Compare commits
7 Commits
pref/promp
...
feature/me
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
3f9740412a | ||
|
|
6b68ee9fc8 | ||
|
|
e53be0765a | ||
|
|
3743188eec | ||
|
|
71e6bea2b8 | ||
|
|
461674c8d8 | ||
|
|
c59e179cc2 |
@@ -158,12 +158,19 @@ class RedisTaskScheduler:
|
|||||||
return {"status": status, "task_id": task_id, "result": result_content}
|
return {"status": status, "task_id": task_id, "result": result_content}
|
||||||
|
|
||||||
def _cleanup_finished(self):
|
def _cleanup_finished(self):
|
||||||
pending = self.redis.hgetall(PENDING_HASH)
|
cursor = 0
|
||||||
if not pending:
|
all_pending = {}
|
||||||
|
while True:
|
||||||
|
cursor, batch = self.redis.hscan(PENDING_HASH, cursor=cursor, count=100)
|
||||||
|
all_pending.update(batch)
|
||||||
|
if cursor == 0:
|
||||||
|
break
|
||||||
|
|
||||||
|
if not all_pending:
|
||||||
return
|
return
|
||||||
|
|
||||||
now = time.time()
|
now = time.time()
|
||||||
task_ids = list(pending.keys())
|
task_ids = list(all_pending.keys())
|
||||||
|
|
||||||
pipe = self.redis.pipeline()
|
pipe = self.redis.pipeline()
|
||||||
for task_id in task_ids:
|
for task_id in task_ids:
|
||||||
@@ -176,7 +183,7 @@ class RedisTaskScheduler:
|
|||||||
|
|
||||||
for task_id, raw_result in zip(task_ids, results):
|
for task_id, raw_result in zip(task_ids, results):
|
||||||
try:
|
try:
|
||||||
meta = json.loads(pending[task_id])
|
meta = json.loads(all_pending[task_id])
|
||||||
lock_key = meta["lock_key"]
|
lock_key = meta["lock_key"]
|
||||||
dispatched_at = meta.get("dispatched_at", 0)
|
dispatched_at = meta.get("dispatched_at", 0)
|
||||||
age = now - dispatched_at
|
age = now - dispatched_at
|
||||||
@@ -276,6 +283,22 @@ class RedisTaskScheduler:
|
|||||||
return True
|
return True
|
||||||
return stable_hash(user_id) % self._shard_count == self._shard_index
|
return stable_hash(user_id) % self._shard_count == self._shard_index
|
||||||
|
|
||||||
|
def _commit_post_dispatch(self, lock_key, task, msg_id, dispatch_lock):
|
||||||
|
pipe = self.redis.pipeline()
|
||||||
|
pipe.set(lock_key, task.id, ex=3600)
|
||||||
|
pipe.hset(PENDING_HASH, task.id, json.dumps({
|
||||||
|
"lock_key": lock_key,
|
||||||
|
"dispatched_at": time.time(),
|
||||||
|
"msg_id": msg_id,
|
||||||
|
}))
|
||||||
|
pipe.delete(dispatch_lock)
|
||||||
|
pipe.set(
|
||||||
|
f"task_tracker:{msg_id}",
|
||||||
|
json.dumps({"status": "DISPATCHED", "task_id": task.id}),
|
||||||
|
ex=86400,
|
||||||
|
)
|
||||||
|
pipe.execute()
|
||||||
|
|
||||||
def _dispatch(self, msg_id, msg_data) -> bool:
|
def _dispatch(self, msg_id, msg_data) -> bool:
|
||||||
user_id = msg_data["user_id"]
|
user_id = msg_data["user_id"]
|
||||||
task_name = msg_data["task_name"]
|
task_name = msg_data["task_name"]
|
||||||
@@ -308,28 +331,17 @@ class RedisTaskScheduler:
|
|||||||
task_name, user_id, msg_id, e, exc_info=True,
|
task_name, user_id, msg_id, e, exc_info=True,
|
||||||
)
|
)
|
||||||
return False
|
return False
|
||||||
|
for attempt in range(2):
|
||||||
try:
|
try:
|
||||||
pipe = self.redis.pipeline()
|
self._commit_post_dispatch(lock_key, task, msg_id, dispatch_lock)
|
||||||
pipe.set(lock_key, task.id, ex=3600)
|
break
|
||||||
pipe.hset(PENDING_HASH, task.id, json.dumps({
|
except Exception as e:
|
||||||
"lock_key": lock_key,
|
logger.error(
|
||||||
"dispatched_at": time.time(),
|
"Post-dispatch state update failed for %s: %s",
|
||||||
"msg_id": msg_id,
|
task.id, e, exc_info=True,
|
||||||
}))
|
)
|
||||||
pipe.delete(dispatch_lock)
|
time.sleep(0.1)
|
||||||
pipe.set(
|
self.errors += 1
|
||||||
f"task_tracker:{msg_id}",
|
|
||||||
json.dumps({"status": "DISPATCHED", "task_id": task.id}),
|
|
||||||
ex=86400,
|
|
||||||
)
|
|
||||||
pipe.execute()
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(
|
|
||||||
"Post-dispatch state update failed for %s: %s",
|
|
||||||
task.id, e, exc_info=True,
|
|
||||||
)
|
|
||||||
self.errors += 1
|
|
||||||
|
|
||||||
self.dispatched += 1
|
self.dispatched += 1
|
||||||
logger.info("Task dispatched: %s (msg=%s)", task.id, msg_id)
|
logger.info("Task dispatched: %s (msg=%s)", task.id, msg_id)
|
||||||
@@ -367,22 +379,21 @@ class RedisTaskScheduler:
|
|||||||
return
|
return
|
||||||
|
|
||||||
for uid, msg in candidates:
|
for uid, msg in candidates:
|
||||||
|
queue_key = f"{USER_QUEUE_PREFIX}{uid}"
|
||||||
if self._dispatch(msg["msg_id"], msg):
|
if self._dispatch(msg["msg_id"], msg):
|
||||||
self.redis.lpop(f"{USER_QUEUE_PREFIX}{uid}")
|
self.redis.lpop(queue_key)
|
||||||
|
if self.redis.llen(queue_key) > 0:
|
||||||
|
self.redis.sadd(READY_SET, uid)
|
||||||
|
|
||||||
def schedule_loop(self):
|
def schedule_loop(self):
|
||||||
self._heartbeat()
|
self._heartbeat()
|
||||||
self._cleanup_finished()
|
self._cleanup_finished()
|
||||||
|
|
||||||
pipe = self.redis.pipeline()
|
ready_users = self.redis.smembers(READY_SET) or set()
|
||||||
pipe.smembers(READY_SET)
|
|
||||||
pipe.delete(READY_SET)
|
|
||||||
results = pipe.execute()
|
|
||||||
ready_users = results[0] or set()
|
|
||||||
|
|
||||||
my_users = [uid for uid in ready_users if self._is_mine(uid)]
|
my_users = [uid for uid in ready_users if self._is_mine(uid)]
|
||||||
|
if my_users:
|
||||||
if not my_users:
|
self.redis.srem(READY_SET, *my_users)
|
||||||
|
else:
|
||||||
time.sleep(0.5)
|
time.sleep(0.5)
|
||||||
return
|
return
|
||||||
|
|
||||||
@@ -445,7 +456,7 @@ class RedisTaskScheduler:
|
|||||||
"Scheduler started: instance=%s", self.instance_id,
|
"Scheduler started: instance=%s", self.instance_id,
|
||||||
)
|
)
|
||||||
|
|
||||||
while True:
|
while self.running:
|
||||||
try:
|
try:
|
||||||
self.schedule_loop()
|
self.schedule_loop()
|
||||||
|
|
||||||
@@ -480,9 +491,7 @@ class RedisTaskScheduler:
|
|||||||
logger.error("Shutdown cleanup error: %s", e)
|
logger.error("Shutdown cleanup error: %s", e)
|
||||||
|
|
||||||
|
|
||||||
scheduler: RedisTaskScheduler | None = None
|
scheduler = RedisTaskScheduler()
|
||||||
if scheduler is None:
|
|
||||||
scheduler = RedisTaskScheduler()
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
import signal
|
import signal
|
||||||
|
|||||||
@@ -27,6 +27,7 @@ from app.services import task_service, workspace_service
|
|||||||
from app.services.memory_agent_service import MemoryAgentService
|
from app.services.memory_agent_service import MemoryAgentService
|
||||||
from app.services.memory_agent_service import get_end_user_connected_config as get_config
|
from app.services.memory_agent_service import get_end_user_connected_config as get_config
|
||||||
from app.services.model_service import ModelConfigService
|
from app.services.model_service import ModelConfigService
|
||||||
|
from app.utils.tmp_session import ChatSessionCache
|
||||||
|
|
||||||
load_dotenv()
|
load_dotenv()
|
||||||
api_logger = get_api_logger()
|
api_logger = get_api_logger()
|
||||||
@@ -300,60 +301,39 @@ async def read_server(
|
|||||||
if knowledge:
|
if knowledge:
|
||||||
user_rag_memory_id = str(knowledge.id)
|
user_rag_memory_id = str(knowledge.id)
|
||||||
|
|
||||||
|
session_id = user_input.session_id.hex
|
||||||
|
|
||||||
api_logger.info(
|
api_logger.info(
|
||||||
f"Read service: group={user_input.end_user_id}, storage_type={storage_type}, user_rag_memory_id={user_rag_memory_id}, workspace_id={workspace_id}")
|
f"Read service: group={user_input.end_user_id}, storage_type={storage_type}, user_rag_memory_id={user_rag_memory_id}, workspace_id={workspace_id}, session_id={session_id}")
|
||||||
try:
|
try:
|
||||||
# result = await memory_agent_service.read_memory(
|
|
||||||
# user_input.end_user_id,
|
|
||||||
# user_input.message,
|
|
||||||
# user_input.history,
|
|
||||||
# user_input.search_switch,
|
|
||||||
# config_id,
|
|
||||||
# db,
|
|
||||||
# storage_type,
|
|
||||||
# user_rag_memory_id
|
|
||||||
# )
|
|
||||||
# if str(user_input.search_switch) == "2":
|
|
||||||
# retrieve_info = result['answer']
|
|
||||||
# history = await SessionService(store).get_history(user_input.end_user_id, user_input.end_user_id,
|
|
||||||
# user_input.end_user_id)
|
|
||||||
# query = user_input.message
|
|
||||||
#
|
|
||||||
# # 调用 memory_agent_service 的方法生成最终答案
|
|
||||||
# result['answer'] = await memory_agent_service.generate_summary_from_retrieve(
|
|
||||||
# end_user_id=user_input.end_user_id,
|
|
||||||
# retrieve_info=retrieve_info,
|
|
||||||
# history=history,
|
|
||||||
# query=query,
|
|
||||||
# config_id=config_id,
|
|
||||||
# db=db
|
|
||||||
# )
|
|
||||||
# if "信息不足,无法回答" in result['answer']:
|
|
||||||
# result['answer'] = retrieve_info
|
|
||||||
memory_config = get_config(user_input.end_user_id, db)
|
memory_config = get_config(user_input.end_user_id, db)
|
||||||
service = MemoryService(
|
service = MemoryService(
|
||||||
db,
|
db,
|
||||||
memory_config["memory_config_id"],
|
memory_config["memory_config_id"],
|
||||||
end_user_id=user_input.end_user_id
|
end_user_id=user_input.end_user_id
|
||||||
)
|
)
|
||||||
|
session_cache = ChatSessionCache(session_id)
|
||||||
search_result = await service.read(
|
search_result = await service.read(
|
||||||
user_input.message,
|
user_input.message,
|
||||||
SearchStrategy(user_input.search_switch)
|
SearchStrategy(user_input.search_switch),
|
||||||
|
history=await session_cache.get_history(),
|
||||||
)
|
)
|
||||||
intermediate_outputs = []
|
intermediate_outputs = []
|
||||||
sub_queries = set()
|
sub_queries = set()
|
||||||
for memory in search_result.memories:
|
for memory in search_result.memories:
|
||||||
sub_queries.add(str(memory.query))
|
sub_queries.add(str(memory.query))
|
||||||
|
idx = 0
|
||||||
if user_input.search_switch in [SearchStrategy.DEEP, SearchStrategy.NORMAL]:
|
if user_input.search_switch in [SearchStrategy.DEEP, SearchStrategy.NORMAL]:
|
||||||
intermediate_outputs.append({
|
intermediate_outputs.append({
|
||||||
"type": "problem_split",
|
"type": "problem_split",
|
||||||
"title": "问题拆分",
|
"title": "问题拆分",
|
||||||
"data": [
|
"data": [
|
||||||
{
|
{
|
||||||
"id": f"Q{idx+1}",
|
"id": f"Q{(idx := idx + 1)}",
|
||||||
"question": question
|
"question": question
|
||||||
}
|
}
|
||||||
for idx, question in enumerate(sub_queries)
|
for question in sub_queries
|
||||||
|
if question
|
||||||
]
|
]
|
||||||
})
|
})
|
||||||
perceptual_data = [
|
perceptual_data = [
|
||||||
@@ -375,16 +355,24 @@ async def read_server(
|
|||||||
"raw_result": search_result.memories,
|
"raw_result": search_result.memories,
|
||||||
"total": len(search_result.memories),
|
"total": len(search_result.memories),
|
||||||
})
|
})
|
||||||
|
answer = await memory_agent_service.generate_summary_from_retrieve(
|
||||||
|
end_user_id=user_input.end_user_id,
|
||||||
|
retrieve_info=search_result.content,
|
||||||
|
history=[],
|
||||||
|
query=user_input.message,
|
||||||
|
config_id=config_id,
|
||||||
|
db=db
|
||||||
|
)
|
||||||
|
await session_cache.append_many(
|
||||||
|
[
|
||||||
|
{"role": "user", "content": user_input.message},
|
||||||
|
{"role": "assistant", "content": answer}
|
||||||
|
]
|
||||||
|
)
|
||||||
result = {
|
result = {
|
||||||
'answer': await memory_agent_service.generate_summary_from_retrieve(
|
'answer': answer,
|
||||||
end_user_id=user_input.end_user_id,
|
"intermediate_outputs": intermediate_outputs,
|
||||||
retrieve_info=search_result.content,
|
"session_id": session_id,
|
||||||
history=[],
|
|
||||||
query=user_input.message,
|
|
||||||
config_id=config_id,
|
|
||||||
db=db
|
|
||||||
),
|
|
||||||
"intermediate_outputs": intermediate_outputs
|
|
||||||
}
|
}
|
||||||
|
|
||||||
return success(data=result, msg="回复对话消息成功")
|
return success(data=result, msg="回复对话消息成功")
|
||||||
@@ -480,9 +468,11 @@ async def read_server_async(
|
|||||||
if knowledge: user_rag_memory_id = str(knowledge.id)
|
if knowledge: user_rag_memory_id = str(knowledge.id)
|
||||||
api_logger.info(f"Async read: storage_type={storage_type}, user_rag_memory_id={user_rag_memory_id}")
|
api_logger.info(f"Async read: storage_type={storage_type}, user_rag_memory_id={user_rag_memory_id}")
|
||||||
try:
|
try:
|
||||||
|
session_id = user_input.session_id.hex
|
||||||
|
session_cache = ChatSessionCache(session_id)
|
||||||
task = celery_app.send_task(
|
task = celery_app.send_task(
|
||||||
"app.core.memory.agent.read_message",
|
"app.core.memory.agent.read_message",
|
||||||
args=[user_input.end_user_id, user_input.message, user_input.history, user_input.search_switch,
|
args=[user_input.end_user_id, user_input.message, await session_cache.get_history(), user_input.search_switch,
|
||||||
config_id, storage_type, user_rag_memory_id]
|
config_id, storage_type, user_rag_memory_id]
|
||||||
)
|
)
|
||||||
api_logger.info(f"Read task queued: {task.id}")
|
api_logger.info(f"Read task queued: {task.id}")
|
||||||
|
|||||||
@@ -43,10 +43,13 @@ class MemoryService:
|
|||||||
self,
|
self,
|
||||||
query: str,
|
query: str,
|
||||||
search_switch: SearchStrategy,
|
search_switch: SearchStrategy,
|
||||||
|
history: list | None = None,
|
||||||
limit: int = 10,
|
limit: int = 10,
|
||||||
) -> MemorySearchResult:
|
) -> MemorySearchResult:
|
||||||
|
if history is None:
|
||||||
|
history = []
|
||||||
with get_db_context() as db:
|
with get_db_context() as db:
|
||||||
return await ReadPipeLine(self.ctx, db).run(query, search_switch, limit)
|
return await ReadPipeLine(self.ctx, db).run(query, search_switch, history, limit)
|
||||||
|
|
||||||
async def forget(self, max_batch: int = 100, min_days: int = 30) -> dict:
|
async def forget(self, max_batch: int = 100, min_days: int = 30) -> dict:
|
||||||
raise NotImplementedError
|
raise NotImplementedError
|
||||||
|
|||||||
@@ -32,10 +32,12 @@ class Memory(BaseModel):
|
|||||||
|
|
||||||
class MemorySearchResult(BaseModel):
|
class MemorySearchResult(BaseModel):
|
||||||
memories: list[Memory]
|
memories: list[Memory]
|
||||||
|
content_str: str = Field(default="")
|
||||||
|
|
||||||
@computed_field
|
|
||||||
@property
|
@property
|
||||||
def content(self) -> str:
|
def content(self) -> str:
|
||||||
|
if self.content_str:
|
||||||
|
return self.content_str
|
||||||
return "\n".join([memory.content for memory in self.memories])
|
return "\n".join([memory.content for memory in self.memories])
|
||||||
|
|
||||||
@computed_field
|
@computed_field
|
||||||
|
|||||||
@@ -1,8 +1,9 @@
|
|||||||
from app.core.memory.enums import SearchStrategy, StorageType
|
from app.core.memory.enums import SearchStrategy, StorageType
|
||||||
from app.core.memory.models.service_models import MemorySearchResult
|
from app.core.memory.models.service_models import MemorySearchResult
|
||||||
from app.core.memory.pipelines.base_pipeline import ModelClientMixin, DBRequiredPipeline
|
from app.core.memory.pipelines.base_pipeline import ModelClientMixin, DBRequiredPipeline
|
||||||
from app.core.memory.read_services.search_engine.content_search import Neo4jSearchService, RAGSearchService
|
|
||||||
from app.core.memory.read_services.generate_engine.query_preprocessor import QueryPreprocessor
|
from app.core.memory.read_services.generate_engine.query_preprocessor import QueryPreprocessor
|
||||||
|
from app.core.memory.read_services.generate_engine.retrieval_summary import RetrievalSummaryProcessor
|
||||||
|
from app.core.memory.read_services.search_engine.content_search import Neo4jSearchService, RAGSearchService
|
||||||
|
|
||||||
|
|
||||||
class ReadPipeLine(ModelClientMixin, DBRequiredPipeline):
|
class ReadPipeLine(ModelClientMixin, DBRequiredPipeline):
|
||||||
@@ -10,20 +11,30 @@ class ReadPipeLine(ModelClientMixin, DBRequiredPipeline):
|
|||||||
self,
|
self,
|
||||||
query: str,
|
query: str,
|
||||||
search_switch: SearchStrategy,
|
search_switch: SearchStrategy,
|
||||||
|
history: list,
|
||||||
limit: int = 10,
|
limit: int = 10,
|
||||||
includes=None
|
includes=None
|
||||||
) -> MemorySearchResult:
|
) -> MemorySearchResult:
|
||||||
|
memory_l0 = None
|
||||||
|
if self.ctx.storage_type == StorageType.NEO4J:
|
||||||
|
memory_l0 = await self._get_search_service(includes).memory_l0()
|
||||||
|
|
||||||
query = QueryPreprocessor.process(query)
|
query = QueryPreprocessor.process(query)
|
||||||
match search_switch:
|
match search_switch:
|
||||||
case SearchStrategy.DEEP:
|
case SearchStrategy.DEEP:
|
||||||
return await self._deep_read(query, limit, includes)
|
res = await self._deep_read(query, history, limit, includes)
|
||||||
case SearchStrategy.NORMAL:
|
case SearchStrategy.NORMAL:
|
||||||
return await self._normal_read(query, limit, includes)
|
res = await self._normal_read(query, history, limit, includes)
|
||||||
case SearchStrategy.QUICK:
|
case SearchStrategy.QUICK:
|
||||||
return await self._quick_read(query, limit, includes)
|
res = await self._quick_read(query, limit, includes)
|
||||||
case _:
|
case _:
|
||||||
raise RuntimeError("Unsupported search strategy")
|
raise RuntimeError("Unsupported search strategy")
|
||||||
|
|
||||||
|
if memory_l0 is not None:
|
||||||
|
res.content_str = memory_l0.content + '\n' + res.content
|
||||||
|
res.memories.insert(0, memory_l0)
|
||||||
|
return res
|
||||||
|
|
||||||
def _get_search_service(self, includes=None):
|
def _get_search_service(self, includes=None):
|
||||||
if self.ctx.storage_type == StorageType.NEO4J:
|
if self.ctx.storage_type == StorageType.NEO4J:
|
||||||
return Neo4jSearchService(
|
return Neo4jSearchService(
|
||||||
@@ -37,10 +48,11 @@ class ReadPipeLine(ModelClientMixin, DBRequiredPipeline):
|
|||||||
self.db
|
self.db
|
||||||
)
|
)
|
||||||
|
|
||||||
async def _deep_read(self, query: str, limit: int, includes=None) -> MemorySearchResult:
|
async def _deep_read(self, query: str, history: list, limit: int, includes=None) -> MemorySearchResult:
|
||||||
search_service = self._get_search_service(includes)
|
search_service = self._get_search_service(includes)
|
||||||
questions = await QueryPreprocessor.split(
|
questions = await QueryPreprocessor.split(
|
||||||
query,
|
query,
|
||||||
|
history,
|
||||||
self.get_llm_client(self.db, self.ctx.memory_config.llm_model_id)
|
self.get_llm_client(self.db, self.ctx.memory_config.llm_model_id)
|
||||||
)
|
)
|
||||||
query_results = []
|
query_results = []
|
||||||
@@ -49,12 +61,18 @@ class ReadPipeLine(ModelClientMixin, DBRequiredPipeline):
|
|||||||
query_results.append(search_results)
|
query_results.append(search_results)
|
||||||
results = sum(query_results, start=MemorySearchResult(memories=[]))
|
results = sum(query_results, start=MemorySearchResult(memories=[]))
|
||||||
results.memories.sort(key=lambda x: x.score, reverse=True)
|
results.memories.sort(key=lambda x: x.score, reverse=True)
|
||||||
|
results.content_str = await RetrievalSummaryProcessor.summary(
|
||||||
|
query,
|
||||||
|
results.content,
|
||||||
|
self.get_llm_client(self.db, self.ctx.memory_config.llm_model_id)
|
||||||
|
)
|
||||||
return results
|
return results
|
||||||
|
|
||||||
async def _normal_read(self, query: str, limit: int, includes=None) -> MemorySearchResult:
|
async def _normal_read(self, query: str, history: list, limit: int, includes=None) -> MemorySearchResult:
|
||||||
search_service = self._get_search_service(includes)
|
search_service = self._get_search_service(includes)
|
||||||
questions = await QueryPreprocessor.split(
|
questions = await QueryPreprocessor.split(
|
||||||
query,
|
query,
|
||||||
|
history,
|
||||||
self.get_llm_client(self.db, self.ctx.memory_config.llm_model_id)
|
self.get_llm_client(self.db, self.ctx.memory_config.llm_model_id)
|
||||||
)
|
)
|
||||||
query_results = []
|
query_results = []
|
||||||
@@ -63,6 +81,11 @@ class ReadPipeLine(ModelClientMixin, DBRequiredPipeline):
|
|||||||
query_results.append(search_results)
|
query_results.append(search_results)
|
||||||
results = sum(query_results, start=MemorySearchResult(memories=[]))
|
results = sum(query_results, start=MemorySearchResult(memories=[]))
|
||||||
results.memories.sort(key=lambda x: x.score, reverse=True)
|
results.memories.sort(key=lambda x: x.score, reverse=True)
|
||||||
|
results.content_str = await RetrievalSummaryProcessor.summary(
|
||||||
|
query,
|
||||||
|
results.content,
|
||||||
|
self.get_llm_client(self.db, self.ctx.memory_config.llm_model_id)
|
||||||
|
)
|
||||||
return results
|
return results
|
||||||
|
|
||||||
async def _quick_read(self, query: str, limit: int, includes=None) -> MemorySearchResult:
|
async def _quick_read(self, query: str, limit: int, includes=None) -> MemorySearchResult:
|
||||||
|
|||||||
15
api/app/core/memory/prompt/retrieval_summary.jinja2
Normal file
15
api/app/core/memory/prompt/retrieval_summary.jinja2
Normal file
@@ -0,0 +1,15 @@
|
|||||||
|
You are a Content Condenser for a memory-augmented retrieval system.
|
||||||
|
|
||||||
|
Your task is to compress the retrieved content while preserving all information that is highly relevant to the user’s query.
|
||||||
|
|
||||||
|
Guidelines:
|
||||||
|
|
||||||
|
Focus only on content related to the query; ignore irrelevant parts.
|
||||||
|
Remove redundancy, filler, or repeated information only for non-XML content.
|
||||||
|
Preserve all factual details: names, dates, decisions, code snippets, technical details.
|
||||||
|
If relevant information is inside XML tags, do not remove, merge, or compress the XML tags or their internal text; keep them fully intact.
|
||||||
|
Structure multiple relevant points as a compact bullet list or paragraph, depending on density.
|
||||||
|
If no content is relevant, return exactly: "No relevant information found."
|
||||||
|
Do not add any knowledge or facts not in the retrieved content.
|
||||||
|
# [IMPORTANT] OUTPUT ONLY THE CONDENSED CONTENT, DO NOT ATTEMPT TO ANSWER THE QUERY.
|
||||||
|
# [IMPORTANT] DO NOT REMOVE OR PARAPHRASE HIGHLY RELEVANT INFORMATION.
|
||||||
@@ -21,14 +21,14 @@ class QueryPreprocessor:
|
|||||||
return text
|
return text
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
async def split(query: str, llm_client: RedBearLLM):
|
async def split(query: str, history: list, llm_client: RedBearLLM):
|
||||||
system_prompt = prompt_manager.render(
|
system_prompt = prompt_manager.render(
|
||||||
name="problem_split",
|
name="problem_split",
|
||||||
datetime=datetime.now().strftime("%Y-%m-%d"),
|
datetime=datetime.now().strftime("%Y-%m-%d"),
|
||||||
)
|
)
|
||||||
messages = [
|
messages = [
|
||||||
{"role": "system", "content": system_prompt},
|
{"role": "system", "content": system_prompt},
|
||||||
{"role": "user", "content": query},
|
{"role": "user", "content": f"<history>{history}</history><query>{query}</query>"},
|
||||||
]
|
]
|
||||||
try:
|
try:
|
||||||
sub_queries = await llm_client.ainvoke(messages) | StructResponse(mode='json')
|
sub_queries = await llm_client.ainvoke(messages) | StructResponse(mode='json')
|
||||||
|
|||||||
@@ -1,11 +1,29 @@
|
|||||||
|
import logging
|
||||||
|
|
||||||
from app.core.models import RedBearLLM
|
from app.core.models import RedBearLLM
|
||||||
|
from app.core.memory.prompt import prompt_manager
|
||||||
|
from app.core.memory.utils.llm.llm_utils import StructResponse
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
class RetrievalSummaryProcessor:
|
class RetrievalSummaryProcessor:
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def summary(content: str, llm_client: RedBearLLM):
|
async def summary(query, content: str, llm_client: RedBearLLM):
|
||||||
return
|
system_prompt = prompt_manager.render(
|
||||||
|
name="retrieval_summary"
|
||||||
|
)
|
||||||
|
messages = [
|
||||||
|
{"role": "system", "content": system_prompt},
|
||||||
|
{"role": "user", "content": f"<query>{query}</query><content>{content}</content>"},
|
||||||
|
]
|
||||||
|
try:
|
||||||
|
summary = await llm_client.ainvoke(messages) | StructResponse(mode='str')
|
||||||
|
return summary
|
||||||
|
except:
|
||||||
|
logger.error("Failed to generate reply summary, returning original content", exc_info=True)
|
||||||
|
return content
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def verify(content: str, llm_client: RedBearLLM):
|
async def verify(query, content: str, llm_client: RedBearLLM):
|
||||||
return
|
return
|
||||||
|
|||||||
@@ -14,6 +14,8 @@ from app.core.rag.nlp.search import knowledge_retrieval
|
|||||||
from app.repositories import knowledge_repository
|
from app.repositories import knowledge_repository
|
||||||
from app.repositories.neo4j.graph_search import search_graph, search_graph_by_embedding
|
from app.repositories.neo4j.graph_search import search_graph, search_graph_by_embedding
|
||||||
from app.repositories.neo4j.neo4j_connector import Neo4jConnector
|
from app.repositories.neo4j.neo4j_connector import Neo4jConnector
|
||||||
|
from app.core.memory.read_services.search_engine.result_builder import MetadataBuilder
|
||||||
|
from app.repositories.neo4j.graph_search import search_user_metadata
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
@@ -177,6 +179,22 @@ class Neo4jSearchService:
|
|||||||
memories.sort(key=lambda x: x.score, reverse=True)
|
memories.sort(key=lambda x: x.score, reverse=True)
|
||||||
return MemorySearchResult(memories=memories[:limit])
|
return MemorySearchResult(memories=memories[:limit])
|
||||||
|
|
||||||
|
async def memory_l0(self) -> Memory:
|
||||||
|
async with Neo4jConnector() as connector:
|
||||||
|
end_user_id = self.ctx.end_user_id
|
||||||
|
user_meta = await search_user_metadata(connector, end_user_id)
|
||||||
|
metadata = MetadataBuilder(user_meta)
|
||||||
|
memory = Memory(
|
||||||
|
score=1,
|
||||||
|
source=Neo4jNodeType.EXTRACTEDENTITY,
|
||||||
|
query='',
|
||||||
|
id=end_user_id,
|
||||||
|
content=metadata.content,
|
||||||
|
data=metadata.data,
|
||||||
|
)
|
||||||
|
|
||||||
|
return memory
|
||||||
|
|
||||||
|
|
||||||
class RAGSearchService:
|
class RAGSearchService:
|
||||||
def __init__(self, ctx: MemoryContext, db: Session):
|
def __init__(self, ctx: MemoryContext, db: Session):
|
||||||
|
|||||||
@@ -42,7 +42,15 @@ class ChunkBuilder(BaseBuilder):
|
|||||||
|
|
||||||
@property
|
@property
|
||||||
def content(self) -> str:
|
def content(self) -> str:
|
||||||
return self.record.get("content")
|
parts = ["<chunk>"]
|
||||||
|
fields = [
|
||||||
|
("content", self.record.get("content", "")),
|
||||||
|
]
|
||||||
|
for tag, value in fields:
|
||||||
|
if value:
|
||||||
|
parts.append(f"<{tag}>{value}</{tag}>")
|
||||||
|
parts.append("</chunk>")
|
||||||
|
return "".join(parts)
|
||||||
|
|
||||||
|
|
||||||
class StatementBuiler(BaseBuilder):
|
class StatementBuiler(BaseBuilder):
|
||||||
@@ -57,7 +65,15 @@ class StatementBuiler(BaseBuilder):
|
|||||||
|
|
||||||
@property
|
@property
|
||||||
def content(self) -> str:
|
def content(self) -> str:
|
||||||
return self.record.get("statement")
|
parts = ["<statement>"]
|
||||||
|
fields = [
|
||||||
|
("statement", self.record.get("statement", "")),
|
||||||
|
]
|
||||||
|
for tag, value in fields:
|
||||||
|
if value:
|
||||||
|
parts.append(f"<{tag}>{value}</{tag}>")
|
||||||
|
parts.append("</statement>")
|
||||||
|
return "".join(parts)
|
||||||
|
|
||||||
|
|
||||||
class EntityBuilder(BaseBuilder):
|
class EntityBuilder(BaseBuilder):
|
||||||
@@ -73,10 +89,16 @@ class EntityBuilder(BaseBuilder):
|
|||||||
|
|
||||||
@property
|
@property
|
||||||
def content(self) -> str:
|
def content(self) -> str:
|
||||||
return (f"<entity>"
|
parts = ["<entity>"]
|
||||||
f"<name>{self.record.get("name")}<name>"
|
fields = [
|
||||||
f"<description>{self.record.get("description")}</description>"
|
("name", self.record.get("name", "")),
|
||||||
f"</entity>")
|
("description", self.record.get("description", "")),
|
||||||
|
]
|
||||||
|
for tag, value in fields:
|
||||||
|
if value:
|
||||||
|
parts.append(f"<{tag}>{value}</{tag}>")
|
||||||
|
parts.append("</entity>")
|
||||||
|
return "".join(parts)
|
||||||
|
|
||||||
|
|
||||||
class SummaryBuilder(BaseBuilder):
|
class SummaryBuilder(BaseBuilder):
|
||||||
@@ -91,7 +113,15 @@ class SummaryBuilder(BaseBuilder):
|
|||||||
|
|
||||||
@property
|
@property
|
||||||
def content(self) -> str:
|
def content(self) -> str:
|
||||||
return self.record.get("content")
|
parts = ["<summary>"]
|
||||||
|
fields = [
|
||||||
|
("content", self.record.get("content", "")),
|
||||||
|
]
|
||||||
|
for tag, value in fields:
|
||||||
|
if value:
|
||||||
|
parts.append(f"<{tag}>{value}</{tag}>")
|
||||||
|
parts.append("</summary>")
|
||||||
|
return "".join(parts)
|
||||||
|
|
||||||
|
|
||||||
class PerceptualBuilder(BaseBuilder):
|
class PerceptualBuilder(BaseBuilder):
|
||||||
@@ -114,15 +144,21 @@ class PerceptualBuilder(BaseBuilder):
|
|||||||
|
|
||||||
@property
|
@property
|
||||||
def content(self) -> str:
|
def content(self) -> str:
|
||||||
return ("<history-file-info>"
|
parts = ["<history-file-info>"]
|
||||||
f"<file-name>{self.record.get('file_name')}</file-name>"
|
fields = [
|
||||||
f"<file-path>{self.record.get('file_path')}</file-path>"
|
("file-name", self.record.get("file_name", "")),
|
||||||
f"<summary>{self.record.get('summary')}</summary>"
|
("file-path", self.record.get("file_path", "")),
|
||||||
f"<topic>{self.record.get('topic')}</topic>"
|
("summary", self.record.get("summary", "")),
|
||||||
f"<domain>{self.record.get('domain')}</domain>"
|
("topic", self.record.get("topic", "")),
|
||||||
f"<keywords>{self.record.get('keywords')}</keywords>"
|
("domain", self.record.get("domain", "")),
|
||||||
f"<file-type>{self.record.get('file_type')}</file-type>"
|
("keywords", self.record.get("keywords", [])),
|
||||||
"</history-file-info>")
|
("file-type", self.record.get("file_type", "")),
|
||||||
|
]
|
||||||
|
for tag, value in fields:
|
||||||
|
if value:
|
||||||
|
parts.append(f"<{tag}>{value}</{tag}>")
|
||||||
|
parts.append("</history-file-info>")
|
||||||
|
return "".join(parts)
|
||||||
|
|
||||||
|
|
||||||
class CommunityBuilder(BaseBuilder):
|
class CommunityBuilder(BaseBuilder):
|
||||||
@@ -137,7 +173,54 @@ class CommunityBuilder(BaseBuilder):
|
|||||||
|
|
||||||
@property
|
@property
|
||||||
def content(self) -> str:
|
def content(self) -> str:
|
||||||
return self.record.get("content")
|
parts = ["<community>"]
|
||||||
|
fields = [
|
||||||
|
("content", self.record.get("content", "")),
|
||||||
|
]
|
||||||
|
for tag, value in fields:
|
||||||
|
if value:
|
||||||
|
parts.append(f"<{tag}>{value}</{tag}>")
|
||||||
|
parts.append("</community>")
|
||||||
|
return "".join(parts)
|
||||||
|
|
||||||
|
|
||||||
|
class MetadataBuilder(BaseBuilder):
|
||||||
|
@property
|
||||||
|
def data(self) -> dict:
|
||||||
|
return {
|
||||||
|
"id": self.record.get("id", ""),
|
||||||
|
"aliases_name": self.record.get("aliases", []) or [],
|
||||||
|
"description": self.record.get("description", ""),
|
||||||
|
"anchors": self.record.get("anchors", []) or [],
|
||||||
|
"beliefs_or_stances": self.record.get("beliefs_or_stances", []) or [],
|
||||||
|
"core_facts": self.record.get("core_facts", []) or [],
|
||||||
|
"events": self.record.get("events", []) or [],
|
||||||
|
"goals": self.record.get("goals", []) or [],
|
||||||
|
"interests": self.record.get("interests", []) or [],
|
||||||
|
"relations": self.record.get("relations", []) or [],
|
||||||
|
"traits": self.record.get("traits", []) or [],
|
||||||
|
}
|
||||||
|
|
||||||
|
@property
|
||||||
|
def content(self) -> str:
|
||||||
|
parts = ["<user-info>"]
|
||||||
|
fields = [
|
||||||
|
("description", self.record.get("description", "")),
|
||||||
|
("aliases", self.record.get("aliases", [])),
|
||||||
|
("anchors", self.record.get("anchors", [])),
|
||||||
|
("beliefs_or_stances", self.record.get("beliefs_or_stances", [])),
|
||||||
|
("core_facts", self.record.get("core_facts", [])),
|
||||||
|
("events", self.record.get("events", [])),
|
||||||
|
("goals", self.record.get("goals", [])),
|
||||||
|
("interests", self.record.get("interests", [])),
|
||||||
|
("relations", self.record.get("relations", [])),
|
||||||
|
("traits", self.record.get("traits", [])),
|
||||||
|
]
|
||||||
|
for tag, value in fields:
|
||||||
|
if value:
|
||||||
|
parts.append(f"<{tag}>{value}</{tag}>")
|
||||||
|
parts.append("</user-info>")
|
||||||
|
return "".join(parts)
|
||||||
|
|
||||||
|
|
||||||
def data_builder_factory(node_type, data: dict) -> T:
|
def data_builder_factory(node_type, data: dict) -> T:
|
||||||
|
|||||||
@@ -17,7 +17,7 @@ async def handle_response(response: type[BaseModel]) -> dict:
|
|||||||
|
|
||||||
|
|
||||||
class StructResponse:
|
class StructResponse:
|
||||||
def __init__(self, mode: Literal["json", "pydantic"], model: Type[BaseModel] = None):
|
def __init__(self, mode: Literal["json", "pydantic", "str"], model: Type[BaseModel] = None):
|
||||||
self.mode = mode
|
self.mode = mode
|
||||||
if mode == "pydantic" and model is None:
|
if mode == "pydantic" and model is None:
|
||||||
raise ValueError("Pydantic model is required")
|
raise ValueError("Pydantic model is required")
|
||||||
@@ -31,6 +31,8 @@ class StructResponse:
|
|||||||
for block in other.content_blocks:
|
for block in other.content_blocks:
|
||||||
if block.get("type") == "text":
|
if block.get("type") == "text":
|
||||||
text += block.get("text", "")
|
text += block.get("text", "")
|
||||||
|
if self.mode == "str":
|
||||||
|
return text
|
||||||
fixed_json = json_repair.repair_json(text, return_objects=True)
|
fixed_json = json_repair.repair_json(text, return_objects=True)
|
||||||
if self.mode == "json":
|
if self.mode == "json":
|
||||||
return fixed_json
|
return fixed_json
|
||||||
|
|||||||
@@ -1,5 +1,6 @@
|
|||||||
import asyncio
|
import asyncio
|
||||||
import logging
|
import logging
|
||||||
|
import re
|
||||||
import time
|
import time
|
||||||
import uuid
|
import uuid
|
||||||
from abc import ABC, abstractmethod
|
from abc import ABC, abstractmethod
|
||||||
@@ -22,6 +23,9 @@ from app.services.multimodal_service import MultimodalService
|
|||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
# 匹配模板变量 {{xxx}} 的正则
|
||||||
|
_TEMPLATE_PATTERN = re.compile(r"\{\{.*?\}\}")
|
||||||
|
|
||||||
|
|
||||||
class NodeExecutionError(Exception):
|
class NodeExecutionError(Exception):
|
||||||
"""节点执行失败异常。
|
"""节点执行失败异常。
|
||||||
@@ -503,10 +507,29 @@ class BaseNode(ABC):
|
|||||||
variable_pool: The variable pool used for reading and writing variables.
|
variable_pool: The variable pool used for reading and writing variables.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
A dictionary containing the node's input data.
|
A dictionary containing the node's input data with all template
|
||||||
|
variables resolved to their actual runtime values.
|
||||||
"""
|
"""
|
||||||
# Default implementation returns the node configuration
|
return {"config": self._resolve_config(self.config, variable_pool)}
|
||||||
return {"config": self.config}
|
|
||||||
|
@staticmethod
|
||||||
|
def _resolve_config(config: Any, variable_pool: VariablePool) -> Any:
|
||||||
|
"""递归解析 config 中的模板变量,将 {{xxx}} 替换为实际值。
|
||||||
|
|
||||||
|
Args:
|
||||||
|
config: 节点的原始配置(可能包含模板变量)。
|
||||||
|
variable_pool: 变量池,用于解析模板变量。
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
解析后的配置,所有字符串中的 {{变量}} 已被替换为真实值。
|
||||||
|
"""
|
||||||
|
if isinstance(config, str) and _TEMPLATE_PATTERN.search(config):
|
||||||
|
return BaseNode._render_template(config, variable_pool, strict=False)
|
||||||
|
elif isinstance(config, dict):
|
||||||
|
return {k: BaseNode._resolve_config(v, variable_pool) for k, v in config.items()}
|
||||||
|
elif isinstance(config, list):
|
||||||
|
return [BaseNode._resolve_config(item, variable_pool) for item in config]
|
||||||
|
return config
|
||||||
|
|
||||||
def _extract_output(self, business_result: Any) -> Any:
|
def _extract_output(self, business_result: Any) -> Any:
|
||||||
"""Extracts the actual output from the business result.
|
"""Extracts the actual output from the business result.
|
||||||
|
|||||||
@@ -132,7 +132,7 @@ class CodeNode(BaseNode):
|
|||||||
|
|
||||||
async with httpx.AsyncClient(timeout=60) as client:
|
async with httpx.AsyncClient(timeout=60) as client:
|
||||||
response = await client.post(
|
response = await client.post(
|
||||||
f"{settings.SANDBOX_URL}:8194/v1/sandbox/run",
|
f"{settings.SANDBOX_URL}/v1/sandbox/run",
|
||||||
headers={
|
headers={
|
||||||
"x-api-key": 'redbear-sandbox'
|
"x-api-key": 'redbear-sandbox'
|
||||||
},
|
},
|
||||||
|
|||||||
@@ -121,7 +121,10 @@ class DocExtractorNode(BaseNode):
|
|||||||
return business_result
|
return business_result
|
||||||
|
|
||||||
def _extract_input(self, state: WorkflowState, variable_pool: VariablePool) -> dict[str, Any]:
|
def _extract_input(self, state: WorkflowState, variable_pool: VariablePool) -> dict[str, Any]:
|
||||||
return {"file_selector": self.config.get("file_selector")}
|
file_selector = self.config.get("file_selector", "")
|
||||||
|
# 将变量选择器(如 sys.files)解析为实际值
|
||||||
|
resolved = self.get_variable(file_selector, variable_pool, strict=False, default=file_selector)
|
||||||
|
return {"file_selector": resolved}
|
||||||
|
|
||||||
async def execute(self, state: WorkflowState, variable_pool: VariablePool) -> Any:
|
async def execute(self, state: WorkflowState, variable_pool: VariablePool) -> Any:
|
||||||
config = DocExtractorNodeConfig(**self.config)
|
config = DocExtractorNodeConfig(**self.config)
|
||||||
|
|||||||
@@ -40,6 +40,7 @@ class MemoryReadNode(BaseNode):
|
|||||||
end_user_id=end_user_id,
|
end_user_id=end_user_id,
|
||||||
user_rag_memory_id=state["user_rag_memory_id"],
|
user_rag_memory_id=state["user_rag_memory_id"],
|
||||||
)
|
)
|
||||||
|
# TODO: Historical Messages -> Used to refer to coreference resolution
|
||||||
search_result = await memory_service.read(
|
search_result = await memory_service.read(
|
||||||
self._render_template(self.typed_config.message, variable_pool),
|
self._render_template(self.typed_config.message, variable_pool),
|
||||||
search_switch=SearchStrategy(self.typed_config.search_switch)
|
search_switch=SearchStrategy(self.typed_config.search_switch)
|
||||||
|
|||||||
@@ -1296,6 +1296,7 @@ RETURN e.id AS id,
|
|||||||
e.name AS name,
|
e.name AS name,
|
||||||
e.end_user_id AS end_user_id,
|
e.end_user_id AS end_user_id,
|
||||||
e.entity_type AS entity_type,
|
e.entity_type AS entity_type,
|
||||||
|
e.description AS description,
|
||||||
COALESCE(e.activation_value, e.importance_score, 0.5) AS activation_value,
|
COALESCE(e.activation_value, e.importance_score, 0.5) AS activation_value,
|
||||||
COALESCE(e.importance_score, 0.5) AS importance_score,
|
COALESCE(e.importance_score, 0.5) AS importance_score,
|
||||||
e.last_access_time AS last_access_time,
|
e.last_access_time AS last_access_time,
|
||||||
@@ -1479,6 +1480,21 @@ ORDER BY score DESC
|
|||||||
LIMIT $limit
|
LIMIT $limit
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
SEARCH_USER_METADATA = """
|
||||||
|
MATCH (n:ExtractedEntity)
|
||||||
|
WHERE (n.end_user_id = $end_user_id AND n.entity_type ='用户')
|
||||||
|
RETURN n.description AS description,
|
||||||
|
n.aliases AS aliases,
|
||||||
|
n.anchors AS anchors,
|
||||||
|
n.beliefs_or_stances AS beliefs_or_stances,
|
||||||
|
n.core_facts AS core_facts,
|
||||||
|
n.events AS events,
|
||||||
|
n.goals AS goals,
|
||||||
|
n.interests AS interests,
|
||||||
|
n.relations AS relations,
|
||||||
|
n.traits AS traits
|
||||||
|
"""
|
||||||
|
|
||||||
FULLTEXT_QUERY_CYPHER_MAPPING = {
|
FULLTEXT_QUERY_CYPHER_MAPPING = {
|
||||||
Neo4jNodeType.STATEMENT: SEARCH_STATEMENTS_BY_KEYWORD,
|
Neo4jNodeType.STATEMENT: SEARCH_STATEMENTS_BY_KEYWORD,
|
||||||
Neo4jNodeType.EXTRACTEDENTITY: SEARCH_ENTITIES_BY_NAME_OR_ALIAS,
|
Neo4jNodeType.EXTRACTEDENTITY: SEARCH_ENTITIES_BY_NAME_OR_ALIAS,
|
||||||
|
|||||||
@@ -27,9 +27,9 @@ from app.repositories.neo4j.cypher_queries import (
|
|||||||
SEARCH_PERCEPTUAL_BY_USER_ID,
|
SEARCH_PERCEPTUAL_BY_USER_ID,
|
||||||
FULLTEXT_QUERY_CYPHER_MAPPING,
|
FULLTEXT_QUERY_CYPHER_MAPPING,
|
||||||
USER_ID_QUERY_CYPHER_MAPPING,
|
USER_ID_QUERY_CYPHER_MAPPING,
|
||||||
NODE_ID_QUERY_CYPHER_MAPPING
|
NODE_ID_QUERY_CYPHER_MAPPING,
|
||||||
|
SEARCH_USER_METADATA
|
||||||
)
|
)
|
||||||
|
|
||||||
from app.repositories.neo4j.neo4j_connector import Neo4jConnector
|
from app.repositories.neo4j.neo4j_connector import Neo4jConnector
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
@@ -513,7 +513,7 @@ async def search_graph_by_embedding(
|
|||||||
task_keys = []
|
task_keys = []
|
||||||
|
|
||||||
for node_type in include:
|
for node_type in include:
|
||||||
tasks.append(search_by_embedding(connector, node_type, end_user_id, embedding, limit*2))
|
tasks.append(search_by_embedding(connector, node_type, end_user_id, embedding, limit * 2))
|
||||||
task_keys.append(node_type.value)
|
task_keys.append(node_type.value)
|
||||||
|
|
||||||
task_results = await asyncio.gather(*tasks, return_exceptions=True)
|
task_results = await asyncio.gather(*tasks, return_exceptions=True)
|
||||||
@@ -557,6 +557,17 @@ async def search_graph_by_embedding(
|
|||||||
return results
|
return results
|
||||||
|
|
||||||
|
|
||||||
|
async def search_user_metadata(
|
||||||
|
connector: Neo4jConnector,
|
||||||
|
end_user_id: str
|
||||||
|
) -> dict:
|
||||||
|
user_info = await connector.execute_query(
|
||||||
|
SEARCH_USER_METADATA,
|
||||||
|
end_user_id=end_user_id
|
||||||
|
)
|
||||||
|
return user_info[0] if user_info else {}
|
||||||
|
|
||||||
|
|
||||||
async def get_dedup_candidates_for_entities( # 适配新版查询:使用全文索引按名称检索候选实体
|
async def get_dedup_candidates_for_entities( # 适配新版查询:使用全文索引按名称检索候选实体
|
||||||
connector: Neo4jConnector,
|
connector: Neo4jConnector,
|
||||||
end_user_id: str,
|
end_user_id: str,
|
||||||
|
|||||||
@@ -1,14 +1,15 @@
|
|||||||
|
import uuid
|
||||||
from abc import ABC
|
from abc import ABC
|
||||||
from typing import Optional
|
from typing import Optional
|
||||||
|
|
||||||
from pydantic import BaseModel
|
from pydantic import BaseModel, Field
|
||||||
|
|
||||||
|
|
||||||
class UserInput(BaseModel):
|
class UserInput(BaseModel):
|
||||||
message: str
|
message: str
|
||||||
history: list[dict]
|
|
||||||
search_switch: str
|
search_switch: str
|
||||||
end_user_id: str
|
end_user_id: str
|
||||||
|
session_id: uuid.UUID = Field(default_factory=uuid.uuid4)
|
||||||
config_id: Optional[str] = None
|
config_id: Optional[str] = None
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -102,6 +102,11 @@ class AppDslService:
|
|||||||
{**r, "_ref": self._agent_ref(r.get("target_agent_id"))} for r in (cfg["routing_rules"] or [])
|
{**r, "_ref": self._agent_ref(r.get("target_agent_id"))} for r in (cfg["routing_rules"] or [])
|
||||||
]
|
]
|
||||||
return enriched
|
return enriched
|
||||||
|
if app_type == AppType.WORKFLOW:
|
||||||
|
enriched = {**cfg}
|
||||||
|
if "nodes" in cfg:
|
||||||
|
enriched["nodes"] = self._enrich_workflow_nodes(cfg["nodes"])
|
||||||
|
return enriched
|
||||||
return cfg
|
return cfg
|
||||||
|
|
||||||
def _export_draft(self, app: App, meta: dict, app_meta: dict) -> tuple[str, str]:
|
def _export_draft(self, app: App, meta: dict, app_meta: dict) -> tuple[str, str]:
|
||||||
@@ -110,7 +115,7 @@ class AppDslService:
|
|||||||
config_data = {
|
config_data = {
|
||||||
"variables": config.variables if config else [],
|
"variables": config.variables if config else [],
|
||||||
"edges": config.edges if config else [],
|
"edges": config.edges if config else [],
|
||||||
"nodes": config.nodes if config else [],
|
"nodes": self._enrich_workflow_nodes(config.nodes) if config else [],
|
||||||
"features": config.features if config else {},
|
"features": config.features if config else {},
|
||||||
"execution_config": config.execution_config if config else {},
|
"execution_config": config.execution_config if config else {},
|
||||||
"triggers": config.triggers if config else [],
|
"triggers": config.triggers if config else [],
|
||||||
@@ -190,6 +195,23 @@ class AppDslService:
|
|||||||
def _enrich_tools(self, tools: list) -> list:
|
def _enrich_tools(self, tools: list) -> list:
|
||||||
return [{**t, "_ref": self._tool_ref(t.get("tool_id"))} for t in (tools or [])]
|
return [{**t, "_ref": self._tool_ref(t.get("tool_id"))} for t in (tools or [])]
|
||||||
|
|
||||||
|
def _enrich_workflow_nodes(self, nodes: list) -> list:
|
||||||
|
"""enrich 工作流节点中的模型引用,添加 name、provider、type 信息"""
|
||||||
|
from app.core.workflow.nodes.enums import NodeType
|
||||||
|
enriched_nodes = []
|
||||||
|
for node in (nodes or []):
|
||||||
|
node_type = node.get("type")
|
||||||
|
config = dict(node.get("config") or {})
|
||||||
|
|
||||||
|
if node_type in (NodeType.LLM.value, NodeType.QUESTION_CLASSIFIER.value, NodeType.PARAMETER_EXTRACTOR.value):
|
||||||
|
model_id = config.get("model_id")
|
||||||
|
if model_id:
|
||||||
|
config["model_ref"] = self._model_ref(model_id)
|
||||||
|
del config["model_id"]
|
||||||
|
|
||||||
|
enriched_nodes.append({**node, "config": config})
|
||||||
|
return enriched_nodes
|
||||||
|
|
||||||
def _skill_ref(self, skill_id) -> Optional[dict]:
|
def _skill_ref(self, skill_id) -> Optional[dict]:
|
||||||
if not skill_id:
|
if not skill_id:
|
||||||
return None
|
return None
|
||||||
@@ -620,16 +642,16 @@ class AppDslService:
|
|||||||
warnings.append(f"[{node_label}] 知识库 '{kb_id}' 未匹配,已移除,请导入后手动配置")
|
warnings.append(f"[{node_label}] 知识库 '{kb_id}' 未匹配,已移除,请导入后手动配置")
|
||||||
config["knowledge_bases"] = resolved_kbs
|
config["knowledge_bases"] = resolved_kbs
|
||||||
elif node_type in (NodeType.LLM.value, NodeType.QUESTION_CLASSIFIER.value, NodeType.PARAMETER_EXTRACTOR.value):
|
elif node_type in (NodeType.LLM.value, NodeType.QUESTION_CLASSIFIER.value, NodeType.PARAMETER_EXTRACTOR.value):
|
||||||
model_ref = config.get("model_id")
|
model_ref = config.get("model_ref") or config.get("model_id")
|
||||||
if model_ref:
|
if model_ref:
|
||||||
ref_dict = None
|
ref_dict = None
|
||||||
if isinstance(model_ref, dict):
|
if isinstance(model_ref, dict):
|
||||||
ref_id = model_ref.get("id")
|
ref_dict = {
|
||||||
ref_name = model_ref.get("name")
|
"id": model_ref.get("id"),
|
||||||
if ref_id:
|
"name": model_ref.get("name"),
|
||||||
ref_dict = {"id": ref_id}
|
"provider": model_ref.get("provider"),
|
||||||
elif ref_name is not None:
|
"type": model_ref.get("type")
|
||||||
ref_dict = {"name": ref_name, "provider": model_ref.get("provider"), "type": model_ref.get("type")}
|
}
|
||||||
elif isinstance(model_ref, str):
|
elif isinstance(model_ref, str):
|
||||||
try:
|
try:
|
||||||
uuid.UUID(model_ref)
|
uuid.UUID(model_ref)
|
||||||
@@ -640,12 +662,18 @@ class AppDslService:
|
|||||||
resolved_model_id = self._resolve_model(ref_dict, tenant_id, warnings)
|
resolved_model_id = self._resolve_model(ref_dict, tenant_id, warnings)
|
||||||
if resolved_model_id:
|
if resolved_model_id:
|
||||||
config["model_id"] = resolved_model_id
|
config["model_id"] = resolved_model_id
|
||||||
|
if "model_ref" in config:
|
||||||
|
del config["model_ref"]
|
||||||
else:
|
else:
|
||||||
warnings.append(f"[{node_label}] 模型未匹配,已置空,请导入后手动配置")
|
warnings.append(f"[{node_label}] 模型未匹配,已置空,请导入后手动配置")
|
||||||
config["model_id"] = None
|
config["model_id"] = None
|
||||||
|
if "model_ref" in config:
|
||||||
|
del config["model_ref"]
|
||||||
else:
|
else:
|
||||||
warnings.append(f"[{node_label}] 模型未匹配,已置空,请导入后手动配置")
|
warnings.append(f"[{node_label}] 模型未匹配,已置空,请导入后手动配置")
|
||||||
config["model_id"] = None
|
config["model_id"] = None
|
||||||
|
if "model_ref" in config:
|
||||||
|
del config["model_ref"]
|
||||||
resolved_nodes.append({**node, "config": config})
|
resolved_nodes.append({**node, "config": config})
|
||||||
return resolved_nodes
|
return resolved_nodes
|
||||||
|
|
||||||
|
|||||||
@@ -108,6 +108,7 @@ def create_long_term_memory_tool(
|
|||||||
try:
|
try:
|
||||||
with get_db_context() as db:
|
with get_db_context() as db:
|
||||||
memory_service = MemoryService(db, config_id, end_user_id)
|
memory_service = MemoryService(db, config_id, end_user_id)
|
||||||
|
# TODO: Historical Messages -> Used to refer to coreference resolution
|
||||||
search_result = asyncio.run(memory_service.read(question, SearchStrategy.QUICK))
|
search_result = asyncio.run(memory_service.read(question, SearchStrategy.QUICK))
|
||||||
|
|
||||||
# memory_content = asyncio.run(
|
# memory_content = asyncio.run(
|
||||||
|
|||||||
0
api/app/utils/__init__.py
Normal file
0
api/app/utils/__init__.py
Normal file
77
api/app/utils/tmp_session.py
Normal file
77
api/app/utils/tmp_session.py
Normal file
@@ -0,0 +1,77 @@
|
|||||||
|
import json
|
||||||
|
import logging
|
||||||
|
|
||||||
|
import redis.asyncio as redis
|
||||||
|
|
||||||
|
from app.aioRedis import get_redis_connection
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
DEFAULT_TTL = 3600
|
||||||
|
|
||||||
|
|
||||||
|
class ChatSessionCache:
|
||||||
|
"""Cache user-AI conversation history in Redis with TTL-based expiry.
|
||||||
|
|
||||||
|
Usage::
|
||||||
|
|
||||||
|
cache = ChatSessionCache(session_id="user_123")
|
||||||
|
await cache.append("user", "Hello")
|
||||||
|
await cache.append("assistant", "Hi there!")
|
||||||
|
history = await cache.get_history()
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, session_id: str, ttl: int = DEFAULT_TTL):
|
||||||
|
self.session_id = session_id
|
||||||
|
self.ttl = ttl
|
||||||
|
self._key = f"chat:session:{session_id}"
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
async def _client() -> redis.StrictRedis:
|
||||||
|
return await get_redis_connection()
|
||||||
|
|
||||||
|
async def append(self, role: str, content: str) -> None:
|
||||||
|
r = await self._client()
|
||||||
|
entry = json.dumps({"role": role, "content": content}, ensure_ascii=False)
|
||||||
|
await r.rpush(self._key, entry)
|
||||||
|
await r.expire(self._key, self.ttl)
|
||||||
|
|
||||||
|
async def append_many(self, messages: list[dict[str, str]]) -> None:
|
||||||
|
"""Batch append messages. Each dict should have ``role`` and ``content`` keys."""
|
||||||
|
if not messages:
|
||||||
|
return
|
||||||
|
r = await self._client()
|
||||||
|
entries = [
|
||||||
|
json.dumps(m, ensure_ascii=False)
|
||||||
|
for m in messages
|
||||||
|
if "role" in m and "content" in m
|
||||||
|
]
|
||||||
|
if entries:
|
||||||
|
await r.rpush(self._key, *entries)
|
||||||
|
await r.expire(self._key, self.ttl)
|
||||||
|
|
||||||
|
async def get_history(self) -> list[dict[str, str]]:
|
||||||
|
r = await self._client()
|
||||||
|
raw = await r.lrange(self._key, 0, -1)
|
||||||
|
return [json.loads(item) for item in raw]
|
||||||
|
|
||||||
|
async def get_history_text(self, user_label: str = "User", ai_label: str = "Assistant") -> str:
|
||||||
|
"""Return conversation as a formatted text block."""
|
||||||
|
history = await self.get_history()
|
||||||
|
lines = []
|
||||||
|
for msg in history:
|
||||||
|
role = msg.get("role", "")
|
||||||
|
content = msg.get("content", "")
|
||||||
|
label = user_label if role == "user" else ai_label if role == "assistant" else role
|
||||||
|
lines.append(f"{label}: {content}")
|
||||||
|
return "\n".join(lines)
|
||||||
|
|
||||||
|
async def reset(self) -> None:
|
||||||
|
"""Delete the session from Redis."""
|
||||||
|
r = await self._client()
|
||||||
|
await r.delete(self._key)
|
||||||
|
|
||||||
|
async def touch(self) -> None:
|
||||||
|
"""Refresh the TTL without modifying data."""
|
||||||
|
r = await self._client()
|
||||||
|
await r.expire(self._key, self.ttl)
|
||||||
@@ -355,14 +355,13 @@ const CaseList: FC<CaseListProps> = ({
|
|||||||
// Update node ports based on case count changes (add/remove cases)
|
// Update node ports based on case count changes (add/remove cases)
|
||||||
const updateNodePorts = (caseCount: number, removedCaseIndex?: number) => {
|
const updateNodePorts = (caseCount: number, removedCaseIndex?: number) => {
|
||||||
if (!selectedNode || !graphRef?.current) return;
|
if (!selectedNode || !graphRef?.current) return;
|
||||||
|
const graph = graphRef.current;
|
||||||
// Get current port count to determine if it's an add or remove operation
|
|
||||||
const currentPorts = selectedNode.getPorts().filter((port: any) => port.group === 'right');
|
const currentRightPorts = selectedNode.getPorts().filter((port: any) => port.group === 'right');
|
||||||
const currentCaseCount = currentPorts.length - 1; // Exclude ELSE port
|
const currentCaseCount = currentRightPorts.length - 1;
|
||||||
const isAddingCase = removedCaseIndex === undefined && caseCount > currentCaseCount;
|
const isAddingCase = removedCaseIndex === undefined && caseCount > currentCaseCount;
|
||||||
|
|
||||||
// Save existing edge connections (including left-side port connections)
|
const existingEdges = graph.getEdges().filter((edge: any) =>
|
||||||
const existingEdges = graphRef.current.getEdges().filter((edge: any) =>
|
|
||||||
edge.getSourceCellId() === selectedNode.id || edge.getTargetCellId() === selectedNode.id
|
edge.getSourceCellId() === selectedNode.id || edge.getTargetCellId() === selectedNode.id
|
||||||
);
|
);
|
||||||
const edgeConnections = existingEdges.map((edge: any) => ({
|
const edgeConnections = existingEdges.map((edge: any) => ({
|
||||||
@@ -371,113 +370,70 @@ const CaseList: FC<CaseListProps> = ({
|
|||||||
targetCellId: edge.getTargetCellId(),
|
targetCellId: edge.getTargetCellId(),
|
||||||
targetPortId: edge.getTargetPortId(),
|
targetPortId: edge.getTargetPortId(),
|
||||||
sourceCellId: edge.getSourceCellId(),
|
sourceCellId: edge.getSourceCellId(),
|
||||||
isIncoming: edge.getTargetCellId() === selectedNode.id
|
isIncoming: edge.getTargetCellId() === selectedNode.id,
|
||||||
}));
|
}));
|
||||||
|
|
||||||
// Remove all existing right-side ports
|
const cases = form.getFieldValue(name) || [];
|
||||||
const existingPorts = selectedNode.getPorts();
|
const leftPorts = selectedNode.getPorts().filter((p: any) => p.group !== 'right');
|
||||||
existingPorts.forEach((port: any) => {
|
const newRightPorts = Array.from({ length: caseCount + 1 }, (_, i) => ({
|
||||||
if (port.group === 'right') {
|
id: `CASE${i + 1}`,
|
||||||
selectedNode.removePort(port.id);
|
group: 'right',
|
||||||
|
args: { x: nodeWidth, y: getConditionNodeCasePortY(cases, i) },
|
||||||
|
}));
|
||||||
|
|
||||||
|
graph.startBatch('update-ports');
|
||||||
|
|
||||||
|
existingEdges.forEach((edge: any) => graph.removeCell(edge));
|
||||||
|
// Replace all ports in one prop call — produces a single cell:change:ports command
|
||||||
|
selectedNode.prop('ports/items', [...leftPorts, ...newRightPorts], { rewrite: true });
|
||||||
|
selectedNode.prop('size', { width: nodeWidth, height: calcConditionNodeTotalHeight(cases) });
|
||||||
|
|
||||||
|
edgeConnections.forEach(({sourcePortId, targetCellId, targetPortId, sourceCellId, isIncoming }: any) => {
|
||||||
|
if (isIncoming) {
|
||||||
|
const sourceCell = graph.getCellById(sourceCellId);
|
||||||
|
if (sourceCell) {
|
||||||
|
graph.addEdge({
|
||||||
|
source: { cell: sourceCellId, port: sourcePortId },
|
||||||
|
target: { cell: selectedNode.id, port: targetPortId },
|
||||||
|
...edgeAttrs
|
||||||
|
});
|
||||||
|
sourceCell.toFront();
|
||||||
|
bringLoopChildrenToFront(sourceCell);
|
||||||
|
selectedNode.toFront();
|
||||||
|
bringLoopChildrenToFront(selectedNode);
|
||||||
|
}
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
const originalCaseNumber = parseInt(sourcePortId.match(/CASE(\d+)/)?.[1] || '0');
|
||||||
|
if (removedCaseIndex !== undefined && originalCaseNumber === removedCaseIndex + 1) return;
|
||||||
|
let newPortId = sourcePortId;
|
||||||
|
|
||||||
|
if (removedCaseIndex !== undefined) {
|
||||||
|
if (originalCaseNumber > removedCaseIndex + 1) {
|
||||||
|
newPortId = `CASE${originalCaseNumber - 1}`;
|
||||||
|
} else if (originalCaseNumber === currentCaseCount + 1) {
|
||||||
|
newPortId = `CASE${caseCount + 1}`;
|
||||||
|
}
|
||||||
|
} else if (isAddingCase && originalCaseNumber === currentCaseCount + 1) {
|
||||||
|
newPortId = `CASE${caseCount + 1}`;
|
||||||
|
}
|
||||||
|
if (newRightPorts.find((p) => p.id === newPortId)) {
|
||||||
|
const targetCell = graph.getCellById(targetCellId);
|
||||||
|
if (targetCell) {
|
||||||
|
graph.addEdge({
|
||||||
|
source: { cell: selectedNode.id, port: newPortId },
|
||||||
|
target: { cell: targetCellId, port: targetPortId },
|
||||||
|
...edgeAttrs
|
||||||
|
});
|
||||||
|
selectedNode.toFront();
|
||||||
|
bringLoopChildrenToFront(selectedNode);
|
||||||
|
targetCell.toFront();
|
||||||
|
bringLoopChildrenToFront(targetCell);
|
||||||
|
}
|
||||||
}
|
}
|
||||||
});
|
});
|
||||||
|
|
||||||
const cases = form.getFieldValue(name) || [];
|
graph.stopBatch('update-ports');
|
||||||
selectedNode.prop('size', { width: nodeWidth, height: calcConditionNodeTotalHeight(cases) });
|
|
||||||
|
|
||||||
// Add ELIF ports
|
|
||||||
for (let i = 0; i < caseCount; i++) {
|
|
||||||
selectedNode.addPort({
|
|
||||||
id: `CASE${i + 1}`,
|
|
||||||
group: 'right',
|
|
||||||
args: {
|
|
||||||
x: nodeWidth,
|
|
||||||
y: getConditionNodeCasePortY(cases, i),
|
|
||||||
},
|
|
||||||
});
|
|
||||||
}
|
|
||||||
|
|
||||||
// Add ELSE port
|
|
||||||
selectedNode.addPort({
|
|
||||||
id: `CASE${caseCount + 1}`,
|
|
||||||
group: 'right',
|
|
||||||
args: {
|
|
||||||
x: nodeWidth,
|
|
||||||
y: getConditionNodeCasePortY(cases, caseCount),
|
|
||||||
},
|
|
||||||
});
|
|
||||||
|
|
||||||
// Restore edge connections
|
|
||||||
setTimeout(() => {
|
|
||||||
edgeConnections.forEach(({ edge, sourcePortId, targetCellId, targetPortId, sourceCellId, isIncoming }: any) => {
|
|
||||||
// If it's an incoming connection (left-side port), restore directly
|
|
||||||
if (isIncoming) {
|
|
||||||
const sourceCell = graphRef.current?.getCellById(sourceCellId);
|
|
||||||
if (sourceCell) {
|
|
||||||
graphRef.current?.addEdge({
|
|
||||||
source: { cell: sourceCellId, port: sourcePortId },
|
|
||||||
target: { cell: selectedNode.id, port: targetPortId },
|
|
||||||
...edgeAttrs,
|
|
||||||
});
|
|
||||||
}
|
|
||||||
sourceCell.toFront()
|
|
||||||
selectedNode.toFront()
|
|
||||||
bringLoopChildrenToFront(sourceCell)
|
|
||||||
bringLoopChildrenToFront(selectedNode)
|
|
||||||
graphRef.current?.removeCell(edge);
|
|
||||||
return;
|
|
||||||
}
|
|
||||||
|
|
||||||
// Handle right-side port connections
|
|
||||||
const originalCaseNumber = parseInt(sourcePortId.match(/CASE(\d+)/)?.[1] || '0');
|
|
||||||
|
|
||||||
// If it's a remove operation and the port is being removed, delete the connection
|
|
||||||
if (removedCaseIndex !== undefined && originalCaseNumber === removedCaseIndex + 1) {
|
|
||||||
graphRef.current?.removeCell(edge);
|
|
||||||
return;
|
|
||||||
}
|
|
||||||
|
|
||||||
let newPortId = sourcePortId;
|
|
||||||
|
|
||||||
// If it's a remove operation, remap port IDs
|
|
||||||
if (removedCaseIndex !== undefined) {
|
|
||||||
if (originalCaseNumber > removedCaseIndex + 1) {
|
|
||||||
// Ports after the removed port, shift numbering forward
|
|
||||||
newPortId = `CASE${originalCaseNumber - 1}`;
|
|
||||||
}
|
|
||||||
// ELSE port always maps to the new ELSE port position
|
|
||||||
else if (originalCaseNumber === currentCaseCount + 1) {
|
|
||||||
newPortId = `CASE${caseCount + 1}`;
|
|
||||||
}
|
|
||||||
} else if (isAddingCase) {
|
|
||||||
// If it's an add operation, ELSE port needs to be remapped
|
|
||||||
if (originalCaseNumber === currentCaseCount + 1) {
|
|
||||||
newPortId = `CASE${caseCount + 1}`; // New ELSE port
|
|
||||||
}
|
|
||||||
// Newly added ports don't restore any connections
|
|
||||||
}
|
|
||||||
|
|
||||||
const newPorts = selectedNode.getPorts();
|
|
||||||
const matchingPort = newPorts.find((port: any) => port.id === newPortId);
|
|
||||||
|
|
||||||
if (matchingPort) {
|
|
||||||
const targetCell = graphRef.current?.getCellById(targetCellId);
|
|
||||||
if (targetCell) {
|
|
||||||
graphRef.current?.addEdge({
|
|
||||||
source: { cell: selectedNode.id, port: newPortId },
|
|
||||||
target: { cell: targetCellId, port: targetPortId },
|
|
||||||
...edgeAttrs
|
|
||||||
});
|
|
||||||
selectedNode.toFront()
|
|
||||||
bringLoopChildrenToFront(selectedNode)
|
|
||||||
targetCell.toFront()
|
|
||||||
bringLoopChildrenToFront(targetCell)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
graphRef.current?.removeCell(edge);
|
|
||||||
});
|
|
||||||
}, 50);
|
|
||||||
};
|
};
|
||||||
|
|
||||||
const handleChangeLogicalOperator = (index: number) => {
|
const handleChangeLogicalOperator = (index: number) => {
|
||||||
|
|||||||
@@ -42,109 +42,73 @@ const CategoryList: FC<CategoryListProps> = ({ parentName, selectedNode, graphRe
|
|||||||
// Update node ports based on category count changes (add/remove categories)
|
// Update node ports based on category count changes (add/remove categories)
|
||||||
const updateNodePorts = (caseCount: number, removedCaseIndex?: number) => {
|
const updateNodePorts = (caseCount: number, removedCaseIndex?: number) => {
|
||||||
if (!selectedNode || !graphRef?.current) return;
|
if (!selectedNode || !graphRef?.current) return;
|
||||||
|
const graph = graphRef.current;
|
||||||
|
|
||||||
// Save existing edge connections (including left-side port connections)
|
const existingEdges = graph.getEdges().filter((edge: any) =>
|
||||||
const existingEdges = graphRef.current.getEdges().filter((edge: any) =>
|
|
||||||
edge.getSourceCellId() === selectedNode.id || edge.getTargetCellId() === selectedNode.id
|
edge.getSourceCellId() === selectedNode.id || edge.getTargetCellId() === selectedNode.id
|
||||||
);
|
);
|
||||||
const edgeConnections = existingEdges.map((edge: any) => ({
|
const edgeConnections = existingEdges.map((edge: any) => ({
|
||||||
edge,
|
|
||||||
sourcePortId: edge.getSourcePortId(),
|
sourcePortId: edge.getSourcePortId(),
|
||||||
targetCellId: edge.getTargetCellId(),
|
targetCellId: edge.getTargetCellId(),
|
||||||
targetPortId: edge.getTargetPortId(),
|
targetPortId: edge.getTargetPortId(),
|
||||||
sourceCellId: edge.getSourceCellId(),
|
sourceCellId: edge.getSourceCellId(),
|
||||||
isIncoming: edge.getTargetCellId() === selectedNode.id
|
isIncoming: edge.getTargetCellId() === selectedNode.id,
|
||||||
}));
|
}));
|
||||||
|
|
||||||
// Remove all existing right-side ports
|
graph.startBatch('update-ports');
|
||||||
const existingPorts = selectedNode.getPorts();
|
|
||||||
existingPorts.forEach((port: any) => {
|
existingEdges.forEach((edge: any) => graph.removeCell(edge));
|
||||||
if (port.group === 'right') {
|
// Replace all ports in one prop call — produces a single cell:change:ports command
|
||||||
selectedNode.removePort(port.id);
|
const leftPorts = selectedNode.getPorts().filter((p: any) => p.group !== 'right');
|
||||||
}
|
const newRightPorts = Array.from({ length: caseCount }, (_, i) => ({
|
||||||
});
|
id: `CASE${i + 1}`,
|
||||||
|
group: 'right',
|
||||||
|
args: { x: nodeWidth, y: portItemArgsY * i + conditionNodePortItemArgsY },
|
||||||
|
}));
|
||||||
|
selectedNode.prop('ports/items', [...leftPorts, ...newRightPorts], { rewrite: true });
|
||||||
|
|
||||||
// Calculate new node height: base height 88px + 30px for each additional port
|
|
||||||
const newHeight = conditionNodeHeight + (caseCount - 2) * conditionNodeItemHeight;
|
const newHeight = conditionNodeHeight + (caseCount - 2) * conditionNodeItemHeight;
|
||||||
|
selectedNode.prop('size', { width: nodeWidth, height: newHeight < conditionNodeHeight ? conditionNodeHeight : newHeight });
|
||||||
|
|
||||||
selectedNode.prop('size', { width: nodeWidth, height: newHeight < conditionNodeHeight ? conditionNodeHeight : newHeight })
|
edgeConnections.forEach(({ sourcePortId, targetCellId, targetPortId, sourceCellId, isIncoming }: any) => {
|
||||||
|
if (isIncoming) {
|
||||||
// Update right port x position
|
const sourceCell = graph.getCellById(sourceCellId);
|
||||||
const currentPorts = selectedNode.getPorts();
|
if (sourceCell) {
|
||||||
currentPorts.forEach(port => {
|
graph.addEdge({
|
||||||
if (port.group === 'right' && port.args) {
|
source: { cell: sourceCellId, port: sourcePortId },
|
||||||
selectedNode.portProp(port.id!, 'args/x', nodeWidth);
|
target: { cell: selectedNode.id, port: targetPortId },
|
||||||
|
...edgeAttrs
|
||||||
|
});
|
||||||
|
sourceCell.toFront();
|
||||||
|
bringLoopChildrenToFront(sourceCell);
|
||||||
|
selectedNode.toFront();
|
||||||
|
bringLoopChildrenToFront(selectedNode);
|
||||||
|
}
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
const originalCaseNumber = parseInt(sourcePortId.match(/CASE(\d+)/)?.[1] || '0');
|
||||||
|
if (removedCaseIndex !== undefined && originalCaseNumber === removedCaseIndex + 1) return;
|
||||||
|
let newPortId = sourcePortId;
|
||||||
|
if (removedCaseIndex !== undefined && originalCaseNumber > removedCaseIndex + 1) {
|
||||||
|
newPortId = `CASE${originalCaseNumber - 1}`;
|
||||||
|
}
|
||||||
|
if (newRightPorts.find((p) => p.id === newPortId)) {
|
||||||
|
const targetCell = graph.getCellById(targetCellId);
|
||||||
|
if (targetCell) {
|
||||||
|
graph.addEdge({
|
||||||
|
source: { cell: selectedNode.id, port: newPortId },
|
||||||
|
target: { cell: targetCellId, port: targetPortId },
|
||||||
|
...edgeAttrs
|
||||||
|
});
|
||||||
|
selectedNode.toFront();
|
||||||
|
bringLoopChildrenToFront(selectedNode);
|
||||||
|
targetCell.toFront();
|
||||||
|
bringLoopChildrenToFront(targetCell);
|
||||||
|
}
|
||||||
}
|
}
|
||||||
});
|
});
|
||||||
|
|
||||||
// Add category ports
|
graph.stopBatch('update-ports');
|
||||||
for (let i = 0; i < caseCount; i++) {
|
|
||||||
selectedNode.addPort({
|
|
||||||
id: `CASE${i + 1}`,
|
|
||||||
group: 'right',
|
|
||||||
args: {
|
|
||||||
x: nodeWidth,
|
|
||||||
y: portItemArgsY * i + conditionNodePortItemArgsY,
|
|
||||||
},
|
|
||||||
});
|
|
||||||
}
|
|
||||||
// Restore edge connections
|
|
||||||
setTimeout(() => {
|
|
||||||
edgeConnections.forEach(({ edge, sourcePortId, targetCellId, targetPortId, sourceCellId, isIncoming }: any) => {
|
|
||||||
graphRef.current?.removeCell(edge);
|
|
||||||
|
|
||||||
// If it's an incoming connection (left-side port), restore directly
|
|
||||||
if (isIncoming) {
|
|
||||||
const sourceCell = graphRef.current?.getCellById(sourceCellId);
|
|
||||||
if (sourceCell) {
|
|
||||||
graphRef.current?.addEdge({
|
|
||||||
source: { cell: sourceCellId, port: sourcePortId },
|
|
||||||
target: { cell: selectedNode.id, port: targetPortId },
|
|
||||||
...edgeAttrs
|
|
||||||
});
|
|
||||||
sourceCell.toFront()
|
|
||||||
bringLoopChildrenToFront(sourceCell)
|
|
||||||
selectedNode.toFront()
|
|
||||||
bringLoopChildrenToFront(selectedNode)
|
|
||||||
}
|
|
||||||
return;
|
|
||||||
}
|
|
||||||
|
|
||||||
// Handle right-side port connections
|
|
||||||
const originalCaseNumber = parseInt(sourcePortId.match(/CASE(\d+)/)?.[1] || '0');
|
|
||||||
|
|
||||||
// If it's a removed port, don't recreate the connection
|
|
||||||
if (removedCaseIndex !== undefined && originalCaseNumber === removedCaseIndex + 1) {
|
|
||||||
return;
|
|
||||||
}
|
|
||||||
|
|
||||||
let newPortId = sourcePortId;
|
|
||||||
|
|
||||||
// If a port was removed, remap subsequent port IDs
|
|
||||||
if (removedCaseIndex !== undefined && originalCaseNumber > removedCaseIndex + 1) {
|
|
||||||
newPortId = `CASE${originalCaseNumber - 1}`;
|
|
||||||
}
|
|
||||||
|
|
||||||
// Check if the new port exists
|
|
||||||
const newPorts = selectedNode.getPorts();
|
|
||||||
const matchingPort = newPorts.find((port: any) => port.id === newPortId);
|
|
||||||
|
|
||||||
if (matchingPort) {
|
|
||||||
const targetCell = graphRef.current?.getCellById(targetCellId);
|
|
||||||
if (targetCell) {
|
|
||||||
graphRef.current?.addEdge({
|
|
||||||
source: { cell: selectedNode.id, port: newPortId },
|
|
||||||
target: { cell: targetCellId, port: targetPortId },
|
|
||||||
...edgeAttrs
|
|
||||||
});
|
|
||||||
selectedNode.toFront()
|
|
||||||
bringLoopChildrenToFront(selectedNode)
|
|
||||||
targetCell.toFront()
|
|
||||||
bringLoopChildrenToFront(targetCell)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
});
|
|
||||||
}, 50);
|
|
||||||
};
|
};
|
||||||
|
|
||||||
const handleAddCategory = (addFunc: Function) => {
|
const handleAddCategory = (addFunc: Function) => {
|
||||||
|
|||||||
@@ -124,9 +124,7 @@ export const useWorkflowGraph = ({
|
|||||||
const [canRedo, setCanRedo] = useState(false)
|
const [canRedo, setCanRedo] = useState(false)
|
||||||
const [historyRecords, setHistoryRecords] = useState<HistoryRecord[]>([])
|
const [historyRecords, setHistoryRecords] = useState<HistoryRecord[]>([])
|
||||||
const lastHistoryRef = useRef<{ cellIds: string[]; timestamp: number; type: string } | null>(null)
|
const lastHistoryRef = useRef<{ cellIds: string[]; timestamp: number; type: string } | null>(null)
|
||||||
const undoRef = useRef<() => void>(() => {})
|
const syncChildRelationshipsRef = useRef<() => void>(() => { })
|
||||||
const redoRef = useRef<() => void>(() => {})
|
|
||||||
const syncChildRelationshipsRef = useRef<() => void>(() => {})
|
|
||||||
const isSyncingRef = useRef(false)
|
const isSyncingRef = useRef(false)
|
||||||
useEffect(() => {
|
useEffect(() => {
|
||||||
if (!graphRef.current) return
|
if (!graphRef.current) return
|
||||||
@@ -532,24 +530,82 @@ export const useWorkflowGraph = ({
|
|||||||
const graph = graphRef.current
|
const graph = graphRef.current
|
||||||
graph.disableHistory()
|
graph.disableHistory()
|
||||||
graph.getNodes().forEach(node => {
|
graph.getNodes().forEach(node => {
|
||||||
const cycleId = node.getData()?.cycle
|
const nodeData = node.getData()
|
||||||
if (!cycleId) return
|
|
||||||
const parentNode = graph.getCellById(cycleId) as Node | null
|
|
||||||
if (!parentNode) return
|
|
||||||
if (!parentNode.getChildren()?.some(c => c.id === node.id)) {
|
|
||||||
parentNode.addChild(node, { silent: true })
|
|
||||||
}
|
|
||||||
})
|
|
||||||
graph.getNodes().forEach(node => {
|
|
||||||
const children = node.getChildren()
|
const children = node.getChildren()
|
||||||
if (!children?.length) return
|
|
||||||
children.forEach(child => {
|
const cycleId = nodeData?.cycle
|
||||||
if (!child.isNode()) return
|
|
||||||
const childCycleId = (child as Node).getData?.()?.cycle
|
if (cycleId) {
|
||||||
if (childCycleId !== node.id && childCycleId !== node.getData?.()?.id) {
|
const parentNode = graph.getCellById(cycleId) as Node | null
|
||||||
node.removeChild(child, { silent: true })
|
if (!parentNode) return
|
||||||
|
if (!parentNode.getChildren()?.some(c => c.id === node.id)) {
|
||||||
|
parentNode.addChild(node, { silent: true })
|
||||||
}
|
}
|
||||||
})
|
}
|
||||||
|
|
||||||
|
if (nodeData.type === 'if-else') {
|
||||||
|
const rightPorts = node.getPorts().filter(p => p.group === 'right')
|
||||||
|
const caseCount = rightPorts.length - 1 // last port is ELSE
|
||||||
|
const currentCases: any[] = nodeData.config?.cases?.defaultValue ?? []
|
||||||
|
const newCases = caseCount !== currentCases.length
|
||||||
|
? Array.from({ length: caseCount }, (_, i) => currentCases[i] ?? { logical_operator: 'and', expressions: [] })
|
||||||
|
: currentCases
|
||||||
|
if (caseCount !== currentCases.length) {
|
||||||
|
node.setData({
|
||||||
|
...nodeData,
|
||||||
|
config: { ...nodeData.config, cases: { ...nodeData.config.cases, defaultValue: newCases } }
|
||||||
|
}, { deep: false, silent: true })
|
||||||
|
}
|
||||||
|
// Sync node height and port Y positions
|
||||||
|
node.prop('size', { width: nodeWidth, height: calcConditionNodeTotalHeight(newCases) })
|
||||||
|
newCases.forEach((_c: any, i: number) => {
|
||||||
|
node.portProp(`CASE${i + 1}`, 'args/y', getConditionNodeCasePortY(newCases, i))
|
||||||
|
})
|
||||||
|
node.portProp(`CASE${newCases.length + 1}`, 'args/y', getConditionNodeCasePortY(newCases, newCases.length))
|
||||||
|
node.toFront()
|
||||||
|
graph.getEdges().filter(e => e.getSourceCellId() === node.id).forEach(e => {
|
||||||
|
const tgt = graph.getCellById(e.getTargetCellId())
|
||||||
|
tgt?.toFront()
|
||||||
|
})
|
||||||
|
} else if (nodeData.type === 'question-classifier') {
|
||||||
|
const rightPorts = node.getPorts().filter(p => p.group === 'right')
|
||||||
|
const currentCategories: any[] = nodeData.config?.categories?.defaultValue ?? []
|
||||||
|
const categoryCount = rightPorts.length
|
||||||
|
const newCategories = categoryCount !== currentCategories.length
|
||||||
|
? rightPorts.map((port, i) => {
|
||||||
|
if (currentCategories[i]) return currentCategories[i]
|
||||||
|
const edge = graph.getEdges().find(e => e.getSourceCellId() === node.id && e.getSourcePortId() === port.id)
|
||||||
|
return edge ? { name: '' } : {}
|
||||||
|
})
|
||||||
|
: currentCategories
|
||||||
|
if (categoryCount !== currentCategories.length) {
|
||||||
|
node.setData({
|
||||||
|
...nodeData,
|
||||||
|
config: { ...nodeData.config, categories: { ...nodeData.config.categories, defaultValue: [...newCategories] } }
|
||||||
|
}, { deep: false, silent: true })
|
||||||
|
}
|
||||||
|
// Sync node height and port Y positions
|
||||||
|
const newHeight = conditionNodeHeight + (categoryCount - 2) * conditionNodeItemHeight
|
||||||
|
node.prop('size', { width: nodeWidth, height: Math.max(newHeight, conditionNodeHeight) })
|
||||||
|
rightPorts.forEach((_p, i) => {
|
||||||
|
node.portProp(`CASE${i + 1}`, 'args/y', portItemArgsY * i + conditionNodePortItemArgsY)
|
||||||
|
})
|
||||||
|
node.toFront()
|
||||||
|
graph.getEdges().filter(e => e.getSourceCellId() === node.id).forEach(e => {
|
||||||
|
const tgt = graph.getCellById(e.getTargetCellId())
|
||||||
|
tgt?.toFront()
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
if (children?.length) {
|
||||||
|
children.forEach(child => {
|
||||||
|
if (!child.isNode()) return
|
||||||
|
const childCycleId = (child as Node).getData?.()?.cycle
|
||||||
|
if (childCycleId !== node.id && childCycleId !== node.getData?.()?.id) {
|
||||||
|
node.removeChild(child, { silent: true })
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
})
|
})
|
||||||
resizeGroupNodes(graph)
|
resizeGroupNodes(graph)
|
||||||
graph.getEdges().forEach(edge => {
|
graph.getEdges().forEach(edge => {
|
||||||
|
|||||||
Reference in New Issue
Block a user