Release/v0.2.3 (#355)
* feat(web): add PageEmpty component
* feat(web): add PageTabs component
* feat(web): add PageEmpty component
* feat(web): add PageTabs component
* feat(prompt): add history tracking for prompt releases
* feat(web): add prompt menu
* refactor: The PageScrollList component supports two generic parameters
* feat(web): BodyWrapper compoent update PageLoading
* feat(web): add Ontology menu
* feat(web): memory management add scene
* feat(tasks): add celery task configuration for periodic jobs
- Add ignore_result=True to prevent storing results for periodic tasks
- Set max_retries=0 to skip failed periodic tasks without retry attempts
- Configure acks_late=False for immediate acknowledgment in beat tasks
- Add time_limit and soft_time_limit to regenerate_memory_cache task (3600s/3300s)
- Add time_limit and soft_time_limit to workspace_reflection_task (300s/240s)
- Add time_limit and soft_time_limit to run_forgetting_cycle_task (7200s/7000s)
- Improve task reliability and resource management for scheduled jobs
* feat(sandbox): add Node.js code execution support to sandbox
* Release/v0.2.2 (#260)
* [modify] migration script
* [add] migration script
* fix(web): change form message
* fix(web): the memoryContent field is compatible with numbers and strings
* feat(web): code node hidden
* fix(model):
1. create a basic model to check if the name and provider are duplicated.
2. The result shows error models because the provider created API Keys for all matching models.
---------
Co-authored-by: Mark <zhuwenhui5566@163.com>
Co-authored-by: zhaoying <yzhao96@best-inc.com>
Co-authored-by: yingzhao <zhaoyingyz@126.com>
Co-authored-by: Timebomb2018 <18868801967@163.com>
* Feature/ontology class clean (#249)
* [add] Complete ontology engineering feature implementation
* [add] Add ontology feature integration and validation utilities
* [add] Add OWL validator and validation utilities
* [fix] Add missing render_ontology_extraction_prompt function
* [fix]Add dependencies, fix functionality
* [add] migration script
* feat(celery): add dedicated periodic tasks worker and queue (#261)
* fix(web): conflict resolve
* Fix/v022 bug (#263)
* [fix]Fix the issue of inconsistent language in explicit and episodic memory.
* [fix]Fix the issue of inconsistent language in explicit and episodic memory.
* [add]Add scene_id
* [fix]Based on the AI review to fix the code
* Fix/develop memory reflex (#265)
* 遗漏的历史映射
* 遗漏的历史映射
* 反思后台报错处理
* [add] migration script
* fix: chat conversation_id add node_start
* feat(web): show code node
* fix(web): Restructure the CustomSelect component, repair the interface that is called multiple times when the form is updated
* feat(web): RadioGroupCard support block mode
* feat(web): create space add icon
* feat(app and model): token consumption statistics
* Add/develop memory (#264)
* 遗漏的历史映射
* 遗漏的历史映射
* 遗漏的历史映射
* 遗漏的历史映射
* 遗漏的历史映射
* 遗漏的历史映射
* 遗漏的历史映射
* 遗漏的历史映射
* 遗漏的历史映射
* 新增长期记忆功能
* 新增长期记忆功能
* 新增长期记忆功能
* 知识库检索多余字段
* 长期
* feat(app and model): token consumption statistics of the cluster
* memory_BUG_fix
* fix(web): prompt history remove pageLoading
* fix(prompt): remove hard-coded import of prompt file paths (#279)
* Fix/develop memory bug (#274)
* 遗漏的历史映射
* 遗漏的历史映射
* fix_timeline_memories
* fix(web): update retrieve_type key
* Fix/develop memory bug (#276)
* 遗漏的历史映射
* 遗漏的历史映射
* fix_timeline_memories
* fix_timeline_memories
* write_gragp/bug_fix
* write_gragp/bug_fix
* write_gragp/bug_fix
* chore(celery): disable periodic task scheduling
* fix(prompt): remove hard-coded import of prompt file paths
---------
Co-authored-by: lixinyue11 <94037597+lixinyue11@users.noreply.github.com>
Co-authored-by: zhaoying <yzhao96@best-inc.com>
Co-authored-by: yingzhao <zhaoyingyz@126.com>
Co-authored-by: Ke Sun <kesun5@illinois.edu>
* fix(web): remove delete confirm content
* refactor(workflow): relocate template directory into workflow
* feat(memory): add long-term storage task routing and batching
* fix(web): PageScrollList loading update
* fix(web): PageScrollList loading update
* Ontology v1 bug (#291)
* [changes]Add 'id' as the secondary sorting key, and 'scene_id' now returns a UUID object
* [fix]Fix the "end_user" return to be sorted by update time.
* [fix]Set the default values of the memory configuration model based on the spatial model.
* [fix]Remove the entity extraction check combination model, read the configuration list, and add the return of scene_id
* [fix]Fix the "end_user" return to be sorted by update time.
* [fix]
* fix(memory): add Redis session validation
- Add macOS fork() safety configuration in celery_app.py to prevent initialization issues
- Add null/False checks for Redis session queries in term_memory_save to handle missing sessions gracefully
- Add null/False checks in memory_long_term_storage to prevent processing empty Redis results
- Add null/False checks in aggregate_judgment before format_parsing to avoid errors on missing data
- Initialize redis_messages variable in window_dialogue for consistency
- Add debug logging when no existing session found in Redis for better troubleshooting
- Add TODO comments for magic numbers (scope=6, time=5) to be extracted as constants
- Improve error handling when Redis returns False or empty results instead of crashing
* fix(web): PageScrollList style update
* fix(workflow): fix argument passing in code execution nodes
* fix(web): prompt add disabled
* fix(web): space icon required
* feat(app): modify the key of the token
* fix(fix the key of the app's token):
* fix(workflow): switch code input encoding to base64+URL encoding
* [add]The main project adds multi-API Key load balancing.
* [changes]Attribute security access, secure numerical conversion, unified use of local variables
* fix(web): save add session update
* fix(web): language editor support paste
* [changes]Active status filtering logic, API Key selection strategy
* memory_BUG
* memory_BUG_long_term
* [changes]
* memory_BUG_long_term
* memory_BUG_long_term
* Fix/release memory bug (#306)
* memory_BUG_fix
* memory_BUG
* memory_BUG_long_term
* memory_BUG_long_term
* memory_BUG_long_term
* knowledge_retrieval/bug/fix
* knowledge_retrieval/bug/fix
* knowledge_retrieval/bug/fix
* [fix]1.The "read_all_config" interface returns "scene_name";2.Memory configuration for lightweight query ontology scenarios
* fix(web): replace code editor
* [changes]Modify the description of the time for the recent event
* [changes]Modify the code based on the AI review
* feat(web): update memory config ontology api
* fix(web): ui update
* knowledge_retrieval/bug/fix
* knowledge_retrieval/bug/fix
* knowledge_retrieval/bug/fix
* feat(workflow): add token usage statistics for question classifier and parameter extraction
* feat(web): move prompt menu
* Multiple independent transactions - single transaction
* Multiple independent transactions - single transaction
* Multiple independent transactions - single transaction
* Multiple independent transactions - single transaction
* Write Missing None (#321)
* Write Missing None
* Write Missing None
* Write Missing None
* Apply suggestion from @sourcery-ai[bot]
Co-authored-by: sourcery-ai[bot] <58596630+sourcery-ai[bot]@users.noreply.github.com>
* Write Missing None
---------
Co-authored-by: sourcery-ai[bot] <58596630+sourcery-ai[bot]@users.noreply.github.com>
* Fix/release memory bug (#324)
* Write Missing None
* Write Missing None
* Write Missing None
* Apply suggestion from @sourcery-ai[bot]
Co-authored-by: sourcery-ai[bot] <58596630+sourcery-ai[bot]@users.noreply.github.com>
* Write Missing None
* redis update
* redis update
* redis update
* redis update
---------
Co-authored-by: sourcery-ai[bot] <58596630+sourcery-ai[bot]@users.noreply.github.com>
* Fix/writer memory bug (#326)
* [fix]Fix the bug
* [fix]Fix the bug
* [fix]Correct the direction indication.
* fix(web): markdown table ui update
* Fix/release memory bug (#332)
* Write Missing None
* Write Missing None
* Write Missing None
* Apply suggestion from @sourcery-ai[bot]
Co-authored-by: sourcery-ai[bot] <58596630+sourcery-ai[bot]@users.noreply.github.com>
* Write Missing None
* redis update
* redis update
* redis update
* redis update
* writer_dup_bug/fix
---------
Co-authored-by: sourcery-ai[bot] <58596630+sourcery-ai[bot]@users.noreply.github.com>
* Fix/fact summary (#333)
* [fix]Disable the contents related to fact_summary
* [fix]Disable the contents related to fact_summary
* [fix]Modify the code based on the AI review
* Fix/release memory bug (#335)
* Write Missing None
* Write Missing None
* Write Missing None
* Apply suggestion from @sourcery-ai[bot]
Co-authored-by: sourcery-ai[bot] <58596630+sourcery-ai[bot]@users.noreply.github.com>
* Write Missing None
* redis update
* redis update
* redis update
* redis update
* writer_dup_bug/fix
* writer_graph_bug/fix
* writer_graph_bug/fix
---------
Co-authored-by: sourcery-ai[bot] <58596630+sourcery-ai[bot]@users.noreply.github.com>
* Revert "feat(web): move prompt menu"
This reverts commit 9e6e8f50f8.
* fix(web): ui update
* fix(web): update text
* fix(web): ui update
* fix(model): change the "vl" model type of dashscope to "chat"
* fix(model): change the "vl" model type of dashscope to "chat"
---------
Co-authored-by: zhaoying <yzhao96@best-inc.com>
Co-authored-by: Eternity <1533512157@qq.com>
Co-authored-by: Mark <zhuwenhui5566@163.com>
Co-authored-by: yingzhao <zhaoyingyz@126.com>
Co-authored-by: Timebomb2018 <18868801967@163.com>
Co-authored-by: 乐力齐 <162269739+lanceyq@users.noreply.github.com>
Co-authored-by: lixinyue11 <94037597+lixinyue11@users.noreply.github.com>
Co-authored-by: lixinyue <2569494688@qq.com>
Co-authored-by: Eternity <61316157+myhMARS@users.noreply.github.com>
Co-authored-by: lanceyq <1982376970@qq.com>
Co-authored-by: sourcery-ai[bot] <58596630+sourcery-ai[bot]@users.noreply.github.com>
This commit is contained in:
@@ -0,0 +1,238 @@
|
||||
import json
|
||||
import os
|
||||
|
||||
from app.core.logging_config import get_agent_logger
|
||||
from app.core.memory.agent.langgraph_graph.tools.write_tool import format_parsing, messages_parse
|
||||
from app.core.memory.agent.langgraph_graph.write_graph import make_write_graph, long_term_storage
|
||||
|
||||
from app.core.memory.agent.models.write_aggregate_model import WriteAggregateModel
|
||||
from app.core.memory.agent.utils.llm_tools import PROJECT_ROOT_
|
||||
from app.core.memory.agent.utils.redis_tool import write_store
|
||||
from app.core.memory.agent.utils.redis_tool import count_store
|
||||
from app.core.memory.agent.utils.template_tools import TemplateService
|
||||
from app.core.memory.utils.llm.llm_utils import MemoryClientFactory
|
||||
from app.db import get_db_context, get_db
|
||||
from app.repositories.memory_short_repository import LongTermMemoryRepository
|
||||
from app.schemas.memory_agent_schema import AgentMemory_Long_Term
|
||||
from app.services.memory_konwledges_server import write_rag
|
||||
from app.services.task_service import get_task_memory_write_result
|
||||
from app.tasks import write_message_task
|
||||
from app.utils.config_utils import resolve_config_id
|
||||
logger = get_agent_logger(__name__)
|
||||
template_root = os.path.join(PROJECT_ROOT_, 'memory', 'agent', 'utils', 'prompt')
|
||||
|
||||
async def write_rag_agent(end_user_id, user_message, ai_message, user_rag_memory_id):
|
||||
# RAG 模式:组合消息为字符串格式(保持原有逻辑)
|
||||
combined_message = f"user: {user_message}\nassistant: {ai_message}"
|
||||
await write_rag(end_user_id, combined_message, user_rag_memory_id)
|
||||
logger.info(f'RAG_Agent:{end_user_id};{user_rag_memory_id}')
|
||||
async def write(storage_type, end_user_id, user_message, ai_message, user_rag_memory_id, actual_end_user_id,
|
||||
actual_config_id, long_term_messages=[]):
|
||||
"""
|
||||
写入记忆(支持结构化消息)
|
||||
|
||||
Args:
|
||||
storage_type: 存储类型 (neo4j/rag)
|
||||
end_user_id: 终端用户ID
|
||||
user_message: 用户消息内容
|
||||
ai_message: AI 回复内容
|
||||
user_rag_memory_id: RAG 记忆ID
|
||||
actual_end_user_id: 实际用户ID
|
||||
actual_config_id: 配置ID
|
||||
|
||||
逻辑说明:
|
||||
- RAG 模式:组合 user_message 和 ai_message 为字符串格式,保持原有逻辑不变
|
||||
- Neo4j 模式:使用结构化消息列表
|
||||
1. 如果 user_message 和 ai_message 都不为空:创建配对消息 [user, assistant]
|
||||
2. 如果只有 user_message:创建单条用户消息 [user](用于历史记忆场景)
|
||||
3. 每条消息会被转换为独立的 Chunk,保留 speaker 字段
|
||||
"""
|
||||
|
||||
db = next(get_db())
|
||||
try:
|
||||
actual_config_id = resolve_config_id(actual_config_id, db)
|
||||
# Neo4j 模式:使用结构化消息列表
|
||||
structured_messages = []
|
||||
|
||||
# 始终添加用户消息(如果不为空)
|
||||
if isinstance(user_message, str) and user_message.strip() != "":
|
||||
structured_messages.append({"role": "user", "content": user_message})
|
||||
|
||||
# 只有当 AI 回复不为空时才添加 assistant 消息
|
||||
if isinstance(ai_message, str) and ai_message.strip() != "":
|
||||
structured_messages.append({"role": "assistant", "content": ai_message})
|
||||
|
||||
# 如果提供了 long_term_messages,使用它替代 structured_messages
|
||||
if long_term_messages and isinstance(long_term_messages, list):
|
||||
structured_messages = long_term_messages
|
||||
elif long_term_messages and isinstance(long_term_messages, str):
|
||||
# 如果是 JSON 字符串,先解析
|
||||
try:
|
||||
structured_messages = json.loads(long_term_messages)
|
||||
except json.JSONDecodeError:
|
||||
logger.error(f"Failed to parse long_term_messages as JSON: {long_term_messages}")
|
||||
|
||||
# 如果没有消息,直接返回
|
||||
if not structured_messages:
|
||||
logger.warning(f"No messages to write for user {actual_end_user_id}")
|
||||
return
|
||||
|
||||
logger.info(
|
||||
f"[WRITE] Submitting Celery task - user={actual_end_user_id}, messages={len(structured_messages)}, config={actual_config_id}")
|
||||
write_id = write_message_task.delay(
|
||||
actual_end_user_id, # end_user_id: 用户ID
|
||||
structured_messages, # message: JSON 字符串格式的消息列表
|
||||
str(actual_config_id), # config_id: 配置ID字符串
|
||||
storage_type, # storage_type: "neo4j"
|
||||
user_rag_memory_id or "" # user_rag_memory_id: RAG记忆ID(Neo4j模式下不使用)
|
||||
)
|
||||
logger.info(f"[WRITE] Celery task submitted - task_id={write_id}")
|
||||
write_status = get_task_memory_write_result(str(write_id))
|
||||
logger.info(f'[WRITE] Task result - user={actual_end_user_id}, status={write_status}')
|
||||
finally:
|
||||
db.close()
|
||||
|
||||
async def term_memory_save(long_term_messages,actual_config_id,end_user_id,type,scope):
|
||||
with get_db_context() as db_session:
|
||||
repo = LongTermMemoryRepository(db_session)
|
||||
|
||||
|
||||
from app.core.memory.agent.utils.redis_tool import write_store
|
||||
result = write_store.get_session_by_userid(end_user_id)
|
||||
if type==AgentMemory_Long_Term.STRATEGY_CHUNK or AgentMemory_Long_Term.STRATEGY_AGGREGATE:
|
||||
data = await format_parsing(result, "dict")
|
||||
chunk_data = data[:scope]
|
||||
if len(chunk_data)==scope:
|
||||
repo.upsert(end_user_id, chunk_data)
|
||||
logger.info(f'---------写入短长期-----------')
|
||||
else:
|
||||
long_time_data = write_store.find_user_recent_sessions(end_user_id, 5)
|
||||
long_messages = await messages_parse(long_time_data)
|
||||
repo.upsert(end_user_id, long_messages)
|
||||
logger.info(f'写入短长期:')
|
||||
|
||||
|
||||
|
||||
'''根据窗口'''
|
||||
async def window_dialogue(end_user_id,langchain_messages,memory_config,scope):
|
||||
'''
|
||||
根据窗口获取redis数据,写入neo4j:
|
||||
Args:
|
||||
end_user_id: 终端用户ID
|
||||
memory_config: 内存配置对象
|
||||
langchain_messages:原始数据LIST
|
||||
scope:窗口大小
|
||||
'''
|
||||
scope=scope
|
||||
is_end_user_id = count_store.get_sessions_count(end_user_id)
|
||||
if is_end_user_id is not False:
|
||||
is_end_user_id = count_store.get_sessions_count(end_user_id)[0]
|
||||
redis_messages = count_store.get_sessions_count(end_user_id)[1]
|
||||
if is_end_user_id and int(is_end_user_id) != int(scope):
|
||||
is_end_user_id += 1
|
||||
langchain_messages += redis_messages
|
||||
count_store.update_sessions_count(end_user_id, is_end_user_id, langchain_messages)
|
||||
elif int(is_end_user_id) == int(scope):
|
||||
logger.info('写入长期记忆NEO4J')
|
||||
formatted_messages = (redis_messages)
|
||||
# 获取 config_id(如果 memory_config 是对象,提取 config_id;否则直接使用)
|
||||
if hasattr(memory_config, 'config_id'):
|
||||
config_id = memory_config.config_id
|
||||
else:
|
||||
config_id = memory_config
|
||||
|
||||
await write(AgentMemory_Long_Term.STORAGE_NEO4J, end_user_id, "", "", None, end_user_id,
|
||||
config_id, formatted_messages)
|
||||
count_store.update_sessions_count(end_user_id, 1, langchain_messages)
|
||||
else:
|
||||
count_store.save_sessions_count(end_user_id, 1, langchain_messages)
|
||||
|
||||
|
||||
"""根据时间"""
|
||||
async def memory_long_term_storage(end_user_id,memory_config,time):
|
||||
'''
|
||||
根据时间获取redis数据,写入neo4j:
|
||||
Args:
|
||||
end_user_id: 终端用户ID
|
||||
memory_config: 内存配置对象
|
||||
'''
|
||||
long_time_data = write_store.find_user_recent_sessions(end_user_id, time)
|
||||
format_messages = (long_time_data)
|
||||
messages=[]
|
||||
memory_config=memory_config.config_id
|
||||
for i in format_messages:
|
||||
message=json.loads(i['Query'])
|
||||
messages+= message
|
||||
if format_messages!=[]:
|
||||
await write(AgentMemory_Long_Term.STORAGE_NEO4J, end_user_id, "", "", None, end_user_id,
|
||||
memory_config, messages)
|
||||
'''聚合判断'''
|
||||
async def aggregate_judgment(end_user_id: str, ori_messages: list, memory_config) -> dict:
|
||||
"""
|
||||
聚合判断函数:判断输入句子和历史消息是否描述同一事件
|
||||
|
||||
Args:
|
||||
end_user_id: 终端用户ID
|
||||
ori_messages: 原始消息列表,格式如 [{"role": "user", "content": "..."}, {"role": "assistant", "content": "..."}]
|
||||
memory_config: 内存配置对象
|
||||
"""
|
||||
|
||||
try:
|
||||
# 1. 获取历史会话数据(使用新方法)
|
||||
result = write_store.get_all_sessions_by_end_user_id(end_user_id)
|
||||
history = await format_parsing(result)
|
||||
if not result:
|
||||
history = []
|
||||
else:
|
||||
history = await format_parsing(result)
|
||||
json_schema = WriteAggregateModel.model_json_schema()
|
||||
template_service = TemplateService(template_root)
|
||||
system_prompt = await template_service.render_template(
|
||||
template_name='write_aggregate_judgment.jinja2',
|
||||
operation_name='aggregate_judgment',
|
||||
history=history,
|
||||
sentence=ori_messages,
|
||||
json_schema=json_schema
|
||||
)
|
||||
with get_db_context() as db_session:
|
||||
factory = MemoryClientFactory(db_session)
|
||||
llm_client = factory.get_llm_client(memory_config.llm_model_id)
|
||||
messages = [
|
||||
{
|
||||
"role": "user",
|
||||
"content": system_prompt
|
||||
}
|
||||
]
|
||||
structured = await llm_client.response_structured(
|
||||
messages=messages,
|
||||
response_model=WriteAggregateModel
|
||||
)
|
||||
output_value = structured.output
|
||||
if isinstance(output_value, list):
|
||||
output_value = [
|
||||
{"role": msg.role, "content": msg.content}
|
||||
for msg in output_value
|
||||
]
|
||||
|
||||
result_dict = {
|
||||
"is_same_event": structured.is_same_event,
|
||||
"output": output_value
|
||||
}
|
||||
if not structured.is_same_event:
|
||||
logger.info(result_dict)
|
||||
await write("neo4j", end_user_id, "", "", None, end_user_id,
|
||||
memory_config.config_id, output_value)
|
||||
return result_dict
|
||||
|
||||
except Exception as e:
|
||||
print(f"[aggregate_judgment] 发生错误: {e}")
|
||||
import traceback
|
||||
traceback.print_exc()
|
||||
|
||||
return {
|
||||
"is_same_event": False,
|
||||
"output": ori_messages,
|
||||
"messages": ori_messages,
|
||||
"history": history if 'history' in locals() else [],
|
||||
"error": str(e)
|
||||
}
|
||||
@@ -186,10 +186,11 @@ def create_hybrid_retrieval_tool_async(memory_config, **search_params):
|
||||
清理后的数据
|
||||
"""
|
||||
# 需要过滤的字段列表
|
||||
# TODO: fact_summary 功能暂时禁用,待后续开发完善后启用
|
||||
fields_to_remove = {
|
||||
'invalid_at', 'valid_at', 'chunk_id_from_rel', 'entity_ids',
|
||||
'expired_at', 'created_at', 'chunk_id', 'id', 'apply_id',
|
||||
'user_id', 'statement_ids', 'updated_at',"chunk_ids","fact_summary"
|
||||
'user_id', 'statement_ids', 'updated_at',"chunk_ids" ,"fact_summary"
|
||||
}
|
||||
|
||||
if isinstance(data, dict):
|
||||
|
||||
@@ -0,0 +1,72 @@
|
||||
import json
|
||||
|
||||
from langchain_core.messages import HumanMessage, AIMessage
|
||||
async def format_parsing(messages: list,type:str='string'):
|
||||
"""
|
||||
格式化解析消息列表
|
||||
|
||||
Args:
|
||||
messages: 消息列表
|
||||
type: 返回类型 ('string' 或 'dict')
|
||||
|
||||
Returns:
|
||||
格式化后的消息列表
|
||||
"""
|
||||
result = []
|
||||
user=[]
|
||||
ai=[]
|
||||
|
||||
for message in messages:
|
||||
hstory_messages = message['messages']
|
||||
for history_messag in hstory_messages.strip().splitlines():
|
||||
history_messag = json.loads(history_messag)
|
||||
for content in history_messag:
|
||||
role = content['role']
|
||||
content = content['content']
|
||||
if type == "string":
|
||||
if role == 'human' or role=="user":
|
||||
content = '用户:' + content
|
||||
else:
|
||||
content = 'AI:' + content
|
||||
result.append(content)
|
||||
if type == "dict" :
|
||||
if role == 'human' or role=="user":
|
||||
user.append( content)
|
||||
else:
|
||||
ai.append(content)
|
||||
if type == "dict":
|
||||
for key,values in zip(user,ai):
|
||||
result.append({key:values})
|
||||
return result
|
||||
|
||||
async def messages_parse(messages: list | dict):
|
||||
user=[]
|
||||
ai=[]
|
||||
database=[]
|
||||
for message in messages:
|
||||
Query = message['Query']
|
||||
Query = json.loads(Query)
|
||||
for data in Query:
|
||||
role = data['role']
|
||||
if role == "human":
|
||||
user.append(data['content'])
|
||||
if role == "ai":
|
||||
ai.append(data['content'])
|
||||
for key, values in zip(user, ai):
|
||||
database.append({key, values})
|
||||
return database
|
||||
|
||||
|
||||
async def agent_chat_messages(user_content,ai_content):
|
||||
messages = [
|
||||
{
|
||||
"role": "user",
|
||||
"content": f"{user_content}"
|
||||
},
|
||||
{
|
||||
"role": "assistant",
|
||||
"content": f"{ai_content}"
|
||||
}
|
||||
|
||||
]
|
||||
return messages
|
||||
@@ -1,22 +1,20 @@
|
||||
|
||||
import asyncio
|
||||
import json
|
||||
import sys
|
||||
import warnings
|
||||
from contextlib import asynccontextmanager
|
||||
|
||||
|
||||
from langchain_core.messages import HumanMessage
|
||||
from langgraph.constants import END, START
|
||||
from langgraph.graph import StateGraph
|
||||
|
||||
|
||||
from app.db import get_db
|
||||
from app.db import get_db, get_db_context
|
||||
from app.core.logging_config import get_agent_logger
|
||||
from app.core.memory.agent.utils.llm_tools import WriteState
|
||||
from app.core.memory.agent.langgraph_graph.nodes.write_nodes import write_node
|
||||
from app.core.memory.agent.langgraph_graph.nodes.data_nodes import content_input_write
|
||||
from app.schemas.memory_agent_schema import AgentMemory_Long_Term
|
||||
from app.services.memory_config_service import MemoryConfigService
|
||||
|
||||
|
||||
warnings.filterwarnings("ignore", category=RuntimeWarning)
|
||||
logger = get_agent_logger(__name__)
|
||||
|
||||
@@ -34,14 +32,6 @@ async def make_write_graph():
|
||||
end_user_id: Group identifier
|
||||
memory_config: MemoryConfig object containing all configuration
|
||||
"""
|
||||
# workflow = StateGraph(WriteState)
|
||||
# workflow.add_node("content_input", content_input_write)
|
||||
# workflow.add_node("save_neo4j", write_node)
|
||||
# workflow.add_edge(START, "content_input")
|
||||
# workflow.add_edge("content_input", "save_neo4j")
|
||||
# workflow.add_edge("save_neo4j", END)
|
||||
#
|
||||
# graph = workflow.compile()
|
||||
workflow = StateGraph(WriteState)
|
||||
workflow.add_node("save_neo4j", write_node)
|
||||
workflow.add_edge(START, "save_neo4j")
|
||||
@@ -51,43 +41,63 @@ async def make_write_graph():
|
||||
|
||||
yield graph
|
||||
|
||||
|
||||
async def main():
|
||||
"""主函数 - 运行工作流"""
|
||||
message = "今天周一"
|
||||
end_user_id = 'new_2025test1103' # 组ID
|
||||
|
||||
|
||||
async def long_term_storage(long_term_type:str="chunk",langchain_messages:list=[],memory_config:str='',end_user_id:str='',scope:int=6):
|
||||
from app.core.memory.agent.langgraph_graph.routing.write_router import memory_long_term_storage, window_dialogue,aggregate_judgment
|
||||
from app.core.memory.agent.utils.redis_tool import write_store
|
||||
write_store.save_session_write(end_user_id, (langchain_messages))
|
||||
# 获取数据库会话
|
||||
db_session = next(get_db())
|
||||
config_service = MemoryConfigService(db_session)
|
||||
memory_config = config_service.load_memory_config(
|
||||
config_id=17, # 改为整数
|
||||
service_name="MemoryAgentService"
|
||||
)
|
||||
try:
|
||||
async with make_write_graph() as graph:
|
||||
config = {"configurable": {"thread_id": end_user_id}}
|
||||
# 初始状态 - 包含所有必要字段
|
||||
initial_state = {"messages": [HumanMessage(content=message)], "end_user_id": end_user_id, "memory_config": memory_config}
|
||||
|
||||
# 获取节点更新信息
|
||||
async for update_event in graph.astream(
|
||||
initial_state,
|
||||
stream_mode="updates",
|
||||
config=config
|
||||
):
|
||||
for node_name, node_data in update_event.items():
|
||||
if 'save_neo4j'==node_name:
|
||||
massages=node_data
|
||||
massages=massages.get('write_result')['status']
|
||||
print(massages) # | 更新数据: {node_data}
|
||||
|
||||
except Exception as e:
|
||||
import traceback
|
||||
traceback.print_exc()
|
||||
with get_db_context() as db_session:
|
||||
config_service = MemoryConfigService(db_session)
|
||||
memory_config = config_service.load_memory_config(
|
||||
config_id=memory_config, # 改为整数
|
||||
service_name="MemoryAgentService"
|
||||
)
|
||||
if long_term_type=='chunk':
|
||||
'''方案一:对话窗口6轮对话'''
|
||||
await window_dialogue(end_user_id,langchain_messages,memory_config,scope)
|
||||
if long_term_type=='time':
|
||||
"""时间"""
|
||||
await memory_long_term_storage(end_user_id, memory_config,5)
|
||||
if long_term_type=='aggregate':
|
||||
"""方案三:聚合判断"""
|
||||
await aggregate_judgment(end_user_id, langchain_messages, memory_config)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
import asyncio
|
||||
asyncio.run(main())
|
||||
|
||||
async def write_long_term(storage_type,end_user_id,message_chat,aimessages,user_rag_memory_id,actual_config_id):
|
||||
from app.core.memory.agent.langgraph_graph.routing.write_router import write_rag_agent
|
||||
from app.core.memory.agent.langgraph_graph.routing.write_router import term_memory_save
|
||||
from app.core.memory.agent.langgraph_graph.tools.write_tool import agent_chat_messages
|
||||
if storage_type == AgentMemory_Long_Term.STORAGE_RAG:
|
||||
await write_rag_agent(end_user_id, message_chat, aimessages, user_rag_memory_id)
|
||||
else:
|
||||
# AI 回复写入(用户消息和 AI 回复配对,一次性写入完整对话)
|
||||
CHUNK = AgentMemory_Long_Term.STRATEGY_CHUNK
|
||||
SCOPE = AgentMemory_Long_Term.DEFAULT_SCOPE
|
||||
long_term_messages = await agent_chat_messages(message_chat, aimessages)
|
||||
await long_term_storage(long_term_type=CHUNK, langchain_messages=long_term_messages,
|
||||
memory_config=actual_config_id, end_user_id=end_user_id, scope=SCOPE)
|
||||
await term_memory_save(long_term_messages, actual_config_id, end_user_id, CHUNK, scope=SCOPE)
|
||||
|
||||
# async def main():
|
||||
# """主函数 - 运行工作流"""
|
||||
# langchain_messages = [
|
||||
# {
|
||||
# "role": "user",
|
||||
# "content": "今天周五去爬山"
|
||||
# },
|
||||
# {
|
||||
# "role": "assistant",
|
||||
# "content": "好耶"
|
||||
# }
|
||||
#
|
||||
# ]
|
||||
# end_user_id = '837fee1b-04a2-48ee-94d7-211488908940' # 组ID
|
||||
# memory_config="08ed205c-0f05-49c3-8e0c-a580d28f5fd4"
|
||||
# await long_term_storage(long_term_type="chunk",langchain_messages=langchain_messages,memory_config=memory_config,end_user_id=end_user_id,scope=2)
|
||||
#
|
||||
#
|
||||
#
|
||||
# if __name__ == "__main__":
|
||||
# import asyncio
|
||||
# asyncio.run(main())
|
||||
28
api/app/core/memory/agent/models/write_aggregate_model.py
Normal file
28
api/app/core/memory/agent/models/write_aggregate_model.py
Normal file
@@ -0,0 +1,28 @@
|
||||
"""Pydantic models for write aggregate judgment operations."""
|
||||
|
||||
from typing import List, Union
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
|
||||
class MessageItem(BaseModel):
|
||||
"""Individual message item in conversation."""
|
||||
|
||||
role: str = Field(..., description="角色:user 或 assistant")
|
||||
content: str = Field(..., description="消息内容")
|
||||
|
||||
|
||||
class WriteAggregateResponse(BaseModel):
|
||||
"""Response model for aggregate judgment containing judgment result and output."""
|
||||
|
||||
is_same_event: bool = Field(
|
||||
...,
|
||||
description="是否是同一事件。True表示是同一事件,False表示不同事件"
|
||||
)
|
||||
output: Union[List[MessageItem], bool] = Field(
|
||||
...,
|
||||
description="如果is_same_event为True,返回False;如果is_same_event为False,返回消息列表"
|
||||
)
|
||||
|
||||
|
||||
# 为了保持向后兼容,保留旧的类名作为别名
|
||||
WriteAggregateModel = WriteAggregateResponse
|
||||
@@ -0,0 +1,57 @@
|
||||
输入句子:{{sentence}}
|
||||
历史消息:{{history}}
|
||||
|
||||
# 你的角色
|
||||
你是一个擅长事件聚合与语义判断的专家。
|
||||
|
||||
# 你的任务
|
||||
结合历史消息和输入句子,判断它们是否在描述**同一件事件或同一事件链**。
|
||||
|
||||
以下情况视为"同一事件"(需要返回 is_same_event=True, output=False):
|
||||
- 描述的是同一个具体事件或事实
|
||||
- 存在明显的因果关系、前后发展关系
|
||||
- 是对同一事件的补充、解释、追问或延展
|
||||
- 逻辑上属于同一语境下的连续讨论
|
||||
|
||||
以下情况视为"不同事件"(需要返回 is_same_event=False, output=消息列表):
|
||||
- 话题不同,事件主体不同
|
||||
- 时间、地点、对象明显不同
|
||||
- 只是语义相似,但并非同一具体事件
|
||||
- 无直接事件、因果或逻辑关联
|
||||
|
||||
# 输出规则(非常重要)
|
||||
你必须按照以下JSON格式输出:
|
||||
|
||||
**如果是同一事件:**
|
||||
```json
|
||||
{
|
||||
"is_same_event": true,
|
||||
"output": false
|
||||
}
|
||||
```
|
||||
|
||||
**如果不是同一事件:**
|
||||
```json
|
||||
{
|
||||
"is_same_event": false,
|
||||
"output": [
|
||||
{
|
||||
"role": "user",
|
||||
"content": "输入句子的内容"
|
||||
},
|
||||
{
|
||||
"role": "assistant",
|
||||
"content": "对应的回复内容"
|
||||
}
|
||||
]
|
||||
}
|
||||
```
|
||||
|
||||
# JSON Schema
|
||||
{{json_schema}}
|
||||
|
||||
# 注意事项
|
||||
- 必须严格按照上述格式输出
|
||||
- output 字段:如果是同一事件返回 false,如果不是同一事件返回完整的消息列表
|
||||
- 消息列表必须包含 role 和 content 字段
|
||||
- 不要输出任何解释、分析或多余内容
|
||||
186
api/app/core/memory/agent/utils/redis_base.py
Normal file
186
api/app/core/memory/agent/utils/redis_base.py
Normal file
@@ -0,0 +1,186 @@
|
||||
import json
|
||||
from typing import Any, List, Dict, Optional
|
||||
from datetime import datetime, timedelta
|
||||
|
||||
|
||||
def serialize_messages(messages: Any) -> str:
|
||||
"""
|
||||
将消息序列化为 JSON 字符串,支持 LangChain 消息对象
|
||||
|
||||
Args:
|
||||
messages: 可以是 list、dict、string 或 LangChain 消息对象列表
|
||||
|
||||
Returns:
|
||||
str: JSON 字符串
|
||||
"""
|
||||
if isinstance(messages, str):
|
||||
return messages
|
||||
|
||||
if isinstance(messages, (list, tuple)):
|
||||
# 检查是否是 LangChain 消息对象列表
|
||||
serialized_list = []
|
||||
for msg in messages:
|
||||
if hasattr(msg, 'type') and hasattr(msg, 'content'):
|
||||
# LangChain 消息对象
|
||||
serialized_list.append({
|
||||
'type': msg.type,
|
||||
'content': msg.content,
|
||||
'role': getattr(msg, 'role', msg.type)
|
||||
})
|
||||
elif isinstance(msg, dict):
|
||||
serialized_list.append(msg)
|
||||
else:
|
||||
serialized_list.append(str(msg))
|
||||
return json.dumps(serialized_list, ensure_ascii=False)
|
||||
|
||||
if isinstance(messages, dict):
|
||||
return json.dumps(messages, ensure_ascii=False)
|
||||
|
||||
# 其他类型转为字符串
|
||||
return str(messages)
|
||||
|
||||
|
||||
def deserialize_messages(messages_str: str) -> Any:
|
||||
"""
|
||||
将 JSON 字符串反序列化为原始格式
|
||||
|
||||
Args:
|
||||
messages_str: JSON 字符串
|
||||
|
||||
Returns:
|
||||
反序列化后的对象(list、dict 或 string)
|
||||
"""
|
||||
if not messages_str:
|
||||
return []
|
||||
|
||||
try:
|
||||
return json.loads(messages_str)
|
||||
except (json.JSONDecodeError, TypeError):
|
||||
return messages_str
|
||||
|
||||
|
||||
def fix_encoding(text: str) -> str:
|
||||
"""
|
||||
修复错误编码的文本
|
||||
|
||||
Args:
|
||||
text: 需要修复的文本
|
||||
|
||||
Returns:
|
||||
str: 修复后的文本
|
||||
"""
|
||||
if not text or not isinstance(text, str):
|
||||
return text
|
||||
try:
|
||||
# 尝试修复 Latin-1 误编码为 UTF-8 的情况
|
||||
return text.encode('latin-1').decode('utf-8')
|
||||
except (UnicodeDecodeError, UnicodeEncodeError):
|
||||
# 如果修复失败,返回原文本
|
||||
return text
|
||||
|
||||
|
||||
def format_session_data(data: Dict[str, Any], include_time: bool = False) -> Dict[str, Any]:
|
||||
"""
|
||||
格式化会话数据为统一的输出格式
|
||||
|
||||
Args:
|
||||
data: 原始会话数据
|
||||
include_time: 是否包含时间字段
|
||||
|
||||
Returns:
|
||||
Dict: 格式化后的数据 {"Query": "...", "Answer": "...", "starttime": "..."}
|
||||
"""
|
||||
result = {
|
||||
"Query": fix_encoding(data.get('messages', '')),
|
||||
"Answer": fix_encoding(data.get('aimessages', ''))
|
||||
}
|
||||
|
||||
if include_time:
|
||||
result["starttime"] = data.get('starttime', '')
|
||||
|
||||
return result
|
||||
|
||||
|
||||
def filter_by_time_range(items: List[Dict], minutes: int) -> List[Dict]:
|
||||
"""
|
||||
根据时间范围过滤数据
|
||||
|
||||
Args:
|
||||
items: 包含 starttime 字段的数据列表
|
||||
minutes: 时间范围(分钟)
|
||||
|
||||
Returns:
|
||||
List[Dict]: 过滤后的数据列表
|
||||
"""
|
||||
time_threshold = datetime.now() - timedelta(minutes=minutes)
|
||||
time_threshold_str = time_threshold.strftime("%Y-%m-%d %H:%M:%S")
|
||||
|
||||
filtered_items = []
|
||||
for item in items:
|
||||
starttime = item.get('starttime', '')
|
||||
if starttime and starttime >= time_threshold_str:
|
||||
filtered_items.append(item)
|
||||
|
||||
return filtered_items
|
||||
|
||||
|
||||
def sort_and_limit_results(items: List[Dict], limit: int = 6,
|
||||
remove_time: bool = True) -> List[Dict]:
|
||||
"""
|
||||
对结果进行排序、限制数量并移除时间字段
|
||||
|
||||
Args:
|
||||
items: 数据列表
|
||||
limit: 最大返回数量
|
||||
remove_time: 是否移除 starttime 字段
|
||||
|
||||
Returns:
|
||||
List[Dict]: 处理后的数据列表
|
||||
"""
|
||||
# 按时间降序排序(最新的在前)
|
||||
items.sort(key=lambda x: x.get('starttime', ''), reverse=True)
|
||||
|
||||
# 限制数量
|
||||
result_items = items[:limit]
|
||||
|
||||
# 移除 starttime 字段
|
||||
if remove_time:
|
||||
for item in result_items:
|
||||
item.pop('starttime', None)
|
||||
|
||||
# 如果结果少于1条,返回空列表
|
||||
if len(result_items) < 1:
|
||||
return []
|
||||
|
||||
return result_items
|
||||
|
||||
|
||||
def generate_session_key(session_id: str, key_type: str = "session") -> str:
|
||||
"""
|
||||
生成 Redis key
|
||||
|
||||
Args:
|
||||
session_id: 会话ID
|
||||
key_type: key 类型 ("session", "read", "write", "count")
|
||||
|
||||
Returns:
|
||||
str: Redis key
|
||||
"""
|
||||
if key_type == "count":
|
||||
return f"session:count:{session_id}"
|
||||
elif key_type == "write":
|
||||
return f"session:write:{session_id}"
|
||||
elif key_type == "session" or key_type == "read":
|
||||
return f"session:{session_id}"
|
||||
else:
|
||||
return f"session:{session_id}"
|
||||
|
||||
|
||||
def get_current_timestamp() -> str:
|
||||
"""
|
||||
获取当前时间戳字符串
|
||||
|
||||
Returns:
|
||||
str: 格式化的时间字符串 "YYYY-MM-DD HH:MM:SS"
|
||||
"""
|
||||
return datetime.now().strftime("%Y-%m-%d %H:%M:%S")
|
||||
@@ -1,11 +1,36 @@
|
||||
import redis
|
||||
import uuid
|
||||
from datetime import datetime
|
||||
from app.core.config import settings
|
||||
from typing import List, Dict, Any, Optional, Union
|
||||
|
||||
from app.core.memory.agent.utils.redis_base import (
|
||||
serialize_messages,
|
||||
deserialize_messages,
|
||||
fix_encoding,
|
||||
format_session_data,
|
||||
filter_by_time_range,
|
||||
sort_and_limit_results,
|
||||
generate_session_key,
|
||||
get_current_timestamp
|
||||
)
|
||||
|
||||
|
||||
class RedisSessionStore:
|
||||
|
||||
|
||||
class RedisWriteStore:
|
||||
"""Redis Write 类型存储类,用于管理 save_session_write 相关的数据"""
|
||||
|
||||
def __init__(self, host='localhost', port=6379, db=0, password=None, session_id=''):
|
||||
"""
|
||||
初始化 Redis 连接
|
||||
|
||||
Args:
|
||||
host: Redis 主机地址
|
||||
port: Redis 端口
|
||||
db: Redis 数据库编号
|
||||
password: Redis 密码
|
||||
session_id: 会话ID
|
||||
"""
|
||||
self.r = redis.Redis(
|
||||
host=host,
|
||||
port=port,
|
||||
@@ -16,32 +41,437 @@ class RedisSessionStore:
|
||||
)
|
||||
self.uudi = session_id
|
||||
|
||||
def _fix_encoding(self, text):
|
||||
"""修复错误编码的文本"""
|
||||
if not text or not isinstance(text, str):
|
||||
return text
|
||||
try:
|
||||
# 尝试修复 Latin-1 误编码为 UTF-8 的情况
|
||||
return text.encode('latin-1').decode('utf-8')
|
||||
except (UnicodeDecodeError, UnicodeEncodeError):
|
||||
# 如果修复失败,返回原文本
|
||||
return text
|
||||
|
||||
# 修改后的 save_session 方法
|
||||
def save_session(self, userid, messages, aimessages, apply_id, end_user_id):
|
||||
def save_session_write(self, userid: str, messages: str) -> str:
|
||||
"""
|
||||
写入一条会话数据,返回 session_id
|
||||
优化版本:确保写入时间不超过1秒
|
||||
|
||||
Args:
|
||||
userid: 用户ID
|
||||
messages: 用户消息
|
||||
|
||||
Returns:
|
||||
str: 新生成的 session_id
|
||||
"""
|
||||
try:
|
||||
session_id = str(uuid.uuid4()) # 为每次会话生成新的 ID
|
||||
starttime = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
|
||||
key = f"session:{session_id}" # 使用新生成的 session_id 作为 key
|
||||
messages = serialize_messages(messages)
|
||||
session_id = str(uuid.uuid4())
|
||||
key = generate_session_key(session_id, key_type="write")
|
||||
|
||||
# 使用 pipeline 批量写入,减少网络往返
|
||||
pipe = self.r.pipeline()
|
||||
pipe.hset(key, mapping={
|
||||
"id": self.uudi,
|
||||
"sessionid": userid,
|
||||
"messages": messages,
|
||||
"starttime": get_current_timestamp()
|
||||
})
|
||||
result = pipe.execute()
|
||||
|
||||
# 直接写入数据,decode_responses=True 已经处理了编码
|
||||
print(f"[save_session_write] 保存结果: {result[0]}, session_id: {session_id}")
|
||||
return session_id
|
||||
except Exception as e:
|
||||
print(f"[save_session_write] 保存会话失败: {e}")
|
||||
raise e
|
||||
|
||||
def get_session_by_userid(self, userid: str) -> Union[List[Dict[str, str]], bool]:
|
||||
"""
|
||||
通过 save_session_write 的 userid 获取 sessionid 和 messages
|
||||
|
||||
Args:
|
||||
userid: 用户ID (对应 sessionid 字段)
|
||||
|
||||
Returns:
|
||||
List[Dict] 或 False: 如果找到数据返回 [{"sessionid": "...", "messages": "..."}, ...],否则返回 False
|
||||
"""
|
||||
try:
|
||||
# 只查询 write 类型的 key
|
||||
keys = self.r.keys('session:write:*')
|
||||
if not keys:
|
||||
return False
|
||||
|
||||
# 批量获取数据
|
||||
pipe = self.r.pipeline()
|
||||
for key in keys:
|
||||
pipe.hgetall(key)
|
||||
all_data = pipe.execute()
|
||||
|
||||
# 筛选符合 userid 的数据
|
||||
results = []
|
||||
for key, data in zip(keys, all_data):
|
||||
if not data:
|
||||
continue
|
||||
|
||||
# 从 write 类型读取,匹配 sessionid 字段
|
||||
if data.get('sessionid') == userid:
|
||||
# 从 key 中提取 session_id: session:write:{session_id}
|
||||
session_id = key.split(':')[-1]
|
||||
results.append({
|
||||
"sessionid": session_id,
|
||||
"messages": fix_encoding(data.get('messages', ''))
|
||||
})
|
||||
|
||||
if not results:
|
||||
return False
|
||||
|
||||
print(f"[get_session_by_userid] userid={userid}, 找到 {len(results)} 条数据")
|
||||
return results
|
||||
except Exception as e:
|
||||
print(f"[get_session_by_userid] 查询失败: {e}")
|
||||
return False
|
||||
|
||||
def get_all_sessions_by_end_user_id(self, end_user_id: str) -> Union[List[Dict[str, Any]], bool]:
|
||||
"""
|
||||
通过 end_user_id 获取所有 write 类型的会话数据
|
||||
|
||||
Args:
|
||||
end_user_id: 终端用户ID (对应 sessionid 字段)
|
||||
|
||||
Returns:
|
||||
List[Dict] 或 False: 如果找到数据返回完整的会话信息列表,否则返回 False
|
||||
|
||||
返回格式:
|
||||
[
|
||||
{
|
||||
"session_id": "uuid",
|
||||
"id": "...",
|
||||
"sessionid": "end_user_id",
|
||||
"messages": "...",
|
||||
"starttime": "timestamp"
|
||||
},
|
||||
...
|
||||
]
|
||||
"""
|
||||
try:
|
||||
# 只查询 write 类型的 key
|
||||
keys = self.r.keys('session:write:*')
|
||||
if not keys:
|
||||
print(f"[get_all_sessions_by_end_user_id] 没有找到任何 write 类型的会话")
|
||||
return False
|
||||
|
||||
# 批量获取数据
|
||||
pipe = self.r.pipeline()
|
||||
for key in keys:
|
||||
pipe.hgetall(key)
|
||||
all_data = pipe.execute()
|
||||
|
||||
# 筛选符合 end_user_id 的数据
|
||||
results = []
|
||||
for key, data in zip(keys, all_data):
|
||||
if not data:
|
||||
continue
|
||||
|
||||
# 从 write 类型读取,匹配 sessionid 字段
|
||||
if data.get('sessionid') == end_user_id:
|
||||
# 从 key 中提取 session_id: session:write:{session_id}
|
||||
session_id = key.split(':')[-1]
|
||||
|
||||
# 构建完整的会话信息
|
||||
session_info = {
|
||||
"session_id": session_id,
|
||||
"id": data.get('id', ''),
|
||||
"sessionid": data.get('sessionid', ''),
|
||||
"messages": fix_encoding(data.get('messages', '')),
|
||||
"starttime": data.get('starttime', '')
|
||||
}
|
||||
results.append(session_info)
|
||||
|
||||
if not results:
|
||||
print(f"[get_all_sessions_by_end_user_id] end_user_id={end_user_id}, 没有找到数据")
|
||||
return False
|
||||
|
||||
# 按时间排序(最新的在前)
|
||||
results.sort(key=lambda x: x.get('starttime', ''), reverse=True)
|
||||
|
||||
print(f"[get_all_sessions_by_end_user_id] end_user_id={end_user_id}, 找到 {len(results)} 条数据")
|
||||
return results
|
||||
except Exception as e:
|
||||
print(f"[get_all_sessions_by_end_user_id] 查询失败: {e}")
|
||||
import traceback
|
||||
traceback.print_exc()
|
||||
return False
|
||||
|
||||
def find_user_recent_sessions(self, userid: str,
|
||||
minutes: int = 5) -> List[Dict[str, str]]:
|
||||
"""
|
||||
根据 userid 从 save_session_write 写入的数据中查询最近 N 分钟内的会话数据
|
||||
|
||||
Args:
|
||||
userid: 用户ID (对应 sessionid 字段)
|
||||
minutes: 查询最近几分钟的数据,默认5分钟
|
||||
|
||||
Returns:
|
||||
List[Dict]: 会话列表 [{"Query": "...", "Answer": "..."}, ...]
|
||||
"""
|
||||
import time
|
||||
start_time = time.time()
|
||||
|
||||
# 只查询 write 类型的 key
|
||||
keys = self.r.keys('session:write:*')
|
||||
if not keys:
|
||||
print(f"[find_user_recent_sessions] 查询耗时: {time.time() - start_time:.3f}秒, 结果数: 0")
|
||||
return []
|
||||
|
||||
# 批量获取数据
|
||||
pipe = self.r.pipeline()
|
||||
for key in keys:
|
||||
pipe.hgetall(key)
|
||||
all_data = pipe.execute()
|
||||
|
||||
# 筛选符合 userid 的数据
|
||||
matched_items = []
|
||||
for data in all_data:
|
||||
if not data:
|
||||
continue
|
||||
|
||||
# 从 write 类型读取,匹配 sessionid 字段
|
||||
if data.get('sessionid') == userid and data.get('starttime'):
|
||||
# write 类型没有 aimessages,所以 Answer 为空
|
||||
matched_items.append({
|
||||
"Query": fix_encoding(data.get('messages', '')),
|
||||
"Answer": "",
|
||||
"starttime": data.get('starttime', '')
|
||||
})
|
||||
|
||||
# 根据时间范围过滤
|
||||
filtered_items = filter_by_time_range(matched_items, minutes)
|
||||
# 排序并移除时间字段
|
||||
result_items = sort_and_limit_results(filtered_items, limit=None)
|
||||
print(result_items)
|
||||
|
||||
elapsed_time = time.time() - start_time
|
||||
print(f"[find_user_recent_sessions] userid={userid}, minutes={minutes}, "
|
||||
f"查询耗时: {elapsed_time:.3f}秒, 结果数: {len(result_items)}")
|
||||
|
||||
return result_items
|
||||
|
||||
def delete_all_write_sessions(self) -> int:
|
||||
"""
|
||||
删除所有 write 类型的会话
|
||||
|
||||
Returns:
|
||||
int: 删除的数量
|
||||
"""
|
||||
keys = self.r.keys('session:write:*')
|
||||
if keys:
|
||||
return self.r.delete(*keys)
|
||||
return 0
|
||||
|
||||
|
||||
class RedisCountStore:
|
||||
"""Redis Count 类型存储类,用于管理访问次数统计相关的数据"""
|
||||
|
||||
def __init__(self, host='localhost', port=6379, db=0, password=None, session_id=''):
|
||||
"""
|
||||
初始化 Redis 连接
|
||||
|
||||
Args:
|
||||
host: Redis 主机地址
|
||||
port: Redis 端口
|
||||
db: Redis 数据库编号
|
||||
password: Redis 密码
|
||||
session_id: 会话ID
|
||||
"""
|
||||
self.r = redis.Redis(
|
||||
host=host,
|
||||
port=port,
|
||||
db=db,
|
||||
password=password,
|
||||
decode_responses=True,
|
||||
encoding='utf-8'
|
||||
)
|
||||
self.uudi = session_id
|
||||
|
||||
def save_sessions_count(self, end_user_id: str, count: int, messages: Any) -> str:
|
||||
"""
|
||||
保存用户访问次数统计
|
||||
|
||||
Args:
|
||||
end_user_id: 终端用户ID
|
||||
count: 访问次数
|
||||
messages: 消息内容
|
||||
|
||||
Returns:
|
||||
str: 新生成的 session_id
|
||||
"""
|
||||
session_id = str(uuid.uuid4())
|
||||
key = generate_session_key(session_id, key_type="count")
|
||||
index_key = f'session:count:index:{end_user_id}' # 索引键
|
||||
|
||||
pipe = self.r.pipeline()
|
||||
pipe.hset(key, mapping={
|
||||
"id": self.uudi,
|
||||
"end_user_id": end_user_id,
|
||||
"count": int(count),
|
||||
"messages": serialize_messages(messages),
|
||||
"starttime": get_current_timestamp()
|
||||
})
|
||||
pipe.expire(key, 30 * 24 * 60 * 60) # 30天过期
|
||||
|
||||
# 创建索引:end_user_id -> session_id 映射
|
||||
pipe.set(index_key, session_id, ex=30 * 24 * 60 * 60)
|
||||
|
||||
result = pipe.execute()
|
||||
|
||||
print(f"[save_sessions_count] 保存结果: {result}, session_id: {session_id}")
|
||||
return session_id
|
||||
|
||||
def get_sessions_count(self, end_user_id: str) -> Union[List[Any], bool]:
|
||||
"""
|
||||
通过 end_user_id 查询访问次数统计
|
||||
|
||||
Args:
|
||||
end_user_id: 终端用户ID
|
||||
|
||||
Returns:
|
||||
list 或 False: 如果找到返回 [count, messages],否则返回 False
|
||||
"""
|
||||
try:
|
||||
# 使用索引键快速查找
|
||||
index_key = f'session:count:index:{end_user_id}'
|
||||
|
||||
# 检查索引键类型,避免 WRONGTYPE 错误
|
||||
try:
|
||||
key_type = self.r.type(index_key)
|
||||
if key_type != 'string' and key_type != 'none':
|
||||
self.r.delete(index_key)
|
||||
return False
|
||||
except Exception as type_error:
|
||||
print(f"[get_sessions_count] 检查键类型失败: {type_error}")
|
||||
|
||||
session_id = self.r.get(index_key)
|
||||
|
||||
if not session_id:
|
||||
return False
|
||||
|
||||
# 直接获取数据
|
||||
key = generate_session_key(session_id, key_type="count")
|
||||
data = self.r.hgetall(key)
|
||||
|
||||
if not data:
|
||||
# 索引存在但数据不存在,清理索引
|
||||
self.r.delete(index_key)
|
||||
return False
|
||||
|
||||
count = data.get('count')
|
||||
messages_str = data.get('messages')
|
||||
|
||||
if count is not None:
|
||||
messages = deserialize_messages(messages_str)
|
||||
return [int(count), messages]
|
||||
|
||||
return False
|
||||
except Exception as e:
|
||||
print(f"[get_sessions_count] 查询失败: {e}")
|
||||
return False
|
||||
def update_sessions_count(self, end_user_id: str, new_count: int,
|
||||
messages: Any) -> bool:
|
||||
"""
|
||||
通过 end_user_id 修改访问次数统计(优化版:使用索引)
|
||||
|
||||
Args:
|
||||
end_user_id: 终端用户ID
|
||||
new_count: 新的 count 值
|
||||
messages: 消息内容
|
||||
|
||||
Returns:
|
||||
bool: 更新成功返回 True,未找到记录返回 False
|
||||
"""
|
||||
try:
|
||||
# 使用索引键快速查找
|
||||
index_key = f'session:count:index:{end_user_id}'
|
||||
|
||||
# 检查索引键类型,避免 WRONGTYPE 错误
|
||||
try:
|
||||
key_type = self.r.type(index_key)
|
||||
if key_type != 'string' and key_type != 'none':
|
||||
# 索引键类型错误,删除并返回 False
|
||||
print(f"[update_sessions_count] 索引键类型错误: {key_type},删除索引")
|
||||
self.r.delete(index_key)
|
||||
print(f"[update_sessions_count] 未找到记录: end_user_id={end_user_id}")
|
||||
return False
|
||||
except Exception as type_error:
|
||||
print(f"[update_sessions_count] 检查键类型失败: {type_error}")
|
||||
|
||||
session_id = self.r.get(index_key)
|
||||
|
||||
if not session_id:
|
||||
print(f"[update_sessions_count] 未找到记录: end_user_id={end_user_id}")
|
||||
return False
|
||||
|
||||
# 直接更新数据
|
||||
key = generate_session_key(session_id, key_type="count")
|
||||
messages_str = serialize_messages(messages)
|
||||
|
||||
pipe = self.r.pipeline()
|
||||
pipe.hset(key, 'count', int(new_count))
|
||||
pipe.hset(key, 'messages', messages_str)
|
||||
result = pipe.execute()
|
||||
|
||||
print(f"[update_sessions_count] 更新成功: end_user_id={end_user_id}, new_count={new_count}, key={key}")
|
||||
return True
|
||||
|
||||
except Exception as e:
|
||||
print(f"[update_sessions_count] 更新失败: {e}")
|
||||
return False
|
||||
|
||||
def delete_all_count_sessions(self) -> int:
|
||||
"""
|
||||
删除所有 count 类型的会话
|
||||
|
||||
Returns:
|
||||
int: 删除的数量
|
||||
"""
|
||||
keys = self.r.keys('session:count:*')
|
||||
if keys:
|
||||
return self.r.delete(*keys)
|
||||
return 0
|
||||
|
||||
|
||||
class RedisSessionStore:
|
||||
"""Redis 会话存储类,用于管理会话数据"""
|
||||
|
||||
def __init__(self, host='localhost', port=6379, db=0, password=None, session_id=''):
|
||||
"""
|
||||
初始化 Redis 连接
|
||||
|
||||
Args:
|
||||
host: Redis 主机地址
|
||||
port: Redis 端口
|
||||
db: Redis 数据库编号
|
||||
password: Redis 密码
|
||||
session_id: 会话ID
|
||||
"""
|
||||
self.r = redis.Redis(
|
||||
host=host,
|
||||
port=port,
|
||||
db=db,
|
||||
password=password,
|
||||
decode_responses=True,
|
||||
encoding='utf-8'
|
||||
)
|
||||
self.uudi = session_id
|
||||
|
||||
# ==================== 写入操作 ====================
|
||||
|
||||
def save_session(self, userid: str, messages: str, aimessages: str,
|
||||
apply_id: str, end_user_id: str) -> str:
|
||||
"""
|
||||
写入一条会话数据,返回 session_id
|
||||
|
||||
Args:
|
||||
userid: 用户ID
|
||||
messages: 用户消息
|
||||
aimessages: AI回复消息
|
||||
apply_id: 应用ID
|
||||
end_user_id: 终端用户ID
|
||||
|
||||
Returns:
|
||||
str: 新生成的 session_id
|
||||
"""
|
||||
try:
|
||||
session_id = str(uuid.uuid4())
|
||||
key = generate_session_key(session_id, key_type="read")
|
||||
|
||||
pipe = self.r.pipeline()
|
||||
pipe.hset(key, mapping={
|
||||
"id": self.uudi,
|
||||
"sessionid": userid,
|
||||
@@ -49,177 +479,195 @@ class RedisSessionStore:
|
||||
"end_user_id": end_user_id,
|
||||
"messages": messages,
|
||||
"aimessages": aimessages,
|
||||
"starttime": starttime
|
||||
"starttime": get_current_timestamp()
|
||||
})
|
||||
|
||||
# 可选:设置过期时间(例如30天),避免数据无限增长
|
||||
# pipe.expire(key, 30 * 24 * 60 * 60)
|
||||
|
||||
# 执行批量操作
|
||||
result = pipe.execute()
|
||||
|
||||
print(f"保存结果: {result[0]}, session_id: {session_id}")
|
||||
return session_id # 返回新生成的 session_id
|
||||
print(f"[save_session] 保存结果: {result[0]}, session_id: {session_id}")
|
||||
return session_id
|
||||
except Exception as e:
|
||||
print(f"保存会话失败: {e}")
|
||||
print(f"[save_session] 保存会话失败: {e}")
|
||||
raise e
|
||||
|
||||
def save_sessions_batch(self, sessions_data):
|
||||
"""
|
||||
批量写入多条会话数据,返回 session_id 列表
|
||||
sessions_data: list of dict, 每个 dict 包含 userid, messages, aimessages, apply_id, end_user_id
|
||||
优化版本:批量操作,大幅提升性能
|
||||
"""
|
||||
try:
|
||||
session_ids = []
|
||||
pipe = self.r.pipeline()
|
||||
|
||||
for session in sessions_data:
|
||||
session_id = str(uuid.uuid4())
|
||||
starttime = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
|
||||
key = f"session:{session_id}"
|
||||
|
||||
pipe.hset(key, mapping={
|
||||
"id": self.uudi,
|
||||
"sessionid": session.get('userid'),
|
||||
"apply_id": session.get('apply_id'),
|
||||
"end_user_id": session.get('end_user_id'),
|
||||
"messages": session.get('messages'),
|
||||
"aimessages": session.get('aimessages'),
|
||||
"starttime": starttime
|
||||
})
|
||||
|
||||
session_ids.append(session_id)
|
||||
|
||||
# 一次性执行所有写入操作
|
||||
results = pipe.execute()
|
||||
print(f"批量保存完成: {len(session_ids)} 条记录")
|
||||
return session_ids
|
||||
except Exception as e:
|
||||
print(f"批量保存会话失败: {e}")
|
||||
raise e
|
||||
|
||||
# ---------------- 读取 ----------------
|
||||
def get_session(self, session_id):
|
||||
# ==================== 读取操作 ====================
|
||||
|
||||
def get_session(self, session_id: str) -> Optional[Dict[str, Any]]:
|
||||
"""
|
||||
读取一条会话数据
|
||||
|
||||
Args:
|
||||
session_id: 会话ID
|
||||
|
||||
Returns:
|
||||
Dict 或 None: 会话数据
|
||||
"""
|
||||
key = f"session:{session_id}"
|
||||
key = generate_session_key(session_id)
|
||||
data = self.r.hgetall(key)
|
||||
return data if data else None
|
||||
|
||||
def get_session_apply_group(self, sessionid, apply_id, end_user_id):
|
||||
def get_all_sessions(self) -> Dict[str, Dict[str, Any]]:
|
||||
"""
|
||||
根据 sessionid、apply_id 和 end_user_id 三个条件查询会话数据
|
||||
"""
|
||||
result_items = []
|
||||
|
||||
# 遍历所有会话数据
|
||||
for key in self.r.keys('session:*'):
|
||||
data = self.r.hgetall(key)
|
||||
|
||||
if not data:
|
||||
continue
|
||||
|
||||
# 检查三个条件是否都匹配
|
||||
if (data.get('sessionid') == sessionid and
|
||||
data.get('apply_id') == apply_id and
|
||||
data.get('end_user_id') == end_user_id):
|
||||
result_items.append(data)
|
||||
|
||||
return result_items
|
||||
|
||||
def get_all_sessions(self):
|
||||
"""
|
||||
获取所有会话数据
|
||||
获取所有会话数据(不包括 count 和 write 类型)
|
||||
|
||||
Returns:
|
||||
Dict: 所有会话数据,key 为 session_id
|
||||
"""
|
||||
sessions = {}
|
||||
for key in self.r.keys('session:*'):
|
||||
sid = key.split(':')[1]
|
||||
sessions[sid] = self.get_session(sid)
|
||||
# 排除 count 和 write 类型的 key
|
||||
if ':count:' not in key and ':write:' not in key:
|
||||
sid = key.split(':')[1]
|
||||
sessions[sid] = self.get_session(sid)
|
||||
return sessions
|
||||
|
||||
# ---------------- 更新 ----------------
|
||||
def update_session(self, session_id, field, value):
|
||||
def find_user_apply_group(self, sessionid: str, apply_id: str,
|
||||
end_user_id: str) -> List[Dict[str, str]]:
|
||||
"""
|
||||
根据 sessionid、apply_id 和 end_user_id 查询会话数据,返回最新的6条
|
||||
|
||||
Args:
|
||||
sessionid: 会话ID(支持模糊匹配)
|
||||
apply_id: 应用ID
|
||||
end_user_id: 终端用户ID
|
||||
|
||||
Returns:
|
||||
List[Dict]: 会话列表 [{"Query": "...", "Answer": "..."}, ...]
|
||||
"""
|
||||
import time
|
||||
start_time = time.time()
|
||||
|
||||
keys = self.r.keys('session:*')
|
||||
if not keys:
|
||||
print(f"[find_user_apply_group] 查询耗时: {time.time() - start_time:.3f}秒, 结果数: 0")
|
||||
return []
|
||||
|
||||
# 批量获取数据
|
||||
pipe = self.r.pipeline()
|
||||
for key in keys:
|
||||
# 排除 count 和 write 类型
|
||||
if ':count:' not in key and ':write:' not in key:
|
||||
pipe.hgetall(key)
|
||||
all_data = pipe.execute()
|
||||
|
||||
# 筛选符合条件的数据
|
||||
matched_items = []
|
||||
for data in all_data:
|
||||
if not data:
|
||||
continue
|
||||
|
||||
if (data.get('apply_id') == apply_id and
|
||||
data.get('end_user_id') == end_user_id):
|
||||
# 支持模糊匹配或完全匹配 sessionid
|
||||
if sessionid in data.get('sessionid', '') or data.get('sessionid') == sessionid:
|
||||
matched_items.append(format_session_data(data, include_time=True))
|
||||
|
||||
# 排序、限制数量并移除时间字段
|
||||
result_items = sort_and_limit_results(matched_items, limit=6)
|
||||
|
||||
elapsed_time = time.time() - start_time
|
||||
print(f"[find_user_apply_group] 查询耗时: {elapsed_time:.3f}秒, 结果数: {len(result_items)}")
|
||||
|
||||
return result_items
|
||||
|
||||
# ==================== 更新操作 ====================
|
||||
|
||||
def update_session(self, session_id: str, field: str, value: Any) -> bool:
|
||||
"""
|
||||
更新单个字段
|
||||
优化版本:使用 pipeline 减少网络往返
|
||||
|
||||
Args:
|
||||
session_id: 会话ID
|
||||
field: 字段名
|
||||
value: 字段值
|
||||
|
||||
Returns:
|
||||
bool: 是否更新成功
|
||||
"""
|
||||
key = f"session:{session_id}"
|
||||
key = generate_session_key(session_id)
|
||||
pipe = self.r.pipeline()
|
||||
pipe.exists(key)
|
||||
pipe.hset(key, field, value)
|
||||
results = pipe.execute()
|
||||
return bool(results[0]) # 返回 key 是否存在
|
||||
return bool(results[0])
|
||||
|
||||
# ---------------- 删除 ----------------
|
||||
def delete_session(self, session_id):
|
||||
# ==================== 删除操作 ====================
|
||||
|
||||
def delete_session(self, session_id: str) -> int:
|
||||
"""
|
||||
删除单条会话
|
||||
|
||||
Args:
|
||||
session_id: 会话ID
|
||||
|
||||
Returns:
|
||||
int: 删除的数量
|
||||
"""
|
||||
key = f"session:{session_id}"
|
||||
key = generate_session_key(session_id)
|
||||
return self.r.delete(key)
|
||||
|
||||
def delete_all_sessions(self):
|
||||
def delete_all_sessions(self) -> int:
|
||||
"""
|
||||
删除所有会话
|
||||
删除所有会话(不包括 count 和 write 类型)
|
||||
|
||||
Returns:
|
||||
int: 删除的数量
|
||||
"""
|
||||
keys = self.r.keys('session:*')
|
||||
if keys:
|
||||
return self.r.delete(*keys)
|
||||
# 过滤掉 count 和 write 类型
|
||||
keys_to_delete = [k for k in keys if ':count:' not in k and ':write:' not in k]
|
||||
if keys_to_delete:
|
||||
return self.r.delete(*keys_to_delete)
|
||||
return 0
|
||||
|
||||
def delete_duplicate_sessions(self):
|
||||
def delete_duplicate_sessions(self) -> int:
|
||||
"""
|
||||
删除重复会话数据,条件:
|
||||
"sessionid"、"user_id"、"end_user_id"、"messages"、"aimessages" 五个字段都相同的只保留一个,其他删除
|
||||
优化版本:使用 pipeline 批量操作,确保在1秒内完成
|
||||
删除重复会话数据(不包括 count 和 write 类型)
|
||||
条件:sessionid、user_id、end_user_id、messages、aimessages 五个字段都相同的只保留一个
|
||||
|
||||
Returns:
|
||||
int: 删除的数量
|
||||
"""
|
||||
import time
|
||||
start_time = time.time()
|
||||
|
||||
# 第一步:使用 pipeline 批量获取所有 key
|
||||
keys = self.r.keys('session:*')
|
||||
|
||||
if not keys:
|
||||
print("[delete_duplicate_sessions] 没有会话数据")
|
||||
return 0
|
||||
|
||||
# 第二步:使用 pipeline 批量获取所有数据
|
||||
# 批量获取所有数据
|
||||
pipe = self.r.pipeline()
|
||||
for key in keys:
|
||||
pipe.hgetall(key)
|
||||
# 排除 count 和 write 类型
|
||||
if ':count:' not in key and ':write:' not in key:
|
||||
pipe.hgetall(key)
|
||||
all_data = pipe.execute()
|
||||
|
||||
# 第三步:在内存中识别重复数据
|
||||
seen = {} # 用字典记录:identifier -> key(保留第一个出现的 key)
|
||||
keys_to_delete = [] # 需要删除的 key 列表
|
||||
# 识别重复数据
|
||||
seen = {}
|
||||
keys_to_delete = []
|
||||
|
||||
for key, data in zip(keys, all_data, strict=False):
|
||||
for key, data in zip([k for k in keys if ':count:' not in k and ':write:' not in k], all_data, strict=False):
|
||||
if not data:
|
||||
continue
|
||||
|
||||
# 获取五个字段的值
|
||||
sessionid = data.get('sessionid', '')
|
||||
user_id = data.get('id', '')
|
||||
end_user_id = data.get('end_user_id', '')
|
||||
messages = data.get('messages', '')
|
||||
aimessages = data.get('aimessages', '')
|
||||
|
||||
# 用五元组作为唯一标识
|
||||
identifier = (sessionid, user_id, end_user_id, messages, aimessages)
|
||||
identifier = (
|
||||
data.get('sessionid', ''),
|
||||
data.get('id', ''),
|
||||
data.get('end_user_id', ''),
|
||||
data.get('messages', ''),
|
||||
data.get('aimessages', '')
|
||||
)
|
||||
|
||||
if identifier in seen:
|
||||
# 重复,标记为待删除
|
||||
keys_to_delete.append(key)
|
||||
else:
|
||||
# 第一次出现,记录
|
||||
seen[identifier] = key
|
||||
|
||||
# 第四步:使用 pipeline 批量删除重复的 key
|
||||
# 批量删除重复的 key
|
||||
deleted_count = 0
|
||||
if keys_to_delete:
|
||||
# 分批删除,避免单次操作过大
|
||||
batch_size = 1000
|
||||
for i in range(0, len(keys_to_delete), batch_size):
|
||||
batch = keys_to_delete[i:i + batch_size]
|
||||
@@ -233,79 +681,28 @@ class RedisSessionStore:
|
||||
print(f"[delete_duplicate_sessions] 删除重复会话数量: {deleted_count}, 耗时: {elapsed_time:.3f}秒")
|
||||
return deleted_count
|
||||
|
||||
def find_user_session(self, sessionid):
|
||||
user_id = sessionid
|
||||
|
||||
result_items = []
|
||||
for key, values in store.get_all_sessions().items():
|
||||
history = {}
|
||||
if user_id == str(values['sessionid']):
|
||||
history["Query"] = values['messages']
|
||||
history["Answer"] = values['aimessages']
|
||||
result_items.append(history)
|
||||
|
||||
if len(result_items) <= 1:
|
||||
result_items = []
|
||||
return (result_items)
|
||||
|
||||
def find_user_apply_group(self, sessionid, apply_id, end_user_id):
|
||||
"""
|
||||
根据 sessionid、apply_id 和 end_user_id 三个条件查询会话数据,返回最新的6条
|
||||
"""
|
||||
import time
|
||||
start_time = time.time()
|
||||
# 使用 pipeline 批量获取数据,提高性能
|
||||
keys = self.r.keys('session:*')
|
||||
|
||||
if not keys:
|
||||
print(f"查询耗时: {time.time() - start_time:.3f}秒, 结果数: 0")
|
||||
return []
|
||||
|
||||
# 使用 pipeline 批量获取所有 hash 数据
|
||||
pipe = self.r.pipeline()
|
||||
for key in keys:
|
||||
pipe.hgetall(key)
|
||||
all_data = pipe.execute()
|
||||
|
||||
# 解析并筛选符合条件的数据
|
||||
matched_items = []
|
||||
for data in all_data:
|
||||
if not data:
|
||||
continue
|
||||
|
||||
# 检查是否符合三个条件
|
||||
|
||||
if (data.get('apply_id') == apply_id and
|
||||
data.get('end_user_id') == end_user_id):
|
||||
# 支持模糊匹配 sessionid 或者完全匹配
|
||||
if sessionid in data.get('sessionid', '') or data.get('sessionid') == sessionid:
|
||||
matched_items.append({
|
||||
"Query": self._fix_encoding(data.get('messages')),
|
||||
"Answer": self._fix_encoding(data.get('aimessages')),
|
||||
"starttime": data.get('starttime', '')
|
||||
})
|
||||
# 按时间降序排序(最新的在前)
|
||||
matched_items.sort(key=lambda x: x.get('starttime', ''), reverse=True)
|
||||
# 只保留最新的6条
|
||||
result_items = matched_items[:6]
|
||||
# # 移除 starttime 字段
|
||||
for item in result_items:
|
||||
item.pop('starttime', None)
|
||||
|
||||
# 如果结果少于等于1条,返回空列表
|
||||
if len(result_items) <= 1:
|
||||
result_items = []
|
||||
|
||||
elapsed_time = time.time() - start_time
|
||||
print(f"查询耗时: {elapsed_time:.3f}秒, 结果数: {len(result_items)}")
|
||||
|
||||
return result_items
|
||||
|
||||
|
||||
# 全局实例
|
||||
store = RedisSessionStore(
|
||||
host=settings.REDIS_HOST,
|
||||
port=settings.REDIS_PORT,
|
||||
db=settings.REDIS_DB,
|
||||
password=settings.REDIS_PASSWORD if settings.REDIS_PASSWORD else None,
|
||||
session_id=str(uuid.uuid4())
|
||||
)
|
||||
)
|
||||
|
||||
write_store = RedisWriteStore(
|
||||
host=settings.REDIS_HOST,
|
||||
port=settings.REDIS_PORT,
|
||||
db=settings.REDIS_DB,
|
||||
password=settings.REDIS_PASSWORD if settings.REDIS_PASSWORD else None,
|
||||
session_id=str(uuid.uuid4())
|
||||
)
|
||||
|
||||
count_store = RedisCountStore(
|
||||
host=settings.REDIS_HOST,
|
||||
port=settings.REDIS_PORT,
|
||||
db=settings.REDIS_DB,
|
||||
password=settings.REDIS_PASSWORD if settings.REDIS_PASSWORD else None,
|
||||
session_id=str(uuid.uuid4())
|
||||
)
|
||||
|
||||
@@ -4,6 +4,7 @@ Write Tools for Memory Knowledge Extraction Pipeline
|
||||
This module provides the main write function for executing the knowledge extraction
|
||||
pipeline. Only MemoryConfig is needed - clients are constructed internally.
|
||||
"""
|
||||
import asyncio
|
||||
import time
|
||||
from datetime import datetime
|
||||
|
||||
@@ -123,23 +124,48 @@ async def write(
|
||||
except Exception as e:
|
||||
logger.error(f"Error creating indexes: {e}", exc_info=True)
|
||||
|
||||
# 添加死锁重试机制
|
||||
max_retries = 3
|
||||
retry_delay = 1 # 秒
|
||||
|
||||
for attempt in range(max_retries):
|
||||
try:
|
||||
success = await save_dialog_and_statements_to_neo4j(
|
||||
dialogue_nodes=all_dialogue_nodes,
|
||||
chunk_nodes=all_chunk_nodes,
|
||||
statement_nodes=all_statement_nodes,
|
||||
entity_nodes=all_entity_nodes,
|
||||
statement_chunk_edges=all_statement_chunk_edges,
|
||||
statement_entity_edges=all_statement_entity_edges,
|
||||
entity_edges=all_entity_entity_edges,
|
||||
connector=neo4j_connector
|
||||
)
|
||||
if success:
|
||||
logger.info("Successfully saved all data to Neo4j")
|
||||
break
|
||||
else:
|
||||
logger.warning("Failed to save some data to Neo4j")
|
||||
if attempt < max_retries - 1:
|
||||
logger.info(f"Retrying... (attempt {attempt + 2}/{max_retries})")
|
||||
await asyncio.sleep(retry_delay * (attempt + 1)) # 指数退避
|
||||
except Exception as e:
|
||||
error_msg = str(e)
|
||||
# 检查是否是死锁错误
|
||||
if "DeadlockDetected" in error_msg or "deadlock" in error_msg.lower():
|
||||
if attempt < max_retries - 1:
|
||||
logger.warning(f"Deadlock detected, retrying... (attempt {attempt + 2}/{max_retries})")
|
||||
await asyncio.sleep(retry_delay * (attempt + 1)) # 指数退避
|
||||
else:
|
||||
logger.error(f"Failed after {max_retries} attempts due to deadlock: {e}")
|
||||
raise
|
||||
else:
|
||||
# 非死锁错误,直接抛出
|
||||
raise
|
||||
|
||||
try:
|
||||
success = await save_dialog_and_statements_to_neo4j(
|
||||
dialogue_nodes=all_dialogue_nodes,
|
||||
chunk_nodes=all_chunk_nodes,
|
||||
statement_nodes=all_statement_nodes,
|
||||
entity_nodes=all_entity_nodes,
|
||||
statement_chunk_edges=all_statement_chunk_edges,
|
||||
statement_entity_edges=all_statement_entity_edges,
|
||||
entity_edges=all_entity_entity_edges,
|
||||
connector=neo4j_connector
|
||||
)
|
||||
if success:
|
||||
logger.info("Successfully saved all data to Neo4j")
|
||||
else:
|
||||
logger.warning("Failed to save some data to Neo4j")
|
||||
finally:
|
||||
await neo4j_connector.close()
|
||||
except Exception as e:
|
||||
logger.error(f"Error closing Neo4j connector: {e}")
|
||||
|
||||
log_time("Neo4j Database Save", time.time() - step_start, log_file)
|
||||
|
||||
|
||||
@@ -58,6 +58,12 @@ from app.core.memory.models.triplet_models import (
|
||||
TripletExtractionResponse,
|
||||
)
|
||||
|
||||
# Ontology models
|
||||
from app.core.memory.models.ontology_models import (
|
||||
OntologyClass,
|
||||
OntologyExtractionResponse,
|
||||
)
|
||||
|
||||
# Variable configuration models
|
||||
from app.core.memory.models.variate_config import (
|
||||
StatementExtractionConfig,
|
||||
@@ -105,6 +111,9 @@ __all__ = [
|
||||
"Entity",
|
||||
"Triplet",
|
||||
"TripletExtractionResponse",
|
||||
# Ontology models
|
||||
"OntologyClass",
|
||||
"OntologyExtractionResponse",
|
||||
# Variable configuration
|
||||
"StatementExtractionConfig",
|
||||
"ForgettingEngineConfig",
|
||||
|
||||
@@ -413,7 +413,8 @@ class ExtractedEntityNode(Node):
|
||||
description="Entity aliases - alternative names for this entity"
|
||||
)
|
||||
name_embedding: Optional[List[float]] = Field(default_factory=list, description="Name embedding vector")
|
||||
fact_summary: str = Field(default="", description="Summary of the fact about this entity")
|
||||
# TODO: fact_summary 功能暂时禁用,待后续开发完善后启用
|
||||
# fact_summary: str = Field(default="", description="Summary of the fact about this entity")
|
||||
connect_strength: str = Field(..., description="Strong VS Weak about this entity")
|
||||
config_id: Optional[int | str] = Field(None, description="Configuration ID used to process this entity (integer or string)")
|
||||
|
||||
|
||||
135
api/app/core/memory/models/ontology_models.py
Normal file
135
api/app/core/memory/models/ontology_models.py
Normal file
@@ -0,0 +1,135 @@
|
||||
"""Models for ontology classes and extraction responses.
|
||||
|
||||
This module contains Pydantic models for representing extracted ontology classes
|
||||
from scenario descriptions, following OWL ontology engineering standards.
|
||||
|
||||
Classes:
|
||||
OntologyClass: Represents an extracted ontology class
|
||||
OntologyExtractionResponse: Response model containing extracted ontology classes
|
||||
"""
|
||||
|
||||
from typing import List, Optional
|
||||
from uuid import uuid4
|
||||
|
||||
from pydantic import BaseModel, ConfigDict, Field, field_validator
|
||||
|
||||
|
||||
class OntologyClass(BaseModel):
|
||||
"""Represents an extracted ontology class from scenario description.
|
||||
|
||||
An ontology class represents an abstract category or concept in a domain,
|
||||
following OWL ontology engineering standards and naming conventions.
|
||||
|
||||
Attributes:
|
||||
id: Unique string identifier for the ontology class
|
||||
name: Name of the class in PascalCase format (e.g., 'MedicalProcedure')
|
||||
name_chinese: Chinese translation of the class name (e.g., '医疗程序')
|
||||
description: Textual description of the class
|
||||
examples: List of concrete instance examples of this class
|
||||
parent_class: Optional name of the parent class in the hierarchy
|
||||
entity_type: Type/category of the entity (e.g., 'Person', 'Organization', 'Concept')
|
||||
domain: Domain this class belongs to (e.g., 'Healthcare', 'Education')
|
||||
|
||||
Config:
|
||||
extra: Ignore extra fields from LLM output
|
||||
"""
|
||||
model_config = ConfigDict(extra='ignore')
|
||||
|
||||
id: str = Field(
|
||||
default_factory=lambda: uuid4().hex,
|
||||
description="Unique identifier for the ontology class"
|
||||
)
|
||||
name: str = Field(
|
||||
...,
|
||||
description="Name of the class in PascalCase format"
|
||||
)
|
||||
name_chinese: Optional[str] = Field(
|
||||
None,
|
||||
description="Chinese translation of the class name"
|
||||
)
|
||||
description: str = Field(
|
||||
...,
|
||||
description="Description of the class"
|
||||
)
|
||||
examples: List[str] = Field(
|
||||
default_factory=list,
|
||||
description="List of concrete instance examples"
|
||||
)
|
||||
parent_class: Optional[str] = Field(
|
||||
None,
|
||||
description="Name of the parent class in the hierarchy"
|
||||
)
|
||||
entity_type: str = Field(
|
||||
...,
|
||||
description="Type/category of the entity"
|
||||
)
|
||||
domain: str = Field(
|
||||
...,
|
||||
description="Domain this class belongs to"
|
||||
)
|
||||
|
||||
@field_validator('name')
|
||||
@classmethod
|
||||
def validate_pascal_case(cls, v: str) -> str:
|
||||
"""Validate that the class name follows PascalCase convention.
|
||||
|
||||
PascalCase rules:
|
||||
- Must start with an uppercase letter
|
||||
- Cannot contain spaces
|
||||
- Should not contain special characters except underscores
|
||||
|
||||
Args:
|
||||
v: The class name to validate
|
||||
|
||||
Returns:
|
||||
The validated class name
|
||||
|
||||
Raises:
|
||||
ValueError: If the name doesn't follow PascalCase convention
|
||||
"""
|
||||
if not v:
|
||||
raise ValueError("Class name cannot be empty")
|
||||
|
||||
if not v[0].isupper():
|
||||
raise ValueError(
|
||||
f"Class name '{v}' must start with an uppercase letter (PascalCase)"
|
||||
)
|
||||
|
||||
if ' ' in v:
|
||||
raise ValueError(
|
||||
f"Class name '{v}' cannot contain spaces (PascalCase)"
|
||||
)
|
||||
|
||||
# Check for invalid characters (allow alphanumeric and underscore only)
|
||||
if not all(c.isalnum() or c == '_' for c in v):
|
||||
raise ValueError(
|
||||
f"Class name '{v}' contains invalid characters. "
|
||||
"Only alphanumeric characters and underscores are allowed"
|
||||
)
|
||||
|
||||
return v
|
||||
|
||||
|
||||
class OntologyExtractionResponse(BaseModel):
|
||||
"""Response model for ontology extraction from LLM.
|
||||
|
||||
This model represents the structured output from the LLM when
|
||||
extracting ontology classes from scenario descriptions.
|
||||
|
||||
Attributes:
|
||||
classes: List of extracted ontology classes
|
||||
domain: Domain/field the scenario belongs to
|
||||
|
||||
Config:
|
||||
extra: Ignore extra fields from LLM output
|
||||
"""
|
||||
model_config = ConfigDict(extra='ignore')
|
||||
|
||||
classes: List[OntologyClass] = Field(
|
||||
default_factory=list,
|
||||
description="List of extracted ontology classes"
|
||||
)
|
||||
domain: str = Field(
|
||||
...,
|
||||
description="Domain/field the scenario belongs to"
|
||||
)
|
||||
@@ -134,42 +134,45 @@ def _merge_attribute(canonical: ExtractedEntityNode, ent: ExtractedEntityNode):
|
||||
if len(desc_b) > len(desc_a):
|
||||
canonical.description = desc_b
|
||||
# 合并事实摘要:统一保留一个“实体: name”行,来源行去重保序
|
||||
fact_a = getattr(canonical, "fact_summary", "") or ""
|
||||
fact_b = getattr(ent, "fact_summary", "") or ""
|
||||
def _extract_sources(txt: str) -> List[str]:
|
||||
sources: List[str] = []
|
||||
if not txt:
|
||||
return sources
|
||||
for line in str(txt).splitlines():
|
||||
ln = line.strip()
|
||||
# TODO: fact_summary 功能暂时禁用,待后续开发完善后启用
|
||||
# fact_a = getattr(canonical, "fact_summary", "") or ""
|
||||
# fact_b = getattr(ent, "fact_summary", "") or ""
|
||||
# def _extract_sources(txt: str) -> List[str]:
|
||||
# sources: List[str] = []
|
||||
# if not txt:
|
||||
# return sources
|
||||
# for line in str(txt).splitlines():
|
||||
# ln = line.strip()
|
||||
# 支持“来源:”或“来源:”前缀
|
||||
m = re.match(r"^来源[::]\s*(.+)$", ln)
|
||||
if m:
|
||||
content = m.group(1).strip()
|
||||
if content:
|
||||
sources.append(content)
|
||||
# m = re.match(r"^来源[::]\s*(.+)$", ln)
|
||||
# if m:
|
||||
# content = m.group(1).strip()
|
||||
# if content:
|
||||
# sources.append(content)
|
||||
# 如果不存在“来源”前缀,则将整体文本视为一个来源片段,避免信息丢失
|
||||
if not sources and txt.strip():
|
||||
sources.append(txt.strip())
|
||||
return sources
|
||||
# if not sources and txt.strip():
|
||||
# sources.append(txt.strip())
|
||||
# return sources
|
||||
try:
|
||||
src_a = _extract_sources(fact_a)
|
||||
src_b = _extract_sources(fact_b)
|
||||
seen = set()
|
||||
merged_sources: List[str] = []
|
||||
for s in src_a + src_b:
|
||||
if s and s not in seen:
|
||||
seen.add(s)
|
||||
merged_sources.append(s)
|
||||
if merged_sources:
|
||||
name_line = f"实体: {getattr(canonical, 'name', '')}".strip()
|
||||
canonical.fact_summary = "\n".join([name_line] + [f"来源: {s}" for s in merged_sources])
|
||||
elif fact_b and not fact_a:
|
||||
canonical.fact_summary = fact_b
|
||||
# src_a = _extract_sources(fact_a)
|
||||
# src_b = _extract_sources(fact_b)
|
||||
# seen = set()
|
||||
# merged_sources: List[str] = []
|
||||
# for s in src_a + src_b:
|
||||
# if s and s not in seen:
|
||||
# seen.add(s)
|
||||
# merged_sources.append(s)
|
||||
# if merged_sources:
|
||||
# name_line = f"实体: {getattr(canonical, 'name', '')}".strip()
|
||||
# canonical.fact_summary = "\n".join([name_line] + [f"来源: {s}" for s in merged_sources])
|
||||
# elif fact_b and not fact_a:
|
||||
# canonical.fact_summary = fact_b
|
||||
pass
|
||||
except Exception:
|
||||
# 兜底:若解析失败,保留较长文本
|
||||
if len(fact_b) > len(fact_a):
|
||||
canonical.fact_summary = fact_b
|
||||
# if len(fact_b) > len(fact_a):
|
||||
# canonical.fact_summary = fact_b
|
||||
pass
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
|
||||
@@ -145,10 +145,13 @@ def _choose_canonical(a: ExtractedEntityNode, b: ExtractedEntityNode) -> int: #
|
||||
# 2. 第二优先级:按“描述+事实摘要”的总长度排序(内容越长,信息越完整)
|
||||
desc_a = (getattr(a, "description", "") or "")
|
||||
desc_b = (getattr(b, "description", "") or "")
|
||||
fact_a = (getattr(a, "fact_summary", "") or "")
|
||||
fact_b = (getattr(b, "fact_summary", "") or "")
|
||||
score_a = len(desc_a) + len(fact_a)
|
||||
score_b = len(desc_b) + len(fact_b)
|
||||
# TODO: fact_summary 功能暂时禁用,待后续开发完善后启用
|
||||
# fact_a = (getattr(a, "fact_summary", "") or "")
|
||||
# fact_b = (getattr(b, "fact_summary", "") or "")
|
||||
# score_a = len(desc_a) + len(fact_a)
|
||||
# score_b = len(desc_b) + len(fact_b)
|
||||
score_a = len(desc_a)
|
||||
score_b = len(desc_b)
|
||||
if score_a != score_b:
|
||||
return 0 if score_a >= score_b else 1
|
||||
return 0
|
||||
@@ -189,7 +192,8 @@ async def _judge_pair(
|
||||
"entity_type": getattr(a, "entity_type", None),
|
||||
"description": getattr(a, "description", None),
|
||||
"aliases": getattr(a, "aliases", None) or [],
|
||||
"fact_summary": getattr(a, "fact_summary", None),
|
||||
# TODO: fact_summary 功能暂时禁用,待后续开发完善后启用
|
||||
# "fact_summary": getattr(a, "fact_summary", None),
|
||||
"connect_strength": getattr(a, "connect_strength", None),
|
||||
}
|
||||
entity_b = {
|
||||
@@ -197,7 +201,8 @@ async def _judge_pair(
|
||||
"entity_type": getattr(b, "entity_type", None),
|
||||
"description": getattr(b, "description", None),
|
||||
"aliases": getattr(b, "aliases", None) or [],
|
||||
"fact_summary": getattr(b, "fact_summary", None),
|
||||
# TODO: fact_summary 功能暂时禁用,待后续开发完善后启用
|
||||
# "fact_summary": getattr(b, "fact_summary", None),
|
||||
"connect_strength": getattr(b, "connect_strength", None),
|
||||
}
|
||||
# 5. 渲染LLM提示词(用工具函数填充模板,包含实体信息、上下文、输出格式)
|
||||
@@ -248,7 +253,8 @@ async def _judge_pair_disamb(
|
||||
"entity_type": getattr(a, "entity_type", None),
|
||||
"description": getattr(a, "description", None),
|
||||
"aliases": getattr(a, "aliases", None) or [],
|
||||
"fact_summary": getattr(a, "fact_summary", None),
|
||||
# TODO: fact_summary 功能暂时禁用,待后续开发完善后启用
|
||||
# "fact_summary": getattr(a, "fact_summary", None),
|
||||
"connect_strength": getattr(a, "connect_strength", None),
|
||||
}
|
||||
entity_b = {
|
||||
@@ -256,7 +262,8 @@ async def _judge_pair_disamb(
|
||||
"entity_type": getattr(b, "entity_type", None),
|
||||
"description": getattr(b, "description", None),
|
||||
"aliases": getattr(b, "aliases", None) or [],
|
||||
"fact_summary": getattr(b, "fact_summary", None),
|
||||
# TODO: fact_summary 功能暂时禁用,待后续开发完善后启用
|
||||
# "fact_summary": getattr(b, "fact_summary", None),
|
||||
"connect_strength": getattr(b, "connect_strength", None),
|
||||
}
|
||||
prompt = render_entity_dedup_prompt(
|
||||
|
||||
@@ -72,7 +72,8 @@ def _row_to_entity(row: Dict[str, Any]) -> ExtractedEntityNode:
|
||||
description=row.get("description") or "",
|
||||
aliases=row.get("aliases") or [],
|
||||
name_embedding=row.get("name_embedding") or [],
|
||||
fact_summary=row.get("fact_summary") or "",
|
||||
# TODO: fact_summary 功能暂时禁用,待后续开发完善后启用
|
||||
# fact_summary=row.get("fact_summary") or "",
|
||||
connect_strength=row.get("connect_strength") or "",
|
||||
)
|
||||
|
||||
|
||||
@@ -1085,7 +1085,8 @@ class ExtractionOrchestrator:
|
||||
entity_type=getattr(entity, 'type', 'unknown'), # 使用 type 而不是 entity_type
|
||||
description=getattr(entity, 'description', ''), # 添加必需的 description 字段
|
||||
example=getattr(entity, 'example', ''), # 新增:传递示例字段
|
||||
fact_summary=getattr(entity, 'fact_summary', ''), # 添加必需的 fact_summary 字段
|
||||
# TODO: fact_summary 功能暂时禁用,待后续开发完善后启用
|
||||
# fact_summary=getattr(entity, 'fact_summary', ''), # 添加必需的 fact_summary 字段
|
||||
connect_strength=entity_connect_strength if entity_connect_strength is not None else 'Strong', # 添加必需的 connect_strength 字段
|
||||
aliases=getattr(entity, 'aliases', []) or [], # 传递从三元组提取阶段获取的aliases
|
||||
name_embedding=getattr(entity, 'name_embedding', None),
|
||||
|
||||
@@ -8,4 +8,5 @@
|
||||
- TemporalExtractor: 时间信息提取
|
||||
- EmbeddingGenerator: 嵌入向量生成
|
||||
- MemorySummaryGenerator: 记忆摘要生成
|
||||
- OntologyExtractor: 本体类提取
|
||||
"""
|
||||
|
||||
@@ -14,6 +14,34 @@ from pydantic import Field
|
||||
|
||||
logger = get_memory_logger(__name__)
|
||||
|
||||
# 支持的语言列表和默认回退值
|
||||
SUPPORTED_LANGUAGES = {"zh", "en"}
|
||||
FALLBACK_LANGUAGE = "en"
|
||||
|
||||
|
||||
def validate_language(language: Optional[str]) -> str:
|
||||
"""
|
||||
校验语言参数,确保其为有效值。
|
||||
|
||||
Args:
|
||||
language: 待校验的语言代码
|
||||
|
||||
Returns:
|
||||
有效的语言代码("zh" 或 "en")
|
||||
"""
|
||||
if language is None:
|
||||
return FALLBACK_LANGUAGE
|
||||
|
||||
lang = str(language).lower().strip()
|
||||
if lang in SUPPORTED_LANGUAGES:
|
||||
return lang
|
||||
|
||||
logger.warning(
|
||||
f"无效的语言参数 '{language}',已回退到默认值 '{FALLBACK_LANGUAGE}'。"
|
||||
f"支持的语言: {SUPPORTED_LANGUAGES}"
|
||||
)
|
||||
return FALLBACK_LANGUAGE
|
||||
|
||||
|
||||
class MemorySummaryResponse(RobustLLMResponse):
|
||||
"""Structured response for summary generation per chunk.
|
||||
@@ -31,7 +59,8 @@ class MemorySummaryResponse(RobustLLMResponse):
|
||||
|
||||
async def generate_title_and_type_for_summary(
|
||||
content: str,
|
||||
llm_client
|
||||
llm_client,
|
||||
language: str = None
|
||||
) -> Tuple[str, str]:
|
||||
"""
|
||||
为MemorySummary生成标题和类型
|
||||
@@ -41,11 +70,18 @@ async def generate_title_and_type_for_summary(
|
||||
Args:
|
||||
content: Summary的内容文本
|
||||
llm_client: LLM客户端实例
|
||||
language: 生成标题使用的语言 ("zh" 中文, "en" 英文),如果为None则从配置读取
|
||||
|
||||
Returns:
|
||||
(标题, 类型)元组
|
||||
"""
|
||||
from app.core.memory.utils.prompt.prompt_utils import render_episodic_title_and_type_prompt
|
||||
from app.core.config import settings
|
||||
|
||||
# 如果没有指定语言,从配置中读取,并校验有效性
|
||||
if language is None:
|
||||
language = settings.DEFAULT_LANGUAGE
|
||||
language = validate_language(language)
|
||||
|
||||
# 定义有效的类型集合
|
||||
VALID_TYPES = {
|
||||
@@ -57,13 +93,19 @@ async def generate_title_and_type_for_summary(
|
||||
}
|
||||
DEFAULT_TYPE = "conversation" # 默认类型
|
||||
|
||||
# 根据语言设置默认标题
|
||||
DEFAULT_TITLE = "空内容" if language == "zh" else "Empty Content"
|
||||
PARSE_ERROR_TITLE = "解析失败" if language == "zh" else "Parse Failed"
|
||||
ERROR_TITLE = "错误" if language == "zh" else "Error"
|
||||
UNKNOWN_TITLE = "未知标题" if language == "zh" else "Unknown Title"
|
||||
|
||||
try:
|
||||
if not content:
|
||||
logger.warning("content为空,无法生成标题和类型")
|
||||
return ("空内容", DEFAULT_TYPE)
|
||||
logger.warning(f"content为空,无法生成标题和类型 (language={language})")
|
||||
return (DEFAULT_TITLE, DEFAULT_TYPE)
|
||||
|
||||
# 1. 渲染Jinja2提示词模板
|
||||
prompt = await render_episodic_title_and_type_prompt(content)
|
||||
# 1. 渲染Jinja2提示词模板,传递语言参数
|
||||
prompt = await render_episodic_title_and_type_prompt(content, language=language)
|
||||
|
||||
# 2. 调用LLM生成标题和类型
|
||||
messages = [
|
||||
@@ -102,7 +144,7 @@ async def generate_title_and_type_for_summary(
|
||||
json_str = json_str.strip()
|
||||
|
||||
result_data = json.loads(json_str)
|
||||
title = result_data.get("title", "未知标题")
|
||||
title = result_data.get("title", UNKNOWN_TITLE)
|
||||
episodic_type_raw = result_data.get("type", DEFAULT_TYPE)
|
||||
|
||||
# 5. 校验和归一化类型
|
||||
@@ -130,16 +172,16 @@ async def generate_title_and_type_for_summary(
|
||||
f"已归一化为 '{episodic_type}'"
|
||||
)
|
||||
|
||||
logger.info(f"成功生成标题和类型: title={title}, type={episodic_type}")
|
||||
logger.info(f"成功生成标题和类型 (language={language}): title={title}, type={episodic_type}")
|
||||
return (title, episodic_type)
|
||||
|
||||
except json.JSONDecodeError:
|
||||
logger.error(f"无法解析LLM响应为JSON: {full_response}")
|
||||
return ("解析失败", DEFAULT_TYPE)
|
||||
logger.error(f"无法解析LLM响应为JSON (language={language}): {full_response}")
|
||||
return (PARSE_ERROR_TITLE, DEFAULT_TYPE)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"生成标题和类型时出错: {str(e)}", exc_info=True)
|
||||
return ("错误", DEFAULT_TYPE)
|
||||
logger.error(f"生成标题和类型时出错 (language={language}): {str(e)}", exc_info=True)
|
||||
return (ERROR_TITLE, DEFAULT_TYPE)
|
||||
|
||||
async def _process_chunk_summary(
|
||||
dialog: DialogData,
|
||||
@@ -153,11 +195,16 @@ async def _process_chunk_summary(
|
||||
return None
|
||||
|
||||
try:
|
||||
# 从配置中获取语言设置(只获取一次,复用),并校验有效性
|
||||
from app.core.config import settings
|
||||
language = validate_language(settings.DEFAULT_LANGUAGE)
|
||||
|
||||
# Render prompt via Jinja2 for a single chunk
|
||||
prompt_content = await render_memory_summary_prompt(
|
||||
chunk_texts=chunk.content,
|
||||
json_schema=MemorySummaryResponse.model_json_schema(),
|
||||
max_words=200,
|
||||
language=language,
|
||||
)
|
||||
|
||||
messages = [
|
||||
@@ -178,9 +225,10 @@ async def _process_chunk_summary(
|
||||
try:
|
||||
title, episodic_type = await generate_title_and_type_for_summary(
|
||||
content=summary_text,
|
||||
llm_client=llm_client
|
||||
llm_client=llm_client,
|
||||
language=language
|
||||
)
|
||||
logger.info(f"Generated title and type for MemorySummary: title={title}, type={episodic_type}")
|
||||
logger.info(f"Generated title and type for MemorySummary (language={language}): title={title}, type={episodic_type}")
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to generate title and type for chunk {chunk.id}: {e}")
|
||||
# Continue without title and type
|
||||
|
||||
@@ -0,0 +1,482 @@
|
||||
"""Ontology class extraction from scenario descriptions using LLM.
|
||||
|
||||
This module provides the OntologyExtractor class for extracting ontology classes
|
||||
from natural language scenario descriptions. It uses LLM-driven extraction combined
|
||||
with two-layer validation (string validation + OWL semantic validation).
|
||||
|
||||
Classes:
|
||||
OntologyExtractor: Extracts ontology classes from scenario descriptions
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import logging
|
||||
import time
|
||||
from typing import List, Optional
|
||||
|
||||
from app.core.memory.llm_tools.openai_client import OpenAIClient
|
||||
from app.core.memory.models.ontology_models import (
|
||||
OntologyClass,
|
||||
OntologyExtractionResponse,
|
||||
)
|
||||
from app.core.memory.utils.validation.ontology_validator import OntologyValidator
|
||||
from app.core.memory.utils.validation.owl_validator import OWLValidator
|
||||
from app.core.memory.utils.prompt.prompt_utils import render_ontology_extraction_prompt
|
||||
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class OntologyExtractor:
|
||||
"""Extractor for ontology classes from scenario descriptions.
|
||||
|
||||
This extractor uses LLM to identify abstract classes and concepts from
|
||||
natural language scenario descriptions, following OWL ontology engineering
|
||||
standards. It performs two-layer validation:
|
||||
1. String validation (naming conventions, reserved words, duplicates)
|
||||
2. OWL semantic validation (consistency checking, circular inheritance)
|
||||
|
||||
Attributes:
|
||||
llm_client: OpenAI client for LLM calls
|
||||
validator: String validator for class names and descriptions
|
||||
owl_validator: OWL validator for semantic validation
|
||||
"""
|
||||
|
||||
def __init__(self, llm_client: OpenAIClient):
|
||||
"""Initialize the OntologyExtractor.
|
||||
|
||||
Args:
|
||||
llm_client: OpenAIClient instance for LLM processing
|
||||
"""
|
||||
self.llm_client = llm_client
|
||||
self.validator = OntologyValidator()
|
||||
self.owl_validator = OWLValidator()
|
||||
|
||||
logger.info("OntologyExtractor initialized")
|
||||
|
||||
async def extract_ontology_classes(
|
||||
self,
|
||||
scenario: str,
|
||||
domain: Optional[str] = None,
|
||||
max_classes: int = 15,
|
||||
min_classes: int = 5,
|
||||
enable_owl_validation: bool = True,
|
||||
llm_temperature: float = 0.3,
|
||||
llm_max_tokens: int = 2000,
|
||||
max_description_length: int = 500,
|
||||
timeout: Optional[float] = None,
|
||||
) -> OntologyExtractionResponse:
|
||||
"""Extract ontology classes from a scenario description.
|
||||
|
||||
This is the main extraction method that orchestrates the entire process:
|
||||
1. Call LLM to extract ontology classes
|
||||
2. Perform first-layer validation (string validation and cleaning)
|
||||
3. Perform second-layer validation (OWL semantic validation)
|
||||
4. Filter invalid classes based on validation errors
|
||||
5. Return validated ontology classes
|
||||
|
||||
Args:
|
||||
scenario: Natural language scenario description
|
||||
domain: Optional domain hint (e.g., "Healthcare", "Education")
|
||||
max_classes: Maximum number of classes to extract (default: 15)
|
||||
min_classes: Minimum number of classes to extract (default: 5)
|
||||
enable_owl_validation: Whether to enable OWL validation (default: True)
|
||||
llm_temperature: LLM temperature parameter (default: 0.3)
|
||||
llm_max_tokens: LLM max tokens parameter (default: 2000)
|
||||
max_description_length: Maximum description length (default: 500)
|
||||
timeout: Optional timeout in seconds for LLM call (default: None, no timeout)
|
||||
|
||||
Returns:
|
||||
OntologyExtractionResponse containing validated ontology classes
|
||||
|
||||
Raises:
|
||||
ValueError: If scenario is empty or invalid
|
||||
asyncio.TimeoutError: If extraction times out
|
||||
|
||||
Examples:
|
||||
>>> extractor = OntologyExtractor(llm_client)
|
||||
>>> response = await extractor.extract_ontology_classes(
|
||||
... scenario="A hospital manages patient records...",
|
||||
... domain="Healthcare",
|
||||
... max_classes=10,
|
||||
... timeout=30.0
|
||||
... )
|
||||
>>> len(response.classes)
|
||||
7
|
||||
"""
|
||||
# Start timing
|
||||
start_time = time.time()
|
||||
|
||||
# Validate input
|
||||
if not scenario or not scenario.strip():
|
||||
logger.error("Scenario description is empty")
|
||||
raise ValueError("Scenario description cannot be empty")
|
||||
|
||||
scenario = scenario.strip()
|
||||
|
||||
logger.info(
|
||||
f"Starting ontology extraction - scenario_length={len(scenario)}, "
|
||||
f"domain={domain}, max_classes={max_classes}, min_classes={min_classes}, "
|
||||
f"timeout={timeout}"
|
||||
)
|
||||
|
||||
try:
|
||||
# Step 1: Call LLM for extraction with timeout
|
||||
logger.info("Step 1: Calling LLM for ontology extraction")
|
||||
llm_start_time = time.time()
|
||||
|
||||
if timeout is not None:
|
||||
# Wrap LLM call with timeout
|
||||
try:
|
||||
response = await asyncio.wait_for(
|
||||
self._call_llm_for_extraction(
|
||||
scenario=scenario,
|
||||
domain=domain,
|
||||
max_classes=max_classes,
|
||||
llm_temperature=llm_temperature,
|
||||
llm_max_tokens=llm_max_tokens,
|
||||
),
|
||||
timeout=timeout
|
||||
)
|
||||
except asyncio.TimeoutError:
|
||||
llm_duration = time.time() - llm_start_time
|
||||
logger.error(
|
||||
f"LLM extraction timed out after {timeout} seconds "
|
||||
f"(actual duration: {llm_duration:.2f}s)"
|
||||
)
|
||||
# Return empty response on timeout
|
||||
return OntologyExtractionResponse(
|
||||
classes=[],
|
||||
domain=domain or "Unknown",
|
||||
)
|
||||
else:
|
||||
# No timeout specified, call directly
|
||||
response = await self._call_llm_for_extraction(
|
||||
scenario=scenario,
|
||||
domain=domain,
|
||||
max_classes=max_classes,
|
||||
llm_temperature=llm_temperature,
|
||||
llm_max_tokens=llm_max_tokens,
|
||||
)
|
||||
|
||||
llm_duration = time.time() - llm_start_time
|
||||
logger.info(
|
||||
f"LLM returned {len(response.classes)} classes in {llm_duration:.2f}s"
|
||||
)
|
||||
|
||||
# Step 2: First-layer validation (string validation and cleaning)
|
||||
logger.info("Step 2: Performing first-layer validation (string validation)")
|
||||
validation_start_time = time.time()
|
||||
|
||||
response = self._validate_and_clean(
|
||||
response=response,
|
||||
max_description_length=max_description_length,
|
||||
)
|
||||
|
||||
validation_duration = time.time() - validation_start_time
|
||||
logger.info(
|
||||
f"After first-layer validation: {len(response.classes)} classes remain "
|
||||
f"(validation took {validation_duration:.2f}s)"
|
||||
)
|
||||
|
||||
# Check if we have enough classes after first-layer validation
|
||||
if len(response.classes) < min_classes:
|
||||
logger.warning(
|
||||
f"Only {len(response.classes)} classes remain after validation, "
|
||||
f"which is below minimum of {min_classes}"
|
||||
)
|
||||
|
||||
# Step 3: Second-layer validation (OWL semantic validation)
|
||||
if enable_owl_validation and response.classes:
|
||||
logger.info("Step 3: Performing second-layer validation (OWL validation)")
|
||||
owl_start_time = time.time()
|
||||
|
||||
is_valid, errors, world = self.owl_validator.validate_ontology_classes(
|
||||
classes=response.classes,
|
||||
)
|
||||
|
||||
owl_duration = time.time() - owl_start_time
|
||||
|
||||
if not is_valid:
|
||||
logger.warning(
|
||||
f"OWL validation found {len(errors)} issues in {owl_duration:.2f}s: {errors}"
|
||||
)
|
||||
|
||||
# Filter invalid classes based on errors
|
||||
response = self._filter_invalid_classes(
|
||||
response=response,
|
||||
errors=errors,
|
||||
)
|
||||
|
||||
logger.info(
|
||||
f"After second-layer validation: {len(response.classes)} classes remain"
|
||||
)
|
||||
else:
|
||||
logger.info(f"OWL validation passed successfully in {owl_duration:.2f}s")
|
||||
else:
|
||||
if not enable_owl_validation:
|
||||
logger.info("Step 3: OWL validation disabled, skipping")
|
||||
else:
|
||||
logger.info("Step 3: No classes to validate, skipping OWL validation")
|
||||
|
||||
# Calculate total duration
|
||||
total_duration = time.time() - start_time
|
||||
|
||||
# Log extraction statistics
|
||||
logger.info(
|
||||
f"Ontology extraction completed - "
|
||||
f"final_class_count={len(response.classes)}, "
|
||||
f"domain={response.domain}, "
|
||||
f"total_duration={total_duration:.2f}s, "
|
||||
f"llm_duration={llm_duration:.2f}s"
|
||||
)
|
||||
|
||||
return response
|
||||
|
||||
except asyncio.TimeoutError:
|
||||
# Re-raise timeout errors
|
||||
total_duration = time.time() - start_time
|
||||
logger.error(
|
||||
f"Ontology extraction timed out after {timeout} seconds "
|
||||
f"(total duration: {total_duration:.2f}s)",
|
||||
exc_info=True
|
||||
)
|
||||
raise
|
||||
except Exception as e:
|
||||
total_duration = time.time() - start_time
|
||||
logger.error(
|
||||
f"Ontology extraction failed after {total_duration:.2f}s: {str(e)}",
|
||||
exc_info=True
|
||||
)
|
||||
# Return empty response on failure
|
||||
return OntologyExtractionResponse(
|
||||
classes=[],
|
||||
domain=domain or "Unknown",
|
||||
)
|
||||
|
||||
async def _call_llm_for_extraction(
|
||||
self,
|
||||
scenario: str,
|
||||
domain: Optional[str],
|
||||
max_classes: int,
|
||||
llm_temperature: float,
|
||||
llm_max_tokens: int,
|
||||
) -> OntologyExtractionResponse:
|
||||
"""Call LLM to extract ontology classes from scenario.
|
||||
|
||||
This method renders the extraction prompt using the Jinja2 template
|
||||
and calls the LLM with structured output to get ontology classes.
|
||||
|
||||
Args:
|
||||
scenario: Scenario description text
|
||||
domain: Optional domain hint
|
||||
max_classes: Maximum number of classes to extract
|
||||
llm_temperature: LLM temperature parameter
|
||||
llm_max_tokens: LLM max tokens parameter
|
||||
|
||||
Returns:
|
||||
OntologyExtractionResponse from LLM
|
||||
|
||||
Raises:
|
||||
Exception: If LLM call fails
|
||||
"""
|
||||
try:
|
||||
# Render prompt using template
|
||||
prompt_content = await render_ontology_extraction_prompt(
|
||||
scenario=scenario,
|
||||
domain=domain,
|
||||
max_classes=max_classes,
|
||||
json_schema=OntologyExtractionResponse.model_json_schema(),
|
||||
)
|
||||
|
||||
logger.debug(f"Rendered prompt length: {len(prompt_content)}")
|
||||
|
||||
# Create messages for LLM
|
||||
messages = [
|
||||
{
|
||||
"role": "system",
|
||||
"content": (
|
||||
"You are an expert ontology engineer specializing in knowledge "
|
||||
"representation and OWL standards. Extract ontology classes from "
|
||||
"scenario descriptions following the provided instructions. "
|
||||
"Return valid JSON conforming to the schema."
|
||||
),
|
||||
},
|
||||
{
|
||||
"role": "user",
|
||||
"content": prompt_content,
|
||||
},
|
||||
]
|
||||
|
||||
# Call LLM with structured output
|
||||
logger.debug(
|
||||
f"Calling LLM with temperature={llm_temperature}, "
|
||||
f"max_tokens={llm_max_tokens}"
|
||||
)
|
||||
|
||||
response = await self.llm_client.response_structured(
|
||||
messages=messages,
|
||||
response_model=OntologyExtractionResponse,
|
||||
)
|
||||
|
||||
logger.info(
|
||||
f"LLM extraction successful - extracted {len(response.classes)} classes"
|
||||
)
|
||||
|
||||
return response
|
||||
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
f"LLM extraction failed: {str(e)}",
|
||||
exc_info=True
|
||||
)
|
||||
raise
|
||||
|
||||
def _validate_and_clean(
|
||||
self,
|
||||
response: OntologyExtractionResponse,
|
||||
max_description_length: int,
|
||||
) -> OntologyExtractionResponse:
|
||||
"""Perform first-layer validation: string validation and cleaning.
|
||||
|
||||
This method validates and cleans the extracted ontology classes:
|
||||
1. Validate class names (PascalCase, no reserved words)
|
||||
2. Sanitize invalid class names
|
||||
3. Truncate long descriptions
|
||||
4. Remove duplicate classes
|
||||
|
||||
Args:
|
||||
response: OntologyExtractionResponse from LLM
|
||||
max_description_length: Maximum description length
|
||||
|
||||
Returns:
|
||||
Cleaned OntologyExtractionResponse
|
||||
"""
|
||||
if not response.classes:
|
||||
logger.debug("No classes to validate")
|
||||
return response
|
||||
|
||||
logger.debug(f"Validating {len(response.classes)} classes")
|
||||
|
||||
validated_classes = []
|
||||
|
||||
for ontology_class in response.classes:
|
||||
# Validate class name
|
||||
is_valid, error_msg = self.validator.validate_class_name(
|
||||
ontology_class.name
|
||||
)
|
||||
|
||||
if not is_valid:
|
||||
logger.warning(
|
||||
f"Invalid class name '{ontology_class.name}': {error_msg}"
|
||||
)
|
||||
|
||||
# Attempt to sanitize
|
||||
sanitized_name = self.validator.sanitize_class_name(
|
||||
ontology_class.name
|
||||
)
|
||||
|
||||
logger.info(
|
||||
f"Sanitized class name: '{ontology_class.name}' -> '{sanitized_name}'"
|
||||
)
|
||||
|
||||
# Update class name
|
||||
ontology_class.name = sanitized_name
|
||||
|
||||
# Re-validate sanitized name
|
||||
is_valid, error_msg = self.validator.validate_class_name(
|
||||
sanitized_name
|
||||
)
|
||||
|
||||
if not is_valid:
|
||||
logger.error(
|
||||
f"Failed to sanitize class name '{ontology_class.name}': {error_msg}. "
|
||||
"Skipping this class."
|
||||
)
|
||||
continue
|
||||
|
||||
# Truncate description if too long
|
||||
if ontology_class.description:
|
||||
original_length = len(ontology_class.description)
|
||||
ontology_class.description = self.validator.truncate_description(
|
||||
ontology_class.description,
|
||||
max_length=max_description_length,
|
||||
)
|
||||
|
||||
if len(ontology_class.description) < original_length:
|
||||
logger.debug(
|
||||
f"Truncated description for '{ontology_class.name}': "
|
||||
f"{original_length} -> {len(ontology_class.description)} chars"
|
||||
)
|
||||
|
||||
validated_classes.append(ontology_class)
|
||||
|
||||
# Remove duplicates (case-insensitive)
|
||||
original_count = len(validated_classes)
|
||||
validated_classes = self.validator.remove_duplicates(validated_classes)
|
||||
|
||||
if len(validated_classes) < original_count:
|
||||
logger.info(
|
||||
f"Removed {original_count - len(validated_classes)} duplicate classes"
|
||||
)
|
||||
|
||||
# Return cleaned response
|
||||
return OntologyExtractionResponse(
|
||||
classes=validated_classes,
|
||||
domain=response.domain,
|
||||
)
|
||||
|
||||
def _filter_invalid_classes(
|
||||
self,
|
||||
response: OntologyExtractionResponse,
|
||||
errors: List[str],
|
||||
) -> OntologyExtractionResponse:
|
||||
"""Filter invalid classes based on OWL validation errors.
|
||||
|
||||
This method analyzes OWL validation errors and removes classes
|
||||
that caused validation failures (e.g., circular inheritance,
|
||||
inconsistencies).
|
||||
|
||||
Args:
|
||||
response: OntologyExtractionResponse to filter
|
||||
errors: List of error messages from OWL validation
|
||||
|
||||
Returns:
|
||||
Filtered OntologyExtractionResponse
|
||||
"""
|
||||
if not errors:
|
||||
return response
|
||||
|
||||
logger.debug(f"Filtering classes based on {len(errors)} OWL validation errors")
|
||||
|
||||
# Extract class names mentioned in errors
|
||||
invalid_class_names = set()
|
||||
|
||||
for error in errors:
|
||||
# Look for class names in error messages
|
||||
for ontology_class in response.classes:
|
||||
if ontology_class.name in error:
|
||||
invalid_class_names.add(ontology_class.name)
|
||||
logger.debug(
|
||||
f"Class '{ontology_class.name}' marked as invalid due to error: {error}"
|
||||
)
|
||||
|
||||
# Filter out invalid classes
|
||||
if invalid_class_names:
|
||||
original_count = len(response.classes)
|
||||
|
||||
filtered_classes = [
|
||||
c for c in response.classes
|
||||
if c.name not in invalid_class_names
|
||||
]
|
||||
|
||||
logger.info(
|
||||
f"Filtered out {original_count - len(filtered_classes)} invalid classes: "
|
||||
f"{invalid_class_names}"
|
||||
)
|
||||
|
||||
return OntologyExtractionResponse(
|
||||
classes=filtered_classes,
|
||||
domain=response.domain,
|
||||
)
|
||||
|
||||
return response
|
||||
@@ -25,6 +25,15 @@ class TripletExtractor:
|
||||
"""
|
||||
self.llm_client = llm_client
|
||||
|
||||
def _get_language(self) -> str:
|
||||
"""Get the configured language for entity descriptions
|
||||
|
||||
Returns:
|
||||
Language code ("zh" or "en")
|
||||
"""
|
||||
from app.core.config import settings
|
||||
return settings.DEFAULT_LANGUAGE
|
||||
|
||||
async def _extract_triplets(self, statement: Statement, chunk_content: str) -> TripletExtractionResponse:
|
||||
"""Process a single statement and return extracted triplets and entities"""
|
||||
# Render the prompt using helper function
|
||||
@@ -40,7 +49,8 @@ class TripletExtractor:
|
||||
statement=statement.statement,
|
||||
chunk_content=chunk_content,
|
||||
json_schema=TripletExtractionResponse.model_json_schema(),
|
||||
predicate_instructions=PREDICATE_DEFINITIONS
|
||||
predicate_instructions=PREDICATE_DEFINITIONS,
|
||||
language=self._get_language()
|
||||
)
|
||||
|
||||
# Create messages for LLM
|
||||
|
||||
@@ -296,7 +296,9 @@ def resolve_alias_cycles(entities: List[Any], cycles: Dict[str, Set[str]]) -> Li
|
||||
key=lambda eid: (
|
||||
_strength_rank(eid),
|
||||
len(getattr(entity_by_id.get(eid), 'description', '') or ''),
|
||||
len(getattr(entity_by_id.get(eid), 'fact_summary', '') or '')
|
||||
# TODO: fact_summary 功能暂时禁用,待后续开发完善后启用
|
||||
# len(getattr(entity_by_id.get(eid), 'fact_summary', '') or '')
|
||||
0 # 临时占位
|
||||
),
|
||||
reverse=True
|
||||
)
|
||||
|
||||
@@ -177,7 +177,7 @@ def render_entity_dedup_prompt(
|
||||
|
||||
# Args:
|
||||
# entity_a: Dict of entity A attributes
|
||||
async def render_triplet_extraction_prompt(statement: str, chunk_content: str, json_schema: dict, predicate_instructions: dict = None) -> str:
|
||||
async def render_triplet_extraction_prompt(statement: str, chunk_content: str, json_schema: dict, predicate_instructions: dict = None, language: str = "zh") -> str:
|
||||
"""
|
||||
Renders the triplet extraction prompt using the extract_triplet.jinja2 template.
|
||||
|
||||
@@ -186,6 +186,7 @@ async def render_triplet_extraction_prompt(statement: str, chunk_content: str, j
|
||||
chunk_content: The content of the chunk to process
|
||||
json_schema: JSON schema for the expected output format
|
||||
predicate_instructions: Optional predicate instructions
|
||||
language: The language to use for entity descriptions ("zh" for Chinese, "en" for English)
|
||||
|
||||
Returns:
|
||||
Rendered prompt content as string
|
||||
@@ -195,7 +196,8 @@ async def render_triplet_extraction_prompt(statement: str, chunk_content: str, j
|
||||
statement=statement,
|
||||
chunk_content=chunk_content,
|
||||
json_schema=json_schema,
|
||||
predicate_instructions=predicate_instructions
|
||||
predicate_instructions=predicate_instructions,
|
||||
language=language
|
||||
)
|
||||
# 记录渲染结果到提示日志(与示例日志结构一致)
|
||||
log_prompt_rendering('triplet extraction', rendered_prompt)
|
||||
@@ -204,7 +206,8 @@ async def render_triplet_extraction_prompt(statement: str, chunk_content: str, j
|
||||
'statement': 'str',
|
||||
'chunk_content': 'str',
|
||||
'json_schema': 'TripletExtractionResponse.schema',
|
||||
'predicate_instructions': 'PREDICATE_DEFINITIONS'
|
||||
'predicate_instructions': 'PREDICATE_DEFINITIONS',
|
||||
'language': language
|
||||
})
|
||||
|
||||
return rendered_prompt
|
||||
@@ -213,6 +216,7 @@ async def render_memory_summary_prompt(
|
||||
chunk_texts: str,
|
||||
json_schema: dict,
|
||||
max_words: int = 200,
|
||||
language: str = "zh",
|
||||
) -> str:
|
||||
"""
|
||||
Renders the memory summary prompt using the memory_summary.jinja2 template.
|
||||
@@ -221,6 +225,7 @@ async def render_memory_summary_prompt(
|
||||
chunk_texts: Concatenated text of conversation chunks
|
||||
json_schema: JSON schema for the expected output format
|
||||
max_words: Maximum words for the summary
|
||||
language: The language to use for summary generation ("zh" for Chinese, "en" for English)
|
||||
|
||||
Returns:
|
||||
Rendered prompt content as string.
|
||||
@@ -230,12 +235,14 @@ async def render_memory_summary_prompt(
|
||||
chunk_texts=chunk_texts,
|
||||
json_schema=json_schema,
|
||||
max_words=max_words,
|
||||
language=language,
|
||||
)
|
||||
log_prompt_rendering('memory summary', rendered_prompt)
|
||||
log_template_rendering('memory_summary.jinja2', {
|
||||
'chunk_texts_len': len(chunk_texts or ""),
|
||||
'max_words': max_words,
|
||||
'json_schema': 'MemorySummaryResponse.schema'
|
||||
'json_schema': 'MemorySummaryResponse.schema',
|
||||
'language': language
|
||||
})
|
||||
return rendered_prompt
|
||||
|
||||
@@ -388,24 +395,65 @@ async def render_memory_insight_prompt(
|
||||
return rendered_prompt
|
||||
|
||||
|
||||
async def render_episodic_title_and_type_prompt(content: str) -> str:
|
||||
async def render_episodic_title_and_type_prompt(content: str, language: str = "zh") -> str:
|
||||
"""
|
||||
Renders the episodic title and type classification prompt using the episodic_type_classification.jinja2 template.
|
||||
|
||||
Args:
|
||||
content: The content of the episodic memory summary to analyze
|
||||
language: The language to use for title generation ("zh" for Chinese, "en" for English)
|
||||
|
||||
Returns:
|
||||
Rendered prompt content as string
|
||||
"""
|
||||
template = prompt_env.get_template("episodic_type_classification.jinja2")
|
||||
rendered_prompt = template.render(content=content)
|
||||
rendered_prompt = template.render(content=content, language=language)
|
||||
|
||||
# 记录渲染结果到提示日志
|
||||
log_prompt_rendering('episodic title and type classification', rendered_prompt)
|
||||
# 可选:记录模板渲染信息
|
||||
log_template_rendering('episodic_type_classification.jinja2', {
|
||||
'content_len': len(content) if content else 0
|
||||
'content_len': len(content) if content else 0,
|
||||
'language': language
|
||||
})
|
||||
|
||||
return rendered_prompt
|
||||
|
||||
|
||||
async def render_ontology_extraction_prompt(
|
||||
scenario: str,
|
||||
domain: str | None = None,
|
||||
max_classes: int = 15,
|
||||
json_schema: dict | None = None
|
||||
) -> str:
|
||||
"""
|
||||
Renders the ontology extraction prompt using the extract_ontology.jinja2 template.
|
||||
|
||||
Args:
|
||||
scenario: The scenario description text to extract ontology classes from
|
||||
domain: Optional domain hint for the scenario (e.g., "Healthcare", "Education")
|
||||
max_classes: Maximum number of classes to extract (default: 15)
|
||||
json_schema: JSON schema for the expected output format
|
||||
|
||||
Returns:
|
||||
Rendered prompt content as string
|
||||
"""
|
||||
template = prompt_env.get_template("extract_ontology.jinja2")
|
||||
rendered_prompt = template.render(
|
||||
scenario=scenario,
|
||||
domain=domain,
|
||||
max_classes=max_classes,
|
||||
json_schema=json_schema
|
||||
)
|
||||
|
||||
# 记录渲染结果到提示日志
|
||||
log_prompt_rendering('ontology extraction', rendered_prompt)
|
||||
# 可选:记录模板渲染信息
|
||||
log_template_rendering('extract_ontology.jinja2', {
|
||||
'scenario_len': len(scenario) if scenario else 0,
|
||||
'domain': domain,
|
||||
'max_classes': max_classes,
|
||||
'json_schema': 'OntologyExtractionResponse.schema'
|
||||
})
|
||||
|
||||
return rendered_prompt
|
||||
|
||||
@@ -9,7 +9,8 @@
|
||||
- 类型: "{{ entity_a.entity_type | default('') }}"
|
||||
- 描述: "{{ entity_a.description | default('') }}"
|
||||
- 别名: {{ entity_a.aliases | default([]) }}
|
||||
- 摘要: "{{ entity_a.fact_summary | default('') }}"
|
||||
{# TODO: fact_summary 功能暂时禁用,待后续开发完善后启用 #}
|
||||
{# - 摘要: "{{ entity_a.fact_summary | default('') }}" #}
|
||||
- 连接强弱: "{{ entity_a.connect_strength | default('') }}"
|
||||
|
||||
实体B:
|
||||
@@ -17,7 +18,8 @@
|
||||
- 类型: "{{ entity_b.entity_type | default('') }}"
|
||||
- 描述: "{{ entity_b.description | default('') }}"
|
||||
- 别名: {{ entity_b.aliases | default([]) }}
|
||||
- 摘要: "{{ entity_b.fact_summary | default('') }}"
|
||||
{# TODO: fact_summary 功能暂时禁用,待后续开发完善后启用 #}
|
||||
{# - 摘要: "{{ entity_b.fact_summary | default('') }}" #}
|
||||
- 连接强弱: "{{ entity_b.connect_strength | default('') }}"
|
||||
|
||||
上下文:
|
||||
|
||||
@@ -1,8 +1,19 @@
|
||||
=== Task ===
|
||||
Generate a concise title and classify the episodic memory into the most appropriate category.
|
||||
|
||||
{% if language == "zh" %}
|
||||
**重要:请使用中文生成标题和分类。**
|
||||
{% else %}
|
||||
**Important: Please generate the title and classification in English.**
|
||||
{% endif %}
|
||||
|
||||
=== Requirements ===
|
||||
- Extract a clear, concise title (10-20 characters) that captures the core content
|
||||
{% if language == "zh" %}
|
||||
- 标题必须使用中文
|
||||
{% else %}
|
||||
- Title must be in English
|
||||
{% endif %}
|
||||
- Classify into exactly one category based on the primary theme
|
||||
- Be specific and avoid ambiguity
|
||||
- Output must be valid JSON conforming to the schema below
|
||||
|
||||
210
api/app/core/memory/utils/prompt/prompts/extract_ontology.jinja2
Normal file
210
api/app/core/memory/utils/prompt/prompts/extract_ontology.jinja2
Normal file
@@ -0,0 +1,210 @@
|
||||
===Task===
|
||||
Extract ontology classes from the given scenario description following ontology engineering standards.
|
||||
|
||||
===Role===
|
||||
You are a professional ontology engineer with expertise in knowledge representation and OWL (Web Ontology Language) standards. Your task is to identify abstract classes and concepts from scenario descriptions, not concrete instances.
|
||||
|
||||
===Scenario Description===
|
||||
{{ scenario }}
|
||||
|
||||
{% if domain -%}
|
||||
===Domain Hint===
|
||||
This scenario belongs to the **{{ domain }}** domain. Consider domain-specific concepts and terminology when extracting classes.
|
||||
{%- endif %}
|
||||
|
||||
===Extraction Rules===
|
||||
|
||||
**1. Abstract Classes, Not Instances:**
|
||||
- Extract abstract categories and concepts (e.g., "MedicalProcedure", "Patient", "Diagnosis")
|
||||
- Do NOT extract concrete instances (e.g., "John Smith", "Room 301", "2024-01-15")
|
||||
- Think in terms of "types of things" rather than "specific things"
|
||||
|
||||
**2. Naming Convention (PascalCase):**
|
||||
- Use PascalCase format for the "name" field: start with uppercase letter, capitalize each word, no spaces
|
||||
- Examples: "MedicalProcedure", "HealthcareProvider", "DiagnosticTest"
|
||||
- Avoid: "medical procedure", "healthcare_provider", "diagnostic-test"
|
||||
- Use clear, descriptive names in English
|
||||
- Avoid abbreviations unless they are standard in the domain (e.g., "API", "DNA")
|
||||
- Provide Chinese translation in the "name_chinese" field (e.g., "医疗程序", "医疗服务提供者", "诊断测试")
|
||||
|
||||
**3. Domain Relevance:**
|
||||
- Focus on classes that are central to the scenario's domain
|
||||
- Prioritize classes that represent key concepts, entities, or relationships
|
||||
- Avoid overly generic classes (e.g., "Thing", "Object") unless they have specific domain meaning
|
||||
|
||||
**4. Class Quantity:**
|
||||
- Extract between 5 and {{ max_classes }} classes
|
||||
- Aim for a balanced set covering the main concepts in the scenario
|
||||
- Quality over quantity: prefer well-defined classes over exhaustive lists
|
||||
|
||||
**5. Clear Descriptions:**
|
||||
- Provide concise, informative descriptions in Chinese (max 500 characters)
|
||||
- Describe what the class represents, not specific instances
|
||||
- Use clear, natural Chinese language that explains the class's role in the domain
|
||||
|
||||
**6. Concrete Examples:**
|
||||
- Provide 2-5 concrete instance examples in Chinese for each class
|
||||
- Examples should be specific, realistic instances of the class
|
||||
- Examples help clarify the class's scope and meaning
|
||||
- Use natural Chinese language for examples
|
||||
- Example format: ["示例1", "示例2", "示例3"]
|
||||
|
||||
**7. Class Hierarchy:**
|
||||
- Identify parent-child relationships where applicable
|
||||
- Use the parent_class field to specify inheritance
|
||||
- Parent class must be one of the extracted classes or a standard OWL class
|
||||
- Leave parent_class as null for top-level classes
|
||||
|
||||
**8. Entity Types:**
|
||||
- Classify each class with an appropriate entity_type
|
||||
- Common types: "Person", "Organization", "Location", "Event", "Concept", "Process", "Object", "Role"
|
||||
- Choose the most specific type that applies
|
||||
|
||||
**9. OWL Reserved Words:**
|
||||
- Do NOT use OWL reserved words as class names
|
||||
- Reserved words include: "Thing", "Nothing", "Class", "Property", "ObjectProperty", "DatatypeProperty", "AnnotationProperty", "Ontology", "Individual", "Literal"
|
||||
- If a reserved word is needed, add a domain-specific prefix (e.g., "MedicalClass" instead of "Class")
|
||||
|
||||
**10. Language Consistency:**
|
||||
- Extract all class names in English (PascalCase format) for the "name" field
|
||||
- Provide Chinese translation for class names in the "name_chinese" field
|
||||
- Descriptions MUST be in Chinese (中文)
|
||||
- Examples MUST be in Chinese (中文)
|
||||
- Use clear, natural Chinese language for descriptions and examples
|
||||
|
||||
===Examples===
|
||||
|
||||
**Example 1 (Healthcare Domain):**
|
||||
Scenario: "A hospital manages patient records, schedules appointments, and coordinates medical procedures. Doctors diagnose conditions and prescribe treatments."
|
||||
|
||||
Output:
|
||||
{
|
||||
"classes": [
|
||||
{
|
||||
"name": "Patient",
|
||||
"name_chinese": "患者",
|
||||
"description": "在医疗机构接受医疗护理或治疗的人",
|
||||
"examples": ["张三", "李四", "患有糖尿病的老年患者"],
|
||||
"parent_class": null,
|
||||
"entity_type": "Person",
|
||||
"domain": "Healthcare"
|
||||
},
|
||||
{
|
||||
"name": "MedicalProcedure",
|
||||
"name_chinese": "医疗程序",
|
||||
"description": "为医疗诊断或治疗而执行的系统性操作流程",
|
||||
"examples": ["手术", "血液检查", "X光检查", "疫苗接种"],
|
||||
"parent_class": null,
|
||||
"entity_type": "Process",
|
||||
"domain": "Healthcare"
|
||||
},
|
||||
{
|
||||
"name": "Diagnosis",
|
||||
"name_chinese": "诊断",
|
||||
"description": "基于症状和检查结果对疾病或状况的识别",
|
||||
"examples": ["糖尿病诊断", "癌症诊断", "流感诊断"],
|
||||
"parent_class": null,
|
||||
"entity_type": "Concept",
|
||||
"domain": "Healthcare"
|
||||
},
|
||||
{
|
||||
"name": "Doctor",
|
||||
"name_chinese": "医生",
|
||||
"description": "诊断和治疗患者的持证医疗专业人员",
|
||||
"examples": ["全科医生", "外科医生", "心脏病专家"],
|
||||
"parent_class": null,
|
||||
"entity_type": "Role",
|
||||
"domain": "Healthcare"
|
||||
},
|
||||
{
|
||||
"name": "Treatment",
|
||||
"name_chinese": "治疗",
|
||||
"description": "为治愈或管理疾病状况而提供的医疗护理或疗法",
|
||||
"examples": ["药物治疗", "物理治疗", "化疗", "手术治疗"],
|
||||
"parent_class": null,
|
||||
"entity_type": "Process",
|
||||
"domain": "Healthcare"
|
||||
}
|
||||
],
|
||||
"domain": "Healthcare",
|
||||
"namespace": "http://example.org/healthcare#"
|
||||
}
|
||||
|
||||
**Example 2 (Education Domain):**
|
||||
Scenario: "A university offers courses taught by professors. Students enroll in programs, attend lectures, and complete assignments to earn degrees."
|
||||
|
||||
Output:
|
||||
{
|
||||
"classes": [
|
||||
{
|
||||
"name": "Student",
|
||||
"name_chinese": "学生",
|
||||
"description": "在教育机构注册学习的人",
|
||||
"examples": ["本科生", "研究生", "在职学生"],
|
||||
"parent_class": null,
|
||||
"entity_type": "Role",
|
||||
"domain": "Education"
|
||||
},
|
||||
{
|
||||
"name": "Course",
|
||||
"name_chinese": "课程",
|
||||
"description": "涵盖特定学科或主题的结构化教育课程",
|
||||
"examples": ["计算机科学导论", "微积分I", "世界历史"],
|
||||
"parent_class": null,
|
||||
"entity_type": "Concept",
|
||||
"domain": "Education"
|
||||
},
|
||||
{
|
||||
"name": "Professor",
|
||||
"name_chinese": "教授",
|
||||
"description": "教授课程并进行研究的学术教师",
|
||||
"examples": ["助理教授", "副教授", "正教授"],
|
||||
"parent_class": null,
|
||||
"entity_type": "Role",
|
||||
"domain": "Education"
|
||||
},
|
||||
{
|
||||
"name": "AcademicProgram",
|
||||
"name_chinese": "学术项目",
|
||||
"description": "通向学位或证书的结构化课程体系",
|
||||
"examples": ["理学学士", "文学硕士", "博士项目"],
|
||||
"parent_class": null,
|
||||
"entity_type": "Concept",
|
||||
"domain": "Education"
|
||||
},
|
||||
{
|
||||
"name": "Assignment",
|
||||
"name_chinese": "作业",
|
||||
"description": "分配给学生以评估学习成果的任务或项目",
|
||||
"examples": ["论文", "习题集", "研究报告", "实验报告"],
|
||||
"parent_class": null,
|
||||
"entity_type": "Object",
|
||||
"domain": "Education"
|
||||
},
|
||||
{
|
||||
"name": "Lecture",
|
||||
"name_chinese": "讲座",
|
||||
"description": "由教师进行的教育性演讲或讲座",
|
||||
"examples": ["入门讲座", "客座讲座", "在线讲座"],
|
||||
"parent_class": null,
|
||||
"entity_type": "Event",
|
||||
"domain": "Education"
|
||||
}
|
||||
],
|
||||
"domain": "Education",
|
||||
"namespace": "http://example.org/education#"
|
||||
}
|
||||
|
||||
===Output Format===
|
||||
|
||||
**JSON Requirements:**
|
||||
- Use only ASCII double quotes (") for JSON structure
|
||||
- Never use Chinese quotation marks ("") or Unicode quotes
|
||||
- Escape quotation marks in text with backslashes (\")
|
||||
- Ensure proper string closure and comma separation
|
||||
- No line breaks within JSON string values
|
||||
- All class names must be in PascalCase format
|
||||
- All class names must be unique (case-insensitive)
|
||||
- Extract between 5 and {{ max_classes }} classes
|
||||
|
||||
{{ json_schema }}
|
||||
@@ -5,6 +5,12 @@
|
||||
===Task===
|
||||
Extract entities and knowledge triplets from the given statement.
|
||||
|
||||
{% if language == "zh" %}
|
||||
**重要:请使用中文生成实体描述(description)和示例(example)。**
|
||||
{% else %}
|
||||
**Important: Please generate entity descriptions and examples in English.**
|
||||
{% endif %}
|
||||
|
||||
===Inputs===
|
||||
**Chunk Content:** "{{ chunk_content }}"
|
||||
**Statement:** "{{ statement }}"
|
||||
@@ -13,6 +19,13 @@ Extract entities and knowledge triplets from the given statement.
|
||||
|
||||
**Entity Extraction:**
|
||||
- Extract entities with their types, context-independent descriptions, **concise examples**, aliases, and semantic memory classification
|
||||
{% if language == "zh" %}
|
||||
- **实体描述(description)必须使用中文**
|
||||
- **示例(example)必须使用中文**
|
||||
{% else %}
|
||||
- **Entity descriptions must be in English**
|
||||
- **Examples must be in English**
|
||||
{% endif %}
|
||||
- **Semantic Memory Classification (is_explicit_memory):**
|
||||
* Set to `true` if the entity represents **explicit/semantic memory**:
|
||||
- **Concepts:** "Machine Learning", "Photosynthesis", "Democracy", "人工智能", "光合作用", "民主"
|
||||
@@ -334,9 +347,11 @@ Output:
|
||||
- Escape quotation marks in text with backslashes (\")
|
||||
- Ensure proper string closure and comma separation
|
||||
- No line breaks within JSON string values
|
||||
- The output language should ALWAYS match the input language
|
||||
- If input is in English, extract statements in English
|
||||
- If input is in Chinese, extract statements in Chinese
|
||||
{% if language == "zh" %}
|
||||
- **语言要求:实体描述(description)和示例(example)必须使用中文**
|
||||
{% else %}
|
||||
- **Language Requirement: Entity descriptions and examples must be in English**
|
||||
{% endif %}
|
||||
- Preserve the original language and do not translate
|
||||
|
||||
{{ json_schema }}
|
||||
@@ -5,10 +5,21 @@
|
||||
=== Task ===
|
||||
Summarize the provided conversation chunks into a concise Memory summary.
|
||||
|
||||
{% if language == "zh" %}
|
||||
**重要:请使用中文生成摘要内容。**
|
||||
{% else %}
|
||||
**Important: Please generate the summary content in English.**
|
||||
{% endif %}
|
||||
|
||||
=== Requirements ===
|
||||
- Focus on factual statements, user preferences, relationships, and salient temporal context.
|
||||
- Avoid repetition and filler; be specific.
|
||||
- Keep it under {{ max_words or 200 }} words.
|
||||
{% if language == "zh" %}
|
||||
- 摘要内容必须使用中文
|
||||
{% else %}
|
||||
- Summary content must be in English
|
||||
{% endif %}
|
||||
- Output must be valid JSON conforming to the schema below.
|
||||
|
||||
=== Input ===
|
||||
@@ -24,6 +35,11 @@ Summarize the provided conversation chunks into a concise Memory summary.
|
||||
4. Do not include line breaks within JSON string values
|
||||
5. Example of proper escaping: "statement": "张曼婷说:\"我很喜欢这本书。\""
|
||||
|
||||
The output language should always be the same as the input language.
|
||||
{% if language == "zh" %}
|
||||
**语言要求:输出内容必须使用中文。**
|
||||
{% else %}
|
||||
**Language Requirement: The output content must be in English.**
|
||||
{% endif %}
|
||||
|
||||
Return only a list of extracted labelled statements in the JSON ARRAY of objects that match the schema below:
|
||||
{{ json_schema }}
|
||||
10
api/app/core/memory/utils/validation/__init__.py
Normal file
10
api/app/core/memory/utils/validation/__init__.py
Normal file
@@ -0,0 +1,10 @@
|
||||
"""Validation utilities for ontology extraction.
|
||||
|
||||
This module provides validation classes for ontology class names,
|
||||
descriptions, and OWL compliance checking.
|
||||
"""
|
||||
|
||||
from .ontology_validator import OntologyValidator
|
||||
from .owl_validator import OWLValidator
|
||||
|
||||
__all__ = ['OntologyValidator', 'OWLValidator']
|
||||
268
api/app/core/memory/utils/validation/ontology_validator.py
Normal file
268
api/app/core/memory/utils/validation/ontology_validator.py
Normal file
@@ -0,0 +1,268 @@
|
||||
"""String validation for ontology class names and descriptions.
|
||||
|
||||
This module provides the OntologyValidator class for validating and sanitizing
|
||||
ontology class names according to OWL standards and naming conventions.
|
||||
|
||||
Classes:
|
||||
OntologyValidator: Validates class names, removes duplicates, and truncates descriptions
|
||||
"""
|
||||
|
||||
import logging
|
||||
import re
|
||||
from typing import List, Tuple
|
||||
|
||||
from app.core.memory.models.ontology_models import OntologyClass
|
||||
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class OntologyValidator:
|
||||
"""Validator for ontology class names and descriptions.
|
||||
|
||||
This validator performs string-level validation including:
|
||||
- PascalCase naming convention validation
|
||||
- OWL reserved word checking
|
||||
- Duplicate class name removal
|
||||
- Description length truncation
|
||||
|
||||
Attributes:
|
||||
OWL_RESERVED_WORDS: Set of OWL reserved words that cannot be used as class names
|
||||
"""
|
||||
|
||||
# OWL reserved words that cannot be used as class names
|
||||
OWL_RESERVED_WORDS = {
|
||||
'Thing', 'Nothing', 'Class', 'Property',
|
||||
'ObjectProperty', 'DatatypeProperty', 'FunctionalProperty',
|
||||
'InverseFunctionalProperty', 'TransitiveProperty', 'SymmetricProperty',
|
||||
'AsymmetricProperty', 'ReflexiveProperty', 'IrreflexiveProperty',
|
||||
'Restriction', 'Ontology', 'Individual', 'NamedIndividual',
|
||||
'Annotation', 'AnnotationProperty', 'Axiom',
|
||||
'AllDifferent', 'AllDisjointClasses', 'AllDisjointProperties',
|
||||
'Datatype', 'DataRange', 'Literal',
|
||||
'DeprecatedClass', 'DeprecatedProperty',
|
||||
'Imports', 'IncompatibleWith', 'PriorVersion', 'VersionInfo',
|
||||
'BackwardCompatibleWith', 'OntologyProperty',
|
||||
}
|
||||
|
||||
def validate_class_name(self, name: str) -> Tuple[bool, str]:
|
||||
"""Validate that a class name follows OWL naming conventions.
|
||||
|
||||
Validation rules:
|
||||
1. Must not be empty
|
||||
2. Must start with an uppercase letter (PascalCase)
|
||||
3. Cannot contain spaces
|
||||
4. Can only contain alphanumeric characters and underscores
|
||||
5. Cannot be an OWL reserved word
|
||||
|
||||
Args:
|
||||
name: The class name to validate
|
||||
|
||||
Returns:
|
||||
Tuple of (is_valid, error_message)
|
||||
- is_valid: True if the name is valid, False otherwise
|
||||
- error_message: Empty string if valid, error description if invalid
|
||||
|
||||
Examples:
|
||||
>>> validator = OntologyValidator()
|
||||
>>> validator.validate_class_name("MedicalProcedure")
|
||||
(True, "")
|
||||
>>> validator.validate_class_name("medical procedure")
|
||||
(False, "Class name 'medical procedure' cannot contain spaces")
|
||||
>>> validator.validate_class_name("Thing")
|
||||
(False, "Class name 'Thing' is an OWL reserved word")
|
||||
"""
|
||||
logger.debug(f"Validating class name: '{name}'")
|
||||
|
||||
# Check if empty
|
||||
if not name or not name.strip():
|
||||
error_msg = "Class name cannot be empty"
|
||||
logger.warning(f"Validation failed: {error_msg}")
|
||||
return False, error_msg
|
||||
|
||||
name = name.strip()
|
||||
|
||||
# Check if it's an OWL reserved word
|
||||
if name in self.OWL_RESERVED_WORDS:
|
||||
error_msg = f"Class name '{name}' is an OWL reserved word"
|
||||
logger.warning(f"Validation failed: {error_msg}")
|
||||
return False, error_msg
|
||||
|
||||
# Check if starts with uppercase letter
|
||||
if not name[0].isupper():
|
||||
error_msg = f"Class name '{name}' must start with an uppercase letter (PascalCase)"
|
||||
logger.warning(f"Validation failed: {error_msg}")
|
||||
return False, error_msg
|
||||
|
||||
# Check for spaces
|
||||
if ' ' in name:
|
||||
error_msg = f"Class name '{name}' cannot contain spaces"
|
||||
logger.warning(f"Validation failed: {error_msg}")
|
||||
return False, error_msg
|
||||
|
||||
# Check for invalid characters (only alphanumeric and underscore allowed)
|
||||
if not re.match(r'^[A-Za-z0-9_]+$', name):
|
||||
error_msg = f"Class name '{name}' contains invalid characters. Only alphanumeric characters and underscores are allowed"
|
||||
logger.warning(f"Validation failed: {error_msg}")
|
||||
return False, error_msg
|
||||
|
||||
logger.debug(f"Class name '{name}' is valid")
|
||||
return True, ""
|
||||
|
||||
def sanitize_class_name(self, name: str) -> str:
|
||||
"""Attempt to sanitize an invalid class name into a valid format.
|
||||
|
||||
Sanitization steps:
|
||||
1. Strip whitespace
|
||||
2. Remove invalid characters
|
||||
3. Replace spaces with empty string (PascalCase)
|
||||
4. Capitalize first letter of each word
|
||||
5. If result is empty or starts with number, prefix with 'Class'
|
||||
|
||||
Args:
|
||||
name: The class name to sanitize
|
||||
|
||||
Returns:
|
||||
Sanitized class name that should pass validation
|
||||
|
||||
Examples:
|
||||
>>> validator = OntologyValidator()
|
||||
>>> validator.sanitize_class_name("medical procedure")
|
||||
'MedicalProcedure'
|
||||
>>> validator.sanitize_class_name("patient-record")
|
||||
'PatientRecord'
|
||||
>>> validator.sanitize_class_name("123invalid")
|
||||
'Class123Invalid'
|
||||
"""
|
||||
logger.debug(f"Sanitizing class name: '{name}'")
|
||||
|
||||
if not name or not name.strip():
|
||||
logger.warning("Empty class name provided for sanitization, returning 'UnnamedClass'")
|
||||
return "UnnamedClass"
|
||||
|
||||
# Strip whitespace
|
||||
name = name.strip()
|
||||
original_name = name
|
||||
|
||||
# Split on spaces, hyphens, and underscores, then capitalize each word
|
||||
words = re.split(r'[\s\-_]+', name)
|
||||
|
||||
# Capitalize first letter of each word and keep rest as is
|
||||
sanitized_words = []
|
||||
for word in words:
|
||||
if word:
|
||||
# Remove non-alphanumeric characters except underscore
|
||||
clean_word = re.sub(r'[^A-Za-z0-9_]', '', word)
|
||||
if clean_word:
|
||||
# Capitalize first letter
|
||||
sanitized_words.append(clean_word[0].upper() + clean_word[1:])
|
||||
|
||||
# Join words
|
||||
sanitized = ''.join(sanitized_words)
|
||||
|
||||
# If empty or starts with number, prefix with 'Class'
|
||||
if not sanitized or sanitized[0].isdigit():
|
||||
sanitized = 'Class' + sanitized
|
||||
logger.info(f"Prefixed class name with 'Class': '{original_name}' -> '{sanitized}'")
|
||||
|
||||
# If it's a reserved word, append 'Class' suffix
|
||||
if sanitized in self.OWL_RESERVED_WORDS:
|
||||
sanitized = sanitized + 'Class'
|
||||
logger.info(f"Appended 'Class' suffix to reserved word: '{original_name}' -> '{sanitized}'")
|
||||
|
||||
logger.info(f"Sanitized class name: '{original_name}' -> '{sanitized}'")
|
||||
return sanitized
|
||||
|
||||
def remove_duplicates(self, classes: List[OntologyClass]) -> List[OntologyClass]:
|
||||
"""Remove duplicate ontology classes based on case-insensitive name comparison.
|
||||
|
||||
When duplicates are found, keeps the first occurrence and discards subsequent ones.
|
||||
Comparison is case-insensitive to catch variations like 'Patient' and 'patient'.
|
||||
|
||||
Args:
|
||||
classes: List of OntologyClass objects
|
||||
|
||||
Returns:
|
||||
List of OntologyClass objects with duplicates removed
|
||||
|
||||
Examples:
|
||||
>>> validator = OntologyValidator()
|
||||
>>> classes = [
|
||||
... OntologyClass(name="Patient", description="A patient", entity_type="Person", domain="Healthcare"),
|
||||
... OntologyClass(name="patient", description="Another patient", entity_type="Person", domain="Healthcare"),
|
||||
... OntologyClass(name="Doctor", description="A doctor", entity_type="Person", domain="Healthcare"),
|
||||
... ]
|
||||
>>> unique = validator.remove_duplicates(classes)
|
||||
>>> len(unique)
|
||||
2
|
||||
>>> [c.name for c in unique]
|
||||
['Patient', 'Doctor']
|
||||
"""
|
||||
if not classes:
|
||||
logger.debug("No classes to check for duplicates")
|
||||
return classes
|
||||
|
||||
logger.debug(f"Checking {len(classes)} classes for duplicates")
|
||||
|
||||
seen_names = set()
|
||||
unique_classes = []
|
||||
duplicates_found = []
|
||||
|
||||
for ontology_class in classes:
|
||||
# Use lowercase for comparison
|
||||
name_lower = ontology_class.name.lower()
|
||||
|
||||
if name_lower not in seen_names:
|
||||
seen_names.add(name_lower)
|
||||
unique_classes.append(ontology_class)
|
||||
else:
|
||||
duplicates_found.append(ontology_class.name)
|
||||
logger.debug(f"Duplicate class found and removed: '{ontology_class.name}'")
|
||||
|
||||
if duplicates_found:
|
||||
logger.info(
|
||||
f"Removed {len(duplicates_found)} duplicate classes: {duplicates_found}"
|
||||
)
|
||||
else:
|
||||
logger.debug("No duplicate classes found")
|
||||
|
||||
return unique_classes
|
||||
|
||||
def truncate_description(self, description: str, max_length: int = 500) -> str:
|
||||
"""Truncate a description to a maximum length.
|
||||
|
||||
If the description exceeds max_length, it will be truncated and
|
||||
an ellipsis (...) will be appended to indicate truncation.
|
||||
|
||||
Args:
|
||||
description: The description text to truncate
|
||||
max_length: Maximum allowed length (default: 500)
|
||||
|
||||
Returns:
|
||||
Truncated description string
|
||||
|
||||
Examples:
|
||||
>>> validator = OntologyValidator()
|
||||
>>> long_desc = "A" * 600
|
||||
>>> truncated = validator.truncate_description(long_desc, max_length=500)
|
||||
>>> len(truncated)
|
||||
500
|
||||
>>> truncated.endswith("...")
|
||||
True
|
||||
"""
|
||||
if not description:
|
||||
return ""
|
||||
|
||||
if len(description) <= max_length:
|
||||
return description
|
||||
|
||||
# Truncate and add ellipsis
|
||||
# Reserve 3 characters for "..."
|
||||
truncate_at = max_length - 3
|
||||
truncated = description[:truncate_at] + "..."
|
||||
|
||||
logger.debug(
|
||||
f"Truncated description from {len(description)} to {len(truncated)} characters"
|
||||
)
|
||||
|
||||
return truncated
|
||||
585
api/app/core/memory/utils/validation/owl_validator.py
Normal file
585
api/app/core/memory/utils/validation/owl_validator.py
Normal file
@@ -0,0 +1,585 @@
|
||||
"""OWL semantic validation for ontology classes using Owlready2.
|
||||
|
||||
This module provides the OWLValidator class for validating ontology classes
|
||||
against OWL standards using the Owlready2 library. It performs semantic
|
||||
validation including consistency checking, circular inheritance detection,
|
||||
and OWL file export.
|
||||
|
||||
Classes:
|
||||
OWLValidator: Validates ontology classes using OWL reasoning and exports to OWL formats
|
||||
"""
|
||||
|
||||
import logging
|
||||
from typing import List, Optional, Tuple
|
||||
|
||||
from owlready2 import (
|
||||
World,
|
||||
Thing,
|
||||
get_ontology,
|
||||
sync_reasoner_pellet,
|
||||
OwlReadyInconsistentOntologyError,
|
||||
)
|
||||
|
||||
from app.core.memory.models.ontology_models import OntologyClass
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class OWLValidator:
|
||||
"""Validator for OWL semantic validation of ontology classes.
|
||||
|
||||
This validator performs semantic-level validation using Owlready2 including:
|
||||
- Creating OWL classes from ontology class definitions
|
||||
- Running consistency checking with Pellet reasoner
|
||||
- Detecting circular inheritance
|
||||
- Validating Protégé compatibility
|
||||
- Exporting ontologies to various OWL formats (RDF/XML, Turtle, N-Triples)
|
||||
|
||||
Attributes:
|
||||
base_namespace: Base URI for the ontology namespace
|
||||
"""
|
||||
|
||||
def __init__(self, base_namespace: str = "http://example.org/ontology#"):
|
||||
"""Initialize the OWL validator.
|
||||
|
||||
Args:
|
||||
base_namespace: Base URI for the ontology namespace (default: http://example.org/ontology#)
|
||||
"""
|
||||
self.base_namespace = base_namespace
|
||||
|
||||
def validate_ontology_classes(
|
||||
self,
|
||||
classes: List[OntologyClass],
|
||||
) -> Tuple[bool, List[str], Optional[World]]:
|
||||
"""Validate extracted ontology classes against OWL standards.
|
||||
|
||||
This method creates an OWL ontology from the provided classes using Owlready2,
|
||||
runs consistency checking with the Pellet reasoner, and detects common issues
|
||||
like circular inheritance.
|
||||
|
||||
Args:
|
||||
classes: List of OntologyClass objects to validate
|
||||
|
||||
Returns:
|
||||
Tuple of (is_valid, error_messages, world):
|
||||
- is_valid: True if ontology is valid and consistent, False otherwise
|
||||
- error_messages: List of error/warning messages
|
||||
- world: Owlready2 World object containing the ontology (None if validation failed)
|
||||
|
||||
Examples:
|
||||
>>> validator = OWLValidator()
|
||||
>>> classes = [
|
||||
... OntologyClass(name="Patient", description="A patient", entity_type="Person", domain="Healthcare"),
|
||||
... OntologyClass(name="Doctor", description="A doctor", entity_type="Person", domain="Healthcare"),
|
||||
... ]
|
||||
>>> is_valid, errors, world = validator.validate_ontology_classes(classes)
|
||||
>>> is_valid
|
||||
True
|
||||
>>> len(errors)
|
||||
0
|
||||
"""
|
||||
if not classes:
|
||||
return False, ["No classes provided for validation"], None
|
||||
|
||||
errors = []
|
||||
|
||||
try:
|
||||
# Create a new world (isolated ontology environment)
|
||||
world = World()
|
||||
|
||||
# Use a proper ontology IRI
|
||||
# Owlready2 expects the IRI to end with .owl or similar
|
||||
onto_iri = self.base_namespace.rstrip('#/')
|
||||
if not onto_iri.endswith('.owl'):
|
||||
onto_iri = onto_iri + '.owl'
|
||||
|
||||
# Create ontology
|
||||
onto = world.get_ontology(onto_iri)
|
||||
|
||||
with onto:
|
||||
# Dictionary to store created OWL classes for parent reference
|
||||
owl_classes = {}
|
||||
|
||||
# First pass: Create all classes without parent relationships
|
||||
for ontology_class in classes:
|
||||
try:
|
||||
# Create OWL class dynamically using type() with Thing as base
|
||||
# The key is to NOT set namespace in the dict, let Owlready2 handle it
|
||||
owl_class = type(
|
||||
ontology_class.name, # Class name
|
||||
(Thing,), # Base classes
|
||||
{} # Class dict (empty, let Owlready2 manage)
|
||||
)
|
||||
|
||||
# Add label (rdfs:label) - include both English and Chinese names
|
||||
labels = [ontology_class.name]
|
||||
if ontology_class.name_chinese:
|
||||
labels.append(ontology_class.name_chinese)
|
||||
owl_class.label = labels
|
||||
|
||||
# Add comment (rdfs:comment) with description
|
||||
if ontology_class.description:
|
||||
owl_class.comment = [ontology_class.description]
|
||||
|
||||
# Store for parent relationship setup
|
||||
owl_classes[ontology_class.name] = owl_class
|
||||
|
||||
logger.debug(
|
||||
f"Created OWL class: {ontology_class.name} "
|
||||
f"(Chinese: {ontology_class.name_chinese}) "
|
||||
f"IRI: {owl_class.iri if hasattr(owl_class, 'iri') else 'N/A'}"
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
error_msg = f"Failed to create OWL class '{ontology_class.name}': {str(e)}"
|
||||
errors.append(error_msg)
|
||||
logger.error(error_msg, exc_info=True)
|
||||
|
||||
# Second pass: Set up parent relationships
|
||||
for ontology_class in classes:
|
||||
if ontology_class.parent_class and ontology_class.name in owl_classes:
|
||||
parent_name = ontology_class.parent_class
|
||||
|
||||
# Check if parent exists
|
||||
if parent_name in owl_classes:
|
||||
try:
|
||||
child_class = owl_classes[ontology_class.name]
|
||||
parent_class = owl_classes[parent_name]
|
||||
|
||||
# Set parent by modifying is_a
|
||||
child_class.is_a = [parent_class]
|
||||
|
||||
logger.debug(
|
||||
f"Set parent relationship: {ontology_class.name} -> {parent_name}"
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
error_msg = (
|
||||
f"Failed to set parent relationship "
|
||||
f"'{ontology_class.name}' -> '{parent_name}': {str(e)}"
|
||||
)
|
||||
errors.append(error_msg)
|
||||
logger.warning(error_msg)
|
||||
else:
|
||||
warning_msg = (
|
||||
f"Parent class '{parent_name}' not found for '{ontology_class.name}'"
|
||||
)
|
||||
errors.append(warning_msg)
|
||||
logger.warning(warning_msg)
|
||||
|
||||
# Check for circular inheritance
|
||||
for class_name, owl_class in owl_classes.items():
|
||||
if self._has_circular_inheritance(owl_class):
|
||||
error_msg = f"Circular inheritance detected for class '{class_name}'"
|
||||
errors.append(error_msg)
|
||||
logger.error(error_msg)
|
||||
|
||||
# Run consistency checking with Pellet reasoner
|
||||
try:
|
||||
logger.info("Running Pellet reasoner for consistency checking...")
|
||||
sync_reasoner_pellet(world, infer_property_values=True, infer_data_property_values=True)
|
||||
logger.info("Consistency check passed")
|
||||
|
||||
except OwlReadyInconsistentOntologyError as e:
|
||||
error_msg = f"Ontology is inconsistent: {str(e)}"
|
||||
errors.append(error_msg)
|
||||
logger.error(error_msg)
|
||||
return False, errors, world
|
||||
|
||||
except Exception as e:
|
||||
# Reasoner errors are often due to Java not being installed or configured
|
||||
# Log as warning but don't fail validation - ontology structure is still valid
|
||||
warning_msg = f"Reasoner check skipped: {str(e)}"
|
||||
if str(e).strip(): # Only log if there's an actual error message
|
||||
logger.warning(warning_msg)
|
||||
else:
|
||||
logger.warning("Reasoner check skipped: Java may not be installed or configured")
|
||||
# Continue - ontology structure is valid even without reasoner check
|
||||
|
||||
# If we have errors (excluding warnings), validation failed
|
||||
is_valid = len(errors) == 0
|
||||
|
||||
return is_valid, errors, world
|
||||
|
||||
except Exception as e:
|
||||
error_msg = f"OWL validation failed: {str(e)}"
|
||||
errors.append(error_msg)
|
||||
logger.error(error_msg, exc_info=True)
|
||||
return False, errors, None
|
||||
|
||||
def _has_circular_inheritance(self, owl_class) -> bool:
|
||||
"""Check if an OWL class has circular inheritance.
|
||||
|
||||
Circular inheritance occurs when a class inherits from itself through
|
||||
a chain of parent relationships (e.g., A -> B -> C -> A).
|
||||
|
||||
Args:
|
||||
owl_class: Owlready2 class object to check
|
||||
|
||||
Returns:
|
||||
True if circular inheritance is detected, False otherwise
|
||||
"""
|
||||
visited = set()
|
||||
current = owl_class
|
||||
|
||||
while current:
|
||||
# Get class IRI or name as identifier
|
||||
class_id = str(current.iri) if hasattr(current, 'iri') else str(current)
|
||||
|
||||
if class_id in visited:
|
||||
# Found a cycle
|
||||
return True
|
||||
|
||||
visited.add(class_id)
|
||||
|
||||
# Get parent classes (is_a relationship)
|
||||
parents = getattr(current, 'is_a', [])
|
||||
|
||||
# Filter out Thing and other base classes
|
||||
parent_classes = [p for p in parents if p != Thing and hasattr(p, 'is_a')]
|
||||
|
||||
if not parent_classes:
|
||||
# No more parents, no cycle
|
||||
break
|
||||
|
||||
# Check first parent (in single inheritance)
|
||||
current = parent_classes[0] if parent_classes else None
|
||||
|
||||
return False
|
||||
|
||||
def export_to_owl(
|
||||
self,
|
||||
world: World,
|
||||
output_path: Optional[str] = None,
|
||||
format: str = "rdfxml",
|
||||
classes: Optional[List] = None
|
||||
) -> str:
|
||||
"""Export ontology to OWL file in specified format.
|
||||
|
||||
Supported formats:
|
||||
- rdfxml: RDF/XML format (default, most compatible)
|
||||
- turtle: Turtle format (more readable)
|
||||
- ntriples: N-Triples format (simplest)
|
||||
- json: JSON format (simplified, human-readable)
|
||||
|
||||
Args:
|
||||
world: Owlready2 World object containing the ontology
|
||||
output_path: Optional file path to save the ontology (if None, returns string)
|
||||
format: Export format - "rdfxml", "turtle", "ntriples", or "json" (default: "rdfxml")
|
||||
classes: Optional list of OntologyClass objects (required for json format)
|
||||
|
||||
Returns:
|
||||
String representation of the exported ontology
|
||||
|
||||
Raises:
|
||||
ValueError: If format is not supported
|
||||
RuntimeError: If export fails
|
||||
|
||||
Examples:
|
||||
>>> validator = OWLValidator()
|
||||
>>> is_valid, errors, world = validator.validate_ontology_classes(classes)
|
||||
>>> owl_content = validator.export_to_owl(world, "ontology.owl", format="rdfxml")
|
||||
"""
|
||||
# Validate format
|
||||
valid_formats = ["rdfxml", "turtle", "ntriples", "json"]
|
||||
if format not in valid_formats:
|
||||
raise ValueError(
|
||||
f"Unsupported format '{format}'. Must be one of: {', '.join(valid_formats)}"
|
||||
)
|
||||
|
||||
# JSON format doesn't need OWL processing
|
||||
if format == "json":
|
||||
if not classes:
|
||||
raise ValueError("Classes list is required for JSON format export")
|
||||
return self._export_to_json(classes)
|
||||
|
||||
# For OWL formats, world is required
|
||||
if not world:
|
||||
raise ValueError("World object is None. Cannot export ontology.")
|
||||
|
||||
# Note: Owlready2 has issues with turtle format export
|
||||
# We'll handle it specially by converting from rdfxml
|
||||
use_conversion = (format == "turtle")
|
||||
|
||||
try:
|
||||
# Get all ontologies in the world
|
||||
ontologies = list(world.ontologies.values())
|
||||
|
||||
if not ontologies:
|
||||
raise RuntimeError("No ontologies found in world")
|
||||
|
||||
# Find the ontology with classes (skip anonymous/empty ontologies)
|
||||
onto = None
|
||||
for ont in ontologies:
|
||||
classes_count = len(list(ont.classes()))
|
||||
logger.debug(f"Checking ontology {ont.base_iri}: {classes_count} classes")
|
||||
if classes_count > 0:
|
||||
onto = ont
|
||||
break
|
||||
|
||||
# If no ontology with classes found, use the last non-anonymous one
|
||||
if onto is None:
|
||||
for ont in reversed(ontologies):
|
||||
if ont.base_iri != "http://anonymous/":
|
||||
onto = ont
|
||||
break
|
||||
|
||||
# If still no ontology, use the first one
|
||||
if onto is None:
|
||||
onto = ontologies[0]
|
||||
|
||||
# Log ontology contents for debugging
|
||||
logger.info(f"Ontology IRI: {onto.base_iri}")
|
||||
logger.info(f"Ontology contains {len(list(onto.classes()))} classes")
|
||||
|
||||
# List all classes in the ontology
|
||||
all_classes = list(onto.classes())
|
||||
for cls in all_classes:
|
||||
logger.info(f"Class in ontology: {cls.name} (IRI: {cls.iri})")
|
||||
if hasattr(cls, 'label'):
|
||||
logger.debug(f" Labels: {cls.label}")
|
||||
if hasattr(cls, 'comment'):
|
||||
logger.debug(f" Comments: {cls.comment}")
|
||||
|
||||
if len(all_classes) == 0:
|
||||
logger.warning("No classes found in ontology! This may indicate a problem with class creation.")
|
||||
|
||||
if output_path:
|
||||
# Save to file
|
||||
export_format = "rdfxml" if use_conversion else format
|
||||
logger.info(f"Exporting ontology to {output_path} in {export_format} format")
|
||||
onto.save(file=output_path, format=export_format)
|
||||
|
||||
# Read back the file content to return
|
||||
with open(output_path, 'r', encoding='utf-8') as f:
|
||||
content = f.read()
|
||||
|
||||
# Convert to turtle if needed
|
||||
if use_conversion:
|
||||
content = self._convert_to_turtle(content)
|
||||
|
||||
logger.info(f"Successfully exported ontology to {output_path}")
|
||||
|
||||
# Format the content for better readability
|
||||
content = self._format_owl_content(content, format)
|
||||
|
||||
return content
|
||||
else:
|
||||
# Export to string (save to temporary location and read)
|
||||
import tempfile
|
||||
import os
|
||||
|
||||
with tempfile.NamedTemporaryFile(mode='w', suffix='.owl', delete=False) as tmp:
|
||||
tmp_path = tmp.name
|
||||
|
||||
try:
|
||||
export_format = "rdfxml" if use_conversion else format
|
||||
onto.save(file=tmp_path, format=export_format)
|
||||
|
||||
with open(tmp_path, 'r', encoding='utf-8') as f:
|
||||
content = f.read()
|
||||
|
||||
# Convert to turtle if needed
|
||||
if use_conversion:
|
||||
content = self._convert_to_turtle(content)
|
||||
|
||||
# Format the content for better readability
|
||||
content = self._format_owl_content(content, format)
|
||||
|
||||
return content
|
||||
|
||||
finally:
|
||||
# Clean up temporary file
|
||||
if os.path.exists(tmp_path):
|
||||
os.remove(tmp_path)
|
||||
|
||||
except Exception as e:
|
||||
error_msg = f"Failed to export ontology: {str(e)}"
|
||||
logger.error(error_msg, exc_info=True)
|
||||
raise RuntimeError(error_msg) from e
|
||||
|
||||
def _export_to_json(self, classes: List) -> str:
|
||||
"""Export ontology classes to simplified JSON format.
|
||||
|
||||
This format is more compact and easier to parse than OWL XML.
|
||||
|
||||
Args:
|
||||
classes: List of OntologyClass objects
|
||||
|
||||
Returns:
|
||||
JSON string representation (compact format)
|
||||
"""
|
||||
import json
|
||||
|
||||
result = {
|
||||
"ontology": {
|
||||
"namespace": self.base_namespace,
|
||||
"classes": []
|
||||
}
|
||||
}
|
||||
|
||||
for cls in classes:
|
||||
class_data = {
|
||||
"name": cls.name,
|
||||
"name_chinese": cls.name_chinese,
|
||||
"description": cls.description,
|
||||
"entity_type": cls.entity_type,
|
||||
"domain": cls.domain,
|
||||
"parent_class": cls.parent_class,
|
||||
"examples": cls.examples if hasattr(cls, 'examples') else []
|
||||
}
|
||||
result["ontology"]["classes"].append(class_data)
|
||||
|
||||
# 使用紧凑格式:无缩进,使用分隔符减少空格
|
||||
return json.dumps(result, ensure_ascii=False, separators=(',', ':'))
|
||||
|
||||
def _convert_to_turtle(self, rdfxml_content: str) -> str:
|
||||
"""Convert RDF/XML content to Turtle format using rdflib.
|
||||
|
||||
Args:
|
||||
rdfxml_content: RDF/XML format content
|
||||
|
||||
Returns:
|
||||
Turtle format content
|
||||
"""
|
||||
try:
|
||||
from rdflib import Graph
|
||||
|
||||
# Parse RDF/XML
|
||||
g = Graph()
|
||||
g.parse(data=rdfxml_content, format="xml")
|
||||
|
||||
# Serialize to Turtle
|
||||
turtle_content = g.serialize(format="turtle")
|
||||
|
||||
# Handle bytes vs string
|
||||
if isinstance(turtle_content, bytes):
|
||||
turtle_content = turtle_content.decode('utf-8')
|
||||
|
||||
return turtle_content
|
||||
|
||||
except ImportError:
|
||||
logger.warning(
|
||||
"rdflib is not installed. Cannot convert to Turtle format. "
|
||||
"Install with: pip install rdflib"
|
||||
)
|
||||
return rdfxml_content
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to convert to Turtle format: {e}")
|
||||
return rdfxml_content
|
||||
|
||||
def _format_owl_content(self, content: str, format: str) -> str:
|
||||
"""Format OWL content for better readability.
|
||||
|
||||
Args:
|
||||
content: Raw OWL content string
|
||||
format: Format type (rdfxml, turtle, ntriples)
|
||||
|
||||
Returns:
|
||||
Formatted OWL content string
|
||||
"""
|
||||
if format == "rdfxml":
|
||||
# Format XML with proper indentation
|
||||
try:
|
||||
import xml.dom.minidom as minidom
|
||||
dom = minidom.parseString(content)
|
||||
# Pretty print with 2-space indentation
|
||||
formatted = dom.toprettyxml(indent=" ", encoding="utf-8").decode("utf-8")
|
||||
|
||||
# Remove extra blank lines
|
||||
lines = []
|
||||
prev_blank = False
|
||||
for line in formatted.split('\n'):
|
||||
is_blank = not line.strip()
|
||||
if not (is_blank and prev_blank): # Skip consecutive blank lines
|
||||
lines.append(line)
|
||||
prev_blank = is_blank
|
||||
|
||||
formatted = '\n'.join(lines)
|
||||
|
||||
return formatted
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to format XML content: {e}")
|
||||
return content
|
||||
|
||||
elif format == "turtle":
|
||||
# Turtle format is already relatively readable
|
||||
# Just ensure consistent line endings and not empty
|
||||
if not content or content.strip() == "":
|
||||
logger.warning("Turtle content is empty, this may indicate an export issue")
|
||||
return content.strip() + '\n' if content.strip() else content
|
||||
|
||||
elif format == "ntriples":
|
||||
# N-Triples format is line-based, ensure proper line endings
|
||||
return content.strip() + '\n' if content.strip() else content
|
||||
|
||||
return content
|
||||
|
||||
def validate_with_protege_compatibility(
|
||||
self,
|
||||
classes: List[OntologyClass]
|
||||
) -> Tuple[bool, List[str]]:
|
||||
"""Validate that ontology classes are compatible with Protégé editor.
|
||||
|
||||
Protégé compatibility checks:
|
||||
- Class names are valid OWL identifiers
|
||||
- No special characters that Protégé cannot handle
|
||||
- Namespace is properly formatted
|
||||
- Labels and comments are properly encoded
|
||||
|
||||
Args:
|
||||
classes: List of OntologyClass objects to validate
|
||||
|
||||
Returns:
|
||||
Tuple of (is_compatible, warnings):
|
||||
- is_compatible: True if compatible with Protégé, False otherwise
|
||||
- warnings: List of compatibility warning messages
|
||||
|
||||
Examples:
|
||||
>>> validator = OWLValidator()
|
||||
>>> classes = [OntologyClass(name="Patient", description="A patient", entity_type="Person", domain="Healthcare")]
|
||||
>>> is_compatible, warnings = validator.validate_with_protege_compatibility(classes)
|
||||
>>> is_compatible
|
||||
True
|
||||
"""
|
||||
warnings = []
|
||||
|
||||
# Check namespace format
|
||||
if not self.base_namespace.startswith(('http://', 'https://')):
|
||||
warnings.append(
|
||||
f"Namespace '{self.base_namespace}' should start with http:// or https:// "
|
||||
"for Protégé compatibility"
|
||||
)
|
||||
|
||||
if not self.base_namespace.endswith(('#', '/')):
|
||||
warnings.append(
|
||||
f"Namespace '{self.base_namespace}' should end with # or / "
|
||||
"for Protégé compatibility"
|
||||
)
|
||||
|
||||
# Check each class
|
||||
for ontology_class in classes:
|
||||
# Check for special characters that might cause issues
|
||||
if any(char in ontology_class.name for char in ['<', '>', '"', '{', '}', '|', '^', '`']):
|
||||
warnings.append(
|
||||
f"Class name '{ontology_class.name}' contains special characters "
|
||||
"that may cause issues in Protégé"
|
||||
)
|
||||
|
||||
# Check description length (Protégé can handle long descriptions but may display poorly)
|
||||
if ontology_class.description and len(ontology_class.description) > 1000:
|
||||
warnings.append(
|
||||
f"Class '{ontology_class.name}' has a very long description ({len(ontology_class.description)} chars) "
|
||||
"which may display poorly in Protégé"
|
||||
)
|
||||
|
||||
# Check for non-ASCII characters (Protégé supports them but encoding issues may occur)
|
||||
if not ontology_class.name.isascii():
|
||||
warnings.append(
|
||||
f"Class name '{ontology_class.name}' contains non-ASCII characters "
|
||||
"which may cause encoding issues in some Protégé versions"
|
||||
)
|
||||
|
||||
# If no warnings, it's compatible
|
||||
is_compatible = len(warnings) == 0
|
||||
|
||||
return is_compatible, warnings
|
||||
Reference in New Issue
Block a user