Release/v0.2.3 (#355)

* feat(web): add PageEmpty component

* feat(web): add PageTabs component

* feat(web): add PageEmpty component

* feat(web): add PageTabs component

* feat(prompt): add history tracking for prompt releases

* feat(web): add prompt menu

* refactor: The PageScrollList component supports two generic parameters

* feat(web): BodyWrapper compoent update PageLoading

* feat(web): add Ontology menu

* feat(web): memory management add scene

* feat(tasks): add celery task configuration for periodic jobs

- Add ignore_result=True to prevent storing results for periodic tasks
- Set max_retries=0 to skip failed periodic tasks without retry attempts
- Configure acks_late=False for immediate acknowledgment in beat tasks
- Add time_limit and soft_time_limit to regenerate_memory_cache task (3600s/3300s)
- Add time_limit and soft_time_limit to workspace_reflection_task (300s/240s)
- Add time_limit and soft_time_limit to run_forgetting_cycle_task (7200s/7000s)
- Improve task reliability and resource management for scheduled jobs

* feat(sandbox): add Node.js code execution support to sandbox

* Release/v0.2.2 (#260)

* [modify] migration script

* [add] migration script

* fix(web): change form message

* fix(web): the memoryContent field is compatible with numbers and strings

* feat(web): code node hidden

* fix(model):
1. create a basic model to check if the name and provider are duplicated.
2. The result shows error models because the provider created API Keys for all matching models.

---------

Co-authored-by: Mark <zhuwenhui5566@163.com>
Co-authored-by: zhaoying <yzhao96@best-inc.com>
Co-authored-by: yingzhao <zhaoyingyz@126.com>
Co-authored-by: Timebomb2018 <18868801967@163.com>

* Feature/ontology class clean (#249)

* [add] Complete ontology engineering feature implementation

* [add] Add ontology feature integration and validation utilities

* [add] Add OWL validator and validation utilities

* [fix] Add missing render_ontology_extraction_prompt function

* [fix]Add dependencies, fix functionality

* [add] migration script

* feat(celery): add dedicated periodic tasks worker and queue (#261)

* fix(web): conflict resolve

* Fix/v022 bug (#263)

* [fix]Fix the issue of inconsistent language in explicit and episodic memory.

* [fix]Fix the issue of inconsistent language in explicit and episodic memory.

* [add]Add scene_id

* [fix]Based on the AI review to fix the code

* Fix/develop memory reflex (#265)

* 遗漏的历史映射

* 遗漏的历史映射

* 反思后台报错处理

* [add] migration script

* fix: chat conversation_id add node_start

* feat(web): show code node

* fix(web): Restructure the CustomSelect component, repair the interface that is called multiple times when the form is updated

* feat(web): RadioGroupCard support block mode

* feat(web): create space add icon

* feat(app and model): token consumption statistics

* Add/develop memory (#264)

* 遗漏的历史映射

* 遗漏的历史映射

* 遗漏的历史映射

* 遗漏的历史映射

* 遗漏的历史映射

* 遗漏的历史映射

* 遗漏的历史映射

* 遗漏的历史映射

* 遗漏的历史映射

* 新增长期记忆功能

* 新增长期记忆功能

* 新增长期记忆功能

* 知识库检索多余字段

* 长期

* feat(app and model): token consumption statistics of the cluster

* memory_BUG_fix

* fix(web): prompt history remove pageLoading

* fix(prompt): remove hard-coded import of prompt file paths (#279)

* Fix/develop memory bug (#274)

* 遗漏的历史映射

* 遗漏的历史映射

* fix_timeline_memories

* fix(web): update retrieve_type key

* Fix/develop memory bug (#276)

* 遗漏的历史映射

* 遗漏的历史映射

* fix_timeline_memories

* fix_timeline_memories

* write_gragp/bug_fix

* write_gragp/bug_fix

* write_gragp/bug_fix

* chore(celery): disable periodic task scheduling

* fix(prompt): remove hard-coded import of prompt file paths

---------

Co-authored-by: lixinyue11 <94037597+lixinyue11@users.noreply.github.com>
Co-authored-by: zhaoying <yzhao96@best-inc.com>
Co-authored-by: yingzhao <zhaoyingyz@126.com>
Co-authored-by: Ke Sun <kesun5@illinois.edu>

* fix(web): remove delete confirm content

* refactor(workflow): relocate template directory into workflow

* feat(memory): add long-term storage task routing and batching

* fix(web): PageScrollList loading update

* fix(web): PageScrollList loading update

* Ontology v1 bug (#291)

* [changes]Add 'id' as the secondary sorting key, and 'scene_id' now returns a UUID object

* [fix]Fix the "end_user" return to be sorted by update time.

* [fix]Set the default values of the memory configuration model based on the spatial model.

* [fix]Remove the entity extraction check combination model, read the configuration list, and add the return of scene_id

* [fix]Fix the "end_user" return to be sorted by update time.

* [fix]

* fix(memory): add Redis session validation

- Add macOS fork() safety configuration in celery_app.py to prevent initialization issues
- Add null/False checks for Redis session queries in term_memory_save to handle missing sessions gracefully
- Add null/False checks in memory_long_term_storage to prevent processing empty Redis results
- Add null/False checks in aggregate_judgment before format_parsing to avoid errors on missing data
- Initialize redis_messages variable in window_dialogue for consistency
- Add debug logging when no existing session found in Redis for better troubleshooting
- Add TODO comments for magic numbers (scope=6, time=5) to be extracted as constants
- Improve error handling when Redis returns False or empty results instead of crashing

* fix(web): PageScrollList style update

* fix(workflow): fix argument passing in code execution nodes

* fix(web): prompt add disabled

* fix(web): space icon required

* feat(app): modify the key of the token

* fix(fix the key of the app's token):

* fix(workflow): switch code input encoding to base64+URL encoding

* [add]The main project adds multi-API Key load balancing.

* [changes]Attribute security access, secure numerical conversion, unified use of local variables

* fix(web): save add session update

* fix(web): language editor support paste

* [changes]Active status filtering logic, API Key selection strategy

* memory_BUG

* memory_BUG_long_term

* [changes]

* memory_BUG_long_term

* memory_BUG_long_term

* Fix/release memory bug (#306)

* memory_BUG_fix

* memory_BUG

* memory_BUG_long_term

* memory_BUG_long_term

* memory_BUG_long_term

* knowledge_retrieval/bug/fix

* knowledge_retrieval/bug/fix

* knowledge_retrieval/bug/fix

* [fix]1.The "read_all_config" interface returns "scene_name";2.Memory configuration for lightweight query ontology scenarios

* fix(web): replace code editor

* [changes]Modify the description of the time for the recent event

* [changes]Modify the code based on the AI review

* feat(web): update memory config ontology api

* fix(web): ui update

* knowledge_retrieval/bug/fix

* knowledge_retrieval/bug/fix

* knowledge_retrieval/bug/fix

* feat(workflow): add token usage statistics for question classifier and parameter extraction

* feat(web): move prompt menu

* Multiple independent transactions - single transaction

* Multiple independent transactions - single transaction

* Multiple independent transactions - single transaction

* Multiple independent transactions - single transaction

* Write Missing None (#321)

* Write Missing None

* Write Missing None

* Write Missing None

* Apply suggestion from @sourcery-ai[bot]

Co-authored-by: sourcery-ai[bot] <58596630+sourcery-ai[bot]@users.noreply.github.com>

* Write Missing None

---------

Co-authored-by: sourcery-ai[bot] <58596630+sourcery-ai[bot]@users.noreply.github.com>

* Fix/release memory bug (#324)

* Write Missing None

* Write Missing None

* Write Missing None

* Apply suggestion from @sourcery-ai[bot]

Co-authored-by: sourcery-ai[bot] <58596630+sourcery-ai[bot]@users.noreply.github.com>

* Write Missing None

* redis update

* redis update

* redis update

* redis update

---------

Co-authored-by: sourcery-ai[bot] <58596630+sourcery-ai[bot]@users.noreply.github.com>

* Fix/writer memory bug (#326)

* [fix]Fix the bug

* [fix]Fix the bug

* [fix]Correct the direction indication.

* fix(web): markdown table ui update

* Fix/release memory bug (#332)

* Write Missing None

* Write Missing None

* Write Missing None

* Apply suggestion from @sourcery-ai[bot]

Co-authored-by: sourcery-ai[bot] <58596630+sourcery-ai[bot]@users.noreply.github.com>

* Write Missing None

* redis update

* redis update

* redis update

* redis update

* writer_dup_bug/fix

---------

Co-authored-by: sourcery-ai[bot] <58596630+sourcery-ai[bot]@users.noreply.github.com>

* Fix/fact summary (#333)

* [fix]Disable the contents related to fact_summary

* [fix]Disable the contents related to fact_summary

* [fix]Modify the code based on the AI review

* Fix/release memory bug (#335)

* Write Missing None

* Write Missing None

* Write Missing None

* Apply suggestion from @sourcery-ai[bot]

Co-authored-by: sourcery-ai[bot] <58596630+sourcery-ai[bot]@users.noreply.github.com>

* Write Missing None

* redis update

* redis update

* redis update

* redis update

* writer_dup_bug/fix

* writer_graph_bug/fix

* writer_graph_bug/fix

---------

Co-authored-by: sourcery-ai[bot] <58596630+sourcery-ai[bot]@users.noreply.github.com>

* Revert "feat(web): move prompt menu"

This reverts commit 9e6e8f50f8.

* fix(web): ui update

* fix(web): update text

* fix(web): ui update

* fix(model): change the "vl" model type of dashscope to "chat"

* fix(model): change the "vl" model type of dashscope to "chat"

---------

Co-authored-by: zhaoying <yzhao96@best-inc.com>
Co-authored-by: Eternity <1533512157@qq.com>
Co-authored-by: Mark <zhuwenhui5566@163.com>
Co-authored-by: yingzhao <zhaoyingyz@126.com>
Co-authored-by: Timebomb2018 <18868801967@163.com>
Co-authored-by: 乐力齐 <162269739+lanceyq@users.noreply.github.com>
Co-authored-by: lixinyue11 <94037597+lixinyue11@users.noreply.github.com>
Co-authored-by: lixinyue <2569494688@qq.com>
Co-authored-by: Eternity <61316157+myhMARS@users.noreply.github.com>
Co-authored-by: lanceyq <1982376970@qq.com>
Co-authored-by: sourcery-ai[bot] <58596630+sourcery-ai[bot]@users.noreply.github.com>
This commit is contained in:
Ke Sun
2026-02-06 19:01:57 +08:00
committed by GitHub
parent eab7225d83
commit 79ab929fb0
187 changed files with 12252 additions and 1656 deletions

View File

@@ -0,0 +1,238 @@
import json
import os
from app.core.logging_config import get_agent_logger
from app.core.memory.agent.langgraph_graph.tools.write_tool import format_parsing, messages_parse
from app.core.memory.agent.langgraph_graph.write_graph import make_write_graph, long_term_storage
from app.core.memory.agent.models.write_aggregate_model import WriteAggregateModel
from app.core.memory.agent.utils.llm_tools import PROJECT_ROOT_
from app.core.memory.agent.utils.redis_tool import write_store
from app.core.memory.agent.utils.redis_tool import count_store
from app.core.memory.agent.utils.template_tools import TemplateService
from app.core.memory.utils.llm.llm_utils import MemoryClientFactory
from app.db import get_db_context, get_db
from app.repositories.memory_short_repository import LongTermMemoryRepository
from app.schemas.memory_agent_schema import AgentMemory_Long_Term
from app.services.memory_konwledges_server import write_rag
from app.services.task_service import get_task_memory_write_result
from app.tasks import write_message_task
from app.utils.config_utils import resolve_config_id
logger = get_agent_logger(__name__)
template_root = os.path.join(PROJECT_ROOT_, 'memory', 'agent', 'utils', 'prompt')
async def write_rag_agent(end_user_id, user_message, ai_message, user_rag_memory_id):
# RAG 模式:组合消息为字符串格式(保持原有逻辑)
combined_message = f"user: {user_message}\nassistant: {ai_message}"
await write_rag(end_user_id, combined_message, user_rag_memory_id)
logger.info(f'RAG_Agent:{end_user_id};{user_rag_memory_id}')
async def write(storage_type, end_user_id, user_message, ai_message, user_rag_memory_id, actual_end_user_id,
actual_config_id, long_term_messages=[]):
"""
写入记忆(支持结构化消息)
Args:
storage_type: 存储类型 (neo4j/rag)
end_user_id: 终端用户ID
user_message: 用户消息内容
ai_message: AI 回复内容
user_rag_memory_id: RAG 记忆ID
actual_end_user_id: 实际用户ID
actual_config_id: 配置ID
逻辑说明:
- RAG 模式:组合 user_message 和 ai_message 为字符串格式,保持原有逻辑不变
- Neo4j 模式:使用结构化消息列表
1. 如果 user_message 和 ai_message 都不为空:创建配对消息 [user, assistant]
2. 如果只有 user_message创建单条用户消息 [user](用于历史记忆场景)
3. 每条消息会被转换为独立的 Chunk保留 speaker 字段
"""
db = next(get_db())
try:
actual_config_id = resolve_config_id(actual_config_id, db)
# Neo4j 模式:使用结构化消息列表
structured_messages = []
# 始终添加用户消息(如果不为空)
if isinstance(user_message, str) and user_message.strip() != "":
structured_messages.append({"role": "user", "content": user_message})
# 只有当 AI 回复不为空时才添加 assistant 消息
if isinstance(ai_message, str) and ai_message.strip() != "":
structured_messages.append({"role": "assistant", "content": ai_message})
# 如果提供了 long_term_messages使用它替代 structured_messages
if long_term_messages and isinstance(long_term_messages, list):
structured_messages = long_term_messages
elif long_term_messages and isinstance(long_term_messages, str):
# 如果是 JSON 字符串,先解析
try:
structured_messages = json.loads(long_term_messages)
except json.JSONDecodeError:
logger.error(f"Failed to parse long_term_messages as JSON: {long_term_messages}")
# 如果没有消息,直接返回
if not structured_messages:
logger.warning(f"No messages to write for user {actual_end_user_id}")
return
logger.info(
f"[WRITE] Submitting Celery task - user={actual_end_user_id}, messages={len(structured_messages)}, config={actual_config_id}")
write_id = write_message_task.delay(
actual_end_user_id, # end_user_id: 用户ID
structured_messages, # message: JSON 字符串格式的消息列表
str(actual_config_id), # config_id: 配置ID字符串
storage_type, # storage_type: "neo4j"
user_rag_memory_id or "" # user_rag_memory_id: RAG记忆IDNeo4j模式下不使用
)
logger.info(f"[WRITE] Celery task submitted - task_id={write_id}")
write_status = get_task_memory_write_result(str(write_id))
logger.info(f'[WRITE] Task result - user={actual_end_user_id}, status={write_status}')
finally:
db.close()
async def term_memory_save(long_term_messages,actual_config_id,end_user_id,type,scope):
with get_db_context() as db_session:
repo = LongTermMemoryRepository(db_session)
from app.core.memory.agent.utils.redis_tool import write_store
result = write_store.get_session_by_userid(end_user_id)
if type==AgentMemory_Long_Term.STRATEGY_CHUNK or AgentMemory_Long_Term.STRATEGY_AGGREGATE:
data = await format_parsing(result, "dict")
chunk_data = data[:scope]
if len(chunk_data)==scope:
repo.upsert(end_user_id, chunk_data)
logger.info(f'---------写入短长期-----------')
else:
long_time_data = write_store.find_user_recent_sessions(end_user_id, 5)
long_messages = await messages_parse(long_time_data)
repo.upsert(end_user_id, long_messages)
logger.info(f'写入短长期:')
'''根据窗口'''
async def window_dialogue(end_user_id,langchain_messages,memory_config,scope):
'''
根据窗口获取redis数据,写入neo4j
Args:
end_user_id: 终端用户ID
memory_config: 内存配置对象
langchain_messages原始数据LIST
scope窗口大小
'''
scope=scope
is_end_user_id = count_store.get_sessions_count(end_user_id)
if is_end_user_id is not False:
is_end_user_id = count_store.get_sessions_count(end_user_id)[0]
redis_messages = count_store.get_sessions_count(end_user_id)[1]
if is_end_user_id and int(is_end_user_id) != int(scope):
is_end_user_id += 1
langchain_messages += redis_messages
count_store.update_sessions_count(end_user_id, is_end_user_id, langchain_messages)
elif int(is_end_user_id) == int(scope):
logger.info('写入长期记忆NEO4J')
formatted_messages = (redis_messages)
# 获取 config_id如果 memory_config 是对象,提取 config_id否则直接使用
if hasattr(memory_config, 'config_id'):
config_id = memory_config.config_id
else:
config_id = memory_config
await write(AgentMemory_Long_Term.STORAGE_NEO4J, end_user_id, "", "", None, end_user_id,
config_id, formatted_messages)
count_store.update_sessions_count(end_user_id, 1, langchain_messages)
else:
count_store.save_sessions_count(end_user_id, 1, langchain_messages)
"""根据时间"""
async def memory_long_term_storage(end_user_id,memory_config,time):
'''
根据时间获取redis数据,写入neo4j
Args:
end_user_id: 终端用户ID
memory_config: 内存配置对象
'''
long_time_data = write_store.find_user_recent_sessions(end_user_id, time)
format_messages = (long_time_data)
messages=[]
memory_config=memory_config.config_id
for i in format_messages:
message=json.loads(i['Query'])
messages+= message
if format_messages!=[]:
await write(AgentMemory_Long_Term.STORAGE_NEO4J, end_user_id, "", "", None, end_user_id,
memory_config, messages)
'''聚合判断'''
async def aggregate_judgment(end_user_id: str, ori_messages: list, memory_config) -> dict:
"""
聚合判断函数:判断输入句子和历史消息是否描述同一事件
Args:
end_user_id: 终端用户ID
ori_messages: 原始消息列表,格式如 [{"role": "user", "content": "..."}, {"role": "assistant", "content": "..."}]
memory_config: 内存配置对象
"""
try:
# 1. 获取历史会话数据(使用新方法)
result = write_store.get_all_sessions_by_end_user_id(end_user_id)
history = await format_parsing(result)
if not result:
history = []
else:
history = await format_parsing(result)
json_schema = WriteAggregateModel.model_json_schema()
template_service = TemplateService(template_root)
system_prompt = await template_service.render_template(
template_name='write_aggregate_judgment.jinja2',
operation_name='aggregate_judgment',
history=history,
sentence=ori_messages,
json_schema=json_schema
)
with get_db_context() as db_session:
factory = MemoryClientFactory(db_session)
llm_client = factory.get_llm_client(memory_config.llm_model_id)
messages = [
{
"role": "user",
"content": system_prompt
}
]
structured = await llm_client.response_structured(
messages=messages,
response_model=WriteAggregateModel
)
output_value = structured.output
if isinstance(output_value, list):
output_value = [
{"role": msg.role, "content": msg.content}
for msg in output_value
]
result_dict = {
"is_same_event": structured.is_same_event,
"output": output_value
}
if not structured.is_same_event:
logger.info(result_dict)
await write("neo4j", end_user_id, "", "", None, end_user_id,
memory_config.config_id, output_value)
return result_dict
except Exception as e:
print(f"[aggregate_judgment] 发生错误: {e}")
import traceback
traceback.print_exc()
return {
"is_same_event": False,
"output": ori_messages,
"messages": ori_messages,
"history": history if 'history' in locals() else [],
"error": str(e)
}

View File

@@ -186,10 +186,11 @@ def create_hybrid_retrieval_tool_async(memory_config, **search_params):
清理后的数据
"""
# 需要过滤的字段列表
# TODO: fact_summary 功能暂时禁用,待后续开发完善后启用
fields_to_remove = {
'invalid_at', 'valid_at', 'chunk_id_from_rel', 'entity_ids',
'expired_at', 'created_at', 'chunk_id', 'id', 'apply_id',
'user_id', 'statement_ids', 'updated_at',"chunk_ids","fact_summary"
'user_id', 'statement_ids', 'updated_at',"chunk_ids" ,"fact_summary"
}
if isinstance(data, dict):

View File

@@ -0,0 +1,72 @@
import json
from langchain_core.messages import HumanMessage, AIMessage
async def format_parsing(messages: list,type:str='string'):
"""
格式化解析消息列表
Args:
messages: 消息列表
type: 返回类型 ('string''dict')
Returns:
格式化后的消息列表
"""
result = []
user=[]
ai=[]
for message in messages:
hstory_messages = message['messages']
for history_messag in hstory_messages.strip().splitlines():
history_messag = json.loads(history_messag)
for content in history_messag:
role = content['role']
content = content['content']
if type == "string":
if role == 'human' or role=="user":
content = '用户:' + content
else:
content = 'AI:' + content
result.append(content)
if type == "dict" :
if role == 'human' or role=="user":
user.append( content)
else:
ai.append(content)
if type == "dict":
for key,values in zip(user,ai):
result.append({key:values})
return result
async def messages_parse(messages: list | dict):
user=[]
ai=[]
database=[]
for message in messages:
Query = message['Query']
Query = json.loads(Query)
for data in Query:
role = data['role']
if role == "human":
user.append(data['content'])
if role == "ai":
ai.append(data['content'])
for key, values in zip(user, ai):
database.append({key, values})
return database
async def agent_chat_messages(user_content,ai_content):
messages = [
{
"role": "user",
"content": f"{user_content}"
},
{
"role": "assistant",
"content": f"{ai_content}"
}
]
return messages

View File

@@ -1,22 +1,20 @@
import asyncio
import json
import sys
import warnings
from contextlib import asynccontextmanager
from langchain_core.messages import HumanMessage
from langgraph.constants import END, START
from langgraph.graph import StateGraph
from app.db import get_db
from app.db import get_db, get_db_context
from app.core.logging_config import get_agent_logger
from app.core.memory.agent.utils.llm_tools import WriteState
from app.core.memory.agent.langgraph_graph.nodes.write_nodes import write_node
from app.core.memory.agent.langgraph_graph.nodes.data_nodes import content_input_write
from app.schemas.memory_agent_schema import AgentMemory_Long_Term
from app.services.memory_config_service import MemoryConfigService
warnings.filterwarnings("ignore", category=RuntimeWarning)
logger = get_agent_logger(__name__)
@@ -34,14 +32,6 @@ async def make_write_graph():
end_user_id: Group identifier
memory_config: MemoryConfig object containing all configuration
"""
# workflow = StateGraph(WriteState)
# workflow.add_node("content_input", content_input_write)
# workflow.add_node("save_neo4j", write_node)
# workflow.add_edge(START, "content_input")
# workflow.add_edge("content_input", "save_neo4j")
# workflow.add_edge("save_neo4j", END)
#
# graph = workflow.compile()
workflow = StateGraph(WriteState)
workflow.add_node("save_neo4j", write_node)
workflow.add_edge(START, "save_neo4j")
@@ -51,43 +41,63 @@ async def make_write_graph():
yield graph
async def main():
"""主函数 - 运行工作流"""
message = "今天周一"
end_user_id = 'new_2025test1103' # 组ID
async def long_term_storage(long_term_type:str="chunk",langchain_messages:list=[],memory_config:str='',end_user_id:str='',scope:int=6):
from app.core.memory.agent.langgraph_graph.routing.write_router import memory_long_term_storage, window_dialogue,aggregate_judgment
from app.core.memory.agent.utils.redis_tool import write_store
write_store.save_session_write(end_user_id, (langchain_messages))
# 获取数据库会话
db_session = next(get_db())
config_service = MemoryConfigService(db_session)
memory_config = config_service.load_memory_config(
config_id=17, # 改为整数
service_name="MemoryAgentService"
)
try:
async with make_write_graph() as graph:
config = {"configurable": {"thread_id": end_user_id}}
# 初始状态 - 包含所有必要字段
initial_state = {"messages": [HumanMessage(content=message)], "end_user_id": end_user_id, "memory_config": memory_config}
# 获取节点更新信息
async for update_event in graph.astream(
initial_state,
stream_mode="updates",
config=config
):
for node_name, node_data in update_event.items():
if 'save_neo4j'==node_name:
massages=node_data
massages=massages.get('write_result')['status']
print(massages) # | 更新数据: {node_data}
except Exception as e:
import traceback
traceback.print_exc()
with get_db_context() as db_session:
config_service = MemoryConfigService(db_session)
memory_config = config_service.load_memory_config(
config_id=memory_config, # 改为整数
service_name="MemoryAgentService"
)
if long_term_type=='chunk':
'''方案一:对话窗口6轮对话'''
await window_dialogue(end_user_id,langchain_messages,memory_config,scope)
if long_term_type=='time':
"""时间"""
await memory_long_term_storage(end_user_id, memory_config,5)
if long_term_type=='aggregate':
"""方案三:聚合判断"""
await aggregate_judgment(end_user_id, langchain_messages, memory_config)
if __name__ == "__main__":
import asyncio
asyncio.run(main())
async def write_long_term(storage_type,end_user_id,message_chat,aimessages,user_rag_memory_id,actual_config_id):
from app.core.memory.agent.langgraph_graph.routing.write_router import write_rag_agent
from app.core.memory.agent.langgraph_graph.routing.write_router import term_memory_save
from app.core.memory.agent.langgraph_graph.tools.write_tool import agent_chat_messages
if storage_type == AgentMemory_Long_Term.STORAGE_RAG:
await write_rag_agent(end_user_id, message_chat, aimessages, user_rag_memory_id)
else:
# AI 回复写入(用户消息和 AI 回复配对,一次性写入完整对话)
CHUNK = AgentMemory_Long_Term.STRATEGY_CHUNK
SCOPE = AgentMemory_Long_Term.DEFAULT_SCOPE
long_term_messages = await agent_chat_messages(message_chat, aimessages)
await long_term_storage(long_term_type=CHUNK, langchain_messages=long_term_messages,
memory_config=actual_config_id, end_user_id=end_user_id, scope=SCOPE)
await term_memory_save(long_term_messages, actual_config_id, end_user_id, CHUNK, scope=SCOPE)
# async def main():
# """主函数 - 运行工作流"""
# langchain_messages = [
# {
# "role": "user",
# "content": "今天周五去爬山"
# },
# {
# "role": "assistant",
# "content": "好耶"
# }
#
# ]
# end_user_id = '837fee1b-04a2-48ee-94d7-211488908940' # 组ID
# memory_config="08ed205c-0f05-49c3-8e0c-a580d28f5fd4"
# await long_term_storage(long_term_type="chunk",langchain_messages=langchain_messages,memory_config=memory_config,end_user_id=end_user_id,scope=2)
#
#
#
# if __name__ == "__main__":
# import asyncio
# asyncio.run(main())

View File

@@ -0,0 +1,28 @@
"""Pydantic models for write aggregate judgment operations."""
from typing import List, Union
from pydantic import BaseModel, Field
class MessageItem(BaseModel):
"""Individual message item in conversation."""
role: str = Field(..., description="角色user 或 assistant")
content: str = Field(..., description="消息内容")
class WriteAggregateResponse(BaseModel):
"""Response model for aggregate judgment containing judgment result and output."""
is_same_event: bool = Field(
...,
description="是否是同一事件。True表示是同一事件False表示不同事件"
)
output: Union[List[MessageItem], bool] = Field(
...,
description="如果is_same_event为True返回False如果is_same_event为False返回消息列表"
)
# 为了保持向后兼容,保留旧的类名作为别名
WriteAggregateModel = WriteAggregateResponse

View File

@@ -0,0 +1,57 @@
输入句子:{{sentence}}
历史消息:{{history}}
# 你的角色
你是一个擅长事件聚合与语义判断的专家。
# 你的任务
结合历史消息和输入句子,判断它们是否在描述**同一件事件或同一事件链**。
以下情况视为"同一事件"(需要返回 is_same_event=True, output=False
- 描述的是同一个具体事件或事实
- 存在明显的因果关系、前后发展关系
- 是对同一事件的补充、解释、追问或延展
- 逻辑上属于同一语境下的连续讨论
以下情况视为"不同事件"(需要返回 is_same_event=False, output=消息列表):
- 话题不同,事件主体不同
- 时间、地点、对象明显不同
- 只是语义相似,但并非同一具体事件
- 无直接事件、因果或逻辑关联
# 输出规则(非常重要)
你必须按照以下JSON格式输出
**如果是同一事件:**
```json
{
"is_same_event": true,
"output": false
}
```
**如果不是同一事件:**
```json
{
"is_same_event": false,
"output": [
{
"role": "user",
"content": "输入句子的内容"
},
{
"role": "assistant",
"content": "对应的回复内容"
}
]
}
```
# JSON Schema
{{json_schema}}
# 注意事项
- 必须严格按照上述格式输出
- output 字段:如果是同一事件返回 false如果不是同一事件返回完整的消息列表
- 消息列表必须包含 role 和 content 字段
- 不要输出任何解释、分析或多余内容

View File

@@ -0,0 +1,186 @@
import json
from typing import Any, List, Dict, Optional
from datetime import datetime, timedelta
def serialize_messages(messages: Any) -> str:
"""
将消息序列化为 JSON 字符串,支持 LangChain 消息对象
Args:
messages: 可以是 list、dict、string 或 LangChain 消息对象列表
Returns:
str: JSON 字符串
"""
if isinstance(messages, str):
return messages
if isinstance(messages, (list, tuple)):
# 检查是否是 LangChain 消息对象列表
serialized_list = []
for msg in messages:
if hasattr(msg, 'type') and hasattr(msg, 'content'):
# LangChain 消息对象
serialized_list.append({
'type': msg.type,
'content': msg.content,
'role': getattr(msg, 'role', msg.type)
})
elif isinstance(msg, dict):
serialized_list.append(msg)
else:
serialized_list.append(str(msg))
return json.dumps(serialized_list, ensure_ascii=False)
if isinstance(messages, dict):
return json.dumps(messages, ensure_ascii=False)
# 其他类型转为字符串
return str(messages)
def deserialize_messages(messages_str: str) -> Any:
"""
将 JSON 字符串反序列化为原始格式
Args:
messages_str: JSON 字符串
Returns:
反序列化后的对象list、dict 或 string
"""
if not messages_str:
return []
try:
return json.loads(messages_str)
except (json.JSONDecodeError, TypeError):
return messages_str
def fix_encoding(text: str) -> str:
"""
修复错误编码的文本
Args:
text: 需要修复的文本
Returns:
str: 修复后的文本
"""
if not text or not isinstance(text, str):
return text
try:
# 尝试修复 Latin-1 误编码为 UTF-8 的情况
return text.encode('latin-1').decode('utf-8')
except (UnicodeDecodeError, UnicodeEncodeError):
# 如果修复失败,返回原文本
return text
def format_session_data(data: Dict[str, Any], include_time: bool = False) -> Dict[str, Any]:
"""
格式化会话数据为统一的输出格式
Args:
data: 原始会话数据
include_time: 是否包含时间字段
Returns:
Dict: 格式化后的数据 {"Query": "...", "Answer": "...", "starttime": "..."}
"""
result = {
"Query": fix_encoding(data.get('messages', '')),
"Answer": fix_encoding(data.get('aimessages', ''))
}
if include_time:
result["starttime"] = data.get('starttime', '')
return result
def filter_by_time_range(items: List[Dict], minutes: int) -> List[Dict]:
"""
根据时间范围过滤数据
Args:
items: 包含 starttime 字段的数据列表
minutes: 时间范围(分钟)
Returns:
List[Dict]: 过滤后的数据列表
"""
time_threshold = datetime.now() - timedelta(minutes=minutes)
time_threshold_str = time_threshold.strftime("%Y-%m-%d %H:%M:%S")
filtered_items = []
for item in items:
starttime = item.get('starttime', '')
if starttime and starttime >= time_threshold_str:
filtered_items.append(item)
return filtered_items
def sort_and_limit_results(items: List[Dict], limit: int = 6,
remove_time: bool = True) -> List[Dict]:
"""
对结果进行排序、限制数量并移除时间字段
Args:
items: 数据列表
limit: 最大返回数量
remove_time: 是否移除 starttime 字段
Returns:
List[Dict]: 处理后的数据列表
"""
# 按时间降序排序(最新的在前)
items.sort(key=lambda x: x.get('starttime', ''), reverse=True)
# 限制数量
result_items = items[:limit]
# 移除 starttime 字段
if remove_time:
for item in result_items:
item.pop('starttime', None)
# 如果结果少于1条返回空列表
if len(result_items) < 1:
return []
return result_items
def generate_session_key(session_id: str, key_type: str = "session") -> str:
"""
生成 Redis key
Args:
session_id: 会话ID
key_type: key 类型 ("session", "read", "write", "count")
Returns:
str: Redis key
"""
if key_type == "count":
return f"session:count:{session_id}"
elif key_type == "write":
return f"session:write:{session_id}"
elif key_type == "session" or key_type == "read":
return f"session:{session_id}"
else:
return f"session:{session_id}"
def get_current_timestamp() -> str:
"""
获取当前时间戳字符串
Returns:
str: 格式化的时间字符串 "YYYY-MM-DD HH:MM:SS"
"""
return datetime.now().strftime("%Y-%m-%d %H:%M:%S")

View File

@@ -1,11 +1,36 @@
import redis
import uuid
from datetime import datetime
from app.core.config import settings
from typing import List, Dict, Any, Optional, Union
from app.core.memory.agent.utils.redis_base import (
serialize_messages,
deserialize_messages,
fix_encoding,
format_session_data,
filter_by_time_range,
sort_and_limit_results,
generate_session_key,
get_current_timestamp
)
class RedisSessionStore:
class RedisWriteStore:
"""Redis Write 类型存储类,用于管理 save_session_write 相关的数据"""
def __init__(self, host='localhost', port=6379, db=0, password=None, session_id=''):
"""
初始化 Redis 连接
Args:
host: Redis 主机地址
port: Redis 端口
db: Redis 数据库编号
password: Redis 密码
session_id: 会话ID
"""
self.r = redis.Redis(
host=host,
port=port,
@@ -16,32 +41,437 @@ class RedisSessionStore:
)
self.uudi = session_id
def _fix_encoding(self, text):
"""修复错误编码的文本"""
if not text or not isinstance(text, str):
return text
try:
# 尝试修复 Latin-1 误编码为 UTF-8 的情况
return text.encode('latin-1').decode('utf-8')
except (UnicodeDecodeError, UnicodeEncodeError):
# 如果修复失败,返回原文本
return text
# 修改后的 save_session 方法
def save_session(self, userid, messages, aimessages, apply_id, end_user_id):
def save_session_write(self, userid: str, messages: str) -> str:
"""
写入一条会话数据,返回 session_id
优化版本确保写入时间不超过1秒
Args:
userid: 用户ID
messages: 用户消息
Returns:
str: 新生成的 session_id
"""
try:
session_id = str(uuid.uuid4()) # 为每次会话生成新的 ID
starttime = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
key = f"session:{session_id}" # 使用新生成的 session_id 作为 key
messages = serialize_messages(messages)
session_id = str(uuid.uuid4())
key = generate_session_key(session_id, key_type="write")
# 使用 pipeline 批量写入,减少网络往返
pipe = self.r.pipeline()
pipe.hset(key, mapping={
"id": self.uudi,
"sessionid": userid,
"messages": messages,
"starttime": get_current_timestamp()
})
result = pipe.execute()
# 直接写入数据decode_responses=True 已经处理了编码
print(f"[save_session_write] 保存结果: {result[0]}, session_id: {session_id}")
return session_id
except Exception as e:
print(f"[save_session_write] 保存会话失败: {e}")
raise e
def get_session_by_userid(self, userid: str) -> Union[List[Dict[str, str]], bool]:
"""
通过 save_session_write 的 userid 获取 sessionid 和 messages
Args:
userid: 用户ID (对应 sessionid 字段)
Returns:
List[Dict] 或 False: 如果找到数据返回 [{"sessionid": "...", "messages": "..."}, ...],否则返回 False
"""
try:
# 只查询 write 类型的 key
keys = self.r.keys('session:write:*')
if not keys:
return False
# 批量获取数据
pipe = self.r.pipeline()
for key in keys:
pipe.hgetall(key)
all_data = pipe.execute()
# 筛选符合 userid 的数据
results = []
for key, data in zip(keys, all_data):
if not data:
continue
# 从 write 类型读取,匹配 sessionid 字段
if data.get('sessionid') == userid:
# 从 key 中提取 session_id: session:write:{session_id}
session_id = key.split(':')[-1]
results.append({
"sessionid": session_id,
"messages": fix_encoding(data.get('messages', ''))
})
if not results:
return False
print(f"[get_session_by_userid] userid={userid}, 找到 {len(results)} 条数据")
return results
except Exception as e:
print(f"[get_session_by_userid] 查询失败: {e}")
return False
def get_all_sessions_by_end_user_id(self, end_user_id: str) -> Union[List[Dict[str, Any]], bool]:
"""
通过 end_user_id 获取所有 write 类型的会话数据
Args:
end_user_id: 终端用户ID (对应 sessionid 字段)
Returns:
List[Dict] 或 False: 如果找到数据返回完整的会话信息列表,否则返回 False
返回格式:
[
{
"session_id": "uuid",
"id": "...",
"sessionid": "end_user_id",
"messages": "...",
"starttime": "timestamp"
},
...
]
"""
try:
# 只查询 write 类型的 key
keys = self.r.keys('session:write:*')
if not keys:
print(f"[get_all_sessions_by_end_user_id] 没有找到任何 write 类型的会话")
return False
# 批量获取数据
pipe = self.r.pipeline()
for key in keys:
pipe.hgetall(key)
all_data = pipe.execute()
# 筛选符合 end_user_id 的数据
results = []
for key, data in zip(keys, all_data):
if not data:
continue
# 从 write 类型读取,匹配 sessionid 字段
if data.get('sessionid') == end_user_id:
# 从 key 中提取 session_id: session:write:{session_id}
session_id = key.split(':')[-1]
# 构建完整的会话信息
session_info = {
"session_id": session_id,
"id": data.get('id', ''),
"sessionid": data.get('sessionid', ''),
"messages": fix_encoding(data.get('messages', '')),
"starttime": data.get('starttime', '')
}
results.append(session_info)
if not results:
print(f"[get_all_sessions_by_end_user_id] end_user_id={end_user_id}, 没有找到数据")
return False
# 按时间排序(最新的在前)
results.sort(key=lambda x: x.get('starttime', ''), reverse=True)
print(f"[get_all_sessions_by_end_user_id] end_user_id={end_user_id}, 找到 {len(results)} 条数据")
return results
except Exception as e:
print(f"[get_all_sessions_by_end_user_id] 查询失败: {e}")
import traceback
traceback.print_exc()
return False
def find_user_recent_sessions(self, userid: str,
minutes: int = 5) -> List[Dict[str, str]]:
"""
根据 userid 从 save_session_write 写入的数据中查询最近 N 分钟内的会话数据
Args:
userid: 用户ID (对应 sessionid 字段)
minutes: 查询最近几分钟的数据默认5分钟
Returns:
List[Dict]: 会话列表 [{"Query": "...", "Answer": "..."}, ...]
"""
import time
start_time = time.time()
# 只查询 write 类型的 key
keys = self.r.keys('session:write:*')
if not keys:
print(f"[find_user_recent_sessions] 查询耗时: {time.time() - start_time:.3f}秒, 结果数: 0")
return []
# 批量获取数据
pipe = self.r.pipeline()
for key in keys:
pipe.hgetall(key)
all_data = pipe.execute()
# 筛选符合 userid 的数据
matched_items = []
for data in all_data:
if not data:
continue
# 从 write 类型读取,匹配 sessionid 字段
if data.get('sessionid') == userid and data.get('starttime'):
# write 类型没有 aimessages所以 Answer 为空
matched_items.append({
"Query": fix_encoding(data.get('messages', '')),
"Answer": "",
"starttime": data.get('starttime', '')
})
# 根据时间范围过滤
filtered_items = filter_by_time_range(matched_items, minutes)
# 排序并移除时间字段
result_items = sort_and_limit_results(filtered_items, limit=None)
print(result_items)
elapsed_time = time.time() - start_time
print(f"[find_user_recent_sessions] userid={userid}, minutes={minutes}, "
f"查询耗时: {elapsed_time:.3f}秒, 结果数: {len(result_items)}")
return result_items
def delete_all_write_sessions(self) -> int:
"""
删除所有 write 类型的会话
Returns:
int: 删除的数量
"""
keys = self.r.keys('session:write:*')
if keys:
return self.r.delete(*keys)
return 0
class RedisCountStore:
"""Redis Count 类型存储类,用于管理访问次数统计相关的数据"""
def __init__(self, host='localhost', port=6379, db=0, password=None, session_id=''):
"""
初始化 Redis 连接
Args:
host: Redis 主机地址
port: Redis 端口
db: Redis 数据库编号
password: Redis 密码
session_id: 会话ID
"""
self.r = redis.Redis(
host=host,
port=port,
db=db,
password=password,
decode_responses=True,
encoding='utf-8'
)
self.uudi = session_id
def save_sessions_count(self, end_user_id: str, count: int, messages: Any) -> str:
"""
保存用户访问次数统计
Args:
end_user_id: 终端用户ID
count: 访问次数
messages: 消息内容
Returns:
str: 新生成的 session_id
"""
session_id = str(uuid.uuid4())
key = generate_session_key(session_id, key_type="count")
index_key = f'session:count:index:{end_user_id}' # 索引键
pipe = self.r.pipeline()
pipe.hset(key, mapping={
"id": self.uudi,
"end_user_id": end_user_id,
"count": int(count),
"messages": serialize_messages(messages),
"starttime": get_current_timestamp()
})
pipe.expire(key, 30 * 24 * 60 * 60) # 30天过期
# 创建索引end_user_id -> session_id 映射
pipe.set(index_key, session_id, ex=30 * 24 * 60 * 60)
result = pipe.execute()
print(f"[save_sessions_count] 保存结果: {result}, session_id: {session_id}")
return session_id
def get_sessions_count(self, end_user_id: str) -> Union[List[Any], bool]:
"""
通过 end_user_id 查询访问次数统计
Args:
end_user_id: 终端用户ID
Returns:
list 或 False: 如果找到返回 [count, messages],否则返回 False
"""
try:
# 使用索引键快速查找
index_key = f'session:count:index:{end_user_id}'
# 检查索引键类型,避免 WRONGTYPE 错误
try:
key_type = self.r.type(index_key)
if key_type != 'string' and key_type != 'none':
self.r.delete(index_key)
return False
except Exception as type_error:
print(f"[get_sessions_count] 检查键类型失败: {type_error}")
session_id = self.r.get(index_key)
if not session_id:
return False
# 直接获取数据
key = generate_session_key(session_id, key_type="count")
data = self.r.hgetall(key)
if not data:
# 索引存在但数据不存在,清理索引
self.r.delete(index_key)
return False
count = data.get('count')
messages_str = data.get('messages')
if count is not None:
messages = deserialize_messages(messages_str)
return [int(count), messages]
return False
except Exception as e:
print(f"[get_sessions_count] 查询失败: {e}")
return False
def update_sessions_count(self, end_user_id: str, new_count: int,
messages: Any) -> bool:
"""
通过 end_user_id 修改访问次数统计(优化版:使用索引)
Args:
end_user_id: 终端用户ID
new_count: 新的 count 值
messages: 消息内容
Returns:
bool: 更新成功返回 True未找到记录返回 False
"""
try:
# 使用索引键快速查找
index_key = f'session:count:index:{end_user_id}'
# 检查索引键类型,避免 WRONGTYPE 错误
try:
key_type = self.r.type(index_key)
if key_type != 'string' and key_type != 'none':
# 索引键类型错误,删除并返回 False
print(f"[update_sessions_count] 索引键类型错误: {key_type},删除索引")
self.r.delete(index_key)
print(f"[update_sessions_count] 未找到记录: end_user_id={end_user_id}")
return False
except Exception as type_error:
print(f"[update_sessions_count] 检查键类型失败: {type_error}")
session_id = self.r.get(index_key)
if not session_id:
print(f"[update_sessions_count] 未找到记录: end_user_id={end_user_id}")
return False
# 直接更新数据
key = generate_session_key(session_id, key_type="count")
messages_str = serialize_messages(messages)
pipe = self.r.pipeline()
pipe.hset(key, 'count', int(new_count))
pipe.hset(key, 'messages', messages_str)
result = pipe.execute()
print(f"[update_sessions_count] 更新成功: end_user_id={end_user_id}, new_count={new_count}, key={key}")
return True
except Exception as e:
print(f"[update_sessions_count] 更新失败: {e}")
return False
def delete_all_count_sessions(self) -> int:
"""
删除所有 count 类型的会话
Returns:
int: 删除的数量
"""
keys = self.r.keys('session:count:*')
if keys:
return self.r.delete(*keys)
return 0
class RedisSessionStore:
"""Redis 会话存储类,用于管理会话数据"""
def __init__(self, host='localhost', port=6379, db=0, password=None, session_id=''):
"""
初始化 Redis 连接
Args:
host: Redis 主机地址
port: Redis 端口
db: Redis 数据库编号
password: Redis 密码
session_id: 会话ID
"""
self.r = redis.Redis(
host=host,
port=port,
db=db,
password=password,
decode_responses=True,
encoding='utf-8'
)
self.uudi = session_id
# ==================== 写入操作 ====================
def save_session(self, userid: str, messages: str, aimessages: str,
apply_id: str, end_user_id: str) -> str:
"""
写入一条会话数据,返回 session_id
Args:
userid: 用户ID
messages: 用户消息
aimessages: AI回复消息
apply_id: 应用ID
end_user_id: 终端用户ID
Returns:
str: 新生成的 session_id
"""
try:
session_id = str(uuid.uuid4())
key = generate_session_key(session_id, key_type="read")
pipe = self.r.pipeline()
pipe.hset(key, mapping={
"id": self.uudi,
"sessionid": userid,
@@ -49,177 +479,195 @@ class RedisSessionStore:
"end_user_id": end_user_id,
"messages": messages,
"aimessages": aimessages,
"starttime": starttime
"starttime": get_current_timestamp()
})
# 可选设置过期时间例如30天避免数据无限增长
# pipe.expire(key, 30 * 24 * 60 * 60)
# 执行批量操作
result = pipe.execute()
print(f"保存结果: {result[0]}, session_id: {session_id}")
return session_id # 返回新生成的 session_id
print(f"[save_session] 保存结果: {result[0]}, session_id: {session_id}")
return session_id
except Exception as e:
print(f"保存会话失败: {e}")
print(f"[save_session] 保存会话失败: {e}")
raise e
def save_sessions_batch(self, sessions_data):
"""
批量写入多条会话数据,返回 session_id 列表
sessions_data: list of dict, 每个 dict 包含 userid, messages, aimessages, apply_id, end_user_id
优化版本:批量操作,大幅提升性能
"""
try:
session_ids = []
pipe = self.r.pipeline()
for session in sessions_data:
session_id = str(uuid.uuid4())
starttime = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
key = f"session:{session_id}"
pipe.hset(key, mapping={
"id": self.uudi,
"sessionid": session.get('userid'),
"apply_id": session.get('apply_id'),
"end_user_id": session.get('end_user_id'),
"messages": session.get('messages'),
"aimessages": session.get('aimessages'),
"starttime": starttime
})
session_ids.append(session_id)
# 一次性执行所有写入操作
results = pipe.execute()
print(f"批量保存完成: {len(session_ids)} 条记录")
return session_ids
except Exception as e:
print(f"批量保存会话失败: {e}")
raise e
# ---------------- 读取 ----------------
def get_session(self, session_id):
# ==================== 读取操作 ====================
def get_session(self, session_id: str) -> Optional[Dict[str, Any]]:
"""
读取一条会话数据
Args:
session_id: 会话ID
Returns:
Dict 或 None: 会话数据
"""
key = f"session:{session_id}"
key = generate_session_key(session_id)
data = self.r.hgetall(key)
return data if data else None
def get_session_apply_group(self, sessionid, apply_id, end_user_id):
def get_all_sessions(self) -> Dict[str, Dict[str, Any]]:
"""
根据 sessionid、apply_id 和 end_user_id 三个条件查询会话数据
"""
result_items = []
# 遍历所有会话数据
for key in self.r.keys('session:*'):
data = self.r.hgetall(key)
if not data:
continue
# 检查三个条件是否都匹配
if (data.get('sessionid') == sessionid and
data.get('apply_id') == apply_id and
data.get('end_user_id') == end_user_id):
result_items.append(data)
return result_items
def get_all_sessions(self):
"""
获取所有会话数据
获取所有会话数据(不包括 count 和 write 类型)
Returns:
Dict: 所有会话数据key 为 session_id
"""
sessions = {}
for key in self.r.keys('session:*'):
sid = key.split(':')[1]
sessions[sid] = self.get_session(sid)
# 排除 count 和 write 类型的 key
if ':count:' not in key and ':write:' not in key:
sid = key.split(':')[1]
sessions[sid] = self.get_session(sid)
return sessions
# ---------------- 更新 ----------------
def update_session(self, session_id, field, value):
def find_user_apply_group(self, sessionid: str, apply_id: str,
end_user_id: str) -> List[Dict[str, str]]:
"""
根据 sessionid、apply_id 和 end_user_id 查询会话数据返回最新的6条
Args:
sessionid: 会话ID支持模糊匹配
apply_id: 应用ID
end_user_id: 终端用户ID
Returns:
List[Dict]: 会话列表 [{"Query": "...", "Answer": "..."}, ...]
"""
import time
start_time = time.time()
keys = self.r.keys('session:*')
if not keys:
print(f"[find_user_apply_group] 查询耗时: {time.time() - start_time:.3f}秒, 结果数: 0")
return []
# 批量获取数据
pipe = self.r.pipeline()
for key in keys:
# 排除 count 和 write 类型
if ':count:' not in key and ':write:' not in key:
pipe.hgetall(key)
all_data = pipe.execute()
# 筛选符合条件的数据
matched_items = []
for data in all_data:
if not data:
continue
if (data.get('apply_id') == apply_id and
data.get('end_user_id') == end_user_id):
# 支持模糊匹配或完全匹配 sessionid
if sessionid in data.get('sessionid', '') or data.get('sessionid') == sessionid:
matched_items.append(format_session_data(data, include_time=True))
# 排序、限制数量并移除时间字段
result_items = sort_and_limit_results(matched_items, limit=6)
elapsed_time = time.time() - start_time
print(f"[find_user_apply_group] 查询耗时: {elapsed_time:.3f}秒, 结果数: {len(result_items)}")
return result_items
# ==================== 更新操作 ====================
def update_session(self, session_id: str, field: str, value: Any) -> bool:
"""
更新单个字段
优化版本:使用 pipeline 减少网络往返
Args:
session_id: 会话ID
field: 字段名
value: 字段值
Returns:
bool: 是否更新成功
"""
key = f"session:{session_id}"
key = generate_session_key(session_id)
pipe = self.r.pipeline()
pipe.exists(key)
pipe.hset(key, field, value)
results = pipe.execute()
return bool(results[0]) # 返回 key 是否存在
return bool(results[0])
# ---------------- 删除 ----------------
def delete_session(self, session_id):
# ==================== 删除操作 ====================
def delete_session(self, session_id: str) -> int:
"""
删除单条会话
Args:
session_id: 会话ID
Returns:
int: 删除的数量
"""
key = f"session:{session_id}"
key = generate_session_key(session_id)
return self.r.delete(key)
def delete_all_sessions(self):
def delete_all_sessions(self) -> int:
"""
删除所有会话
删除所有会话(不包括 count 和 write 类型)
Returns:
int: 删除的数量
"""
keys = self.r.keys('session:*')
if keys:
return self.r.delete(*keys)
# 过滤掉 count 和 write 类型
keys_to_delete = [k for k in keys if ':count:' not in k and ':write:' not in k]
if keys_to_delete:
return self.r.delete(*keys_to_delete)
return 0
def delete_duplicate_sessions(self):
def delete_duplicate_sessions(self) -> int:
"""
删除重复会话数据,条件:
"sessionid""user_id""end_user_id""messages""aimessages" 五个字段都相同的只保留一个,其他删除
优化版本:使用 pipeline 批量操作确保在1秒内完成
删除重复会话数据(不包括 count 和 write 类型)
条件:sessionid、user_id、end_user_id、messages、aimessages 五个字段都相同的只保留一个
Returns:
int: 删除的数量
"""
import time
start_time = time.time()
# 第一步:使用 pipeline 批量获取所有 key
keys = self.r.keys('session:*')
if not keys:
print("[delete_duplicate_sessions] 没有会话数据")
return 0
# 第二步:使用 pipeline 批量获取所有数据
# 批量获取所有数据
pipe = self.r.pipeline()
for key in keys:
pipe.hgetall(key)
# 排除 count 和 write 类型
if ':count:' not in key and ':write:' not in key:
pipe.hgetall(key)
all_data = pipe.execute()
# 第三步:在内存中识别重复数据
seen = {} # 用字典记录identifier -> key保留第一个出现的 key
keys_to_delete = [] # 需要删除的 key 列表
# 识别重复数据
seen = {}
keys_to_delete = []
for key, data in zip(keys, all_data, strict=False):
for key, data in zip([k for k in keys if ':count:' not in k and ':write:' not in k], all_data, strict=False):
if not data:
continue
# 获取五个字段的值
sessionid = data.get('sessionid', '')
user_id = data.get('id', '')
end_user_id = data.get('end_user_id', '')
messages = data.get('messages', '')
aimessages = data.get('aimessages', '')
# 用五元组作为唯一标识
identifier = (sessionid, user_id, end_user_id, messages, aimessages)
identifier = (
data.get('sessionid', ''),
data.get('id', ''),
data.get('end_user_id', ''),
data.get('messages', ''),
data.get('aimessages', '')
)
if identifier in seen:
# 重复,标记为待删除
keys_to_delete.append(key)
else:
# 第一次出现,记录
seen[identifier] = key
# 第四步:使用 pipeline 批量删除重复的 key
# 批量删除重复的 key
deleted_count = 0
if keys_to_delete:
# 分批删除,避免单次操作过大
batch_size = 1000
for i in range(0, len(keys_to_delete), batch_size):
batch = keys_to_delete[i:i + batch_size]
@@ -233,79 +681,28 @@ class RedisSessionStore:
print(f"[delete_duplicate_sessions] 删除重复会话数量: {deleted_count}, 耗时: {elapsed_time:.3f}")
return deleted_count
def find_user_session(self, sessionid):
user_id = sessionid
result_items = []
for key, values in store.get_all_sessions().items():
history = {}
if user_id == str(values['sessionid']):
history["Query"] = values['messages']
history["Answer"] = values['aimessages']
result_items.append(history)
if len(result_items) <= 1:
result_items = []
return (result_items)
def find_user_apply_group(self, sessionid, apply_id, end_user_id):
"""
根据 sessionid、apply_id 和 end_user_id 三个条件查询会话数据返回最新的6条
"""
import time
start_time = time.time()
# 使用 pipeline 批量获取数据,提高性能
keys = self.r.keys('session:*')
if not keys:
print(f"查询耗时: {time.time() - start_time:.3f}秒, 结果数: 0")
return []
# 使用 pipeline 批量获取所有 hash 数据
pipe = self.r.pipeline()
for key in keys:
pipe.hgetall(key)
all_data = pipe.execute()
# 解析并筛选符合条件的数据
matched_items = []
for data in all_data:
if not data:
continue
# 检查是否符合三个条件
if (data.get('apply_id') == apply_id and
data.get('end_user_id') == end_user_id):
# 支持模糊匹配 sessionid 或者完全匹配
if sessionid in data.get('sessionid', '') or data.get('sessionid') == sessionid:
matched_items.append({
"Query": self._fix_encoding(data.get('messages')),
"Answer": self._fix_encoding(data.get('aimessages')),
"starttime": data.get('starttime', '')
})
# 按时间降序排序(最新的在前)
matched_items.sort(key=lambda x: x.get('starttime', ''), reverse=True)
# 只保留最新的6条
result_items = matched_items[:6]
# # 移除 starttime 字段
for item in result_items:
item.pop('starttime', None)
# 如果结果少于等于1条返回空列表
if len(result_items) <= 1:
result_items = []
elapsed_time = time.time() - start_time
print(f"查询耗时: {elapsed_time:.3f}秒, 结果数: {len(result_items)}")
return result_items
# 全局实例
store = RedisSessionStore(
host=settings.REDIS_HOST,
port=settings.REDIS_PORT,
db=settings.REDIS_DB,
password=settings.REDIS_PASSWORD if settings.REDIS_PASSWORD else None,
session_id=str(uuid.uuid4())
)
)
write_store = RedisWriteStore(
host=settings.REDIS_HOST,
port=settings.REDIS_PORT,
db=settings.REDIS_DB,
password=settings.REDIS_PASSWORD if settings.REDIS_PASSWORD else None,
session_id=str(uuid.uuid4())
)
count_store = RedisCountStore(
host=settings.REDIS_HOST,
port=settings.REDIS_PORT,
db=settings.REDIS_DB,
password=settings.REDIS_PASSWORD if settings.REDIS_PASSWORD else None,
session_id=str(uuid.uuid4())
)

View File

@@ -4,6 +4,7 @@ Write Tools for Memory Knowledge Extraction Pipeline
This module provides the main write function for executing the knowledge extraction
pipeline. Only MemoryConfig is needed - clients are constructed internally.
"""
import asyncio
import time
from datetime import datetime
@@ -123,23 +124,48 @@ async def write(
except Exception as e:
logger.error(f"Error creating indexes: {e}", exc_info=True)
# 添加死锁重试机制
max_retries = 3
retry_delay = 1 # 秒
for attempt in range(max_retries):
try:
success = await save_dialog_and_statements_to_neo4j(
dialogue_nodes=all_dialogue_nodes,
chunk_nodes=all_chunk_nodes,
statement_nodes=all_statement_nodes,
entity_nodes=all_entity_nodes,
statement_chunk_edges=all_statement_chunk_edges,
statement_entity_edges=all_statement_entity_edges,
entity_edges=all_entity_entity_edges,
connector=neo4j_connector
)
if success:
logger.info("Successfully saved all data to Neo4j")
break
else:
logger.warning("Failed to save some data to Neo4j")
if attempt < max_retries - 1:
logger.info(f"Retrying... (attempt {attempt + 2}/{max_retries})")
await asyncio.sleep(retry_delay * (attempt + 1)) # 指数退避
except Exception as e:
error_msg = str(e)
# 检查是否是死锁错误
if "DeadlockDetected" in error_msg or "deadlock" in error_msg.lower():
if attempt < max_retries - 1:
logger.warning(f"Deadlock detected, retrying... (attempt {attempt + 2}/{max_retries})")
await asyncio.sleep(retry_delay * (attempt + 1)) # 指数退避
else:
logger.error(f"Failed after {max_retries} attempts due to deadlock: {e}")
raise
else:
# 非死锁错误,直接抛出
raise
try:
success = await save_dialog_and_statements_to_neo4j(
dialogue_nodes=all_dialogue_nodes,
chunk_nodes=all_chunk_nodes,
statement_nodes=all_statement_nodes,
entity_nodes=all_entity_nodes,
statement_chunk_edges=all_statement_chunk_edges,
statement_entity_edges=all_statement_entity_edges,
entity_edges=all_entity_entity_edges,
connector=neo4j_connector
)
if success:
logger.info("Successfully saved all data to Neo4j")
else:
logger.warning("Failed to save some data to Neo4j")
finally:
await neo4j_connector.close()
except Exception as e:
logger.error(f"Error closing Neo4j connector: {e}")
log_time("Neo4j Database Save", time.time() - step_start, log_file)

View File

@@ -58,6 +58,12 @@ from app.core.memory.models.triplet_models import (
TripletExtractionResponse,
)
# Ontology models
from app.core.memory.models.ontology_models import (
OntologyClass,
OntologyExtractionResponse,
)
# Variable configuration models
from app.core.memory.models.variate_config import (
StatementExtractionConfig,
@@ -105,6 +111,9 @@ __all__ = [
"Entity",
"Triplet",
"TripletExtractionResponse",
# Ontology models
"OntologyClass",
"OntologyExtractionResponse",
# Variable configuration
"StatementExtractionConfig",
"ForgettingEngineConfig",

View File

@@ -413,7 +413,8 @@ class ExtractedEntityNode(Node):
description="Entity aliases - alternative names for this entity"
)
name_embedding: Optional[List[float]] = Field(default_factory=list, description="Name embedding vector")
fact_summary: str = Field(default="", description="Summary of the fact about this entity")
# TODO: fact_summary 功能暂时禁用,待后续开发完善后启用
# fact_summary: str = Field(default="", description="Summary of the fact about this entity")
connect_strength: str = Field(..., description="Strong VS Weak about this entity")
config_id: Optional[int | str] = Field(None, description="Configuration ID used to process this entity (integer or string)")

View File

@@ -0,0 +1,135 @@
"""Models for ontology classes and extraction responses.
This module contains Pydantic models for representing extracted ontology classes
from scenario descriptions, following OWL ontology engineering standards.
Classes:
OntologyClass: Represents an extracted ontology class
OntologyExtractionResponse: Response model containing extracted ontology classes
"""
from typing import List, Optional
from uuid import uuid4
from pydantic import BaseModel, ConfigDict, Field, field_validator
class OntologyClass(BaseModel):
"""Represents an extracted ontology class from scenario description.
An ontology class represents an abstract category or concept in a domain,
following OWL ontology engineering standards and naming conventions.
Attributes:
id: Unique string identifier for the ontology class
name: Name of the class in PascalCase format (e.g., 'MedicalProcedure')
name_chinese: Chinese translation of the class name (e.g., '医疗程序')
description: Textual description of the class
examples: List of concrete instance examples of this class
parent_class: Optional name of the parent class in the hierarchy
entity_type: Type/category of the entity (e.g., 'Person', 'Organization', 'Concept')
domain: Domain this class belongs to (e.g., 'Healthcare', 'Education')
Config:
extra: Ignore extra fields from LLM output
"""
model_config = ConfigDict(extra='ignore')
id: str = Field(
default_factory=lambda: uuid4().hex,
description="Unique identifier for the ontology class"
)
name: str = Field(
...,
description="Name of the class in PascalCase format"
)
name_chinese: Optional[str] = Field(
None,
description="Chinese translation of the class name"
)
description: str = Field(
...,
description="Description of the class"
)
examples: List[str] = Field(
default_factory=list,
description="List of concrete instance examples"
)
parent_class: Optional[str] = Field(
None,
description="Name of the parent class in the hierarchy"
)
entity_type: str = Field(
...,
description="Type/category of the entity"
)
domain: str = Field(
...,
description="Domain this class belongs to"
)
@field_validator('name')
@classmethod
def validate_pascal_case(cls, v: str) -> str:
"""Validate that the class name follows PascalCase convention.
PascalCase rules:
- Must start with an uppercase letter
- Cannot contain spaces
- Should not contain special characters except underscores
Args:
v: The class name to validate
Returns:
The validated class name
Raises:
ValueError: If the name doesn't follow PascalCase convention
"""
if not v:
raise ValueError("Class name cannot be empty")
if not v[0].isupper():
raise ValueError(
f"Class name '{v}' must start with an uppercase letter (PascalCase)"
)
if ' ' in v:
raise ValueError(
f"Class name '{v}' cannot contain spaces (PascalCase)"
)
# Check for invalid characters (allow alphanumeric and underscore only)
if not all(c.isalnum() or c == '_' for c in v):
raise ValueError(
f"Class name '{v}' contains invalid characters. "
"Only alphanumeric characters and underscores are allowed"
)
return v
class OntologyExtractionResponse(BaseModel):
"""Response model for ontology extraction from LLM.
This model represents the structured output from the LLM when
extracting ontology classes from scenario descriptions.
Attributes:
classes: List of extracted ontology classes
domain: Domain/field the scenario belongs to
Config:
extra: Ignore extra fields from LLM output
"""
model_config = ConfigDict(extra='ignore')
classes: List[OntologyClass] = Field(
default_factory=list,
description="List of extracted ontology classes"
)
domain: str = Field(
...,
description="Domain/field the scenario belongs to"
)

View File

@@ -134,42 +134,45 @@ def _merge_attribute(canonical: ExtractedEntityNode, ent: ExtractedEntityNode):
if len(desc_b) > len(desc_a):
canonical.description = desc_b
# 合并事实摘要:统一保留一个“实体: name”行来源行去重保序
fact_a = getattr(canonical, "fact_summary", "") or ""
fact_b = getattr(ent, "fact_summary", "") or ""
def _extract_sources(txt: str) -> List[str]:
sources: List[str] = []
if not txt:
return sources
for line in str(txt).splitlines():
ln = line.strip()
# TODO: fact_summary 功能暂时禁用,待后续开发完善后启用
# fact_a = getattr(canonical, "fact_summary", "") or ""
# fact_b = getattr(ent, "fact_summary", "") or ""
# def _extract_sources(txt: str) -> List[str]:
# sources: List[str] = []
# if not txt:
# return sources
# for line in str(txt).splitlines():
# ln = line.strip()
# 支持“来源:”或“来源:”前缀
m = re.match(r"^来源[:]\s*(.+)$", ln)
if m:
content = m.group(1).strip()
if content:
sources.append(content)
# m = re.match(r"^来源[:]\s*(.+)$", ln)
# if m:
# content = m.group(1).strip()
# if content:
# sources.append(content)
# 如果不存在“来源”前缀,则将整体文本视为一个来源片段,避免信息丢失
if not sources and txt.strip():
sources.append(txt.strip())
return sources
# if not sources and txt.strip():
# sources.append(txt.strip())
# return sources
try:
src_a = _extract_sources(fact_a)
src_b = _extract_sources(fact_b)
seen = set()
merged_sources: List[str] = []
for s in src_a + src_b:
if s and s not in seen:
seen.add(s)
merged_sources.append(s)
if merged_sources:
name_line = f"实体: {getattr(canonical, 'name', '')}".strip()
canonical.fact_summary = "\n".join([name_line] + [f"来源: {s}" for s in merged_sources])
elif fact_b and not fact_a:
canonical.fact_summary = fact_b
# src_a = _extract_sources(fact_a)
# src_b = _extract_sources(fact_b)
# seen = set()
# merged_sources: List[str] = []
# for s in src_a + src_b:
# if s and s not in seen:
# seen.add(s)
# merged_sources.append(s)
# if merged_sources:
# name_line = f"实体: {getattr(canonical, 'name', '')}".strip()
# canonical.fact_summary = "\n".join([name_line] + [f"来源: {s}" for s in merged_sources])
# elif fact_b and not fact_a:
# canonical.fact_summary = fact_b
pass
except Exception:
# 兜底:若解析失败,保留较长文本
if len(fact_b) > len(fact_a):
canonical.fact_summary = fact_b
# if len(fact_b) > len(fact_a):
# canonical.fact_summary = fact_b
pass
except Exception:
pass

View File

@@ -145,10 +145,13 @@ def _choose_canonical(a: ExtractedEntityNode, b: ExtractedEntityNode) -> int: #
# 2. 第二优先级:按“描述+事实摘要”的总长度排序(内容越长,信息越完整)
desc_a = (getattr(a, "description", "") or "")
desc_b = (getattr(b, "description", "") or "")
fact_a = (getattr(a, "fact_summary", "") or "")
fact_b = (getattr(b, "fact_summary", "") or "")
score_a = len(desc_a) + len(fact_a)
score_b = len(desc_b) + len(fact_b)
# TODO: fact_summary 功能暂时禁用,待后续开发完善后启用
# fact_a = (getattr(a, "fact_summary", "") or "")
# fact_b = (getattr(b, "fact_summary", "") or "")
# score_a = len(desc_a) + len(fact_a)
# score_b = len(desc_b) + len(fact_b)
score_a = len(desc_a)
score_b = len(desc_b)
if score_a != score_b:
return 0 if score_a >= score_b else 1
return 0
@@ -189,7 +192,8 @@ async def _judge_pair(
"entity_type": getattr(a, "entity_type", None),
"description": getattr(a, "description", None),
"aliases": getattr(a, "aliases", None) or [],
"fact_summary": getattr(a, "fact_summary", None),
# TODO: fact_summary 功能暂时禁用,待后续开发完善后启用
# "fact_summary": getattr(a, "fact_summary", None),
"connect_strength": getattr(a, "connect_strength", None),
}
entity_b = {
@@ -197,7 +201,8 @@ async def _judge_pair(
"entity_type": getattr(b, "entity_type", None),
"description": getattr(b, "description", None),
"aliases": getattr(b, "aliases", None) or [],
"fact_summary": getattr(b, "fact_summary", None),
# TODO: fact_summary 功能暂时禁用,待后续开发完善后启用
# "fact_summary": getattr(b, "fact_summary", None),
"connect_strength": getattr(b, "connect_strength", None),
}
# 5. 渲染LLM提示词用工具函数填充模板包含实体信息、上下文、输出格式
@@ -248,7 +253,8 @@ async def _judge_pair_disamb(
"entity_type": getattr(a, "entity_type", None),
"description": getattr(a, "description", None),
"aliases": getattr(a, "aliases", None) or [],
"fact_summary": getattr(a, "fact_summary", None),
# TODO: fact_summary 功能暂时禁用,待后续开发完善后启用
# "fact_summary": getattr(a, "fact_summary", None),
"connect_strength": getattr(a, "connect_strength", None),
}
entity_b = {
@@ -256,7 +262,8 @@ async def _judge_pair_disamb(
"entity_type": getattr(b, "entity_type", None),
"description": getattr(b, "description", None),
"aliases": getattr(b, "aliases", None) or [],
"fact_summary": getattr(b, "fact_summary", None),
# TODO: fact_summary 功能暂时禁用,待后续开发完善后启用
# "fact_summary": getattr(b, "fact_summary", None),
"connect_strength": getattr(b, "connect_strength", None),
}
prompt = render_entity_dedup_prompt(

View File

@@ -72,7 +72,8 @@ def _row_to_entity(row: Dict[str, Any]) -> ExtractedEntityNode:
description=row.get("description") or "",
aliases=row.get("aliases") or [],
name_embedding=row.get("name_embedding") or [],
fact_summary=row.get("fact_summary") or "",
# TODO: fact_summary 功能暂时禁用,待后续开发完善后启用
# fact_summary=row.get("fact_summary") or "",
connect_strength=row.get("connect_strength") or "",
)

View File

@@ -1085,7 +1085,8 @@ class ExtractionOrchestrator:
entity_type=getattr(entity, 'type', 'unknown'), # 使用 type 而不是 entity_type
description=getattr(entity, 'description', ''), # 添加必需的 description 字段
example=getattr(entity, 'example', ''), # 新增:传递示例字段
fact_summary=getattr(entity, 'fact_summary', ''), # 添加必需的 fact_summary 字段
# TODO: fact_summary 功能暂时禁用,待后续开发完善后启用
# fact_summary=getattr(entity, 'fact_summary', ''), # 添加必需的 fact_summary 字段
connect_strength=entity_connect_strength if entity_connect_strength is not None else 'Strong', # 添加必需的 connect_strength 字段
aliases=getattr(entity, 'aliases', []) or [], # 传递从三元组提取阶段获取的aliases
name_embedding=getattr(entity, 'name_embedding', None),

View File

@@ -8,4 +8,5 @@
- TemporalExtractor: 时间信息提取
- EmbeddingGenerator: 嵌入向量生成
- MemorySummaryGenerator: 记忆摘要生成
- OntologyExtractor: 本体类提取
"""

View File

@@ -14,6 +14,34 @@ from pydantic import Field
logger = get_memory_logger(__name__)
# 支持的语言列表和默认回退值
SUPPORTED_LANGUAGES = {"zh", "en"}
FALLBACK_LANGUAGE = "en"
def validate_language(language: Optional[str]) -> str:
"""
校验语言参数,确保其为有效值。
Args:
language: 待校验的语言代码
Returns:
有效的语言代码("zh""en"
"""
if language is None:
return FALLBACK_LANGUAGE
lang = str(language).lower().strip()
if lang in SUPPORTED_LANGUAGES:
return lang
logger.warning(
f"无效的语言参数 '{language}',已回退到默认值 '{FALLBACK_LANGUAGE}'"
f"支持的语言: {SUPPORTED_LANGUAGES}"
)
return FALLBACK_LANGUAGE
class MemorySummaryResponse(RobustLLMResponse):
"""Structured response for summary generation per chunk.
@@ -31,7 +59,8 @@ class MemorySummaryResponse(RobustLLMResponse):
async def generate_title_and_type_for_summary(
content: str,
llm_client
llm_client,
language: str = None
) -> Tuple[str, str]:
"""
为MemorySummary生成标题和类型
@@ -41,11 +70,18 @@ async def generate_title_and_type_for_summary(
Args:
content: Summary的内容文本
llm_client: LLM客户端实例
language: 生成标题使用的语言 ("zh" 中文, "en" 英文)如果为None则从配置读取
Returns:
(标题, 类型)元组
"""
from app.core.memory.utils.prompt.prompt_utils import render_episodic_title_and_type_prompt
from app.core.config import settings
# 如果没有指定语言,从配置中读取,并校验有效性
if language is None:
language = settings.DEFAULT_LANGUAGE
language = validate_language(language)
# 定义有效的类型集合
VALID_TYPES = {
@@ -57,13 +93,19 @@ async def generate_title_and_type_for_summary(
}
DEFAULT_TYPE = "conversation" # 默认类型
# 根据语言设置默认标题
DEFAULT_TITLE = "空内容" if language == "zh" else "Empty Content"
PARSE_ERROR_TITLE = "解析失败" if language == "zh" else "Parse Failed"
ERROR_TITLE = "错误" if language == "zh" else "Error"
UNKNOWN_TITLE = "未知标题" if language == "zh" else "Unknown Title"
try:
if not content:
logger.warning("content为空无法生成标题和类型")
return ("空内容", DEFAULT_TYPE)
logger.warning(f"content为空无法生成标题和类型 (language={language})")
return (DEFAULT_TITLE, DEFAULT_TYPE)
# 1. 渲染Jinja2提示词模板
prompt = await render_episodic_title_and_type_prompt(content)
# 1. 渲染Jinja2提示词模板,传递语言参数
prompt = await render_episodic_title_and_type_prompt(content, language=language)
# 2. 调用LLM生成标题和类型
messages = [
@@ -102,7 +144,7 @@ async def generate_title_and_type_for_summary(
json_str = json_str.strip()
result_data = json.loads(json_str)
title = result_data.get("title", "未知标题")
title = result_data.get("title", UNKNOWN_TITLE)
episodic_type_raw = result_data.get("type", DEFAULT_TYPE)
# 5. 校验和归一化类型
@@ -130,16 +172,16 @@ async def generate_title_and_type_for_summary(
f"已归一化为 '{episodic_type}'"
)
logger.info(f"成功生成标题和类型: title={title}, type={episodic_type}")
logger.info(f"成功生成标题和类型 (language={language}): title={title}, type={episodic_type}")
return (title, episodic_type)
except json.JSONDecodeError:
logger.error(f"无法解析LLM响应为JSON: {full_response}")
return ("解析失败", DEFAULT_TYPE)
logger.error(f"无法解析LLM响应为JSON (language={language}): {full_response}")
return (PARSE_ERROR_TITLE, DEFAULT_TYPE)
except Exception as e:
logger.error(f"生成标题和类型时出错: {str(e)}", exc_info=True)
return ("错误", DEFAULT_TYPE)
logger.error(f"生成标题和类型时出错 (language={language}): {str(e)}", exc_info=True)
return (ERROR_TITLE, DEFAULT_TYPE)
async def _process_chunk_summary(
dialog: DialogData,
@@ -153,11 +195,16 @@ async def _process_chunk_summary(
return None
try:
# 从配置中获取语言设置(只获取一次,复用),并校验有效性
from app.core.config import settings
language = validate_language(settings.DEFAULT_LANGUAGE)
# Render prompt via Jinja2 for a single chunk
prompt_content = await render_memory_summary_prompt(
chunk_texts=chunk.content,
json_schema=MemorySummaryResponse.model_json_schema(),
max_words=200,
language=language,
)
messages = [
@@ -178,9 +225,10 @@ async def _process_chunk_summary(
try:
title, episodic_type = await generate_title_and_type_for_summary(
content=summary_text,
llm_client=llm_client
llm_client=llm_client,
language=language
)
logger.info(f"Generated title and type for MemorySummary: title={title}, type={episodic_type}")
logger.info(f"Generated title and type for MemorySummary (language={language}): title={title}, type={episodic_type}")
except Exception as e:
logger.warning(f"Failed to generate title and type for chunk {chunk.id}: {e}")
# Continue without title and type

View File

@@ -0,0 +1,482 @@
"""Ontology class extraction from scenario descriptions using LLM.
This module provides the OntologyExtractor class for extracting ontology classes
from natural language scenario descriptions. It uses LLM-driven extraction combined
with two-layer validation (string validation + OWL semantic validation).
Classes:
OntologyExtractor: Extracts ontology classes from scenario descriptions
"""
import asyncio
import logging
import time
from typing import List, Optional
from app.core.memory.llm_tools.openai_client import OpenAIClient
from app.core.memory.models.ontology_models import (
OntologyClass,
OntologyExtractionResponse,
)
from app.core.memory.utils.validation.ontology_validator import OntologyValidator
from app.core.memory.utils.validation.owl_validator import OWLValidator
from app.core.memory.utils.prompt.prompt_utils import render_ontology_extraction_prompt
logger = logging.getLogger(__name__)
class OntologyExtractor:
"""Extractor for ontology classes from scenario descriptions.
This extractor uses LLM to identify abstract classes and concepts from
natural language scenario descriptions, following OWL ontology engineering
standards. It performs two-layer validation:
1. String validation (naming conventions, reserved words, duplicates)
2. OWL semantic validation (consistency checking, circular inheritance)
Attributes:
llm_client: OpenAI client for LLM calls
validator: String validator for class names and descriptions
owl_validator: OWL validator for semantic validation
"""
def __init__(self, llm_client: OpenAIClient):
"""Initialize the OntologyExtractor.
Args:
llm_client: OpenAIClient instance for LLM processing
"""
self.llm_client = llm_client
self.validator = OntologyValidator()
self.owl_validator = OWLValidator()
logger.info("OntologyExtractor initialized")
async def extract_ontology_classes(
self,
scenario: str,
domain: Optional[str] = None,
max_classes: int = 15,
min_classes: int = 5,
enable_owl_validation: bool = True,
llm_temperature: float = 0.3,
llm_max_tokens: int = 2000,
max_description_length: int = 500,
timeout: Optional[float] = None,
) -> OntologyExtractionResponse:
"""Extract ontology classes from a scenario description.
This is the main extraction method that orchestrates the entire process:
1. Call LLM to extract ontology classes
2. Perform first-layer validation (string validation and cleaning)
3. Perform second-layer validation (OWL semantic validation)
4. Filter invalid classes based on validation errors
5. Return validated ontology classes
Args:
scenario: Natural language scenario description
domain: Optional domain hint (e.g., "Healthcare", "Education")
max_classes: Maximum number of classes to extract (default: 15)
min_classes: Minimum number of classes to extract (default: 5)
enable_owl_validation: Whether to enable OWL validation (default: True)
llm_temperature: LLM temperature parameter (default: 0.3)
llm_max_tokens: LLM max tokens parameter (default: 2000)
max_description_length: Maximum description length (default: 500)
timeout: Optional timeout in seconds for LLM call (default: None, no timeout)
Returns:
OntologyExtractionResponse containing validated ontology classes
Raises:
ValueError: If scenario is empty or invalid
asyncio.TimeoutError: If extraction times out
Examples:
>>> extractor = OntologyExtractor(llm_client)
>>> response = await extractor.extract_ontology_classes(
... scenario="A hospital manages patient records...",
... domain="Healthcare",
... max_classes=10,
... timeout=30.0
... )
>>> len(response.classes)
7
"""
# Start timing
start_time = time.time()
# Validate input
if not scenario or not scenario.strip():
logger.error("Scenario description is empty")
raise ValueError("Scenario description cannot be empty")
scenario = scenario.strip()
logger.info(
f"Starting ontology extraction - scenario_length={len(scenario)}, "
f"domain={domain}, max_classes={max_classes}, min_classes={min_classes}, "
f"timeout={timeout}"
)
try:
# Step 1: Call LLM for extraction with timeout
logger.info("Step 1: Calling LLM for ontology extraction")
llm_start_time = time.time()
if timeout is not None:
# Wrap LLM call with timeout
try:
response = await asyncio.wait_for(
self._call_llm_for_extraction(
scenario=scenario,
domain=domain,
max_classes=max_classes,
llm_temperature=llm_temperature,
llm_max_tokens=llm_max_tokens,
),
timeout=timeout
)
except asyncio.TimeoutError:
llm_duration = time.time() - llm_start_time
logger.error(
f"LLM extraction timed out after {timeout} seconds "
f"(actual duration: {llm_duration:.2f}s)"
)
# Return empty response on timeout
return OntologyExtractionResponse(
classes=[],
domain=domain or "Unknown",
)
else:
# No timeout specified, call directly
response = await self._call_llm_for_extraction(
scenario=scenario,
domain=domain,
max_classes=max_classes,
llm_temperature=llm_temperature,
llm_max_tokens=llm_max_tokens,
)
llm_duration = time.time() - llm_start_time
logger.info(
f"LLM returned {len(response.classes)} classes in {llm_duration:.2f}s"
)
# Step 2: First-layer validation (string validation and cleaning)
logger.info("Step 2: Performing first-layer validation (string validation)")
validation_start_time = time.time()
response = self._validate_and_clean(
response=response,
max_description_length=max_description_length,
)
validation_duration = time.time() - validation_start_time
logger.info(
f"After first-layer validation: {len(response.classes)} classes remain "
f"(validation took {validation_duration:.2f}s)"
)
# Check if we have enough classes after first-layer validation
if len(response.classes) < min_classes:
logger.warning(
f"Only {len(response.classes)} classes remain after validation, "
f"which is below minimum of {min_classes}"
)
# Step 3: Second-layer validation (OWL semantic validation)
if enable_owl_validation and response.classes:
logger.info("Step 3: Performing second-layer validation (OWL validation)")
owl_start_time = time.time()
is_valid, errors, world = self.owl_validator.validate_ontology_classes(
classes=response.classes,
)
owl_duration = time.time() - owl_start_time
if not is_valid:
logger.warning(
f"OWL validation found {len(errors)} issues in {owl_duration:.2f}s: {errors}"
)
# Filter invalid classes based on errors
response = self._filter_invalid_classes(
response=response,
errors=errors,
)
logger.info(
f"After second-layer validation: {len(response.classes)} classes remain"
)
else:
logger.info(f"OWL validation passed successfully in {owl_duration:.2f}s")
else:
if not enable_owl_validation:
logger.info("Step 3: OWL validation disabled, skipping")
else:
logger.info("Step 3: No classes to validate, skipping OWL validation")
# Calculate total duration
total_duration = time.time() - start_time
# Log extraction statistics
logger.info(
f"Ontology extraction completed - "
f"final_class_count={len(response.classes)}, "
f"domain={response.domain}, "
f"total_duration={total_duration:.2f}s, "
f"llm_duration={llm_duration:.2f}s"
)
return response
except asyncio.TimeoutError:
# Re-raise timeout errors
total_duration = time.time() - start_time
logger.error(
f"Ontology extraction timed out after {timeout} seconds "
f"(total duration: {total_duration:.2f}s)",
exc_info=True
)
raise
except Exception as e:
total_duration = time.time() - start_time
logger.error(
f"Ontology extraction failed after {total_duration:.2f}s: {str(e)}",
exc_info=True
)
# Return empty response on failure
return OntologyExtractionResponse(
classes=[],
domain=domain or "Unknown",
)
async def _call_llm_for_extraction(
self,
scenario: str,
domain: Optional[str],
max_classes: int,
llm_temperature: float,
llm_max_tokens: int,
) -> OntologyExtractionResponse:
"""Call LLM to extract ontology classes from scenario.
This method renders the extraction prompt using the Jinja2 template
and calls the LLM with structured output to get ontology classes.
Args:
scenario: Scenario description text
domain: Optional domain hint
max_classes: Maximum number of classes to extract
llm_temperature: LLM temperature parameter
llm_max_tokens: LLM max tokens parameter
Returns:
OntologyExtractionResponse from LLM
Raises:
Exception: If LLM call fails
"""
try:
# Render prompt using template
prompt_content = await render_ontology_extraction_prompt(
scenario=scenario,
domain=domain,
max_classes=max_classes,
json_schema=OntologyExtractionResponse.model_json_schema(),
)
logger.debug(f"Rendered prompt length: {len(prompt_content)}")
# Create messages for LLM
messages = [
{
"role": "system",
"content": (
"You are an expert ontology engineer specializing in knowledge "
"representation and OWL standards. Extract ontology classes from "
"scenario descriptions following the provided instructions. "
"Return valid JSON conforming to the schema."
),
},
{
"role": "user",
"content": prompt_content,
},
]
# Call LLM with structured output
logger.debug(
f"Calling LLM with temperature={llm_temperature}, "
f"max_tokens={llm_max_tokens}"
)
response = await self.llm_client.response_structured(
messages=messages,
response_model=OntologyExtractionResponse,
)
logger.info(
f"LLM extraction successful - extracted {len(response.classes)} classes"
)
return response
except Exception as e:
logger.error(
f"LLM extraction failed: {str(e)}",
exc_info=True
)
raise
def _validate_and_clean(
self,
response: OntologyExtractionResponse,
max_description_length: int,
) -> OntologyExtractionResponse:
"""Perform first-layer validation: string validation and cleaning.
This method validates and cleans the extracted ontology classes:
1. Validate class names (PascalCase, no reserved words)
2. Sanitize invalid class names
3. Truncate long descriptions
4. Remove duplicate classes
Args:
response: OntologyExtractionResponse from LLM
max_description_length: Maximum description length
Returns:
Cleaned OntologyExtractionResponse
"""
if not response.classes:
logger.debug("No classes to validate")
return response
logger.debug(f"Validating {len(response.classes)} classes")
validated_classes = []
for ontology_class in response.classes:
# Validate class name
is_valid, error_msg = self.validator.validate_class_name(
ontology_class.name
)
if not is_valid:
logger.warning(
f"Invalid class name '{ontology_class.name}': {error_msg}"
)
# Attempt to sanitize
sanitized_name = self.validator.sanitize_class_name(
ontology_class.name
)
logger.info(
f"Sanitized class name: '{ontology_class.name}' -> '{sanitized_name}'"
)
# Update class name
ontology_class.name = sanitized_name
# Re-validate sanitized name
is_valid, error_msg = self.validator.validate_class_name(
sanitized_name
)
if not is_valid:
logger.error(
f"Failed to sanitize class name '{ontology_class.name}': {error_msg}. "
"Skipping this class."
)
continue
# Truncate description if too long
if ontology_class.description:
original_length = len(ontology_class.description)
ontology_class.description = self.validator.truncate_description(
ontology_class.description,
max_length=max_description_length,
)
if len(ontology_class.description) < original_length:
logger.debug(
f"Truncated description for '{ontology_class.name}': "
f"{original_length} -> {len(ontology_class.description)} chars"
)
validated_classes.append(ontology_class)
# Remove duplicates (case-insensitive)
original_count = len(validated_classes)
validated_classes = self.validator.remove_duplicates(validated_classes)
if len(validated_classes) < original_count:
logger.info(
f"Removed {original_count - len(validated_classes)} duplicate classes"
)
# Return cleaned response
return OntologyExtractionResponse(
classes=validated_classes,
domain=response.domain,
)
def _filter_invalid_classes(
self,
response: OntologyExtractionResponse,
errors: List[str],
) -> OntologyExtractionResponse:
"""Filter invalid classes based on OWL validation errors.
This method analyzes OWL validation errors and removes classes
that caused validation failures (e.g., circular inheritance,
inconsistencies).
Args:
response: OntologyExtractionResponse to filter
errors: List of error messages from OWL validation
Returns:
Filtered OntologyExtractionResponse
"""
if not errors:
return response
logger.debug(f"Filtering classes based on {len(errors)} OWL validation errors")
# Extract class names mentioned in errors
invalid_class_names = set()
for error in errors:
# Look for class names in error messages
for ontology_class in response.classes:
if ontology_class.name in error:
invalid_class_names.add(ontology_class.name)
logger.debug(
f"Class '{ontology_class.name}' marked as invalid due to error: {error}"
)
# Filter out invalid classes
if invalid_class_names:
original_count = len(response.classes)
filtered_classes = [
c for c in response.classes
if c.name not in invalid_class_names
]
logger.info(
f"Filtered out {original_count - len(filtered_classes)} invalid classes: "
f"{invalid_class_names}"
)
return OntologyExtractionResponse(
classes=filtered_classes,
domain=response.domain,
)
return response

View File

@@ -25,6 +25,15 @@ class TripletExtractor:
"""
self.llm_client = llm_client
def _get_language(self) -> str:
"""Get the configured language for entity descriptions
Returns:
Language code ("zh" or "en")
"""
from app.core.config import settings
return settings.DEFAULT_LANGUAGE
async def _extract_triplets(self, statement: Statement, chunk_content: str) -> TripletExtractionResponse:
"""Process a single statement and return extracted triplets and entities"""
# Render the prompt using helper function
@@ -40,7 +49,8 @@ class TripletExtractor:
statement=statement.statement,
chunk_content=chunk_content,
json_schema=TripletExtractionResponse.model_json_schema(),
predicate_instructions=PREDICATE_DEFINITIONS
predicate_instructions=PREDICATE_DEFINITIONS,
language=self._get_language()
)
# Create messages for LLM

View File

@@ -296,7 +296,9 @@ def resolve_alias_cycles(entities: List[Any], cycles: Dict[str, Set[str]]) -> Li
key=lambda eid: (
_strength_rank(eid),
len(getattr(entity_by_id.get(eid), 'description', '') or ''),
len(getattr(entity_by_id.get(eid), 'fact_summary', '') or '')
# TODO: fact_summary 功能暂时禁用,待后续开发完善后启用
# len(getattr(entity_by_id.get(eid), 'fact_summary', '') or '')
0 # 临时占位
),
reverse=True
)

View File

@@ -177,7 +177,7 @@ def render_entity_dedup_prompt(
# Args:
# entity_a: Dict of entity A attributes
async def render_triplet_extraction_prompt(statement: str, chunk_content: str, json_schema: dict, predicate_instructions: dict = None) -> str:
async def render_triplet_extraction_prompt(statement: str, chunk_content: str, json_schema: dict, predicate_instructions: dict = None, language: str = "zh") -> str:
"""
Renders the triplet extraction prompt using the extract_triplet.jinja2 template.
@@ -186,6 +186,7 @@ async def render_triplet_extraction_prompt(statement: str, chunk_content: str, j
chunk_content: The content of the chunk to process
json_schema: JSON schema for the expected output format
predicate_instructions: Optional predicate instructions
language: The language to use for entity descriptions ("zh" for Chinese, "en" for English)
Returns:
Rendered prompt content as string
@@ -195,7 +196,8 @@ async def render_triplet_extraction_prompt(statement: str, chunk_content: str, j
statement=statement,
chunk_content=chunk_content,
json_schema=json_schema,
predicate_instructions=predicate_instructions
predicate_instructions=predicate_instructions,
language=language
)
# 记录渲染结果到提示日志(与示例日志结构一致)
log_prompt_rendering('triplet extraction', rendered_prompt)
@@ -204,7 +206,8 @@ async def render_triplet_extraction_prompt(statement: str, chunk_content: str, j
'statement': 'str',
'chunk_content': 'str',
'json_schema': 'TripletExtractionResponse.schema',
'predicate_instructions': 'PREDICATE_DEFINITIONS'
'predicate_instructions': 'PREDICATE_DEFINITIONS',
'language': language
})
return rendered_prompt
@@ -213,6 +216,7 @@ async def render_memory_summary_prompt(
chunk_texts: str,
json_schema: dict,
max_words: int = 200,
language: str = "zh",
) -> str:
"""
Renders the memory summary prompt using the memory_summary.jinja2 template.
@@ -221,6 +225,7 @@ async def render_memory_summary_prompt(
chunk_texts: Concatenated text of conversation chunks
json_schema: JSON schema for the expected output format
max_words: Maximum words for the summary
language: The language to use for summary generation ("zh" for Chinese, "en" for English)
Returns:
Rendered prompt content as string.
@@ -230,12 +235,14 @@ async def render_memory_summary_prompt(
chunk_texts=chunk_texts,
json_schema=json_schema,
max_words=max_words,
language=language,
)
log_prompt_rendering('memory summary', rendered_prompt)
log_template_rendering('memory_summary.jinja2', {
'chunk_texts_len': len(chunk_texts or ""),
'max_words': max_words,
'json_schema': 'MemorySummaryResponse.schema'
'json_schema': 'MemorySummaryResponse.schema',
'language': language
})
return rendered_prompt
@@ -388,24 +395,65 @@ async def render_memory_insight_prompt(
return rendered_prompt
async def render_episodic_title_and_type_prompt(content: str) -> str:
async def render_episodic_title_and_type_prompt(content: str, language: str = "zh") -> str:
"""
Renders the episodic title and type classification prompt using the episodic_type_classification.jinja2 template.
Args:
content: The content of the episodic memory summary to analyze
language: The language to use for title generation ("zh" for Chinese, "en" for English)
Returns:
Rendered prompt content as string
"""
template = prompt_env.get_template("episodic_type_classification.jinja2")
rendered_prompt = template.render(content=content)
rendered_prompt = template.render(content=content, language=language)
# 记录渲染结果到提示日志
log_prompt_rendering('episodic title and type classification', rendered_prompt)
# 可选:记录模板渲染信息
log_template_rendering('episodic_type_classification.jinja2', {
'content_len': len(content) if content else 0
'content_len': len(content) if content else 0,
'language': language
})
return rendered_prompt
async def render_ontology_extraction_prompt(
scenario: str,
domain: str | None = None,
max_classes: int = 15,
json_schema: dict | None = None
) -> str:
"""
Renders the ontology extraction prompt using the extract_ontology.jinja2 template.
Args:
scenario: The scenario description text to extract ontology classes from
domain: Optional domain hint for the scenario (e.g., "Healthcare", "Education")
max_classes: Maximum number of classes to extract (default: 15)
json_schema: JSON schema for the expected output format
Returns:
Rendered prompt content as string
"""
template = prompt_env.get_template("extract_ontology.jinja2")
rendered_prompt = template.render(
scenario=scenario,
domain=domain,
max_classes=max_classes,
json_schema=json_schema
)
# 记录渲染结果到提示日志
log_prompt_rendering('ontology extraction', rendered_prompt)
# 可选:记录模板渲染信息
log_template_rendering('extract_ontology.jinja2', {
'scenario_len': len(scenario) if scenario else 0,
'domain': domain,
'max_classes': max_classes,
'json_schema': 'OntologyExtractionResponse.schema'
})
return rendered_prompt

View File

@@ -9,7 +9,8 @@
- 类型: "{{ entity_a.entity_type | default('') }}"
- 描述: "{{ entity_a.description | default('') }}"
- 别名: {{ entity_a.aliases | default([]) }}
- 摘要: "{{ entity_a.fact_summary | default('') }}"
{# TODO: fact_summary 功能暂时禁用,待后续开发完善后启用 #}
{# - 摘要: "{{ entity_a.fact_summary | default('') }}" #}
- 连接强弱: "{{ entity_a.connect_strength | default('') }}"
实体B:
@@ -17,7 +18,8 @@
- 类型: "{{ entity_b.entity_type | default('') }}"
- 描述: "{{ entity_b.description | default('') }}"
- 别名: {{ entity_b.aliases | default([]) }}
- 摘要: "{{ entity_b.fact_summary | default('') }}"
{# TODO: fact_summary 功能暂时禁用,待后续开发完善后启用 #}
{# - 摘要: "{{ entity_b.fact_summary | default('') }}" #}
- 连接强弱: "{{ entity_b.connect_strength | default('') }}"
上下文:

View File

@@ -1,8 +1,19 @@
=== Task ===
Generate a concise title and classify the episodic memory into the most appropriate category.
{% if language == "zh" %}
**重要:请使用中文生成标题和分类。**
{% else %}
**Important: Please generate the title and classification in English.**
{% endif %}
=== Requirements ===
- Extract a clear, concise title (10-20 characters) that captures the core content
{% if language == "zh" %}
- 标题必须使用中文
{% else %}
- Title must be in English
{% endif %}
- Classify into exactly one category based on the primary theme
- Be specific and avoid ambiguity
- Output must be valid JSON conforming to the schema below

View File

@@ -0,0 +1,210 @@
===Task===
Extract ontology classes from the given scenario description following ontology engineering standards.
===Role===
You are a professional ontology engineer with expertise in knowledge representation and OWL (Web Ontology Language) standards. Your task is to identify abstract classes and concepts from scenario descriptions, not concrete instances.
===Scenario Description===
{{ scenario }}
{% if domain -%}
===Domain Hint===
This scenario belongs to the **{{ domain }}** domain. Consider domain-specific concepts and terminology when extracting classes.
{%- endif %}
===Extraction Rules===
**1. Abstract Classes, Not Instances:**
- Extract abstract categories and concepts (e.g., "MedicalProcedure", "Patient", "Diagnosis")
- Do NOT extract concrete instances (e.g., "John Smith", "Room 301", "2024-01-15")
- Think in terms of "types of things" rather than "specific things"
**2. Naming Convention (PascalCase):**
- Use PascalCase format for the "name" field: start with uppercase letter, capitalize each word, no spaces
- Examples: "MedicalProcedure", "HealthcareProvider", "DiagnosticTest"
- Avoid: "medical procedure", "healthcare_provider", "diagnostic-test"
- Use clear, descriptive names in English
- Avoid abbreviations unless they are standard in the domain (e.g., "API", "DNA")
- Provide Chinese translation in the "name_chinese" field (e.g., "医疗程序", "医疗服务提供者", "诊断测试")
**3. Domain Relevance:**
- Focus on classes that are central to the scenario's domain
- Prioritize classes that represent key concepts, entities, or relationships
- Avoid overly generic classes (e.g., "Thing", "Object") unless they have specific domain meaning
**4. Class Quantity:**
- Extract between 5 and {{ max_classes }} classes
- Aim for a balanced set covering the main concepts in the scenario
- Quality over quantity: prefer well-defined classes over exhaustive lists
**5. Clear Descriptions:**
- Provide concise, informative descriptions in Chinese (max 500 characters)
- Describe what the class represents, not specific instances
- Use clear, natural Chinese language that explains the class's role in the domain
**6. Concrete Examples:**
- Provide 2-5 concrete instance examples in Chinese for each class
- Examples should be specific, realistic instances of the class
- Examples help clarify the class's scope and meaning
- Use natural Chinese language for examples
- Example format: ["示例1", "示例2", "示例3"]
**7. Class Hierarchy:**
- Identify parent-child relationships where applicable
- Use the parent_class field to specify inheritance
- Parent class must be one of the extracted classes or a standard OWL class
- Leave parent_class as null for top-level classes
**8. Entity Types:**
- Classify each class with an appropriate entity_type
- Common types: "Person", "Organization", "Location", "Event", "Concept", "Process", "Object", "Role"
- Choose the most specific type that applies
**9. OWL Reserved Words:**
- Do NOT use OWL reserved words as class names
- Reserved words include: "Thing", "Nothing", "Class", "Property", "ObjectProperty", "DatatypeProperty", "AnnotationProperty", "Ontology", "Individual", "Literal"
- If a reserved word is needed, add a domain-specific prefix (e.g., "MedicalClass" instead of "Class")
**10. Language Consistency:**
- Extract all class names in English (PascalCase format) for the "name" field
- Provide Chinese translation for class names in the "name_chinese" field
- Descriptions MUST be in Chinese (中文)
- Examples MUST be in Chinese (中文)
- Use clear, natural Chinese language for descriptions and examples
===Examples===
**Example 1 (Healthcare Domain):**
Scenario: "A hospital manages patient records, schedules appointments, and coordinates medical procedures. Doctors diagnose conditions and prescribe treatments."
Output:
{
"classes": [
{
"name": "Patient",
"name_chinese": "患者",
"description": "在医疗机构接受医疗护理或治疗的人",
"examples": ["张三", "李四", "患有糖尿病的老年患者"],
"parent_class": null,
"entity_type": "Person",
"domain": "Healthcare"
},
{
"name": "MedicalProcedure",
"name_chinese": "医疗程序",
"description": "为医疗诊断或治疗而执行的系统性操作流程",
"examples": ["手术", "血液检查", "X光检查", "疫苗接种"],
"parent_class": null,
"entity_type": "Process",
"domain": "Healthcare"
},
{
"name": "Diagnosis",
"name_chinese": "诊断",
"description": "基于症状和检查结果对疾病或状况的识别",
"examples": ["糖尿病诊断", "癌症诊断", "流感诊断"],
"parent_class": null,
"entity_type": "Concept",
"domain": "Healthcare"
},
{
"name": "Doctor",
"name_chinese": "医生",
"description": "诊断和治疗患者的持证医疗专业人员",
"examples": ["全科医生", "外科医生", "心脏病专家"],
"parent_class": null,
"entity_type": "Role",
"domain": "Healthcare"
},
{
"name": "Treatment",
"name_chinese": "治疗",
"description": "为治愈或管理疾病状况而提供的医疗护理或疗法",
"examples": ["药物治疗", "物理治疗", "化疗", "手术治疗"],
"parent_class": null,
"entity_type": "Process",
"domain": "Healthcare"
}
],
"domain": "Healthcare",
"namespace": "http://example.org/healthcare#"
}
**Example 2 (Education Domain):**
Scenario: "A university offers courses taught by professors. Students enroll in programs, attend lectures, and complete assignments to earn degrees."
Output:
{
"classes": [
{
"name": "Student",
"name_chinese": "学生",
"description": "在教育机构注册学习的人",
"examples": ["本科生", "研究生", "在职学生"],
"parent_class": null,
"entity_type": "Role",
"domain": "Education"
},
{
"name": "Course",
"name_chinese": "课程",
"description": "涵盖特定学科或主题的结构化教育课程",
"examples": ["计算机科学导论", "微积分I", "世界历史"],
"parent_class": null,
"entity_type": "Concept",
"domain": "Education"
},
{
"name": "Professor",
"name_chinese": "教授",
"description": "教授课程并进行研究的学术教师",
"examples": ["助理教授", "副教授", "正教授"],
"parent_class": null,
"entity_type": "Role",
"domain": "Education"
},
{
"name": "AcademicProgram",
"name_chinese": "学术项目",
"description": "通向学位或证书的结构化课程体系",
"examples": ["理学学士", "文学硕士", "博士项目"],
"parent_class": null,
"entity_type": "Concept",
"domain": "Education"
},
{
"name": "Assignment",
"name_chinese": "作业",
"description": "分配给学生以评估学习成果的任务或项目",
"examples": ["论文", "习题集", "研究报告", "实验报告"],
"parent_class": null,
"entity_type": "Object",
"domain": "Education"
},
{
"name": "Lecture",
"name_chinese": "讲座",
"description": "由教师进行的教育性演讲或讲座",
"examples": ["入门讲座", "客座讲座", "在线讲座"],
"parent_class": null,
"entity_type": "Event",
"domain": "Education"
}
],
"domain": "Education",
"namespace": "http://example.org/education#"
}
===Output Format===
**JSON Requirements:**
- Use only ASCII double quotes (") for JSON structure
- Never use Chinese quotation marks ("") or Unicode quotes
- Escape quotation marks in text with backslashes (\")
- Ensure proper string closure and comma separation
- No line breaks within JSON string values
- All class names must be in PascalCase format
- All class names must be unique (case-insensitive)
- Extract between 5 and {{ max_classes }} classes
{{ json_schema }}

View File

@@ -5,6 +5,12 @@
===Task===
Extract entities and knowledge triplets from the given statement.
{% if language == "zh" %}
**重要请使用中文生成实体描述description和示例example。**
{% else %}
**Important: Please generate entity descriptions and examples in English.**
{% endif %}
===Inputs===
**Chunk Content:** "{{ chunk_content }}"
**Statement:** "{{ statement }}"
@@ -13,6 +19,13 @@ Extract entities and knowledge triplets from the given statement.
**Entity Extraction:**
- Extract entities with their types, context-independent descriptions, **concise examples**, aliases, and semantic memory classification
{% if language == "zh" %}
- **实体描述description必须使用中文**
- **示例example必须使用中文**
{% else %}
- **Entity descriptions must be in English**
- **Examples must be in English**
{% endif %}
- **Semantic Memory Classification (is_explicit_memory):**
* Set to `true` if the entity represents **explicit/semantic memory**:
- **Concepts:** "Machine Learning", "Photosynthesis", "Democracy", "人工智能", "光合作用", "民主"
@@ -334,9 +347,11 @@ Output:
- Escape quotation marks in text with backslashes (\")
- Ensure proper string closure and comma separation
- No line breaks within JSON string values
- The output language should ALWAYS match the input language
- If input is in English, extract statements in English
- If input is in Chinese, extract statements in Chinese
{% if language == "zh" %}
- **语言要求实体描述description和示例example必须使用中文**
{% else %}
- **Language Requirement: Entity descriptions and examples must be in English**
{% endif %}
- Preserve the original language and do not translate
{{ json_schema }}

View File

@@ -5,10 +5,21 @@
=== Task ===
Summarize the provided conversation chunks into a concise Memory summary.
{% if language == "zh" %}
**重要:请使用中文生成摘要内容。**
{% else %}
**Important: Please generate the summary content in English.**
{% endif %}
=== Requirements ===
- Focus on factual statements, user preferences, relationships, and salient temporal context.
- Avoid repetition and filler; be specific.
- Keep it under {{ max_words or 200 }} words.
{% if language == "zh" %}
- 摘要内容必须使用中文
{% else %}
- Summary content must be in English
{% endif %}
- Output must be valid JSON conforming to the schema below.
=== Input ===
@@ -24,6 +35,11 @@ Summarize the provided conversation chunks into a concise Memory summary.
4. Do not include line breaks within JSON string values
5. Example of proper escaping: "statement": "张曼婷说:\"我很喜欢这本书。\""
The output language should always be the same as the input language.
{% if language == "zh" %}
**语言要求:输出内容必须使用中文。**
{% else %}
**Language Requirement: The output content must be in English.**
{% endif %}
Return only a list of extracted labelled statements in the JSON ARRAY of objects that match the schema below:
{{ json_schema }}

View File

@@ -0,0 +1,10 @@
"""Validation utilities for ontology extraction.
This module provides validation classes for ontology class names,
descriptions, and OWL compliance checking.
"""
from .ontology_validator import OntologyValidator
from .owl_validator import OWLValidator
__all__ = ['OntologyValidator', 'OWLValidator']

View File

@@ -0,0 +1,268 @@
"""String validation for ontology class names and descriptions.
This module provides the OntologyValidator class for validating and sanitizing
ontology class names according to OWL standards and naming conventions.
Classes:
OntologyValidator: Validates class names, removes duplicates, and truncates descriptions
"""
import logging
import re
from typing import List, Tuple
from app.core.memory.models.ontology_models import OntologyClass
logger = logging.getLogger(__name__)
class OntologyValidator:
"""Validator for ontology class names and descriptions.
This validator performs string-level validation including:
- PascalCase naming convention validation
- OWL reserved word checking
- Duplicate class name removal
- Description length truncation
Attributes:
OWL_RESERVED_WORDS: Set of OWL reserved words that cannot be used as class names
"""
# OWL reserved words that cannot be used as class names
OWL_RESERVED_WORDS = {
'Thing', 'Nothing', 'Class', 'Property',
'ObjectProperty', 'DatatypeProperty', 'FunctionalProperty',
'InverseFunctionalProperty', 'TransitiveProperty', 'SymmetricProperty',
'AsymmetricProperty', 'ReflexiveProperty', 'IrreflexiveProperty',
'Restriction', 'Ontology', 'Individual', 'NamedIndividual',
'Annotation', 'AnnotationProperty', 'Axiom',
'AllDifferent', 'AllDisjointClasses', 'AllDisjointProperties',
'Datatype', 'DataRange', 'Literal',
'DeprecatedClass', 'DeprecatedProperty',
'Imports', 'IncompatibleWith', 'PriorVersion', 'VersionInfo',
'BackwardCompatibleWith', 'OntologyProperty',
}
def validate_class_name(self, name: str) -> Tuple[bool, str]:
"""Validate that a class name follows OWL naming conventions.
Validation rules:
1. Must not be empty
2. Must start with an uppercase letter (PascalCase)
3. Cannot contain spaces
4. Can only contain alphanumeric characters and underscores
5. Cannot be an OWL reserved word
Args:
name: The class name to validate
Returns:
Tuple of (is_valid, error_message)
- is_valid: True if the name is valid, False otherwise
- error_message: Empty string if valid, error description if invalid
Examples:
>>> validator = OntologyValidator()
>>> validator.validate_class_name("MedicalProcedure")
(True, "")
>>> validator.validate_class_name("medical procedure")
(False, "Class name 'medical procedure' cannot contain spaces")
>>> validator.validate_class_name("Thing")
(False, "Class name 'Thing' is an OWL reserved word")
"""
logger.debug(f"Validating class name: '{name}'")
# Check if empty
if not name or not name.strip():
error_msg = "Class name cannot be empty"
logger.warning(f"Validation failed: {error_msg}")
return False, error_msg
name = name.strip()
# Check if it's an OWL reserved word
if name in self.OWL_RESERVED_WORDS:
error_msg = f"Class name '{name}' is an OWL reserved word"
logger.warning(f"Validation failed: {error_msg}")
return False, error_msg
# Check if starts with uppercase letter
if not name[0].isupper():
error_msg = f"Class name '{name}' must start with an uppercase letter (PascalCase)"
logger.warning(f"Validation failed: {error_msg}")
return False, error_msg
# Check for spaces
if ' ' in name:
error_msg = f"Class name '{name}' cannot contain spaces"
logger.warning(f"Validation failed: {error_msg}")
return False, error_msg
# Check for invalid characters (only alphanumeric and underscore allowed)
if not re.match(r'^[A-Za-z0-9_]+$', name):
error_msg = f"Class name '{name}' contains invalid characters. Only alphanumeric characters and underscores are allowed"
logger.warning(f"Validation failed: {error_msg}")
return False, error_msg
logger.debug(f"Class name '{name}' is valid")
return True, ""
def sanitize_class_name(self, name: str) -> str:
"""Attempt to sanitize an invalid class name into a valid format.
Sanitization steps:
1. Strip whitespace
2. Remove invalid characters
3. Replace spaces with empty string (PascalCase)
4. Capitalize first letter of each word
5. If result is empty or starts with number, prefix with 'Class'
Args:
name: The class name to sanitize
Returns:
Sanitized class name that should pass validation
Examples:
>>> validator = OntologyValidator()
>>> validator.sanitize_class_name("medical procedure")
'MedicalProcedure'
>>> validator.sanitize_class_name("patient-record")
'PatientRecord'
>>> validator.sanitize_class_name("123invalid")
'Class123Invalid'
"""
logger.debug(f"Sanitizing class name: '{name}'")
if not name or not name.strip():
logger.warning("Empty class name provided for sanitization, returning 'UnnamedClass'")
return "UnnamedClass"
# Strip whitespace
name = name.strip()
original_name = name
# Split on spaces, hyphens, and underscores, then capitalize each word
words = re.split(r'[\s\-_]+', name)
# Capitalize first letter of each word and keep rest as is
sanitized_words = []
for word in words:
if word:
# Remove non-alphanumeric characters except underscore
clean_word = re.sub(r'[^A-Za-z0-9_]', '', word)
if clean_word:
# Capitalize first letter
sanitized_words.append(clean_word[0].upper() + clean_word[1:])
# Join words
sanitized = ''.join(sanitized_words)
# If empty or starts with number, prefix with 'Class'
if not sanitized or sanitized[0].isdigit():
sanitized = 'Class' + sanitized
logger.info(f"Prefixed class name with 'Class': '{original_name}' -> '{sanitized}'")
# If it's a reserved word, append 'Class' suffix
if sanitized in self.OWL_RESERVED_WORDS:
sanitized = sanitized + 'Class'
logger.info(f"Appended 'Class' suffix to reserved word: '{original_name}' -> '{sanitized}'")
logger.info(f"Sanitized class name: '{original_name}' -> '{sanitized}'")
return sanitized
def remove_duplicates(self, classes: List[OntologyClass]) -> List[OntologyClass]:
"""Remove duplicate ontology classes based on case-insensitive name comparison.
When duplicates are found, keeps the first occurrence and discards subsequent ones.
Comparison is case-insensitive to catch variations like 'Patient' and 'patient'.
Args:
classes: List of OntologyClass objects
Returns:
List of OntologyClass objects with duplicates removed
Examples:
>>> validator = OntologyValidator()
>>> classes = [
... OntologyClass(name="Patient", description="A patient", entity_type="Person", domain="Healthcare"),
... OntologyClass(name="patient", description="Another patient", entity_type="Person", domain="Healthcare"),
... OntologyClass(name="Doctor", description="A doctor", entity_type="Person", domain="Healthcare"),
... ]
>>> unique = validator.remove_duplicates(classes)
>>> len(unique)
2
>>> [c.name for c in unique]
['Patient', 'Doctor']
"""
if not classes:
logger.debug("No classes to check for duplicates")
return classes
logger.debug(f"Checking {len(classes)} classes for duplicates")
seen_names = set()
unique_classes = []
duplicates_found = []
for ontology_class in classes:
# Use lowercase for comparison
name_lower = ontology_class.name.lower()
if name_lower not in seen_names:
seen_names.add(name_lower)
unique_classes.append(ontology_class)
else:
duplicates_found.append(ontology_class.name)
logger.debug(f"Duplicate class found and removed: '{ontology_class.name}'")
if duplicates_found:
logger.info(
f"Removed {len(duplicates_found)} duplicate classes: {duplicates_found}"
)
else:
logger.debug("No duplicate classes found")
return unique_classes
def truncate_description(self, description: str, max_length: int = 500) -> str:
"""Truncate a description to a maximum length.
If the description exceeds max_length, it will be truncated and
an ellipsis (...) will be appended to indicate truncation.
Args:
description: The description text to truncate
max_length: Maximum allowed length (default: 500)
Returns:
Truncated description string
Examples:
>>> validator = OntologyValidator()
>>> long_desc = "A" * 600
>>> truncated = validator.truncate_description(long_desc, max_length=500)
>>> len(truncated)
500
>>> truncated.endswith("...")
True
"""
if not description:
return ""
if len(description) <= max_length:
return description
# Truncate and add ellipsis
# Reserve 3 characters for "..."
truncate_at = max_length - 3
truncated = description[:truncate_at] + "..."
logger.debug(
f"Truncated description from {len(description)} to {len(truncated)} characters"
)
return truncated

View File

@@ -0,0 +1,585 @@
"""OWL semantic validation for ontology classes using Owlready2.
This module provides the OWLValidator class for validating ontology classes
against OWL standards using the Owlready2 library. It performs semantic
validation including consistency checking, circular inheritance detection,
and OWL file export.
Classes:
OWLValidator: Validates ontology classes using OWL reasoning and exports to OWL formats
"""
import logging
from typing import List, Optional, Tuple
from owlready2 import (
World,
Thing,
get_ontology,
sync_reasoner_pellet,
OwlReadyInconsistentOntologyError,
)
from app.core.memory.models.ontology_models import OntologyClass
logger = logging.getLogger(__name__)
class OWLValidator:
"""Validator for OWL semantic validation of ontology classes.
This validator performs semantic-level validation using Owlready2 including:
- Creating OWL classes from ontology class definitions
- Running consistency checking with Pellet reasoner
- Detecting circular inheritance
- Validating Protégé compatibility
- Exporting ontologies to various OWL formats (RDF/XML, Turtle, N-Triples)
Attributes:
base_namespace: Base URI for the ontology namespace
"""
def __init__(self, base_namespace: str = "http://example.org/ontology#"):
"""Initialize the OWL validator.
Args:
base_namespace: Base URI for the ontology namespace (default: http://example.org/ontology#)
"""
self.base_namespace = base_namespace
def validate_ontology_classes(
self,
classes: List[OntologyClass],
) -> Tuple[bool, List[str], Optional[World]]:
"""Validate extracted ontology classes against OWL standards.
This method creates an OWL ontology from the provided classes using Owlready2,
runs consistency checking with the Pellet reasoner, and detects common issues
like circular inheritance.
Args:
classes: List of OntologyClass objects to validate
Returns:
Tuple of (is_valid, error_messages, world):
- is_valid: True if ontology is valid and consistent, False otherwise
- error_messages: List of error/warning messages
- world: Owlready2 World object containing the ontology (None if validation failed)
Examples:
>>> validator = OWLValidator()
>>> classes = [
... OntologyClass(name="Patient", description="A patient", entity_type="Person", domain="Healthcare"),
... OntologyClass(name="Doctor", description="A doctor", entity_type="Person", domain="Healthcare"),
... ]
>>> is_valid, errors, world = validator.validate_ontology_classes(classes)
>>> is_valid
True
>>> len(errors)
0
"""
if not classes:
return False, ["No classes provided for validation"], None
errors = []
try:
# Create a new world (isolated ontology environment)
world = World()
# Use a proper ontology IRI
# Owlready2 expects the IRI to end with .owl or similar
onto_iri = self.base_namespace.rstrip('#/')
if not onto_iri.endswith('.owl'):
onto_iri = onto_iri + '.owl'
# Create ontology
onto = world.get_ontology(onto_iri)
with onto:
# Dictionary to store created OWL classes for parent reference
owl_classes = {}
# First pass: Create all classes without parent relationships
for ontology_class in classes:
try:
# Create OWL class dynamically using type() with Thing as base
# The key is to NOT set namespace in the dict, let Owlready2 handle it
owl_class = type(
ontology_class.name, # Class name
(Thing,), # Base classes
{} # Class dict (empty, let Owlready2 manage)
)
# Add label (rdfs:label) - include both English and Chinese names
labels = [ontology_class.name]
if ontology_class.name_chinese:
labels.append(ontology_class.name_chinese)
owl_class.label = labels
# Add comment (rdfs:comment) with description
if ontology_class.description:
owl_class.comment = [ontology_class.description]
# Store for parent relationship setup
owl_classes[ontology_class.name] = owl_class
logger.debug(
f"Created OWL class: {ontology_class.name} "
f"(Chinese: {ontology_class.name_chinese}) "
f"IRI: {owl_class.iri if hasattr(owl_class, 'iri') else 'N/A'}"
)
except Exception as e:
error_msg = f"Failed to create OWL class '{ontology_class.name}': {str(e)}"
errors.append(error_msg)
logger.error(error_msg, exc_info=True)
# Second pass: Set up parent relationships
for ontology_class in classes:
if ontology_class.parent_class and ontology_class.name in owl_classes:
parent_name = ontology_class.parent_class
# Check if parent exists
if parent_name in owl_classes:
try:
child_class = owl_classes[ontology_class.name]
parent_class = owl_classes[parent_name]
# Set parent by modifying is_a
child_class.is_a = [parent_class]
logger.debug(
f"Set parent relationship: {ontology_class.name} -> {parent_name}"
)
except Exception as e:
error_msg = (
f"Failed to set parent relationship "
f"'{ontology_class.name}' -> '{parent_name}': {str(e)}"
)
errors.append(error_msg)
logger.warning(error_msg)
else:
warning_msg = (
f"Parent class '{parent_name}' not found for '{ontology_class.name}'"
)
errors.append(warning_msg)
logger.warning(warning_msg)
# Check for circular inheritance
for class_name, owl_class in owl_classes.items():
if self._has_circular_inheritance(owl_class):
error_msg = f"Circular inheritance detected for class '{class_name}'"
errors.append(error_msg)
logger.error(error_msg)
# Run consistency checking with Pellet reasoner
try:
logger.info("Running Pellet reasoner for consistency checking...")
sync_reasoner_pellet(world, infer_property_values=True, infer_data_property_values=True)
logger.info("Consistency check passed")
except OwlReadyInconsistentOntologyError as e:
error_msg = f"Ontology is inconsistent: {str(e)}"
errors.append(error_msg)
logger.error(error_msg)
return False, errors, world
except Exception as e:
# Reasoner errors are often due to Java not being installed or configured
# Log as warning but don't fail validation - ontology structure is still valid
warning_msg = f"Reasoner check skipped: {str(e)}"
if str(e).strip(): # Only log if there's an actual error message
logger.warning(warning_msg)
else:
logger.warning("Reasoner check skipped: Java may not be installed or configured")
# Continue - ontology structure is valid even without reasoner check
# If we have errors (excluding warnings), validation failed
is_valid = len(errors) == 0
return is_valid, errors, world
except Exception as e:
error_msg = f"OWL validation failed: {str(e)}"
errors.append(error_msg)
logger.error(error_msg, exc_info=True)
return False, errors, None
def _has_circular_inheritance(self, owl_class) -> bool:
"""Check if an OWL class has circular inheritance.
Circular inheritance occurs when a class inherits from itself through
a chain of parent relationships (e.g., A -> B -> C -> A).
Args:
owl_class: Owlready2 class object to check
Returns:
True if circular inheritance is detected, False otherwise
"""
visited = set()
current = owl_class
while current:
# Get class IRI or name as identifier
class_id = str(current.iri) if hasattr(current, 'iri') else str(current)
if class_id in visited:
# Found a cycle
return True
visited.add(class_id)
# Get parent classes (is_a relationship)
parents = getattr(current, 'is_a', [])
# Filter out Thing and other base classes
parent_classes = [p for p in parents if p != Thing and hasattr(p, 'is_a')]
if not parent_classes:
# No more parents, no cycle
break
# Check first parent (in single inheritance)
current = parent_classes[0] if parent_classes else None
return False
def export_to_owl(
self,
world: World,
output_path: Optional[str] = None,
format: str = "rdfxml",
classes: Optional[List] = None
) -> str:
"""Export ontology to OWL file in specified format.
Supported formats:
- rdfxml: RDF/XML format (default, most compatible)
- turtle: Turtle format (more readable)
- ntriples: N-Triples format (simplest)
- json: JSON format (simplified, human-readable)
Args:
world: Owlready2 World object containing the ontology
output_path: Optional file path to save the ontology (if None, returns string)
format: Export format - "rdfxml", "turtle", "ntriples", or "json" (default: "rdfxml")
classes: Optional list of OntologyClass objects (required for json format)
Returns:
String representation of the exported ontology
Raises:
ValueError: If format is not supported
RuntimeError: If export fails
Examples:
>>> validator = OWLValidator()
>>> is_valid, errors, world = validator.validate_ontology_classes(classes)
>>> owl_content = validator.export_to_owl(world, "ontology.owl", format="rdfxml")
"""
# Validate format
valid_formats = ["rdfxml", "turtle", "ntriples", "json"]
if format not in valid_formats:
raise ValueError(
f"Unsupported format '{format}'. Must be one of: {', '.join(valid_formats)}"
)
# JSON format doesn't need OWL processing
if format == "json":
if not classes:
raise ValueError("Classes list is required for JSON format export")
return self._export_to_json(classes)
# For OWL formats, world is required
if not world:
raise ValueError("World object is None. Cannot export ontology.")
# Note: Owlready2 has issues with turtle format export
# We'll handle it specially by converting from rdfxml
use_conversion = (format == "turtle")
try:
# Get all ontologies in the world
ontologies = list(world.ontologies.values())
if not ontologies:
raise RuntimeError("No ontologies found in world")
# Find the ontology with classes (skip anonymous/empty ontologies)
onto = None
for ont in ontologies:
classes_count = len(list(ont.classes()))
logger.debug(f"Checking ontology {ont.base_iri}: {classes_count} classes")
if classes_count > 0:
onto = ont
break
# If no ontology with classes found, use the last non-anonymous one
if onto is None:
for ont in reversed(ontologies):
if ont.base_iri != "http://anonymous/":
onto = ont
break
# If still no ontology, use the first one
if onto is None:
onto = ontologies[0]
# Log ontology contents for debugging
logger.info(f"Ontology IRI: {onto.base_iri}")
logger.info(f"Ontology contains {len(list(onto.classes()))} classes")
# List all classes in the ontology
all_classes = list(onto.classes())
for cls in all_classes:
logger.info(f"Class in ontology: {cls.name} (IRI: {cls.iri})")
if hasattr(cls, 'label'):
logger.debug(f" Labels: {cls.label}")
if hasattr(cls, 'comment'):
logger.debug(f" Comments: {cls.comment}")
if len(all_classes) == 0:
logger.warning("No classes found in ontology! This may indicate a problem with class creation.")
if output_path:
# Save to file
export_format = "rdfxml" if use_conversion else format
logger.info(f"Exporting ontology to {output_path} in {export_format} format")
onto.save(file=output_path, format=export_format)
# Read back the file content to return
with open(output_path, 'r', encoding='utf-8') as f:
content = f.read()
# Convert to turtle if needed
if use_conversion:
content = self._convert_to_turtle(content)
logger.info(f"Successfully exported ontology to {output_path}")
# Format the content for better readability
content = self._format_owl_content(content, format)
return content
else:
# Export to string (save to temporary location and read)
import tempfile
import os
with tempfile.NamedTemporaryFile(mode='w', suffix='.owl', delete=False) as tmp:
tmp_path = tmp.name
try:
export_format = "rdfxml" if use_conversion else format
onto.save(file=tmp_path, format=export_format)
with open(tmp_path, 'r', encoding='utf-8') as f:
content = f.read()
# Convert to turtle if needed
if use_conversion:
content = self._convert_to_turtle(content)
# Format the content for better readability
content = self._format_owl_content(content, format)
return content
finally:
# Clean up temporary file
if os.path.exists(tmp_path):
os.remove(tmp_path)
except Exception as e:
error_msg = f"Failed to export ontology: {str(e)}"
logger.error(error_msg, exc_info=True)
raise RuntimeError(error_msg) from e
def _export_to_json(self, classes: List) -> str:
"""Export ontology classes to simplified JSON format.
This format is more compact and easier to parse than OWL XML.
Args:
classes: List of OntologyClass objects
Returns:
JSON string representation (compact format)
"""
import json
result = {
"ontology": {
"namespace": self.base_namespace,
"classes": []
}
}
for cls in classes:
class_data = {
"name": cls.name,
"name_chinese": cls.name_chinese,
"description": cls.description,
"entity_type": cls.entity_type,
"domain": cls.domain,
"parent_class": cls.parent_class,
"examples": cls.examples if hasattr(cls, 'examples') else []
}
result["ontology"]["classes"].append(class_data)
# 使用紧凑格式:无缩进,使用分隔符减少空格
return json.dumps(result, ensure_ascii=False, separators=(',', ':'))
def _convert_to_turtle(self, rdfxml_content: str) -> str:
"""Convert RDF/XML content to Turtle format using rdflib.
Args:
rdfxml_content: RDF/XML format content
Returns:
Turtle format content
"""
try:
from rdflib import Graph
# Parse RDF/XML
g = Graph()
g.parse(data=rdfxml_content, format="xml")
# Serialize to Turtle
turtle_content = g.serialize(format="turtle")
# Handle bytes vs string
if isinstance(turtle_content, bytes):
turtle_content = turtle_content.decode('utf-8')
return turtle_content
except ImportError:
logger.warning(
"rdflib is not installed. Cannot convert to Turtle format. "
"Install with: pip install rdflib"
)
return rdfxml_content
except Exception as e:
logger.error(f"Failed to convert to Turtle format: {e}")
return rdfxml_content
def _format_owl_content(self, content: str, format: str) -> str:
"""Format OWL content for better readability.
Args:
content: Raw OWL content string
format: Format type (rdfxml, turtle, ntriples)
Returns:
Formatted OWL content string
"""
if format == "rdfxml":
# Format XML with proper indentation
try:
import xml.dom.minidom as minidom
dom = minidom.parseString(content)
# Pretty print with 2-space indentation
formatted = dom.toprettyxml(indent=" ", encoding="utf-8").decode("utf-8")
# Remove extra blank lines
lines = []
prev_blank = False
for line in formatted.split('\n'):
is_blank = not line.strip()
if not (is_blank and prev_blank): # Skip consecutive blank lines
lines.append(line)
prev_blank = is_blank
formatted = '\n'.join(lines)
return formatted
except Exception as e:
logger.warning(f"Failed to format XML content: {e}")
return content
elif format == "turtle":
# Turtle format is already relatively readable
# Just ensure consistent line endings and not empty
if not content or content.strip() == "":
logger.warning("Turtle content is empty, this may indicate an export issue")
return content.strip() + '\n' if content.strip() else content
elif format == "ntriples":
# N-Triples format is line-based, ensure proper line endings
return content.strip() + '\n' if content.strip() else content
return content
def validate_with_protege_compatibility(
self,
classes: List[OntologyClass]
) -> Tuple[bool, List[str]]:
"""Validate that ontology classes are compatible with Protégé editor.
Protégé compatibility checks:
- Class names are valid OWL identifiers
- No special characters that Protégé cannot handle
- Namespace is properly formatted
- Labels and comments are properly encoded
Args:
classes: List of OntologyClass objects to validate
Returns:
Tuple of (is_compatible, warnings):
- is_compatible: True if compatible with Protégé, False otherwise
- warnings: List of compatibility warning messages
Examples:
>>> validator = OWLValidator()
>>> classes = [OntologyClass(name="Patient", description="A patient", entity_type="Person", domain="Healthcare")]
>>> is_compatible, warnings = validator.validate_with_protege_compatibility(classes)
>>> is_compatible
True
"""
warnings = []
# Check namespace format
if not self.base_namespace.startswith(('http://', 'https://')):
warnings.append(
f"Namespace '{self.base_namespace}' should start with http:// or https:// "
"for Protégé compatibility"
)
if not self.base_namespace.endswith(('#', '/')):
warnings.append(
f"Namespace '{self.base_namespace}' should end with # or / "
"for Protégé compatibility"
)
# Check each class
for ontology_class in classes:
# Check for special characters that might cause issues
if any(char in ontology_class.name for char in ['<', '>', '"', '{', '}', '|', '^', '`']):
warnings.append(
f"Class name '{ontology_class.name}' contains special characters "
"that may cause issues in Protégé"
)
# Check description length (Protégé can handle long descriptions but may display poorly)
if ontology_class.description and len(ontology_class.description) > 1000:
warnings.append(
f"Class '{ontology_class.name}' has a very long description ({len(ontology_class.description)} chars) "
"which may display poorly in Protégé"
)
# Check for non-ASCII characters (Protégé supports them but encoding issues may occur)
if not ontology_class.name.isascii():
warnings.append(
f"Class name '{ontology_class.name}' contains non-ASCII characters "
"which may cause encoding issues in some Protégé versions"
)
# If no warnings, it's compatible
is_compatible = len(warnings) == 0
return is_compatible, warnings