feat(memory): add emotion batch extraction task and improve extraction prompts

- Add extract_emotion_batch_task for async emotion extraction
- Refine Chinese entity types and relation types in extraction prompts
- Add STATEMENT_EMOTION_UPDATE Cypher query for Neo4j backfill
- Refactor statement_step and triplet_step implementations
This commit is contained in:
lanceyq
2026-04-24 13:55:14 +08:00
parent a98011fc8a
commit b0ddd12cc6
14 changed files with 1321 additions and 833 deletions

View File

@@ -1370,6 +1370,160 @@ def write_message_task(
_shutdown_loop_gracefully(loop)
@celery_app.task(
bind=True,
name="app.tasks.extract_emotion_batch",
max_retries=2,
default_retry_delay=30,
)
def extract_emotion_batch_task(
self,
statements: List[Dict[str, str]],
llm_model_id: str,
language: str = "zh",
emotion_config: Optional[Dict[str, Any]] = None,
) -> Dict[str, Any]:
"""Celery task: batch emotion extraction + Neo4j backfill.
Runs asynchronously after the main write pipeline completes.
Each statement is processed independently; individual failures
degrade gracefully without affecting other statements.
Args:
statements: List of dicts with keys: statement_id, statement_text, speaker.
llm_model_id: UUID string of the LLM model to use.
language: Language code ("zh" / "en").
emotion_config: Optional dict with emotion step config overrides
(emotion_extract_keywords, emotion_enable_subject).
"""
task_id = self.request.id
total = len(statements)
logger.info(
f"[Emotion] 开始批量情绪提取: "
f"statements={total}, llm_model_id={llm_model_id}, "
f"language={language}, task_id={task_id}"
)
start_time = time.time()
if not statements:
return {"status": "SUCCESS", "total": 0, "extracted": 0, "failed": 0, "task_id": task_id}
async def _run() -> Dict[str, Any]:
from app.core.memory.models.variate_config import ExtractionPipelineConfig
from app.core.memory.storage_services.extraction_engine.steps.base import StepContext
from app.core.memory.storage_services.extraction_engine.steps.emotion_step import EmotionExtractionStep
from app.core.memory.storage_services.extraction_engine.steps.schema import (
EmotionStepInput,
EmotionStepOutput,
)
from app.core.memory.utils.llm.llm_utils import MemoryClientFactory
from app.db import get_db_context
from app.repositories.neo4j.neo4j_connector import Neo4jConnector
from app.repositories.neo4j.cypher_queries import STATEMENT_EMOTION_UPDATE
# Build LLM client
with get_db_context() as db:
factory = MemoryClientFactory(db)
llm_client = factory.get_llm_client(llm_model_id)
# Build minimal pipeline config with emotion enabled
pipeline_config = ExtractionPipelineConfig(emotion_enabled=True)
# Apply optional config overrides
emo_cfg = emotion_config or {}
for key in ("emotion_extract_keywords", "emotion_enable_subject"):
if key in emo_cfg:
setattr(pipeline_config, key, emo_cfg[key])
context = StepContext(
llm_client=llm_client,
language=language,
config=pipeline_config,
)
step = EmotionExtractionStep(context)
# Concurrent extraction for all statements
extracted = 0
failed = 0
update_items = []
async def _extract_one(stmt_dict: Dict[str, str]):
nonlocal extracted, failed
inp = EmotionStepInput(
statement_id=stmt_dict["statement_id"],
statement_text=stmt_dict["statement_text"],
speaker=stmt_dict.get("speaker", "user"),
)
try:
result: EmotionStepOutput = await step.run(inp)
update_items.append({
"statement_id": stmt_dict["statement_id"],
"emotion_type": result.emotion_type,
"emotion_intensity": result.emotion_intensity,
"emotion_keywords": result.emotion_keywords,
})
extracted += 1
logger.debug(
f"[Emotion] 单条提取完成: stmt={stmt_dict['statement_id']}, "
f"type={result.emotion_type}, intensity={result.emotion_intensity}"
)
except Exception as e:
failed += 1
logger.warning(
f"[Emotion] 单条提取失败 stmt={stmt_dict['statement_id']}: {e}"
)
await asyncio.gather(*[_extract_one(s) for s in statements])
# Batch update Neo4j via write transaction
if update_items:
connector = Neo4jConnector()
try:
async def _write_emotions(tx):
result = await tx.run(STATEMENT_EMOTION_UPDATE, items=update_items)
records = [record async for record in result]
return records
records = await connector.execute_write_transaction(_write_emotions)
logger.info(
f"[Emotion] Neo4j 回写完成: "
f"更新 {len(records)}/{len(update_items)} 条 Statement 节点"
)
except Exception as e:
logger.error(f"[Emotion] Neo4j 回写失败: {e}")
raise
finally:
await connector.close()
return {"extracted": extracted, "failed": failed}
loop = None
try:
loop = set_asyncio_event_loop()
result = loop.run_until_complete(_run())
elapsed = time.time() - start_time
logger.info(
f"[Emotion] 任务完成: 提取={result['extracted']}, "
f"失败={result['failed']}, 耗时={elapsed:.2f}s, task_id={task_id}"
)
return {
"status": "SUCCESS",
"total": total,
**result,
"elapsed_time": elapsed,
"task_id": task_id,
}
except Exception as e:
elapsed = time.time() - start_time
logger.error(
f"[Emotion] 任务失败: {e}, 耗时={elapsed:.2f}s",
exc_info=True,
)
raise self.retry(exc=e)
finally:
if loop:
_shutdown_loop_gracefully(loop)
# unused task
# @celery_app.task(name="app.core.memory.agent.health.check_read_service")
# def check_read_service_task() -> Dict[str, str]: