From b9c705998b17365ab81b4039ca3c763b37078e7d Mon Sep 17 00:00:00 2001 From: Ke Sun Date: Tue, 2 Dec 2025 20:29:44 +0800 Subject: [PATCH] chore: remove deprecated codebase and related files --- Dockerfile | 97 - LICENSE | 201 -- alembic.ini | 116 - app/aioRedis.py | 201 -- app/celery_app.py | 109 - app/celery_worker.py | 10 - app/controllers/__init__.py | 60 - app/controllers/api_key_controller.py | 151 -- app/controllers/app_controller.py | 716 ------- app/controllers/auth_controller.py | 195 -- app/controllers/chunk_controller.py | 447 ---- app/controllers/document_controller.py | 341 --- app/controllers/file_controller.py | 453 ---- app/controllers/knowledge_controller.py | 305 --- app/controllers/knowledgeshare_controller.py | 199 -- app/controllers/memory_agent_controller.py | 802 ------- .../memory_dashboard_controller.py | 516 ----- app/controllers/memory_storage_controller.py | 542 ----- app/controllers/model_controller.py | 332 --- app/controllers/multi_agent_controller.py | 404 ---- app/controllers/public_share_controller.py | 437 ---- app/controllers/release_share_controller.py | 170 -- app/controllers/service/__init__.py | 17 - app/controllers/service/app_api_controller.py | 16 - .../service/memory_api_controller.py | 16 - app/controllers/service/rag_api_controller.py | 16 - app/controllers/setup_controller.py | 23 - app/controllers/task_controller.py | 25 - app/controllers/test_controller.py | 126 -- app/controllers/upload_controller.py | 376 ---- app/controllers/user_controller.py | 183 -- app/controllers/workspace_controller.py | 342 --- app/core/agent/__init__.py | 0 app/core/agent/agent_api_text.py | 35 - app/core/agent/agent_chat.py | 109 - app/core/agent/langchain_agent.py | 347 --- app/core/api_key_utils.py | 56 - app/core/compensation.py | 47 - app/core/config.py | 237 -- app/core/error_codes.py | 130 -- app/core/exceptions.py | 86 - app/core/logging_config.py | 633 ------ app/core/memory/__init__.py | 0 app/core/memory/agent/__init__.py | 0 .../memory/agent/langgraph_graph/__init__.py | 16 - .../agent/langgraph_graph/nodes/__init__.py | 10 - .../agent/langgraph_graph/nodes/input_node.py | 144 -- .../agent/langgraph_graph/nodes/tool_node.py | 199 -- .../agent/langgraph_graph/read_graph.py | 508 ----- .../agent/langgraph_graph/routing/__init__.py | 13 - .../agent/langgraph_graph/routing/routers.py | 123 -- .../agent/langgraph_graph/state/__init__.py | 13 - .../agent/langgraph_graph/state/extractors.py | 164 -- .../agent/langgraph_graph/write_graph.py | 78 - .../memory/agent/logger_file/log_streamer.py | 285 --- .../memory/agent/logger_file/logger_data.py | 32 - app/core/memory/agent/mcp_server/__init__.py | 28 - .../memory/agent/mcp_server/mcp_instance.py | 11 - .../agent/mcp_server/models/__init__.py | 30 - .../agent/mcp_server/models/problem_models.py | 34 - .../mcp_server/models/retrieval_models.py | 17 - .../agent/mcp_server/models/summary_models.py | 31 - .../mcp_server/models/verification_models.py | 14 - app/core/memory/agent/mcp_server/server.py | 161 -- .../agent/mcp_server/services/__init__.py | 23 - .../mcp_server/services/parameter_builder.py | 157 -- .../mcp_server/services/search_service.py | 193 -- .../mcp_server/services/session_service.py | 169 -- .../mcp_server/services/template_service.py | 116 - .../memory/agent/mcp_server/tools/__init__.py | 27 - .../agent/mcp_server/tools/data_tools.py | 149 -- .../agent/mcp_server/tools/problem_tools.py | 293 --- .../agent/mcp_server/tools/retrieval_tools.py | 282 --- .../agent/mcp_server/tools/summary_tools.py | 647 ------ .../mcp_server/tools/verification_tools.py | 169 -- app/core/memory/agent/utils/__init__.py | 7 - app/core/memory/agent/utils/get_dialogs.py | 70 - app/core/memory/agent/utils/llm_tools.py | 204 -- app/core/memory/agent/utils/mcp_tools.py | 15 - app/core/memory/agent/utils/messages_tool.py | 239 --- app/core/memory/agent/utils/model_tool.py | 38 - app/core/memory/agent/utils/multimodal.py | 131 -- .../prompt/Problem_Extension_prompt.jinja2 | 81 - .../prompt/Retrieve_Summary_prompt.jinja2 | 37 - .../agent/utils/prompt/Retrieve_prompt.jinja2 | 29 - ...mplate_for_image_recognition_prompt.jinja2 | 10 - .../prompt/distinguish_types_prompt.jinja2 | 34 - .../prompt/problem_breakdown_prompt.jinja2 | 160 -- .../utils/prompt/split_verify_prompt.jinja2 | 60 - .../agent/utils/prompt/summary_prompt.jinja2 | 57 - app/core/memory/agent/utils/redis_tool.py | 203 -- .../memory/agent/utils/type_classifier.py | 59 - app/core/memory/agent/utils/verify_tool.py | 76 - .../memory/agent/utils/write_to_database.py | 49 - app/core/memory/agent/utils/write_tools.py | 183 -- app/core/memory/llm_tools/__init__.py | 19 - app/core/memory/llm_tools/chunker_client.py | 330 --- app/core/memory/llm_tools/embedder_client.py | 176 -- app/core/memory/llm_tools/llm_client.py | 187 -- app/core/memory/llm_tools/openai_client.py | 198 -- app/core/memory/llm_tools/openai_embedder.py | 87 - app/core/memory/main.py | 332 --- app/core/memory/models/__init__.py | 115 - app/core/memory/models/base_response.py | 59 - app/core/memory/models/config_models.py | 93 - app/core/memory/models/dedup_models.py | 52 - app/core/memory/models/graph_models.py | 304 --- app/core/memory/models/message_models.py | 247 --- app/core/memory/models/triplet_models.py | 85 - app/core/memory/models/variate_config.py | 151 -- app/core/memory/src/__init__.py | 0 app/core/memory/src/llm_tools/__init__.py | 0 .../memory/src/llm_tools/chunker_client.py | 330 --- .../memory/src/llm_tools/embedder_client.py | 22 - app/core/memory/src/llm_tools/llm_client.py | 37 - .../memory/src/llm_tools/openai_client.py | 224 -- .../memory/src/llm_tools/openai_embedder.py | 26 - app/core/memory/src/search.py | 980 --------- app/core/memory/storage_services/__init__.py | 8 - .../extraction_engine/__init__.py | 8 - .../data_preprocessing/__init__.py | 13 - .../data_preprocessing/data_chunker.py | 54 - .../data_preprocessing/data_preprocessor.py | 785 ------- .../data_preprocessing/data_pruning.py | 573 ----- .../deduplication/__init__.py | 41 - .../deduplication/deduped_and_disamb.py | 784 ------- .../deduplication/entity_dedup_llm.py | 689 ------ .../deduplication/second_layer_dedup.py | 149 -- .../deduplication/two_stage_dedup.py | 106 - .../extraction_orchestrator.py | 1306 ----------- .../knowledge_extraction/__init__.py | 11 - .../knowledge_extraction/chunk_extraction.py | 103 - .../embedding_generation.py | 307 --- .../knowledge_extraction/memory_summary.py | 117 - .../statement_extraction.py | 301 --- .../temporal_extraction.py | 222 -- .../triplet_extraction.py | 223 -- .../extraction_engine/pipeline_help.py | 528 ----- .../forgetting_engine/__init__.py | 8 - .../forgetting_engine/forgetting_engine.py | 271 --- .../forgetting_engine/memory_strength.py | 251 --- .../reflection_engine/__init__.py | 21 - .../reflection_engine/self_reflexion.py | 585 ----- .../storage_services/search/__init__.py | 131 -- .../storage_services/search/hybrid_chatbot.py | 447 ---- .../storage_services/search/hybrid_search.py | 408 ---- .../storage_services/search/keyword_search.py | 122 -- .../search/search_strategy.py | 125 -- .../search/semantic_search.py | 159 -- app/core/memory/utils/README.md | 445 ---- app/core/memory/utils/__init__.py | 65 - app/core/memory/utils/config/__init__.py | 82 - .../utils/config/config_optimization.py | 398 ---- app/core/memory/utils/config/config_utils.py | 267 --- app/core/memory/utils/config/definitions.py | 360 ---- app/core/memory/utils/config/get_data.py | 93 - .../memory/utils/config/get_example_data.py | 90 - .../memory/utils/config/litellm_config.py | 516 ----- app/core/memory/utils/config/overrides.py | 611 ------ app/core/memory/utils/data/__init__.py | 43 - app/core/memory/utils/data/ontology.py | 199 -- app/core/memory/utils/data/text_utils.py | 81 - app/core/memory/utils/data/time_utils.py | 127 -- app/core/memory/utils/llm/__init__.py | 18 - app/core/memory/utils/llm/llm_utils.py | 77 - app/core/memory/utils/log/__init__.py | 24 - app/core/memory/utils/log/audit_logger.py | 182 -- app/core/memory/utils/log/logging_utils.py | 38 - app/core/memory/utils/paths/__init__.py | 16 - app/core/memory/utils/paths/output_paths.py | 133 -- app/core/memory/utils/prompt/__init__.py | 34 - app/core/memory/utils/prompt/prompt_utils.py | 240 --- .../utils/prompt/prompts/entity_dedup.jinja2 | 60 - .../utils/prompt/prompts/evaluate.jinja2 | 19 - .../prompt/prompts/extracat_Pruning.jinja2 | 49 - .../prompt/prompts/extract_statement.jinja2 | 207 -- .../prompt/prompts/extract_temporal.jinja2 | 81 - .../prompt/prompts/extract_triplet.jinja2 | 248 --- .../prompt/prompts/memory_summary.jinja2 | 29 - .../utils/prompt/prompts/reflexion.jinja2 | 23 - .../memory/utils/prompt/prompts/system.jinja2 | 2 - .../memory/utils/prompt/prompts/user.jinja2 | 5 - .../memory/utils/prompt/template_render.py | 42 - .../utils/self_reflexion_utils/__init__.py | 16 - .../utils/self_reflexion_utils/evaluate.py | 49 - .../utils/self_reflexion_utils/reflexion.py | 51 - .../self_reflexion_utils/self_reflexion.py | 250 --- .../memory/utils/visualization/__init__.py | 26 - .../visualization/forgetting_visualizer.py | 386 ---- app/core/models/__init__.py | 13 - app/core/models/base.py | 167 -- app/core/models/embedding.py | 23 - app/core/models/factory.py | 16 - app/core/models/llm.py | 133 -- app/core/models/rerank copy.py | 35 - app/core/models/rerank.py | 80 - app/core/permissions/__init__.py | 17 - app/core/permissions/models.py | 133 -- app/core/permissions/policies.py | 151 -- app/core/permissions/service.py | 176 -- app/core/rag/__init__.py | 0 app/core/rag/app/__init__.py | 0 app/core/rag/app/audio.py | 42 - app/core/rag/app/book.py | 170 -- app/core/rag/app/laws.py | 219 -- app/core/rag/app/mail.py | 114 - app/core/rag/app/manual.py | 299 --- app/core/rag/app/naive.py | 849 -------- app/core/rag/app/one.py | 149 -- app/core/rag/app/paper.py | 284 --- app/core/rag/app/picture.py | 96 - app/core/rag/app/presentation.py | 164 -- app/core/rag/app/qa.py | 455 ---- app/core/rag/common/__init__.py | 0 app/core/rag/common/connection_utils.py | 106 - app/core/rag/common/constants.py | 180 -- app/core/rag/common/file_utils.py | 28 - app/core/rag/common/float_utils.py | 30 - app/core/rag/common/misc_utils.py | 92 - app/core/rag/common/settings.py | 2 - app/core/rag/common/string_utils.py | 57 - app/core/rag/common/token_utils.py | 59 - app/core/rag/deepdoc/README.md | 122 -- app/core/rag/deepdoc/README_zh.md | 116 - app/core/rag/deepdoc/__init__.py | 2 - app/core/rag/deepdoc/parser/__init__.py | 24 - app/core/rag/deepdoc/parser/docx_parser.py | 123 -- app/core/rag/deepdoc/parser/excel_parser.py | 210 -- app/core/rag/deepdoc/parser/figure_parser.py | 118 - app/core/rag/deepdoc/parser/html_parser.py | 197 -- app/core/rag/deepdoc/parser/json_parser.py | 159 -- .../rag/deepdoc/parser/markdown_parser.py | 277 --- app/core/rag/deepdoc/parser/mineru_parser.py | 524 ----- app/core/rag/deepdoc/parser/pdf_parser.py | 1387 ------------ app/core/rag/deepdoc/parser/ppt_parser.py | 83 - app/core/rag/deepdoc/parser/txt_parser.py | 48 - app/core/rag/deepdoc/parser/utils.py | 16 - app/core/rag/deepdoc/vision/__init__.py | 75 - .../rag/deepdoc/vision/layout_recognizer.py | 440 ---- app/core/rag/deepdoc/vision/ocr.py | 737 ------- app/core/rag/deepdoc/vision/operators.py | 709 ------ app/core/rag/deepdoc/vision/postprocess.py | 354 --- app/core/rag/deepdoc/vision/recognizer.py | 427 ---- app/core/rag/deepdoc/vision/seeit.py | 71 - app/core/rag/deepdoc/vision/t_ocr.py | 77 - app/core/rag/deepdoc/vision/t_recognizer.py | 170 -- .../vision/table_structure_recognizer.py | 597 ------ app/core/rag/graphrag/__init__.py | 0 app/core/rag/graphrag/utils.py | 19 - app/core/rag/llm/__init__.py | 0 app/core/rag/llm/chat_model.py | 670 ------ app/core/rag/llm/cv_model.py | 470 ---- app/core/rag/llm/sequence2txt_model.py | 179 -- app/core/rag/models/__init__.py | 0 app/core/rag/models/chunk.py | 72 - app/core/rag/nlp/__init__.py | 857 -------- app/core/rag/nlp/query.py | 261 --- app/core/rag/nlp/rag_tokenizer.py | 499 ----- app/core/rag/nlp/search.py | 192 -- app/core/rag/nlp/surname.py | 126 -- app/core/rag/nlp/synonym.py | 85 - app/core/rag/nlp/term_weight.py | 228 -- app/core/rag/prompts/__init__.py | 6 - app/core/rag/prompts/analyze_task_system.md | 48 - app/core/rag/prompts/analyze_task_user.md | 9 - app/core/rag/prompts/ask_summary.md | 14 - app/core/rag/prompts/assign_toc_levels.md | 53 - app/core/rag/prompts/citation_plus.md | 13 - app/core/rag/prompts/citation_prompt.md | 109 - .../rag/prompts/content_tagging_prompt.md | 32 - .../rag/prompts/cross_languages_sys_prompt.md | 35 - .../prompts/cross_languages_user_prompt.md | 7 - app/core/rag/prompts/full_question_prompt.md | 62 - app/core/rag/prompts/generator.py | 728 ------- app/core/rag/prompts/keyword_prompt.md | 16 - app/core/rag/prompts/meta_filter.md | 53 - app/core/rag/prompts/next_step.md | 92 - app/core/rag/prompts/question_prompt.md | 19 - app/core/rag/prompts/rank_memory.md | 30 - app/core/rag/prompts/reflect.md | 75 - app/core/rag/prompts/related_question.md | 55 - .../rag/prompts/structured_output_prompt.md | 16 - app/core/rag/prompts/summary4memory.md | 35 - app/core/rag/prompts/template.py | 20 - app/core/rag/prompts/toc_detection.md | 29 - app/core/rag/prompts/toc_extraction.md | 53 - .../rag/prompts/toc_extraction_continue.md | 60 - app/core/rag/prompts/toc_from_text_system.md | 119 -- app/core/rag/prompts/toc_from_text_user.md | 8 - app/core/rag/prompts/toc_index.md | 20 - app/core/rag/prompts/toc_relevance_system.md | 118 - app/core/rag/prompts/toc_relevance_user.md | 17 - app/core/rag/prompts/tool_call_summary.md | 19 - .../rag/prompts/vision_llm_describe_prompt.md | 23 - .../vision_llm_figure_describe_prompt.md | 24 - app/core/rag/utils/__init__.py | 0 app/core/rag/utils/doc_store_conn.py | 255 --- app/core/rag/utils/file_utils.py | 247 --- app/core/rag/vdb/__init__.py | 0 app/core/rag/vdb/elasticsearch/__init__.py | 0 .../vdb/elasticsearch/elasticsearch_vector.py | 779 ------- app/core/rag/vdb/field.py | 16 - app/core/rag/vdb/vector_base.py | 67 - app/core/rag_utils/README.md | 116 - app/core/rag_utils/__init__.py | 14 - app/core/rag_utils/chunk_insight.py | 205 -- app/core/rag_utils/chunk_summary.py | 99 - app/core/rag_utils/chunk_tags.py | 191 -- app/core/response_utils.py | 22 - app/core/security.py | 126 -- app/core/sensitive_filter.py | 210 -- app/core/share_utils.py | 94 - app/core/storage_strategy.py | 198 -- app/core/transaction_monitor.py | 230 -- app/core/uow.py | 265 --- app/core/upload_enums.py | 10 - app/core/upload_policies.py | 80 - app/core/validators/__init__.py | 6 - app/core/validators/file_validator.py | 357 ---- app/core/workflow/__init__.py | 0 app/db.py | 20 - app/dependencies.py | 459 ---- app/main.py | 382 ---- app/models/__init__.py | 52 - app/models/agent_app_config_model.py | 44 - app/models/api_key_model.py | 90 - app/models/app_model.py | 115 - app/models/app_release_model.py | 68 - app/models/appshare_model.py | 28 - app/models/conversation_model.py | 80 - app/models/data_config_model.py | 71 - app/models/document_model.py | 28 - app/models/end_user_model.py | 24 - app/models/file_model.py | 17 - app/models/generic_file_model.py | 52 - app/models/knowledge_model.py | 69 - app/models/knowledgeshare_model.py | 24 - app/models/memory_increment_model.py | 18 - app/models/models_model.py | 104 - app/models/multi_agent_model.py | 143 -- app/models/release_share_model.py | 47 - app/models/retrieval_info.py | 13 - app/models/tenant_model.py | 23 - app/models/user_model.py | 30 - app/models/workspace_model.py | 70 - app/repositories/__init__.py | 171 -- app/repositories/api_key_repository.py | 138 -- app/repositories/app_repository.py | 30 - app/repositories/base_repository.py | 108 - app/repositories/data_config_repository.py | 408 ---- app/repositories/document_repository.py | 153 -- app/repositories/end_user_repository.py | 105 - app/repositories/file_repository.py | 121 -- app/repositories/generic_file_repository.py | 243 --- app/repositories/knowledge_repository.py | 211 -- app/repositories/knowledgeshare_repository.py | 142 -- .../memory_increment_repository.py | 110 - app/repositories/model_repository.py | 386 ---- app/repositories/neo4j/__init__.py | 32 - app/repositories/neo4j/add_edges.py | 102 - app/repositories/neo4j/add_nodes.py | 215 -- .../neo4j/base_neo4j_repository.py | 175 -- app/repositories/neo4j/create_indexes.py | 332 --- app/repositories/neo4j/cypher_queries.py | 684 ------ app/repositories/neo4j/dialog_repository.py | 185 -- app/repositories/neo4j/entity_repository.py | 339 --- app/repositories/neo4j/graph_saver.py | 216 -- app/repositories/neo4j/graph_search.py | 584 ----- app/repositories/neo4j/neo4j_connector.py | 114 - .../neo4j/statement_repository.py | 319 --- app/repositories/release_share_repository.py | 59 - app/repositories/tenant_repository.py | 167 -- app/repositories/user_repository.py | 322 --- .../workspace_invite_repository.py | 134 -- app/repositories/workspace_repository.py | 383 ---- app/schemas/__init__.py | 108 - app/schemas/api_key_schema.py | 104 - app/schemas/app_schema.py | 425 ---- app/schemas/chunk_schema.py | 26 - app/schemas/conversation_schema.py | 86 - app/schemas/document_schema.py | 63 - app/schemas/end_user_schema.py | 17 - app/schemas/file_schema.py | 39 - app/schemas/generic_file_schema.py | 69 - app/schemas/item_schema.py | 5 - app/schemas/knowledge_schema.py | 69 - app/schemas/knowledgeshare_schema.py | 37 - app/schemas/memory_agent_schema.py | 17 - app/schemas/memory_increment_schema.py | 18 - app/schemas/memory_storage_schema.py | 343 --- app/schemas/model_schema.py | 162 -- app/schemas/multi_agent_schema.py | 167 -- app/schemas/prompt_schema.py | 61 - app/schemas/release_share_schema.py | 104 - app/schemas/response_schema.py | 22 - app/schemas/retrieval_info_schema.py | 13 - app/schemas/tenant_schema.py | 65 - app/schemas/token_schema.py | 30 - app/schemas/user_schema.py | 76 - app/schemas/workspace_schema.py | 172 -- app/services/__init__.py | 0 app/services/agent_config_converter.py | 116 - app/services/agent_config_helper.py | 38 - app/services/agent_invocation_service.py | 0 app/services/agent_registry.py | 191 -- app/services/agent_server.py | 130 -- app/services/agent_tools.py | 331 --- app/services/api_key_service.py | 173 -- app/services/app_service.py | 1903 ----------------- app/services/auth_service.py | 262 --- app/services/conversation_service.py | 229 -- app/services/conversation_state_manager.py | 261 --- app/services/document_service.py | 85 - app/services/draft_run_service.py | 1630 -------------- app/services/file_service.py | 87 - app/services/knowledge_service.py | 126 -- app/services/knowledgeshare_service.py | 108 - app/services/langchain_tool_server.py | 51 - app/services/llm_client.py | 340 --- app/services/llm_router.py | 685 ------ app/services/memory_agent_service.py | 1035 --------- app/services/memory_dashboard_service.py | 595 ------ app/services/memory_konwledges_server.py | 582 ----- app/services/memory_storage_service.py | 568 ----- app/services/model_parameter_merger.py | 160 -- app/services/model_service.py | 409 ---- app/services/multi_agent_config_converter.py | 191 -- app/services/multi_agent_orchestrator.py | 1116 ---------- app/services/multi_agent_service.py | 630 ------ app/services/release_share_service.py | 444 ---- app/services/session_service.py | 160 -- app/services/shared_chat_service.py | 759 ------- app/services/smart_router.py | 426 ---- app/services/task_service.py | 52 - app/services/tenant_service.py | 220 -- app/services/upload_service.py | 617 ------ app/services/user_service.py | 570 ----- app/services/workspace_service.py | 776 ------- app/tasks.py | 451 ---- app/utils/volc_asr.py | 112 - docker-compose.yml | 22 - env.example | 87 - main.py | 6 - migrations/README | 1 - migrations/env.py | 141 -- migrations/script.py.mako | 26 - pyproject.toml | 137 -- 447 files changed, 82854 deletions(-) delete mode 100644 Dockerfile delete mode 100644 LICENSE delete mode 100644 alembic.ini delete mode 100644 app/aioRedis.py delete mode 100644 app/celery_app.py delete mode 100644 app/celery_worker.py delete mode 100644 app/controllers/__init__.py delete mode 100644 app/controllers/api_key_controller.py delete mode 100644 app/controllers/app_controller.py delete mode 100644 app/controllers/auth_controller.py delete mode 100644 app/controllers/chunk_controller.py delete mode 100644 app/controllers/document_controller.py delete mode 100644 app/controllers/file_controller.py delete mode 100644 app/controllers/knowledge_controller.py delete mode 100644 app/controllers/knowledgeshare_controller.py delete mode 100644 app/controllers/memory_agent_controller.py delete mode 100644 app/controllers/memory_dashboard_controller.py delete mode 100644 app/controllers/memory_storage_controller.py delete mode 100644 app/controllers/model_controller.py delete mode 100644 app/controllers/multi_agent_controller.py delete mode 100644 app/controllers/public_share_controller.py delete mode 100644 app/controllers/release_share_controller.py delete mode 100644 app/controllers/service/__init__.py delete mode 100644 app/controllers/service/app_api_controller.py delete mode 100644 app/controllers/service/memory_api_controller.py delete mode 100644 app/controllers/service/rag_api_controller.py delete mode 100644 app/controllers/setup_controller.py delete mode 100644 app/controllers/task_controller.py delete mode 100644 app/controllers/test_controller.py delete mode 100644 app/controllers/upload_controller.py delete mode 100644 app/controllers/user_controller.py delete mode 100644 app/controllers/workspace_controller.py delete mode 100644 app/core/agent/__init__.py delete mode 100644 app/core/agent/agent_api_text.py delete mode 100644 app/core/agent/agent_chat.py delete mode 100644 app/core/agent/langchain_agent.py delete mode 100644 app/core/api_key_utils.py delete mode 100644 app/core/compensation.py delete mode 100644 app/core/config.py delete mode 100644 app/core/error_codes.py delete mode 100644 app/core/exceptions.py delete mode 100644 app/core/logging_config.py delete mode 100644 app/core/memory/__init__.py delete mode 100644 app/core/memory/agent/__init__.py delete mode 100644 app/core/memory/agent/langgraph_graph/__init__.py delete mode 100644 app/core/memory/agent/langgraph_graph/nodes/__init__.py delete mode 100644 app/core/memory/agent/langgraph_graph/nodes/input_node.py delete mode 100644 app/core/memory/agent/langgraph_graph/nodes/tool_node.py delete mode 100644 app/core/memory/agent/langgraph_graph/read_graph.py delete mode 100644 app/core/memory/agent/langgraph_graph/routing/__init__.py delete mode 100644 app/core/memory/agent/langgraph_graph/routing/routers.py delete mode 100644 app/core/memory/agent/langgraph_graph/state/__init__.py delete mode 100644 app/core/memory/agent/langgraph_graph/state/extractors.py delete mode 100644 app/core/memory/agent/langgraph_graph/write_graph.py delete mode 100644 app/core/memory/agent/logger_file/log_streamer.py delete mode 100644 app/core/memory/agent/logger_file/logger_data.py delete mode 100644 app/core/memory/agent/mcp_server/__init__.py delete mode 100644 app/core/memory/agent/mcp_server/mcp_instance.py delete mode 100644 app/core/memory/agent/mcp_server/models/__init__.py delete mode 100644 app/core/memory/agent/mcp_server/models/problem_models.py delete mode 100644 app/core/memory/agent/mcp_server/models/retrieval_models.py delete mode 100644 app/core/memory/agent/mcp_server/models/summary_models.py delete mode 100644 app/core/memory/agent/mcp_server/models/verification_models.py delete mode 100644 app/core/memory/agent/mcp_server/server.py delete mode 100644 app/core/memory/agent/mcp_server/services/__init__.py delete mode 100644 app/core/memory/agent/mcp_server/services/parameter_builder.py delete mode 100644 app/core/memory/agent/mcp_server/services/search_service.py delete mode 100644 app/core/memory/agent/mcp_server/services/session_service.py delete mode 100644 app/core/memory/agent/mcp_server/services/template_service.py delete mode 100644 app/core/memory/agent/mcp_server/tools/__init__.py delete mode 100644 app/core/memory/agent/mcp_server/tools/data_tools.py delete mode 100644 app/core/memory/agent/mcp_server/tools/problem_tools.py delete mode 100644 app/core/memory/agent/mcp_server/tools/retrieval_tools.py delete mode 100644 app/core/memory/agent/mcp_server/tools/summary_tools.py delete mode 100644 app/core/memory/agent/mcp_server/tools/verification_tools.py delete mode 100644 app/core/memory/agent/utils/__init__.py delete mode 100644 app/core/memory/agent/utils/get_dialogs.py delete mode 100644 app/core/memory/agent/utils/llm_tools.py delete mode 100644 app/core/memory/agent/utils/mcp_tools.py delete mode 100644 app/core/memory/agent/utils/messages_tool.py delete mode 100644 app/core/memory/agent/utils/model_tool.py delete mode 100644 app/core/memory/agent/utils/multimodal.py delete mode 100644 app/core/memory/agent/utils/prompt/Problem_Extension_prompt.jinja2 delete mode 100644 app/core/memory/agent/utils/prompt/Retrieve_Summary_prompt.jinja2 delete mode 100644 app/core/memory/agent/utils/prompt/Retrieve_prompt.jinja2 delete mode 100644 app/core/memory/agent/utils/prompt/Template_for_image_recognition_prompt.jinja2 delete mode 100644 app/core/memory/agent/utils/prompt/distinguish_types_prompt.jinja2 delete mode 100644 app/core/memory/agent/utils/prompt/problem_breakdown_prompt.jinja2 delete mode 100644 app/core/memory/agent/utils/prompt/split_verify_prompt.jinja2 delete mode 100644 app/core/memory/agent/utils/prompt/summary_prompt.jinja2 delete mode 100644 app/core/memory/agent/utils/redis_tool.py delete mode 100644 app/core/memory/agent/utils/type_classifier.py delete mode 100644 app/core/memory/agent/utils/verify_tool.py delete mode 100644 app/core/memory/agent/utils/write_to_database.py delete mode 100644 app/core/memory/agent/utils/write_tools.py delete mode 100644 app/core/memory/llm_tools/__init__.py delete mode 100644 app/core/memory/llm_tools/chunker_client.py delete mode 100644 app/core/memory/llm_tools/embedder_client.py delete mode 100644 app/core/memory/llm_tools/llm_client.py delete mode 100644 app/core/memory/llm_tools/openai_client.py delete mode 100644 app/core/memory/llm_tools/openai_embedder.py delete mode 100644 app/core/memory/main.py delete mode 100644 app/core/memory/models/__init__.py delete mode 100644 app/core/memory/models/base_response.py delete mode 100644 app/core/memory/models/config_models.py delete mode 100644 app/core/memory/models/dedup_models.py delete mode 100644 app/core/memory/models/graph_models.py delete mode 100644 app/core/memory/models/message_models.py delete mode 100644 app/core/memory/models/triplet_models.py delete mode 100644 app/core/memory/models/variate_config.py delete mode 100644 app/core/memory/src/__init__.py delete mode 100644 app/core/memory/src/llm_tools/__init__.py delete mode 100644 app/core/memory/src/llm_tools/chunker_client.py delete mode 100644 app/core/memory/src/llm_tools/embedder_client.py delete mode 100644 app/core/memory/src/llm_tools/llm_client.py delete mode 100644 app/core/memory/src/llm_tools/openai_client.py delete mode 100644 app/core/memory/src/llm_tools/openai_embedder.py delete mode 100644 app/core/memory/src/search.py delete mode 100644 app/core/memory/storage_services/__init__.py delete mode 100644 app/core/memory/storage_services/extraction_engine/__init__.py delete mode 100644 app/core/memory/storage_services/extraction_engine/data_preprocessing/__init__.py delete mode 100644 app/core/memory/storage_services/extraction_engine/data_preprocessing/data_chunker.py delete mode 100644 app/core/memory/storage_services/extraction_engine/data_preprocessing/data_preprocessor.py delete mode 100644 app/core/memory/storage_services/extraction_engine/data_preprocessing/data_pruning.py delete mode 100644 app/core/memory/storage_services/extraction_engine/deduplication/__init__.py delete mode 100644 app/core/memory/storage_services/extraction_engine/deduplication/deduped_and_disamb.py delete mode 100644 app/core/memory/storage_services/extraction_engine/deduplication/entity_dedup_llm.py delete mode 100644 app/core/memory/storage_services/extraction_engine/deduplication/second_layer_dedup.py delete mode 100644 app/core/memory/storage_services/extraction_engine/deduplication/two_stage_dedup.py delete mode 100644 app/core/memory/storage_services/extraction_engine/extraction_orchestrator.py delete mode 100644 app/core/memory/storage_services/extraction_engine/knowledge_extraction/__init__.py delete mode 100644 app/core/memory/storage_services/extraction_engine/knowledge_extraction/chunk_extraction.py delete mode 100644 app/core/memory/storage_services/extraction_engine/knowledge_extraction/embedding_generation.py delete mode 100644 app/core/memory/storage_services/extraction_engine/knowledge_extraction/memory_summary.py delete mode 100644 app/core/memory/storage_services/extraction_engine/knowledge_extraction/statement_extraction.py delete mode 100644 app/core/memory/storage_services/extraction_engine/knowledge_extraction/temporal_extraction.py delete mode 100644 app/core/memory/storage_services/extraction_engine/knowledge_extraction/triplet_extraction.py delete mode 100644 app/core/memory/storage_services/extraction_engine/pipeline_help.py delete mode 100644 app/core/memory/storage_services/forgetting_engine/__init__.py delete mode 100644 app/core/memory/storage_services/forgetting_engine/forgetting_engine.py delete mode 100644 app/core/memory/storage_services/forgetting_engine/memory_strength.py delete mode 100644 app/core/memory/storage_services/reflection_engine/__init__.py delete mode 100644 app/core/memory/storage_services/reflection_engine/self_reflexion.py delete mode 100644 app/core/memory/storage_services/search/__init__.py delete mode 100644 app/core/memory/storage_services/search/hybrid_chatbot.py delete mode 100644 app/core/memory/storage_services/search/hybrid_search.py delete mode 100644 app/core/memory/storage_services/search/keyword_search.py delete mode 100644 app/core/memory/storage_services/search/search_strategy.py delete mode 100644 app/core/memory/storage_services/search/semantic_search.py delete mode 100644 app/core/memory/utils/README.md delete mode 100644 app/core/memory/utils/__init__.py delete mode 100644 app/core/memory/utils/config/__init__.py delete mode 100644 app/core/memory/utils/config/config_optimization.py delete mode 100644 app/core/memory/utils/config/config_utils.py delete mode 100644 app/core/memory/utils/config/definitions.py delete mode 100644 app/core/memory/utils/config/get_data.py delete mode 100644 app/core/memory/utils/config/get_example_data.py delete mode 100644 app/core/memory/utils/config/litellm_config.py delete mode 100644 app/core/memory/utils/config/overrides.py delete mode 100644 app/core/memory/utils/data/__init__.py delete mode 100644 app/core/memory/utils/data/ontology.py delete mode 100644 app/core/memory/utils/data/text_utils.py delete mode 100644 app/core/memory/utils/data/time_utils.py delete mode 100644 app/core/memory/utils/llm/__init__.py delete mode 100644 app/core/memory/utils/llm/llm_utils.py delete mode 100644 app/core/memory/utils/log/__init__.py delete mode 100644 app/core/memory/utils/log/audit_logger.py delete mode 100644 app/core/memory/utils/log/logging_utils.py delete mode 100644 app/core/memory/utils/paths/__init__.py delete mode 100644 app/core/memory/utils/paths/output_paths.py delete mode 100644 app/core/memory/utils/prompt/__init__.py delete mode 100644 app/core/memory/utils/prompt/prompt_utils.py delete mode 100644 app/core/memory/utils/prompt/prompts/entity_dedup.jinja2 delete mode 100644 app/core/memory/utils/prompt/prompts/evaluate.jinja2 delete mode 100644 app/core/memory/utils/prompt/prompts/extracat_Pruning.jinja2 delete mode 100644 app/core/memory/utils/prompt/prompts/extract_statement.jinja2 delete mode 100644 app/core/memory/utils/prompt/prompts/extract_temporal.jinja2 delete mode 100644 app/core/memory/utils/prompt/prompts/extract_triplet.jinja2 delete mode 100644 app/core/memory/utils/prompt/prompts/memory_summary.jinja2 delete mode 100644 app/core/memory/utils/prompt/prompts/reflexion.jinja2 delete mode 100644 app/core/memory/utils/prompt/prompts/system.jinja2 delete mode 100644 app/core/memory/utils/prompt/prompts/user.jinja2 delete mode 100644 app/core/memory/utils/prompt/template_render.py delete mode 100644 app/core/memory/utils/self_reflexion_utils/__init__.py delete mode 100644 app/core/memory/utils/self_reflexion_utils/evaluate.py delete mode 100644 app/core/memory/utils/self_reflexion_utils/reflexion.py delete mode 100644 app/core/memory/utils/self_reflexion_utils/self_reflexion.py delete mode 100644 app/core/memory/utils/visualization/__init__.py delete mode 100644 app/core/memory/utils/visualization/forgetting_visualizer.py delete mode 100644 app/core/models/__init__.py delete mode 100644 app/core/models/base.py delete mode 100644 app/core/models/embedding.py delete mode 100644 app/core/models/factory.py delete mode 100644 app/core/models/llm.py delete mode 100644 app/core/models/rerank copy.py delete mode 100644 app/core/models/rerank.py delete mode 100644 app/core/permissions/__init__.py delete mode 100644 app/core/permissions/models.py delete mode 100644 app/core/permissions/policies.py delete mode 100644 app/core/permissions/service.py delete mode 100644 app/core/rag/__init__.py delete mode 100644 app/core/rag/app/__init__.py delete mode 100644 app/core/rag/app/audio.py delete mode 100644 app/core/rag/app/book.py delete mode 100644 app/core/rag/app/laws.py delete mode 100644 app/core/rag/app/mail.py delete mode 100644 app/core/rag/app/manual.py delete mode 100644 app/core/rag/app/naive.py delete mode 100644 app/core/rag/app/one.py delete mode 100644 app/core/rag/app/paper.py delete mode 100644 app/core/rag/app/picture.py delete mode 100644 app/core/rag/app/presentation.py delete mode 100644 app/core/rag/app/qa.py delete mode 100644 app/core/rag/common/__init__.py delete mode 100644 app/core/rag/common/connection_utils.py delete mode 100644 app/core/rag/common/constants.py delete mode 100644 app/core/rag/common/file_utils.py delete mode 100644 app/core/rag/common/float_utils.py delete mode 100644 app/core/rag/common/misc_utils.py delete mode 100644 app/core/rag/common/settings.py delete mode 100644 app/core/rag/common/string_utils.py delete mode 100644 app/core/rag/common/token_utils.py delete mode 100644 app/core/rag/deepdoc/README.md delete mode 100644 app/core/rag/deepdoc/README_zh.md delete mode 100644 app/core/rag/deepdoc/__init__.py delete mode 100644 app/core/rag/deepdoc/parser/__init__.py delete mode 100644 app/core/rag/deepdoc/parser/docx_parser.py delete mode 100644 app/core/rag/deepdoc/parser/excel_parser.py delete mode 100644 app/core/rag/deepdoc/parser/figure_parser.py delete mode 100644 app/core/rag/deepdoc/parser/html_parser.py delete mode 100644 app/core/rag/deepdoc/parser/json_parser.py delete mode 100644 app/core/rag/deepdoc/parser/markdown_parser.py delete mode 100644 app/core/rag/deepdoc/parser/mineru_parser.py delete mode 100644 app/core/rag/deepdoc/parser/pdf_parser.py delete mode 100644 app/core/rag/deepdoc/parser/ppt_parser.py delete mode 100644 app/core/rag/deepdoc/parser/txt_parser.py delete mode 100644 app/core/rag/deepdoc/parser/utils.py delete mode 100644 app/core/rag/deepdoc/vision/__init__.py delete mode 100644 app/core/rag/deepdoc/vision/layout_recognizer.py delete mode 100644 app/core/rag/deepdoc/vision/ocr.py delete mode 100644 app/core/rag/deepdoc/vision/operators.py delete mode 100644 app/core/rag/deepdoc/vision/postprocess.py delete mode 100644 app/core/rag/deepdoc/vision/recognizer.py delete mode 100644 app/core/rag/deepdoc/vision/seeit.py delete mode 100644 app/core/rag/deepdoc/vision/t_ocr.py delete mode 100644 app/core/rag/deepdoc/vision/t_recognizer.py delete mode 100644 app/core/rag/deepdoc/vision/table_structure_recognizer.py delete mode 100644 app/core/rag/graphrag/__init__.py delete mode 100644 app/core/rag/graphrag/utils.py delete mode 100644 app/core/rag/llm/__init__.py delete mode 100644 app/core/rag/llm/chat_model.py delete mode 100644 app/core/rag/llm/cv_model.py delete mode 100644 app/core/rag/llm/sequence2txt_model.py delete mode 100644 app/core/rag/models/__init__.py delete mode 100644 app/core/rag/models/chunk.py delete mode 100644 app/core/rag/nlp/__init__.py delete mode 100644 app/core/rag/nlp/query.py delete mode 100644 app/core/rag/nlp/rag_tokenizer.py delete mode 100644 app/core/rag/nlp/search.py delete mode 100644 app/core/rag/nlp/surname.py delete mode 100644 app/core/rag/nlp/synonym.py delete mode 100644 app/core/rag/nlp/term_weight.py delete mode 100644 app/core/rag/prompts/__init__.py delete mode 100644 app/core/rag/prompts/analyze_task_system.md delete mode 100644 app/core/rag/prompts/analyze_task_user.md delete mode 100644 app/core/rag/prompts/ask_summary.md delete mode 100644 app/core/rag/prompts/assign_toc_levels.md delete mode 100644 app/core/rag/prompts/citation_plus.md delete mode 100644 app/core/rag/prompts/citation_prompt.md delete mode 100644 app/core/rag/prompts/content_tagging_prompt.md delete mode 100644 app/core/rag/prompts/cross_languages_sys_prompt.md delete mode 100644 app/core/rag/prompts/cross_languages_user_prompt.md delete mode 100644 app/core/rag/prompts/full_question_prompt.md delete mode 100644 app/core/rag/prompts/generator.py delete mode 100644 app/core/rag/prompts/keyword_prompt.md delete mode 100644 app/core/rag/prompts/meta_filter.md delete mode 100644 app/core/rag/prompts/next_step.md delete mode 100644 app/core/rag/prompts/question_prompt.md delete mode 100644 app/core/rag/prompts/rank_memory.md delete mode 100644 app/core/rag/prompts/reflect.md delete mode 100644 app/core/rag/prompts/related_question.md delete mode 100644 app/core/rag/prompts/structured_output_prompt.md delete mode 100644 app/core/rag/prompts/summary4memory.md delete mode 100644 app/core/rag/prompts/template.py delete mode 100644 app/core/rag/prompts/toc_detection.md delete mode 100644 app/core/rag/prompts/toc_extraction.md delete mode 100644 app/core/rag/prompts/toc_extraction_continue.md delete mode 100644 app/core/rag/prompts/toc_from_text_system.md delete mode 100644 app/core/rag/prompts/toc_from_text_user.md delete mode 100644 app/core/rag/prompts/toc_index.md delete mode 100644 app/core/rag/prompts/toc_relevance_system.md delete mode 100644 app/core/rag/prompts/toc_relevance_user.md delete mode 100644 app/core/rag/prompts/tool_call_summary.md delete mode 100644 app/core/rag/prompts/vision_llm_describe_prompt.md delete mode 100644 app/core/rag/prompts/vision_llm_figure_describe_prompt.md delete mode 100644 app/core/rag/utils/__init__.py delete mode 100644 app/core/rag/utils/doc_store_conn.py delete mode 100644 app/core/rag/utils/file_utils.py delete mode 100644 app/core/rag/vdb/__init__.py delete mode 100644 app/core/rag/vdb/elasticsearch/__init__.py delete mode 100644 app/core/rag/vdb/elasticsearch/elasticsearch_vector.py delete mode 100644 app/core/rag/vdb/field.py delete mode 100644 app/core/rag/vdb/vector_base.py delete mode 100644 app/core/rag_utils/README.md delete mode 100644 app/core/rag_utils/__init__.py delete mode 100644 app/core/rag_utils/chunk_insight.py delete mode 100644 app/core/rag_utils/chunk_summary.py delete mode 100644 app/core/rag_utils/chunk_tags.py delete mode 100644 app/core/response_utils.py delete mode 100644 app/core/security.py delete mode 100644 app/core/sensitive_filter.py delete mode 100644 app/core/share_utils.py delete mode 100644 app/core/storage_strategy.py delete mode 100644 app/core/transaction_monitor.py delete mode 100644 app/core/uow.py delete mode 100644 app/core/upload_enums.py delete mode 100644 app/core/upload_policies.py delete mode 100644 app/core/validators/__init__.py delete mode 100644 app/core/validators/file_validator.py delete mode 100644 app/core/workflow/__init__.py delete mode 100644 app/db.py delete mode 100644 app/dependencies.py delete mode 100644 app/main.py delete mode 100644 app/models/__init__.py delete mode 100644 app/models/agent_app_config_model.py delete mode 100644 app/models/api_key_model.py delete mode 100644 app/models/app_model.py delete mode 100644 app/models/app_release_model.py delete mode 100644 app/models/appshare_model.py delete mode 100644 app/models/conversation_model.py delete mode 100644 app/models/data_config_model.py delete mode 100644 app/models/document_model.py delete mode 100644 app/models/end_user_model.py delete mode 100644 app/models/file_model.py delete mode 100644 app/models/generic_file_model.py delete mode 100644 app/models/knowledge_model.py delete mode 100644 app/models/knowledgeshare_model.py delete mode 100644 app/models/memory_increment_model.py delete mode 100644 app/models/models_model.py delete mode 100644 app/models/multi_agent_model.py delete mode 100644 app/models/release_share_model.py delete mode 100644 app/models/retrieval_info.py delete mode 100644 app/models/tenant_model.py delete mode 100644 app/models/user_model.py delete mode 100644 app/models/workspace_model.py delete mode 100644 app/repositories/__init__.py delete mode 100644 app/repositories/api_key_repository.py delete mode 100644 app/repositories/app_repository.py delete mode 100644 app/repositories/base_repository.py delete mode 100644 app/repositories/data_config_repository.py delete mode 100644 app/repositories/document_repository.py delete mode 100644 app/repositories/end_user_repository.py delete mode 100644 app/repositories/file_repository.py delete mode 100644 app/repositories/generic_file_repository.py delete mode 100644 app/repositories/knowledge_repository.py delete mode 100644 app/repositories/knowledgeshare_repository.py delete mode 100644 app/repositories/memory_increment_repository.py delete mode 100644 app/repositories/model_repository.py delete mode 100644 app/repositories/neo4j/__init__.py delete mode 100644 app/repositories/neo4j/add_edges.py delete mode 100644 app/repositories/neo4j/add_nodes.py delete mode 100644 app/repositories/neo4j/base_neo4j_repository.py delete mode 100644 app/repositories/neo4j/create_indexes.py delete mode 100644 app/repositories/neo4j/cypher_queries.py delete mode 100644 app/repositories/neo4j/dialog_repository.py delete mode 100644 app/repositories/neo4j/entity_repository.py delete mode 100644 app/repositories/neo4j/graph_saver.py delete mode 100644 app/repositories/neo4j/graph_search.py delete mode 100644 app/repositories/neo4j/neo4j_connector.py delete mode 100644 app/repositories/neo4j/statement_repository.py delete mode 100644 app/repositories/release_share_repository.py delete mode 100644 app/repositories/tenant_repository.py delete mode 100644 app/repositories/user_repository.py delete mode 100644 app/repositories/workspace_invite_repository.py delete mode 100644 app/repositories/workspace_repository.py delete mode 100644 app/schemas/__init__.py delete mode 100644 app/schemas/api_key_schema.py delete mode 100644 app/schemas/app_schema.py delete mode 100644 app/schemas/chunk_schema.py delete mode 100644 app/schemas/conversation_schema.py delete mode 100644 app/schemas/document_schema.py delete mode 100644 app/schemas/end_user_schema.py delete mode 100644 app/schemas/file_schema.py delete mode 100644 app/schemas/generic_file_schema.py delete mode 100644 app/schemas/item_schema.py delete mode 100644 app/schemas/knowledge_schema.py delete mode 100644 app/schemas/knowledgeshare_schema.py delete mode 100644 app/schemas/memory_agent_schema.py delete mode 100644 app/schemas/memory_increment_schema.py delete mode 100644 app/schemas/memory_storage_schema.py delete mode 100644 app/schemas/model_schema.py delete mode 100644 app/schemas/multi_agent_schema.py delete mode 100644 app/schemas/prompt_schema.py delete mode 100644 app/schemas/release_share_schema.py delete mode 100644 app/schemas/response_schema.py delete mode 100644 app/schemas/retrieval_info_schema.py delete mode 100644 app/schemas/tenant_schema.py delete mode 100644 app/schemas/token_schema.py delete mode 100644 app/schemas/user_schema.py delete mode 100644 app/schemas/workspace_schema.py delete mode 100644 app/services/__init__.py delete mode 100644 app/services/agent_config_converter.py delete mode 100644 app/services/agent_config_helper.py delete mode 100644 app/services/agent_invocation_service.py delete mode 100644 app/services/agent_registry.py delete mode 100644 app/services/agent_server.py delete mode 100644 app/services/agent_tools.py delete mode 100644 app/services/api_key_service.py delete mode 100644 app/services/app_service.py delete mode 100644 app/services/auth_service.py delete mode 100644 app/services/conversation_service.py delete mode 100644 app/services/conversation_state_manager.py delete mode 100644 app/services/document_service.py delete mode 100644 app/services/draft_run_service.py delete mode 100644 app/services/file_service.py delete mode 100644 app/services/knowledge_service.py delete mode 100644 app/services/knowledgeshare_service.py delete mode 100644 app/services/langchain_tool_server.py delete mode 100644 app/services/llm_client.py delete mode 100644 app/services/llm_router.py delete mode 100644 app/services/memory_agent_service.py delete mode 100644 app/services/memory_dashboard_service.py delete mode 100644 app/services/memory_konwledges_server.py delete mode 100644 app/services/memory_storage_service.py delete mode 100644 app/services/model_parameter_merger.py delete mode 100644 app/services/model_service.py delete mode 100644 app/services/multi_agent_config_converter.py delete mode 100644 app/services/multi_agent_orchestrator.py delete mode 100644 app/services/multi_agent_service.py delete mode 100644 app/services/release_share_service.py delete mode 100644 app/services/session_service.py delete mode 100644 app/services/shared_chat_service.py delete mode 100644 app/services/smart_router.py delete mode 100644 app/services/task_service.py delete mode 100644 app/services/tenant_service.py delete mode 100644 app/services/upload_service.py delete mode 100644 app/services/user_service.py delete mode 100644 app/services/workspace_service.py delete mode 100644 app/tasks.py delete mode 100644 app/utils/volc_asr.py delete mode 100644 docker-compose.yml delete mode 100644 env.example delete mode 100644 main.py delete mode 100644 migrations/README delete mode 100644 migrations/env.py delete mode 100644 migrations/script.py.mako delete mode 100644 pyproject.toml diff --git a/Dockerfile b/Dockerfile deleted file mode 100644 index a5c818ea..00000000 --- a/Dockerfile +++ /dev/null @@ -1,97 +0,0 @@ -FROM python:3.12-slim -USER root -SHELL ["/bin/bash", "-c"] - -ARG NEED_MIRROR=1 - -WORKDIR /code - -# 1. Download dependencies through download_deps.py: python download_deps.py --china-mirrors -# 2. Copy models -COPY huggingface.co/InfiniFlow/deepdoc/ /code/res/deepdoc/ -COPY huggingface.co/InfiniFlow/text_concat_xgb_v1.0/ /code/res/text_concat_xgb_v1.0/ -COPY huggingface.co/InfiniFlow/huqie/huqie.txt.trie /code/res/ - -# https://github.com/chrismattmann/tika-python -# 3. This is the only way to run python-tika without internet access. Without this set, the default is to check the tika version and pull latest every time from Apache. -COPY nltk_data/ /root/nltk_data/ -COPY tika-server-standard-3.1.0.jar /tmp/tika-server.jar -COPY tika-server-standard-3.1.0.jar.md5 /tmp/tika-server.jar.md5 -COPY cl100k_base.tiktoken /code/res/9b5ad71b2ce5302211f9c61530b329a4922fc6a4 - -ENV TIKA_SERVER_JAR="file:///tmp/tika-server.jar" -ENV DEBIAN_FRONTEND=noninteractive - -# 4. Setup apt -# Python package and implicit dependencies: -# opencv-python: libglib2.0-0 libglx-mesa0 libgl1 -# aspose-slides: pkg-config libicu-dev libgdiplus libssl1.1_1.1.1f-1ubuntu2_amd64.deb -# python-pptx: default-jdk tika-server-standard-3.0.0.jar -# Building C extensions: libpython3-dev libgtk-4-1 libnss3 xdg-utils libgbm-dev -RUN --mount=type=cache,id=mem_apt,target=/var/cache/apt,sharing=locked \ - apt install -y libicu-dev && \ - if [ "$NEED_MIRROR" == "1" ]; then \ - rm -f /etc/apt/sources.list.d/debian.sources && \ - echo "deb https://mirrors.tuna.tsinghua.edu.cn/debian/ bookworm main contrib non-free non-free-firmware" > /etc/apt/sources.list && \ - echo "deb https://mirrors.tuna.tsinghua.edu.cn/debian/ bookworm-updates main contrib non-free non-free-firmware" >> /etc/apt/sources.list && \ - echo "deb https://mirrors.tuna.tsinghua.edu.cn/debian/ bookworm-backports main contrib non-free non-free-firmware" >> /etc/apt/sources.list && \ - echo "deb https://mirrors.tuna.tsinghua.edu.cn/debian-security bookworm-security main contrib non-free non-free-firmware" >> /etc/apt/sources.list; \ - fi; \ - rm -f /etc/apt/apt.conf.d/docker-clean && \ - echo 'Binary::apt::APT::Keep-Downloaded-Packages "true";' > /etc/apt/apt.conf.d/keep-cache && \ - chmod 1777 /tmp && \ - apt update && \ - apt --no-install-recommends install -y ca-certificates && \ - apt update && \ - apt install -y libglib2.0-0 libglx-mesa0 libgl1 && \ - apt install -y pkg-config libgdiplus && \ - apt install -y default-jdk && \ - apt install -y libpython3-dev libgtk-4-1 libnss3 xdg-utils libgbm-dev && \ - apt install -y libjemalloc-dev && \ - apt install -y python3-pip pipx nginx unzip curl wget git vim less && \ - apt install -y ghostscript - -RUN if [ "$NEED_MIRROR" == "1" ]; then \ - pip3 config set global.index-url https://pypi.tuna.tsinghua.edu.cn/simple && \ - pip3 config set global.trusted-host pypi.tuna.tsinghua.edu.cn; \ - mkdir -p /etc/uv && \ - echo "[[index]]" > /etc/uv/uv.toml && \ - echo 'url = "https://pypi.tuna.tsinghua.edu.cn/simple"' >> /etc/uv/uv.toml && \ - echo "default = true" >> /etc/uv/uv.toml; \ - fi; \ - pipx install uv - -ENV PYTHONDONTWRITEBYTECODE=1 DOTNET_SYSTEM_GLOBALIZATION_INVARIANT=1 -ENV PATH=/root/.local/bin:$PATH - -# https://forum.aspose.com/t/aspose-slides-for-net-no-usable-version-of-libssl-found-with-linux-server/271344/13 -# 5. aspose-slides on linux/arm64 is unavailable -COPY libssl1.1_1.1.1f-1ubuntu2_amd64.deb libssl1.1_1.1.1f-1ubuntu2_arm64.deb /tmp/ -RUN if [ "$(uname -m)" = "x86_64" ]; then \ - dpkg -i /tmp/libssl1.1_1.1.1f-1ubuntu2_amd64.deb; \ - elif [ "$(uname -m)" = "aarch64" ]; then \ - dpkg -i /tmp/libssl1.1_1.1.1f-1ubuntu2_arm64.deb; \ - fi && \ - rm -f /tmp/libssl1.1_*.deb - - -# 6. install dependencies from uv.lock file -COPY ./pyproject.toml /code/pyproject.toml -COPY ./uv.lock /code/uv.lock -COPY ./app /code/app - -# https://github.com/astral-sh/uv/issues/10462 -# uv records index url into uv.lock but doesn't failover among multiple indexes -RUN --mount=type=cache,id=mem_uv,target=/root/.cache/uv,sharing=locked \ - if [ "$NEED_MIRROR" == "1" ]; then \ - sed -i 's|pypi.org|pypi.tuna.tsinghua.edu.cn|g' uv.lock; \ - else \ - sed -i 's|pypi.tuna.tsinghua.edu.cn|pypi.org|g' uv.lock; \ - fi; \ - uv lock && \ - uv sync --locked --no-dev - -ENV PATH=/code/.venv/bin:$PATH - - - diff --git a/LICENSE b/LICENSE deleted file mode 100644 index 0f032d49..00000000 --- a/LICENSE +++ /dev/null @@ -1,201 +0,0 @@ -Apache License -Version 2.0, January 2004 -http://www.apache.org/licenses/ - -TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION - -1. Definitions. - -"License" shall mean the terms and conditions for use, reproduction, -and distribution as defined by Sections 1 through 9 of this document. - -"Licensor" shall mean the copyright owner or entity authorized by -the copyright owner that is granting the License. - -"Legal Entity" shall mean the union of the acting entity and all -other entities that control, are controlled by, or are under common -control with that entity. For the purposes of this definition, -"control" means (i) the power, direct or indirect, to cause the -direction or management of such entity, whether by contract or -otherwise, or (ii) ownership of fifty percent (50%) or more of the -outstanding shares, or (iii) beneficial ownership of such entity. - -"You" (or "Your") shall mean an individual or Legal Entity -exercising permissions granted by this License. - -"Source" form shall mean the preferred form for making modifications, -including but not limited to software source code, documentation -sources, and configuration files. - -"Object" form shall mean any form resulting from mechanical -transformation or translation of a Source form, including but not -limited to compiled object code, generated documentation, and -conversions to other media types. - -"Work" shall mean the work of authorship, whether in Source or -Object form, made available under the License, as indicated by a -copyright notice that is included in or attached to the work -(an example is provided in the Appendix below). - -"Derivative Works" shall mean any work, whether in Source or Object -form, that is based on (or derived from) the Work and for which the -editorial revisions, annotations, elaborations, or other modifications -represent, as a whole, an original work of authorship. For the purposes -of this License, Derivative Works shall not include works that remain -separable from, or merely link (or bind by name) to the interfaces of, -the Work and Derivative Works thereof. - -"Contribution" shall mean any work of authorship, including -the original version of the Work and any modifications or additions -to that Work or Derivative Works thereof, that is intentionally -submitted to Licensor for inclusion in the Work by the copyright owner -or by an individual or Legal Entity authorized to submit on behalf of -the copyright owner. For the purposes of this definition, "submitted" -means any form of electronic, verbal, or written communication sent -to the Licensor or its representatives, including but not limited to -communication on electronic mailing lists, source code control systems, -and issue tracking systems that are managed by, or on behalf of, the -Licensor for the purpose of discussing and improving the Work, but -excluding communication that is conspicuously marked or otherwise -designated in writing by the copyright owner as "Not a Contribution." - -"Contributor" shall mean Licensor and any individual or Legal Entity -on behalf of whom a Contribution has been received by Licensor and -subsequently incorporated within the Work. - -2. Grant of Copyright License. Subject to the terms and conditions of -this License, each Contributor hereby grants to You a perpetual, -worldwide, non-exclusive, no-charge, royalty-free, irrevocable -copyright license to reproduce, prepare Derivative Works of, -publicly display, publicly perform, sublicense, and distribute the -Work and such Derivative Works in Source or Object form. - -3. Grant of Patent License. Subject to the terms and conditions of -this License, each Contributor hereby grants to You a perpetual, -worldwide, non-exclusive, no-charge, royalty-free, irrevocable -(except as stated in this section) patent license to make, have made, -use, offer to sell, sell, import, and otherwise transfer the Work, -where such license applies only to those patent claims licensable -by such Contributor that are necessarily infringed by their -Contribution(s) alone or by combination of their Contribution(s) -with the Work to which such Contribution(s) was submitted. If You -institute patent litigation against any entity (including a -cross-claim or counterclaim in a lawsuit) alleging that the Work -or a Contribution incorporated within the Work constitutes direct -or contributory patent infringement, then any patent licenses -granted to You under this License for that Work shall terminate -as of the date such litigation is filed. - -4. Redistribution. You may reproduce and distribute copies of the -Work or Derivative Works thereof in any medium, with or without -modifications, and in Source or Object form, provided that You -meet the following conditions: - - (a) You must give any other recipients of the Work or - Derivative Works a copy of this License; and - - (b) You must cause any modified files to carry prominent notices - stating that You changed the files; and - - (c) You must retain, in the Source form of any Derivative Works - that You distribute, all copyright, patent, trademark, and - attribution notices from the Source form of the Work, - excluding those notices that do not pertain to any part of - the Derivative Works; and - - (d) If the Work includes a "NOTICE" text file as part of its - distribution, then any Derivative Works that You distribute must - include a readable copy of the attribution notices contained - within such NOTICE file, excluding those notices that do not - pertain to any part of the Derivative Works, in at least one - of the following places: within a NOTICE text file distributed - as part of the Derivative Works; within the Source form or - documentation, if provided along with the Derivative Works; or, - within a display generated by the Derivative Works, if and - wherever such third-party notices normally appear. The contents - of the NOTICE file are for informational purposes only and - do not modify the License. You may add Your own attribution - notices within Derivative Works that You distribute, alongside - or as an addendum to the NOTICE text from the Work, provided - that such additional attribution notices cannot be construed - as modifying the License. - -You may add Your own copyright statement to Your modifications and -may provide additional or different license terms and conditions -for use, reproduction, or distribution of Your modifications, or -for any such Derivative Works as a whole, provided Your use, -reproduction, and distribution of the Work otherwise complies with -the conditions stated in this License. - -5. Submission of Contributions. Unless You explicitly state otherwise, -any Contribution intentionally submitted for inclusion in the Work -by You to the Licensor shall be under the terms and conditions of -this License, without any additional terms or conditions. -Notwithstanding the above, nothing herein shall supersede or modify -the terms of any separate license agreement you may have executed -with Licensor regarding such Contributions. - -6. Trademarks. This License does not grant permission to use the trade -names, trademarks, service marks, or product names of the Licensor, -except as required for reasonable and customary use in describing the -origin of the Work and reproducing the content of the NOTICE file. - -7. Disclaimer of Warranty. Unless required by applicable law or -agreed to in writing, Licensor provides the Work (and each -Contributor provides its Contributions) on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or -implied, including, without limitation, any warranties or conditions -of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A -PARTICULAR PURPOSE. You are solely responsible for determining the -appropriateness of using or redistributing the Work and assume any -risks associated with Your exercise of permissions under this License. - -8. Limitation of Liability. In no event and under no legal theory, -whether in tort (including negligence), contract, or otherwise, -unless required by applicable law (such as deliberate and grossly -negligent acts) or agreed to in writing, shall any Contributor be -liable to You for damages, including any direct, indirect, special, -incidental, or consequential damages of any character arising as a -result of this License or out of the use or inability to use the -Work (including but not limited to damages for loss of goodwill, -work stoppage, computer failure or malfunction, or any and all -other commercial damages or losses), even if such Contributor -has been advised of the possibility of such damages. - -9. Accepting Warranty or Additional Liability. While redistributing -the Work or Derivative Works thereof, You may choose to offer, -and charge a fee for, acceptance of support, warranty, indemnity, -or other liability obligations and/or rights consistent with this -License. However, in accepting such obligations, You may act only -on Your own behalf and on Your sole responsibility, not on behalf -of any other Contributor, and only if You agree to indemnify, -defend, and hold each Contributor harmless for any liability -incurred by, or claims asserted against, such Contributor by reason -of your accepting any such warranty or additional liability. - -END OF TERMS AND CONDITIONS - -APPENDIX: How to apply the Apache License to your work. - -To apply the Apache License to your work, attach the following -boilerplate notice, with the fields enclosed by brackets "[]" -replaced with your own identifying information. (Don't include -the brackets!) The text should be enclosed in the appropriate -comment syntax for the file format. We also recommend that a -file or class name and description of purpose be included on the -same "printed page" as the copyright notice for easier -identification within third-party archives. - -Copyright [2025] [SuanmoSuanyangTechnology] - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. \ No newline at end of file diff --git a/alembic.ini b/alembic.ini deleted file mode 100644 index 71a6553f..00000000 --- a/alembic.ini +++ /dev/null @@ -1,116 +0,0 @@ -# A generic, single database configuration. - -[alembic] -# path to migration scripts -script_location = migrations - -# template used to generate migration file names; The default value is %%(rev)s_%%(slug)s -# Uncomment the line below if you want the files to be prepended with date and time -# see https://alembic.sqlalchemy.org/en/latest/tutorial.html#editing-the-ini-file -# for all available tokens -# file_template = %%(year)d_%%(month).2d_%%(day).2d_%%(hour).2d%%(minute).2d-%%(rev)s_%%(slug)s - -# sys.path path, will be prepended to sys.path if present. -# defaults to the current working directory. -prepend_sys_path = . - -# timezone to use when rendering the date within the migration file -# as well as the filename. -# If specified, requires the python>=3.9 or backports.zoneinfo library. -# Any required deps can installed by adding `alembic[tz]` to the pip requirements -# string value is passed to ZoneInfo() -# leave blank for localtime -# timezone = - -# max length of characters to apply to the -# "slug" field -# truncate_slug_length = 40 - -# set to 'true' to run the environment during -# the 'revision' command, regardless of autogenerate -# revision_environment = false - -# set to 'true' to allow .pyc and .pyo files without -# a source .py file to be detected as revisions in the -# versions/ directory -# sourceless = false - -# version location specification; This defaults -# to migrations/versions. When using multiple version -# directories, initial revisions must be specified with --version-path. -# The path separator used here should be the separator specified by "version_path_separator" below. -# version_locations = %(here)s/bar:%(here)s/bat:migrations/versions - -# version path separator; As mentioned above, this is the character used to split -# version_locations. The default within new alembic.ini files is "os", which uses os.pathsep. -# If this key is omitted entirely, it falls back to the legacy behavior of splitting on spaces and/or commas. -# Valid values for version_path_separator are: -# -# version_path_separator = : -# version_path_separator = ; -# version_path_separator = space -version_path_separator = os # Use os.pathsep. Default configuration used for new projects. - -# set to 'true' to search source files recursively -# in each "version_locations" directory -# new in Alembic version 1.10 -# recursive_version_locations = false - -# the output encoding used when revision files -# are written from script.py.mako -# output_encoding = utf-8 - -sqlalchemy.url = postgresql://user:password@localhost/dbname - - -[post_write_hooks] -# post_write_hooks defines scripts or Python functions that are run -# on newly generated revision scripts. See the documentation for further -# detail and examples - -# format using "black" - use the console_scripts runner, against the "black" entrypoint -# hooks = black -# black.type = console_scripts -# black.entrypoint = black -# black.options = -l 79 REVISION_SCRIPT_FILENAME - -# lint with attempts to fix using "ruff" - use the exec runner, execute a binary -# hooks = ruff -# ruff.type = exec -# ruff.executable = %(here)s/.venv/bin/ruff -# ruff.options = --fix REVISION_SCRIPT_FILENAME - -# Logging configuration -[loggers] -keys = root,sqlalchemy,alembic - -[handlers] -keys = console - -[formatters] -keys = generic - -[logger_root] -level = WARN -handlers = console -qualname = - -[logger_sqlalchemy] -level = WARN -handlers = -qualname = sqlalchemy.engine - -[logger_alembic] -level = INFO -handlers = -qualname = alembic - -[handler_console] -class = StreamHandler -args = (sys.stderr,) -level = NOTSET -formatter = generic - -[formatter_generic] -format = %(levelname)-5.5s [%(name)s] %(message)s -datefmt = %H:%M:%S diff --git a/app/aioRedis.py b/app/aioRedis.py deleted file mode 100644 index c729a3dc..00000000 --- a/app/aioRedis.py +++ /dev/null @@ -1,201 +0,0 @@ -import os -import asyncio -import json -import logging -from typing import Dict, Any, Optional -import redis.asyncio as redis -from redis.asyncio import ConnectionPool -from app.core.config import settings - -# 设置日志记录器 -logger = logging.getLogger(__name__) - - -# 创建连接池 -pool = ConnectionPool.from_url( - f"redis://{settings.REDIS_HOST}:{settings.REDIS_PORT}", - db=settings.REDIS_DB, - password=settings.REDIS_PASSWORD, - decode_responses=True, - max_connections=30 -) -aio_redis = redis.StrictRedis(connection_pool=pool) - -async def get_redis_connection(): - """获取Redis连接""" - try: - return redis.StrictRedis(connection_pool=pool) - except Exception as e: - logger.error(f"Redis连接失败: {str(e)}") - return None - -async def aio_redis_set(key: str, val: str|dict, expire: int = None): - """设置Redis键值 - - Args: - key: Redis键 - val: 要存储的值(字符串或字典) - expire: 过期时间(秒),None表示永不过期 - """ - try: - if isinstance(val, dict): - val = json.dumps(val, ensure_ascii=False) - - if expire is not None: - # 设置带过期时间的键值 - await aio_redis.set(key, val, ex=expire) - else: - # 设置永久键值 - await aio_redis.set(key, val) - except Exception as e: - logger.error(f"Redis set错误: {str(e)}") - -async def aio_redis_get(key: str): - """获取Redis键值""" - try: - return await aio_redis.get(key) - except Exception as e: - logger.error(f"Redis get错误: {str(e)}") - return None - -async def aio_redis_delete(key: str): - """删除Redis键""" - try: - return await aio_redis.delete(key) - except Exception as e: - logger.error(f"Redis delete错误: {str(e)}") - return None - -async def aio_redis_publish(channel: str, message: Dict[str, Any]) -> bool: - """发布消息到Redis频道""" - try: - conn = await get_redis_connection() - if not conn: - return False - await conn.publish(channel, json.dumps(message, ensure_ascii=False)) - return True - except Exception as e: - logger.error(f"Redis发布错误: {str(e)}") - return False - -class RedisSubscriber: - """Redis订阅器""" - - def __init__(self, channel: str): - self.channel = channel - self.conn = None - self.pubsub = None - self.is_closed = False - self._queue = asyncio.Queue() - self._task = None - - async def start(self): - """开始订阅""" - if self.is_closed or self._task: - return - - self._task = asyncio.create_task(self._receive_messages()) - logger.info(f"开始订阅: {self.channel}") - - async def _receive_messages(self): - """接收消息""" - try: - self.conn = await get_redis_connection() - if not self.conn: - return - - self.pubsub = self.conn.pubsub() - await self.pubsub.subscribe(self.channel) - - while not self.is_closed: - try: - message = await self.pubsub.get_message(ignore_subscribe_messages=True, timeout=0.01) - if message and isinstance(message.get("data"), str): - try: - await self._queue.put(json.loads(message["data"])) - except json.JSONDecodeError: - logger.warning(f"消息解析失败: {message['data']}") - await asyncio.sleep(0.01) - except Exception as e: - if "closed" in str(e).lower(): - break - logger.warning(f"接收消息错误: {str(e)}") - await asyncio.sleep(0.1) - except Exception as e: - logger.error(f"订阅错误: {str(e)}") - await self._queue.put({"type": "error", "data": {"message": str(e), "status": "error"}}) - finally: - await self._queue.put(None) - await self._cleanup() - - async def _cleanup(self): - """清理资源""" - if self.pubsub: - try: - await self.pubsub.unsubscribe(self.channel) - await self.pubsub.close() - except Exception: - pass - if self.conn: - try: - await self.conn.close() - except Exception: - pass - - async def get_message(self) -> Optional[Dict[str, Any]]: - """获取消息""" - if self.is_closed: - return None - if not self._task: - await self.start() - try: - return await self._queue.get() - except Exception as e: - logger.error(f"获取消息错误: {str(e)}") - return None - - async def close(self): - """关闭订阅器""" - if self.is_closed: - return - self.is_closed = True - if self._task: - self._task.cancel() - await self._cleanup() - -class RedisPubSubManager: - """Redis发布订阅管理器""" - - def __init__(self): - self.subscribers = {} - - async def publish(self, channel: str, message: Dict[str, Any]) -> bool: - return await aio_redis_publish(channel, message) - - def get_subscriber(self, channel: str) -> RedisSubscriber: - if channel in self.subscribers: - subscriber = self.subscribers[channel] - if not subscriber.is_closed: - return subscriber - - subscriber = RedisSubscriber(channel) - self.subscribers[channel] = subscriber - return subscriber - - def cancel_subscription(self, channel: str) -> bool: - if channel in self.subscribers: - asyncio.create_task(self.subscribers[channel].close()) - del self.subscribers[channel] - return True - return False - - def cancel_all_subscriptions(self) -> int: - count = len(self.subscribers) - for subscriber in self.subscribers.values(): - asyncio.create_task(subscriber.close()) - self.subscribers.clear() - return count - -# 全局实例 -pubsub_manager = RedisPubSubManager() - diff --git a/app/celery_app.py b/app/celery_app.py deleted file mode 100644 index 8e1b4d5d..00000000 --- a/app/celery_app.py +++ /dev/null @@ -1,109 +0,0 @@ -import os -from datetime import timedelta -from urllib.parse import quote -from celery import Celery -from app.core.config import settings - -# 创建 Celery 应用实例 -# broker: 任务队列(使用 Redis DB 0) -# backend: 结果存储(使用 Redis DB 10) -celery_app = Celery( - "redbear_tasks", - broker=f"redis://:{quote(settings.REDIS_PASSWORD)}@{settings.REDIS_HOST}:{settings.REDIS_PORT}/{settings.CELERY_BROKER}", - backend=f"redis://:{quote(settings.REDIS_PASSWORD)}@{settings.REDIS_HOST}:{settings.REDIS_PORT}/{settings.CELERY_BACKEND}", -) - -# 配置使用本地队列,避免与远程 worker 冲突 -celery_app.conf.task_default_queue = 'localhost_test_wyl' -celery_app.conf.task_default_exchange = 'localhost_test_wyl' -celery_app.conf.task_default_routing_key = 'localhost_test_wyl' - -# macOS 兼容性配置 -import platform -if platform.system() == 'Darwin': # macOS - # 设置环境变量解决 fork 问题 - os.environ.setdefault('OBJC_DISABLE_INITIALIZE_FORK_SAFETY', 'YES') - - # 使用 solo 池避免多进程问题 - celery_app.conf.worker_pool = 'solo' - - # 设置唯一的节点名称 - import socket - import time - hostname = socket.gethostname() - timestamp = int(time.time()) - celery_app.conf.worker_name = f"celery@{hostname}-{timestamp}" - -# Celery 配置 -celery_app.conf.update( - # 序列化 - task_serializer='json', - accept_content=['json'], - result_serializer='json', - - # 时区 - timezone='Asia/Shanghai', - enable_utc=True, - - # 任务追踪 - task_track_started=True, - task_ignore_result=False, - - # 超时设置 - task_time_limit=30 * 60, # 30 分钟硬超时 - task_soft_time_limit=25 * 60, # 25 分钟软超时 - - # Worker 设置 - 针对 macOS 优化 - worker_prefetch_multiplier=1, # 减少预取任务数,避免内存堆积 - worker_max_tasks_per_child=10, # 大幅减少每个 worker 执行的任务数,频繁重启防止内存泄漏 - worker_max_memory_per_child=200000, # 200MB 内存限制,超过后重启 worker - - # 结果过期时间 - result_expires=3600, # 结果保存 1 小时 - - # 任务确认设置 - task_acks_late=True, # 任务完成后才确认,避免任务丢失 - worker_disable_rate_limits=True, # 禁用速率限制 - - # 任务路由(可选,用于不同队列) - # task_routes={ - # 'app.core.rag.tasks.parse_document': {'queue': 'document_processing'}, - # 'app.core.memory.agent.read_message': {'queue': 'memory_processing'}, - # 'app.core.memory.agent.write_message': {'queue': 'memory_processing'}, - # 'tasks.process_item': {'queue': 'default'}, - # }, -) - -# 自动发现任务模块 -celery_app.autodiscover_tasks(['app']) - -# Celery Beat schedule for periodic tasks -reflection_schedule = timedelta(seconds=settings.REFLECTION_INTERVAL_SECONDS) -health_schedule = timedelta(seconds=settings.HEALTH_CHECK_SECONDS) -memory_increment_schedule = timedelta(hours=settings.MEMORY_INCREMENT_INTERVAL_HOURS) - -# 构建定时任务配置 -beat_schedule_config = { - "run-reflection-engine": { - "task": "app.core.memory.agent.reflection.timer", - "schedule": reflection_schedule, - "args": (), - }, - "check-read-service": { - "task": "app.core.memory.agent.health.check_read_service", - "schedule": health_schedule, - "args": (), - }, -} - -# 如果配置了默认工作空间ID,则添加记忆总量统计任务 -if settings.DEFAULT_WORKSPACE_ID: - beat_schedule_config["write-total-memory"] = { - "task": "app.controllers.memory_storage_controller.search_all", - "schedule": memory_increment_schedule, - "kwargs": { - "workspace_id": settings.DEFAULT_WORKSPACE_ID, - }, - } - -celery_app.conf.beat_schedule = beat_schedule_config diff --git a/app/celery_worker.py b/app/celery_worker.py deleted file mode 100644 index baecdb3d..00000000 --- a/app/celery_worker.py +++ /dev/null @@ -1,10 +0,0 @@ -""" -Celery Worker 入口点 -用于启动 Celery Worker: celery -A app.celery_worker worker --loglevel=info -""" -from app.celery_app import celery_app - -# 导入任务模块以注册任务 -import app.tasks - -__all__ = ['celery_app'] \ No newline at end of file diff --git a/app/controllers/__init__.py b/app/controllers/__init__.py deleted file mode 100644 index 951f2d73..00000000 --- a/app/controllers/__init__.py +++ /dev/null @@ -1,60 +0,0 @@ -"""管理端接口 - 基于 JWT 认证 - -路由前缀: / -认证方式: JWT Token -""" -from fastapi import APIRouter -from . import ( - model_controller, - task_controller, - test_controller, - user_controller, - auth_controller, - workspace_controller, - setup_controller, - file_controller, - document_controller, - knowledge_controller, - chunk_controller, - knowledgeshare_controller, - app_controller, - upload_controller, - memory_agent_controller, - memory_dashboard_controller, - memory_storage_controller, - memory_dashboard_controller, - api_key_controller, - release_share_controller, - public_share_controller, - multi_agent_controller, -) - -# 创建管理端 API 路由器 -manager_router = APIRouter() - -# 注册所有管理端路由 -manager_router.include_router(task_controller.router) -manager_router.include_router(user_controller.router) -manager_router.include_router(auth_controller.router) -manager_router.include_router(workspace_controller.router) -manager_router.include_router(workspace_controller.public_router) # 公开路由(无需认证) -manager_router.include_router(setup_controller.router) -manager_router.include_router(model_controller.router) -manager_router.include_router(file_controller.router) -manager_router.include_router(document_controller.router) -manager_router.include_router(knowledge_controller.router) -manager_router.include_router(chunk_controller.router) -manager_router.include_router(test_controller.router) -manager_router.include_router(knowledgeshare_controller.router) -manager_router.include_router(app_controller.router) -manager_router.include_router(upload_controller.router) -manager_router.include_router(memory_agent_controller.router) -manager_router.include_router(memory_dashboard_controller.router) -manager_router.include_router(memory_storage_controller.router) -manager_router.include_router(api_key_controller.router) -manager_router.include_router(release_share_controller.router) -manager_router.include_router(public_share_controller.router) # 公开路由(无需认证) -manager_router.include_router(memory_dashboard_controller.router) -manager_router.include_router(multi_agent_controller.router) - -__all__ = ["manager_router"] diff --git a/app/controllers/api_key_controller.py b/app/controllers/api_key_controller.py deleted file mode 100644 index 3948115e..00000000 --- a/app/controllers/api_key_controller.py +++ /dev/null @@ -1,151 +0,0 @@ -"""API Key 管理接口 - 基于 JWT 认证""" -from fastapi import APIRouter, Depends, Query -from sqlalchemy.orm import Session -import uuid - -from app.db import get_db -from app.dependencies import get_current_user, cur_workspace_access_guard -from app.models.user_model import User -from app.core.response_utils import success -from app.schemas import api_key_schema -from app.schemas.response_schema import ApiResponse -from app.services.api_key_service import ApiKeyService -from app.core.logging_config import get_business_logger - -router = APIRouter(prefix="/apikeys", tags=["API Keys"]) -logger = get_business_logger() - - -@router.post("", response_model=ApiResponse) -@cur_workspace_access_guard() -def create_api_key( - data: api_key_schema.ApiKeyCreate, - db: Session = Depends(get_db), - current_user: User = Depends(get_current_user) -): - """创建 API Key - - - 创建后返回明文 API Key(仅此一次) - - 支持设置权限范围、速率限制、配额等 - """ - workspace_id = current_user.current_workspace_id - api_key_obj, api_key = ApiKeyService.create_api_key( - db, - workspace_id=workspace_id, - user_id=current_user.id, - data=data - ) - - # 返回包含明文 Key 的响应 - response_data = api_key_schema.ApiKeyResponse( - **api_key_obj.__dict__, - api_key=api_key - ) - - return success(data=response_data, msg="API Key 创建成功") - - -@router.get("", response_model=ApiResponse) -@cur_workspace_access_guard() -def list_api_keys( - type: api_key_schema.ApiKeyType = Query(None), - is_active: bool = Query(None), - resource_id: uuid.UUID = Query(None), - page: int = Query(1, ge=1), - pagesize: int = Query(10, ge=1, le=100), - db: Session = Depends(get_db), - current_user: User = Depends(get_current_user) -): - """列出 API Keys""" - workspace_id = current_user.current_workspace_id - query = api_key_schema.ApiKeyQuery( - type=type, - is_active=is_active, - resource_id=resource_id, - page=page, - pagesize=pagesize - ) - - result = ApiKeyService.list_api_keys(db, workspace_id, query) - return success(data=result) - - -@router.get("/{api_key_id}", response_model=ApiResponse) -@cur_workspace_access_guard() -def get_api_key( - api_key_id: uuid.UUID, - db: Session = Depends(get_db), - current_user: User = Depends(get_current_user) -): - """获取 API Key 详情""" - workspace_id = current_user.current_workspace_id - api_key = ApiKeyService.get_api_key(db, api_key_id, workspace_id) - - return success(data=api_key_schema.ApiKey.model_validate(api_key)) - - -@router.put("/{api_key_id}", response_model=ApiResponse) -@cur_workspace_access_guard() -def update_api_key( - api_key_id: uuid.UUID, - data: api_key_schema.ApiKeyUpdate, - db: Session = Depends(get_db), - current_user: User = Depends(get_current_user) -): - """更新 API Key""" - workspace_id = current_user.current_workspace_id - api_key = ApiKeyService.update_api_key(db, api_key_id, workspace_id, data) - - return success(data=api_key_schema.ApiKey.model_validate(api_key), msg="API Key 更新成功") - - -@router.delete("/{api_key_id}", response_model=ApiResponse) -@cur_workspace_access_guard() -def delete_api_key( - api_key_id: uuid.UUID, - db: Session = Depends(get_db), - current_user: User = Depends(get_current_user) -): - """删除 API Key""" - workspace_id = current_user.current_workspace_id - ApiKeyService.delete_api_key(db, api_key_id, workspace_id) - - return success(msg="API Key 删除成功") - - -@router.post("/{api_key_id}/regenerate", response_model=ApiResponse) -@cur_workspace_access_guard() -def regenerate_api_key( - api_key_id: uuid.UUID, - db: Session = Depends(get_db), - current_user: User = Depends(get_current_user) -): - """重新生成 API Key - - - 生成新的 API Key 并返回明文(仅此一次) - - 旧的 API Key 立即失效 - """ - workspace_id = current_user.current_workspace_id - api_key_obj, api_key = ApiKeyService.regenerate_api_key(db, api_key_id, workspace_id) - - # 返回包含明文 Key 的响应 - response_data = api_key_schema.ApiKeyResponse( - **api_key_obj.__dict__, - api_key=api_key - ) - - return success(data=response_data, msg="API Key 重新生成成功") - - -@router.get("/{api_key_id}/stats", response_model=ApiResponse) -@cur_workspace_access_guard() -def get_api_key_stats( - api_key_id: uuid.UUID, - db: Session = Depends(get_db), - current_user: User = Depends(get_current_user) -): - """获取 API Key 使用统计""" - workspace_id = current_user.current_workspace_id - stats = ApiKeyService.get_stats(db, api_key_id, workspace_id) - - return success(data=stats) diff --git a/app/controllers/app_controller.py b/app/controllers/app_controller.py deleted file mode 100644 index 90783647..00000000 --- a/app/controllers/app_controller.py +++ /dev/null @@ -1,716 +0,0 @@ -import uuid -from typing import Optional -from fastapi import APIRouter, Depends -from sqlalchemy.orm import Session - -from app.db import get_db -from app.core.response_utils import success -from app.core.logging_config import get_business_logger -from app.models import User -from app.repositories import knowledge_repository -from app.schemas import app_schema -from app.schemas.response_schema import PageData, PageMeta -from app.services import app_service, workspace_service -from app.services.app_service import AppService -from app.services.agent_config_helper import enrich_agent_config -from app.dependencies import get_current_user, cur_workspace_access_guard, workspace_access_guard -from fastapi.responses import StreamingResponse -from app.models.app_model import AppType -from app.core.error_codes import BizCode - -router = APIRouter(prefix="/apps", tags=["Apps"]) -logger = get_business_logger() - - -@router.post("", summary="创建应用(可选创建 Agent 配置)") -@cur_workspace_access_guard() -def create_app( - payload: app_schema.AppCreate, - db: Session = Depends(get_db), - current_user=Depends(get_current_user), -): - workspace_id = current_user.current_workspace_id - app = app_service.create_app(db, user_id=current_user.id, workspace_id=workspace_id, data=payload) - return success(data=app_schema.App.model_validate(app)) - - -@router.get("", summary="应用列表(分页)") -@cur_workspace_access_guard() -def list_apps( - type: str | None = None, - visibility: str | None = None, - status: str | None = None, - search: str | None = None, - include_shared: bool = True, - page: int = 1, - pagesize: int = 10, - db: Session = Depends(get_db), - current_user=Depends(get_current_user), -): - """列出应用 - - - 默认包含本工作空间的应用和分享给本工作空间的应用 - - 设置 include_shared=false 可以只查看本工作空间的应用 - """ - workspace_id = current_user.current_workspace_id - items_orm, total = app_service.list_apps( - db, - workspace_id=workspace_id, - type=type, - visibility=visibility, - status=status, - search=search, - include_shared=include_shared, - page=page, - pagesize=pagesize, - ) - - # 使用 AppService 的转换方法来设置 is_shared 字段 - service = app_service.AppService(db) - items = [service._convert_to_schema(app, workspace_id) for app in items_orm] - meta = PageMeta(page=page, pagesize=pagesize, total=total, hasnext=(page * pagesize) < total) - return success(data=PageData(page=meta, items=items)) - -@router.get("/{app_id}", summary="获取应用详情") -@cur_workspace_access_guard() -def get_app( - app_id: uuid.UUID, - db: Session = Depends(get_db), - current_user=Depends(get_current_user), -): - """获取应用详细信息 - - - 支持获取本工作空间的应用 - - 支持获取分享给本工作空间的应用 - """ - workspace_id = current_user.current_workspace_id - service = app_service.AppService(db) - app = service.get_app(app_id, workspace_id) - - # 转换为 Schema 并设置 is_shared 字段 - app_schema_obj = service._convert_to_schema(app, workspace_id) - return success(data=app_schema_obj) - - -@router.put("/{app_id}", summary="更新应用基本信息") -@cur_workspace_access_guard() -def update_app( - app_id: uuid.UUID, - payload: app_schema.AppUpdate, - db: Session = Depends(get_db), - current_user=Depends(get_current_user), -): - workspace_id = current_user.current_workspace_id - app = app_service.update_app(db, app_id=app_id, data=payload, workspace_id=workspace_id) - return success(data=app_schema.App.model_validate(app)) - - -@router.delete("/{app_id}", summary="删除应用") -@cur_workspace_access_guard() -def delete_app( - app_id: uuid.UUID, - db: Session = Depends(get_db), - current_user=Depends(get_current_user), -): - """删除应用 - - 会级联删除: - - Agent 配置 - - 发布版本 - - 会话和消息 - """ - workspace_id = current_user.current_workspace_id - logger.info( - f"用户请求删除应用", - extra={ - "app_id": str(app_id), - "user_id": str(current_user.id), - "workspace_id": str(workspace_id) - } - ) - - app_service.delete_app(db, app_id=app_id, workspace_id=workspace_id) - - return success(msg="应用删除成功") - - -@router.post("/{app_id}/copy", summary="复制应用") -@cur_workspace_access_guard() -def copy_app( - app_id: uuid.UUID, - new_name: Optional[str] = None, - db: Session = Depends(get_db), - current_user=Depends(get_current_user), -): - """复制应用(包括基础信息和配置) - - - 复制应用的基础信息(名称、描述、图标等) - - 复制 Agent 配置(如果是 agent 类型) - - 新应用默认为草稿状态 - - 不影响原应用 - """ - workspace_id = current_user.current_workspace_id - logger.info( - f"用户请求复制应用", - extra={ - "source_app_id": str(app_id), - "user_id": str(current_user.id), - "workspace_id": str(workspace_id), - "new_name": new_name - } - ) - - service = AppService(db) - new_app = service.copy_app( - app_id=app_id, - user_id=current_user.id, - workspace_id=workspace_id, - new_name=new_name - ) - - return success(data=app_schema.App.model_validate(new_app), msg="应用复制成功") - - -@router.put("/{app_id}/config", summary="更新 Agent 配置") -@cur_workspace_access_guard() -def update_agent_config( - app_id: uuid.UUID, - payload: app_schema.AgentConfigUpdate, - db: Session = Depends(get_db), - current_user=Depends(get_current_user), -): - workspace_id = current_user.current_workspace_id - cfg = app_service.update_agent_config(db, app_id=app_id, data=payload, workspace_id=workspace_id) - cfg = enrich_agent_config(cfg) - return success(data=app_schema.AgentConfig.model_validate(cfg)) - - -@router.get("/{app_id}/config", summary="获取 Agent 配置") -@cur_workspace_access_guard() -def get_agent_config( - app_id: uuid.UUID, - db: Session = Depends(get_db), - current_user=Depends(get_current_user), -): - workspace_id = current_user.current_workspace_id - cfg = app_service.get_agent_config(db, app_id=app_id, workspace_id=workspace_id) - # 配置总是存在(不存在时返回默认模板) - cfg = enrich_agent_config(cfg) - return success(data=app_schema.AgentConfig.model_validate(cfg)) - - -@router.post("/{app_id}/publish", summary="发布应用(生成不可变快照)") -@cur_workspace_access_guard() -def publish_app( - app_id: uuid.UUID, - payload: app_schema.PublishRequest, - db: Session = Depends(get_db), - current_user=Depends(get_current_user), -): - workspace_id = current_user.current_workspace_id - release = app_service.publish( - db, - app_id=app_id, - publisher_id=current_user.id, - workspace_id=workspace_id, - version_name = payload.version_name, - release_notes=payload.release_notes - ) - return success(data=app_schema.AppRelease.model_validate(release)) - - -@router.get("/{app_id}/release", summary="获取当前发布版本") -@cur_workspace_access_guard() -def get_current_release( - app_id: uuid.UUID, - db: Session = Depends(get_db), - current_user=Depends(get_current_user), -): - workspace_id = current_user.current_workspace_id - release = app_service.get_current_release(db, app_id=app_id, workspace_id=workspace_id) - if not release: - return success(data=None) - return success(data=app_schema.AppRelease.model_validate(release)) - - -@router.get("/{app_id}/releases", summary="列出历史发布版本(倒序)") -@cur_workspace_access_guard() -def list_releases( - app_id: uuid.UUID, - db: Session = Depends(get_db), - current_user=Depends(get_current_user), -): - workspace_id = current_user.current_workspace_id - releases = app_service.list_releases(db, app_id=app_id, workspace_id=workspace_id) - data = [app_schema.AppRelease.model_validate(r) for r in releases] - return success(data=data) - - -@router.post("/{app_id}/rollback/{version}", summary="回滚到指定版本") -@cur_workspace_access_guard() -def rollback( - app_id: uuid.UUID, - version: int, - db: Session = Depends(get_db), - current_user=Depends(get_current_user), -): - workspace_id = current_user.current_workspace_id - release = app_service.rollback(db, app_id=app_id, version=version, workspace_id=workspace_id) - return success(data=app_schema.AppRelease.model_validate(release)) - - -@router.post("/{app_id}/share", summary="分享应用到其他工作空间") -@cur_workspace_access_guard() -def share_app( - app_id: uuid.UUID, - payload: app_schema.AppShareCreate, - db: Session = Depends(get_db), - current_user=Depends(get_current_user), -): - """分享应用到其他工作空间 - - - 只能分享自己工作空间的应用 - - 不能分享到自己的工作空间 - - 同一个应用不能重复分享到同一个工作空间 - """ - workspace_id = current_user.current_workspace_id - - service = app_service.AppService(db) - shares = service.share_app( - app_id=app_id, - target_workspace_ids=payload.target_workspace_ids, - user_id=current_user.id, - workspace_id=workspace_id - ) - - data = [app_schema.AppShare.model_validate(s) for s in shares] - return success(data=data, msg=f"应用已分享到 {len(shares)} 个工作空间") - - -@router.delete("/{app_id}/share/{target_workspace_id}", summary="取消应用分享") -@cur_workspace_access_guard() -def unshare_app( - app_id: uuid.UUID, - target_workspace_id: uuid.UUID, - db: Session = Depends(get_db), - current_user=Depends(get_current_user), -): - """取消应用分享 - - - 只能取消自己工作空间应用的分享 - """ - workspace_id = current_user.current_workspace_id - - service = app_service.AppService(db) - service.unshare_app( - app_id=app_id, - target_workspace_id=target_workspace_id, - workspace_id=workspace_id - ) - - return success(msg="应用分享已取消") - - -@router.get("/{app_id}/shares", summary="列出应用的分享记录") -@cur_workspace_access_guard() -def list_app_shares( - app_id: uuid.UUID, - db: Session = Depends(get_db), - current_user=Depends(get_current_user), -): - """列出应用的所有分享记录 - - - 只能查看自己工作空间应用的分享记录 - """ - workspace_id = current_user.current_workspace_id - - service = app_service.AppService(db) - shares = service.list_app_shares( - app_id=app_id, - workspace_id=workspace_id - ) - - data = [app_schema.AppShare.model_validate(s) for s in shares] - return success(data=data) - -@router.post("/{app_id}/draft/run", summary="试运行 Agent(使用当前草稿配置)") -@cur_workspace_access_guard() -async def draft_run( - app_id: uuid.UUID, - payload: app_schema.DraftRunRequest, - db: Session = Depends(get_db), - current_user=Depends(get_current_user), -): - """ - 试运行 Agent,使用当前的草稿配置(未发布的配置) - - - 不需要发布应用即可测试 - - 使用当前的 AgentConfig 配置 - - 支持流式和非流式返回 - """ - workspace_id = current_user.current_workspace_id - - # 获取 storage_type,如果为 None 则使用默认值 - storage_type = workspace_service.get_workspace_storage_type( - db=db, - workspace_id=workspace_id, - user=current_user - ) - if storage_type is None: storage_type = 'neo4j' - user_rag_memory_id = '' - if workspace_id: - - knowledge = knowledge_repository.get_knowledge_by_name( - db=db, - name="USER_RAG_MERORY", - workspace_id=workspace_id - ) - if knowledge: user_rag_memory_id = str(knowledge.id) - - - # 提前验证和准备(在流式响应开始前完成) - from app.services.app_service import AppService - from app.services.multi_agent_service import MultiAgentService - from app.models import AgentConfig, ModelConfig - from sqlalchemy import select - from app.core.exceptions import BusinessException - - - service = AppService(db) - - # 1. 验证应用 - app = service._get_app_or_404(app_id) - if app.type != AppType.AGENT and app.type != AppType.MULTI_AGENT: - raise BusinessException("只有 Agent 类型应用支持试运行", BizCode.APP_TYPE_NOT_SUPPORTED) - - # 只读操作,允许访问共享应用 - service._validate_app_accessible(app, workspace_id) - if app.type == AppType.AGENT: - service._check_agent_config(app_id) - - # 2. 获取 Agent 配置 - stmt = select(AgentConfig).where(AgentConfig.app_id == app_id) - agent_cfg = db.scalars(stmt).first() - if not agent_cfg: - raise BusinessException("Agent 配置不存在", BizCode.AGENT_CONFIG_MISSING) - - # 3. 获取模型配置 - model_config = None - if agent_cfg.default_model_config_id: - model_config = db.get(ModelConfig, agent_cfg.default_model_config_id) - if not model_config: - from app.core.exceptions import ResourceNotFoundException - raise ResourceNotFoundException("模型配置", str(agent_cfg.default_model_config_id)) - - # 流式返回 - if payload.stream: - async def event_generator(): - from app.services.draft_run_service import DraftRunService - draft_service = DraftRunService(db) - async for event in draft_service.run_stream( - agent_config=agent_cfg, - model_config=model_config, - message=payload.message, - workspace_id=workspace_id, - conversation_id=payload.conversation_id, - user_id=payload.user_id or str(current_user.id), - variables=payload.variables, - storage_type=storage_type, - user_rag_memory_id=user_rag_memory_id - ): - yield event - - return StreamingResponse( - event_generator(), - media_type="text/event-stream", - headers={ - "Cache-Control": "no-cache", - "Connection": "keep-alive", - "X-Accel-Buffering": "no" - } - ) - - # 非流式返回 - logger.debug( - f"开始非流式试运行", - extra={ - "app_id": str(app_id), - "message_length": len(payload.message), - "has_conversation_id": bool(payload.conversation_id), - "has_variables": bool(payload.variables) - } - ) - - from app.services.draft_run_service import DraftRunService - draft_service = DraftRunService(db) - result = await draft_service.run( - agent_config=agent_cfg, - model_config=model_config, - message=payload.message, - workspace_id=workspace_id, - conversation_id=payload.conversation_id, - user_id=payload.user_id or str(current_user.id), - variables=payload.variables, - storage_type=storage_type, - user_rag_memory_id=user_rag_memory_id - ) - - logger.debug( - f"试运行返回结果", - extra={ - "result_type": str(type(result)), - "result_keys": list(result.keys()) if isinstance(result, dict) else "not_dict" - } - ) - - # 验证结果 - try: - validated_result = app_schema.DraftRunResponse.model_validate(result) - logger.debug(f"结果验证成功") - return success(data=validated_result) - except Exception as e: - logger.error( - f"结果验证失败", - extra={ - "error": str(e), - "error_type": str(type(e)), - "result": str(result)[:200] - } - ) - raise - elif app.type == AppType.MULTI_AGENT: - # 1. 检查多智能体配置完整性 - service._check_multi_agent_config(app_id) - - # 2. 构建多智能体运行请求 - from app.schemas.multi_agent_schema import MultiAgentRunRequest - - multi_agent_request = MultiAgentRunRequest( - message=payload.message, - conversation_id=payload.conversation_id, - user_id=payload.user_id, - variables=payload.variables or {}, - use_llm_routing=True # 默认启用 LLM 路由 - ) - - # 3. 流式返回 - if payload.stream: - logger.debug( - f"开始多智能体流式试运行", - extra={ - "app_id": str(app_id), - "message_length": len(payload.message), - "has_conversation_id": bool(payload.conversation_id) - } - ) - - async def event_generator(): - """多智能体流式事件生成器""" - multiservice = MultiAgentService(db) - - # 调用多智能体服务的流式方法 - async for event in multiservice.run_stream( - app_id=app_id, - request=multi_agent_request, - storage_type=storage_type, - user_rag_memory_id=user_rag_memory_id - - ): - yield event - - return StreamingResponse( - event_generator(), - media_type="text/event-stream", - headers={ - "Cache-Control": "no-cache", - "Connection": "keep-alive", - "X-Accel-Buffering": "no" - } - ) - - # 4. 非流式返回 - logger.debug( - f"开始多智能体非流式试运行", - extra={ - "app_id": str(app_id), - "message_length": len(payload.message), - "has_conversation_id": bool(payload.conversation_id) - } - ) - - multiservice = MultiAgentService(db) - result = await multiservice.run(app_id, multi_agent_request) - - logger.debug( - f"多智能体试运行返回结果", - extra={ - "result_type": str(type(result)), - "has_response": "response" in result if isinstance(result, dict) else False - } - ) - - return success( - data=result, - msg="多 Agent 任务执行成功" - ) - - - - -@router.post("/{app_id}/draft/run/compare", summary="多模型对比试运行") -@cur_workspace_access_guard() -async def draft_run_compare( - app_id: uuid.UUID, - payload: app_schema.DraftRunCompareRequest, - db: Session = Depends(get_db), - current_user=Depends(get_current_user), -): - """ - 多模型对比试运行 - - - 支持对比 1-5 个模型 - - 可以是不同的模型,也可以是同一模型的不同参数配置 - - 通过 model_parameters 覆盖默认参数 - - 支持并行或串行执行(非流式) - - 支持流式返回(串行执行) - - 返回每个模型的运行结果和性能对比 - - 使用场景: - 1. 对比不同模型的效果(GPT-4 vs Claude vs Gemini) - 2. 调优模型参数(不同 temperature 的效果对比) - 3. 性能和成本分析 - """ - workspace_id = current_user.current_workspace_id - - # 获取 storage_type,如果为 None 则使用默认值 - storage_type = workspace_service.get_workspace_storage_type( - db=db, - workspace_id=workspace_id, - user=current_user - ) - if storage_type is None: storage_type = 'neo4j' - user_rag_memory_id = '' - if workspace_id: - knowledge = knowledge_repository.get_knowledge_by_name( - db=db, - name="USER_RAG_MERORY", - workspace_id=workspace_id - ) - if knowledge: user_rag_memory_id = str(knowledge.id) - - logger.info( - f"多模型对比试运行", - extra={ - "app_id": str(app_id), - "model_count": len(payload.models), - "parallel": payload.parallel, - "stream": payload.stream - } - ) - - # 提前验证和准备(在流式响应开始前完成) - from app.services.app_service import AppService - from app.models import ModelConfig - - service = AppService(db) - - # 1. 验证应用和权限 - app = service._get_app_or_404(app_id) - if app.type != "agent": - from app.core.exceptions import BusinessException - from app.core.error_codes import BizCode - raise BusinessException("只有 Agent 类型应用支持试运行", BizCode.APP_TYPE_NOT_SUPPORTED) - service._validate_app_accessible(app, workspace_id) - - # 2. 获取 Agent 配置 - from sqlalchemy import select - from app.models import AgentConfig - stmt = select(AgentConfig).where(AgentConfig.app_id == app_id) - agent_cfg = db.scalars(stmt).first() - if not agent_cfg: - from app.core.exceptions import BusinessException - from app.core.error_codes import BizCode - raise BusinessException("Agent 配置不存在", BizCode.AGENT_CONFIG_MISSING) - - # 3. 验证所有模型配置 - model_configs = [] - for model_item in payload.models: - model_config = db.get(ModelConfig, model_item.model_config_id) - if not model_config: - from app.core.exceptions import ResourceNotFoundException - raise ResourceNotFoundException("模型配置", str(model_item.model_config_id)) - - merged_parameters = { - **(agent_cfg.model_parameters or {}), - **(model_item.model_parameters or {}) - } - - model_configs.append({ - "model_config": model_config, - "parameters": merged_parameters, - "label": model_item.label or model_config.name, - "model_config_id": model_item.model_config_id, - "conversation_id": model_item.conversation_id # 传递每个模型的 conversation_id - }) - - # 流式返回 - if payload.stream: - async def event_generator(): - from app.services.draft_run_service import DraftRunService - draft_service = DraftRunService(db) - async for event in draft_service.run_compare_stream( - agent_config=agent_cfg, - models=model_configs, - message=payload.message, - workspace_id=workspace_id, - conversation_id=payload.conversation_id, - user_id=payload.user_id or str(current_user.id), - variables=payload.variables, - storage_type=storage_type, - user_rag_memory_id=user_rag_memory_id, - web_search=True, - memory=True, - parallel=payload.parallel, - timeout=payload.timeout or 60 - ): - yield event - - return StreamingResponse( - event_generator(), - media_type="text/event-stream", - headers={ - "Cache-Control": "no-cache", - "Connection": "keep-alive", - "X-Accel-Buffering": "no" - } - ) - - # 非流式返回 - from app.services.draft_run_service import DraftRunService - draft_service = DraftRunService(db) - result = await draft_service.run_compare( - agent_config=agent_cfg, - models=model_configs, - message=payload.message, - workspace_id=workspace_id, - conversation_id=payload.conversation_id, - user_id=payload.user_id or str(current_user.id), - variables=payload.variables, - storage_type=storage_type, - user_rag_memory_id=user_rag_memory_id, - web_search=True, - memory=True, - parallel=payload.parallel, - timeout=payload.timeout or 60 - ) - - logger.info( - f"多模型对比完成", - extra={ - "app_id": str(app_id), - "successful": result["successful_count"], - "failed": result["failed_count"] - } - ) - - return success(data=app_schema.DraftRunCompareResponse(**result)) diff --git a/app/controllers/auth_controller.py b/app/controllers/auth_controller.py deleted file mode 100644 index a6960096..00000000 --- a/app/controllers/auth_controller.py +++ /dev/null @@ -1,195 +0,0 @@ -from datetime import datetime, timedelta, timezone -from fastapi import APIRouter, Depends -from sqlalchemy.orm import Session - -from app.core.response_utils import success -from app.db import get_db -from app.schemas.response_schema import ApiResponse -from app.schemas.token_schema import Token, RefreshTokenRequest, TokenRequest -from app.schemas.workspace_schema import InviteAcceptRequest -from app.services import auth_service, user_service, workspace_service -from app.core import security -from app.core.config import settings -from app.services.session_service import SessionService -from app.core.logging_config import get_auth_logger, get_security_logger -from app.core.exceptions import BusinessException -from app.core.error_codes import BizCode -from app.dependencies import get_current_user, oauth2_scheme -from app.models.user_model import User - -# 获取专用日志器 -auth_logger = get_auth_logger() -security_logger = get_security_logger() - -router = APIRouter(tags=["Authentication"]) - -@router.post("/token", response_model=ApiResponse) -async def login_for_access_token( - form_data: TokenRequest, - db: Session = Depends(get_db) -): - """用户登录获取token""" - auth_logger.info(f"用户登录请求: {form_data.email}") - - # 验证邀请码(如果提供) - invite_info = None - # 验证用户凭据或注册新用户 - user = None - if form_data.invite: - auth_logger.info(f"检测到邀请码: {form_data.invite[:8]}...") - invite_info = workspace_service.validate_invite_token(db, form_data.invite) - - if not invite_info.is_valid: - raise BusinessException("邀请码无效或已过期", code=BizCode.BAD_REQUEST) - - if invite_info.email != form_data.email: - raise BusinessException("邀请邮箱与登录邮箱不匹配", code=BizCode.BAD_REQUEST) - auth_logger.info(f"邀请码验证成功: workspace={invite_info.workspace_name}") - try: - # 尝试认证用户 - user = auth_service.authenticate_user_or_raise(db, form_data.email, form_data.password) - auth_logger.info(f"用户认证成功: {user.email} (ID: {user.id})") - if form_data.invite: - auth_service.bind_workspace_with_invite(db=db, - user=user, - invite_token=form_data.invite, - workspace_id=invite_info.workspace_id) - except BusinessException as e: - # 用户不存在且有邀请码,尝试注册 - if e.code == BizCode.USER_NOT_FOUND: - auth_logger.info(f"用户不存在,使用邀请码注册: {form_data.email}") - user = auth_service.register_user_with_invite( - db=db, - email=form_data.email, - password=form_data.password, - invite_token=form_data.invite, - workspace_id=invite_info.workspace_id - ) - elif e.code == BizCode.PASSWORD_ERROR: - # 用户存在但密码错误 - auth_logger.warning(f"接受邀请失败,密码验证错误: {form_data.email}") - raise BusinessException("接受邀请失败,密码验证错误", BizCode.LOGIN_FAILED) - else: - # 其他认证失败情况,直接抛出 - raise - else: - try: - # 尝试认证用户 - user = auth_service.authenticate_user_or_raise(db, form_data.email, form_data.password) - auth_logger.info(f"用户认证成功: {user.email} (ID: {user.id})") - - except BusinessException as e: - - # 其他认证失败情况,直接抛出 - raise BusinessException(e.message,BizCode.LOGIN_FAILED) - - # 创建 tokens - access_token, access_token_id = security.create_access_token(subject=user.id) - refresh_token, refresh_token_id = security.create_refresh_token(subject=user.id) - - # 计算过期时间 - access_expires_at = datetime.now(timezone.utc) + timedelta(minutes=settings.ACCESS_TOKEN_EXPIRE_MINUTES) - refresh_expires_at = datetime.now(timezone.utc) + timedelta(days=settings.REFRESH_TOKEN_EXPIRE_DAYS) - - # 单点登录会话管理 - if settings.ENABLE_SINGLE_SESSION: - await SessionService.invalidate_old_session(user.id, access_token_id) - await SessionService.set_user_active_session(user.id, access_token_id, access_expires_at) - - # 更新最后登录时间 - user_service.update_last_login_time(db, user.id) - - auth_logger.info(f"用户 {user.username} 登录成功") - - return success( - data=Token( - access_token=access_token, - refresh_token=refresh_token, - token_type="bearer", - expires_at=access_expires_at, - refresh_expires_at=refresh_expires_at - ), - msg="登录成功" - ) - - -@router.post("/refresh", response_model=ApiResponse) -async def refresh_token( - refresh_request: RefreshTokenRequest, - db: Session = Depends(get_db) -): - """刷新token""" - auth_logger.info("收到token刷新请求") - - # 验证 refresh token - userId = security.verify_token(refresh_request.refresh_token, "refresh") - if not userId: - raise BusinessException("无效的refresh token", code=BizCode.TOKEN_INVALID) - - # 检查用户是否存在 - user = auth_service.get_user_by_id(db, userId) - if not user: - raise BusinessException("用户不存在", code=BizCode.USER_NOT_FOUND) - - # 检查 refresh token 黑名单 - if settings.ENABLE_SINGLE_SESSION: - refresh_token_id = security.get_token_id(refresh_request.refresh_token) - if refresh_token_id and await SessionService.is_token_blacklisted(refresh_token_id): - raise BusinessException("Refresh token已失效", code=BizCode.TOKEN_BLACKLISTED) - - # 生成新 tokens - new_access_token, new_access_token_id = security.create_access_token(subject=user.id) - new_refresh_token, new_refresh_token_id = security.create_refresh_token(subject=user.id) - - # 计算过期时间 - access_expires_at = datetime.now() + timedelta(minutes=settings.ACCESS_TOKEN_EXPIRE_MINUTES) - refresh_expires_at = datetime.now() + timedelta(days=settings.REFRESH_TOKEN_EXPIRE_DAYS) - - # 单点登录会话管理 - if settings.ENABLE_SINGLE_SESSION: - # 将旧 refresh token 加入黑名单 - old_refresh_token_id = security.get_token_id(refresh_request.refresh_token) - if old_refresh_token_id: - await SessionService.blacklist_token(old_refresh_token_id) - - # 更新会话 - await SessionService.invalidate_old_session(user.id, new_access_token_id) - await SessionService.set_user_active_session(user.id, new_access_token_id, access_expires_at) - - auth_logger.info(f"用户 {user.id} token刷新成功") - - return success( - data=Token( - access_token=new_access_token, - refresh_token=new_refresh_token, - token_type="bearer", - expires_at=access_expires_at, - refresh_expires_at=refresh_expires_at - ), - msg="token刷新成功" - ) - - -@router.post("/logout", response_model=ApiResponse) -async def logout( - token: str = Depends(oauth2_scheme), - current_user: User = Depends(get_current_user), - db: Session = Depends(get_db) -): - """登出当前用户:加入token黑名单并清理会话""" - auth_logger.info(f"用户 {current_user.username} 请求登出") - - token_id = security.get_token_id(token) - if not token_id: - raise BusinessException("无效的access token", code=BizCode.TOKEN_INVALID) - - # 加入黑名单 - await SessionService.blacklist_token(token_id) - - # 清理会话 - if settings.ENABLE_SINGLE_SESSION: - await SessionService.clear_user_session(current_user.username) - - auth_logger.info(f"用户 {current_user.username} 登出成功") - return success(msg="登出成功") - diff --git a/app/controllers/chunk_controller.py b/app/controllers/chunk_controller.py deleted file mode 100644 index 9942eed0..00000000 --- a/app/controllers/chunk_controller.py +++ /dev/null @@ -1,447 +0,0 @@ -import os -from typing import Any, Optional -import uuid -from fastapi import APIRouter, Depends, HTTPException, status, Query -from sqlalchemy.orm import Session -from sqlalchemy import func - -from app.core.config import settings -from app.db import get_db -from app.core.rag.llm.cv_model import QWenCV -from app.dependencies import get_current_user -from app.models.user_model import User -from app.models.document_model import Document -from app.models import knowledge_model, knowledgeshare_model -from app.core.rag.models.chunk import DocumentChunk -from app.schemas import chunk_schema -from app.schemas.response_schema import ApiResponse -from app.core.response_utils import success -from app.services import knowledge_service, document_service, file_service, knowledgeshare_service -from app.core.rag.vdb.elasticsearch.elasticsearch_vector import ElasticSearchVectorFactory -from app.core.logging_config import get_api_logger - -# Obtain a dedicated API logger -api_logger = get_api_logger() - -router = APIRouter( - prefix="/chunks", - tags=["chunks"], - dependencies=[Depends(get_current_user)] # Apply auth to all routes in this controller -) - - -@router.get("/{kb_id}/{document_id}/previewchunks", response_model=ApiResponse) -async def get_preview_chunks( - kb_id: uuid.UUID, - document_id: uuid.UUID, - page: int = Query(1, gt=0), # Default: 1, which must be greater than 0 - pagesize: int = Query(20, gt=0, le=100), # Default: 20 items per page, maximum: 100 items - keywords: Optional[str] = Query(None, description="The keywords used to match chunk content"), - db: Session = Depends(get_db), - current_user: User = Depends(get_current_user) -): - """ - Paged query document block preview list - - Support filtering by document_id - - Support keyword search for segmented content - - Return paging metadata + file list - """ - api_logger.info(f"Paged query document block preview list: kb_id={kb_id}, document_id={document_id}, page={page}, pagesize={pagesize}, keywords={keywords}, username: {current_user.username}") - # 1. parameter validation - if page < 1 or pagesize < 1: - raise HTTPException( - status_code=status.HTTP_400_BAD_REQUEST, - detail="The paging parameter must be greater than 0" - ) - - # 2. Obtain knowledge base information - db_knowledge = knowledge_service.get_knowledge_by_id(db, knowledge_id=kb_id, current_user=current_user) - if not db_knowledge: - raise HTTPException( - status_code=status.HTTP_404_NOT_FOUND, - detail="The knowledge base does not exist or access is denied" - ) - # 3. Check if the document exists - db_document = document_service.get_document_by_id(db, document_id=document_id, current_user=current_user) - if not db_document: - raise HTTPException( - status_code=status.HTTP_404_NOT_FOUND, - detail="The document does not exist or you do not have permission to access it" - ) - - # 4. Check if the file exists - db_file = file_service.get_file_by_id(db, file_id=db_document.file_id) - if not db_file: - raise HTTPException( - status_code=status.HTTP_404_NOT_FOUND, - detail="The file does not exist or you do not have permission to access it" - ) - - # 5. Construct file path:/files/{kb_id}/{parent_id}/{file.id}{file.file_ext} - file_path = os.path.join( - settings.FILE_PATH, - str(db_file.kb_id), - str(db_file.parent_id), - f"{db_file.id}{db_file.file_ext}" - ) - - # 6. Check if the file exists - if not os.path.exists(file_path): - raise HTTPException( - status_code=status.HTTP_404_NOT_FOUND, - detail="File not found (possibly deleted)" - ) - - # 7. Document parsing & segmentation - def progress_callback(prog=None, msg=None): - print(f"prog: {prog} msg: {msg}\n") - # Prepare to configure vision_model information - vision_model = QWenCV( - key=db_knowledge.image2text.api_keys[0].api_key, - model_name=db_knowledge.image2text.api_keys[0].model_name, - lang="Chinese", # Default to Chinese - base_url=db_knowledge.image2text.api_keys[0].api_base - ) - from app.core.rag.app.naive import chunk - res = chunk(filename=file_path, - from_page=0, - to_page=5, - callback=progress_callback, - vision_model=vision_model, - parser_config=db_document.parser_config, - is_root=False) - - start_index = (page - 1) * pagesize - end_index = start_index + pagesize - # Use slicing to obtain the data of the current page - paginated_chunk_str_list = res[start_index:end_index] - chunks = [] - for idx, item in enumerate(paginated_chunk_str_list): - metadata = { - "doc_id": uuid.uuid4().hex, - "file_id": str(db_document.file_id), - "file_name": db_document.file_name, - "file_created_at": int(db_document.created_at.timestamp() * 1000), - "document_id": str(db_document.id), - "knowledge_id": str(db_document.kb_id), - "sort_id": idx, - "status": 1, - } - chunks.append(DocumentChunk(page_content=item["content_with_weight"], metadata=metadata)) - - # 8. Return structured response - total = len(res) - result = { - "items": chunks, - "page": { - "page": page, - "pagesize": pagesize, - "total": total, - "has_next": True if page * pagesize < total else False - } - } - api_logger.info(f"Querying the document block preview list successful: total={total}, returned={len(chunks)} records") - return success(data=result, msg="Querying the document block preview list succeeded") - - -@router.get("/{kb_id}/{document_id}/chunks", response_model=ApiResponse) -async def get_chunks( - kb_id: uuid.UUID, - document_id: uuid.UUID, - page: int = Query(1, gt=0), # Default: 1, which must be greater than 0 - pagesize: int = Query(20, gt=0, le=100), # Default: 20 items per page, maximum: 100 items - keywords: Optional[str] = Query(None, description="The keywords used to match chunk content"), - db: Session = Depends(get_db), - current_user: User = Depends(get_current_user) -): - """ - Paged query document chunk list - - Support filtering by document_id - - Support keyword search for segmented content - - Return paging metadata + file list - """ - api_logger.info(f"Paged query document chunk list: kb_id={kb_id}, document_id={document_id}, page={page}, pagesize={pagesize}, keywords={keywords}, username: {current_user.username}") - # 1. parameter validation - if page < 1 or pagesize < 1: - raise HTTPException( - status_code=status.HTTP_400_BAD_REQUEST, - detail="The paging parameter must be greater than 0" - ) - - # 2. Obtain knowledge base information - db_knowledge = knowledge_service.get_knowledge_by_id(db, knowledge_id=kb_id, current_user=current_user) - if not db_knowledge: - raise HTTPException( - status_code=status.HTTP_404_NOT_FOUND, - detail="The knowledge base does not exist or access is denied" - ) - - # 3. Execute paged query - try: - api_logger.debug(f"Start executing document chunk query") - vector_service = ElasticSearchVectorFactory().init_vector(knowledge=db_knowledge) - total, items = vector_service.search_by_segment(document_id=str(document_id), query=keywords, pagesize=pagesize, page=page, asc=True) - api_logger.info(f"Document chunk query successful: total={total}, returned={len(items)} records") - except Exception as e: - api_logger.error(f"Document chunk query failed: {str(e)}") - raise HTTPException( - status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, - detail=f"Query failed: {str(e)}" - ) - - # 4. Return structured response - result = { - "items": items, - "page": { - "page": page, - "pagesize": pagesize, - "total": total, - "has_next": True if page * pagesize < total else False - } - } - return success(data=result, msg="Query of document chunk list succeeded") - - -@router.post("/{kb_id}/{document_id}/chunk", response_model=ApiResponse) -async def create_chunk( - kb_id: uuid.UUID, - document_id: uuid.UUID, - create_data: chunk_schema.ChunkCreate, - db: Session = Depends(get_db), - current_user: User = Depends(get_current_user) -): - """ - create chunk - """ - api_logger.info(f"Create chunk request: kb_id={kb_id}, document_id={document_id}, content={create_data.content}, username: {current_user.username}") - - # 1. Obtain knowledge base information - db_knowledge = knowledge_service.get_knowledge_by_id(db, knowledge_id=kb_id, current_user=current_user) - if not db_knowledge: - raise HTTPException( - status_code=status.HTTP_404_NOT_FOUND, - detail="The knowledge base does not exist or access is denied" - ) - # 1. Obtain document information - db_document = db.query(Document).filter(Document.id == document_id).first() - if not db_document: - raise HTTPException( - status_code=status.HTTP_404_NOT_FOUND, - detail="The document does not exist or you do not have permission to access it" - ) - - vector_service = ElasticSearchVectorFactory().init_vector(knowledge=db_knowledge) - - # 2. Get the sort ID - sort_id = 0 - total, items = vector_service.search_by_segment(document_id=str(document_id), pagesize=1, page=1, asc=False) - if items: - sort_id = items[0].metadata["sort_id"] - sort_id = sort_id + 1 - - doc_id = uuid.uuid4().hex - metadata = { - "doc_id": doc_id, - "file_id": str(db_document.file_id), - "file_name": db_document.file_name, - "file_created_at": int(db_document.created_at.timestamp() * 1000), - "document_id": str(document_id), - "knowledge_id": str(kb_id), - "sort_id": sort_id, - "status": 1, - } - chunk = DocumentChunk(page_content=create_data.content, metadata=metadata) - # 3. Segmented vector storage - vector_service.add_chunks([chunk]) - - # 4.update chunk_num - db_document.chunk_num += 1 - db.commit() - - return success(data=chunk, msg="Document chunk creation successful") - - -@router.get("/{kb_id}/{document_id}/{doc_id}", response_model=ApiResponse) -async def get_chunk( - kb_id: uuid.UUID, - document_id: uuid.UUID, - doc_id: str, - db: Session = Depends(get_db), - current_user: User = Depends(get_current_user) -): - """ - Retrieve document chunk information based on doc_id - """ - api_logger.info(f"Obtain document chunk information: kb_id={kb_id}, document_id={document_id}, doc_id={doc_id}, username: {current_user.username}") - - # 1. Obtain knowledge base information - db_knowledge = knowledge_service.get_knowledge_by_id(db, knowledge_id=kb_id, current_user=current_user) - if not db_knowledge: - raise HTTPException( - status_code=status.HTTP_404_NOT_FOUND, - detail="The knowledge base does not exist or access is denied" - ) - - vector_service = ElasticSearchVectorFactory().init_vector(knowledge=db_knowledge) - total, items = vector_service.get_by_segment(doc_id=doc_id) - if total: - return success(data=items[0], msg="Document chunk query successful") - else: - raise HTTPException( - status_code=status.HTTP_404_NOT_FOUND, - detail="The document chunk does not exist or you do not have access" - ) - - -@router.put("/{kb_id}/{document_id}/{doc_id}", response_model=ApiResponse) -async def update_chunk( - kb_id: uuid.UUID, - document_id: uuid.UUID, - doc_id: str, - update_data: chunk_schema.ChunkUpdate, - db: Session = Depends(get_db), - current_user: User = Depends(get_current_user) -): - """ - Update document chunk content - """ - api_logger.info(f"Update document chunk content: kb_id={kb_id}, document_id={document_id}, doc_id={doc_id}, content={update_data.content}, username: {current_user.username}") - - db_knowledge = knowledge_service.get_knowledge_by_id(db, knowledge_id=kb_id, current_user=current_user) - if not db_knowledge: - raise HTTPException( - status_code=status.HTTP_404_NOT_FOUND, - detail="The knowledge base does not exist or access is denied" - ) - - vector_service = ElasticSearchVectorFactory().init_vector(knowledge=db_knowledge) - total, items = vector_service.get_by_segment(doc_id=doc_id) - if total: - chunk = items[0] - chunk.page_content = update_data.content - vector_service.update_by_segment(chunk) - return success(data=chunk, msg="The document chunk has been successfully updated") - else: - raise HTTPException( - status_code=status.HTTP_404_NOT_FOUND, - detail="The document chunk does not exist or you do not have access to it" - ) - - -@router.delete("/{kb_id}/{document_id}/{doc_id}", response_model=ApiResponse) -async def delete_chunk( - kb_id: uuid.UUID, - document_id: uuid.UUID, - doc_id: str, - db: Session = Depends(get_db), - current_user: User = Depends(get_current_user) -): - """ - delete document chunk - """ - api_logger.info(f"Request to delete document chunk: kb_id={kb_id}, document_id={document_id}, doc_id={doc_id}, username: {current_user.username}") - - db_knowledge = knowledge_service.get_knowledge_by_id(db, knowledge_id=kb_id, current_user=current_user) - if not db_knowledge: - raise HTTPException( - status_code=status.HTTP_404_NOT_FOUND, - detail="The knowledge base does not exist or access is denied" - ) - - vector_service = ElasticSearchVectorFactory().init_vector(knowledge=db_knowledge) - if vector_service.text_exists(doc_id): - vector_service.delete_by_ids([doc_id]) - # 更新 chunk_num - db_document = db.query(Document).filter(Document.id == document_id).first() - db_document.chunk_num -= 1 - db.commit() - return success(msg="The document chunk has been successfully deleted") - else: - raise HTTPException( - status_code=status.HTTP_404_NOT_FOUND, - detail="The document chunk does not exist or you do not have access to it" - ) - - -@router.get("/retrieve_type", response_model=ApiResponse) -def get_retrieve_types(): - return success(msg="Successfully obtained the retrieval type", data=list(chunk_schema.RetrieveType)) - - -@router.post("/retrieval", response_model=Any, status_code=status.HTTP_200_OK) -async def retrieve_chunks( - retrieve_data: chunk_schema.ChunkRetrieve, - db: Session = Depends(get_db), - current_user: User = Depends(get_current_user) -): - """ - retrieve chunk - """ - api_logger.info(f"retrieve chunk: query={retrieve_data.query}, username: {current_user.username}") - - filters = [ - knowledge_model.Knowledge.id.in_(retrieve_data.kb_ids), - knowledge_model.Knowledge.permission_id == knowledge_model.PermissionType.Private, - knowledge_model.Knowledge.chunk_num > 0, - knowledge_model.Knowledge.status == 1 - ] - existing_ids = knowledge_service.get_chunded_knowledgeids( - db=db, - filters=filters, - current_user=current_user - ) - filters = [ - knowledge_model.Knowledge.id.in_(retrieve_data.kb_ids), - knowledge_model.Knowledge.permission_id == knowledge_model.PermissionType.Share, - knowledge_model.Knowledge.chunk_num > 0, - knowledge_model.Knowledge.status == 1 - ] - share_ids = knowledge_service.get_chunded_knowledgeids( - db=db, - filters=filters, - current_user=current_user - ) - if share_ids: - filters = [ - knowledgeshare_model.KnowledgeShare.target_kb_id.in_(retrieve_data.kb_ids) - ] - items = knowledgeshare_service.get_source_kb_ids_by_target_kb_id( - db=db, - filters=filters, - current_user=current_user - ) - existing_ids.extend(items) - if not existing_ids: - return success(data=[], msg="retrieval successful") - kb_id = existing_ids[0] - uuid_strs = [f"Vector_index_{kb_id}_Node".lower() for kb_id in existing_ids] - indices = ",".join(uuid_strs) - db_knowledge = knowledge_service.get_knowledge_by_id(db, knowledge_id=kb_id, current_user=current_user) - if not db_knowledge: - raise HTTPException( - status_code=status.HTTP_404_NOT_FOUND, - detail="The knowledge base does not exist or access is denied" - ) - - vector_service = ElasticSearchVectorFactory().init_vector(knowledge=db_knowledge) - - # 1 participle search, 2 semantic search, 3 hybrid search - match retrieve_data.retrieve_type: - case chunk_schema.RetrieveType.PARTICIPLE: - rs = vector_service.search_by_full_text(query=retrieve_data.query, top_k=retrieve_data.top_k, indices=indices, score_threshold=retrieve_data.similarity_threshold) - return success(data=rs, msg="retrieval successful") - case chunk_schema.RetrieveType.SEMANTIC: - rs = vector_service.search_by_vector(query=retrieve_data.query, top_k=retrieve_data.top_k, indices=indices, score_threshold=retrieve_data.vector_similarity_weight) - return success(data=rs, msg="retrieval successful") - case _: - rs1 = vector_service.search_by_vector(query=retrieve_data.query, top_k=retrieve_data.top_k, indices=indices, score_threshold=retrieve_data.vector_similarity_weight) - rs2 = vector_service.search_by_full_text(query=retrieve_data.query, top_k=retrieve_data.top_k, indices=indices, score_threshold=retrieve_data.similarity_threshold) - # Efficient deduplication - seen_ids = set() - unique_rs = [] - for doc in rs1 + rs2: - if doc.metadata["doc_id"] not in seen_ids: - seen_ids.add(doc.metadata["doc_id"]) - unique_rs.append(doc) - rs = vector_service.rerank(query=retrieve_data.query, docs=unique_rs, top_k=retrieve_data.top_k) - return success(data=rs, msg="retrieval successful") \ No newline at end of file diff --git a/app/controllers/document_controller.py b/app/controllers/document_controller.py deleted file mode 100644 index 651e700d..00000000 --- a/app/controllers/document_controller.py +++ /dev/null @@ -1,341 +0,0 @@ -import os -from typing import Optional -import datetime -import uuid -from fastapi import APIRouter, Depends, HTTPException, status, Query -from sqlalchemy.orm import Session - -from app.core.config import settings -from app.db import get_db -from app.dependencies import get_current_user -from app.models.user_model import User -from app.models import document_model -from app.schemas import document_schema -from app.schemas.response_schema import ApiResponse -from app.core.response_utils import success -from app.services import document_service, file_service, knowledge_service -from app.controllers import file_controller -from app.celery_app import celery_app -from app.core.rag.vdb.elasticsearch.elasticsearch_vector import ElasticSearchVectorFactory -from app.core.logging_config import get_api_logger - -# Obtain a dedicated API logger -api_logger = get_api_logger() - -router = APIRouter( - prefix="/documents", - tags=["documents"], - dependencies=[Depends(get_current_user)] # Apply auth to all routes in this controller -) - - -@router.get("/{kb_id}/{parent_id}/documents", response_model=ApiResponse) -async def get_documents( - kb_id: uuid.UUID, - parent_id: uuid.UUID, - page: int = Query(1, gt=0), # Default: 1, which must be greater than 0 - pagesize: int = Query(20, gt=0, le=100), # Default: 20 items per page, maximum: 100 items - orderby: Optional[str] = Query(None, description="Sort fields, such as: created_at,updated_at"), - desc: Optional[bool] = Query(False, description="Is it descending order"), - keywords: Optional[str] = Query(None, description="Search keywords (file name)"), - document_ids: Optional[str] = Query(None, description="document ids, separated by commas"), - db: Session = Depends(get_db), - current_user: User = Depends(get_current_user) -): - """ - Paged query document list - - Support filtering by kb_id and parent_id - - Support keyword search for file names - - Support dynamic sorting - - Return paging metadata + file list - """ - api_logger.info(f"Query document list: kb_id={kb_id}, page={page}, pagesize={pagesize}, keywords={keywords}, document_ids={document_ids}, username: {current_user.username}") - # 1. parameter validation - if page < 1 or pagesize < 1: - raise HTTPException( - status_code=status.HTTP_400_BAD_REQUEST, - detail="The paging parameter must be greater than 0" - ) - - # 2. Construct query conditions - filters = [ - document_model.Document.kb_id == kb_id, - document_model.Document.status == 1 - ] - - if parent_id: - files = file_service.get_files_by_parent_id(db=db, parent_id=parent_id, current_user=current_user) - files_ids = [item.id for item in files] - filters.append(document_model.Document.file_id.in_(files_ids)) - - # Keyword search (fuzzy matching of file name) - if keywords: - api_logger.debug(f"Add keyword search criteria: {keywords}") - filters.append(document_model.Document.file_name.ilike(f"%{keywords}%")) - # document ids - if document_ids: - filters.append(document_model.Document.id.in_(document_ids.split(','))) - - # 3. Execute paged query - try: - api_logger.debug(f"Start executing document paging query") - total, items = document_service.get_documents_paginated( - db=db, - filters=filters, - page=page, - pagesize=pagesize, - orderby=orderby, - desc=desc, - current_user=current_user - ) - api_logger.info(f"Document query successful: total={total}, returned={len(items)} records") - except Exception as e: - api_logger.error(f"Document query failed: {str(e)}") - raise HTTPException( - status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, - detail=f"Query failed: {str(e)}" - ) - - # 4. Return structured response - result = { - "items": items, - "page": { - "page": page, - "pagesize": pagesize, - "total": total, - "has_next": True if page * pagesize < total else False - } - } - return success(data=result, msg="Query of document list succeeded") - - -@router.post("/document", response_model=ApiResponse) -async def create_document( - create_data: document_schema.DocumentCreate, - db: Session = Depends(get_db), - current_user: User = Depends(get_current_user) -): - """ - create document - """ - api_logger.info(f"Create document request: file_name={create_data.file_name}, kb_id={create_data.kb_id}, username: {current_user.username}") - - try: - api_logger.debug(f"Start creating a document: {create_data.file_name}") - db_document = document_service.create_document(db=db, document=create_data, current_user=current_user) - api_logger.info(f"Document created successfully: {db_document.file_name} (ID: {db_document.id})") - return success(data=document_schema.Document.model_validate(db_document), msg="Document creation successful") - except Exception as e: - api_logger.error(f"Document creation failed: {create_data.file_name} - {str(e)}") - raise - - -@router.get("/{document_id}", response_model=ApiResponse) -async def get_document( - document_id: uuid.UUID, - db: Session = Depends(get_db), - current_user: User = Depends(get_current_user) -): - """ - Retrieve document information based on document_id - """ - api_logger.info(f"Obtain document information: document_id={document_id}, username: {current_user.username}") - - try: - # 1. Query document information from the database - api_logger.debug(f"query documentation: {document_id}") - db_document = document_service.get_document_by_id(db, document_id=document_id, current_user=current_user) - if not db_document: - api_logger.warning(f"The document does not exist or you do not have access: document_id={document_id}") - raise HTTPException( - status_code=status.HTTP_404_NOT_FOUND, - detail="The document does not exist or you do not have access" - ) - - api_logger.info(f"Document query successful: {db_document.file_name} (ID: {db_document.id})") - return success(data=document_schema.Document.model_validate(db_document), msg="Successfully obtained document information") - except HTTPException: - raise - except Exception as e: - api_logger.error(f"Document query failed: document_id={document_id} - {str(e)}") - raise - - -@router.put("/{document_id}", response_model=ApiResponse) -async def update_document( - document_id: uuid.UUID, - update_data: document_schema.DocumentUpdate, - db: Session = Depends(get_db), - current_user: User = Depends(get_current_user) -): - """ - Update document information - """ - # 1. Check if the document exists - api_logger.debug(f"Query the document to be updated: {document_id}") - db_document = document_service.get_document_by_id(db, document_id=document_id, current_user=current_user) - - if not db_document: - api_logger.warning(f"The document does not exist or you do not have permission to access it: document_id={document_id}") - raise HTTPException( - status_code=status.HTTP_404_NOT_FOUND, - detail="The document does not exist or you do not have permission to access it" - ) - - # 2. If updating the status, synchronize the document status switch to whether it can be retrieved from the vector database - update_dict = update_data.dict(exclude_unset=True) - if "status" in update_dict: - new_status = update_dict["status"] - if new_status != db_document.status: - db_knowledge = knowledge_service.get_knowledge_by_id(db, knowledge_id=db_document.kb_id, current_user=current_user) - vector_service = ElasticSearchVectorFactory().init_vector(knowledge=db_knowledge) - vector_service.change_status_by_document_id(document_id=str(document_id), status=new_status) - - # 3. Update fields (only update non-null fields) - api_logger.debug(f"Start updating the document fields: {document_id}") - updated_fields = [] - for field, value in update_dict.items(): - if hasattr(db_document, field): - old_value = getattr(db_document, field) - if old_value != value: - # update value - setattr(db_document, field, value) - updated_fields.append(f"{field}: {old_value} -> {value}") - - if updated_fields: - api_logger.debug(f"updated fields: {', '.join(updated_fields)}") - - db_document.updated_at = datetime.datetime.now() - - # 4. Save to database - try: - db.commit() - db.refresh(db_document) - api_logger.info(f"The document has been successfully updated: {db_document.file_name} (ID: {db_document.id})") - except Exception as e: - db.rollback() - api_logger.error(f"Document update failed: document_id={document_id} - {str(e)}") - raise HTTPException( - status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, - detail=f"Document update failed: {str(e)}" - ) - - # 5. Return the updated document - return success(data=document_schema.Document.model_validate(db_document), msg="Document information updated successfully") - - -@router.delete("/{document_id}", response_model=ApiResponse) -async def delete_document( - document_id: uuid.UUID, - db: Session = Depends(get_db), - current_user: User = Depends(get_current_user) -): - """ - Delete document - """ - api_logger.info(f"Request to delete document: document_id={document_id}, username: {current_user.username}") - - try: - # 1. Check if the document exists - api_logger.debug(f"Check whether the document exists: {document_id}") - db_document = document_service.get_document_by_id(db, document_id=document_id, current_user=current_user) - - if not db_document: - api_logger.warning(f"The document does not exist or you do not have permission to access it: document_id={document_id}") - raise HTTPException( - status_code=status.HTTP_404_NOT_FOUND, - detail="The document does not exist or you do not have permission to access it" - ) - file_id = db_document.file_id - - # 2. Delete document - api_logger.debug(f"Perform document delete: {db_document.file_name} (ID: {document_id})") - db.delete(db_document) - db.commit() - - # 3. Delete file - await file_controller._delete_file(db=db, file_id=file_id, current_user=current_user) - - # 4. Delete vector index - db_knowledge = knowledge_service.get_knowledge_by_id(db, knowledge_id=db_document.kb_id, current_user=current_user) - vector_service = ElasticSearchVectorFactory().init_vector(knowledge=db_knowledge) - vector_service.delete_by_metadata_field(key="document_id", value=str(document_id)) - - api_logger.info(f"The document has been successfully deleted: {db_document.file_name} (ID: {document_id})") - return success(msg="The document has been successfully deleted") - except Exception as e: - api_logger.error(f"Failed to delete from the document: document_id={document_id} - {str(e)}") - raise - - -@router.post("/{document_id}/chunks", response_model=ApiResponse) -async def parse_documents( - document_id: uuid.UUID, - db: Session = Depends(get_db), - current_user: User = Depends(get_current_user) -): - """ - parse document - """ - api_logger.info(f"Request to parse document: document_id={document_id}, username: {current_user.username}") - - try: - # 1. Check if the document exists - api_logger.debug(f"Check whether the document exists: {document_id}") - db_document = document_service.get_document_by_id(db, document_id=document_id, current_user=current_user) - - if not db_document: - api_logger.warning(f"The document does not exist or you do not have permission to access it: document_id={document_id}") - raise HTTPException( - status_code=status.HTTP_404_NOT_FOUND, - detail="The document does not exist or you do not have permission to access it" - ) - - # 2. Check if the file exists - api_logger.debug(f"Check whether the file exists: {db_document.file_id}") - db_file = file_service.get_file_by_id(db, file_id=db_document.file_id) - - if not db_file: - api_logger.warning(f"The file does not exist or you do not have permission to access it: file_id={db_document.file_id}") - raise HTTPException( - status_code=status.HTTP_404_NOT_FOUND, - detail="The file does not exist or you do not have permission to access it" - ) - - # 3. Construct file path:/files/{kb_id}/{parent_id}/{file.id}{file.file_ext} - file_path = os.path.join( - settings.FILE_PATH, - str(db_file.kb_id), - str(db_file.parent_id), - f"{db_file.id}{db_file.file_ext}" - ) - - # 4. Check if the file exists - if not os.path.exists(file_path): - api_logger.warning(f"File not found (possibly deleted): file_path={file_path}") - raise HTTPException( - status_code=status.HTTP_404_NOT_FOUND, - detail="File not found (possibly deleted)" - ) - - # 5. Obtain knowledge base information - api_logger.info( f"Obtain details of the knowledge base: knowledge_id={db_document.kb_id}") - db_knowledge = knowledge_service.get_knowledge_by_id(db, knowledge_id=db_document.kb_id, current_user=current_user) - if not db_knowledge: - api_logger.warning(f"The knowledge base does not exist or access is denied: knowledge_id={db_document.kb_id}") - raise HTTPException( - status_code=status.HTTP_404_NOT_FOUND, - detail="The knowledge base does not exist or access is denied" - ) - - # 6. Task: Document parsing, vectorization, and storage - # from app.tasks import parse_document - # parse_document(file_path, document_id) - task = celery_app.send_task("app.core.rag.tasks.parse_document", args=[file_path, document_id]) - result = { - "task_id": task.id - } - return success(data=result, msg="Task accepted. The document is being processed in the background.") - except Exception as e: - api_logger.error(f"Failed to parse document: document_id={document_id} - {str(e)}") - raise diff --git a/app/controllers/file_controller.py b/app/controllers/file_controller.py deleted file mode 100644 index 57a99118..00000000 --- a/app/controllers/file_controller.py +++ /dev/null @@ -1,453 +0,0 @@ -import os -from typing import Any, Optional -from pathlib import Path -import shutil -import uuid -from fastapi import APIRouter, Depends, HTTPException, status, File, UploadFile, Query -from fastapi.responses import FileResponse -from sqlalchemy.orm import Session - -from app.core.config import settings -from app.db import get_db -from app.dependencies import get_current_user -from app.models.user_model import User -from app.models import file_model -from app.schemas import file_schema, document_schema -from app.schemas.response_schema import ApiResponse -from app.core.response_utils import success -from app.services import file_service, document_service -from app.core.logging_config import get_api_logger - -# Obtain a dedicated API logger -api_logger = get_api_logger() - -router = APIRouter( - prefix="/files", - tags=["files"] -) - - -@router.get("/{kb_id}/{parent_id}/files", response_model=ApiResponse) -async def get_files( - kb_id: uuid.UUID, - parent_id: uuid.UUID, - page: int = Query(1, gt=0), # Default: 1, which must be greater than 0 - pagesize: int = Query(20, gt=0, le=100), # Default: 20 items per page, maximum: 100 items - orderby: Optional[str] = Query(None, description="Sort fields, such as: created_at"), - desc: Optional[bool] = Query(False, description="Is it descending order"), - keywords: Optional[str] = Query(None, description="Search keywords (file name)"), - db: Session = Depends(get_db), - current_user: User = Depends(get_current_user) -): - """ - Paged query file list - - Support filtering by kb_id and parent_id - - Support keyword search for file names - - Support dynamic sorting - - Return paging metadata + file list - """ - api_logger.info(f"Query file list: kb_id={kb_id}, parent_id={parent_id}, page={page}, pagesize={pagesize}, keywords={keywords}, username: {current_user.username}") - # 1. parameter validation - if page < 1 or pagesize < 1: - raise HTTPException( - status_code=status.HTTP_400_BAD_REQUEST, - detail="The paging parameter must be greater than 0" - ) - - # 2. Construct query conditions - filters = [ - file_model.File.kb_id == kb_id - ] - if parent_id: - filters.append(file_model.File.parent_id == parent_id) - # Keyword search (fuzzy matching of file name) - if keywords: - filters.append(file_model.File.file_name.ilike(f"%{keywords}%")) - - # 3. Execute paged query - try: - api_logger.debug(f"Start executing file paging query") - total, items = file_service.get_files_paginated( - db=db, - filters=filters, - page=page, - pagesize=pagesize, - orderby=orderby, - desc=desc, - current_user=current_user - ) - api_logger.info(f"File query successful: total={total}, returned={len(items)} records") - except Exception as e: - raise HTTPException( - status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, - detail=f"Query failed: {str(e)}" - ) - - # 4. Return structured response - result = { - "items": items, - "page": { - "page": page, - "pagesize": pagesize, - "total": total, - "has_next": True if page * pagesize < total else False - } - } - return success(data=result, msg="Query of file list succeeded") - - -@router.post("/folder", response_model=ApiResponse) -def create_folder( - kb_id: uuid.UUID, - parent_id: uuid.UUID, - folder_name: str = '/', - db: Session = Depends(get_db), - current_user: User = Depends(get_current_user), -): - """ - Create a new folder - """ - api_logger.info(f"Create folder request: kb_id={kb_id}, parent_id={parent_id}, folder_name={folder_name}, username: {current_user.username}") - - try: - api_logger.debug(f"Start creating a folder: {folder_name}") - create_folder = file_schema.FileCreate( - kb_id=kb_id, - created_by=current_user.id, - parent_id=parent_id, - file_name=folder_name, - file_ext='folder', - file_size=0, - ) - db_file = file_service.create_file(db=db, file=create_folder, current_user=current_user) - api_logger.info(f"Folder created successfully: {db_file.file_name} (ID: {db_file.id})") - return success(data=file_schema.File.model_validate(db_file), msg="Folder creation successful") - except Exception as e: - api_logger.error(f"Folder creation failed: {folder_name} - {str(e)}") - raise - - -@router.post("/file", response_model=ApiResponse) -async def upload_file( - kb_id: uuid.UUID, - parent_id: uuid.UUID, - file: UploadFile = File(...), - db: Session = Depends(get_db), - current_user: User = Depends(get_current_user) -): - """ - upload file - """ - api_logger.info(f"upload file request: kb_id={kb_id}, parent_id={parent_id}, filename={file.filename}, username: {current_user.username}") - - # Read the contents of the file - contents = await file.read() - # Check file size - file_size = len(contents) - print(f"file size: {file_size} byte") - if file_size == 0: - raise HTTPException( - status_code=status.HTTP_400_BAD_REQUEST, - detail="The file is empty." - ) - # If the file size exceeds 50MB (50 * 1024 * 1024 bytes) - if file_size > settings.MAX_FILE_SIZE: - raise HTTPException( - status_code=status.HTTP_400_BAD_REQUEST, - detail=f"The file size exceeds the {settings.MAX_FILE_SIZE}byte limit" - ) - - # Extract the extension using `os.path.splitext` - _, file_extension = os.path.splitext(file.filename) - upload_file = file_schema.FileCreate( - kb_id=kb_id, - created_by=current_user.id, - parent_id=parent_id, - file_name=file.filename, - file_ext=file_extension.lower(), - file_size=file_size, - ) - db_file = file_service.create_file(db=db, file=upload_file, current_user=current_user) - - # Construct a save path:/files/{kb_id}/{parent_id}/{file.id}{file_extension} - save_dir = os.path.join(settings.FILE_PATH, str(kb_id), str(parent_id)) - Path(save_dir).mkdir(parents=True, exist_ok=True) # Ensure that the directory exists - save_path = os.path.join(save_dir, f"{db_file.id}{file_extension}") - - # Save file - with open(save_path, "wb") as f: - f.write(contents) - - # Verify whether the file has been saved successfully - if not os.path.exists(save_path): - raise HTTPException( - status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, - detail="File save failed" - ) - - # Create a document - create_data = document_schema.DocumentCreate( - kb_id=kb_id, - created_by=current_user.id, - file_id=db_file.id, - file_name=db_file.file_name, - file_ext=db_file.file_ext, - file_size=db_file.file_size, - file_meta={}, - parser_id="naive", - parser_config={ - "layout_recognize": "DeepDOC", - "chunk_token_num": 128, - "delimiter": "\n", - "auto_keywords": 0, - "auto_questions": 0, - "html4excel": "false" - } - ) - db_document = document_service.create_document(db=db, document=create_data, current_user=current_user) - - api_logger.info(f"File upload successfully: {file.filename} (file_id: {db_file.id}, document_id: {db_document.id})") - return success(data=document_schema.Document.model_validate(db_document), msg="File upload successful") - - -@router.post("/customtext", response_model=ApiResponse) -async def custom_text( - kb_id: uuid.UUID, - parent_id: uuid.UUID, - create_data: file_schema.CustomTextFileCreate, - db: Session = Depends(get_db), - current_user: User = Depends(get_current_user) -): - """ - custom text - """ - api_logger.info(f"custom text upload request: kb_id={kb_id}, parent_id={parent_id}, title={create_data.title}, content={create_data.content}, username: {current_user.username}") - - # Check file content size - # 将内容编码为字节(UTF-8) - content_bytes = create_data.content.encode('utf-8') - file_size = len(content_bytes) - print(f"file size: {file_size} byte") - if file_size == 0: - raise HTTPException( - status_code=status.HTTP_400_BAD_REQUEST, - detail="The content is empty." - ) - # If the file size exceeds 50MB (50 * 1024 * 1024 bytes) - if file_size > settings.MAX_FILE_SIZE: - raise HTTPException( - status_code=status.HTTP_400_BAD_REQUEST, - detail=f"The content size exceeds the {settings.MAX_FILE_SIZE}byte limit" - ) - - upload_file = file_schema.FileCreate( - kb_id=kb_id, - created_by=current_user.id, - parent_id=parent_id, - file_name=f"{create_data.title}.txt", - file_ext=".txt", - file_size=file_size, - ) - db_file = file_service.create_file(db=db, file=upload_file, current_user=current_user) - - # Construct a save path:/files/{kb_id}/{parent_id}/{file.id}{file_extension} - save_dir = os.path.join(settings.FILE_PATH, str(kb_id), str(parent_id)) - Path(save_dir).mkdir(parents=True, exist_ok=True) # Ensure that the directory exists - save_path = os.path.join(save_dir, f"{db_file.id}.txt") - - # Save file - with open(save_path, "wb") as f: - f.write(content_bytes) - - # Verify whether the file has been saved successfully - if not os.path.exists(save_path): - raise HTTPException( - status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, - detail="File save failed" - ) - - # Create a document - create_document_data = document_schema.DocumentCreate( - kb_id=kb_id, - created_by=current_user.id, - file_id=db_file.id, - file_name=db_file.file_name, - file_ext=db_file.file_ext, - file_size=db_file.file_size, - file_meta={}, - parser_id="naive", - parser_config={ - "layout_recognize": "DeepDOC", - "chunk_token_num": 128, - "delimiter": "\n", - "auto_keywords": 0, - "auto_questions": 0, - "html4excel": "false" - } - ) - db_document = document_service.create_document(db=db, document=create_document_data, current_user=current_user) - - api_logger.info(f"custom text upload successfully: {create_data.title} (file_id: {db_file.id}, document_id: {db_document.id})") - return success(data=document_schema.Document.model_validate(db_document), msg="custom text upload successful") - - -@router.get("/{file_id}", response_model=Any) -async def get_file( - file_id: uuid.UUID, - db: Session = Depends(get_db) -) -> Any: - """ - Download the file based on the file_id - - Query file information from the database - - Construct the file path and check if it exists - - Return a FileResponse to download the file - """ - api_logger.info(f"Download the file based on the file_id: file_id={file_id}") - - # 1. Query file information from the database - db_file = file_service.get_file_by_id(db, file_id=file_id) - if not db_file: - api_logger.warning(f"The file does not exist or you do not have permission to access it: file_id={file_id}") - raise HTTPException( - status_code=status.HTTP_404_NOT_FOUND, - detail="The file does not exist or you do not have permission to access it" - ) - - # 2. Construct file path:/files/{kb_id}/{parent_id}/{file.id}{file.file_ext} - file_path = os.path.join( - settings.FILE_PATH, - str(db_file.kb_id), - str(db_file.parent_id), - f"{db_file.id}{db_file.file_ext}" - ) - - # 3. Check if the file exists - if not os.path.exists(file_path): - raise HTTPException( - status_code=status.HTTP_404_NOT_FOUND, - detail="File not found (possibly deleted)" - ) - - # 4.Return FileResponse (automatically handle download) - return FileResponse( - path=file_path, - filename=db_file.file_name, # Use original file name - media_type="application/octet-stream" # Universal binary stream type - ) - - -@router.put("/{file_id}", response_model=ApiResponse) -async def update_file( - file_id: uuid.UUID, - update_data: file_schema.FileUpdate, - db: Session = Depends(get_db), - current_user: User = Depends(get_current_user) -): - """ - Update file information (such as file name) - - Only specified fields such as file_name are allowed to be modified - """ - api_logger.debug(f"Query the file to be updated: {file_id}") - - # 1. Check if the file exists - db_file = file_service.get_file_by_id(db, file_id=file_id) - - if not db_file: - api_logger.warning(f"The file does not exist or you do not have permission to access it: file_id={file_id}") - raise HTTPException( - status_code=status.HTTP_404_NOT_FOUND, - detail="The file does not exist or you do not have permission to access it" - ) - - # 2. Update fields (only update non-null fields) - api_logger.debug(f"Start updating the file fields: {file_id}") - updated_fields = [] - for field, value in update_data.items(): - if hasattr(db_file, field): - old_value = getattr(db_file, field) - if old_value != value: - # update value - setattr(db_file, field, value) - updated_fields.append(f"{field}: {old_value} -> {value}") - - if updated_fields: - api_logger.debug(f"updated fields: {', '.join(updated_fields)}") - - # 3. Save to database - try: - db.commit() - db.refresh(db_file) - api_logger.info(f"The file has been successfully updated: {db_file.file_name} (ID: {db_file.id})") - except Exception as e: - db.rollback() - api_logger.error(f"File update failed: file_id={file_id} - {str(e)}") - raise HTTPException( - status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, - detail=f"File update failed: {str(e)}" - ) - - # 4. Return the updated file - return success(data=file_schema.File.model_validate(db_file), msg="File information updated successfully") - - -@router.delete("/{file_id}", response_model=ApiResponse) -async def delete_file( - file_id: uuid.UUID, - db: Session = Depends(get_db), - current_user: User = Depends(get_current_user) -): - """ - Delete a file or folder - """ - api_logger.info(f"Request to delete file: file_id={file_id}, username: {current_user.username}") - await _delete_file(db=db, file_id=file_id, current_user=current_user) - return success(msg="File deleted successfully") - -async def _delete_file( - file_id: uuid.UUID, - db: Session = Depends(get_db), - current_user: User = Depends(get_current_user) -) -> None: - """ - Delete a file or folder - """ - # 1. Check if the file exists - db_file = file_service.get_file_by_id(db, file_id=file_id) - - if not db_file: - api_logger.warning(f"The file does not exist or you do not have permission to access it: file_id={file_id}") - raise HTTPException( - status_code=status.HTTP_404_NOT_FOUND, - detail="The file does not exist or you do not have permission to access it" - ) - - # 2. Construct physical path - file_path = Path( - settings.FILE_PATH, - str(db_file.kb_id), - str(db_file.id) - ) if db_file.file_ext == 'folder' else Path( - settings.FILE_PATH, - str(db_file.kb_id), - str(db_file.parent_id), - f"{db_file.id}{db_file.file_ext}" - ) - - # 3. Delete physical files/folders - try: - if file_path.exists(): - if db_file.file_ext == 'folder': - shutil.rmtree(file_path) # Recursively delete folders - else: - file_path.unlink() # Delete a single file - except Exception as e: - raise HTTPException( - status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, - detail=f"Failed to delete physical file/folder: {str(e)}" - ) - - # 4.Delete db_file - if db_file.file_ext == 'folder': - db.query(file_model.File).filter(file_model.File.parent_id == db_file.id).delete() - db.delete(db_file) - db.commit() diff --git a/app/controllers/knowledge_controller.py b/app/controllers/knowledge_controller.py deleted file mode 100644 index 892dcc39..00000000 --- a/app/controllers/knowledge_controller.py +++ /dev/null @@ -1,305 +0,0 @@ -from typing import Optional -import datetime -import uuid -from fastapi import APIRouter, Depends, HTTPException, status, Query -from sqlalchemy import or_ -from sqlalchemy.orm import Session - -from app.db import get_db -from app.dependencies import get_current_user -from app.models.user_model import User -from app.models import knowledge_model, document_model, file_model -from app.schemas import knowledge_schema -from app.schemas.response_schema import ApiResponse -from app.core.response_utils import success -from app.services import knowledge_service, document_service -from app.core.rag.vdb.elasticsearch.elasticsearch_vector import ElasticSearchVectorFactory -from app.core.logging_config import get_api_logger - -# Obtain a dedicated API logger -api_logger = get_api_logger() - -router = APIRouter( - prefix="/knowledges", - tags=["knowledges"], - dependencies=[Depends(get_current_user)] # Apply auth to all routes in this controller -) - - -@router.get("/knowledgetype", response_model=ApiResponse) -def get_knowledge_types(): - return success(msg="Successfully obtained the knowledge type", data=list(knowledge_model.KnowledgeType)) - - -@router.get("/permissiontype", response_model=ApiResponse) -def get_permission_types(): - return success(msg="Successfully obtained the knowledge permission type", data=list(knowledge_model.PermissionType)) - - -@router.get("/parsertype", response_model=ApiResponse) -def get_parser_types(): - return success(msg="Successfully obtained the knowledge parser type", data=list(knowledge_model.ParserType)) - - -@router.get("/knowledges", response_model=ApiResponse) -async def get_knowledges( - parent_id: Optional[uuid.UUID] = Query(None, description="parent folder id"), - page: int = Query(1, gt=0), # Default: 1, which must be greater than 0 - pagesize: int = Query(20, gt=0, le=100), # Default: 20 items per page, maximum: 100 items - orderby: Optional[str] = Query(None, description="Sort fields, such as: created_at,updated_at"), - desc: Optional[bool] = Query(False, description="Is it descending order"), - keywords: Optional[str] = Query(None, description="Search keywords (knowledge base name)"), - kb_ids: Optional[str] = Query(None, description="Knowledge base ids, separated by commas"), - db: Session = Depends(get_db), - current_user: User = Depends(get_current_user) -): - """ - Query the knowledge base list in pages - - Support filtering by parent_id - - Support keyword search for knowledge base names - - Support dynamic sorting - - Return paging metadata + file list - """ - api_logger.info(f"Query knowledge base list: workspace_id={current_user.current_workspace_id}, page={page}, pagesize={pagesize}, keywords={keywords}, kb_ids={kb_ids}, username: {current_user.username}") - - # 1. parameter validation - if page < 1 or pagesize < 1: - api_logger.warning(f"Error in paging parameters: page={page}, pagesize={pagesize}") - raise HTTPException( - status_code=status.HTTP_400_BAD_REQUEST, - detail="The paging parameter must be greater than 0" - ) - - # 2. Construct query conditions - filters = [ - knowledge_model.Knowledge.workspace_id == current_user.current_workspace_id - ] - if parent_id: - filters.append(knowledge_model.Knowledge.parent_id == parent_id) - - # Keyword search (fuzzy matching of knowledge base name) - if keywords: - api_logger.debug(f"Add keyword search criteria: {keywords}") - filters.append( - or_( - knowledge_model.Knowledge.name.ilike(f"%{keywords}%"), - knowledge_model.Knowledge.description.ilike(f"%{keywords}%") - ) - ) - # Knowledge base ids - if kb_ids: - filters.append(knowledge_model.Knowledge.id.in_(kb_ids.split(','))) - else: - filters.append(knowledge_model.Knowledge.status != 2) - # 3. Execute paged query - try: - api_logger.debug(f"Start executing knowledge base paging query") - total, items = knowledge_service.get_knowledges_paginated( - db=db, - filters=filters, - page=page, - pagesize=pagesize, - orderby=orderby, - desc=desc, - current_user=current_user - ) - api_logger.info(f"Knowledge base query successful: total={total}, returned={len(items)} records") - except Exception as e: - api_logger.error(f"Knowledge base query failed: {str(e)}") - raise HTTPException( - status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, - detail=f"Query failed: {str(e)}" - ) - - # 4. Return structured response - result = { - "items": items, - "page": { - "page": page, - "pagesize": pagesize, - "total": total, - "has_next": True if page*pagesize < total else False - } - } - return success(data=result, msg="Query of knowledge base list successful") - - -@router.post("/knowledge", response_model=ApiResponse) -async def create_knowledge( - create_data: knowledge_schema.KnowledgeCreate, - db: Session = Depends(get_db), - current_user: User = Depends(get_current_user) -): - """ - create knowledge - """ - api_logger.info(f"Request to create a knowledge base: name={create_data.name}, workspace_id={current_user.current_workspace_id}, username: {current_user.username}") - - try: - api_logger.debug(f"Start creating the knowledge base: {create_data.name}") - # 1. Check if the knowledge base name already exists - db_knowledge_exist = knowledge_service.get_knowledge_by_name(db, name=create_data.name, current_user=current_user) - if db_knowledge_exist: - api_logger.warning(f"The knowledge base name already exists: {create_data.name}") - raise HTTPException( - status_code=status.HTTP_400_BAD_REQUEST, - detail=f"The knowledge base name already exists: {create_data.name}" - ) - db_knowledge = knowledge_service.create_knowledge(db=db, knowledge=create_data, current_user=current_user) - api_logger.info(f"The knowledge base has been successfully created: {db_knowledge.name} (ID: {db_knowledge.id})") - return success(data=knowledge_schema.Knowledge.model_validate(db_knowledge), msg="The knowledge base has been successfully created") - except Exception as e: - api_logger.error(f"The creation of the knowledge base failed: {create_data.name} - {str(e)}") - raise - - -@router.get("/{knowledge_id}", response_model=ApiResponse) -async def get_knowledge( - knowledge_id: uuid.UUID, - db: Session = Depends(get_db), - current_user: User = Depends(get_current_user) -): - """ - Retrieve knowledge base information based on knowledge_id - """ - api_logger.info(f"Obtain details of the knowledge base: knowledge_id={knowledge_id}, username: {current_user.username}") - - try: - # 1. Query knowledge base information from the database - api_logger.debug(f"Query knowledge base: {knowledge_id}") - db_knowledge = knowledge_service.get_knowledge_by_id(db, knowledge_id=knowledge_id, current_user=current_user) - if not db_knowledge: - api_logger.warning(f"The knowledge base does not exist or access is denied: knowledge_id={knowledge_id}") - raise HTTPException( - status_code=status.HTTP_404_NOT_FOUND, - detail="The knowledge base does not exist or access is denied" - ) - - api_logger.info(f"Knowledge base query successful: {db_knowledge.name} (ID: {db_knowledge.id})") - return success(data=knowledge_schema.Knowledge.model_validate(db_knowledge), msg="Successfully obtained knowledge base information") - except HTTPException: - raise - except Exception as e: - api_logger.error(f"Knowledge base query failed: knowledge_id={knowledge_id} - {str(e)}") - raise - - -@router.put("/{knowledge_id}", response_model=ApiResponse) -async def update_knowledge( - knowledge_id: uuid.UUID, - update_data: knowledge_schema.KnowledgeUpdate, - db: Session = Depends(get_db), - current_user: User = Depends(get_current_user) -): - api_logger.info(f"Update knowledge base request: knowledge_id={knowledge_id}, username: {current_user.username}") - db_knowledge = await _update_knowledge(knowledge_id=knowledge_id, update_data=update_data, db=db, current_user=current_user) - return success(data=knowledge_schema.Knowledge.model_validate(db_knowledge), msg="The knowledge base information has been successfully updated") - - -async def _update_knowledge( - knowledge_id: uuid.UUID, - update_data: knowledge_schema.KnowledgeUpdate, - db: Session = Depends(get_db), - current_user: User = Depends(get_current_user) -) -> knowledge_schema.Knowledge: - """ - Update knowledge base information - """ - try: - # 1. Check whether the knowledge base exists - api_logger.debug(f"Query the knowledge base to be updated: {knowledge_id}") - db_knowledge = knowledge_service.get_knowledge_by_id(db, knowledge_id=knowledge_id, current_user=current_user) - - if not db_knowledge: - api_logger.warning(f"The knowledge base does not exist or you do not have permission to access it: knowledge_id={knowledge_id}") - raise HTTPException( - status_code=status.HTTP_404_NOT_FOUND, - detail="The knowledge base does not exist or you do not have permission to access it" - ) - - # 2. If updating the embedding_id, delete the knowledge base vector index, reset all document parsing progress to 0, and set chunk_num to 0 - update_dict = update_data.dict(exclude_unset=True) - if "name" in update_dict: - name = update_dict["name"] - if name != db_knowledge.name: - # Check if the knowledge base name already exists - db_knowledge_exist = knowledge_service.get_knowledge_by_name(db, name=name, current_user=current_user) - if db_knowledge_exist: - api_logger.warning(f"The knowledge base name already exists: {name}") - raise HTTPException( - status_code=status.HTTP_400_BAD_REQUEST, - detail=f"The knowledge base name already exists: {name}" - ) - if "embedding_id" in update_dict: - embedding_id = update_dict["embedding_id"] - if embedding_id != db_knowledge.embedding_id: - vector_service = ElasticSearchVectorFactory().init_vector(knowledge=db_knowledge) - vector_service.delete() - document_service.reset_documents_progress_by_kb_id(db, kb_id=db_knowledge.id, current_user=current_user) - - # 2. Update fields (only update non-null fields) - api_logger.debug(f"Start updating the knowledge base fields: {knowledge_id}") - updated_fields = [] - for field, value in update_data.dict(exclude_unset=True).items(): - if hasattr(db_knowledge, field): - old_value = getattr(db_knowledge, field) - if old_value != value: - # update value - setattr(db_knowledge, field, value) - updated_fields.append(f"{field}: {old_value} -> {value}") - - if updated_fields: - api_logger.debug(f"updated fields: {', '.join(updated_fields)}") - - db_knowledge.updated_at = datetime.datetime.now() - - # 3. Save to database - db.commit() - db.refresh(db_knowledge) - api_logger.info(f"The knowledge base has been successfully updated: {db_knowledge.name} (ID: {db_knowledge.id})") - - # 4. Return the updated knowledge base - return db_knowledge - except HTTPException: - raise - except Exception as e: - db.rollback() - api_logger.error(f"Knowledge base update failed: knowledge_id={knowledge_id} - {str(e)}") - raise HTTPException( - status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, - detail=f"Knowledge base update failed: {str(e)}" - ) - - -@router.delete("/{knowledge_id}", response_model=ApiResponse) -async def delete_knowledge( - knowledge_id: uuid.UUID, - db: Session = Depends(get_db), - current_user: User = Depends(get_current_user) -): - """ - Soft-delete knowledge base - """ - api_logger.info(f"Request to delete knowledge base: knowledge_id={knowledge_id}, username: {current_user.username}") - - try: - # 1. Check whether the knowledge base exists - api_logger.debug(f"Check whether the knowledge base exists: {knowledge_id}") - db_knowledge = knowledge_service.get_knowledge_by_id(db, knowledge_id=knowledge_id, current_user=current_user) - - if not db_knowledge: - api_logger.warning(f"The knowledge base does not exist or you do not have permission to access it: knowledge_id={knowledge_id}") - raise HTTPException( - status_code=status.HTTP_404_NOT_FOUND, - detail="The knowledge base does not exist or you do not have permission to access it" - ) - - # 2. Soft-delete knowledge base - api_logger.debug(f"Perform a soft delete: {db_knowledge.name} (ID: {knowledge_id})") - db_knowledge.status = 2 - db.commit() - api_logger.info(f"The knowledge base has been successfully deleted: {db_knowledge.name} (ID: {knowledge_id})") - return success(msg="The knowledge base has been successfully deleted") - except Exception as e: - api_logger.error(f"Failed to delete from the knowledge base: knowledge_id={knowledge_id} - {str(e)}") - raise diff --git a/app/controllers/knowledgeshare_controller.py b/app/controllers/knowledgeshare_controller.py deleted file mode 100644 index 8a1b5bb7..00000000 --- a/app/controllers/knowledgeshare_controller.py +++ /dev/null @@ -1,199 +0,0 @@ -from typing import Optional -import uuid -from fastapi import APIRouter, Depends, HTTPException, status, Query -from sqlalchemy.orm import Session - -from app.db import get_db -from app.dependencies import get_current_user -from app.models.user_model import User -from app.models import knowledgeshare_model, knowledge_model -from app.schemas import knowledgeshare_schema, knowledge_schema -from app.schemas.response_schema import ApiResponse -from app.core.response_utils import success -from app.services import knowledgeshare_service, knowledge_service -from app.core.logging_config import get_api_logger - -# Obtain a dedicated API logger -api_logger = get_api_logger() - -router = APIRouter( - prefix="/knowledgeshares", - tags=["knowledgeshares"], - dependencies=[Depends(get_current_user)] # Apply auth to all routes in this controller -) - - -@router.get("/{kb_id}/knowledgeshares", response_model=ApiResponse) -async def get_knowledgeshares( - kb_id: uuid.UUID, - page: int = Query(1, gt=0), # Default: 1, which must be greater than 0 - pagesize: int = Query(20, gt=0, le=100), # Default: 20 items per page, maximum: 100 items - orderby: Optional[str] = Query(None, description="Sort fields, such as: created_at,updated_at"), - desc: Optional[bool] = Query(False, description="Is it descending order"), - db: Session = Depends(get_db), - current_user: User = Depends(get_current_user) -): - """ - Paged query knowledge base sharing list - - Support filtering by kb_id - - Support dynamic sorting - - Return paging metadata + share list - """ - api_logger.info( - f"Query knowledge base sharing list: workspace_id={current_user.current_workspace_id}, kb_id={kb_id}, page={page}, pagesize={pagesize}, username: {current_user.username}") - - # 1. parameter validation - if page < 1 or pagesize < 1: - api_logger.warning(f"Error in paging parameters: page={page}, pagesize={pagesize}") - raise HTTPException( - status_code=status.HTTP_400_BAD_REQUEST, - detail="The paging parameter must be greater than 0" - ) - - # 2. Construct query conditions - filters = [ - knowledgeshare_model.KnowledgeShare.source_workspace_id == current_user.current_workspace_id, - knowledgeshare_model.KnowledgeShare.source_kb_id == kb_id - ] - - # 3. Execute paged query - try: - api_logger.debug(f"Start executing knowledge base sharing and paging query") - total, items = knowledgeshare_service.get_knowledgeshares_paginated( - db=db, - filters=filters, - page=page, - pagesize=pagesize, - orderby=orderby, - desc=desc, - current_user=current_user - ) - api_logger.info(f"Knowledge base sharing query successful: total={total}, returned={len(items)} records") - except Exception as e: - api_logger.error(f"Knowledge base sharing query failed: {str(e)}") - raise HTTPException( - status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, - detail=f"Query failed: {str(e)}" - ) - - # 4. Return structured response - result = { - "items": items, - "page": { - "page": page, - "pagesize": pagesize, - "total": total, - "has_next": True if page * pagesize < total else False - } - } - return success(data=result, msg="Query of knowledge base sharing list successful") - - -@router.post("/knowledgeshare", response_model=ApiResponse) -async def create_knowledgeshare( - create_data: knowledgeshare_schema.KnowledgeShareCreate, - db: Session = Depends(get_db), - current_user: User = Depends(get_current_user) -): - """ - create knowledgeshare - """ - api_logger.info( - f"Create a knowledge base sharing request: source_kb_id={create_data.source_kb_id}, source_workspace_id={current_user.current_workspace_id}, username: {current_user.username}") - - try: - # 1.Create a knowledge base with permission_id=knowledge_model.PermissionType.Share - db_knowledge = knowledge_service.get_knowledge_by_id(db, knowledge_id=create_data.source_kb_id, current_user=current_user) - knowledge = knowledge_schema.KnowledgeCreate( - workspace_id=create_data.target_workspace_id, - created_by=current_user.id, - parent_id=create_data.target_workspace_id, - name=db_knowledge.name, - description=db_knowledge.description, - avatar=db_knowledge.avatar, - type=db_knowledge.type, - permission_id=knowledge_model.PermissionType.Share, - embedding_id=db_knowledge.embedding_id, - reranker_id=db_knowledge.reranker_id, - llm_id=db_knowledge.llm_id, - image2text_id=db_knowledge.image2text_id, - doc_num=db_knowledge.doc_num, - chunk_num=db_knowledge.chunk_num, - parser_id=db_knowledge.parser_id, - parser_config=db_knowledge.parser_config - ) - db_knowledge = knowledge_service.create_knowledge(db=db, knowledge=knowledge, current_user=current_user) - # 2. Create a knowledge base for sharing - api_logger.debug(f"Start creating the knowledge base sharing: {db_knowledge.name}") - create_data.target_kb_id = db_knowledge.id - db_knowledgeshare = knowledgeshare_service.create_knowledgeshare(db=db, knowledgeshare=create_data, current_user=current_user) - api_logger.info(f"The knowledge base sharing has been successfully created: (ID: {db_knowledgeshare.id})") - return success(data=knowledgeshare_schema.KnowledgeShare.model_validate(db_knowledgeshare), msg="The knowledge base sharing has been successfully created") - except Exception as e: - api_logger.error(f"The creation of the knowledge base sharing failed: {str(e)}") - raise - - -@router.get("/{knowledgeshare_id}", response_model=ApiResponse) -async def get_knowledgeshare( - knowledgeshare_id: uuid.UUID, - db: Session = Depends(get_db), - current_user: User = Depends(get_current_user) -): - """ - Retrieve knowledge base sharing information based on knowledgeshare_id - """ - api_logger.info(f"Obtain details of the knowledge base sharing: knowledgeshare_id={knowledgeshare_id}, username: {current_user.username}") - - try: - # 1. Query knowledge base sharing information from the database - api_logger.debug(f"Query knowledge base sharing: {knowledgeshare_id}") - db_knowledgeshare = knowledgeshare_service.get_knowledgeshare_by_id(db, knowledgeshare_id=knowledgeshare_id, current_user=current_user) - if not db_knowledgeshare: - api_logger.warning(f"The knowledge base sharing does not exist or access is denied: knowledgeshare_id={knowledgeshare_id}") - raise HTTPException( - status_code=status.HTTP_404_NOT_FOUND, - detail="The knowledge base sharing does not exist or access is denied" - ) - - api_logger.info(f"Knowledge base sharing query successful: (ID: {db_knowledgeshare.id})") - return success(data=knowledgeshare_schema.KnowledgeShare.model_validate(db_knowledgeshare), msg="Successfully obtained knowledge base sharing information") - except HTTPException: - raise - except Exception as e: - api_logger.error(f"Knowledge base sharing query failed: knowledgeshare_id={knowledgeshare_id} - {str(e)}") - raise - - -@router.delete("/{knowledgeshare_id}", response_model=ApiResponse) -async def delete_knowledgeshare( - knowledgeshare_id: uuid.UUID, - db: Session = Depends(get_db), - current_user: User = Depends(get_current_user) -): - """ - Delete knowledge base sharing - """ - api_logger.info(f"Delete knowledge base sharing request: knowledgeshare_id={knowledgeshare_id}, username: {current_user.username}") - - try: - # 1. Query knowledge base sharing information from the database - api_logger.debug(f"Query knowledge base sharing: {knowledgeshare_id}") - db_knowledgeshare = knowledgeshare_service.get_knowledgeshare_by_id(db, knowledgeshare_id=knowledgeshare_id, current_user=current_user) - if not db_knowledgeshare: - api_logger.warning(f"The knowledge base sharing does not exist or access is denied: knowledgeshare_id={knowledgeshare_id}") - raise HTTPException( - status_code=status.HTTP_404_NOT_FOUND, - detail="The knowledge base sharing does not exist or access is denied" - ) - # 2. Deleting shared knowledge base - knowledge_service.delete_knowledge_by_id(db, knowledge_id=db_knowledgeshare.target_kb_id ,current_user=current_user) - # 3. Delete knowledge base sharing - api_logger.debug(f"perform knowledge base sharing delete: (ID: {knowledgeshare_id})") - - knowledgeshare_service.delete_knowledgeshare_by_id(db, knowledgeshare_id=knowledgeshare_id, current_user=current_user) - api_logger.info(f"The knowledge base sharing has been successfully deleted: (ID: {knowledgeshare_id})") - return success(msg="The knowledge base sharing has been successfully deleted") - except Exception as e: - api_logger.error(f"Failed to delete from the knowledge base sharing: knowledgeshare_id={knowledgeshare_id} - {str(e)}") - raise diff --git a/app/controllers/memory_agent_controller.py b/app/controllers/memory_agent_controller.py deleted file mode 100644 index 419de257..00000000 --- a/app/controllers/memory_agent_controller.py +++ /dev/null @@ -1,802 +0,0 @@ -import json -import time -from typing import Optional, List -from fastapi import APIRouter, Depends, Query, UploadFile -from sqlalchemy.orm import Session -from starlette.responses import StreamingResponse -from app.db import get_db -from app.core.memory.utils.config.config_utils import get_model_config -from app.core.rag.llm.cv_model import QWenCV -from app.models import ModelApiKey, Knowledge -from app.services.memory_agent_service import MemoryAgentService -from app.dependencies import get_current_superuser, get_current_user, get_current_tenant, workspace_access_guard, cur_workspace_access_guard -from app.celery_app import celery_app -from app.core.logging_config import get_api_logger -from app.core.response_utils import success, fail -from app.core.error_codes import BizCode -from app.services import task_service, workspace_service -from app.schemas.memory_agent_schema import UserInput, Write_UserInput -from app.schemas.response_schema import ApiResponse -from app.dependencies import get_current_user -from app.models.user_model import User -from fastapi import APIRouter, Depends, File, UploadFile, Form -from app.repositories import knowledge_repository -from app.services.model_service import ModelConfigService -from dotenv import load_dotenv -import os - -# 加载.env文件 -load_dotenv() -# Get API logger -api_logger = get_api_logger() - -# Initialize service -memory_agent_service = MemoryAgentService() - -router = APIRouter( - prefix="/memory", - tags=["Memory"], -) - - -def validate_config_id(config_id: int, db: Session) -> int: - """ - Validate and ensure config_id is available, valid, and exists in database. - - Args: - config_id: Configuration ID to validate - db: Database session for checking existence - - Returns: - int: Validated config_id - - Raises: - ValueError: If config_id is None, invalid, or doesn't exist in database - """ - if config_id is None: - api_logger.info(f"config_id is required but was not provided") - config_id = os.getenv('config_id') - if config_id is None: - raise ValueError("config_id is required but was not provided") - - - # Check if config exists in database - try: - from app.models.data_config_model import DataConfig - from app.models.models_model import ModelConfig - - config = db.query(DataConfig).filter(DataConfig.config_id == config_id).first() - if config is None: - error_msg = f"Configuration with config_id={config_id} does not exist in database" - api_logger.error(error_msg) - raise ValueError(error_msg) - - # Validate llm_id exists and is usable - if config.llm_id: - try: - llm_config = db.query(ModelConfig).filter(ModelConfig.id == config.llm_id).first() - if llm_config is None: - error_msg = f"LLM model with id={config.llm_id} (from config_id={config_id}) does not exist" - api_logger.error(error_msg) - raise ValueError(error_msg) - if not llm_config.is_active: - error_msg = f"LLM model with id={config.llm_id} (from config_id={config_id}) is not active" - api_logger.error(error_msg) - raise ValueError(error_msg) - api_logger.debug(f"LLM validation successful: llm_id={config.llm_id}, name={llm_config.name}") - except ValueError: - raise - except Exception as e: - error_msg = f"Error validating LLM model: {str(e)}" - api_logger.error(error_msg, exc_info=True) - raise ValueError(error_msg) - else: - api_logger.error(f"Config {config_id} has no llm_id set") - raise ValueError(f"Config {config_id} has no llm_id set") - - # Validate embedding_id exists and is usable - if config.embedding_id: - try: - embedding_config = db.query(ModelConfig).filter(ModelConfig.id == config.embedding_id).first() - if embedding_config is None: - error_msg = f"Embedding model with id={config.embedding_id} (from config_id={config_id}) does not exist" - api_logger.error(error_msg) - raise ValueError(error_msg) - if not embedding_config.is_active: - error_msg = f"Embedding model with id={config.embedding_id} (from config_id={config_id}) is not active" - api_logger.error(error_msg) - raise ValueError(error_msg) - api_logger.debug(f"Embedding validation successful: embedding_id={config.embedding_id}, name={embedding_config.name}") - except ValueError: - raise - except Exception as e: - error_msg = f"Error validating embedding model: {str(e)}" - api_logger.error(error_msg, exc_info=True) - raise ValueError(error_msg) - else: - api_logger.error(f"Config {config_id} has no embedding_id set") - raise ValueError(f"Config {config_id} has no embedding_id set") - - api_logger.info(f"Config validation successful: config_id={config_id}, config_name={config.config_name}, llm_id={config.llm_id}, embedding_id={config.embedding_id}") - return config_id - except ValueError: - # Re-raise ValueError from above - raise - except Exception as e: - error_msg = f"Database error while validating config_id={config_id}: {str(e)}" - api_logger.error(error_msg, exc_info=True) - raise ValueError(error_msg) - - -@router.get("/health/status", response_model=ApiResponse) -async def get_health_status( - current_user: User = Depends(get_current_user) -): - """ - Get latest health status written by Celery periodic task - - Returns health status information from Redis cache - """ - api_logger.info("Health status check requested") - try: - result = await memory_agent_service.get_health_status() - return success(data=result["status"]) - except Exception as e: - api_logger.error(f"Health status check failed: {str(e)}") - return fail(BizCode.SERVICE_UNAVAILABLE, "健康状态查询失败", str(e)) - - -@router.get("/download_log") -async def download_log( - log_type: str = Query("file", regex="^(file|transmission)$", description="日志类型: file=完整文件, transmission=实时流式传输"), - current_user: User = Depends(get_current_user) -): - """ - Download or stream agent service log file - - log_type: str = Query("file", regex="^(file|transmission)$", - description="日志类型: file=完整文件, transmission=实时流式传输"), - current_user: User = Depends(get_current_user) - - Args: - log_type: Log retrieval mode - - "file": Returns complete log file content in single response (default) - - "transmission": Real-time streaming of log content using Server-Sent Events - - Returns: - - file mode: ApiResponse with log content - - transmission mode: StreamingResponse with SSE - """ - api_logger.info(f"Log download requested with log_type={log_type}") - - # Validate log_type parameter (FastAPI Query regex already validates, but explicit check for clarity) - if log_type not in ["file", "transmission"]: - api_logger.warning(f"Invalid log_type parameter: {log_type}") - return fail( - BizCode.BAD_REQUEST, - "无效的log_type参数", - "log_type必须是'file'或'transmission'" - ) - - # Route to appropriate mode - if log_type == "file": - # File mode: Return complete log file content - try: - log_content = memory_agent_service.get_log_content() - return success(data=log_content) - except ValueError as e: - api_logger.warning(f"Log content issue: {str(e)}") - return fail(BizCode.FILE_NOT_FOUND, str(e)) - except Exception as e: - api_logger.error(f"Log reading failed: {str(e)}") - return fail(BizCode.FILE_READ_ERROR, "日志读取失败", str(e)) - - else: # log_type == "transmission" - # Transmission mode: Stream log content using SSE - try: - api_logger.info("Starting SSE log streaming") - return StreamingResponse( - memory_agent_service.stream_log_content(), - media_type="text/event-stream", - headers={ - "Cache-Control": "no-cache", - "Connection": "keep-alive", - "X-Accel-Buffering": "no" # Disable nginx buffering - } - ) - except Exception as e: - api_logger.error(f"Failed to start log streaming: {str(e)}") - return fail(BizCode.INTERNAL_ERROR, "启动日志流式传输失败", str(e)) - - -@router.post("/writer_service", response_model=ApiResponse) -@cur_workspace_access_guard() -async def write_server( - user_input: Write_UserInput, - db: Session = Depends(get_db), - current_user: User = Depends(get_current_user) -): - """ - Write service endpoint - processes write operations synchronously - - Args: - user_input: Write request containing message and group_id - - Returns: - Response with write operation status - """ - # Validate config_id - try: - config_id = validate_config_id(user_input.config_id, db) - except ValueError as e: - return fail(BizCode.INVALID_PARAMETER, "配置ID无效", str(e)) - - workspace_id = current_user.current_workspace_id - api_logger.info(f"Write service: workspace_id={workspace_id}, config_id={config_id}") - - # 获取 storage_type,如果为 None 则使用默认值 - storage_type = workspace_service.get_workspace_storage_type( - db=db, - workspace_id=workspace_id, - user=current_user - ) - if storage_type is None: storage_type = 'neo4j' - user_rag_memory_id = '' - - # 如果 storage_type 是 rag,必须确保有有效的 user_rag_memory_id - if storage_type == 'rag': - if workspace_id: - knowledge = knowledge_repository.get_knowledge_by_name( - db=db, - name="USER_RAG_MERORY", - workspace_id=workspace_id - ) - if knowledge: - user_rag_memory_id = str(knowledge.id) - else: - api_logger.warning(f"未找到名为 'USER_RAG_MERORY' 的知识库,workspace_id: {workspace_id},将使用 neo4j 存储") - storage_type = 'neo4j' - else: - api_logger.warning(f"workspace_id 为空,无法使用 rag 存储,将使用 neo4j 存储") - storage_type = 'neo4j' - - api_logger.info(f"Write service requested for group {user_input.group_id}, storage_type: {storage_type}, user_rag_memory_id: {user_rag_memory_id}") - try: - result = await memory_agent_service.write_memory( - user_input.group_id, - user_input.message, - config_id, - storage_type, - user_rag_memory_id - ) - return success(data=result, msg="写入成功") - except Exception as e: - api_logger.error(f"Write operation error: {str(e)}") - return fail(BizCode.INTERNAL_ERROR, "写入失败", str(e)) - - -@router.post("/writer_service_async", response_model=ApiResponse) -@cur_workspace_access_guard() -async def write_server_async( - user_input: Write_UserInput, - db: Session = Depends(get_db), - current_user: User = Depends(get_current_user) -): - """ - Async write service endpoint - enqueues write processing to Celery - - Args: - user_input: Write request containing message and group_id - - Returns: - Task ID for tracking async operation - Use GET /memory/write_result/{task_id} to check task status and get result - """ - # Validate config_id - try: - config_id = validate_config_id(user_input.config_id, db) - except ValueError as e: - return fail(BizCode.INVALID_PARAMETER, "配置ID无效", str(e)) - - workspace_id = current_user.current_workspace_id - api_logger.info(f"Async write service: workspace_id={workspace_id}, config_id={config_id}") - - # 获取 storage_type,如果为 None 则使用默认值 - storage_type = workspace_service.get_workspace_storage_type( - db=db, - workspace_id=workspace_id, - user=current_user - ) - if storage_type is None: storage_type = 'neo4j' - user_rag_memory_id = '' - if workspace_id: - - knowledge = knowledge_repository.get_knowledge_by_name( - db=db, - name="USER_RAG_MERORY", - workspace_id=workspace_id - ) - if knowledge: user_rag_memory_id = str(knowledge.id) - api_logger.info(f"Async write: storage_type={storage_type}, user_rag_memory_id={user_rag_memory_id}") - try: - task = celery_app.send_task( - "app.core.memory.agent.write_message", - args=[user_input.group_id, user_input.message, config_id, storage_type, user_rag_memory_id] - ) - api_logger.info(f"Write task queued: {task.id}") - - return success(data={"task_id": task.id}, msg="写入任务已提交") - except Exception as e: - api_logger.error(f"Async write operation failed: {str(e)}") - return fail(BizCode.INTERNAL_ERROR, "写入失败", str(e)) - - -@router.post("/read_service", response_model=ApiResponse) -@cur_workspace_access_guard() -async def read_server( - user_input: UserInput, - db: Session = Depends(get_db), - current_user: User = Depends(get_current_user) -): - """ - Read service endpoint - processes read operations synchronously - - search_switch values: - - "0": Requires verification - - "1": No verification, direct split - - "2": Direct answer based on context - - Args: - user_input: Read request with message, history, search_switch, and group_id - - Returns: - Response with query answer - """ - # Validate config_id - try: - config_id = validate_config_id(user_input.config_id, db) - except ValueError as e: - return fail(BizCode.INVALID_PARAMETER, "配置ID无效", str(e)) - - workspace_id = current_user.current_workspace_id - api_logger.info(f"Read service: workspace_id={workspace_id}, config_id={config_id}") - - # 获取 storage_type,如果为 None 则使用默认值 - storage_type = workspace_service.get_workspace_storage_type( - db=db, - workspace_id=workspace_id, - user=current_user - ) - if storage_type is None: storage_type = 'neo4j' - user_rag_memory_id = '' - if workspace_id: - - knowledge = knowledge_repository.get_knowledge_by_name( - db=db, - name="USER_RAG_MERORY", - workspace_id=workspace_id - ) - if knowledge: user_rag_memory_id = str(knowledge.id) - - api_logger.info(f"Read service: group={user_input.group_id}, storage_type={storage_type}, user_rag_memory_id={user_rag_memory_id}, workspace_id={workspace_id}") - try: - result = await memory_agent_service.read_memory( - user_input.group_id, - user_input.message, - user_input.history, - user_input.search_switch, - config_id, - storage_type, - user_rag_memory_id - ) - return success(data=result, msg="回复对话消息成功") - except Exception as e: - api_logger.error(f"Read operation error: {str(e)}") - return fail(BizCode.INTERNAL_ERROR, "回复对话消息失败", str(e)) - - -@router.post("/file", response_model=ApiResponse) -async def file_update( - files: List[UploadFile] = File(..., description="要上传的文件"), - model_id:str = Form(..., description="模型ID"), - metadata: Optional[str] = Form(None, description="文件元数据 (JSON格式)"), - current_user: User = Depends(get_current_user) -): - """ - 文件上传接口 - 支持图片识别 - - Args: - files: 上传的文件列表 - metadata: 文件元数据(可选) - current_user: 当前用户 - - Returns: - 文件处理结果 - """ - - db_gen = get_db() # get_db 通常是一个生成器 - db = next(db_gen) - api_logger.info(f"File upload requested, file count: {len(files)}") - config = ModelConfigService.get_model_by_id(db=db, model_id=model_id) - apiConfig: ModelApiKey = config.api_keys[0] - file_content = [] - try: - for file in files: - api_logger.debug(f"Processing file: {file.filename}, content_type: {file.content_type}") - content = await file.read() - - if file.content_type and file.content_type.startswith("image/"): - vision_model = QWenCV( - key=apiConfig.api_key, - model_name=apiConfig.model_name, - lang="Chinese", - base_url=apiConfig.api_base - ) - description, token_count = vision_model.describe(content) - file_content.append(description) - api_logger.info(f"Image processed: {file.filename}, tokens: {token_count}") - else: - api_logger.warning(f"Unsupported file type: {file.content_type}") - file_content.append(f"[不支持的文件类型: {file.content_type}]") - - result_text = ';'.join(file_content) - api_logger.info(f"File processing completed, result length: {len(result_text)}") - - return success(data=result_text, msg="转换文本成功") - - except Exception as e: - api_logger.error(f"File processing failed: {str(e)}", exc_info=True) - return fail(BizCode.INTERNAL_ERROR, "转换文本失败", str(e)) - - -@router.post("/read_service_async", response_model=ApiResponse) -@cur_workspace_access_guard() -async def read_server_async( - user_input: UserInput, - db: Session = Depends(get_db), - current_user: User = Depends(get_current_user) -): - # Validate config_id - try: - config_id = validate_config_id(user_input.config_id, db) - except ValueError as e: - return fail(BizCode.INVALID_PARAMETER, "配置ID无效", str(e)) - - workspace_id = current_user.current_workspace_id - api_logger.info(f"Async read service: workspace_id={workspace_id}, config_id={config_id}") - - # 获取 storage_type,如果为 None 则使用默认值 - storage_type = workspace_service.get_workspace_storage_type( - db=db, - workspace_id=workspace_id, - user=current_user - ) - if storage_type is None: storage_type = 'neo4j' - user_rag_memory_id = '' - if workspace_id: - - knowledge = knowledge_repository.get_knowledge_by_name( - db=db, - name="USER_RAG_MERORY", - workspace_id=workspace_id - ) - if knowledge: user_rag_memory_id = str(knowledge.id) - api_logger.info(f"Async read: storage_type={storage_type}, user_rag_memory_id={user_rag_memory_id}") - try: - task = celery_app.send_task( - "app.core.memory.agent.read_message", - args=[user_input.group_id, user_input.message, user_input.history, user_input.search_switch, - config_id, storage_type, user_rag_memory_id] - ) - api_logger.info(f"Read task queued: {task.id}") - - return success(data={"task_id": task.id}, msg="查询任务已提交") - except Exception as e: - api_logger.error(f"Async read operation failed: {str(e)}") - return fail(BizCode.INTERNAL_ERROR, "回复对话消息失败", str(e)) - - -@router.get("/read_result/", response_model=ApiResponse) -async def get_read_task_result( - task_id: str, - current_user: User = Depends(get_current_user) -): - """ - Get the status and result of an async read task - - Args: - task_id: Celery task ID returned from /read_service_async - - Returns: - Task status and result if completed - - Response format: - - PENDING: Task is waiting to be executed - - STARTED: Task has started - - SUCCESS: Task completed successfully, returns result - - FAILURE: Task failed, returns error message - """ - api_logger.info(f"Read task status check requested for task {task_id}") - try: - result = task_service.get_task_memory_read_result(task_id) - status = result.get("status") - - if status == "SUCCESS": - # 任务成功完成 - task_result = result.get("result", {}) - if isinstance(task_result, dict): - # 新格式:包含详细信息 - return success( - data={ - "result": task_result.get("result"), - "group_id": task_result.get("group_id"), - "elapsed_time": task_result.get("elapsed_time"), - "task_id": task_id - }, - msg="查询任务已完成" - ) - else: - # 旧格式:直接返回结果 - return success(data=task_result, msg="查询任务已完成") - - elif status == "FAILURE": - # 任务失败 - error_info = result.get("result", "Unknown error") - if isinstance(error_info, dict): - error_msg = error_info.get("error", str(error_info)) - else: - error_msg = str(error_info) - return fail(BizCode.INTERNAL_ERROR, "查询任务失败", error_msg) - - elif status in ["PENDING", "STARTED"]: - # 任务进行中 - return success( - data={ - "status": status, - "task_id": task_id, - "message": "任务处理中,请稍后查询" - }, - msg="查询任务处理中" - ) - else: - # 未知状态 - return success( - data={ - "status": status, - "task_id": task_id - }, - msg=f"任务状态: {status}" - ) - - except Exception as e: - api_logger.error(f"Read task status check failed: {str(e)}", exc_info=True) - return fail(BizCode.INTERNAL_ERROR, "任务状态查询失败", str(e)) - - -@router.get("/write_result/", response_model=ApiResponse) -async def get_write_task_result( - task_id: str, - current_user: User = Depends(get_current_user) -): - """ - Get the status and result of an async write task - - Args: - task_id: Celery task ID returned from /writer_service_async - - Returns: - Task status and result if completed - - Response format: - - PENDING: Task is waiting to be executed - - STARTED: Task has started - - SUCCESS: Task completed successfully, returns result - - FAILURE: Task failed, returns error message - """ - api_logger.info(f"Write task status check requested for task {task_id}") - try: - result = task_service.get_task_memory_write_result(task_id) - status = result.get("status") - - if status == "SUCCESS": - # 任务成功完成 - task_result = result.get("result", {}) - if isinstance(task_result, dict): - # 新格式:包含详细信息 - return success( - data={ - "result": task_result.get("result"), - "group_id": task_result.get("group_id"), - "elapsed_time": task_result.get("elapsed_time"), - "task_id": task_id - }, - msg="写入任务已完成" - ) - else: - # 旧格式:直接返回结果 - return success(data=task_result, msg="写入任务已完成") - - elif status == "FAILURE": - # 任务失败 - error_info = result.get("result", "Unknown error") - if isinstance(error_info, dict): - error_msg = error_info.get("error", str(error_info)) - else: - error_msg = str(error_info) - return fail(BizCode.INTERNAL_ERROR, "写入任务失败", error_msg) - - elif status in ["PENDING", "STARTED"]: - # 任务进行中 - return success( - data={ - "status": status, - "task_id": task_id, - "message": "任务处理中,请稍后查询" - }, - msg="写入任务处理中" - ) - else: - # 未知状态 - return success( - data={ - "status": status, - "task_id": task_id - }, - msg=f"任务状态: {status}" - ) - - except Exception as e: - api_logger.error(f"Write task status check failed: {str(e)}", exc_info=True) - return fail(BizCode.INTERNAL_ERROR, "任务状态查询失败", str(e)) - - -@router.post("/status_type", response_model=ApiResponse) -async def status_type( - user_input: Write_UserInput, - current_user: User = Depends(get_current_user) -): - """ - Determine the type of user message (read or write) - - Args: - user_input: Request containing user message and group_id - - Returns: - Type classification result - """ - api_logger.info(f"Status type check requested for group {user_input.group_id}") - try: - result = await memory_agent_service.classify_message_type(user_input.message) - return success(data=result) - except Exception as e: - api_logger.error(f"Message type classification failed: {str(e)}") - return fail(BizCode.INTERNAL_ERROR, "类型判断失败", str(e)) - - -# ==================== 新增的三个接口路由 ==================== - -@router.get("/stats/types", response_model=ApiResponse) -async def get_knowledge_type_stats_api( - end_user_id: Optional[str] = Query(None, description="用户ID(可选)"), - only_active: bool = Query(True, description="仅统计有效记录(status=1)"), - current_user: User = Depends(get_current_user) -): - """ - 统计当前空间下各知识库类型的数量,包含 General | Web | Third-party | Folder | memory。 - 会对缺失类型补 0,返回字典形式。 - 可选按状态过滤。 - - 知识库类型根据当前用户的 current_workspace_id 过滤 - - memory 是 Neo4j 中 Chunk 的数量,根据 end_user_id (group_id) 过滤 - - 如果用户没有当前工作空间或未提供 end_user_id,对应的统计返回 0 - """ - api_logger.info(f"Knowledge type stats requested for workspace_id: {current_user.current_workspace_id}, end_user_id: {end_user_id}") - try: - from app.db import get_db - - # 获取数据库会话 - db_gen = get_db() - db = next(db_gen) - - # 调用service层函数 - result = await memory_agent_service.get_knowledge_type_stats( - end_user_id=end_user_id, - only_active=only_active, - current_workspace_id=current_user.current_workspace_id, - db=db - ) - - return success(data=result, msg="获取知识库类型统计成功") - except Exception as e: - api_logger.error(f"Knowledge type stats failed: {str(e)}") - return fail(BizCode.INTERNAL_ERROR, "获取知识库类型统计失败", str(e)) - - -@router.get("/analytics/hot_memory_tags/by_user", response_model=ApiResponse) -async def get_hot_memory_tags_by_user_api( - end_user_id: Optional[str] = Query(None, description="用户ID(可选)"), - limit: int = Query(20, description="返回标签数量限制"), - current_user: User = Depends(get_current_user) -): - """ - 获取指定用户的热门记忆标签 - - 返回格式: - [ - {"name": "标签名", "frequency": 频次}, - ... - ] - """ - api_logger.info(f"Hot memory tags by user requested: end_user_id={end_user_id}") - try: - result = await memory_agent_service.get_hot_memory_tags_by_user( - end_user_id=end_user_id, - limit=limit - ) - return success(data=result, msg="获取热门记忆标签成功") - except Exception as e: - api_logger.error(f"Hot memory tags by user failed: {str(e)}") - return fail(BizCode.INTERNAL_ERROR, "获取热门记忆标签失败", str(e)) - - -@router.get("/analytics/user_profile", response_model=ApiResponse) -async def get_user_profile_api( - end_user_id: Optional[str] = Query(None, description="用户ID(可选)"), - current_user: User = Depends(get_current_user) -): - """ - 获取用户详情,包含: - - name: 用户名字(直接使用 end_user_id) - - tags: 3个用户特征标签(从语句和实体中LLM总结) - - hot_tags: 4个热门记忆标签 - - 返回格式: - { - "name": "用户名", - "tags": ["产品设计师", "旅行爱好者", "摄影发烧友"], - "hot_tags": [ - {"name": "标签1", "frequency": 10}, - {"name": "标签2", "frequency": 8}, - ... - ] - } - """ - api_logger.info(f"User profile requested: end_user_id={end_user_id}, current_user={current_user.id}") - try: - result = await memory_agent_service.get_user_profile( - end_user_id=end_user_id, - current_user_id=str(current_user.id) - ) - return success(data=result, msg="获取用户详情成功") - except Exception as e: - api_logger.error(f"User profile failed: {str(e)}") - return fail(BizCode.INTERNAL_ERROR, "获取用户详情失败", str(e)) - - -# @router.get("/docs/api", response_model=ApiResponse) -# async def get_api_docs_api( -# file_path: Optional[str] = Query(None, description="API文档文件路径,不传则使用默认路径") -# ): -# """ -# Get parsed API documentation (Public endpoint - no authentication required) - -# Args: -# file_path: Optional path to API docs file. If None, uses default path. - -# Returns: -# Parsed API documentation including title, meta info, and sections -# """ -# api_logger.info(f"API docs requested, file_path: {file_path or 'default'}") -# try: -# result = await memory_agent_service.get_api_docs(file_path) - -# if result.get("success"): -# return success(msg=result["msg"], data=result["data"]) -# else: -# return fail( -# code=BizCode.BAD_REQUEST, -# msg=result["msg"], -# error=result.get("data", {}).get("error", result.get("error_code", "")) -# ) -# except Exception as e: -# api_logger.error(f"API docs retrieval failed: {str(e)}") -# return fail(BizCode.INTERNAL_ERROR, "API文档获取失败", str(e)) \ No newline at end of file diff --git a/app/controllers/memory_dashboard_controller.py b/app/controllers/memory_dashboard_controller.py deleted file mode 100644 index e915a7a8..00000000 --- a/app/controllers/memory_dashboard_controller.py +++ /dev/null @@ -1,516 +0,0 @@ -from fastapi import APIRouter, Depends, HTTPException, status, Query -from sqlalchemy.orm import Session -from typing import List, Optional -import uuid - -from app.core.response_utils import success -from app.db import get_db -from app.dependencies import get_current_user -from app.models.user_model import User -from app.schemas.response_schema import ApiResponse -from app.schemas.app_schema import App as AppSchema - -from app.services import memory_dashboard_service, memory_storage_service, workspace_service -from app.core.logging_config import get_api_logger - -# 获取API专用日志器 -api_logger = get_api_logger() - -router = APIRouter( - prefix="/dashboard", - tags=["Dashboard"], - dependencies=[Depends(get_current_user)] # Apply auth to all routes in this controller -) - - -@router.get("/total_end_users", response_model=ApiResponse) -def get_workspace_total_end_users( - db: Session = Depends(get_db), - current_user: User = Depends(get_current_user), -): - """ - 获取用户列表的总用户数 - """ - workspace_id = current_user.current_workspace_id - api_logger.info(f"用户 {current_user.username} 请求获取工作空间 {workspace_id} 的宿主列表") - total_end_users = memory_dashboard_service.get_workspace_total_end_users( - db=db, - workspace_id=workspace_id, - current_user=current_user - ) - api_logger.info(f"成功获取最新用户总数: total_num={total_end_users.get('total_num', 0)}") - return success(data=total_end_users, msg="用户数量获取成功") - - -@router.get("/end_users", response_model=ApiResponse) -async def get_workspace_end_users( - db: Session = Depends(get_db), - current_user: User = Depends(get_current_user), -): - """ - 获取工作空间的宿主列表 - - 返回格式与原 memory_list 接口中的 end_users 字段相同 - """ - workspace_id = current_user.current_workspace_id - api_logger.info(f"用户 {current_user.username} 请求获取工作空间 {workspace_id} 的宿主列表") - end_users = memory_dashboard_service.get_workspace_end_users( - db=db, - workspace_id=workspace_id, - current_user=current_user - ) - result = [] - for end_user in end_users: - # EndUser 是 Pydantic 模型,直接访问属性而不是使用 .get() - memory_num = await memory_storage_service.search_all(str(end_user.id)) - result.append( - { - 'end_user':end_user, - 'memory_num':memory_num - } - ) - api_logger.info(f"成功获取 {len(end_users)} 个宿主记录") - return success(data=result, msg="宿主列表获取成功") - - -@router.get("/memory_increment", response_model=ApiResponse) -def get_workspace_memory_increment( - limit: int = Query(7, description="返回记录数"), - db: Session = Depends(get_db), - current_user: User = Depends(get_current_user), -): - """获取工作空间的记忆增量""" - workspace_id = current_user.current_workspace_id - api_logger.info(f"用户 {current_user.username} 请求获取工作空间 {workspace_id} 的记忆增量") - memory_increment = memory_dashboard_service.get_workspace_memory_increment( - db=db, - workspace_id=workspace_id, - current_user=current_user, - limit=limit - ) - api_logger.info(f"成功获取 {len(memory_increment)} 条记忆增量记录") - return success(data=memory_increment, msg="记忆增量获取成功") - - -@router.get("/api_increment", response_model=ApiResponse) -def get_workspace_api_increment( - db: Session = Depends(get_db), - current_user: User = Depends(get_current_user), -): - """获取API调用趋势""" - workspace_id = current_user.current_workspace_id - api_logger.info(f"用户 {current_user.username} 请求获取工作空间 {workspace_id} 的API调用增量") - api_increment = memory_dashboard_service.get_workspace_api_increment( - db=db, - workspace_id=workspace_id, - current_user=current_user - ) - api_logger.info(f"成功获取 {api_increment} API调用增量") - return success(data=api_increment, msg="API调用增量获取成功") - - -@router.post("/total_memory", response_model=ApiResponse) -def write_workspace_total_memory( - db: Session = Depends(get_db), - current_user: User = Depends(get_current_user), -): - """工作空间记忆总量的写入(异步任务)""" - workspace_id = current_user.current_workspace_id - api_logger.info(f"用户 {current_user.username} 请求写入工作空间 {workspace_id} 的记忆总量") - - # 触发 Celery 异步任务 - from app.celery_app import celery_app - task = celery_app.send_task( - "app.controllers.memory_storage_controller.search_all", - kwargs={"workspace_id": str(workspace_id)} - ) - - api_logger.info(f"已触发记忆总量统计任务,task_id: {task.id}") - return success( - data={"task_id": task.id, "workspace_id": str(workspace_id)}, - msg="记忆总量统计任务已启动" - ) - - -@router.get("/task_status/{task_id}", response_model=ApiResponse) -def get_task_status( - task_id: str, - current_user: User = Depends(get_current_user), -): - """查询异步任务的执行状态和结果""" - api_logger.info(f"用户 {current_user.username} 查询任务状态: task_id={task_id}") - - from app.celery_app import celery_app - from celery.result import AsyncResult - - # 获取任务结果 - task_result = AsyncResult(task_id, app=celery_app) - - response_data = { - "task_id": task_id, - "status": task_result.state, # PENDING, STARTED, SUCCESS, FAILURE, RETRY, REVOKED - } - - # 如果任务完成,返回结果 - if task_result.ready(): - if task_result.successful(): - response_data["result"] = task_result.result - api_logger.info(f"任务 {task_id} 执行成功") - return success(data=response_data, msg="任务执行成功") - else: - # 任务失败 - response_data["error"] = str(task_result.result) - api_logger.error(f"任务 {task_id} 执行失败: {task_result.result}") - return success(data=response_data, msg="任务执行失败") - else: - # 任务还在执行中 - api_logger.info(f"任务 {task_id} 状态: {task_result.state}") - return success(data=response_data, msg=f"任务状态: {task_result.state}") - - -@router.get("/memory_list", response_model=ApiResponse) -def get_workspace_memory_list( - limit: int = Query(7, description="记忆增量返回记录数"), - db: Session = Depends(get_db), - current_user: User = Depends(get_current_user), -): - """ - 用户记忆列表整合接口 - - 整合以下三个接口的数据: - 1. total_memory - 工作空间记忆总量 - 2. memory_increment - 工作空间记忆增量 - 3. hosts - 工作空间宿主列表 - - 返回格式: - { - "total_memory": float, - "memory_increment": [ - {"date": "2024-01-01", "count": 100}, - ... - ], - "hosts": [ - {"id": "uuid", "name": "宿主名", ...}, - ... - ] - } - """ - workspace_id = current_user.current_workspace_id - api_logger.info(f"用户 {current_user.username} 请求获取工作空间 {workspace_id} 的记忆列表") - memory_list = memory_dashboard_service.get_workspace_memory_list( - db=db, - workspace_id=workspace_id, - current_user=current_user, - limit=limit - ) - api_logger.info(f"成功获取记忆列表") - return success(data=memory_list, msg="记忆列表获取成功") - - -@router.get("/total_memory_count", response_model=ApiResponse) -async def get_workspace_total_memory_count( - end_user_id: Optional[str] = Query(None, description="可选的用户ID"), - db: Session = Depends(get_db), - current_user: User = Depends(get_current_user), -): - """ - 获取工作空间的记忆总量(通过聚合所有host的记忆数) - - 逻辑: - 1. 从 memory_list 获取所有 host_id - 2. 对每个 host_id 调用 search_all 获取 total - 3. 将所有 total 求和返回 - - 返回格式: - { - "total_memory_count": int, - "host_count": int, - "details": [ - {"host_id": "uuid", "count": 100}, - ... - ] - } - """ - workspace_id = current_user.current_workspace_id - api_logger.info(f"用户 {current_user.username} 请求获取工作空间 {workspace_id} 的记忆总量") - total_memory_count = await memory_dashboard_service.get_workspace_total_memory_count( - db=db, - workspace_id=workspace_id, - current_user=current_user, - end_user_id=end_user_id - ) - api_logger.info(f"成功获取记忆总量: {total_memory_count.get('total_memory_count', 0)}") - return success(data=total_memory_count, msg="记忆总量获取成功") - - -# ======== RAG 数据统计 ======== -@router.get("/total_rag_count", response_model=ApiResponse) -def get_workspace_total_rag_count( - db: Session = Depends(get_db), - current_user: User = Depends(get_current_user) -): - """ - 获取 rag 的总文档数、总chunk数、总知识库数量、总api调用数量 - """ - total_documents = memory_dashboard_service.get_rag_total_doc(db, current_user) - total_chunk = memory_dashboard_service.get_rag_total_chunk(db, current_user) - total_kb = memory_dashboard_service.get_rag_total_kb(db, current_user) - data = { - 'total_documents':total_documents, - 'total_chunk':total_chunk, - 'total_kb':total_kb, - 'total_api':1024 - } - return success(data=data, msg="RAG相关数据获取成功") - -@router.get("/current_user_rag_total_num", response_model=ApiResponse) -def get_current_user_rag_total_num( - end_user_id: str = Query(..., description="宿主ID"), - db: Session = Depends(get_db), - current_user: User = Depends(get_current_user), -): - """ - 获取当前宿主的 RAG 的总chunk数量 - """ - total_chunk = memory_dashboard_service.get_current_user_total_chunk(end_user_id, db, current_user) - return success(data=total_chunk, msg="宿主RAG知识数据获取成功") - -@router.get("/rag_content", response_model=ApiResponse) -def get_rag_content( - end_user_id: str = Query(..., description="宿主ID"), - limit: int = Query(15, description="返回记录数"), - db: Session = Depends(get_db), - current_user: User = Depends(get_current_user), -): - """ - 获取当前宿主知识库中的chunk内容 - """ - data = memory_dashboard_service.get_rag_content(end_user_id, limit, db, current_user) - return success(data=data, msg="宿主RAGchunk数据获取成功") - - -@router.get("/chunk_summary_tag", response_model=ApiResponse) -async def get_chunk_summary_tag( - end_user_id: str = Query(..., description="宿主ID"), - limit: int = Query(15, description="返回记录数"), - max_tags: int = Query(10, description="最大标签数量"), - db: Session = Depends(get_db), - current_user: User = Depends(get_current_user), -): - """ - 获取chunk总结、提取的标签和人物形象 - - 返回格式: - { - "summary": "chunk内容的总结", - "tags": [ - {"tag": "标签1", "frequency": 5}, - {"tag": "标签2", "frequency": 3}, - ... - ], - "personas": [ - "产品设计师", - "旅行爱好者", - "摄影发烧友", - ... - ] - } - """ - api_logger.info(f"用户 {current_user.username} 请求获取宿主 {end_user_id} 的chunk摘要、标签和人物形象") - - data = await memory_dashboard_service.get_chunk_summary_and_tags( - end_user_id=end_user_id, - limit=limit, - max_tags=max_tags, - db=db, - current_user=current_user - ) - - api_logger.info(f"成功获取chunk摘要、{len(data.get('tags', []))} 个标签和 {len(data.get('personas', []))} 个人物形象") - return success(data=data, msg="chunk摘要、标签和人物形象获取成功") - - -@router.get("/chunk_insight", response_model=ApiResponse) -async def get_chunk_insight( - end_user_id: str = Query(..., description="宿主ID"), - limit: int = Query(15, description="返回记录数"), - db: Session = Depends(get_db), - current_user: User = Depends(get_current_user), -): - """ - 获取chunk的洞察内容 - - 返回格式: - { - "insight": "对chunk内容的深度洞察分析" - } - """ - api_logger.info(f"用户 {current_user.username} 请求获取宿主 {end_user_id} 的chunk洞察") - - data = await memory_dashboard_service.get_chunk_insight( - end_user_id=end_user_id, - limit=limit, - db=db, - current_user=current_user - ) - - api_logger.info(f"成功获取chunk洞察") - return success(data=data, msg="chunk洞察获取成功") - - -@router.get("/dashboard_data", response_model=ApiResponse) -async def dashboard_data( - end_user_id: Optional[str] = Query(None, description="可选的用户ID"), - db: Session = Depends(get_db), - current_user: User = Depends(get_current_user), -): - """ - 整合dashboard数据接口 - - 整合以下接口的数据: - 1. /dashboard/total_memory_count - 记忆总量 - 2. /dashboard/api_increment - API调用增量 - 3. /memory/stats/types - 知识库类型统计(只要total数据) - 4. /dashboard/total_rag_count - RAG相关数据 - - 根据 storage_type 判断调用不同的接口 - - 返回格式: - { - "storage_type": str, - "neo4j_data": { - "total_memory": int, - "total_app": int, - "total_knowledge": int, - "total_api_call": int - } | null, - "rag_data": { - "total_memory": int, - "total_app": int, - "total_knowledge": int, - "total_api_call": int - } | null - } - """ - workspace_id = current_user.current_workspace_id - api_logger.info(f"用户 {current_user.username} 请求获取工作空间 {workspace_id} 的dashboard整合数据") - - # 获取 storage_type,如果为 None 则使用默认值 - storage_type = workspace_service.get_workspace_storage_type( - db=db, - workspace_id=workspace_id, - user=current_user - ) - if storage_type is None: - storage_type = 'neo4j' - - user_rag_memory_id = None - - # 根据 storage_type 决定返回哪个数据对象 - # 如果是 'rag',neo4j_data 为 null;否则 rag_data 为 null - result = { - "storage_type": storage_type, - "neo4j_data": None, - "rag_data": None - } - - try: - # 如果 storage_type 为 'neo4j' 或空,获取 neo4j_data - if storage_type == 'neo4j': - neo4j_data = { - "total_memory": None, - "total_app": None, - "total_knowledge": None, - "total_api_call": None - } - - # 1. 获取记忆总量(total_memory) - try: - total_memory_data = await memory_dashboard_service.get_workspace_total_memory_count( - db=db, - workspace_id=workspace_id, - current_user=current_user, - end_user_id=end_user_id - ) - neo4j_data["total_memory"] = total_memory_data.get("total_memory_count", 0) - # total_app: 统计当前空间下的所有app数量 - from app.repositories import app_repository - apps_orm = app_repository.get_apps_by_workspace_id(db, workspace_id) - neo4j_data["total_app"] = len(apps_orm) - api_logger.info(f"成功获取记忆总量: {neo4j_data['total_memory']}, 应用数量: {neo4j_data['total_app']}") - except Exception as e: - api_logger.warning(f"获取记忆总量失败: {str(e)}") - - # 2. 获取知识库类型统计(total_knowledge) - try: - from app.services.memory_agent_service import MemoryAgentService - memory_agent_service = MemoryAgentService() - knowledge_stats = await memory_agent_service.get_knowledge_type_stats( - end_user_id=end_user_id, - only_active=True, - current_workspace_id=workspace_id, - db=db - ) - neo4j_data["total_knowledge"] = knowledge_stats.get("total", 0) - api_logger.info(f"成功获取知识库类型统计total: {neo4j_data['total_knowledge']}") - except Exception as e: - api_logger.warning(f"获取知识库类型统计失败: {str(e)}") - - # 3. 获取API调用增量(total_api_call,转换为整数) - try: - api_increment = memory_dashboard_service.get_workspace_api_increment( - db=db, - workspace_id=workspace_id, - current_user=current_user - ) - neo4j_data["total_api_call"] = api_increment - api_logger.info(f"成功获取API调用增量: {neo4j_data['total_api_call']}") - except Exception as e: - api_logger.warning(f"获取API调用增量失败: {str(e)}") - - result["neo4j_data"] = neo4j_data - api_logger.info(f"成功获取neo4j_data") - - # 如果 storage_type 为 'rag',获取 rag_data - elif storage_type == 'rag': - rag_data = { - "total_memory": None, - "total_app": None, - "total_knowledge": None, - "total_api_call": None - } - - # 获取RAG相关数据 - try: - # total_memory: 使用 total_chunk(总chunk数) - total_chunk = memory_dashboard_service.get_rag_total_chunk(db, current_user) - rag_data["total_memory"] = total_chunk - - # total_app: 统计当前空间下的所有app数量 - from app.repositories import app_repository - apps_orm = app_repository.get_apps_by_workspace_id(db, workspace_id) - rag_data["total_app"] = len(apps_orm) - - # total_knowledge: 使用 total_kb(总知识库数) - total_kb = memory_dashboard_service.get_rag_total_kb(db, current_user) - rag_data["total_knowledge"] = total_kb - - # total_api_call: 固定值 - rag_data["total_api_call"] = 1024 - - api_logger.info(f"成功获取RAG相关数据: memory={total_chunk}, app={len(apps_orm)}, knowledge={total_kb}") - except Exception as e: - api_logger.warning(f"获取RAG相关数据失败: {str(e)}") - - result["rag_data"] = rag_data - api_logger.info(f"成功获取rag_data") - - api_logger.info(f"成功获取dashboard整合数据") - return success(data=result, msg="Dashboard数据获取成功") - - except Exception as e: - api_logger.error(f"获取dashboard整合数据失败: {str(e)}") - raise HTTPException( - status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, - detail=f"获取dashboard整合数据失败: {str(e)}" - ) \ No newline at end of file diff --git a/app/controllers/memory_storage_controller.py b/app/controllers/memory_storage_controller.py deleted file mode 100644 index 6d1a901b..00000000 --- a/app/controllers/memory_storage_controller.py +++ /dev/null @@ -1,542 +0,0 @@ -from typing import Optional -import os -import uuid -from fastapi import APIRouter, Depends - -from app.core.logging_config import get_api_logger -from app.core.response_utils import success, fail -from app.core.error_codes import BizCode -from app.services.memory_storage_service import ( - MemoryStorageService, - DataConfigService, - kb_type_distribution, - search_dialogue, - search_chunk, - search_statement, - search_entity, - search_all, - search_detials, - search_edges, - search_entity_graph, - analytics_hot_memory_tags, - analytics_memory_insight_report, - analytics_recent_activity_stats, - analytics_user_summary, -) -from app.schemas.response_schema import ApiResponse -from app.schemas.memory_storage_schema import ( - ConfigParamsCreate, - ConfigParamsDelete, - ConfigUpdate, - ConfigUpdateExtracted, - ConfigUpdateForget, - ConfigKey, - ConfigPilotRun, -) -from app.core.memory.utils.config.definitions import reload_configuration_from_database -from app.dependencies import get_current_user -from app.models.user_model import User -# Get API logger -api_logger = get_api_logger() - -# Initialize service -memory_storage_service = MemoryStorageService() - -router = APIRouter( - prefix="/memory-storage", - tags=["Memory Storage"], -) - - -@router.get("/info", response_model=ApiResponse) -async def get_storage_info( - storage_id: str, - current_user: User = Depends(get_current_user) -): - """ - Example wrapper endpoint - retrieves storage information - - Args: - storage_id: Storage identifier - - Returns: - Storage information - """ - api_logger.info(f"Storage info requested ") - try: - result = await memory_storage_service.get_storage_info() - return success(data=result) - except Exception as e: - api_logger.error(f"Storage info retrieval failed: {str(e)}") - return fail(BizCode.INTERNAL_ERROR, "存储信息获取失败", str(e)) - - -# --- DB connection dependency --- -_CONN: Optional[object] = None - - -"""PostgreSQL 连接生成与管理(使用 psycopg2)。""" -# 这个可以转移,可能是已经有的 -# PostgreSQL 数据库连接 -def _make_pgsql_conn() -> Optional[object]: # 创建 PostgreSQL 数据库连接 - host = os.getenv("DB_HOST") - user = os.getenv("DB_USER") - password = os.getenv("DB_PASSWORD") - database = os.getenv("DB_NAME") - port_str = os.getenv("DB_PORT") - try: - import psycopg2 # type: ignore - port = int(port_str) if port_str else 5432 - conn = psycopg2.connect( - host=host or "localhost", - port=port, - user=user, - password=password, - dbname=database, - ) - # 设置自动提交,避免显式事务管理 - conn.autocommit = True - # 设置会话时区为中国标准时间(Asia/Shanghai),便于直接以本地时区展示 - try: - cur = conn.cursor() - cur.execute("SET TIME ZONE 'Asia/Shanghai'") - cur.close() - except Exception: - # 时区设置失败不影响连接,仅记录但不抛出 - pass - return conn - except Exception as e: - try: - print(f"[PostgreSQL] 连接失败: {e}") - except Exception: - pass - return None - -def get_db_conn() -> Optional[object]: # 获取 PostgreSQL 数据库连接 - global _CONN - if _CONN is None: - _CONN = _make_pgsql_conn() - return _CONN - - -def reset_db_conn() -> bool: # 重置 PostgreSQL 数据库连接 - """Close and recreate the global DB connection.""" - global _CONN - try: - if _CONN: - try: - _CONN.close() - except Exception: - pass - _CONN = _make_pgsql_conn() - return _CONN is not None - except Exception: - _CONN = None - return False - - -@router.post("/create_config", response_model=ApiResponse) # 创建配置文件,其他参数默认 -def create_config( - payload: ConfigParamsCreate, - current_user: User = Depends(get_current_user), - ) -> dict: - workspace_id = current_user.current_workspace_id - - # 检查用户是否已选择工作空间 - if workspace_id is None: - api_logger.warning(f"用户 {current_user.username} 尝试创建配置但未选择工作空间") - return fail(BizCode.INVALID_PARAMETER, "请先切换到一个工作空间", "current_workspace_id is None") - - api_logger.info(f"用户 {current_user.username} 在工作空间 {workspace_id} 请求创建配置: {payload.config_name}") - try: - # 将 workspace_id 注入到 payload 中(保持为 UUID 类型) - payload.workspace_id = workspace_id - svc = DataConfigService(get_db_conn()) - result = svc.create(payload) - return success(data=result, msg="创建成功") - except Exception as e: - api_logger.error(f"Create config failed: {str(e)}") - return fail(BizCode.INTERNAL_ERROR, "创建配置失败", str(e)) - - -@router.delete("/delete_config", response_model=ApiResponse) # 删除数据库中的内容(按配置名称) -def delete_config( - config_id: str, - current_user: User = Depends(get_current_user), - ) -> dict: - workspace_id = current_user.current_workspace_id - - # 检查用户是否已选择工作空间 - if workspace_id is None: - api_logger.warning(f"用户 {current_user.username} 尝试删除配置但未选择工作空间") - return fail(BizCode.INVALID_PARAMETER, "请先切换到一个工作空间", "current_workspace_id is None") - - api_logger.info(f"用户 {current_user.username} 在工作空间 {workspace_id} 请求删除配置: {config_id}") - try: - svc = DataConfigService(get_db_conn()) - result = svc.delete(ConfigParamsDelete(config_id=config_id)) - return success(data=result, msg="删除成功") - except Exception as e: - api_logger.error(f"Delete config failed: {str(e)}") - return fail(BizCode.INTERNAL_ERROR, "删除配置失败", str(e)) - -@router.post("/update_config", response_model=ApiResponse) # 更新配置文件中name和desc -def update_config( - payload: ConfigUpdate, - current_user: User = Depends(get_current_user), - ) -> dict: - workspace_id = current_user.current_workspace_id - - # 检查用户是否已选择工作空间 - if workspace_id is None: - api_logger.warning(f"用户 {current_user.username} 尝试更新配置但未选择工作空间") - return fail(BizCode.INVALID_PARAMETER, "请先切换到一个工作空间", "current_workspace_id is None") - - api_logger.info(f"用户 {current_user.username} 在工作空间 {workspace_id} 请求更新配置: {payload.config_id}") - try: - svc = DataConfigService(get_db_conn()) - result = svc.update(payload) - return success(data=result, msg="更新成功") - except Exception as e: - api_logger.error(f"Update config failed: {str(e)}") - return fail(BizCode.INTERNAL_ERROR, "更新配置失败", str(e)) - - -@router.post("/update_config_extracted", response_model=ApiResponse) # 更新数据库中的部分内容 所有业务字段均可选 -def update_config_extracted( - payload: ConfigUpdateExtracted, - current_user: User = Depends(get_current_user), - ) -> dict: - workspace_id = current_user.current_workspace_id - - # 检查用户是否已选择工作空间 - if workspace_id is None: - api_logger.warning(f"用户 {current_user.username} 尝试更新提取配置但未选择工作空间") - return fail(BizCode.INVALID_PARAMETER, "请先切换到一个工作空间", "current_workspace_id is None") - - api_logger.info(f"用户 {current_user.username} 在工作空间 {workspace_id} 请求更新提取配置: {payload.config_id}") - try: - svc = DataConfigService(get_db_conn()) - result = svc.update_extracted(payload) - return success(data=result, msg="更新成功") - except Exception as e: - api_logger.error(f"Update config extracted failed: {str(e)}") - return fail(BizCode.INTERNAL_ERROR, "更新配置失败", str(e)) - - -# --- Forget config params --- -@router.post("/update_config_forget", response_model=ApiResponse) # 更新遗忘引擎配置参数(固定路径) -def update_config_forget( - payload: ConfigUpdateForget, - current_user: User = Depends(get_current_user), - ) -> dict: - workspace_id = current_user.current_workspace_id - - # 检查用户是否已选择工作空间 - if workspace_id is None: - api_logger.warning(f"用户 {current_user.username} 尝试更新遗忘引擎配置但未选择工作空间") - return fail(BizCode.INVALID_PARAMETER, "请先切换到一个工作空间", "current_workspace_id is None") - - api_logger.info(f"用户 {current_user.username} 在工作空间 {workspace_id} 请求更新遗忘引擎配置: {payload.config_id}") - try: - svc = DataConfigService(get_db_conn()) - result = svc.update_forget(payload) - return success(data=result, msg="更新成功") - except Exception as e: - api_logger.error(f"Update config forget failed: {str(e)}") - return fail(BizCode.INTERNAL_ERROR, "更新遗忘引擎配置失败", str(e)) - - -@router.get("/read_config_extracted", response_model=ApiResponse) # 通过查询参数读取某条配置(固定路径) 没有意义的话就删除 -def read_config_extracted( - config_id: str, - current_user: User = Depends(get_current_user), - ) -> dict: - workspace_id = current_user.current_workspace_id - - # 检查用户是否已选择工作空间 - if workspace_id is None: - api_logger.warning(f"用户 {current_user.username} 尝试读取提取配置但未选择工作空间") - return fail(BizCode.INVALID_PARAMETER, "请先切换到一个工作空间", "current_workspace_id is None") - - api_logger.info(f"用户 {current_user.username} 在工作空间 {workspace_id} 请求读取提取配置: {config_id}") - try: - svc = DataConfigService(get_db_conn()) - result = svc.get_extracted(ConfigKey(config_id=config_id)) - return success(data=result, msg="查询成功") - except Exception as e: - api_logger.error(f"Read config extracted failed: {str(e)}") - return fail(BizCode.INTERNAL_ERROR, "查询配置失败", str(e)) - -@router.get("/read_config_forget", response_model=ApiResponse) # 通过查询参数读取某条配置(固定路径) 没有意义的话就删除 -def read_config_forget( - config_id: str, - current_user: User = Depends(get_current_user), - ) -> dict: - workspace_id = current_user.current_workspace_id - - # 检查用户是否已选择工作空间 - if workspace_id is None: - api_logger.warning(f"用户 {current_user.username} 尝试读取遗忘引擎配置但未选择工作空间") - return fail(BizCode.INVALID_PARAMETER, "请先切换到一个工作空间", "current_workspace_id is None") - - api_logger.info(f"用户 {current_user.username} 在工作空间 {workspace_id} 请求读取遗忘引擎配置: {config_id}") - try: - svc = DataConfigService(get_db_conn()) - result = svc.get_forget(ConfigKey(config_id=config_id)) - return success(data=result, msg="查询成功") - except Exception as e: - api_logger.error(f"Read config forget failed: {str(e)}") - return fail(BizCode.INTERNAL_ERROR, "查询遗忘引擎配置失败", str(e)) - -@router.get("/read_all_config", response_model=ApiResponse) # 读取所有配置文件列表 -def read_all_config( - current_user: User = Depends(get_current_user), - ) -> dict: - workspace_id = current_user.current_workspace_id - - # 检查用户是否已选择工作空间 - if workspace_id is None: - api_logger.warning(f"用户 {current_user.username} 尝试查询配置但未选择工作空间") - return fail(BizCode.INVALID_PARAMETER, "请先切换到一个工作空间", "current_workspace_id is None") - - api_logger.info(f"用户 {current_user.username} 在工作空间 {workspace_id} 请求读取所有配置") - try: - svc = DataConfigService(get_db_conn()) - # 传递 workspace_id 进行过滤(保持为 UUID 类型) - result = svc.get_all(workspace_id=workspace_id) - return success(data=result, msg="查询成功") - except Exception as e: - api_logger.error(f"Read all config failed: {str(e)}") - return fail(BizCode.INTERNAL_ERROR, "查询所有配置失败", str(e)) - - -@router.post("/pilot_run", response_model=ApiResponse) # 试运行:触发执行主管线,使用 POST 更为合理 -async def pilot_run( - payload: ConfigPilotRun, - current_user: User = Depends(get_current_user), - ) -> dict: - api_logger.info(f"Pilot run requested: config_id={payload.config_id}, dialogue_text_length={len(payload.dialogue_text)}") - - # 先尝试从数据库加载配置 - try: - config_loaded = reload_configuration_from_database(str(payload.config_id)) - if not config_loaded: - api_logger.error(f"Failed to load configuration for config_id: {payload.config_id}") - return fail(BizCode.INTERNAL_ERROR, "配置加载失败", f"无法加载 config_id={payload.config_id} 的配置") - api_logger.info(f"Configuration loaded successfully for config_id: {payload.config_id}") - except Exception as e: - api_logger.error(f"Exception while loading configuration: {str(e)}") - return fail(BizCode.INTERNAL_ERROR, "配置加载异常", str(e)) - - try: - svc = DataConfigService(get_db_conn()) - result = await svc.pilot_run(payload) - return success(data=result, msg="试运行完成") - except ValueError as e: - # 捕获参数验证错误 - api_logger.error(f"Pilot run parameter validation failed: {str(e)}") - return fail(BizCode.INVALID_PARAMETER, "参数验证失败", str(e)) - except Exception as e: - api_logger.error(f"Pilot run failed: {str(e)}") - return fail(BizCode.INTERNAL_ERROR, "试运行失败", str(e)) - -""" -以下为搜索与分析接口,直接挂载到同一 router,统一响应为 ApiResponse。 -""" - -@router.get("/search/kb_type_distribution", response_model=ApiResponse) -async def get_kb_type_distribution( - end_user_id: Optional[str] = None, - current_user: User = Depends(get_current_user), - ) -> dict: - api_logger.info(f"KB type distribution requested for end_user_id: {end_user_id}") - try: - result = await kb_type_distribution(end_user_id) - return success(data=result, msg="查询成功") - except Exception as e: - api_logger.error(f"KB type distribution failed: {str(e)}") - return fail(BizCode.INTERNAL_ERROR, "知识库类型分布查询失败", str(e)) - - -@router.get("/search/dialogue", response_model=ApiResponse) -async def search_dialogues_num( - end_user_id: Optional[str] = None, - current_user: User = Depends(get_current_user), - ) -> dict: - api_logger.info(f"Search dialogue requested for end_user_id: {end_user_id}") - try: - result = await search_dialogue(end_user_id) - return success(data=result, msg="查询成功") - except Exception as e: - api_logger.error(f"Search dialogue failed: {str(e)}") - return fail(BizCode.INTERNAL_ERROR, "对话查询失败", str(e)) - - -@router.get("/search/chunk", response_model=ApiResponse) -async def search_chunks_num( - end_user_id: Optional[str] = None, - current_user: User = Depends(get_current_user), - ) -> dict: - api_logger.info(f"Search chunk requested for end_user_id: {end_user_id}") - try: - result = await search_chunk(end_user_id) - return success(data=result, msg="查询成功") - except Exception as e: - api_logger.error(f"Search chunk failed: {str(e)}") - return fail(BizCode.INTERNAL_ERROR, "分块查询失败", str(e)) - - -@router.get("/search/statement", response_model=ApiResponse) -async def search_statements_num( - end_user_id: Optional[str] = None, - current_user: User = Depends(get_current_user), - ) -> dict: - api_logger.info(f"Search statement requested for end_user_id: {end_user_id}") - try: - result = await search_statement(end_user_id) - return success(data=result, msg="查询成功") - except Exception as e: - api_logger.error(f"Search statement failed: {str(e)}") - return fail(BizCode.INTERNAL_ERROR, "语句查询失败", str(e)) - - -@router.get("/search/entity", response_model=ApiResponse) -async def search_entities_num( - end_user_id: Optional[str] = None, - current_user: User = Depends(get_current_user), - ) -> dict: - api_logger.info(f"Search entity requested for end_user_id: {end_user_id}") - try: - result = await search_entity(end_user_id) - return success(data=result, msg="查询成功") - except Exception as e: - api_logger.error(f"Search entity failed: {str(e)}") - return fail(BizCode.INTERNAL_ERROR, "实体查询失败", str(e)) - - -@router.get("/search", response_model=ApiResponse) -async def search_all_num( - end_user_id: Optional[str] = None, - current_user: User = Depends(get_current_user), - ) -> dict: - api_logger.info(f"Search all requested for end_user_id: {end_user_id}") - try: - result = await search_all(end_user_id) - return success(data=result, msg="查询成功") - except Exception as e: - api_logger.error(f"Search all failed: {str(e)}") - return fail(BizCode.INTERNAL_ERROR, "全部查询失败", str(e)) - - -@router.get("/search/detials", response_model=ApiResponse) -async def search_entities_detials( - end_user_id: Optional[str] = None, - current_user: User = Depends(get_current_user), - ) -> dict: - api_logger.info(f"Search details requested for end_user_id: {end_user_id}") - try: - result = await search_detials(end_user_id) - return success(data=result, msg="查询成功") - except Exception as e: - api_logger.error(f"Search details failed: {str(e)}") - return fail(BizCode.INTERNAL_ERROR, "详情查询失败", str(e)) - - -@router.get("/search/edges", response_model=ApiResponse) -async def search_entity_edges( - end_user_id: Optional[str] = None, - current_user: User = Depends(get_current_user), - ) -> dict: - api_logger.info(f"Search edges requested for end_user_id: {end_user_id}") - try: - result = await search_edges(end_user_id) - return success(data=result, msg="查询成功") - except Exception as e: - api_logger.error(f"Search edges failed: {str(e)}") - return fail(BizCode.INTERNAL_ERROR, "边查询失败", str(e)) - -@router.get("/search/entity_graph", response_model=ApiResponse) -async def search_for_entity_graph( - end_user_id: Optional[str] = None, - current_user: User = Depends(get_current_user), - ) -> dict: - """ - 搜索所有实体之间的关系网络 - """ - api_logger.info(f"Search entity graph requested for end_user_id: {end_user_id}") - try: - result = await search_entity_graph(end_user_id) - return success(data=result, msg="查询成功") - except Exception as e: - api_logger.error(f"Search entity graph failed: {str(e)}") - return fail(BizCode.INTERNAL_ERROR, "实体图查询失败", str(e)) - - -@router.get("/analytics/hot_memory_tags", response_model=ApiResponse) -async def get_hot_memory_tags_api( - end_user_id: Optional[str] = None, - limit: int = 10, - current_user: User = Depends(get_current_user), - ) -> dict: - api_logger.info(f"Hot memory tags requested for end_user_id: {end_user_id}") - try: - result = await analytics_hot_memory_tags(end_user_id, limit) - return success(data=result, msg="查询成功") - except Exception as e: - api_logger.error(f"Hot memory tags failed: {str(e)}") - return fail(BizCode.INTERNAL_ERROR, "热门标签查询失败", str(e)) - - -@router.get("/analytics/memory_insight/report", response_model=ApiResponse) -async def get_memory_insight_report_api( - end_user_id: Optional[str] = None, - current_user: User = Depends(get_current_user), - ) -> dict: - api_logger.info(f"Memory insight report requested for end_user_id: {end_user_id}") - try: - result = await analytics_memory_insight_report(end_user_id) - return success(data=result, msg="查询成功") - except Exception as e: - api_logger.error(f"Memory insight report failed: {str(e)}") - return fail(BizCode.INTERNAL_ERROR, "记忆洞察报告生成失败", str(e)) - - -@router.get("/analytics/recent_activity_stats", response_model=ApiResponse) -async def get_recent_activity_stats_api( - current_user: User = Depends(get_current_user), - ) -> dict: - api_logger.info("Recent activity stats requested") - try: - result = await analytics_recent_activity_stats() - return success(data=result, msg="查询成功") - except Exception as e: - api_logger.error(f"Recent activity stats failed: {str(e)}") - return fail(BizCode.INTERNAL_ERROR, "最近活动统计失败", str(e)) - - -@router.get("/analytics/user_summary", response_model=ApiResponse) -async def get_user_summary_api( - end_user_id: Optional[str] = None, - current_user: User = Depends(get_current_user), - ) -> dict: - api_logger.info(f"User summary requested for end_user_id: {end_user_id}") - try: - result = await analytics_user_summary(end_user_id) - return success(data=result, msg="查询成功") - except Exception as e: - api_logger.error(f"User summary failed: {str(e)}") - return fail(BizCode.INTERNAL_ERROR, "用户摘要生成失败", str(e)) - -from app.core.memory.utils.self_reflexion_utils import self_reflexion -@router.get("/self_reflexion") -async def self_reflexion_endpoint(host_id: uuid.UUID) -> str: - """ - 自我反思接口,自动对检索出的信息进行自我反思并返回自我反思结果。 - - Args: - None - Returns: - 自我反思结果。 - """ - return await self_reflexion(host_id) diff --git a/app/controllers/model_controller.py b/app/controllers/model_controller.py deleted file mode 100644 index 0c32c225..00000000 --- a/app/controllers/model_controller.py +++ /dev/null @@ -1,332 +0,0 @@ -from fastapi import APIRouter, Depends, status, Query -from langchain_core.messages import HumanMessage, SystemMessage -from langchain_core.prompts import ChatPromptTemplate -from sqlalchemy.orm import Session -from typing import List, Optional -import uuid - - -from app.core.models import RedBearLLM -from app.core.models.base import RedBearModelConfig -from app.db import get_db -from app.dependencies import get_current_user -from app.models.models_model import ModelProvider, ModelType -from app.models.user_model import User -from app.schemas import model_schema -from app.core.response_utils import success -from app.schemas.response_schema import ApiResponse, PageData -from app.services.model_service import ModelConfigService, ModelApiKeyService -from app.core.logging_config import get_api_logger - -# 获取API专用日志器 -api_logger = get_api_logger() - -router = APIRouter( - prefix="/models", - tags=["Models"], -) - -@router.get("/type", response_model=ApiResponse) -def get_model_types(): - - return success(msg="获取模型类型成功", data=list(ModelType)) - - -@router.get("/provider", response_model=ApiResponse) -def get_model_providers(): - return success(msg="获取模型提供商成功", data=list(ModelProvider)) - - -@router.get("", response_model=ApiResponse) -def get_model_list( - type: Optional[List[model_schema.ModelType]] = Query(None, description="模型类型筛选(支持多个,如 ?type=LLM&type=EMBEDDING)"), - provider: Optional[model_schema.ModelProvider] = Query(None, description="提供商筛选(基于API Key)"), - is_active: Optional[bool] = Query(None, description="激活状态筛选"), - is_public: Optional[bool] = Query(None, description="公开状态筛选"), - search: Optional[str] = Query(None, description="搜索关键词"), - page: int = Query(1, ge=1, description="页码"), - pagesize: int = Query(10, ge=1, le=100, description="每页数量"), - db: Session = Depends(get_db) -): - """ - 获取模型配置列表 - - 支持多个 type 参数: - - 单个:?type=LLM - - 多个:?type=LLM&type=EMBEDDING - """ - api_logger.info(f"获取模型配置列表请求: type={type}, provider={provider}, page={page}, pagesize={pagesize}") - - try: - query = model_schema.ModelConfigQuery( - type=type, - provider=provider, - is_active=is_active, - is_public=is_public, - search=search, - page=page, - pagesize=pagesize - ) - - api_logger.debug(f"开始获取模型配置列表: {query.dict()}") - result_orm = ModelConfigService.get_model_list(db=db, query=query) - result = PageData.model_validate(result_orm) - api_logger.info(f"模型配置列表获取成功: 总数={result.page.total}, 当前页={len(result.items)}") - return success(data=result, msg="模型配置列表获取成功") - except Exception as e: - api_logger.error(f"获取模型配置列表失败: {str(e)}") - raise - - -@router.get("/{model_id}", response_model=ApiResponse) -def get_model_by_id( - model_id: uuid.UUID, - db: Session = Depends(get_db) -): - """ - 根据ID获取模型配置 - """ - api_logger.info(f"获取模型配置请求: model_id={model_id}") - - try: - api_logger.debug(f"开始获取模型配置: model_id={model_id}") - result_orm = ModelConfigService.get_model_by_id(db=db, model_id=model_id) - api_logger.info(f"模型配置获取成功: {result_orm.name}") - - # 将ORM对象转换为Pydantic模型 - result_pydantic = model_schema.ModelConfig.model_validate(result_orm) - - return success(data=result_pydantic, msg="模型配置获取成功") - except Exception as e: - api_logger.error(f"获取模型配置失败: model_id={model_id} - {str(e)}") - raise - - -@router.post("", response_model=ApiResponse) -async def create_model( - model_data: model_schema.ModelConfigCreate, - db: Session = Depends(get_db), - current_user: User = Depends(get_current_user) -): - """ - 创建模型配置 - - - 创建模型配置基础信息 - - 如果包含 API Key,会先验证配置有效性,然后创建 - - 验证失败时会抛出异常,不会创建配置 - - 可通过 skip_validation=true 跳过验证 - """ - api_logger.info(f"创建模型配置请求: {model_data.name}, 用户: {current_user.username}") - - try: - api_logger.debug(f"开始创建模型配置: {model_data.name}") - result_orm = await ModelConfigService.create_model(db=db, model_data=model_data) - api_logger.info(f"模型配置创建成功: {result_orm.name} (ID: {result_orm.id})") - - # 将ORM对象转换为Pydantic模型 - result = model_schema.ModelConfig.model_validate(result_orm) - - return success(data=result, msg="模型配置创建成功") - except Exception as e: - api_logger.error(f"创建模型配置失败: {model_data.name} - {str(e)}") - raise - - -@router.put("/{model_id}", response_model=ApiResponse) -def update_model( - model_id: uuid.UUID, - model_data: model_schema.ModelConfigUpdate, - db: Session = Depends(get_db), - current_user: User = Depends(get_current_user) -): - """ - 更新模型配置 - """ - api_logger.info(f"更新模型配置请求: model_id={model_id}, 用户: {current_user.username}") - - try: - api_logger.debug(f"开始更新模型配置: model_id={model_id}") - result_orm = ModelConfigService.update_model(db=db, model_id=model_id, model_data=model_data) - api_logger.info(f"模型配置更新成功: {result_orm.name} (ID: {model_id})") - - # 将ORM对象转换为Pydantic模型 - result_pydantic = model_schema.ModelConfig.model_validate(result_orm) - - return success(data=result_pydantic, msg="模型配置更新成功") - except Exception as e: - api_logger.error(f"更新模型配置失败: model_id={model_id} - {str(e)}") - raise - - -@router.delete("/{model_id}", response_model=ApiResponse) -def delete_model( - model_id: uuid.UUID, - db: Session = Depends(get_db), - current_user: User = Depends(get_current_user) -): - """ - 删除模型配置 - """ - api_logger.info(f"删除模型配置请求: model_id={model_id}, 用户: {current_user.username}") - - try: - api_logger.debug(f"开始删除模型配置: model_id={model_id}") - ModelConfigService.delete_model(db=db, model_id=model_id) - api_logger.info(f"模型配置删除成功: model_id={model_id}") - return success(msg="模型配置删除成功") - except Exception as e: - api_logger.error(f"删除模型配置失败: model_id={model_id} - {str(e)}") - raise - - -# API Key 相关接口 -@router.get("/{model_id}/apikeys", response_model=ApiResponse) -def get_model_api_keys( - model_id: uuid.UUID, - is_active: bool = Query(True, description="是否只获取活跃的API Key"), - db: Session = Depends(get_db), - current_user: User = Depends(get_current_user) -): - """ - 获取模型的API Key列表 - """ - api_logger.info(f"获取模型API Key列表请求: model_id={model_id}, 用户: {current_user.username}") - - try: - api_logger.debug(f"开始获取模型API Key列表: model_id={model_id}") - result_orm = ModelApiKeyService.get_api_keys_by_model( - db=db, model_config_id=model_id, is_active=is_active - ) - - # 将ORM对象列表转换为Pydantic模型列表 - result_pydantic = [model_schema.ModelApiKey.model_validate(item) for item in result_orm] - - api_logger.info(f"模型API Key列表获取成功: 数量={len(result_pydantic)}") - return success(data=result_pydantic, msg="模型API Key列表获取成功") - except Exception as e: - api_logger.error(f"获取模型API Key列表失败: model_id={model_id} - {str(e)}") - raise - - -@router.post("/{model_id}/apikeys", response_model=ApiResponse, status_code=status.HTTP_201_CREATED) -async def create_model_api_key( - model_id: uuid.UUID, - api_key_data: model_schema.ModelApiKeyCreate, - db: Session = Depends(get_db), - current_user: User = Depends(get_current_user) -): - """ - 为模型创建API Key - """ - api_logger.info(f"创建模型API Key请求: model_id={model_id}, model_name={api_key_data.model_name}, 用户: {current_user.username}") - - try: - # 设置模型配置ID - api_key_data.model_config_id = model_id - - api_logger.debug(f"开始创建模型API Key: {api_key_data.model_name}") - result = await ModelApiKeyService.create_api_key(db=db, api_key_data=api_key_data) - api_logger.info(f"模型API Key创建成功: {result.model_name} (ID: {result.id})") - return success(data=result, msg="模型API Key创建成功") - except Exception as e: - api_logger.error(f"创建模型API Key失败: {api_key_data.model_name} - {str(e)}") - raise - - -@router.get("/apikeys/{api_key_id}", response_model=ApiResponse) -def get_api_key_by_id( - api_key_id: uuid.UUID, - db: Session = Depends(get_db), - current_user: User = Depends(get_current_user) -): - """ - 根据ID获取API Key - """ - api_logger.info(f"获取API Key请求: api_key_id={api_key_id}, 用户: {current_user.username}") - - try: - api_logger.debug(f"开始获取API Key: api_key_id={api_key_id}") - result = ModelApiKeyService.get_api_key_by_id(db=db, api_key_id=api_key_id) - api_logger.info(f"API Key获取成功: {result.model_name}") - return success(data=result, msg="API Key获取成功") - except Exception as e: - api_logger.error(f"获取API Key失败: api_key_id={api_key_id} - {str(e)}") - raise - - -@router.put("/apikeys/{api_key_id}", response_model=ApiResponse) -async def update_api_key( - api_key_id: uuid.UUID, - api_key_data: model_schema.ModelApiKeyUpdate, - db: Session = Depends(get_db), - current_user: User = Depends(get_current_user) -): - """ - 更新API Key - """ - api_logger.info(f"更新API Key请求: api_key_id={api_key_id}, 用户: {current_user.username}") - - try: - api_logger.debug(f"开始更新API Key: api_key_id={api_key_id}") - result = await ModelApiKeyService.update_api_key(db=db, api_key_id=api_key_id, api_key_data=api_key_data) - api_logger.info(f"API Key更新成功: {result.model_name} (ID: {api_key_id})") - result_pydantic = model_schema.ModelApiKey.model_validate(result) - return success(data=result_pydantic, msg="API Key更新成功") - except Exception as e: - api_logger.error(f"更新API Key失败: api_key_id={api_key_id} - {str(e)}") - raise - - -@router.delete("/apikeys/{api_key_id}", response_model=ApiResponse) -def delete_api_key( - api_key_id: uuid.UUID, - db: Session = Depends(get_db), - current_user: User = Depends(get_current_user) -): - """ - 删除API Key - """ - api_logger.info(f"删除API Key请求: api_key_id={api_key_id}, 用户: {current_user.username}") - - try: - api_logger.debug(f"开始删除API Key: api_key_id={api_key_id}") - ModelApiKeyService.delete_api_key(db=db, api_key_id=api_key_id) - api_logger.info(f"API Key删除成功: api_key_id={api_key_id}") - return success(msg="API Key删除成功") - except Exception as e: - api_logger.error(f"删除API Key失败: api_key_id={api_key_id} - {str(e)}") - raise - - -@router.post("/validate", response_model=ApiResponse) -async def validate_model_config( - validate_data: model_schema.ModelValidateRequest, - db: Session = Depends(get_db), - current_user: User = Depends(get_current_user) -): - """ - 验证模型配置是否有效 - - 支持验证不同类型的模型: - - llm: 大语言模型 - - chat: 对话模型 - - embedding: 向量模型 - - rerank: 重排序模型 - """ - api_logger.info(f"验证模型配置请求: {validate_data.model_name} ({validate_data.model_type}), 用户: {current_user.username}") - - result = await ModelConfigService.validate_model_config( - db=db, - model_name=validate_data.model_name, - provider=validate_data.provider, - api_key=validate_data.api_key, - api_base=validate_data.api_base, - model_type=validate_data.model_type, - test_message=validate_data.test_message - ) - - return success(data=model_schema.ModelValidateResponse(**result), msg="验证完成") - - - - diff --git a/app/controllers/multi_agent_controller.py b/app/controllers/multi_agent_controller.py deleted file mode 100644 index f832ac89..00000000 --- a/app/controllers/multi_agent_controller.py +++ /dev/null @@ -1,404 +0,0 @@ -"""多 Agent 控制器""" -import uuid -from fastapi import APIRouter, Depends, Query, Path -from sqlalchemy.orm import Session - -from app.db import get_db -from app.dependencies import get_current_user -from app.core.response_utils import success -from app.core.logging_config import get_business_logger -from app.schemas import multi_agent_schema -from app.schemas.response_schema import PageData, PageMeta -from app.services.multi_agent_service import MultiAgentService -from app.models import User - -router = APIRouter(prefix="/apps", tags=["Multi-Agent"]) -logger = get_business_logger() - - -# ==================== 多 Agent 配置管理 ==================== - -@router.post( - "/{app_id}/multi-agent", - summary="创建多 Agent 配置" -) -def create_multi_agent_config( - app_id: uuid.UUID = Path(..., description="应用 ID"), - data: multi_agent_schema.MultiAgentConfigCreate = ..., - current_user: User = Depends(get_current_user), - db: Session = Depends(get_db), -): - """创建多 Agent 配置 - - 支持四种编排模式: - - sequential: 顺序执行 - - parallel: 并行执行 - - conditional: 条件路由 - - loop: 循环执行 - """ - service = MultiAgentService(db) - config = service.create_config( - app_id=app_id, - data=data, - created_by=current_user.id - ) - - return success( - data=multi_agent_schema.MultiAgentConfigSchema.model_validate(config), - msg="多 Agent 配置创建成功" - ) - - - -@router.get( - "/{app_id}/multi-agent", - summary="获取当前应用的最新有效多 Agent 配置" -) -def get_multi_agent_configs( - app_id: uuid.UUID = Path(..., description="应用 ID"), - current_user: User = Depends(get_current_user), - db: Session = Depends(get_db), -): - """获取指定应用的最新有效多 Agent 配置,如果不存在则返回默认模板""" - service = MultiAgentService(db) - - # 通过 app_id 获取最新有效配置(已转换 agent_id 为 app_id) - config = service.get_multi_agent_configs(app_id) - - if not config: - # 返回默认模板 - default_template = { - "app_id": str(app_id), - "master_agent_id": None, - "master_agent_name": None, - "orchestration_mode": "conditional", - "sub_agents": [], - "routing_rules": [], - "execution_config": { - "max_iterations": 10, - "timeout": 300, - "enable_parallel": False, - "error_handling": "stop" - }, - "aggregation_strategy": "merge", - } - return success( - data=default_template, - msg="该应用暂无配置,返回默认模板" - ) - - # config 已经是字典格式,直接返回 - return success(data=config) - -@router.put( - "/{app_id}/multi-agent", - summary="更新多 Agent 配置" -) -def update_multi_agent_config( - app_id: uuid.UUID = Path(..., description="应用 ID"), - data: multi_agent_schema.MultiAgentConfigUpdate = ..., - current_user: User = Depends(get_current_user), - db: Session = Depends(get_db), -): - """更新多 Agent 配置""" - service = MultiAgentService(db) - config = service.update_config(app_id, data) - - return success( - data=multi_agent_schema.MultiAgentConfigSchema.model_validate(config), - msg="多 Agent 配置更新成功" - ) - - -@router.delete( - "/{app_id}/multi-agent", - summary="删除多 Agent 配置" -) -def delete_multi_agent_config( - app_id: uuid.UUID = Path(..., description="应用 ID"), - current_user: User = Depends(get_current_user), - db: Session = Depends(get_db), -): - """删除多 Agent 配置""" - service = MultiAgentService(db) - service.delete_config(app_id) - - return success(msg="多 Agent 配置删除成功") - -# ==================== 多 Agent 运行 ==================== - -@router.post( - "/{app_id}/multi-agent/run", - summary="运行多 Agent 任务" -) -async def run_multi_agent( - app_id: uuid.UUID = Path(..., description="应用 ID"), - request: multi_agent_schema.MultiAgentRunRequest = ..., - current_user: User = Depends(get_current_user), - db: Session = Depends(get_db), -): - """运行多 Agent 任务 - - 根据配置的编排模式执行多个 Agent: - - sequential: 按优先级顺序执行 - - parallel: 并行执行所有 Agent - - conditional: 根据条件选择 Agent - - loop: 循环执行直到满足条件 - """ - service = MultiAgentService(db) - result = await service.run(app_id, request) - - return success( - data=multi_agent_schema.MultiAgentRunResponse(**result), - msg="多 Agent 任务执行成功" - ) - - -# ==================== 智能路由测试 ==================== - -@router.post( - "/{app_id}/multi-agent/test-routing", - summary="测试智能路由" -) -async def test_routing( - app_id: uuid.UUID = Path(..., description="应用 ID"), - request: multi_agent_schema.RoutingTestRequest = ..., - current_user: User = Depends(get_current_user), - db: Session = Depends(get_db), -): - """测试智能路由功能 - - 支持三种路由模式: - - keyword: 仅使用关键词路由 - - llm: 使用 LLM 路由(需要提供 routing_model_id) - - hybrid: 混合路由(关键词 + LLM) - - 参数: - - message: 测试消息 - - conversation_id: 会话 ID(可选) - - routing_model_id: 路由模型 ID(可选,用于 LLM 路由) - - use_llm: 是否启用 LLM(默认 False) - - keyword_threshold: 关键词置信度阈值(默认 0.8) - """ - from app.services.conversation_state_manager import ConversationStateManager - from app.services.llm_router import LLMRouter - from app.models import ModelConfig - - # 1. 获取多 Agent 配置 - service = MultiAgentService(db) - config = service.get_config(app_id) - - if not config: - return success( - data=None, - msg="应用未配置多 Agent,无法测试路由" - ) - - # 2. 准备子 Agent 信息 - sub_agents = {} - for sub_agent_info in config.sub_agents: - agent_id = sub_agent_info["agent_id"] - sub_agents[agent_id] = { - "name": sub_agent_info.get("name", agent_id), - "role": sub_agent_info.get("role", "") - } - - # 3. 获取路由模型(如果指定) - routing_model = None - if request.routing_model_id: - routing_model = db.get(ModelConfig, request.routing_model_id) - if not routing_model: - return success( - data=None, - msg=f"路由模型不存在: {request.routing_model_id}" - ) - - # 4. 初始化路由器 - state_manager = ConversationStateManager() - router = LLMRouter( - db=db, - state_manager=state_manager, - routing_rules=config.routing_rules or [], - sub_agents=sub_agents, - routing_model_config=routing_model, - use_llm=request.use_llm and routing_model is not None - ) - - # 5. 设置阈值 - if request.keyword_threshold: - router.keyword_high_confidence_threshold = request.keyword_threshold - - # 6. 执行路由 - try: - routing_result = await router.route( - message=request.message, - conversation_id=str(request.conversation_id) if request.conversation_id else None, - force_new=request.force_new - ) - - # 7. 获取 Agent 信息 - agent_id = routing_result["agent_id"] - agent_info = sub_agents.get(agent_id, {}) - - # 8. 构建响应 - response_data = { - "message": request.message, - "routing_result": { - "agent_id": agent_id, - "agent_name": agent_info.get("name", agent_id), - "agent_role": agent_info.get("role", ""), - "confidence": routing_result["confidence"], - "strategy": routing_result["strategy"], - "topic": routing_result["topic"], - "topic_changed": routing_result["topic_changed"], - "reason": routing_result["reason"], - "routing_method": routing_result["routing_method"] - }, - "cmulti-agent/batch-test-routingonfig_info": { - "use_llm": request.use_llm and routing_model is not None, - "routing_model": routing_model.name if routing_model else None, - "keyword_threshold": router.keyword_high_confidence_threshold, - "total_sub_agents": len(sub_agents) - } - } - - return success( - data=response_data, - msg="路由测试成功" - ) - - except Exception as e: - logger.error(f"路由测试失败: {str(e)}") - return success( - data=None, - msg=f"路由测试失败: {str(e)}" - ) - - -@router.post( - "/{app_id}/", - summary="批量测试智能路由" -) -async def batch_test_routing( - app_id: uuid.UUID = Path(..., description="应用 ID"), - request: multi_agent_schema.BatchRoutingTestRequest = ..., - current_user: User = Depends(get_current_user), - db: Session = Depends(get_db), -): - """批量测试智能路由功能 - - 用于测试多条消息的路由效果,并统计准确率 - - 参数: - - test_cases: 测试用例列表 - - routing_model_id: 路由模型 ID(可选) - - use_llm: 是否启用 LLM - - keyword_threshold: 关键词置信度阈值 - """ - from app.services.conversation_state_manager import ConversationStateManager - from app.services.llm_router import LLMRouter - from app.models import ModelConfig - - # 1. 获取多 Agent 配置 - service = MultiAgentService(db) - config = service.get_config(app_id) - - if not config: - return success( - data=None, - msg="应用未配置多 Agent,无法测试路由" - ) - - # 2. 准备子 Agent 信息 - sub_agents = {} - for sub_agent_info in config.sub_agents: - agent_id = sub_agent_info["agent_id"] - sub_agents[agent_id] = { - "name": sub_agent_info.get("name", agent_id), - "role": sub_agent_info.get("role", "") - } - - # 3. 获取路由模型 - routing_model = None - if request.routing_model_id: - routing_model = db.get(ModelConfig, request.routing_model_id) - - # 4. 初始化路由器 - state_manager = ConversationStateManager() - router = LLMRouter( - db=db, - state_manager=state_manager, - routing_rules=config.routing_rules or [], - sub_agents=sub_agents, - routing_model_config=routing_model, - use_llm=request.use_llm and routing_model is not None - ) - - if request.keyword_threshold: - router.keyword_high_confidence_threshold = request.keyword_threshold - - # 5. 批量测试 - results = [] - correct_count = 0 - total_count = len(request.test_cases) - - for test_case in request.test_cases: - try: - routing_result = await router.route( - message=test_case.message, - conversation_id=str(uuid.uuid4()) # 每个测试用例使用独立会话 - ) - - agent_id = routing_result["agent_id"] - agent_info = sub_agents.get(agent_id, {}) - - # 判断是否正确 - is_correct = None - if test_case.expected_agent_id: - is_correct = (agent_id == str(test_case.expected_agent_id)) - if is_correct: - correct_count += 1 - - results.append({ - "message": test_case.message, - "description": test_case.description, - "routed_agent_id": agent_id, - "routed_agent_name": agent_info.get("name"), - "expected_agent_id": str(test_case.expected_agent_id) if test_case.expected_agent_id else None, - "is_correct": is_correct, - "confidence": routing_result["confidence"], - "routing_method": routing_result["routing_method"], - "strategy": routing_result["strategy"] - }) - - except Exception as e: - logger.error(f"测试用例失败: {test_case.message}, 错误: {str(e)}") - results.append({ - "message": test_case.message, - "description": test_case.description, - "error": str(e) - }) - - # 6. 统计 - accuracy = None - if correct_count > 0: - total_with_expected = sum(1 for r in results if r.get("expected_agent_id")) - if total_with_expected > 0: - accuracy = correct_count / total_with_expected * 100 - - response_data = { - "total_count": total_count, - "correct_count": correct_count, - "accuracy": accuracy, - "results": results, - "config_info": { - "use_llm": request.use_llm and routing_model is not None, - "routing_model": routing_model.name if routing_model else None, - "keyword_threshold": router.keyword_high_confidence_threshold - } - } - - return success( - data=response_data, - msg=f"批量测试完成,准确率: {accuracy:.1f}%" if accuracy else "批量测试完成" - ) diff --git a/app/controllers/public_share_controller.py b/app/controllers/public_share_controller.py deleted file mode 100644 index 236deda4..00000000 --- a/app/controllers/public_share_controller.py +++ /dev/null @@ -1,437 +0,0 @@ -from fastapi import APIRouter, Depends, Query, Request, Header -from fastapi.responses import StreamingResponse -from sqlalchemy.orm import Session -import uuid -import hashlib -import time -import jwt -from typing import Optional, Dict -from functools import wraps - -from app.db import get_db -from app.core.response_utils import success -from app.core.logging_config import get_business_logger -from app.core.exceptions import BusinessException -from app.core.error_codes import BizCode -from app.core.config import settings -from app.schemas import release_share_schema, conversation_schema -from app.schemas.response_schema import PageData, PageMeta -from app.services.release_share_service import ReleaseShareService -from app.services.shared_chat_service import SharedChatService -from app.services.conversation_service import ConversationService -from app.services.auth_service import create_access_token -from app.dependencies import get_share_user_id, ShareTokenData - - -router = APIRouter(prefix="/public/share", tags=["Public Share"]) -logger = get_business_logger() - - -def get_base_url(request: Request) -> str: - """从请求中获取基础 URL""" - return f"{request.url.scheme}://{request.url.netloc}" - - -def get_or_generate_user_id(payload_user_id: str, request: Request) -> str: - """获取或生成用户 ID - - 优先级: - 1. 使用前端传递的 user_id - 2. 基于 IP + User-Agent 生成唯一 ID - - Args: - payload_user_id: 前端传递的 user_id - request: FastAPI Request 对象 - - Returns: - 用户 ID - """ - if payload_user_id: - return payload_user_id - - # 获取客户端 IP - client_ip = request.client.host if request.client else "unknown" - - # 获取 User-Agent - user_agent = request.headers.get("user-agent", "unknown") - - # 生成唯一 ID:基于 IP + User-Agent 的哈希 - unique_string = f"{client_ip}_{user_agent}" - hash_value = hashlib.md5(unique_string.encode()).hexdigest()[:16] - - return f"guest_{hash_value}" - - -@router.post( - "/{share_token}/token", - summary="获取访问 token" -) -def get_access_token( - share_token: str, - payload: release_share_schema.TokenRequest, - request: Request, - db: Session = Depends(get_db), -): - """获取访问 token - - - 用户通过 user_id + share_token 换取访问 token - - 后续请求需要携带此 token - """ - # 获取或生成 user_id - user_id = get_or_generate_user_id(payload.user_id, request) - - # 验证分享链接(可选:验证密码) - service = ReleaseShareService(db) - try: - service.get_shared_release_info( - share_token=share_token, - password=payload.password - ) - except Exception as e: - logger.error(f"获取分享信息失败: {str(e)}") - raise - - # 生成 token - access_token = create_access_token(user_id, share_token) - - logger.info( - f"生成访问 token", - extra={ - "share_token": share_token, - "user_id": user_id - } - ) - - return success(data={ - "access_token": access_token, - "token_type": "Bearer", - "user_id": user_id - }) - - -@router.get( - "", - summary="获取公开分享的应用信息", - response_model=None -) -def get_shared_release( - password: str = Query(None, description="访问密码(如果需要)"), - share_data: ShareTokenData = Depends(get_share_user_id), - db: Session = Depends(get_db), -): - """获取公开分享的发布版本信息 - - - 无需认证即可访问 - - 如果设置了密码保护,需要提供正确的密码 - - 如果密码错误或未提供密码,返回基本信息(不含配置详情) - """ - service = ReleaseShareService(db) - info = service.get_shared_release_info( - share_token=share_data.share_token, - password=password - ) - - return success(data=info) - - -@router.post( - "/verify", - summary="验证访问密码" -) -def verify_password( - payload: release_share_schema.PasswordVerifyRequest, - share_data: ShareTokenData = Depends(get_share_user_id), - db: Session = Depends(get_db), -): - """验证分享的访问密码 - - - 用于前端先验证密码,再获取完整信息 - """ - service = ReleaseShareService(db) - is_valid = service.verify_password( - share_token=share_data.share_token, - password=payload.password - ) - - return success(data={"valid": is_valid}) - - -@router.get( - "/embed", - summary="获取嵌入代码" -) -def get_embed_code( - width: str = Query("100%", description="iframe 宽度"), - height: str = Query("600px", description="iframe 高度"), - request: Request = None, - share_data: ShareTokenData = Depends(get_share_user_id), - db: Session = Depends(get_db), -): - """获取嵌入代码 - - - 返回 iframe 嵌入代码 - - 可以自定义宽度和高度 - """ - base_url = get_base_url(request) if request else None - - service = ReleaseShareService(db) - embed_code = service.get_embed_code( - share_token=share_data.share_token, - width=width, - height=height, - base_url=base_url - ) - - return success(data=embed_code) - - - -# ---------- 会话管理接口 ---------- - -@router.get( - "/conversations", - summary="获取会话列表" -) -def list_conversations( - password: str = Query(None, description="访问密码"), - page: int = Query(1, ge=1), - pagesize: int = Query(20, ge=1, le=100), - share_data: ShareTokenData = Depends(get_share_user_id), - db: Session = Depends(get_db), -): - """获取分享应用的会话列表 - - - 可以按 user_id 筛选 - - 支持分页 - """ - logger.debug(f"share_data:{share_data.user_id}") - other_id = share_data.user_id - service = SharedChatService(db) - share, release = service._get_release_by_share_token(share_data.share_token, password) - from app.repositories.end_user_repository import EndUserRepository - end_user_repo = EndUserRepository(db) - new_end_user = end_user_repo.get_or_create_end_user( - app_id=share.app_id, - other_id=other_id - ) - logger.debug(new_end_user.id) - service = SharedChatService(db) - conversations, total = service.list_conversations( - share_token=share_data.share_token, - user_id=str(new_end_user.id), - password=password, - page=page, - pagesize=pagesize - ) - - items = [conversation_schema.Conversation.model_validate(c) for c in conversations] - meta = PageMeta(page=page, pagesize=pagesize, total=total, hasnext=(page * pagesize) < total) - - return success(data=PageData(page=meta, items=items)) - - -@router.get( - "/conversations/{conversation_id}", - summary="获取会话详情(含消息)" -) -def get_conversation( - conversation_id: uuid.UUID, - password: str = Query(None, description="访问密码"), - share_data: ShareTokenData = Depends(get_share_user_id), - db: Session = Depends(get_db), -): - """获取会话详情和消息历史""" - chat_service = SharedChatService(db) - conversation = chat_service.get_conversation_messages( - share_token=share_data.share_token, - conversation_id=conversation_id, - password=password - ) - - # 获取消息 - conv_service = ConversationService(db) - messages = conv_service.get_messages(conversation_id) - - # 构建响应 - conv_dict = conversation_schema.Conversation.model_validate(conversation).model_dump() - conv_dict["messages"] = [ - conversation_schema.Message.model_validate(m) for m in messages - ] - - return success(data=conv_dict) - - -# ---------- 聊天接口 ---------- - -@router.post( - "/chat", - summary="发送消息(支持流式和非流式)" -) -async def chat( - payload: conversation_schema.ChatRequest, - share_data: ShareTokenData = Depends(get_share_user_id), - db: Session = Depends(get_db), -): - """发送消息并获取回复 - - 使用 Bearer token 认证: - - Header: Authorization: Bearer {token} - - user_id 和 share_token 从 token 中解码 - - - 支持多轮对话(提供 conversation_id) - - 支持流式返回(设置 stream=true) - - 如果不提供 conversation_id,会自动创建新会话 - """ - service = SharedChatService(db) - - # 从依赖中获取 user_id 和 share_token - user_id = share_data.user_id - share_token = share_data.share_token - password = None # Token 认证不需要密码 - # end_user_id = user_id - other_id = user_id - - # 提前验证和准备(在流式响应开始前完成) - # 这样可以确保错误能正确返回,而不是在流式响应中间出错 - from app.models.app_model import AppType - try: - from app.core.exceptions import BusinessException - from app.core.error_codes import BizCode - from app.services.app_service import AppService - # 验证分享链接和密码 - share, release = service._get_release_by_share_token(share_token, password) - - # # Create end_user_id by concatenating app_id with user_id - # end_user_id = f"{share.app_id}_{user_id}" - - # Store end_user_id in database with original user_id - from app.repositories.end_user_repository import EndUserRepository - end_user_repo = EndUserRepository(db) - new_end_user = end_user_repo.get_or_create_end_user( - app_id=share.app_id, - other_id=other_id, - original_user_id=user_id # Save original user_id to other_id - ) - - # 获取应用类型 - app_type = release.app.type if release.app else None - - # 根据应用类型验证配置 - if app_type == "agent": - # Agent 类型:验证模型配置 - model_config_id = release.default_model_config_id - if not model_config_id: - raise BusinessException("Agent 应用未配置模型", BizCode.AGENT_CONFIG_MISSING) - elif app_type == "multi_agent": - # Multi-Agent 类型:验证多 Agent 配置 - config = release.config or {} - if not config.get("sub_agents"): - raise BusinessException("多 Agent 应用未配置子 Agent", BizCode.AGENT_CONFIG_MISSING) - else: - raise BusinessException(f"不支持的应用类型: {app_type}", BizCode.APP_TYPE_NOT_SUPPORTED) - - # 获取或创建会话(提前验证) - conversation = service.create_or_get_conversation( - share_token=share_data.share_token, - conversation_id=payload.conversation_id, - user_id=str(new_end_user.id), # 转换为字符串 - password=password - ) - - logger.debug( - f"参数验证完成", - extra={ - "share_token": share_token, - "app_type": app_type, - "conversation_id": str(conversation.id), - "stream": payload.stream - } - ) - - except Exception as e: - # 验证失败,直接抛出异常(会被 FastAPI 的异常处理器捕获) - logger.error(f"参数验证失败: {str(e)}") - raise - - if app_type == AppType.AGENT: - # 流式返回 - if payload.stream: - async def event_generator(): - async for event in service.chat_stream( - share_token=share_token, - message=payload.message, - conversation_id=conversation.id, # 使用已创建的会话 ID - user_id=str(new_end_user.id), # 转换为字符串 - variables=payload.variables, - password=password, - web_search=payload.web_search, - memory=payload.memory - ): - yield event - - return StreamingResponse( - event_generator(), - media_type="text/event-stream", - headers={ - "Cache-Control": "no-cache", - "Connection": "keep-alive", - "X-Accel-Buffering": "no" - } - ) - - # 非流式返回 - result = await service.chat( - share_token=share_token, - message=payload.message, - conversation_id=conversation.id, # 使用已创建的会话 ID - user_id=str(new_end_user.id), # 转换为字符串 - variables=payload.variables, - password=password, - web_search=payload.web_search, - memory=payload.memory - ) - return success(data=conversation_schema.ChatResponse(**result)) - elif app_type == AppType.MULTI_AGENT: - # 多 Agent 流式返回 - if payload.stream: - async def event_generator(): - async for event in service.multi_agent_chat_stream( - share_token=share_token, - message=payload.message, - conversation_id=conversation.id, # 使用已创建的会话 ID - user_id=str(new_end_user.id), # 转换为字符串 - variables=payload.variables, - password=password, - web_search=payload.web_search, - memory=payload.memory - ): - yield event - - return StreamingResponse( - event_generator(), - media_type="text/event-stream", - headers={ - "Cache-Control": "no-cache", - "Connection": "keep-alive", - "X-Accel-Buffering": "no" - } - ) - - # 多 Agent 非流式返回 - result = await service.multi_agent_chat( - share_token=share_token, - message=payload.message, - conversation_id=conversation.id, # 使用已创建的会话 ID - user_id=str(new_end_user.id), # 转换为字符串 - variables=payload.variables, - password=password, - web_search=payload.web_search, - memory=payload.memory - ) - - return success(data=conversation_schema.ChatResponse(**result)) - else: - from app.core.exceptions import BusinessException - from app.core.error_codes import BizCode - raise BusinessException(f"不支持的应用类型: {app_type}", BizCode.APP_TYPE_NOT_SUPPORTED) - pass diff --git a/app/controllers/release_share_controller.py b/app/controllers/release_share_controller.py deleted file mode 100644 index 033c0209..00000000 --- a/app/controllers/release_share_controller.py +++ /dev/null @@ -1,170 +0,0 @@ -import uuid -from fastapi import APIRouter, Depends, Request -from sqlalchemy.orm import Session - -from app.db import get_db -from app.core.response_utils import success -from app.core.logging_config import get_business_logger -from app.schemas import release_share_schema -from app.services.release_share_service import ReleaseShareService -from app.dependencies import get_current_user, cur_workspace_access_guard - -router = APIRouter(tags=["Release Share"]) -logger = get_business_logger() - - -def get_base_url(request: Request) -> str: - """从请求中获取基础 URL""" - return f"{request.url.scheme}://{request.url.netloc}" - - -@router.post( - "/apps/{app_id}/releases/{release_id}/share", - summary="创建/启用分享配置" -) -@cur_workspace_access_guard() -def create_share( - app_id: uuid.UUID, - release_id: uuid.UUID, - payload: release_share_schema.ReleaseShareCreate, - request: Request, - db: Session = Depends(get_db), - current_user=Depends(get_current_user), -): - """创建或更新发布版本的分享配置 - - - 如果已存在分享配置,则更新 - - 自动生成唯一的分享 token - - 返回完整的分享 URL - """ - workspace_id = current_user.current_workspace_id - base_url = get_base_url(request) - - service = ReleaseShareService(db) - share = service.create_or_update_share( - release_id=release_id, - user_id=current_user.id, - workspace_id=workspace_id, - data=payload, - base_url=base_url - ) - - share_schema = service._convert_to_schema(share, base_url) - return success(data=share_schema, msg="分享配置已创建") - - -@router.put( - "/apps/{app_id}/releases/{release_id}/share", - summary="更新分享配置" -) -@cur_workspace_access_guard() -def update_share( - app_id: uuid.UUID, - release_id: uuid.UUID, - payload: release_share_schema.ReleaseShareUpdate, - request: Request, - db: Session = Depends(get_db), - current_user=Depends(get_current_user), -): - """更新分享配置 - - - 可以更新启用状态、密码、嵌入设置等 - - 不会改变 share_token - """ - workspace_id = current_user.current_workspace_id - base_url = get_base_url(request) - - service = ReleaseShareService(db) - share = service.update_share( - release_id=release_id, - workspace_id=workspace_id, - data=payload - ) - - share_schema = service._convert_to_schema(share, base_url) - return success(data=share_schema, msg="分享配置已更新") - - -@router.get( - "/apps/{app_id}/releases/{release_id}/share", - summary="获取分享配置" -) -@cur_workspace_access_guard() -def get_share( - app_id: uuid.UUID, - release_id: uuid.UUID, - request: Request, - db: Session = Depends(get_db), - current_user=Depends(get_current_user), -): - """获取发布版本的分享配置 - - - 如果不存在分享配置,返回 null - """ - workspace_id = current_user.current_workspace_id - base_url = get_base_url(request) - - service = ReleaseShareService(db) - share = service.get_share( - release_id=release_id, - workspace_id=workspace_id, - base_url=base_url - ) - - return success(data=share) - - -@router.delete( - "/apps/{app_id}/releases/{release_id}/share", - summary="删除分享配置" -) -@cur_workspace_access_guard() -def delete_share( - app_id: uuid.UUID, - release_id: uuid.UUID, - db: Session = Depends(get_db), - current_user=Depends(get_current_user), -): - """删除分享配置 - - - 删除后,公开访问链接将失效 - """ - workspace_id = current_user.current_workspace_id - - service = ReleaseShareService(db) - service.delete_share( - release_id=release_id, - workspace_id=workspace_id - ) - - return success(msg="分享配置已删除") - - -@router.post( - "/apps/{app_id}/releases/{release_id}/share/regenerate-token", - summary="重新生成分享链接" -) -@cur_workspace_access_guard() -def regenerate_token( - app_id: uuid.UUID, - release_id: uuid.UUID, - request: Request, - db: Session = Depends(get_db), - current_user=Depends(get_current_user), -): - """重新生成分享 token - - - 旧的分享链接将失效 - - 生成新的唯一 token - """ - workspace_id = current_user.current_workspace_id - base_url = get_base_url(request) - - service = ReleaseShareService(db) - share = service.regenerate_token( - release_id=release_id, - workspace_id=workspace_id - ) - - share_schema = service._convert_to_schema(share, base_url) - return success(data=share_schema, msg="分享链接已重新生成") diff --git a/app/controllers/service/__init__.py b/app/controllers/service/__init__.py deleted file mode 100644 index 00f056e7..00000000 --- a/app/controllers/service/__init__.py +++ /dev/null @@ -1,17 +0,0 @@ -"""Service API Controllers - 基于 API Key 认证的服务接口 - -路由前缀: /v1 -认证方式: API Key -""" -from fastapi import APIRouter -from . import app_api_controller, rag_api_controller, memory_api_controller - -# 创建 V1 API 路由器 -service_router = APIRouter() - -# 注册子路由 -service_router.include_router(app_api_controller.router) -service_router.include_router(rag_api_controller.router) -service_router.include_router(memory_api_controller.router) - -__all__ = ["service_router"] diff --git a/app/controllers/service/app_api_controller.py b/app/controllers/service/app_api_controller.py deleted file mode 100644 index f2a322cd..00000000 --- a/app/controllers/service/app_api_controller.py +++ /dev/null @@ -1,16 +0,0 @@ -"""App 服务接口 - 基于 API Key 认证""" -from fastapi import APIRouter, Depends -from sqlalchemy.orm import Session - -from app.db import get_db -from app.core.response_utils import success -from app.core.logging_config import get_business_logger - -router = APIRouter(prefix="/v1/apps", tags=["V1 - App API"]) -logger = get_business_logger() - - -@router.get("") -async def list_apps(): - """列出可访问的应用(占位)""" - return success(data=[], msg="App API - Coming Soon") diff --git a/app/controllers/service/memory_api_controller.py b/app/controllers/service/memory_api_controller.py deleted file mode 100644 index 22dcb87b..00000000 --- a/app/controllers/service/memory_api_controller.py +++ /dev/null @@ -1,16 +0,0 @@ -"""Memory 服务接口 - 基于 API Key 认证""" -from fastapi import APIRouter, Depends -from sqlalchemy.orm import Session - -from app.db import get_db -from app.core.response_utils import success -from app.core.logging_config import get_business_logger - -router = APIRouter(prefix="/memory", tags=["V1 - Memory API"]) -logger = get_business_logger() - - -@router.get("") -async def get_memory_info(): - """获取记忆服务信息(占位)""" - return success(data={}, msg="Memory API - Coming Soon") diff --git a/app/controllers/service/rag_api_controller.py b/app/controllers/service/rag_api_controller.py deleted file mode 100644 index ecd1dd23..00000000 --- a/app/controllers/service/rag_api_controller.py +++ /dev/null @@ -1,16 +0,0 @@ -"""RAG 服务接口 - 基于 API Key 认证""" -from fastapi import APIRouter, Depends -from sqlalchemy.orm import Session - -from app.db import get_db -from app.core.response_utils import success -from app.core.logging_config import get_business_logger - -router = APIRouter(prefix="/knowledge", tags=["V1 - RAG API"]) -logger = get_business_logger() - - -@router.get("") -async def list_knowledge(): - """列出可访问的知识库(占位)""" - return success(data=[], msg="RAG API - Coming Soon") diff --git a/app/controllers/setup_controller.py b/app/controllers/setup_controller.py deleted file mode 100644 index 50c24f3d..00000000 --- a/app/controllers/setup_controller.py +++ /dev/null @@ -1,23 +0,0 @@ -from fastapi import APIRouter, Depends -from sqlalchemy.orm import Session - -from app.core.response_utils import success -from app.db import get_db -from app.schemas.response_schema import ApiResponse -from app.services import user_service - -router = APIRouter( - prefix="/setup", - tags=["Setup"], -) - -@router.post("", summary="Create the first superuser", response_model=ApiResponse) -def setup_initial_user(db: Session = Depends(get_db)): - """ - Create the initial superuser. This can only be run once. - Reads credentials from environment variables. - """ - user = user_service.create_initial_superuser(db) - if not user: - return success(msg="Superuser already exists.") - return success(msg="Superuser created successfully.") diff --git a/app/controllers/task_controller.py b/app/controllers/task_controller.py deleted file mode 100644 index 22be4fe2..00000000 --- a/app/controllers/task_controller.py +++ /dev/null @@ -1,25 +0,0 @@ -from fastapi import APIRouter, status -from app.schemas.item_schema import Item -from app.services import task_service - -router = APIRouter( - prefix="/tasks", - tags=["Tasks"], -) - -@router.post("/process_item", status_code=status.HTTP_202_ACCEPTED) -def process_item_task(item: Item): - """ - This endpoint receives an item, and instead of processing it directly, - it sends a task to the Celery queue via the task service. - """ - task_id = task_service.create_processing_task(item.dict()) - return {"message": "Task accepted. The item is being processed in the background.", "task_id": task_id} - -@router.get("/result/{task_id}") -def get_task_result_controller(task_id: str): - """ - This endpoint allows clients to check the status and result of a - previously submitted task using its ID, by calling the task service. - """ - return task_service.get_task_result(task_id) diff --git a/app/controllers/test_controller.py b/app/controllers/test_controller.py deleted file mode 100644 index ad46386c..00000000 --- a/app/controllers/test_controller.py +++ /dev/null @@ -1,126 +0,0 @@ -from fastapi import APIRouter, Depends, status, Query, HTTPException -from langchain_core.messages import HumanMessage, SystemMessage -from langchain_core.prompts import ChatPromptTemplate -from sqlalchemy.orm import Session -from typing import List, Optional -import uuid - - -from app.core.models import RedBearLLM, RedBearRerank -from app.core.models.base import RedBearModelConfig -from app.core.models.embedding import RedBearEmbeddings -from app.db import get_db -from app.dependencies import get_current_user -from app.models.models_model import ModelApiKey, ModelProvider, ModelType -from app.models.user_model import User -from app.schemas import model_schema -from app.core.response_utils import success -from app.schemas.response_schema import ApiResponse, PageData -from app.services.model_service import ModelConfigService, ModelApiKeyService -from app.core.logging_config import get_api_logger - -# 获取API专用日志器 -api_logger = get_api_logger() - -router = APIRouter( - prefix="/test", - tags=["test"], -) - - -@router.get(f"/llm/{{model_id}}", response_model=ApiResponse) -def test_llm( - model_id: uuid.UUID, - db: Session = Depends(get_db) -): - config = ModelConfigService.get_model_by_id(db=db, model_id=model_id) - if not config: - api_logger.error(f"模型ID {model_id} 不存在") - raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="模型ID不存在") - try: - apiConfig: ModelApiKey = config.api_keys[0] - llm = RedBearLLM(RedBearModelConfig( - model_name=apiConfig.model_name, - provider=apiConfig.provider, - api_key=apiConfig.api_key, - base_url=apiConfig.api_base - ), type=config.type) - print(llm.dict()) - - template = """Question: {question} - -Answer: Let's think step by step.""" - # ChatPromptTemplate - prompt = ChatPromptTemplate.from_template(template) - chain = prompt | llm - answer = chain.invoke({"question": "What is LangChain?"}) - print("Answer:", answer) - return success(msg="测试LLM成功", data={"question": "What is LangChain?", "answer": answer}) - - except Exception as e: - api_logger.error(f"测试LLM失败: {str(e)}") - raise - - -@router.get(f"/embedding/{{model_id}}", response_model=ApiResponse) -def test_embedding( - model_id: uuid.UUID, - db: Session = Depends(get_db) -): - config = ModelConfigService.get_model_by_id(db=db, model_id=model_id) - if not config: - api_logger.error(f"模型ID {model_id} 不存在") - raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="模型ID不存在") - - apiConfig: ModelApiKey = config.api_keys[0] - model = RedBearEmbeddings(RedBearModelConfig( - model_name=apiConfig.model_name, - provider=apiConfig.provider, - api_key=apiConfig.api_key, - base_url=apiConfig.api_base - )) - - data = [ - "最近哪家咖啡店评价最好?", - "附近有没有推荐的咖啡厅?", - "明天天气预报说会下雨。", - "北京是中国的首都。", - "我想找一个适合学习的地方。" - ] - embeddings = model.embed_documents(data) - print(embeddings) - query = "我想找一个适合学习的地方。" - query_embedding = model.embed_query(query) - print(query_embedding) - - return success(msg="测试LLM成功") - - -@router.get(f"/rerank/{{model_id}}", response_model=ApiResponse) -def test_rerank( - model_id: uuid.UUID, - db: Session = Depends(get_db) -): - config = ModelConfigService.get_model_by_id(db=db, model_id=model_id) - if not config: - api_logger.error(f"模型ID {model_id} 不存在") - raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="模型ID不存在") - - apiConfig: ModelApiKey = config.api_keys[0] - model = RedBearRerank(RedBearModelConfig( - model_name=apiConfig.model_name, - provider=apiConfig.provider, - api_key=apiConfig.api_key, - base_url=apiConfig.api_base - )) - query = "最近哪家咖啡店评价最好?" - data = [ - "最近哪家咖啡店评价最好?", - "附近有没有推荐的咖啡厅?", - "明天天气预报说会下雨。", - "北京是中国的首都。", - "我想找一个适合学习的地方。" - ] - scores = model.rerank(query=query, documents=data, top_n=3) - print(scores) - return success(msg="测试Rerank成功", data={"query": query, "documents": data, "scores": scores}) diff --git a/app/controllers/upload_controller.py b/app/controllers/upload_controller.py deleted file mode 100644 index 28e2b950..00000000 --- a/app/controllers/upload_controller.py +++ /dev/null @@ -1,376 +0,0 @@ -""" -Upload Controller for Generic File Upload System -Handles HTTP requests for file upload, download, deletion, and metadata updates. -""" -import os -import json -from typing import List, Optional, Any -from pathlib import Path -from fastapi import APIRouter, Depends, File, UploadFile, Form -from fastapi.responses import FileResponse -from sqlalchemy.orm import Session - -from app.db import get_db -from app.dependencies import get_current_user -from app.models.user_model import User -from app.schemas.response_schema import ApiResponse -from app.schemas.generic_file_schema import ( - GenericFileResponse, - FileMetadataUpdate, - UploadResultSchema, - BatchUploadResponse -) -from app.core.response_utils import success, fail -from app.core.upload_enums import UploadContext -from app.services.upload_service import UploadService -from app.core.logging_config import get_logger -from app.core.exceptions import ( - ValidationException, - ResourceNotFoundException, - FileUploadException, - BusinessException -) - -# Get logger -logger = get_logger(__name__) - -# Create router -router = APIRouter( - prefix="/api", - tags=["upload"], - dependencies=[Depends(get_current_user)] -) - -# Initialize upload service -upload_service = UploadService() - - -@router.post("/upload", response_model=ApiResponse) -async def upload_file( - file: UploadFile = File(..., description="要上传的文件"), - context: str = Form(..., description="上传上下文 (avatar, app_icon, knowledge_base, temp, attachment)"), - metadata: Optional[str] = Form(None, description="文件元数据 (JSON格式)"), - db: Session = Depends(get_db), - current_user: User = Depends(get_current_user) -) -> ApiResponse: - """ - 单文件上传接口 - - - **file**: 要上传的文件 - - **context**: 上传上下文,决定文件存储位置和验证规则 - - **metadata**: 可选的文件元数据,JSON格式字符串 - - 返回上传成功的文件信息 - """ - logger.info(f"Upload request: filename={file.filename}, context={context}, user={current_user.id}") - - try: - # Validate and parse context - try: - upload_context = UploadContext(context) - except ValueError: - logger.warning(f"Invalid upload context: {context}") - raise ValidationException( - f"无效的上传上下文: {context}. 允许的值: {', '.join([c.value for c in UploadContext])}", - field="context" - ) - - # Parse metadata if provided - file_metadata = {} - if metadata: - try: - file_metadata = json.loads(metadata) - except json.JSONDecodeError: - logger.warning(f"Invalid metadata JSON: {metadata}") - raise ValidationException( - "元数据必须是有效的JSON格式", - field="metadata" - ) - - # Upload file - db_file = upload_service.upload_file( - file=file, - context=upload_context, - metadata=file_metadata, - current_user=current_user, - db=db - ) - - # Convert to response schema - file_response = GenericFileResponse.model_validate(db_file) - - logger.info(f"Upload successful: {file.filename} (ID: {db_file.id})") - return success(data=file_response.dict(), msg="文件上传成功") - - except BusinessException: - # Business exceptions are handled by global exception handlers - raise - except Exception as e: - logger.error(f"Upload failed: {str(e)}") - # Wrap unknown exceptions as FileUploadException - raise FileUploadException( - f"文件上传失败: {str(e)}", - cause=e - ) - - -@router.post("/upload/batch", response_model=ApiResponse) -async def upload_files_batch( - files: List[UploadFile] = File(..., description="要上传的文件列表"), - context: str = Form(..., description="上传上下文 (avatar, app_icon, knowledge_base, temp, attachment)"), - metadata: Optional[str] = Form(None, description="文件元数据 (JSON格式)"), - db: Session = Depends(get_db), - current_user: User = Depends(get_current_user) -) -> ApiResponse: - """ - 批量文件上传接口 - - - **files**: 要上传的文件列表(最多20个) - - **context**: 上传上下文,决定文件存储位置和验证规则 - - **metadata**: 可选的文件元数据,JSON格式字符串,应用于所有文件 - - 返回每个文件的上传结果 - """ - logger.info(f"Batch upload request: {len(files)} files, context={context}, user={current_user.id}") - - try: - # Validate and parse context - try: - upload_context = UploadContext(context) - except ValueError: - logger.warning(f"Invalid upload context: {context}") - raise ValidationException( - f"无效的上传上下文: {context}. 允许的值: {', '.join([c.value for c in UploadContext])}", - field="context" - ) - - # Parse metadata if provided - file_metadata = {} - if metadata: - try: - file_metadata = json.loads(metadata) - except json.JSONDecodeError: - logger.warning(f"Invalid metadata JSON: {metadata}") - raise ValidationException( - "元数据必须是有效的JSON格式", - field="metadata" - ) - - # Upload files in batch - upload_results = upload_service.upload_files_batch( - files=files, - context=upload_context, - metadata=file_metadata, - current_user=current_user, - db=db - ) - - # Convert results to response schemas - result_schemas = [] - for result in upload_results: - result_schema = UploadResultSchema( - success=result.success, - file_id=result.file_id, - file_name=result.file_name, - error=result.error, - file_info=None - ) - - # If upload was successful, get file info - if result.success and result.file_id: - try: - db_file = upload_service.get_file(result.file_id, current_user, db) - result_schema.file_info = GenericFileResponse.model_validate(db_file) - except Exception as e: - logger.warning(f"Failed to get file info for {result.file_id}: {str(e)}") - - result_schemas.append(result_schema) - - # Create batch response - batch_response = BatchUploadResponse( - total=len(files), - success_count=sum(1 for r in upload_results if r.success), - failed_count=sum(1 for r in upload_results if not r.success), - results=result_schemas - ) - - logger.info(f"Batch upload completed: {batch_response.success_count}/{batch_response.total} successful") - return success(data=batch_response.dict(), msg="批量上传完成") - - except BusinessException: - # Business exceptions are handled by global exception handlers - raise - except Exception as e: - logger.error(f"Batch upload failed: {str(e)}") - # Wrap unknown exceptions as FileUploadException - raise FileUploadException( - f"批量上传失败: {str(e)}", - cause=e - ) - - -@router.get("/files/{file_id}", response_model=Any) -async def download_file( - file_id: str, - db: Session = Depends(get_db), - current_user: User = Depends(get_current_user) -) -> Any: - """ - 文件下载接口 - - - **file_id**: 文件ID - - 返回文件内容供下载 - """ - logger.info(f"Download request: file_id={file_id}, user={current_user.id}") - - try: - # Parse file_id - import uuid - try: - file_uuid = uuid.UUID(file_id) - except ValueError: - logger.warning(f"Invalid file ID format: {file_id}") - raise ValidationException( - "无效的文件ID格式", - field="file_id" - ) - - # Get file from database - db_file = upload_service.get_file(file_uuid, current_user, db) - - # Check if physical file exists - storage_path = Path(db_file.storage_path) - if not storage_path.exists(): - logger.error(f"Physical file not found: {storage_path}") - raise ResourceNotFoundException( - "文件", - str(file_uuid), - context={"detail": "文件未找到(可能已被删除)"} - ) - - # Return file response - logger.info(f"Download successful: {db_file.file_name} (ID: {file_id})") - return FileResponse( - path=str(storage_path), - filename=db_file.file_name, - media_type=db_file.mime_type or "application/octet-stream" - ) - - except BusinessException: - # Business exceptions are handled by global exception handlers - raise - except Exception as e: - logger.error(f"Download failed: {str(e)}") - # Wrap unknown exceptions - raise FileUploadException( - f"文件下载失败: {str(e)}", - cause=e - ) - - -@router.delete("/files/{file_id}", response_model=ApiResponse) -async def delete_file( - file_id: str, - db: Session = Depends(get_db), - current_user: User = Depends(get_current_user) -) -> ApiResponse: - """ - 文件删除接口 - - - **file_id**: 文件ID - - 删除文件(包括物理文件和数据库记录) - """ - logger.info(f"Delete request: file_id={file_id}, user={current_user.id}") - - try: - # Parse file_id - import uuid - try: - file_uuid = uuid.UUID(file_id) - except ValueError: - logger.warning(f"Invalid file ID format: {file_id}") - raise ValidationException( - "无效的文件ID格式", - field="file_id" - ) - - # Delete file - upload_service.delete_file(file_uuid, current_user, db) - - logger.info(f"Delete successful: file_id={file_id}") - return success(msg="文件删除成功") - - except BusinessException: - # Business exceptions are handled by global exception handlers - raise - except Exception as e: - logger.error(f"Delete failed: {str(e)}") - # Wrap unknown exceptions - raise FileUploadException( - f"文件删除失败: {str(e)}", - cause=e - ) - - -@router.put("/files/{file_id}", response_model=ApiResponse) -async def update_file_metadata( - file_id: str, - update_data: FileMetadataUpdate, - db: Session = Depends(get_db), - current_user: User = Depends(get_current_user) -) -> ApiResponse: - """ - 文件元数据更新接口 - - - **file_id**: 文件ID - - **update_data**: 要更新的元数据 - - 更新文件的元数据(文件名、自定义元数据、公开状态) - """ - logger.info(f"Update metadata request: file_id={file_id}, user={current_user.id}") - - try: - # Parse file_id - import uuid - try: - file_uuid = uuid.UUID(file_id) - except ValueError: - logger.warning(f"Invalid file ID format: {file_id}") - raise ValidationException( - "无效的文件ID格式", - field="file_id" - ) - - # Convert update data to dict, excluding unset fields - update_dict = update_data.dict(exclude_unset=True) - - if not update_dict: - logger.warning(f"No fields to update for file: {file_id}") - raise ValidationException( - "没有提供要更新的字段", - field="update_data" - ) - - # Update file metadata - updated_file = upload_service.update_file_metadata( - file_uuid, update_dict, current_user, db - ) - - # Convert to response schema - file_response = GenericFileResponse.model_validate(updated_file) - - logger.info(f"Update metadata successful: file_id={file_id}") - return success(data=file_response.dict(), msg="文件元数据更新成功") - - except BusinessException: - # Business exceptions are handled by global exception handlers - raise - except Exception as e: - logger.error(f"Update metadata failed: {str(e)}") - # Wrap unknown exceptions - raise FileUploadException( - f"文件元数据更新失败: {str(e)}", - cause=e - ) diff --git a/app/controllers/user_controller.py b/app/controllers/user_controller.py deleted file mode 100644 index b4d1c123..00000000 --- a/app/controllers/user_controller.py +++ /dev/null @@ -1,183 +0,0 @@ -from fastapi import APIRouter, Depends, status -from sqlalchemy.orm import Session -import uuid - -from app.db import get_db -from app.dependencies import get_current_user, get_current_superuser -from app.models.user_model import User -from app.schemas import user_schema -from app.schemas.user_schema import ChangePasswordRequest, AdminChangePasswordRequest -from app.schemas.response_schema import ApiResponse -from app.services import user_service -from app.core.logging_config import get_api_logger -from app.core.response_utils import success - -# 获取API专用日志器 -api_logger = get_api_logger() - -router = APIRouter( - prefix="/users", - tags=["Users"], -) - - -@router.post("/superuser", response_model=ApiResponse) -def create_superuser( - user: user_schema.UserCreate, - db: Session = Depends(get_db), - current_superuser: User = Depends(get_current_superuser) -): - """创建超级管理员(仅超级管理员可访问)""" - api_logger.info(f"超级管理员创建请求: {user.username}, email: {user.email}") - - result = user_service.create_superuser(db=db, user=user, current_user=current_superuser) - api_logger.info(f"超级管理员创建成功: {result.username} (ID: {result.id})") - - result_schema = user_schema.User.model_validate(result) - return success(data=result_schema, msg="超级管理员创建成功") - - -@router.delete("/{user_id}", response_model=ApiResponse) -def delete_user( - user_id: uuid.UUID, - db: Session = Depends(get_db), - current_user: User = Depends(get_current_user), -): - """停用用户(软删除)""" - api_logger.info(f"用户停用请求: user_id={user_id}, 操作者: {current_user.username}") - result = user_service.deactivate_user( - db=db, user_id_to_deactivate=user_id, current_user=current_user - ) - api_logger.info(f"用户停用成功: {result.username} (ID: {result.id})") - return success(msg="用户停用成功") - -@router.post("/{user_id}/activate", response_model=ApiResponse) -def activate_user( - user_id: uuid.UUID, - db: Session = Depends(get_db), - current_user: User = Depends(get_current_user), -): - """激活用户""" - api_logger.info(f"用户激活请求: user_id={user_id}, 操作者: {current_user.username}") - - result = user_service.activate_user( - db=db, user_id_to_activate=user_id, current_user=current_user - ) - api_logger.info(f"用户激活成功: {result.username} (ID: {result.id})") - - result_schema = user_schema.User.model_validate(result) - return success(data=result_schema, msg="用户激活成功") - - -@router.get("", response_model=ApiResponse) -def get_current_user_info( - db: Session = Depends(get_db), - current_user: User = Depends(get_current_user), -): - """获取当前用户信息""" - api_logger.info(f"当前用户信息请求: {current_user.username}") - - result = user_service.get_user( - db=db, user_id=current_user.id, current_user=current_user - ) - - result_schema = user_schema.User.model_validate(result) - - # 设置当前工作空间的角色和名称 - if current_user.current_workspace_id: - from app.repositories.workspace_repository import WorkspaceRepository - workspace_repo = WorkspaceRepository(db) - current_workspace = workspace_repo.get_workspace_by_id(current_user.current_workspace_id) - if current_workspace: - result_schema.current_workspace_name = current_workspace.name - - for ws in result.workspaces: - if ws.workspace_id == current_user.current_workspace_id: - result_schema.role = ws.role - break - - api_logger.info(f"当前用户信息获取成功: {result.username}, 角色: {result_schema.role}, 工作空间: {result_schema.current_workspace_name}") - return success(data=result_schema, msg="用户信息获取成功") - - -@router.get("/superusers", response_model=ApiResponse) -def get_tenant_superusers( - include_inactive: bool = False, - db: Session = Depends(get_db), - current_user: User = Depends(get_current_superuser), -): - """获取当前租户下的超管账号列表(仅超级管理员可访问)""" - api_logger.info(f"获取租户超管列表请求: {current_user.username}") - - superusers = user_service.get_tenant_superusers( - db=db, - current_user=current_user, - include_inactive=include_inactive - ) - api_logger.info(f"租户超管列表获取成功: count={len(superusers)}") - - superusers_schema = [user_schema.User.model_validate(u) for u in superusers] - return success(data=superusers_schema, msg="租户超管列表获取成功") - - -@router.get("/{user_id}", response_model=ApiResponse) -def get_user_info_by_id( - user_id: uuid.UUID, - db: Session = Depends(get_db), - current_user: User = Depends(get_current_user), -): - """根据用户ID获取用户信息""" - api_logger.info(f"获取用户信息请求: user_id={user_id}, 操作者: {current_user.username}") - - result = user_service.get_user( - db=db, user_id=user_id, current_user=current_user - ) - api_logger.info(f"用户信息获取成功: {result.username}") - - result_schema = user_schema.User.model_validate(result) - return success(data=result_schema, msg="用户信息获取成功") - - -@router.put("/change-password", response_model=ApiResponse) -async def change_password( - request: ChangePasswordRequest, - db: Session = Depends(get_db), - current_user: User = Depends(get_current_user), -): - """修改当前用户密码""" - api_logger.info(f"用户密码修改请求: {current_user.username}") - - await user_service.change_password( - db=db, - user_id=current_user.id, - old_password=request.old_password, - new_password=request.new_password, - current_user=current_user - ) - api_logger.info(f"用户密码修改成功: {current_user.username}") - return success(msg="密码修改成功") - - -@router.put("/admin/change-password", response_model=ApiResponse) -async def admin_change_password( - request: AdminChangePasswordRequest, - db: Session = Depends(get_db), - current_user: User = Depends(get_current_superuser), -): - """超级管理员修改指定用户的密码""" - api_logger.info(f"管理员密码修改请求: 管理员 {current_user.username} 修改用户 {request.user_id}") - - user, generated_password = await user_service.admin_change_password( - db=db, - target_user_id=request.user_id, - new_password=request.new_password, - current_user=current_user - ) - - # 根据是否生成了随机密码来构造响应 - if request.new_password: - api_logger.info(f"管理员密码修改成功: 用户 {request.user_id}") - return success(msg="密码修改成功") - else: - api_logger.info(f"管理员密码重置成功: 用户 {request.user_id}, 随机密码已生成") - return success(data=generated_password, msg="密码重置成功") \ No newline at end of file diff --git a/app/controllers/workspace_controller.py b/app/controllers/workspace_controller.py deleted file mode 100644 index fc9dca8f..00000000 --- a/app/controllers/workspace_controller.py +++ /dev/null @@ -1,342 +0,0 @@ -from fastapi import APIRouter, Depends, HTTPException, status, Query -from sqlalchemy.orm import Session -from typing import List, Optional -import uuid - -from app.core.response_utils import success -from app.db import get_db -from app.dependencies import get_current_superuser, get_current_user, get_current_tenant, workspace_access_guard, cur_workspace_access_guard -from app.models.user_model import User -from app.models.tenant_model import Tenants -from app.models.workspace_model import Workspace, InviteStatus -from app.schemas.response_schema import ApiResponse -from app.schemas.workspace_schema import ( - WorkspaceCreate, WorkspaceUpdate, WorkspaceResponse, - WorkspaceInviteCreate, WorkspaceInviteResponse, - InviteValidateResponse, InviteAcceptRequest, - WorkspaceMemberUpdate, WorkspaceMemberItem -) -from app.schemas import knowledge_schema -from app.services import workspace_service -from app.core.logging_config import get_api_logger -from app.services import knowledge_service, document_service -# 获取API专用日志器 -api_logger = get_api_logger() -# 需要认证的路由器 -router = APIRouter( - prefix="/workspaces", - tags=["Workspaces"], - dependencies=[Depends(get_current_user)] -) - -# 公开路由器(不需要认证) -public_router = APIRouter( - prefix="/workspaces", - tags=["Workspaces"] -) - - -def _convert_members_to_table_items(members): - """将工作空间成员列表转换为表格项""" - return [ - WorkspaceMemberItem( - id=m.id, - username=m.user.username, - account=m.user.email, - role=m.role, - last_login_at=m.user.last_login_at - ) - for m in members - ] - - -@router.get("", response_model=ApiResponse) -def get_workspaces( - include_current: bool = Query(True, description="是否包含当前工作空间"), - db: Session = Depends(get_db), - current_user: User = Depends(get_current_user), - current_tenant: Tenants = Depends(get_current_tenant) -): - """获取当前租户下用户参与的所有工作空间 - - Args: - include_current: 是否包含当前工作空间(默认 True) - """ - api_logger.info( - f"用户 {current_user.username} 在租户 {current_tenant.name} 中请求获取工作空间列表", - extra={"include_current": include_current} - ) - - workspaces = workspace_service.get_user_workspaces(db, current_user) - - # 如果不包含当前工作空间,则过滤掉 - if not include_current and current_user.current_workspace_id: - workspaces = [w for w in workspaces if w.id != current_user.current_workspace_id] - api_logger.debug( - f"过滤掉当前工作空间", - extra={"current_workspace_id": str(current_user.current_workspace_id)} - ) - - api_logger.info(f"成功获取 {len(workspaces)} 个工作空间") - workspaces_schema = [WorkspaceResponse.model_validate(w) for w in workspaces] - return success(data=workspaces_schema, msg="工作空间列表获取成功") - - -@router.post("", response_model=ApiResponse) -def create_workspace( - workspace: WorkspaceCreate, - db: Session = Depends(get_db), - current_user: User = Depends(get_current_superuser), -): - """创建新的工作空间""" - api_logger.info(f"用户 {current_user.username} 请求创建工作空间: {workspace.name}") - - result = workspace_service.create_workspace( - db=db, workspace=workspace, user=current_user) - - api_logger.info(f"工作空间创建成功 - 名称: {workspace.name}, ID: {result.id}, 创建者: {current_user.username}") - result_schema = WorkspaceResponse.model_validate(result) - return success(data=result_schema, msg="工作空间创建成功") - -@router.put("", response_model=ApiResponse) -@cur_workspace_access_guard() -def update_workspace( - workspace: WorkspaceUpdate, - db: Session = Depends(get_db), - current_user: User = Depends(get_current_user), -): - """更新工作空间""" - workspace_id = current_user.current_workspace_id - api_logger.info(f"用户 {current_user.username} 请求更新工作空间 ID: {workspace_id}") - - result = workspace_service.update_workspace( - db=db, - workspace_id=workspace_id, - workspace_in=workspace, - user=current_user, - ) - api_logger.info(f"工作空间更新成功 - ID: {workspace_id}, 用户: {current_user.username}") - result_schema = WorkspaceResponse.model_validate(result) - return success(data=result_schema, msg="工作空间更新成功") - -@router.get("/members", response_model=ApiResponse) -@cur_workspace_access_guard() -def get_cur_workspace_members( - db: Session = Depends(get_db), - current_user: User = Depends(get_current_user), -): - """获取工作空间成员列表(关系序列化)""" - api_logger.info(f"用户 {current_user.username} 请求获取工作空间 {current_user.current_workspace_id} 的成员列表") - - members = workspace_service.get_workspace_members( - db=db, - workspace_id=current_user.current_workspace_id, - user=current_user, - ) - api_logger.info(f"工作空间成员列表获取成功 - ID: {current_user.current_workspace_id}, 数量: {len(members)}") - table_items = _convert_members_to_table_items(members) - return success(data=table_items, msg="工作空间成员列表获取成功") - - -@router.put("/members", response_model=ApiResponse) -@cur_workspace_access_guard() -def update_workspace_members( - - updates: List[WorkspaceMemberUpdate], - db: Session = Depends(get_db), - current_user: User = Depends(get_current_user), -): - workspace_id = current_user.current_workspace_id - api_logger.info(f"用户 {current_user.username} 请求更新工作空间 {workspace_id} 的成员角色") - members = workspace_service.update_workspace_member_roles( - db=db, - workspace_id=workspace_id, - updates=updates, - user=current_user, - ) - api_logger.info(f"工作空间成员角色更新成功 - ID: {workspace_id}, 数量: {len(members)}") - return success(msg="成员角色更新成功") - - -@router.delete("/members/{member_id}", response_model=ApiResponse) -@cur_workspace_access_guard() -def delete_workspace_member( - member_id: uuid.UUID, - db: Session = Depends(get_db), - current_user: User = Depends(get_current_user), -): - workspace_id = current_user.current_workspace_id - api_logger.info(f"用户 {current_user.username} 请求删除工作空间 {workspace_id} 的成员 {member_id}") - - workspace_service.delete_workspace_member( - db=db, - workspace_id=workspace_id, - member_id=member_id, - user=current_user, - ) - api_logger.info(f"工作空间成员删除成功 - ID: {workspace_id}, 成员: {member_id}") - return success(msg="成员删除成功") - - -# 创建空间协作邀请 -@router.post("/invites", response_model=ApiResponse) -@cur_workspace_access_guard() -def create_workspace_invite( - invite_data: WorkspaceInviteCreate, - db: Session = Depends(get_db), - current_user: User = Depends(get_current_user), -): - """创建工作空间邀请""" - workspace_id = current_user.current_workspace_id - api_logger.info(f"用户 {current_user.username} 请求为工作空间 {workspace_id} 创建邀请: {invite_data.email}") - - result = workspace_service.create_workspace_invite( - db=db, - workspace_id=workspace_id, - invite_data=invite_data, - user=current_user - ) - api_logger.info(f"工作空间邀请创建成功 - 工作空间: {workspace_id}, 邮箱: {invite_data.email}") - return success(data=result, msg="邀请创建成功") - - -@router.get("/invites", response_model=ApiResponse) -@cur_workspace_access_guard() -def get_workspace_invites( - - status_filter: Optional[InviteStatus] = Query(None, alias="status"), - limit: int = Query(50, ge=1, le=100), - offset: int = Query(0, ge=0), - db: Session = Depends(get_db), - current_user: User = Depends(get_current_user), -): - """获取工作空间邀请列表""" - workspace_id = current_user.current_workspace_id - api_logger.info(f"用户 {current_user.username} 请求获取工作空间 {workspace_id} 的邀请列表") - - invites = workspace_service.get_workspace_invites( - db=db, - workspace_id=workspace_id, - user=current_user, - status=status_filter, - limit=limit, - offset=offset - ) - api_logger.info(f"成功获取 {len(invites)} 个邀请记录") - return success(data=invites, msg="邀请列表获取成功") - - -@public_router.get("/invites/validate/{token}", response_model=ApiResponse) -def get_workspace_invite_info( - token: str, - db: Session = Depends(get_db), -): - """获取工作空间邀请用户信息(无需认证)""" - result = workspace_service.validate_invite_token(db=db, token=token) - api_logger.info(f"工作空间邀请验证成功 - 邀请: {token}") - return success(data=result, msg="邀请验证成功") - - -@router.delete("/invites/{invite_id}", response_model=ApiResponse) -@cur_workspace_access_guard() -def revoke_workspace_invite( - - invite_id: uuid.UUID, - db: Session = Depends(get_db), - current_user: User = Depends(get_current_user), -): - """撤销工作空间邀请""" - workspace_id = current_user.current_workspace_id - api_logger.info(f"用户 {current_user.username} 请求撤销工作空间 {workspace_id} 的邀请 {invite_id}") - - result = workspace_service.revoke_workspace_invite( - db=db, - workspace_id=workspace_id, - invite_id=invite_id, - user=current_user - ) - api_logger.info(f"工作空间邀请撤销成功 - 邀请: {invite_id}") - return success(data=result, msg="邀请撤销成功") - -# ==================== 公开邀请接口(无需认证) ==================== - -# # 创建一个新的路由器用于公开接口 -# public_router = APIRouter( -# prefix="/invites", -# tags=["Public Invites"] -# ) - -# @public_router.get("/validate", response_model=ApiResponse) -# def validate_invite_token( -# token: str = Query(..., description="邀请令牌"), -# db: Session = Depends(get_db), -# ): -# """验证邀请令牌(公开接口)""" -# api_logger.info(f"验证邀请令牌请求") -@router.put("/{workspace_id}/switch", response_model=ApiResponse) -@workspace_access_guard() -def switch_workspace( - workspace_id: uuid.UUID, - db: Session = Depends(get_db), - current_user: User = Depends(get_current_user), -): - """切换工作空间""" - api_logger.info(f"用户 {current_user.username} 请求切换工作空间为 {workspace_id}") - - workspace_service.switch_workspace( - db=db, - workspace_id=workspace_id, - user=current_user, - ) - api_logger.info(f"成功切换工作空间为 {workspace_id}") - return success(msg="工作空间切换成功") - - -@router.get("/storage", response_model=ApiResponse) -@cur_workspace_access_guard() -def get_workspace_storage_type( - db: Session = Depends(get_db), - current_user: User = Depends(get_current_user), -): - """获取当前工作空间的存储类型""" - workspace_id = current_user.current_workspace_id - api_logger.info(f"用户 {current_user.username} 请求获取工作空间 {workspace_id} 的存储类型") - - storage_type = workspace_service.get_workspace_storage_type( - db=db, - workspace_id=workspace_id, - user=current_user - ) - api_logger.info(f"成功获取工作空间 {workspace_id} 的存储类型: {storage_type}") - return success(data={"storage_type": storage_type}, msg="存储类型获取成功") - - -@router.get("/workspace_models", response_model=ApiResponse) -@cur_workspace_access_guard() -def workspace_models_configs( - db: Session = Depends(get_db), - current_user: User = Depends(get_current_user), -): - """获取当前工作空间的模型配置(llm, embedding, rerank)""" - workspace_id = current_user.current_workspace_id - api_logger.info(f"用户 {current_user.username} 请求获取工作空间 {workspace_id} 的模型配置") - - configs = workspace_service.get_workspace_models_configs( - db=db, - workspace_id=workspace_id, - user=current_user - ) - - if configs is None: - api_logger.warning(f"工作空间 {workspace_id} 不存在或无权访问") - raise HTTPException( - status_code=status.HTTP_404_NOT_FOUND, - detail="工作空间不存在或无权访问" - ) - - api_logger.info( - f"成功获取工作空间 {workspace_id} 的模型配置: " - f"llm={configs.get('llm')}, embedding={configs.get('embedding')}, rerank={configs.get('rerank')}" - ) - return success(data=configs, msg="模型配置获取成功") - diff --git a/app/core/agent/__init__.py b/app/core/agent/__init__.py deleted file mode 100644 index e69de29b..00000000 diff --git a/app/core/agent/agent_api_text.py b/app/core/agent/agent_api_text.py deleted file mode 100644 index 74b9e46e..00000000 --- a/app/core/agent/agent_api_text.py +++ /dev/null @@ -1,35 +0,0 @@ -from pydantic import BaseModel - -from app.core.agent.agent_chat import Agent_chat -from app.core.logging_config import get_business_logger -from fastapi import APIRouter, Depends, HTTPException - -from app.dependencies import workspace_access_guard -from app.services.agent_server import config,ChatRequest -router = APIRouter(prefix="/Test", tags=["Apps"]) -logger = get_business_logger() -class CombinedRequest(BaseModel): - config_base: config - agent_config: ChatRequest - -@router.post("", summary="uuid") -async def agent_chat( - config_base: CombinedRequest -): - chat_config=config_base.agent_config - chat_base=config_base.config_base - request = ChatRequest( - end_user_id=chat_config.end_user_id, - message=chat_config.message, - search_switch=chat_config.search_switch, - kb_ids=chat_config.kb_ids, - similarity_threshold=chat_config.similarity_threshold, - vector_similarity_weight=chat_config.vector_similarity_weight, - top_k=chat_config.top_k, - hybrid=chat_config.hybrid, - token=chat_config.token - ) - - chat_result=await Agent_chat(chat_base).chat(request) - - return chat_result diff --git a/app/core/agent/agent_chat.py b/app/core/agent/agent_chat.py deleted file mode 100644 index 23a821c0..00000000 --- a/app/core/agent/agent_chat.py +++ /dev/null @@ -1,109 +0,0 @@ -import asyncio -import os -import time - -from typing import Dict, Any, List - -from app.core.logging_config import get_business_logger -from app.schemas.prompt_schema import render_prompt_message, PromptMessageRole -from app.services.api_resquests_server import messages_type, write_messages -from app.services.agent_server import ChatRequest, tool_memory, create_dynamic_agent, tool_Retrieval - -logger = get_business_logger() -class Agent_chat: - def __init__(self,config_data: dict): - self.prompt_message = render_prompt_message( - config_data.template_str, - PromptMessageRole.USER, - config_data.params - ) - self.prompt = self.prompt_message.get_text_content() - self.model_configs = config_data.model_configs - self.history_memory = config_data.history_memory - self.knowledge_base = config_data.knowledge_base - logger.info(f"渲染结果:{self.prompt_message.get_text_content()}" ) - - async def run_agent(self,agent, end_user_id:str, user_prompt:str, model_name:str): - response = agent.invoke( - { - "messages": [ - { - "role": "user", - "content": user_prompt - } - ] - }, - {"configurable": {"thread_id": f'{model_name}_{end_user_id}'}}, - ) - outputs = [] - for msg in response["messages"]: - if hasattr(msg, "tool_calls") and msg.tool_calls: - outputs.append({ - "role": "assistant", - "tool_calls": [ - {"name": t["name"], "arguments": t["args"]} - for t in msg.tool_calls - ] - }) - elif hasattr(msg, "content") and msg.content: - outputs.append({ - "role": msg.__class__.__name__.lower().replace("message", ""), - "content": msg.content - }) - ai_messages=[msg['content'] for msg in outputs if msg["role"] == "ai"] - return {"model_name": model_name, "end_user_id": end_user_id, "response": ai_messages} - - async def chat(self,req: ChatRequest) -> Dict[str, Any]: - - end_user_id = req.end_user_id # 用 user_id 作为对话线程标识 - start=time.time() - user_prompt = req.message - - '''判断是都写入redis数据库''' - messags_type = await messages_type(req.message,end_user_id) - messags_type=messags_type['data'] - if messags_type=='question': - writer_result=await write_messages(f'{end_user_id}', req.message) - logger.info(f'判断类型写入耗时:{time.time() - start},{writer_result}') - - - - '''history_memory''' - - if self.history_memory==True: - tool_result =await tool_memory(req) - if tool_result!='' :tool_result=tool_result['data'] - if tool_result!='' :self.prompt=self.prompt+f''',历史消息:{tool_result},结合历史消息''' - logger.info(f"记忆科学消耗时间:{time.time()-start},工具调用结果:{tool_result}") - - '''baidu''' - - - '''knowledge_base''' - if self.knowledge_base == True: - retrieval_result=await tool_Retrieval(req) - retrieval_knowledge = [i['page_content'] for i in retrieval_result['data']] - retrieval_knowledge=','.join(retrieval_knowledge) - logger.info(f"检索消耗时间:{time.time()-start},{retrieval_knowledge}") - if retrieval_knowledge!='' :self.prompt=self.prompt+f",知识库检索内容:{retrieval_knowledge},结合检索结果" - self.prompt=self.prompt+f'给出最合适的答案,确保答案的完整性,只保留用户的问题的回答,不额外输出提示语' - logger.info(f"用户输入:{user_prompt}") - logger.info(f"系统prompt:{self.prompt}") - - AGENTS = { - cfg["name"]: await create_dynamic_agent(cfg["name"], cfg["moder_id"], self.prompt, req.token) - for cfg in self.model_configs - } - tasks=[ - self.run_agent(agent, end_user_id, user_prompt, model_name) - for model_name, agent in AGENTS.items() - ] - # 并行运行 - results = await asyncio.gather(*tasks) - - result=[] - - for i in results: - result.append(i) - chat_result=(f"最终耗时:{time.time()-start},{result}") - return chat_result \ No newline at end of file diff --git a/app/core/agent/langchain_agent.py b/app/core/agent/langchain_agent.py deleted file mode 100644 index a32c8c96..00000000 --- a/app/core/agent/langchain_agent.py +++ /dev/null @@ -1,347 +0,0 @@ -""" -LangChain Agent 封装 - -使用 LangChain 1.x 标准方式 -- 使用 create_agent 创建 agent graph -- 支持工具调用循环 -- 支持流式输出 -- 使用 RedBearLLM 支持多提供商 -""" -import os -import time -import asyncio -from typing import Dict, Any, List, Optional, AsyncGenerator, Sequence -from langchain_core.messages import SystemMessage, HumanMessage, AIMessage, BaseMessage -from langchain_core.tools import BaseTool -from langchain.agents import create_agent - -from app.core.models import RedBearLLM, RedBearModelConfig -from app.models.models_model import ModelType -from app.core.logging_config import get_business_logger -from app.services.memory_agent_service import MemoryAgentService -from app.services.memory_konwledges_server import write_rag -from app.services.task_service import get_task_memory_write_result -from app.tasks import write_message_task -logger = get_business_logger() - - -class LangChainAgent: - - def __init__( - self, - model_name: str, - api_key: str, - provider: str = "openai", - api_base: Optional[str] = None, - temperature: float = 0.7, - max_tokens: int = 2000, - system_prompt: Optional[str] = None, - tools: Optional[Sequence[BaseTool]] = None, - streaming: bool = False - ): - """初始化 LangChain Agent - - Args: - model_name: 模型名称 - api_key: API Key - provider: 提供商(openai, xinference, gpustack, ollama, dashscope) - api_base: API 基础 URL - temperature: 温度参数 - max_tokens: 最大 token 数 - system_prompt: 系统提示词 - tools: 工具列表(可选,框架自动走 ReAct 循环) - streaming: 是否启用流式输出(默认 True) - """ - self.model_name = model_name - self.provider = provider - self.system_prompt = system_prompt or "你是一个专业的AI助手" - self.tools = tools or [] - self.streaming = streaming - - # 创建 RedBearLLM(支持多提供商) - model_config = RedBearModelConfig( - model_name=model_name, - provider=provider, - api_key=api_key, - base_url=api_base, - extra_params={ - "temperature": temperature, - "max_tokens": max_tokens, - "streaming": streaming # 使用参数控制流式 - } - ) - - self.llm = RedBearLLM(model_config, type=ModelType.CHAT) - - # 获取底层模型用于真正的流式调用 - self._underlying_llm = self.llm._model if hasattr(self.llm, '_model') else self.llm - - # 确保底层模型也启用流式 - if streaming and hasattr(self._underlying_llm, 'streaming'): - self._underlying_llm.streaming = True - - # 使用 create_agent 创建 agent graph(LangChain 1.x 标准方式) - # 无论是否有工具,都使用 agent 统一处理 - self.agent = create_agent( - model=self.llm, - tools=self.tools if self.tools else None, - system_prompt=self.system_prompt - ) - - logger.info( - f"LangChain Agent 初始化完成", - extra={ - "model": model_name, - "provider": provider, - "has_api_base": bool(api_base), - "temperature": temperature, - "streaming": streaming, - "tool_count": len(self.tools), - "tool_names": [tool.name for tool in self.tools] if self.tools else [], - "tool_count": len(self.tools) - } - ) - - def _prepare_messages( - self, - message: str, - history: Optional[List[Dict[str, str]]] = None, - context: Optional[str] = None - ) -> List[BaseMessage]: - """准备消息列表 - - Args: - message: 用户消息 - history: 历史消息列表 - context: 上下文信息 - - Returns: - List[BaseMessage]: 消息列表 - """ - messages = [] - - # 添加系统提示词 - messages.append(SystemMessage(content=self.system_prompt)) - - # 添加历史消息 - if history: - for msg in history: - if msg["role"] == "user": - messages.append(HumanMessage(content=msg["content"])) - elif msg["role"] == "assistant": - messages.append(AIMessage(content=msg["content"])) - - # 添加当前用户消息 - user_content = message - if context: - user_content = f"参考信息:\n{context}\n\n用户问题:\n{user_content}" - - messages.append(HumanMessage(content=user_content)) - - return messages - - async def chat( - self, - message: str, - history: Optional[List[Dict[str, str]]] = None, - context: Optional[str] = None, - end_user_id: Optional[str] = None, - config_id: Optional[str] = None, # 添加这个参数 - storage_type: Optional[str] = None, - user_rag_memory_id: Optional[str] = None, - ) -> Dict[str, Any]: - """执行对话 - - Args: - message: 用户消息 - history: 历史消息列表 [{"role": "user/assistant", "content": "..."}] - context: 上下文信息(如知识库检索结果) - - Returns: - Dict: 包含 content 和元数据的字典 - """ - start_time = time.time() - - logger.info(f'写入类型{storage_type,str(end_user_id), message, str(user_rag_memory_id)}') - print(f'写入类型{storage_type,str(end_user_id), message, str(user_rag_memory_id)}') - if storage_type == "rag": - await write_rag(end_user_id, message, user_rag_memory_id) - logger.info(f'RAG_Agent:{end_user_id};{user_rag_memory_id}') - else: - if config_id==None: - actual_config_id = os.getenv("config_id") - else:actual_config_id=config_id - actual_end_user_id = end_user_id if end_user_id is not None else "unknown" - write_id = write_message_task.delay(actual_end_user_id, message, actual_config_id,storage_type,user_rag_memory_id) - write_status = get_task_memory_write_result(str(write_id)) - logger.info(f'Agent:{actual_end_user_id};{write_status}') - - - try: - # 准备消息列表 - messages = self._prepare_messages(message, history, context) - - logger.debug( - f"准备调用 LangChain Agent", - extra={ - "has_context": bool(context), - "has_history": bool(history), - "has_tools": bool(self.tools), - "message_count": len(messages) - } - ) - - # 统一使用 agent.invoke 调用 - result = await self.agent.ainvoke({"messages": messages}) - - # 获取最后的 AI 消息 - output_messages = result.get("messages", []) - content = "" - for msg in reversed(output_messages): - if isinstance(msg, AIMessage): - content = msg.content - break - - elapsed_time = time.time() - start_time - - if storage_type == "rag": - await write_rag(end_user_id, message, user_rag_memory_id) - logger.info(f'RAG_Agent:{end_user_id};{user_rag_memory_id}') - else: - write_id = write_message_task.delay(actual_end_user_id, content, actual_config_id, storage_type, user_rag_memory_id) - write_status = get_task_memory_write_result(str(write_id)) - logger.info(f'Agent:{actual_end_user_id};{write_status}') - - response = { - "content": content, - "model": self.model_name, - "elapsed_time": elapsed_time, - "usage": { - "prompt_tokens": 0, - "completion_tokens": 0, - "total_tokens": 0 - } - } - - logger.debug( - f"Agent 调用完成", - extra={ - "elapsed_time": elapsed_time, - "content_length": len(response["content"]) - } - ) - - return response - - except Exception as e: - logger.error(f"Agent 调用失败", extra={"error": str(e)}) - raise - - async def chat_stream( - self, - message: str, - history: Optional[List[Dict[str, str]]] = None, - context: Optional[str] = None, - end_user_id:Optional[str] = None, - config_id: Optional[str] = None, - storage_type:Optional[str] = None, - user_rag_memory_id:Optional[str] = None, - - ) -> AsyncGenerator[str, None]: - """执行流式对话 - - Args: - message: 用户消息 - history: 历史消息列表 - context: 上下文信息 - - Yields: - str: 消息内容块 - """ - logger.info("=" * 80) - logger.info(f" chat_stream 方法开始执行") - logger.info(f" Message: {message[:100]}") - logger.info(f" Has tools: {bool(self.tools)}") - logger.info(f" Tool count: {len(self.tools) if self.tools else 0}") - logger.info("=" * 80) - - start_time = time.time() - if storage_type == "rag": - await write_rag(end_user_id, message, user_rag_memory_id) - else: - if config_id==None: - actual_config_id = os.getenv("config_id") - else:actual_config_id=config_id - actual_end_user_id = end_user_id if end_user_id is not None else "unknown" - write_id = write_message_task.delay(actual_end_user_id, message, actual_config_id,storage_type,user_rag_memory_id) - - try: - write_status = get_task_memory_write_result(str(write_id)) - logger.info(f'Agent:{actual_end_user_id};{write_status}') - except Exception as e: - logger.error(f"Agent 记忆用户输入出错", extra={"error": str(e)}) - - try: - # 准备消息列表 - messages = self._prepare_messages(message, history, context) - - logger.debug( - f"准备流式调用,has_tools={bool(self.tools)}, message_count={len(messages)}" - ) - - chunk_count = 0 - yielded_content = False - - # 统一使用 agent 的 astream_events 实现流式输出 - logger.debug("使用 Agent astream_events 实现流式输出") - - try: - async for event in self.agent.astream_events( - {"messages": messages}, - version="v2" - ): - chunk_count += 1 - kind = event.get("event") - - # 处理所有可能的流式事件 - if kind == "on_chat_model_stream": - # LLM 流式输出 - chunk = event.get("data", {}).get("chunk") - if chunk and hasattr(chunk, "content") and chunk.content: - yield chunk.content - yielded_content = True - - elif kind == "on_llm_stream": - # 另一种 LLM 流式事件 - chunk = event.get("data", {}).get("chunk") - if chunk: - if hasattr(chunk, "content") and chunk.content: - yield chunk.content - yielded_content = True - elif isinstance(chunk, str): - yield chunk - yielded_content = True - - # 记录工具调用(可选) - elif kind == "on_tool_start": - logger.debug(f"工具调用开始: {event.get('name')}") - elif kind == "on_tool_end": - logger.debug(f"工具调用结束: {event.get('name')}") - - logger.debug(f"Agent 流式完成,共 {chunk_count} 个事件") - - except Exception as e: - logger.error(f"Agent astream_events 失败: {str(e)}", exc_info=True) - raise - - except Exception as e: - logger.error("=" * 80) - logger.error(f"chat_stream 异常: {str(e)}") - logger.error("=" * 80, exc_info=True) - raise - finally: - logger.info("=" * 80) - logger.info(f"chat_stream 方法执行结束") - logger.info("=" * 80) - - diff --git a/app/core/api_key_utils.py b/app/core/api_key_utils.py deleted file mode 100644 index e8dc5d98..00000000 --- a/app/core/api_key_utils.py +++ /dev/null @@ -1,56 +0,0 @@ -"""API Key 工具函数""" -import secrets -import hashlib -from app.models.api_key_model import ApiKeyType - - -def generate_api_key(key_type: ApiKeyType) -> tuple[str, str, str]: - """生成 API Key - - Args: - key_type: API Key 类型 - - Returns: - tuple: (api_key, key_hash, key_prefix) - """ - # 前缀映射 - prefix_map = { - ApiKeyType.APP: "sk-app-", - ApiKeyType.RAG: "sk-rag-", - ApiKeyType.MEMORY: "sk-mem-", - ApiKeyType.GENERAL: "sk-gen-", - } - - prefix = prefix_map[key_type] - random_string = secrets.token_urlsafe(32)[:32] # 32 字符 - api_key = f"{prefix}{random_string}" - - # 生成哈希值存储 - key_hash = hash_api_key(api_key) - - return api_key, key_hash, prefix - - -def hash_api_key(api_key: str) -> str: - """对 API Key 进行哈希 - - Args: - api_key: API Key 明文 - - Returns: - str: 哈希值 - """ - return hashlib.sha256(api_key.encode()).hexdigest() - - -def verify_api_key(api_key: str, key_hash: str) -> bool: - """验证 API Key - - Args: - api_key: API Key 明文 - key_hash: 存储的哈希值 - - Returns: - bool: 是否匹配 - """ - return hash_api_key(api_key) == key_hash diff --git a/app/core/compensation.py b/app/core/compensation.py deleted file mode 100644 index 916f10ba..00000000 --- a/app/core/compensation.py +++ /dev/null @@ -1,47 +0,0 @@ -""" -Compensation Transaction Handler -Handles operations that cannot be rolled back (like file system operations). -""" -from typing import List, Callable -from app.core.logging_config import get_logger - -logger = get_logger(__name__) - - -class CompensationHandler: - """补偿事务处理器,用于处理无法回滚的操作""" - - def __init__(self): - self._compensations: List[Callable] = [] - - def register(self, compensation: Callable): - """ - 注册补偿操作 - - Args: - compensation: 补偿操作的可调用对象 - """ - self._compensations.append(compensation) - logger.debug(f"Registered compensation operation: {compensation.__name__ if hasattr(compensation, '__name__') else 'lambda'}") - - def execute(self): - """执行所有补偿操作(按注册的逆序执行)""" - if not self._compensations: - logger.debug("No compensation operations to execute") - return - - logger.info(f"Executing {len(self._compensations)} compensation operations") - - for compensation in reversed(self._compensations): - try: - compensation() - logger.debug(f"Compensation operation executed successfully") - except Exception as e: - logger.error(f"补偿操作失败: {e}", exc_info=True) - - def clear(self): - """清空补偿操作""" - count = len(self._compensations) - self._compensations.clear() - if count > 0: - logger.debug(f"Cleared {count} compensation operations") diff --git a/app/core/config.py b/app/core/config.py deleted file mode 100644 index 5fe205f3..00000000 --- a/app/core/config.py +++ /dev/null @@ -1,237 +0,0 @@ -import os -import json -from pathlib import Path -from typing import Dict, Any, Optional -from dotenv import load_dotenv - -load_dotenv() - -class Settings: - ENABLE_SINGLE_WORKSPACE: bool = os.getenv("ENABLE_SINGLE_WORKSPACE", "true").lower() == "true" - # API Keys Configuration - OPENAI_API_KEY: str = os.getenv("OPENAI_API_KEY", "") - DASHSCOPE_API_KEY: str = os.getenv("DASHSCOPE_API_KEY", "") - - # Neo4j Configuration (记忆系统数据库) - NEO4J_URI: str = os.getenv("NEO4J_URI", "bolt://1.94.111.67:7687") - NEO4J_USERNAME: str = os.getenv("NEO4J_USERNAME", "neo4j") - NEO4J_PASSWORD: str = os.getenv("NEO4J_PASSWORD", "") - - # Database configuration (Postgres) - DB_HOST: str = os.getenv("DB_HOST", "127.0.0.1") - DB_PORT: int = int(os.getenv("DB_PORT", "5432")) - DB_USER: str = os.getenv("DB_USER", "postgres") - DB_PASSWORD: str = os.getenv("DB_PASSWORD", "password") - DB_NAME: str = os.getenv("DB_NAME", "redbear-mem") - - DB_AUTO_UPGRADE = os.getenv("DB_AUTO_UPGRADE", "false").lower() == "true" - - # Redis configuration - REDIS_HOST: str = os.getenv("REDIS_HOST", "127.0.0.1") - REDIS_PORT: int = int(os.getenv("REDIS_PORT", "6379")) - REDIS_DB: int = int(os.getenv("REDIS_DB", "1")) - REDIS_PASSWORD: str = os.getenv("REDIS_PASSWORD", "") - - # ElasticSearch configuration - ELASTICSEARCH_HOST: str = os.getenv("ELASTICSEARCH_HOST", "https://127.0.0.1") - ELASTICSEARCH_PORT: int = int(os.getenv("ELASTICSEARCH_PORT", "9200")) - ELASTICSEARCH_USERNAME: str = os.getenv("ELASTICSEARCH_USERNAME", "elastic") - ELASTICSEARCH_PASSWORD: str = os.getenv("ELASTICSEARCH_PASSWORD", "") - ELASTICSEARCH_VERIFY_CERTS: bool = os.getenv("ELASTICSEARCH_VERIFY_CERTS", "False").lower() == "true" - ELASTICSEARCH_CA_CERTS: str = os.getenv("ELASTICSEARCH_CA_CERTS", "") - ELASTICSEARCH_REQUEST_TIMEOUT: int = int(os.getenv("ELASTICSEARCH_REQUEST_TIMEOUT", "100000")) - ELASTICSEARCH_RETRY_ON_TIMEOUT: bool = os.getenv("ELASTICSEARCH_RETRY_ON_TIMEOUT", "True").lower() == "true" - ELASTICSEARCH_MAX_RETRIES: int = int(os.getenv("ELASTICSEARCH_MAX_RETRIES", "10")) - - # Xinference configuration - XINFERENCE_URL: str = os.getenv("XINFERENCE_URL", "http://127.0.0.1") - - # LangSmith configuration - LANGCHAIN_TRACING_V2: bool = os.getenv("LANGCHAIN_TRACING_V2", "false").lower() == "true" - LANGCHAIN_TRACING: bool = os.getenv("LANGCHAIN_TRACING", "false").lower() == "true" - LANGCHAIN_API_KEY: str = os.getenv("LANGCHAIN_API_KEY", "") - LANGCHAIN_ENDPOINT: str = os.getenv("LANGCHAIN_ENDPOINT", "") - - # LLM Request Configuration - LLM_TIMEOUT: float = float(os.getenv("LLM_TIMEOUT", "120.0")) - LLM_MAX_RETRIES: int = int(os.getenv("LLM_MAX_RETRIES", "2")) - - # JWT Token Configuration - SECRET_KEY: str = os.getenv("SECRET_KEY", "a_default_secret_key_that_is_long_and_random") - ALGORITHM: str = "HS256" - ACCESS_TOKEN_EXPIRE_MINUTES: int = int(os.getenv("ACCESS_TOKEN_EXPIRE_MINUTES", "30")) - REFRESH_TOKEN_EXPIRE_DAYS: int = int(os.getenv("REFRESH_TOKEN_EXPIRE_DAYS", "7")) - - # Single Sign-On configuration - ENABLE_SINGLE_SESSION: bool = os.getenv("ENABLE_SINGLE_SESSION", "false").lower() == "true" - - # File Upload - MAX_FILE_SIZE: int = int(os.getenv("MAX_FILE_SIZE", "52428800")) - FILE_PATH: str = os.getenv("FILE_PATH", "/files") - - # VOLC ASR settings - VOLC_APP_KEY: str = os.getenv("VOLC_APP_KEY", "") - VOLC_ACCESS_KEY: str = os.getenv("VOLC_ACCESS_KEY", "") - VOLC_SUBMIT_URL: str = os.getenv("VOLC_SUBMIT_URL", "https://openspeech.bytedance.com/api/v3/auc/bigmodel/submit") - VOLC_QUERY_URL: str = os.getenv("VOLC_QUERY_URL", "https://openspeech.bytedance.com/api/v3/auc/bigmodel/query") - - # Langfuse configuration - LANGFUSE_PUBLIC_KEY: str = os.getenv("LANGFUSE_PUBLIC_KEY", "") - LANGFUSE_SECRET_KEY: str = os.getenv("LANGFUSE_SECRET_KEY", "") - LANGFUSE_HOST: str = os.getenv("LANGFUSE_HOST", "") - - # Server Configuration - SERVER_IP: str = os.getenv("SERVER_IP", "127.0.0.1") - - # ======================================================================== - # Internal Configuration (not in .env, used by application code) - # ======================================================================== - - # Superuser settings (internal defaults) - FIRST_SUPERUSER_EMAIL: str = os.getenv("FIRST_SUPERUSER_EMAIL", "admin@example.com") - FIRST_SUPERUSER_USERNAME: str = os.getenv("FIRST_SUPERUSER_USERNAME", "admin") - FIRST_SUPERUSER_PASSWORD: str = os.getenv("FIRST_SUPERUSER_PASSWORD", "admin_password") - - # Generic File Upload (internal) - GENERIC_FILE_PATH: str = os.getenv("GENERIC_FILE_PATH", "/uploads") - ENABLE_FILE_COMPRESSION: bool = os.getenv("ENABLE_FILE_COMPRESSION", "false").lower() == "true" - ENABLE_VIRUS_SCAN: bool = os.getenv("ENABLE_VIRUS_SCAN", "false").lower() == "true" - FILE_ACCESS_URL_PREFIX: str = os.getenv("FILE_ACCESS_URL_PREFIX", "http://localhost:8000/api/files") - - # Frontend URL for workspace invitations (internal) - WEB_URL: str = os.getenv("WEB_URL", "http://localhost:3000") - - # CORS configuration (internal) - CORS_ORIGINS: list[str] = [ - origin.strip() - for origin in os.getenv("CORS_ORIGINS", "").split(",") - if origin.strip() - ] - - # Logging settings - LOG_LEVEL: str = os.getenv("LOG_LEVEL", "INFO") - LOG_FORMAT: str = os.getenv("LOG_FORMAT", "%(asctime)s - %(name)s - %(levelname)s - %(message)s") - LOG_FILE_PATH: str = os.getenv("LOG_FILE_PATH", "logs/app.log") - LOG_MAX_SIZE: int = int(os.getenv("LOG_MAX_SIZE", "10485760")) # 10MB - LOG_BACKUP_COUNT: int = int(os.getenv("LOG_BACKUP_COUNT", "5")) - LOG_TO_CONSOLE: bool = os.getenv("LOG_TO_CONSOLE", "true").lower() == "true" - LOG_TO_FILE: bool = os.getenv("LOG_TO_FILE", "true").lower() == "true" - - # Sensitive Data Filtering - ENABLE_SENSITIVE_DATA_FILTER: bool = os.getenv("ENABLE_SENSITIVE_DATA_FILTER", "true").lower() == "true" - - # Memory Module Logging - PROMPT_LOG_LEVEL: str = os.getenv("PROMPT_LOG_LEVEL", "INFO") - ENABLE_TEMPLATE_LOGGING: bool = os.getenv("ENABLE_TEMPLATE_LOGGING", "false").lower() == "true" - TIMING_LOG_FILE: str = os.getenv("TIMING_LOG_FILE", "logs/time.log") - TIMING_LOG_TO_CONSOLE: bool = os.getenv("TIMING_LOG_TO_CONSOLE", "true").lower() == "true" - AGENT_LOG_FILE: str = os.getenv("AGENT_LOG_FILE", "logs/agent_service.log") - AGENT_LOG_MAX_SIZE: int = int(os.getenv("AGENT_LOG_MAX_SIZE", "5242880")) # 5MB - AGENT_LOG_BACKUP_COUNT: int = int(os.getenv("AGENT_LOG_BACKUP_COUNT", "20")) - - # Log Streaming Configuration - LOG_STREAM_KEEPALIVE_INTERVAL: int = int(os.getenv("LOG_STREAM_KEEPALIVE_INTERVAL", "300")) # 5 minutes - LOG_STREAM_MAX_CONNECTIONS: int = int(os.getenv("LOG_STREAM_MAX_CONNECTIONS", "10")) - LOG_STREAM_BUFFER_SIZE: int = int(os.getenv("LOG_STREAM_BUFFER_SIZE", "8192")) # 8KB - LOG_FILE_MAX_SIZE_MB: int = int(os.getenv("LOG_FILE_MAX_SIZE_MB", "10")) # 10MB - - - # Celery configuration (internal) - CELERY_BROKER: int = int(os.getenv("CELERY_BROKER", "1")) - CELERY_BACKEND: int = int(os.getenv("CELERY_BACKEND", "2")) - REFLECTION_INTERVAL_SECONDS: float = float(os.getenv("REFLECTION_INTERVAL_SECONDS", "300")) - HEALTH_CHECK_SECONDS: float = float(os.getenv("HEALTH_CHECK_SECONDS", "600")) - MEMORY_INCREMENT_INTERVAL_HOURS: float = float(os.getenv("MEMORY_INCREMENT_INTERVAL_HOURS", "24")) - DEFAULT_WORKSPACE_ID: Optional[str] = os.getenv("DEFAULT_WORKSPACE_ID", None) - - # Memory Module Configuration (internal) - MEMORY_OUTPUT_DIR: str = os.getenv("MEMORY_OUTPUT_DIR", "logs/memory-output") - MEMORY_CONFIG_DIR: str = os.getenv("MEMORY_CONFIG_DIR", "app/core/memory") - MEMORY_CONFIG_FILE: str = os.getenv("MEMORY_CONFIG_FILE", "config.json") - MEMORY_RUNTIME_FILE: str = os.getenv("MEMORY_RUNTIME_FILE", "runtime.json") - MEMORY_DBRUN_FILE: str = os.getenv("MEMORY_DBRUN_FILE", "dbrun.json") - - def get_memory_output_path(self, filename: str = "") -> str: - """ - Get the full path for memory module output files. - - Args: - filename: Optional filename to append to the output directory - - Returns: - Full path to the output file or directory - """ - base_path = Path(self.MEMORY_OUTPUT_DIR) - if filename: - return str(base_path / filename) - return str(base_path) - - def get_memory_config_path(self, config_file: str = "") -> str: - """ - Get the full path for memory module configuration files. - - Args: - config_file: Optional config filename (defaults to MEMORY_CONFIG_FILE) - - Returns: - Full path to the config file - """ - if not config_file: - config_file = self.MEMORY_CONFIG_FILE - return str(Path(self.MEMORY_CONFIG_DIR) / config_file) - - def load_memory_config(self) -> Dict[str, Any]: - """ - Load memory module configuration from config.json. - - Returns: - Dictionary containing memory configuration - """ - config_path = self.get_memory_config_path(self.MEMORY_CONFIG_FILE) - try: - with open(config_path, "r", encoding="utf-8") as f: - return json.load(f) - except (FileNotFoundError, json.JSONDecodeError) as e: - print(f"Warning: Memory config file not found or malformed at {config_path}. Error: {e}") - return {} - - def load_memory_runtime_config(self) -> Dict[str, Any]: - """ - Load memory module runtime configuration from runtime.json. - - Returns: - Dictionary containing runtime configuration - """ - runtime_path = self.get_memory_config_path(self.MEMORY_RUNTIME_FILE) - try: - with open(runtime_path, "r", encoding="utf-8") as f: - return json.load(f) - except (FileNotFoundError, json.JSONDecodeError) as e: - print(f"Warning: Memory runtime config not found or malformed at {runtime_path}. Error: {e}") - return {"selections": {}} - - def load_memory_dbrun_config(self) -> Dict[str, Any]: - """ - Load memory module database run configuration from dbrun.json. - - Returns: - Dictionary containing dbrun configuration - """ - dbrun_path = self.get_memory_config_path(self.MEMORY_DBRUN_FILE) - try: - with open(dbrun_path, "r", encoding="utf-8") as f: - return json.load(f) - except (FileNotFoundError, json.JSONDecodeError) as e: - print(f"Warning: Memory dbrun config not found or malformed at {dbrun_path}. Error: {e}") - return {"selections": {}} - - def ensure_memory_output_dir(self) -> None: - """ - Ensure the memory output directory exists. - Creates the directory if it doesn't exist. - """ - output_dir = Path(self.MEMORY_OUTPUT_DIR) - output_dir.mkdir(parents=True, exist_ok=True) - - -settings = Settings() diff --git a/app/core/error_codes.py b/app/core/error_codes.py deleted file mode 100644 index b175c10c..00000000 --- a/app/core/error_codes.py +++ /dev/null @@ -1,130 +0,0 @@ -from enum import IntEnum - - -class BizCode(IntEnum): - # 通用(1xxx) - OK = 0 - BAD_REQUEST = 1000 - VALIDATION_FAILED = 1001 - MISSING_PARAMETER = 1002 - INVALID_PARAMETER = 1003 - # 认证/鉴权(2xxx/3xxx) - UNAUTHORIZED = 2001 - TOKEN_INVALID = 2002 - TOKEN_EXPIRED = 2003 - TOKEN_BLACKLISTED = 2004 - PASSWORD_ERROR = 2005 - LOGIN_FAILED = 2006 - FORBIDDEN = 3001 - TENANT_NOT_FOUND = 3002 - WORKSPACE_NO_ACCESS = 3003 - WORKSPACE_INVITE_NOT_FOUND = 3004 - # 资源(4xxx) - NOT_FOUND = 4000 - USER_NOT_FOUND = 4001 - WORKSPACE_NOT_FOUND = 4002 - MODEL_NOT_FOUND = 4003 - KNOWLEDGE_NOT_FOUND = 4004 - DOCUMENT_NOT_FOUND = 4005 - FILE_NOT_FOUND = 4006 - APP_NOT_FOUND = 4007 - RELEASE_NOT_FOUND = 4008 - - # 冲突/状态(5xxx) - DUPLICATE_NAME = 5001 - RESOURCE_ALREADY_EXISTS = 5002 - VERSION_ALREADY_EXISTS = 5003 - STATE_CONFLICT = 5004 - - # 应用发布(6xxx) - PUBLISH_FAILED = 6001 - NO_DRAFT_TO_PUBLISH = 6002 - ROLLBACK_TARGET_NOT_FOUND = 6003 - APP_TYPE_NOT_SUPPORTED = 6004 - AGENT_CONFIG_MISSING = 6005 - SHARE_DISABLED = 6006 - INVALID_PASSWORD = 6007 - PASSWORD_REQUIRED = 6008 - EMBED_NOT_ALLOWED = 6009 - PERMISSION_DENIED = 6010 - INVALID_CONVERSATION = 6011 - - # 模型(7xxx) - MODEL_CONFIG_INVALID = 7001 - API_KEY_MISSING = 7002 - PROVIDER_NOT_SUPPORTED = 7003 - LLM_ERROR = 7004 - EMBEDDING_ERROR = 7005 - - # 文件/解析(8xxx) - FILE_READ_ERROR = 8001 - PARSER_NOT_SUPPORTED = 8002 - CHUNKING_FAILED = 8003 - - # RAG/知识(9xxx) - INDEX_BUILD_FAILED = 9001 - EMBEDDING_FAILED = 9002 - SEARCH_FAILED = 9003 - - # 系统(100xx) - INTERNAL_ERROR = 10001 - DB_ERROR = 10002 - SERVICE_UNAVAILABLE = 10003 - RATE_LIMITED = 10004 - - -# 建议的HTTP状态映射(如需在异常处理器中使用) -HTTP_MAPPING = { - BizCode.OK: 200, - BizCode.LOGIN_FAILED: 200, - BizCode.BAD_REQUEST: 400, - BizCode.VALIDATION_FAILED: 400, - BizCode.MISSING_PARAMETER: 400, - BizCode.INVALID_PARAMETER: 400, - BizCode.UNAUTHORIZED: 401, - BizCode.TOKEN_INVALID: 401, - BizCode.TOKEN_EXPIRED: 401, - BizCode.TOKEN_BLACKLISTED: 401, - BizCode.FORBIDDEN: 403, - BizCode.TENANT_NOT_FOUND: 404, - BizCode.WORKSPACE_NO_ACCESS: 403, - BizCode.NOT_FOUND: 404, - BizCode.USER_NOT_FOUND: 200, - BizCode.WORKSPACE_NOT_FOUND: 404, - BizCode.MODEL_NOT_FOUND: 404, - BizCode.KNOWLEDGE_NOT_FOUND: 404, - BizCode.DOCUMENT_NOT_FOUND: 404, - BizCode.FILE_NOT_FOUND: 404, - BizCode.APP_NOT_FOUND: 404, - BizCode.RELEASE_NOT_FOUND: 404, - BizCode.DUPLICATE_NAME: 409, - BizCode.RESOURCE_ALREADY_EXISTS: 409, - BizCode.VERSION_ALREADY_EXISTS: 409, - BizCode.STATE_CONFLICT: 409, - BizCode.PUBLISH_FAILED: 500, - BizCode.NO_DRAFT_TO_PUBLISH: 400, - BizCode.ROLLBACK_TARGET_NOT_FOUND: 404, - BizCode.APP_TYPE_NOT_SUPPORTED: 400, - BizCode.AGENT_CONFIG_MISSING: 400, - BizCode.SHARE_DISABLED: 403, - BizCode.INVALID_PASSWORD: 401, - BizCode.PASSWORD_REQUIRED: 401, - BizCode.EMBED_NOT_ALLOWED: 403, - BizCode.PERMISSION_DENIED: 403, - BizCode.INVALID_CONVERSATION: 400, - BizCode.MODEL_CONFIG_INVALID: 400, - BizCode.API_KEY_MISSING: 400, - BizCode.PROVIDER_NOT_SUPPORTED: 400, - BizCode.LLM_ERROR: 500, - BizCode.EMBEDDING_ERROR: 500, - BizCode.FILE_READ_ERROR: 500, - BizCode.PARSER_NOT_SUPPORTED: 400, - BizCode.CHUNKING_FAILED: 500, - BizCode.INDEX_BUILD_FAILED: 500, - BizCode.EMBEDDING_FAILED: 500, - BizCode.SEARCH_FAILED: 500, - BizCode.INTERNAL_ERROR: 500, - BizCode.DB_ERROR: 500, - BizCode.SERVICE_UNAVAILABLE: 503, - BizCode.RATE_LIMITED: 429, -} \ No newline at end of file diff --git a/app/core/exceptions.py b/app/core/exceptions.py deleted file mode 100644 index ef5dd2cd..00000000 --- a/app/core/exceptions.py +++ /dev/null @@ -1,86 +0,0 @@ -""" -业务异常定义 -""" -from typing import Any, Dict, Optional -from app.core.error_codes import BizCode - - -class BusinessException(Exception): - """业务逻辑异常基类""" - - def __init__( - self, - message: str, - code: BizCode | int | None = None, - context: Optional[Dict[str, Any]] = None, - cause: Optional[Exception] = None - ): - self.message = message - self.code = code if code is not None else BizCode.BAD_REQUEST - # Make a copy of context to avoid modifying the original dict - self.context = dict(context) if context else {} - self.cause = cause - super().__init__(self.message) - - def __str__(self) -> str: - ctx = f", context={self.context}" if self.context else "" - code_name = self.code.name if isinstance(self.code, BizCode) else str(self.code) - return f"{code_name}: {self.message}{ctx}" - - -class ValidationException(BusinessException): - """数据验证异常""" - - def __init__(self, message: str, field: str = None, **kwargs): - context = {"field": field} if field else {} - if "context" in kwargs: - context.update(kwargs.pop("context")) - super().__init__(message, BizCode.VALIDATION_FAILED, context, **kwargs) - - -class AuthenticationException(BusinessException): - """认证异常""" - - def __init__(self, message: str = "认证失败", **kwargs): - super().__init__(message, BizCode.UNAUTHORIZED, **kwargs) - - -class AuthorizationException(BusinessException): - """授权异常""" - - def __init__(self, message: str = "权限不足", **kwargs): - super().__init__(message, BizCode.FORBIDDEN, **kwargs) - - -class ResourceNotFoundException(BusinessException): - """资源未找到异常""" - - def __init__(self, resource_type: str, resource_id: str = None, **kwargs): - message = f"{resource_type} 不存在" - context = {"resource_type": resource_type} - if resource_id: - context["resource_id"] = resource_id - if "context" in kwargs: - context.update(kwargs.pop("context")) - super().__init__(message, BizCode.FILE_NOT_FOUND, context, **kwargs) - - -class DuplicateResourceException(BusinessException): - """资源重复异常""" - - def __init__(self, message: str = "资源已存在", **kwargs): - super().__init__(message, BizCode.DUPLICATE_NAME, **kwargs) - - -class FileUploadException(BusinessException): - """文件上传异常""" - - def __init__(self, message: str, **kwargs): - super().__init__(message, BizCode.FILE_READ_ERROR, **kwargs) - - -class PermissionDeniedException(BusinessException): - """权限拒绝异常""" - - def __init__(self, message: str = "权限不足", **kwargs): - super().__init__(message, BizCode.FORBIDDEN, **kwargs) \ No newline at end of file diff --git a/app/core/logging_config.py b/app/core/logging_config.py deleted file mode 100644 index 02747cfb..00000000 --- a/app/core/logging_config.py +++ /dev/null @@ -1,633 +0,0 @@ -import logging -import logging.handlers -import os -from pathlib import Path -from typing import Optional - -from app.core.config import settings -from app.core.sensitive_filter import SensitiveDataFilter - - -class SensitiveDataLoggingFilter(logging.Filter): - """日志过滤器:自动过滤敏感信息""" - - def filter(self, record: logging.LogRecord) -> bool: - """ - 过滤日志记录中的敏感信息 - - Args: - record: 日志记录 - - Returns: - True表示允许记录,False表示拒绝 - """ - # 过滤消息中的敏感信息 - if hasattr(record, 'msg') and isinstance(record.msg, str): - record.msg = SensitiveDataFilter.filter_string(record.msg) - - # 过滤参数中的敏感信息 - if hasattr(record, 'args') and record.args: - if isinstance(record.args, dict): - record.args = SensitiveDataFilter.filter_dict(record.args) - elif isinstance(record.args, (list, tuple)): - record.args = tuple( - SensitiveDataFilter.filter_string(str(arg)) if isinstance(arg, str) else arg - for arg in record.args - ) - - return True - - -class LoggingConfig: - """全局日志配置类""" - - _initialized = False - _memory_loggers_initialized = False - _prompt_logger = None - _template_logger = None - _timing_logger = None - _agent_loggers = {} - - @classmethod - def setup_logging(cls) -> None: - """初始化全局日志配置""" - if cls._initialized: - return - - # 创建日志目录 - log_dir = Path(settings.LOG_FILE_PATH).parent - log_dir.mkdir(parents=True, exist_ok=True) - - # 配置根日志器 - root_logger = logging.getLogger() - root_logger.setLevel(getattr(logging, settings.LOG_LEVEL.upper())) - - # 清除现有处理器 - root_logger.handlers.clear() - - # 创建格式化器 - formatter = logging.Formatter( - fmt=settings.LOG_FORMAT, - datefmt='%Y-%m-%d %H:%M:%S' - ) - - # 创建敏感信息过滤器 - sensitive_filter = SensitiveDataLoggingFilter() - - # 控制台处理器 - if settings.LOG_TO_CONSOLE: - console_handler = logging.StreamHandler() - console_handler.setFormatter(formatter) - console_handler.setLevel(getattr(logging, settings.LOG_LEVEL.upper())) - console_handler.addFilter(sensitive_filter) - root_logger.addHandler(console_handler) - - # 文件处理器(带轮转) - if settings.LOG_TO_FILE: - file_handler = logging.handlers.RotatingFileHandler( - filename=settings.LOG_FILE_PATH, - maxBytes=settings.LOG_MAX_SIZE, - backupCount=5, - encoding='utf-8' - ) - file_handler.setFormatter(formatter) - file_handler.setLevel(getattr(logging, settings.LOG_LEVEL.upper())) - file_handler.addFilter(sensitive_filter) - root_logger.addHandler(file_handler) - - cls._initialized = True - - # Initialize memory module logging - cls.setup_memory_logging() - - # 记录初始化完成 - logger = logging.getLogger(__name__) - logger.info("全局日志系统初始化完成") - - @classmethod - def setup_memory_logging(cls) -> None: - """Initialize memory module specific loggers. - - Called automatically by setup_logging() or can be called independently. - Sets up: - - Prompt logger with timestamped files - - Template logger with conditional file output - - Timing logger with dual output (file + console) - - Agent logger factory with concurrent handlers - """ - if cls._memory_loggers_initialized: - return - - # Create logs directory if it doesn't exist - log_dir = Path("logs") - try: - log_dir.mkdir(parents=True, exist_ok=True) - except OSError as e: - print(f"Warning: Could not create log directory: {e}") - # Continue with console-only logging - - # Initialize memory-specific loggers - # These will be created lazily when first requested via factory functions - # This method just marks the system as ready for memory logging - - cls._memory_loggers_initialized = True - - -def get_logger(name: Optional[str] = None) -> logging.Logger: - """获取日志器实例 - - Args: - name: 日志器名称,默认为调用模块名 - - Returns: - 配置好的日志器实例 - """ - return logging.getLogger(name) - - -def get_auth_logger() -> logging.Logger: - """获取认证专用日志器""" - return logging.getLogger("auth") - - -def get_security_logger() -> logging.Logger: - """获取安全专用日志器""" - return logging.getLogger("security") - - -def get_api_logger() -> logging.Logger: - """获取API专用日志器""" - return logging.getLogger("api") - - -def get_db_logger() -> logging.Logger: - """获取数据库专用日志器""" - return logging.getLogger("database") - - -def get_business_logger() -> logging.Logger: - """获取业务逻辑专用日志器""" - return logging.getLogger("business") - - -def get_prompt_logger() -> logging.Logger: - """Get the prompt logger for memory module. - - Returns a logger configured for prompt rendering output with: - - Logger name: memory.prompts - - Output: logs/prompt_logs-{timestamp}.log - - Level: Configurable via PROMPT_LOG_LEVEL setting (default: INFO) - - Handler: FileHandler (no console output) - - The logger is cached after first creation for performance. - - Returns: - Logger configured for prompt rendering output - - Example: - >>> logger = get_prompt_logger() - >>> logger.info("=== RENDERED EXTRACTION PROMPT ===\\n%s", prompt_content) - """ - # Return cached logger if already initialized - if LoggingConfig._prompt_logger is not None: - return LoggingConfig._prompt_logger - - # Ensure memory logging is initialized - if not LoggingConfig._memory_loggers_initialized: - LoggingConfig.setup_memory_logging() - - # Create prompt logger - logger = logging.getLogger("memory.prompts") - logger.setLevel(getattr(logging, settings.PROMPT_LOG_LEVEL.upper())) - logger.propagate = False # Don't propagate to root logger (no console output) - - # Create timestamped log file - from datetime import datetime - timestamp = datetime.now().strftime("%Y%m%d-%H%M%S") - log_file = Path("logs/prompts/") / f"prompt_logs-{timestamp}.log" - - # Ensure log directory exists - log_file.parent.mkdir(parents=True, exist_ok=True) - - # Create file handler - file_handler = logging.FileHandler( - filename=str(log_file), - encoding='utf-8' - ) - - # Create formatter - formatter = logging.Formatter( - fmt='%(asctime)s - %(name)s - %(levelname)s - %(message)s', - datefmt='%Y-%m-%d %H:%M:%S' - ) - file_handler.setFormatter(formatter) - - # Add handler to logger - logger.addHandler(file_handler) - - # Cache the logger - LoggingConfig._prompt_logger = logger - - return logger - - -def get_template_logger() -> logging.Logger: - """Get the template logger for memory module. - - Returns a logger configured for template rendering information with: - - Logger name: memory.templates - - Output: logs/prompt_templates.log (only when ENABLE_TEMPLATE_LOGGING is True) - - Level: INFO - - Handler: FileHandler when enabled, NullHandler when disabled - - The logger is cached after first creation for performance. - - Returns: - Logger configured for template rendering info - - Example: - >>> logger = get_template_logger() - >>> logger.info("Rendering template: %s with context keys: %s", - ... template_name, list(context.keys())) - """ - # Return cached logger if already initialized - if LoggingConfig._template_logger is not None: - return LoggingConfig._template_logger - - # Ensure memory logging is initialized - if not LoggingConfig._memory_loggers_initialized: - LoggingConfig.setup_memory_logging() - - # Create template logger - logger = logging.getLogger("memory.templates") - logger.setLevel(logging.INFO) - logger.propagate = False # Don't propagate to root logger - - # Add appropriate handler based on configuration - if settings.ENABLE_TEMPLATE_LOGGING: - # Create log file path - log_file = Path("logs") / "prompt_templates.log" - - # Ensure log directory exists - log_file.parent.mkdir(parents=True, exist_ok=True) - - # Create file handler - file_handler = logging.FileHandler( - filename=str(log_file), - encoding='utf-8' - ) - - # Create formatter - formatter = logging.Formatter( - fmt='%(asctime)s - %(name)s - %(levelname)s - %(message)s', - datefmt='%Y-%m-%d %H:%M:%S' - ) - file_handler.setFormatter(formatter) - - # Add handler to logger - logger.addHandler(file_handler) - else: - # Use NullHandler when template logging is disabled - null_handler = logging.NullHandler() - logger.addHandler(null_handler) - - # Cache the logger - LoggingConfig._template_logger = logger - - return logger - - -def log_prompt_rendering(prompt_type: str, content: str) -> None: - """Log rendered prompt content. - - Logs the rendered prompt with a formatted header and separator for easy - identification in log files. This is useful for debugging LLM interactions - and understanding what prompts are being sent. - - Args: - prompt_type: Type of prompt (e.g., 'statement_extraction', 'triplet_extraction') - content: The rendered prompt text - - Example: - >>> log_prompt_rendering("extraction", "Extract entities from: Hello world") - # Logs: - # === RENDERED EXTRACTION PROMPT === - # Extract entities from: Hello world - # ===================================== - """ - logger = get_prompt_logger() - - # Format the log entry with header and separator - separator = "=" * 50 - header = f"=== RENDERED {prompt_type.upper()} PROMPT ===" - - log_message = f"\n{header}\n{content}\n{separator}\n" - - logger.info(log_message) - - -def log_template_rendering(template_name: str, context: dict | None = None) -> None: - """Log template rendering information. - - Logs the template name and context keys for debugging template rendering. - This function is wrapped in try-except to ensure it never breaks application - flow, even if logging fails. - - Args: - template_name: Name of the Jinja2 template being rendered - context: Optional context dictionary with template variables - - Example: - >>> log_template_rendering("extract_triplet.jinja2", {"text": "...", "ontology": "..."}) - # Logs: Rendering template: extract_triplet.jinja2 with context keys: ['text', 'ontology'] - - >>> log_template_rendering("system.jinja2") - # Logs: Rendering template: system.jinja2 with no context - """ - try: - logger = get_template_logger() - - if context is not None: - context_keys = list(context.keys()) - logger.info(f"Rendering template: {template_name} with context keys: {context_keys}") - else: - logger.info(f"Rendering template: {template_name} with no context") - except Exception: - # Never break application flow due to logging issues - # Silently ignore any logging errors - pass - - - -def get_timing_logger() -> logging.Logger: - """Get the timing logger for memory module. - - Returns a logger configured for performance timing with: - - Logger name: memory.timing - - Output: Configurable via TIMING_LOG_FILE setting (default: logs/time.log) - - Level: INFO - - Handlers: FileHandler + optional StreamHandler for console output - - Console output: Controlled by TIMING_LOG_TO_CONSOLE setting (default: True) - - The logger is cached after first creation for performance. - - Returns: - Logger configured for performance timing - - Example: - >>> logger = get_timing_logger() - >>> logger.info("[2025-11-18 10:30:45] Extraction: 2.34 seconds") - """ - # Return cached logger if already initialized - if LoggingConfig._timing_logger is not None: - return LoggingConfig._timing_logger - - # Ensure memory logging is initialized - if not LoggingConfig._memory_loggers_initialized: - LoggingConfig.setup_memory_logging() - - # Create timing logger - logger = logging.getLogger("memory.timing") - logger.setLevel(logging.INFO) - logger.propagate = False # Don't propagate to root logger - - # Create formatter - formatter = logging.Formatter( - fmt='%(asctime)s - %(name)s - %(levelname)s - %(message)s', - datefmt='%Y-%m-%d %H:%M:%S' - ) - - # Add file handler - log_file = Path(settings.TIMING_LOG_FILE) - - # Ensure log directory exists - log_file.parent.mkdir(parents=True, exist_ok=True) - - file_handler = logging.FileHandler( - filename=str(log_file), - encoding='utf-8' - ) - file_handler.setFormatter(formatter) - logger.addHandler(file_handler) - - # Add console handler if enabled - if settings.TIMING_LOG_TO_CONSOLE: - console_handler = logging.StreamHandler() - console_handler.setFormatter(formatter) - logger.addHandler(console_handler) - - # Cache the logger - LoggingConfig._timing_logger = logger - - return logger - - -def log_time(step_name: str, duration: float, log_file: str = "logs/time.log") -> None: - """Log timing information for performance tracking. - - Logs timing information to both file and console (console output is always shown - for backward compatibility). The file output includes a timestamp and full details, - while console output shows a concise checkmark format. - - Args: - step_name: Name of the operation being timed - duration: Duration in seconds - log_file: Optional custom log file path (default: logs/time.log) - - Example: - >>> log_time("Knowledge Extraction", 2.34) - # File logs: [2025-11-18 10:30:45] Knowledge Extraction: 2.34 seconds - # Console prints: ✓ Knowledge Extraction: 2.34s - - >>> log_time("Database Query", 0.15, "logs/custom_time.log") - # Logs to custom file and console - """ - from datetime import datetime - - # Format timestamp - timestamp = datetime.now().strftime("%Y-%m-%d %H:%M:%S") - - # Format timing entry for file - log_entry = f"[{timestamp}] {step_name}: {duration:.2f} seconds\n" - - # Write to file with error handling - try: - log_path = Path(log_file) - log_path.parent.mkdir(parents=True, exist_ok=True) - - with open(log_path, "a", encoding="utf-8") as f: - f.write(log_entry) - except IOError as e: - # Fallback to console only if file write fails - print(f"Warning: Could not write to timing log: {e}") - - # Always print to console (backward compatible behavior) - print(f"✓ {step_name}: {duration:.2f}s") - - -def get_agent_logger(name: str = "agent_service", - console_level: str = "INFO", - file_level: str = "DEBUG") -> logging.Logger: - """Get an agent logger with concurrent file handling. - - Returns a logger configured for agent operations with: - - Logger name: memory.agent.{name} - - Output: Configurable via AGENT_LOG_FILE setting (default: logs/agent_service.log) - - Console level: Configurable (default: INFO) - - File level: Configurable (default: DEBUG) - - Handler: ConcurrentRotatingFileHandler for multi-process support - - Rotation: Configurable via AGENT_LOG_MAX_SIZE (default: 5MB) and - AGENT_LOG_BACKUP_COUNT (default: 20) - - The logger is cached by name after first creation for performance. - Supports concurrent writes from multiple processes. - - Args: - name: Logger name for namespacing (default: "agent_service") - console_level: Log level for console output (default: "INFO") - file_level: Log level for file output (default: "DEBUG") - - Returns: - Logger configured for agent operations - - Example: - >>> logger = get_agent_logger("my_agent") - >>> logger.info("Agent operation started") - >>> logger.debug("Detailed agent state information") - - >>> logger = get_agent_logger("custom_agent", console_level="WARNING", file_level="INFO") - >>> logger.warning("This appears in console and file") - >>> logger.info("This only appears in file") - """ - # Return cached logger if already initialized - if name in LoggingConfig._agent_loggers: - return LoggingConfig._agent_loggers[name] - - # Ensure memory logging is initialized - if not LoggingConfig._memory_loggers_initialized: - LoggingConfig.setup_memory_logging() - - # Create agent logger with namespaced name - logger_name = f"memory.agent.{name}" - logger = logging.getLogger(logger_name) - logger.setLevel(logging.DEBUG) # Set to DEBUG to allow both handlers to filter - logger.propagate = False # Don't propagate to root logger - - # Create formatter - formatter = logging.Formatter( - fmt='%(asctime)s - %(name)s - %(levelname)s - %(message)s', - datefmt='%Y-%m-%d %H:%M:%S' - ) - - # Add console handler - console_handler = logging.StreamHandler() - console_handler.setLevel(getattr(logging, console_level.upper())) - console_handler.setFormatter(formatter) - logger.addHandler(console_handler) - - # Add concurrent rotating file handler - try: - from concurrent_log_handler import ConcurrentRotatingFileHandler - except ImportError: - # Fall back to standard RotatingFileHandler if concurrent handler not available - from logging.handlers import RotatingFileHandler as ConcurrentRotatingFileHandler - print("Warning: concurrent-log-handler not available, using standard RotatingFileHandler") - - # Create log file path - log_file = Path(settings.AGENT_LOG_FILE) - - # Ensure log directory exists - log_file.parent.mkdir(parents=True, exist_ok=True) - - # Create file handler with rotation - file_handler = ConcurrentRotatingFileHandler( - filename=str(log_file), - maxBytes=settings.AGENT_LOG_MAX_SIZE, - backupCount=settings.AGENT_LOG_BACKUP_COUNT, - encoding='utf-8' - ) - file_handler.setLevel(getattr(logging, file_level.upper())) - file_handler.setFormatter(formatter) - logger.addHandler(file_handler) - - # Cache the logger - LoggingConfig._agent_loggers[name] = logger - - return logger - - -def get_named_logger(name: str) -> logging.Logger: - """Backward compatible alias for get_agent_logger. - - This function maintains backward compatibility with existing code that uses - the get_named_logger pattern from the agent logger module. - - Args: - name: Logger name for namespacing - - Returns: - Logger configured for agent operations - - Example: - >>> logger = get_named_logger("my_agent") - >>> logger.info("Agent operation started") - """ - return get_agent_logger(name) - - -def get_memory_logger(name: Optional[str] = None) -> logging.Logger: - """Get a standard logger for memory module components. - - Returns a logger configured for memory module components that inherits - the root logger's configuration (handlers, formatters, and level). This - provides consistent logging behavior across the memory module while - maintaining the ability to filter and identify memory-specific logs. - - The logger uses the 'memory' namespace: - - If name is provided: logger name is 'memory.{module_name}' - - If name is None: logger name is 'memory' - - The logger inherits all handlers and formatters from the root logger, - ensuring consistent output format and destinations (console, file, etc.). - - Args: - name: Optional logger name, typically __name__ from the calling module. - If provided, creates a namespaced logger under 'memory.{name}'. - If None, returns the base 'memory' logger. - - Returns: - Logger configured for memory module operations with root logger inheritance - - Example: - >>> # In app/core/memory/src/search.py - >>> logger = get_memory_logger(__name__) - >>> logger.info("Starting search operation") - # Logs: [timestamp] - memory.app.core.memory.src.search - INFO - Starting search operation - - >>> # Get base memory logger - >>> logger = get_memory_logger() - >>> logger.debug("Memory module initialized") - # Logs: [timestamp] - memory - DEBUG - Memory module initialized - - >>> # In app/core/memory/src/knowledge_extraction/triplet_extraction.py - >>> logger = get_memory_logger(__name__) - >>> logger.error("Extraction failed", exc_info=True) - # Logs error with full traceback - """ - # Ensure memory logging is initialized - if not LoggingConfig._memory_loggers_initialized: - LoggingConfig.setup_memory_logging() - - # Construct logger name with memory namespace - if name is not None: - logger_name = f"memory.{name}" - else: - logger_name = "memory" - - # Get logger - it will inherit from root logger configuration - logger = logging.getLogger(logger_name) - - # The logger automatically inherits handlers, formatters, and level from root logger - # through Python's logging hierarchy, so no additional configuration is needed - - return logger diff --git a/app/core/memory/__init__.py b/app/core/memory/__init__.py deleted file mode 100644 index e69de29b..00000000 diff --git a/app/core/memory/agent/__init__.py b/app/core/memory/agent/__init__.py deleted file mode 100644 index e69de29b..00000000 diff --git a/app/core/memory/agent/langgraph_graph/__init__.py b/app/core/memory/agent/langgraph_graph/__init__.py deleted file mode 100644 index a0596e38..00000000 --- a/app/core/memory/agent/langgraph_graph/__init__.py +++ /dev/null @@ -1,16 +0,0 @@ -""" -LangGraph Graph package for memory agent. - -This package provides the LangGraph workflow orchestrator with modular -node implementations, routing logic, and state management. - -Package structure: -- read_graph: Main graph factory for read operations -- write_graph: Main graph factory for write operations -- nodes: LangGraph node implementations -- routing: State routing logic -- state: State management utilities -""" -from app.core.memory.agent.langgraph_graph.read_graph import make_read_graph - -__all__ = ['make_read_graph'] \ No newline at end of file diff --git a/app/core/memory/agent/langgraph_graph/nodes/__init__.py b/app/core/memory/agent/langgraph_graph/nodes/__init__.py deleted file mode 100644 index 4e808919..00000000 --- a/app/core/memory/agent/langgraph_graph/nodes/__init__.py +++ /dev/null @@ -1,10 +0,0 @@ -""" -LangGraph node implementations. - -This module contains custom node implementations for the LangGraph workflow. -""" - -from app.core.memory.agent.langgraph_graph.nodes.tool_node import ToolExecutionNode -from app.core.memory.agent.langgraph_graph.nodes.input_node import create_input_message - -__all__ = ["ToolExecutionNode", "create_input_message"] diff --git a/app/core/memory/agent/langgraph_graph/nodes/input_node.py b/app/core/memory/agent/langgraph_graph/nodes/input_node.py deleted file mode 100644 index 350043fa..00000000 --- a/app/core/memory/agent/langgraph_graph/nodes/input_node.py +++ /dev/null @@ -1,144 +0,0 @@ -""" -Input node for LangGraph workflow entry point. - -This module provides the create_input_message function which processes initial -user input with multimodal support and creates the first tool call message. -""" - -import logging -import re -import uuid -from datetime import datetime -from typing import Dict, Any - -from langchain_core.messages import AIMessage - -from app.core.memory.agent.utils.multimodal import MultimodalProcessor - -logger = logging.getLogger(__name__) - - -async def create_input_message( - state: Dict[str, Any], - tool_name: str, - session_id: str, - search_switch: str, - apply_id: str, - group_id: str, - multimodal_processor: MultimodalProcessor -) -> Dict[str, Any]: - """ - Create initial tool call message from user input. - - This function: - 1. Extracts the last message content from state - 2. Processes multimodal inputs (images/audio) using the multimodal processor - 3. Generates a unique message ID - 4. Extracts namespace from session_id - 5. Handles verified_data extraction for backward compatibility - 6. Returns AIMessage with complete tool_calls structure - - Args: - state: LangGraph state dictionary containing messages - tool_name: Name of the tool to invoke (typically "Split_The_Problem") - session_id: Session identifier (format: "call_id_{namespace}") - search_switch: Search routing parameter - apply_id: Application identifier - group_id: Group identifier - multimodal_processor: Processor for handling image/audio inputs - - Returns: - State update with AIMessage containing tool_call - - Examples: - >>> state = {"messages": [HumanMessage(content="What is AI?")]} - >>> result = await create_input_message( - ... state, "Split_The_Problem", "call_id_user123", "0", "app1", "group1", processor - ... ) - >>> result["messages"][0].tool_calls[0]["name"] - 'Split_The_Problem' - """ - messages = state.get("messages", []) - - # Extract last message content - if messages: - last_message = messages[-1].content if hasattr(messages[-1], 'content') else str(messages[-1]) - else: - logger.warning("[create_input_message] No messages in state, using empty string") - last_message = "" - - logger.debug(f"[create_input_message] Original input: {last_message[:100]}...") - - # Process multimodal input (images/audio) - try: - processed_content = await multimodal_processor.process_input(last_message) - if processed_content != last_message: - logger.info( - f"[create_input_message] Multimodal processing converted input " - f"from {len(last_message)} to {len(processed_content)} chars" - ) - last_message = processed_content - except Exception as e: - logger.error( - f"[create_input_message] Multimodal processing failed: {e}", - exc_info=True - ) - # Continue with original content - - # Generate unique message ID - uuid_str = uuid.uuid4() - time_str = datetime.now().strftime("%Y-%m-%d %H:%M:%S") - - # Extract namespace from session_id - # Expected format: "call_id_{namespace}" or similar - try: - namespace = str(session_id).split('_id_')[1] - except (IndexError, AttributeError): - logger.warning( - f"[create_input_message] Could not extract namespace from session_id: {session_id}" - ) - namespace = "unknown" - - # Handle verified_data extraction (backward compatibility) - # This regex-based extraction is kept for compatibility with existing data formats - if 'verified_data' in str(last_message): - try: - messages_last = str(last_message).replace('\\n', '').replace('\\', '') - query_match = re.findall(r'"query": "(.*?)",', messages_last) - if query_match: - last_message = query_match[0] - logger.debug( - f"[create_input_message] Extracted query from verified_data: {last_message}" - ) - except Exception as e: - logger.warning( - f"[create_input_message] Failed to extract query from verified_data: {e}" - ) - - # Construct tool call message - tool_call_id = f"{session_id}_{uuid_str}" - - logger.info( - f"[create_input_message] Creating tool call for '{tool_name}' " - f"with ID: {tool_call_id}" - ) - - return { - "messages": [ - AIMessage( - content="", - tool_calls=[{ - "name": tool_name, - "args": { - "sentence": last_message, - "sessionid": session_id, - "messages_id": str(uuid_str), - "search_switch": search_switch, - "apply_id": apply_id, - "group_id": group_id - }, - "id": tool_call_id - }] - ) - ] - } diff --git a/app/core/memory/agent/langgraph_graph/nodes/tool_node.py b/app/core/memory/agent/langgraph_graph/nodes/tool_node.py deleted file mode 100644 index 9ea2cad6..00000000 --- a/app/core/memory/agent/langgraph_graph/nodes/tool_node.py +++ /dev/null @@ -1,199 +0,0 @@ -""" -Tool execution node for LangGraph workflow. - -This module provides the ToolExecutionNode class which wraps tool execution -with parameter transformation logic using the ParameterBuilder service. -""" - -import logging -import time -from typing import Any, Callable, Dict - -from langchain_core.messages import AIMessage -from langgraph.prebuilt import ToolNode - -from app.core.memory.agent.langgraph_graph.state.extractors import ( - extract_tool_call_id, - extract_content_payload -) -from app.core.memory.agent.mcp_server.services.parameter_builder import ParameterBuilder - -logger = logging.getLogger(__name__) - - -class ToolExecutionNode: - """ - Custom LangGraph node that wraps tool execution with parameter transformation. - - This node extracts content from previous tool results, transforms parameters - based on tool type using ParameterBuilder, and invokes the tool with the - correct argument structure. - - Attributes: - tool_node: LangGraph ToolNode wrapping the actual tool - id: Node identifier for message IDs - tool_name: Name of the tool being executed - namespace: Namespace for session management - search_switch: Search routing parameter - apply_id: Application identifier - group_id: Group identifier - parameter_builder: Service for building tool-specific arguments - """ - - def __init__( - self, - tool: Callable, - node_id: str, - namespace: str, - search_switch: str, - apply_id: str, - group_id: str, - parameter_builder: ParameterBuilder, - storage_type:str, - user_rag_memory_id:str - ): - """ - Initialize the tool execution node. - - Args: - tool: The tool function to execute - node_id: Identifier for this node (used in message IDs) - namespace: Namespace for session management - search_switch: Search routing parameter - apply_id: Application identifier - group_id: Group identifier - parameter_builder: Service for building tool-specific arguments - """ - self.tool_node = ToolNode([tool]) - self.id = node_id - self.tool_name = tool.name if hasattr(tool, 'name') else str(tool) - self.namespace = namespace - self.search_switch = search_switch - self.apply_id = apply_id - self.group_id = group_id - self.parameter_builder = parameter_builder - self.storage_type=storage_type - self.user_rag_memory_id=user_rag_memory_id - - logger.info( - f"[ToolExecutionNode] Initialized node '{self.id}' for tool '{self.tool_name}'" - ) - - async def __call__(self, state: Dict[str, Any]) -> Dict[str, Any]: - """ - Execute the tool with transformed parameters. - - This method: - 1. Extracts the last message from state - 2. Extracts tool call ID using state extractors - 3. Extracts content payload using state extractors - 4. Builds tool arguments using parameter builder - 5. Constructs AIMessage with tool_calls - 6. Invokes the tool and returns the result - - Args: - state: LangGraph state dictionary - - Returns: - Updated state with tool result in messages - """ - messages = state.get("messages", []) - logger.debug( self.tool_name) - - if not messages: - logger.warning(f"[ToolExecutionNode] {self.id} - No messages in state") - return {"messages": [AIMessage(content="Error: No messages in state")]} - - last_message = messages[-1] - logger.debug( - f"[ToolExecutionNode] {self.id} - Processing message at {time.time()}" - ) - - try: - # Extract tool call ID using state extractors - tool_call_id = extract_tool_call_id(last_message) - logger.debug(f"[ToolExecutionNode] {self.id} - Extracted tool_call_id: {tool_call_id}") - - except ValueError as e: - logger.error( - f"[ToolExecutionNode] {self.id} - Failed to extract tool call ID: {e}" - ) - return {"messages": [AIMessage(content=f"Error: {str(e)}")]} - - try: - # Extract content payload using state extractors - content = extract_content_payload(last_message) - logger.debug( - f"[ToolExecutionNode] {self.id} - Extracted content type: {type(content)}" - ) - - except Exception as e: - logger.error( - f"[ToolExecutionNode] {self.id} - Failed to extract content: {e}", - exc_info=True - ) - content = {} - - try: - # Build tool arguments using parameter builder - tool_args = self.parameter_builder.build_tool_args( - tool_name=self.tool_name, - content=content, - tool_call_id=tool_call_id, - search_switch=self.search_switch, - apply_id=self.apply_id, - group_id=self.group_id, - storage_type=self.storage_type, - user_rag_memory_id=self.user_rag_memory_id - ) - logger.debug( - f"[ToolExecutionNode] {self.id} - Built tool args with keys: {list(tool_args.keys())}" - ) - - except Exception as e: - logger.error( - f"[ToolExecutionNode] {self.id} - Failed to build tool args: {e}", - exc_info=True - ) - return {"messages": [AIMessage(content=f"Error building arguments: {str(e)}")]} - - # Construct tool input message - tool_input = { - "messages": [ - AIMessage( - content="", - tool_calls=[{ - "name": self.tool_name, - "args": tool_args, - "id": f"{self.id}_{tool_call_id}", - }] - ) - ] - } - - try: - # Invoke the tool - result = await self.tool_node.ainvoke(tool_input) - - logger.debug( - f"[ToolExecutionNode] {self.id} - Tool execution completed" - ) - - # Return the result directly - it already contains the messages list - return result - - except Exception as e: - logger.error( - f"[ToolExecutionNode] {self.id} - Tool execution failed: {e}", - exc_info=True - ) - # Return error as ToolMessage to maintain message chain consistency - from langchain_core.messages import ToolMessage - return { - "messages": [ - ToolMessage( - content=f"Error executing tool: {str(e)}", - tool_call_id=f"{self.id}_{tool_call_id}" - ) - ] - } diff --git a/app/core/memory/agent/langgraph_graph/read_graph.py b/app/core/memory/agent/langgraph_graph/read_graph.py deleted file mode 100644 index 51127f3e..00000000 --- a/app/core/memory/agent/langgraph_graph/read_graph.py +++ /dev/null @@ -1,508 +0,0 @@ -import asyncio -import io -import json -import logging -import os -import re -import time -import uuid -import warnings -from contextlib import asynccontextmanager -from datetime import datetime -from typing import Literal - -from dotenv import load_dotenv -from langchain_core.messages import AIMessage -from langgraph.constants import START, END -from langgraph.graph import StateGraph -from langgraph.prebuilt import ToolNode -from functools import partial - -from app.core.memory.agent.utils.llm_tools import ReadState, COUNTState -from langgraph.checkpoint.memory import InMemorySaver - -from app.core.memory.agent.utils.redis_tool import store -from app.core.logging_config import get_agent_logger - -# Import new modular components -from app.core.memory.agent.langgraph_graph.nodes import ToolExecutionNode, create_input_message -from app.core.memory.agent.langgraph_graph.routing.routers import ( - Verify_continue, - Retrieve_continue, - Split_continue -) -from app.core.memory.agent.mcp_server.services.parameter_builder import ParameterBuilder -from app.core.memory.agent.utils.multimodal import MultimodalProcessor - -logger = get_agent_logger(__name__) - -warnings.filterwarnings("ignore", category=RuntimeWarning) -load_dotenv() -redishost=os.getenv("REDISHOST") -redisport=os.getenv('REDISPORT') -redisdb=os.getenv('REDISDB') -redispassword=os.getenv('REDISPASSWORD') -counter = COUNTState(limit=3) - -# 在工作流中添加循环计数更新 -async def update_loop_count(state): - """更新循环计数器""" - current_count = state.get("loop_count", 0) - return {"loop_count": current_count + 1} - - -def Verify_continue(state: ReadState) -> Literal["Summary", "Summary_fails", "content_input"]: - messages = state["messages"] - - # 添加边界检查 - if not messages: - return END - counter.add(1) # 累加 1 - - loop_count = counter.get_total() - logger.debug(f"[should_continue] 当前循环次数: {loop_count}") - - last_message = messages[-1] - last_message_str = str(last_message).replace('\\', '') - status_tools = re.findall(r'"split_result": "(.*?)"', last_message_str) - logger.debug(f"Status tools: {status_tools}") - - if "success" in status_tools: - counter.reset() - return "Summary" - elif "failed" in status_tools: - if loop_count < 2: # 最大循环次数 3 - return "content_input" - else: - counter.reset() - return "Summary_fails" - else: - # 添加默认返回值,避免返回 None - counter.reset() - return "Summary" # 或根据业务需求选择合适的默认值 - - -def Retrieve_continue(state) -> Literal["Verify", "Retrieve_Summary"]: - """ - Determine routing based on search_switch value. - - Args: - state: State dictionary containing search_switch - - Returns: - Next node to execute - """ - # Direct dictionary access instead of regex parsing - search_switch = state.get("search_switch") - - # Handle case where search_switch might be in messages - if search_switch is None and "messages" in state: - messages = state.get("messages", []) - if messages: - last_message = messages[-1] - # Try to extract from tool_calls args - if hasattr(last_message, "tool_calls") and last_message.tool_calls: - for tool_call in last_message.tool_calls: - if isinstance(tool_call, dict) and "args" in tool_call: - search_switch = tool_call["args"].get("search_switch") - break - - # Convert to string for comparison if needed - if search_switch is not None: - search_switch = str(search_switch) - if search_switch == '0': - return 'Verify' - elif search_switch == '1': - return 'Retrieve_Summary' - - # 添加默认返回值,避免返回 None - return 'Retrieve_Summary' # 或根据业务逻辑选择合适的默认值 - - -def Split_continue(state) -> Literal["Split_The_Problem", "Input_Summary"]: - """ - Determine routing based on search_switch value. - - Args: - state: State dictionary containing search_switch - - Returns: - Next node to execute - """ - logger.debug(f"Split_continue state: {state}") - - # Direct dictionary access instead of regex parsing - search_switch = state.get("search_switch") - - # Handle case where search_switch might be in messages - if search_switch is None and "messages" in state: - messages = state.get("messages", []) - if messages: - last_message = messages[-1] - # Try to extract from tool_calls args - if hasattr(last_message, "tool_calls") and last_message.tool_calls: - for tool_call in last_message.tool_calls: - if isinstance(tool_call, dict) and "args" in tool_call: - search_switch = tool_call["args"].get("search_switch") - break - - # Convert to string for comparison if needed - if search_switch is not None: - search_switch = str(search_switch) - if search_switch == '2': - return 'Input_Summary' - return 'Split_The_Problem' # 默认情况 - -# 在 input_sentence 函数中修改参数名称 -async def input_sentence(state, name, id, search_switch,apply_id,group_id): - messages = state["messages"] - last_message = messages[-1].content if messages else "" - - if last_message.endswith('.jpg') or last_message.endswith('.png'): - last_message=await picture_model_requests(last_message) - if any(last_message.endswith(ext) for ext in audio_extensions): - last_message=await Vico_recognition([last_message]).run() - logger.debug(f"Audio recognition result: {last_message}") - - - uuid_str = uuid.uuid4() - time_str = datetime.now().strftime("%Y-%m-%d %H:%M:%S") - namespace = str(id).split('_id_')[1] - if 'verified_data' in str(last_message): - messages_last = str(last_message).replace('\\n', '').replace('\\', '') - last_message = re.findall(r'"query": "(.*?)",', str(messages_last))[0] - - return { - "messages": [ - AIMessage( - content="", - tool_calls=[{ - "name": name, - "args": { - "sentence": last_message, - 'sessionid': id, - 'messages_id': str(uuid_str), - "search_switch": search_switch, # 正确地将 search_switch 放入 args 中 - "apply_id":apply_id, - "group_id":group_id - }, - "id": id + f'_{uuid_str}' - }] - ) - ] - } - - -class ProblemExtensionNode: - def __init__(self, tool, id, namespace, search_switch, apply_id, group_id, storage_type="", user_rag_memory_id=""): - self.tool_node = ToolNode([tool]) - self.id = id - self.tool_name = tool.name if hasattr(tool, 'name') else str(tool) - self.namespace = namespace - self.search_switch = search_switch - self.apply_id = apply_id - self.group_id = group_id - self.storage_type = storage_type - self.user_rag_memory_id = user_rag_memory_id - - async def __call__(self, state): - messages = state["messages"] - last_message = messages[-1] if messages else "" - logger.debug(f"ProblemExtensionNode {self.id} - 当前时间: {time.time()} - Message: {last_message}") - if self.tool_name=='Input_Summary': - tool_call =re.findall(f"'id': '(.*?)'",str(last_message))[0] - else:tool_call = str(re.findall(r"tool_call_id=.*?'(.*?)'", str(last_message))[0]).replace('\\', '').split('_id')[1] - # try: - # content = json.loads(last_message.content) if hasattr(last_message, 'content') else last_message - # except: - # content = last_message.content if hasattr(last_message, 'content') else str(last_message) - # 尝试从上一工具的结果中提取实际的内容载荷(而不是整个对象的字符串表示) - raw_msg = last_message.content if hasattr(last_message, 'content') else str(last_message) - extracted_payload = None - # 捕获 ToolMessage 的 content 字段(支持单/双引号),并避免贪婪匹配 - m = re.search(r"content=(?:\"|\')(.*?)(?:\"|\'),\s*name=", raw_msg, flags=re.S) - if m: - extracted_payload = m.group(1) - else: - # 回退:直接尝试使用原始字符串 - extracted_payload = raw_msg - - # 优先尝试将内容解析为 JSON - try: - content = json.loads(extracted_payload) - except Exception: - # 尝试从文本中提取 JSON 片段再解析 - parsed = None - candidates = re.findall(r"[\[{].*[\]}]", extracted_payload, flags=re.S) - for cand in candidates: - try: - parsed = json.loads(cand) - break - except Exception: - continue - # 如果仍然失败,则以原始字符串作为内容 - content = parsed if parsed is not None else extracted_payload - - # 根据工具名称构建正确的参数 - tool_args = {} - - if self.tool_name == "Verify": - # Verify工具需要context和usermessages参数 - if isinstance(content, dict): - tool_args["context"] = content - else: - tool_args["context"] = {"content": content} - tool_args["usermessages"] = str(tool_call) - tool_args["apply_id"] = str(self.apply_id) - tool_args["group_id"] = str(self.group_id) - elif self.tool_name == "Retrieve": - # Retrieve工具需要context和usermessages参数 - if isinstance(content, dict): - tool_args["context"] = content - else: - tool_args["context"] = {"content": content} - tool_args["usermessages"] = str(tool_call) - tool_args["search_switch"] = str(self.search_switch) - tool_args["apply_id"] = str(self.apply_id) - tool_args["group_id"] = str(self.group_id) - elif self.tool_name == "Summary": - # Summary工具需要字符串类型的context参数 - if isinstance(content, dict): - # 将字典转换为JSON字符串 - tool_args["context"] = json.dumps(content, ensure_ascii=False) - else: - tool_args["context"] = str(content) - tool_args["usermessages"] = str(tool_call) - tool_args["apply_id"] = str(self.apply_id) - tool_args["group_id"] = str(self.group_id) - elif self.tool_name == "Summary_fails": - # Summary工具需要字符串类型的context参数 - if isinstance(content, dict): - # 将字典转换为JSON字符串 - tool_args["context"] = json.dumps(content, ensure_ascii=False) - else: - tool_args["context"] = str(content) - tool_args["usermessages"] = str(tool_call) - tool_args["apply_id"] = str(self.apply_id) - tool_args["group_id"] = str(self.group_id) - elif self.tool_name=='Input_Summary': - tool_args["context"] =str(last_message) - tool_args["usermessages"] = str(tool_call) - tool_args["search_switch"] = str(self.search_switch) - tool_args["apply_id"] = str(self.apply_id) - tool_args["group_id"] = str(self.group_id) - tool_args["storage_type"] = getattr(self, 'storage_type', "") - tool_args["user_rag_memory_id"] = getattr(self, 'user_rag_memory_id', "") - elif self.tool_name=='Retrieve_Summary' : - # Retrieve_Summary expects dict directly, not JSON string - # content might be a JSON string, try to parse it - if isinstance(content, str): - try: - parsed_content = json.loads(content) - # Check if it has a "context" key - if isinstance(parsed_content, dict) and "context" in parsed_content: - tool_args["context"] = parsed_content["context"] - else: - tool_args["context"] = parsed_content - except json.JSONDecodeError: - # If parsing fails, wrap the string - tool_args["context"] = {"content": content} - elif isinstance(content, dict): - # Check if content has a "context" key that needs unwrapping - if "context" in content: - tool_args["context"] = content["context"] - else: - tool_args["context"] = content - else: - tool_args["context"] = {"content": str(content)} - - tool_args["usermessages"] = str(tool_call) - tool_args["apply_id"] = str(self.apply_id) - tool_args["group_id"] = str(self.group_id) - else: - # 其他工具使用context参数 - if isinstance(content, dict): - tool_args["context"] = content - else: - tool_args["context"] = {"content": content} - tool_args["usermessages"] = str(tool_call) - tool_args["apply_id"] = str(self.apply_id) - tool_args["group_id"] = str(self.group_id) - - - tool_input = { - "messages": [ - AIMessage( - content="", - tool_calls=[{ - "name": self.tool_name, - "args": tool_args, - "id": self.id + f"{tool_call}", - }] - ) - ] - } - result = await self.tool_node.ainvoke(tool_input) - result_text = str(result) - - return {"messages": [AIMessage(content=result_text)]} - - -@asynccontextmanager -async def make_read_graph(namespace,tools,search_switch,apply_id,group_id,config_id=None,storage_type=None,user_rag_memory_id=None): - memory = InMemorySaver() - tool=[i.name for i in tools ] - logger.info(f"Initializing read graph with tools: {tool}") - if config_id: - logger.info(f"使用配置 ID: {config_id}") - - # Extract tool functions - Split_The_Problem_ = next((t for t in tools if t.name == "Split_The_Problem"), None) - Problem_Extension_ = next((t for t in tools if t.name == "Problem_Extension"), None) - Retrieve_ = next((t for t in tools if t.name == "Retrieve"), None) - Verify_ = next((t for t in tools if t.name == "Verify"), None) - Summary_ = next((t for t in tools if t.name == "Summary"), None) - Summary_fails_ = next((t for t in tools if t.name == "Summary_fails"), None) - Retrieve_Summary_ = next((t for t in tools if t.name == "Retrieve_Summary"), None) - Input_Summary_ = next((t for t in tools if t.name == "Input_Summary"), None) - - # Instantiate services - parameter_builder = ParameterBuilder() - multimodal_processor = MultimodalProcessor() - - # Create nodes using new modular components - Split_The_Problem_node = ToolNode([Split_The_Problem_]) - - Problem_Extension_node = ToolExecutionNode( - tool=Problem_Extension_, - node_id="Problem_Extension_id", - namespace=namespace, - search_switch=search_switch, - apply_id=apply_id, - group_id=group_id, - parameter_builder=parameter_builder, - storage_type=storage_type, - user_rag_memory_id=user_rag_memory_id - ) - - Retrieve_node = ToolExecutionNode( - tool=Retrieve_, - node_id="Retrieve_id", - namespace=namespace, - search_switch=search_switch, - apply_id=apply_id, - group_id=group_id, - parameter_builder=parameter_builder, - storage_type=storage_type, - user_rag_memory_id=user_rag_memory_id - ) - - Verify_node = ToolExecutionNode( - tool=Verify_, - node_id="Verify_id", - namespace=namespace, - search_switch=search_switch, - apply_id=apply_id, - group_id=group_id, - parameter_builder=parameter_builder, - storage_type=storage_type, - user_rag_memory_id=user_rag_memory_id - ) - - Summary_node = ToolExecutionNode( - tool=Summary_, - node_id="Summary_id", - namespace=namespace, - search_switch=search_switch, - apply_id=apply_id, - group_id=group_id, - parameter_builder=parameter_builder, - storage_type=storage_type, - user_rag_memory_id=user_rag_memory_id - ) - - Summary_fails_node = ToolExecutionNode( - tool=Summary_fails_, - node_id="Summary_fails_id", - namespace=namespace, - search_switch=search_switch, - apply_id=apply_id, - group_id=group_id, - parameter_builder=parameter_builder, - storage_type=storage_type, - user_rag_memory_id=user_rag_memory_id - ) - - Retrieve_Summary_node = ToolExecutionNode( - tool=Retrieve_Summary_, - node_id="Retrieve_Summary_id", - namespace=namespace, - search_switch=search_switch, - apply_id=apply_id, - group_id=group_id, - parameter_builder=parameter_builder, - storage_type=storage_type, - user_rag_memory_id=user_rag_memory_id - ) - - Input_Summary_node = ToolExecutionNode( - tool=Input_Summary_, - node_id="Input_Summary_id", - namespace=namespace, - search_switch=search_switch, - apply_id=apply_id, - group_id=group_id, - parameter_builder=parameter_builder, - storage_type=storage_type, - user_rag_memory_id=user_rag_memory_id - ) - - - async def content_input_node(state): - state_search_switch = state.get("search_switch", search_switch) - - tool_name = "Input_Summary" if state_search_switch == '2' else "Split_The_Problem" - session_prefix = "input_summary_call_id" if state_search_switch == '2' else "split_call_id" - - return await create_input_message( - state=state, - tool_name=tool_name, - session_id=f"{session_prefix}_{namespace}", - search_switch=search_switch, - apply_id=apply_id, - group_id=group_id, - multimodal_processor=multimodal_processor - ) - - - # Build workflow graph - workflow = StateGraph(ReadState) - workflow.add_node("content_input", content_input_node) - workflow.add_node("Split_The_Problem", Split_The_Problem_node) - workflow.add_node("Problem_Extension", Problem_Extension_node) - workflow.add_node("Retrieve", Retrieve_node) - workflow.add_node("Verify", Verify_node) - workflow.add_node("Summary", Summary_node) - workflow.add_node("Summary_fails", Summary_fails_node) - workflow.add_node("Retrieve_Summary", Retrieve_Summary_node) - workflow.add_node("Input_Summary", Input_Summary_node) - - # Add edges using imported routers - workflow.add_edge(START, "content_input") - workflow.add_conditional_edges("content_input", Split_continue) - workflow.add_edge("Input_Summary", END) - workflow.add_edge("Split_The_Problem", "Problem_Extension") - workflow.add_edge("Problem_Extension", "Retrieve") - workflow.add_conditional_edges("Retrieve", Retrieve_continue) - workflow.add_edge("Retrieve_Summary", END) - workflow.add_conditional_edges("Verify", Verify_continue) - workflow.add_edge("Summary_fails", END) - workflow.add_edge("Summary", END) - - graph = workflow.compile(checkpointer=memory) - yield graph - - -# 添加到文件末尾或创建新的执行脚本 -# 在 memory_agent_service.py 文件中添加以下函数 - diff --git a/app/core/memory/agent/langgraph_graph/routing/__init__.py b/app/core/memory/agent/langgraph_graph/routing/__init__.py deleted file mode 100644 index a9366bd0..00000000 --- a/app/core/memory/agent/langgraph_graph/routing/__init__.py +++ /dev/null @@ -1,13 +0,0 @@ -"""LangGraph routing logic.""" - -from app.core.memory.agent.langgraph_graph.routing.routers import ( - Verify_continue, - Retrieve_continue, - Split_continue, -) - -__all__ = [ - "Verify_continue", - "Retrieve_continue", - "Split_continue", -] diff --git a/app/core/memory/agent/langgraph_graph/routing/routers.py b/app/core/memory/agent/langgraph_graph/routing/routers.py deleted file mode 100644 index c8abd544..00000000 --- a/app/core/memory/agent/langgraph_graph/routing/routers.py +++ /dev/null @@ -1,123 +0,0 @@ -""" -Routing functions for LangGraph conditional edges. - -This module provides routing functions that determine the next node to execute -based on state values. All functions return Literal types for type safety. -""" - -import logging -import re -from typing import Literal - -from app.core.memory.agent.langgraph_graph.state.extractors import extract_search_switch -from app.core.memory.agent.utils.llm_tools import ReadState, COUNTState - -logger = logging.getLogger(__name__) - -# Global counter for Verify routing -counter = COUNTState(limit=3) - - -def Verify_continue(state: ReadState) -> Literal["Summary", "Summary_fails", "content_input"]: - """ - Determine routing after Verify node based on verification result. - - This function checks the verification result in the last message and routes to: - - Summary: if verification succeeded - - content_input: if verification failed and retry limit not reached - - Summary_fails: if verification failed and retry limit reached - - Args: - state: LangGraph state containing messages - - Returns: - Next node name as Literal type - """ - messages = state.get("messages", []) - - # Boundary check - if not messages: - logger.warning("[Verify_continue] No messages in state, defaulting to Summary") - counter.reset() - return "Summary" - - # Increment counter - counter.add(1) - loop_count = counter.get_total() - logger.debug(f"[Verify_continue] Current loop count: {loop_count}") - - # Extract verification result from last message - last_message = messages[-1] - last_message_str = str(last_message).replace('\\', '') - status_tools = re.findall(r'"split_result": "(.*?)"', last_message_str) - logger.debug(f"[Verify_continue] Status tools: {status_tools}") - - # Route based on verification result - if "success" in status_tools: - counter.reset() - return "Summary" - elif "failed" in status_tools: - if loop_count < 2: # Max retry count is 2 - return "content_input" - else: - counter.reset() - return "Summary_fails" - else: - # Default to Summary if status is unclear - counter.reset() - return "Summary" - - -def Retrieve_continue(state: dict) -> Literal["Verify", "Retrieve_Summary"]: - """ - Determine routing after Retrieve node based on search_switch value. - - This function routes based on the search_switch parameter: - - search_switch == '0': Route to Verify (verification needed) - - search_switch == '1': Route to Retrieve_Summary (direct summary) - - Args: - state: LangGraph state dictionary - - Returns: - Next node name as Literal type - """ - search_switch = extract_search_switch(state) - - logger.debug(f"[Retrieve_continue] search_switch: {search_switch}") - - if search_switch == '0': - return 'Verify' - elif search_switch == '1': - return 'Retrieve_Summary' - - # Default to Retrieve_Summary - logger.debug("[Retrieve_continue] No valid search_switch, defaulting to Retrieve_Summary") - return 'Retrieve_Summary' - - -def Split_continue(state: dict) -> Literal["Split_The_Problem", "Input_Summary"]: - """ - Determine routing after content_input node based on search_switch value. - - This function routes based on the search_switch parameter: - - search_switch == '2': Route to Input_Summary (direct input summary) - - Otherwise: Route to Split_The_Problem (problem decomposition) - - Args: - state: LangGraph state dictionary - - Returns: - Next node name as Literal type - """ - logger.debug(f"[Split_continue] state keys: {state.keys()}") - - search_switch = extract_search_switch(state) - - logger.debug(f"[Split_continue] search_switch: {search_switch}") - - if search_switch == '2': - return 'Input_Summary' - - # Default to Split_The_Problem - return 'Split_The_Problem' diff --git a/app/core/memory/agent/langgraph_graph/state/__init__.py b/app/core/memory/agent/langgraph_graph/state/__init__.py deleted file mode 100644 index 279c6463..00000000 --- a/app/core/memory/agent/langgraph_graph/state/__init__.py +++ /dev/null @@ -1,13 +0,0 @@ -"""LangGraph state management utilities.""" - -from app.core.memory.agent.langgraph_graph.state.extractors import ( - extract_search_switch, - extract_tool_call_id, - extract_content_payload, -) - -__all__ = [ - "extract_search_switch", - "extract_tool_call_id", - "extract_content_payload", -] diff --git a/app/core/memory/agent/langgraph_graph/state/extractors.py b/app/core/memory/agent/langgraph_graph/state/extractors.py deleted file mode 100644 index 92bec147..00000000 --- a/app/core/memory/agent/langgraph_graph/state/extractors.py +++ /dev/null @@ -1,164 +0,0 @@ -""" -State extraction utilities for type-safe access to LangGraph state values. - -This module provides utility functions for extracting values from LangGraph state -dictionaries with proper error handling and sensible defaults. -""" - -import json -import logging -from typing import Any, Optional - -logger = logging.getLogger(__name__) - -def extract_search_switch(state: dict) -> Optional[str]: - """ - Extract search_switch from state or messages. - """ - - search_switch = state.get("search_switch") - - if search_switch is not None: - return str(search_switch) - - # Try to extract from messages - messages = state.get("messages", []) - if not messages: - return None - - # 从最新的消息开始查找 - for message in reversed(messages): - # 尝试从 tool_calls 中提取 - if hasattr(message, "tool_calls") and message.tool_calls: - for tool_call in message.tool_calls: - if isinstance(tool_call, dict): - # 从 tool_call 的 args 中提取 - if "args" in tool_call and isinstance(tool_call["args"], dict): - search_switch = tool_call["args"].get("search_switch") - if search_switch is not None: - return str(search_switch) - # 直接从 tool_call 中提取 - search_switch = tool_call.get("search_switch") - if search_switch is not None: - return str(search_switch) - - # 尝试从 content 中提取(如果是 JSON 格式) - if hasattr(message, "content"): - try: - import json - if isinstance(message.content, str): - content_data = json.loads(message.content) - if isinstance(content_data, dict): - search_switch = content_data.get("search_switch") - if search_switch is not None: - return str(search_switch) - except (json.JSONDecodeError, ValueError): - pass - - return None - - -def extract_tool_call_id(message: Any) -> str: - """ - Extract tool call ID from message using structured attributes. - - This function extracts the tool call ID from a message object, handling both - direct attribute access and tool_calls list structures. - - Args: - message: Message object (typically ToolMessage or AIMessage) - - Returns: - Tool call ID as string - - Raises: - ValueError: If tool call ID cannot be extracted - - Examples: - >>> message = ToolMessage(content="...", tool_call_id="call_123") - >>> extract_tool_call_id(message) - 'call_123' - """ - # Try direct attribute access for ToolMessage - if hasattr(message, "tool_call_id"): - tool_call_id = message.tool_call_id - if tool_call_id: - return str(tool_call_id) - - # Try extracting from tool_calls list for AIMessage - if hasattr(message, "tool_calls") and message.tool_calls: - tool_call = message.tool_calls[0] - if isinstance(tool_call, dict) and "id" in tool_call: - return str(tool_call["id"]) - - # Try extracting from id attribute - if hasattr(message, "id"): - message_id = message.id - if message_id: - return str(message_id) - - # If all else fails, raise an error - raise ValueError(f"Could not extract tool call ID from message: {type(message)}") - - -def extract_content_payload(message: Any) -> Any: - """ - Extract content payload from ToolMessage, parsing JSON if needed. - - This function extracts the content from a message and attempts to parse it as JSON - if it appears to be a JSON string. It handles various message formats and provides - sensible fallbacks. - - Args: - message: Message object (typically ToolMessage) - - Returns: - Parsed content (dict, list, or str) - - Examples: - >>> message = ToolMessage(content='{"key": "value"}') - >>> extract_content_payload(message) - {'key': 'value'} - - >>> message = ToolMessage(content='plain text') - >>> extract_content_payload(message) - 'plain text' - """ - # Extract raw content - # For ToolMessages (responses from tools), extract from content - if hasattr(message, "content"): - raw_content = message.content - - # If content is empty and this is an AIMessage with tool_calls, - # extract from args (this handles the initial tool call from content_input) - if not raw_content and hasattr(message, "tool_calls") and message.tool_calls: - tool_call = message.tool_calls[0] - if isinstance(tool_call, dict) and "args" in tool_call: - return tool_call["args"] - else: - raw_content = str(message) - - # If content is already a dict or list, return it directly - if isinstance(raw_content, (dict, list)): - return raw_content - - # Try to parse as JSON - if isinstance(raw_content, str): - # First, try direct JSON parsing - try: - return json.loads(raw_content) - except (json.JSONDecodeError, ValueError): - pass - - # If that fails, try to extract JSON from the string - # This handles cases where the content is embedded in a larger string - import re - json_candidates = re.findall(r'[\[{].*[\]}]', raw_content, flags=re.DOTALL) - for candidate in json_candidates: - try: - return json.loads(candidate) - except (json.JSONDecodeError, ValueError): - continue - - # If all parsing attempts fail, return the raw content - return raw_content diff --git a/app/core/memory/agent/langgraph_graph/write_graph.py b/app/core/memory/agent/langgraph_graph/write_graph.py deleted file mode 100644 index dbdc51d6..00000000 --- a/app/core/memory/agent/langgraph_graph/write_graph.py +++ /dev/null @@ -1,78 +0,0 @@ -import asyncio -import json -from contextlib import asynccontextmanager -from langgraph.constants import START, END -from langgraph.graph import add_messages, StateGraph - -from langgraph.prebuilt import ToolNode -from app.core.memory.agent.utils.llm_tools import WriteState -import warnings -import sys -from langchain_core.messages import AIMessage -from app.core.logging_config import get_agent_logger - -warnings.filterwarnings("ignore", category=RuntimeWarning) - -logger = get_agent_logger(__name__) - -if sys.platform.startswith("win"): - import asyncio - asyncio.set_event_loop_policy(asyncio.WindowsSelectorEventLoopPolicy()) -@asynccontextmanager -async def make_write_graph(user_id, tools, apply_id, group_id, config_id=None): - logger.info("加载 MCP 工具: %s", [t.name for t in tools]) - if config_id: - logger.info(f"使用配置 ID: {config_id}") - - data_type_tool = next((t for t in tools if t.name == "Data_type_differentiation"), None) - data_write_tool = next((t for t in tools if t.name == "Data_write"), None) - - if not data_type_tool or not data_write_tool: - logger.error('不存在数据存储工具', exc_info=True) - raise ValueError('不存在数据存储工具') - # ToolNode - write_node = ToolNode([data_write_tool]) - - - async def call_model(state): - messages = state["messages"] - last_message = messages[-1] - - result = await data_type_tool.ainvoke({ - "context": last_message[1] if isinstance(last_message, tuple) else last_message.content - }) - result=json.loads( result) - - # 调用 Data_write,传递 config_id - write_params = { - "content": result["context"], - "apply_id": apply_id, - "group_id": group_id, - "user_id": user_id - } - - # 如果提供了 config_id,添加到参数中 - if config_id: - write_params["config_id"] = config_id - logger.debug(f"传递 config_id 到 Data_write: {config_id}") - - write_result = await data_write_tool.ainvoke(write_params) - - if isinstance(write_result, dict): - content = write_result.get("data", str(write_result)) - else: - content = str(write_result) - logger.info("写入内容: %s", content) - return {"messages": [AIMessage(content=content)]} - - workflow = StateGraph(WriteState) - workflow.add_node("content_input", call_model) - workflow.add_node("save_neo4j", write_node) - workflow.add_edge(START, "content_input") - workflow.add_edge("content_input", "save_neo4j") - workflow.add_edge("save_neo4j", END) - - graph = workflow.compile() - - - yield graph diff --git a/app/core/memory/agent/logger_file/log_streamer.py b/app/core/memory/agent/logger_file/log_streamer.py deleted file mode 100644 index 4d98266c..00000000 --- a/app/core/memory/agent/logger_file/log_streamer.py +++ /dev/null @@ -1,285 +0,0 @@ -""" -Log Streamer Module - -Manages streaming of log file content with file watching and real-time transmission. -""" -import os -import re -import time -import asyncio -from typing import AsyncGenerator, Optional -from pathlib import Path - -from app.core.logging_config import get_logger - -logger = get_logger(__name__) - - -class LogStreamer: - """Manages log file streaming with file watching and content transmission""" - - def __init__(self, log_path: str, keepalive_interval: int = 300): - """ - Initialize LogStreamer - - Args: - log_path: Path to the log file to stream - keepalive_interval: Interval in seconds for sending keepalive messages (default: 300) - """ - self.log_path = log_path - self.keepalive_interval = keepalive_interval - self.last_position = 0 - - # Pattern to match and remove timestamp and log level prefix - # Matches: "YYYY-MM-DD HH:MM:SS,mmm - [LEVEL] - module_name - " - # This pattern is comprehensive to handle various log formats - self.pattern = re.compile( - r'^\d{4}-\d{2}-\d{2} \d{2}:\d{2}:\d{2},\d{3} - \[(?:INFO|DEBUG|WARNING|ERROR|CRITICAL)\] - \S+ - ' - ) - - logger.info(f"LogStreamer initialized for {log_path}") - - @staticmethod - def clean_log_line(line: str) -> str: - """ - Static method to clean log entry by removing timestamp and log level prefix. - This is the canonical log cleaning method used by both file mode and transmission mode. - - Args: - line: Raw log line - - Returns: - Cleaned log line without timestamp and log level prefix - """ - # Pattern to match and remove timestamp and log level prefix - # Matches: "YYYY-MM-DD HH:MM:SS,mmm - [LEVEL] - module_name - " - pattern = re.compile( - r'^\d{4}-\d{2}-\d{2} \d{2}:\d{2}:\d{2},\d{3} - \[(?:INFO|DEBUG|WARNING|ERROR|CRITICAL)\] - \S+ - ' - ) - cleaned = re.sub(pattern, '', line) - return cleaned - - def clean_log_entry(self, line: str) -> str: - """ - Clean log entry by removing timestamp and log level prefix. - This instance method delegates to the static method for consistency. - - Args: - line: Raw log line - - Returns: - Cleaned log line without timestamp and log level prefix - """ - return LogStreamer.clean_log_line(line) - - async def send_keepalive(self) -> dict: - """ - Generate keepalive message - - Returns: - Keepalive message dict with timestamp - """ - return { - "event": "keepalive", - "data": { - "timestamp": int(time.time()) - } - } - - async def read_existing_and_stream(self) -> AsyncGenerator[dict, None]: - """ - Read existing log content first, then watch for new content - - This method reads all existing content in the file first, - then continues to watch for new content as it's written. - - Yields: - Dict messages with event type and data: - - log events: {"event": "log", "data": {"content": "...", "timestamp": ...}} - - keepalive events: {"event": "keepalive", "data": {"timestamp": ...}} - - error events: {"event": "error", "data": {"code": ..., "message": "...", "error": "..."}} - - done events: {"event": "done", "data": {"message": "..."}} - """ - logger.info(f"Starting log stream (read existing) for {self.log_path}") - - # Check if file exists - if not os.path.exists(self.log_path): - logger.error(f"Log file not found: {self.log_path}") - yield { - "event": "error", - "data": { - "code": 4006, - "message": "日志文件不存在", - "error": f"File not found: {self.log_path}" - } - } - return - - try: - with open(self.log_path, 'r', encoding='utf-8') as f: - # First, read all existing content - for line in f: - if line.strip(): # Skip empty lines - cleaned_line = self.clean_log_entry(line) - yield { - "event": "log", - "data": { - "content": cleaned_line.rstrip('\n'), - "timestamp": int(time.time()) - } - } - - # Now watch for new content - self.last_position = f.tell() - last_keepalive = time.time() - - while True: - line = f.readline() - if line: - cleaned_line = self.clean_log_entry(line) - yield { - "event": "log", - "data": { - "content": cleaned_line.rstrip('\n'), - "timestamp": int(time.time()) - } - } - last_keepalive = time.time() - else: - # No new content, check if we need to send keepalive - current_time = time.time() - if current_time - last_keepalive >= self.keepalive_interval: - keepalive_msg = await self.send_keepalive() - yield keepalive_msg - last_keepalive = current_time - - # Sleep briefly before checking again - await asyncio.sleep(0.1) - - except FileNotFoundError: - logger.error(f"Log file disappeared during streaming: {self.log_path}") - yield { - "event": "error", - "data": { - "code": 4006, - "message": "日志文件在流式传输期间变得不可用", - "error": "File not found during streaming" - } - } - except Exception as e: - logger.error(f"Error during log streaming: {e}", exc_info=True) - yield { - "event": "error", - "data": { - "code": 8001, - "message": "流式传输期间发生错误", - "error": str(e) - } - } - finally: - logger.info(f"Log stream ended for {self.log_path}") - yield { - "event": "done", - "data": { - "message": "流式传输完成" - } - } - - async def watch_and_stream(self) -> AsyncGenerator[dict, None]: - """ - Watch log file and stream only new content as it's written - - This method starts from the end of the file and only streams - new content that is written after the stream starts. - - Yields: - Dict messages with event type and data: - - log events: {"event": "log", "data": {"content": "...", "timestamp": ...}} - - keepalive events: {"event": "keepalive", "data": {"timestamp": ...}} - - error events: {"event": "error", "data": {"code": ..., "message": "...", "error": "..."}} - - done events: {"event": "done", "data": {"message": "..."}} - """ - logger.info(f"Starting log stream (new content only) for {self.log_path}") - - # Check if file exists - if not os.path.exists(self.log_path): - logger.error(f"Log file not found: {self.log_path}") - yield { - "event": "error", - "data": { - "code": 4006, - "message": "日志文件不存在", - "error": f"File not found: {self.log_path}" - } - } - return - - try: - # Open file and seek to end to start streaming new content - with open(self.log_path, 'r', encoding='utf-8') as f: - # Move to end of file - f.seek(0, os.SEEK_END) - self.last_position = f.tell() - - last_keepalive = time.time() - - while True: - # Check if file has new content - current_position = f.tell() - - # Read new lines if available - line = f.readline() - if line: - # Clean the log entry - cleaned_line = self.clean_log_entry(line) - - # Yield log event - yield { - "event": "log", - "data": { - "content": cleaned_line.rstrip('\n'), - "timestamp": int(time.time()) - } - } - - # Update last keepalive time since we sent data - last_keepalive = time.time() - else: - # No new content, check if we need to send keepalive - current_time = time.time() - if current_time - last_keepalive >= self.keepalive_interval: - keepalive_msg = await self.send_keepalive() - yield keepalive_msg - last_keepalive = current_time - - # Sleep briefly before checking again - await asyncio.sleep(0.1) - - except FileNotFoundError: - logger.error(f"Log file disappeared during streaming: {self.log_path}") - yield { - "event": "error", - "data": { - "code": 4006, - "message": "日志文件在流式传输期间变得不可用", - "error": "File not found during streaming" - } - } - except Exception as e: - logger.error(f"Error during log streaming: {e}", exc_info=True) - yield { - "event": "error", - "data": { - "code": 8001, - "message": "流式传输期间发生错误", - "error": str(e) - } - } - finally: - logger.info(f"Log stream ended for {self.log_path}") - yield { - "event": "done", - "data": { - "message": "流式传输完成" - } - } diff --git a/app/core/memory/agent/logger_file/logger_data.py b/app/core/memory/agent/logger_file/logger_data.py deleted file mode 100644 index fb5e3e54..00000000 --- a/app/core/memory/agent/logger_file/logger_data.py +++ /dev/null @@ -1,32 +0,0 @@ -""" -Agent logger module for backward compatibility. - -This module maintains the get_named_logger() function for backward compatibility -while delegating to the centralized logging configuration. - -All new code should import directly from app.core.logging_config instead. -""" - -__version__ = "0.1.0" -__author__ = "RED_BEAR" - -from app.core.logging_config import get_agent_logger - - -def get_named_logger(name): - """Get a named logger for agent operations. - - This function maintains backward compatibility with existing code. - It delegates to the centralized get_agent_logger() function. - - Args: - name: Logger name for namespacing - - Returns: - Logger configured for agent operations - - Example: - >>> logger = get_named_logger("my_agent") - >>> logger.info("Agent operation started") - """ - return get_agent_logger(name) diff --git a/app/core/memory/agent/mcp_server/__init__.py b/app/core/memory/agent/mcp_server/__init__.py deleted file mode 100644 index 61a804c5..00000000 --- a/app/core/memory/agent/mcp_server/__init__.py +++ /dev/null @@ -1,28 +0,0 @@ -""" -MCP Server package for memory agent. - -This package provides the FastMCP server implementation with context-based -dependency injection for tool functions. - -Package structure: -- server: FastMCP server initialization and context setup -- tools: MCP tool implementations -- models: Pydantic response models -- services: Business logic services -""" -from app.core.memory.agent.mcp_server.server import ( - mcp, - initialize_context, - main, - get_context_resource -) - -# Import tools to register them (but don't export them) -from app.core.memory.agent.mcp_server import tools - -__all__ = [ - 'mcp', - 'initialize_context', - 'main', - 'get_context_resource', -] \ No newline at end of file diff --git a/app/core/memory/agent/mcp_server/mcp_instance.py b/app/core/memory/agent/mcp_server/mcp_instance.py deleted file mode 100644 index 3a2eeb78..00000000 --- a/app/core/memory/agent/mcp_server/mcp_instance.py +++ /dev/null @@ -1,11 +0,0 @@ -""" -MCP Server Instance - -This module contains the FastMCP server instance that is shared across all modules. -It's in a separate file to avoid circular import issues. -""" -from mcp.server.fastmcp import FastMCP - -# Initialize FastMCP server instance -# This instance is shared across all tool modules -mcp = FastMCP('data_flow') diff --git a/app/core/memory/agent/mcp_server/models/__init__.py b/app/core/memory/agent/mcp_server/models/__init__.py deleted file mode 100644 index 2d096f92..00000000 --- a/app/core/memory/agent/mcp_server/models/__init__.py +++ /dev/null @@ -1,30 +0,0 @@ -"""Pydantic models for MCP server responses.""" - -from .problem_models import ( - ProblemBreakdownItem, - ProblemBreakdownResponse, - ExtendedQuestionItem, - ProblemExtensionResponse, -) -from .summary_models import ( - SummaryData, - SummaryResponse, - RetrieveSummaryData, - RetrieveSummaryResponse, -) -from .verification_models import VerificationResult -from .retrieval_models import RetrievalResult, DistinguishTypeResponse - -__all__ = [ - "ProblemBreakdownItem", - "ProblemBreakdownResponse", - "ExtendedQuestionItem", - "ProblemExtensionResponse", - "SummaryData", - "SummaryResponse", - "RetrieveSummaryData", - "RetrieveSummaryResponse", - "VerificationResult", - "RetrievalResult", - "DistinguishTypeResponse", -] diff --git a/app/core/memory/agent/mcp_server/models/problem_models.py b/app/core/memory/agent/mcp_server/models/problem_models.py deleted file mode 100644 index de08f3fa..00000000 --- a/app/core/memory/agent/mcp_server/models/problem_models.py +++ /dev/null @@ -1,34 +0,0 @@ -"""Pydantic models for problem breakdown and extension operations.""" - -from typing import List, Optional -from pydantic import BaseModel, Field, RootModel - - -class ProblemBreakdownItem(BaseModel): - """Individual item in problem breakdown response.""" - - id: str - question: str - type: str - reason: Optional[str] = None - - -class ProblemBreakdownResponse(RootModel[List[ProblemBreakdownItem]]): - """Response model for problem breakdown containing list of breakdown items.""" - - pass - - -class ExtendedQuestionItem(BaseModel): - """Individual extended question item with reasoning.""" - - original_question: str = Field(..., description="原始初步问题") - extended_question: str = Field(..., description="扩展后的问题") - type: str = Field(..., description="类型(事实检索 / 澄清 / 定义 / 比较 / 行动建议等)") - reason: str = Field(..., description="生成该扩展问题的理由") - - -class ProblemExtensionResponse(RootModel[List[ExtendedQuestionItem]]): - """Response model for problem extension containing list of extended questions.""" - - pass diff --git a/app/core/memory/agent/mcp_server/models/retrieval_models.py b/app/core/memory/agent/mcp_server/models/retrieval_models.py deleted file mode 100644 index e8c08c89..00000000 --- a/app/core/memory/agent/mcp_server/models/retrieval_models.py +++ /dev/null @@ -1,17 +0,0 @@ -"""Pydantic models for retrieval operations.""" - -from typing import List, Dict, Any -from pydantic import BaseModel - - -class RetrievalResult(BaseModel): - """Result model for retrieval operation.""" - - Query: str - Expansion_issue: List[Dict[str, Any]] - - -class DistinguishTypeResponse(BaseModel): - """Response model for data type differentiation.""" - - type: str diff --git a/app/core/memory/agent/mcp_server/models/summary_models.py b/app/core/memory/agent/mcp_server/models/summary_models.py deleted file mode 100644 index bffe486a..00000000 --- a/app/core/memory/agent/mcp_server/models/summary_models.py +++ /dev/null @@ -1,31 +0,0 @@ -"""Pydantic models for summary operations.""" - -from typing import List -from pydantic import BaseModel, Field - - -class SummaryData(BaseModel): - """Data structure for summary input.""" - - query: str - history: List[str] = Field(default_factory=list) - retrieve_info: List[str] = Field(default_factory=list) - - -class SummaryResponse(BaseModel): - """Response model for summary operation.""" - - data: SummaryData - query_answer: str - - -class RetrieveSummaryData(BaseModel): - """Data structure for retrieve summary response.""" - - query_answer: str = Field(default="") - - -class RetrieveSummaryResponse(BaseModel): - """Response model for retrieve summary operation.""" - - data: RetrieveSummaryData diff --git a/app/core/memory/agent/mcp_server/models/verification_models.py b/app/core/memory/agent/mcp_server/models/verification_models.py deleted file mode 100644 index bd8896b3..00000000 --- a/app/core/memory/agent/mcp_server/models/verification_models.py +++ /dev/null @@ -1,14 +0,0 @@ -"""Pydantic models for verification operations.""" - -from typing import List, Optional, Dict, Any -from pydantic import BaseModel, Field - - -class VerificationResult(BaseModel): - """Result model for verification operation.""" - - query: str - expansion_issue: List[Dict[str, Any]] - split_result: str - reason: Optional[str] = None - history: List[Dict[str, Any]] = Field(default_factory=list) diff --git a/app/core/memory/agent/mcp_server/server.py b/app/core/memory/agent/mcp_server/server.py deleted file mode 100644 index 6cb454ee..00000000 --- a/app/core/memory/agent/mcp_server/server.py +++ /dev/null @@ -1,161 +0,0 @@ -""" -MCP Server initialization with FastMCP context setup. - -This module initializes the FastMCP server and registers shared resources -in the context for dependency injection into tool functions. -""" -import os -import sys -from mcp.server.fastmcp import FastMCP - -from app.core.config import settings -from app.core.logging_config import get_agent_logger -from app.core.memory.agent.utils.redis_tool import RedisSessionStore, store -from app.core.memory.agent.utils.llm_tools import PROJECT_ROOT_ -from app.core.memory.utils.config.definitions import SELECTED_LLM_ID,reload_configuration_from_database -from app.core.memory.utils.llm.llm_utils import get_llm_client -from app.core.memory.agent.mcp_server.services.template_service import TemplateService -from app.core.memory.agent.mcp_server.services.search_service import SearchService -from app.core.memory.agent.mcp_server.services.session_service import SessionService -from app.core.memory.agent.mcp_server.mcp_instance import mcp - - -logger = get_agent_logger(__name__) - - -def get_context_resource(ctx, resource_name: str): - """ - Helper function to retrieve a resource from the FastMCP context. - - Args: - ctx: FastMCP Context object (passed to tool functions) - resource_name: Name of the resource to retrieve - - Returns: - The requested resource - - Raises: - AttributeError: If the resource doesn't exist - - Example: - @mcp.tool() - async def my_tool(ctx: Context): - template_service = get_context_resource(ctx, 'template_service') - llm_client = get_context_resource(ctx, 'llm_client') - """ - if not hasattr(ctx, 'fastmcp') or ctx.fastmcp is None: - raise RuntimeError("Context does not have fastmcp attribute") - - if not hasattr(ctx.fastmcp, resource_name): - raise AttributeError( - f"Resource '{resource_name}' not found in context. " - f"Available resources: {[k for k in dir(ctx.fastmcp) if not k.startswith('_')]}" - ) - - return getattr(ctx.fastmcp, resource_name) - - -def initialize_context(): - """ - Initialize and register shared resources in FastMCP context. - - This function sets up all shared resources that will be available - to tool functions via dependency injection through the context parameter. - - Resources are stored as attributes on the FastMCP instance and can be - accessed via ctx.fastmcp in tool functions. - - Resources registered: - - session_store: RedisSessionStore for session management - - llm_client: LLM client for structured API calls - - app_settings: Application settings (renamed to avoid conflict with FastMCP settings) - - template_service: Service for template rendering - - search_service: Service for hybrid search - - session_service: Service for session operations - """ - try: - # Register Redis session store - logger.info("Registering session_store in context") - mcp.session_store = store - - # Register LLM client - try: - logger.info(f"Registering llm_client in context with model ID: {SELECTED_LLM_ID}") - llm_client = get_llm_client(SELECTED_LLM_ID) - mcp.llm_client = llm_client - logger.info("llm_client registered successfully") - except Exception as e: - logger.error(f"Failed to register llm_client: {e}", exc_info=True) - # 注册一个 None 值,避免工具调用时找不到资源 - mcp.llm_client = None - logger.warning("llm_client set to None due to initialization failure") - - # Register application settings (renamed to avoid conflict with FastMCP's settings) - logger.info("Registering app_settings in context") - mcp.app_settings = settings - - # Register template service - template_root = PROJECT_ROOT_ + '/agent/utils/prompt' - # logger.info(f"Registering template_service in context with root: {template_root}") - template_service = TemplateService(template_root) - mcp.template_service = template_service - - # Register search service - # logger.info("Registering search_service in context") - search_service = SearchService() - mcp.search_service = search_service - - # Register session service - # logger.info("Registering session_service in context") - session_service = SessionService(store) - mcp.session_service = session_service - - # logger.info("All context resources registered successfully") - - except Exception as e: - logger.error(f"Failed to initialize context: {e}", exc_info=True) - raise - - -def main(): - """ - Main entry point for the MCP server. - - Initializes context and starts the server with SSE transport. - """ - try: - # logger.info("Starting MCP server initialization") - reload_configuration_from_database(config_id=os.getenv("config_id"), force_reload=True) - # Initialize context resources - initialize_context() - - # Import and register tools - # logger.info("Importing MCP tools") - from app.core.memory.agent.mcp_server.tools import ( - problem_tools, - retrieval_tools, - verification_tools, - summary_tools, - data_tools - ) - # logger.info("All MCP tools imported and registered") - - # Log registered tools for debugging - import asyncio - tools_list = asyncio.run(mcp.list_tools()) - # logger.info(f"Registered {len(tools_list)} MCP tools: {[t.name for t in tools_list]}") - # logger.info(f"Starting MCP server on {settings.SERVER_IP}:8081 with SSE transport") - - # Run the server with SSE transport for HTTP connections - # The server will be available at http://127.0.0.1:8081 - import uvicorn - app = mcp.sse_app() - uvicorn.run(app, host=settings.SERVER_IP, port=8081, log_level="info") - - except Exception as e: - logger.error(f"Failed to start MCP server: {e}", exc_info=True) - sys.exit(1) - - -if __name__ == "__main__": - main() diff --git a/app/core/memory/agent/mcp_server/services/__init__.py b/app/core/memory/agent/mcp_server/services/__init__.py deleted file mode 100644 index aab51c0c..00000000 --- a/app/core/memory/agent/mcp_server/services/__init__.py +++ /dev/null @@ -1,23 +0,0 @@ -""" -MCP Server Services - -This module provides business logic services for the MCP server: -- TemplateService: Template loading and rendering -- SearchService: Search result processing -- SessionService: Session and history management -- ParameterBuilder: Tool parameter construction -""" - -from .template_service import TemplateService, TemplateRenderError -from .search_service import SearchService -from .session_service import SessionService -from .parameter_builder import ParameterBuilder - - -__all__ = [ - "TemplateService", - "TemplateRenderError", - "SearchService", - "SessionService", - "ParameterBuilder", -] diff --git a/app/core/memory/agent/mcp_server/services/parameter_builder.py b/app/core/memory/agent/mcp_server/services/parameter_builder.py deleted file mode 100644 index 0da9dd22..00000000 --- a/app/core/memory/agent/mcp_server/services/parameter_builder.py +++ /dev/null @@ -1,157 +0,0 @@ -""" -Parameter Builder for constructing tool call arguments. - -This service provides tool-specific parameter transformation logic -to build correct arguments for each tool type. -""" -import json -from typing import Any, Dict, Optional - -from app.core.logging_config import get_agent_logger - - -logger = get_agent_logger(__name__) - - -class ParameterBuilder: - """Service for building tool call arguments based on tool type.""" - - def __init__(self): - """Initialize the parameter builder.""" - logger.info("ParameterBuilder initialized") - - def build_tool_args( - self, - tool_name: str, - content: Any, - tool_call_id: str, - search_switch: str, - apply_id: str, - group_id: str, - storage_type: Optional[str] = None, - user_rag_memory_id: Optional[str] = None - ) -> Dict[str, Any]: - """ - Build tool arguments based on tool type. - - Different tools expect different argument formats: - - Verify: dict context - - Retrieve: dict context + search_switch - - Summary/Summary_fails: JSON string context - - Retrieve_Summary: unwrap nested context structures - - Input_Summary: raw message string - - Args: - tool_name: Name of the tool being invoked - content: Parsed content from previous tool result - tool_call_id: Extracted tool call identifier - search_switch: Search routing parameter - apply_id: Application identifier - group_id: Group identifier - storage_type: Storage type for the workspace (optional) - user_rag_memory_id: User RAG memory ID for knowledge base retrieval (optional) - - Returns: - Dictionary of tool arguments ready for invocation - """ - # Base arguments common to most tools - base_args = { - "usermessages": tool_call_id, - "apply_id": apply_id, - "group_id": group_id - } - - # Always add storage_type and user_rag_memory_id (with defaults if None) - base_args["storage_type"] = storage_type if storage_type is not None else "" - base_args["user_rag_memory_id"] = user_rag_memory_id if user_rag_memory_id is not None else "" - - # Tool-specific argument construction - if tool_name == "Verify": - # Verify expects dict context - return { - "context": content if isinstance(content, dict) else {}, - **base_args - } - - elif tool_name == "Retrieve": - # Retrieve expects dict context + search_switch - return { - "context": content if isinstance(content, dict) else {}, - "search_switch": search_switch, - **base_args - } - - elif tool_name in ["Summary", "Summary_fails"]: - # Summary tools expect JSON string context - if isinstance(content, dict): - context_str = json.dumps(content, ensure_ascii=False) - elif isinstance(content, str): - context_str = content - else: - context_str = json.dumps({"data": content}, ensure_ascii=False) - - return { - "context": context_str, - **base_args - } - - elif tool_name == "Retrieve_Summary": - # Retrieve_Summary needs to unwrap nested context structures - # Handle both 'content' and 'context' keys - context_dict = content - - if isinstance(content, dict): - # Check for nested 'content' wrapper - if "content" in content: - inner = content["content"] - - # If it's a JSON string, parse it - if isinstance(inner, str): - try: - parsed = json.loads(inner) - # Check if parsed has 'context' wrapper - if isinstance(parsed, dict) and "context" in parsed: - context_dict = parsed["context"] - else: - context_dict = parsed - except json.JSONDecodeError: - logger.warning( - f"Failed to parse JSON content for {tool_name}: {inner[:100]}" - ) - context_dict = {"Query": "", "Expansion_issue": []} - elif isinstance(inner, dict): - context_dict = inner - - # Check for 'context' wrapper - elif "context" in content: - context_dict = content["context"] if isinstance(content["context"], dict) else content - - return { - "context": context_dict, - **base_args - } - - elif tool_name == "Input_Summary": - # Input_Summary expects raw message string + search_switch - # Content should be the raw message string - if isinstance(content, dict): - # Try to extract message from dict - message_str = content.get("sentence", str(content)) - else: - message_str = str(content) - - return { - "context": message_str, - "search_switch": search_switch, - **base_args - } - - else: - # Default: pass content as context - logger.warning( - f"Unknown tool name '{tool_name}', using default argument structure" - ) - return { - "context": content, - **base_args - } diff --git a/app/core/memory/agent/mcp_server/services/search_service.py b/app/core/memory/agent/mcp_server/services/search_service.py deleted file mode 100644 index 28dd82c7..00000000 --- a/app/core/memory/agent/mcp_server/services/search_service.py +++ /dev/null @@ -1,193 +0,0 @@ -""" -Search Service for executing hybrid search and processing results. - -This service provides clean search result processing with content extraction -and deduplication. -""" -from typing import List, Tuple, Optional - -from app.core.logging_config import get_agent_logger -from app.core.memory.src.search import run_hybrid_search -from app.core.memory.utils.data.text_utils import escape_lucene_query - - -logger = get_agent_logger(__name__) - - -class SearchService: - """Service for executing hybrid search and processing results.""" - - def __init__(self): - """Initialize the search service.""" - logger.info("SearchService initialized") - - def extract_content_from_result(self, result: dict) -> str: - """ - Extract only meaningful content from search results, dropping all metadata. - - Extraction rules by node type: - - Statements: extract 'statement' field - - Entities: extract 'name' and 'fact_summary' fields - - Summaries: extract 'content' field - - Chunks: extract 'content' field - - Args: - result: Search result dictionary - - Returns: - Clean content string without metadata - """ - if not isinstance(result, dict): - return str(result) - - content_parts = [] - - # Statements: extract statement field - if 'statement' in result and result['statement']: - content_parts.append(result['statement']) - - # Summaries/Chunks: extract content field - if 'content' in result and result['content']: - content_parts.append(result['content']) - - # Entities: extract name and fact_summary (commented out in original) - # if 'name' in result and result['name']: - # content_parts.append(result['name']) - # if result.get('fact_summary'): - # content_parts.append(result['fact_summary']) - - # Return concatenated content or empty string - return '\n'.join(content_parts) if content_parts else "" - - def clean_query(self, query: str) -> str: - """ - Clean and escape query text for Lucene. - - - Removes wrapping quotes - - Removes newlines and carriage returns - - Applies Lucene escaping - - Args: - query: Raw query string - - Returns: - Cleaned and escaped query string - """ - q = str(query).strip() - - # Remove wrapping quotes - if (q.startswith("'") and q.endswith("'")) or ( - q.startswith('"') and q.endswith('"') - ): - q = q[1:-1] - - # Remove newlines and carriage returns - q = q.replace('\r', ' ').replace('\n', ' ').strip() - - # Apply Lucene escaping - q = escape_lucene_query(q) - - return q - - async def execute_hybrid_search( - self, - group_id: str, - question: str, - limit: int = 5, - search_type: str = "hybrid", - include: Optional[List[str]] = None, - rerank_alpha: float = 0.4, - output_path: str = "search_results.json", - return_raw_results: bool = False - ) -> Tuple[str, str, Optional[dict]]: - """ - Execute hybrid search and return clean content. - - Args: - group_id: Group identifier for filtering results - question: Search query text - limit: Maximum number of results to return (default: 5) - search_type: Type of search - "hybrid", "keyword", or "embedding" (default: "hybrid") - include: List of result types to include (default: ["statements", "chunks", "entities", "summaries"]) - rerank_alpha: Weight for BM25 scores in reranking (default: 0.4) - output_path: Path to save search results (default: "search_results.json") - return_raw_results: If True, also return the raw search results as third element (default: False) - - Returns: - Tuple of (clean_content, cleaned_query, raw_results) - raw_results is None if return_raw_results=False - """ - if include is None: - include = ["statements", "chunks", "entities", "summaries"] - - # Clean query - cleaned_query = self.clean_query(question) - - try: - # Execute search - answer = await run_hybrid_search( - query_text=cleaned_query, - search_type=search_type, - group_id=group_id, - limit=limit, - include=include, - output_path=output_path, - rerank_alpha=rerank_alpha - ) - - # Extract results based on search type and include parameter - # Prioritize summaries as they contain synthesized contextual information - answer_list = [] - - # For hybrid search, use reranked_results - if search_type == "hybrid": - reranked_results = answer.get('reranked_results', {}) - - # Priority order: summaries first (most contextual), then statements, chunks, entities - priority_order = ['summaries', 'statements', 'chunks', 'entities'] - - for category in priority_order: - if category in include and category in reranked_results: - category_results = reranked_results[category] - if isinstance(category_results, list): - answer_list.extend(category_results) - else: - # For keyword or embedding search, results are directly in answer dict - # Apply same priority order - priority_order = ['summaries', 'statements', 'chunks', 'entities'] - - for category in priority_order: - if category in include and category in answer: - category_results = answer[category] - if isinstance(category_results, list): - answer_list.extend(category_results) - - # Extract clean content from all results - content_list = [ - self.extract_content_from_result(ans) - for ans in answer_list - ] - - - # Filter out empty strings and join with newlines - clean_content = '\n'.join([c for c in content_list if c]) - - # Log first 200 chars - logger.info(f"检索接口搜索结果==>>:{clean_content[:200]}...") - - # Return raw results if requested - if return_raw_results: - return clean_content, cleaned_query, answer - else: - return clean_content, cleaned_query, None - - except Exception as e: - logger.error( - f"Search failed for query '{question}' in group '{group_id}': {e}", - exc_info=True - ) - # Return empty results on failure - if return_raw_results: - return "", cleaned_query, {} - else: - return "", cleaned_query, None diff --git a/app/core/memory/agent/mcp_server/services/session_service.py b/app/core/memory/agent/mcp_server/services/session_service.py deleted file mode 100644 index b2d4f0ff..00000000 --- a/app/core/memory/agent/mcp_server/services/session_service.py +++ /dev/null @@ -1,169 +0,0 @@ -""" -Session Service for managing user sessions and conversation history. - -This service provides clean Redis interactions with error handling and -session management utilities. -""" -from typing import List, Optional - -from app.core.logging_config import get_agent_logger -from app.core.memory.agent.utils.redis_tool import RedisSessionStore - - -logger = get_agent_logger(__name__) - - -class SessionService: - """Service for managing user sessions and conversation history.""" - - def __init__(self, store: RedisSessionStore): - """ - Initialize the session service. - - Args: - store: Redis session store instance - """ - self.store = store - logger.info("SessionService initialized") - - def resolve_user_id(self, session_string: str) -> str: - """ - Extract user ID from session string. - - Handles formats like: - - 'call_id_user123' -> 'user123' - - 'prefix_id_user456_suffix' -> 'user456_suffix' - - Args: - session_string: Session identifier string - - Returns: - Extracted user ID - """ - try: - # Split by '_id_' and take everything after it - parts = session_string.split('_id_') - if len(parts) > 1: - return parts[1] - - # Fallback: return original string - return session_string - - except Exception as e: - logger.warning( - f"Failed to parse user ID from session string '{session_string}': {e}" - ) - return session_string - - async def get_history( - self, - user_id: str, - apply_id: str, - group_id: str - ) -> List[dict]: - """ - Retrieve conversation history from Redis. - - Args: - user_id: User identifier - apply_id: Application identifier - group_id: Group identifier - - Returns: - List of conversation history items with Query and Answer keys - Returns empty list if no history found or on error - """ - try: - history = self.store.find_user_apply_group(user_id, apply_id, group_id) - - # Validate history structure - if not isinstance(history, list): - logger.warning( - f"Invalid history format for user {user_id}, " - f"apply {apply_id}, group {group_id}: expected list, got {type(history)}" - ) - return [] - - return history - - except Exception as e: - logger.error( - f"Failed to retrieve history for user {user_id}, " - f"apply {apply_id}, group {group_id}: {e}", - exc_info=True - ) - # Return empty list on error to allow execution to continue - return [] - - async def save_session( - self, - user_id: str, - query: str, - apply_id: str, - group_id: str, - ai_response: str - ) -> Optional[str]: - """ - Save conversation turn to Redis. - - Args: - user_id: User identifier - query: User query/message - apply_id: Application identifier - group_id: Group identifier - ai_response: AI response/answer - - Returns: - Session ID if successful, None on error - """ - try: - # Validate required fields - if not user_id: - logger.warning("Cannot save session: user_id is empty") - return None - - if not query: - logger.warning("Cannot save session: query is empty") - return None - - # Save session - session_id = self.store.save_session( - userid=user_id, - messages=query, - apply_id=apply_id, - group_id=group_id, - aimessages=ai_response - ) - - logger.info(f"Session saved successfully: {session_id}") - return session_id - - except Exception as e: - logger.error( - f"Failed to save session for user {user_id}: {e}", - exc_info=True - ) - return None - - async def cleanup_duplicates(self) -> int: - """ - Remove duplicate session entries. - - Duplicates are identified by matching: - - sessionid - - user_id (id field) - - group_id - - messages - - aimessages - - Returns: - Number of duplicate sessions deleted - """ - try: - deleted_count = self.store.delete_duplicate_sessions() - logger.info(f"Cleaned up {deleted_count} duplicate sessions") - return deleted_count - - except Exception as e: - logger.error(f"Failed to cleanup duplicate sessions: {e}", exc_info=True) - return 0 diff --git a/app/core/memory/agent/mcp_server/services/template_service.py b/app/core/memory/agent/mcp_server/services/template_service.py deleted file mode 100644 index 95223f0b..00000000 --- a/app/core/memory/agent/mcp_server/services/template_service.py +++ /dev/null @@ -1,116 +0,0 @@ -""" -Template Service for loading and rendering Jinja2 templates. - -This service provides centralized template management with caching and error handling. -""" -import os -from functools import lru_cache -from typing import Optional -from jinja2 import Environment, FileSystemLoader, Template, TemplateNotFound - -from app.core.logging_config import get_agent_logger, log_prompt_rendering - - -logger = get_agent_logger(__name__) - - -class TemplateRenderError(Exception): - """Exception raised when template rendering fails.""" - - def __init__(self, template_name: str, error: Exception, variables: dict): - self.template_name = template_name - self.error = error - self.variables = variables - super().__init__( - f"Failed to render template '{template_name}': {str(error)}" - ) - - -class TemplateService: - """Service for loading and rendering Jinja2 templates with caching.""" - - def __init__(self, template_root: str): - """ - Initialize the template service. - - Args: - template_root: Root directory containing template files - """ - self.template_root = template_root - self.env = Environment( - loader=FileSystemLoader(template_root), - autoescape=False # Disable autoescape for prompt templates - ) - logger.info(f"TemplateService initialized with root: {template_root}") - - @lru_cache(maxsize=128) - def _load_template(self, template_name: str) -> Template: - """ - Load a template from disk with caching. - - Args: - template_name: Relative path to template file - - Returns: - Loaded Jinja2 Template object - - Raises: - TemplateNotFound: If template file doesn't exist - """ - try: - return self.env.get_template(template_name) - except TemplateNotFound as e: - expected_path = os.path.join(self.template_root, template_name) - logger.error( - f"Template not found: {template_name}. " - f"Expected path: {expected_path}" - ) - raise - - async def render_template( - self, - template_name: str, - operation_name: str, - **variables - ) -> str: - """ - Load and render a Jinja2 template. - - Args: - template_name: Relative path to template file - operation_name: Name for logging (e.g., "split_the_problem") - **variables: Template variables to render - - Returns: - Rendered template string - - Raises: - TemplateRenderError: If template loading or rendering fails - """ - try: - # Load template (cached) - template = self._load_template(template_name) - - # Render template - rendered = template.render(**variables) - - # Log rendered prompt - log_prompt_rendering(operation_name, rendered) - - return rendered - - except TemplateNotFound as e: - logger.error( - f"Template rendering failed for {operation_name} " - f"({template_name}): Template not found", - exc_info=True - ) - raise TemplateRenderError(template_name, e, variables) - - except Exception as e: - logger.error( - f"Template rendering failed for {operation_name} " - f"({template_name}): {e}", - exc_info=True - ) - raise TemplateRenderError(template_name, e, variables) diff --git a/app/core/memory/agent/mcp_server/tools/__init__.py b/app/core/memory/agent/mcp_server/tools/__init__.py deleted file mode 100644 index 5ce04ef3..00000000 --- a/app/core/memory/agent/mcp_server/tools/__init__.py +++ /dev/null @@ -1,27 +0,0 @@ -""" -MCP Tools module. - -This module contains all MCP tool implementations organized by functionality. - -Tools are organized into the following modules: -- problem_tools: Question segmentation and extension -- retrieval_tools: Database and context retrieval -- verification_tools: Data verification -- summary_tools: Summarization and summary retrieval -- data_tools: Data type differentiation and writing -""" - -# Import all tool modules to register them with the MCP server -from . import problem_tools -from . import retrieval_tools -from . import verification_tools -from . import summary_tools -from . import data_tools - -__all__ = [ - 'problem_tools', - 'retrieval_tools', - 'verification_tools', - 'summary_tools', - 'data_tools', -] diff --git a/app/core/memory/agent/mcp_server/tools/data_tools.py b/app/core/memory/agent/mcp_server/tools/data_tools.py deleted file mode 100644 index 283aa6b6..00000000 --- a/app/core/memory/agent/mcp_server/tools/data_tools.py +++ /dev/null @@ -1,149 +0,0 @@ -""" -Data Tools for data type differentiation and writing. - -This module contains MCP tools for distinguishing data types and writing data. -""" -import os - -from mcp.server.fastmcp import Context - -from app.core.logging_config import get_agent_logger -from app.core.memory.agent.mcp_server.mcp_instance import mcp -from app.core.memory.agent.mcp_server.server import get_context_resource -from app.core.memory.agent.mcp_server.models.retrieval_models import DistinguishTypeResponse -from app.core.memory.agent.utils.write_tools import write - - -logger = get_agent_logger(__name__) - - -@mcp.tool() -async def Data_type_differentiation( - ctx: Context, - context: str -) -> dict: - """ - Distinguish the type of data (read or write). - - Args: - ctx: FastMCP context for dependency injection - context: Text to analyze for type differentiation - - Returns: - dict: Contains 'context' with the original text and 'type' field - """ - try: - # Extract services from context - template_service = get_context_resource(ctx, 'template_service') - llm_client = get_context_resource(ctx, 'llm_client') - - # Render template - try: - system_prompt = await template_service.render_template( - template_name='distinguish_types_prompt.jinja2', - operation_name='status_typle', - user_query=context - ) - except Exception as e: - logger.error( - f"Template rendering failed for Data_type_differentiation: {e}", - exc_info=True - ) - return { - "type": "error", - "message": f"Prompt rendering failed: {str(e)}" - } - - # Call LLM with structured response - try: - structured = await llm_client.response_structured( - messages=[{"role": "system", "content": system_prompt}], - response_model=DistinguishTypeResponse - ) - - result = structured.model_dump() - - # Add context to result - result["context"] = context - - return result - - except Exception as e: - logger.error( - f"LLM call failed for Data_type_differentiation: {e}", - exc_info=True - ) - return { - "context": context, - "type": "error", - "message": f"LLM call failed: {str(e)}" - } - - except Exception as e: - logger.error( - f"Data_type_differentiation failed: {e}", - exc_info=True - ) - return { - "context": context, - "type": "error", - "message": str(e) - } - - -@mcp.tool() -async def Data_write( - ctx: Context, - content: str, - user_id: str, - apply_id: str, - group_id: str, - config_id: str -) -> dict: - """ - Write data to the database/file system. - - Args: - ctx: FastMCP context for dependency injection - content: Data content to write - user_id: User identifier - apply_id: Application identifier - group_id: Group identifier - config_id: Configuration ID for processing (optional, integer) - - Returns: - dict: Contains 'status', 'saved_to', and 'data' fields - """ - try: - # Ensure output directory exists - os.makedirs("data_output", exist_ok=True) - file_path = os.path.join("data_output", "user_data.csv") - - # Write data using utility function - try: - await write(content, user_id, apply_id, group_id, config_id=config_id) - logger.info(f"写入成功!Config ID: {config_id if config_id else 'None'}") - - return { - "status": "success", - "saved_to": file_path, - "data": content, - "config_id": config_id - } - - except Exception as e: - logger.error(f"写入失败: {e}", exc_info=True) - return { - "status": "error", - "message": str(e) - } - - except Exception as e: - logger.error( - f"Data_write failed: {e}", - exc_info=True - ) - return { - "status": "error", - "message": str(e) - } diff --git a/app/core/memory/agent/mcp_server/tools/problem_tools.py b/app/core/memory/agent/mcp_server/tools/problem_tools.py deleted file mode 100644 index 07d323a6..00000000 --- a/app/core/memory/agent/mcp_server/tools/problem_tools.py +++ /dev/null @@ -1,293 +0,0 @@ -""" -Problem Tools for question segmentation and extension. - -This module contains MCP tools for breaking down and extending user questions. -""" -import json -import time -from typing import List - -from pydantic import BaseModel, Field, RootModel -from mcp.server.fastmcp import Context - -from app.core.logging_config import get_agent_logger, log_time -from app.core.memory.agent.mcp_server.mcp_instance import mcp -from app.core.memory.agent.mcp_server.server import get_context_resource -from app.core.memory.agent.mcp_server.models.problem_models import ( - ProblemBreakdownItem, - ProblemBreakdownResponse, - ExtendedQuestionItem, - ProblemExtensionResponse -) -from app.core.memory.agent.utils.messages_tool import Problem_Extension_messages_deal - - -logger = get_agent_logger(__name__) - - -@mcp.tool() -async def Split_The_Problem( - ctx: Context, - sentence: str, - sessionid: str, - messages_id: str, - apply_id: str, - group_id: str -) -> dict: - """ - Segment the dialogue or sentence into sub-problems. - - Args: - ctx: FastMCP context for dependency injection - sentence: Original sentence to split - sessionid: Session identifier - messages_id: Message identifier - apply_id: Application identifier - group_id: Group identifier - - Returns: - dict: Contains 'context' (JSON string of split results) and 'original' sentence - """ - start = time.time() - - try: - # Extract services from context - template_service = get_context_resource(ctx, 'template_service') - session_service = get_context_resource(ctx, 'session_service') - llm_client = get_context_resource(ctx, 'llm_client') - - # Extract user ID from session - user_id = session_service.resolve_user_id(sessionid) - - # Get conversation history - history = await session_service.get_history(user_id, apply_id, group_id) - # Override with empty list for now (as in original) - history = [] - - # Render template - try: - system_prompt = await template_service.render_template( - template_name='problem_breakdown_prompt.jinja2', - operation_name='split_the_problem', - history=history, - sentence=sentence - ) - except Exception as e: - logger.error( - f"Template rendering failed for Split_The_Problem: {e}", - exc_info=True - ) - return { - "context": json.dumps([], ensure_ascii=False), - "original": sentence, - "error": f"Prompt rendering failed: {str(e)}" - } - - # Call LLM with structured response - try: - structured = await llm_client.response_structured( - messages=[{"role": "system", "content": system_prompt}], - response_model=ProblemBreakdownResponse - ) - - # Handle RootModel response with .root attribute access - if structured is None: - # LLM returned None, use empty list as fallback - split_result = json.dumps([], ensure_ascii=False) - elif hasattr(structured, 'root') and structured.root is not None: - split_result = json.dumps( - [item.model_dump() for item in structured.root], - ensure_ascii=False - ) - elif isinstance(structured, list): - # Fallback: treat structured itself as the list - split_result = json.dumps( - [item.model_dump() for item in structured], - ensure_ascii=False - ) - else: - # Last resort: use empty list - split_result = json.dumps([], ensure_ascii=False) - - except Exception as e: - logger.error( - f"LLM call failed for Split_The_Problem: {e}", - exc_info=True - ) - split_result = json.dumps([], ensure_ascii=False) - - logger.info(f"问题拆分") - logger.info(f"问题拆分结果==>>:{split_result}") - - # Emit intermediate output for frontend - result = { - "context": split_result, - "original": sentence, - "_intermediate": { - "type": "problem_split", - "data": json.loads(split_result) if split_result else [], - "original_query": sentence - } - } - - return result - - except Exception as e: - logger.error( - f"Split_The_Problem failed: {e}", - exc_info=True - ) - return { - "context": json.dumps([], ensure_ascii=False), - "original": sentence, - "error": str(e) - } - - finally: - # Log execution time - end = time.time() - try: - duration = end - start - except Exception: - duration = 0.0 - log_time('问题拆分', duration) - - -@mcp.tool() -async def Problem_Extension( - ctx: Context, - context: dict, - usermessages: str, - apply_id: str, - group_id: str, - storage_type: str = "", - user_rag_memory_id: str = "" -) -> dict: - """ - Extend the problem with additional sub-questions. - - Args: - ctx: FastMCP context for dependency injection - context: Dictionary containing split problem results - usermessages: User messages identifier - apply_id: Application identifier - group_id: Group identifier - storage_type: Storage type for the workspace (optional) - user_rag_memory_id: User RAG memory identifier (optional) - - Returns: - dict: Contains 'context' (aggregated questions) and 'original' question - """ - start = time.time() - - try: - # Extract services from context - template_service = get_context_resource(ctx, 'template_service') - session_service = get_context_resource(ctx, 'session_service') - llm_client = get_context_resource(ctx, 'llm_client') - - # Resolve session ID from usermessages - from app.core.memory.agent.utils.messages_tool import Resolve_username - sessionid = Resolve_username(usermessages) - - # Get conversation history - history = await session_service.get_history(sessionid, apply_id, group_id) - # Override with empty list for now (as in original) - history = [] - - # Process context to extract questions - extent_quest, original = await Problem_Extension_messages_deal(context) - - # Format questions for template rendering - questions_formatted = [] - for msg in extent_quest: - if msg.get("role") == "user": - questions_formatted.append(msg.get("content", "")) - - # Render template - try: - system_prompt = await template_service.render_template( - template_name='Problem_Extension_prompt.jinja2', - operation_name='problem_extension', - history=history, - questions=questions_formatted - ) - except Exception as e: - logger.error( - f"Template rendering failed for Problem_Extension: {e}", - exc_info=True - ) - return { - "context": {}, - "original": original, - "error": f"Prompt rendering failed: {str(e)}" - } - - # Call LLM with structured response - try: - response_content = await llm_client.response_structured( - messages=[{"role": "system", "content": system_prompt}], - response_model=ProblemExtensionResponse - ) - - # Aggregate results by original question - aggregated_dict = {} - for item in response_content.root: - key = getattr(item, "original_question", None) or ( - item.get("original_question") if isinstance(item, dict) else None - ) - value = getattr(item, "extended_question", None) or ( - item.get("extended_question") if isinstance(item, dict) else None - ) - if not key or not value: - continue - aggregated_dict.setdefault(key, []).append(value) - - except Exception as e: - logger.error( - f"LLM call failed for Problem_Extension: {e}", - exc_info=True - ) - aggregated_dict = {} - - logger.info(f"问题扩展") - logger.info(f"问题扩展==>>:{aggregated_dict}") - - # Emit intermediate output for frontend - result = { - "context": aggregated_dict, - "original": original, - "storage_type": storage_type, - "user_rag_memory_id": user_rag_memory_id, - "_intermediate": { - "type": "problem_extension", - "data": aggregated_dict, - "original_query": original, - "storage_type": storage_type, - "user_rag_memory_id": user_rag_memory_id - } - } - - return result - - except Exception as e: - logger.error( - f"Problem_Extension failed: {e}", - exc_info=True - ) - return { - "context": {}, - "original": context.get("original", ""), - "storage_type": storage_type, - "user_rag_memory_id": user_rag_memory_id, - "error": str(e) - } - - finally: - # Log execution time - end = time.time() - try: - duration = end - start - except Exception: - duration = 0.0 - log_time('问题扩展', duration) diff --git a/app/core/memory/agent/mcp_server/tools/retrieval_tools.py b/app/core/memory/agent/mcp_server/tools/retrieval_tools.py deleted file mode 100644 index 3639742a..00000000 --- a/app/core/memory/agent/mcp_server/tools/retrieval_tools.py +++ /dev/null @@ -1,282 +0,0 @@ -""" -Retrieval Tools for database and context retrieval. - -This module contains MCP tools for retrieving data using hybrid search. -""" -from dotenv import load_dotenv -import os - -from app.core.rag.nlp.search import knowledge_retrieval - -# 加载.env文件 -load_dotenv() -import time -from typing import List - -from mcp.server.fastmcp import Context - -from app.core.logging_config import get_agent_logger, log_time -from app.core.memory.agent.mcp_server.mcp_instance import mcp -from app.core.memory.agent.mcp_server.server import get_context_resource -from app.core.memory.agent.utils.llm_tools import deduplicate_entries, merge_to_key_value_pairs -from app.core.memory.agent.utils.messages_tool import Retriev_messages_deal - - -logger = get_agent_logger(__name__) - - -@mcp.tool() -async def Retrieve( - ctx: Context, - context, - usermessages: str, - apply_id: str, - group_id: str, - storage_type: str = "", - user_rag_memory_id: str = "" -) -> dict: - """ - Retrieve data from the database using hybrid search. - - Args: - ctx: FastMCP context for dependency injection - context: Dictionary or string containing query information - usermessages: User messages identifier - apply_id: Application identifier - group_id: Group identifier - storage_type: Storage type for the workspace (e.g., 'rag', 'vector') - user_rag_memory_id: User RAG memory identifier - - Returns: - dict: Contains 'context' with Query and Expansion_issue results - """ - kb_config = { - "knowledge_bases": [ - { - "kb_id": user_rag_memory_id, - "similarity_threshold": 0.7, - "vector_similarity_weight": 0.5, - "top_k": 10, - "retrieve_type": "participle" - } - ], - "merge_strategy": "weight", - "reranker_id": os.getenv('reranker_id'), - "reranker_top_k": 10 - } - start = time.time() - logger.info(f"Retrieve: storage_type={storage_type}, user_rag_memory_id={user_rag_memory_id}") - - try: - # Extract services from context - search_service = get_context_resource(ctx, 'search_service') - - databases_anser = [] - - # Handle both dict and string context - if isinstance(context, dict): - # Process dict context with extended questions - all_items = [] - content, original = await Retriev_messages_deal(context) - - # Extract all query items from content - # content is like {original_question: [extended_questions...], ...} - for key, values in content.items(): - if isinstance(values, list): - all_items.extend(values) - elif isinstance(values, str): - all_items.append(values) - elif values is not None: - # Fallback: convert non-empty non-list values to string - all_items.append(str(values)) - - # Execute search for each question - for idx, question in enumerate(all_items): - try: - # Prepare search parameters based on storage type - search_params = { - "group_id": group_id, - "question": question, - "return_raw_results": True - } - - # Add storage-specific parameters - if storage_type == "rag" and user_rag_memory_id: - retrieve_chunks_result = knowledge_retrieval(question, kb_config,[str(group_id)]) - try: - retrieval_knowledge = [i.page_content for i in retrieve_chunks_result] - clean_content = '\n\n'.join(retrieval_knowledge) - cleaned_query=question - raw_results=clean_content - logger.info(f" Using RAG storage with memory_id={user_rag_memory_id}") - except: - clean_content = '' - raw_results='' - cleaned_query = question - logger.info(f"知识库没有检索的内容{user_rag_memory_id}") - else: - clean_content, cleaned_query, raw_results = await search_service.execute_hybrid_search(**search_params) - - databases_anser.append({ - "Query_small": cleaned_query, - "Result_small": clean_content, - "_intermediate": { - "type": "search_result", - "query": cleaned_query, - "raw_results": raw_results, - "index": idx + 1, - "total": len(all_items) - } - }) - except Exception as e: - logger.error( - f"Retrieve: hybrid_search failed for question '{question}': {e}", - exc_info=True - ) - # Continue with empty result for this question - databases_anser.append({ - "Query_small": question, - "Result_small": "" - }) - - # Build initial database data structure - databases_data = { - "Query": original, - "Expansion_issue": databases_anser - } - - # Collect intermediate outputs before deduplication - intermediate_outputs = [] - for item in databases_anser: - if '_intermediate' in item: - intermediate_outputs.append(item['_intermediate']) - - # Deduplicate and merge results - deduplicated_data = deduplicate_entries(databases_data['Expansion_issue']) - deduplicated_data_merged = merge_to_key_value_pairs( - deduplicated_data, - 'Query_small', - 'Result_small' - ) - - # Restructure for Verify/Retrieve_Summary compatibility - keys, val = [], [] - for item in deduplicated_data_merged: - for items_key, items_value in item.items(): - keys.append(items_key) - val.append(items_value) - - send_verify = [] - for i, j in zip(keys, val): - send_verify.append({ - "Query_small": i, - "Answer_Small": j - }) - - dup_databases = { - "Query": original, - "Expansion_issue": send_verify, - "_intermediate_outputs": intermediate_outputs # Preserve intermediate outputs - } - - logger.info(f"Collected {len(intermediate_outputs)} intermediate outputs from search results") - - else: - # Handle string context (simple query) - query = str(context).strip() - - try: - # Prepare search parameters based on storage type - search_params = { - "group_id": group_id, - "question": query, - "return_raw_results": True - } - - # Add storage-specific parameters - if storage_type == "rag" and user_rag_memory_id: - retrieve_chunks_result = knowledge_retrieval(query, kb_config,[str(group_id)]) - try: - retrieval_knowledge = [i.page_content for i in retrieve_chunks_result] - clean_content = '\n\n'.join(retrieval_knowledge) - cleaned_query = query - raw_results = clean_content - logger.info(f" Using RAG storage with memory_id={user_rag_memory_id}") - except: - clean_content = '' - raw_results = '' - cleaned_query = query - logger.info(f"知识库没有检索的内容{user_rag_memory_id}") - else: - clean_content, cleaned_query, raw_results = await search_service.execute_hybrid_search(**search_params) - # Keep structure for Verify/Retrieve_Summary compatibility - dup_databases = { - "Query": cleaned_query, - "Expansion_issue": [{ - "Query_small": cleaned_query, - "Answer_Small": clean_content, - "_intermediate": { - "type": "search_result", - "query": cleaned_query, - "raw_results": raw_results, - "index": 1, - "total": 1 - } - }] - } - except Exception as e: - logger.error( - f"Retrieve: hybrid_search failed for query '{query}': {e}", - exc_info=True - ) - # Return empty results on failure - dup_databases = { - "Query": query, - "Expansion_issue": [] - } - - logger.info( - f"检索==>>:{storage_type}--{user_rag_memory_id}--Query={dup_databases.get('Query', '')}, " - f"Expansion_issue count={len(dup_databases.get('Expansion_issue', []))}" - ) - - # Build result with intermediate outputs - result = { - "context": dup_databases, - "storage_type": storage_type, - "user_rag_memory_id": user_rag_memory_id - } - - # Add intermediate outputs list if they exist - intermediate_outputs = dup_databases.get('_intermediate_outputs', []) - if intermediate_outputs: - result['_intermediates'] = intermediate_outputs - logger.info(f"Adding {len(intermediate_outputs)} intermediate outputs to result") - else: - logger.warning("No intermediate outputs found in dup_databases") - - return result - - except Exception as e: - logger.error( - f"Retrieve failed: {e}", - exc_info=True - ) - return { - "context": { - "Query": "", - "Expansion_issue": [] - }, - "storage_type": storage_type, - "user_rag_memory_id": user_rag_memory_id, - "error": str(e) - } - - finally: - # Log execution time - end = time.time() - try: - duration = end - start - except Exception: - duration = 0.0 - log_time('检索', duration) diff --git a/app/core/memory/agent/mcp_server/tools/summary_tools.py b/app/core/memory/agent/mcp_server/tools/summary_tools.py deleted file mode 100644 index 4d0d77d4..00000000 --- a/app/core/memory/agent/mcp_server/tools/summary_tools.py +++ /dev/null @@ -1,647 +0,0 @@ -""" -Summary Tools for data summarization. - -This module contains MCP tools for summarizing retrieved data and generating responses. -""" -import json -import re -import time -from typing import List - -from pydantic import BaseModel, Field -from mcp.server.fastmcp import Context - -from app.core.logging_config import get_agent_logger, log_time -from app.core.memory.agent.mcp_server.mcp_instance import mcp -from app.core.memory.agent.mcp_server.server import get_context_resource -from app.core.memory.agent.mcp_server.models.summary_models import ( - SummaryData, - SummaryResponse, - RetrieveSummaryData, - RetrieveSummaryResponse -) -from app.core.memory.agent.utils.messages_tool import ( - Summary_messages_deal, - Resolve_username -) -from app.core.rag.nlp.search import knowledge_retrieval -from dotenv import load_dotenv -import os - -# 加载.env文件 -load_dotenv() -logger = get_agent_logger(__name__) - - -@mcp.tool() -async def Summary( - ctx: Context, - context: str, - usermessages: str, - apply_id: str, - group_id: str, - storage_type: str = "", - user_rag_memory_id: str = "" -) -> dict: - """ - Summarize the verified data. - - Args: - ctx: FastMCP context for dependency injection - context: JSON string containing verified data - usermessages: User messages identifier - apply_id: Application identifier - group_id: Group identifier - storage_type: Storage type for the workspace (optional) - user_rag_memory_id: User RAG memory identifier (optional) - - Returns: - dict: Contains 'status' and 'summary_result' - """ - start = time.time() - - try: - # Extract services from context - template_service = get_context_resource(ctx, 'template_service') - session_service = get_context_resource(ctx, 'session_service') - llm_client = get_context_resource(ctx, 'llm_client') - - # Resolve session ID - sessionid = Resolve_username(usermessages) - - # Process context to extract answer and query - answer_small, query = await Summary_messages_deal(context) - - - # Get conversation history - history = await session_service.get_history(sessionid, apply_id, group_id) - # Override with empty list for now (as in original) - # Prepare data for template - data = { - "query": query, - "history": history, - "retrieve_info": answer_small - } - - except Exception as e: - logger.error( - f"Summary: initialization failed: {e}", - exc_info=True - ) - return { - "status": "error", - "summary_result": "信息不足,无法回答" - } - - try: - # Render template - system_prompt = await template_service.render_template( - template_name='summary_prompt.jinja2', - operation_name='summary', - data=data, - query=query - ) - except Exception as e: - logger.error( - f"Template rendering failed for Summary: {e}", - exc_info=True - ) - return { - "status": "error", - "message": f"Prompt rendering failed: {str(e)}" - } - - try: - # Call LLM with structured response - structured = await llm_client.response_structured( - messages=[{"role": "system", "content": system_prompt}], - response_model=SummaryResponse - ) - - aimessages = structured.query_answer or "" - - except Exception as e: - logger.error( - f"LLM call failed for Summary: {e}", - exc_info=True - ) - aimessages = "" - - try: - # Save session - if aimessages != "": - await session_service.save_session( - user_id=sessionid, - query=query, - apply_id=apply_id, - group_id=group_id, - ai_response=aimessages - ) - logger.info(f"sessionid: {aimessages} 写入成功") - except Exception as e: - logger.error( - f"sessionid: {sessionid} 写入失败,错误信息:{str(e)}", - exc_info=True - ) - return { - "status": "error", - "message": str(e) - } - - # Cleanup duplicate sessions - await session_service.cleanup_duplicates() - - # Use fallback if empty - if aimessages == '': - aimessages = '信息不足,无法回答' - - logger.info(f"验证之后的总结==>>:{aimessages}") - - # Log execution time - end = time.time() - try: - duration = end - start - except Exception: - duration = 0.0 - log_time('总结', duration) - - return { - "status": "success", - "summary_result": aimessages, - "storage_type": storage_type, - "user_rag_memory_id": user_rag_memory_id - } - - -@mcp.tool() -async def Retrieve_Summary( - ctx: Context, - context: dict, - usermessages: str, - apply_id: str, - group_id: str, - storage_type: str = "", - user_rag_memory_id: str = "" -) -> dict: - """ - Summarize data directly from retrieval results. - - Args: - ctx: FastMCP context for dependency injection - context: Dictionary containing Query and Expansion_issue from Retrieve - usermessages: User messages identifier - apply_id: Application identifier - group_id: Group identifier - storage_type: Storage type for the workspace (optional) - user_rag_memory_id: User RAG memory identifier (optional) - - Returns: - dict: Contains 'status' and 'summary_result' - """ - start = time.time() - - try: - # Extract services from context - template_service = get_context_resource(ctx, 'template_service') - session_service = get_context_resource(ctx, 'session_service') - llm_client = get_context_resource(ctx, 'llm_client') - - # Resolve session ID - sessionid = Resolve_username(usermessages) - - - - # Handle both 'content' and 'context' keys (LangGraph uses 'content') - if isinstance(context, dict): - if "content" in context: - inner = context["content"] - # If it's a JSON string, parse it - if isinstance(inner, str): - try: - parsed = json.loads(inner) - logger.info(f"Retrieve_Summary: successfully parsed JSON") - except json.JSONDecodeError: - # Try unescaping first - try: - unescaped = inner.encode('utf-8').decode('unicode_escape') - parsed = json.loads(unescaped) - logger.info(f"Retrieve_Summary: parsed after unescaping") - except (json.JSONDecodeError, UnicodeDecodeError) as e: - logger.error( - f"Retrieve_Summary: parsing failed even after unescape: {e}" - ) - context_dict = {"Query": "", "Expansion_issue": []} - parsed = None - - if parsed: - # Check if parsed has 'context' wrapper - if isinstance(parsed, dict) and "context" in parsed: - context_dict = parsed["context"] - else: - context_dict = parsed - elif isinstance(inner, dict): - context_dict = inner - else: - context_dict = {"Query": "", "Expansion_issue": []} - elif "context" in context: - context_dict = context["context"] if isinstance(context["context"], dict) else context - else: - context_dict = context - else: - context_dict = {"Query": "", "Expansion_issue": []} - - query = context_dict.get("Query", "") - expansion_issue = context_dict.get("Expansion_issue", []) - - # Extract retrieve_info from expansion_issue - retrieve_info = [] - for item in expansion_issue: - # Check for both Answer_Small and Answer_Samll (typo) for backward compatibility - answer = None - if isinstance(item, dict): - if "Answer_Small" in item: - answer = item["Answer_Small"] - elif "Answer_Samll" in item: - answer = item["Answer_Samll"] - - if answer is not None: - # Handle both string and list formats - if isinstance(answer, list): - # Join list of characters/strings into a single string - retrieve_info.append(''.join(str(x) for x in answer)) - elif isinstance(answer, str): - retrieve_info.append(answer) - else: - retrieve_info.append(str(answer)) - - # Join all retrieve_info into a single string - retrieve_info_str = '\n\n'.join(retrieve_info) if retrieve_info else "" - - # Get conversation history - history = await session_service.get_history(sessionid, apply_id, group_id) - # Override with empty list for now (as in original) - - except Exception as e: - logger.error( - f"Retrieve_Summary: initialization failed: {e}", - exc_info=True - ) - return { - "status": "error", - "summary_result": "信息不足,无法回答" - } - - try: - # Render template - system_prompt = await template_service.render_template( - template_name='Retrieve_Summary_prompt.jinja2', - operation_name='retrieve_summary', - query=query, - history=history, - retrieve_info=retrieve_info_str - ) - except Exception as e: - logger.error( - f"Template rendering failed for Retrieve_Summary: {e}", - exc_info=True - ) - return { - "status": "error", - "message": f"Prompt rendering failed: {str(e)}" - } - - try: - # Call LLM with structured response - structured = await llm_client.response_structured( - messages=[{"role": "system", "content": system_prompt}], - response_model=RetrieveSummaryResponse - ) - - # Handle case where structured response might be None or incomplete - if structured and hasattr(structured, 'data') and structured.data: - aimessages = structured.data.query_answer or "" - else: - logger.warning("Structured response is None or incomplete, using default message") - aimessages = "信息不足,无法回答" - - - # Check for insufficient information response - if '信息不足,无法回答' not in str(aimessages) or str(aimessages)!="": - # Save session - await session_service.save_session( - user_id=sessionid, - query=query, - apply_id=apply_id, - group_id=group_id, - ai_response=aimessages - ) - logger.info(f"sessionid: {aimessages} 写入成功") - except Exception as e: - logger.error( - f"Retrieve_Summary: LLM call failed: {e}", - exc_info=True - ) - aimessages = "" - # Cleanup duplicate sessions - await session_service.cleanup_duplicates() - - # Use fallback if empty - if aimessages == '': - aimessages = '信息不足,无法回答' - - logger.info(f"检索之后的总结==>>:{aimessages}") - - # Log execution time - end = time.time() - try: - duration = end - start - except Exception: - duration = 0.0 - log_time('检索总结', duration) - - # Emit intermediate output for frontend - return { - "status": "success", - "summary_result": aimessages, - "storage_type": storage_type, - "user_rag_memory_id": user_rag_memory_id, - "_intermediate": { - "type": "retrieval_summary", - "summary": aimessages, - "query": query, - "storage_type": storage_type, - "user_rag_memory_id": user_rag_memory_id - } - } - - -@mcp.tool() -async def Input_Summary( - ctx: Context, - context: str, - usermessages: str, - search_switch: str, - apply_id: str, - group_id: str, - storage_type: str = "", - user_rag_memory_id: str = "" -) -> dict: - """ - Generate a quick summary for direct input without verification. - - Args: - ctx: FastMCP context for dependency injection - context: String containing the input sentence - usermessages: User messages identifier - search_switch: Search switch value for routing ('2' for summaries only) - apply_id: Application identifier - group_id: Group identifier - storage_type: Storage type for the workspace (e.g., 'rag', 'vector') - user_rag_memory_id: User RAG memory identifier - - Returns: - dict: Contains 'query_answer' with the summary result - """ - start = time.time() - logger.info(f"Input_Summary: storage_type={storage_type}, user_rag_memory_id={user_rag_memory_id}") - - # Initialize variables to avoid UnboundLocalError - - - try: - # Extract services from context - template_service = get_context_resource(ctx, 'template_service') - session_service = get_context_resource(ctx, 'session_service') - llm_client = get_context_resource(ctx, 'llm_client') - search_service = get_context_resource(ctx, 'search_service') - - # Check if llm_client is None - if llm_client is None: - error_msg = "LLM client is not available. Please check server configuration and SELECTED_LLM_ID environment variable." - logger.error(error_msg) - return error_msg - - # Resolve session ID - sessionid = Resolve_username(usermessages) or "" - sessionid = sessionid.replace('call_id_', '') - - # Get conversation history - history = await session_service.get_history( - str(sessionid), - str(apply_id), - str(group_id) - ) - # Override with empty list for now (as in original) - - # Log the raw context for debugging - logger.info(f"Input_Summary: Received context type={type(context)}, value={context[:200] if isinstance(context, str) else context}") - - # Extract sentence from context - # Context can be a string or might contain the sentence in various formats - try: - # Try to parse as JSON first - if isinstance(context, str) and (context.startswith('{') or context.startswith('[')): - try: - import json - context_dict = json.loads(context) - if isinstance(context_dict, dict): - query = context_dict.get('sentence', context_dict.get('content', context)) - else: - query = context - except json.JSONDecodeError: - # Not valid JSON, try regex - match = re.search(r"'sentence':\s*['\"]?(.*?)['\"]?\s*,", context) - query = match.group(1) if match else context - else: - query = context - except Exception as e: - logger.warning(f"Failed to extract query from context: {e}") - query = context - - # Clean query - query = str(query).strip().strip("\"'") - - logger.debug(f"Input_Summary: Extracted query='{query}' from context type={type(context)}") - - # Execute search based on search_switch and storage_type - try: - logger.info(f"search_switch: {search_switch}, storage_type: {storage_type}") - - # Prepare search parameters based on storage type - search_params = { - "group_id": group_id, - "question": query, - "return_raw_results": True - } - - # Add storage-specific parameters - - '''检索''' - if search_switch == '2': - search_params["include"] = ["summaries"] - if storage_type == "rag" and user_rag_memory_id: - raw_results = [] - retrieve_info = "" - kb_config={ - "knowledge_bases": [ - { - "kb_id": user_rag_memory_id, - "similarity_threshold": 0.7, - "vector_similarity_weight": 0.5, - "top_k": 10, - "retrieve_type": "participle" - } - ], - "merge_strategy": "weight", - "reranker_id":os.getenv('reranker_id'), - "reranker_top_k": 10 - } - - retrieve_chunks_result = knowledge_retrieval(query, kb_config,[str(group_id)]) - try: - retrieval_knowledge = [i.page_content for i in retrieve_chunks_result] - retrieve_info = '\n\n'.join(retrieval_knowledge) - raw_results=[retrieve_info] - logger.info(f"Input_Summary: Using RAG storage with memory_id={user_rag_memory_id}") - except: - retrieve_info='' - raw_results=[''] - logger.info(f"知识库没有检索的内容{user_rag_memory_id}") - else: - retrieve_info, question, raw_results = await search_service.execute_hybrid_search(**search_params) - logger.info(f"Input_Summary: 使用 summary 进行检索") - else: - retrieve_info, question, raw_results = await search_service.execute_hybrid_search(**search_params) - - except Exception as e: - logger.error( - f"Input_Summary: hybrid_search failed, using empty results: {e}", - exc_info=True - ) - retrieve_info, question, raw_results = "", query, [] - - - # Render template - system_prompt = await template_service.render_template( - template_name='Retrieve_Summary_prompt.jinja2', - operation_name='input_summary', - query=query, - history=history, - retrieve_info=retrieve_info - ) - - # Call LLM with structured response - try: - structured = await llm_client.response_structured( - messages=[{"role": "system", "content": system_prompt}], - response_model=RetrieveSummaryResponse - ) - aimessages = structured.data.query_answer or "信息不足,无法回答" - except Exception as e: - logger.error( - f"Input_Summary: response_structured failed, using default answer: {e}", - exc_info=True - ) - aimessages = "信息不足,无法回答" - - logger.info(f"快速答案总结==>>:{storage_type}--{user_rag_memory_id}--{aimessages}") - - # Emit intermediate output for frontend - return { - "status": "success", - "summary_result": aimessages, - "storage_type": storage_type, - "user_rag_memory_id": user_rag_memory_id, - "_intermediate": { - "type": "input_summary", - "title": "快速答案", - "summary": aimessages, - "query": query, - "raw_results": raw_results, - "search_mode": "quick_search", - "storage_type": storage_type, - "user_rag_memory_id": user_rag_memory_id - } - } - - except Exception as e: - logger.error( - f"Input_Summary failed: {e}", - exc_info=True - ) - return { - "status": "fail", - "summary_result": "信息不足,无法回答", - "storage_type": storage_type, - "user_rag_memory_id": user_rag_memory_id, - "error": str(e) - } - - finally: - # Log execution time - end = time.time() - try: - duration = end - start - except Exception: - duration = 0.0 - log_time('检索', duration) - - -@mcp.tool() -async def Summary_fails( - ctx: Context, - context: str, - usermessages: str, - apply_id: str, - group_id: str, - storage_type: str = "", - user_rag_memory_id: str = "" -) -> dict: - """ - Handle workflow failure when summary cannot be generated. - - Args: - ctx: FastMCP context for dependency injection - context: Failure context string - usermessages: User messages identifier - apply_id: Application identifier - group_id: Group identifier - storage_type: Storage type for the workspace (optional) - user_rag_memory_id: User RAG memory identifier (optional) - - Returns: - dict: Contains 'query_answer' with failure message - """ - try: - # Extract services from context - session_service = get_context_resource(ctx, 'session_service') - - # Parse session ID from usermessages - usermessages_parts = usermessages.split('_')[1:] - sessionid = '_'.join(usermessages_parts[:-1]) - - # Cleanup duplicate sessions - await session_service.cleanup_duplicates() - - logger.info(f"没有相关数据") - logger.debug(f"Summary_fails called with apply_id: {apply_id}, group_id: {group_id}") - - return { - "status": "success", - "summary_result": "没有相关数据", - "storage_type": storage_type, - "user_rag_memory_id": user_rag_memory_id - } - - except Exception as e: - logger.error( - f"Summary_fails failed: {e}", - exc_info=True - ) - return { - "status": "fail", - "summary_result": "没有相关数据", - "storage_type": storage_type, - "user_rag_memory_id": user_rag_memory_id, - "error": str(e) - } diff --git a/app/core/memory/agent/mcp_server/tools/verification_tools.py b/app/core/memory/agent/mcp_server/tools/verification_tools.py deleted file mode 100644 index 652386c7..00000000 --- a/app/core/memory/agent/mcp_server/tools/verification_tools.py +++ /dev/null @@ -1,169 +0,0 @@ -""" -Verification Tools for data verification. - -This module contains MCP tools for verifying retrieved data. -""" -import time - -from jinja2 import Template -from mcp.server.fastmcp import Context - -from app.core.logging_config import get_agent_logger, log_time -from app.core.memory.agent.mcp_server.mcp_instance import mcp -from app.core.memory.agent.mcp_server.server import get_context_resource -from app.core.memory.agent.utils.verify_tool import VerifyTool -from app.core.memory.agent.utils.messages_tool import ( - Verify_messages_deal, - Retrieve_verify_tool_messages_deal, - Resolve_username -) -from app.core.memory.agent.utils.llm_tools import PROJECT_ROOT_ - - -logger = get_agent_logger(__name__) - - -@mcp.tool() -async def Verify( - ctx: Context, - context: dict, - usermessages: str, - apply_id: str, - group_id: str, - storage_type: str = "", - user_rag_memory_id: str = "" -) -> dict: - """ - Verify the retrieved data. - - Args: - ctx: FastMCP context for dependency injection - context: Dictionary containing query and expansion issues - usermessages: User messages identifier - apply_id: Application identifier - group_id: Group identifier - storage_type: Storage type for the workspace (optional) - user_rag_memory_id: User RAG memory identifier (optional) - - Returns: - dict: Contains 'status' and 'verified_data' with verification results - """ - start = time.time() - - - try: - # Extract services from context - session_service = get_context_resource(ctx, 'session_service') - - # Load verification prompt template - file_path = PROJECT_ROOT_ + '/agent/utils/prompt/split_verify_prompt.jinja2' - - # Read template file directly (VerifyTool expects raw template content) - from app.core.memory.agent.utils.messages_tool import read_template_file - system_prompt = await read_template_file(file_path) - - - - # Resolve session ID - sessionid = Resolve_username(usermessages) - - # Get conversation history - history = await session_service.get_history(sessionid, apply_id, group_id) - - template = Template(system_prompt) - system_prompt = template.render(history=history, sentence=context) - - # Process context to extract query and results - Query_small, Result_small, query = await Verify_messages_deal(context) - - # Build query list for verification - query_list = [] - for query_small, anser in zip(Query_small, Result_small): - query_list.append({ - 'Query_small': query_small, - 'Answer_Small': anser - }) - - messages = { - "Query": query, - "Expansion_issue": query_list - } - - - - # Call verification workflow - verify_tool = VerifyTool(system_prompt, messages) - verify_result = await verify_tool.verify() - - # Parse LLM verification result with error handling - try: - messages_deal = await Retrieve_verify_tool_messages_deal( - verify_result, - history, - query - ) - except Exception as e: - logger.error( - f"Retrieve_verify_tool_messages_deal parsing failed: {e}", - exc_info=True - ) - # Fallback to avoid 500 errors - messages_deal = { - "data": { - "query": query, - "expansion_issue": [] - }, - "split_result": "failed", - "reason": str(e), - "history": history, - } - - logger.info(f"验证==>>:{messages_deal}") - - # Emit intermediate output for frontend - return { - "status": "success", - "verified_data": messages_deal, - "storage_type": storage_type, - "user_rag_memory_id": user_rag_memory_id, - "_intermediate": { - "type": "verification", - "title": "数据验证", - "result": messages_deal.get("split_result", "unknown"), - "reason": messages_deal.get("reason", ""), - "query": query, - "verified_count": len(query_list), - "storage_type": storage_type, - "user_rag_memory_id": user_rag_memory_id - } - } - - except Exception as e: - logger.error( - f"Verify failed: {e}", - exc_info=True - ) - return { - "status": "error", - "message": str(e), - "storage_type": storage_type, - "user_rag_memory_id": user_rag_memory_id, - "verified_data": { - "data": { - "query": "", - "expansion_issue": [] - }, - "split_result": "failed", - "reason": str(e), - "history": [], - } - } - - finally: - # Log execution time - end = time.time() - try: - duration = end - start - except Exception: - duration = 0.0 - log_time('验证', duration) diff --git a/app/core/memory/agent/utils/__init__.py b/app/core/memory/agent/utils/__init__.py deleted file mode 100644 index 2b77e240..00000000 --- a/app/core/memory/agent/utils/__init__.py +++ /dev/null @@ -1,7 +0,0 @@ -"""Agent utilities.""" - -from app.core.memory.agent.utils.multimodal import MultimodalProcessor - -__all__ = [ - "MultimodalProcessor", -] diff --git a/app/core/memory/agent/utils/get_dialogs.py b/app/core/memory/agent/utils/get_dialogs.py deleted file mode 100644 index b03fe57c..00000000 --- a/app/core/memory/agent/utils/get_dialogs.py +++ /dev/null @@ -1,70 +0,0 @@ -import os -import json -from typing import List -from datetime import datetime - -from app.core.memory.storage_services.extraction_engine.knowledge_extraction.chunk_extraction import DialogueChunker -from app.core.memory.models.message_models import DialogData, ConversationContext, ConversationMessage - - -async def get_chunked_dialogs( - chunker_strategy: str = "RecursiveChunker", - group_id: str = "group_1", - user_id: str = "user1", - apply_id: str = "applyid", - content: str = "这是用户的输入", - ref_id: str = "wyl_20251027", - config_id: str = None -) -> List[DialogData]: - """Generate chunks from all test data entries using the specified chunker strategy. - - Args: - chunker_strategy: The chunking strategy to use (default: RecursiveChunker) - group_id: Group identifier - user_id: User identifier - apply_id: Application identifier - content: Dialog content - ref_id: Reference identifier - config_id: Configuration ID for processing - - Returns: - List of DialogData objects with generated chunks for each test entry - """ - dialog_data_list = [] - messages = [] - - messages.append(ConversationMessage(role="用户", msg=content)) - - # Create DialogData - conversation_context = ConversationContext(msgs=messages) - # Create DialogData with group_id based on the entry's id for uniqueness - dialog_data = DialogData( - context=conversation_context, - ref_id=ref_id, - group_id=group_id, - user_id=user_id, - apply_id=apply_id, - config_id=config_id - ) - # Create DialogueChunker and process the dialogue - chunker = DialogueChunker(chunker_strategy) - extracted_chunks = await chunker.process_dialogue(dialog_data) - dialog_data.chunks = extracted_chunks - - dialog_data_list.append(dialog_data) - - # Convert to dict with datetime serialized - def serialize_datetime(obj): - if isinstance(obj, datetime): - return obj.isoformat() - raise TypeError(f"Object of type {obj.__class__.__name__} is not JSON serializable") - - combined_output = [dd.model_dump() for dd in dialog_data_list] - - print(dialog_data_list) - - # with open(os.path.join(os.path.dirname(__file__), "chunker_test_output.txt"), "w", encoding="utf-8") as f: - # json.dump(combined_output, f, ensure_ascii=False, indent=4, default=serialize_datetime) - - - return dialog_data_list diff --git a/app/core/memory/agent/utils/llm_tools.py b/app/core/memory/agent/utils/llm_tools.py deleted file mode 100644 index e314dd09..00000000 --- a/app/core/memory/agent/utils/llm_tools.py +++ /dev/null @@ -1,204 +0,0 @@ -import asyncio -import json -from collections import defaultdict -from typing import TypedDict, Annotated -import os -import logging - -from jinja2 import Template -from langchain_core.messages import AnyMessage -from dotenv import load_dotenv -from langgraph.graph import add_messages -from openai import OpenAI - -from app.core.memory.agent.utils.messages_tool import read_template_file -from app.core.memory.utils.config.config_utils import get_picture_config, get_voice_config -from app.core.memory.utils.llm.llm_utils import get_llm_client -from app.core.memory.utils.config.definitions import SELECTED_LLM_ID, SELECTED_LLM_PICTURE_NAME, SELECTED_LLM_VOICE_NAME -from app.core.models.base import RedBearModelConfig -from app.core.memory.src.llm_tools.openai_client import OpenAIClient - -PROJECT_ROOT_ = os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) -logger = logging.getLogger(__name__) - -load_dotenv() - -#TODO: Refactor entire picture/voice -# async def LLM_model_request(context,data,query): -# ''' -# Agent model request -# Args: -# context:Input request -# data: template parameters -# query:request content -# Returns: - -# ''' -# template = Template(context) -# system_prompt = template.render(**data) -# llm_client = get_llm_client(SELECTED_LLM_ID) -# result = await llm_client.chat( -# messages=[{"role": "system", "content": system_prompt}] + [{"role": "user", "content": query}] -# ) -# return result - -async def picture_model_requests(image_url): - ''' - - Args: - image_url: - Returns: - - ''' - file_path = PROJECT_ROOT_ + '/agent/utils/prompt/Template_for_image_recognition_prompt.jinja2 ' - system_prompt = await read_template_file(file_path) - result = await Picture_recognize(image_url,system_prompt) - return (result) -class WriteState(TypedDict): - ''' - Langgrapg Writing TypedDict - ''' - messages: Annotated[list[AnyMessage], add_messages] - user_id:str - apply_id:str - group_id:str - -class ReadState(TypedDict): - ''' - Langgrapg READING TypedDict - name: - id:user id - loop_count:Traverse times - search_switch:type - config_id: configuration id for filtering results - ''' - messages: Annotated[list[AnyMessage], add_messages] #消息追加的模式增加消息 - name: str - id: str - loop_count:int - search_switch: str - user_id: str - apply_id: str - group_id: str - config_id: str - - -class COUNTState: - ''' - The number of times the workflow dialogue retrieval content has no correct message recall traversal - ''' - def __init__(self, limit: int = 5): - self.total: int = 0 # 当前累加值 - self.limit: int = limit # 最大上限 - - def add(self, value: int = 1): - """累加数字,如果达到上限就保持最大值""" - self.total += value - print(f"[COUNTState] 当前值: {self.total}") - if self.total >= self.limit: - print(f"[COUNTState] 达到上限 {self.limit}") - self.total = self.limit # 达到上限不再增加 - - def get_total(self) -> int: - """获取当前累加值""" - return self.total - - def reset(self): - """手动重置累加值""" - self.total = 0 - print(f"[COUNTState] 已重置为 0") - - - -# def embed(texts: list[str]) -> list[list[float]]: -# # 这里可以换成 LangChain Embeddings -# return [[float(len(t) % 5), float(len(t) % 3)] for t in texts] - - -# def export_store_to_json(store, namespace): -# """Export the entire storage content to a JSON file""" -# # 搜索所有存储项 -# all_items = store.search(namespace) - -# # 整理数据 -# export_data = {} -# for item in all_items: -# if hasattr(item, 'key') and hasattr(item, 'value'): -# export_data[item.key] = item.value - -# # 保存到文件 -# os.makedirs("memory_logs", exist_ok=True) -# with open("memory_logs/full_memory_export.json", "w", encoding="utf-8") as f: -# json.dump(export_data, f, ensure_ascii=False, indent=2) - -# print(f"{len(export_data)} 条记忆到 JSON 文件") - -def merge_to_key_value_pairs(data, query_key, result_key): - grouped = defaultdict(list) - for item in data: - grouped[item[query_key]].append(item[result_key]) - return [{key: values} for key, values in grouped.items()] - -def deduplicate_entries(entries): - seen = set() - deduped = [] - for entry in entries: - key = (entry['Query_small'], entry['Result_small']) - if key not in seen: - seen.add(key) - deduped.append(entry) - return deduped - - - -async def Picture_recognize(image_path,PROMPT_TICKET_EXTRACTION) -> str: - try: - model_config = get_picture_config(SELECTED_LLM_PICTURE_NAME) - except Exception as e: - err = f"LLM配置不可用:{str(e)}。请检查 config.json 和 runtime.json。" - logger.error(err) - return err - api_key = os.getenv(model_config["api_key"]) # 从环境变量读取对应后端的 API key - backend_model_name = model_config["llm_name"].split("/")[-1] - api_base=model_config['api_base'] - - logger.info(f"model_name: {backend_model_name}") - logger.info(f"api_key set: {'yes' if api_key else 'no'}") - logger.info(f"base_url: {model_config['api_base']}") - - client = OpenAI( - api_key=api_key, base_url=api_base, - ) - completion = client.chat.completions.create( - model=backend_model_name, - messages=[ - { - "role": "user", - "content": [ - { - "type": "image_url", - "image_url":image_path, - }, - {"type": "text", - "text": PROMPT_TICKET_EXTRACTION} - ] - } - ]) - picture_text = completion.choices[0].message.content - picture_text = picture_text.replace('```json', '').replace('```', '') - picture_text = json.loads(picture_text) - return (picture_text['statement']) - -async def Voice_recognize(): - try: - model_config = get_voice_config(SELECTED_LLM_VOICE_NAME) - except Exception as e: - err = f"LLM配置不可用:{str(e)}。请检查 config.json 和 runtime.json。" - logger.error(err) - return err - api_key = os.getenv(model_config["api_key"]) # 从环境变量读取对应后端的 API key - backend_model_name = model_config["llm_name"].split("/")[-1] - api_base = model_config['api_base'] - return api_key,backend_model_name,api_base - - diff --git a/app/core/memory/agent/utils/mcp_tools.py b/app/core/memory/agent/utils/mcp_tools.py deleted file mode 100644 index e32b8ef3..00000000 --- a/app/core/memory/agent/utils/mcp_tools.py +++ /dev/null @@ -1,15 +0,0 @@ -from app.core.config import settings - -def get_mcp_server_config(): - """ - Get the MCP server configuration - """ - mcp_server_config = { - "data_flow": { - "url": f"http://{settings.SERVER_IP}:8081/sse", # 你前面的 FastMCP(weather) 服务端口 - "transport": "sse", - "timeout": 15000, - "sse_read_timeout": 15000, - } - } - return mcp_server_config diff --git a/app/core/memory/agent/utils/messages_tool.py b/app/core/memory/agent/utils/messages_tool.py deleted file mode 100644 index 273bc719..00000000 --- a/app/core/memory/agent/utils/messages_tool.py +++ /dev/null @@ -1,239 +0,0 @@ -import json -import logging -import re -from typing import List, Any - -from langchain_core.messages import AnyMessage -from app.core.logging_config import get_agent_logger - -logger = get_agent_logger(__name__) - - -def _to_openai_messages(msgs: List[AnyMessage]) -> List[dict]: - out = [] - for m in msgs: - if hasattr(m, "content"): - out.append({"role": "user", "content": getattr(m, "content", "")}) - elif isinstance(m, dict) and "role" in m and "content" in m: - out.append(m) - else: - out.append({"role": "user", "content": str(m)}) - return out - - -def _extract_content(resp: Any) -> str: - """Extract LLM content and sanitize to raw JSON/text. - - - Supports both object and dict response shapes. - - Removes leading role labels (e.g., "Assistant:"). - - Strips Markdown code fences like ```json ... ```. - - Attempts to isolate the first valid JSON array/object block when extra text is present. - """ - - def _to_text(r: Any) -> str: - try: - # 对象形式: resp.choices[0].message.content - if hasattr(r, "choices") and getattr(r, "choices", None): - msg = r.choices[0].message - if hasattr(msg, "content"): - return msg.content - if isinstance(msg, dict) and "content" in msg: - return msg["content"] - # 字典形式: resp["choices"][0]["message"]["content"] - if isinstance(r, dict): - return r.get("choices", [{}])[0].get("message", {}).get("content", "") - except Exception: - pass - return str(r) - - def _clean_text(text: str) -> str: - s = str(text).strip() - # 移除可能的角色前缀 - s = re.sub(r"^\s*(Assistant|assistant)\s*:\s*", "", s) - # 提取 ```json ... ``` 代码块 - m = re.search(r"```json\s*(.*?)\s*```", s, flags=re.S | re.I) - if m: - s = m.group(1).strip() - # 如果仍然包含多余文本,尝试截取第一个 JSON 数组/对象片段 - if not (s.startswith("{") or s.startswith("[")): - left = s.find("[") - right = s.rfind("]") - if left != -1 and right != -1 and right > left: - s = s[left:right + 1].strip() - else: - left = s.find("{") - right = s.rfind("}") - if left != -1 and right != -1 and right > left: - s = s[left:right + 1].strip() - return s - - raw = _to_text(resp) - return _clean_text(raw) - -def Resolve_username(usermessages): - ''' - Extract username - Args: - usermessages: user name - - Returns: - - ''' - usermessages = usermessages.split('_')[1:] - sessionid = '_'.join(usermessages[:-1]) - return sessionid - - -# TODO: USE app.core.memory.src.utils.render_template instead -async def read_template_file(template_path: str) -> str: - """ - 读取模板文件 - - Args: - template_path: 模板文件路径 - - Returns: - 模板内容字符串 - - Note: - 建议使用 app.core.memory.utils.template_render 中的统一模板渲染功能 - """ - try: - with open(template_path, "r", encoding="utf-8") as f: - return f.read() - except FileNotFoundError: - logger.error(f"模板文件未找到: {template_path}") - raise - except IOError as e: - logger.error(f"读取模板文件失败: {template_path}, 错误: {str(e)}", exc_info=True) - raise - - -async def Problem_Extension_messages_deal(context): - ''' - Extract data - Args: - context: - Returns: - ''' - extent_quest = [] - original = context.get('original', '') - messages = context.get('context', '') - messages = json.loads(messages) - for message in messages: - question = message.get('question', '') - type = message.get('type', '') - extent_quest.append({"role": "user", "content": f"问题:{question};问题类型:{type}"}) - - return extent_quest, original - - -async def Retriev_messages_deal(context): - ''' - Extract data - Args: - context: - Returns: - ''' - if isinstance(context, dict): - if 'context' in context or 'original' in context: - return context.get('context', {}), context.get('original', '') - return content, original_value - -async def Verify_messages_deal(context): - ''' - Extract data - Args: - context: - Returns: - ''' - - query = context['context']['Query'] - Query_small_list = context['context']['Expansion_issue'] - Result_small = [] - Query_small = [] - for i in Query_small_list: - Result_small.append(i['Answer_Small'][0]) - Query_small.append(i['Query_small']) - return Query_small, Result_small, query - - -async def Summary_messages_deal(context): - ''' - Extract data - Args: - context: - Returns: - ''' - messages = str(context).replace('\\n', '').replace('\n', '').replace('\\', '') - query = re.findall(r'"query": (.*?),', messages)[0] - query = query.replace('[', '').replace(']', '').strip() - matches = re.findall(r'"answer_small"\s*:\s*"(\[.*?\])"', messages) - answer_small_texts = [] - for m in matches: - try: - parsed = json.loads(m) - for item in parsed: - answer_small_texts.append(item.strip().replace('\\', '').replace('[', '').replace(']', '')) - except Exception: - answer_small_texts.append(m.strip().replace('\\', '').replace('[', '').replace(']', '')) - - return answer_small_texts, query - - -async def VerifyTool_messages_deal(context): - ''' - Extract data - Args: - context: - Returns: - ''' - messages = str(context).replace('\\n', '').replace('\n', '').replace('\\', '') - content_messages = messages.split('"context":')[1].replace('""', '"') - messages = str(content_messages).split("name='Retrieve'")[0] - query = re.findall(f'"Query": "(.*?)"', messages)[0] - Query_small = re.findall(f'"Query_small": "(.*?)"', messages) - Result_small = re.findall(f'"Result_small": "(.*?)"', messages) - return Query_small, Result_small, query - - -async def Retrieve_Summary_messages_deal(context): - pass - - -async def Retrieve_verify_tool_messages_deal(context, history, query): - ''' - Extract data - Args: - context: - Returns: - ''' - results = [] - # 统一转为字符串,避免 None 或非字符串导致正则报错 - text = str(context) - blocks = re.findall(r'\{(.*?)\}', text, flags=re.S) - for block in blocks: - query_small = re.search(r'"Query_small"\s*:\s*"([^"]*)"', block) - answer_small = re.search(r'"Answer_Small"\s*:\s*(\[[^\]]*\])', block) - status = re.search(r'"status"\s*:\s*"([^"]*)"', block) - query_answer = re.search(r'"Query_answer"\s*:\s*"([^"]*)"', block) - - results.append({ - "query_small": query_small.group(1) if query_small else None, - "answer_small": answer_small.group(1) if answer_small else None, - # 将缺失的 status 统一为空字符串,后续用字符串判定,避免 NoneType 错误 - "status": status.group(1) if status else "", - "query_answer": query_answer.group(1) if query_answer else None - }) - result = [] - for r in results: - # 统一按字符串判定状态,兼容大小写和缺失情况 - status_str = str(r.get('status', '')).strip().lower() - if status_str == 'false': - continue - else: - result.append(r) - split_result = 'failed' if not result else 'success' - result = {"data": {"query": query, "expansion_issue": result}, "split_result": split_result, "reason": "", - "history": history} - return result diff --git a/app/core/memory/agent/utils/model_tool.py b/app/core/memory/agent/utils/model_tool.py deleted file mode 100644 index 969a2a91..00000000 --- a/app/core/memory/agent/utils/model_tool.py +++ /dev/null @@ -1,38 +0,0 @@ - - -# project_root = os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) -# sys.path.insert(0, project_root) - -# load_dotenv() - -# async def llm_client_chat(messages: List[dict]) -> str: -# """使用 OpenAI 兼容接口进行对话,返回内容字符串。""" -# try: -# cfg = get_model_config(SELECTED_LLM_ID) -# rb_config = RedBearModelConfig( -# model_name=cfg["model_name"], -# provider=cfg["provider"], -# api_key=cfg["api_key"], -# base_url=cfg["base_url"], -# ) -# client = OpenAIClient(model_config=rb_config, type_="chat") - -# except Exception as e: -# logger.error(f"获取模型配置失败:{e}") -# err = f"获取模型配置失败:{str(e)}。请检查!!!" -# return err -# try: -# response = await client.chat(messages) -# print(f"model_tool's llm_client_chat response ======>:\n {response}") -# return _extract_content(response) -# # return _extract_content(result) -# except Exception as e: -# logger.error(f"LLM调用失败:{str(e)}。请检查 model_name、api_key、api_base 是否正确。") -# return f"LLM调用失败:{str(e)}。请检查 model_name、api_key、api_base 是否正确。" - -# async def main(image_url): -# await llm_client_chat(image_url) -# -# # 运行主函数 -# asyncio.run(main(['https://dashscope.oss-cn-beijing.aliyuncs.com/samples/audio/paraformer/hello_world_male2.wav'])) -# diff --git a/app/core/memory/agent/utils/multimodal.py b/app/core/memory/agent/utils/multimodal.py deleted file mode 100644 index 5beaf892..00000000 --- a/app/core/memory/agent/utils/multimodal.py +++ /dev/null @@ -1,131 +0,0 @@ -""" -Multimodal input processor for handling image and audio content. - -This module provides utilities for detecting and processing multimodal inputs -(images and audio files) by converting them to text using appropriate models. -""" - -import logging -from typing import List - -from app.core.memory.agent.multimodal.speech_model import Vico_recognition -from app.core.memory.agent.utils.llm_tools import picture_model_requests - -logger = logging.getLogger(__name__) - - -class MultimodalProcessor: - """ - Processor for handling multimodal inputs (images and audio). - - This class detects image and audio file paths in input content and converts - them to text using appropriate recognition models. - """ - - # Supported file extensions - IMAGE_EXTENSIONS = ['.jpg', '.png'] - AUDIO_EXTENSIONS = [ - 'aac', 'amr', 'avi', 'flac', 'flv', 'm4a', 'mkv', 'mov', - 'mp3', 'mp4', 'mpeg', 'ogg', 'opus', 'wav', 'webm', 'wma', 'wmv' - ] - - def __init__(self): - """Initialize the multimodal processor.""" - pass - - def is_image(self, content: str) -> bool: - """ - Check if content is an image file path. - - Args: - content: Input string to check - - Returns: - True if content ends with a supported image extension - - Examples: - >>> processor = MultimodalProcessor() - >>> processor.is_image("photo.jpg") - True - >>> processor.is_image("document.pdf") - False - """ - if not isinstance(content, str): - return False - - content_lower = content.lower() - return any(content_lower.endswith(ext) for ext in self.IMAGE_EXTENSIONS) - - def is_audio(self, content: str) -> bool: - """ - Check if content is an audio file path. - - Args: - content: Input string to check - - Returns: - True if content ends with a supported audio extension - - Examples: - >>> processor = MultimodalProcessor() - >>> processor.is_audio("recording.mp3") - True - >>> processor.is_audio("video.mp4") - True - >>> processor.is_audio("document.txt") - False - """ - if not isinstance(content, str): - return False - - content_lower = content.lower() - return any(content_lower.endswith(f'.{ext}') for ext in self.AUDIO_EXTENSIONS) - - async def process_input(self, content: str) -> str: - """ - Process input content, converting images/audio to text if needed. - - This method detects if the input is an image or audio file and converts - it to text using the appropriate recognition model. If processing fails - or the content is not multimodal, it returns the original content. - - Args: - content: Input string (may be file path or regular text) - - Returns: - Text content (original or converted from image/audio) - - Examples: - >>> processor = MultimodalProcessor() - >>> await processor.process_input("photo.jpg") - "Recognized text from image..." - - >>> await processor.process_input("Hello world") - "Hello world" - """ - if not isinstance(content, str): - logger.warning(f"[MultimodalProcessor] Content is not a string: {type(content)}") - return str(content) - - try: - # Check for image input - if self.is_image(content): - logger.info(f"[MultimodalProcessor] Detected image input: {content}") - result = await picture_model_requests(content) - logger.info(f"[MultimodalProcessor] Image recognition result: {result[:100]}...") - return result - - # Check for audio input - if self.is_audio(content): - logger.info(f"[MultimodalProcessor] Detected audio input: {content}") - result = await Vico_recognition([content]).run() - logger.info(f"[MultimodalProcessor] Audio recognition result: {result[:100]}...") - return result - - except Exception as e: - logger.error(f"[MultimodalProcessor] Error processing multimodal input: {e}", exc_info=True) - logger.info(f"[MultimodalProcessor] Falling back to original content") - return content - - # Return original content if not multimodal - return content diff --git a/app/core/memory/agent/utils/prompt/Problem_Extension_prompt.jinja2 b/app/core/memory/agent/utils/prompt/Problem_Extension_prompt.jinja2 deleted file mode 100644 index a0e21fbd..00000000 --- a/app/core/memory/agent/utils/prompt/Problem_Extension_prompt.jinja2 +++ /dev/null @@ -1,81 +0,0 @@ - -你是一个高效的问题拆分助手,任务是根据用户提供的原始问题和问题类型,生成可操作的扩展问题,用于精确回答原问题。请严格遵循以下规则: - -角色: -- 你是“问题拆分专家”,专注于逻辑、信息完整性和可操作性。 -- 你能够结合【历史信息】、【上下文】、【背景知识】进行分析,以保持问题拆分的连贯性和相关性。 -- 如果历史信息或上下文与当前问题无关,可忽略。 - ---- - -### 历史信息参考 -在生成扩展问题时,你可以参考以下历史数据(如果提供): -- 历史对话或任务的主题; -- 历史中出现的关键实体(时间、人物、地点、研究主题等); -- 历史中已解答的问题(避免重复); -- 历史推理链(保持逻辑一致性)。 - -> 如果没有提供历史信息,则仅根据当前输入问题进行分析。 -输入历史信息内容:{{history}} - -## User Input -{% if questions is string %} -{{ questions }} -{% else %} -{% for question in questions %} -- {{ question }} -{% endfor %} -{% endif %} - -需求: -- 如果问题是单跳问题(单步可答),直接保留原问题提取重要提问部分作为拆分/扩展问题。 -- 如果问题是多跳问题(需多个信息点才能回答),对问题进行扩展拆分。 -- 扩展问题必须完整覆盖原问题的所有关键要素,包括时间、主体、动作、目标等,不得遗漏。 -- 扩展问题不得冗余:避免重复询问相同信息或过度拆分同一主题。 -- 扩展问题必须高度相关:每个子问题直接服务于原问题,不引入未提及的新概念、人物或细节。 -- 扩展问题必须可操作:每个子问题能在有限资源下独立解答。 -- 子问题数量不超过4个。 -- 拆分问题的时候可以考虑输入的历史内容,以保持逻辑连贯。 - 比如:输入历史信息内容:[{'Query': '4月27日,我和你推荐过一本书,书名是什么?', 'ANswer': '张曼玉推荐了《小王子》'}] - 拆分问题:4月27日,我和你推荐过一本书,书名是什么?,可以拆分为:4月27日,张曼玉推荐过一本书,书名是什么? - - - -输出要求: -- 仅输出 JSON 数组,不要包含任何解释或代码块。 -- 每个元素包含: - - `original_question`: 原始问题 - - `extended_question`: 扩展后的问题 - - `type`: 类型(事实检索/澄清/定义/比较/行动建议) - - `reason`: 生成该扩展问题的简短理由 -- 使用标准 ASCII 双引号,无换行;确保字符串正确关闭并以逗号分隔。 - -示例: -输入: -[ - "问题:今年诺贝尔物理学奖的获奖者是谁,他们因为什么贡献获奖?;问题类型:多跳", -] - -输出: -[ - { - "original_question": "今年诺贝尔物理学奖的获奖者是谁,他们因为什么贡献获奖?", - "extended_question": "今年诺贝尔物理学奖的获奖者有哪些人?", - "type": "多跳", - "reason": "输出原问题的关键要素" - }, - { - "original_question": "今年诺贝尔物理学奖的获奖者是谁,他们因为什么贡献获奖?", - "extended_question": "今年诺贝尔物理学奖的获奖者是因哪些具体贡献获奖的?", - "type": "多跳", - "reason": "输出原问题的关键要素" - } -] -**Output format** -**CRITICAL JSON FORMATTING REQUIREMENTS:** -1. Use only standard ASCII double quotes (") for JSON structure - never use Chinese quotation marks ("") or other Unicode quotes -2. If the extracted statement text contains quotation marks, escape them properly using backslashes (\") -3. Ensure all JSON strings are properly closed and comma-separated -4. Do not include line breaks within JSON string values - -The output language should always be the same as the input language.{{ json_schema }} diff --git a/app/core/memory/agent/utils/prompt/Retrieve_Summary_prompt.jinja2 b/app/core/memory/agent/utils/prompt/Retrieve_Summary_prompt.jinja2 deleted file mode 100644 index 1fa71df3..00000000 --- a/app/core/memory/agent/utils/prompt/Retrieve_Summary_prompt.jinja2 +++ /dev/null @@ -1,37 +0,0 @@ -# 角色 -你是一个专业的问答助手,擅长基于检索信息和历史对话回答用户问题。 - -# 任务 -根据提供的上下文信息回答用户的问题。 - -# 输入信息 -- 历史对话:{{history}} -- 检索信息:{{retrieve_info}} - -## User Query -{{query}} - -# 回答指南 -1. 仔细分析用户的问题 -2. 优先使用检索信息中的相关内容回答 -3. 结合历史对话提供连贯的回复 -4. 如果信息不足: - - 对于简单问候或日常对话,给出自然简短的回复 - - 对于复杂问题,诚实说明信息不足 -5. 保持回答简洁、相关、自然 -6. 使用与问题相同的语言回答 - -**Output format** -- 直接回答问题,像人类对话一样自然流畅 -- 不要提及"检索信息"、"搜索结果"、"根据资料"等技术术语 -- 不要解释推理过程或评论信息来源 -- 如果只能部分回答问题,先回答能回答的部分,然后说明哪些方面信息不足 -- 如果完全无法回答,简洁地说明:"信息不足,无法回答。" - -**CRITICAL JSON FORMATTING REQUIREMENTS:** -1. Use only standard ASCII double quotes (") for JSON structure - never use Chinese quotation marks ("") or other Unicode quotes -2. If the extracted statement text contains quotation marks, escape them properly using backslashes (\") -3. Ensure all JSON strings are properly closed and comma-separated -4. Do not include line breaks within JSON string values - -The output language should always be the same as the input language.{{ json_schema }} diff --git a/app/core/memory/agent/utils/prompt/Retrieve_prompt.jinja2 b/app/core/memory/agent/utils/prompt/Retrieve_prompt.jinja2 deleted file mode 100644 index 846777e9..00000000 --- a/app/core/memory/agent/utils/prompt/Retrieve_prompt.jinja2 +++ /dev/null @@ -1,29 +0,0 @@ - -# 角色:{#InputSlot placeholder="角色名称" mode="input"#}{#/InputSlot#} -你是一个智能问答助手,任务如下 -## 目标: - -1. 接收一个字典,格式为 {'问题': [答案列表]}。 -2. 接收一个问题(字典中的 key)。 -3. 找到与问题匹配的答案列表。 -4. 将答案列表合并成一句自然流畅的话: - - 如果答案有两条,使用“是”连接,例如:“A,是B”。 - - 如果答案有三条或以上,使用“,并且”“另外”等自然连词,保证句子流畅。 -5. 输出内容时只输出合并后的答案,不输出关键点或其他文字。 -6. 如果问题未在字典中找到对应答案,请输出: - 对不起,我没有找到相关信息。 - - -输出要求: -- 文本形式 ---- - -字典示例: -{ - '今天的天气怎么样': ['今天天气很好', '今天是晴天'] -} - -问题示例: -今天的天气怎么样 -输出要求: -今天天气很好,是晴天 \ No newline at end of file diff --git a/app/core/memory/agent/utils/prompt/Template_for_image_recognition_prompt.jinja2 b/app/core/memory/agent/utils/prompt/Template_for_image_recognition_prompt.jinja2 deleted file mode 100644 index 83d6fbbc..00000000 --- a/app/core/memory/agent/utils/prompt/Template_for_image_recognition_prompt.jinja2 +++ /dev/null @@ -1,10 +0,0 @@ -请提图像内的文本 -返回数据格式以json方式输出, -- 必须通过json.loads()的格式支持的形式输出,响应必须是与此确切模式匹配的有效JSON对象。不要在JSON之前或之后包含任何文本。 -- 关键的JSON格式要求{"statement":识别出的文本内容} -1.JSON结构仅使用标准ASCII双引号(“)-切勿使用中文引号(“”)或其他Unicode引号 -2.如果提取的语句文本包含引号,请使用反斜杠(\“)正确转义它们 -3.确保所有JSON字符串都正确关闭并以逗号分隔 -4.JSON字符串值中不包括换行符 -5.正确转义的例子:“statement”:“Zhang Xinhua said:\”我非常喜欢这本书\"" -6.不允许输出```json```相关符号,如```json```、``````、```python```、```javascript```、```html```、```css```、```sql```、```java```、```c```、```c++```、```c#```、```ruby``` diff --git a/app/core/memory/agent/utils/prompt/distinguish_types_prompt.jinja2 b/app/core/memory/agent/utils/prompt/distinguish_types_prompt.jinja2 deleted file mode 100644 index 38bd8615..00000000 --- a/app/core/memory/agent/utils/prompt/distinguish_types_prompt.jinja2 +++ /dev/null @@ -1,34 +0,0 @@ -你是一个输入分类助手,负责判断用户输入的意图类型。 - -## User Input -{{ user_query }} - -请你根据以下规则判断: -1. 如果输入是在寻求信息、提问、请求解释、或疑问句(包括隐含的问题),则分类为 "question"。 -2. 如果输入是命令、陈述、描述、感叹、或其他类型,不在寻求答案,则分类为 "other"。 -只输出: -{ - "type": "question" -} -或 -{ - "type": "other" -} -示例: -输入:"Python怎么读取文件?" -输出:{"type": "question"} - -输入:"帮我写个读取文件的函数" -输出:{"type": "other"} - -输入:"今天是星期几?" -输出:{"type": "question"} -返回数据格式以json方式输出, -- 必须通过json.loads()的格式支持的形式输出,响应必须是与此确切模式匹配的有效JSON对象。不要在JSON之前或之后包含任何文本。 -- 关键的JSON格式要求{"statement":识别出的文本内容} -1.JSON结构仅使用标准ASCII双引号(“)-切勿使用中文引号(“”)或其他Unicode引号 -2.如果提取的语句文本包含引号,请使用反斜杠(\“)正确转义它们 -3.确保所有JSON字符串都正确关闭并以逗号分隔 -4.JSON字符串值中不包括换行符 -5.正确转义的例子:“statement”:“Zhang Xinhua said:\”我非常喜欢这本书\"" -6.不允许输出```json```相关符号,如```json```、``````、```python```、```javascript```、```html```、```css```、```sql```、```java```、```c```、```c++```、```c#```、```ruby``` diff --git a/app/core/memory/agent/utils/prompt/problem_breakdown_prompt.jinja2 b/app/core/memory/agent/utils/prompt/problem_breakdown_prompt.jinja2 deleted file mode 100644 index aca716a4..00000000 --- a/app/core/memory/agent/utils/prompt/problem_breakdown_prompt.jinja2 +++ /dev/null @@ -1,160 +0,0 @@ - -# 角色:{#InputSlot placeholder="角色名称" mode="input"#}{#/InputSlot#} -你是一个智能数据拆分助手,请根据数据特性判断输入属于哪种类型: -## 目标: -你需要根据以下类型对输入数据进行分类,并生成相应的拆分策略和示例。 ---- - -### 历史信息参考 -在生成扩展问题时,你可以参考以下历史数据(如果提供): -- 历史对话或任务的主题; -- 历史中出现的关键实体(时间、人物、地点、研究主题等); -- 历史中已解答的问题(避免重复); -- 历史推理链(保持逻辑一致性)。 - -> 如果没有提供历史信息,则仅根据当前输入问题进行分析。 -输入历史信息内容:{{history}} - -## User Input -{{ sentence }} - -## 需求: -1:首先判断类型(单跳、多跳、开放域、时间)。 -2:根据类型进行拆分。 -3:拆分后的内容需保证信息完整且可独立处理。 -4:对每个拆分条目,可附加示例或说明。 -5:拆分问题的时候可以考虑输入的历史内容,以保持逻辑连贯。 - 比如:输入历史信息内容:[{'Query': '4月27日,我和你推荐过一本书,书名是什么?', 'ANswer': '张曼玉推荐了《小王子》'}] - 拆分问题:4月27日,我和你推荐过一本书,书名是什么?,可以拆分为:4月27日,张曼玉推荐过一本书,书名是什么? - -## 指令: -你是一个智能数据拆分助手,请根据数据特性判断输入属于哪种类型: -单跳(Single-hop) - 描述:问题或数据只需要通过一步即可得到答案或完成拆分,不依赖其他信息。 - 拆分策略:直接识别核心信息或关键字段,生成可独立处理的片段。 - 示例: - 输入数据:"请列出今年诺贝尔物理学奖的得主" - 拆分结果:[ - { - "id": "Q1", - "question": "今年诺贝尔物理学奖得主是谁", - "type": "单跳’", - } - ] - 注意: 当遇到上下文依赖问题时,明确指出缺失的信息类型并且,question可填写输入问题 -多跳(Multi-hop): - 描述:问题或数据需要通过多步推理或跨多个信息源才能得到答案。 - 拆分策略:将问题拆解为多个子问题,每个子问题对应一个独立处理步骤,需要具备推理链条与逻辑连接数量。 - 示例: - 输入数据:"今年诺贝尔物理学奖得主的研究领域及代表性成果" - 拆分结果: - [ - { - "id": "Q1", - "question": 今年诺贝尔物理学奖得主是谁?", - "type": "多跳’", - }, - { - "id": "Q2", - "question": "该得主的研究领域是什么?", - "type": "多跳’", - }, - { - "id": "Q3", - "question": "该得主的代表性成果有哪些?", - "type": "多跳’" - } - ] -开放域(Open-domain): - 描述:问题或数据不局限于特定知识库,需要从大范围信息中检索和生成答案,而不是从一个已知的小范围数据源中查找。。 - 拆分策略:根据主题或关键实体拆分,同时保留上下文以便检索外部知识,问题涉及一般性、常识性、跨学科内容,可能是开放式回答(描述性、推理性、综合性) - 需要外部知识检索或推理才能确定,比如:“为什么人类需要睡眠?”、“量子计算与经典计算的主要区别是什么?”。 - 示例: - 输入数据:"介绍量子计算的最新研究进展" - 拆分结果: - [ - { - "id": "Q1", - "question": 量子计算的基本概念是什么?", - "type": "开放域’", - }, - { - "id": "Q2", - "question": "当前量子计算的主要研究方向有哪些?", - "type": "开放域’", - }, - { - "id": "Q3", - "question": "近期在量子计算领域有哪些重大进展?", - "type": "开放域’", - } - ] - -时间(Temporal): - 描述:问题或数据涉及时间维度,需要按时间顺序或时间点拆分。 - 拆分策略:根据事件时间或时间段拆分为独立条目或问题。 - 示例: - 输入数据:"列出苹果公司过去五年的重大事件" - 拆分结果: - [ - { - "id": "Q1", - "question": 苹果公司2019年的重大事件有哪些?", - "type": "时间’", - }, - { - "id": "Q2", - "question": "苹果公司2020年的重大事件有哪些?", - "type": "时间’", - }, - { - "id": "Q3", - "question": "苹果公司2021年的重大事件有哪些?", - "type": "时间’", - }, - { - "id": "Q3", - "question": "苹果公司2022年的重大事件有哪些?", - "type": "时间’", - } - , - { - "id": "Q4", - "question": "苹果公司2023年的重大事件有哪些?", - "type": "时间’", - } - ] - -输出要求: -- 每个子问题包括: - - `id`: 子问题编号(Q1, Q2...) - - `question`: 子问题内容 - - `type`: 类型(事实检索 / 澄清 / 定义 / 比较 / 行动建议等) - - `reason`: 拆分的理由(为什么要这样拆) -- 格式案例: -[ - { - "id": "Q1", - "question": 量子计算的基本概念是什么?", - "type": "开放域’", - }, - { - "id": "Q2", - "question": "当前量子计算的主要研究方向有哪些?", - "type": "开放域’", - }, - { - "id": "Q3", - "question": "近期在量子计算领域有哪些重大进展?", - "type": "开放域’", - } -] -- 必须通过json.loads()的格式支持的形式输出 -- 必须通过json.loads()的格式支持的形式输出,响应必须是与此确切模式匹配的有效JSON对象。不要在JSON之前或之后包含任何文本。 -- 关键的JSON格式要求 -1.JSON结构仅使用标准ASCII双引号(“)-切勿使用中文引号(“”)或其他Unicode引号 -2.如果提取的语句文本包含引号,请使用反斜杠(\“)正确转义它们 -3.确保所有JSON字符串都正确关闭并以逗号分隔 -4.JSON字符串值中不包括换行符 -5.正确转义的例子:“statement”:“Zhang Xinhua said:\”我非常喜欢这本书\"" -6.不允许输出```json```相关符号,如```json```、``````、```python```、```javascript```、```html```、```css```、```sql```、```java```、```c```、```c++```、```c#```、```ruby``` diff --git a/app/core/memory/agent/utils/prompt/split_verify_prompt.jinja2 b/app/core/memory/agent/utils/prompt/split_verify_prompt.jinja2 deleted file mode 100644 index 6cdbaf6a..00000000 --- a/app/core/memory/agent/utils/prompt/split_verify_prompt.jinja2 +++ /dev/null @@ -1,60 +0,0 @@ -# 角色 -你是验证专家 -你的目标是针对用户的输入Query_Samll字段的提问和Answer_Samll的回答分析,是不是回答Query_Samll这个字段的问题 - -{#以下可以采用先总括,再展开详细说明的方式,描述你希望智能体在每一个步骤如何进行工作,具体的工作步骤数量可以根据实际需求增删#} -## 工作步骤 -1. 获取所有的Query_Samll字段和Answer_Samll字段 -2. 分析Answer_Samll的回复是不是和Query_Samll有关系 -3. 判断Answer_Samll和Query_Samll之间分析出来的关系状态 -4. 如果是True保留,否则不要相对应的问题和回答 -5. 输出,需要严格按照模版 -输入:{{history}} -历史消息:{"history":{{sentence}}} -### 第一步 获取用户的输入 -获取用户的输入提取对应的Query_Samll和Answer_Samll -### 第二步 分析验证 -需要分析Query_Samll和Answer_Samll之间的关系可以参考history字段的内容,如果有关系不是答非所问 -## 核心验证标准 -在评估子问题拆分时,必须严格遵循以下标准,且验证过程中完全不依赖于子问题的相关信息(Answer_Samll): -1. 合理性标准(必须全部满足): -- 完整性:每个不同的子问题必须完整覆盖原问题的所有关键要素(如时间、主体、动作、目标等),无遗漏。 -- 最小化:每个不同的子问题数量应尽可能少,通常不超过原问题关键要素数量的2倍(建议2-4个),避免冗余和不必要拆分。 -- 相关性:每个不同的子问题必须直接服务于原问题的解答,不引入无关内容或扩展原问题未提及的主题。 -- 可操作性:每个不同的子问题应能在有限资源(如标准工具或合理时间)内独立解答,且难度适中。 -- 逻辑性:每个不同的子问题间应有清晰的逻辑关系(如并列、递进、因果),共同构成原问题的解答路径。 - -2. 不合理拆分的特征(出现任一特征即为不合理): -- 不同的子问题数量超过5个或明显多于必要数量。 -- 引入原问题未提及的新主题、人物、细节或个人看法。 -- 拆分过于细碎,失去实用价值,无法高效合成原问题答案。 - -3. 特殊情况说明: -- 每个不同的子问题与原问题相同,需进一步判断: - - 每个不同的子问题不可进一步拆分 → success(合理,最小化拆分) - - 每个不同的子问题能够进一步拆分为更小、更合理的问题 → failed(不合理,拆分没有最小化) -- 每个不同的子问题数量=原问题核心要素数量 → success(理想情况) -- 每个不同的子问题数量=核心要素数量+1 → success(通常合理) - -### 第三步 添加状态 -如果有相关性并且比较高给一个状态TRUE,否则给一个FLASE的状态 -### 第四步 判断 -如果状态是TRUE保留这条数据,否则需不需要这条数据 -### 第五步 输出格式 -按照json的形式输出 -{"data":"Query":原来Query的字段,"history":原来的history字段, -"expansion_issue":以为列表的形式存储验证之后的数据比如[ -{"query_small": query_small, - "answer_small": answer_small,, - "status": 回答的结果是否符合query_small,填写状态, - "query_answer": answer_small}, -{ - "query_small": "张曼婷生日是什么时候?", - "answer_small": "张曼婷喜欢绘画。", - "status": "True", - "query_answer": "张曼 婷喜欢绘画。" - },{}......] -, - "split_result":如果expansion_issue是空的列表返回failed,不是空列表返回success, - "reason": 为以上分析完之后的结果给一个说明 - } \ No newline at end of file diff --git a/app/core/memory/agent/utils/prompt/summary_prompt.jinja2 b/app/core/memory/agent/utils/prompt/summary_prompt.jinja2 deleted file mode 100644 index e73171fc..00000000 --- a/app/core/memory/agent/utils/prompt/summary_prompt.jinja2 +++ /dev/null @@ -1,57 +0,0 @@ -{# 角色定义 #} -你是专业的问题解答专家,负责根据上下文信息和检索到的所有信息准确回答用户的问题。 - -{# 输入数据展示 #} -{% if data %} -## 输入数据 -上下文信息: -{% for item in data.history %} -- {{ item }} -{% endfor %} -检索到的所有信息: -{% for item in data.retrieve_info %} -- {{ item }} -{% endfor %} -{% endif %} - -## User Query -{{ query }} - -{# 问题回答标准 #} -## 问题回答核心标准 -根据上下文信息(history)和检索到的所有信息(retrieve_info)准确回答用户的问题(query)。注意,若不能根据已有信息回答用户的问题,应直接回复“信息不足,无法回答。”,不能自己编造答案。 -- 若能根据已有信息回答用户的问题,应根据上下文信息和检索到的所有信息提供简明扼要的答案。 -- 若不能根据已有信息回答用户的问题,应直接回复“信息不足,无法回答。”,不能自己编造答案。 - -{# 重要提醒 #} -再次提醒,给出问题的答案时,仅根据已有的信息进行回答,不能自己编造答案。 - -{# 输出格式模板 #} -## 输出格式 -严格按照以下JSON格式输出,不添加任何其他内容: -{ - "data": { - "query": "{{ query }}", - "history": [ - {% for item in data.history %} - "{{ item | replace('"', '\\"') }}" - {% if not loop.last %},{% endif %} - {% endfor %} - ], - "retrieve_info": [ - {% for item in data.retrieve_info %} - "{{ item | replace('"', '\\"') }}" - {% if not loop.last %},{% endif %} - {% endfor %} - ] - }, - "query_answer": "{% if not data.history and not data.retrieve_info %}信息不足,无法回答。{% endif %}" -} -**Output format** -**CRITICAL JSON FORMATTING REQUIREMENTS:** -1. Use only standard ASCII double quotes (") for JSON structure - never use Chinese quotation marks ("") or other Unicode quotes -2. If the extracted statement text contains quotation marks, escape them properly using backslashes (\") -3. Ensure all JSON strings are properly closed and comma-separated -4. Do not include line breaks within JSON string values - -The output language should always be the same as the input language.{{ json_schema }} diff --git a/app/core/memory/agent/utils/redis_tool.py b/app/core/memory/agent/utils/redis_tool.py deleted file mode 100644 index 68c16e1d..00000000 --- a/app/core/memory/agent/utils/redis_tool.py +++ /dev/null @@ -1,203 +0,0 @@ -import redis -import uuid -from datetime import datetime -from app.core.config import settings -class RedisSessionStore: - def __init__(self, host='localhost', port=6379, db=0, password=None,session_id=''): - self.r = redis.Redis(host=host, port=port, db=db, password=password) - self.uudi=session_id - - - # 修改后的 save_session 方法 - def save_session(self, userid, messages, aimessages, apply_id, group_id): - """ - 写入一条会话数据,返回 session_id - """ - try: - session_id = str(uuid.uuid4()) # 为每次会话生成新的 ID - starttime = datetime.now().strftime("%Y-%m-%d %H:%M:%S") - key = f"session:{session_id}" # 使用新生成的 session_id 作为 key - - # 使用 Hash 存储结构化数据 - result = self.r.hset(key, mapping={ - "id": self.uudi, - "sessionid": userid, - "apply_id": apply_id, - "group_id": group_id, - "messages": messages, - "aimessages": aimessages, - "starttime": starttime - }) - print(f"保存结果: {result}, session_id: {session_id}") - return session_id # 返回新生成的 session_id - except Exception as e: - print(f"保存会话失败: {e}") - raise e - - # ---------------- 读取 ---------------- - def get_session(self, session_id): - """ - 读取一条会话数据 - """ - key = f"session:{session_id}" - data = self.r.hgetall(key) - if data: - return {k.decode('utf-8'): v.decode('utf-8') for k, v in data.items()} - return None - - def get_session_apply_group(self, sessionid, apply_id, group_id): - """ - 根据 sessionid、apply_id 和 group_id 三个条件查询会话数据 - """ - result_items = [] - - # 遍历所有会话数据 - for key_bytes in self.r.keys('session:*'): - key = key_bytes.decode('utf-8') - data = self.r.hgetall(key) - - if not data: - continue - - # 解码数据 - decoded_data = {k.decode('utf-8'): v.decode('utf-8') for k, v in data.items()} - - # 检查三个条件是否都匹配 - if (decoded_data.get('sessionid') == sessionid and - decoded_data.get('apply_id') == apply_id and - decoded_data.get('group_id') == group_id): - result_items.append(decoded_data) - - return result_items - - def get_all_sessions(self): - """ - 获取所有会话数据 - """ - sessions = {} - for key in self.r.keys('session:*'): - sid = key.decode('utf-8').split(':')[1] - sessions[sid] = self.get_session(sid) - return sessions - - # ---------------- 更新 ---------------- - def update_session(self, session_id, field, value): - """ - 更新单个字段 - """ - key = f"session:{session_id}" - if self.r.exists(key): - self.r.hset(key, field, value) - return True - return False - - # ---------------- 删除 ---------------- - def delete_session(self, session_id): - """ - 删除单条会话 - """ - key = f"session:{session_id}" - return self.r.delete(key) - - def delete_all_sessions(self): - """ - 删除所有会话 - """ - keys = self.r.keys('session:*') - if keys: - return self.r.delete(*keys) - return 0 - - def delete_duplicate_sessions(self): - """ - 删除重复会话数据,条件: - "sessionid"、"user_id"、"group_id"、"messages"、"aimessages" 五个字段都相同的只保留一个,其他删除 - """ - seen = set() # 用来记录已出现的唯一组合 - deleted_count = 0 - - for key_bytes in self.r.keys('session:*'): - key = key_bytes.decode('utf-8') - data = self.r.hgetall(key) - if not data: - continue - - # 获取五个字段的值并解码 - sessionid = data.get(b'sessionid', b'').decode('utf-8') - user_id = data.get(b'id', b'').decode('utf-8') # 对应user_id - group_id = data.get(b'group_id', b'').decode('utf-8') - messages = data.get(b'messages', b'').decode('utf-8') - aimessages = data.get(b'aimessages', b'').decode('utf-8') - - # 用五元组作为唯一标识 - identifier = (sessionid, user_id, group_id, messages, aimessages) - - if identifier in seen: - # 重复,删除该 key - self.r.delete(key) - deleted_count += 1 - else: - # 第一次出现,加入 seen - seen.add(identifier) - - print(f"[delete_duplicate_sessions] 删除重复会话数量: {deleted_count}") - return deleted_count - - def find_user_session(self,sessionid): - user_id = sessionid - - result_items = [] - for key, values in store.get_all_sessions().items(): - history = {} - if user_id == str(values['sessionid']): - history["Query"] = values['messages'] - history["Answer"] = values['aimessages'] - result_items.append(history) - - if len(result_items) <= 1: - result_items = [] - return (result_items) - - def find_user_apply_group(self, sessionid, apply_id, group_id): - """ - 根据 sessionid、apply_id 和 group_id 三个条件查询会话数据 - """ - result_items = [] - - # 遍历所有会话数据 - for key_bytes in self.r.keys('session:*'): - key = key_bytes.decode('utf-8') - data = self.r.hgetall(key) - - if not data: - continue - - # 解码数据 - decoded_data = {k.decode('utf-8'): v.decode('utf-8') for k, v in data.items()} - - - # 检查三个条件是否都匹配 - if (decoded_data.get('sessionid') == sessionid and - decoded_data.get('apply_id') == apply_id and - decoded_data.get('group_id') == group_id): - history = { - "Query": decoded_data.get('messages'), - "Answer": decoded_data.get('aimessages') - } - - - result_items.append(history) - - # 如果结果少于等于1条,返回空列表 - if len(result_items) <= 1: - result_items = [] - - return result_items - -store = RedisSessionStore( - host=settings.REDIS_HOST, - port=settings.REDIS_PORT, - db=settings.REDIS_DB, - password=settings.REDIS_PASSWORD if settings.REDIS_PASSWORD else None, - session_id=str(uuid.uuid4()) -) diff --git a/app/core/memory/agent/utils/type_classifier.py b/app/core/memory/agent/utils/type_classifier.py deleted file mode 100644 index 2f5e2501..00000000 --- a/app/core/memory/agent/utils/type_classifier.py +++ /dev/null @@ -1,59 +0,0 @@ -""" -Type classification utility for distinguishing read/write operations. -""" -from jinja2 import Template -from pydantic import BaseModel - -from app.core.logging_config import get_agent_logger, log_prompt_rendering -from app.core.memory.agent.utils.llm_tools import PROJECT_ROOT_ -from app.core.memory.agent.utils.messages_tool import read_template_file -from app.core.memory.utils.llm.llm_utils import get_llm_client -from app.core.config import settings - - -logger = get_agent_logger(__name__) - - -class DistinguishTypeResponse(BaseModel): - """Response model for type classification""" - type: str - - -async def status_typle(messages: str) -> dict: - """ - Classify message type as read or write operation. - - Args: - messages: User message to classify - - Returns: - dict: Contains 'type' field with classification result - """ - try: - file_path = PROJECT_ROOT_ + '/agent/utils/prompt/distinguish_types_prompt.jinja2' - template_content = await read_template_file(file_path) - template = Template(template_content) - system_prompt = template.render(user_query=messages) - log_prompt_rendering("status_typle", system_prompt) - except Exception as e: - logger.error(f"Template rendering failed for status_typle: {e}", exc_info=True) - return { - "type": "error", - "message": f"Prompt rendering failed: {str(e)}" - } - - from app.core.memory.utils.config import definitions as config_defs - llm_client = get_llm_client(config_defs.SELECTED_LLM_ID) - - try: - structured = await llm_client.response_structured( - messages=[{"role": "system", "content": system_prompt}], - response_model=DistinguishTypeResponse - ) - return structured.model_dump() - except Exception as e: - logger.error(f"LLM call failed for status_typle: {e}", exc_info=True) - return { - "type": "error", - "message": f"LLM call failed: {str(e)}" - } diff --git a/app/core/memory/agent/utils/verify_tool.py b/app/core/memory/agent/utils/verify_tool.py deleted file mode 100644 index 5e1ce897..00000000 --- a/app/core/memory/agent/utils/verify_tool.py +++ /dev/null @@ -1,76 +0,0 @@ -from typing import TypedDict, Annotated, List, Any -from langchain_core.messages import AnyMessage -from langgraph.constants import START, END -from langgraph.graph import StateGraph, add_messages -import asyncio -import json -from dotenv import load_dotenv, find_dotenv -import os -from app.core.memory.agent.utils.llm_tools import PROJECT_ROOT_ -from langchain_core.messages import HumanMessage -from jinja2 import Environment, FileSystemLoader -from app.core.memory.agent.utils.messages_tool import _to_openai_messages -from app.core.memory.utils.llm.llm_utils import get_llm_client -from app.core.memory.utils.config.definitions import SELECTED_LLM_ID -from app.core.logging_config import get_agent_logger - -load_dotenv(find_dotenv()) - -logger = get_agent_logger(__name__) - -def keep_last(_, right): - return right -class State(TypedDict): - user_input: Annotated[dict, keep_last] - messages: Annotated[List[AnyMessage], add_messages] - agent1_response: str - agent2_response: str - agent3_response: str - final_response: str - status: Annotated[str, keep_last] - - -class VerifyTool: - def __init__(self, system_prompt: str="", verify_data: Any=None): - self.system_prompt = system_prompt - if isinstance(verify_data, str): - self.verify_data = verify_data - else: - try: - self.verify_data = json.dumps(verify_data, ensure_ascii=False) - except Exception: - self.verify_data = str(verify_data) - - async def model_1(self, state: State) -> State: - llm_client = get_llm_client(SELECTED_LLM_ID) - response_content = await llm_client.chat( - messages=[{"role": "system", "content": self.system_prompt}] + _to_openai_messages(state["messages"]) - ) - return { - "agent1_response": response_content, - "status": "processed", - } - - - def get_graph(self): - graph = StateGraph(State) - graph.add_node("model_1", self.model_1) - - graph.add_edge(START, "model_1") - graph.add_edge("model_1", END) - - compiled_graph = graph.compile() - return compiled_graph - - async def verify(self): - graph = self.get_graph() - initial_state = { - "user_input": self.verify_data, - "messages": [HumanMessage(content=self.verify_data)], - "final_response": "", - "status": "" - } - final_state = await graph.ainvoke(initial_state) - # return final_state["final_response"] - return final_state["agent1_response"] - diff --git a/app/core/memory/agent/utils/write_to_database.py b/app/core/memory/agent/utils/write_to_database.py deleted file mode 100644 index bd78fe9d..00000000 --- a/app/core/memory/agent/utils/write_to_database.py +++ /dev/null @@ -1,49 +0,0 @@ -import os -import uuid -from datetime import datetime -from typing import Any -from sqlalchemy.orm import Session -import logging -import json - -from app.db import get_db -from app.models.retrieval_info import RetrievalInfo - -logger = logging.getLogger(__name__) - -async def write_to_database(host_id: uuid.UUID, data: Any) -> str: - """ - 将数据写入数据库 - :param host_id: 宿主 ID - :param data: 要写入的数据 - :return: 写入数据库的结果 - """ - # 从数据库会话中获取会话 - db: Session = next(get_db()) - try: - if isinstance(data, (dict, list)): - serialized = json.dumps(data, ensure_ascii=False) - elif isinstance(data, str): - serialized = data - else: - serialized = str(data) - - new_retrieval_info = RetrievalInfo( - # host_id=host_id, - host_id=uuid.UUID("2f6ff1eb-50c7-4765-8e89-e4566be19122"), - retrieve_info=serialized, - created_at=datetime.now() - ) - db.add(new_retrieval_info) - db.commit() - logger.info(f"success to write data to database, host_id: {host_id}, retrieve_info: {serialized}") - return "success to write data to database" - except Exception as e: - db.rollback() - logger.error(f"failed to write data to database, host_id: {host_id}, retrieve_info: {data}, error: {e}") - raise e - finally: - try: - db.close() - except Exception: - pass diff --git a/app/core/memory/agent/utils/write_tools.py b/app/core/memory/agent/utils/write_tools.py deleted file mode 100644 index a535fe9d..00000000 --- a/app/core/memory/agent/utils/write_tools.py +++ /dev/null @@ -1,183 +0,0 @@ -import asyncio -from dotenv import load_dotenv -import time -from datetime import datetime - -from app.repositories.neo4j.graph_saver import save_dialog_and_statements_to_neo4j - -from app.core.memory.agent.utils.get_dialogs import get_chunked_dialogs -from app.core.logging_config import get_agent_logger - -logger = get_agent_logger(__name__) -# 使用新的模块化架构 -from app.core.memory.storage_services.extraction_engine.extraction_orchestrator import ExtractionOrchestrator -from app.core.memory.storage_services.extraction_engine.knowledge_extraction.embedding_generation import ( - embedding_generation_all, -) - -# 使用新的仓储层 -from app.repositories.neo4j.neo4j_connector import Neo4jConnector -# 导入配置模块(而不是直接导入变量) -from app.core.memory.utils.config import definitions as config_defs -from app.core.memory.utils.llm.llm_utils import get_llm_client -from app.core.memory.utils.log.logging_utils import log_time -from app.core.memory.storage_services.extraction_engine.knowledge_extraction.memory_summary import Memory_summary_generation -from app.repositories.neo4j.add_nodes import add_memory_summary_nodes -from app.repositories.neo4j.add_edges import add_memory_summary_statement_edges -load_dotenv() - - -async def write(content: str, user_id: str, apply_id: str, group_id: str, ref_id: str = "wyl20251027", config_id: str = None) -> None: - """ - 执行完整的知识提取流水线(使用新的 ExtractionOrchestrator) - - Args: - content: 对话内容 - user_id: 用户ID - apply_id: 应用ID - group_id: 组ID - ref_id: 参考ID,默认为 "wyl20251027" - config_id: 配置ID,用于标记数据处理配置 - """ - logger.info("=== MemSci Knowledge Extraction Pipeline ===") - logger.info(f"Using model: {config_defs.SELECTED_LLM_NAME}") - logger.info(f"Using LLM ID: {config_defs.SELECTED_LLM_ID}") - logger.info(f"Using chunker strategy: {config_defs.SELECTED_CHUNKER_STRATEGY}") - logger.info(f"Using group ID: {config_defs.SELECTED_GROUP_ID}") - logger.info(f"Using embedding ID: {config_defs.SELECTED_EMBEDDING_ID}") - logger.info(f"Config ID: {config_id if config_id else 'None'}") - logger.info(f"LANGFUSE_ENABLED: {config_defs.LANGFUSE_ENABLED}") - logger.info(f"AGENTA_ENABLED: {config_defs.AGENTA_ENABLED}") - - # Initialize timing log - log_file = "logs/time.log" - timestamp = datetime.now().strftime("%Y-%m-%d %H:%M:%S") - with open(log_file, "a", encoding="utf-8") as f: - f.write(f"\n=== Pipeline Run Started: {timestamp} ===\n") - - pipeline_start = time.time() - - # 初始化客户端 - llm_client = get_llm_client(config_defs.SELECTED_LLM_ID) - - # 获取 embedder 配置 - from app.core.models.base import RedBearModelConfig - from app.core.memory.utils.config.config_utils import get_embedder_config - from app.core.memory.src.llm_tools.openai_embedder import OpenAIEmbedderClient - - embedder_config_dict = get_embedder_config(config_defs.SELECTED_EMBEDDING_ID) - embedder_config = RedBearModelConfig(**embedder_config_dict) - embedder_client = OpenAIEmbedderClient(embedder_config) - - neo4j_connector = Neo4jConnector() - - # Step 1: 加载和分块数据 - step_start = time.time() - chunked_dialogs = await get_chunked_dialogs( - chunker_strategy=config_defs.SELECTED_CHUNKER_STRATEGY, - group_id=group_id, - user_id=user_id, - apply_id=apply_id, - content=content, - ref_id=ref_id, - config_id=config_id, - ) - log_time("Data Loading & Chunking", time.time() - step_start, log_file) - - # Step 2: 初始化并运行 ExtractionOrchestrator - step_start = time.time() - from app.core.memory.utils.config.config_utils import get_pipeline_config - config = get_pipeline_config() - - orchestrator = ExtractionOrchestrator( - llm_client=llm_client, - embedder_client=embedder_client, - connector=neo4j_connector, - config=config, - ) - - # 运行完整的提取流水线 - # orchestrator.run returns a flat tuple of 7 values after deduplication - ( - all_dialogue_nodes, - all_chunk_nodes, - all_statement_nodes, - all_entity_nodes, - all_statement_chunk_edges, - all_statement_entity_edges, - all_entity_entity_edges, - ) = await orchestrator.run(chunked_dialogs, is_pilot_run=False) - - log_time("Extraction Pipeline", time.time() - step_start, log_file) - - # Step 8: Save all data to Neo4j database using graph models - step_start = time.time() - # 运行索引创建 - from app.repositories.neo4j.create_indexes import create_fulltext_indexes - try: - await create_fulltext_indexes() - except Exception as e: - logger.error(f"Error creating indexes: {e}", exc_info=True) - - try: - success = await save_dialog_and_statements_to_neo4j( - dialogue_nodes=all_dialogue_nodes, - chunk_nodes=all_chunk_nodes, - statement_nodes=all_statement_nodes, - entity_nodes=all_entity_nodes, - statement_chunk_edges=all_statement_chunk_edges, - statement_entity_edges=all_statement_entity_edges, - entity_edges=all_entity_entity_edges, - connector=neo4j_connector - ) - if success: - logger.info("Successfully saved all data to Neo4j") - else: - logger.warning("Failed to save some data to Neo4j") - finally: - await neo4j_connector.close() - - log_time("Neo4j Database Save", time.time() - step_start, log_file) - - # Step 9: Generate Memory summaries and save to local vector DB and Neo4j - step_start = time.time() - try: - summaries = await Memory_summary_generation( - chunked_dialogs, llm_client=llm_client, embedding_id=config_defs.SELECTED_EMBEDDING_ID - ) - - # Save memory summaries to Neo4j as nodes - try: - ms_connector = Neo4jConnector() - await add_memory_summary_nodes(summaries, ms_connector) - # Link summaries to statements via chunks for summary→entity queries - await add_memory_summary_statement_edges(summaries, ms_connector) - finally: - try: - await ms_connector.close() - except Exception: - pass - except Exception as e: - logger.error(f"Memory summary step failed: {e}", exc_info=True) - finally: - log_time("Memory Summary (Local Vector DB & Neo4j)", time.time() - step_start, log_file) - - - - # Log total pipeline time - total_time = time.time() - pipeline_start - log_time("TOTAL PIPELINE TIME", total_time, log_file) - - # Add completion marker to log - timestamp = datetime.now().strftime("%Y-%m-%d %H:%M:%S") - with open(log_file, "a", encoding="utf-8") as f: - f.write(f"=== Pipeline Run Completed: {timestamp} ===\n\n") - - logger.info("=== Pipeline Complete ===") - logger.info(f"Total execution time: {total_time:.2f} seconds") - logger.info(f"Timing details saved to: {log_file}") - - -if __name__ == "__main__": - content = "你好,我是张三,是张曼婷的新朋友。请问张曼婷喜欢什么?" - asyncio.run(write(content, ref_id="wyl20251027")) diff --git a/app/core/memory/llm_tools/__init__.py b/app/core/memory/llm_tools/__init__.py deleted file mode 100644 index 55a1fc95..00000000 --- a/app/core/memory/llm_tools/__init__.py +++ /dev/null @@ -1,19 +0,0 @@ -""" -LLM 工具模块 - -提供 LLM 和 Embedder 客户端的抽象基类和具体实现。 -""" - -from app.core.memory.llm_tools.llm_client import LLMClient -from app.core.memory.llm_tools.embedder_client import EmbedderClient -from app.core.memory.llm_tools.openai_client import OpenAIClient -from app.core.memory.llm_tools.openai_embedder import OpenAIEmbedderClient -from app.core.memory.llm_tools.chunker_client import ChunkerClient - -__all__ = [ - "LLMClient", - "EmbedderClient", - "OpenAIClient", - "OpenAIEmbedderClient", - "ChunkerClient", -] diff --git a/app/core/memory/llm_tools/chunker_client.py b/app/core/memory/llm_tools/chunker_client.py deleted file mode 100644 index 4178ce0a..00000000 --- a/app/core/memory/llm_tools/chunker_client.py +++ /dev/null @@ -1,330 +0,0 @@ -from typing import Any, List -import re -import os -import asyncio -import json -import numpy as np - -# Fix tokenizer parallelism warning -os.environ["TOKENIZERS_PARALLELISM"] = "false" - -from chonkie import ( - SemanticChunker, - RecursiveChunker, - RecursiveRules, - LateChunker, - NeuralChunker, - SentenceChunker, - TokenChunker, -) - -from app.core.memory.models.config_models import ChunkerConfig -from app.core.memory.models.message_models import DialogData, Chunk -try: - from app.core.memory.llm_tools.openai_client import OpenAIClient -except Exception: - # 在测试或无可用依赖(如 langfuse)环境下,允许惰性导入 - OpenAIClient = Any - - -class LLMChunker: - """基于LLM的智能分块策略""" - def __init__(self, llm_client: OpenAIClient, chunk_size: int = 1000): - self.llm_client = llm_client - self.chunk_size = chunk_size - - async def __call__(self, text: str) -> List[Any]: - # 使用LLM分析文本结构并进行智能分块 - prompt = f""" - 请将以下文本分割成语义连贯的段落。每个段落应该围绕一个主题,长度大约在{self.chunk_size}字符左右。 - 请以JSON格式返回结果,包含chunks数组,每个chunk有text字段。 - - 文本内容: - {text[:5000]} - """ - - messages = [ - {"role": "system", "content": "你是一个专业的文本分析助手,擅长将长文本分割成语义连贯的段落。"}, - {"role": "user", "content": prompt} - ] - - try: - # 使用异步的 achat 方法 - if hasattr(self.llm_client, 'achat'): - response = await self.llm_client.achat(messages) - else: - # 如果没有异步方法,使用同步方法并转换为异步 - response = await asyncio.to_thread(self.llm_client.chat, messages) - - # 检查响应格式并提取内容 - if hasattr(response, 'choices') and len(response.choices) > 0: - content = response.choices[0].message.content - elif hasattr(response, 'content'): - content = response.content - else: - content = str(response) - - # 解析LLM响应 - if "```json" in content: - json_str = content.split("```json")[1].split("```")[0].strip() - elif "```" in content: - json_str = content.split("```")[1].split("```")[0].strip() - else: - json_str = content - - result = json.loads(json_str) - - class SimpleChunk: - def __init__(self, text, index): - self.text = text - self.start_index = index * 100 # 近似位置 - self.end_index = (index + 1) * 100 - - return [SimpleChunk(chunk["text"], i) for i, chunk in enumerate(result.get("chunks", []))] - - except Exception as e: - print(f"LLM分块失败: {e}") - # 失败时返回空列表,外层会处理回退方案 - return [] - - -class HybridChunker: - """混合分块策略:先按结构分块,再按语义合并""" - def __init__(self, semantic_threshold: float = 0.8, base_chunk_size: int = 300): - self.semantic_threshold = semantic_threshold - self.base_chunk_size = base_chunk_size - self.base_chunker = TokenChunker(tokenizer="character", chunk_size=base_chunk_size) - self.semantic_chunker = SemanticChunker(threshold=semantic_threshold) - - def __call__(self, text: str) -> List[Any]: - # 先用基础分块 - base_chunks = self.base_chunker(text) - - # 如果文本不长,直接返回基础分块 - if len(base_chunks) <= 3: - return base_chunks - - # 对基础分块进行语义合并 - combined_text = " ".join([chunk.text for chunk in base_chunks]) - return self.semantic_chunker(combined_text) - - -class ChunkerClient: - def __init__(self, chunker_config: ChunkerConfig, llm_client: OpenAIClient = None): - self.chunker_config = chunker_config - self.embedding_model = chunker_config.embedding_model - self.chunk_size = chunker_config.chunk_size - self.threshold = chunker_config.threshold - self.language = chunker_config.language - self.skip_window = chunker_config.skip_window - self.min_sentences = chunker_config.min_sentences - self.min_characters_per_chunk = chunker_config.min_characters_per_chunk - self.llm_client = llm_client - - # 可选参数(从配置中安全获取,提供默认值) - self.chunk_overlap = getattr(chunker_config, 'chunk_overlap', 0) - self.min_sentences_per_chunk = getattr(chunker_config, 'min_sentences_per_chunk', 1) - self.min_characters_per_sentence = getattr(chunker_config, 'min_characters_per_sentence', 12) - self.delim = getattr(chunker_config, 'delim', [".", "!", "?", "\n"]) - self.include_delim = getattr(chunker_config, 'include_delim', "prev") - self.tokenizer_or_token_counter = getattr(chunker_config, 'tokenizer_or_token_counter', "character") - - # 初始化具体分块器策略 - if chunker_config.chunker_strategy == "TokenChunker": - self.chunker = TokenChunker( - tokenizer=self.tokenizer_or_token_counter, - chunk_size=self.chunk_size, - chunk_overlap=self.chunk_overlap, - ) - elif chunker_config.chunker_strategy == "SemanticChunker": - self.chunker = SemanticChunker( - embedding_model=self.embedding_model, - threshold=self.threshold, - chunk_size=self.chunk_size, - min_sentences=self.min_sentences, - ) - elif chunker_config.chunker_strategy == "RecursiveChunker": - self.chunker = RecursiveChunker( - rules=RecursiveRules(), - min_characters_per_chunk=self.min_characters_per_chunk or 50, - chunk_size=self.chunk_size, - ) - elif chunker_config.chunker_strategy == "LateChunker": - self.chunker = LateChunker( - embedding_model=self.embedding_model, - chunk_size=self.chunk_size, - rules=RecursiveRules(), - min_characters_per_chunk=self.min_characters_per_chunk, - ) - elif chunker_config.chunker_strategy == "NeuralChunker": - self.chunker = NeuralChunker( - model=self.embedding_model, - min_characters_per_chunk=self.min_characters_per_chunk, - ) - elif chunker_config.chunker_strategy == "LLMChunker": - if not llm_client: - raise ValueError("LLMChunker requires an LLM client") - self.chunker = LLMChunker(llm_client, self.chunk_size) - elif chunker_config.chunker_strategy == "HybridChunker": - self.chunker = HybridChunker( - semantic_threshold=self.threshold, - base_chunk_size=self.chunk_size, - ) - elif chunker_config.chunker_strategy == "SentenceChunker": - # 某些 chonkie 版本的 SentenceChunker 不支持 tokenizer_or_token_counter 参数 - # 为了兼容不同版本,这里仅传递广泛支持的参数 - self.chunker = SentenceChunker( - chunk_size=self.chunk_size, - chunk_overlap=self.chunk_overlap, - min_sentences_per_chunk=self.min_sentences_per_chunk, - min_characters_per_sentence=self.min_characters_per_sentence, - delim=self.delim, - include_delim=self.include_delim, - ) - else: - raise ValueError(f"Unknown chunker strategy: {chunker_config.chunker_strategy}") - - async def generate_chunks(self, dialogue: DialogData): - """ - 生成分块,支持异步操作 - """ - try: - # 预处理文本:确保对话标记格式统一 - content = dialogue.content - content = content.replace('AI:', 'AI:').replace('用户:', '用户:') # 统一冒号 - content = re.sub(r'(\n\s*)+\n', '\n\n', content) # 合并多个空行 - - if hasattr(self.chunker, '__call__') and not asyncio.iscoroutinefunction(self.chunker.__call__): - # 同步分块器 - chunks = self.chunker(content) - else: - # 异步分块器(如LLMChunker) - chunks = await self.chunker(content) - - # 过滤空块和过小的块 - valid_chunks = [] - for c in chunks: - chunk_text = getattr(c, 'text', str(c)) if not isinstance(c, str) else c - if isinstance(chunk_text, str) and len(chunk_text.strip()) >= (self.min_characters_per_chunk or 50): - valid_chunks.append(c) - - dialogue.chunks = [ - Chunk( - content=c.text if hasattr(c, 'text') else str(c), - metadata={ - "start_index": getattr(c, "start_index", None), - "end_index": getattr(c, "end_index", None), - "chunker_strategy": self.chunker_config.chunker_strategy, - }, - ) - for c in valid_chunks - ] - return dialogue - - except Exception as e: - print(f"分块失败: {e}") - - # 改进的后备方案:尝试按对话回合分割 - try: - # 简单的按对话分割 - dialogue_pattern = r'(AI:|用户:)(.*?)(?=AI:|用户:|$)' - matches = re.findall(dialogue_pattern, dialogue.content, re.DOTALL) - - class SimpleChunk: - def __init__(self, text, start_index, end_index): - self.text = text - self.start_index = start_index - self.end_index = end_index - - chunks = [] - current_chunk = "" - current_start = 0 - - for match in matches: - speaker, ct = match[0], match[1].strip() - turn_text = f"{speaker} {ct}" - - if len(current_chunk) + len(turn_text) > (self.chunk_size or 500): - if current_chunk: - chunks.append(SimpleChunk(current_chunk, current_start, current_start + len(current_chunk))) - current_chunk = turn_text - current_start = dialogue.content.find(turn_text, current_start) - else: - current_chunk += ("\n" + turn_text) if current_chunk else turn_text - - if current_chunk: - chunks.append(SimpleChunk(current_chunk, current_start, current_start + len(current_chunk))) - - dialogue.chunks = [ - Chunk( - content=c.text, - metadata={ - "start_index": c.start_index, - "end_index": c.end_index, - "chunker_strategy": "DialogueTurnFallback", - }, - ) - for c in chunks - ] - - except Exception: - # 最后的手段:单一大块 - dialogue.chunks = [Chunk( - content=dialogue.content, - metadata={"chunker_strategy": "SingleChunkFallback"}, - )] - - return dialogue - - def evaluate_chunking(self, dialogue: DialogData) -> dict: - """ - 评估分块质量 - """ - if not getattr(dialogue, 'chunks', None): - return {} - - chunks = dialogue.chunks - total_chars = sum(len(chunk.content) for chunk in chunks) - avg_chunk_size = total_chars / len(chunks) - - # 计算各种指标 - chunk_sizes = [len(chunk.content) for chunk in chunks] - - metrics = { - "strategy": self.chunker_config.chunker_strategy, - "num_chunks": len(chunks), - "total_characters": total_chars, - "avg_chunk_size": avg_chunk_size, - "min_chunk_size": min(chunk_sizes), - "max_chunk_size": max(chunk_sizes), - "chunk_size_std": np.std(chunk_sizes) if len(chunk_sizes) > 1 else 0, - "coverage_ratio": total_chars / len(dialogue.content) if dialogue.content else 0, - } - - return metrics - - def save_chunking_results(self, dialogue: DialogData, output_path: str): - """ - 保存分块结果到文件,文件名包含策略名称 - """ - strategy_name = self.chunker_config.chunker_strategy - # 在文件名中添加策略名称 - base_name, ext = os.path.splitext(output_path) - strategy_output_path = f"{base_name}_{strategy_name}{ext}" - - with open(strategy_output_path, 'w', encoding='utf-8') as f: - f.write(f"=== Chunking Strategy: {strategy_name} ===\n") - f.write(f"Total chunks: {len(dialogue.chunks)}\n") - f.write(f"Total characters: {sum(len(chunk.content) for chunk in dialogue.chunks)}\n") - f.write("=" * 60 + "\n\n") - - for i, chunk in enumerate(dialogue.chunks): - f.write(f"Chunk {i+1}:\n") - f.write(f"Size: {len(chunk.content)} characters\n") - if hasattr(chunk, 'metadata') and 'start_index' in chunk.metadata: - f.write(f"Position: {chunk.metadata.get('start_index')}-{chunk.metadata.get('end_index')}\n") - f.write(f"Content: {chunk.content}\n") - f.write("-" * 40 + "\n\n") - - print(f"Chunking results saved to: {strategy_output_path}") - return strategy_output_path diff --git a/app/core/memory/llm_tools/embedder_client.py b/app/core/memory/llm_tools/embedder_client.py deleted file mode 100644 index 0a08f824..00000000 --- a/app/core/memory/llm_tools/embedder_client.py +++ /dev/null @@ -1,176 +0,0 @@ -""" -Embedder 客户端抽象基类 - -提供统一的嵌入向量生成接口,支持重试机制和错误处理。 -""" - -from abc import ABC, abstractmethod -from typing import List, Optional -import asyncio -import logging -from tenacity import ( - retry, - stop_after_attempt, - wait_exponential, - retry_if_exception_type, - before_sleep_log, -) - -from app.core.models.base import RedBearModelConfig -from app.core.exceptions import BusinessException -from app.core.error_codes import BizCode - -logger = logging.getLogger(__name__) - - -class EmbedderClientException(BusinessException): - """Embedder 客户端异常""" - def __init__(self, message: str, code: str = BizCode.EMBEDDING_ERROR): - super().__init__(message, code=code) - - -class EmbedderClient(ABC): - """ - Embedder 客户端抽象基类 - - 提供统一的嵌入向量生成接口,包括: - - 批量文本嵌入(response) - - 自动重试机制 - - 错误处理 - """ - - def __init__(self, model_config: RedBearModelConfig): - """ - 初始化 Embedder 客户端 - - Args: - model_config: 模型配置,包含模型名称、提供商、API密钥等信息 - """ - self.config = model_config - self.model_name = model_config.model_name - self.provider = model_config.provider - self.api_key = model_config.api_key - self.base_url = model_config.base_url - self.max_retries = model_config.max_retries - self.timeout = model_config.timeout - - logger.info( - f"初始化 Embedder 客户端: provider={self.provider}, " - f"model={self.model_name}, max_retries={self.max_retries}" - ) - - @abstractmethod - async def response( - self, - messages: List[str], - **kwargs - ) -> List[List[float]]: - """ - 生成嵌入向量 - - Args: - messages: 文本列表 - **kwargs: 额外参数 - - Returns: - 嵌入向量列表,每个向量是一个浮点数列表 - - Raises: - EmbedderClientException: 嵌入向量生成失败 - """ - pass - - def _create_retry_decorator(self): - """ - 创建重试装饰器 - - Returns: - 配置好的 tenacity retry 装饰器 - """ - return retry( - stop=stop_after_attempt(self.max_retries), - wait=wait_exponential(multiplier=1, min=2, max=10), - retry=retry_if_exception_type(( - asyncio.TimeoutError, - ConnectionError, - Exception, # 可以根据需要细化异常类型 - )), - before_sleep=before_sleep_log(logger, logging.WARNING), - reraise=True, - ) - - async def response_with_retry( - self, - messages: List[str], - **kwargs - ) -> List[List[float]]: - """ - 带重试机制的嵌入向量生成接口 - - Args: - messages: 文本列表 - **kwargs: 额外参数 - - Returns: - 嵌入向量列表 - - Raises: - EmbedderClientException: 重试失败后抛出 - """ - retry_decorator = self._create_retry_decorator() - - @retry_decorator - async def _response_with_retry(): - try: - return await self.response(messages, **kwargs) - except Exception as e: - logger.error(f"嵌入向量生成失败: {e}") - raise EmbedderClientException(f"嵌入向量生成失败: {e}") from e - - return await _response_with_retry() - - async def embed_single(self, text: str, **kwargs) -> List[float]: - """ - 为单个文本生成嵌入向量 - - Args: - text: 单个文本 - **kwargs: 额外参数 - - Returns: - 嵌入向量(浮点数列表) - - Raises: - EmbedderClientException: 嵌入向量生成失败 - """ - embeddings = await self.response_with_retry([text], **kwargs) - return embeddings[0] if embeddings else [] - - async def embed_batch( - self, - texts: List[str], - batch_size: int = 100, - **kwargs - ) -> List[List[float]]: - """ - 批量生成嵌入向量(支持大批量文本) - - Args: - texts: 文本列表 - batch_size: 每批处理的文本数量 - **kwargs: 额外参数 - - Returns: - 嵌入向量列表 - - Raises: - EmbedderClientException: 嵌入向量生成失败 - """ - all_embeddings = [] - - for i in range(0, len(texts), batch_size): - batch = texts[i:i + batch_size] - batch_embeddings = await self.response_with_retry(batch, **kwargs) - all_embeddings.extend(batch_embeddings) - - return all_embeddings diff --git a/app/core/memory/llm_tools/llm_client.py b/app/core/memory/llm_tools/llm_client.py deleted file mode 100644 index e26aba3e..00000000 --- a/app/core/memory/llm_tools/llm_client.py +++ /dev/null @@ -1,187 +0,0 @@ -""" -LLM 客户端抽象基类 - -提供统一的 LLM 调用接口,支持重试机制和错误处理。 -""" - -from abc import ABC, abstractmethod -from typing import List, Dict, Any, Optional -from pydantic import BaseModel -import asyncio -import logging -from tenacity import ( - retry, - stop_after_attempt, - wait_exponential, - retry_if_exception_type, - before_sleep_log, -) - -from app.core.models.base import RedBearModelConfig -from app.core.exceptions import BusinessException -from app.core.error_codes import BizCode - -logger = logging.getLogger(__name__) - - -class LLMClientException(BusinessException): - """LLM 客户端异常""" - def __init__(self, message: str, code: str = BizCode.LLM_ERROR): - super().__init__(message, code=code) - - -class LLMClient(ABC): - """ - LLM 客户端抽象基类 - - 提供统一的 LLM 调用接口,包括: - - 聊天接口(chat) - - 结构化输出接口(response_structured) - - 自动重试机制 - - 错误处理 - """ - - def __init__(self, model_config: RedBearModelConfig): - """ - 初始化 LLM 客户端 - - Args: - model_config: 模型配置,包含模型名称、提供商、API密钥等信息 - """ - self.config = model_config - self.model_name = self.config.model_name - self.provider = self.config.provider - self.api_key = self.config.api_key - self.base_url = self.config.base_url - self.max_retries = self.config.max_retries - self.timeout = self.config.timeout - - logger.info( - f"初始化 LLM 客户端: provider={self.provider}, " - f"model={self.model_name}, max_retries={self.max_retries}" - ) - - @abstractmethod - async def chat(self, messages: List[Dict[str, str]], **kwargs) -> Any: - """ - 聊天接口 - - Args: - messages: 消息列表,每个消息包含 role 和 content - **kwargs: 额外参数 - - Returns: - LLM 响应内容 - - Raises: - LLMClientException: LLM 调用失败 - """ - pass - - @abstractmethod - async def response_structured( - self, - messages: List[Dict[str, str]], - response_model: type[BaseModel], - **kwargs - ) -> BaseModel: - """ - 结构化输出接口 - - Args: - messages: 消息列表 - response_model: 期望的响应模型类型(Pydantic BaseModel) - **kwargs: 额外参数 - - Returns: - 解析后的 Pydantic 模型实例 - - Raises: - LLMClientException: LLM 调用或解析失败 - """ - pass - - def _create_retry_decorator(self): - """ - 创建重试装饰器 - - Returns: - 配置好的 tenacity retry 装饰器 - """ - return retry( - stop=stop_after_attempt(self.max_retries), - wait=wait_exponential(multiplier=1, min=2, max=10), - retry=retry_if_exception_type(( - asyncio.TimeoutError, - ConnectionError, - Exception, # 可以根据需要细化异常类型 - )), - before_sleep=before_sleep_log(logger, logging.WARNING), - reraise=True, - ) - - async def chat_with_retry( - self, - messages: List[Dict[str, str]], - **kwargs - ) -> Any: - """ - 带重试机制的聊天接口 - - Args: - messages: 消息列表 - **kwargs: 额外参数 - - Returns: - LLM 响应内容 - - Raises: - LLMClientException: 重试失败后抛出 - """ - retry_decorator = self._create_retry_decorator() - - @retry_decorator - async def _chat_with_retry(): - try: - return await self.chat(messages, **kwargs) - except Exception as e: - logger.error(f"LLM 调用失败: {e}") - raise LLMClientException(f"LLM 调用失败: {e}") from e - - return await _chat_with_retry() - - async def response_structured_with_retry( - self, - messages: List[Dict[str, str]], - response_model: type[BaseModel], - **kwargs - ) -> BaseModel: - """ - 带重试机制的结构化输出接口 - - Args: - messages: 消息列表 - response_model: 期望的响应模型类型 - **kwargs: 额外参数 - - Returns: - 解析后的 Pydantic 模型实例 - - Raises: - LLMClientException: 重试失败后抛出 - """ - retry_decorator = self._create_retry_decorator() - - @retry_decorator - async def _response_structured_with_retry(): - try: - return await self.response_structured( - messages, - response_model, - **kwargs - ) - except Exception as e: - logger.error(f"LLM 结构化输出失败: {e}") - raise LLMClientException(f"LLM 结构化输出失败: {e}") from e - - return await _response_structured_with_retry() diff --git a/app/core/memory/llm_tools/openai_client.py b/app/core/memory/llm_tools/openai_client.py deleted file mode 100644 index bcaa52c2..00000000 --- a/app/core/memory/llm_tools/openai_client.py +++ /dev/null @@ -1,198 +0,0 @@ -""" -OpenAI LLM 客户端实现 - -基于 LangChain 和 RedBearLLM 的 OpenAI 客户端实现。 -""" - -import asyncio -from typing import List, Dict, Any -import json -import logging - -from pydantic import BaseModel -from langchain_core.prompts import ChatPromptTemplate -from langchain_core.output_parsers import PydanticOutputParser - -from app.core.models.base import RedBearModelConfig -from app.core.models.llm import RedBearLLM -from app.core.memory.llm_tools.llm_client import LLMClient, LLMClientException -from app.core.memory.utils.config.definitions import LANGFUSE_ENABLED - -logger = logging.getLogger(__name__) - - -class OpenAIClient(LLMClient): - """ - OpenAI LLM 客户端实现 - - 基于 LangChain 和 RedBearLLM 的实现,支持: - - 聊天接口 - - 结构化输出 - - Langfuse 追踪(可选) - """ - - def __init__(self, model_config: RedBearModelConfig, type_: str = "chat"): - """ - 初始化 OpenAI 客户端 - - Args: - model_config: 模型配置 - type_: 模型类型,"chat" 或 "completion" - """ - super().__init__(model_config) - - # 初始化 Langfuse 回调处理器(如果启用) - self.langfuse_handler = None - if LANGFUSE_ENABLED: - try: - from langfuse.langchain import CallbackHandler - self.langfuse_handler = CallbackHandler() - logger.info("Langfuse 追踪已启用") - except ImportError: - logger.warning("Langfuse 未安装,跳过追踪功能") - except Exception as e: - logger.warning(f"初始化 Langfuse 处理器失败: {e}") - - # 初始化 RedBearLLM 客户端 - self.client = RedBearLLM( - RedBearModelConfig( - model_name=self.model_name, - provider=self.provider, - api_key=self.api_key, - base_url=self.base_url, - max_retries=self.max_retries, - timeout=self.timeout, - ), - type=type_ - ) - - logger.info(f"OpenAI 客户端初始化完成: type={type_}") - - async def chat(self, messages: List[Dict[str, str]], **kwargs) -> Any: - """ - 聊天接口实现 - - Args: - messages: 消息列表 - **kwargs: 额外参数 - - Returns: - LLM 响应内容 - - Raises: - LLMClientException: LLM 调用失败 - """ - try: - template = """{messages}""" - prompt = ChatPromptTemplate.from_template(template) - chain = prompt | self.client - - # 添加 Langfuse 回调(如果可用) - config = {} - if self.langfuse_handler: - config["callbacks"] = [self.langfuse_handler] - - response = await chain.ainvoke({"messages": messages}, config=config) - - logger.debug(f"LLM 响应成功: {len(str(response))} 字符") - return response - - except Exception as e: - logger.error(f"LLM 调用失败: {e}") - raise LLMClientException(f"LLM 调用失败: {e}") from e - - async def response_structured( - self, - messages: List[Dict[str, str]], - response_model: type[BaseModel], - **kwargs - ) -> BaseModel: - """ - 结构化输出接口实现 - - Args: - messages: 消息列表 - response_model: 期望的响应模型类型 - **kwargs: 额外参数 - - Returns: - 解析后的 Pydantic 模型实例 - - Raises: - LLMClientException: LLM 调用或解析失败 - """ - try: - # 构建问题文本 - question_text = "\n\n".join([ - str(m.get("content", "")) for m in messages - ]) - - # 准备配置(包含 Langfuse 回调) - config = {} - if self.langfuse_handler: - config["callbacks"] = [self.langfuse_handler] - - # 方法 1: 使用 PydanticOutputParser - if PydanticOutputParser is not None: - try: - parser = PydanticOutputParser(pydantic_object=response_model) - format_instructions = parser.get_format_instructions() - prompt = ChatPromptTemplate.from_template( - "{question}\n{format_instructions}" - ) - chain = prompt | self.client | parser - - parsed = await chain.ainvoke( - { - "question": question_text, - "format_instructions": format_instructions, - }, - config=config - ) - - logger.debug(f"使用 PydanticOutputParser 解析成功") - return parsed - - except Exception as e: - logger.warning( - f"PydanticOutputParser 解析失败,尝试其他方法: {e}" - ) - - # 方法 2: 使用 LangChain 的 with_structured_output - template = """{question}""" - prompt = ChatPromptTemplate.from_template(template) - - try: - with_so = getattr(self.client, "with_structured_output", None) - - if callable(with_so): - structured_chain = prompt | with_so(response_model, strict=True) - parsed = await structured_chain.ainvoke( - {"question": question_text}, - config=config - ) - - # 验证并返回结果 - try: - return response_model.model_validate(parsed) - except Exception: - # 如果已经是 Pydantic 实例,直接返回 - if hasattr(parsed, "model_dump"): - return parsed - # 尝试从 JSON 解析 - return response_model.model_validate_json(json.dumps(parsed)) - - except Exception as e: - logger.error(f"结构化输出失败: {e}") - raise LLMClientException(f"结构化输出失败: {e}") from e - - # 如果所有方法都失败,抛出异常 - raise LLMClientException( - "无法生成结构化输出,所有解析方法均失败" - ) - - except LLMClientException: - raise - except Exception as e: - logger.error(f"结构化输出处理失败: {e}") - raise LLMClientException(f"结构化输出处理失败: {e}") from e diff --git a/app/core/memory/llm_tools/openai_embedder.py b/app/core/memory/llm_tools/openai_embedder.py deleted file mode 100644 index 2d6fccbc..00000000 --- a/app/core/memory/llm_tools/openai_embedder.py +++ /dev/null @@ -1,87 +0,0 @@ -""" -OpenAI Embedder 客户端实现 - -基于 LangChain 和 RedBearEmbeddings 的 OpenAI 嵌入模型客户端实现。 -""" - -from typing import List -import logging - -from app.core.memory.llm_tools.embedder_client import ( - EmbedderClient, - EmbedderClientException -) -from app.core.models.base import RedBearModelConfig -from app.core.models.embedding import RedBearEmbeddings - -logger = logging.getLogger(__name__) - - -class OpenAIEmbedderClient(EmbedderClient): - """ - OpenAI Embedder 客户端实现 - - 基于 LangChain 和 RedBearEmbeddings 的实现,支持: - - 批量文本嵌入 - - 自动重试机制 - - 错误处理 - """ - - def __init__(self, model_config: RedBearModelConfig): - """ - 初始化 OpenAI Embedder 客户端 - - Args: - model_config: 模型配置 - """ - super().__init__(model_config) - - # 初始化 RedBearEmbeddings 模型 - self.model = RedBearEmbeddings( - RedBearModelConfig( - model_name=self.model_name, - provider=self.provider, - api_key=self.api_key, - base_url=self.base_url, - max_retries=self.max_retries, - timeout=self.timeout, - ) - ) - - logger.info("OpenAI Embedder 客户端初始化完成") - - async def response( - self, - messages: List[str], - **kwargs - ) -> List[List[float]]: - """ - 生成嵌入向量实现 - - Args: - messages: 文本列表 - **kwargs: 额外参数 - - Returns: - 嵌入向量列表 - - Raises: - EmbedderClientException: 嵌入向量生成失败 - """ - try: - # 过滤空文本 - texts: List[str] = [str(m) for m in messages if m is not None] - - if not texts: - logger.warning("输入文本列表为空,返回空结果") - return [] - - # 生成嵌入向量 - embeddings = await self.model.aembed_documents(texts) - - logger.debug(f"成功生成 {len(embeddings)} 个嵌入向量") - return embeddings - - except Exception as e: - logger.error(f"嵌入向量生成失败: {e}") - raise EmbedderClientException(f"嵌入向量生成失败: {e}") from e diff --git a/app/core/memory/main.py b/app/core/memory/main.py deleted file mode 100644 index c4fd043a..00000000 --- a/app/core/memory/main.py +++ /dev/null @@ -1,332 +0,0 @@ -""" -MemSci 记忆系统主入口 - 重构版本 - -该模块是重构后的记忆系统主入口,使用新的模块化架构。 -旧版本入口(app/core/memory/src/main.py)已删除。 - -主要功能: -1. 协调整个知识提取流水线 -2. 支持试运行模式和正常运行模式 -3. 使用重构后的 storage_services 模块 -4. 提供统一的配置管理和日志记录 - -作者:Lance77 -日期:2025-11-22 -""" - -# 必须在最开始禁用 LangSmith 追踪,避免速率限制错误 -import os -os.environ["LANGCHAIN_TRACING_V2"] = "false" -os.environ["LANGCHAIN_TRACING"] = "false" -import asyncio -import time -from datetime import datetime -from typing import Optional -from dotenv import load_dotenv - -# 导入重构后的模块 -from app.repositories.neo4j.neo4j_connector import Neo4jConnector -from app.core.memory.storage_services.extraction_engine.extraction_orchestrator import ExtractionOrchestrator -from app.core.memory.utils.llm.llm_utils import get_llm_client -from app.core.memory.utils.config.config_utils import get_embedder_config -from app.core.memory.llm_tools.openai_embedder import OpenAIEmbedderClient -from app.core.memory.models.message_models import ConversationMessage, ConversationContext, DialogData -from app.core.memory.models.variate_config import ExtractionPipelineConfig - -# 导入数据加载函数 -from app.core.memory.storage_services.extraction_engine.extraction_orchestrator import ( - get_chunked_dialogs_with_preprocessing, - get_chunked_dialogs_from_preprocessed, -) -# 导入配置模块(而不是直接导入变量) -from app.core.memory.utils.config import definitions as config_defs -from app.core.logging_config import get_memory_logger, log_time - -load_dotenv() - -logger = get_memory_logger(__name__) - - - - - -async def main(dialogue_text: Optional[str] = None, is_pilot_run: bool = False): - """ - 记忆系统主流程 - 重构版本 - - 该函数是重构后的主入口,使用新的模块化架构。 - - Args: - dialogue_text: 输入的对话文本(可选,用于试运行模式) - is_pilot_run: 是否为试运行模式 - - True: 试运行模式,不保存到 Neo4j - - False: 正常运行模式,保存到 Neo4j - - 工作流程: - 1. 初始化客户端和配置 - 2. 加载或准备数据 - 3. 执行知识提取流水线 - 4. 保存结果(正常模式)或输出结果(试运行模式) - """ - print("=" * 60) - print("MemSci 知识提取流水线 - 重构版本") - print("=" * 60) - print(f"运行模式: {'试运行(不保存到Neo4j)' if is_pilot_run else '正常运行(保存到Neo4j)'}") - print("Using chunker strategy:", config_defs.SELECTED_CHUNKER_STRATEGY) - print("Using group ID:", config_defs.SELECTED_GROUP_ID) - print("Using model ID:", config_defs.SELECTED_LLM_ID) - print("Using embedding model ID:", config_defs.SELECTED_EMBEDDING_ID) - print("LANGFUSE_ENABLED:", config_defs.LANGFUSE_ENABLED) - print("AGENTA_ENABLED:", config_defs.AGENTA_ENABLED) - print("=" * 60) - - # 初始化日志 - log_file = "logs/time.log" - os.makedirs(os.path.dirname(log_file), exist_ok=True) - timestamp = datetime.now().strftime("%Y-%m-%d %H:%M:%S") - with open(log_file, "a", encoding="utf-8") as f: - f.write(f"\n=== Pipeline Run Started: {timestamp} ({'Pilot Run' if is_pilot_run else 'Normal Run'}) ===\n") - - pipeline_start = time.time() - - try: - # 步骤 1: 初始化客户端 - logger.info("Initializing clients...") - step_start = time.time() - - llm_client = get_llm_client(config_defs.SELECTED_LLM_ID) - - # 获取 embedder 配置并转换为 RedBearModelConfig 对象 - from app.core.models.base import RedBearModelConfig - embedder_config_dict = get_embedder_config(config_defs.SELECTED_EMBEDDING_ID) - embedder_config = RedBearModelConfig(**embedder_config_dict) - embedder_client = OpenAIEmbedderClient(embedder_config) - - neo4j_connector = Neo4jConnector() - - log_time("Client Initialization", time.time() - step_start, log_file) - - # 步骤 2: 加载或准备数据 - logger.info("Loading data...") - logger.info(f"[MAIN] dialogue_text type={type(dialogue_text)}, length={len(dialogue_text) if dialogue_text else 0}, is_pilot_run={is_pilot_run}") - logger.info(f"[MAIN] dialogue_text preview: {repr(dialogue_text)[:200] if dialogue_text else 'None'}") - logger.info(f"[MAIN] Condition check: dialogue_text={bool(dialogue_text)}, isinstance={isinstance(dialogue_text, str) if dialogue_text else False}, strip={bool(dialogue_text.strip()) if dialogue_text and isinstance(dialogue_text, str) else False}") - step_start = time.time() - - if dialogue_text and isinstance(dialogue_text, str) and dialogue_text.strip(): - # 试运行模式:处理前端传入的对话文本 - logger.info("[MAIN] ✓ Using frontend dialogue text (pilot run mode)") - import re - - # 解析对话文本,支持 "用户:" 和 "AI:" 格式 - pattern = r"(用户|AI)[::]\s*([^\n]+(?:\n(?!(?:用户|AI)[::])[^\n]*)*?)" - matches = re.findall(pattern, dialogue_text, re.MULTILINE | re.DOTALL) - messages = [ - ConversationMessage(role=r, msg=c.strip()) - for r, c in matches if c.strip() - ] - - # 如果没有匹配到格式化的对话,将整个文本作为用户消息 - if not messages: - messages = [ConversationMessage(role="用户", msg=dialogue_text.strip())] - - # 创建对话上下文和对话数据 - context = ConversationContext(msgs=messages) - dialog = DialogData( - context=context, - ref_id="pilot_dialog_1", - group_id=config_defs.SELECTED_GROUP_ID, - user_id=config_defs.SELECTED_USER_ID, - apply_id=config_defs.SELECTED_APPLY_ID, - metadata={"source": "pilot_run", "input_type": "frontend_text"} - ) - - # 对前端传入的对话进行分块处理 - chunked_dialogs = await get_chunked_dialogs_from_preprocessed( - data=[dialog], - chunker_strategy=config_defs.SELECTED_CHUNKER_STRATEGY, - llm_client=llm_client, - ) - logger.info(f"Processed frontend dialogue text: {len(messages)} messages") - else: - # 正常运行模式:从 testdata.json 文件加载 - logger.warning("[MAIN] ✗ Falling back to testdata.json (dialogue_text not provided or empty)") - logger.info("Loading data from testdata.json...") - test_data_path = os.path.join( - os.path.dirname(__file__), "data", "testdata.json" - ) - - if not os.path.exists(test_data_path): - raise FileNotFoundError(f"Test data file not found: {test_data_path}") - - chunked_dialogs = await get_chunked_dialogs_with_preprocessing( - chunker_strategy=config_defs.SELECTED_CHUNKER_STRATEGY, - group_id=config_defs.SELECTED_GROUP_ID, - user_id=config_defs.SELECTED_USER_ID, - apply_id=config_defs.SELECTED_APPLY_ID, - indices=config_defs.SELECTED_TEST_DATA_INDICES, - input_data_path=test_data_path, - llm_client=llm_client, - skip_cleaning=True, - ) - logger.info(f"Loaded {len(chunked_dialogs)} dialogues from testdata.json") - - log_time("Data Loading & Chunking", time.time() - step_start, log_file) - - # 步骤 3: 初始化流水线编排器 - logger.info("Initializing extraction orchestrator...") - step_start = time.time() - - # 从 runtime.json 加载配置(已经过数据库覆写) - from app.core.memory.utils.config.config_utils import get_pipeline_config - config = get_pipeline_config() - - logger.info(f"Pipeline config loaded: enable_llm_dedup_blockwise={config.deduplication.enable_llm_dedup_blockwise}, enable_llm_disambiguation={config.deduplication.enable_llm_disambiguation}") - - orchestrator = ExtractionOrchestrator( - llm_client=llm_client, - embedder_client=embedder_client, - connector=neo4j_connector, - config=config, - ) - - log_time("Orchestrator Initialization", time.time() - step_start, log_file) - - # 步骤 4: 执行知识提取流水线 - logger.info("Running extraction pipeline...") - step_start = time.time() - - extraction_result = await orchestrator.run( - dialog_data_list=chunked_dialogs, - is_pilot_run=is_pilot_run, # 传递试运行模式标志 - ) - - # 解包 extraction_result tuple - # extraction_result 是一个包含 7 个元素的 tuple: - # (dialogue_nodes, chunk_nodes, statement_nodes, entity_nodes, - # statement_chunk_edges, statement_entity_edges, entity_edges) - ( - dialogue_nodes, - chunk_nodes, - statement_nodes, - entity_nodes, - statement_chunk_edges, - statement_entity_edges, - entity_edges, - ) = extraction_result - - log_time("Extraction Pipeline", time.time() - step_start, log_file) - - # 步骤 5: 保存结果或输出结果 - if is_pilot_run: - logger.info("Pilot run mode: Skipping Neo4j save") - print("\n试运行模式:跳过 Neo4j 保存,流水线处理完成。") - print("提取结果已生成,可在相关输出中查看。") - else: - logger.info("Normal mode: Saving to Neo4j...") - step_start = time.time() - - # 创建索引和约束 - try: - from app.repositories.neo4j.create_indexes import ( - create_fulltext_indexes, - create_unique_constraints, - ) - await create_fulltext_indexes() - await create_unique_constraints() - logger.info("Successfully created indexes and constraints") - except Exception as e: - logger.error(f"Error creating indexes/constraints: {e}") - - # 保存数据到 Neo4j - try: - from app.repositories.neo4j.graph_saver import ( - save_dialog_and_statements_to_neo4j, - ) - - success = await save_dialog_and_statements_to_neo4j( - dialogue_nodes=dialogue_nodes, - chunk_nodes=chunk_nodes, - statement_nodes=statement_nodes, - entity_nodes=entity_nodes, - statement_chunk_edges=statement_chunk_edges, - statement_entity_edges=statement_entity_edges, - entity_edges=entity_edges, - connector=neo4j_connector, - ) - - if success: - logger.info("Successfully saved all data to Neo4j") - print("\n✓ 成功保存所有数据到 Neo4j") - else: - logger.warning("Failed to save some data to Neo4j") - print("\n⚠ 部分数据保存到 Neo4j 失败") - except Exception as e: - logger.error(f"Error saving to Neo4j: {e}", exc_info=True) - print(f"\n✗ 保存到 Neo4j 失败: {e}") - - log_time("Neo4j Database Save", time.time() - step_start, log_file) - - # 步骤 6: 生成记忆摘要(可选) - try: - logger.info("Generating memory summaries...") - step_start = time.time() - - from app.core.memory.storage_services.extraction_engine.knowledge_extraction.memory_summary import ( - Memory_summary_generation, - ) - from app.repositories.neo4j.add_nodes import add_memory_summary_nodes - from app.repositories.neo4j.add_edges import ( - add_memory_summary_statement_edges, - ) - - summaries = await Memory_summary_generation( - chunked_dialogs, llm_client=llm_client, embedding_id=config_defs.SELECTED_EMBEDDING_ID - ) - - if not is_pilot_run: - # 保存记忆摘要到 Neo4j - ms_connector = Neo4jConnector() - try: - await add_memory_summary_nodes(summaries, ms_connector) - await add_memory_summary_statement_edges(summaries, ms_connector) - finally: - await ms_connector.close() - - log_time("Memory Summary Generation", time.time() - step_start, log_file) - except Exception as e: - logger.error(f"Memory summary step failed: {e}", exc_info=True) - - except Exception as e: - logger.error(f"Pipeline execution failed: {e}", exc_info=True) - print(f"\n✗ 流水线执行失败: {e}") - raise - finally: - # 清理资源 - try: - await neo4j_connector.close() - except Exception: - pass - - # 记录总时间 - total_time = time.time() - pipeline_start - log_time("TOTAL PIPELINE TIME", total_time, log_file) - - # 添加完成标记 - timestamp = datetime.now().strftime("%Y-%m-%d %H:%M:%S") - with open(log_file, "a", encoding="utf-8") as f: - f.write(f"=== Pipeline Run Completed: {timestamp} ===\n\n") - - logger.info("=== Pipeline Complete ===") - logger.info(f"Total execution time: {total_time:.2f} seconds") - logger.info(f"Timing details saved to: {log_file}") - - print("\n" + "=" * 60) - print(f"✓ 流水线执行完成") - print(f"✓ 总耗时: {total_time:.2f} 秒") - print(f"✓ 详细日志: {log_file}") - print("=" * 60) - - -if __name__ == "__main__": - asyncio.run(main()) diff --git a/app/core/memory/models/__init__.py b/app/core/memory/models/__init__.py deleted file mode 100644 index 1de3424a..00000000 --- a/app/core/memory/models/__init__.py +++ /dev/null @@ -1,115 +0,0 @@ -"""Data models for the Memory module. - -This package contains all Pydantic models used in the memory system, -including models for messages, dialogues, statements, entities, triplets, -graph nodes/edges, configurations, and deduplication decisions. -""" - -# Base response models -from app.core.memory.models.base_response import RobustLLMResponse - -# Configuration models -from app.core.memory.models.config_models import ( - LLMConfig, - ChunkerConfig, - PruningConfig, - TemporalSearchParams, -) - -# Deduplication models -from app.core.memory.models.dedup_models import ( - EntityDedupDecision, - EntityDisambDecision, -) - -# Graph models (nodes and edges) -from app.core.memory.models.graph_models import ( - # Edges - Edge, - ChunkEdge, - ChunkEntityEdge, - ChunkDialogEdge, - StatementChunkEdge, - StatementEntityEdge, - EntityEntityEdge, - # Nodes - Node, - DialogueNode, - StatementNode, - ChunkNode, - ExtractedEntityNode, - MemorySummaryNode, -) - -# Message and dialogue models -from app.core.memory.models.message_models import ( - ConversationMessage, - TemporalValidityRange, - Statement, - ConversationContext, - Chunk, - DialogData, -) - -# Triplet and entity models -from app.core.memory.models.triplet_models import ( - Entity, - Triplet, - TripletExtractionResponse, -) - -# Variable configuration models -from app.core.memory.models.variate_config import ( - StatementExtractionConfig, - ForgettingEngineConfig, - TripletExtractionConfig, - TemporalExtractionConfig, - DedupConfig, - ExtractionPipelineConfig, -) - -__all__ = [ - # Base response - "RobustLLMResponse", - # Configuration - "LLMConfig", - "ChunkerConfig", - "PruningConfig", - "TemporalSearchParams", - # Deduplication - "EntityDedupDecision", - "EntityDisambDecision", - # Graph edges - "Edge", - "ChunkEdge", - "ChunkEntityEdge", - "ChunkDialogEdge", - "StatementChunkEdge", - "StatementEntityEdge", - "EntityEntityEdge", - # Graph nodes - "Node", - "DialogueNode", - "StatementNode", - "ChunkNode", - "ExtractedEntityNode", - "MemorySummaryNode", - # Messages and dialogues - "ConversationMessage", - "TemporalValidityRange", - "Statement", - "ConversationContext", - "Chunk", - "DialogData", - # Triplets and entities - "Entity", - "Triplet", - "TripletExtractionResponse", - # Variable configuration - "StatementExtractionConfig", - "ForgettingEngineConfig", - "TripletExtractionConfig", - "TemporalExtractionConfig", - "DedupConfig", - "ExtractionPipelineConfig", -] diff --git a/app/core/memory/models/base_response.py b/app/core/memory/models/base_response.py deleted file mode 100644 index 775588f3..00000000 --- a/app/core/memory/models/base_response.py +++ /dev/null @@ -1,59 +0,0 @@ -"""Base classes for LLM response models with common validators. - -This module provides reusable base classes for Pydantic models that handle -common LLM response patterns and edge cases. - -Classes: - RobustLLMResponse: Base class for LLM response models with robust validation -""" - -from typing import Any -from pydantic import BaseModel, ConfigDict, model_validator - - -class RobustLLMResponse(BaseModel): - """Base class for LLM response models with robust validation. - - This base class provides: - - Automatic handling of list-wrapped responses (e.g., [{"field": "value"}]) - - Ignoring extra fields from LLM output - - Validation on assignment - - Usage: - class MyResponse(RobustLLMResponse): - field1: str - field2: int - """ - - model_config = ConfigDict( - extra="ignore", # Allow extra fields to be ignored (more forgiving) - validate_assignment=True # Validate on assignment - ) - - @model_validator(mode='before') - @classmethod - def handle_list_input(cls, data: Any) -> Any: - """Handle cases where LLM returns a list instead of a dict. - - Some LLMs may wrap the response in a list like [{"field": "value"}]. - This validator extracts the first item if that happens. - - Args: - data: The input data from the LLM - - Returns: - The unwrapped data (dict) - - Raises: - ValueError: If the input is invalid (empty list, wrong type, etc.) - """ - if isinstance(data, list): - if len(data) == 0: - raise ValueError("Received empty list from LLM") - # Extract first item from list - data = data[0] - - if not isinstance(data, dict): - raise ValueError(f"Expected dict or list, got {type(data).__name__}") - - return data diff --git a/app/core/memory/models/config_models.py b/app/core/memory/models/config_models.py deleted file mode 100644 index f3341cc5..00000000 --- a/app/core/memory/models/config_models.py +++ /dev/null @@ -1,93 +0,0 @@ -"""Configuration models for Memory module components. - -This module contains Pydantic models for configuring various components -of the memory system including LLM, chunking, pruning, and search. - -Classes: - LLMConfig: Configuration for LLM client - ChunkerConfig: Configuration for dialogue chunking - PruningConfig: Configuration for semantic pruning - TemporalSearchParams: Parameters for temporal search queries -""" - -from typing import Optional -from pydantic import BaseModel, Field - - -class LLMConfig(BaseModel): - """Configuration for Large Language Model client. - - Attributes: - llm_name: The name of the LLM model to use (e.g., 'gpt-4', 'claude-3') - api_base: Optional base URL for the API endpoint - max_retries: Maximum number of retries for failed API calls (default: 3) - """ - llm_name: str = Field(..., description="The name of the LLM model to use.") - api_base: Optional[str] = Field(None, description="The base URL for the API endpoint.") - max_retries: Optional[int] = Field(3, ge=0, description="The maximum number of retries for API calls.") - - -class ChunkerConfig(BaseModel): - """Configuration for dialogue chunking strategy. - - Attributes: - chunker_strategy: Name of the chunking strategy (e.g., 'RecursiveChunker', 'SemanticChunker') - embedding_model: Name of the embedding model to use for semantic chunking - chunk_size: Maximum size of each chunk in characters (default: 2048) - threshold: Similarity threshold for semantic chunking (0-1, default: 0.8) - language: Language of the text (default: 'zh' for Chinese) - skip_window: Window size for skip-and-merge strategy (default: 0) - min_sentences: Minimum number of sentences per chunk (default: 1) - min_characters_per_chunk: Minimum characters per chunk (default: 24) - """ - chunker_strategy: str = Field(..., description="The name of the chunker strategy to use.") - embedding_model: str = Field(..., description="The name of the embedding model to use.") - chunk_size: Optional[int] = Field(2048, ge=0, description="The size of each chunk.") - threshold: Optional[float] = Field(0.8, ge=0, le=1, description="The threshold for similarity.") - language: Optional[str] = Field("zh", description="The language of the text.") - skip_window: Optional[int] = Field(0, ge=0, description="The window for skip-and-merge.") - min_sentences: Optional[int] = Field(1, ge=0, description="The minimum number of sentences in each chunk.") - min_characters_per_chunk: Optional[int] = Field(24, ge=0, description="The minimum number of characters in each chunk.") - - -class PruningConfig(BaseModel): - """Configuration for semantic pruning of dialogue content. - - Attributes: - pruning_switch: Enable or disable semantic pruning - pruning_scene: Scene type for pruning ('education', 'online_service', 'outbound') - pruning_threshold: Pruning ratio (0-0.9, max 0.9 to avoid complete removal) - """ - pruning_switch: bool = Field(False, description="Enable semantic pruning when True.") - pruning_scene: str = Field( - "education", - description="Scene for pruning: one of 'education', 'online_service', 'outbound'.", - ) - pruning_threshold: float = Field( - 0.5, ge=0.0, le=0.9, - description="Pruning ratio within 0-0.9 (max 0.9 to avoid termination).") - - -class TemporalSearchParams(BaseModel): - """Parameters for temporal search queries in the knowledge graph. - - Attributes: - group_id: Group ID to filter search results (default: 'test') - apply_id: Application ID to filter search results - user_id: User ID to filter search results - start_date: Start date for temporal filtering (format: 'YYYY-MM-DD') - end_date: End date for temporal filtering (format: 'YYYY-MM-DD') - valid_date: Date when memory should be valid (format: 'YYYY-MM-DD') - invalid_date: Date when memory should be invalid (format: 'YYYY-MM-DD') - limit: Maximum number of results to return (default: 3) - """ - group_id: Optional[str] = Field("test", description="The group ID to filter the search.") - apply_id: Optional[str] = Field(None, description="The apply ID to filter the search.") - user_id: Optional[str] = Field(None, description="The user ID to filter the search.") - start_date: Optional[str] = Field(None, description="The start date for the search.") - end_date: Optional[str] = Field(None, description="The end date for the search.") - valid_date: Optional[str] = Field(None, description="The valid date for the search.") - invalid_date: Optional[str] = Field(None, description="The invalid date for the search.") - limit: int = Field(default=3, description="The maximum number of results to return.") - - diff --git a/app/core/memory/models/dedup_models.py b/app/core/memory/models/dedup_models.py deleted file mode 100644 index 87dcfb84..00000000 --- a/app/core/memory/models/dedup_models.py +++ /dev/null @@ -1,52 +0,0 @@ -"""Models for entity deduplication and disambiguation decisions. - -This module contains Pydantic models for structured LLM responses -during entity deduplication and disambiguation processes. - -Classes: - EntityDedupDecision: Decision model for entity deduplication - EntityDisambDecision: Decision model for entity disambiguation -""" - -from typing import Optional -from pydantic import BaseModel, Field - - -class EntityDedupDecision(BaseModel): - """Structured decision returned by LLM for entity deduplication. - - This model represents the LLM's decision on whether two entities - refer to the same real-world entity and should be merged. - - Attributes: - same_entity: Whether the two entities refer to the same real-world entity - confidence: Model confidence in the decision (0.0 to 1.0) - canonical_idx: Index of the canonical entity to keep when merging (0 or 1, -1 if not applicable) - reason: Brief rationale for the decision (1-3 sentences, kept for audit) - """ - same_entity: bool = Field(..., description="Two entities refer to the same entity") - confidence: float = Field(..., ge=0.0, le=1.0, description="Confidence of the decision") - canonical_idx: int = Field(..., description="Index of canonical entity among the pair: 0 or 1; -1 if not applicable") - reason: str = Field(..., description="Short rationale, 1-3 sentences") - - -class EntityDisambDecision(BaseModel): - """Structured disambiguation decision for same-name but different-type entities. - - This model represents the LLM's decision on whether two entities with - the same name but different types should be merged or kept separate. - - Attributes: - should_merge: Whether the two entities should be merged despite type difference - confidence: Model confidence in the decision (0.0 to 1.0) - canonical_idx: Index of the canonical entity to keep when merging (0 or 1, -1 if not applicable) - block_pair: If True, this pair should be blocked from fuzzy/auto merges - suggested_type: Optional unified type to apply when should_merge is True - reason: Brief rationale for audit and analysis (1-3 sentences) - """ - should_merge: bool = Field(..., description="Merge the pair despite type difference") - confidence: float = Field(..., ge=0.0, le=1.0, description="Confidence of the decision") - canonical_idx: int = Field(..., description="Index of canonical entity among the pair: 0 or 1; -1 if not applicable") - block_pair: bool = Field(False, description="Block this pair from fuzzy or heuristic merges") - suggested_type: Optional[str] = Field(None, description="Unified entity type when merging") - reason: str = Field(..., description="Short rationale, 1-3 sentences") diff --git a/app/core/memory/models/graph_models.py b/app/core/memory/models/graph_models.py deleted file mode 100644 index b1dc5de7..00000000 --- a/app/core/memory/models/graph_models.py +++ /dev/null @@ -1,304 +0,0 @@ -"""Graph models for Neo4j knowledge graph nodes and edges. - -This module contains Pydantic models representing nodes and edges -in the Neo4j knowledge graph, including dialogues, statements, -chunks, entities, and their relationships. - -Classes: - Edge: Base class for all graph edges - ChunkEdge: Edge connecting chunks - ChunkEntityEdge: Edge connecting chunks to entities - ChunkDialogEdge: Edge connecting chunks to dialogues - StatementChunkEdge: Edge connecting statements to chunks - StatementEntityEdge: Edge connecting statements to entities - EntityEntityEdge: Edge connecting related entities - Node: Base class for all graph nodes - DialogueNode: Node representing a dialogue - StatementNode: Node representing a statement - ChunkNode: Node representing a conversation chunk - ExtractedEntityNode: Node representing an extracted entity - MemorySummaryNode: Node representing a memory summary -""" - -from uuid import uuid4 -from datetime import datetime, timezone -from typing import List, Optional -from pydantic import BaseModel, Field, field_validator -import re - -from app.core.memory.utils.data.ontology import TemporalInfo - - -def parse_historical_datetime(v): - """支持任意年份的日期解析,包括历史日期(如公元755年) - - Python datetime 支持公元1年到9999年的日期 - 此函数手动解析 ISO 8601 格式的日期字符串,支持1-4位年份 - - Args: - v: 日期值(可以是 None、datetime 对象或字符串) - - Returns: - datetime 对象或 None - """ - if v is None or isinstance(v, datetime): - return v - - if isinstance(v, str): - # 匹配 ISO 8601 格式:YYYY-MM-DD 或 YYYY-MM-DDTHH:MM:SS[.ffffff][Z|±HH:MM] - # 支持1-4位年份 - pattern = r'^(\d{1,4})-(\d{2})-(\d{2})(?:T(\d{2}):(\d{2}):(\d{2})(?:\.(\d+))?(?:Z|([+-]\d{2}:\d{2}))?)?' - match = re.match(pattern, v) - - if match: - try: - year = int(match.group(1)) - month = int(match.group(2)) - day = int(match.group(3)) - hour = int(match.group(4)) if match.group(4) else 0 - minute = int(match.group(5)) if match.group(5) else 0 - second = int(match.group(6)) if match.group(6) else 0 - microsecond = 0 - - # 处理微秒 - if match.group(7): - # 补齐或截断到6位 - us_str = match.group(7).ljust(6, '0')[:6] - microsecond = int(us_str) - - # 处理时区 - tzinfo = None - if 'Z' in v or match.group(8): - tzinfo = timezone.utc - - # 创建 datetime 对象 - return datetime(year, month, day, hour, minute, second, microsecond, tzinfo=tzinfo) - - except (ValueError, OverflowError): - # 日期值无效(如月份13、日期32等) - return None - - # 如果不匹配模式,尝试使用 fromisoformat(用于标准格式) - try: - return datetime.fromisoformat(v.replace('Z', '+00:00')) - except Exception: - return None - - return v - - -class Edge(BaseModel): - """Base class for all graph edges in the knowledge graph. - - Attributes: - id: Unique identifier for the edge - source: ID of the source node - target: ID of the target node - group_id: Group ID for multi-tenancy - user_id: User ID for user-specific data - apply_id: Application ID for application-specific data - run_id: Unique identifier for the pipeline run that created this edge - created_at: Timestamp when the edge was created (system perspective) - expired_at: Optional timestamp when the edge expires (system perspective) - """ - id: str = Field(default_factory=lambda: uuid4().hex, description="A unique identifier for the edge.") - source: str = Field(..., description="The ID of the source node.") - target: str = Field(..., description="The ID of the target node.") - group_id: str = Field(..., description="The group ID of the edge.") - user_id: str = Field(..., description="The user ID of the edge.") - apply_id: str = Field(..., description="The apply ID of the edge.") - run_id: str = Field(default_factory=lambda: uuid4().hex, description="Unique identifier for this pipeline run.") - created_at: datetime = Field(..., description="The valid time of the edge from system perspective.") - expired_at: Optional[datetime] = Field(None, description="The expired time of the edge from system perspective.") - - -class ChunkEdge(Edge): - """Edge connecting two chunks in sequence.""" - pass - - -class ChunkEntityEdge(Edge): - """Edge connecting a chunk to an entity mentioned in it.""" - pass - - -class ChunkDialogEdge(Edge): - """Edge connecting a chunk to its parent dialog. - - Attributes: - sequence_number: Order of this chunk within the dialog - """ - sequence_number: int = Field(..., description="Order of this chunk within the dialog") - - -class StatementChunkEdge(Edge): - """Edge connecting a statement to its parent chunk.""" - pass - - -class StatementEntityEdge(Edge): - """Edge connecting a statement to entities extracted from it. - - Attributes: - connect_strength: Classification of connection strength ('Strong' or 'Weak') - """ - connect_strength: str = Field(..., description="Strong VS Weak about this statement-entity edge") - - -class EntityEntityEdge(Edge): - """Edge connecting related entities (from triplet relationships). - - Attributes: - relation_type: Type of relationship as defined in ontology - relation_value: Optional value of the relation - statement: The statement text where this relationship was found - source_statement_id: ID of the statement where this relationship was extracted - valid_at: Optional start date of temporal validity - invalid_at: Optional end date of temporal validity - """ - relation_type: str = Field(..., description="Relation type as defined in ontology") - relation_value: Optional[str] = Field(None, description="Value of the relation") - statement: str = Field(..., description='The statement of the edge.') - source_statement_id: str = Field(..., description="Statement where this relationship was extracted") - valid_at: Optional[datetime] = Field(None, description="Temporal validity start") - invalid_at: Optional[datetime] = Field(None, description="Temporal validity end") - - @field_validator('valid_at', 'invalid_at', mode='before') - @classmethod - def validate_datetime(cls, v): - """使用通用的历史日期解析函数""" - return parse_historical_datetime(v) - - -class Node(BaseModel): - """Base class for all graph nodes in the knowledge graph. - - Attributes: - id: Unique identifier for the node - name: Name of the node - group_id: Group ID for multi-tenancy - user_id: User ID for user-specific data - apply_id: Application ID for application-specific data - run_id: Unique identifier for the pipeline run that created this node - created_at: Timestamp when the node was created (system perspective) - expired_at: Optional timestamp when the node expires (system perspective) - """ - id: str = Field(..., description="The unique identifier for the node.") - name: str = Field(..., description="The name of the node.") - group_id: str = Field(..., description="The group ID of the node.") - user_id: str = Field(..., description="The user ID of the edge.") - apply_id: str = Field(..., description="The apply ID of the edge.") - run_id: str = Field(default_factory=lambda: uuid4().hex, description="Unique identifier for this pipeline run.") - created_at: datetime = Field(..., description="The valid time of the node from system perspective.") - expired_at: Optional[datetime] = Field(None, description="The expired time of the node from system perspective.") - - -class DialogueNode(Node): - """Node representing a dialogue in the knowledge graph. - - Attributes: - ref_id: Reference identifier linking to external dialog system - content: Full dialogue content as text - dialog_embedding: Optional embedding vector for the entire dialogue - config_id: Configuration ID used to process this dialogue - """ - ref_id: str = Field(..., description="Reference identifier of the dialog") - content: str = Field(..., description="Dialogue content") - dialog_embedding: Optional[List[float]] = Field(None, description="Dialog embedding vector") - config_id: Optional[int | str] = Field(None, description="Configuration ID used to process this dialogue (integer or string)") - - -class StatementNode(Node): - """Node representing a statement extracted from dialogue. - - Attributes: - chunk_id: ID of the parent chunk this statement belongs to - stmt_type: Type of the statement (from ontology) - temporal_info: Temporal information extracted from the statement - statement: The actual statement text content - connect_strength: Classification of connection strength ('Strong' or 'Weak') - valid_at: Optional start date of temporal validity - invalid_at: Optional end date of temporal validity - statement_embedding: Optional embedding vector for the statement - chunk_embedding: Optional embedding vector for the parent chunk - config_id: Configuration ID used to process this statement - """ - chunk_id: str = Field(..., description="ID of the parent chunk") - stmt_type: str = Field(..., description="Type of the statement") - temporal_info: TemporalInfo = Field(..., description="Temporal information") - statement: str = Field(..., description="The statement text content") - connect_strength: str = Field(..., description="Strong VS Weak classification of this statement") - valid_at: Optional[datetime] = Field(None, description="Temporal validity start") - invalid_at: Optional[datetime] = Field(None, description="Temporal validity end") - statement_embedding: Optional[List[float]] = Field(None, description="Statement embedding vector") - chunk_embedding: Optional[List[float]] = Field(None, description="Chunk embedding vector") - config_id: Optional[int | str] = Field(None, description="Configuration ID used to process this statement (integer or string)") - - @field_validator('valid_at', 'invalid_at', mode='before') - @classmethod - def validate_datetime(cls, v): - """使用通用的历史日期解析函数""" - return parse_historical_datetime(v) - - -class ChunkNode(Node): - """Node representing a chunk of conversation in the knowledge graph. - - Attributes: - dialog_id: ID of the parent dialog - content: The text content of the chunk - chunk_embedding: Optional embedding vector for the chunk - sequence_number: Order of this chunk within the dialog - metadata: Additional chunk metadata as key-value pairs - """ - dialog_id: str = Field(..., description="ID of the parent dialog") - content: str = Field(..., description="The text content of the chunk") - chunk_embedding: Optional[List[float]] = Field(None, description="Chunk embedding vector") - sequence_number: int = Field(..., description="Order of this chunk within the dialog") - metadata: dict = Field(default_factory=dict, description="Additional chunk metadata") - - -class ExtractedEntityNode(Node): - """Node representing an extracted entity in the knowledge graph. - - Attributes: - entity_idx: Unique numeric identifier for the entity - statement_id: ID of the statement this entity was extracted from - entity_type: Type/category of the entity - description: Textual description of the entity - aliases: Optional list of alternative names for the entity - name_embedding: Optional embedding vector for the entity name - fact_summary: Summary of facts about this entity - connect_strength: Classification of connection strength ('Strong' or 'Weak') - config_id: Configuration ID used to process this entity - """ - entity_idx: int = Field(..., description="Unique identifier for the entity") - statement_id: str = Field(..., description="Statement this entity was extracted from") - entity_type: str = Field(..., description="Type of the entity") - description: str = Field(..., description="Entity description") - aliases: Optional[List[str]] = Field(default_factory=list, description="Entity aliases") - name_embedding: Optional[List[float]] = Field(default_factory=list, description="Name embedding vector") - fact_summary: str = Field(..., description="Summary of the fact about this entity") - connect_strength: str = Field(..., description="Strong VS Weak about this entity") - config_id: Optional[int | str] = Field(None, description="Configuration ID used to process this entity (integer or string)") - - -class MemorySummaryNode(Node): - """Node representing a memory summary with vector embedding. - - Attributes: - summary_id: Unique identifier for the summary - dialog_id: ID of the parent dialog - chunk_ids: List of chunk IDs used to generate this summary - content: Summary text content - summary_embedding: Optional embedding vector for the summary - metadata: Additional metadata for the summary - config_id: Configuration ID used to process this summary - """ - summary_id: str = Field(default_factory=lambda: uuid4().hex, description="Unique identifier for the summary") - dialog_id: str = Field(..., description="ID of the parent dialog") - chunk_ids: List[str] = Field(default_factory=list, description="List of chunk IDs used in the summary") - content: str = Field(..., description="Summary text content") - summary_embedding: Optional[List[float]] = Field(None, description="Embedding vector for the summary") - metadata: dict = Field(default_factory=dict, description="Additional metadata for the summary") - config_id: Optional[int | str] = Field(None, description="Configuration ID used to process this summary (integer or string)") diff --git a/app/core/memory/models/message_models.py b/app/core/memory/models/message_models.py deleted file mode 100644 index 192816fd..00000000 --- a/app/core/memory/models/message_models.py +++ /dev/null @@ -1,247 +0,0 @@ -"""Models for dialogue messages, conversations, and statements. - -This module contains Pydantic models for representing dialogue data, -including messages, conversation context, chunks, and statements. - -Classes: - ConversationMessage: Single message in a conversation - TemporalValidityRange: Temporal validity range for statements - Statement: Statement extracted from dialogue with metadata - ConversationContext: Full conversation history - Chunk: Chunk of conversation text - DialogData: Complete dialogue data structure -""" - -from typing import List, Dict, Any, Optional -from pydantic import BaseModel, Field -from uuid import uuid4 -from datetime import datetime - -from app.core.memory.utils.data.ontology import StatementType, TemporalInfo, RelevenceInfo -from app.core.memory.models.triplet_models import TripletExtractionResponse, Triplet - - -class ConversationMessage(BaseModel): - """Represents a single message in a conversation. - - Attributes: - role: Role of the speaker (e.g., '用户' for user, 'AI' for assistant) - msg: Text content of the message - """ - role: str = Field(..., description="The role of the speaker (e.g., '用户', 'AI').") - msg: str = Field(..., description="The text content of the message.") - - -class TemporalValidityRange(BaseModel): - """Represents the temporal validity range of a statement. - - Attributes: - valid_at: Start date of validity in 'YYYY-MM-DD' format (None if not specified) - invalid_at: End date of validity in 'YYYY-MM-DD' format (None if not specified) - """ - valid_at: Optional[str] = Field( - None, - description="The start date of the statement's validity, in 'YYYY-MM-DD' format or 'None'.", - ) - invalid_at: Optional[str] = Field( - None, - description="The end date of the statement's validity, in 'YYYY-MM-DD' format or 'None'.", - ) - - -class Statement(BaseModel): - """Represents a statement extracted from dialogue with metadata. - - Attributes: - id: Unique identifier for the statement - chunk_id: ID of the parent chunk this statement belongs to - group_id: Optional group ID for multi-tenancy - statement: The actual statement text content - statement_embedding: Optional embedding vector for the statement - stmt_type: Type of the statement (from ontology) - temporal_info: Temporal information extracted from the statement - relevence_info: Relevance classification (RELEVANT or IRRELEVANT) - connect_strength: Optional connection strength ('Strong' or 'Weak') - temporal_validity: Optional temporal validity range - triplet_extraction_info: Optional triplet extraction results - """ - id: str = Field(default_factory=lambda: uuid4().hex, description="A unique identifier for the statement.") - chunk_id: str = Field(..., description="ID of the parent chunk this statement belongs to.") - group_id: Optional[str] = Field(None, description="ID of the group this statement belongs to.") - statement: str = Field(..., description="The text content of the statement.") - statement_embedding: Optional[List[float]] = Field(None, description="The embedding vector of the statement.") - stmt_type: StatementType = Field(..., description="The type of the statement.") - temporal_info: TemporalInfo = Field(..., description="The temporal information of the statement.") - relevence_info: RelevenceInfo = Field(RelevenceInfo.RELEVANT, description="The relevence information of the statement.") - connect_strength: Optional[str] = Field(None, description="Strong VS Weak about this entity") - temporal_validity: Optional[TemporalValidityRange] = Field( - None, description="The temporal validity range of the statement." - ) - triplet_extraction_info: Optional[TripletExtractionResponse] = Field( - None, description="The triplet extraction information of the statement." - ) - - -class ConversationContext(BaseModel): - """Represents the full conversation history. - - Attributes: - msgs: List of messages in the conversation - - Properties: - content: Formatted string representation of the conversation - """ - msgs: List[ConversationMessage] = Field(..., description="A list of messages in the conversation.") - - @property - def content(self) -> str: - """Get the content of the conversation as a formatted string. - - Returns: - String with format "role: message" for each message, joined by newlines - """ - return "\n".join([f"{msg.role}: {msg.msg}" for msg in self.msgs]) - -class Chunk(BaseModel): - """A chunk of text from the conversation context. - - Attributes: - id: Unique identifier for the chunk - text: List of messages in the chunk - content: The content of the chunk as a formatted string - statements: List of statements extracted from this chunk - chunk_embedding: Optional embedding vector for the chunk - metadata: Additional metadata as key-value pairs - """ - id: str = Field(default_factory=lambda: uuid4().hex, description="A unique identifier for the chunk.") - text: List[ConversationMessage] = Field(default_factory=list, description="A list of messages in the chunk.") - content: str = Field(..., description="The content of the chunk as a string.") - statements: List[Statement] = Field(default_factory=list, description="A list of statements in the chunk.") - chunk_embedding: Optional[List[float]] = Field(None, description="The embedding vector of the chunk.") - metadata: Dict[str, Any] = Field(default_factory=dict, description="Additional metadata for the chunk.") - - @classmethod - def from_messages(cls, messages: List[ConversationMessage], metadata: Optional[Dict[str, Any]] = None): - """Create a chunk from a list of messages. - - Args: - messages: List of conversation messages - metadata: Optional metadata dictionary - - Returns: - Chunk instance with formatted content - """ - if metadata is None: - metadata = {} - # Generate content from messages - content = "\n".join([f"{msg.role}: {msg.msg}" for msg in messages]) - return cls(text=messages, content=content, metadata=metadata) - - -class DialogData(BaseModel): - """Represents the complete data structure for a dialog record. - - Attributes: - id: Unique identifier for the dialog - context: Full conversation context - dialog_embedding: Optional embedding vector for the entire dialog - ref_id: Reference ID linking to external dialog system - group_id: Group ID for multi-tenancy - user_id: User ID for user-specific data - apply_id: Application ID for application-specific data - created_at: Timestamp when the dialog was created - expired_at: Timestamp when the dialog expires (default: far future) - metadata: Additional metadata as key-value pairs - chunks: List of chunks from the conversation - config_id: Configuration ID used to process this dialog - - Properties: - content: Formatted string representation of the dialog - """ - id: str = Field(default_factory=lambda: uuid4().hex, description="A unique identifier for the dialog.") - context: ConversationContext = Field(..., description="The full conversation context as a single string.") - dialog_embedding: Optional[List[float]] = Field(None, description="The embedding vector of the dialog.") - ref_id: str = Field(..., description="Refer to external dialog id. This is used to link to the original dialog.") - group_id: str = Field(default=..., description="Group ID of dialogue data") - user_id: str = Field(..., description="USER ID of dialogue data") - apply_id: str = Field(..., description="APPLY ID of dialogue data") - run_id: str = Field(default_factory=lambda: uuid4().hex, description="Unique identifier for this pipeline run.") - created_at: datetime = Field(default_factory=datetime.now, description="The timestamp when the dialog was created.") - expired_at: datetime = Field(default_factory=lambda: datetime(9999, 12, 31), description="The timestamp when the dialog expires.") - metadata: Dict[str, Any] = Field(default_factory=dict, description="Additional metadata for the dialog.") - chunks: List[Chunk] = Field(default_factory=list, description="A list of chunks from the conversation context.") - config_id: Optional[int | str] = Field(None, description="Configuration ID used to process this dialog (integer or string)") - - @property - def content(self) -> str: - """Get the content of the dialog as a formatted string. - - Returns: - String representation of the conversation context - """ - return self.context.content - - def get_statement_chunk(self, statement_id: str) -> Optional[Chunk]: - """Find the chunk containing a specific statement. - - Args: - statement_id: ID of the statement to find - - Returns: - Chunk containing the statement, or None if not found - """ - for chunk in self.chunks: - for statement in chunk.statements: - if statement.id == statement_id: - return chunk - return None - - def get_all_statements(self) -> List[Statement]: - """Get all statements from all chunks. - - Returns: - List of all statements in the dialog - """ - all_statements = [] - for chunk in self.chunks: - all_statements.extend(chunk.statements) - return all_statements - - def get_statement_by_id(self, statement_id: str) -> Optional[Statement]: - """Find a specific statement by its ID. - - Args: - statement_id: ID of the statement to find - - Returns: - Statement with the given ID, or None if not found - """ - for chunk in self.chunks: - for statement in chunk.statements: - if statement.id == statement_id: - return statement - return None - - def get_triplets_for_statement(self, statement_id: str) -> List[Triplet]: - """Get all triplets extracted from a specific statement. - - Args: - statement_id: ID of the statement - - Returns: - List of triplets from the statement, or empty list if none found - """ - statement = self.get_statement_by_id(statement_id) - if statement and statement.triplet_extraction_info: - return statement.triplet_extraction_info.triplets - return [] - - def assign_group_id_to_statements(self) -> None: - """Assign this dialog's group_id to all statements in all chunks. - - This method updates statements that don't have a group_id set. - """ - for chunk in self.chunks: - for statement in chunk.statements: - if statement.group_id is None: - statement.group_id = self.group_id diff --git a/app/core/memory/models/triplet_models.py b/app/core/memory/models/triplet_models.py deleted file mode 100644 index 7439ee34..00000000 --- a/app/core/memory/models/triplet_models.py +++ /dev/null @@ -1,85 +0,0 @@ -"""Models for knowledge triplets and entities. - -This module contains Pydantic models for representing extracted knowledge -in the form of entities and triplets (subject-predicate-object relationships). - -Classes: - Entity: Represents an extracted entity - Triplet: Represents a knowledge triplet (subject-predicate-object) - TripletExtractionResponse: Response model containing extracted triplets and entities -""" - -from typing import List, Optional -from pydantic import BaseModel, Field, ConfigDict -from uuid import uuid4 - - -class Entity(BaseModel): - """Represents an extracted entity from dialogue. - - Attributes: - id: Unique string identifier for the entity - entity_idx: Numeric index for the entity - name: Name of the entity - name_embedding: Optional embedding vector for the entity name - type: Type/category of the entity (e.g., 'Person', 'Organization') - description: Textual description of the entity - - Config: - extra: Ignore extra fields from LLM output - """ - model_config = ConfigDict(extra='ignore') - id: str = Field(default_factory=lambda: uuid4().hex, description="Unique identifier for the entity.") - entity_idx: int = Field(..., description="Unique identifier for the entity") - name: str = Field(..., description="Name of the entity") - name_embedding: Optional[List[float]] = Field(None, description="Embedding vector for the entity name") - type: str = Field(..., description="Type/category of the entity") - description: str = Field(..., description="Description of the entity") - - -class Triplet(BaseModel): - """Represents an extracted knowledge triplet (subject-predicate-object). - - A triplet represents a relationship between two entities, forming - the basic unit of knowledge in the knowledge graph. - - Attributes: - id: Unique string identifier for the triplet - statement_id: Optional ID of the parent statement (set programmatically) - subject_name: Name of the subject entity - subject_id: Numeric ID of the subject entity - predicate: Relationship/predicate between subject and object - object_name: Name of the object entity - object_id: Numeric ID of the object entity - value: Optional additional value or context for the relationship - - Config: - extra: Ignore extra fields from LLM output - """ - model_config = ConfigDict(extra='ignore') - id: str = Field(default_factory=lambda: uuid4().hex, description="Unique identifier for the triplet.") - statement_id: Optional[str] = Field(None, description="ID of the parent statement this triplet was extracted from.") - subject_name: str = Field(..., description="Name of the subject entity") - subject_id: int = Field(..., description="ID of the subject entity") - predicate: str = Field(..., description="Relationship/predicate between subject and object") - object_name: str = Field(..., description="Name of the object entity") - object_id: int = Field(..., description="ID of the object entity") - value: Optional[str] = Field(None, description="Additional value or context") - - -class TripletExtractionResponse(BaseModel): - """Response model for triplet extraction from LLM. - - This model represents the structured output from the LLM when - extracting knowledge triplets and entities from statements. - - Attributes: - triplets: List of extracted knowledge triplets - entities: List of extracted entities - - Config: - extra: Ignore extra fields from LLM output - """ - model_config = ConfigDict(extra='ignore') - triplets: List[Triplet] = Field(default_factory=list, description="List of extracted triplets") - entities: List[Entity] = Field(default_factory=list, description="List of extracted entities") diff --git a/app/core/memory/models/variate_config.py b/app/core/memory/models/variate_config.py deleted file mode 100644 index 24abd39c..00000000 --- a/app/core/memory/models/variate_config.py +++ /dev/null @@ -1,151 +0,0 @@ -"""Variable configuration models for extraction pipeline components. - -This module contains Pydantic models for configuring various aspects -of the extraction pipeline, including statement extraction, triplet extraction, -temporal extraction, deduplication, and forgetting mechanisms. - -Classes: - StatementExtractionConfig: Configuration for statement extraction - ForgettingEngineConfig: Configuration for forgetting engine - TripletExtractionConfig: Configuration for triplet extraction - TemporalExtractionConfig: Configuration for temporal extraction - DedupConfig: Configuration for entity deduplication - ExtractionPipelineConfig: Combined configuration for entire pipeline -""" - -from typing import Optional -from pydantic import BaseModel, Field - - -class StatementExtractionConfig(BaseModel): - """Configuration for statement extraction behavior. - - Attributes: - statement_granularity: Granularity level (1-3): - - 1: Split sentences into different statements - - 2: Sentence-level statements - - 3: Combine sentences, shorten long statements - temperature: LLM temperature for statement extraction (0-2, default: 0.1) - include_dialogue_context: Whether to include full dialogue context - max_dialogue_context_chars: Maximum characters from dialogue context (default: 2000) - """ - statement_granularity: Optional[int] = Field(None, ge=1, le=3, description="Granularity of statements to extract, level 1 to 3") - temperature: Optional[float] = Field(0.1, ge=0, le=2, description="LLM temperature for statement extraction") - include_dialogue_context: bool = Field(True, description="Whether to include full dialogue context in extraction") - max_dialogue_context_chars: Optional[int] = Field(2000, ge=100, description="Maximum number of characters to include from dialogue context") - - -class ForgettingEngineConfig(BaseModel): - """Configuration for the forgetting engine. - - The forgetting engine implements a memory decay mechanism based on - time and memory strength parameters. - - Attributes: - offset: Minimum retention level (0-1, prevents complete forgetting, default: 0.1) - lambda_time: Lambda parameter controlling time decay effect (default: 0.1) - lambda_mem: Lambda parameter controlling memory strength effect (default: 1.0) - """ - offset: float = Field(0.1, ge=0.0, le=1.0, description="Minimum retention level (prevents complete forgetting).") - lambda_time: float = Field(0.1, gt=0.0, description="Lambda parameter controlling time effect.") - lambda_mem: float = Field(1.0, gt=0.0, description="Lambda parameter controlling memory strength effect.") - - -class TripletExtractionConfig(BaseModel): - """Configuration for triplet extraction behavior. - - Attributes: - temperature: LLM temperature for triplet extraction (0-2, default: 0.1) - enable_entity_normalization: Whether to normalize entity names (default: True) - confidence_threshold: Minimum confidence for extracted triplets (0-1, default: 0.7) - """ - temperature: Optional[float] = Field(0.1, ge=0, le=2, description="LLM temperature for triplet extraction") - enable_entity_normalization: bool = Field(True, description="Whether to normalize entity names") - confidence_threshold: Optional[float] = Field(0.7, ge=0, le=1, description="Minimum confidence threshold for extracted triplets") - - -class TemporalExtractionConfig(BaseModel): - """Configuration for temporal extraction behavior. - - Attributes: - temperature: LLM temperature for temporal extraction (0-2, default: 0.1) - """ - temperature: Optional[float] = Field(0.1, ge=0, le=2, description="LLM temperature for temporal extraction") - - -class DedupConfig(BaseModel): - """Configuration for entity deduplication behavior. - - This configuration controls the multi-stage deduplication process, - including fuzzy matching, LLM-based deduplication, and disambiguation. - - Attributes: - enable_llm_dedup_blockwise: Enable blockwise LLM-driven deduplication (default: False) - enable_llm_disambiguation: Enable LLM disambiguation for same-name different-type entities (default: False) - enable_llm_fallback_only_on_borderline: Only trigger LLM when borderline pairs exist (default: True) - fuzzy_name_threshold_strict: Strict threshold for name similarity (0-1, default: 0.90) - fuzzy_type_threshold_strict: Strict threshold for type similarity (0-1, default: 0.75) - fuzzy_overall_threshold: Overall similarity threshold to merge (0-1, default: 0.82) - fuzzy_unknown_type_name_threshold: Name threshold when entity type is UNKNOWN (0-1, default: 0.92) - fuzzy_unknown_type_type_threshold: Type threshold when entity type is UNKNOWN (0-1, default: 0.50) - name_weight: Weight of name similarity in overall score (0-1, default: 0.50) - desc_weight: Weight of description similarity in overall score (0-1, default: 0.30) - type_weight: Weight of type similarity in overall score (0-1, default: 0.20) - context_bonus: Bonus when entities co-occur in same statements (0-0.2, default: 0.03) - llm_fallback_floor: Lower bound for borderline score (0-1, default: 0.76) - llm_fallback_ceiling: Upper bound for borderline score (0-1, default: 0.82) - llm_block_size: Entities per block for LLM dedup (1-500, default: 50) - llm_block_concurrency: Concurrent blocks processed by LLM (1-64, default: 4) - llm_pair_concurrency: Concurrent pairwise decisions per block (1-64, default: 4) - llm_max_rounds: Maximum LLM iterative dedup rounds (1-10, default: 3) - """ - # LLM deduplication toggles - enable_llm_dedup_blockwise: bool = Field(False, description="Toggle blockwise LLM-driven deduplication") - enable_llm_disambiguation: bool = Field(False, description="Toggle LLM-driven disambiguation for same-name different-type entities") - enable_llm_fallback_only_on_borderline: bool = Field(True, description="Trigger LLM dedup only when borderline pairs are detected in fuzzy stage") - - # Fuzzy match thresholds - fuzzy_name_threshold_strict: float = Field(0.90, ge=0, le=1, description="Strict threshold for name similarity") - fuzzy_type_threshold_strict: float = Field(0.75, ge=0, le=1, description="Strict threshold for type similarity") - fuzzy_overall_threshold: float = Field(0.82, ge=0, le=1, description="Overall similarity threshold to merge") - - # Specialized thresholds when type is UNKNOWN - fuzzy_unknown_type_name_threshold: float = Field(0.92, ge=0, le=1, description="Name threshold when any entity type is UNKNOWN") - fuzzy_unknown_type_type_threshold: float = Field(0.50, ge=0, le=1, description="Type threshold when any entity type is UNKNOWN") - - # Weighted scoring components for overall similarity - name_weight: float = Field(0.50, ge=0, le=1, description="Weight of name similarity in overall score") - desc_weight: float = Field(0.30, ge=0, le=1, description="Weight of description similarity in overall score") - type_weight: float = Field(0.20, ge=0, le=1, description="Weight of type similarity in overall score") - context_bonus: float = Field(0.03, ge=0, le=0.2, description="Bonus added to score when entities co-occur in same statements") - - # Borderline range for LLM fallback triggering - llm_fallback_floor: float = Field(0.76, ge=0, le=1, description="Lower bound of overall score to consider as borderline for LLM fallback") - llm_fallback_ceiling: float = Field(0.82, ge=0, le=1, description="Upper bound (below merge threshold) of overall score for LLM fallback") - - # LLM iterative dedup parameters - llm_block_size: int = Field(50, ge=1, le=500, description="Entities per block for LLM dedup") - llm_block_concurrency: int = Field(4, ge=1, le=64, description="Concurrent blocks processed by LLM") - llm_pair_concurrency: int = Field(4, ge=1, le=64, description="Concurrent pairwise decisions per block") - llm_max_rounds: int = Field(3, ge=1, le=10, description="Maximum LLM iterative dedup rounds") - - -class ExtractionPipelineConfig(BaseModel): - """Configuration for the entire extraction pipeline. - - This model combines all configuration components for the complete - extraction pipeline, including statement extraction, triplet extraction, - temporal extraction, deduplication, and forgetting mechanisms. - - Attributes: - statement_extraction: Configuration for statement extraction - triplet_extraction: Configuration for triplet extraction - temporal_extraction: Configuration for temporal extraction - deduplication: Configuration for entity deduplication - forgetting_engine: Configuration for forgetting engine - """ - statement_extraction: StatementExtractionConfig = Field(default_factory=StatementExtractionConfig) - triplet_extraction: TripletExtractionConfig = Field(default_factory=TripletExtractionConfig) - temporal_extraction: TemporalExtractionConfig = Field(default_factory=TemporalExtractionConfig) - deduplication: DedupConfig = Field(default_factory=DedupConfig) - forgetting_engine: ForgettingEngineConfig = Field(default_factory=ForgettingEngineConfig) diff --git a/app/core/memory/src/__init__.py b/app/core/memory/src/__init__.py deleted file mode 100644 index e69de29b..00000000 diff --git a/app/core/memory/src/llm_tools/__init__.py b/app/core/memory/src/llm_tools/__init__.py deleted file mode 100644 index e69de29b..00000000 diff --git a/app/core/memory/src/llm_tools/chunker_client.py b/app/core/memory/src/llm_tools/chunker_client.py deleted file mode 100644 index 780f3345..00000000 --- a/app/core/memory/src/llm_tools/chunker_client.py +++ /dev/null @@ -1,330 +0,0 @@ -from typing import Any, List -import re -import os -import asyncio -import json -import numpy as np - -# Fix tokenizer parallelism warning -os.environ["TOKENIZERS_PARALLELISM"] = "false" - -from chonkie import ( - SemanticChunker, - RecursiveChunker, - RecursiveRules, - LateChunker, - NeuralChunker, - SentenceChunker, - TokenChunker, -) - -from app.core.memory.models.config_models import ChunkerConfig -from app.core.memory.models.message_models import DialogData, Chunk -try: - from app.core.memory.src.llm_tools.openai_client import OpenAIClient -except Exception: - # 在测试或无可用依赖(如 langfuse)环境下,允许惰性导入 - OpenAIClient = Any - - -class LLMChunker: - """基于LLM的智能分块策略""" - def __init__(self, llm_client: OpenAIClient, chunk_size: int = 1000): - self.llm_client = llm_client - self.chunk_size = chunk_size - - async def __call__(self, text: str) -> List[Any]: - # 使用LLM分析文本结构并进行智能分块 - prompt = f""" - 请将以下文本分割成语义连贯的段落。每个段落应该围绕一个主题,长度大约在{self.chunk_size}字符左右。 - 请以JSON格式返回结果,包含chunks数组,每个chunk有text字段。 - - 文本内容: - {text[:5000]} - """ - - messages = [ - {"role": "system", "content": "你是一个专业的文本分析助手,擅长将长文本分割成语义连贯的段落。"}, - {"role": "user", "content": prompt} - ] - - try: - # 使用异步的 achat 方法 - if hasattr(self.llm_client, 'achat'): - response = await self.llm_client.achat(messages) - else: - # 如果没有异步方法,使用同步方法并转换为异步 - response = await asyncio.to_thread(self.llm_client.chat, messages) - - # 检查响应格式并提取内容 - if hasattr(response, 'choices') and len(response.choices) > 0: - content = response.choices[0].message.content - elif hasattr(response, 'content'): - content = response.content - else: - content = str(response) - - # 解析LLM响应 - if "```json" in content: - json_str = content.split("```json")[1].split("```")[0].strip() - elif "```" in content: - json_str = content.split("```")[1].split("```")[0].strip() - else: - json_str = content - - result = json.loads(json_str) - - class SimpleChunk: - def __init__(self, text, index): - self.text = text - self.start_index = index * 100 # 近似位置 - self.end_index = (index + 1) * 100 - - return [SimpleChunk(chunk["text"], i) for i, chunk in enumerate(result.get("chunks", []))] - - except Exception as e: - print(f"LLM分块失败: {e}") - # 失败时返回空列表,外层会处理回退方案 - return [] - - -class HybridChunker: - """混合分块策略:先按结构分块,再按语义合并""" - def __init__(self, semantic_threshold: float = 0.8, base_chunk_size: int = 300): - self.semantic_threshold = semantic_threshold - self.base_chunk_size = base_chunk_size - self.base_chunker = TokenChunker(tokenizer="character", chunk_size=base_chunk_size) - self.semantic_chunker = SemanticChunker(threshold=semantic_threshold) - - def __call__(self, text: str) -> List[Any]: - # 先用基础分块 - base_chunks = self.base_chunker(text) - - # 如果文本不长,直接返回基础分块 - if len(base_chunks) <= 3: - return base_chunks - - # 对基础分块进行语义合并 - combined_text = " ".join([chunk.text for chunk in base_chunks]) - return self.semantic_chunker(combined_text) - - -class ChunkerClient: - def __init__(self, chunker_config: ChunkerConfig, llm_client: OpenAIClient = None): - self.chunker_config = chunker_config - self.embedding_model = chunker_config.embedding_model - self.chunk_size = chunker_config.chunk_size - self.threshold = chunker_config.threshold - self.language = chunker_config.language - self.skip_window = chunker_config.skip_window - self.min_sentences = chunker_config.min_sentences - self.min_characters_per_chunk = chunker_config.min_characters_per_chunk - self.llm_client = llm_client - - # 可选参数(从配置中安全获取,提供默认值) - self.chunk_overlap = getattr(chunker_config, 'chunk_overlap', 0) - self.min_sentences_per_chunk = getattr(chunker_config, 'min_sentences_per_chunk', 1) - self.min_characters_per_sentence = getattr(chunker_config, 'min_characters_per_sentence', 12) - self.delim = getattr(chunker_config, 'delim', [".", "!", "?", "\n"]) - self.include_delim = getattr(chunker_config, 'include_delim', "prev") - self.tokenizer_or_token_counter = getattr(chunker_config, 'tokenizer_or_token_counter', "character") - - # 初始化具体分块器策略 - if chunker_config.chunker_strategy == "TokenChunker": - self.chunker = TokenChunker( - tokenizer=self.tokenizer_or_token_counter, - chunk_size=self.chunk_size, - chunk_overlap=self.chunk_overlap, - ) - elif chunker_config.chunker_strategy == "SemanticChunker": - self.chunker = SemanticChunker( - embedding_model=self.embedding_model, - threshold=self.threshold, - chunk_size=self.chunk_size, - min_sentences=self.min_sentences, - ) - elif chunker_config.chunker_strategy == "RecursiveChunker": - self.chunker = RecursiveChunker( - rules=RecursiveRules(), - min_characters_per_chunk=self.min_characters_per_chunk or 50, - chunk_size=self.chunk_size, - ) - elif chunker_config.chunker_strategy == "LateChunker": - self.chunker = LateChunker( - embedding_model=self.embedding_model, - chunk_size=self.chunk_size, - rules=RecursiveRules(), - min_characters_per_chunk=self.min_characters_per_chunk, - ) - elif chunker_config.chunker_strategy == "NeuralChunker": - self.chunker = NeuralChunker( - model=self.embedding_model, - min_characters_per_chunk=self.min_characters_per_chunk, - ) - elif chunker_config.chunker_strategy == "LLMChunker": - if not llm_client: - raise ValueError("LLMChunker requires an LLM client") - self.chunker = LLMChunker(llm_client, self.chunk_size) - elif chunker_config.chunker_strategy == "HybridChunker": - self.chunker = HybridChunker( - semantic_threshold=self.threshold, - base_chunk_size=self.chunk_size, - ) - elif chunker_config.chunker_strategy == "SentenceChunker": - # 某些 chonkie 版本的 SentenceChunker 不支持 tokenizer_or_token_counter 参数 - # 为了兼容不同版本,这里仅传递广泛支持的参数 - self.chunker = SentenceChunker( - chunk_size=self.chunk_size, - chunk_overlap=self.chunk_overlap, - min_sentences_per_chunk=self.min_sentences_per_chunk, - min_characters_per_sentence=self.min_characters_per_sentence, - delim=self.delim, - include_delim=self.include_delim, - ) - else: - raise ValueError(f"Unknown chunker strategy: {chunker_config.chunker_strategy}") - - async def generate_chunks(self, dialogue: DialogData): - """ - 生成分块,支持异步操作 - """ - try: - # 预处理文本:确保对话标记格式统一 - content = dialogue.content - content = content.replace('AI:', 'AI:').replace('用户:', '用户:') # 统一冒号 - content = re.sub(r'(\n\s*)+\n', '\n\n', content) # 合并多个空行 - - if hasattr(self.chunker, '__call__') and not asyncio.iscoroutinefunction(self.chunker.__call__): - # 同步分块器 - chunks = self.chunker(content) - else: - # 异步分块器(如LLMChunker) - chunks = await self.chunker(content) - - # 过滤空块和过小的块 - valid_chunks = [] - for c in chunks: - chunk_text = getattr(c, 'text', str(c)) if not isinstance(c, str) else c - if isinstance(chunk_text, str) and len(chunk_text.strip()) >= (self.min_characters_per_chunk or 50): - valid_chunks.append(c) - - dialogue.chunks = [ - Chunk( - content=c.text if hasattr(c, 'text') else str(c), - metadata={ - "start_index": getattr(c, "start_index", None), - "end_index": getattr(c, "end_index", None), - "chunker_strategy": self.chunker_config.chunker_strategy, - }, - ) - for c in valid_chunks - ] - return dialogue - - except Exception as e: - print(f"分块失败: {e}") - - # 改进的后备方案:尝试按对话回合分割 - try: - # 简单的按对话分割 - dialogue_pattern = r'(AI:|用户:)(.*?)(?=AI:|用户:|$)' - matches = re.findall(dialogue_pattern, dialogue.content, re.DOTALL) - - class SimpleChunk: - def __init__(self, text, start_index, end_index): - self.text = text - self.start_index = start_index - self.end_index = end_index - - chunks = [] - current_chunk = "" - current_start = 0 - - for match in matches: - speaker, ct = match[0], match[1].strip() - turn_text = f"{speaker} {ct}" - - if len(current_chunk) + len(turn_text) > (self.chunk_size or 500): - if current_chunk: - chunks.append(SimpleChunk(current_chunk, current_start, current_start + len(current_chunk))) - current_chunk = turn_text - current_start = dialogue.content.find(turn_text, current_start) - else: - current_chunk += ("\n" + turn_text) if current_chunk else turn_text - - if current_chunk: - chunks.append(SimpleChunk(current_chunk, current_start, current_start + len(current_chunk))) - - dialogue.chunks = [ - Chunk( - content=c.text, - metadata={ - "start_index": c.start_index, - "end_index": c.end_index, - "chunker_strategy": "DialogueTurnFallback", - }, - ) - for c in chunks - ] - - except Exception: - # 最后的手段:单一大块 - dialogue.chunks = [Chunk( - content=dialogue.content, - metadata={"chunker_strategy": "SingleChunkFallback"}, - )] - - return dialogue - - def evaluate_chunking(self, dialogue: DialogData) -> dict: - """ - 评估分块质量 - """ - if not getattr(dialogue, 'chunks', None): - return {} - - chunks = dialogue.chunks - total_chars = sum(len(chunk.content) for chunk in chunks) - avg_chunk_size = total_chars / len(chunks) - - # 计算各种指标 - chunk_sizes = [len(chunk.content) for chunk in chunks] - - metrics = { - "strategy": self.chunker_config.chunker_strategy, - "num_chunks": len(chunks), - "total_characters": total_chars, - "avg_chunk_size": avg_chunk_size, - "min_chunk_size": min(chunk_sizes), - "max_chunk_size": max(chunk_sizes), - "chunk_size_std": np.std(chunk_sizes) if len(chunk_sizes) > 1 else 0, - "coverage_ratio": total_chars / len(dialogue.content) if dialogue.content else 0, - } - - return metrics - - def save_chunking_results(self, dialogue: DialogData, output_path: str): - """ - 保存分块结果到文件,文件名包含策略名称 - """ - strategy_name = self.chunker_config.chunker_strategy - # 在文件名中添加策略名称 - base_name, ext = os.path.splitext(output_path) - strategy_output_path = f"{base_name}_{strategy_name}{ext}" - - with open(strategy_output_path, 'w', encoding='utf-8') as f: - f.write(f"=== Chunking Strategy: {strategy_name} ===\n") - f.write(f"Total chunks: {len(dialogue.chunks)}\n") - f.write(f"Total characters: {sum(len(chunk.content) for chunk in dialogue.chunks)}\n") - f.write("=" * 60 + "\n\n") - - for i, chunk in enumerate(dialogue.chunks): - f.write(f"Chunk {i+1}:\n") - f.write(f"Size: {len(chunk.content)} characters\n") - if hasattr(chunk, 'metadata') and 'start_index' in chunk.metadata: - f.write(f"Position: {chunk.metadata.get('start_index')}-{chunk.metadata.get('end_index')}\n") - f.write(f"Content: {chunk.content}\n") - f.write("-" * 40 + "\n\n") - - print(f"Chunking results saved to: {strategy_output_path}") - return strategy_output_path diff --git a/app/core/memory/src/llm_tools/embedder_client.py b/app/core/memory/src/llm_tools/embedder_client.py deleted file mode 100644 index f1033fc6..00000000 --- a/app/core/memory/src/llm_tools/embedder_client.py +++ /dev/null @@ -1,22 +0,0 @@ -from abc import ABC, abstractmethod -from typing import List - -from app.core.models.base import RedBearModelConfig -class EmbedderClient(ABC): - def __init__(self, model_config: RedBearModelConfig): - self.config = model_config - - self.model_name = model_config.model_name - self.provider = model_config.provider - self.api_key = model_config.api_key - self.base_url = model_config.base_url - self.max_retries = model_config.max_retries - # self.dimension = model_config.dimension - - - @abstractmethod - async def response( - self, - messages: List[str], - ) -> List[str]: - pass diff --git a/app/core/memory/src/llm_tools/llm_client.py b/app/core/memory/src/llm_tools/llm_client.py deleted file mode 100644 index 8925de6a..00000000 --- a/app/core/memory/src/llm_tools/llm_client.py +++ /dev/null @@ -1,37 +0,0 @@ -from abc import ABC, abstractmethod -from typing import List, Dict, Any -from pydantic import BaseModel -from app.core.memory.models.config_models import LLMConfig - -""" - model_name: str - provider: str - api_key: str - base_url: Optional[str] = None - timeout: float = 30.0 # 请求超时时间(秒) - max_retries: int = 3 # 最大重试次数 - concurrency: int = 5 # 并发限流 - extra_params: Dict[str, Any] = {} -""" -from app.core.models.base import RedBearModelConfig -class LLMClient(ABC): - def __init__(self, model_config: RedBearModelConfig): - self.config = model_config - - self.model_name = self.config.model_name - self.provider = self.config.provider - self.api_key = self.config.api_key - self.base_url = self.config.base_url - self.max_retries = self.config.max_retries - - @abstractmethod - def chat(self, messages: List[Dict[str, str]]) -> Any: - pass - - @abstractmethod - async def response_structured( - self, - messages: List[Dict[str, str]], - response_model: type[BaseModel], - ) -> type[BaseModel]: - pass diff --git a/app/core/memory/src/llm_tools/openai_client.py b/app/core/memory/src/llm_tools/openai_client.py deleted file mode 100644 index dcb9da27..00000000 --- a/app/core/memory/src/llm_tools/openai_client.py +++ /dev/null @@ -1,224 +0,0 @@ -import asyncio -from typing import List, Dict, Any -import json - -from pydantic import BaseModel -from langchain_core.prompts import ChatPromptTemplate -from langchain_core.output_parsers import PydanticOutputParser - -from app.core.models.base import RedBearModelConfig -from app.core.models.llm import RedBearLLM -from app.core.memory.src.llm_tools.llm_client import LLMClient -# from app.core.memory.utils.config.definitions import LANGFUSE_ENABLED -LANGFUSE_ENABLED=False - -class OpenAIClient(LLMClient): - def __init__(self, model_config: RedBearModelConfig, type_: str = "chat"): - super().__init__(model_config) - - # Initialize Langfuse callback handler if enabled - self.langfuse_handler = None - if LANGFUSE_ENABLED: - try: - from langfuse.langchain import CallbackHandler - self.langfuse_handler = CallbackHandler() - except ImportError: - # Langfuse not installed, continue without tracing - pass - except Exception as e: - # Log error but don't fail initialization - import logging - logging.warning(f"Failed to initialize Langfuse handler: {e}") - - # Initialize RedBearLLM client - self.client = RedBearLLM(RedBearModelConfig( - model_name=self.model_name, - provider=self.provider, - api_key=self.api_key, - base_url=self.base_url, - max_retries=self.max_retries, - ), type=type_) - - async def chat(self, messages: List[Dict[str, str]]) -> Any: - template = """{messages}""" - # ChatPromptTemplate - prompt = ChatPromptTemplate.from_template(template) - chain = prompt | self.client - - # Add Langfuse callback if available - config = {} - if self.langfuse_handler: - config["callbacks"] = [self.langfuse_handler] - - response = await chain.ainvoke({"messages": messages}, config=config) - # print(f"OpenAIClient response ======>:\n {response}") - return response - - async def response_structured( - self, - messages: List[Dict[str, str]], - response_model: type[BaseModel], - ) -> type[BaseModel]: - # Build a simple prompt pipeline that sends messages to the underlying LLM - question_text = "\n\n".join([str(m.get("content", "")) for m in messages]) - - # Prepare config with Langfuse callback if available - config = {} - if self.langfuse_handler: - config["callbacks"] = [self.langfuse_handler] - - # Primary: enforce schema with PydanticOutputParser if available - if PydanticOutputParser is not None: - try: - import logging - logger = logging.getLogger(__name__) - # 使用正确的属性路径:self.config.timeout(从LLMClient基类继承) - # logger.info(f"开始LLM结构化输出请求 (模型: {self.model_name}, 超时: {self.config.timeout}秒)") - - parser = PydanticOutputParser(pydantic_object=response_model) - format_instructions = parser.get_format_instructions() - prompt = ChatPromptTemplate.from_template("{question}\n{format_instructions}") - chain = prompt | self.client | parser - parsed = await chain.ainvoke({ - "question": question_text, - "format_instructions": format_instructions, - }) - - # logger.info(f"LLM结构化输出请求成功完成") - return parsed - except Exception as e: - import logging - logger = logging.getLogger(__name__) - logger.warning(f"PydanticOutputParser失败,尝试备用方法: {str(e)}") - # Fall through to alternative structured methods - pass - - # Fallback path: create plain prompt for other structured methods - template = """{question}""" - prompt = ChatPromptTemplate.from_template(template) - - # Try LangChain structured output if available on the underlying client - try: - with_so = getattr(self.client, "with_structured_output", None) - - if callable(with_so): - try: - structured_chain = prompt | with_so(response_model, strict=True) - parsed = await structured_chain.ainvoke({"question": question_text}, config=config) - # parsed may already be a pydantic model or a dict - try: - return response_model.model_validate(parsed) - except Exception: - try: - # If it's already a pydantic instance (LangChain returns model), return it - if hasattr(parsed, "model_dump"): - return parsed - return response_model.model_validate_json(json.dumps(parsed)) - except Exception: - # Fall through to manual parsing below - pass - except NotImplementedError: - # The underlying model doesn't support structured output, fall through - import logging - logger = logging.getLogger(__name__) - logger.warning( - f"Model {self.model_name} doesn't support with_structured_output, falling back to manual parsing") - pass - except Exception as e: - import logging - logger = logging.getLogger(__name__) - logger.warning(f"Structured output attempt failed: {e}, falling back to manual parsing") - - # Final fallback: manual parsing with plain LLM response - try: - import logging - logger = logging.getLogger(__name__) - logger.info(f"Using manual parsing fallback for model {self.model_name}") - - # Create a prompt that asks for JSON output - json_prompt = ChatPromptTemplate.from_template( - "{question}\n\n" - "Please respond with a valid JSON object that matches this schema:\n" - "{schema}\n\n" - "Response (JSON only):" - ) - - # Get the schema from the response model - schema = response_model.model_json_schema() - - chain = json_prompt | self.client - response = await chain.ainvoke({ - "question": question_text, - "schema": json.dumps(schema, indent=2) - }, config=config) - - # Extract JSON from response - response_text = str(response.content if hasattr(response, 'content') else response) - - # Try to find JSON in the response - import re - json_match = re.search(r'\{.*\}', response_text, re.DOTALL) - if json_match: - json_str = json_match.group(0) - try: - parsed_dict = json.loads(json_str) - return response_model.model_validate(parsed_dict) - except json.JSONDecodeError: - pass - - # If JSON parsing fails, try to create a minimal valid response - logger.warning(f"Failed to parse JSON from LLM response, creating minimal response") - - # Create a minimal response based on the schema - return self._create_minimal_response(response_model) - - except Exception as fallback_error: - import logging - logger = logging.getLogger(__name__) - logger.error(f"Manual parsing fallback also failed: {fallback_error}") - # Return minimal response as last resort - return self._create_minimal_response(response_model) - - def _create_minimal_response(self, response_model: type[BaseModel]) -> BaseModel: - """Create a minimal valid response based on the model schema.""" - try: - minimal_response = {} - - for field_name, field_info in response_model.model_fields.items(): - # Check if field has a default value - if hasattr(field_info, 'default') and field_info.default is not None: - minimal_response[field_name] = field_info.default - else: - # Create default based on field type - field_type = field_info.annotation - - # Handle nested BaseModel - if hasattr(field_type, '__bases__') and BaseModel in field_type.__bases__: - minimal_response[field_name] = self._create_minimal_response(field_type) - elif field_type == str: - minimal_response[field_name] = "信息不足,无法回答" - elif field_type == int: - minimal_response[field_name] = 0 - elif field_type == float: - minimal_response[field_name] = 0.0 - elif field_type == bool: - minimal_response[field_name] = False - elif field_type == list: - minimal_response[field_name] = [] - elif field_type == dict: - minimal_response[field_name] = {} - else: - minimal_response[field_name] = None - - return response_model.model_validate(minimal_response) - - except Exception as e: - import logging - logger = logging.getLogger(__name__) - logger.error(f"Failed to create minimal response: {e}") - # Last resort: try to create with just required fields - try: - return response_model() - except Exception: - # If even that fails, raise the original error - raise ValueError(f"Unable to create minimal response for {response_model.__name__}") from e diff --git a/app/core/memory/src/llm_tools/openai_embedder.py b/app/core/memory/src/llm_tools/openai_embedder.py deleted file mode 100644 index 427c38a0..00000000 --- a/app/core/memory/src/llm_tools/openai_embedder.py +++ /dev/null @@ -1,26 +0,0 @@ -from typing import List - -from app.core.memory.src.llm_tools.embedder_client import EmbedderClient -from app.core.models.base import RedBearModelConfig -# from app.models.models_model import ModelType -from app.core.models.embedding import RedBearEmbeddings - - -class OpenAIEmbedderClient(EmbedderClient): - def __init__(self, model_config: RedBearModelConfig): - super().__init__(model_config) - - async def response( - self, - messages: List[str], - ) -> List[List[float]]: - texts: List[str] = [str(m) for m in messages if m is not None] - - model = RedBearEmbeddings(RedBearModelConfig( - model_name=self.model_name, - provider=self.provider, - api_key=self.api_key, - base_url=self.base_url, - )) - embeddings = await model.aembed_documents(texts) - return embeddings diff --git a/app/core/memory/src/search.py b/app/core/memory/src/search.py deleted file mode 100644 index 685d038f..00000000 --- a/app/core/memory/src/search.py +++ /dev/null @@ -1,980 +0,0 @@ -import argparse -import asyncio -import json -import os -import time -from typing import List, Dict, Any, Optional -from dotenv import load_dotenv -from datetime import datetime -import math -from app.core.logging_config import get_memory_logger -# 使用新的仓储层 -from app.repositories.neo4j.neo4j_connector import Neo4jConnector -from app.repositories.neo4j.graph_search import ( - search_graph_by_embedding, search_graph, - search_graph_by_temporal, search_graph_by_keyword_temporal, - search_graph_by_chunk_id -) -from app.core.memory.src.llm_tools.openai_embedder import OpenAIEmbedderClient -from app.core.memory.models.config_models import TemporalSearchParams -from app.core.memory.utils.config.config_utils import get_embedder_config, get_pipeline_config -from app.core.memory.utils.data.time_utils import normalize_date_safe -from app.core.memory.models.variate_config import ForgettingEngineConfig -from app.core.memory.utils.config.definitions import CONFIG, RUNTIME_CONFIG -from app.core.memory.storage_services.forgetting_engine.forgetting_engine import ForgettingEngine -from app.core.memory.utils.data.text_utils import extract_plain_query -from app.core.memory.utils.config import definitions as config_defs -from app.core.models.base import RedBearModelConfig -from app.core.memory.utils.llm.llm_utils import get_reranker_client -load_dotenv() - -logger = get_memory_logger(__name__) - -def _parse_datetime(value: Any) -> Optional[datetime]: - """Parse ISO `created_at` strings of the form 'YYYY-MM-DDTHH:MM:SS.ssssss'.""" - if value is None: - return None - if isinstance(value, datetime): - return value - if isinstance(value, str): - s = value.strip() - if not s: - return None - try: - return datetime.fromisoformat(s) - except Exception: - return None - return None - - -def normalize_scores(results: List[Dict[str, Any]], score_field: str = "score") -> List[Dict[str, Any]]: - """Normalize scores using z-score normalization followed by sigmoid transformation.""" - if not results: - return results - - # Extract scores, ensuring they are numeric and not None - scores = [] - for item in results: - if score_field in item: - score = item.get(score_field) - if score is not None and isinstance(score, (int, float)): - scores.append(float(score)) - else: - scores.append(0.0) # Default for None or non-numeric values - - if not scores: - return results - - if len(scores) == 1: - # Single score, set to 1.0 - for item in results: - if score_field in item: - item[f"normalized_{score_field}"] = 1.0 - return results - - # Calculate mean and standard deviation - mean_score = sum(scores) / len(scores) - variance = sum((score - mean_score) ** 2 for score in scores) / len(scores) - std_dev = math.sqrt(variance) - - if std_dev == 0: - # All scores are the same, set them to 1.0 - for item in results: - if score_field in item: - item[f"normalized_{score_field}"] = 1.0 - else: - for item in results: - if score_field in item: - score = item[score_field] - # Handle None or non-numeric scores - if score is None or not isinstance(score, (int, float)): - score = 0.0 - # Calculate z-score - z_score = (score - mean_score) / std_dev - # Transform to positive range using sigmoid function - normalized = 1 / (1 + math.exp(-z_score)) - item[f"normalized_{score_field}"] = normalized - - return results - - -def rerank_hybrid_results( - keyword_results: Dict[str, List[Dict[str, Any]]], - embedding_results: Dict[str, List[Dict[str, Any]]], - alpha: float = 0.6, - limit: int = 10 -) -> Dict[str, List[Dict[str, Any]]]: - """ - Rerank hybrid search results by combining BM25 and embedding scores. - - Args: - keyword_results: Results from keyword/BM25 search - embedding_results: Results from embedding search - alpha: Weight for BM25 scores (1-alpha for embedding scores) - limit: Maximum number of results to return per category - - Returns: - Reranked results with combined scores - """ - reranked = {} - - for category in ["statements", "chunks", "entities","summaries"]: - keyword_items = keyword_results.get(category, []) - embedding_items = embedding_results.get(category, []) - - # Normalize scores within each search type - keyword_items = normalize_scores(keyword_items, "score") - embedding_items = normalize_scores(embedding_items, "score") - - # Create a combined pool of unique items - combined_items = {} - - # Add keyword results with BM25 scores - for item in keyword_items: - item_id = item.get("id") or item.get("uuid") - if item_id: - combined_items[item_id] = item.copy() - combined_items[item_id]["bm25_score"] = item.get("normalized_score", 0) - combined_items[item_id]["embedding_score"] = 0 # Default - - # Add or update with embedding results - for item in embedding_items: - item_id = item.get("id") or item.get("uuid") - if item_id: - if item_id in combined_items: - # Update existing item with embedding score - combined_items[item_id]["embedding_score"] = item.get("normalized_score", 0) - else: - # New item from embedding search only - combined_items[item_id] = item.copy() - combined_items[item_id]["bm25_score"] = 0 # Default - combined_items[item_id]["embedding_score"] = item.get("normalized_score", 0) - - # Calculate combined scores and rank - for item_id, item in combined_items.items(): - bm25_score = item.get("bm25_score", 0) - embedding_score = item.get("embedding_score", 0) - - # Combined score: weighted average of normalized scores - combined_score = alpha * bm25_score + (1 - alpha) * embedding_score - item["combined_score"] = combined_score - - # Keep original score for reference - if "score" not in item and bm25_score > 0: - item["score"] = bm25_score - elif "score" not in item and embedding_score > 0: - item["score"] = embedding_score - - # Sort by combined score and limit results - sorted_items = sorted( - combined_items.values(), - key=lambda x: x.get("combined_score", 0), - reverse=True - )[:limit] - - reranked[category] = sorted_items - - return reranked - -def rerank_with_forgetting_curve( - keyword_results: Dict[str, List[Dict[str, Any]]], - embedding_results: Dict[str, List[Dict[str, Any]]], - alpha: float = 0.6, - limit: int = 10, - forgetting_config: ForgettingEngineConfig | None = None, - now: datetime | None = None, -) -> Dict[str, List[Dict[str, Any]]]: - """ - Rerank hybrid results with a forgetting curve applied to combined scores. - - The forgetting curve reduces scores for older memories or weaker connections. - - Args: - keyword_results: Results from keyword/BM25 search - embedding_results: Results from embedding search - alpha: Weight for BM25 scores (1-alpha for embedding scores) - limit: Maximum number of results to return per category - forgetting_config: Configuration for the forgetting engine - now: Optional current time override for testing - - Returns: - Reranked results with combined and final scores (after forgetting) - """ - engine = ForgettingEngine(forgetting_config or ForgettingEngineConfig()) - now_dt = now or datetime.now() - - reranked: Dict[str, List[Dict[str, Any]]] = {} - - for category in ["statements", "chunks", "entities","summaries"]: - keyword_items = keyword_results.get(category, []) - embedding_items = embedding_results.get(category, []) - - # Normalize scores within each search type - keyword_items = normalize_scores(keyword_items, "score") - embedding_items = normalize_scores(embedding_items, "score") - - combined_items: Dict[str, Dict[str, Any]] = {} - - # Combine two result sets by ID - for src_items, is_embedding in ( - (keyword_items, False), (embedding_items, True) - ): - for item in src_items: - item_id = item.get("id") or item.get("uuid") - if not item_id: - continue - existing = combined_items.get(item_id) - if not existing: - combined_items[item_id] = item.copy() - combined_items[item_id]["bm25_score"] = 0 - combined_items[item_id]["embedding_score"] = 0 - # Update normalized score from the right source - if is_embedding: - combined_items[item_id]["embedding_score"] = item.get("normalized_score", 0) - else: - combined_items[item_id]["bm25_score"] = item.get("normalized_score", 0) - - # Calculate scores and apply forgetting weights - for item_id, item in combined_items.items(): - bm25_score = float(item.get("bm25_score", 0) or 0) - embedding_score = float(item.get("embedding_score", 0) or 0) - combined_score = alpha * bm25_score + (1 - alpha) * embedding_score - - # Estimate time elapsed in days - dt = _parse_datetime(item.get("created_at")) - if dt is None: - time_elapsed_days = 0.0 - else: - time_elapsed_days = max(0.0, (now_dt - dt).total_seconds() / 86400.0) - - # Memory strength (currently set to default value) - memory_strength = 1.0 - forgetting_weight = engine.calculate_weight( - time_elapsed=time_elapsed_days, memory_strength=memory_strength - ) - # print(f"Forgetting weight for {item_id}: {forgetting_weight}") - # print(f"Time elapsed days for {item_id}: {time_elapsed_days}") - final_score = combined_score * forgetting_weight - item["combined_score"] = final_score - - sorted_items = sorted( - combined_items.values(), key=lambda x: x.get("combined_score", 0), reverse=True - )[:limit] - - reranked[category] = sorted_items - - return reranked - - -def log_search_query(query_text: str, search_type: str, group_id: str | None, limit: int, include: List[str], log_file: str = "search_log.txt"): - """Log search query information to file""" - timestamp = datetime.now().strftime("%Y-%m-%d %H:%M:%S") - # Ensure the query text is plain and clean before logging - cleaned_query = extract_plain_query(query_text) - log_entry = { - "timestamp": timestamp, - # "query": query_text, - "query": cleaned_query, - "search_type": search_type, - "group_id": group_id, - "limit": limit, - "include": include - } - - # Append to log file - with open(log_file, "a", encoding="utf-8") as f: - f.write(json.dumps(log_entry, ensure_ascii=False) + "\n") - - logger.info(f"Search logged: {query_text} ({search_type})") - - -def _remove_keys_recursive(obj: Any, keys_to_remove: List[str]) -> Any: - """Remove specified keys recursively from dict/list structures (in place).""" - try: - if isinstance(obj, dict): - for k in keys_to_remove: - if k in obj: - obj.pop(k, None) - for v in list(obj.values()): - _remove_keys_recursive(v, keys_to_remove) - elif isinstance(obj, list): - for item in obj: - _remove_keys_recursive(item, keys_to_remove) - except Exception: - # Be defensive: never fail search because of sanitization - pass - return obj - - -def apply_reranker_placeholder( - results: Dict[str, List[Dict[str, Any]]], - query_text: str, -) -> Dict[str, List[Dict[str, Any]]]: - """ - Placeholder for a cross-encoder reranker. - If config enables reranker, annotate items with a final_score equal to combined_score - and keep ordering. This is a no-op reranker to be replaced later. - """ - try: - rc = (RUNTIME_CONFIG.get("reranker", {}) or CONFIG.get("reranker", {})) - except Exception as e: - logger.debug(f"Failed to load reranker config: {e}") - rc = {} - if not rc or not rc.get("enabled", False): - return results - - top_k = int(rc.get("top_k", 100)) - model_name = rc.get("model", "placeholder") - - for cat, items in results.items(): - head = items[:top_k] - for it in head: - base = float(it.get("combined_score", it.get("score", 0.0)) or 0.0) - it["final_score"] = base - it["reranker_model"] = model_name - # Keep overall order by final_score if present, otherwise combined/score - results[cat] = sorted( - items, - key=lambda x: float(x.get("final_score", x.get("combined_score", x.get("score", 0.0)) or 0.0)), - reverse=True, - ) - return results - - -async def apply_llm_reranker( - results: Dict[str, List[Dict[str, Any]]], - query_text: str, - reranker_client: Optional[Any] = None, - llm_weight: Optional[float] = None, - top_k: Optional[int] = None, - batch_size: Optional[int] = None, -) -> Dict[str, List[Dict[str, Any]]]: - """ - Apply LLM-based reranking to search results. - - Args: - results: Search results organized by category - query_text: Original search query - reranker_client: Optional pre-initialized reranker client - llm_weight: Weight for LLM score (0.0-1.0, higher favors LLM) - top_k: Maximum number of items to rerank per category - batch_size: Number of items to process concurrently - - Returns: - Reranked results with final_score and reranker_model fields - """ - # Load reranker configuration from runtime.json - try: - rc = RUNTIME_CONFIG.get("reranker", {}) or CONFIG.get("reranker", {}) - except Exception as e: - logger.debug(f"Failed to load reranker config: {e}") - rc = {} - - # Check if reranking is enabled - enabled = rc.get("enabled", False) - if not enabled: - logger.debug("LLM reranking is disabled in configuration") - return results - - # Load configuration parameters with defaults - llm_weight = llm_weight if llm_weight is not None else rc.get("llm_weight", 0.5) - top_k = top_k if top_k is not None else rc.get("top_k", 20) - batch_size = batch_size if batch_size is not None else rc.get("batch_size", 5) - - # Initialize reranker client if not provided - if reranker_client is None: - try: - reranker_client = get_reranker_client() - except Exception as e: - logger.warning(f"Failed to initialize reranker client: {e}, skipping LLM reranking") - return results - - # Get model name for metadata - model_name = getattr(reranker_client, 'model_name', 'unknown') - - # Process each category - reranked_results = {} - for category in ["statements", "chunks", "entities", "summaries"]: - items = results.get(category, []) - if not items: - reranked_results[category] = [] - continue - - # Select top K items by combined_score for reranking - sorted_items = sorted( - items, - key=lambda x: float(x.get("combined_score", x.get("score", 0.0)) or 0.0), - reverse=True - ) - - top_items = sorted_items[:top_k] - remaining_items = sorted_items[top_k:] - - # Extract text content from each item - def extract_text(item: Dict[str, Any]) -> str: - """Extract text content from a result item.""" - # Try different text fields based on category - text = item.get("text") or item.get("content") or item.get("statement") or item.get("name") or "" - return str(text).strip() - - # Batch items for concurrent processing - batches = [] - for i in range(0, len(top_items), batch_size): - batch = top_items[i:i + batch_size] - batches.append(batch) - - # Process batches concurrently - async def process_batch(batch: List[Dict[str, Any]]) -> List[Dict[str, Any]]: - """Process a batch of items with LLM relevance scoring.""" - scored_batch = [] - - for item in batch: - item_text = extract_text(item) - - # Skip items with no text - if not item_text: - item_copy = item.copy() - combined_score = float(item.get("combined_score", item.get("score", 0.0)) or 0.0) - item_copy["final_score"] = combined_score - item_copy["llm_relevance_score"] = 0.0 - item_copy["reranker_model"] = model_name - scored_batch.append(item_copy) - continue - - # Create relevance scoring prompt - prompt = f"""Given the search query and a result item, rate the relevance of the item to the query on a scale from 0.0 to 1.0. - -Query: {query_text} - -Result: {item_text} - -Respond with only a number between 0.0 and 1.0, where: -- 0.0 means completely irrelevant -- 1.0 means perfectly relevant - -Relevance score:""" - - # Send request to LLM - try: - messages = [{"role": "user", "content": prompt}] - response = await reranker_client.chat(messages) - - # Parse LLM response to extract relevance score - response_text = str(response.content if hasattr(response, 'content') else response).strip() - - # Try to extract a float from the response - try: - # Remove any non-numeric characters except decimal point - import re - score_match = re.search(r'(\d+\.?\d*)', response_text) - if score_match: - llm_score = float(score_match.group(1)) - # Clamp to [0.0, 1.0] - llm_score = max(0.0, min(1.0, llm_score)) - else: - raise ValueError("No numeric score found in response") - except (ValueError, AttributeError) as e: - logger.warning(f"Invalid LLM score format: {response_text}, using combined_score. Error: {e}") - llm_score = None - - # Calculate final score - item_copy = item.copy() - combined_score = float(item.get("combined_score", item.get("score", 0.0)) or 0.0) - - if llm_score is not None: - final_score = (1 - llm_weight) * combined_score + llm_weight * llm_score - item_copy["llm_relevance_score"] = llm_score - else: - # Use combined_score as fallback - final_score = combined_score - item_copy["llm_relevance_score"] = combined_score - - item_copy["final_score"] = final_score - item_copy["reranker_model"] = model_name - scored_batch.append(item_copy) - except Exception as e: - logger.warning(f"Error processing item in LLM reranking: {e}, using combined_score") - item_copy = item.copy() - combined_score = float(item.get("combined_score", item.get("score", 0.0)) or 0.0) - item_copy["final_score"] = combined_score - item_copy["llm_relevance_score"] = combined_score - item_copy["reranker_model"] = model_name - scored_batch.append(item_copy) - - return scored_batch - - # Process all batches concurrently - try: - batch_tasks = [process_batch(batch) for batch in batches] - batch_results = await asyncio.gather(*batch_tasks, return_exceptions=True) - - # Merge batch results - scored_items = [] - for result in batch_results: - if isinstance(result, Exception): - logger.warning(f"Batch processing failed: {result}") - continue - scored_items.extend(result) - - # Add remaining items (not in top K) with their combined_score as final_score - for item in remaining_items: - item_copy = item.copy() - combined_score = float(item.get("combined_score", item.get("score", 0.0)) or 0.0) - item_copy["final_score"] = combined_score - item_copy["reranker_model"] = model_name - scored_items.append(item_copy) - - # Sort all items by final_score in descending order - scored_items.sort(key=lambda x: float(x.get("final_score", 0.0) or 0.0), reverse=True) - reranked_results[category] = scored_items - - except Exception as e: - logger.error(f"Error in LLM reranking for category {category}: {e}, returning original results") - # Return original items with combined_score as final_score - for item in items: - combined_score = float(item.get("combined_score", item.get("score", 0.0)) or 0.0) - item["final_score"] = combined_score - item["reranker_model"] = model_name - reranked_results[category] = items - - return reranked_results - - -async def run_hybrid_search( - query_text: str, - search_type: str, - group_id: str | None, - limit: int, - include: List[str], - output_path: str | None, - rerank_alpha: float = 0.6, - use_forgetting_rerank: bool = False, - use_llm_rerank: bool = False, -): - """ - - Run search with specified type: 'keyword', 'embedding', or 'hybrid' - """ - # Start overall timing - search_start_time = time.time() - latency_metrics = {} - - # Clean and normalize the incoming query before use/logging - query_text = extract_plain_query(query_text) - - # Validate query is not empty after cleaning - if not query_text or not query_text.strip(): - logger.warning(f"Empty query after cleaning, returning empty results") - return { - "keyword_search": {}, - "embedding_search": {}, - "reranked_results": {}, - "combined_summary": { - "total_keyword_results": 0, - "total_embedding_results": 0, - "total_reranked_results": 0, - "search_query": "", - "search_timestamp": datetime.now().isoformat(), - "error": "Empty query" - } - } - - # Log the search query - log_search_query(query_text, search_type, group_id, limit, include) - - connector = Neo4jConnector() - results = {} - - try: - keyword_task = None - embedding_task = None - - if search_type in ["keyword", "hybrid"]: - # Keyword-based search - logger.info("Starting keyword search...") - keyword_start = time.time() - keyword_task = asyncio.create_task( - search_graph( - connector=connector, - q=query_text, - group_id=group_id, - limit=limit, - include=include - ) - ) - - if search_type in ["embedding", "hybrid"]: - # Embedding-based search - logger.info("Starting embedding search...") - embedding_start = time.time() - - # 从数据库读取嵌入器配置(按 ID)并构建 RedBearModelConfig - config_load_start = time.time() - embedder_config_dict = get_embedder_config(config_defs.SELECTED_EMBEDDING_ID) - rb_config = RedBearModelConfig( - model_name=embedder_config_dict["model_name"], - provider=embedder_config_dict["provider"], - api_key=embedder_config_dict["api_key"], - base_url=embedder_config_dict["base_url"], - type="llm" - ) - config_load_time = time.time() - config_load_start - logger.info(f"Config loading took {config_load_time:.4f}s") - - # Init embedder - embedder_init_start = time.time() - embedder = OpenAIEmbedderClient(model_config=rb_config) - embedder_init_time = time.time() - embedder_init_start - logger.info(f"Embedder init took {embedder_init_time:.4f}s") - - embedding_task = asyncio.create_task( - search_graph_by_embedding( - connector=connector, - embedder_client=embedder, - query_text=query_text, - group_id=group_id, - limit=limit, - include=include, - ) - ) - - if keyword_task: - keyword_results = await keyword_task - keyword_latency = time.time() - keyword_start - latency_metrics["keyword_search_latency"] = round(keyword_latency, 4) - logger.info(f"Keyword search completed in {keyword_latency:.4f}s") - if search_type == "keyword": - results = keyword_results - else: - results["keyword_search"] = keyword_results - - if embedding_task: - embedding_results = await embedding_task - embedding_latency = time.time() - embedding_start - latency_metrics["embedding_search_latency"] = round(embedding_latency, 4) - logger.info(f"Embedding search completed in {embedding_latency:.4f}s") - if search_type == "embedding": - results = embedding_results - else: - results["embedding_search"] = embedding_results - - # Merge and rank results for hybrid search - if search_type == "hybrid": - results["combined_summary"] = { - "total_keyword_results": sum(len(v) if isinstance(v, list) else 0 for v in keyword_results.values()), - "total_embedding_results": sum(len(v) if isinstance(v, list) else 0 for v in embedding_results.values()), - "search_query": query_text, - "search_timestamp": datetime.now().isoformat() - } - - # Apply reranking (optionally with forgetting curve) - rerank_start = time.time() - if use_forgetting_rerank: - # Load forgetting parameters from pipeline config - try: - pc = get_pipeline_config() - forgetting_cfg = pc.forgetting_engine - except Exception as e: - logger.debug(f"Failed to load forgetting config, using defaults: {e}") - forgetting_cfg = ForgettingEngineConfig() - reranked_results = rerank_with_forgetting_curve( - keyword_results=keyword_results, - embedding_results=embedding_results, - alpha=rerank_alpha, - limit=limit, - forgetting_config=forgetting_cfg, - ) - else: - reranked_results = rerank_hybrid_results( - keyword_results=keyword_results, - embedding_results=embedding_results, - alpha=rerank_alpha, # Configurable weight for BM25 vs embedding - limit=limit - ) - rerank_latency = time.time() - rerank_start - latency_metrics["reranking_latency"] = round(rerank_latency, 4) - logger.info(f"Reranking completed in {rerank_latency:.4f}s") - - # Optional: apply reranker placeholder if enabled via config - reranked_results = apply_reranker_placeholder(reranked_results, query_text) - - # Apply LLM reranking if enabled - llm_rerank_applied = False - if use_llm_rerank: - try: - reranked_results = await apply_llm_reranker( - results=reranked_results, - query_text=query_text, - ) - llm_rerank_applied = True - logger.info("LLM reranking applied successfully") - except Exception as e: - logger.warning(f"LLM reranking failed: {e}, using previous scores") - - results["reranked_results"] = reranked_results - results["combined_summary"] = { - "total_keyword_results": sum(len(v) if isinstance(v, list) else 0 for v in keyword_results.values()), - "total_embedding_results": sum(len(v) if isinstance(v, list) else 0 for v in embedding_results.values()), - "total_reranked_results": sum(len(v) if isinstance(v, list) else 0 for v in reranked_results.values()), - "search_query": query_text, - "search_timestamp": datetime.now().isoformat(), - "reranking_alpha": rerank_alpha, - "forgetting_rerank": use_forgetting_rerank, - "llm_rerank": llm_rerank_applied, - } - - # Calculate total latency - total_latency = time.time() - search_start_time - latency_metrics["total_latency"] = round(total_latency, 4) - - # Add latency metrics to results - if "combined_summary" in results: - results["combined_summary"]["latency_metrics"] = latency_metrics - else: - results["latency_metrics"] = latency_metrics - - logger.info(f"Total search completed in {total_latency:.4f}s") - logger.info(f"Latency breakdown: {latency_metrics}") - - # Sanitize results: drop large/unused fields - _remove_keys_recursive(results, ["name_embedding"]) # drop entity name embeddings from outputs - - # print(json.dumps(results, ensure_ascii=False, indent=2, default=str)) - - # Save to file - output_path = output_path or "search_results.json" - out_dir = os.path.dirname(output_path) - if out_dir: - os.makedirs(out_dir, exist_ok=True) - with open(output_path, "w", encoding="utf-8") as f: - json.dump(results, f, ensure_ascii=False, indent=2, default=str) - logger.info(f"Search results saved to: {output_path}") - - # Log search completion with result count - if search_type == "hybrid": - result_counts = { - "keyword": {key: len(value) if isinstance(value, list) else 0 for key, value in keyword_results.items()}, - "embedding": {key: len(value) if isinstance(value, list) else 0 for key, value in embedding_results.items()} - } - else: - result_counts = {key: len(value) if isinstance(value, list) else 0 for key, value in results.items()} - - completion_log = { - "timestamp": datetime.now().strftime("%Y-%m-%d %H:%M:%S"), - "query": query_text, - "search_type": search_type, - "status": "completed", - "result_counts": result_counts, - "output_file": output_path, - "latency_metrics": latency_metrics - } - - with open("search_log.txt", "a", encoding="utf-8") as f: - f.write(json.dumps(completion_log, ensure_ascii=False) + "\n") - - return results - - finally: - await connector.close() - - -async def search_by_temporal( - group_id: Optional[str] = "test", - apply_id: Optional[str] = None, - user_id: Optional[str] = None, - start_date: Optional[str] = None, - end_date: Optional[str] = None, - valid_date: Optional[str] = None, - invalid_date: Optional[str] = None, - limit: int = 1, -): - """ - Temporal search across Statements. - - - Matches statements created between start_date and end_date - - Optionally filters by group_id - - Returns up to 'limit' statements - """ - connector = Neo4jConnector() - if start_date: - start_date = normalize_date_safe(start_date) - if end_date: - end_date = normalize_date_safe(end_date) - - params = TemporalSearchParams.model_validate({ - "group_id": group_id, - "apply_id": apply_id, - "user_id": user_id, - "start_date": start_date, - "end_date": end_date, - "valid_date": valid_date, - "invalid_date": invalid_date, - "limit": limit, - }) - statements = await search_graph_by_temporal( - connector=connector, - group_id=params.group_id, - apply_id=params.apply_id, - user_id=params.user_id, - start_date=params.start_date, - end_date=params.end_date, - valid_date=params.valid_date, - invalid_date=params.invalid_date, - limit=params.limit - ) - return {"statements": statements} - - -async def search_by_keyword_temporal( - query_text: str, - group_id: Optional[str] = "test", - apply_id: Optional[str] = None, - user_id: Optional[str] = None, - start_date: Optional[str] = None, - end_date: Optional[str] = None, - valid_date: Optional[str] = None, - invalid_date: Optional[str] = None, - limit: int = 1, -): - """ - Temporal keyword search across Statements. - """ - connector = Neo4jConnector() - if start_date: - start_date = normalize_date_safe(start_date) - if end_date: - end_date = normalize_date_safe(end_date) - if valid_date: - valid_date = normalize_date_safe(valid_date) - if invalid_date: - invalid_date = normalize_date_safe(invalid_date) - - params = TemporalSearchParams.model_validate({ - "group_id": group_id, - "apply_id": apply_id, - "user_id": user_id, - "start_date": start_date, - "end_date": end_date, - "valid_date": valid_date, - "invalid_date": invalid_date, - "limit": limit, - }) - statements = await search_graph_by_keyword_temporal( - connector=connector, - query_text=query_text, - group_id=params.group_id, - apply_id=params.apply_id, - user_id=params.user_id, - start_date=params.start_date, - end_date=params.end_date, - valid_date=params.valid_date, - invalid_date=params.invalid_date, - limit=params.limit - ) - return {"statements": statements} - - -async def search_chunk_by_chunk_id( - chunk_id: str, - group_id: Optional[str] = "test", - limit: int = 1, -): - """ - Search for Chunks by chunk_id. - """ - connector = Neo4jConnector() - chunks = await search_graph_by_chunk_id( - connector=connector, - chunk_id=chunk_id, - group_id=group_id, - limit=limit - ) - return {"chunks": chunks} - - -def main(): - """Main entry point for the hybrid graph search CLI. - - Parses command line arguments and executes search with specified parameters. - Supports keyword, embedding, and hybrid search modes. - """ - parser = argparse.ArgumentParser(description="Hybrid graph search with keyword and embedding options") - parser.add_argument( - "--query", "-q", required=True, help="Free-text query to search" - ) - parser.add_argument( - "--search-type", - "-t", - choices=["keyword", "embedding", "hybrid"], - default="hybrid", - help="Search type: keyword (text matching), embedding (semantic), or hybrid (both) (default: hybrid)" - ) - parser.add_argument( - "--embedding-name", - "-m", - default="openai/nomic-embed-text:v1.5", - help="Embedding config name from config.json (default: openai/nomic-embed-text:v1.5)", - ) - parser.add_argument( - "--group-id", - "-g", - default=None, - help="Optional group_id to filter results (default: None)", - ) - parser.add_argument( - "--limit", - "-k", - type=int, - default=5, - help="Max number of results per type (default: 5)", - ) - parser.add_argument( - "--include", - "-i", - nargs="+", - default=["statements", "chunks", "entities", "summaries"], - choices=["statements", "chunks", "entities", "summaries"], - help="Which targets to search for embedding search (default: statements chunks entities summaries)" - ) - parser.add_argument( - "--output", - "-o", - default="search_results.json", - help="Path to save the search results JSON (default: search_results.json)", - ) - parser.add_argument( - "--rerank-alpha", - "-a", - type=float, - default=0.6, - help="Weight for BM25 scores in reranking (0.0-1.0, higher values favor keyword search) (default: 0.6)", - ) - parser.add_argument( - "--forgetting-rerank", - action="store_true", - help="Apply forgetting curve during reranking for hybrid search.", - ) - parser.add_argument( - "--llm-rerank", - action="store_true", - help="Apply LLM-based reranking for hybrid search.", - ) - args = parser.parse_args() - - asyncio.run( - run_hybrid_search( - query_text=args.query, - search_type=args.search_type, - group_id=args.group_id, - limit=args.limit, - include=args.include, - output_path=args.output, - rerank_alpha=args.rerank_alpha, - use_forgetting_rerank=args.forgetting_rerank, - use_llm_rerank=args.llm_rerank, - ) - ) - - -if __name__ == "__main__": - main() diff --git a/app/core/memory/storage_services/__init__.py b/app/core/memory/storage_services/__init__.py deleted file mode 100644 index d7cd6df6..00000000 --- a/app/core/memory/storage_services/__init__.py +++ /dev/null @@ -1,8 +0,0 @@ -""" -存储服务模块 - -包含三大引擎: -1. 萃取引擎(Extraction Engine)- 知识提取、预处理、去重消歧 -2. 遗忘引擎(Forgetting Engine)- 记忆遗忘机制 -3. 自我反思引擎(Reflection Engine)- 自我反思和优化 -""" diff --git a/app/core/memory/storage_services/extraction_engine/__init__.py b/app/core/memory/storage_services/extraction_engine/__init__.py deleted file mode 100644 index 6ddfb3bc..00000000 --- a/app/core/memory/storage_services/extraction_engine/__init__.py +++ /dev/null @@ -1,8 +0,0 @@ -""" -萃取引擎(Extraction Engine) - -负责从对话数据中提取结构化知识,包括: -- 数据预处理 -- 知识提取(分块、陈述句、三元组、时间信息、嵌入向量) -- 去重消歧 -""" diff --git a/app/core/memory/storage_services/extraction_engine/data_preprocessing/__init__.py b/app/core/memory/storage_services/extraction_engine/data_preprocessing/__init__.py deleted file mode 100644 index 0704e350..00000000 --- a/app/core/memory/storage_services/extraction_engine/data_preprocessing/__init__.py +++ /dev/null @@ -1,13 +0,0 @@ -""" -数据预处理模块 - 负责对话数据的清洗、转换和预处理 - -包含: -- data_preprocessor: 数据预处理器 - 读取、清洗和转换对话数据 -- data_pruning: 语义剪枝器 - 过滤与场景不相关的内容 -- data_chunker: 数据分块器 - 将对话分割成可处理的片段 -""" - -from app.core.memory.storage_services.extraction_engine.data_preprocessing.data_preprocessor import DataPreprocessor -from app.core.memory.storage_services.extraction_engine.data_preprocessing.data_pruning import SemanticPruner - -__all__ = ['DataPreprocessor', 'SemanticPruner'] diff --git a/app/core/memory/storage_services/extraction_engine/data_preprocessing/data_chunker.py b/app/core/memory/storage_services/extraction_engine/data_preprocessing/data_chunker.py deleted file mode 100644 index 37d02360..00000000 --- a/app/core/memory/storage_services/extraction_engine/data_preprocessing/data_chunker.py +++ /dev/null @@ -1,54 +0,0 @@ -""" -数据分块器 - 将对话分割成可处理的片段 - -功能: -- 支持多种分块策略(递归分块、语义分块、LLM分块等) -- 根据对话长度和内容特征进行智能分块 -- 保持对话上下文的连贯性 - -注意:此模块当前为占位符,具体实现将在后续任务中完成。 -分块功能目前在 app/core/memory/llm_tools/chunker_client.py 中实现。 -""" - -from typing import List, Optional -from app.core.memory.models.message_models import DialogData, Chunk - - -class DataChunker: - """数据分块器 - 将长对话分割成多个可处理的片段""" - - def __init__(self, chunker_strategy: str = "RecursiveChunker"): - """ - 初始化数据分块器 - - Args: - chunker_strategy: 分块策略名称 - """ - self.chunker_strategy = chunker_strategy - - async def chunk_dialog(self, dialog: DialogData) -> List[Chunk]: - """ - 将对话分割成多个块 - - Args: - dialog: 对话数据 - - Returns: - 分块列表 - - Note: - 当前此功能在 app/core/memory/llm_tools/chunker_client.py 中实现 - """ - raise NotImplementedError("数据分块功能将在后续任务中实现") - - async def chunk_dialogs(self, dialogs: List[DialogData]) -> List[DialogData]: - """ - 批量处理多个对话的分块 - - Args: - dialogs: 对话数据列表 - - Returns: - 包含分块信息的对话数据列表 - """ - raise NotImplementedError("数据分块功能将在后续任务中实现") diff --git a/app/core/memory/storage_services/extraction_engine/data_preprocessing/data_preprocessor.py b/app/core/memory/storage_services/extraction_engine/data_preprocessing/data_preprocessor.py deleted file mode 100644 index 796a76af..00000000 --- a/app/core/memory/storage_services/extraction_engine/data_preprocessing/data_preprocessor.py +++ /dev/null @@ -1,785 +0,0 @@ -""" -数据预处理器 - 支持多种格式的对话数据读取、清洗和预处理 - -功能: -- 支持多种文件格式:JSON、CSV、Excel、TXT -- 自动检测文件编码 -- 清洗和标准化对话数据 -- 转换为 DialogData 对象 -""" - -import json -import csv -import pandas as pd -import re -import os - -from typing import List, Dict, Any, Optional, Union -from pathlib import Path -from datetime import datetime - -from app.core.memory.models.message_models import DialogData, ConversationContext, ConversationMessage - - -class DataPreprocessor: - """数据预处理器类,支持多种格式的对话数据读取、清洗和预处理。""" - - def __init__(self, input_file_path: str = None, output_file_path: str = None): - """ - 初始化数据预处理器。 - - Args: - input_file_path: 输入文件路径(可选,可后续通过set_input_path设置) - output_file_path: 输出文件路径(可选,可后续通过set_output_path设置) - - 注意:您可以通过以下方式指定输入输出路径: - 1. 初始化时传入参数 - 2. 调用set_input_path()和set_output_path()方法 - 3. 在preprocess()方法中直接传入路径参数 - """ - self.input_file_path = input_file_path or r"src\extracted_statements.txt" - self.output_file_path = output_file_path or r"src\data_preprocessing\out-file\extracted_statements-pre.txt" - self.supported_formats = ['.json', '.csv', '.txt', '.xlsx', '.tsv'] - - def set_input_path(self, input_path: str) -> None: - """ - 设置输入文件路径。 - - Args: - input_path: 输入文件的完整路径 - """ - self.input_file_path = input_path - - def set_output_path(self, output_path: str) -> None: - """ - 设置输出文件路径。 - - Args: - output_path: 输出文件的完整路径 - """ - self.output_file_path = output_path - - def get_file_format(self, file_path: str) -> str: - """ - 获取文件格式。 - - Args: - file_path: 文件路径 - - Returns: - 文件扩展名(小写) - """ - return Path(file_path).suffix.lower() - - def _detect_encoding(self, file_path: str) -> str: - """ - 检测文件编码,使用多种方法确保准确性。 - - Args: - file_path: 文件路径 - - Returns: - 检测到的编码格式 - """ - # 常见编码列表,按优先级排序 - encodings_to_try = ['utf-8', 'gbk', 'gb2312', 'utf-16', 'latin-1'] - - # 首先尝试使用chardet检测 - try: - import chardet - with open(file_path, 'rb') as f: - raw_data = f.read(10000) # 读取前10KB进行检测 - result = chardet.detect(raw_data) - detected_encoding = result.get('encoding') - confidence = result.get('confidence', 0) - - # 如果检测置信度较高,使用检测结果 - if detected_encoding and confidence > 0.7: - return detected_encoding - except ImportError: - print("警告: chardet库未安装,使用备用编码检测方法") - except Exception as e: - print(f"chardet检测失败: {e},使用备用方法") - - # 备用方法:尝试不同编码读取文件开头 - for encoding in encodings_to_try: - try: - with open(file_path, 'r', encoding=encoding) as f: - f.read(1000) # 尝试读取前1000个字符 - return encoding - except (UnicodeDecodeError, UnicodeError): - continue - - # 如果所有编码都失败,返回utf-8作为最后选择 - return 'utf-8' - - def _read_json(self, data_path: str) -> List[Dict[str, Any]]: - """ - 读取JSON格式的对话数据,支持标准JSON和JSONL格式。 - - Args: - data_path: JSON文件路径 - - Returns: - 解析后的数据列表 - """ - encoding = self._detect_encoding(data_path) - content = None - - # 尝试使用检测到的编码读取文件 - encodings_to_try = [encoding, 'utf-8', 'gbk', 'gb2312', 'latin-1'] - - for enc in encodings_to_try: - try: - with open(data_path, 'r', encoding=enc) as f: - content = f.read().strip() - print(f"成功使用编码 {enc} 读取文件") - break - except (UnicodeDecodeError, UnicodeError) as e: - print(f"编码 {enc} 读取失败: {e}") - continue - - if content is None: - raise ValueError(f"无法使用任何编码读取文件: {data_path}") - - try: - - # 尝试解析为标准JSON - try: - data = json.loads(content) - if isinstance(data, dict): - return [data] - elif isinstance(data, list): - return data - else: - raise ValueError(f"不支持的JSON数据结构: {type(data)}") - except json.JSONDecodeError as e: - # 如果标准JSON解析失败,尝试JSONL格式(每行一个JSON对象) - print(f"标准JSON解析失败: {e},尝试JSONL格式...") - data_list = [] - lines = content.split('\n') - - for line_num, line in enumerate(lines, 1): - line = line.strip() - if line: # 跳过空行 - try: - json_obj = json.loads(line) - data_list.append(json_obj) - except json.JSONDecodeError as line_error: - # 如果是单行巨大JSON数组,可能需要特殊处理 - if line_num == 1 and len(lines) == 1: - print(f"检测到单行大型JSON,尝试分块解析...") - # 对于超大单行JSON,尝试使用json.JSONDecoder进行流式解析 - try: - decoder = json.JSONDecoder() - idx = 0 - while idx < len(line): - line = line[idx:].lstrip() - if not line: - break - try: - obj, end_idx = decoder.raw_decode(line) - if isinstance(obj, list): - data_list.extend(obj) - elif isinstance(obj, dict): - data_list.append(obj) - idx += end_idx - except json.JSONDecodeError: - break - except Exception as decode_error: - print(f"分块解析也失败: {decode_error}") - else: - print(f"警告: 第{line_num}行JSON解析失败: {line_error}") - continue - - return data_list - - except Exception as e: - raise ValueError(f"读取JSON文件时发生错误: {e}") - - def _read_csv(self, data_path: str) -> List[Dict[str, Any]]: - """ - 读取CSV格式的对话数据。 - - Args: - data_path: CSV文件路径 - - Returns: - 解析后的数据列表 - """ - encoding = self._detect_encoding(data_path) - encodings_to_try = [encoding, 'utf-8', 'gbk', 'gb2312', 'latin-1'] - - for enc in encodings_to_try: - try: - # 尝试不同的分隔符 - separators = [',', '\t', ';', '|'] - df = None - - for sep in separators: - try: - df = pd.read_csv(data_path, encoding=enc, sep=sep) - if len(df.columns) > 1: # 如果成功分割出多列,则认为找到了正确的分隔符 - break - except Exception: - continue - - if df is None: - df = pd.read_csv(data_path, encoding=enc) - - print(f"成功使用编码 {enc} 读取CSV文件") - return df.to_dict('records') - - except (UnicodeDecodeError, UnicodeError) as e: - print(f"编码 {enc} 读取CSV失败: {e}") - continue - except Exception as e: - if enc == encodings_to_try[-1]: # 最后一个编码也失败了 - raise ValueError(f"读取CSV文件失败: {e}") - continue - - raise ValueError(f"无法使用任何编码读取CSV文件: {data_path}") - - def _read_excel(self, data_path: str) -> List[Dict[str, Any]]: - """ - 读取Excel格式的对话数据。 - - Args: - data_path: Excel文件路径 - - Returns: - 解析后的数据列表 - """ - try: - df = pd.read_excel(data_path) - return df.to_dict('records') - except Exception as e: - raise ValueError(f"读取Excel文件失败: {e}") - - def _read_text(self, data_path: str) -> List[Dict[str, Any]]: - """ - 读取纯文本格式的对话数据。 - - Args: - data_path: 文本文件路径 - - Returns: - 解析后的数据列表 - """ - encoding = self._detect_encoding(data_path) - encodings_to_try = [encoding, 'utf-8', 'gbk', 'gb2312', 'latin-1'] - content = None - - # 尝试使用不同编码读取文件 - for enc in encodings_to_try: - try: - with open(data_path, 'r', encoding=enc) as f: - content = f.read() - print(f"成功使用编码 {enc} 读取文本文件") - break - except (UnicodeDecodeError, UnicodeError) as e: - print(f"编码 {enc} 读取文本失败: {e}") - continue - - if content is None: - raise ValueError(f"无法使用任何编码读取文本文件: {data_path}") - - try: - - # 尝试解析不同的文本格式 - lines = content.strip().split('\n') - - # 格式1: 每行一个对话轮次,格式为 "角色: 内容" 或 "角色:内容" - messages = [] - for line in lines: - line = line.strip() - if not line: - continue - - # 尝试匹配 "角色: 内容" 或 "角色:内容" 格式 - match = re.match(r'^([^::]+)[::]\s*(.+)$', line) - if match: - role, msg = match.groups() - messages.append({'role': role.strip(), 'msg': msg.strip()}) - else: - # 如果不匹配,则作为用户消息处理 - messages.append({'role': 'User', 'msg': line}) - - if messages: - return [{'context': {'msgs': messages}}] - else: - # 如果没有解析出消息,则将整个文本作为一条消息 - return [{'context': {'msgs': [{'role': 'User', 'msg': content}]}}] - - except Exception as e: - raise ValueError(f"读取文本文件失败: {e}") - - def read_data(self, data_path: str = None) -> List[Dict[str, Any]]: - """ - 根据文件格式自动选择合适的读取方法。 - - Args: - data_path: 数据文件路径(如果为None,则使用初始化时设置的路径) - - Returns: - 解析后的原始数据列表 - """ - if data_path is None: - data_path = self.input_file_path - - if not data_path: - raise ValueError("请指定输入文件路径") - - if not os.path.exists(data_path): - raise FileNotFoundError(f"文件不存在: {data_path}") - - file_format = self.get_file_format(data_path) - - if file_format == '.json': - return self._read_json(data_path) - elif file_format == '.csv': - return self._read_csv(data_path) - elif file_format in ['.xlsx', '.xls']: - return self._read_excel(data_path) - elif file_format in ['.txt', '.tsv']: - return self._read_text(data_path) - else: - raise ValueError(f"不支持的文件格式: {file_format}。支持的格式: {self.supported_formats}") - - def _clean_text(self, text: str) -> str: - """ - 增强的文本清洗函数。 - """ - if not text or not isinstance(text, str): - return "" - - # 1. 移除消息中的角色标识(支持英文冒号":"与中文冒号":") - text = re.sub(r'^(用户|AI|user|ai|assistant|bot|助手|机器人)[::]\s*', '', text, flags=re.IGNORECASE) - - # 2. 移除URL链接 - text = re.sub(r'https?://[^\s]+', '', text) - text = re.sub(r'www\.[^\s]+', '', text) - - # 3. 移除HTML标签 - text = re.sub(r'<[^>]+>', '', text) - - # 4. 移除乱码和控制字符 - text = re.sub(r'[�]+', '', text) - text = re.sub(r'[\x00-\x1f\x7f-\x9f]', '', text) - - # 5. 标点符号规范化 - # 将连续的感叹号(中英文)替换为一个句号 - text = re.sub(r'[!!]+', '。', text) - # 将连续的句点/省略号(中英文)替换为一个句号 - text = re.sub(r'(…{1,}|\.{2,}|。{2,})', '。', text) - # 将英文句点统一为中文句号(避免残留英文句点影响断句) - text = re.sub(r'\.', '。', text) - # 将连续的逗号(中英文)规范为一个中文逗号 - text = re.sub(r'[,,]{2,}', ',', text) - # 将英文逗号统一为中文逗号 - text = re.sub(r',', ',', text) - - # 6. 规范化空白字符 - text = re.sub(r'\s+', ' ', text) - text = text.strip() - - return text - - def _parse_message_content(self, content: str) -> List[Dict[str, str]]: - """ - 增强的消息内容解析。 - """ - messages = [] - - # 先清洗内容 - cleaned_content = self._clean_text(content) - - if not cleaned_content: - return messages - - # 检查是否为有效消息(至少包含中文或英文单词) - if not re.search(r'[\u4e00-\u9fff\w]', cleaned_content): - return messages - - # 根据内容特征判断角色(更智能的角色识别) - if re.search(r'(你好|嗨|早上好|晚上好|请问|谢谢|抱歉)', cleaned_content): - role = 'User' - elif re.search(r'(很高兴|建议|推荐|可以帮助|请提供)', cleaned_content): - role = 'Assistant' - else: - role = 'User' # 默认 - - messages.append({'role': role, 'msg': cleaned_content}) - - return messages - - def _filter_empty_messages(self, messages: List[ConversationMessage]) -> List[ConversationMessage]: - """ - 更严格的空消息过滤。 - """ - filtered = [] - for msg in messages: - # 检查消息是否有效 - if (msg.msg and - isinstance(msg.msg, str) and - len(msg.msg.strip()) >= 2 and # 至少2个字符 - re.search(r'[\u4e00-\u9fff\w]', msg.msg)): # 包含有效字符 - filtered.append(msg) - return filtered - - - def _normalize_role(self, role: str) -> str: - """ - 标准化角色名称。 - - Args: - role: 原始角色名称 - - Returns: - 标准化后的角色名称 - """ - if not role or not isinstance(role, str): - return "User" - - role = role.strip().lower() - - # 用户角色的各种表示 - user_roles = ['user', 'human', '用户', '人类', 'customer', '客户', 'u'] - # AI角色的各种表示 - ai_roles = ['assistant', 'ai', 'bot', 'chatbot', '助手', '机器人', 'system', 'a'] - - if role in user_roles: - return "User" - elif role in ai_roles: - return "Assistant" - else: - return "User" # 默认为用户 - - def clean_data(self, raw_data: List[Dict[str, Any]], skip_cleaning: bool = True) -> List[DialogData]: - """ - 清洗原始数据并转换为DialogData对象。 - - Args: - raw_data: 原始数据列表 - skip_cleaning: 是否跳过数据清洗,直接转换为DialogData对象(默认False) - - Returns: - 清洗后的DialogData对象列表 - """ - if skip_cleaning: - print("跳过数据清洗步骤,直接转换数据...") - return self._convert_to_dialog_data(raw_data) - - cleaned_dialogs = [] - - for i, item in enumerate(raw_data): - conv_date: Optional[str] = None - try: - # 提取对话消息 - messages = [] - - # 处理不同的数据结构 - if 'content' in item and isinstance(item['content'], list): - # 新格式:dialog_release_zh.json格式,content是字符串数组 - content_list = item['content'] - for j, content_text in enumerate(content_list): - # 交替分配角色:偶数索引为用户,奇数索引为AI - role = 'User' if j % 2 == 0 else 'Assistant' - normalized_role = self._normalize_role(role) - - # 清洗消息内容 - cleaned_content = self._clean_text(str(content_text)) - - # 过滤空消息 - if cleaned_content: - messages.append(ConversationMessage(role=normalized_role, msg=cleaned_content)) - - elif 'context' in item and isinstance(item['context'], dict) and 'msgs' in item['context']: - # 标准格式:context是字典且包含msgs - raw_messages = item['context']['msgs'] - elif 'context' in item and isinstance(item['context'], str): - # testdata.json格式:context是字符串,需要解析对话内容 - context_text = item['context'] - # 从context文本中解析绝对日期并存入conv_date(格式:YYYY-MM-DD) - m = re.search(r"(\d{4})年(\d{1,2})月(\d{1,2})日", context_text) - if m: - y, mo, d = int(m.group(1)), int(m.group(2)), int(m.group(3)) - conv_date = f"{y:04d}-{mo:02d}-{d:02d}" - else: - m = re.search(r"(\d{4})[-/](\d{1,2})[-/](\d{1,2})", context_text) - if m: - y, mo, d = int(m.group(1)), int(m.group(2)), int(m.group(3)) - conv_date = f"{y:04d}-{mo:02d}-{d:02d}" - messages = self._parse_context_string(context_text) - elif 'messages' in item: - # 另一种常见格式 - raw_messages = item['messages'] - elif 'conversation' in item: - # 对话格式 - raw_messages = item['conversation'] - else: - # 尝试直接解析 - raw_messages = [item] if 'role' in item and 'msg' in item else [] - - # 如果messages还是空的,说明需要处理raw_messages - if not messages and 'raw_messages' in locals(): - # 清洗每条消息 - for msg_data in raw_messages: - if isinstance(msg_data, dict): - role = self._normalize_role(msg_data.get('role', 'User')) - content = msg_data.get('msg', msg_data.get('content', msg_data.get('message', ''))) - - # 清洗消息内容 - cleaned_content = self._clean_text(str(content)) - - # 过滤空消息 - if cleaned_content: - messages.append(ConversationMessage(role=role, msg=cleaned_content)) - - # 过滤空对话 - if not messages: - continue - - # 去重相邻的重复消息 - deduplicated_messages = [] - for msg in messages: - if not deduplicated_messages or ( - deduplicated_messages[-1].role != msg.role or - deduplicated_messages[-1].msg != msg.msg - ): - deduplicated_messages.append(msg) - - # 创建DialogData对象 - context = ConversationContext(msgs=deduplicated_messages) - # 获取对话ID,优先使用dialog_id,然后是ref_id、id,最后生成默认ID - dialog_id = item.get('dialog_id', item.get('ref_id', item.get('id', f'dialog_{i}'))) - - - # 获取group_id,如果不存在则生成默认值 - group_id = item.get('group_id', f'group_default_{i}') - user_id = item.get('user_id', f'user_default_{i}') - apply_id = item.get('apply_id', f'apply_default_{i}') - - - # 构建元数据,附加解析到的会话日期 - metadata = { - **item.get('metadata', {}), - 'document_id': str(item.get('document_id', 'unknown')) if item.get('document_id') is not None else 'unknown', - 'original_format': 'dialog_release_zh' if 'content' in item and isinstance(item['content'], list) else 'testdata' - } - if conv_date: - metadata['conversation_date'] = conv_date - metadata['publication_date'] = conv_date - - dialog_data = DialogData( - context=context, - ref_id=dialog_id, - group_id=group_id, - user_id=user_id, - apply_id=apply_id, - metadata=metadata - ) - - cleaned_dialogs.append(dialog_data) - - except Exception as e: - print(f"警告: 处理第{i+1}条数据时出错: {e}") - continue - - return cleaned_dialogs - - def _convert_to_dialog_data(self, raw_data: List[Dict[str, Any]]) -> List[DialogData]: - """ - 直接将原始数据转换为DialogData对象,不进行清洗。 - - Args: - raw_data: 原始数据列表 - - Returns: - DialogData对象列表 - """ - dialog_list = [] - - for i, item in enumerate(raw_data): - try: - messages = [] - - # 处理不同的数据结构 - if 'content' in item and isinstance(item['content'], list): - content_list = item['content'] - for j, content_text in enumerate(content_list): - role = 'User' if j % 2 == 0 else 'Assistant' - if content_text: - messages.append(ConversationMessage(role=role, msg=str(content_text))) - - elif 'context' in item and isinstance(item['context'], dict) and 'msgs' in item['context']: - raw_messages = item['context']['msgs'] - for msg_data in raw_messages: - if isinstance(msg_data, dict): - role = msg_data.get('role', 'User') - content = msg_data.get('msg', msg_data.get('content', msg_data.get('message', ''))) - if content: - messages.append(ConversationMessage(role=role, msg=str(content))) - - elif 'context' in item and isinstance(item['context'], str): - # 尝试解析结构化对话,如果失败则作为单条用户消息处理 - messages = self._parse_context_string(item['context']) - if not messages: - # 如果没有解析出结构化消息,将整个context作为用户消息 - context_text = item['context'].strip() - if context_text: - messages.append(ConversationMessage(role='User', msg=context_text)) - - elif 'messages' in item: - raw_messages = item['messages'] - for msg_data in raw_messages: - if isinstance(msg_data, dict): - role = msg_data.get('role', 'User') - content = msg_data.get('msg', msg_data.get('content', msg_data.get('message', ''))) - if content: - messages.append(ConversationMessage(role=role, msg=str(content))) - - if not messages: - continue - - context = ConversationContext(msgs=messages) - dialog_id = item.get('dialog_id', item.get('ref_id', item.get('id', f'dialog_{i}'))) - group_id = item.get('group_id', f'group_default_{i}') - user_id = item.get('user_id', f'user_default_{i}') - apply_id = item.get('apply_id', f'apply_default_{i}') - - metadata = { - **item.get('metadata', {}), - 'document_id': str(item.get('document_id', 'unknown')) if item.get('document_id') is not None else 'unknown', - 'original_format': 'raw' - } - - dialog_data = DialogData( - context=context, - ref_id=dialog_id, - group_id=group_id, - user_id=user_id, - apply_id=apply_id, - metadata=metadata - ) - - dialog_list.append(dialog_data) - - except Exception as e: - print(f"警告: 转换第{i+1}条数据时出错: {e}") - continue - - return dialog_list - - def _parse_context_string(self, context_text: str) -> List[ConversationMessage]: - """ - 解析context字符串中的对话内容。 - - Args: - context_text: 包含对话的字符串 - - Returns: - 解析后的ConversationMessage列表 - """ - messages = [] - - # 使用正则表达式匹配对话模式 - # 匹配 "User: 内容" / "用户: 内容" 或 "Assistant: 内容" / "AI: 内容" 格式 - pattern = r'(User|用户|Assistant|AI|user|assistant)[::]\s*([^\n]+(?:\n(?!(?:User|用户|Assistant|AI|user|assistant)[::])[^\n]*)*?)' - matches = re.findall(pattern, context_text, re.MULTILINE | re.DOTALL | re.IGNORECASE) - - for role, content in matches: - # 标准化角色名称 - normalized_role = self._normalize_role(role) - - # 清洗消息内容 - cleaned_content = self._clean_text(content.strip()) - - # 过滤空消息 - if cleaned_content: - messages.append(ConversationMessage(role=normalized_role, msg=cleaned_content)) - - return messages - - def save_data(self, dialog_data_list: List[DialogData], output_path: str = None) -> None: - """ - 保存处理后的数据。 - - Args: - dialog_data_list: DialogData对象列表 - output_path: 输出文件路径(如果为None,则使用初始化时设置的路径) - """ - if output_path is None: - output_path = self.output_file_path - - if not output_path: - raise ValueError("请指定输出文件路径") - - # 确保输出目录存在 - os.makedirs(os.path.dirname(output_path), exist_ok=True) - - # 转换为可序列化的格式 - serializable_data = [] - for dialog in dialog_data_list: - serializable_data.append({ - 'id': dialog.id, - 'ref_id': dialog.ref_id, - 'created_at': dialog.created_at.isoformat(), - 'context': { - 'msgs': [{'role': msg.role, 'msg': msg.msg} for msg in dialog.context.msgs] - }, - 'metadata': dialog.metadata, - 'chunks': [] - }) - - # 保存为JSON格式 - with open(output_path, 'w', encoding='utf-8') as f: - json.dump(serializable_data, f, ensure_ascii=False, indent=2) - - print(f"数据已保存到: {output_path}") - - def preprocess(self, input_path: str = None, output_path: str = None, skip_cleaning: bool = True, indices: Optional[List[int]] = None) -> List[DialogData]: - """ - 完整的数据预处理流程。 - - Args: - input_path: 输入文件路径(可选) - output_path: 输出文件路径(可选) - skip_cleaning: 是否跳过数据清洗步骤(默认False) - indices: 要处理的数据索引列表(可选) - - Returns: - 处理后的DialogData对象列表 - """ - print("开始数据预处理...") - - # 读取原始数据 - print("正在读取数据...") - raw_data = self.read_data(input_path) - print(f"成功读取 {len(raw_data)} 条原始数据") - - # 根据索引筛选数据 - if indices: - selected = [raw_data[i] for i in indices if 0 <= i < len(raw_data)] - if selected: - raw_data = selected - print(f"根据索引 {indices} 筛选后,保留 {len(raw_data)} 条数据") - else: - print(f"警告: 提供的索引 {indices} 筛选为空,处理全部 {len(raw_data)} 条数据") - - # 清洗数据 - if skip_cleaning: - print("跳过数据清洗步骤...") - cleaned_data = self.clean_data(raw_data, skip_cleaning=True) - else: - print("正在清洗数据...") - cleaned_data = self.clean_data(raw_data, skip_cleaning=False) - print(f"处理完成,得到 {len(cleaned_data)} 条有效对话") - - # 保存数据(如果指定了输出路径) - if output_path or self.output_file_path: - print("正在保存数据...") - self.save_data(cleaned_data, output_path) - - print("数据预处理完成!") - return cleaned_data diff --git a/app/core/memory/storage_services/extraction_engine/data_preprocessing/data_pruning.py b/app/core/memory/storage_services/extraction_engine/data_preprocessing/data_pruning.py deleted file mode 100644 index 6544b2ce..00000000 --- a/app/core/memory/storage_services/extraction_engine/data_preprocessing/data_pruning.py +++ /dev/null @@ -1,573 +0,0 @@ -""" -语义剪枝器 - 在预处理与分块之间过滤与场景不相关内容 - -功能: -- 对话级一次性抽取判定相关性 -- 仅对"不相关对话"的消息按比例删除 -- 重要信息(时间、编号、金额、联系方式、地址等)优先保留 -""" - -import os -import hashlib -import json -import re -from datetime import datetime -from typing import List, Optional -from pydantic import BaseModel, Field - -from app.core.memory.models.message_models import DialogData, ConversationMessage, ConversationContext -from app.core.memory.models.config_models import PruningConfig -from app.core.memory.utils.config.config_utils import get_pruning_config -from app.core.memory.utils.prompt.prompt_utils import prompt_env, log_prompt_rendering, log_template_rendering - - -class DialogExtractionResponse(BaseModel): - """对话级一次性抽取的结构化返回,用于加速剪枝。 - - - is_related:对话与场景的相关性判定。 - - times / ids / amounts / contacts / addresses / keywords:重要信息片段,用来在不相关对话中保留关键消息。 - """ - is_related: bool = Field(...) - times: List[str] = Field(default_factory=list) - ids: List[str] = Field(default_factory=list) - amounts: List[str] = Field(default_factory=list) - contacts: List[str] = Field(default_factory=list) - addresses: List[str] = Field(default_factory=list) - keywords: List[str] = Field(default_factory=list) - - -class SemanticPruner: - """语义剪枝:在预处理与分块之间过滤与场景不相关内容。 - - 采用对话级一次性抽取判定相关性;仅对"不相关对话"的消息按比例删除, - 重要信息(时间、编号、金额、联系方式、地址等)优先保留。 - """ - - def __init__(self, config: Optional[PruningConfig] = None, llm_client=None): - cfg_dict = get_pruning_config() if config is None else config.model_dump() - self.config = PruningConfig.model_validate(cfg_dict) - self.llm_client = llm_client - # Load Jinja2 template - self.template = prompt_env.get_template("extracat_Pruning.jinja2") - # 对话抽取缓存:避免同一对话重复调用 LLM / 重复渲染 - self._dialog_extract_cache: dict[str, DialogExtractionResponse] = {} - # 运行日志:收集关键终端输出,便于写入 JSON - self.run_logs: List[str] = [] - # 采用顺序处理,移除并发配置以简化与稳定执行 - - def _is_important_message(self, message: ConversationMessage) -> bool: - """基于启发式规则识别重要信息消息,优先保留。 - - - 含日期/时间(如YYYY-MM-DD、HH:MM、2024年11月10日、上午/下午)。 - - 含编号/ID/订单号/申请号/账号/电话/金额等关键字段。 - - 关键词:"时间"、"日期"、"编号"、"订单"、"流水"、"金额"、"¥"、"元"、"电话"、"手机号"、"邮箱"、"地址"。 - """ - import re - text = message.msg.strip() - if not text: - return False - patterns = [ - r"\b\d{4}-\d{1,2}-\d{1,2}\b", - r"\b\d{1,2}:\d{2}\b", - r"\d{4}年\d{1,2}月\d{1,2}日", - r"上午|下午|AM|PM", - r"订单号|工单|申请号|编号|ID|账号|账户", - r"电话|手机号|微信|QQ|邮箱", - r"地址|地点", - r"金额|费用|价格|¥|¥|\d+元", - r"时间|日期|有效期|截止", - ] - for p in patterns: - if re.search(p, text, flags=re.IGNORECASE): - return True - return False - - def _importance_score(self, message: ConversationMessage) -> int: - """为重要消息打分,用于在保留比例内优先保留更关键的内容。 - - 简单启发:匹配到的类别越多、越关键分值越高。 - """ - import re - text = message.msg.strip() - score = 0 - weights = [ - (r"\b\d{4}-\d{1,2}-\d{1,2}\b", 3), - (r"\b\d{1,2}:\d{2}\b", 2), - (r"\d{4}年\d{1,2}月\d{1,2}日", 3), - (r"订单号|工单|申请号|编号|ID|账号|账户", 4), - (r"电话|手机号|微信|QQ|邮箱", 3), - (r"地址|地点", 2), - (r"金额|费用|价格|¥|¥|\d+元", 4), - (r"时间|日期|有效期|截止", 2), - ] - for p, w in weights: - if re.search(p, text, flags=re.IGNORECASE): - score += w - return score - - def _is_filler_message(self, message: ConversationMessage) -> bool: - """检测典型寒暄/口头禅/确认类短消息,用于跳过LLM分类以加速。 - - 满足以下之一视为填充消息: - - 纯标点或长度很短(<= 4 个汉字或 <= 8 个字符)且不包含数字或关键实体; - - 常见词:你好/您好/在吗/嗯/嗯嗯/哦/好的/好/行/可以/不可以/谢谢/拜拜/再见/哈哈/呵呵/哈哈哈/。。。/??。 - """ - import re - t = message.msg.strip() - if not t: - return True - # 常见填充语 - fillers = [ - "你好", "您好", "在吗", "嗯", "嗯嗯", "哦", "好的", "好", "行", "可以", "不可以", "谢谢", - "拜拜", "再见", "哈哈", "呵呵", "哈哈哈", "。。。", "??", "??" - ] - if t in fillers: - return True - # 长度与字符类型判断 - if len(t) <= 8: - # 非数字、无关键实体的短文本 - if not re.search(r"[0-9]", t) and not self._is_important_message(message): - # 主要是标点或简单确认词 - if re.fullmatch(r"[。!?,.!?…·\s]+", t) or t in fillers: - return True - return False - - async def _extract_dialog_important(self, dialog_text: str) -> DialogExtractionResponse: - """对话级一次性抽取:从整段对话中提取重要信息并判定相关性。 - - - 仅使用 LLM 结构化输出; - """ - # 缓存命中则直接返回(场景+内容作为键) - cache_key = f"{self.config.pruning_scene}:" + hashlib.sha1(dialog_text.encode("utf-8")).hexdigest() - if cache_key in self._dialog_extract_cache: - return self._dialog_extract_cache[cache_key] - - rendered = self.template.render(pruning_scene=self.config.pruning_scene, dialog_text=dialog_text) - log_template_rendering("extracat_Pruning.jinja2", {"pruning_scene": self.config.pruning_scene}) - log_prompt_rendering("pruning-extract", rendered) - - # 强制使用 LLM;移除正则回退 - if not self.llm_client: - raise RuntimeError("llm_client 未配置;请配置 LLM 以进行结构化抽取。") - - messages = [ - {"role": "system", "content": "你是一个严谨的场景抽取助手,只输出严格 JSON。"}, - {"role": "user", "content": rendered}, - ] - try: - ex = await self.llm_client.response_structured(messages, DialogExtractionResponse) - self._dialog_extract_cache[cache_key] = ex - return ex - except Exception as e: - raise RuntimeError("LLM 结构化抽取失败;请检查 LLM 配置或重试。") from e - - def _msg_matches_tokens(self, message: ConversationMessage, tokens: List[str]) -> bool: - """判断消息是否包含任意抽取到的重要片段。""" - if not tokens: - return False - t = message.msg - return any(tok and (tok in t) for tok in tokens) - - async def prune_dialog(self, dialog: DialogData) -> DialogData: - """单对话剪枝:使用一次性对话抽取,避免逐条消息 LLM 调用。 - - 流程: - - 对整段对话进行抽取与相关性判定;若相关则不剪; - - 若不相关:用抽取到的重要片段 + 简单启发识别重要消息,按比例删除不相关消息,优先删除不重要,再删除重要(但重要最多按比例)。 - - 删除策略:不重要消息按出现顺序删除(确定性、无随机)。 - """ - if not self.config.pruning_switch: - return dialog - - proportion = float(self.config.pruning_threshold) - extraction = await self._extract_dialog_important(dialog.content) - if extraction.is_related: - # 相关对话不剪枝 - return dialog - - # 在不相关对话中,识别重要/不重要消息 - tokens = extraction.times + extraction.ids + extraction.amounts + extraction.contacts + extraction.addresses + extraction.keywords - msgs = dialog.context.msgs - imp_unrel_msgs: List[ConversationMessage] = [] - unimp_unrel_msgs: List[ConversationMessage] = [] - for m in msgs: - if self._msg_matches_tokens(m, tokens) or self._is_important_message(m): - imp_unrel_msgs.append(m) - else: - unimp_unrel_msgs.append(m) - # 计算总删除目标数量 - total_unrel = len(msgs) - delete_target = int(total_unrel * proportion) - if proportion > 0 and total_unrel > 0 and delete_target == 0: - delete_target = 1 - imp_del_cap = min(int(len(imp_unrel_msgs) * proportion), len(imp_unrel_msgs)) - unimp_del_cap = len(unimp_unrel_msgs) - max_capacity = max(0, len(msgs) - 1) - max_deletable = min(imp_del_cap + unimp_del_cap, max_capacity) - delete_target = min(delete_target, max_deletable) - # 删除配额分配 - del_unimp = min(delete_target, unimp_del_cap) - rem = delete_target - del_unimp - del_imp = min(rem, imp_del_cap) - - # 选取删除集合 - unimp_delete_ids = [] - imp_delete_ids = [] - if del_unimp > 0: - # 按出现顺序选取前 del_unimp 条不重要消息进行删除(确定性、可复现) - unimp_delete_ids = [id(m) for m in unimp_unrel_msgs[:del_unimp]] - if del_imp > 0: - imp_sorted = sorted(imp_unrel_msgs, key=lambda m: self._importance_score(m)) - imp_delete_ids = [id(m) for m in imp_sorted[:del_imp]] - - # 统计实际删除数量(重要/不重要) - actual_unimp_deleted = 0 - actual_imp_deleted = 0 - kept_msgs = [] - delete_targets = set(unimp_delete_ids) | set(imp_delete_ids) - for m in msgs: - mid = id(m) - if mid in delete_targets: - if mid in set(unimp_delete_ids) and actual_unimp_deleted < del_unimp: - actual_unimp_deleted += 1 - continue - if mid in set(imp_delete_ids) and actual_imp_deleted < del_imp: - actual_imp_deleted += 1 - continue - kept_msgs.append(m) - if not kept_msgs and msgs: - kept_msgs = [msgs[0]] - - deleted_total = actual_unimp_deleted + actual_imp_deleted - self._log( - f"[剪枝-对话] 对话ID={dialog.id} 总消息={len(msgs)} 删除目标={delete_target} 实删={deleted_total} 保留={len(kept_msgs)}" - ) - - dialog.context = ConversationContext(msgs=kept_msgs) - return dialog - - async def prune_dataset(self, dialogs: List[DialogData]) -> List[DialogData]: - """数据集层面:全局消息级剪枝,保留所有对话。 - - - 仅在"不相关对话"的范围内执行消息剪枝;相关对话不动。 - - 只删除"不重要的不相关消息",重要信息(时间、编号等)强制保留。 - - 删除总量 = 阈值 * 全部不相关可删消息数,按可删容量比例分配;顺序删除。 - - 保证每段对话至少保留1条消息,不会删除整段对话。 - """ - # 如果剪枝功能关闭,直接返回原始数据集。 - if not self.config.pruning_switch: - return dialogs - - # 阈值保护:最高0.9 - proportion = float(self.config.pruning_threshold) - if proportion > 0.9: - print(f"[剪枝-数据集] 阈值{proportion}超过上限0.9,已自动调整为0.9") - proportion = 0.9 - if proportion < 0.0: - proportion = 0.0 - evaluated_dialogs = [] # list of dicts: {dialog, is_related} - - self._log( - f"[剪枝-数据集] 对话总数={len(dialogs)} 场景={self.config.pruning_scene} 删除比例={proportion} 开关={self.config.pruning_switch}" - ) - # 对话级相关性分类(一次性对整段对话文本进行判断,顺序执行并复用缓存) - evaluated_dialogs = [] - for idx, dd in enumerate(dialogs): - try: - ex = await self._extract_dialog_important(dd.content) - evaluated_dialogs.append({ - "dialog": dd, - "is_related": bool(ex.is_related), - "index": idx, - "extraction": ex - }) - except Exception: - evaluated_dialogs.append({ - "dialog": dd, - "is_related": True, - "index": idx, - "extraction": None - }) - - # 统计相关 / 不相关对话 - not_related_dialogs = [d for d in evaluated_dialogs if not d["is_related"]] - related_dialogs = [d for d in evaluated_dialogs if d["is_related"]] - self._log( - f"[剪枝-数据集] 相关对话数={len(related_dialogs)} 不相关对话数={len(not_related_dialogs)}" - ) - - # 简洁打印第几段对话相关/不相关(索引基于1) - def _fmt_indices(items, cap: int = 10): - inds = [i["index"] + 1 for i in items] - if len(inds) <= cap: - return inds - # 超过上限时只打印前cap个,并标注总数 - return inds[:cap] + ["...", f"共{len(inds)}个"] - - rel_inds = _fmt_indices(related_dialogs) - nrel_inds = _fmt_indices(not_related_dialogs) - self._log(f"[剪枝-数据集] 相关对话:第{rel_inds}段;不相关对话:第{nrel_inds}段") - - result: List[DialogData] = [] - if not_related_dialogs: - # 为每个不相关对话进行一次性抽取,识别重要/不重要(避免逐条 LLM) - per_dialog_info = {} - total_unrelated = 0 - total_capacity = 0 - for d in not_related_dialogs: - dd = d["dialog"] - extraction = d.get("extraction") - if extraction is None: - extraction = await self._extract_dialog_important(dd.content) - # 合并所有重要标记 - tokens = extraction.times + extraction.ids + extraction.amounts + extraction.contacts + extraction.addresses + extraction.keywords - msgs = dd.context.msgs - # 分类消息 - imp_unrel_msgs = [m for m in msgs if self._msg_matches_tokens(m, tokens) or self._is_important_message(m)] - unimp_unrel_msgs = [m for m in msgs if m not in imp_unrel_msgs] - # 重要消息按重要性排序 - imp_sorted_ids = [id(m) for m in sorted(imp_unrel_msgs, key=lambda m: self._importance_score(m))] - info = { - "dialog": dd, - "total_msgs": len(msgs), - "unrelated_count": len(msgs), - "imp_ids_sorted": imp_sorted_ids, - "unimp_ids": [id(m) for m in unimp_unrel_msgs], - } - per_dialog_info[d["index"]] = info - total_unrelated += info["unrelated_count"] - # 全局删除配额:比例作用于全部不相关消息(重要+不重要) - global_delete = int(total_unrelated * proportion) - if proportion > 0 and total_unrelated > 0 and global_delete == 0: - global_delete = 1 - # 每段的最大可删容量:不重要全部 + 重要最多删除 floor(len(重要)*比例),且至少保留1条消息 - capacities = [] - for d in not_related_dialogs: - idx = d["index"] - info = per_dialog_info[idx] - # 统计重要数量 - imp_count = len(info["imp_ids_sorted"]) - unimp_count = len(info["unimp_ids"]) - imp_cap = int(imp_count * proportion) - cap = min(unimp_count + imp_cap, max(0, info["total_msgs"] - 1)) - capacities.append(cap) - total_capacity = sum(capacities) - if global_delete > total_capacity: - print(f"[剪枝-数据集] 不相关消息总数={total_unrelated},目标删除={global_delete},最大可删={total_capacity}(重要消息按比例保留)。将按最大可删执行。") - global_delete = total_capacity - - # 配额分配:按不相关消息占比分配到各对话,但不超过各自容量 - alloc = [] - for i, d in enumerate(not_related_dialogs): - idx = d["index"] - info = per_dialog_info[idx] - share = int(global_delete * (info["unrelated_count"] / total_unrelated)) if total_unrelated > 0 else 0 - alloc.append(min(share, capacities[i])) - allocated = sum(alloc) - rem = global_delete - allocated - turn = 0 - while rem > 0 and turn < 100000: - progressed = False - for i in range(len(not_related_dialogs)): - if rem <= 0: - break - if alloc[i] < capacities[i]: - alloc[i] += 1 - rem -= 1 - progressed = True - if not progressed: - break - turn += 1 - - # 应用删除:相关对话不动;不相关按分配先删不重要,再删重要(低分优先) - total_deleted_confirm = 0 - for d in evaluated_dialogs: - dd = d["dialog"] - msgs = dd.context.msgs - original = len(msgs) - if d["is_related"]: - result.append(dd) - continue - idx_in_unrel = next((k for k, x in enumerate(not_related_dialogs) if x["index"] == d["index"]), None) - if idx_in_unrel is None: - result.append(dd) - continue - quota = alloc[idx_in_unrel] - info = per_dialog_info[d["index"]] - # 计算本对话重要最多可删数量 - imp_count = len(info["imp_ids_sorted"]) - imp_del_cap = int(imp_count * proportion) - # 先构造顺序删除的"不重要ID集合"(按出现顺序前 quota 条) - unimp_delete_ids = set(info["unimp_ids"][:min(quota, len(info["unimp_ids"]))]) - del_unimp = min(quota, len(unimp_delete_ids)) - rem_quota = quota - del_unimp - # 再从重要里选低分优先的删除ID(不超过 imp_del_cap) - imp_delete_ids = set(info["imp_ids_sorted"][:min(rem_quota, imp_del_cap)]) - deleted_here = 0 - actual_unimp_deleted = 0 - actual_imp_deleted = 0 - kept = [] - for m in msgs: - mid = id(m) - if mid in unimp_delete_ids and actual_unimp_deleted < del_unimp: - actual_unimp_deleted += 1 - deleted_here += 1 - continue - if mid in imp_delete_ids and actual_imp_deleted < len(imp_delete_ids): - actual_imp_deleted += 1 - deleted_here += 1 - continue - kept.append(m) - if not kept and msgs: - kept = [msgs[0]] - dd.context.msgs = kept - total_deleted_confirm += deleted_here - self._log( - f"[剪枝-对话] 对话 {d['index']+1} 总消息={original} 分配删除={quota} 实删={deleted_here} 保留={len(kept)}" - ) - result.append(dd) - self._log(f"[剪枝-数据集] 全局消息级顺序剪枝完成,总删除 {total_deleted_confirm} 条(不相关消息,重要按比例保留)。") - else: - # 全部相关:不执行剪枝 - result = [d["dialog"] for d in evaluated_dialogs] - self._log(f"[剪枝-数据集] 剩余对话数={len(result)}") - - # 将本次剪枝阶段的终端输出保存为 JSON 文件(仅在剪枝器内部完成) - try: - from app.core.config import settings - settings.ensure_memory_output_dir() - log_output_path = settings.get_memory_output_path("pruned_terminal.json") - # 去除日志前缀标签(如 [剪枝-数据集]、[剪枝-对话])后再解析为结构化字段保存 - sanitized_logs = [self._sanitize_log_line(l) for l in self.run_logs] - payload = self._parse_logs_to_structured(sanitized_logs) - with open(log_output_path, "w", encoding="utf-8") as f: - json.dump(payload, f, ensure_ascii=False, indent=2) - except Exception as e: - self._log(f"[剪枝-数据集] 保存终端输出日志失败:{e}") - - # Safety: avoid empty dataset - if not result: - print("警告: 语义剪枝后数据集为空,已回退为未剪枝数据以避免流程中断") - return dialogs - return result - - def _log(self, msg: str) -> None: - """记录日志并打印到终端。""" - try: - self.run_logs.append(msg) - except Exception: - # 任何异常都不影响打印 - pass - print(msg) - - def _sanitize_log_line(self, line: str) -> str: - """移除行首的方括号标签前缀,例如 [剪枝-数据集] 或 [剪枝-对话]。""" - try: - return re.sub(r"^\[[^\]]+\]\s*", "", line) - except Exception: - return line - - def _parse_logs_to_structured(self, logs: List[str]) -> dict: - """将已去前缀的日志列表解析为结构化 JSON,便于数据对接。""" - summary = { - "scene": self.config.pruning_scene, - "dialog_total": None, - "deletion_ratio": None, - "enabled": None, - "related_count": None, - "unrelated_count": None, - "related_indices": [], - "unrelated_indices": [], - "total_deleted_messages": None, - "remaining_dialogs": None, - } - dialogs = [] - - # 解析函数 - def parse_int(value: str) -> Optional[int]: - try: - return int(value) - except Exception: - return None - - def parse_float(value: str) -> Optional[float]: - try: - return float(value) - except Exception: - return None - - def parse_indices(s: str) -> List[int]: - s = s.strip() - if not s: - return [] - parts = [p.strip() for p in s.split(",") if p.strip()] - out: List[int] = [] - for p in parts: - try: - out.append(int(p)) - except Exception: - pass - return out - - # 正则 - re_header = re.compile(r"对话总数=(\d+)\s+场景=([^\s]+)\s+删除比例=([0-9.]+)\s+开关=(True|False)") - re_counts = re.compile(r"相关对话数=(\d+)\s+不相关对话数=(\d+)") - re_indices = re.compile(r"相关对话:第\[(.*?)\]段;不相关对话:第\[(.*?)\]段") - re_dialog = re.compile(r"对话\s+(\d+)\s+总消息=(\d+)\s+分配删除=(\d+)\s+实删=(\d+)\s+保留=(\d+)") - re_total_del = re.compile(r"总删除\s+(\d+)\s+条") - re_remaining = re.compile(r"剩余对话数=(\d+)") - - for line in logs: - # 第一行:总览 - m = re_header.search(line) - if m: - summary["dialog_total"] = parse_int(m.group(1)) - # 顶层 scene 依配置,这里不覆盖,但也可校验 m.group(2) - summary["deletion_ratio"] = parse_float(m.group(3)) - summary["enabled"] = True if m.group(4) == "True" else False - continue - - # 第二行:相关/不相关数量 - m = re_counts.search(line) - if m: - summary["related_count"] = parse_int(m.group(1)) - summary["unrelated_count"] = parse_int(m.group(2)) - continue - - # 第三行:相关/不相关索引 - m = re_indices.search(line) - if m: - summary["related_indices"] = parse_indices(m.group(1)) - summary["unrelated_indices"] = parse_indices(m.group(2)) - continue - - # 对话级统计 - m = re_dialog.search(line) - if m: - dialogs.append({ - "index": parse_int(m.group(1)), - "total_messages": parse_int(m.group(2)), - "quota_delete": parse_int(m.group(3)), - "actual_deleted": parse_int(m.group(4)), - "kept": parse_int(m.group(5)), - }) - continue - - # 全局删除总数 - m = re_total_del.search(line) - if m: - summary["total_deleted_messages"] = parse_int(m.group(1)) - continue - - # 剩余对话数 - m = re_remaining.search(line) - if m: - summary["remaining_dialogs"] = parse_int(m.group(1)) - continue - - return { - "scene": summary["scene"], - "timestamp": datetime.now().isoformat(), - "summary": {k: v for k, v in summary.items() if k != "scene"}, - "dialogs": dialogs, - } diff --git a/app/core/memory/storage_services/extraction_engine/deduplication/__init__.py b/app/core/memory/storage_services/extraction_engine/deduplication/__init__.py deleted file mode 100644 index 9257bcce..00000000 --- a/app/core/memory/storage_services/extraction_engine/deduplication/__init__.py +++ /dev/null @@ -1,41 +0,0 @@ -""" -去重消歧模块 - -提供实体去重和消歧功能,包括: -- 基础去重和消歧(精确匹配、模糊匹配) -- LLM 实体去重 -- 第二层去重(与 Neo4j 数据库联合去重) -- 两阶段去重(完整的去重流程) -""" - -from app.core.memory.storage_services.extraction_engine.deduplication.deduped_and_disamb import ( - deduplicate_entities_and_edges, - accurate_match, - fuzzy_match, - LLM_decision, - LLM_disamb_decision, -) -from app.core.memory.storage_services.extraction_engine.deduplication.entity_dedup_llm import ( - llm_dedup_entities, - llm_dedup_entities_iterative_blocks, - llm_disambiguate_pairs_iterative, -) -from app.core.memory.storage_services.extraction_engine.deduplication.second_layer_dedup import ( - second_layer_dedup_and_merge_with_neo4j, -) -from app.core.memory.storage_services.extraction_engine.deduplication.two_stage_dedup import ( - dedup_layers_and_merge_and_return, -) - -__all__ = [ - "deduplicate_entities_and_edges", - "accurate_match", - "fuzzy_match", - "LLM_decision", - "LLM_disamb_decision", - "llm_dedup_entities", - "llm_dedup_entities_iterative_blocks", - "llm_disambiguate_pairs_iterative", - "second_layer_dedup_and_merge_with_neo4j", - "dedup_layers_and_merge_and_return", -] diff --git a/app/core/memory/storage_services/extraction_engine/deduplication/deduped_and_disamb.py b/app/core/memory/storage_services/extraction_engine/deduplication/deduped_and_disamb.py deleted file mode 100644 index 8af9042f..00000000 --- a/app/core/memory/storage_services/extraction_engine/deduplication/deduped_and_disamb.py +++ /dev/null @@ -1,784 +0,0 @@ -""" -去重功能函数 -""" -from app.core.memory.models.variate_config import DedupConfig -from typing import List, Dict, Tuple -from app.core.memory.models.graph_models import( - StatementEntityEdge, - EntityEntityEdge, - ExtractedEntityNode -) -import os -from datetime import datetime -import difflib # 提供字符串相似度计算工具 -import asyncio -import importlib -import re -# 模块级属性融合工具函数(统一行为) -def _merge_attribute(canonical: ExtractedEntityNode, ent: ExtractedEntityNode): - # 强弱连接合并 - can_strength = (getattr(canonical, "connect_strength", "") or "").lower() - inc_strength = (getattr(ent, "connect_strength", "") or "").lower() - pair = {can_strength, inc_strength} - {""} - if pair: - if "both" in pair or pair == {"strong", "weak"}: - canonical.connect_strength = "both" - elif pair == {"strong"}: - canonical.connect_strength = "strong" - elif pair == {"weak"}: - canonical.connect_strength = "weak" - else: - canonical.connect_strength = next(iter(pair)) - - # 别名合并(去重保序) - try: - existing = getattr(canonical, "aliases", []) or [] - incoming = getattr(ent, "aliases", []) or [] - seen = set() - merged_list: List[str] = [] - for x in existing + incoming: - xn = (x or "").strip() - if xn and xn not in seen: - seen.add(xn) - merged_list.append(x) - canonical.aliases = merged_list - except Exception: - pass - - # 描述与事实摘要(保留更长者) - try: - desc_a = getattr(canonical, "description", "") or "" - desc_b = getattr(ent, "description", "") or "" - if len(desc_b) > len(desc_a): - canonical.description = desc_b - # 合并事实摘要:统一保留一个“实体: name”行,来源行去重保序 - fact_a = getattr(canonical, "fact_summary", "") or "" - fact_b = getattr(ent, "fact_summary", "") or "" - def _extract_sources(txt: str) -> List[str]: - sources: List[str] = [] - if not txt: - return sources - for line in str(txt).splitlines(): - ln = line.strip() - # 支持“来源:”或“来源:”前缀 - m = re.match(r"^来源[::]\s*(.+)$", ln) - if m: - content = m.group(1).strip() - if content: - sources.append(content) - # 如果不存在“来源”前缀,则将整体文本视为一个来源片段,避免信息丢失 - if not sources and txt.strip(): - sources.append(txt.strip()) - return sources - try: - src_a = _extract_sources(fact_a) - src_b = _extract_sources(fact_b) - seen = set() - merged_sources: List[str] = [] - for s in src_a + src_b: - if s and s not in seen: - seen.add(s) - merged_sources.append(s) - if merged_sources: - name_line = f"实体: {getattr(canonical, 'name', '')}".strip() - canonical.fact_summary = "\n".join([name_line] + [f"来源: {s}" for s in merged_sources]) - elif fact_b and not fact_a: - canonical.fact_summary = fact_b - except Exception: - # 兜底:若解析失败,保留较长文本 - if len(fact_b) > len(fact_a): - canonical.fact_summary = fact_b - except Exception: - pass - - # 名称向量补全 - try: - emb_a = getattr(canonical, "name_embedding", []) or [] - emb_b = getattr(ent, "name_embedding", []) or [] - if not emb_a and emb_b: - canonical.name_embedding = emb_b - except Exception: - pass - - # 时间范围合并 - try: - # 统一使用 created_at / expired_at - if getattr(ent, "created_at", None) and getattr(canonical, "created_at", None) and ent.created_at < canonical.created_at: - canonical.created_at = ent.created_at - if getattr(ent, "expired_at", None) and getattr(canonical, "expired_at", None): - if canonical.expired_at is None: - canonical.expired_at = ent.expired_at - elif ent.expired_at and ent.expired_at > canonical.expired_at: - canonical.expired_at = ent.expired_at - except Exception: - pass - -def accurate_match( - entity_nodes: List[ExtractedEntityNode] -) -> Tuple[List[ExtractedEntityNode], Dict[str, str], Dict[str, Dict]]: - """ - 精确匹配:按 (group_id, name, entity_type) 合并实体并建立重定向与合并记录。 - 返回: (deduped_entities, id_redirect, exact_merge_map) - """ - exact_merge_map: Dict[str, Dict] = {} - canonical_map: Dict[str, ExtractedEntityNode] = {} - id_redirect: Dict[str, str] = {} - - # 1) 构建规范实体映射(按名称+类型+group 精确匹配) - for ent in entity_nodes: - name_norm = (getattr(ent, "name", "") or "").strip() - type_norm = (getattr(ent, "entity_type", "") or "").strip() - key = f"{getattr(ent, 'group_id', None)}|{name_norm}|{type_norm}" - # 为避免跨业务组误并,明确以 group_id 为范围边界 - if key not in canonical_map: - canonical_map[key] = ent - id_redirect[getattr(ent, "id")] = getattr(ent, "id") - continue - canonical = canonical_map[key] - - # 执行精确属性与强弱合并,并建立重定向 - _merge_attribute(canonical, ent) - id_redirect[getattr(ent, "id")] = getattr(canonical, "id") - # 记录精确匹配的合并项(使用规范化键,避免外层变量误用) - try: - k = f"{getattr(canonical, 'group_id')}|{(getattr(canonical, 'name') or '').strip()}|{(getattr(canonical, 'entity_type') or '').strip()}" - if k not in exact_merge_map: - exact_merge_map[k] = { - "canonical_id": getattr(canonical, "id"), - "group_id": getattr(canonical, "group_id"), - "name": getattr(canonical, "name"), - "entity_type": getattr(canonical, "entity_type"), - "merged_ids": set(), - } - exact_merge_map[k]["merged_ids"].add(getattr(ent, "id")) - except Exception: - pass - - deduped_entities = list(canonical_map.values()) - return deduped_entities, id_redirect, exact_merge_map - -def fuzzy_match( - deduped_entities: List[ExtractedEntityNode], - statement_entity_edges: List[StatementEntityEdge], - id_redirect: Dict[str, str], - config: DedupConfig | None = None, -) -> Tuple[List[ExtractedEntityNode], Dict[str, str], List[str]]: - """ - 模糊匹配:在精确匹配之后,基于名称/类型相似度与上下文共现,进一步融合高相似实体。 - 返回: (updated_entities, updated_redirect, fuzzy_merge_records) - """ - fuzzy_merge_records: List[str] = [] - - def _normalize_text(s: str) -> str: - try: - return re.sub(r"\s+", " ", re.sub(r"[^\w\u4e00-\u9fff]+", " ", (s or "").lower())).strip() - except Exception: - return str(s).lower().strip() - - def _tokenize(s: str) -> List[str]: - norm = _normalize_text(s) - tokens = re.findall(r"[\u4e00-\u9fff]+|[a-z0-9]+", norm) - return tokens - - def _jaccard(a_tokens: List[str], b_tokens: List[str]) -> float: - try: - set_a, set_b = set(a_tokens), set(b_tokens) - if not set_a and not set_b: - return 0.0 - inter = len(set_a & set_b) - union = len(set_a | set_b) - return inter / union if union > 0 else 0.0 - except Exception: - return 0.0 - - def _cosine(a: List[float], b: List[float]) -> float: - try: - if not a or not b or len(a) != len(b): - return 0.0 - dot = sum(x * y for x, y in zip(a, b)) - na = sum(x * x for x in a) ** 0.5 - nb = sum(y * y for y in b) ** 0.5 - if na == 0 or nb == 0: - return 0.0 - return dot / (na * nb) - except Exception: - return 0.0 - - def _name_similarity(e1: ExtractedEntityNode, e2: ExtractedEntityNode): - emb_sim = _cosine(getattr(e1, "name_embedding", []) or [], getattr(e2, "name_embedding", []) or []) - tokens1 = set(_tokenize(getattr(e1, "name", "") or "")) - tokens2 = set(_tokenize(getattr(e2, "name", "") or "")) - aliases1 = getattr(e1, "aliases", []) or [] - aliases2 = getattr(e2, "aliases", []) or [] - alias_tokens1 = set(tokens1) - alias_tokens2 = set(tokens2) - for a in aliases1: - alias_tokens1 |= set(_tokenize(a)) - for a in aliases2: - alias_tokens2 |= set(_tokenize(a)) - j_primary = _jaccard(list(tokens1), list(tokens2)) - j_alias = _jaccard(list(alias_tokens1), list(alias_tokens2)) - s_name = 0.6 * emb_sim + 0.2 * j_primary + 0.2 * j_alias - return s_name, emb_sim, j_primary, j_alias - - def _desc_similarity(e1: ExtractedEntityNode, e2: ExtractedEntityNode): - """ - 计算实体描述的相似度(Jaccard + SequenceMatcher) - 返回: (相似度得分, Jaccard 相似度(词重合), SequenceMatcher 相似度(序列相似)) - """ - d1 = getattr(e1, "description", "") or "" - d2 = getattr(e2, "description", "") or "" - if not d1 and not d2: - return 0.0, 0.0, 0.0 - t1 = _tokenize(d1) - t2 = _tokenize(d2) - j = _jaccard(t1, t2) - try: - seq = difflib.SequenceMatcher(None, _normalize_text(d1), _normalize_text(d2)).ratio() - except Exception: - seq = 0.0 - # 平衡词重合与序列相似(更鲁棒) - s_desc = 0.5 * j + 0.5 * seq - return s_desc, j, seq - - def _canonicalize_type(t: str) -> str: # 扩展类型同义归一 - t = (t or "").strip() - if not t: - return "" - t_up = t.upper() - TYPE_ALIASES = { - "PERSON": {"人物", "人", "个人", "人名", "PERSON", "PEOPLE", "INDIVIDUAL"}, - "ORG": {"组织", "ORG"}, - "COMPANY": {"公司", "企业", "COMPANY"}, - "INSTITUTION": {"机构", "INSTITUTION"}, - "LOCATION": {"地点", "位置", "LOCATION"}, - "CITY": {"城市", "CITY"}, - "COUNTRY": {"国家", "COUNTRY"}, - "EVENT": {"事件", "EVENT"}, - # 扩展活动与技能近义,统一到 ACTIVITY,便于本地模糊匹配 - "ACTIVITY": {"活动", "技术活动", "技能", "ACTIVITY", "SKILL"}, - "PRODUCT": {"产品", "商品", "物品", "OBJECT", "PRODUCT"}, - "TOOL": {"工具", "TOOL"}, - "SOFTWARE": {"软件", "SOFTWARE"}, - "FOOD": {"食品", "食物", "FOOD"}, - "INGREDIENT": {"食材", "配料", "原料", "INGREDIENT"}, - "SWEETMEATS": {"甜点", "甜品", "甜食", "SWEETMEATS"}, - # 统一本地与 LLM 阶段:将 EQUIPMENT/装备 映射为 APPLIANCE - "APPLIANCE": {"设备", "器材", "摄影器材", "摄影设备", "电器", "烤箱", "装备","镜头", "EQUIPMENT", "APPLIANCE"}, - "ART": {"艺术", "艺术形式", "ART"}, - "FLOWER": {"花卉", "鲜花", "FLOWER"}, - "PLANT": {"植物", "PLANT"}, - "AGENT": {"AI助手", "助手", "人工智能助手", "智能助手", "智能体", "Agent", "AGENTA"}, - "ROLE": {"角色", "ROLE"}, - "SCENE_ELEMENT": {"场景元素", "SCENE_ELEMENT"}, - "UNKNOWN": {"UNKNOWN", "未知", "不明"}, - } - for canon, aliases in TYPE_ALIASES.items(): - if t_up in {a.upper() for a in aliases}: - return canon - return t_up - - def _type_similarity(t1: str, t2: str) -> float: - import difflib - c1 = _canonicalize_type(t1) - c2 = _canonicalize_type(t2) - if not c1 or not c2: - return 0.0 - if c1 == c2: - return 0.5 if c1 == "UNKNOWN" else 1.0 - if c1 == "UNKNOWN" or c2 == "UNKNOWN": - return 0.5 - sim_table = { - ("ORG", "COMPANY"): 0.9, ("COMPANY", "ORG"): 0.9, - ("ORG", "INSTITUTION"): 0.85, ("INSTITUTION", "ORG"): 0.85, - ("LOCATION", "CITY"): 0.9, ("CITY", "LOCATION"): 0.9, - ("LOCATION", "COUNTRY"): 0.9, ("COUNTRY", "LOCATION"): 0.9, - ("EVENT", "ACTIVITY"): 0.8, ("ACTIVITY", "EVENT"): 0.8, - ("PRODUCT", "TOOL"): 0.8, ("TOOL", "PRODUCT"): 0.8, - ("PRODUCT", "SOFTWARE"): 0.8, ("SOFTWARE", "PRODUCT"): 0.8, - ("FOOD", "SWEETMEATS"): 0.8, ("SWEETMEATS", "FOOD"): 0.8, - ("INGREDIENT", "FOOD"): 0.85, ("FOOD", "INGREDIENT"): 0.85, - ("APPLIANCE", "TOOL"): 0.8, ("TOOL", "APPLIANCE"): 0.8, - ("APPLIANCE", "PRODUCT"): 0.7, ("PRODUCT", "APPLIANCE"): 0.7, - ("FLOWER", "PLANT"): 0.9, ("PLANT", "FLOWER"): 0.9, - ("AGENT", "SOFTWARE"): 0.85, ("SOFTWARE", "AGENT"): 0.85, - ("AGENT", "PRODUCT"): 0.7, ("PRODUCT", "AGENT"): 0.7, - ("AGENT", "ROLE"): 0.9, ("ROLE", "AGENT"): 0.9, - ("SCENE_ELEMENT", "PRODUCT"): 0.6, ("PRODUCT", "SCENE_ELEMENT"): 0.6, - } - base = sim_table.get((c1, c2), 0.0) - if base: - return base - t1n = (t1 or "").strip().lower() - t2n = (t2 or "").strip().lower() - seq_ratio = difflib.SequenceMatcher(None, t1n, t2n).ratio() - return seq_ratio * 0.6 - # 阈值与权重设定(从配置读取;若无配置则使用 DedupConfig 的默认值) - _defaults = DedupConfig() - T_NAME_STRICT = (config.fuzzy_name_threshold_strict if config is not None else _defaults.fuzzy_name_threshold_strict) - T_TYPE_STRICT = (config.fuzzy_type_threshold_strict if config is not None else _defaults.fuzzy_type_threshold_strict) - T_OVERALL = (config.fuzzy_overall_threshold if config is not None else _defaults.fuzzy_overall_threshold) - UNKNOWN_NAME_T = (config.fuzzy_unknown_type_name_threshold if config is not None else _defaults.fuzzy_unknown_type_name_threshold) - UNKNOWN_TYPE_T = (config.fuzzy_unknown_type_type_threshold if config is not None else _defaults.fuzzy_unknown_type_type_threshold) - W_NAME = (config.name_weight if config is not None else _defaults.name_weight) - W_DESC = (config.desc_weight if config is not None else _defaults.desc_weight) - W_TYPE = (config.type_weight if config is not None else _defaults.type_weight) - CTX_BONUS = (config.context_bonus if config is not None else _defaults.context_bonus) # 上下文共现加分 - FALL_FLOOR = (config.llm_fallback_floor if config is not None else _defaults.llm_fallback_floor) - FALL_CEIL = (config.llm_fallback_ceiling if config is not None else _defaults.llm_fallback_ceiling) - - - i = 0 - while i < len(deduped_entities): - a = deduped_entities[i] - j = i + 1 - while j < len(deduped_entities): - b = deduped_entities[j] - if getattr(a, "group_id", None) != getattr(b, "group_id", None): - j += 1 - continue - # 上下文共现 - try: - sources_a = {e.source for e in statement_entity_edges if getattr(e, "target", None) == getattr(a, "id", None)} - sources_b = {e.source for e in statement_entity_edges if getattr(e, "target", None) == getattr(b, "id", None)} - co_ctx = bool(sources_a & sources_b) - except Exception: - co_ctx = False - s_name, emb_sim, j_primary, j_alias = _name_similarity(a, b) - s_desc, j_desc, seq_desc = _desc_similarity(a, b) - s_type = _type_similarity(getattr(a, "entity_type", None), getattr(b, "entity_type", None)) - unknown_present = ( - str(getattr(a, "entity_type", "")).upper() == "UNKNOWN" - or str(getattr(b, "entity_type", "")).upper() == "UNKNOWN" - ) - tn = UNKNOWN_NAME_T if unknown_present else T_NAME_STRICT - tn = min(tn, 0.88) if co_ctx else tn - type_threshold = UNKNOWN_TYPE_T if unknown_present else T_TYPE_STRICT - tover = T_OVERALL - a_cs = (getattr(a, "connect_strength", "") or "").lower() - b_cs = (getattr(b, "connect_strength", "") or "").lower() - if a_cs in ("strong", "both") or b_cs in ("strong", "both"): - tover = 0.80 - # 综合评分:名称、描述、类型加权 + 上下文加分 - overall = W_NAME * s_name + W_DESC * s_desc + W_TYPE * s_type + (CTX_BONUS if co_ctx else 0.0) - - if s_name >= tn and s_type >= type_threshold and overall >= tover: - _merge_attribute(a, b) - try: - fuzzy_merge_records.append( - f"[模糊] 规范实体 {a.id} ({a.group_id}|{a.name}|{a.entity_type}) <- 合并实体 {b.id} ({b.group_id}|{b.name}|{b.entity_type}) | s_name={s_name:.3f}, s_desc={s_desc:.3f}, s_type={s_type:.3f}, overall={overall:.3f}, ctx={co_ctx}" - ) - except Exception: - pass - # 用于处理合并实体后,Statement节点下方无挂载边的情况 后续考虑将其代码逻辑统一由关系去重消歧管理 - # 建立 ID 重定向:将合并实体 b 的 ID 指向规范实体 a 的 ID - try: - canonical_id = id_redirect.get(getattr(a, "id", None), getattr(a, "id", None)) - losing_id = getattr(b, "id", None) - if losing_id and canonical_id: - id_redirect[losing_id] = canonical_id - # 扁平化可能的重定向链:凡是映射到 b.id 的,统一指向 a.id - for k, v in list(id_redirect.items()): - if v == losing_id: - id_redirect[k] = canonical_id - except Exception: - pass - deduped_entities.pop(j) - continue - else: - try: - if s_name >= tn and s_type >= type_threshold and (FALL_FLOOR <= overall < tover) and (overall <= FALL_CEIL): - fuzzy_merge_records.append( - f"[边界] {a.id}<->{b.id} ({a.group_id}|{a.name}|{a.entity_type} ~ {b.group_id}|{b.name}|{b.entity_type}) | s_name={s_name:.3f}, s_desc={s_desc:.3f}, s_type={s_type:.3f}, overall={overall:.3f}, ctx={co_ctx}" - ) - except Exception: - pass - j += 1 - i += 1 - - return deduped_entities, id_redirect, fuzzy_merge_records - -async def LLM_decision( # 决策中包含去重和消歧的功能 - deduped_entities: List[ExtractedEntityNode], - statement_entity_edges: List[StatementEntityEdge], - entity_entity_edges: List[EntityEntityEdge], - id_redirect: Dict[str, str], - config: DedupConfig | None = None, -) -> Tuple[List[ExtractedEntityNode], Dict[str, str], List[str]]: - """ - 基于迭代分块并发的 LLM 判定,生成实体重定向并在本地应用融合。 - 返回 (updated_entities, updated_redirect, llm_records)。 - - 仅在配置 enable_llm_dedup_blockwise 为 True 时启用; - 若未提供配置,则使用 DedupConfig 的默认值作为回退。 - - 内部调用 llm_dedup_entities_iterative_blocks 获取 pairwise 的重定向映射。 - - 将映射应用到 deduped_entities 与 id_redirect,并记录融合日志。 - """ - llm_records: List[str] = [] - try: - # 优先使用运行时配置;若未提供配置,使用模型默认值,不再回退到环境变量 - enable_switch = ( - bool(config.enable_llm_dedup_blockwise) if config is not None else DedupConfig().enable_llm_dedup_blockwise - ) - if not enable_switch: - return deduped_entities, id_redirect, llm_records - # 从配置读取 LLM 迭代参数;若无配置则使用 DedupConfig 的默认值 - _defaults = DedupConfig() - block_size = (config.llm_block_size if config is not None else _defaults.llm_block_size) - block_concurrency = (config.llm_block_concurrency if config is not None else _defaults.llm_block_concurrency) - pair_concurrency = (config.llm_pair_concurrency if config is not None else _defaults.llm_pair_concurrency) - max_rounds = (config.llm_max_rounds if config is not None else _defaults.llm_max_rounds) - - # 动态导入 llm 客户端(统一从 app.core.memory.utils.llm_utils 获取) - try: - llm_utils_mod = importlib.import_module("app.core.memory.utils.llm_utils") - get_llm_client_fn = getattr(llm_utils_mod, "get_llm_client") - except Exception: - get_llm_client_fn = lambda: None - - try: - llm_mod = importlib.import_module("app.core.memory.storage_services.extraction_engine.deduplication.entity_dedup_llm") - llm_fn = getattr(llm_mod, "llm_dedup_entities_iterative_blocks") - except Exception: - raise RuntimeError("LLM 模块加载失败:deduplication.entity_dedup_llm 缺少 llm_dedup_entities_iterative_blocks") - - # 获取 LLM 客户端,若环境未配置或抛错则回退为 None - try: - llm_client = get_llm_client_fn() - except Exception: - llm_client = None - - llm_redirect, llm_records = await llm_fn( - entity_nodes=deduped_entities, - statement_entity_edges=statement_entity_edges, - entity_entity_edges=entity_entity_edges, - llm_client=llm_client, - block_size=block_size, - block_concurrency=block_concurrency, - pair_concurrency=pair_concurrency, - max_rounds=max_rounds, - ) - except Exception as e: - # 记录错误,不中断主流程 - llm_records.append(f"[LLM错误] 迭代分块执行失败: {e}") - return deduped_entities, id_redirect, llm_records - - # 若存在 LLM 的重定向,应用到实体与映射 - # 确保实体集合与 id_redirect 完整反映 LLM 的合并结果;否则后续边重定向不会指向规范 ID,实体仍然重复 - if llm_redirect: - entity_by_id: Dict[str, ExtractedEntityNode] = {e.id: e for e in deduped_entities} - for losing_id, canonical_id in list(llm_redirect.items()): - if losing_id == canonical_id: - continue - a = entity_by_id.get(canonical_id) - b = entity_by_id.get(losing_id) - if not a or not b: # 若不存在 a 或 b,可能已在精确或模糊阶段合并,在之前阶段合并之后,不会再处理但是处于审计的目的会记录 - continue - _merge_attribute(a, b) - # ID 重定向 - try: - id_redirect[b.id] = a.id - for k, v in list(id_redirect.items()): - if v == b.id: - id_redirect[k] = a.id - except Exception: - pass - # 记录 LLM 融合日志 - try: - llm_records.append( - f"[LLM融合] 规范实体 {a.id} ({a.group_id}|{a.name}|{a.entity_type}) <- 合并实体 {b.id} ({b.group_id}|{b.name}|{b.entity_type})" - ) - # 详细的“同类名称相似”记录改由 LLM 去重模块统一生成以携带 conf/reason - except Exception: - pass - # 移除 losing 实体 - try: - if b in deduped_entities: - deduped_entities.remove(b) - entity_by_id.pop(b.id, None) - except Exception: - pass - - return deduped_entities, id_redirect, llm_records - -async def LLM_disamb_decision( - deduped_entities: List[ExtractedEntityNode], - statement_entity_edges: List[StatementEntityEdge], - entity_entity_edges: List[EntityEntityEdge], - id_redirect: Dict[str, str], - config: DedupConfig | None = None, -) -> Tuple[List[ExtractedEntityNode], Dict[str, str], set[tuple[str, str]], List[str]]: - """ - 预消歧阶段:对“同名但类型不同”的实体对调用LLM进行消歧, - 产出:需阻断的实体对(blocked_pairs)与必要的合并(merge_redirect)。 - 返回 (updated_entities, updated_redirect, blocked_pairs, disamb_records)。 - - 仅在配置开关 enable_llm_disambiguation 为 True 时启用;否则返回空阻断列表。 - """ - disamb_records: List[str] = [] - blocked_pairs: set[tuple[str, str]] = set() - try: - enable_switch = ( - config.enable_llm_disambiguation - if config is not None - else DedupConfig().enable_llm_disambiguation - ) - if not bool(enable_switch): - return deduped_entities, id_redirect, blocked_pairs, disamb_records - - from app.core.memory.utils.llm.llm_utils import get_llm_client - from app.core.memory.storage_services.extraction_engine.deduplication.entity_dedup_llm import llm_disambiguate_pairs_iterative - from app.core.memory.utils.config import definitions as config_defs - llm_client = get_llm_client(config_defs.SELECTED_LLM_ID) - merge_redirect, block_list, disamb_records = await llm_disambiguate_pairs_iterative( - entity_nodes=deduped_entities, - statement_entity_edges=statement_entity_edges, - entity_entity_edges=entity_entity_edges, - llm_client=llm_client, - ) - - # 应用LLM消歧的合并建议 - if merge_redirect: - entity_by_id: Dict[str, ExtractedEntityNode] = {e.id: e for e in deduped_entities} - for losing_id, canonical_id in list(merge_redirect.items()): - if losing_id == canonical_id: - continue - a = entity_by_id.get(canonical_id) - b = entity_by_id.get(losing_id) - if not a or not b: - continue - _merge_attribute(a, b) - id_redirect[b.id] = a.id - for k, v in list(id_redirect.items()): - if v == b.id: - id_redirect[k] = a.id - try: - disamb_records.append( - f"[DISAMB合并应用] 规范实体 {a.id} ({a.group_id}|{a.name}|{a.entity_type}) <- 合并实体 {b.id} ({b.group_id}|{b.name}|{b.entity_type})" - ) - except Exception: - pass - try: - if b in deduped_entities: - deduped_entities.remove(b) - entity_by_id.pop(b.id, None) - except Exception: - pass - # 保存阻断对 - try: - blocked_pairs = {tuple(sorted(p)) for p in (block_list or [])} - except Exception: - blocked_pairs = set() - except Exception as e: - disamb_records.append(f"[DISAMB错误] 消歧执行失败: {e}") - return deduped_entities, id_redirect, blocked_pairs, disamb_records - - return deduped_entities, id_redirect, blocked_pairs, disamb_records - -async def deduplicate_entities_and_edges( - entity_nodes: List[ExtractedEntityNode], - statement_entity_edges: List[StatementEntityEdge], - entity_entity_edges: List[EntityEntityEdge], - report_stage: str = "第一层去重消歧", - report_append: bool = False, - report_stage_notes: List[str] | None = None, - dedup_config: DedupConfig | None = None, -) -> Tuple[List[ExtractedEntityNode], List[StatementEntityEdge], List[EntityEntityEdge]]: - """ - 主流程:依次执行精确匹配、模糊匹配与(可选)LLM 决策融合,随后对边做重定向与去重。之后再处理边,是关系去重和消歧 - 返回:去重后的实体、语句→实体边、实体↔实体边。 - """ - local_llm_records: List[str] = [] # 作为“审计日志”的本地收集器 初始化,保留为了之后对于LLM决策追溯 - # 1) 精确匹配 - deduped_entities, id_redirect, exact_merge_map = accurate_match(entity_nodes) - - # 1.5) LLM 决策消歧:阻断同名不同类型的高相似对,并应用必要的合并 - deduped_entities, id_redirect, blocked_pairs, disamb_records = await LLM_disamb_decision( - deduped_entities, statement_entity_edges, entity_entity_edges, id_redirect, config=dedup_config - ) - - # 2) 模糊匹配(本地规则) - deduped_entities, id_redirect, fuzzy_merge_records = fuzzy_match( - deduped_entities, statement_entity_edges, id_redirect, config=dedup_config - ) - - # 3) LLM 决策(仅按配置开关) - try: - enable_switch = ( - dedup_config.enable_llm_dedup_blockwise - if dedup_config is not None - else DedupConfig().enable_llm_dedup_blockwise - ) - should_trigger_llm = bool(enable_switch) - # 将触发信息写入阶段备注,便于输出报告审计 - if report_stage_notes is None: - report_stage_notes = [] - report_stage_notes.append(f"LLM触发: {'是' if should_trigger_llm else '否'}") - except Exception: - should_trigger_llm = False - - if should_trigger_llm: - deduped_entities, id_redirect, llm_decision_records = await LLM_decision( - deduped_entities, statement_entity_edges, entity_entity_edges, id_redirect, config=dedup_config - ) - else: - llm_decision_records = [] - # 累加 LLM 记录 把 LLM_decision 返回的日志 llm_decision_records 追加到 local_llm_records - try: - local_llm_records.extend(llm_decision_records or []) - except Exception: - pass - - -# 在主流程这里 这里是之后关系去重和消歧的地方,方法可以写在其他地方 -# 此处统一对边进行处理,使用累积的 id_redirect 把边的 source/target 改成规范ID - # 4) 边重定向与去重 - # 4.1 语句→实体边:重复时优先保留 strong - stmt_ent_map: Dict[str, StatementEntityEdge] = {} - for edge in statement_entity_edges: - new_target = id_redirect.get(edge.target, edge.target) - edge.target = new_target - key = f"{edge.source}_{edge.target}" - if key not in stmt_ent_map: - stmt_ent_map[key] = edge - else: - existing = stmt_ent_map[key] - old_strength = getattr(existing, "connect_strength", "") - new_strength = getattr(edge, "connect_strength", "") - if old_strength != "strong" and new_strength == "strong": - stmt_ent_map[key] = edge - - # 4.2 实体↔实体边:按 source_target 去重(无强弱属性) - ent_ent_map: Dict[str, EntityEntityEdge] = {} - for edge in entity_entity_edges: - new_source = id_redirect.get(edge.source, edge.source) - new_target = id_redirect.get(edge.target, edge.target) - edge.source = new_source - edge.target = new_target - key = f"{edge.source}_{edge.target}" - if key not in ent_ent_map: - ent_ent_map[key] = edge - - - _write_dedup_fusion_report( - exact_merge_map=exact_merge_map, - fuzzy_merge_records=fuzzy_merge_records, - local_llm_records=local_llm_records, - disamb_records=disamb_records, - stage_label=report_stage, - append=report_append, - stage_notes=report_stage_notes, - ) - - return deduped_entities, list(stmt_ent_map.values()), list(ent_ent_map.values()) - -# 独立模块:去重融合报告写入(与实体/边的计算解耦) -def _write_dedup_fusion_report( - exact_merge_map: Dict[str, Dict], - fuzzy_merge_records: List[str], - local_llm_records: List[str], - disamb_records: List[str] | None = None, - stage_label: str | None = None, - append: bool = False, - stage_notes: List[str] | None = None, -): - try: - # 使用全局配置的输出路径 - from app.core.config import settings - settings.ensure_memory_output_dir() - out_path = settings.get_memory_output_path("dedup_entity_output.txt") - report_lines: List[str] = [] - if not append: - report_lines.append(f"去重融合报告 - {datetime.now().isoformat()}") - report_lines.append("") - if stage_label: - # 追加写入时,在阶段标题前增加一个空行以增强分隔 - if append: - report_lines.append("") - report_lines.append(f"=== {stage_label} ===") - report_lines.append("") - # 阶段注释:在标题下追加,如候选数、是否跳过等 - if stage_notes: - for note in stage_notes: - try: - report_lines.append(str(note)) - except Exception: - pass - report_lines.append("") - # 精确 - report_lines.append("精确匹配去重:") - aggregated_exact_lines: List[str] = [] - try: - for k, info in (exact_merge_map or {}).items(): - merged_ids = sorted(list(info.get("merged_ids", set()))) - if merged_ids: - aggregated_exact_lines.append( - f"[精确] 键 {k} 规范实体 {info.get('canonical_id')} 名称 '{info.get('name')}' 类型 {info.get('entity_type')} <- 合并实体IDs {', '.join(merged_ids)}" - ) - except Exception: - pass - report_lines.extend(aggregated_exact_lines if aggregated_exact_lines else ["无合并项"]) - report_lines.append("") - # 消歧 - report_lines.append("LLM 决策消歧:") - try: - # 仅展示阻断项,过滤掉合并与合并应用 - disamb_block_only = [ - line for line in (disamb_records or []) - if str(line).startswith("[DISAMB阻断]") or str(line).startswith("[DISAMB异常阻断]") - ] - except Exception: - disamb_block_only = disamb_records or [] - report_lines.extend(disamb_block_only if disamb_block_only else ["未执行或无阻断/合并项"]) - report_lines.append("") - # 模糊 - report_lines.append("模糊匹配去重:") - report_lines.extend(fuzzy_merge_records if fuzzy_merge_records else ["未执行或无合并项"]) - report_lines.append("") - # LLM - report_lines.append("LLM 决策去重:") - try: - # 仅保留 LLM 的“去重判定”记录,排除“合并指令/融合落地” - def _is_llm_dedup_record(s: str) -> bool: - try: - text = str(s) - return "[LLM去重]" in text - except Exception: - return False - - llm_dedup_only = [ - line for line in (local_llm_records or []) - if _is_llm_dedup_record(str(line)) - ] - # 同名类型相似的 LLM 去重记录可能来源于消歧阶段,将其也纳入展示 - try: - llm_dedup_only.extend([ - line for line in (disamb_records or []) - if _is_llm_dedup_record(str(line)) - ]) - except Exception: - pass - except Exception: - llm_dedup_only = [] - # 输出前移除块前缀(如 "[LLM块0] "),并对重复记录去重(保序) - try: - import re as _re - def _strip_block_prefix(s: str) -> str: - try: - return _re.sub(r"^\[LLM块\d+\]\s*", "", str(s)) - except Exception: - return str(s) - stripped = [ _strip_block_prefix(line) for line in (llm_dedup_only or []) ] - seen = set() - deduped_ordered = [] - for line in stripped: - if line not in seen: - seen.add(line) - deduped_ordered.append(line) - llm_dedup_only = deduped_ordered - except Exception: - pass - report_lines.extend(llm_dedup_only if llm_dedup_only else ["未执行或无合并项"]) - with open(out_path, ("a" if append else "w"), encoding="utf-8") as f: - f.write("\n".join(report_lines) + "\n") - except Exception: - # 静默失败,避免影响主流程 - pass diff --git a/app/core/memory/storage_services/extraction_engine/deduplication/entity_dedup_llm.py b/app/core/memory/storage_services/extraction_engine/deduplication/entity_dedup_llm.py deleted file mode 100644 index 01799941..00000000 --- a/app/core/memory/storage_services/extraction_engine/deduplication/entity_dedup_llm.py +++ /dev/null @@ -1,689 +0,0 @@ -""" -用于实体去重,基于LLM的决策 -提供“LLM判定逻辑”的核心实现与并发控制。 -""" - -import asyncio -import difflib -from typing import List, Tuple, Dict -import anyio - -from app.core.memory.llm_tools.openai_client import OpenAIClient -from app.core.memory.models.graph_models import ExtractedEntityNode, StatementEntityEdge, EntityEntityEdge -from app.core.memory.models.dedup_models import EntityDedupDecision, EntityDisambDecision -from app.core.memory.utils.prompt.prompt_utils import render_entity_dedup_prompt - - -# --- 类型同义归并与相似度 --- -_TYPE_ALIASES_UPPER: Dict[str, set[str]] = { - # 设备/器材类近义:统一到 EQUIPMENT - "EQUIPMENT": {s.upper() for s in {"设备", "器材", "摄影器材", "装备", "工具", "APPLIANCE", "TOOL"}}, - # 活动/技能近义:统一到 ACTIVITY,放宽“技术活动/技能”的同类判断 - "ACTIVITY": {s.upper() for s in {"活动", "技术活动", "技能", "ACTIVITY", "SKILL"}}, - # 常见类别,按需扩展 - "PERSON": {s.upper() for s in {"人物", "人", "个人", "人名", "PERSON"}}, - "LOCATION": {s.upper() for s in {"地点", "位置", "LOCATION", "城市", "CITY", "国家", "COUNTRY"}}, - "SOFTWARE": {s.upper() for s in {"软件", "SOFTWARE"}}, - "EVENT": {s.upper() for s in {"事件", "EVENT"}}, -} - -def _canonicalize_type(t: str | None) -> str: - u = (str(t or "").strip().upper()) - if not u or u == "UNKNOWN": - return "UNKNOWN" - for canon, aliases in _TYPE_ALIASES_UPPER.items(): - if u in aliases: - return canon - return u # 未知类型直接返回自身(保守兼容) - -def _type_similarity(t1: str | None, t2: str | None) -> float: - c1, c2 = _canonicalize_type(t1), _canonicalize_type(t2) - if c1 == c2: - return 1.0 - if c1 == "UNKNOWN" or c2 == "UNKNOWN": - return 0.6 # 任一未知,给中等相似度,允许模型结合描述判断 - return 0.0 - -def _simple_type_ok(t1: str | None, t2: str | None) -> bool: - """类型门控: - - 允许同类(含近义归并后同类)或任一 UNKNOWN/空; - - 其余不同类不放行(例如 PERSON vs EQUIPMENT)。 - """ - c1, c2 = _canonicalize_type(t1), _canonicalize_type(t2) - if c1 == "UNKNOWN" or c2 == "UNKNOWN": - return True - return c1 == c2 - - -def _name_embed_sim(a: List[float] | None, b: List[float] | None) -> float: # 计算实体名称嵌入向量的余弦相似度 - a = a or [] - b = b or [] - if not a or not b or len(a) != len(b): - return 0.0 - try: - dot = sum(x * y for x, y in zip(a, b)) - na = (sum(x * x for x in a)) ** 0.5 - nb = (sum(y * y for y in b)) ** 0.5 - if na > 0 and nb > 0: - return dot / (na * nb) - except Exception: - pass - return 0.0 - - -def _name_text_sim(name1: str, name2: str) -> float: # 计算实体名称文本的字符串相似度 - name1 = (name1 or "").strip().lower() - name2 = (name2 or "").strip().lower() - if not name1 or not name2: - return 0.0 - return difflib.SequenceMatcher(None, name1, name2).ratio() - - -def _co_occurrence(statement_edges: List[StatementEntityEdge], a_id: str, b_id: str) -> bool: # 判断两个实体是否在同一陈述中 “同现” - try: - sources_a = {e.source for e in statement_edges if getattr(e, "target", None) == a_id} - sources_b = {e.source for e in statement_edges if getattr(e, "target", None) == b_id} - return bool(sources_a & sources_b) - except Exception: - return False - - -def _relation_statements(entity_edges: List[EntityEntityEdge], a_id: str, b_id: str) -> List[str]: # 提取两个实体间的所有关联语句 - stmts: List[str] = [] - for e in entity_edges: - if (getattr(e, "source", None) == a_id and getattr(e, "target", None) == b_id) or ( - getattr(e, "source", None) == b_id and getattr(e, "target", None) == a_id - ): - s_text = getattr(e, "statement", None) or "" - r_type = getattr(e, "relation_type", None) or "" - if s_text or r_type: - stmts.append(f"{r_type}: {s_text}".strip(': ')) - return stmts - - -def _choose_canonical(a: ExtractedEntityNode, b: ExtractedEntityNode) -> int: # 选择 “规范实体”(合并时保留的实体) - # 0 for a, 1 for b - # 1. 第一优先级:按“连接强度”排序(连接强度越高,实体越可靠) - cs_a = (getattr(a, "connect_strength", "") or "").lower() - cs_b = (getattr(b, "connect_strength", "") or "").lower() - prio = {"strong": 3, "both": 3, "weak": 1, "": 0} - if prio.get(cs_a, 0) != prio.get(cs_b, 0): - return 0 if prio.get(cs_a, 0) > prio.get(cs_b, 0) else 1 - # pick longer description/fact_summary - # 2. 第二优先级:按“描述+事实摘要”的总长度排序(内容越长,信息越完整) - desc_a = (getattr(a, "description", "") or "") - desc_b = (getattr(b, "description", "") or "") - fact_a = (getattr(a, "fact_summary", "") or "") - fact_b = (getattr(b, "fact_summary", "") or "") - score_a = len(desc_a) + len(fact_a) - score_b = len(desc_b) + len(fact_b) - if score_a != score_b: - return 0 if score_a >= score_b else 1 - return 0 - -# _judge_pair(单对实体的 LLM 判断) 已经有分块迭代的函数内容是否还需要单对LLM判断--这是已经创建的工具服务于分块迭代的函数 -async def _judge_pair( - llm_client: OpenAIClient, - a: ExtractedEntityNode, - b: ExtractedEntityNode, - statement_edges: List[StatementEntityEdge], - entity_edges: List[EntityEntityEdge], -) -> Tuple[EntityDedupDecision, Dict]: -# 1. 计算实体名称的核心相似度指标 - name_text_sim = _name_text_sim(getattr(a, "name", ""), getattr(b, "name", "")) - name_embed_sim = _name_embed_sim(getattr(a, "name_embedding", []), getattr(b, "name_embedding", [])) -# 2. 判断名称是否存在“包含关系”(如“苹果公司”包含“苹果”) - name_contains = False - try: - n1 = (getattr(a, "name", "") or "").strip().lower() - n2 = (getattr(b, "name", "") or "").strip().lower() - name_contains = bool(n1 and n2 and (n1 in n2 or n2 in n1)) - except Exception: - pass -# 3. 构建LLM判断的“上下文信息”(规则层计算的所有特征) 判断上下文特征有助于实体消歧首先判断的类型关系 - ctx = { - "same_group": getattr(a, "group_id", None) == getattr(b, "group_id", None), - "type_ok": _simple_type_ok(getattr(a, "entity_type", None), getattr(b, "entity_type", None)), - "type_similarity": _type_similarity(getattr(a, "entity_type", None), getattr(b, "entity_type", None)), - "name_text_sim": name_text_sim, - "name_embed_sim": name_embed_sim, - "name_contains": name_contains, - "co_occurrence": _co_occurrence(statement_edges, getattr(a, "id", None), getattr(b, "id", None)), - "relation_statements": _relation_statements(entity_edges, getattr(a, "id", None), getattr(b, "id", None)), - } - - entity_a = { - "name": getattr(a, "name", None), - "entity_type": getattr(a, "entity_type", None), - "description": getattr(a, "description", None), - "aliases": getattr(a, "aliases", None) or [], - "fact_summary": getattr(a, "fact_summary", None), - "connect_strength": getattr(a, "connect_strength", None), - } - entity_b = { - "name": getattr(b, "name", None), - "entity_type": getattr(b, "entity_type", None), - "description": getattr(b, "description", None), - "aliases": getattr(b, "aliases", None) or [], - "fact_summary": getattr(b, "fact_summary", None), - "connect_strength": getattr(b, "connect_strength", None), - } - # 5. 渲染LLM提示词(用工具函数填充模板,包含实体信息、上下文、输出格式) - prompt = render_entity_dedup_prompt( - entity_a=entity_a, - entity_b=entity_b, - context=ctx, - json_schema=EntityDedupDecision.model_json_schema(), - ) - - messages = [ - {"role": "system", "content": "You judge whether two entities are the same. Return valid JSON only."}, - {"role": "user", "content": prompt}, - ] - - decision = await llm_client.response_structured(messages, EntityDedupDecision) - return decision, ctx - -# 消歧场景(同名不同类型)下的LLM判断 -async def _judge_pair_disamb( - llm_client: OpenAIClient, - a: ExtractedEntityNode, - b: ExtractedEntityNode, - statement_edges: List[StatementEntityEdge], - entity_edges: List[EntityEntityEdge], -) -> Tuple[EntityDisambDecision, Dict]: - name_text_sim = _name_text_sim(getattr(a, "name", ""), getattr(b, "name", "")) - name_embed_sim = _name_embed_sim(getattr(a, "name_embedding", []), getattr(b, "name_embedding", [])) - name_contains = False - try: - n1 = (getattr(a, "name", "") or "").strip().lower() - n2 = (getattr(b, "name", "") or "").strip().lower() - name_contains = bool(n1 and n2 and (n1 in n2 or n2 in n1)) - except Exception: - pass - ctx = { - "same_group": getattr(a, "group_id", None) == getattr(b, "group_id", None), - "type_ok": _simple_type_ok(getattr(a, "entity_type", None), getattr(b, "entity_type", None)), - "name_text_sim": name_text_sim, - "name_embed_sim": name_embed_sim, - "name_contains": name_contains, - "co_occurrence": _co_occurrence(statement_edges, getattr(a, "id", None), getattr(b, "id", None)), - "relation_statements": _relation_statements(entity_edges, getattr(a, "id", None), getattr(b, "id", None)), - } - entity_a = { - "name": getattr(a, "name", None), - "entity_type": getattr(a, "entity_type", None), - "description": getattr(a, "description", None), - "aliases": getattr(a, "aliases", None) or [], - "fact_summary": getattr(a, "fact_summary", None), - "connect_strength": getattr(a, "connect_strength", None), - } - entity_b = { - "name": getattr(b, "name", None), - "entity_type": getattr(b, "entity_type", None), - "description": getattr(b, "description", None), - "aliases": getattr(b, "aliases", None) or [], - "fact_summary": getattr(b, "fact_summary", None), - "connect_strength": getattr(b, "connect_strength", None), - } - prompt = render_entity_dedup_prompt( - entity_a=entity_a, - entity_b=entity_b, - context=ctx, - json_schema=EntityDisambDecision.model_json_schema(), - disambiguation_mode=True, - ) - messages = [ - {"role": "system", "content": "You disambiguate same-name different-type entities. Return valid JSON only."}, - {"role": "user", "content": prompt}, - ] - decision = await llm_client.response_structured(messages, EntityDisambDecision) - return decision, ctx - -# llm_dedup_entities(单轮实体去重) -async def llm_dedup_entities( # 保留对偶判断作为子流程,是为了保证高精度、可审计、可复用和行为一致性 - # 对偶判断让每次决策只聚焦于一对实体,信息维度清晰,噪声更低,模型更容易给出稳定的“是否同一实体”与“规范方”选择。 - # 考虑是否将其保留 - entity_nodes: List[ExtractedEntityNode], - statement_entity_edges: List[StatementEntityEdge], - entity_entity_edges: List[EntityEntityEdge], - llm_client: OpenAIClient, - max_concurrency: int = 4, - auto_merge_threshold: float = 0.90, - co_ctx_threshold: float = 0.83, -) -> Tuple[Dict[str, str], List[str]]: - """ - Use LLM to assist fuzzy deduplication among candidate entity pairs and - produce an `id_redirect` mapping plus audit log records. - - Parameters: - - entity_nodes: deduplication input entities - - statement_entity_edges: edges from statements to entities (for co-occurrence context) - - entity_entity_edges: relational edges between entities (for relation statements) - - llm_client: configured async client used to call the model - - max_concurrency: semaphore limit for concurrent LLM calls (default 4) - - auto_merge_threshold: confidence threshold to auto-merge without co-occurrence (default 0.90) - - co_ctx_threshold: slightly lower threshold when co-occurrence is detected (default 0.83) - - Returns: - - id_redirect_updates: dict of losing_id -> canonical_id decided by LLM - - records: textual logs for decisions, errors, and non-merges - - Notes: - - Candidate generation uses simple gates: same group, type compatible, and - name similarity or containment, optionally lowered threshold with co-occurrence. - - The higher-level pipeline should call this async function upstream, then - pass the resulting mapping and records into `deduplicate_entities_and_edges` - via `llm_redirect` and `llm_records` to apply merges synchronously before - edge redirection. - """ - # 1. 构建“候选实体对”(用规则层筛选,减少LLM调用量,提高效率) - # Build candidate pairs: simple gates - candidates: List[Tuple[int, int]] = [] - for i in range(len(entity_nodes)): - a = entity_nodes[i] - for j in range(i + 1, len(entity_nodes)): - b = entity_nodes[j] - # 规则1:必须属于同一组(group_id相同,不同组的实体不重复) - if getattr(a, "group_id", None) != getattr(b, "group_id", None): - continue - # 规则2:类型必须兼容(调用_simple_type_ok判断) - if not _simple_type_ok(getattr(a, "entity_type", None), getattr(b, "entity_type", None)): - continue - # 规则3:名称相似度达标(文本/嵌入相似度取最大值) - txt_sim = _name_text_sim(getattr(a, "name", ""), getattr(b, "name", "")) - emb_sim = _name_embed_sim(getattr(a, "name_embedding", []), getattr(b, "name_embedding", [])) - # 规则4:名称是否包含(如“苹果公司”和“苹果”) - contains = False - try: - n1 = (getattr(a, "name", "") or "").strip().lower() - n2 = (getattr(b, "name", "") or "").strip().lower() - contains = bool(n1 and n2 and (n1 in n2 or n2 in n1)) - except Exception: - pass - # 规则5:是否同现(同现的实体更可能重复,降低相似度阈值) - co_ctx = _co_occurrence(statement_entity_edges, getattr(a, "id", None), getattr(b, "id", None)) - sim = max(txt_sim, emb_sim) - # 候选对筛选条件:满足任一即加入(减少漏判) - if (sim >= 0.80) or (co_ctx and sim >= 0.75) or contains: - candidates.append((i, j)) - - # Use anyio for cross-compatibility with asyncio and trio - results = [] - async with anyio.create_task_group() as tg: - result_list = [None] * len(candidates) - - async def _wrapped(idx: int, i: int, j: int): - try: - result_list[idx] = await _judge_pair(llm_client, entity_nodes[i], entity_nodes[j], statement_entity_edges, entity_entity_edges) - except Exception as e: - result_list[idx] = e - - # Limit concurrency using semaphore - sem = anyio.Semaphore(max_concurrency) - - async def _limited_wrapped(idx: int, i: int, j: int): - async with sem: - await _wrapped(idx, i, j) - - for idx, (i, j) in enumerate(candidates): - tg.start_soon(_limited_wrapped, idx, i, j) - - results = result_list - - id_redirect_updates: Dict[str, str] = {} - records: List[str] = [] - for idx, res in enumerate(results): - if isinstance(res, Exception): - i, j = candidates[idx] - a = entity_nodes[i] - b = entity_nodes[j] - records.append(f"[LLM异常] pair ({a.id},{b.id}) -> {res}") - continue - decision, ctx = res - i, j = candidates[idx] - a = entity_nodes[i] - b = entity_nodes[j] - th = auto_merge_threshold if not ctx.get("co_occurrence") else co_ctx_threshold - if decision.same_entity and decision.confidence >= th: - canon_idx = decision.canonical_idx if decision.canonical_idx in (0, 1) else _choose_canonical(a, b) - canon = a if canon_idx == 0 else b - other = b if canon_idx == 0 else a - id_redirect_updates[getattr(other, "id")] = getattr(canon, "id") - records.append( - f"[LLM合并] 规范实体 {canon.id} 名称 '{getattr(canon, 'name', '')}' <- 合并实体 {other.id} 名称 '{getattr(other, 'name', '')}' | conf={decision.confidence:.3f}, th={th:.3f}, co_ctx={ctx.get('co_occurrence')}" - ) - # 若类型相同且名称高度相似/包含关系,补充“同类名称相似”记录,格式与报告要求一致(名称后带类型) - try: - type_same = (getattr(a, "entity_type", None) == getattr(b, "entity_type", None)) and getattr(a, "entity_type", None) is not None - name_sim = max(float(ctx.get("name_text_sim", 0.0)), float(ctx.get("name_embed_sim", 0.0))) - name_contains = bool(ctx.get("name_contains", False)) - if type_same and (name_sim >= 0.80 or name_contains): - name_a = (getattr(a, "name", "") or "").strip() - name_b = (getattr(b, "name", "") or "").strip() - type_a = getattr(a, "entity_type", "") - type_b = getattr(b, "entity_type", "") - records.append( - f"[LLM去重] 同类名称相似 {name_a}({type_a})|{name_b}({type_b}) | conf={decision.confidence:.2f} | reason={decision.reason}" - ) - except Exception: - pass - else: - records.append( - f"[LLM不合并] A={a.id} B={b.id} | same={decision.same_entity} conf={decision.confidence:.3f} co_ctx={ctx.get('co_occurrence')}" - ) - - return id_redirect_updates, records - -# 迭代分块去重,这才是重点 -async def llm_dedup_entities_iterative_blocks( # 迭代分块并发 LLM 去重 - entity_nodes: List[ExtractedEntityNode], # 待去重实体列表(需先经过精确去重),LLM决策属于模糊匹配下 - statement_entity_edges: List[StatementEntityEdge], - entity_entity_edges: List[EntityEntityEdge], - llm_client: OpenAIClient, - block_size: int = 50, - block_concurrency: int = 4, - pair_concurrency: int = 4, - max_rounds: int = 3, - auto_merge_threshold: float = 0.90, - co_ctx_threshold: float = 0.83, - shuffle_each_round: bool = True, # 每轮是否打乱实体顺序(避免同一块内实体重复,提高覆盖度) -) -> Tuple[Dict[str, str], List[str]]: # 返回:全局ID映射、全局审计日志 - """ - Iteratively deduplicate entities using LLM in block-wise concurrent rounds. - - Process: - - Partition the input entities (post exact + local fuzzy stage) into blocks per round. - - Run LLM pairwise decisions concurrently *within each block*, and also run multiple blocks concurrently. - - Apply merges from all blocks, collapse to canonical set, re-partition, and repeat until no new merges or max_rounds reached. - - Parameters: - - entity_nodes: entities to deduplicate (should already be exact/fuzzy merged candidates) - - statement_entity_edges: statement→entity edges for co-occurrence context - - entity_entity_edges: entity↔entity relational edges for relation statements context - - llm_client: initialized async client - - block_size: target number of entities per block (default 50) - - block_concurrency: how many blocks to process concurrently (default 4) - - pair_concurrency: concurrency for pairwise LLM calls inside each block (default 4) - - max_rounds: upper bound for iterative passes (default 3) - - auto_merge_threshold: decision confidence for auto-merge when no co-occurrence (default 0.90) - - co_ctx_threshold: lower threshold when co-occurrence is detected (default 0.83) - - shuffle_each_round: whether to shuffle entities within group_id each round to vary block composition - - Returns: - - global_redirect: dict losing_id -> canonical_id accumulated across rounds - - records: textual logs including per-round/per-block summaries and per-pair decisions - """ - import asyncio - import random - # 初始化全局日志和全局ID映射(存储所有轮次的结果) - records: List[str] = [] - global_redirect: Dict[str, str] = {} - - # Helper: resolve final canonical id following redirect chain - # 辅助函数1:_resolve:递归解析实体的“最终规范ID”(处理ID映射链,如a→b→c,返回c) - def _resolve(id_: str) -> str: - while id_ in global_redirect and global_redirect[id_] != id_: # 若ID在映射中且未指向自身 - id_ = global_redirect[id_] # 递归替换为映射的ID - return id_ # 返回最终规范ID - ## 这里辅助函数没有看懂 - - # Helper: collapse nodes to canonical representatives per current global_redirect - # 辅助函数2:_collapse_nodes:根据全局ID映射,“折叠”实体列表(保留每个规范ID对应的实体) - def _collapse_nodes(nodes: List[ExtractedEntityNode]) -> List[ExtractedEntityNode]: - by_id: Dict[str, ExtractedEntityNode] = {e.id: e for e in nodes} # 实体ID→实体的映射 - keep: Dict[str, ExtractedEntityNode] = {} # 存储需保留的规范实体 - for e in nodes: - cid = _resolve(e.id) # 解析e的最终规范ID - # 优先保留by_id中已存在的规范实体(若有),否则保留第一个遇到的实体 - if cid in by_id: - keep[cid] = by_id[cid] - else: - keep[cid] = keep.get(cid, e) - return list(keep.values()) - - def _partition_blocks(nodes: List[ExtractedEntityNode]) -> List[List[ExtractedEntityNode]]: - """ - 按 group_id 分块,避免跨组实体在同一块,减少无效候选对 - - Args: - nodes: 实体节点列表 - - Returns: - 分块后的实体列表 - """ - groups: Dict[str, List[ExtractedEntityNode]] = {} - for e in nodes: - gid = getattr(e, "group_id", None) - groups.setdefault(str(gid), []).append(e) - blocks: List[List[ExtractedEntityNode]] = [] - for gid, arr in groups.items(): - if shuffle_each_round: - random.shuffle(arr) - # chunk into block_size - for i in range(0, len(arr), max(1, block_size)): - blocks.append(arr[i:i + max(1, block_size)]) - return blocks - - # Semaphore for block-level concurrency - # 初始化块级并发信号量(控制同时处理的块数量) - block_sem = asyncio.Semaphore(max(1, block_concurrency)) - - # 辅助函数4:_run_one_block:处理单个块的去重(调用llm_dedup_entities) - async def _run_one_block(block_idx: int, block_nodes: List[ExtractedEntityNode]): - async with block_sem: - # Delegate to existing pairwise function with limited concurrency per block - id_map, recs = await llm_dedup_entities( - entity_nodes=block_nodes, - statement_entity_edges=statement_entity_edges, - entity_entity_edges=entity_entity_edges, - llm_client=llm_client, - max_concurrency=pair_concurrency, - auto_merge_threshold=auto_merge_threshold, - co_ctx_threshold=co_ctx_threshold, - ) - # Prefix block index in records for readability - prefixed = [f"[LLM块{block_idx}] {line}" for line in recs] - return id_map, prefixed - - # Iterative rounds - # 核心:迭代分块去重(多轮处理) - current_nodes: List[ExtractedEntityNode] = list(entity_nodes) - round_idx = 1 - while round_idx <= max(1, max_rounds): - # Collapse nodes to canonical reps before each round to avoid redundant comparisons - # 步骤1:折叠实体(合并已确定的重复实体,减少后续计算量) - current_nodes = _collapse_nodes(current_nodes) - # 步骤2:分块(按group_id分块,避免跨组处理) - blocks = _partition_blocks(current_nodes) - if not blocks: # 无块可处理(实体已全部折叠),退出循环 - break - # 步骤3:记录当前轮次的基本信息(轮次、块数、块大小) - records.append(f"[LLM批次] 轮次 {round_idx} 预计处理块数 {len(blocks)} 每块大小≈{block_size}") - - # Run all blocks concurrently with block-level semaphore - # 步骤4:并发处理所有块(创建块处理任务,批量执行) - results = [None] * len(blocks) - async with anyio.create_task_group() as tg: - async def _run_block_wrapper(idx: int, block: List[ExtractedEntityNode]): - try: - results[idx] = await _run_one_block(idx, block) - except Exception as e: - results[idx] = e - - for i in range(len(blocks)): - tg.start_soon(_run_block_wrapper, i, blocks[i]) - - # Collect and normalize redirects from blocks - # 步骤5:合并块结果到全局映射和日志 - merged_this_round = 0 - for bi, res in enumerate(results): - if isinstance(res, Exception): - records.append(f"[LLM块异常] 轮次 {round_idx} 块 {bi} -> {res}") - continue - id_map, recs = res - records.extend(recs) - # Normalize with current global redirects - for losing, canon in id_map.items(): - losing_final = _resolve(losing) - canon_final = _resolve(canon) - if losing_final == canon_final: - continue - # Apply mapping and ensure chain consistency - global_redirect[losing_final] = canon_final - merged_this_round += 1 - records.append(f"[LLM批次] 轮次 {round_idx} 块数 {len(blocks)} 新合并 {merged_this_round}") - - if merged_this_round == 0: - break - - # Prepare nodes for next round: collapse canonical set - current_nodes = _collapse_nodes(current_nodes) - round_idx += 1 - - return global_redirect, records - - -# LLM 消歧:同名不同类型的实体对,输出合并建议与阻断对列表 -async def llm_disambiguate_pairs_iterative( - entity_nodes: List[ExtractedEntityNode], - statement_entity_edges: List[StatementEntityEdge], - entity_entity_edges: List[EntityEntityEdge], - llm_client: OpenAIClient, - max_concurrency: int = 4, - merge_conf_threshold: float = 0.88, - block_conf_threshold: float = 0.60, -) -> Tuple[Dict[str, str], List[Tuple[str, str]], List[str]]: - """ - Disambiguate same-name different-type pairs using LLM. - - Returns: - - merge_redirect: dict losing_id -> canonical_id for merges decided by LLM - - block_pairs: list of sorted (id1, id2) pairs to block from fuzzy/heuristic merges - - records: textual logs for audit - """ - records: List[str] = [] - merge_redirect: Dict[str, str] = {} - block_pairs: List[Tuple[str, str]] = [] - - def _is_typed(t: str) -> bool: - t = (t or "").strip().upper() - return bool(t) and t not in {"UNKNOWN", "UNDEFINED", ""} - - candidates: List[Tuple[int, int]] = [] - n = len(entity_nodes) - for i in range(n): - for j in range(i + 1, n): - a = entity_nodes[i] - b = entity_nodes[j] - # 必须同组 - if getattr(a, "group_id", None) != getattr(b, "group_id", None): - continue - ta = getattr(a, "entity_type", None) - tb = getattr(b, "entity_type", None) - # 必须不同类型且两者均为已定义类型 - if ta == tb: - continue - if not (_is_typed(ta) and _is_typed(tb)): - continue - # 严格“同名不同义”:名称需严格相同(大小写与首尾空格忽略) - try: - na = (getattr(a, "name", "") or "").strip().lower() - nb = (getattr(b, "name", "") or "").strip().lower() - except Exception: - na, nb = "", "" - if not na or not nb: - continue - if na == nb: - candidates.append((i, j)) - - if not candidates: - return merge_redirect, block_pairs, records - - # Use anyio for cross-compatibility with asyncio and trio - judged = [None] * len(candidates) - async with anyio.create_task_group() as tg: - async def _wrapped(idx: int, i: int, j: int): - try: - judged[idx] = await _judge_pair_disamb(llm_client, entity_nodes[i], entity_nodes[j], statement_entity_edges, entity_entity_edges) - except Exception as e: - judged[idx] = e - - # Limit concurrency using semaphore - sem = anyio.Semaphore(max_concurrency) - - async def _limited_wrapped(idx: int, i: int, j: int): - async with sem: - await _wrapped(idx, i, j) - - for idx, (i, j) in enumerate(candidates): - tg.start_soon(_limited_wrapped, idx, i, j) - for k, res in enumerate(judged): - i, j = candidates[k] - a = entity_nodes[i] - b = entity_nodes[j] - a_id = getattr(a, "id", None) or "" - b_id = getattr(b, "id", None) or "" - if isinstance(res, Exception): - records.append(f"[DISAMB错误] 对({a_id},{b_id})调用失败: {res}") - block_pairs.append(tuple(sorted((a_id, b_id)))) - continue - decision, ctx = res - try: - if decision.should_merge and decision.confidence >= merge_conf_threshold: - can_idx = 0 if decision.canonical_idx == 0 else 1 - canonical = a if can_idx == 0 else b - losing = b if can_idx == 0 else a - merge_redirect[getattr(losing, "id", "")] = getattr(canonical, "id", "") - records.append( - f"[DISAMB合并] {getattr(losing,'id','')} -> {getattr(canonical,'id','')} | conf={decision.confidence:.2f} | reason={decision.reason} | suggested_type={decision.suggested_type or ''}" - ) - # 追加 LLM 决策去重记录,以便下方报告展示到“LLM 决策去重”区块 - records.append( - f"[LLM去重] 同名类型相似 {getattr(a,'name','')}({getattr(a,'entity_type','')})|{getattr(b,'name','')}({getattr(b,'entity_type','')}) | conf={decision.confidence:.2f} | reason={decision.reason}" - ) - else: - # Fallback:同名且类型不同,但语义高度相似且未要求阻断,按“同名类型相似”进行合并 - name_a = (getattr(a, "name", "") or "").strip().lower() - name_b = (getattr(b, "name", "") or "").strip().lower() - def _strength_rank(x: str) -> int: - s = (x or "").strip().lower() - return {"strong": 3, "both": 2, "weak": 1}.get(s, 0) - if ( - name_a and name_b and name_a == name_b - and (not decision.block_pair) - and decision.confidence >= max(0.80, block_conf_threshold) - ): - # 选择规范实体:优先使用 canonical_idx;否则根据连接强度挑选更强者 - if decision.canonical_idx in (0, 1): - canonical = a if decision.canonical_idx == 0 else b - losing = b if decision.canonical_idx == 0 else a - else: - sa = _strength_rank(getattr(a, "connect_strength", None)) - sb = _strength_rank(getattr(b, "connect_strength", None)) - canonical = a if sa >= sb else b - losing = b if sa >= sb else a - merge_redirect[getattr(losing, "id", "")] = getattr(canonical, "id", "") - # 消歧合并审计 - records.append( - f"[DISAMB合并] {getattr(losing,'id','')} -> {getattr(canonical,'id','')} | conf={decision.confidence:.2f} | reason={decision.reason} | suggested_type={decision.suggested_type or ''}" - ) - # 追加 LLM 决策去重记录(同名类型相似) - records.append( - f"[LLM去重] 同名类型相似 {getattr(a,'name','')}({getattr(a,'entity_type','')})|{getattr(b,'name','')}({getattr(b,'entity_type','')}) | conf={decision.confidence:.2f} | reason={decision.reason}" - ) - else: - if decision.block_pair or decision.confidence >= block_conf_threshold: - block_pairs.append(tuple(sorted((a_id, b_id)))) - # 仅保留阻断条目在预筛选报告,包含实体名称与类型,便于人读 - records.append( - f"[DISAMB阻断] {getattr(a,'name','')}({getattr(a,'entity_type','')})|{getattr(b,'name','')}({getattr(b,'entity_type','')}) | conf={decision.confidence:.2f} | reason={decision.reason} || block_pair={decision.block_pair}" - ) - except Exception: - block_pairs.append(tuple(sorted((a_id, b_id)))) - # 异常情况也以阻断形式记录,包含名称便于定位 - records.append( - f"[DISAMB异常阻断] {getattr(a,'name','')}({getattr(a,'entity_type','')})|{getattr(b,'name','')}({getattr(b,'entity_type','')}) | ids=({a_id},{b_id})" - ) - - return merge_redirect, block_pairs, records diff --git a/app/core/memory/storage_services/extraction_engine/deduplication/second_layer_dedup.py b/app/core/memory/storage_services/extraction_engine/deduplication/second_layer_dedup.py deleted file mode 100644 index 04aa6cb6..00000000 --- a/app/core/memory/storage_services/extraction_engine/deduplication/second_layer_dedup.py +++ /dev/null @@ -1,149 +0,0 @@ -# 导入 Python 的annotations特性,允许在类型注解中使用尚未定义的类(支持 “向前引用”),提升代码中类型注解的灵活性。 -# 这是什么意思? 该类的属性的类型是这个类本身(递归定义)? -""" -这段代码是 “第二层去重消歧” 的核心实现,逻辑可分为四步: -1.从第一层去重消歧后的实体中提取核心信息,作为索引查询 Neo4j 中同组的候选实体; -2.对候选实体去重并转换为统一模型; -3.构建预重定向关系(第一层实体 ID→数据库实体 ID),确保优先使用数据库 ID; -4.合并数据库候选实体与第一层实体,调用去重函数完成最终融合,返回结果。 -""" - -from __future__ import annotations - -from typing import List, Dict, Any, Tuple -from datetime import datetime - -# 使用新的仓储层 -from app.repositories.neo4j.neo4j_connector import Neo4jConnector # 导入 Neo4j 数据库连接器类,用于与 Neo4j 数据库进行交互 -from app.repositories.neo4j.graph_search import get_dedup_candidates_for_entities # 导入ge函数,用于从 Neo4j 中检索与输入实体可能重复的候选实体(去重的核心检索逻辑)。 -from app.core.memory.models.graph_models import ExtractedEntityNode, StatementEntityEdge, EntityEntityEdge -from app.core.memory.storage_services.extraction_engine.deduplication.deduped_and_disamb import deduplicate_entities_and_edges, _write_dedup_fusion_report # 导入报告写入以在跳过时追加说明 -from app.core.memory.models.variate_config import DedupConfig - - -def _parse_dt(val: Any) -> datetime: # 定义内部辅助函数_parse_dt,用于将任意类型的输入值解析为datetime对象(处理实体节点中的时间字段) - if isinstance(val, datetime): - return val - if isinstance(val, str) and val: - try: - return datetime.fromisoformat(val) # 使用fromisoformat方法将 ISO 格式的字符串(如 "2023-10-01T12:00:00")解析为datetime对象 - except Exception: - pass - # Fallback: now; upstream should provide real times - return datetime.now() - - -def _row_to_entity(row: Dict[str, Any]) -> ExtractedEntityNode: - """ - 将 Neo4j 返回的数据库记录转换为 ExtractedEntityNode 模型对象 - - Args: - row: Neo4j 查询返回的记录字典 - - Returns: - ExtractedEntityNode: 实体节点对象 - - Note: - 从数据库中查询到的内容是 JSON 格式的字符串,需要先解析为 Python 对象 - """ - return ExtractedEntityNode( - id=row.get("id"), - name=row.get("name") or "", - group_id=row.get("group_id") or "", - user_id=row.get("user_id") or "", - apply_id=row.get("apply_id") or "", - created_at=_parse_dt(row.get("created_at")), - expired_at=_parse_dt(row.get("expired_at")) if row.get("expired_at") else None, - entity_idx=int(row.get("entity_idx") or 0), - statement_id=row.get("statement_id") or "", - entity_type=row.get("entity_type") or "", - description=row.get("description") or "", - aliases=row.get("aliases") or [], - name_embedding=row.get("name_embedding") or [], - fact_summary=row.get("fact_summary") or "", - connect_strength=row.get("connect_strength") or "", - ) - - -async def second_layer_dedup_and_merge_with_neo4j( # 二层去重的核心逻辑,与 Neo4j 中同组实体联合去重 - connector: Neo4jConnector, - group_id: str, # 用于定位neo4j中同一组的实体,确保只在同组内去重 - entity_nodes: List[ExtractedEntityNode], # 输入的实体节点列表,包含待去重的实体 - statement_entity_edges: List[StatementEntityEdge], # 输入的语句实体边列表,用于处理实体之间的关系 - entity_entity_edges: List[EntityEntityEdge], # 输入的实体实体边列表,用于处理实体之间的关系 - dedup_config: DedupConfig | None = None, -) -> Tuple[List[ExtractedEntityNode], List[StatementEntityEdge], List[EntityEntityEdge]]: - """ - 第二层去重消歧: - - 以第一层结果为索引,检索相同 group_id 下的 DB 候选实体 - - 将 DB 候选与当前实体集合联合,按既有精确/模糊/LLM 决策进行融合 - - 返回融合后的实体与重定向后的边(边已指向规范 ID,优先 DB ID) - """ - if not entity_nodes: - return entity_nodes, statement_entity_edges, entity_entity_edges - - # 构造批量行并检索候选(精确/别名 + CONTAINS 召回) - # 将第一层去重消歧的结果作为索引,批量查询DB候选实体 - incoming_rows: List[Dict[str, Any]] = [ # 定义 包含第一层实体的核心信息(用于数据库查询) - {"id": e.id, "name": e.name, "entity_type": e.entity_type} for e in entity_nodes # 对entity_nodes中的每个实体e,提取id(实体 ID)、name(名称)、entity_type(类型),构造字典作为查询条件。 - - ] - candidates_map = await get_dedup_candidates_for_entities( # 从 Neo4j 中查询候选实体,并将结果赋值给candidates_map(等待异步操作完成)。 - connector=connector, group_id=group_id, - entities=incoming_rows, # 传入参数:第一层实体的核心信息(作为查询索引) - use_contains_fallback=True # 传入参数:启用 “包含关系” 作为匹配失败的降级策略(若精确匹配无结果,用包含关系召回候选),与src\database\cypher_queries.py的307产生联动 - ) - - # 拉平候选,转为模型(按 DB 节点优先) - db_candidate_rows: List[Dict[str, Any]] = [] # 存储去重后的数据库候选实体记录(行) - seen_db_ids: set[str] = set() # 集合,用于记录已处理的数据库实体 ID(避免重复添加同一实体) - for _, rows in candidates_map.items(): - for r in rows: - rid = r.get("id") - if rid and rid not in seen_db_ids: # 如果rid存在且未被处理 - seen_db_ids.add(rid) # 将rid加入seen_db_ids,标记为已处理 - db_candidate_rows.append(r) # 将该记录r添加到db_candidate_rows(确保数据库实体唯一) - - db_candidate_models: List[ExtractedEntityNode] = [] - for r in db_candidate_rows: # db_candidate_rows:去重后的数据库候选实体记录(行) - try: - m = _row_to_entity(r) # 调用_row_to_entity函数,将数据库记录r转换为实体模型m - db_candidate_models.append(m) # m添加到db_candidate_models - except Exception: - # 忽略无法解析的记录 - pass - - # 若 DB 候选为空:跳过二层融合,直接返回第一层结果,并在报告中标注候选数 - candidate_count = len(db_candidate_models) - if candidate_count == 0: - try: - _write_dedup_fusion_report( - exact_merge_map={}, - fuzzy_merge_records=[], - local_llm_records=[], - disamb_records=[], - stage_label="第二层去重消歧", - append=True, - stage_notes=[f"候选数:{candidate_count}(DB 为空则标注跳过)"], - ) - except Exception: - # 报告写入失败不影响主流程 - pass - return entity_nodes, statement_entity_edges, entity_entity_edges - - # 联合集合(DB 在前,确保规范 ID 优先使用 DB ID) - # 将从 DB 检索到的候选实体与第一层去重消歧的实体合并,作为输入继续调用去重方法。 - # 由于按顺序遍历,规范实体将优先选择位于前面的 DB 节点,因此无需显式预重定向。 - union_entities: List[ExtractedEntityNode] = db_candidate_models + list(entity_nodes) - - # 融合(内部执行精确/模糊/LLM 决策;随后再做边重定向与去重) - fused_entities, fused_stmt_entity_edges, fused_entity_entity_edges = await deduplicate_entities_and_edges( - union_entities, - statement_entity_edges, - entity_entity_edges, - report_stage="第二层去重消歧", - report_append=True, - dedup_config=dedup_config, - ) - - return fused_entities, fused_stmt_entity_edges, fused_entity_entity_edges diff --git a/app/core/memory/storage_services/extraction_engine/deduplication/two_stage_dedup.py b/app/core/memory/storage_services/extraction_engine/deduplication/two_stage_dedup.py deleted file mode 100644 index a5f600b4..00000000 --- a/app/core/memory/storage_services/extraction_engine/deduplication/two_stage_dedup.py +++ /dev/null @@ -1,106 +0,0 @@ -from __future__ import annotations - -from typing import List, Tuple, Optional - -from app.core.memory.models.variate_config import ExtractionPipelineConfig -from app.core.memory.utils.config.config_utils import get_pipeline_config -from app.core.memory.storage_services.extraction_engine.deduplication.deduped_and_disamb import deduplicate_entities_and_edges -from app.core.memory.storage_services.extraction_engine.deduplication.second_layer_dedup import second_layer_dedup_and_merge_with_neo4j -# 使用新的仓储层 -from app.repositories.neo4j.neo4j_connector import Neo4jConnector -from app.core.memory.models.graph_models import ( - DialogueNode, - ChunkNode, - StatementNode, - ExtractedEntityNode, - StatementChunkEdge, - StatementEntityEdge, - EntityEntityEdge, -) -from app.core.memory.models.message_models import DialogData - - -async def dedup_layers_and_merge_and_return( - dialogue_nodes: List[DialogueNode], - chunk_nodes: List[ChunkNode], - statement_nodes: List[StatementNode], - entity_nodes: List[ExtractedEntityNode], - statement_chunk_edges: List[StatementChunkEdge], - statement_entity_edges: List[StatementEntityEdge], - entity_entity_edges: List[EntityEntityEdge], - dialog_data_list: List[DialogData], - pipeline_config: Optional[ExtractionPipelineConfig] = None, - connector: Optional[Neo4jConnector] = None, -) -> Tuple[ - List[DialogueNode], - List[ChunkNode], - List[StatementNode], - List[ExtractedEntityNode], - List[StatementChunkEdge], - List[StatementEntityEdge], - List[EntityEntityEdge], -]: - """ - 执行两层实体去重与融合: - - 第一层:精确/模糊/LLM 决策去重 - - 第二层:与 Neo4j 同组实体联合去重与融合(依赖传入的 connector) - 返回融合后的实体与边,同时保留原始的对话、片段与语句节点与边。 - """ - - # 默认从 runtime.json 加载管线配置,避免回退到环境变量 - if pipeline_config is None: - try: - pipeline_config = get_pipeline_config() - except Exception: - pipeline_config = None - - # 先探测 group_id,决定报告写入策略 - group_id: Optional[str] = None - for dd in dialog_data_list: - group_id = getattr(dd, "group_id", None) - if group_id: - break - - # 第一层去重消歧 - dedup_entity_nodes, dedup_statement_entity_edges, dedup_entity_entity_edges = await deduplicate_entities_and_edges( - entity_nodes, - statement_entity_edges, - entity_entity_edges, - report_stage="第一层去重消歧", - report_append=False, - dedup_config=(pipeline_config.deduplication if pipeline_config else None), - ) - - # 初始化第二层融合结果为第一层结果 - fused_entity_nodes = dedup_entity_nodes - fused_statement_entity_edges = dedup_statement_entity_edges - fused_entity_entity_edges = dedup_entity_entity_edges - - # 第二层去重消歧:与 Neo4j 中同组实体联合融合 - try: - if group_id: - if connector: - fused_entity_nodes, fused_statement_entity_edges, fused_entity_entity_edges = await second_layer_dedup_and_merge_with_neo4j( - connector=connector, - group_id=group_id, - entity_nodes=dedup_entity_nodes, - statement_entity_edges=dedup_statement_entity_edges, - entity_entity_edges=dedup_entity_entity_edges, - dedup_config=(pipeline_config.deduplication if pipeline_config else None), - ) - else: - print("Skip second-layer dedup: missing connector") - else: - print("Skip second-layer dedup: missing group_id") - except Exception as e: - print(f"Second-layer dedup failed: {e}") - - return ( - dialogue_nodes, - chunk_nodes, - statement_nodes, - fused_entity_nodes, - statement_chunk_edges, - fused_statement_entity_edges, - fused_entity_entity_edges, - ) diff --git a/app/core/memory/storage_services/extraction_engine/extraction_orchestrator.py b/app/core/memory/storage_services/extraction_engine/extraction_orchestrator.py deleted file mode 100644 index 3cf74b41..00000000 --- a/app/core/memory/storage_services/extraction_engine/extraction_orchestrator.py +++ /dev/null @@ -1,1306 +0,0 @@ -""" -萃取引擎 - 流水线编排器 - -该模块提供了一个统一的流水线编排器,用于协调整个知识提取流程。 -它整合了数据预处理、知识提取、去重消歧等模块,提供统一的执行接口。 - -主要功能: -1. 协调数据预处理、分块、陈述句提取、三元组提取、时间信息提取等步骤 -2. 管理嵌入向量生成 -3. 执行两阶段去重和消歧 -4. 将提取结果转换为图数据库节点和边 -5. 提供错误处理和日志记录 -6. 支持试运行模式(不写入数据库) - -作者:Memory Refactoring Team -日期:2025-11-21 -""" - -import asyncio -import logging -from typing import List, Dict, Any, Tuple, Optional -from datetime import datetime - -from app.core.memory.models.message_models import DialogData -from app.core.memory.models.graph_models import ( - DialogueNode, - ChunkNode, - StatementNode, - ExtractedEntityNode, - StatementChunkEdge, - StatementEntityEdge, - EntityEntityEdge, -) -from app.core.memory.utils.data.ontology import TemporalInfo -from app.core.memory.models.variate_config import ( - ExtractionPipelineConfig, - StatementExtractionConfig, -) -from app.core.memory.src.llm_tools.openai_client import LLMClient -from app.core.memory.src.llm_tools.openai_embedder import OpenAIEmbedderClient -from app.repositories.neo4j.neo4j_connector import Neo4jConnector - -# 导入各个提取模块 -from app.core.memory.storage_services.extraction_engine.knowledge_extraction.statement_extraction import ( - StatementExtractor, -) -from app.core.memory.storage_services.extraction_engine.knowledge_extraction.triplet_extraction import ( - TripletExtractor, -) -from app.core.memory.storage_services.extraction_engine.knowledge_extraction.temporal_extraction import ( - TemporalExtractor, -) -from app.core.memory.storage_services.extraction_engine.knowledge_extraction.embedding_generation import ( - embedding_generation, - embedding_generation_all, - generate_entity_embeddings_from_triplets, -) -from app.core.memory.storage_services.extraction_engine.deduplication.two_stage_dedup import ( - dedup_layers_and_merge_and_return, -) -from app.core.memory.storage_services.extraction_engine.pipeline_help import ( - _write_extracted_result_summary, - export_test_input_doc, -) - -# 配置日志 -logger = logging.getLogger(__name__) - - -class ExtractionOrchestrator: - """ - 知识提取流水线编排器 - - 该类负责协调整个知识提取流程,包括: - 1. 陈述句提取 - 2. 三元组提取 - 3. 时间信息提取 - 4. 嵌入向量生成 - 5. 数据赋值到语句 - 6. 节点和边的创建 - 7. 两阶段去重和消歧 - 8. 结果汇总和输出 - - Attributes: - llm_client: LLM 客户端,用于调用大语言模型 - embedder_client: 嵌入模型客户端,用于生成向量嵌入 - connector: Neo4j 连接器,用于数据库操作 - config: 流水线配置 - """ - - def __init__( - self, - llm_client: LLMClient, - embedder_client: OpenAIEmbedderClient, - connector: Neo4jConnector, - config: Optional[ExtractionPipelineConfig] = None, - ): - """ - 初始化流水线编排器 - - Args: - llm_client: LLM 客户端 - embedder_client: 嵌入模型客户端 - connector: Neo4j 连接器 - config: 流水线配置,如果为 None 则使用默认配置 - """ - self.llm_client = llm_client - self.embedder_client = embedder_client - self.connector = connector - self.config = config or ExtractionPipelineConfig() - self.is_pilot_run = False # 默认非试运行模式 - - # 初始化各个提取器 - self.statement_extractor = StatementExtractor( - llm_client=llm_client, - config=self.config.statement_extraction, - ) - self.triplet_extractor = TripletExtractor(llm_client=llm_client) - self.temporal_extractor = TemporalExtractor(llm_client=llm_client) - - logger.info("ExtractionOrchestrator 初始化完成") - - async def run( - self, - dialog_data_list: List[DialogData], - is_pilot_run: bool = False, - ) -> Tuple[ - Tuple[List[DialogueNode], List[ChunkNode], List[StatementNode]], - Tuple[List[ExtractedEntityNode], List[StatementEntityEdge], List[EntityEntityEdge]], - Tuple[List[ExtractedEntityNode], List[StatementEntityEdge], List[EntityEntityEdge]], - ]: - """ - 运行完整的知识提取流水线(优化版:并行执行) - - 该方法协调所有提取步骤,优化执行顺序: - 1. 陈述句提取 - 2. 并行执行:三元组提取 + 时间信息提取 + 陈述句/分块嵌入生成 - 3. 实体嵌入生成(依赖三元组) - 4. 数据赋值 - 5. 节点和边创建 - 6. 两阶段去重 - 7. 结果汇总 - - Args: - dialog_data_list: 已分块的对话数据列表 - is_pilot_run: 是否为试运行模式(不写入数据库) - - Returns: - 包含三个元组的元组: - - 第一个元组:(对话节点列表, 分块节点列表, 陈述句节点列表) - - 第二个元组:去重前的 (实体节点列表, 陈述句-实体边列表, 实体-实体边列表) - - 第三个元组:去重后的 (实体节点列表, 陈述句-实体边列表, 实体-实体边列表) - """ - try: - # 设置试运行模式标志 - self.is_pilot_run = is_pilot_run - mode_str = "试运行模式" if is_pilot_run else "正式模式" - logger.info(f"开始运行知识提取流水线(优化版 - {mode_str}),共 {len(dialog_data_list)} 个对话") - - # 步骤 1: 陈述句提取 - logger.info("步骤 1/6: 陈述句提取(全局分块级并行)") - dialog_data_list = await self._extract_statements(dialog_data_list) - - # 步骤 2: 并行执行三元组提取、时间信息提取和基础嵌入生成 - logger.info("步骤 2/6: 并行执行三元组提取、时间信息提取和嵌入生成") - ( - triplet_maps, - temporal_maps, - statement_embedding_maps, - chunk_embedding_maps, - dialog_embeddings, - ) = await self._parallel_extract_and_embed(dialog_data_list) - - # 步骤 3: 生成实体嵌入(依赖三元组提取结果) - logger.info("步骤 3/6: 生成实体嵌入") - triplet_maps = await self._generate_entity_embeddings(triplet_maps) - - # 步骤 4: 将提取的数据赋值到语句 - logger.info("步骤 4/6: 数据赋值") - dialog_data_list = await self._assign_extracted_data( - dialog_data_list, - temporal_maps, - triplet_maps, - statement_embedding_maps, - chunk_embedding_maps, - dialog_embeddings, - ) - - # 步骤 5: 创建节点和边 - logger.info("步骤 5/6: 创建节点和边") - ( - dialogue_nodes, - chunk_nodes, - statement_nodes, - entity_nodes, - statement_chunk_edges, - statement_entity_edges, - entity_entity_edges, - ) = await self._create_nodes_and_edges(dialog_data_list) - - # 导出去重前的测试输入文档(试运行和正式模式都需要,用于生成结果汇总) - export_test_input_doc(entity_nodes, statement_entity_edges, entity_entity_edges) - - # 步骤 6: 两阶段去重和消歧 - if is_pilot_run: - logger.info("步骤 6/6: 去重和消歧(试运行模式:仅第一层去重)") - else: - logger.info("步骤 6/6: 两阶段去重和消歧") - - result = await self._run_dedup_and_write_summary( - dialogue_nodes, - chunk_nodes, - statement_nodes, - entity_nodes, - statement_chunk_edges, - statement_entity_edges, - entity_entity_edges, - dialog_data_list, - ) - - logger.info(f"知识提取流水线运行完成({mode_str})") - return result - - except Exception as e: - logger.error(f"知识提取流水线运行失败: {e}", exc_info=True) - raise - - async def _extract_statements( - self, dialog_data_list: List[DialogData] - ) -> List[DialogData]: - """ - 从对话中提取陈述句(优化版:全局分块级并行) - - Args: - dialog_data_list: 对话数据列表 - - Returns: - 更新后的对话数据列表(包含提取的陈述句) - """ - logger.info("开始陈述句提取(全局分块级并行)") - - # 收集所有分块及其元数据 - all_chunks = [] - chunk_metadata = [] # (dialog_idx, chunk_idx) - - for d_idx, dialog in enumerate(dialog_data_list): - dialogue_content = dialog.content if self.config.statement_extraction.include_dialogue_context else None - for c_idx, chunk in enumerate(dialog.chunks): - all_chunks.append((chunk, dialog.group_id, dialogue_content)) - chunk_metadata.append((d_idx, c_idx)) - - logger.info(f"收集到 {len(all_chunks)} 个分块,开始全局并行提取") - - # 全局并行处理所有分块 - async def extract_for_chunk(chunk_data): - chunk, group_id, dialogue_content = chunk_data - try: - return await self.statement_extractor._extract_statements(chunk, group_id, dialogue_content) - except Exception as e: - logger.error(f"分块 {chunk.id} 陈述句提取失败: {e}") - return [] - - tasks = [extract_for_chunk(chunk_data) for chunk_data in all_chunks] - results = await asyncio.gather(*tasks, return_exceptions=True) - - # 将结果分配回对话 - for i, result in enumerate(results): - d_idx, c_idx = chunk_metadata[i] - if isinstance(result, Exception): - logger.error(f"分块处理异常: {result}") - dialog_data_list[d_idx].chunks[c_idx].statements = [] - elif isinstance(result, list): - dialog_data_list[d_idx].chunks[c_idx].statements = result - else: - dialog_data_list[d_idx].chunks[c_idx].statements = [] - - # 统计并保存(试运行和正式模式都需要保存,用于生成结果汇总) - all_statements = [] - for dialog in dialog_data_list: - for chunk in dialog.chunks: - if chunk.statements: - all_statements.extend(chunk.statements) - - # 保存陈述句到文件(试运行和正式模式都需要) - self.statement_extractor.save_statements(all_statements) - - logger.info(f"陈述句提取完成,共提取 {len(all_statements)} 条陈述句") - - return dialog_data_list - - async def _extract_triplets( - self, dialog_data_list: List[DialogData] - ) -> List[Dict[str, Any]]: - """ - 从对话中提取三元组(优化版:全局陈述句级并行) - - Args: - dialog_data_list: 对话数据列表 - - Returns: - 三元组映射列表,每个对话对应一个字典 - """ - logger.info("开始三元组提取(全局陈述句级并行)") - - # 收集所有陈述句及其元数据 - all_statements = [] - statement_metadata = [] # (dialog_idx, statement_id, chunk_content) - - for d_idx, dialog in enumerate(dialog_data_list): - for chunk in dialog.chunks: - for statement in chunk.statements: - all_statements.append((statement, chunk.content)) - statement_metadata.append((d_idx, statement.id)) - - logger.info(f"收集到 {len(all_statements)} 个陈述句,开始全局并行提取三元组") - - # 全局并行处理所有陈述句 - async def extract_for_statement(stmt_data): - statement, chunk_content = stmt_data - try: - return await self.triplet_extractor._extract_triplets(statement, chunk_content) - except Exception as e: - logger.error(f"陈述句 {statement.id} 三元组提取失败: {e}") - from app.core.memory.models.triplet_models import TripletExtractionResponse - return TripletExtractionResponse(triplets=[], entities=[]) - - tasks = [extract_for_statement(stmt_data) for stmt_data in all_statements] - results = await asyncio.gather(*tasks, return_exceptions=True) - - # 将结果组织成对话级别的映射 - triplet_maps = [{} for _ in dialog_data_list] - all_responses = [] - - for i, result in enumerate(results): - d_idx, stmt_id = statement_metadata[i] - if isinstance(result, Exception): - logger.error(f"陈述句处理异常: {result}") - from app.core.memory.models.triplet_models import TripletExtractionResponse - triplet_maps[d_idx][stmt_id] = TripletExtractionResponse(triplets=[], entities=[]) - else: - triplet_maps[d_idx][stmt_id] = result - all_responses.append(result) - - # 统计提取结果 - total_triplets = sum(len(m) for m in triplet_maps) - logger.info(f"三元组提取完成,共提取 {total_triplets} 个三元组") - - # 保存三元组到文件(试运行和正式模式都需要,用于生成结果汇总) - if all_responses: - try: - self.triplet_extractor.save_triplets(all_responses) - logger.info(f"三元组数据已保存到文件") - except Exception as e: - logger.error(f"保存三元组到文件失败: {e}", exc_info=True) - - return triplet_maps - - async def _extract_temporal( - self, dialog_data_list: List[DialogData] - ) -> List[Dict[str, Any]]: - """ - 从对话中提取时间信息(优化版:全局陈述句级并行) - - Args: - dialog_data_list: 对话数据列表 - - Returns: - 时间信息映射列表,每个对话对应一个字典 - """ - logger.info("开始时间信息提取(全局陈述句级并行)") - - # 收集所有需要提取时间的陈述句 - all_statements = [] - statement_metadata = [] # (dialog_idx, statement_id, ref_dates) - - for d_idx, dialog in enumerate(dialog_data_list): - # 获取参考日期 - ref_dates = {} - if hasattr(dialog, 'metadata') and dialog.metadata: - if 'conversation_date' in dialog.metadata: - ref_dates['conversation_date'] = dialog.metadata['conversation_date'] - if 'publication_date' in dialog.metadata: - ref_dates['publication_date'] = dialog.metadata['publication_date'] - - if not ref_dates: - from datetime import datetime - ref_dates = {"today": datetime.now().strftime("%Y-%m-%d")} - - for chunk in dialog.chunks: - for statement in chunk.statements: - # 跳过 ATEMPORAL 类型的陈述句 - from app.core.memory.utils.data.ontology import TemporalInfo - if statement.temporal_info != TemporalInfo.ATEMPORAL: - all_statements.append((statement, ref_dates)) - statement_metadata.append((d_idx, statement.id)) - - logger.info(f"收集到 {len(all_statements)} 个需要时间提取的陈述句,开始全局并行提取") - - # 全局并行处理所有陈述句 - async def extract_for_statement(stmt_data): - statement, ref_dates = stmt_data - try: - return await self.temporal_extractor._extract_temporal_ranges(statement, ref_dates) - except Exception as e: - logger.error(f"陈述句 {statement.id} 时间信息提取失败: {e}") - from app.core.memory.models.message_models import TemporalValidityRange - return TemporalValidityRange(valid_at=None, invalid_at=None) - - tasks = [extract_for_statement(stmt_data) for stmt_data in all_statements] - results = await asyncio.gather(*tasks, return_exceptions=True) - - # 将结果组织成对话级别的映射 - temporal_maps = [{} for _ in dialog_data_list] - - for i, result in enumerate(results): - d_idx, stmt_id = statement_metadata[i] - if isinstance(result, Exception): - logger.error(f"陈述句处理异常: {result}") - from app.core.memory.models.message_models import TemporalValidityRange - temporal_maps[d_idx][stmt_id] = TemporalValidityRange(valid_at=None, invalid_at=None) - else: - temporal_maps[d_idx][stmt_id] = result - - # 为 ATEMPORAL 陈述句添加空的时间范围 - from app.core.memory.utils.data.ontology import TemporalInfo - from app.core.memory.models.message_models import TemporalValidityRange - for d_idx, dialog in enumerate(dialog_data_list): - for chunk in dialog.chunks: - for statement in chunk.statements: - if statement.temporal_info == TemporalInfo.ATEMPORAL and statement.id not in temporal_maps[d_idx]: - temporal_maps[d_idx][statement.id] = TemporalValidityRange(valid_at=None, invalid_at=None) - - # 统计提取结果 - total_temporal = sum(len(m) for m in temporal_maps) - logger.info(f"时间信息提取完成,共提取 {total_temporal} 个时间范围") - - return temporal_maps - - async def _parallel_extract_and_embed( - self, dialog_data_list: List[DialogData] - ) -> Tuple[ - List[Dict[str, Any]], - List[Dict[str, Any]], - List[Dict[str, List[float]]], - List[Dict[str, List[float]]], - List[List[float]], - ]: - """ - 并行执行三元组提取、时间信息提取和基础嵌入生成 - - 这三个任务都依赖陈述句提取的结果,但彼此独立,可以并行执行: - - 三元组提取:从陈述句中提取实体和关系 - - 时间信息提取:从陈述句中提取时间范围 - - 嵌入生成:为陈述句、分块和对话生成向量(不依赖三元组) - - Args: - dialog_data_list: 对话数据列表 - - Returns: - 五个列表的元组: - - 三元组映射列表 - - 时间信息映射列表 - - 陈述句嵌入映射列表 - - 分块嵌入映射列表 - - 对话嵌入列表 - """ - logger.info("并行执行:三元组提取 + 时间信息提取 + 基础嵌入生成") - - # 创建三个并行任务 - triplet_task = self._extract_triplets(dialog_data_list) - temporal_task = self._extract_temporal(dialog_data_list) - embedding_task = self._generate_basic_embeddings(dialog_data_list) - - # 并行执行 - results = await asyncio.gather( - triplet_task, - temporal_task, - embedding_task, - return_exceptions=True - ) - - # 解包结果 - triplet_maps = results[0] if not isinstance(results[0], Exception) else [{} for _ in dialog_data_list] - temporal_maps = results[1] if not isinstance(results[1], Exception) else [{} for _ in dialog_data_list] - - if isinstance(results[2], Exception): - logger.error(f"基础嵌入生成失败: {results[2]}") - statement_embedding_maps = [{} for _ in dialog_data_list] - chunk_embedding_maps = [{} for _ in dialog_data_list] - dialog_embeddings = [[] for _ in dialog_data_list] - else: - statement_embedding_maps, chunk_embedding_maps, dialog_embeddings = results[2] - - logger.info("并行任务执行完成") - return ( - triplet_maps, - temporal_maps, - statement_embedding_maps, - chunk_embedding_maps, - dialog_embeddings, - ) - - async def _generate_basic_embeddings( - self, dialog_data_list: List[DialogData] - ) -> Tuple[List[Dict[str, List[float]]], List[Dict[str, List[float]]], List[List[float]]]: - """ - 生成基础嵌入向量(陈述句、分块、对话) - - 这些嵌入不依赖三元组提取结果,可以提前生成 - 在试运行模式下,跳过嵌入生成以节省时间 - - Args: - dialog_data_list: 对话数据列表 - - Returns: - 三个列表的元组: - - 陈述句嵌入映射列表 - - 分块嵌入映射列表 - - 对话嵌入列表 - """ - # 试运行模式:跳过嵌入生成 - if self.is_pilot_run: - logger.info("试运行模式:跳过基础嵌入生成(节省约 20 秒)") - return ( - [{} for _ in dialog_data_list], - [{} for _ in dialog_data_list], - [[] for _ in dialog_data_list], - ) - - logger.info("开始生成基础嵌入向量(陈述句、分块、对话)") - - try: - # 从 runtime.json 获取嵌入模型配置ID - from app.core.memory.utils.config import definitions as config_defs - embedding_id = config_defs.SELECTED_EMBEDDING_ID - - if not embedding_id: - logger.error("未在 runtime.json 中配置 embedding 模型 ID") - raise ValueError("未配置嵌入模型ID") - - # 只生成陈述句、分块和对话的嵌入(不包括实体) - statement_embedding_maps, chunk_embedding_maps, dialog_embeddings = await embedding_generation( - dialog_data_list, embedding_id - ) - - # 统计生成结果 - total_statement_embeddings = sum(len(m) for m in statement_embedding_maps) - total_chunk_embeddings = sum(len(m) for m in chunk_embedding_maps) - logger.info( - f"基础嵌入生成完成:{total_statement_embeddings} 个陈述句嵌入," - f"{total_chunk_embeddings} 个分块嵌入,{len(dialog_embeddings)} 个对话嵌入" - ) - - return statement_embedding_maps, chunk_embedding_maps, dialog_embeddings - - except Exception as e: - logger.error(f"基础嵌入生成失败: {e}", exc_info=True) - # 返回空结果 - return ( - [{} for _ in dialog_data_list], - [{} for _ in dialog_data_list], - [[] for _ in dialog_data_list], - ) - - async def _generate_entity_embeddings( - self, triplet_maps: List[Dict[str, Any]] - ) -> List[Dict[str, Any]]: - """ - 生成实体嵌入向量 - - 在试运行模式下,跳过实体嵌入生成以节省时间 - - Args: - triplet_maps: 三元组映射列表 - - Returns: - 更新后的三元组映射列表(包含实体嵌入) - """ - # 试运行模式:跳过实体嵌入生成 - if self.is_pilot_run: - logger.info("试运行模式:跳过实体嵌入生成(节省约 5-8 秒)") - return triplet_maps - - logger.info("开始生成实体嵌入向量") - - try: - # 从 runtime.json 获取嵌入模型配置ID - from app.core.memory.utils.config import definitions as config_defs - embedding_id = config_defs.SELECTED_EMBEDDING_ID - - if not embedding_id: - logger.error("未在 runtime.json 中配置 embedding 模型 ID") - return triplet_maps - - # 生成实体嵌入 - updated_triplet_maps = await generate_entity_embeddings_from_triplets( - triplet_maps, embedding_id - ) - - logger.info("实体嵌入生成完成") - return updated_triplet_maps - - except Exception as e: - logger.error(f"实体嵌入生成失败: {e}", exc_info=True) - return triplet_maps - - - - async def _assign_extracted_data( - self, - dialog_data_list: List[DialogData], - temporal_maps: List[Dict[str, Any]], - triplet_maps: List[Dict[str, Any]], - statement_embedding_maps: List[Dict[str, List[float]]], - chunk_embedding_maps: List[Dict[str, List[float]]], - dialog_embeddings: List[List[float]], - ) -> List[DialogData]: - """ - 将提取的数据赋值到语句 - - Args: - dialog_data_list: 对话数据列表 - temporal_maps: 时间信息映射列表 - triplet_maps: 三元组映射列表 - statement_embedding_maps: 陈述句嵌入映射列表 - chunk_embedding_maps: 分块嵌入映射列表 - dialog_embeddings: 对话嵌入列表 - - Returns: - 更新后的对话数据列表 - """ - logger.info("开始将提取数据赋值到语句") - - # 确保列表长度匹配 - expected_length = len(dialog_data_list) - if ( - len(temporal_maps) != expected_length - or len(triplet_maps) != expected_length - or len(statement_embedding_maps) != expected_length - or len(chunk_embedding_maps) != expected_length - or len(dialog_embeddings) != expected_length - ): - logger.warning( - f"数据大小不匹配 - 对话: {len(dialog_data_list)}, " - f"时间映射: {len(temporal_maps)}, 三元组映射: {len(triplet_maps)}, " - f"陈述句嵌入: {len(statement_embedding_maps)}, " - f"分块嵌入: {len(chunk_embedding_maps)}, " - f"对话嵌入: {len(dialog_embeddings)}" - ) - - total_statements = 0 - assigned_temporal = 0 - assigned_triplets = 0 - assigned_statement_embeddings = 0 - assigned_chunk_embeddings = 0 - assigned_dialog_embeddings = 0 - - # 处理每个对话 - for i, dialog_data in enumerate(dialog_data_list): - # 检查是否有缺失的数据 - if i >= len(temporal_maps) or i >= len(triplet_maps): - logger.warning(f"对话 {dialog_data.id} 缺少提取数据,跳过赋值") - continue - - temporal_map = temporal_maps[i] - triplet_map = triplet_maps[i] - statement_embedding_map = statement_embedding_maps[i] if i < len(statement_embedding_maps) else {} - chunk_embedding_map = chunk_embedding_maps[i] if i < len(chunk_embedding_maps) else {} - dialog_embedding = dialog_embeddings[i] if i < len(dialog_embeddings) else [] - - # 赋值对话嵌入 - if dialog_embedding: - dialog_data.dialog_embedding = dialog_embedding - assigned_dialog_embeddings += 1 - - # 处理每个分块 - for chunk in dialog_data.chunks: - # 赋值分块嵌入 - if chunk.id in chunk_embedding_map: - chunk.chunk_embedding = chunk_embedding_map[chunk.id] - assigned_chunk_embeddings += 1 - - # 处理每个陈述句 - for statement in chunk.statements: - total_statements += 1 - - # 赋值时间信息 - if statement.id in temporal_map: - statement.temporal_validity = temporal_map[statement.id] - assigned_temporal += 1 - - # 赋值三元组 - if statement.id in triplet_map: - statement.triplet_extraction_info = triplet_map[statement.id] - assigned_triplets += 1 - - # 赋值陈述句嵌入 - if statement.id in statement_embedding_map: - statement.statement_embedding = statement_embedding_map[statement.id] - assigned_statement_embeddings += 1 - - logger.info( - f"数据赋值完成 - 总陈述句: {total_statements}, " - f"时间信息: {assigned_temporal}, 三元组: {assigned_triplets}, " - f"陈述句嵌入: {assigned_statement_embeddings}, " - f"分块嵌入: {assigned_chunk_embeddings}, " - f"对话嵌入: {assigned_dialog_embeddings}" - ) - - return dialog_data_list - - async def _create_nodes_and_edges( - self, dialog_data_list: List[DialogData] - ) -> Tuple[ - List[DialogueNode], - List[ChunkNode], - List[StatementNode], - List[ExtractedEntityNode], - List[StatementChunkEdge], - List[StatementEntityEdge], - List[EntityEntityEdge], - ]: - """ - 创建图数据库节点和边 - - 将对话数据转换为图数据库的节点和边结构 - - Args: - dialog_data_list: 对话数据列表 - - Returns: - 包含所有节点和边的元组 - """ - logger.info("开始创建节点和边") - - dialogue_nodes = [] - chunk_nodes = [] - statement_nodes = [] - entity_nodes = [] - statement_chunk_edges = [] - statement_entity_edges = [] - entity_entity_edges = [] - - # 用于去重的集合 - entity_id_set = set() - - for dialog_data in dialog_data_list: - # 创建对话节点 - dialogue_node = DialogueNode( - id=dialog_data.id, - name=f"Dialog_{dialog_data.id}", # 添加必需的 name 字段 - ref_id=dialog_data.ref_id, - group_id=dialog_data.group_id, - user_id=dialog_data.user_id, - apply_id=dialog_data.apply_id, - run_id=dialog_data.run_id, # 使用 dialog_data 的 run_id - content=dialog_data.context.content if dialog_data.context else "", - dialog_embedding=dialog_data.dialog_embedding if hasattr(dialog_data, 'dialog_embedding') else None, - created_at=dialog_data.created_at, - expired_at=dialog_data.expired_at, - metadata=dialog_data.metadata, - config_id=dialog_data.config_id if hasattr(dialog_data, 'config_id') else None, - ) - dialogue_nodes.append(dialogue_node) - - # 处理每个分块 - for chunk_idx, chunk in enumerate(dialog_data.chunks): - # 创建分块节点 - chunk_node = ChunkNode( - id=chunk.id, - name=f"Chunk_{chunk.id}", # 添加必需的 name 字段 - dialog_id=dialog_data.id, - group_id=dialog_data.group_id, - user_id=dialog_data.user_id, - apply_id=dialog_data.apply_id, - run_id=dialog_data.run_id, # 使用 dialog_data 的 run_id - content=chunk.content, - chunk_embedding=chunk.chunk_embedding, - sequence_number=chunk_idx, # 添加必需的 sequence_number 字段 - created_at=dialog_data.created_at, - expired_at=dialog_data.expired_at, - metadata=chunk.metadata, - ) - chunk_nodes.append(chunk_node) - - # 处理每个陈述句 - for statement in chunk.statements: - # 创建陈述句节点 - statement_node = StatementNode( - id=statement.id, - name=f"Statement_{statement.id}", # 添加必需的 name 字段 - chunk_id=chunk.id, - stmt_type=getattr(statement, 'stmt_type', 'general'), # 添加必需的 stmt_type 字段 - temporal_info=getattr(statement, 'temporal_info', TemporalInfo.ATEMPORAL), # 添加必需的 temporal_info 字段 - connect_strength=statement.connect_strength if statement.connect_strength is not None else 'Strong', # 添加必需的 connect_strength 字段 - group_id=dialog_data.group_id, - user_id=dialog_data.user_id, - apply_id=dialog_data.apply_id, - run_id=dialog_data.run_id, # 使用 dialog_data 的 run_id - statement=statement.statement, - statement_embedding=statement.statement_embedding, - valid_at=statement.temporal_validity.valid_at if hasattr(statement, 'temporal_validity') and statement.temporal_validity else None, - invalid_at=statement.temporal_validity.invalid_at if hasattr(statement, 'temporal_validity') and statement.temporal_validity else None, - created_at=dialog_data.created_at, - expired_at=dialog_data.expired_at, - config_id=dialog_data.config_id if hasattr(dialog_data, 'config_id') else None, - ) - statement_nodes.append(statement_node) - - # 创建陈述句-分块边 - statement_chunk_edge = StatementChunkEdge( - source=statement.id, - target=chunk.id, - group_id=dialog_data.group_id, - user_id=dialog_data.user_id, - apply_id=dialog_data.apply_id, - run_id=dialog_data.run_id, # 使用 dialog_data 的 run_id - created_at=dialog_data.created_at, - ) - statement_chunk_edges.append(statement_chunk_edge) - - # 处理三元组信息 - if statement.triplet_extraction_info: - triplet_info = statement.triplet_extraction_info - - # 创建实体索引到ID的映射 - entity_idx_to_id = {} - - # 创建实体节点 - for entity_idx, entity in enumerate(triplet_info.entities): - # 映射实体索引到实体ID - entity_idx_to_id[entity.entity_idx] = entity.id - - if entity.id not in entity_id_set: - entity_connect_strength = getattr(entity, 'connect_strength', 'Strong') - entity_node = ExtractedEntityNode( - id=entity.id, - name=getattr(entity, 'name', f"Entity_{entity.id}"), # 使用 name 而不是 entity_name - entity_idx=entity.entity_idx, # 使用实体自己的 entity_idx - statement_id=statement.id, # 添加必需的 statement_id 字段 - entity_type=getattr(entity, 'type', 'unknown'), # 使用 type 而不是 entity_type - description=getattr(entity, 'description', ''), # 添加必需的 description 字段 - fact_summary=getattr(entity, 'fact_summary', ''), # 添加必需的 fact_summary 字段 - connect_strength=entity_connect_strength if entity_connect_strength is not None else 'Strong', # 添加必需的 connect_strength 字段 - name_embedding=getattr(entity, 'name_embedding', None), - group_id=dialog_data.group_id, - user_id=dialog_data.user_id, - apply_id=dialog_data.apply_id, - run_id=dialog_data.run_id, # 使用 dialog_data 的 run_id - created_at=dialog_data.created_at, - expired_at=dialog_data.expired_at, - config_id=dialog_data.config_id if hasattr(dialog_data, 'config_id') else None, - ) - entity_nodes.append(entity_node) - entity_id_set.add(entity.id) - - # 创建陈述句-实体边 - entity_connect_strength = getattr(entity, 'connect_strength', 'Strong') - statement_entity_edge = StatementEntityEdge( - source=statement.id, - target=entity.id, - connect_strength=entity_connect_strength if entity_connect_strength is not None else 'Strong', - group_id=dialog_data.group_id, - user_id=dialog_data.user_id, - apply_id=dialog_data.apply_id, - run_id=dialog_data.run_id, # 使用 dialog_data 的 run_id - created_at=dialog_data.created_at, - ) - statement_entity_edges.append(statement_entity_edge) - - # 创建实体-实体边(从三元组) - for triplet in triplet_info.triplets: - # 将三元组中的整数索引映射到实体ID - subject_entity_id = entity_idx_to_id.get(triplet.subject_id) - object_entity_id = entity_idx_to_id.get(triplet.object_id) - - # 只有当两个实体ID都存在时才创建边 - if subject_entity_id and object_entity_id: - entity_entity_edge = EntityEntityEdge( - source=subject_entity_id, - target=object_entity_id, - relation_type=triplet.predicate, - statement=statement.statement, - source_statement_id=statement.id, - group_id=dialog_data.group_id, - user_id=dialog_data.user_id, - apply_id=dialog_data.apply_id, - run_id=dialog_data.run_id, # 使用 dialog_data 的 run_id - created_at=dialog_data.created_at, - expired_at=dialog_data.expired_at, - ) - entity_entity_edges.append(entity_entity_edge) - else: - logger.warning( - f"跳过三元组 - 无法找到实体ID: subject_id={triplet.subject_id}, " - f"object_id={triplet.object_id}, statement_id={statement.id}" - ) - - logger.info( - f"节点和边创建完成 - 对话节点: {len(dialogue_nodes)}, " - f"分块节点: {len(chunk_nodes)}, 陈述句节点: {len(statement_nodes)}, " - f"实体节点: {len(entity_nodes)}, 陈述句-分块边: {len(statement_chunk_edges)}, " - f"陈述句-实体边: {len(statement_entity_edges)}, " - f"实体-实体边: {len(entity_entity_edges)}" - ) - - return ( - dialogue_nodes, - chunk_nodes, - statement_nodes, - entity_nodes, - statement_chunk_edges, - statement_entity_edges, - entity_entity_edges, - ) - - async def _run_dedup_and_write_summary( - self, - dialogue_nodes: List[DialogueNode], - chunk_nodes: List[ChunkNode], - statement_nodes: List[StatementNode], - entity_nodes: List[ExtractedEntityNode], - statement_chunk_edges: List[StatementChunkEdge], - statement_entity_edges: List[StatementEntityEdge], - entity_entity_edges: List[EntityEntityEdge], - dialog_data_list: List[DialogData], - ) -> Tuple[ - Tuple[List[DialogueNode], List[ChunkNode], List[StatementNode]], - Tuple[List[ExtractedEntityNode], List[StatementEntityEdge], List[EntityEntityEdge]], - Tuple[List[ExtractedEntityNode], List[StatementEntityEdge], List[EntityEntityEdge]], - ]: - """ - 执行两阶段去重并写入汇总 - - Args: - dialogue_nodes: 对话节点列表 - chunk_nodes: 分块节点列表 - statement_nodes: 陈述句节点列表 - entity_nodes: 实体节点列表 - statement_chunk_edges: 陈述句-分块边列表 - statement_entity_edges: 陈述句-实体边列表 - entity_entity_edges: 实体-实体边列表 - dialog_data_list: 对话数据列表 - - Returns: - 包含三个元组的元组: - - 第一个元组:(对话节点列表, 分块节点列表, 陈述句节点列表) - - 第二个元组:去重前的 (实体节点列表, 陈述句-实体边列表, 实体-实体边列表) - - 第三个元组:去重后的 (实体节点列表, 陈述句-实体边列表, 实体-实体边列表) - """ - logger.info("开始两阶段实体去重和消歧") - logger.info( - f"去重前: {len(entity_nodes)} 个实体节点, " - f"{len(statement_entity_edges)} 条陈述句-实体边, " - f"{len(entity_entity_edges)} 条实体-实体边" - ) - - try: - # 在试运行模式下,跳过第二层去重(不查询数据库) - if self.is_pilot_run: - logger.info("试运行模式:仅执行第一层去重,跳过第二层数据库去重") - # 只执行第一层去重 - from app.core.memory.storage_services.extraction_engine.deduplication.deduped_and_disamb import deduplicate_entities_and_edges - - dedup_entity_nodes, dedup_statement_entity_edges, dedup_entity_entity_edges = await deduplicate_entities_and_edges( - entity_nodes, - statement_entity_edges, - entity_entity_edges, - report_stage="第一层去重消歧(试运行)", - report_append=False, - dedup_config=self.config.deduplication, - ) - - result_tuple = ( - dialogue_nodes, - chunk_nodes, - statement_nodes, - dedup_entity_nodes, - statement_chunk_edges, - dedup_statement_entity_edges, - dedup_entity_entity_edges, - ) - - final_entity_nodes = dedup_entity_nodes - final_statement_entity_edges = dedup_statement_entity_edges - final_entity_entity_edges = dedup_entity_entity_edges - else: - # 正式模式:执行完整的两阶段去重 - result_tuple = await dedup_layers_and_merge_and_return( - dialogue_nodes, - chunk_nodes, - statement_nodes, - entity_nodes, - statement_chunk_edges, - statement_entity_edges, - entity_entity_edges, - dialog_data_list, - self.config, - self.connector, - ) - - # 解包返回值 - ( - _, - _, - _, - final_entity_nodes, - _, - final_statement_entity_edges, - final_entity_entity_edges, - ) = result_tuple - - logger.info( - f"去重后: {len(final_entity_nodes)} 个实体节点, " - f"{len(final_statement_entity_edges)} 条陈述句-实体边, " - f"{len(final_entity_entity_edges)} 条实体-实体边" - ) - logger.info( - f"去重效果: 实体减少 {len(entity_nodes) - len(final_entity_nodes)}, " - f"陈述句-实体边减少 {len(statement_entity_edges) - len(final_statement_entity_edges)}, " - f"实体-实体边减少 {len(entity_entity_edges) - len(final_entity_entity_edges)}" - ) - - # 写入提取结果汇总(试运行和正式模式都需要生成) - try: - from app.core.config import settings - settings.ensure_memory_output_dir() - _write_extracted_result_summary( - chunk_nodes=chunk_nodes, - pipeline_output_dir=settings.MEMORY_OUTPUT_DIR, - ) - mode_str = "试运行" if self.is_pilot_run else "正式" - logger.info(f"提取结果汇总已写入({mode_str}模式)") - except Exception as e: - logger.warning(f"写入提取结果汇总失败: {e}") - - return result_tuple - - except Exception as e: - logger.error(f"两阶段去重失败: {e}", exc_info=True) - raise - - -# ============================================================================ -# 数据加载和预处理函数 -# ============================================================================ -# 以下函数从 extraction_pipeline.py 迁移而来,用于数据加载和预处理 - - -async def get_chunked_dialogs( - chunker_strategy: str = "RecursiveChunker", - group_id: str = "group_1", - indices: Optional[List[int]] = None, -) -> List[DialogData]: - """从测试数据生成分块对话 - - Args: - chunker_strategy: 分块策略(默认: RecursiveChunker) - group_id: 组ID - indices: 要处理的数据索引列表(可选) - - Returns: - 包含分块的 DialogData 对象列表 - """ - import json - import re - import os - - # 加载测试数据 - testdata_path = os.path.join(os.path.dirname(__file__), "../../data", "testdata.json") - with open(testdata_path, "r", encoding="utf-8") as f: - test_data = [json.loads(line) for line in f] - - dialog_data_list = [] - - if indices is not None: - # 选择特定索引 - selected_data = [test_data[i] for i in indices if 0 <= i < len(test_data)] - else: - # 默认使用所有数据 - selected_data = test_data - - for data in selected_data: - # 解析对话上下文 - context_text = data["context"] - - # 从context文本中解析日期 - conv_date: Optional[str] = None - m = re.search(r"(\d{4})年(\d{1,2})月(\d{1,2})日", context_text) - if m: - y, mo, d = int(m.group(1)), int(m.group(2)), int(m.group(3)) - conv_date = f"{y:04d}-{mo:02d}-{d:02d}" - else: - m = re.search(r"(\d{4})[-/](\d{1,2})[-/](\d{1,2})", context_text) - if m: - y, mo, d = int(m.group(1)), int(m.group(2)), int(m.group(3)) - conv_date = f"{y:04d}-{mo:02d}-{d:02d}" - - dialog_metadata: Dict[str, Any] = {} - if conv_date: - dialog_metadata["conversation_date"] = conv_date - dialog_metadata["publication_date"] = conv_date - - # 分割对话为消息 - lines = context_text.split("\n") - messages = [] - - # 解析对话行 - for raw_line in lines: - line = raw_line.strip() - match = re.match(r'^[""]?(用户|AI)\s*[::]\s*(.*)$', line) - if match: - role = match.group(1) - msg = match.group(2).strip().rstrip('""') - from app.core.memory.models.message_models import ConversationMessage - messages.append(ConversationMessage(role=role, msg=msg)) - - # 创建 DialogData - from app.core.memory.models.message_models import ConversationContext - conversation_context = ConversationContext(msgs=messages) - dialog_data = DialogData( - context=conversation_context, - ref_id=data['id'], - group_id=group_id, - metadata=dialog_metadata, - ) - - # 创建分块器并处理对话 - from app.core.memory.storage_services.extraction_engine.knowledge_extraction.chunk_extraction import DialogueChunker - chunker = DialogueChunker(chunker_strategy) - extracted_chunks = await chunker.process_dialogue(dialog_data) - dialog_data.chunks = extracted_chunks - - dialog_data_list.append(dialog_data) - - # 保存输出 - def serialize_datetime(obj): - if isinstance(obj, datetime): - return obj.isoformat() - raise TypeError( - f"Object of type {obj.__class__.__name__} is not JSON serializable" - ) - - combined_output = [dd.model_dump() for dd in dialog_data_list] - from app.core.config import settings - settings.ensure_memory_output_dir() - output_path = settings.get_memory_output_path("chunker_test_output.txt") - - import json - with open(output_path, "w", encoding="utf-8") as f: - json.dump( - combined_output, f, ensure_ascii=False, indent=4, default=serialize_datetime - ) - - return dialog_data_list - - -def preprocess_data( - input_path: Optional[str] = None, - output_path: Optional[str] = None, - skip_cleaning: bool = True, - indices: Optional[List[int]] = None -) -> List[DialogData]: - """数据预处理 - - Args: - input_path: 原始数据路径 - output_path: 预处理后数据保存路径 - skip_cleaning: 是否跳过数据清洗步骤(默认False) - indices: 要处理的数据索引列表 - - Returns: - 经过清洗转换后的 DialogData 列表 - """ - print("\n=== 数据预处理 ===") - from app.core.memory.storage_services.extraction_engine.data_preprocessing.data_preprocessor import DataPreprocessor - preprocessor = DataPreprocessor() - try: - cleaned_data = preprocessor.preprocess(input_path=input_path, output_path=output_path, skip_cleaning=skip_cleaning, indices=indices) - print(f"数据预处理完成!共处理了 {len(cleaned_data)} 条对话数据") - return cleaned_data - except Exception as e: - print(f"数据预处理过程中出现错误: {e}") - raise - - -async def get_chunked_dialogs_from_preprocessed( - data: List[DialogData], - chunker_strategy: str = "RecursiveChunker", - llm_client: Optional[Any] = None, -) -> List[DialogData]: - """从预处理后的数据中生成分块 - - Args: - data: 预处理后的 DialogData 列表 - chunker_strategy: 分块策略 - llm_client: LLM 客户端(用于 LLMChunker) - - Returns: - 带 chunks 的 DialogData 列表 - """ - print(f"\n=== 批量对话分块处理 (使用 {chunker_strategy}) ===") - if not data: - raise ValueError("预处理数据为空,无法进行分块") - - all_chunked_dialogs: List[DialogData] = [] - from app.core.memory.storage_services.extraction_engine.knowledge_extraction.chunk_extraction import DialogueChunker - - for dialog_data in data: - chunker = DialogueChunker(chunker_strategy, llm_client=llm_client) - chunks = await chunker.process_dialogue(dialog_data) - dialog_data.chunks = chunks - all_chunked_dialogs.append(dialog_data) - - return all_chunked_dialogs - - -async def get_chunked_dialogs_with_preprocessing( - chunker_strategy: str = "RecursiveChunker", - group_id: str = "default", - user_id: str = "default", - apply_id: str = "default", - indices: Optional[List[int]] = None, - input_data_path: Optional[str] = None, - llm_client: Optional[Any] = None, - skip_cleaning: bool = True, -) -> List[DialogData]: - """包含数据预处理步骤的完整分块流程 - - Args: - chunker_strategy: 分块策略 - group_id: 组ID - user_id: 用户ID - apply_id: 应用ID - indices: 要处理的数据索引列表 - input_data_path: 输入数据路径 - llm_client: LLM 客户端 - skip_cleaning: 是否跳过数据清洗步骤(默认False) - - Returns: - 带 chunks 的 DialogData 列表 - """ - import os - print("\n=== 完整数据处理流程(包含预处理)===") - - if input_data_path is None: - input_data_path = os.path.join( - os.path.dirname(__file__), "../../data", "testdata.json" - ) - - # 步骤1: 数据预处理(包含索引筛选) - from app.core.config import settings - settings.ensure_memory_output_dir() - preprocessed_data = preprocess_data( - input_path=input_data_path, - output_path=settings.get_memory_output_path("preprocessed_data.json"), - skip_cleaning=skip_cleaning, - indices=indices, - ) - - # 设置 group_id, user_id, apply_id - for dd in preprocessed_data: - dd.group_id = group_id - dd.user_id = user_id - dd.apply_id = apply_id - - # 步骤2: 语义剪枝 - try: - from app.core.memory.storage_services.extraction_engine.data_preprocessing.data_pruning import SemanticPruner - pruner = SemanticPruner(llm_client=llm_client) - - # 记录单对话场景下剪枝前的消息数量 - single_dialog_original_msgs = None - if len(preprocessed_data) == 1 and preprocessed_data[0].context: - single_dialog_original_msgs = len(preprocessed_data[0].context.msgs) - - preprocessed_data = await pruner.prune_dataset(preprocessed_data) - - # 单对话:打印清洗与剪枝信息 - if len(preprocessed_data) == 1 and single_dialog_original_msgs is not None: - remaining_msgs = len(preprocessed_data[0].context.msgs) if preprocessed_data[0].context else 0 - deleted_msgs = max(0, single_dialog_original_msgs - remaining_msgs) - print( - f"语义剪枝完成!剩余 1 条对话!原始消息数:{single_dialog_original_msgs}," - f"保留消息数:{remaining_msgs},删除 {deleted_msgs} 条。" - ) - else: - print(f"语义剪枝完成!剩余 {len(preprocessed_data)} 条对话") - - # 保存剪枝后的数据 - try: - from app.core.memory.storage_services.extraction_engine.data_preprocessing.data_preprocessor import DataPreprocessor - pruned_output_path = settings.get_memory_output_path("pruned_data.json") - dp = DataPreprocessor(output_file_path=pruned_output_path) - dp.save_data(preprocessed_data, output_path=pruned_output_path) - except Exception as se: - print(f"保存剪枝结果失败:{se}") - except Exception as e: - print(f"语义剪枝过程中出现错误,跳过剪枝: {e}") - - # 步骤3: 对话分块 - return await get_chunked_dialogs_from_preprocessed( - preprocessed_data, - chunker_strategy=chunker_strategy, - llm_client=llm_client, - ) diff --git a/app/core/memory/storage_services/extraction_engine/knowledge_extraction/__init__.py b/app/core/memory/storage_services/extraction_engine/knowledge_extraction/__init__.py deleted file mode 100644 index 53815124..00000000 --- a/app/core/memory/storage_services/extraction_engine/knowledge_extraction/__init__.py +++ /dev/null @@ -1,11 +0,0 @@ -""" -知识提取模块 - -包含以下提取器: -- DialogueChunker: 对话分块 -- StatementExtractor: 陈述句提取 -- TripletExtractor: 三元组提取 -- TemporalExtractor: 时间信息提取 -- EmbeddingGenerator: 嵌入向量生成 -- MemorySummaryGenerator: 记忆摘要生成 -""" diff --git a/app/core/memory/storage_services/extraction_engine/knowledge_extraction/chunk_extraction.py b/app/core/memory/storage_services/extraction_engine/knowledge_extraction/chunk_extraction.py deleted file mode 100644 index edb60a4d..00000000 --- a/app/core/memory/storage_services/extraction_engine/knowledge_extraction/chunk_extraction.py +++ /dev/null @@ -1,103 +0,0 @@ -import os -from typing import Optional - -from app.core.logging_config import get_memory_logger -from app.core.memory.models.message_models import DialogData, Chunk -from app.core.memory.models.config_models import ChunkerConfig -from app.core.memory.llm_tools.chunker_client import ChunkerClient -from app.core.memory.utils.config.config_utils import get_chunker_config - -logger = get_memory_logger(__name__) - - -class DialogueChunker: - """A class that processes dialogues and fills them with chunks based on a specified strategy. - - This class encapsulates the chunking process, allowing for easy configuration and application - of different chunking strategies to dialogue data. - """ - - def __init__(self, chunker_strategy: str = "RecursiveChunker", llm_client=None): - """Initialize the DialogueChunker with a specific chunking strategy. - - Args: - chunker_strategy: The chunking strategy to use (default: RecursiveChunker) - Options include: SemanticChunker, RecursiveChunker, LateChunker, NeuralChunker - """ - self.chunker_strategy = chunker_strategy - chunker_config_dict = get_chunker_config(chunker_strategy) - self.chunker_config = ChunkerConfig.model_validate(chunker_config_dict) - # 对于 LLMChunker,需要传入 llm_client - if self.chunker_config.chunker_strategy == "LLMChunker": - self.chunker_client = ChunkerClient(self.chunker_config, llm_client) - else: - self.chunker_client = ChunkerClient(self.chunker_config) - - async def process_dialogue(self, dialogue: DialogData) -> list[Chunk]: - """Process a dialogue by generating chunks and adding them to the DialogData object. - - Args: - dialogue: The DialogData object to process - - Returns: - A list of Chunk objects - """ - result_dialogue = await self.chunker_client.generate_chunks(dialogue) - # Defensive fallback: ensure at least one chunk is returned for non-empty content - try: - chunks = result_dialogue.chunks - except Exception: - chunks = [] - - if not chunks or len(chunks) == 0: - # If the dialogue has content, return a single fallback chunk built from messages - content_str = getattr(result_dialogue, "content", "") or getattr(dialogue, "content", "") - if content_str and len(content_str.strip()) > 0: - fallback_chunk = Chunk.from_messages( - dialogue.context.msgs, - metadata={ - "fallback": "single_chunk", - "chunker_strategy": self.chunker_config.chunker_strategy, - "source": "DialogueChunkerFallback", - }, - ) - return [fallback_chunk] - # No content: return empty list - return [] - - return chunks - - def save_chunking_results(self, dialogue: DialogData, output_path: Optional[str] = None) -> str: - """Save the chunking results to a file and return the output path. - - Args: - dialogue: The processed DialogData object with chunks - output_path: Optional path to save the output (default: chunker_output_{strategy}.txt) - - Returns: - The path where the output was saved - """ - if not output_path: - output_path = os.path.join(os.path.dirname(__file__), "..", "..", - f"chunker_output_{self.chunker_strategy.lower()}.txt") - - output_lines = [] - output_lines.append(f"=== Chunking Results ({self.chunker_strategy}) ===") - output_lines.append(f"Dialogue ID: {dialogue.ref_id}") - output_lines.append(f"Original conversation has {len(dialogue.context.msgs)} messages") - output_lines.append(f"Total characters: {len(dialogue.content)}") - - output_lines.append(f"Generated {len(dialogue.chunks)} chunks:") - for i, chunk in enumerate(dialogue.chunks): - output_lines.append(f" Chunk {i+1}: {len(chunk.content)} characters") - output_lines.append(f" Content preview: {chunk.content}...") - if chunk.metadata: - output_lines.append(f" Metadata: {chunk.metadata}") - - with open(output_path, "w", encoding="utf-8") as f: - f.write("\n".join(output_lines)) - - logger.info(f"Chunking results saved to: {output_path}") - return output_path - - diff --git a/app/core/memory/storage_services/extraction_engine/knowledge_extraction/embedding_generation.py b/app/core/memory/storage_services/extraction_engine/knowledge_extraction/embedding_generation.py deleted file mode 100644 index 0dc48815..00000000 --- a/app/core/memory/storage_services/extraction_engine/knowledge_extraction/embedding_generation.py +++ /dev/null @@ -1,307 +0,0 @@ -""" -嵌入向量生成器 - -为陈述句、分块、对话和实体生成嵌入向量,用于语义搜索。 -""" - -import asyncio -from typing import List, Dict, Any, Tuple -from app.core.memory.models.message_models import DialogData -from app.core.memory.src.llm_tools.openai_embedder import OpenAIEmbedderClient -from app.core.memory.utils.config.config_utils import get_embedder_config -from app.core.models.base import RedBearModelConfig - - -class EmbeddingGenerator: - """嵌入向量生成器""" - - def __init__(self, embedding_id: str): - """初始化嵌入向量生成器 - - Args: - embedding_id: 嵌入模型 ID - """ - embedder_config = get_embedder_config(embedding_id) - self.embedder_client = OpenAIEmbedderClient( - model_config=RedBearModelConfig.model_validate(embedder_config), - ) - - async def _generate_embeddings(self, texts: List[str], batch_size: int = 100) -> List[List[float]]: - """生成一批文本的嵌入向量(支持分批并行) - - Args: - texts: 文本列表 - batch_size: 每批处理的文本数量(默认 100) - - Returns: - 嵌入向量列表 - """ - if not texts: - return [] - - # 如果文本数量小于批次大小,直接处理 - if len(texts) <= batch_size: - return await self.embedder_client.response(texts) - - # 分批并行处理 - print(f"文本数量 {len(texts)} 超过批次大小 {batch_size},分批并行处理") - batches = [texts[i:i+batch_size] for i in range(0, len(texts), batch_size)] - print(f"分成 {len(batches)} 批,每批最多 {batch_size} 个文本") - - # 并行发送所有批次 - batch_results = await asyncio.gather(*[ - self.embedder_client.response(batch) for batch in batches - ]) - - # 合并结果 - embeddings = [] - for batch_result in batch_results: - embeddings.extend(batch_result) - - print(f"分批并行处理完成,共生成 {len(embeddings)} 个嵌入向量") - return embeddings - - async def generate_statement_embeddings( - self, - chunked_dialogs: List[DialogData] - ) -> List[Dict[str, List[float]]]: - """为所有对话中的陈述句生成嵌入向量 - - Args: - chunked_dialogs: 包含分块和陈述句的对话列表 - - Returns: - 每个对话的陈述句嵌入向量映射列表 - """ - print("\n=== 生成陈述句嵌入向量 ===") - - # 收集所有陈述句 - all_statements = [] - statement_to_dialog_chunk_map = [] - - for d_idx, dialog in enumerate(chunked_dialogs): - chunks = dialog.chunks - if asyncio.iscoroutine(chunks): - chunks = await chunks - for c_idx, chunk in enumerate(chunks): - for s_idx, stmt in enumerate(chunk.statements): - all_statements.append(stmt.statement) - statement_to_dialog_chunk_map.append((d_idx, c_idx, s_idx)) - - # 批量生成嵌入向量 - stmt_embeddings = await self._generate_embeddings(all_statements) - - # 创建映射 - stmt_embedding_maps = [{} for _ in chunked_dialogs] - for idx, embedding in enumerate(stmt_embeddings): - d_idx, c_idx, s_idx = statement_to_dialog_chunk_map[idx] - stmt_id = chunked_dialogs[d_idx].chunks[c_idx].statements[s_idx].id - stmt_embedding_maps[d_idx][stmt_id] = embedding - - print(f"为 {len(all_statements)} 个陈述句生成了嵌入向量") - return stmt_embedding_maps - - async def generate_chunk_embeddings( - self, - chunked_dialogs: List[DialogData] - ) -> List[Dict[str, List[float]]]: - """为所有对话中的分块生成嵌入向量 - - Args: - chunked_dialogs: 包含分块的对话列表 - - Returns: - 每个对话的分块嵌入向量映射列表 - """ - print("\n=== 生成分块嵌入向量 ===") - - # 收集所有分块 - all_chunks = [] - chunk_to_dialog_map = [] - - for d_idx, dialog in enumerate(chunked_dialogs): - for c_idx, chunk in enumerate(dialog.chunks): - all_chunks.append(chunk.content) - chunk_to_dialog_map.append((d_idx, c_idx)) - - # 批量生成嵌入向量 - chunk_embeddings = await self._generate_embeddings(all_chunks) - - # 创建映射 - chunk_embedding_maps = [{} for _ in chunked_dialogs] - for idx, embedding in enumerate(chunk_embeddings): - d_idx, c_idx = chunk_to_dialog_map[idx] - chunk_id = chunked_dialogs[d_idx].chunks[c_idx].id - chunk_embedding_maps[d_idx][chunk_id] = embedding - - print(f"为 {len(all_chunks)} 个分块生成了嵌入向量") - return chunk_embedding_maps - - async def generate_dialog_embeddings( - self, - chunked_dialogs: List[DialogData] - ) -> List[List[float]]: - """为对话生成嵌入向量(当前跳过,返回空列表) - - Args: - chunked_dialogs: 对话列表 - - Returns: - 对话嵌入向量列表(当前为空) - """ - # 跳过对话嵌入向量生成,但保持正确的长度 - return [[] for _ in chunked_dialogs] - - async def generate_all_embeddings( - self, - chunked_dialogs: List[DialogData] - ) -> Tuple[ - List[Dict[str, List[float]]], - List[Dict[str, List[float]]], - List[List[float]] - ]: - """生成所有类型的嵌入向量 - - Args: - chunked_dialogs: 包含分块和陈述句的对话列表 - - Returns: - (陈述句嵌入映射列表, 分块嵌入映射列表, 对话嵌入列表) - """ - print("\n=== 生成所有嵌入向量 ===") - - # 并发生成陈述句和分块嵌入向量 - stmt_embedding_maps, chunk_embedding_maps = await asyncio.gather( - self.generate_statement_embeddings(chunked_dialogs), - self.generate_chunk_embeddings(chunked_dialogs) - ) - - # 对话嵌入向量(当前跳过) - dialog_embeddings = await self.generate_dialog_embeddings(chunked_dialogs) - - print( - f"生成完成:{len(chunked_dialogs)} 个对话的嵌入向量" - ) - - return stmt_embedding_maps, chunk_embedding_maps, dialog_embeddings - - async def generate_entity_embeddings( - self, - triplet_maps: List[Dict[str, Any]] - ) -> List[Dict[str, Any]]: - """为三元组中的实体生成嵌入向量 - - Args: - triplet_maps: 三元组映射列表 - - Returns: - 更新后的三元组映射列表(实体包含嵌入向量) - """ - print("\n=== 生成实体嵌入向量 ===") - - entity_texts: List[str] = [] - entity_refs: List[Any] = [] - - # 收集所有实体 - for trip_map in triplet_maps: - for _, triplet_info in trip_map.items(): - entities = getattr(triplet_info, "entities", None) - if not entities: - continue - for ent in entities: - text = getattr(ent, "name", None) or getattr(ent, "description", None) - if text: - entity_texts.append(text) - entity_refs.append(ent) - - if not entity_texts: - print("没有找到需要生成嵌入向量的实体") - return triplet_maps - - # 批量生成嵌入向量 - embeddings = await self._generate_embeddings(entity_texts) - - # 打印前几个嵌入向量的维度 - for i in range(min(5, len(embeddings))): - print(f"实体 '{entity_texts[i]}' 嵌入向量维度: {len(embeddings[i])}") - - # 将嵌入向量赋值给实体 - for ent, emb in zip(entity_refs, embeddings): - setattr(ent, "name_embedding", emb) - - print(f"为 {len(entity_refs)} 个实体生成了嵌入向量") - return triplet_maps - - -# 保持向后兼容的函数接口 -async def embedding_generation( - chunked_dialogs: List[DialogData], - embedding_id: str -) -> Tuple[ - List[Dict[str, List[float]]], - List[Dict[str, List[float]]], - List[List[float]] -]: - """生成陈述句、分块和对话的嵌入向量(向后兼容接口) - - Args: - chunked_dialogs: 包含分块和陈述句的对话列表 - embedding_id: 嵌入模型 ID - - Returns: - (陈述句嵌入映射列表, 分块嵌入映射列表, 对话嵌入列表) - """ - generator = EmbeddingGenerator(embedding_id) - return await generator.generate_all_embeddings(chunked_dialogs) - - -async def generate_entity_embeddings_from_triplets( - triplet_maps: List[Dict[str, Any]], - embedding_id: str -) -> List[Dict[str, Any]]: - """为三元组中的实体生成嵌入向量(向后兼容接口) - - Args: - triplet_maps: 三元组映射列表 - embedding_id: 嵌入模型 ID - - Returns: - 更新后的三元组映射列表(实体包含嵌入向量) - """ - generator = EmbeddingGenerator(embedding_id) - return await generator.generate_entity_embeddings(triplet_maps) - - -async def embedding_generation_all( - chunked_dialogs: List[DialogData], - triplet_maps: List[Dict[str, Any]], - embedding_id: str -) -> Tuple[ - List[Dict[str, List[float]]], - List[Dict[str, List[float]]], - List[List[float]], - List[Dict[str, Any]] -]: - """生成所有类型的嵌入向量(向后兼容接口) - - Args: - chunked_dialogs: 包含分块和陈述句的对话列表 - triplet_maps: 三元组映射列表 - embedding_id: 嵌入模型 ID - - Returns: - (陈述句嵌入映射列表, 分块嵌入映射列表, 对话嵌入列表, 更新后的三元组映射列表) - """ - print("\n=== 综合嵌入向量生成(陈述句/分块/对话 + 实体)===") - - generator = EmbeddingGenerator(embedding_id) - - # 生成陈述句、分块和对话的嵌入向量 - stmt_embedding_maps, chunk_embedding_maps, dialog_embeddings = await generator.generate_all_embeddings( - chunked_dialogs - ) - - # 生成实体嵌入向量 - updated_triplet_maps = await generator.generate_entity_embeddings(triplet_maps) - - return stmt_embedding_maps, chunk_embedding_maps, dialog_embeddings, updated_triplet_maps diff --git a/app/core/memory/storage_services/extraction_engine/knowledge_extraction/memory_summary.py b/app/core/memory/storage_services/extraction_engine/knowledge_extraction/memory_summary.py deleted file mode 100644 index 4c62bd4c..00000000 --- a/app/core/memory/storage_services/extraction_engine/knowledge_extraction/memory_summary.py +++ /dev/null @@ -1,117 +0,0 @@ -import os -import asyncio -from datetime import datetime -from typing import List, Optional - -from pydantic import Field, field_validator - -from app.core.logging_config import get_memory_logger -from app.core.memory.models.message_models import DialogData - -logger = get_memory_logger(__name__) -from app.core.memory.models.graph_models import MemorySummaryNode -from app.core.memory.models.base_response import RobustLLMResponse -from app.core.models.base import RedBearModelConfig -from app.core.memory.src.llm_tools.openai_embedder import OpenAIEmbedderClient -from app.core.memory.utils.config.config_utils import get_embedder_config -from app.core.memory.utils.prompt.prompt_utils import render_memory_summary_prompt -from uuid import uuid4 - - -class MemorySummaryResponse(RobustLLMResponse): - """Structured response for summary generation per chunk. - - This model ensures the LLM returns a valid, non-empty summary. - Inherits robust validation from RobustLLMResponse. - """ - summary: str = Field( - ..., - description="Concise memory summary for a single chunk. Must be a meaningful, non-empty string.", - min_length=1, - max_length=5000 - ) - - -async def _process_chunk_summary( - dialog: DialogData, - chunk, - llm_client, - embedder: OpenAIEmbedderClient, -) -> Optional[MemorySummaryNode]: - """Process a single chunk to generate a memory summary node.""" - # Skip empty chunks - if not chunk.content or not chunk.content.strip(): - return None - - try: - # Render prompt via Jinja2 for a single chunk - prompt_content = await render_memory_summary_prompt( - chunk_texts=chunk.content, - json_schema=MemorySummaryResponse.model_json_schema(), - max_words=200, - ) - - messages = [ - {"role": "system", "content": "You are an expert memory summarizer."}, - {"role": "user", "content": prompt_content}, - ] - - # Generate structured summary with the existing LLM client - structured = await llm_client.response_structured( - messages=messages, - response_model=MemorySummaryResponse, - ) - summary_text = structured.summary.strip() - - # Embed the summary - embedding = (await embedder.response([summary_text]))[0] - - # Build node per chunk - node = MemorySummaryNode( - id=uuid4().hex, - name=f"MemorySummaryChunk_{chunk.id}", - group_id=dialog.group_id, - user_id=dialog.user_id, - apply_id=dialog.apply_id, - run_id=dialog.run_id, # 使用 dialog 的 run_id - created_at=datetime.now(), - expired_at=datetime(9999, 12, 31), - dialog_id=dialog.id, - chunk_ids=[chunk.id], - content=summary_text, - summary_embedding=embedding, - metadata={"ref_id": dialog.ref_id}, - config_id=dialog.config_id, # 添加 config_id - ) - return node - - except Exception as e: - # Log the error but continue processing other chunks - logger.warning(f"Failed to generate summary for chunk {chunk.id} in dialog {dialog.id}: {e}", exc_info=True) - return None - - -async def Memory_summary_generation( - chunked_dialogs: List[DialogData], - llm_client, - embedding_id, -) -> List[MemorySummaryNode]: - """Generate memory summaries per chunk, embed them, and return nodes.""" - embedder_cfg_dict = get_embedder_config(embedding_id) - embedder = OpenAIEmbedderClient( - model_config=RedBearModelConfig.model_validate(embedder_cfg_dict), - ) - - # Collect all tasks for parallel processing - tasks = [] - for dialog in chunked_dialogs: - for chunk in dialog.chunks: - tasks.append(_process_chunk_summary(dialog, chunk, llm_client, embedder)) - - # Process all chunks in parallel - results = await asyncio.gather(*tasks, return_exceptions=False) - - # Filter out None values (failed or empty chunks) - nodes = [node for node in results if node is not None] - - return nodes diff --git a/app/core/memory/storage_services/extraction_engine/knowledge_extraction/statement_extraction.py b/app/core/memory/storage_services/extraction_engine/knowledge_extraction/statement_extraction.py deleted file mode 100644 index 1e79c339..00000000 --- a/app/core/memory/storage_services/extraction_engine/knowledge_extraction/statement_extraction.py +++ /dev/null @@ -1,301 +0,0 @@ -import os -import asyncio -import logging -from typing import List, Optional, Dict, Any -from pydantic import BaseModel, Field -from datetime import datetime - -from app.core.memory.models.message_models import DialogData, Statement -#避免在测试收集阶段因为 OpenAIClient 间接引入 langfuse 导致 ModuleNotFoundError 。这只是类型注解与导入时机的调整,不改变实现。 -from app.core.memory.utils.data.ontology import LABEL_DEFINITIONS, StatementType, TemporalInfo - -from app.core.memory.models.variate_config import StatementExtractionConfig -from app.core.memory.utils.prompt.prompt_utils import render_statement_extraction_prompt -from app.core.memory.utils.data.ontology import LABEL_DEFINITIONS, StatementType, TemporalInfo, RelevenceInfo - -logger = logging.getLogger(__name__) - -class ExtractedStatement(BaseModel): - """Schema for extracted statement from LLM""" - statement: str = Field(..., description="The extracted statement text") - statement_type: str = Field(..., description="FACT, OPINION,SUGGESTION or PREDICTION") - temporal_type: str = Field(..., description="STATIC, DYNAMIC, ATEMPORAL") - relevence: str = Field(..., description="RELEVANT or IRRELEVANT") - -# 统一使用 StatementExtractionResponse 作为 LLM 的结构化返回(仅语句) -class StatementExtractionResponse(BaseModel): - statements: List[ExtractedStatement] = Field(default_factory=list, description="List of extracted statements") - -class StatementExtractor: - """Class for extracting statements from dialog chunks using LLM (relations separated)""" - - def __init__(self, llm_client: Any, config: StatementExtractionConfig = None): - # 避免在测试收集阶段因为 OpenAIClient 间接引入 langfuse 导致 ModuleNotFoundError 。这只是类型注解与导入时机的调整,不改变实现。 - """Initialize the StatementExtractor with an LLM client and configuration - - Args: - llm_client: OpenAIClient instance for processing LLM requests - config: StatementExtractionConfig for controlling extraction behavior - """ - self.llm_client = llm_client - self.config = config or StatementExtractionConfig() - - async def _extract_statements(self, chunk, group_id: Optional[str] = None, dialogue_content: str = None) -> List[Statement]: - """Process a single chunk and return extracted statements - - Args: - chunk: Chunk object to process - group_id: Group ID to assign to all statements in this chunk - dialogue_content: Full dialogue content to provide as context - - Returns: - List of ExtractedStatement objects extracted from the chunk - """ - # Prepare the chunk content for processing - chunk_content = chunk.content - - # Render the prompt using helper function - prompt_content = await render_statement_extraction_prompt( - chunk_content=chunk_content, - definitions=LABEL_DEFINITIONS, - json_schema=ExtractedStatement.model_json_schema(), - granularity=self.config.statement_granularity, - include_dialogue_context=self.config.include_dialogue_context, - dialogue_content=dialogue_content, - max_dialogue_chars=self.config.max_dialogue_context_chars - ) - - # Simple system message - system_content = "You are an expert at extracting and labeling atomic statements from conversational text. Return valid JSON conforming to the schema." - - # Create messages for LLM - messages = [ - {"role": "system", "content": system_content}, - {"role": "user", "content": prompt_content} - ] - - try: - # Get structured response from LLM (statements only) - response = await self.llm_client.response_structured(messages, StatementExtractionResponse) - # Defensive: ensure response has the expected structure - if not hasattr(response, "statements") or response.statements is None: - logger.warning("Invalid structured response: missing 'statements'. Returning empty list for this chunk.") - return [] - - # Convert extracted statements to Statement objects - chunk_statements = [] - for extracted_stmt in response.statements: - # Normalize and correct enums defensively - stmt_type_str = str(extracted_stmt.statement_type).strip().upper() - temporal_type_str = str(extracted_stmt.temporal_type).strip().upper() - relevence_str = str(extracted_stmt.relevence).strip().upper() - - # Convert strings to enum types with fallback defaults - try: - stmt_type = StatementType[stmt_type_str] if stmt_type_str in StatementType.__members__ else StatementType.FACT - except (KeyError, ValueError): - stmt_type = StatementType.FACT - - try: - temporal_type = TemporalInfo[temporal_type_str] if temporal_type_str in TemporalInfo.__members__ else TemporalInfo.STATIC - except (KeyError, ValueError): - temporal_type = TemporalInfo.STATIC - - try: - relevence_info = RelevenceInfo[relevence_str] if relevence_str in RelevenceInfo.__members__ else RelevenceInfo.RELEVANT - except (KeyError, ValueError): - relevence_info = RelevenceInfo.RELEVANT - - chunk_statement = Statement( - statement=extracted_stmt.statement, - stmt_type=stmt_type, - temporal_info=temporal_type, - relevence_info=relevence_info, - chunk_id=chunk.id, - group_id=group_id, - ) - chunk_statements.append(chunk_statement) - - # 分离强弱关系分类:不在句子提取阶段进行,也不写入 chunk.metadata - return chunk_statements - - except Exception as e: - logger.error(f"Error processing chunk: {e}", exc_info=True) - # Return empty list to indicate failure for this chunk - return [] - - async def extract_statements(self, dialog_data: DialogData, limit_chunks: int = None) -> List[List[Statement]]: - """Extract statements from a DialogData object. - - Args: - dialog_data: The DialogData object containing chunks. - limit_chunks: Optional limit on the number of chunks to process. - """ - # Determine how many chunks to process - chunks_to_process = dialog_data.chunks[:limit_chunks] if limit_chunks else dialog_data.chunks - - logger.info(f"Processing {len(chunks_to_process)} chunks for statement extraction") - - # Process all chunks concurrently, passing the group_id and dialogue content from dialog_data - dialogue_content = dialog_data.content if self.config.include_dialogue_context else None - results = await asyncio.gather( - *[self._extract_statements(chunk, dialog_data.group_id, dialogue_content) for chunk in chunks_to_process], - return_exceptions=True - ) - - # Filter out exceptions and return valid results - valid_results = [] - for result in results: - if isinstance(result, list) and result is not None: - valid_results.append(result) - else: - print(f"Error in statement extraction: {result}") - valid_results.append([]) - - return valid_results - - def save_statements(self, statements: List[Statement], output_path: str = None) -> str: - """Save the extracted statements to a file and return the output path. - - Args: - statements: List of Statement objects to save - output_path: Optional path to save the output (default: statement_extraction.txt) - - Returns: - The path where the output was saved - """ - # 使用全局配置的输出路径 - if not output_path: - from app.core.config import settings - settings.ensure_memory_output_dir() - output_path = settings.get_memory_output_path("statement_extraction.txt") - - with open(output_path, "w", encoding="utf-8") as f: - f.write(f"Extracted Statements ({len(statements)} total)\n") - f.write("=" * 50 + "\n\n") - - for i, statement in enumerate(statements, 1): - f.write(f"Statement {i}:\n") - f.write(f"Id: {statement.id}\n") - f.write(f"Group Id: {statement.group_id}\n") - f.write(f"Content: {statement.statement}\n") - f.write(f"Type: {statement.stmt_type.value}\n") - f.write(f"Temporal Info: {statement.temporal_info.value}\n") - f.write(f"Created At: {datetime.now()}\n") - f.write(f"Expired At: {None}\n") - f.write(f"Valid At: {statement.temporal_validity.valid_at if statement.temporal_validity else None}\n") - f.write(f"Invalid At: {statement.temporal_validity.invalid_at if statement.temporal_validity else None}\n") - f.write(f"Chunk Id: {statement.chunk_id}\n") - # add relevance information to satisfy tests - if hasattr(statement, "relevence_info") and statement.relevence_info is not None: - f.write(f"Relevence Info: {statement.relevence_info.value}\n") - f.write("-" * 30 + "\n\n") - - print(f"Extracted {len(statements)} statements and saved to {output_path}") - return output_path - - def save_relations(self, dialogs: List[DialogData], output_path: str = None) -> str: - """按对话分组聚合强/弱关系并写入 TXT 文件。 - - 每个对话单独成段:输出该对话的 `Dialog ID`、`Group ID`、`Content` - - 在该对话段内再分为 Strong Relations / Weak Relations 两部分 - - Strong: 逐条输出 `Chunk ID` 与 `Triple` - - Weak: 逐条输出 `Chunk ID` 与 `Entity` - """ - print("\n=== Relations Classify ===") - - # 使用全局配置的输出路径 - if not output_path: - from app.core.config import settings - settings.ensure_memory_output_dir() - output_path = settings.get_memory_output_path("relations_output.txt") - # output_path = os.path.join(os.path.dirname(__file__), "..", "relations_output.txt") - - dialog_sections: List[Dict[str, Any]] = [] - total_strong = 0 - total_weak = 0 - - for dialog in dialogs: - strong_relations: List[Dict[str, Any]] = [] - weak_relations: List[Dict[str, Any]] = [] - - for chunk in dialog.chunks or []: - # 基于三元组/实体推导强弱关系 - for stmt in chunk.statements or []: - te = getattr(stmt, "triplet_extraction_info", None) - if not te: - continue - trips = getattr(te, "triplets", []) or [] - ents = getattr(te, "entities", []) or [] - - # Strong: 逐条输出三元组 - if trips: - for trip in trips: - subj = getattr(trip, "subject_name", "") - pred = str(getattr(trip, "predicate", "")) - obj = getattr(trip, "object_name", "") - triple_str = f"({subj}, {pred}, {obj})" - strong_relations.append({ - "chunk_id": chunk.id, - "triple": triple_str, - }) - else: - # Weak: 无三元组但有实体 - for ent in ents: - name = getattr(ent, "name", "") - desc = getattr(ent, "description", "") or "" - entity_str = f"{name}: {desc}" if desc else name - if name: - weak_relations.append({ - "chunk_id": chunk.id, - "entity": entity_str, - }) - - total_strong += len(strong_relations) - total_weak += len(weak_relations) - - dialog_sections.append({ - "dialog_id": dialog.ref_id, - "group_id": dialog.group_id, - "content": dialog.content if getattr(dialog, "content", None) else "", - "strong": strong_relations, - "weak": weak_relations, - }) - - try: - with open(output_path, "w", encoding="utf-8") as f: - f.write(f"Relations Extraction (grouped by dialogs, strong: {total_strong}, weak: {total_weak})\n") - f.write("=" * 50 + "\n\n") - - for idx, section in enumerate(dialog_sections, 1): - f.write(f"Dialog {idx}:\n") - f.write(f"Dialog ID: {section.get('dialog_id', '')}\n") - f.write(f"Group ID: {section.get('group_id', '')}\n") - f.write("Content:\n") - f.write(f"{section.get('content', '')}\n") - f.write("-" * 40 + "\n\n") - - # Strong Relations for this dialog - strong_list = section.get("strong", []) - f.write(f"Strong Relations ({len(strong_list)} total)\n") - f.write("-" * 30 + "\n\n") - for i, item in enumerate(strong_list, 1): - f.write(f"Item {i}:\n") - f.write(f"Chunk ID: {item.get('chunk_id', '')}\n") - f.write(f"Triple: {item.get('triple', '')}\n") - f.write("-" * 30 + "\n\n") - - # Weak Relations for this dialog - weak_list = section.get("weak", []) - f.write(f"Weak Relations ({len(weak_list)} total)\n") - f.write("-" * 30 + "\n\n") - for i, item in enumerate(weak_list, 1): - f.write(f"Item {i}:\n") - f.write(f"Chunk ID: {item.get('chunk_id', '')}\n") - f.write(f"Entity: {item.get('entity', '')}\n") - f.write("-" * 30 + "\n\n") - - print(f"Saved relations to {output_path}") - return output_path - except Exception as e: - print(f"Failed to save relations to {output_path}: {e}") - return output_path diff --git a/app/core/memory/storage_services/extraction_engine/knowledge_extraction/temporal_extraction.py b/app/core/memory/storage_services/extraction_engine/knowledge_extraction/temporal_extraction.py deleted file mode 100644 index 646ae914..00000000 --- a/app/core/memory/storage_services/extraction_engine/knowledge_extraction/temporal_extraction.py +++ /dev/null @@ -1,222 +0,0 @@ -import os -import asyncio -from datetime import datetime -from typing import Any, Optional -from pydantic import BaseModel, Field -from app.core.memory.src.llm_tools.openai_client import OpenAIClient -from app.core.memory.models.message_models import DialogData, Statement, TemporalValidityRange -from app.core.memory.utils.prompt.prompt_utils import render_temporal_extraction_prompt -from app.core.memory.utils.data.ontology import LABEL_DEFINITIONS, TemporalInfo -from app.core.memory.utils.log.logging_utils import prompt_logger - - -class RawTemporalRange(BaseModel): - """Schema for the raw temporal range extracted by the LLM.""" - - valid_at: Optional[str] = Field( - None, description="The start date and time of the validity range in ISO 8601 format." - ) - invalid_at: Optional[str] = Field( - None, description="The end date and time of the validity range in ISO 8601 format." - ) - - -class TemporalExtractor: - """ - Extracts temporal validity ranges from statements using an LLM. - """ - - def __init__(self, llm_client: OpenAIClient): - """ - Initializes the TemporalExtractor. - - Args: - llm_client (OpenAIClient): The OpenAI client to use for LLM calls. - """ - self.llm_client = llm_client - - async def _extract_temporal_ranges( - self, statement: Statement, ref_dates: dict[str, Any] - ) -> TemporalValidityRange: - """ - Extracts the temporal range for a single statement. - - Args: - statement (Statement): The statement to process. - ref_dates (dict[str, Any]): Reference dates for context. - - Returns: - TemporalValidityRange: The extracted temporal validity range. - """ - if not ref_dates: - ref_dates = {"today": datetime.now().strftime("%Y-%m-%d")} - - if statement.temporal_info == TemporalInfo.ATEMPORAL: - return TemporalValidityRange(valid_at=None, invalid_at=None) - - temporal_guide = LABEL_DEFINITIONS["temporal_labelling"] - statement_guide = LABEL_DEFINITIONS["statement_labelling"] - - # Log start and input context - try: - prompt_logger.info(f"[Temporal] Started - statement_id={statement.id}") - prompt_logger.debug( - f"[Temporal] Input statement=\"{statement.statement}\" ref_dates={ref_dates}" - ) - except Exception: - pass - - prompt_content = await render_temporal_extraction_prompt( - ref_dates=ref_dates, - statement=statement.model_dump(), - temporal_guide=temporal_guide, - statement_guide=statement_guide, - json_schema=RawTemporalRange.model_json_schema(), - ) - - messages = [ - { - "role": "system", - "content": "You are an expert at extracting temporal validity ranges from statements. Follow the provided instructions carefully and return valid JSON.", - }, - {"role": "user", "content": prompt_content}, - ] - - try: - response = await self.llm_client.response_structured( - messages, RawTemporalRange - ) - if response: - # Log raw structured response - try: - prompt_logger.debug( - f"[Temporal] Raw structured response - statement_id={statement.id}: valid_at={response.valid_at}, invalid_at={response.invalid_at}" - ) - except Exception: - pass - return TemporalValidityRange( - valid_at=response.valid_at, invalid_at=response.invalid_at - ) - except Exception as e: - try: - prompt_logger.warning( - f"[Temporal] Failed to process statement_id={statement.id}. Error: {e}" - ) - except Exception: - pass - - return TemporalValidityRange(valid_at=None, invalid_at=None) - - from typing import Dict, Tuple - - async def extract_temporal_ranges( - self, dialog_data: DialogData, ref_dates: Optional[dict[str, Any]] = None - ) -> Dict[str, TemporalValidityRange]: - """ - Extracts temporal ranges for statements in the dialog_data. - - Args: - dialog_data (DialogData): The dialog data containing chunks with statements to process. - ref_dates (Optional[dict[str, Any]]): Reference dates for context. - - Returns: - Dict[str, TemporalValidityRange]: A dictionary mapping statement IDs to their temporal ranges. - """ - if ref_dates is None: - ref_dates = {} - - statement_temporal_map = {} - - # Header (match legacy format) - try: - prompt_logger.info("") - prompt_logger.info("=== TEMPORAL EXTRACTION RESULTS ===") - prompt_logger.info( - f"[Temporal] Dialog ref_id={getattr(dialog_data, 'ref_id', None)}, group_id={getattr(dialog_data, 'group_id', None)}" - ) - except Exception: - pass - - # Collect all statements with their IDs - all_tasks = [] - statement_ids = [] - - for chunk in dialog_data.chunks: - if not chunk.statements: - continue - - for statement in chunk.statements: - if statement.temporal_info == TemporalInfo.ATEMPORAL: - # Log skipped - try: - prompt_logger.info( - f"[Temporal] Skipped ATEMPORAL - statement_id={statement.id}" - ) - except Exception: - pass - statement_temporal_map[statement.id] = TemporalValidityRange( - valid_at=None, invalid_at=None - ) - continue - all_tasks.append(self._extract_temporal_ranges(statement, ref_dates)) - statement_ids.append(statement.id) - - # Process all statements concurrently - results = await asyncio.gather(*all_tasks, return_exceptions=True) - - # Map results back to statement IDs - for i, result in enumerate(results): - statement_id = statement_ids[i] - if isinstance(result, TemporalValidityRange): - statement_temporal_map[statement_id] = result - else: - try: - prompt_logger.warning( - f"[Temporal] Failed to process statement_id={statement_id}. Error: {result}" - ) - except Exception: - pass - statement_temporal_map[statement_id] = TemporalValidityRange( - valid_at=None, invalid_at=None - ) - - # Summary (match legacy completion line) - try: - extracted_count = sum( - 1 - for v in statement_temporal_map.values() - if (v.valid_at is not None or v.invalid_at is not None) - ) - prompt_logger.info( - f"[Temporal] Dialog ref_id={getattr(dialog_data, 'ref_id', None)} completed, extracted_valid_ranges={extracted_count}" - ) - except Exception: - pass - - return statement_temporal_map - - def save_temporal_extractions_to_file( - self, dialog_data: DialogData, output_path: Optional[str] = None - ): - """ - Saves the extracted temporal data to a text file. - - Args: - dialog_data (DialogData): The dialog data containing the statements with temporal data. - output_path (str): The path to the output file. - """ - if not output_path: - from app.core.config import settings - settings.ensure_memory_output_dir() - output_path = settings.get_memory_output_path("extracted_temporal_data.txt") - with open(output_path, "w") as f: - for chunk in dialog_data.chunks: - f.write(f"Chunk: {chunk.content}\n") - for statement in chunk.statements: - f.write(f" - Statement: {statement.statement}\n") - if statement.temporal_validity: - f.write(f" - Valid At: {statement.temporal_validity.valid_at}\n") - f.write(f" - Invalid At: {statement.temporal_validity.invalid_at}\n") - else: - f.write(f" - Temporal Validity: Not Extracted\n") - f.write("\n") diff --git a/app/core/memory/storage_services/extraction_engine/knowledge_extraction/triplet_extraction.py b/app/core/memory/storage_services/extraction_engine/knowledge_extraction/triplet_extraction.py deleted file mode 100644 index c65d5b74..00000000 --- a/app/core/memory/storage_services/extraction_engine/knowledge_extraction/triplet_extraction.py +++ /dev/null @@ -1,223 +0,0 @@ -import os -import asyncio -from typing import List, Dict - -from app.core.logging_config import get_memory_logger -from app.core.memory.src.llm_tools.openai_client import OpenAIClient -from app.core.memory.utils.prompt.prompt_utils import render_triplet_extraction_prompt -from app.core.memory.utils.data.ontology import PREDICATE_DEFINITIONS, Predicate # 引入枚举 Predicate 白名单过滤 -from app.core.memory.models.triplet_models import TripletExtractionResponse -from app.core.memory.models.message_models import DialogData, Statement -from app.core.memory.utils.log.logging_utils import prompt_logger - -logger = get_memory_logger(__name__) - - - -class TripletExtractor: - """Extracts knowledge triplets and entities from statements using LLM""" - - def __init__(self, llm_client: OpenAIClient): - """Initialize the TripletExtractor with an LLM client - - Args: - llm_client: OpenAIClient instance for processing - """ - self.llm_client = llm_client - - async def _extract_triplets(self, statement: Statement, chunk_content: str) -> TripletExtractionResponse: - """Process a single statement and return extracted triplets and entities""" - # Render the prompt using helper function - # Log start and input context similar to legacy logs - try: - prompt_logger.info(f"[Triplet] Started - statement_id={statement.id}") - prompt_logger.debug(f"[Triplet] Input statement=\"{statement.statement}\"") - except Exception: - # Avoid breaking flow due to logging issues - pass - - prompt_content = await render_triplet_extraction_prompt( - statement=statement.statement, - chunk_content=chunk_content, - json_schema=TripletExtractionResponse.model_json_schema(), - predicate_instructions=PREDICATE_DEFINITIONS - ) - - # Create messages for LLM - messages = [ - {"role": "system", "content": "You are an expert at extracting knowledge triplets and entities from text. Follow the provided instructions carefully and return valid JSON."}, - {"role": "user", "content": prompt_content} - ] - - try: - # Get structured response from LLM - response = await self.llm_client.response_structured(messages, TripletExtractionResponse) - # Filter triplets to only allowed predicates from ontology - # 这里过滤掉了不在 Predicate 枚举中的谓语 但是容易造成谓语太严格,有点语句的谓语没有在枚举中,就被判断为弱关系 - allowed_predicates = {p.value for p in Predicate} - filtered_triplets = [t for t in response.triplets if getattr(t, "predicate", "") in allowed_predicates] - # 仅保留predicate ∈ Predicate 的三元组,其余全部剔除 - - # Create new triplets with statement_id set during creation - updated_triplets = [] - for triplet in filtered_triplets: # 仅保留 predicate ∈ Predicate 的三元组 - updated_triplet = triplet.model_copy(update={"statement_id": statement.id}) - updated_triplets.append(updated_triplet) - - # Log completion and per-item details to match legacy format - try: - prompt_logger.info( - f"[Triplet] Completed - statement_id={statement.id}, triplets={len(updated_triplets)}, entities={len(response.entities)}" - ) - for i, t in enumerate(updated_triplets, 1): - prompt_logger.debug( - f"[Triplet] Triplet #{i}: ({t.subject_name}) - {t.predicate} - ({t.object_name}) value={t.value if t.value is not None else 'None'}" - ) - for i, e in enumerate(response.entities, 1): - prompt_logger.debug( - f"[Triplet] Entity #{i}: id={getattr(e, 'entity_idx', None)} name={getattr(e, 'name', None)} type={getattr(e, 'type', None)} desc={getattr(e, 'description', None)}" - ) - except Exception: - print(f"Error logging triplet details: {e}") - pass - - # Return new response with updated triplets - return TripletExtractionResponse( - triplets=updated_triplets, - entities=response.entities - ) - # # Set statement_id for each triplet to establish parent relationship - # for triplet in response.triplets: - # triplet.statement_id = statement.id - - # return response - - except Exception as e: - logger.error(f"Error processing statement: {e}", exc_info=True) - return TripletExtractionResponse(triplets=[], entities=[]) - - async def extract_triplets_from_statements(self, dialog_data: DialogData, limit_chunks: int = None) -> Dict[str, TripletExtractionResponse]: - """Extract triplets and entities from statements - - Args: - dialog_data: DialogData object to process - limit_chunks: Number of chunks to process - - Returns: - Dict[str, TripletExtractionResponse]: Dictionary mapping statement IDs to their triplet responses - """ - # Collect all statements from the specified chunks - all_statements = [] - chunks_to_process = dialog_data.chunks[:limit_chunks] if limit_chunks else dialog_data.chunks - - for chunk in chunks_to_process: - all_statements.extend(chunk.statements) - - logger.info(f"Processing {len(all_statements)} statements for triplet extraction...") - try: - prompt_logger.info( - f"[Triplet] Dialog ref_id={getattr(dialog_data, 'ref_id', None)}, group_id={getattr(dialog_data, 'group_id', None)}, statements_to_process={len(all_statements)}" - ) - except Exception: - pass - - # Prepare tasks and statement IDs - tasks = [] - statement_ids = [] - - for chunk in chunks_to_process: - for statement in chunk.statements: - tasks.append(self._extract_triplets(statement, chunk.content)) - statement_ids.append(statement.id) - - results = await asyncio.gather(*tasks, return_exceptions=True) - - # Map results to statement IDs - statement_triplet_map = {} - for i, result in enumerate(results): - statement_id = statement_ids[i] - if isinstance(result, TripletExtractionResponse): - statement_triplet_map[statement_id] = result - else: - logger.error(f"Error in triplet extraction for statement {statement_id}: {result}", exc_info=True) - statement_triplet_map[statement_id] = TripletExtractionResponse(triplets=[], entities=[]) - - # Dialog-level summary and details (match legacy format) - try: - # Flatten totals - all_triplets = [] - all_entities_with_stmt = [] - for sid, resp in statement_triplet_map.items(): - for t in resp.triplets: - all_triplets.append(t) - for e in resp.entities: - all_entities_with_stmt.append((sid, e)) - - prompt_logger.info( - f"[Triplet] Dialog ref_id={getattr(dialog_data, 'ref_id', None)} completed, total_triplets={len(all_triplets)}, total_entities={len(all_entities_with_stmt)}" - ) - - # Triplets Detail section - prompt_logger.info("\n--- Triplets Detail ---") - for i, t in enumerate(all_triplets, 1): - prompt_logger.info( - f"[Triplet] #{i} statement_id={getattr(t, 'statement_id', None)} subject=({getattr(t, 'subject_name', None)}:{getattr(t, 'subject_id', None)}) predicate={getattr(t, 'predicate', None)} object=({getattr(t, 'object_name', None)}:{getattr(t, 'object_id', None)}) value={getattr(t, 'value', None) if getattr(t, 'value', None) is not None else 'None'}" - ) - - # Entities Detail section - prompt_logger.info("\n--- Entities Detail ---") - for i, (sid, e) in enumerate(all_entities_with_stmt, 1): - prompt_logger.info( - f"[Entity] #{i} statement_id={sid} id={getattr(e, 'entity_idx', None)} name={getattr(e, 'name', None)} type={getattr(e, 'type', None)} desc={getattr(e, 'description', None)}" - ) - except Exception: - pass - - return statement_triplet_map - - def save_triplets(self, triplet_responses: List[TripletExtractionResponse], output_path: str = None) -> str: - """Save extracted triplets and entities to a file - - Args: - triplet_responses: List of TripletExtractionResponse objects - output_path: Optional path to save the results - - Returns: - Path where the triplets were saved - """ - if output_path is None: - from app.core.config import settings - settings.ensure_memory_output_dir() - output_path = settings.get_memory_output_path("extracted_triplets.txt") - - # Flatten all triplets and entities - all_triplets = [] - all_entities = [] - - for response in triplet_responses: - all_triplets.extend(response.triplets) - all_entities.extend(response.entities) - - # Save to file - with open(output_path, "w", encoding="utf-8") as f: - f.write(f"=== EXTRACTED TRIPLETS ({len(all_triplets)} total) ===\n\n") - for i, triplet in enumerate(all_triplets, 1): - f.write(f"Triplet {i}:\n") - f.write(f" Subject: {triplet.subject_name} (ID: {triplet.subject_id})\n") - f.write(f" Predicate: {triplet.predicate}\n") - f.write(f" Object: {triplet.object_name} (ID: {triplet.object_id})\n") - if triplet.value: - f.write(f" Value: {triplet.value}\n") - f.write("\n") - - f.write(f"\n=== EXTRACTED ENTITIES ({len(all_entities)} total) ===\n\n") - for i, entity in enumerate(all_entities, 1): - f.write(f"Entity {i}:\n") - f.write(f" ID: {entity.entity_idx}\n") - f.write(f" Name: {entity.name}\n") - f.write(f" Type: {entity.type}\n") - f.write(f" Description: {entity.description}\n") - f.write("\n") - - logger.info(f"Saved {len(all_triplets)} triplets and {len(all_entities)} entities to: {output_path}") - return output_path diff --git a/app/core/memory/storage_services/extraction_engine/pipeline_help.py b/app/core/memory/storage_services/extraction_engine/pipeline_help.py deleted file mode 100644 index f6b709cd..00000000 --- a/app/core/memory/storage_services/extraction_engine/pipeline_help.py +++ /dev/null @@ -1,528 +0,0 @@ -""" -提取流水线工具函数 - -该模块提供知识提取流水线的辅助工具函数,包括: -1. 解析和格式化提取结果 -2. 生成提取结果汇总报告 -3. 导出测试输入文档 - -这些函数主要用于: -- 解析三元组和实体信息 -- 统计去重和消歧效果 -- 生成可读的结果报告 - -作者:Memory Refactoring Team -原路径:app/core/memory/src/pipeline_help.py(已迁移) -迁移日期:2025-11-22 -""" - -import os -import re -import json -from datetime import datetime -from collections import defaultdict - - -def _parse_triplets_from_file(filepath): - """解析三元组文件,返回三元组列表""" - triplets = [] - if not os.path.exists(filepath): - return triplets - - try: - with open(filepath, 'r', encoding='utf-8') as f: - content = f.read() - - lines = content.split('\n') - current_triplet = {} - - for line in lines: - line = line.strip() - if line.startswith('Triplet '): - if current_triplet: - triplets.append(current_triplet) - current_triplet = {} - elif line.startswith('Subject:'): - subject = line.replace('Subject:', '').strip() - subject = subject.split('(ID:')[0].strip() - current_triplet['subject'] = subject - elif line.startswith('Predicate:'): - predicate = line.replace('Predicate:', '').strip() - current_triplet['predicate'] = predicate - elif line.startswith('Object:'): - obj = line.replace('Object:', '').strip() - obj = obj.split('(ID:')[0].strip() - current_triplet['object'] = obj - - if current_triplet: - triplets.append(current_triplet) - except Exception as e: - print(f"解析三元组文件失败: {e}") - - return triplets - - -def _parse_entities_from_triplets(filepath): - """从三元组文件中解析实体信息,按类型分组""" - entities_by_type = defaultdict(list) - - if not os.path.exists(filepath): - return entities_by_type - - try: - with open(filepath, 'r', encoding='utf-8') as f: - content = f.read() - - if '=== EXTRACTED ENTITIES' in content: - entity_section = content.split('=== EXTRACTED ENTITIES')[1] - lines = entity_section.split('\n') - - current_entity = {} - for line in lines: - line = line.strip() - if line.startswith('Entity '): - if current_entity and 'name' in current_entity and 'type' in current_entity: - entities_by_type[current_entity['type']].append(current_entity['name']) - current_entity = {} - elif line.startswith('Name:'): - name = line.replace('Name:', '').strip() - current_entity['name'] = name - elif line.startswith('Type:'): - entity_type = line.replace('Type:', '').strip() - current_entity['type'] = entity_type - - if current_entity and 'name' in current_entity and 'type' in current_entity: - entities_by_type[current_entity['type']].append(current_entity['name']) - - # 去重 - for entity_type in entities_by_type: - entities_by_type[entity_type] = list(set(entities_by_type[entity_type])) - except Exception as e: - print(f"解析实体信息失败: {e}") - - return entities_by_type - - -def _format_predicate(predicate): - """格式化谓词为中文""" - predicate_map = { - 'COLLABORATES_WITH': '同事', - 'MENTIONS': '提到', - 'DEVELOPED': '开发', - 'PART_OF': '参与', - 'LOCATED_IN': '位于', - 'WORKS_AT': '工作于', - 'PURCHASED': '购买', - 'INTERESTED_IN': '感兴趣' - } - return predicate_map.get(predicate, predicate.lower().replace('_', ' ')) - - -def _write_extracted_result_summary( - chunk_nodes, - pipeline_output_dir: str, -): - """ - 汇总生成 logs/memory-output/extracted_result.json,包含: - - 提取实体数(从 extracted_entities_edges.txt 的 ENTITY 行计数) - - 去重后合并个数(统计 dedup_entity_output.txt 的精确/模糊/LLM合并记录) - - 实体消歧次数(统计阻断与合并应用,并输出同名实体“消歧成功”) - - 记忆片段数(chunk_nodes 的数量) - - 关系三元组数(从 extracted_triplets.txt 标题获取总数) - """ - os.makedirs(pipeline_output_dir, exist_ok=True) - result_path = os.path.join(pipeline_output_dir, "extracted_result.json") - entities_edges_path = os.path.join(pipeline_output_dir, "extracted_entities_edges.txt") - dedup_report_path = os.path.join(pipeline_output_dir, "dedup_entity_output.txt") - triplets_path = os.path.join(pipeline_output_dir, "extracted_triplets.txt") - - # 1) 提取实体数 - extracted_entity_count = 0 - # 初始提取的名称计数(用于“出现X次”的基础计数) - initial_name_counts: dict[str, int] = {} - try: - with open(entities_edges_path, "r", encoding="utf-8") as f: - for line in f: - if line.strip().startswith("ENTITY:"): - extracted_entity_count += 1 - # 解析 name 字段 - try: - m = re.search(r"\{\s*\"id\"\s*:\s*\"[^\"]*\"\s*,\s*\"name\"\s*:\s*\"([^\"]+)\"", line) - if m: - nm = m.group(1).strip() - if nm: - initial_name_counts[nm] = initial_name_counts.get(nm, 0) + 1 - except Exception: - pass - except Exception: - pass - - # 2) 去重后合并个数 & 3) 实体消歧次数(含成功名称) - exact_merge_total = 0 - fuzzy_merge_total = 0 - llm_merge_total = 0 - disamb_block_total = 0 - # 记录成功区分的消歧对(阻断的左右实体及类型) - disamb_success_pairs: list[tuple[str, str, str, str]] = [] - # 在外部定义这些字典,确保后续代码可以访问 - dedup_impact: dict[tuple[str, str], int] = {} - # 第二层精准合并新增:包含自合并(自合并视为"比较两个实体后合并为一") - second_layer_exact_additions: dict[tuple[str, str], int] = {} - # LLM 同名类型相似:按名称计一次出现(代表两个实体合并为一) - llm_same_name_additions: dict[str, int] = {} - - try: - with open(dedup_report_path, "r", encoding="utf-8") as f: - current_layer: str | None = None - for raw in f: - line = raw.strip() - if line.startswith("=== 第一层去重消歧 ==="): - current_layer = "第一层去重消歧" - continue - if line.startswith("=== 第二层去重消歧 ==="): - current_layer = "第二层去重消歧" - continue - # 精确合并:统计“合并实体IDs”数量 - if line.startswith("[精确] ") and "合并实体IDs" in line: - try: - # 先提取规范ID(用于第二层去重统计) - canonical_id = "" - id_match = re.search(r"规范实体\s+([0-9a-f]{40})", line) - if id_match: - canonical_id = id_match.group(1).strip() - - # 提取名称、类型和合并实体IDs - m = re.search(r"名称\s+'([^']+)'\s+类型\s+(\S+)\s+<-\s+合并实体IDs\s+(.+)$", line) - if m: - name = m.group(1).strip() - ent_type = m.group(2).strip() - ids_part = m.group(3).strip() - else: - # 退化解析:如果上式失败,回退到简单切分 - canonical_id = "" - name = "" - ent_type = "" - ids_part = line.split("合并实体IDs", 1)[1].lstrip("::").strip() - id_list = [i.strip() for i in ids_part.split(",") if i.strip()] - exact_merge_total += len(id_list) - if name and ent_type: - key = (name, ent_type) - dedup_impact[key] = dedup_impact.get(key, 0) + len(id_list) - # 在第二层:统计新增出现次数(包含自合并,视为两实体比较后合并为一,至少+1) - if current_layer == "第二层去重消歧": - try: - non_self = len([i for i in id_list if i != canonical_id]) if canonical_id else len(id_list) - except Exception: - non_self = len(id_list) - add_cnt = non_self if non_self > 0 else 1 - second_layer_exact_additions[key] = second_layer_exact_additions.get(key, 0) + add_cnt - except Exception: - pass - # 模糊合并:每条记录算一次合并 - elif line.startswith("[模糊] ") and "<- 合并实体" in line: - fuzzy_merge_total += 1 - # 解析括号中的三元组 (group|name|type) - try: - m = re.search(r"规范实体[^\(]*\(([^|]+)\|([^|]+)\|([^\)]+)\)", line) - if m: - name = m.group(2).strip() - ent_type = m.group(3).strip() - key = (name, ent_type) - dedup_impact[key] = dedup_impact.get(key, 0) + 1 - except Exception: - pass - # LLM 决策合并:每条记录算一次合并(包含 LLM融合/LLM合并 以及 “同名类型相似”的 LLM 去重) - elif (line.startswith("[LLM融合]") or line.startswith("[LLM合并]")) and "<- 合并实体" in line: - llm_merge_total += 1 - try: - m = re.search(r"规范实体[^\(]*\(([^|]+)\|([^|]+)\|([^\)]+)\)", line) - if m: - name = m.group(2).strip() - ent_type = m.group(3).strip() - key = (name, ent_type) - dedup_impact[key] = dedup_impact.get(key, 0) + 1 - except Exception: - pass - elif line.startswith("[LLM去重]"): - # 例如:[LLM去重] 同名类型相似 A(TypeA)|B(TypeB) | conf=... | reason=... - # 这类记录同样属于 LLM 决策的去重合并,计入 LLM 合并总数 - llm_merge_total += 1 - # 若同名类型相似(名称相同),按“名称”计一次出现(两实体合并为一) - try: - m = re.search(r"同名类型相似\s*([^((]+)[((][^))]+[))]\|([^((]+)[((][^))]+[))]", line) - if m: - left = m.group(1).strip() - right = m.group(2).strip() - if left and right and left == right: - llm_same_name_additions[left] = llm_same_name_additions.get(left, 0) + 1 - except Exception: - pass - # 可选:解析名称与类型,当前不用于后续统计输出,保持简单 - # 若未来需要统计影响,可以解析左右两侧名称/类型并分别+1 - # 消歧阻断计数:仅统计 [DISAMB阻断],忽略异常阻断与合并应用 - elif line.startswith("[DISAMB阻断]"): - disamb_block_total += 1 - # 解析形如: - # [DISAMB阻断] A(TypeA)|B(TypeB) | conf=... | reason=... || block_pair=True - try: - m = re.search(r"\[DISAMB阻断\]\s*([^((]+)[((]([^))]+)[))]\|([^((]+)[((]([^))]+)[))]", line) - if m: - left_name = m.group(1).strip() - left_type = m.group(2).strip() - right_name = m.group(3).strip() - right_type = m.group(4).strip() - disamb_success_pairs.append((left_name, left_type, right_name, right_type)) - except Exception: - pass - except Exception: - pass - - total_merged_count = exact_merge_total + fuzzy_merge_total + llm_merge_total - disamb_total = disamb_block_total - - # 4) 记忆片段数(分块器生成的 chunk 数量) - memory_chunk_count = 0 - try: - memory_chunk_count = len(chunk_nodes) if chunk_nodes is not None else 0 - except Exception: - pass - - # 5) 关系三元组数(从文件头部“EXTRACTED TRIPLETS (N total)”解析) - triplet_count = 0 - try: - with open(triplets_path, "r", encoding="utf-8") as f: - head = f.readline() - m = re.search(r"EXTRACTED\s+TRIPLETS\s*\((\d+)\s+total\)", head) - if m: - triplet_count = int(m.group(1)) - except Exception: - pass - - # 写入结果文件 - # 构建 JSON 结构(字段顺序按用户需求组织:先“实体去重的影响”,后“实体消歧的效果”) - readable_path = os.path.join(pipeline_output_dir, "extracted_result_readable.txt") - summary_json = { - "generated_at": datetime.now().isoformat(), - "entities": { - "extracted_count": extracted_entity_count, - }, - "dedup": { - "total_merged_count": total_merged_count, - "breakdown": { - "exact": exact_merge_total, - "fuzzy": fuzzy_merge_total, - "llm": llm_merge_total, - }, - "impact": [ - { - "name": nm, - "type": tp, - "appear_count": (initial_name_counts.get(nm, 0) - + second_layer_exact_additions.get((nm, tp), 0) - + llm_same_name_additions.get(nm, 0)) if (initial_name_counts.get(nm, 0) - + second_layer_exact_additions.get((nm, tp), 0) - + llm_same_name_additions.get(nm, 0)) > 0 else merge_cnt, - "merge_count": merge_cnt, - } - for (nm, tp), merge_cnt in (dedup_impact.items() if 'dedup_impact' in locals() else []) - ], - }, - "disambiguation": { - "block_count": disamb_block_total, - "effects": [ - { - "left": {"name": ln, "type": lt}, - "right": {"name": rn, "type": rt}, - "result": "成功区分" - } - for (ln, lt, rn, rt) in disamb_success_pairs - ], - }, - "memory": {"chunks": memory_chunk_count}, - "triplets": {"count": triplet_count}, - "core_entities": [], # 将在下面填充 - "triplet_samples": [], # 将在下面填充 - } - - # 解析实体和三元组数据(用于JSON和文本输出) - entities_by_type = _parse_entities_from_triplets(triplets_path) - triplets_list = _parse_triplets_from_file(triplets_path) - - # 类型翻译映射 - type_translation = { - 'Person': '人物', - 'Organization': '组织', - 'Location': '地点', - 'Product': '产品', - 'Event': '事件', - 'Technology': '技术', - 'Activity': '活动', - 'Exercise': '运动' - } - - # 构建核心实体数据(按类型分组) - core_entities_data = [] - for entity_type, entities in sorted(entities_by_type.items(), key=lambda x: -len(x[1])): - type_name_cn = type_translation.get(entity_type, entity_type) - core_entities_data.append({ - "type": entity_type, - "type_cn": type_name_cn, - "count": len(entities), - "entities": entities[:5] # 最多显示5个 - }) - summary_json["core_entities"] = core_entities_data - - # 构建三元组示例数据 - triplet_samples = [] - display_count = min(7, len(triplets_list)) - for i in range(display_count): - triplet = triplets_list[i] - predicate_cn = _format_predicate(triplet.get('predicate', '')) - triplet_samples.append({ - "subject": triplet.get('subject', ''), - "predicate": triplet.get('predicate', ''), - "predicate_cn": predicate_cn, - "object": triplet.get('object', '') - }) - summary_json["triplet_samples"] = triplet_samples - - # 写 JSON 到 extracted_result.json(满足"以 json 格式输出并为 .json 文件"的要求) - with open(result_path, "w", encoding="utf-8") as f: - json.dump(summary_json, f, ensure_ascii=False, indent=2) - - # 额外生成可读版文本,模块顺序调整 - lines: list[str] = [] - lines.append(f"结果汇总 - {datetime.now().isoformat()}") - lines.append("") - # 提取实体数模块 - lines.append("提取实体数:") - lines.append(f"总计 {extracted_entity_count} 个") - lines.append(f"去重后合并个数:{total_merged_count} (精确={exact_merge_total},模糊={fuzzy_merge_total},LLM={llm_merge_total})") - lines.append("") - # 实体消歧次数模块 - lines.append("实体消歧次数:") - lines.append(f"总计 {disamb_total} 次(阻断={disamb_block_total})") - lines.append("") - # 记忆片段数模块 - lines.append("记忆片段数:") - lines.append(f"总计 {memory_chunk_count} 条") - lines.append("") - # 关系三元组数模块 - lines.append("关系三元组数:") - lines.append(f"总计 {triplet_count} 条") - lines.append("") - - # 新增模块1:提取的核心实体(去重后) - lines.append("提取的核心实体(去重后):") - lines.append("") - # 从 extracted_triplets.txt 解析去重后的实体并按类型分组 - entities_by_type = _parse_entities_from_triplets(triplets_path) - type_translation = { - 'Person': '人物', - 'Organization': '组织', - 'Location': '地点', - 'Product': '产品', - 'Event': '事件', - 'Technology': '技术', - 'Activity': '活动', - 'Exercise': '运动' - } - for entity_type, entities in sorted(entities_by_type.items(), key=lambda x: -len(x[1])): - type_name = type_translation.get(entity_type, entity_type) - count = len(entities) - lines.append(f"{type_name}({count}):") - # 最多显示5个实体 - display_entities = entities[:5] - for entity in display_entities: - lines.append(f" • {entity}") - lines.append("") - - # 新增模块2:提取的关系三元组(部分) - lines.append("提取的关系三元组(部分):") - lines.append("") - # 从 extracted_triplets.txt 读取三元组 - triplets = _parse_triplets_from_file(triplets_path) - display_count = min(7, len(triplets)) - for i in range(display_count): - triplet = triplets[i] - predicate_cn = _format_predicate(triplet['predicate']) - lines.append(f" • ({triplet['subject']}, {predicate_cn}, {triplet['object']})") - lines.append("") - lines.append(f"... 共{triplet_count}条关系三元组") - lines.append("") - - # 实体去重的影响模块(先输出) - if dedup_impact: - lines.append("实体去重的影响:") - # 出现次数 = 初始提取次数 + 第二层精准合并新增次数(包含自合并至少+1) + LLM同名类型相似按名称的新增次数 - # 若某名称初始未出现但发生了合并(少见),退化为使用合并次数 - for (nm, tp), merge_cnt in dedup_impact.items(): - init_cnt = initial_name_counts.get(nm, 0) - add_cnt = second_layer_exact_additions.get((nm, tp), 0) - llm_add = llm_same_name_additions.get(nm, 0) - appear_cnt = init_cnt + add_cnt + llm_add - if appear_cnt <= 0: - appear_cnt = merge_cnt - lines.append(f"[{nm}]出现{appear_cnt}次 → 合并为1个类型是[{tp}]的实体") - lines.append("") - - # 新增模块:实体消歧的效果(后输出,来源于 dedup_entity_output.txt 的 DISAMB阻断 记录) - if disamb_success_pairs: - lines.append("实体消歧的效果:") - for left_name, left_type, right_name, right_type in disamb_success_pairs: - lines.append(f"{left_name}({left_type}) vs {right_name}({right_type}) → 成功区分。") - lines.append("") - - with open(readable_path, "w", encoding="utf-8") as f: - f.write("\n".join(lines)) - -def export_test_input_doc( - entity_nodes, - statement_entity_edges, - entity_entity_edges, -): - """将提取出的实体与两类边导出到 extracted_entities_edges.txt。 - - 保持与 extraction_pipeline.py 原本本地函数一致的行为与输出格式。 - """ - try: - from app.core.config import settings - settings.ensure_memory_output_dir() - out_path = settings.get_memory_output_path("extracted_entities_edges.txt") - - def _to_dict(m): - d = m.model_dump() - for k, v in list(d.items()): - if isinstance(v, datetime): - d[k] = v.isoformat() - return d - - def _entity_to_dict(e): - return { - "id": getattr(e, "id"), - "name": getattr(e, "name"), - "entity_type": getattr(e, "entity_type"), - "description": getattr(e, "description"), - } - - with open(out_path, "w", encoding="utf-8") as f: - header_time = entity_nodes[0].created_at.isoformat() - f.write( - f"=== TEST EXTRACTED ENTITIES === (created_at: {header_time})\n" - ) - for e in entity_nodes: - f.write( - "ENTITY: " + json.dumps(_entity_to_dict(e), ensure_ascii=False) + "\n" - ) - - f.write("\n=== TEST STATEMENT-ENTITY EDGES ===\n") - for se in statement_entity_edges: - f.write("SE_EDGE: " + json.dumps(_to_dict(se), ensure_ascii=False) + "\n") - - f.write("\n=== TEST ENTITY-ENTITY EDGES ===\n") - for ee in entity_entity_edges: - f.write("EE_EDGE: " + json.dumps(_to_dict(ee), ensure_ascii=False) + "\n") - - print(f"Exported extracted entities & edges to: {out_path}") - except Exception as e: - print(f"Failed to export test input doc: {e}") diff --git a/app/core/memory/storage_services/forgetting_engine/__init__.py b/app/core/memory/storage_services/forgetting_engine/__init__.py deleted file mode 100644 index db5c0769..00000000 --- a/app/core/memory/storage_services/forgetting_engine/__init__.py +++ /dev/null @@ -1,8 +0,0 @@ -"""遗忘引擎模块 - -该模块实现记忆的遗忘机制,基于改进的艾宾浩斯遗忘曲线。 -""" - -from app.core.memory.storage_services.forgetting_engine.forgetting_engine import ForgettingEngine - -__all__ = ["ForgettingEngine"] diff --git a/app/core/memory/storage_services/forgetting_engine/forgetting_engine.py b/app/core/memory/storage_services/forgetting_engine/forgetting_engine.py deleted file mode 100644 index 44ce50a8..00000000 --- a/app/core/memory/storage_services/forgetting_engine/forgetting_engine.py +++ /dev/null @@ -1,271 +0,0 @@ -"""遗忘引擎实现 - -该模块实现基于改进的艾宾浩斯遗忘曲线的记忆遗忘机制。 - -遗忘曲线公式: -R(t, S) = offset + (1 - offset) * exp(-λ_time * t / (λ_mem * S)) - -其中: -- R: 记忆保持率 (0 到 1) -- t: 自学习以来经过的时间 -- S: 记忆强度(值越高表示记忆越强) -- offset: 最小保持率(防止完全遗忘) -- λ_time: 控制时间效应的 Lambda 参数 -- λ_mem: 控制记忆强度效应的 Lambda 参数 -""" - -import math -from typing import List, Dict, Any, Optional -from datetime import datetime, timedelta -from app.core.memory.models.variate_config import ForgettingEngineConfig - - -class ForgettingEngine: - """遗忘引擎 - 实现记忆遗忘机制 - - 该引擎基于改进的艾宾浩斯遗忘曲线计算记忆保持率, - 结合时间衰减和记忆强度因素,支持可配置的遗忘行为。 - - Attributes: - config: 遗忘引擎配置 - offset: 最小保持率(防止完全遗忘) - lambda_time: 控制时间衰减效应的参数 - lambda_mem: 控制记忆强度效应的参数 - """ - - def __init__(self, config: Optional[ForgettingEngineConfig] = None): - """初始化遗忘引擎 - - Args: - config: ForgettingEngineConfig 实例,包含遗忘参数配置 - """ - if config is None: - config = ForgettingEngineConfig() - - self.config = config - self.offset = config.offset - self.lambda_time = config.lambda_time - self.lambda_mem = config.lambda_mem - - def forgetting_curve(self, t: float, S: float) -> float: - """使用改进的艾宾浩斯遗忘曲线计算记忆保持率 - - 公式: R = offset + (1-offset) * e^(-λ_time * t / (λ_mem * S)) - - Args: - t: 自学习以来经过的时间 - S: 记忆的相对强度 - - Returns: - 记忆保持率,值在 0 到 1 之间 - """ - if S <= 0: - return self.offset - - exponent = -self.lambda_time * t / (self.lambda_mem * S) - retention = self.offset + (1 - self.offset) * math.exp(exponent) - - # 确保保持率在 0 到 1 之间 - return max(0.0, min(1.0, retention)) - - def calculate_forgetting_score(self, time_elapsed: float, memory_strength: float) -> float: - """计算记忆项的遗忘分数 - - 遗忘分数 = 1 - 保持率,值越高表示越容易被遗忘 - - Args: - time_elapsed: 自记忆创建/最后访问以来的时间 - memory_strength: 记忆强度(值越高表示越难忘记) - - Returns: - 遗忘分数,值在 0 到 1 之间 - """ - retention = self.forgetting_curve(time_elapsed, memory_strength) - return 1.0 - retention - - def calculate_weight(self, time_elapsed: float, memory_strength: float) -> float: - """计算记忆项的权重(即保持率) - - Args: - time_elapsed: 自记忆创建/最后访问以来的时间 - memory_strength: 记忆强度(值越高表示越难忘记) - - Returns: - 权重值,值在 0 到 1 之间 - """ - return self.forgetting_curve(time_elapsed, memory_strength) - - def apply_forgetting_weights( - self, - items: List[dict], - time_key: str = 'time_elapsed', - strength_key: str = 'strength' - ) -> List[dict]: - """为记忆项列表应用遗忘权重 - - Args: - items: 包含记忆项的字典列表 - time_key: 每个项中时间经过的键名 - strength_key: 每个项中记忆强度的键名 - - Returns: - 添加了 'forgetting_weight' 字段的项列表 - """ - weighted_items = [] - - for item in items: - item_copy = item.copy() - time_elapsed = item.get(time_key, 0) - strength = item.get(strength_key, 1.0) - - weight = self.calculate_weight(time_elapsed, strength) - item_copy['forgetting_weight'] = weight - - weighted_items.append(item_copy) - - return weighted_items - - def mark_items_for_forgetting( - self, - items: List[dict], - forgetting_threshold: float = 0.5, - time_key: str = 'time_elapsed', - strength_key: str = 'strength' - ) -> tuple[List[dict], List[dict]]: - """标记应该被遗忘的记忆项 - - Args: - items: 包含记忆项的字典列表 - forgetting_threshold: 遗忘阈值,遗忘分数超过此值的项将被标记 - time_key: 每个项中时间经过的键名 - strength_key: 每个项中记忆强度的键名 - - Returns: - 元组 (应保留的项列表, 应遗忘的项列表) - """ - to_keep = [] - to_forget = [] - - for item in items: - time_elapsed = item.get(time_key, 0) - strength = item.get(strength_key, 1.0) - - forgetting_score = self.calculate_forgetting_score(time_elapsed, strength) - - item_copy = item.copy() - item_copy['forgetting_score'] = forgetting_score - - if forgetting_score > forgetting_threshold: - to_forget.append(item_copy) - else: - to_keep.append(item_copy) - - return to_keep, to_forget - - def get_forgetting_statistics( - self, - items: List[dict], - forgetting_threshold: float = 0.5, - time_key: str = 'time_elapsed', - strength_key: str = 'strength' - ) -> Dict[str, Any]: - """获取记忆项的遗忘统计信息 - - Args: - items: 包含记忆项的字典列表 - forgetting_threshold: 遗忘阈值 - time_key: 每个项中时间经过的键名 - strength_key: 每个项中记忆强度的键名 - - Returns: - 包含统计信息的字典: - - total_items: 总项数 - - items_to_keep: 应保留的项数 - - items_to_forget: 应遗忘的项数 - - forgetting_rate: 遗忘率 - - average_retention: 平均保持率 - - average_forgetting_score: 平均遗忘分数 - """ - if not items: - return { - "total_items": 0, - "items_to_keep": 0, - "items_to_forget": 0, - "forgetting_rate": 0.0, - "average_retention": 0.0, - "average_forgetting_score": 0.0 - } - - to_keep, to_forget = self.mark_items_for_forgetting( - items, forgetting_threshold, time_key, strength_key - ) - - total = len(items) - keep_count = len(to_keep) - forget_count = len(to_forget) - - # 计算平均保持率和遗忘分数 - total_retention = 0.0 - total_forgetting_score = 0.0 - - for item in items: - time_elapsed = item.get(time_key, 0) - strength = item.get(strength_key, 1.0) - - retention = self.calculate_weight(time_elapsed, strength) - forgetting_score = self.calculate_forgetting_score(time_elapsed, strength) - - total_retention += retention - total_forgetting_score += forgetting_score - - avg_retention = total_retention / total - avg_forgetting_score = total_forgetting_score / total - - return { - "total_items": total, - "items_to_keep": keep_count, - "items_to_forget": forget_count, - "forgetting_rate": forget_count / total, - "average_retention": avg_retention, - "average_forgetting_score": avg_forgetting_score - } - - def calculate_time_elapsed_days( - self, - created_at: datetime, - current_time: Optional[datetime] = None - ) -> float: - """计算经过的天数 - - Args: - created_at: 创建时间 - current_time: 当前时间,如果为 None 则使用当前系统时间 - - Returns: - 经过的天数(浮点数) - """ - if current_time is None: - current_time = datetime.now() - - time_diff = current_time - created_at - return time_diff.total_seconds() / (24 * 3600) - - def calculate_time_elapsed_hours( - self, - created_at: datetime, - current_time: Optional[datetime] = None - ) -> float: - """计算经过的小时数 - - Args: - created_at: 创建时间 - current_time: 当前时间,如果为 None 则使用当前系统时间 - - Returns: - 经过的小时数(浮点数) - """ - if current_time is None: - current_time = datetime.now() - - time_diff = current_time - created_at - return time_diff.total_seconds() / 3600 diff --git a/app/core/memory/storage_services/forgetting_engine/memory_strength.py b/app/core/memory/storage_services/forgetting_engine/memory_strength.py deleted file mode 100644 index 2bc819a0..00000000 --- a/app/core/memory/storage_services/forgetting_engine/memory_strength.py +++ /dev/null @@ -1,251 +0,0 @@ -""" -Memory Strength Calculator based on ACT-R Theory - -This module implements the Base-Level Activation equation from ACT-R -(Adaptive Control of Thought-Rational) cognitive architecture. - -Formula: B(i) = ln(Σ(t_k^(-d))) - -Where: -- B(i): Base-level activation score -- t_k: Time since the k-th access -- d: Decay parameter (typically 0.5) -- n: Number of accesses - -Reference: Anderson, J. R. (2007). How Can the Human Mind Occur in the Physical Universe? -""" - -import math -from typing import List, Optional -from datetime import datetime, timedelta - - -class MemoryStrengthCalculator: - """ - Calculate memory strength using ACT-R base-level activation formula. - """ - - def __init__(self, decay_parameter: float = 0.5, time_unit: str = "seconds"): - """ - Initialize the memory strength calculator. - - Args: - decay_parameter: The decay rate (d). Typically 0.5 for human memory. - Higher values = faster forgetting. - time_unit: Unit for time calculations. Options: 'seconds', 'minutes', - 'hours', 'days'. Default is 'seconds'. - """ - self.decay_parameter = decay_parameter - self.time_unit = time_unit - self._time_multipliers = { - "seconds": 1, - "minutes": 60, - "hours": 3600, - "days": 86400, - } - - def calculate_activation( - self, access_times: List[datetime], current_time: Optional[datetime] = None - ) -> float: - """ - Calculate the base-level activation B(i) for a memory item. - - Args: - access_times: List of datetime objects representing when the memory - was accessed (most recent first or in any order). - current_time: The current time for calculation. If None, uses datetime.now(). - - Returns: - float: The base-level activation score B(i). - Higher values indicate stronger, more retrievable memories. - - Raises: - ValueError: If access_times is empty or contains invalid data. - """ - if not access_times: - raise ValueError("access_times cannot be empty") - - if current_time is None: - current_time = datetime.now() - - # Calculate time differences in specified units - time_diffs = [] - for access_time in access_times: - diff_seconds = (current_time - access_time).total_seconds() - if diff_seconds < 0: - raise ValueError(f"Access time {access_time} is in the future") - - # Convert to specified time unit - diff = diff_seconds / self._time_multipliers[self.time_unit] - - # Avoid division by zero for very recent accesses - # Use a small epsilon (0.01 time units) - diff = max(diff, 0.01) - time_diffs.append(diff) - - # Calculate B(i) = ln(Σ(t_k^(-d))) - sum_power_law = sum(t ** (-self.decay_parameter) for t in time_diffs) - activation = math.log(sum_power_law) - - return activation - - def calculate_activation_from_intervals( - self, time_intervals: List[float] - ) -> float: - """ - Calculate activation directly from time intervals (in the configured time unit). - - Args: - time_intervals: List of time intervals since each access. - E.g., [1.0, 3.5, 7.2] means accessed 1, 3.5, and 7.2 time units ago. - - Returns: - float: The base-level activation score B(i). - """ - if not time_intervals: - raise ValueError("time_intervals cannot be empty") - - # Ensure no zero or negative intervals - safe_intervals = [max(t, 0.01) for t in time_intervals] - - sum_power_law = sum(t ** (-self.decay_parameter) for t in safe_intervals) - activation = math.log(sum_power_law) - - return activation - - def calculate_memory_strength(self, activation: float) -> float: - """ - Convert activation score to memory strength S(i) = e^(B(i)). - - This converts the log-space activation to linear space, - suitable for use in the Ebbinghaus forgetting curve. - - Args: - activation: The base-level activation B(i). - - Returns: - float: Memory strength S(i) in linear space. - """ - return math.exp(activation) - - def calculate_retention_probability( - self, - activation: float, - time_since_last_access: float, - decay_rate: float = 0.01, - offset: float = 0.1, - ) -> float: - """ - Calculate retention probability using the unified Ebbinghaus-ACT-R formula. - - Formula: R(i) = offset + (1-offset) * exp(-λ * t / Σ(t_k^(-d))) - - Args: - activation: The base-level activation B(i). - time_since_last_access: Time since last access (in configured time units). - decay_rate: Lambda (λ) parameter controlling forgetting speed. - offset: Baseline retention rate (minimum memory strength). - - Returns: - float: Retention probability between 0 and 1. - """ - memory_strength = self.calculate_memory_strength(activation) - - # Unified formula: R(i) = offset + (1-offset) * exp(-λ * t / S(i)) - retention = offset + (1 - offset) * math.exp( - -decay_rate * time_since_last_access / memory_strength - ) - - return retention - - def should_retain( - self, - access_times: List[datetime], - threshold: float = 0.5, - current_time: Optional[datetime] = None, - decay_rate: float = 0.01, - offset: float = 0.1, - ) -> tuple[bool, float, float]: - """ - Determine if a memory should be retained based on its strength. - - Args: - access_times: List of access timestamps. - threshold: Retention probability threshold (default 0.5 = 50%). - current_time: Current time for calculation. - decay_rate: Lambda parameter for forgetting curve. - offset: Baseline retention rate. - - Returns: - tuple: (should_retain: bool, retention_probability: float, activation: float) - """ - if current_time is None: - current_time = datetime.now() - - activation = self.calculate_activation(access_times, current_time) - - # Time since last access - last_access = max(access_times) - time_since_last = (current_time - last_access).total_seconds() / self._time_multipliers[self.time_unit] - time_since_last = max(time_since_last, 0.01) - - retention_prob = self.calculate_retention_probability( - activation, time_since_last, decay_rate, offset - ) - - return (retention_prob >= threshold, retention_prob, activation) - - -# Convenience functions for quick calculations -def calculate_activation( - access_times: List[datetime], - decay_parameter: float = 0.5, - current_time: Optional[datetime] = None, -) -> float: - """ - Quick function to calculate activation without creating a calculator instance. - - Args: - access_times: List of access timestamps. - decay_parameter: Decay rate (default 0.5). - current_time: Current time (default now). - - Returns: - float: Base-level activation B(i). - """ - calculator = MemoryStrengthCalculator(decay_parameter=decay_parameter) - return calculator.calculate_activation(access_times, current_time) - - -def calculate_retention( - access_times: List[datetime], - decay_parameter: float = 0.5, - decay_rate: float = 0.01, - offset: float = 0.1, - current_time: Optional[datetime] = None, -) -> float: - """ - Quick function to calculate retention probability. - - Args: - access_times: List of access timestamps. - decay_parameter: ACT-R decay parameter (default 0.5). - decay_rate: Ebbinghaus decay rate lambda (default 0.01). - offset: Baseline retention (default 0.1). - current_time: Current time (default now). - - Returns: - float: Retention probability between 0 and 1. - """ - calculator = MemoryStrengthCalculator(decay_parameter=decay_parameter) - activation = calculator.calculate_activation(access_times, current_time) - - if current_time is None: - current_time = datetime.now() - - last_access = max(access_times) - time_since_last = (current_time - last_access).total_seconds() - - return calculator.calculate_retention_probability( - activation, time_since_last, decay_rate, offset - ) diff --git a/app/core/memory/storage_services/reflection_engine/__init__.py b/app/core/memory/storage_services/reflection_engine/__init__.py deleted file mode 100644 index 0f3f1eb1..00000000 --- a/app/core/memory/storage_services/reflection_engine/__init__.py +++ /dev/null @@ -1,21 +0,0 @@ -""" -自我反思引擎模块 - -该模块实现了记忆系统的自我反思功能,包括: -- 基于时间的反思 -- 基于事实的反思(冲突检测) -- 综合反思 -- 反思结果应用 -""" - -from app.core.memory.storage_services.reflection_engine.self_reflexion import ( - ReflectionEngine, - ReflectionConfig, - ReflectionResult, -) - -__all__ = [ - "ReflectionEngine", - "ReflectionConfig", - "ReflectionResult", -] diff --git a/app/core/memory/storage_services/reflection_engine/self_reflexion.py b/app/core/memory/storage_services/reflection_engine/self_reflexion.py deleted file mode 100644 index b3e5813d..00000000 --- a/app/core/memory/storage_services/reflection_engine/self_reflexion.py +++ /dev/null @@ -1,585 +0,0 @@ -""" -自我反思引擎实现 - -该模块实现了记忆系统的自我反思功能,包括: -1. 基于时间的反思 - 根据时间周期触发反思 -2. 基于事实的反思 - 检测记忆冲突并解决 -3. 综合反思 - 整合多种反思策略 -4. 反思结果应用 - 更新记忆库 -""" - -import os -import json -import logging -import asyncio -from typing import List, Dict, Any, Optional -from datetime import datetime -from enum import Enum -import uuid - -from pydantic import BaseModel, Field - - -# 配置日志 -_root_logger = logging.getLogger() -if not _root_logger.handlers: - logging.basicConfig( - level=logging.INFO, - format="%(asctime)s %(levelname)s %(message)s" - ) -else: - _root_logger.setLevel(logging.INFO) - - -class ReflectionRange(str, Enum): - """反思范围枚举""" - RETRIEVAL = "retrieval" # 从检索结果中反思 - DATABASE = "database" # 从整个数据库中反思 - - -class ReflectionBaseline(str, Enum): - """反思基线枚举""" - TIME = "TIME" # 基于时间的反思 - FACT = "FACT" # 基于事实的反思 - HYBRID = "HYBRID" # 混合反思 - - -class ReflectionConfig(BaseModel): - """反思引擎配置""" - enabled: bool = False - iteration_period: str = "3" # 反思周期 - reflexion_range: ReflectionRange = ReflectionRange.RETRIEVAL - baseline: ReflectionBaseline = ReflectionBaseline.TIME - concurrency: int = Field(default=5, description="并发数量") - - class Config: - use_enum_values = True - - -class ReflectionResult(BaseModel): - """反思结果""" - success: bool - message: str - conflicts_found: int = 0 - conflicts_resolved: int = 0 - memories_updated: int = 0 - execution_time: float = 0.0 - details: Optional[Dict[str, Any]] = None - - -class ReflectionEngine: - """ - 自我反思引擎 - - 负责执行记忆系统的自我反思,包括冲突检测、冲突解决和记忆更新。 - """ - - def __init__( - self, - config: ReflectionConfig, - neo4j_connector: Optional[Any] = None, - llm_client: Optional[Any] = None, - get_data_func: Optional[Any] = None, - render_evaluate_prompt_func: Optional[Any] = None, - render_reflexion_prompt_func: Optional[Any] = None, - conflict_schema: Optional[Any] = None, - reflexion_schema: Optional[Any] = None, - update_query: Optional[str] = None - ): - """ - 初始化反思引擎 - - Args: - config: 反思引擎配置 - neo4j_connector: Neo4j 连接器(可选) - llm_client: LLM 客户端(可选) - get_data_func: 获取数据的函数(可选) - render_evaluate_prompt_func: 渲染评估提示词的函数(可选) - render_reflexion_prompt_func: 渲染反思提示词的函数(可选) - conflict_schema: 冲突结果 Schema(可选) - reflexion_schema: 反思结果 Schema(可选) - update_query: 更新查询语句(可选) - """ - self.config = config - self.neo4j_connector = neo4j_connector - self.llm_client = llm_client - self.get_data_func = get_data_func - self.render_evaluate_prompt_func = render_evaluate_prompt_func - self.render_reflexion_prompt_func = render_reflexion_prompt_func - self.conflict_schema = conflict_schema - self.reflexion_schema = reflexion_schema - self.update_query = update_query - self._semaphore = asyncio.Semaphore(config.concurrency) - - # 延迟导入以避免循环依赖 - self._lazy_init_done = False - - def _lazy_init(self): - """延迟初始化,避免循环导入""" - if self._lazy_init_done: - return - - if self.neo4j_connector is None: - from app.repositories.neo4j.neo4j_connector import Neo4jConnector - self.neo4j_connector = Neo4jConnector() - - if self.llm_client is None: - from app.core.memory.utils.llm.llm_utils import get_llm_client - from app.core.memory.utils.config import definitions as config_defs - self.llm_client = get_llm_client(config_defs.SELECTED_LLM_ID) - - if self.get_data_func is None: - from app.core.memory.utils.config.get_data import get_data - self.get_data_func = get_data - - if self.render_evaluate_prompt_func is None: - from app.core.memory.utils.prompt.template_render import render_evaluate_prompt - self.render_evaluate_prompt_func = render_evaluate_prompt - - if self.render_reflexion_prompt_func is None: - from app.core.memory.utils.prompt.template_render import render_reflexion_prompt - self.render_reflexion_prompt_func = render_reflexion_prompt - - if self.conflict_schema is None: - from app.schemas.memory_storage_schema import ConflictResultSchema - self.conflict_schema = ConflictResultSchema - - if self.reflexion_schema is None: - from app.schemas.memory_storage_schema import ReflexionResultSchema - self.reflexion_schema = ReflexionResultSchema - - if self.update_query is None: - from app.repositories.neo4j.cypher_queries import UPDATE_STATEMENT_INVALID_AT - self.update_query = UPDATE_STATEMENT_INVALID_AT - - self._lazy_init_done = True - - async def execute_reflection(self, host_id: uuid.UUID) -> ReflectionResult: - """ - 执行完整的反思流程 - - Args: - host_id: 主机ID - - Returns: - ReflectionResult: 反思结果 - """ - # 延迟初始化 - self._lazy_init() - - if not self.config.enabled: - return ReflectionResult( - success=False, - message="反思引擎未启用" - ) - - start_time = asyncio.get_event_loop().time() - logging.info("====== 自我反思流程开始 ======") - - try: - # 1. 获取反思数据 - reflexion_data = await self._get_reflexion_data(host_id) - if not reflexion_data: - return ReflectionResult( - success=True, - message="无反思数据,结束反思", - execution_time=asyncio.get_event_loop().time() - start_time - ) - - # 2. 检测冲突(基于事实的反思) - conflict_data = await self._detect_conflicts(reflexion_data) - if not conflict_data: - return ReflectionResult( - success=True, - message="无冲突,无需反思", - execution_time=asyncio.get_event_loop().time() - start_time - ) - - conflicts_found = len(conflict_data) - logging.info(f"发现 {conflicts_found} 个冲突") - - # 记录冲突数据 - await self._log_data("conflict", conflict_data) - - # 3. 解决冲突 - solved_data = await self._resolve_conflicts(conflict_data) - if not solved_data: - return ReflectionResult( - success=False, - message="反思失败,未解决冲突", - conflicts_found=conflicts_found, - execution_time=asyncio.get_event_loop().time() - start_time - ) - - conflicts_resolved = len(solved_data) - logging.info(f"解决了 {conflicts_resolved} 个冲突") - - # 记录解决方案 - await self._log_data("solved_data", solved_data) - - # 4. 应用反思结果(更新记忆库) - memories_updated = await self._apply_reflection_results(solved_data) - - execution_time = asyncio.get_event_loop().time() - start_time - - logging.info("====== 自我反思流程结束 ======") - - return ReflectionResult( - success=True, - message="反思完成", - conflicts_found=conflicts_found, - conflicts_resolved=conflicts_resolved, - memories_updated=memories_updated, - execution_time=execution_time - ) - - except Exception as e: - logging.error(f"反思流程执行失败: {e}", exc_info=True) - return ReflectionResult( - success=False, - message=f"反思流程执行失败: {str(e)}", - execution_time=asyncio.get_event_loop().time() - start_time - ) - - async def _get_reflexion_data(self, host_id: uuid.UUID) -> List[Any]: - """ - 获取反思数据 - - 根据配置的反思范围获取需要反思的记忆数据。 - - Args: - host_id: 主机ID - - Returns: - List[Any]: 反思数据列表 - """ - if self.config.reflexion_range == ReflectionRange.RETRIEVAL: - # 从检索结果中获取数据 - return await self.get_data_func(host_id) - elif self.config.reflexion_range == ReflectionRange.DATABASE: - # 从整个数据库中获取数据(待实现) - logging.warning("从数据库获取反思数据功能尚未实现") - return [] - else: - raise ValueError(f"未知的反思范围: {self.config.reflexion_range}") - - async def _detect_conflicts(self, data: List[Any]) -> List[Any]: - """ - 检测冲突(基于事实的反思) - - 使用 LLM 分析记忆数据,检测其中的冲突。 - - Args: - data: 待检测的记忆数据 - - Returns: - List[Any]: 冲突记忆列表 - """ - if not data: - return [] - - logging.info("====== 冲突检测开始 ======") - start_time = asyncio.get_event_loop().time() - - try: - # 渲染冲突检测提示词 - rendered_prompt = await self.render_evaluate_prompt_func( - data, - self.conflict_schema - ) - - messages = [{"role": "user", "content": rendered_prompt}] - logging.info(f"提示词长度: {len(rendered_prompt)}") - - # 调用 LLM 进行冲突检测 - response = await self.llm_client.response_structured( - messages, - self.conflict_schema - ) - - execution_time = asyncio.get_event_loop().time() - start_time - logging.info(f"冲突检测耗时: {execution_time:.2f} 秒") - - if not response: - logging.error("LLM 冲突检测输出解析失败") - return [] - - # 标准化返回格式 - if isinstance(response, BaseModel): - return [response.model_dump()] - elif hasattr(response, 'dict'): - return [response.dict()] - else: - return [response] - - except Exception as e: - logging.error(f"冲突检测失败: {e}", exc_info=True) - return [] - - async def _resolve_conflicts(self, conflicts: List[Any]) -> List[Any]: - """ - 解决冲突 - - 使用 LLM 对检测到的冲突进行反思和解决。 - - Args: - conflicts: 冲突列表 - - Returns: - List[Any]: 解决方案列表 - """ - if not conflicts: - return [] - - logging.info("====== 冲突解决开始 ======") - - # 并行处理每个冲突 - async def _resolve_one(conflict: Any) -> Optional[Dict[str, Any]]: - """解决单个冲突""" - async with self._semaphore: - try: - # 渲染反思提示词 - rendered_prompt = await self.render_reflexion_prompt_func( - [conflict], - self.reflexion_schema - ) - - messages = [{"role": "user", "content": rendered_prompt}] - - # 调用 LLM 进行反思 - response = await self.llm_client.response_structured( - messages, - self.reflexion_schema - ) - - if not response: - return None - - # 标准化返回格式 - if isinstance(response, BaseModel): - return response.model_dump() - elif hasattr(response, 'dict'): - return response.dict() - elif isinstance(response, dict): - return response - else: - return None - - except Exception as e: - logging.warning(f"解决单个冲突失败: {e}") - return None - - # 并发执行所有冲突解决任务 - tasks = [_resolve_one(conflict) for conflict in conflicts] - results = await asyncio.gather(*tasks, return_exceptions=False) - - # 过滤掉失败的结果 - solved = [r for r in results if r is not None] - - logging.info(f"成功解决 {len(solved)}/{len(conflicts)} 个冲突") - - return solved - - async def _apply_reflection_results( - self, - solved_data: List[Dict[str, Any]] - ) -> int: - """ - 应用反思结果(更新记忆库) - - 将解决冲突后的记忆更新到 Neo4j 数据库中。 - - Args: - solved_data: 解决方案列表 - - Returns: - int: 成功更新的记忆数量 - """ - if not solved_data: - logging.warning("无解决方案数据,跳过更新") - return 0 - - logging.info("====== 记忆更新开始 ======") - - success_count = 0 - - async def _update_one(item: Dict[str, Any]) -> bool: - """更新单条记忆""" - async with self._semaphore: - try: - if not isinstance(item, dict): - return False - - # 提取更新参数 - resolved = item.get("resolved", {}) - resolved_mem = resolved.get("resolved_memory", {}) - group_id = resolved_mem.get("group_id") - memory_id = resolved_mem.get("id") - new_invalid_at = resolved_mem.get("invalid_at") - - if not all([group_id, memory_id, new_invalid_at]): - logging.warning(f"记忆更新参数缺失,跳过此项: {item}") - return False - - # 执行更新 - await self.neo4j_connector.execute_query( - self.update_query, - group_id=group_id, - id=memory_id, - new_invalid_at=new_invalid_at, - ) - - return True - - except Exception as e: - logging.error(f"更新单条记忆失败: {e}") - return False - - # 并发执行所有更新任务 - tasks = [ - _update_one(item) - for item in solved_data - if isinstance(item, dict) - ] - results = await asyncio.gather(*tasks, return_exceptions=False) - success_count = sum(1 for r in results if r) - - logging.info(f"成功更新 {success_count}/{len(solved_data)} 条记忆") - - return success_count - - async def _log_data(self, label: str, data: Any) -> None: - """ - 记录数据到文件 - - Args: - label: 数据标签 - data: 要记录的数据 - """ - def _write(): - try: - with open("reflexion_data.json", "a", encoding="utf-8") as f: - f.write(f"### {label} ###\n") - json.dump(data, f, ensure_ascii=False, indent=4) - f.write("\n\n") - except Exception as e: - logging.warning(f"记录数据失败: {e}") - - # 在后台线程中执行写入,避免阻塞事件循环 - await asyncio.to_thread(_write) - - # 基于时间的反思方法 - async def time_based_reflection( - self, - host_id: uuid.UUID, - time_period: Optional[str] = None - ) -> ReflectionResult: - """ - 基于时间的反思 - - 根据时间周期触发反思,检查在指定时间段内的记忆。 - - Args: - host_id: 主机ID - time_period: 时间周期(如"三小时"),如果不提供则使用配置中的值 - - Returns: - ReflectionResult: 反思结果 - """ - period = time_period or self.config.iteration_period - logging.info(f"执行基于时间的反思,周期: {period}") - - # 使用标准反思流程 - return await self.execute_reflection(host_id) - - # 基于事实的反思方法 - async def fact_based_reflection( - self, - host_id: uuid.UUID - ) -> ReflectionResult: - """ - 基于事实的反思 - - 检测记忆中的事实冲突并解决。 - - Args: - host_id: 主机ID - - Returns: - ReflectionResult: 反思结果 - """ - logging.info("执行基于事实的反思") - - # 使用标准反思流程 - return await self.execute_reflection(host_id) - - # 综合反思方法 - async def comprehensive_reflection( - self, - host_id: uuid.UUID - ) -> ReflectionResult: - """ - 综合反思 - - 整合基于时间和基于事实的反思策略。 - - Args: - host_id: 主机ID - - Returns: - ReflectionResult: 反思结果 - """ - logging.info("执行综合反思") - - # 根据配置的基线选择反思策略 - if self.config.baseline == ReflectionBaseline.TIME: - return await self.time_based_reflection(host_id) - elif self.config.baseline == ReflectionBaseline.FACT: - return await self.fact_based_reflection(host_id) - elif self.config.baseline == ReflectionBaseline.HYBRID: - # 混合策略:先执行基于时间的反思,再执行基于事实的反思 - time_result = await self.time_based_reflection(host_id) - fact_result = await self.fact_based_reflection(host_id) - - # 合并结果 - return ReflectionResult( - success=time_result.success and fact_result.success, - message=f"时间反思: {time_result.message}; 事实反思: {fact_result.message}", - conflicts_found=time_result.conflicts_found + fact_result.conflicts_found, - conflicts_resolved=time_result.conflicts_resolved + fact_result.conflicts_resolved, - memories_updated=time_result.memories_updated + fact_result.memories_updated, - execution_time=time_result.execution_time + fact_result.execution_time - ) - else: - raise ValueError(f"未知的反思基线: {self.config.baseline}") - - -# 便捷函数:创建默认配置的反思引擎 -def create_reflection_engine( - enabled: bool = False, - iteration_period: str = "3", - reflexion_range: str = "retrieval", - baseline: str = "TIME", - concurrency: int = 5 -) -> ReflectionEngine: - """ - 创建反思引擎实例 - - Args: - enabled: 是否启用反思 - iteration_period: 反思周期 - reflexion_range: 反思范围 - baseline: 反思基线 - concurrency: 并发数量 - - Returns: - ReflectionEngine: 反思引擎实例 - """ - config = ReflectionConfig( - enabled=enabled, - iteration_period=iteration_period, - reflexion_range=reflexion_range, - baseline=baseline, - concurrency=concurrency - ) - return ReflectionEngine(config) diff --git a/app/core/memory/storage_services/search/__init__.py b/app/core/memory/storage_services/search/__init__.py deleted file mode 100644 index 1109ed3e..00000000 --- a/app/core/memory/storage_services/search/__init__.py +++ /dev/null @@ -1,131 +0,0 @@ -# -*- coding: utf-8 -*- -"""搜索服务模块 - -本模块提供统一的搜索服务接口,支持关键词搜索、语义搜索和混合搜索。 -""" - -from app.core.memory.storage_services.search.search_strategy import SearchStrategy, SearchResult -from app.core.memory.storage_services.search.keyword_search import KeywordSearchStrategy -from app.core.memory.storage_services.search.semantic_search import SemanticSearchStrategy -from app.core.memory.storage_services.search.hybrid_search import HybridSearchStrategy - -__all__ = [ - "SearchStrategy", - "SearchResult", - "KeywordSearchStrategy", - "SemanticSearchStrategy", - "HybridSearchStrategy", -] - - -# ============================================================================ -# 向后兼容的函数式API -# ============================================================================ -# 为了兼容旧代码,提供与 src/search.py 相同的函数式接口 - - -async def run_hybrid_search( - query_text: str, - search_type: str = "hybrid", - group_id: str | None = None, - apply_id: str | None = None, - user_id: str | None = None, - limit: int = 50, - include: list[str] | None = None, - alpha: float = 0.6, - use_forgetting_curve: bool = False, - embedding_id: str | None = None, - **kwargs -) -> dict: - """运行混合搜索(向后兼容的函数式API) - - 这是一个向后兼容的包装函数,将旧的函数式API转换为新的基于类的API。 - - Args: - query_text: 查询文本 - search_type: 搜索类型("hybrid", "keyword", "semantic") - group_id: 组ID过滤 - apply_id: 应用ID过滤 - user_id: 用户ID过滤 - limit: 每个类别的最大结果数 - include: 要包含的搜索类别列表 - alpha: BM25分数权重(0.0-1.0) - use_forgetting_curve: 是否使用遗忘曲线 - embedding_id: 嵌入模型ID - **kwargs: 其他参数 - - Returns: - dict: 搜索结果字典,格式与旧API兼容 - """ - from app.repositories.neo4j.neo4j_connector import Neo4jConnector - from app.core.memory.src.llm_tools.openai_embedder import OpenAIEmbedderClient - from app.core.memory.utils.config.config_utils import get_embedder_config - from app.core.memory.utils.config import definitions as config_defs - from app.core.models.base import RedBearModelConfig - - # 使用提供的embedding_id或默认值 - emb_id = embedding_id or config_defs.SELECTED_EMBEDDING_ID - - # 初始化客户端 - connector = Neo4jConnector() - embedder_config_dict = get_embedder_config(emb_id) - embedder_config = RedBearModelConfig(**embedder_config_dict) - embedder_client = OpenAIEmbedderClient(embedder_config) - - try: - # 根据搜索类型选择策略 - if search_type == "keyword": - strategy = KeywordSearchStrategy(connector=connector) - elif search_type == "semantic": - strategy = SemanticSearchStrategy( - connector=connector, - embedder_client=embedder_client - ) - else: # hybrid - strategy = HybridSearchStrategy( - connector=connector, - embedder_client=embedder_client, - alpha=alpha, - use_forgetting_curve=use_forgetting_curve - ) - - # 执行搜索 - result = await strategy.search( - query_text=query_text, - group_id=group_id, - limit=limit, - include=include, - alpha=alpha, - use_forgetting_curve=use_forgetting_curve, - **kwargs - ) - - # 转换为旧格式 - result_dict = result.to_dict() - - # 保存到文件(如果指定了output_path) - output_path = kwargs.get('output_path', 'search_results.json') - if output_path: - import json - import os - from datetime import datetime - - try: - # 确保目录存在 - out_dir = os.path.dirname(output_path) - if out_dir: - os.makedirs(out_dir, exist_ok=True) - - # 保存结果 - with open(output_path, "w", encoding="utf-8") as f: - json.dump(result_dict, f, ensure_ascii=False, indent=2, default=str) - print(f"Search results saved to {output_path}") - except Exception as e: - print(f"Error saving search results: {e}") - return result_dict - - finally: - await connector.close() - - -__all__.append("run_hybrid_search") diff --git a/app/core/memory/storage_services/search/hybrid_chatbot.py b/app/core/memory/storage_services/search/hybrid_chatbot.py deleted file mode 100644 index 5b3e6827..00000000 --- a/app/core/memory/storage_services/search/hybrid_chatbot.py +++ /dev/null @@ -1,447 +0,0 @@ - -# TODO hybrid_chatbot.py 是一个独立的GUI演示应用,不是核心功能的一部分,可以考虑删除 -from app.core.memory.utils.llm.llm_utils import get_llm_client -import asyncio -import os -import time -import json -from datetime import datetime, timezone -import tkinter as tk -from tkinter import scrolledtext, messagebox -import threading -from typing import Any, Dict, Tuple, List - -# Import our hybrid search functionality -from app.core.memory.storage_services.search import run_hybrid_search -# 使用新的仓储层 -from app.repositories.neo4j.neo4j_connector import Neo4jConnector -from app.core.memory.src.llm_tools.openai_client import OpenAIClient -from app.core.memory.models.config_models import LLMConfig -from dotenv import load_dotenv - -load_dotenv() - - -class HybridSearchChatbot: - def __init__(self): - - from app.core.memory.utils.config import definitions as config_defs - self.llm_client = get_llm_client(config_defs.SELECTED_LLM_ID) - - # Chat history - self.chat_history = [] - - # Search configuration - self.search_config = { - "group_id": "group_wyl_25", - "limit": 10, - "include": ["statements", "chunks", "entities","summaries"], - # "include": ["statements", "dialogues", "entities"], - "rerank_alpha": 0.6 - } - - # Setup GUI - self.setup_gui() - - def setup_gui(self): - """Setup the GUI interface""" - self.root = tk.Tk() - self.root.title("Hybrid Search Chatbot") - self.root.geometry("800x600") - - # Chat display area - self.chat_display = scrolledtext.ScrolledText( - self.root, - wrap=tk.WORD, - width=80, - height=25, - state=tk.DISABLED - ) - self.chat_display.pack(padx=10, pady=10, fill=tk.BOTH, expand=True) - - # Input frame - input_frame = tk.Frame(self.root) - input_frame.pack(padx=10, pady=5, fill=tk.X) - - # User input - self.user_input = tk.Entry(input_frame, font=("Arial", 12)) - self.user_input.pack(side=tk.LEFT, fill=tk.X, expand=True, padx=(0, 5)) - self.user_input.bind("", self.on_send_message) - - # Send button - self.send_button = tk.Button( - input_frame, - text="发送", - command=self.on_send_message, - font=("Arial", 12) - ) - self.send_button.pack(side=tk.RIGHT) - - # Status frame - status_frame = tk.Frame(self.root) - status_frame.pack(padx=10, pady=5, fill=tk.X) - - # Status label - self.status_label = tk.Label( - status_frame, - text="就绪", - font=("Arial", 10), - anchor="w" - ) - self.status_label.pack(side=tk.LEFT, fill=tk.X, expand=True) - - # Search config button - config_button = tk.Button( - status_frame, - text="搜索配置", - command=self.show_config_dialog, - font=("Arial", 10) - ) - config_button.pack(side=tk.RIGHT) - - # Add welcome message - self.add_message("系统", "欢迎使用混合搜索聊天机器人!我可以基于知识图谱中的信息回答您的问题。") - - def add_message(self, sender: str, message: str, metadata: Dict = None): - """Add a message to the chat display""" - self.chat_display.config(state=tk.NORMAL) - - timestamp = datetime.now().strftime("%H:%M:%S") - - # Add sender and timestamp - self.chat_display.insert(tk.END, f"[{timestamp}] {sender}:\n", "sender") - - # Add message content - self.chat_display.insert(tk.END, f"{message}\n", "message") - - # Add metadata if available - if metadata: - self.chat_display.insert(tk.END, f" {metadata}\n", "metadata") - - self.chat_display.insert(tk.END, "\n") - self.chat_display.config(state=tk.DISABLED) - self.chat_display.see(tk.END) - - # Configure text tags for styling - self.chat_display.tag_config("sender", foreground="blue", font=("Arial", 10, "bold")) - self.chat_display.tag_config("message", foreground="black", font=("Arial", 10)) - self.chat_display.tag_config("metadata", foreground="gray", font=("Arial", 8)) - - def show_config_dialog(self): - """Show search configuration dialog""" - config_window = tk.Toplevel(self.root) - config_window.title("搜索配置") - config_window.geometry("400x600") - config_window.transient(self.root) - config_window.grab_set() - - # Current configuration display - current_config_frame = tk.Frame(config_window) - current_config_frame.pack(pady=10, padx=10, fill=tk.X) - tk.Label(current_config_frame, text="当前配置:", font=("Arial", 10, "bold")).pack(anchor="w") - current_text = f"Alpha: {self.search_config['rerank_alpha']}, 限制: {self.search_config['limit']}, 目标: {', '.join(self.search_config['include'])}" - tk.Label(current_config_frame, text=current_text, font=("Arial", 9), fg="blue").pack(anchor="w") - - # Alpha parameter - tk.Label(config_window, text="重排权重 (Alpha):").pack(pady=(10, 5)) - alpha_var = tk.DoubleVar(value=self.search_config["rerank_alpha"]) - alpha_scale = tk.Scale( - config_window, - from_=0.0, - to=1.0, - resolution=0.1, - orient=tk.HORIZONTAL, - variable=alpha_var - ) - alpha_scale.pack(pady=5, padx=20, fill=tk.X) - tk.Label(config_window, text="0.0=纯语义搜索, 1.0=纯关键词搜索", font=("Arial", 8)).pack() - - # Limit parameter - tk.Label(config_window, text="搜索结果数量:").pack(pady=(20, 5)) - limit_var = tk.IntVar(value=self.search_config["limit"]) - limit_spinbox = tk.Spinbox( - config_window, - from_=1, - to=50, - textvariable=limit_var, - width=10 - ) - limit_spinbox.pack(pady=5) - - # Include options - tk.Label(config_window, text="搜索目标:").pack(pady=(20, 5)) - include_frame = tk.Frame(config_window) - include_frame.pack(pady=5) - - include_vars = {} - for option in ["statements", "chunks", "entities","summaries"]: - var = tk.BooleanVar(value=option in self.search_config["include"]) - include_vars[option] = var - tk.Checkbutton( - include_frame, - text=option, - variable=var - ).pack(side=tk.LEFT, padx=10) - - # Buttons - button_frame = tk.Frame(config_window) - button_frame.pack(pady=20) - - def save_config(): - try: - # Validate inputs - alpha_value = alpha_var.get() - limit_value = limit_var.get() - include_list = [ - option for option, var in include_vars.items() if var.get() - ] - - # Check if at least one search target is selected - if not include_list: - messagebox.showerror("配置错误", "请至少选择一个搜索目标!") - return - - # Update configuration - self.search_config["rerank_alpha"] = alpha_value - self.search_config["limit"] = limit_value - self.search_config["include"] = include_list - - config_window.destroy() - self.add_message("系统", - f"配置已更新: Alpha={alpha_value:.1f}, 限制={limit_value}, 目标={', '.join(include_list)}") - - except Exception as e: - messagebox.showerror("配置错误", f"保存配置时出错: {str(e)}") - print(f"Config save error: {e}") # Debug output - - tk.Button(button_frame, text="保存", command=save_config).pack(side=tk.LEFT, padx=5) - tk.Button(button_frame, text="取消", command=config_window.destroy).pack(side=tk.LEFT, padx=5) - - def on_send_message(self, event=None): - """Handle sending a message""" - user_message = self.user_input.get().strip() - if not user_message: - return - - # Clear input - self.user_input.delete(0, tk.END) - - # Add user message to display - self.add_message("用户", user_message) - - # Disable send button and show processing status - self.send_button.config(state=tk.DISABLED) - self.status_label.config(text="正在搜索和生成回复...") - - # Process message in background thread - threading.Thread( - target=self.process_message_async, - args=(user_message,), - daemon=True - ).start() - - def process_message_async(self, user_message: str): - """Process message asynchronously""" - try: - # Run the async processing - loop = asyncio.new_event_loop() - asyncio.set_event_loop(loop) - response, metadata = loop.run_until_complete( - self.process_message(user_message) - ) - loop.close() - - # Update GUI in main thread - self.root.after(0, self.on_response_ready, response, metadata) - - except Exception as e: - error_msg = f"处理消息时出错: {str(e)}" - self.root.after(0, self.on_error, error_msg) - - async def process_message(self, user_message: str) -> Tuple[str, Dict[str, Any]]: - """Process user message with hybrid search""" - start_time = time.time() - - # Perform hybrid search - search_start = time.time() - search_results = await run_hybrid_search( - query_text=user_message, - search_type="hybrid", - group_id=self.search_config["group_id"], - limit=self.search_config["limit"], - include=self.search_config["include"], - output_path=None, - rerank_alpha=self.search_config["rerank_alpha"] - ) - search_time = time.time() - search_start - - # Extract relevant information from search results - context_info = self.extract_context_from_search(search_results) - - # Generate response using LLM - llm_start = time.time() - response = await self.generate_response(user_message, context_info) - llm_time = time.time() - llm_start - - total_time = time.time() - start_time - - # Prepare metadata - metadata = { - "搜索时间": f"{search_time:.2f}s", - "生成时间": f"{llm_time:.2f}s", - "总时间": f"{total_time:.2f}s", - "搜索结果": self.get_search_summary(search_results), - "重排权重": self.search_config["rerank_alpha"] - } - - return response, metadata - - def extract_context_from_search(self, search_results: Dict) -> str: - """Extract context information from search results""" - if not search_results: - return "未找到相关信息。" - - context_parts = [] - - # Get reranked results if available, otherwise use individual results - if "reranked_results" in search_results: - results = search_results["reranked_results"] - else: - results = {} - for key in ["keyword_search", "embedding_search"]: - if key in search_results: - for category, items in search_results[key].items(): - if category not in results: - results[category] = [] - results[category].extend(items) - - # Extract statements - if "statements" in results and results["statements"]: - statements = results["statements"][:5] # Top 5 - context_parts.append("相关陈述:") - for i, stmt in enumerate(statements, 1): - content = stmt.get("statement", "") - score = stmt.get("combined_score", stmt.get("score", 0)) - context_parts.append(f"{i}. {content} (相关度: {score:.3f})") - - # Extract chunks - if "chunks" in results and results["chunks"]: - chunks = results["chunks"][:3] # Top 3 - context_parts.append("\n相关对话:") - for i, chunk in enumerate(chunks, 1): - content = chunk.get("content", "") - score = chunk.get("combined_score", chunk.get("score", 0)) - context_parts.append(f"{i}. {content} (相关度: {score:.3f})") - - # Extract entities - if "entities" in results and results["entities"]: - entities = results["entities"][:5] # Top 5 - context_parts.append("\n相关实体:") - entity_names = [ent.get("name", "") for ent in entities] - context_parts.append(", ".join(entity_names)) - - return "\n".join(context_parts) if context_parts else "未找到相关信息。" - - def get_search_summary(self, search_results: Dict) -> str: - """Get a summary of search results""" - if not search_results: - return "无结果" - - summary_parts = [] - - if "combined_summary" in search_results: - summary = search_results["combined_summary"] - if "total_reranked_results" in summary: - summary_parts.append(f"重排结果: {summary['total_reranked_results']}") - if "total_keyword_results" in summary: - summary_parts.append(f"关键词: {summary['total_keyword_results']}") - if "total_embedding_results" in summary: - summary_parts.append(f"语义: {summary['total_embedding_results']}") - - return ", ".join(summary_parts) if summary_parts else "有结果" - - async def generate_response(self, user_message: str, context: str) -> str: - """Generate response using LLM""" - system_prompt = f"""你是一个智能助手,基于知识图谱中的信息回答用户问题。 - -以下是从知识图谱中检索到的相关信息: -{context} - -请基于这些信息回答用户的问题。如果信息不足,请诚实地说明。回答要自然、友好,并且准确。""" - - try: - messages = [ - {"role": "system", "content": system_prompt}, - {"role": "user", "content": user_message} - ] - - response = self.llm_client.chat( - messages=messages, - ) - print(response) - # Extract content from various possible response types - # 1) LangChain AIMessage or similar object with `.content` - if hasattr(response, 'content'): - return getattr(response, 'content') - - # 2) OpenAI-style response with `.choices` - if hasattr(response, 'choices') and response.choices: - first_choice = response.choices[0] - # Newer clients may have `.message.content`, some have `.content` directly - if hasattr(first_choice, 'message') and hasattr(first_choice.message, 'content'): - return first_choice.message.content - if hasattr(first_choice, 'content'): - return first_choice.content - - # 3) Dict-like responses - if isinstance(response, dict): - if 'content' in response: - return response['content'] - if 'choices' in response and response['choices']: - ch = response['choices'][0] - if isinstance(ch, dict): - if 'message' in ch and 'content' in ch['message']: - return ch['message']['content'] - if 'content' in ch: - return ch['content'] - - # 4) Fallback: if it's a plain string - if isinstance(response, str): - return response - - # Default fallback - return "抱歉,我无法生成回复。" - - except Exception as e: - return f"生成回复时出错: {str(e)}" - - def on_response_ready(self, response: str, metadata: Dict[str, Any]): - """Handle when response is ready""" - self.add_message("助手", response, metadata) - self.send_button.config(state=tk.NORMAL) - self.status_label.config(text="就绪") - self.user_input.focus() - - def on_error(self, error_message: str): - """Handle errors""" - self.add_message("系统", f" {error_message}") - self.send_button.config(state=tk.NORMAL) - self.status_label.config(text="就绪") - self.user_input.focus() - - def run(self): - """Start the chatbot""" - self.root.mainloop() - - -def main(): - """Main function to run the chatbot""" - try: - chatbot = HybridSearchChatbot() - chatbot.run() - except Exception as e: - print(f"启动聊天机器人时出错: {e}") - - -if __name__ == "__main__": - main() diff --git a/app/core/memory/storage_services/search/hybrid_search.py b/app/core/memory/storage_services/search/hybrid_search.py deleted file mode 100644 index 8203aacf..00000000 --- a/app/core/memory/storage_services/search/hybrid_search.py +++ /dev/null @@ -1,408 +0,0 @@ -# -*- coding: utf-8 -*- -"""混合搜索策略 - -结合关键词搜索和语义搜索的混合检索方法。 -支持结果重排序和遗忘曲线加权。 -""" - -from typing import List, Dict, Any, Optional -import math -from datetime import datetime -from app.core.logging_config import get_memory_logger -from app.repositories.neo4j.neo4j_connector import Neo4jConnector -from app.core.memory.storage_services.search.search_strategy import SearchStrategy, SearchResult -from app.core.memory.storage_services.search.keyword_search import KeywordSearchStrategy -from app.core.memory.storage_services.search.semantic_search import SemanticSearchStrategy -from app.core.memory.src.llm_tools.openai_embedder import OpenAIEmbedderClient -from app.core.memory.models.variate_config import ForgettingEngineConfig -from app.core.memory.storage_services.forgetting_engine.forgetting_engine import ForgettingEngine - -logger = get_memory_logger(__name__) - - -class HybridSearchStrategy(SearchStrategy): - """混合搜索策略 - - 结合关键词搜索和语义搜索的优势: - - 关键词搜索:精确匹配,适合已知术语 - - 语义搜索:语义理解,适合概念查询 - - 混合重排序:综合两种搜索的结果 - - 遗忘曲线:根据时间衰减调整相关性 - """ - - def __init__( - self, - connector: Optional[Neo4jConnector] = None, - embedder_client: Optional[OpenAIEmbedderClient] = None, - alpha: float = 0.6, - use_forgetting_curve: bool = False, - forgetting_config: Optional[ForgettingEngineConfig] = None - ): - """初始化混合搜索策略 - - Args: - connector: Neo4j连接器 - embedder_client: 嵌入模型客户端 - alpha: BM25分数权重(0.0-1.0),1-alpha为嵌入分数权重 - use_forgetting_curve: 是否使用遗忘曲线 - forgetting_config: 遗忘引擎配置 - """ - self.connector = connector - self.embedder_client = embedder_client - self.alpha = alpha - self.use_forgetting_curve = use_forgetting_curve - self.forgetting_config = forgetting_config or ForgettingEngineConfig() - self._owns_connector = connector is None - - # 创建子策略 - self.keyword_strategy = KeywordSearchStrategy(connector=connector) - self.semantic_strategy = SemanticSearchStrategy( - connector=connector, - embedder_client=embedder_client - ) - - async def __aenter__(self): - """异步上下文管理器入口""" - if self._owns_connector: - self.connector = Neo4jConnector() - self.keyword_strategy.connector = self.connector - self.semantic_strategy.connector = self.connector - return self - - async def __aexit__(self, exc_type, exc_val, exc_tb): - """异步上下文管理器出口""" - if self._owns_connector and self.connector: - await self.connector.close() - - async def search( - self, - query_text: str, - group_id: Optional[str] = None, - limit: int = 50, - include: Optional[List[str]] = None, - **kwargs - ) -> SearchResult: - """执行混合搜索 - - Args: - query_text: 查询文本 - group_id: 可选的组ID过滤 - limit: 每个类别的最大结果数 - include: 要包含的搜索类别列表 - **kwargs: 其他搜索参数(如alpha, use_forgetting_curve) - - Returns: - SearchResult: 搜索结果对象 - """ - logger.info(f"执行混合搜索: query='{query_text}', group_id={group_id}, limit={limit}") - - # 从kwargs中获取参数 - alpha = kwargs.get("alpha", self.alpha) - use_forgetting = kwargs.get("use_forgetting_curve", self.use_forgetting_curve) - - # 获取有效的搜索类别 - include_list = self._get_include_list(include) - - try: - # 并行执行关键词搜索和语义搜索 - keyword_result = await self.keyword_strategy.search( - query_text=query_text, - group_id=group_id, - limit=limit, - include=include_list - ) - - semantic_result = await self.semantic_strategy.search( - query_text=query_text, - group_id=group_id, - limit=limit, - include=include_list - ) - - # 重排序结果 - if use_forgetting: - reranked_results = self._rerank_with_forgetting_curve( - keyword_result=keyword_result, - semantic_result=semantic_result, - alpha=alpha, - limit=limit - ) - else: - reranked_results = self._rerank_hybrid_results( - keyword_result=keyword_result, - semantic_result=semantic_result, - alpha=alpha, - limit=limit - ) - - # 创建元数据 - metadata = self._create_metadata( - query_text=query_text, - search_type="hybrid", - group_id=group_id, - limit=limit, - include=include_list, - alpha=alpha, - use_forgetting_curve=use_forgetting - ) - - # 添加结果统计 - metadata["keyword_results"] = keyword_result.metadata.get("result_counts", {}) - metadata["semantic_results"] = semantic_result.metadata.get("result_counts", {}) - metadata["total_keyword_results"] = keyword_result.total_results() - metadata["total_semantic_results"] = semantic_result.total_results() - metadata["total_reranked_results"] = reranked_results.total_results() - - reranked_results.metadata = metadata - - logger.info(f"混合搜索完成: 共找到 {reranked_results.total_results()} 条结果") - return reranked_results - - except Exception as e: - logger.error(f"混合搜索失败: {e}", exc_info=True) - # 返回空结果但包含错误信息 - return SearchResult( - metadata=self._create_metadata( - query_text=query_text, - search_type="hybrid", - group_id=group_id, - limit=limit, - error=str(e) - ) - ) - - def _normalize_scores( - self, - results: List[Dict[str, Any]], - score_field: str = "score" - ) -> List[Dict[str, Any]]: - """使用z-score标准化和sigmoid转换归一化分数 - - Args: - results: 结果列表 - score_field: 分数字段名 - - Returns: - List[Dict[str, Any]]: 归一化后的结果列表 - """ - if not results: - return results - - # 提取分数 - scores = [] - for item in results: - if score_field in item: - score = item.get(score_field) - if score is not None and isinstance(score, (int, float)): - scores.append(float(score)) - else: - scores.append(0.0) - - if not scores or len(scores) == 1: - # 单个分数或无分数,设置为1.0 - for item in results: - if score_field in item: - item[f"normalized_{score_field}"] = 1.0 - return results - - # 计算均值和标准差 - mean_score = sum(scores) / len(scores) - variance = sum((score - mean_score) ** 2 for score in scores) / len(scores) - std_dev = math.sqrt(variance) - - if std_dev == 0: - # 所有分数相同,设置为1.0 - for item in results: - if score_field in item: - item[f"normalized_{score_field}"] = 1.0 - else: - # z-score标准化 + sigmoid转换 - for item in results: - if score_field in item: - score = item[score_field] - if score is None or not isinstance(score, (int, float)): - score = 0.0 - z_score = (score - mean_score) / std_dev - normalized = 1 / (1 + math.exp(-z_score)) - item[f"normalized_{score_field}"] = normalized - - return results - - def _rerank_hybrid_results( - self, - keyword_result: SearchResult, - semantic_result: SearchResult, - alpha: float, - limit: int - ) -> SearchResult: - """重排序混合搜索结果 - - Args: - keyword_result: 关键词搜索结果 - semantic_result: 语义搜索结果 - alpha: BM25分数权重 - limit: 结果限制 - - Returns: - SearchResult: 重排序后的结果 - """ - reranked_data = {} - - for category in ["statements", "chunks", "entities", "summaries"]: - keyword_items = getattr(keyword_result, category, []) - semantic_items = getattr(semantic_result, category, []) - - # 归一化分数 - keyword_items = self._normalize_scores(keyword_items, "score") - semantic_items = self._normalize_scores(semantic_items, "score") - - # 合并结果 - combined_items = {} - - # 添加关键词结果 - for item in keyword_items: - item_id = item.get("id") or item.get("uuid") - if item_id: - combined_items[item_id] = item.copy() - combined_items[item_id]["bm25_score"] = item.get("normalized_score", 0) - combined_items[item_id]["embedding_score"] = 0 - - # 添加或更新语义结果 - for item in semantic_items: - item_id = item.get("id") or item.get("uuid") - if item_id: - if item_id in combined_items: - combined_items[item_id]["embedding_score"] = item.get("normalized_score", 0) - else: - combined_items[item_id] = item.copy() - combined_items[item_id]["bm25_score"] = 0 - combined_items[item_id]["embedding_score"] = item.get("normalized_score", 0) - - # 计算组合分数 - for item_id, item in combined_items.items(): - bm25_score = item.get("bm25_score", 0) - embedding_score = item.get("embedding_score", 0) - combined_score = alpha * bm25_score + (1 - alpha) * embedding_score - item["combined_score"] = combined_score - - # 排序并限制结果 - sorted_items = sorted( - combined_items.values(), - key=lambda x: x.get("combined_score", 0), - reverse=True - )[:limit] - - reranked_data[category] = sorted_items - - return SearchResult( - statements=reranked_data.get("statements", []), - chunks=reranked_data.get("chunks", []), - entities=reranked_data.get("entities", []), - summaries=reranked_data.get("summaries", []) - ) - - def _parse_datetime(self, value: Any) -> Optional[datetime]: - """解析日期时间字符串""" - if value is None: - return None - if isinstance(value, datetime): - return value - if isinstance(value, str): - s = value.strip() - if not s: - return None - try: - return datetime.fromisoformat(s) - except Exception: - return None - return None - - def _rerank_with_forgetting_curve( - self, - keyword_result: SearchResult, - semantic_result: SearchResult, - alpha: float, - limit: int - ) -> SearchResult: - """使用遗忘曲线重排序混合搜索结果 - - Args: - keyword_result: 关键词搜索结果 - semantic_result: 语义搜索结果 - alpha: BM25分数权重 - limit: 结果限制 - - Returns: - SearchResult: 重排序后的结果 - """ - engine = ForgettingEngine(self.forgetting_config) - now_dt = datetime.now() - - reranked_data = {} - - for category in ["statements", "chunks", "entities", "summaries"]: - keyword_items = getattr(keyword_result, category, []) - semantic_items = getattr(semantic_result, category, []) - - # 归一化分数 - keyword_items = self._normalize_scores(keyword_items, "score") - semantic_items = self._normalize_scores(semantic_items, "score") - - # 合并结果 - combined_items = {} - - for src_items, is_embedding in [(keyword_items, False), (semantic_items, True)]: - for item in src_items: - item_id = item.get("id") or item.get("uuid") - if not item_id: - continue - - if item_id not in combined_items: - combined_items[item_id] = item.copy() - combined_items[item_id]["bm25_score"] = 0 - combined_items[item_id]["embedding_score"] = 0 - - if is_embedding: - combined_items[item_id]["embedding_score"] = item.get("normalized_score", 0) - else: - combined_items[item_id]["bm25_score"] = item.get("normalized_score", 0) - - # 计算分数并应用遗忘权重 - for item_id, item in combined_items.items(): - bm25_score = float(item.get("bm25_score", 0) or 0) - embedding_score = float(item.get("embedding_score", 0) or 0) - combined_score = alpha * bm25_score + (1 - alpha) * embedding_score - - # 计算时间衰减 - dt = self._parse_datetime(item.get("created_at")) - if dt is None: - time_elapsed_days = 0.0 - else: - time_elapsed_days = max(0.0, (now_dt - dt).total_seconds() / 86400.0) - - memory_strength = 1.0 # 默认强度 - forgetting_weight = engine.calculate_weight( - time_elapsed=time_elapsed_days, - memory_strength=memory_strength - ) - - final_score = combined_score * forgetting_weight - item["combined_score"] = final_score - item["forgetting_weight"] = forgetting_weight - item["time_elapsed_days"] = time_elapsed_days - - # 排序并限制结果 - sorted_items = sorted( - combined_items.values(), - key=lambda x: x.get("combined_score", 0), - reverse=True - )[:limit] - - reranked_data[category] = sorted_items - - return SearchResult( - statements=reranked_data.get("statements", []), - chunks=reranked_data.get("chunks", []), - entities=reranked_data.get("entities", []), - summaries=reranked_data.get("summaries", []) - ) diff --git a/app/core/memory/storage_services/search/keyword_search.py b/app/core/memory/storage_services/search/keyword_search.py deleted file mode 100644 index 95dd0581..00000000 --- a/app/core/memory/storage_services/search/keyword_search.py +++ /dev/null @@ -1,122 +0,0 @@ -# -*- coding: utf-8 -*- -"""关键词搜索策略 - -实现基于关键词的全文搜索功能。 -使用Neo4j的全文索引进行高效的文本匹配。 -""" - -from typing import List, Dict, Any, Optional -from app.core.logging_config import get_memory_logger -from app.repositories.neo4j.neo4j_connector import Neo4jConnector -from app.core.memory.storage_services.search.search_strategy import SearchStrategy, SearchResult -from app.repositories.neo4j.graph_search import search_graph - -logger = get_memory_logger(__name__) - - -class KeywordSearchStrategy(SearchStrategy): - """关键词搜索策略 - - 使用Neo4j全文索引进行关键词匹配搜索。 - 支持跨陈述句、实体、分块和摘要的搜索。 - """ - - def __init__(self, connector: Optional[Neo4jConnector] = None): - """初始化关键词搜索策略 - - Args: - connector: Neo4j连接器,如果为None则创建新连接 - """ - self.connector = connector - self._owns_connector = connector is None - - async def __aenter__(self): - """异步上下文管理器入口""" - if self._owns_connector: - self.connector = Neo4jConnector() - return self - - async def __aexit__(self, exc_type, exc_val, exc_tb): - """异步上下文管理器出口""" - if self._owns_connector and self.connector: - await self.connector.close() - - async def search( - self, - query_text: str, - group_id: Optional[str] = None, - limit: int = 50, - include: Optional[List[str]] = None, - **kwargs - ) -> SearchResult: - """执行关键词搜索 - - Args: - query_text: 查询文本 - group_id: 可选的组ID过滤 - limit: 每个类别的最大结果数 - include: 要包含的搜索类别列表 - **kwargs: 其他搜索参数 - - Returns: - SearchResult: 搜索结果对象 - """ - logger.info(f"执行关键词搜索: query='{query_text}', group_id={group_id}, limit={limit}") - - # 获取有效的搜索类别 - include_list = self._get_include_list(include) - - # 确保连接器已初始化 - if not self.connector: - self.connector = Neo4jConnector() - - try: - # 调用底层的关键词搜索函数 - results_dict = await search_graph( - connector=self.connector, - q=query_text, - group_id=group_id, - limit=limit, - include=include_list - ) - - # 创建元数据 - metadata = self._create_metadata( - query_text=query_text, - search_type="keyword", - group_id=group_id, - limit=limit, - include=include_list - ) - - # 添加结果统计 - metadata["result_counts"] = { - category: len(results_dict.get(category, [])) - for category in include_list - } - metadata["total_results"] = sum(metadata["result_counts"].values()) - - # 构建SearchResult对象 - search_result = SearchResult( - statements=results_dict.get("statements", []), - chunks=results_dict.get("chunks", []), - entities=results_dict.get("entities", []), - summaries=results_dict.get("summaries", []), - metadata=metadata - ) - - logger.info(f"关键词搜索完成: 共找到 {search_result.total_results()} 条结果") - return search_result - - except Exception as e: - logger.error(f"关键词搜索失败: {e}", exc_info=True) - # 返回空结果但包含错误信息 - return SearchResult( - metadata=self._create_metadata( - query_text=query_text, - search_type="keyword", - group_id=group_id, - limit=limit, - error=str(e) - ) - ) diff --git a/app/core/memory/storage_services/search/search_strategy.py b/app/core/memory/storage_services/search/search_strategy.py deleted file mode 100644 index 27c02c89..00000000 --- a/app/core/memory/storage_services/search/search_strategy.py +++ /dev/null @@ -1,125 +0,0 @@ -# -*- coding: utf-8 -*- -"""搜索策略基类 - -定义搜索策略的抽象接口和统一的搜索结果数据结构。 -遵循策略模式(Strategy Pattern)和开放-关闭原则(OCP)。 -""" - -from abc import ABC, abstractmethod -from typing import List, Dict, Any, Optional -from pydantic import BaseModel, Field -from datetime import datetime - - -class SearchResult(BaseModel): - """统一的搜索结果数据结构 - - Attributes: - statements: 陈述句搜索结果列表 - chunks: 分块搜索结果列表 - entities: 实体搜索结果列表 - summaries: 摘要搜索结果列表 - metadata: 搜索元数据(如查询时间、结果数量等) - """ - statements: List[Dict[str, Any]] = Field(default_factory=list, description="陈述句搜索结果") - chunks: List[Dict[str, Any]] = Field(default_factory=list, description="分块搜索结果") - entities: List[Dict[str, Any]] = Field(default_factory=list, description="实体搜索结果") - summaries: List[Dict[str, Any]] = Field(default_factory=list, description="摘要搜索结果") - metadata: Dict[str, Any] = Field(default_factory=dict, description="搜索元数据") - - def total_results(self) -> int: - """返回所有类别的结果总数""" - return ( - len(self.statements) + - len(self.chunks) + - len(self.entities) + - len(self.summaries) - ) - - def to_dict(self) -> Dict[str, Any]: - """转换为字典格式""" - return { - "statements": self.statements, - "chunks": self.chunks, - "entities": self.entities, - "summaries": self.summaries, - "metadata": self.metadata - } - - -class SearchStrategy(ABC): - """搜索策略抽象基类 - - 定义所有搜索策略必须实现的接口。 - 遵循依赖反转原则(DIP):高层模块依赖抽象而非具体实现。 - """ - - @abstractmethod - async def search( - self, - query_text: str, - group_id: Optional[str] = None, - limit: int = 50, - include: Optional[List[str]] = None, - **kwargs - ) -> SearchResult: - """执行搜索 - - Args: - query_text: 查询文本 - group_id: 可选的组ID过滤 - limit: 每个类别的最大结果数 - include: 要包含的搜索类别列表(statements, chunks, entities, summaries) - **kwargs: 其他搜索参数 - - Returns: - SearchResult: 统一的搜索结果对象 - """ - pass - - def _create_metadata( - self, - query_text: str, - search_type: str, - group_id: Optional[str] = None, - limit: int = 50, - **kwargs - ) -> Dict[str, Any]: - """创建搜索元数据 - - Args: - query_text: 查询文本 - search_type: 搜索类型 - group_id: 组ID - limit: 结果限制 - **kwargs: 其他元数据 - - Returns: - Dict[str, Any]: 元数据字典 - """ - metadata = { - "query": query_text, - "search_type": search_type, - "group_id": group_id, - "limit": limit, - "timestamp": datetime.now().isoformat() - } - metadata.update(kwargs) - return metadata - - def _get_include_list(self, include: Optional[List[str]] = None) -> List[str]: - """获取要包含的搜索类别列表 - - Args: - include: 用户指定的类别列表 - - Returns: - List[str]: 有效的类别列表 - """ - default_include = ["statements", "chunks", "entities", "summaries"] - if include is None: - return default_include - - # 验证并过滤有效的类别 - valid_categories = set(default_include) - return [cat for cat in include if cat in valid_categories] diff --git a/app/core/memory/storage_services/search/semantic_search.py b/app/core/memory/storage_services/search/semantic_search.py deleted file mode 100644 index 38c58cc1..00000000 --- a/app/core/memory/storage_services/search/semantic_search.py +++ /dev/null @@ -1,159 +0,0 @@ -# -*- coding: utf-8 -*- -"""语义搜索策略 - -实现基于向量嵌入的语义搜索功能。 -使用余弦相似度进行语义匹配。 -""" - -from typing import List, Dict, Any, Optional -from app.core.logging_config import get_memory_logger -from app.repositories.neo4j.neo4j_connector import Neo4jConnector -from app.core.memory.storage_services.search.search_strategy import SearchStrategy, SearchResult -from app.repositories.neo4j.graph_search import search_graph_by_embedding -from app.core.memory.src.llm_tools.openai_embedder import OpenAIEmbedderClient -from app.core.memory.utils.config.config_utils import get_embedder_config -from app.core.memory.utils.config import definitions as config_defs -from app.core.models.base import RedBearModelConfig - -logger = get_memory_logger(__name__) - - -class SemanticSearchStrategy(SearchStrategy): - """语义搜索策略 - - 使用向量嵌入和余弦相似度进行语义搜索。 - 支持跨陈述句、分块、实体和摘要的语义匹配。 - """ - - def __init__( - self, - connector: Optional[Neo4jConnector] = None, - embedder_client: Optional[OpenAIEmbedderClient] = None - ): - """初始化语义搜索策略 - - Args: - connector: Neo4j连接器,如果为None则创建新连接 - embedder_client: 嵌入模型客户端,如果为None则根据配置创建 - """ - self.connector = connector - self.embedder_client = embedder_client - self._owns_connector = connector is None - self._owns_embedder = embedder_client is None - - async def __aenter__(self): - """异步上下文管理器入口""" - if self._owns_connector: - self.connector = Neo4jConnector() - if self._owns_embedder: - self.embedder_client = self._create_embedder_client() - return self - - async def __aexit__(self, exc_type, exc_val, exc_tb): - """异步上下文管理器出口""" - if self._owns_connector and self.connector: - await self.connector.close() - - def _create_embedder_client(self) -> OpenAIEmbedderClient: - """创建嵌入模型客户端 - - Returns: - OpenAIEmbedderClient: 嵌入模型客户端实例 - """ - try: - # 从数据库读取嵌入器配置 - embedder_config_dict = get_embedder_config(config_defs.SELECTED_EMBEDDING_ID) - rb_config = RedBearModelConfig( - model_name=embedder_config_dict["model_name"], - provider=embedder_config_dict["provider"], - api_key=embedder_config_dict["api_key"], - base_url=embedder_config_dict["base_url"], - type="llm" - ) - return OpenAIEmbedderClient(model_config=rb_config) - except Exception as e: - logger.error(f"创建嵌入模型客户端失败: {e}", exc_info=True) - raise - - async def search( - self, - query_text: str, - group_id: Optional[str] = None, - limit: int = 50, - include: Optional[List[str]] = None, - **kwargs - ) -> SearchResult: - """执行语义搜索 - - Args: - query_text: 查询文本 - group_id: 可选的组ID过滤 - limit: 每个类别的最大结果数 - include: 要包含的搜索类别列表 - **kwargs: 其他搜索参数 - - Returns: - SearchResult: 搜索结果对象 - """ - logger.info(f"执行语义搜索: query='{query_text}', group_id={group_id}, limit={limit}") - - # 获取有效的搜索类别 - include_list = self._get_include_list(include) - - # 确保连接器和嵌入器已初始化 - if not self.connector: - self.connector = Neo4jConnector() - if not self.embedder_client: - self.embedder_client = self._create_embedder_client() - - try: - # 调用底层的语义搜索函数 - results_dict = await search_graph_by_embedding( - connector=self.connector, - embedder_client=self.embedder_client, - query_text=query_text, - group_id=group_id, - limit=limit, - include=include_list - ) - - # 创建元数据 - metadata = self._create_metadata( - query_text=query_text, - search_type="semantic", - group_id=group_id, - limit=limit, - include=include_list - ) - - # 添加结果统计 - metadata["result_counts"] = { - category: len(results_dict.get(category, [])) - for category in include_list - } - metadata["total_results"] = sum(metadata["result_counts"].values()) - - # 构建SearchResult对象 - search_result = SearchResult( - statements=results_dict.get("statements", []), - chunks=results_dict.get("chunks", []), - entities=results_dict.get("entities", []), - summaries=results_dict.get("summaries", []), - metadata=metadata - ) - - logger.info(f"语义搜索完成: 共找到 {search_result.total_results()} 条结果") - return search_result - - except Exception as e: - logger.error(f"语义搜索失败: {e}", exc_info=True) - # 返回空结果但包含错误信息 - return SearchResult( - metadata=self._create_metadata( - query_text=query_text, - search_type="semantic", - group_id=group_id, - limit=limit, - error=str(e) - ) - ) diff --git a/app/core/memory/utils/README.md b/app/core/memory/utils/README.md deleted file mode 100644 index 32264569..00000000 --- a/app/core/memory/utils/README.md +++ /dev/null @@ -1,445 +0,0 @@ -# Memory 模块工具函数文档 - -本目录包含 Memory 模块使用的所有工具函数,统一管理以提高代码可维护性和可复用性。 - -## 目录结构 - -``` -app/core/memory/utils/ -├── __init__.py # 包初始化文件,导出所有公共接口 -├── README.md # 本文档 -├── config/ # 配置管理模块 -│ ├── __init__.py # 配置模块初始化 -│ ├── config_utils.py # 配置管理工具 -│ ├── definitions.py # 全局定义和常量 -│ ├── overrides.py # 运行时配置覆写 -│ ├── get_data.py # 数据获取工具 -│ ├── litellm_config.py # LiteLLM 配置和监控 -│ └── config_optimization.py # 配置优化工具 -├── log/ # 日志管理模块 -│ ├── __init__.py # 日志模块初始化 -│ ├── logging_utils.py # 日志工具 -│ └── audit_logger.py # 审计日志 -├── prompt/ # 提示词管理模块 -│ ├── __init__.py # 提示词模块初始化 -│ ├── prompt_utils.py # 提示词渲染工具 -│ ├── template_render.py # 模板渲染工具 -│ └── prompts/ # Jinja2 提示词模板目录 -│ ├── entity_dedup.jinja2 # 实体去重提示词 -│ ├── extract_statement.jinja2 # 陈述句提取提示词 -│ ├── extract_temporal.jinja2 # 时间信息提取提示词 -│ ├── extract_triplet.jinja2 # 三元组提取提示词 -│ ├── memory_summary.jinja2 # 记忆摘要提示词 -│ ├── evaluate.jinja2 # 评估提示词 -│ ├── reflexion.jinja2 # 反思提示词 -│ ├── system.jinja2 # 系统提示词 -│ └── user.jinja2 # 用户提示词 -├── llm/ # LLM 工具模块 -│ ├── __init__.py # LLM 模块初始化 -│ └── llm_utils.py # LLM 客户端工具 -├── data/ # 数据处理模块 -│ ├── __init__.py # 数据模块初始化 -│ ├── text_utils.py # 文本处理工具 -│ ├── time_utils.py # 时间处理工具 -│ └── ontology.py # 本体定义(谓语、标签等) -├── paths/ # 路径管理模块 -│ ├── __init__.py # 路径模块初始化 -│ └── output_paths.py # 输出路径管理 -├── visualization/ # 可视化模块 -│ ├── __init__.py # 可视化模块初始化 -│ └── forgetting_visualizer.py # 遗忘曲线可视化 -└── self_reflexion_utils/ # 自我反思工具模块 - ├── __init__.py # 反思模块初始化 - ├── evaluate.py # 冲突评估 - ├── reflexion.py # 反思处理 - └── self_reflexion.py # 自我反思主逻辑 -``` - -## 模块分类 - -### 1. 配置管理(config/) - -配置管理模块包含所有与配置相关的工具函数和定义。 - -#### config_utils.py -提供配置加载和管理功能: -- `get_model_config(model_id)` - 获取 LLM 模型配置 -- `get_embedder_config(embedding_id)` - 获取嵌入模型配置 -- `get_neo4j_config()` - 获取 Neo4j 数据库配置 -- `get_chunker_config(chunker_strategy)` - 获取分块策略配置 -- `get_pipeline_config()` - 获取流水线配置 -- `get_pruning_config()` - 获取语义剪枝配置 -- `get_picture_config()` - 获取图片模型配置 -- `get_voice_config()` - 获取语音模型配置 - -#### definitions.py -全局定义和常量: -- `CONFIG` - 基础配置(从 config.json 加载) -- `RUNTIME_CONFIG` - 运行时配置(从 runtime.json 或数据库加载) -- `PROJECT_ROOT` - 项目根目录路径 -- 各种选择配置常量(LLM、嵌入模型、分块策略等) -- `reload_configuration_from_database(config_id)` - 动态重新加载配置 - -#### overrides.py -运行时配置覆写: -- `load_unified_config(project_root)` - 加载统一配置 - -#### get_data.py -数据获取工具: -- `get_data(host_id)` - 从 SQL 数据库获取数据 - -#### litellm_config.py -LiteLLM 配置和监控: -- `LiteLLMConfig` - LiteLLM 配置类 -- `setup_litellm_enhanced(max_retries)` - 设置增强的 LiteLLM 配置 -- `get_usage_summary()` - 获取使用统计摘要 -- `print_usage_summary()` - 打印使用统计 -- `get_instant_qps(module)` - 获取即时 QPS 数据 -- `print_instant_qps(module)` - 打印即时 QPS 信息 - -#### config_optimization.py -配置优化工具: -- 配置参数优化相关功能 - -### 3. LLM 工具(llm/) - -LLM 工具模块包含所有与 LLM 客户端相关的工具函数。 - -#### llm_utils.py -LLM 客户端工具: -- `get_llm_client(llm_id)` - 获取 LLM 客户端实例 -- `get_reranker_client(rerank_id)` - 获取重排序客户端实例 -- `handle_response(response)` - 处理 LLM 响应 - -#### litellm_config.py -LiteLLM 配置和监控: -- `LiteLLMConfig` - LiteLLM 配置类 -- `setup_litellm_enhanced(max_retries)` - 设置增强的 LiteLLM 配置 -- `get_usage_summary()` - 获取使用统计摘要 -- `print_usage_summary()` - 打印使用统计 -- `get_instant_qps(module)` - 获取即时 QPS 数据 -- `print_instant_qps(module)` - 打印即时 QPS 信息 - -### 4. 提示词管理(prompt/) - -提示词管理模块包含所有提示词渲染和模板管理相关的工具函数。 - -#### prompt_utils.py -提示词渲染工具(使用 Jinja2 模板): -- `get_prompts(message)` - 获取系统和用户提示词 -- `render_statement_extraction_prompt(...)` - 渲染陈述句提取提示词 -- `render_temporal_extraction_prompt(...)` - 渲染时间信息提取提示词 -- `render_entity_dedup_prompt(...)` - 渲染实体去重提示词 -- `render_triplet_extraction_prompt(...)` - 渲染三元组提取提示词 -- `render_memory_summary_prompt(...)` - 渲染记忆摘要提示词 -- `prompt_env` - Jinja2 环境对象 - -#### template_render.py -模板渲染工具(用于评估和反思): -- `render_evaluate_prompt(evaluate_data, schema)` - 渲染评估提示词 -- `render_reflexion_prompt(data, schema)` - 渲染反思提示词 - -#### prompts/ -Jinja2 模板文件目录,包含所有提示词模板 - -### 5. 数据处理(data/) - -数据处理模块包含所有数据处理相关的工具函数。 - -#### text_utils.py -文本处理工具: -- `escape_lucene_query(query)` - 转义 Lucene 查询特殊字符 -- `extract_plain_query(query_input)` - 从各种输入格式提取纯文本查询 - -#### time_utils.py -时间处理工具: -- `validate_date_format(date_str)` - 验证日期格式(YYYY-MM-DD) -- `normalize_date(date_str)` - 标准化日期格式 -- `normalize_date_safe(date_str, default)` - 安全的日期标准化(带默认值) -- `preprocess_date_string(date_str)` - 预处理日期字符串 - -#### ontology.py -本体定义: -- `PREDICATE_DEFINITIONS` - 谓语定义字典 -- `LABEL_DEFINITIONS` - 标签定义字典 -- `Predicate` - 谓语枚举 -- `StatementType` - 陈述句类型枚举 -- `TemporalInfo` - 时间信息枚举 -- `RelevenceInfo` - 相关性信息枚举 - -### 2. 日志管理(log/) - -日志管理模块包含所有与日志记录相关的工具函数。 - -#### logging_utils.py -日志工具: -- `log_prompt_rendering(role, content)` - 记录提示词渲染 -- `log_template_rendering(template_name, context)` - 记录模板渲染 -- `log_time(operation, duration)` - 记录操作耗时 -- `prompt_logger` - 提示词日志记录器 - -#### audit_logger.py -审计日志: -- `audit_logger` - 审计日志记录器 -- 记录系统关键操作和安全事件 - -### 6. 自我反思工具(self_reflexion_utils/) - -自我反思工具模块包含记忆冲突检测和反思处理功能。 - -#### evaluate.py -冲突评估: -- `conflict(evaluate_data, schema)` - 评估记忆冲突 - -#### reflexion.py -反思处理: -- `reflexion(data, schema)` - 执行反思处理 - -#### self_reflexion.py -自我反思主逻辑: -- `self_reflexion(...)` - 自我反思主函数 - -### 7. 数据模型 - -#### json_schema.py -JSON Schema 数据模型: -- `BaseDataSchema` - 基础数据模型 -- `ConflictResultSchema` - 冲突结果模型 -- `ConflictSchema` - 冲突模型 -- `ReflexionSchema` - 反思模型 -- `ResolvedSchema` - 解决方案模型 -- `ReflexionResultSchema` - 反思结果模型 - -#### messages.py -API 消息模型: -- `ConfigKey` - 配置键模型 -- `ChunkerStrategy` - 分块策略枚举 -- `ConfigParams` - 配置参数模型 -- `ConfigParamsCreate` - 创建配置参数模型 -- `ConfigUpdate` - 更新配置模型 -- `ConfigUpdateExtracted` - 更新萃取引擎配置模型 -- `ConfigUpdateForget` - 更新遗忘引擎配置模型 -- `ConfigPilotRun` - 试运行配置模型 -- `ConfigFilter` - 配置过滤模型 -- `ApiResponse` - API 响应模型 -- `ok(msg, data)` - 成功响应构造函数 -- `fail(msg, error_code, data)` - 失败响应构造函数 - -### 8. 可视化(visualization/) - -可视化模块包含所有可视化相关的工具函数。 - -#### forgetting_visualizer.py -遗忘曲线可视化: -- `export_memory_curve_numpy(...)` - 导出记忆曲线为 NumPy 数组 -- `export_memory_curves_multiple_strengths(...)` - 导出多个强度的记忆曲线 -- `export_parameter_sweep_numpy(...)` - 导出参数扫描结果 -- `visualize_forgetting_curve(...)` - 可视化遗忘曲线 -- `plot_3d_forgetting_surface(...)` - 绘制 3D 遗忘曲线表面 -- `create_comparison_visualization(...)` - 创建对比可视化 -- `save_memory_curves_to_file(...)` - 保存记忆曲线到文件 - -### 9. 路径管理(paths/) - -路径管理模块包含所有路径管理相关的工具函数。 - -#### output_paths.py -输出路径管理: -- `get_output_dir()` - 获取输出目录 -- `get_output_path(filename)` - 获取输出文件路径 - -## 使用示例 - -### 配置管理 - -```python -from app.core.memory.utils.config import get_model_config, get_pipeline_config -from app.core.memory.utils.config.definitions import SELECTED_LLM_ID - -# 获取模型配置 -model_config = get_model_config("model_id_123") - -# 获取流水线配置 -pipeline_config = get_pipeline_config() - -# 使用全局常量 -llm_id = SELECTED_LLM_ID -``` - -### 日志管理 - -```python -from app.core.memory.utils.log import log_prompt_rendering, log_time, audit_logger - -# 记录提示词渲染 -log_prompt_rendering('user', 'Hello, world!') - -# 记录操作耗时 -log_time('extraction', 1.23) - -# 使用审计日志 -audit_logger.info('User action performed') -``` - -### LLM 工具 - -```python -from app.core.memory.utils.llm import get_llm_client - -# 获取 LLM 客户端 -llm_client = get_llm_client("llm_id_456") - -# 调用 LLM -response = await llm_client.chat([ - {"role": "user", "content": "Hello"} -]) -``` - -### 提示词渲染 - -```python -from app.core.memory.utils.prompt import render_statement_extraction_prompt -from app.core.memory.utils.data.ontology import LABEL_DEFINITIONS - -# 渲染陈述句提取提示词 -prompt = await render_statement_extraction_prompt( - chunk_content="对话内容...", - definitions=LABEL_DEFINITIONS, - json_schema=schema, - granularity=2 -) -``` - -### 数据处理 - -```python -from app.core.memory.utils.data.time_utils import normalize_date -from app.core.memory.utils.data.text_utils import escape_lucene_query - -# 标准化日期 -normalized = normalize_date("2025/10/28") # 返回 "2025-10-28" - -# 转义 Lucene 查询 -escaped = escape_lucene_query("user:admin AND status:active") -``` - -### 运行时配置覆写 - -```python -from app.core.memory.utils import apply_runtime_overrides_with_config_id - -# 使用指定 config_id 覆写配置 -runtime_cfg = {"selections": {}} -updated_cfg = apply_runtime_overrides_with_config_id( - project_root="/path/to/project", - runtime_cfg=runtime_cfg, - config_id="config_123" -) -``` - -## 迁移说明 - -### 从旧路径迁移 - -如果你的代码使用了旧的导入路径,请按以下方式更新: - -**旧路径(2024年11月之前):** -```python -from app.core.memory.src.utils.config_utils import get_model_config -from app.core.memory.src.utils.prompt_utils import render_statement_extraction_prompt -from app.core.memory.src.data_config_api.utils.messages import ok, fail -``` - -**中间路径(2024年11月):** -```python -from app.core.memory.utils.config_utils import get_model_config -from app.core.memory.utils.logging_utils import log_prompt_rendering -from app.schemas.memory_storage_schema import ok, fail -``` - -**新路径(2024年11月27日之后):** -```python -# 配置相关 -from app.core.memory.utils.config.config_utils import get_model_config -from app.core.memory.utils.config import get_model_config # 简化导入 - -# 日志相关 -from app.core.memory.utils.log.logging_utils import log_prompt_rendering -from app.core.memory.utils.log import log_prompt_rendering # 简化导入 - -# 其他工具 -from app.core.memory.utils import prompt_utils -from app.schemas.memory_storage_schema import ok, fail -``` - -### 目录结构重组(2024年11月27日) - -utils 目录已按功能进行了完整的重组: - -**重组前的结构:** -- 所有文件都在 `app/core/memory/utils/` 根目录下 - -**重组后的结构:** -- `config/` - 配置管理相关文件 -- `log/` - 日志管理相关文件 -- `prompt/` - 提示词管理相关文件 -- `llm/` - LLM 工具相关文件 -- `data/` - 数据处理相关文件 -- `paths/` - 路径管理相关文件 -- `visualization/` - 可视化相关文件 -- `self_reflexion_utils/` - 自我反思工具(已存在) - -**导入路径变化:** -```python -# 旧导入方式 -from app.core.memory.utils.config_utils import get_model_config -from app.core.memory.utils.logging_utils import log_prompt_rendering -from app.core.memory.utils.prompt_utils import render_statement_extraction_prompt - -# 新导入方式 -from app.core.memory.utils.config.config_utils import get_model_config -from app.core.memory.utils.log.logging_utils import log_prompt_rendering -from app.core.memory.utils.prompt.prompt_utils import render_statement_extraction_prompt - -# 或使用简化导入 -from app.core.memory.utils.config import get_model_config -from app.core.memory.utils.log import log_prompt_rendering -from app.core.memory.utils.prompt import render_statement_extraction_prompt -``` - -## 维护指南 - -### 添加新工具函数 - -1. 在相应的模块文件中添加函数 -2. 在 `__init__.py` 中导出函数 -3. 在本 README 中添加文档 -4. 编写单元测试 - -### 删除旧工具函数 - -1. 确认没有代码使用该函数 -2. 从模块文件中删除函数 -3. 从 `__init__.py` 中删除导出 -4. 更新本 README - -### 重构工具函数 - -1. 保持向后兼容性(使用别名或包装器) -2. 更新所有使用该函数的代码 -3. 更新文档和测试 -4. 在适当时机删除旧版本 - -## 注意事项 - -1. **向后兼容性**:所有工具函数应保持向后兼容,避免破坏现有代码 -2. **文档完整性**:每个函数都应有清晰的文档字符串 -3. **类型注解**:使用类型注解提高代码可读性 -4. **错误处理**:工具函数应有适当的错误处理 -5. **测试覆盖**:所有工具函数都应有单元测试 - -## 相关文档 - -- [Memory 模块架构设计](../.kiro/specs/memory-refactoring/design.md) -- [Memory 模块需求文档](../.kiro/specs/memory-refactoring/requirements.md) -- [Memory 模块任务列表](../.kiro/specs/memory-refactoring/tasks.md) diff --git a/app/core/memory/utils/__init__.py b/app/core/memory/utils/__init__.py deleted file mode 100644 index 8b91c46f..00000000 --- a/app/core/memory/utils/__init__.py +++ /dev/null @@ -1,65 +0,0 @@ -""" -Memory 模块工具函数包 - -本包包含 Memory 模块使用的所有工具函数,按功能分类管理。 - -目录结构: -- config/: 配置管理模块(config_utils, definitions, overrides, get_data, litellm_config, config_optimization) -- log/: 日志管理模块(logging_utils, audit_logger) -- prompt/: 提示词管理模块(prompt_utils, template_render, prompts/) -- llm/: LLM 工具模块(llm_utils) -- data/: 数据处理模块(text_utils, time_utils, ontology) -- paths/: 路径管理模块(output_paths) -- visualization/: 可视化模块(forgetting_visualizer) -- self_reflexion_utils/: 自我反思工具(evaluate, reflexion, self_reflexion) - -注意: -- json_schema 和 messages 已迁移到 app.schemas.memory_storage_schema -- 所有工具函数已按功能分类到对应的子目录 - -使用示例: - # 配置管理 - from app.core.memory.utils.config import get_model_config - from app.core.memory.utils.config.definitions import SELECTED_LLM_ID - - # 日志管理 - from app.core.memory.utils.log import log_prompt_rendering, audit_logger - - # 提示词管理 - from app.core.memory.utils.prompt import render_statement_extraction_prompt - - # LLM 工具 - from app.core.memory.utils.llm import get_llm_client - - # 数据处理 - from app.core.memory.utils.data import text_utils, time_utils - from app.core.memory.utils.data.ontology import Predicate, StatementType - - # 路径管理 - from app.core.memory.utils.paths import get_output_dir - - # 可视化 - from app.core.memory.utils.visualization import visualize_forgetting_curve - - # 自我反思 - from app.core.memory.utils.self_reflexion_utils import self_reflexion -""" - -# 不在 __init__.py 中进行模块级别的导入,以避免循环导入 -# 用户应该直接导入需要的模块,例如: -# from app.core.memory.utils.config import config_utils -# from app.core.memory.utils.log import logging_utils -# from app.core.memory.utils.data import text_utils -# from app.core.memory.utils.prompt import prompt_utils - -__all__ = [ - # 子模块 - "config", - "log", - "prompt", - "llm", - "data", - "paths", - "visualization", - "self_reflexion_utils", -] diff --git a/app/core/memory/utils/config/__init__.py b/app/core/memory/utils/config/__init__.py deleted file mode 100644 index 2b41b522..00000000 --- a/app/core/memory/utils/config/__init__.py +++ /dev/null @@ -1,82 +0,0 @@ -""" -配置管理模块 - -包含所有配置相关的工具函数和定义。 -""" - -# 从子模块导出常用函数和常量,保持向后兼容 -from .config_utils import ( - get_model_config, - get_embedder_config, - get_neo4j_config, - get_chunker_config, - get_pipeline_config, - get_pruning_config, - get_picture_config, - get_voice_config, -) -from .definitions import ( - CONFIG, - RUNTIME_CONFIG, - PROJECT_ROOT, - SELECTED_LLM_ID, - SELECTED_EMBEDDING_ID, - SELECTED_GROUP_ID, - SELECTED_RERANK_ID, - SELECTED_LLM_PICTURE_NAME, - SELECTED_LLM_VOICE_NAME, - REFLEXION_ENABLED, - REFLEXION_ITERATION_PERIOD, - REFLEXION_RANGE, - REFLEXION_BASELINE, - reload_configuration_from_database, -) -from .overrides import load_unified_config -from .get_data import get_data -# litellm_config 需要时动态导入,避免循环依赖 -# from .litellm_config import ( -# LiteLLMConfig, -# setup_litellm_enhanced, -# get_usage_summary, -# print_usage_summary, -# get_instant_qps, -# print_instant_qps, -# ) - -__all__ = [ - # config_utils - "get_model_config", - "get_embedder_config", - "get_neo4j_config", - "get_chunker_config", - "get_pipeline_config", - "get_pruning_config", - "get_picture_config", - "get_voice_config", - # definitions - "CONFIG", - "RUNTIME_CONFIG", - "PROJECT_ROOT", - "SELECTED_LLM_ID", - "SELECTED_EMBEDDING_ID", - "SELECTED_GROUP_ID", - "SELECTED_RERANK_ID", - "SELECTED_LLM_PICTURE_NAME", - "SELECTED_LLM_VOICE_NAME", - "REFLEXION_ENABLED", - "REFLEXION_ITERATION_PERIOD", - "REFLEXION_RANGE", - "REFLEXION_BASELINE", - "reload_configuration_from_database", - # overrides - "load_unified_config", - # get_data - "get_data", - # litellm_config - 需要时从 .litellm_config 直接导入 - # "LiteLLMConfig", - # "setup_litellm_enhanced", - # "get_usage_summary", - # "print_usage_summary", - # "get_instant_qps", - # "print_instant_qps", -] diff --git a/app/core/memory/utils/config/config_optimization.py b/app/core/memory/utils/config/config_optimization.py deleted file mode 100644 index 41848a80..00000000 --- a/app/core/memory/utils/config/config_optimization.py +++ /dev/null @@ -1,398 +0,0 @@ -""" -配置管理优化模块 - -提供可选的配置管理优化功能,包括: -- LRU 缓存策略 -- 缓存预热 -- 缓存监控指标 -- 动态 TTL 策略 -- 配置版本控制 - -这些优化是可选的,当前的基础实现已经满足大多数需求。 -""" -import logging -import statistics -import threading -from collections import OrderedDict -from datetime import datetime, timedelta -from typing import Dict, Any, List, Optional, Tuple - -logger = logging.getLogger(__name__) - - -class LRUConfigCache: - """ - LRU(Least Recently Used)配置缓存 - - 当缓存达到最大容量时,自动淘汰最少使用的配置 - """ - - def __init__(self, max_size: int = 100, ttl: timedelta = timedelta(minutes=5)): - """ - 初始化 LRU 缓存 - - Args: - max_size: 最大缓存容量 - ttl: 缓存过期时间 - """ - self.max_size = max_size - self.ttl = ttl - self._cache: OrderedDict[str, Dict[str, Any]] = OrderedDict() - self._timestamps: Dict[str, datetime] = {} - self._lock = threading.RLock() - - # 统计信息 - self._stats = { - 'hits': 0, - 'misses': 0, - 'evictions': 0, - 'load_times': [] - } - - def get(self, config_id: str) -> Optional[Dict[str, Any]]: - """ - 获取配置(如果存在且未过期) - - Args: - config_id: 配置 ID - - Returns: - 配置字典,如果不存在或已过期则返回 None - """ - with self._lock: - if config_id not in self._cache: - self._stats['misses'] += 1 - return None - - # 检查是否过期 - timestamp = self._timestamps.get(config_id) - if timestamp and (datetime.now() - timestamp) >= self.ttl: - # 过期,移除 - self._cache.pop(config_id, None) - self._timestamps.pop(config_id, None) - self._stats['misses'] += 1 - return None - - # 命中,移动到末尾(标记为最近使用) - self._cache.move_to_end(config_id) - self._stats['hits'] += 1 - return self._cache[config_id] - - def put(self, config_id: str, config: Dict[str, Any]) -> None: - """ - 添加或更新配置 - - Args: - config_id: 配置 ID - config: 配置字典 - """ - with self._lock: - if config_id in self._cache: - # 更新现有配置 - self._cache.move_to_end(config_id) - else: - # 添加新配置 - if len(self._cache) >= self.max_size: - # 缓存已满,移除最旧的配置 - oldest_id, _ = self._cache.popitem(last=False) - self._timestamps.pop(oldest_id, None) - self._stats['evictions'] += 1 - logger.debug(f"[LRUCache] 淘汰配置: {oldest_id}") - - self._cache[config_id] = config - self._timestamps[config_id] = datetime.now() - - def clear(self, config_id: Optional[str] = None) -> None: - """ - 清除缓存 - - Args: - config_id: 如果指定,只清除该配置;否则清除所有 - """ - with self._lock: - if config_id: - self._cache.pop(config_id, None) - self._timestamps.pop(config_id, None) - else: - self._cache.clear() - self._timestamps.clear() - - def get_stats(self) -> Dict[str, Any]: - """ - 获取缓存统计信息 - - Returns: - 统计信息字典 - """ - with self._lock: - total = self._stats['hits'] + self._stats['misses'] - hit_rate = (self._stats['hits'] / total * 100) if total > 0 else 0 - - return { - 'cache_size': len(self._cache), - 'max_size': self.max_size, - 'total_requests': total, - 'cache_hits': self._stats['hits'], - 'cache_misses': self._stats['misses'], - 'evictions': self._stats['evictions'], - 'hit_rate': hit_rate, - 'avg_load_time': statistics.mean(self._stats['load_times']) if self._stats['load_times'] else 0 - } - - def record_load_time(self, load_time_ms: float) -> None: - """ - 记录加载时间 - - Args: - load_time_ms: 加载时间(毫秒) - """ - with self._lock: - self._stats['load_times'].append(load_time_ms) - # 只保留最近 1000 次的记录 - if len(self._stats['load_times']) > 1000: - self._stats['load_times'] = self._stats['load_times'][-1000:] - - -class ConfigCacheWarmer: - """ - 配置缓存预热器 - - 在系统启动时预加载常用配置,减少首次请求延迟 - """ - - @staticmethod - def warmup(config_ids: List[str], load_func) -> Dict[str, bool]: - """ - 预热缓存 - - Args: - config_ids: 要预加载的配置 ID 列表 - load_func: 配置加载函数 - - Returns: - 每个配置的加载结果 - """ - results = {} - - logger.info(f"[CacheWarmer] 开始预热 {len(config_ids)} 个配置") - - for config_id in config_ids: - try: - result = load_func(config_id) - results[config_id] = result - if result: - logger.debug(f"[CacheWarmer] 成功预热配置: {config_id}") - else: - logger.warning(f"[CacheWarmer] 预热配置失败: {config_id}") - except Exception as e: - logger.error(f"[CacheWarmer] 预热配置异常: {config_id}, 错误: {e}") - results[config_id] = False - - success_count = sum(1 for r in results.values() if r) - logger.info(f"[CacheWarmer] 预热完成: {success_count}/{len(config_ids)} 成功") - - return results - - -class DynamicTTLStrategy: - """ - 动态 TTL 策略 - - 根据配置类型和更新频率动态调整缓存过期时间 - """ - - # 预定义的 TTL 策略 - TTL_STRATEGIES = { - 'production': timedelta(minutes=30), # 生产配置较稳定 - 'staging': timedelta(minutes=15), # 预发布配置中等稳定 - 'development': timedelta(minutes=5), # 开发配置频繁变化 - 'testing': timedelta(minutes=1), # 测试配置快速过期 - 'default': timedelta(minutes=5) # 默认策略 - } - - @classmethod - def get_ttl(cls, config_id: str, config_type: Optional[str] = None) -> timedelta: - """ - 获取配置的 TTL - - Args: - config_id: 配置 ID - config_type: 配置类型(production/staging/development/testing) - - Returns: - TTL 时间间隔 - """ - if config_type and config_type in cls.TTL_STRATEGIES: - return cls.TTL_STRATEGIES[config_type] - - # 根据 config_id 推断类型 - if 'prod' in config_id.lower(): - return cls.TTL_STRATEGIES['production'] - elif 'stag' in config_id.lower(): - return cls.TTL_STRATEGIES['staging'] - elif 'dev' in config_id.lower(): - return cls.TTL_STRATEGIES['development'] - elif 'test' in config_id.lower(): - return cls.TTL_STRATEGIES['testing'] - - return cls.TTL_STRATEGIES['default'] - - -class ConfigVersionManager: - """ - 配置版本管理器 - - 跟踪配置版本,当配置更新时自动失效旧版本缓存 - """ - - def __init__(self): - self._versions: Dict[str, str] = {} - self._lock = threading.RLock() - - def get_version(self, config_id: str) -> Optional[str]: - """ - 获取配置版本 - - Args: - config_id: 配置 ID - - Returns: - 版本号,如果不存在则返回 None - """ - with self._lock: - return self._versions.get(config_id) - - def set_version(self, config_id: str, version: str) -> None: - """ - 设置配置版本 - - Args: - config_id: 配置 ID - version: 版本号 - """ - with self._lock: - old_version = self._versions.get(config_id) - self._versions[config_id] = version - - if old_version and old_version != version: - logger.info(f"[VersionManager] 配置版本更新: {config_id} {old_version} -> {version}") - - def check_version(self, config_id: str, cached_version: Optional[str]) -> bool: - """ - 检查缓存版本是否有效 - - Args: - config_id: 配置 ID - cached_version: 缓存的版本号 - - Returns: - True 如果版本匹配,False 如果版本不匹配或不存在 - """ - with self._lock: - current_version = self._versions.get(config_id) - - if not current_version or not cached_version: - return False - - return current_version == cached_version - - def invalidate(self, config_id: str) -> None: - """ - 使配置版本失效 - - Args: - config_id: 配置 ID - """ - with self._lock: - if config_id in self._versions: - # 生成新版本号 - import uuid - new_version = str(uuid.uuid4()) - self._versions[config_id] = new_version - logger.info(f"[VersionManager] 配置版本失效: {config_id} -> {new_version}") - - -class CacheMonitor: - """ - 缓存监控器 - - 提供缓存性能监控和报告功能 - """ - - def __init__(self, cache: LRUConfigCache): - self.cache = cache - - def get_report(self) -> str: - """ - 生成缓存性能报告 - - Returns: - 格式化的报告字符串 - """ - stats = self.cache.get_stats() - - report = f""" -配置缓存性能报告 -================ -缓存容量: {stats['cache_size']}/{stats['max_size']} -总请求数: {stats['total_requests']} -缓存命中: {stats['cache_hits']} -缓存未命中: {stats['cache_misses']} -缓存命中率: {stats['hit_rate']:.2f}% -淘汰次数: {stats['evictions']} -平均加载时间: {stats['avg_load_time']:.2f}ms -""" - return report - - def log_stats(self) -> None: - """记录统计信息到日志""" - stats = self.cache.get_stats() - logger.info( - f"[CacheMonitor] 缓存统计 - " - f"容量: {stats['cache_size']}/{stats['max_size']}, " - f"命中率: {stats['hit_rate']:.2f}%, " - f"淘汰: {stats['evictions']}" - ) - - -# 使用示例 -def example_usage(): - """ - 优化功能使用示例 - """ - # 1. 使用 LRU 缓存 - lru_cache = LRUConfigCache(max_size=100, ttl=timedelta(minutes=5)) - - # 获取配置 - config = lru_cache.get("config_001") - if config is None: - # 缓存未命中,从数据库加载 - config = {"llm_name": "openai/gpt-4"} - lru_cache.put("config_001", config) - - # 2. 预热缓存 - def load_config(config_id): - # 实际的配置加载逻辑 - return True - - warmer = ConfigCacheWarmer() - results = warmer.warmup(["config_001", "config_002"], load_config) - - # 3. 动态 TTL - ttl = DynamicTTLStrategy.get_ttl("prod_config_001", "production") - print(f"TTL: {ttl}") - - # 4. 版本管理 - version_manager = ConfigVersionManager() - version_manager.set_version("config_001", "v1.0.0") - - # 检查版本 - is_valid = version_manager.check_version("config_001", "v1.0.0") - - # 5. 监控 - monitor = CacheMonitor(lru_cache) - print(monitor.get_report()) - - -if __name__ == "__main__": - example_usage() diff --git a/app/core/memory/utils/config/config_utils.py b/app/core/memory/utils/config/config_utils.py deleted file mode 100644 index 0f1934f0..00000000 --- a/app/core/memory/utils/config/config_utils.py +++ /dev/null @@ -1,267 +0,0 @@ -import uuid -import json -from typing import Optional - -from sqlalchemy.orm import Session -from fastapi.exceptions import HTTPException -from fastapi import status - -from app.core.memory.utils.config.definitions import CONFIG, RUNTIME_CONFIG -from app.core.memory.models.variate_config import ( - ExtractionPipelineConfig, - DedupConfig, - StatementExtractionConfig, - ForgettingEngineConfig, -) -from app.core.memory.models.config_models import PruningConfig -from app.db import get_db -from app.models.models_model import ModelConfig, ModelApiKey -from app.services.model_service import ModelConfigService -def get_model_config(model_id: str, db: Session | None = None) -> dict: - if db is None: - db_gen = get_db() # get_db 通常是一个生成器 - db = next(db_gen) # 取到真正的 Session - - config = ModelConfigService.get_model_by_id(db=db, model_id=model_id) - if not config: - print(f"模型ID {model_id} 不存在") - raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="模型ID不存在") - apiConfig: ModelApiKey = config.api_keys[0] - - # 从环境变量读取超时和重试配置 - from app.core.config import settings - - model_config = { - "model_name": apiConfig.model_name, - "provider": apiConfig.provider, - "api_key": apiConfig.api_key, - "base_url": apiConfig.api_base, - "model_config_id":apiConfig.model_config_id, - "type": config.type, - # 添加超时和重试配置,避免 LLM 请求超时 - "timeout": settings.LLM_TIMEOUT, # 从环境变量读取,默认120秒 - "max_retries": settings.LLM_MAX_RETRIES, # 从环境变量读取,默认2次 - } - # 写入model_config.log文件中 - with open("logs/model_config.log", "a", encoding="utf-8") as f: - f.write(f"模型ID: {model_id}\n") - f.write(f"模型配置信息:\n{model_config}\n") - f.write(f"=============================\n\n") - return model_config - -def get_embedder_config(embedding_id: str, db: Session | None = None) -> dict: - if db is None: - db_gen = get_db() # get_db 通常是一个生成器 - db = next(db_gen) # 取到真正的 Session - - config = ModelConfigService.get_model_by_id(db=db, model_id=embedding_id) - if not config: - print(f"嵌入模型ID {embedding_id} 不存在") - raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="嵌入模型ID不存在") - apiConfig: ModelApiKey = config.api_keys[0] - model_config = { - "model_name": apiConfig.model_name, - "provider": apiConfig.provider, - "api_key": apiConfig.api_key, - "base_url": apiConfig.api_base, - "model_config_id":apiConfig.model_config_id, - # Ensure required field for RedBearModelConfig validation - "type": config.type, - # 添加超时和重试配置,避免嵌入服务请求超时 - "timeout": 120.0, # 嵌入服务超时时间(秒) - "max_retries": 5, # 最大重试次数 - } - # 写入embedder_config.log文件中 - with open("logs/embedder_config.log", "a", encoding="utf-8") as f: - f.write(f"嵌入模型ID: {embedding_id}\n") - f.write(f"嵌入模型配置信息:\n{model_config}\n") - f.write(f"=============================\n\n") - return model_config - -def get_neo4j_config() -> dict: - """Retrieves the Neo4j configuration from the config file.""" - return CONFIG.get("neo4j", {}) -def get_picture_config(llm_name: str) -> dict: - """Retrieves the configuration for a specific model from the config file.""" - for model_config in CONFIG.get("picture_recognition", []): - if model_config["llm_name"] == llm_name: - return model_config - raise ValueError(f"Model '{llm_name}' not found in config.json") -def get_voice_config(llm_name: str) -> dict: - """Retrieves the configuration for a specific model from the config file.""" - for model_config in CONFIG.get("voice_recognition", []): - if model_config["llm_name"] == llm_name: - return model_config - raise ValueError(f"Model '{llm_name}' not found in config.json") - - -def get_chunker_config(chunker_strategy: str) -> dict: - """Retrieves the configuration for a specific chunker strategy. - - Enhancements: - - Supports default configs for `LLMChunker` and `HybridChunker` if not present. - - Falls back to the first available chunker config when the requested one is missing. - """ - # 1) Try to find exact match in config - chunker_list = CONFIG.get("chunker_list", []) - for chunker_config in chunker_list: - if chunker_config.get("chunker_strategy") == chunker_strategy: - return chunker_config - - # 2) Provide sane defaults for newer strategies - default_configs = { - "LLMChunker": { - "chunker_strategy": "LLMChunker", - "embedding_model": "BAAI/bge-m3", - "chunk_size": 1000, - "threshold": 0.8, - "min_sentences": 2, - "language": "zh", - "skip_window": 1, - "min_characters_per_chunk": 100, - }, - "HybridChunker": { - "chunker_strategy": "HybridChunker", - "embedding_model": "BAAI/bge-m3", - "chunk_size": 512, - "threshold": 0.8, - "min_sentences": 2, - "language": "zh", - "skip_window": 1, - "min_characters_per_chunk": 100, - }, - } - if chunker_strategy in default_configs: - return default_configs[chunker_strategy] - - # 3) Fallback: use first available config but tag with requested strategy - if chunker_list: - fallback = chunker_list[0].copy() - fallback["chunker_strategy"] = chunker_strategy - # Non-fatal notice for visibility in logs if any - print(f"Warning: Using first available chunker config as fallback for '{chunker_strategy}'") - return fallback - - # 4) If no configs available at all - raise ValueError( - f"Chunker '{chunker_strategy}' not found in config.json and no default or fallback available" - ) - - -def get_pipeline_config() -> ExtractionPipelineConfig: - """Build ExtractionPipelineConfig using only runtime.json values. - - Behavior: - - Read `deduplication` section from runtime.json if present. - - Read `statement_extraction` section from runtime.json if present. - - Read `forgetting_engine` section from runtime.json if present. - - If absent, check legacy top-level `enable_llm_dedup` key. - - Do NOT fall back to environment variables. - - Unspecified fields use model defaults defined in DedupConfig. - """ - dedup_rc = RUNTIME_CONFIG.get("deduplication", {}) or {} - stmt_rc = RUNTIME_CONFIG.get("statement_extraction", {}) or {} - forget_rc = RUNTIME_CONFIG.get("forgetting_engine", {}) or {} - - # Assemble kwargs from runtime.json only - kwargs = {} - # LLM switch: prefer new key, then legacy top-level, default False - if "enable_llm_dedup_blockwise" in dedup_rc: - kwargs["enable_llm_dedup_blockwise"] = bool(dedup_rc.get("enable_llm_dedup_blockwise")) - else: - # Legacy top-level fallback inside runtime.json only - legacy = RUNTIME_CONFIG.get("enable_llm_dedup") - if legacy is not None: - kwargs["enable_llm_dedup_blockwise"] = bool(legacy) - else: - kwargs["enable_llm_dedup_blockwise"] = False # default reserve - # Disambiguation switch: only from runtime.json deduplication section - if "enable_llm_disambiguation" in dedup_rc: - kwargs["enable_llm_disambiguation"] = bool(dedup_rc.get("enable_llm_disambiguation")) - - # Optional LLM fallback gating - if "enable_llm_fallback_only_on_borderline" in dedup_rc: - kwargs["enable_llm_fallback_only_on_borderline"] = bool(dedup_rc.get("enable_llm_fallback_only_on_borderline")) - - # Optional fuzzy thresholds: use values if provided; otherwise rely on DedupConfig defaults - for key in ( - "fuzzy_name_threshold_strict", - "fuzzy_type_threshold_strict", - "fuzzy_overall_threshold", - "fuzzy_unknown_type_name_threshold", - "fuzzy_unknown_type_type_threshold", - ): - if key in dedup_rc: - kwargs[key] = dedup_rc[key] - - # Optional weights and bonuses for overall scoring - for key in ( - "name_weight", - "desc_weight", - "type_weight", - "context_bonus", - "llm_fallback_floor", - "llm_fallback_ceiling", - ): - if key in dedup_rc: - kwargs[key] = dedup_rc[key] - - # Optional LLM iterative dedup parameters - for key in ( - "llm_block_size", - "llm_block_concurrency", - "llm_pair_concurrency", - "llm_max_rounds", - ): - if key in dedup_rc: - kwargs[key] = dedup_rc[key] - - dedup_config = DedupConfig(**kwargs) - - # Build StatementExtractionConfig from runtime.json - stmt_kwargs = {} - for key in ( - "statement_granularity", - "temperature", - "include_dialogue_context", - "max_dialogue_context_chars", - ): - if key in stmt_rc: - stmt_kwargs[key] = stmt_rc[key] - stmt_config = StatementExtractionConfig(**stmt_kwargs) - - # Build ForgettingEngineConfig from runtime.json - forget_kwargs = {} - for key in ("offset", "lambda_time", "lambda_mem"): - if key in forget_rc: - forget_kwargs[key] = forget_rc[key] - forget_config = ForgettingEngineConfig(**forget_kwargs) - - return ExtractionPipelineConfig( - statement_extraction=stmt_config, - deduplication=dedup_config, - forgetting_engine=forget_config, - ) - - -def get_pruning_config() -> dict: - """Retrieve semantic pruning config from runtime.json. - - Returns a dict suitable for PruningConfig.model_validate. - - Structure in runtime.json: - { - "pruning": { - "enabled": true, - "scene": "education" | "online_service" | "outbound", - "threshold": 0.5 - } - } - """ - pruning_rc = RUNTIME_CONFIG.get("pruning", {}) or {} - - return { - "pruning_switch": bool(pruning_rc.get("enabled", False)), - "pruning_scene": pruning_rc.get("scene", "education"), - "pruning_threshold": float(pruning_rc.get("threshold", 0.5)), - } diff --git a/app/core/memory/utils/config/definitions.py b/app/core/memory/utils/config/definitions.py deleted file mode 100644 index 316245c2..00000000 --- a/app/core/memory/utils/config/definitions.py +++ /dev/null @@ -1,360 +0,0 @@ -""" -配置加载模块 - 三阶段架构(已迁移到统一配置管理) - -本模块现在使用全局配置管理系统 (app/core/config.py) -来加载和管理配置,同时保持向后兼容性。 - -阶段 1: 从 runtime.json 加载配置(路径 A) -阶段 2: 从数据库加载配置(路径 B,基于 dbrun.json 中的 config_id) -阶段 3: 暴露配置常量供项目使用(路径 A 和 B 的汇合点) -""" -import os -import json -import threading -from typing import Any, Dict, Optional -from datetime import datetime, timedelta - -try: - from dotenv import load_dotenv - load_dotenv() -except Exception: - pass - -# Import unified configuration system -try: - from app.core.config import settings - USE_UNIFIED_CONFIG = True -except ImportError: - USE_UNIFIED_CONFIG = False - settings = None - -# PROJECT_ROOT 应该指向 app/core/memory/ 目录 -# __file__ = app/core/memory/utils/config/definitions.py -# os.path.dirname(__file__) = app/core/memory/utils/config -# os.path.dirname(...) = app/core/memory/utils -# os.path.dirname(...) = app/core/memory -PROJECT_ROOT = os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) - -# 全局配置锁 - 用于线程安全 -_config_lock = threading.RLock() - -# 加载基础配置(config.json)- 使用全局配置系统 -if USE_UNIFIED_CONFIG: - CONFIG = settings.load_memory_config() -else: - # Fallback to legacy loading - config_path = os.path.join(PROJECT_ROOT, "config.json") - try: - with open(config_path, "r") as f: - CONFIG = json.load(f) - except (FileNotFoundError, json.JSONDecodeError): - print("Warning: config.json not found or is malformed. Using default settings.") - CONFIG = {} - -DEFAULT_VALUES = { - "llm_name": "openai/qwen-plus", - "embedding_name": "openai/nomic-embed-text:v1.5", - "chunker_strategy": "RecursiveChunker", - "group_id": "group_123", - "user_id": "default_user", - "apply_id": "default_apply", - "llm_agent_name": "openai/qwen-plus", - "llm_verify_name": "openai/qwen-plus", - "llm_image_recognition": "openai/qwen-plus", - "llm_voice_recognition": "openai/qwen-plus", - "prompt_level": "DEBUG", - "reflexion_iteration_period": "3", - "reflexion_range": "retrieval", - "reflexion_baseline": "TIME", -} - - -# 阶段 1: 从 runtime.json 加载配置(路径 A) -def _load_from_runtime_json() -> Dict[str, Any]: - """ - 从 runtime.json 文件加载配置(通过统一配置加载器) - - 使用 overrides.py 的统一配置加载器,按优先级加载: - 1. 数据库配置(如果 dbrun.json 中有 config_id/group_id) - 2. 环境变量配置 - 3. runtime.json 默认配置 - - Returns: - Dict[str, Any]: 运行时配置字典 - """ - try: - # 使用 overrides.py 的统一配置加载器 - from app.core.memory.utils.config.overrides import load_unified_config - - runtime_cfg = load_unified_config(PROJECT_ROOT) - return runtime_cfg - except Exception as e: - # Fallback: 直接读取 runtime.json - runtime_config_path = os.path.join(PROJECT_ROOT, "runtime.json") - try: - with open(runtime_config_path, "r", encoding="utf-8") as f: - return json.load(f) - except (FileNotFoundError, json.JSONDecodeError) as e2: - pass # print(f"[definitions] ❌ 无法加载 runtime.json: {e2},使用空配置") - return {"selections": {}} - - -# 阶段 2: 从数据库加载配置(路径 B)- 已整合到统一加载器 -# 注意:此函数已被 _load_from_runtime_json 中的统一配置加载器替代 -# 保留此函数仅为向后兼容 -def _load_from_database() -> Optional[Dict[str, Any]]: - """ - 从数据库加载配置(基于 dbrun.json 中的 config_id) - - 注意:此函数已被统一配置加载器替代,现在直接调用 _load_from_runtime_json - 即可获得包含数据库配置的完整配置。 - - Returns: - Optional[Dict[str, Any]]: 配置字典 - """ - try: - # 直接使用统一配置加载器 - return _load_from_runtime_json() - except Exception: - return None - - -# 阶段 3: 暴露配置常量(路径 A 和 B 的汇合点) -def _expose_runtime_constants(runtime_cfg: Dict[str, Any]) -> None: - """ - 将运行时配置暴露为全局常量供项目使用 - - 这是路径 A(runtime.json)和路径 B(数据库)的汇合点, - 无论配置来自哪里,都通过这个函数统一暴露为常量。 - - Args: - runtime_cfg: 运行时配置字典 - """ - global RUNTIME_CONFIG, SELECTIONS, LOGGING_CONFIG - global LANGFUSE_ENABLED, AGENTA_ENABLED, PROMPT_LOG_LEVEL_NAME - global SELECTED_LLM_NAME, SELECTED_EMBEDDING_NAME, SELECTED_CHUNKER_STRATEGY - global SELECTED_GROUP_ID, SELECTED_USER_ID, SELECTED_APPLY_ID, SELECTED_TEST_DATA_INDICES - global SELECTED_LLM_AGENT_NAME, SELECTED_LLM_VERIFY_NAME, SELECTED_LLM_PICTURE_NAME, SELECTED_LLM_VOICE_NAME - global SELECTED_LLM_ID, SELECTED_EMBEDDING_ID, SELECTED_RERANK_ID - global REFLEXION_CONFIG, REFLEXION_ENABLED, REFLEXION_ITERATION_PERIOD, REFLEXION_RANGE, REFLEXION_BASELINE - - RUNTIME_CONFIG = runtime_cfg - - # 可观测性配置 - LANGFUSE_ENABLED = RUNTIME_CONFIG.get("langfuse", {}).get("enabled", False) - AGENTA_ENABLED = RUNTIME_CONFIG.get("agenta", {}).get("enabled", False) - - # 日志配置 - LOGGING_CONFIG = RUNTIME_CONFIG.get("logging", {}) - PROMPT_LOG_LEVEL_NAME = LOGGING_CONFIG.get("prompt_level", DEFAULT_VALUES["prompt_level"]) - - # 选择配置 - SELECTIONS = RUNTIME_CONFIG.get("selections", {}) - - # 基础模型选择 - SELECTED_LLM_NAME = SELECTIONS.get("llm_name", DEFAULT_VALUES["llm_name"]) - SELECTED_EMBEDDING_NAME = SELECTIONS.get("embedding_name", DEFAULT_VALUES["embedding_name"]) - SELECTED_CHUNKER_STRATEGY = SELECTIONS.get("chunker_strategy", DEFAULT_VALUES["chunker_strategy"]) - - # 分组和用户配置 - SELECTED_GROUP_ID = SELECTIONS.get("group_id", DEFAULT_VALUES["group_id"]) - SELECTED_USER_ID = SELECTIONS.get("user_id", DEFAULT_VALUES["user_id"]) - SELECTED_APPLY_ID = SELECTIONS.get("apply_id", DEFAULT_VALUES["apply_id"]) - SELECTED_TEST_DATA_INDICES = SELECTIONS.get("test_data_indices", None) - - # 专用 LLM 配置 - SELECTED_LLM_AGENT_NAME = SELECTIONS.get("llm_agent_name", DEFAULT_VALUES["llm_agent_name"]) - SELECTED_LLM_VERIFY_NAME = SELECTIONS.get("llm_verify_name", DEFAULT_VALUES["llm_verify_name"]) - SELECTED_LLM_PICTURE_NAME = SELECTIONS.get("llm_image_recognition", DEFAULT_VALUES["llm_image_recognition"]) - SELECTED_LLM_VOICE_NAME = SELECTIONS.get("llm_voice_recognition", DEFAULT_VALUES["llm_voice_recognition"]) - - # 模型 ID 配置 - SELECTED_LLM_ID = SELECTIONS.get("llm_id", None) - SELECTED_EMBEDDING_ID = SELECTIONS.get("embedding_id", None) - SELECTED_RERANK_ID = SELECTIONS.get("rerank_id", None) - - # 反思配置 - REFLEXION_CONFIG = RUNTIME_CONFIG.get("reflexion", {}) - REFLEXION_ENABLED = REFLEXION_CONFIG.get("enabled", False) - REFLEXION_ITERATION_PERIOD = REFLEXION_CONFIG.get("iteration_period", DEFAULT_VALUES["reflexion_iteration_period"]) - REFLEXION_RANGE = REFLEXION_CONFIG.get("reflexion_range", DEFAULT_VALUES["reflexion_range"]) - REFLEXION_BASELINE = REFLEXION_CONFIG.get("baseline", DEFAULT_VALUES["reflexion_baseline"]) - - -# 初始化:使用统一配置加载器 -def _initialize_configuration() -> None: - """ - 初始化配置:使用统一配置加载器 - - 配置加载优先级(由 overrides.py 统一处理): - 1. 数据库配置(如果 dbrun.json 中有 config_id/group_id) - 2. 环境变量配置(.env) - 3. runtime.json 默认配置 - """ - try: - - # 使用统一配置加载器(已包含所有优先级处理) - runtime_config = _load_from_runtime_json() - - # 暴露为全局常量 - _expose_runtime_constants(runtime_config) - - - except Exception as e: - pass # print(f"[definitions] × 配置初始化失败: {e}") - # 使用空配置 - _expose_runtime_constants({"selections": {}}) - - -# 模块加载时自动初始化配置 -_initialize_configuration() - - -# 公共 API:动态重新加载配置 -def reload_configuration_from_database(config_id: int | str, force_reload: bool = False) -> bool: - """ - 动态重新加载配置(从数据库)- 使用统一配置加载器 - 用于运行时切换配置,例如前端传入新的 config_id 时调用。 - - 注意:此函数仅在内存中覆写配置,不会修改 runtime.json 文件。 - - Args: - config_id: 配置 ID(整数或字符串,会自动转换) - force_reload: 保留参数以保持向后兼容(已移除缓存逻辑) - - Returns: - bool: 是否成功重新加载配置 - """ - import logging - logger = logging.getLogger(__name__) - - # 导入审计日志记录器 - try: - from app.core.memory.utils.log.audit_logger import audit_logger - except ImportError: - audit_logger = None - - with _config_lock: - try: - from app.core.memory.utils.config.overrides import load_unified_config - except Exception as e: - logger.error(f"[definitions] 导入统一配置加载器失败: {e}") - - # 记录配置加载失败 - if audit_logger: - audit_logger.log_config_load( - config_id=config_id, - success=False, - details={"error": f"Import failed: {str(e)}"} - ) - - return False - - try: - logger.info(f"[definitions] 开始重新加载配置,config_id={config_id}") - - # 使用统一配置加载器(指定 config_id) - updated_cfg = load_unified_config(PROJECT_ROOT, config_id=config_id) - - # 检查是否成功加载 - if not updated_cfg or not updated_cfg.get('selections'): - logger.error(f"[definitions] 配置加载失败:数据库中未找到 config_id={config_id} 的配置") - - # 记录配置加载失败 - if audit_logger: - audit_logger.log_config_load( - config_id=config_id, - success=False, - details={"reason": "config not found in database"} - ) - - return False - - # 重新暴露常量 - _expose_runtime_constants(updated_cfg) - - logger.info(f"[definitions] 配置重新加载成功,已暴露常量") - logger.debug(f"[definitions] 配置详情: LLM_ID={updated_cfg.get('selections', {}).get('llm_id')}, " - f"EMBEDDING_ID={updated_cfg.get('selections', {}).get('embedding_id')}") - - # 记录成功的配置加载 - if audit_logger: - selections = updated_cfg.get('selections', {}) - audit_logger.log_config_load( - config_id=config_id, - user_id=selections.get('user_id', None), - group_id=selections.get('group_id', None), - success=True, - details={ - "llm_id": selections.get('llm_id'), - "embedding_id": selections.get('embedding_id'), - "chunker_strategy": selections.get('chunker_strategy') - } - ) - - return True - except Exception as e: - logger.error(f"[definitions] 重新加载配置时发生异常: {e}", exc_info=True) - - # 记录配置加载异常 - if audit_logger: - audit_logger.log_config_load( - config_id=config_id, - success=False, - details={"error": str(e)} - ) - - return False - - - - - -def get_current_config_id() -> Optional[str]: - """ - 获取当前使用的 config_id - - Returns: - Optional[str]: 当前的 config_id,如果未设置则返回 None - """ - return SELECTIONS.get("config_id", None) - - -def ensure_fresh_config(config_id: Optional[int | str] = None) -> bool: - """ - 确保使用最新的配置(每次写入操作前调用) - - 如果提供了 config_id,则加载该配置; - 否则从 dbrun.json 读取并加载最新配置。 - - Args: - config_id: 可选的配置ID(整数或字符串,会自动转换) - - Returns: - bool: 是否成功加载配置 - """ - import logging - logger = logging.getLogger(__name__) - - with _config_lock: - try: - if config_id: - # 使用指定的 config_id - logger.debug(f"[definitions] 加载指定配置,config_id={config_id}") - return reload_configuration_from_database(config_id) - else: - # 从数据库重新加载配置 - logger.debug("[definitions] 从数据库重新加载最新配置") - memory_config = _load_from_database() - - if not memory_config or not memory_config.get('selections'): - logger.warning("[definitions] 未能从数据库加载配置,使用当前配置") - return False - - _expose_memory_constants(memory_config) - return True - except Exception as e: - logger.error(f"[definitions] 加载配置失败: {e}", exc_info=True) - return False - - diff --git a/app/core/memory/utils/config/get_data.py b/app/core/memory/utils/config/get_data.py deleted file mode 100644 index f2f21198..00000000 --- a/app/core/memory/utils/config/get_data.py +++ /dev/null @@ -1,93 +0,0 @@ -import json -import os -import uuid -from typing import List, Dict, Any, Optional -from sqlalchemy.orm import Session -from app.db import get_db -from app.models.retrieval_info import RetrievalInfo -from app.schemas.memory_storage_schema import BaseDataSchema - -import logging -logger = logging.getLogger(__name__) - -async def _load_(data: List[Any]) -> List[Dict]: - target_keys = [ - "id", - "statement", - "group_id", - "chunk_id", - "created_at", - "expired_at", - "valid_at", - "invalid_at", - ] - results = [] - for row in data or []: - s = None - if isinstance(row, (tuple, list)) and row: - s = row[0] - elif hasattr(row, "retrieve_info"): - s = getattr(row, "retrieve_info") - elif isinstance(row, dict) and "retrieve_info" in row: - s = row.get("retrieve_info") - elif hasattr(row, "_mapping") and "retrieve_info" in getattr(row, "_mapping"): - s = row._mapping["retrieve_info"] - else: - s = row - if s is None: - continue - if isinstance(s, bytes): - try: - s = s.decode("utf-8") - except Exception: - try: - s = s.decode() - except Exception: - continue - s = str(s).strip() - if not s or s == "[]": - continue - try: - parsed = json.loads(s) - except Exception: - continue - items = parsed if isinstance(parsed, list) else [parsed] - for item in items: - if "statement" not in item and "statements" in item: - item["statement"] = item.get("statements") or "" - normalized = {k: item.get(k, "") for k in target_keys} - results.append(normalized) - return results - - -async def get_data(host_id: uuid.UUID) -> List[Dict]: - """ - 从数据库中获取数据 - """ - # 从数据库会话中获取会话 - db: Session = next(get_db()) - try: - data = db.query(RetrievalInfo.retrieve_info).filter(RetrievalInfo.host_id == host_id).all() - - # print(f"data:\n{data}") - # 解析,提取为字典的列表 - results = await _load_(data) - return results - except Exception as e: - logger.error(f"failed to get data from database, host_id: {host_id}, error: {e}") - raise e - finally: - try: - db.close() - except Exception: - pass - - -if __name__ == "__main__": - import asyncio - - # 从数据库中获取数据 - host_id = uuid.UUID("2f6ff1eb-50c7-4765-8e89-e4566be19122") - data = asyncio.run(get_data(host_id)) - print(type(data)) - print(data) diff --git a/app/core/memory/utils/config/get_example_data.py b/app/core/memory/utils/config/get_example_data.py deleted file mode 100644 index c466645b..00000000 --- a/app/core/memory/utils/config/get_example_data.py +++ /dev/null @@ -1,90 +0,0 @@ -import os -import re -import uuid -import random -import string -from typing import List, Dict, Optional - -# 生成包含字母(大小写)和数字的随机字符串 -def generate_random_string(length=16): - characters = string.ascii_letters + string.digits - return ''.join(random.choice(characters) for _ in range(length)) - -def get_example_data() -> List[Dict[str, Optional[str]]]: - """ - 从句子提取日志中获取数据 - Content: 在苹果公司中国总部,用户和李华偶遇了从美国来的技术专家约翰·史密斯。 - Created At: 2025-11-28 19:28:38.256421 - Expired At: None - Valid At: None - Invalid At: None - 将数据构造成如下形式: - [ - { - "id":id, - "group_id":group_id, - "statement": Content, - "created_at": Created At, - "expired_at": Expired At, - "valid_at": Valid At, - "invalid_at": Invalid At, - "chunk_id": "86da9022710c40eaa5f518a294c398d2", - "entity_ids": [] - }, - ... - ] - """ - # 获取日志文件路径 - log_file_path = os.path.join("logs", "memory-output", "statement_extraction.txt") - - # 检查文件是否存在 - if not os.path.exists(log_file_path): - return [] - - # 读取日志文件 - with open(log_file_path, "r", encoding="utf-8") as f: - content = f.read() - - # 解析数据 - results = [] - - # 使用正则表达式分割每个 Statement - statement_blocks = re.split(r"Statement \d+:", content) - - for block in statement_blocks[1:]: # 跳过第一个空块 - # 提取各个字段 - id_match = re.search(r"Id:\s*(.+?)(?=\n)", block) - group_id_match = re.search(r"Group Id:\s*(.+?)(?=\n)", block) - statement_match = re.search(r"Content:\s*(.+?)(?=\n)", block) - created_at_match = re.search(r"Created At:\s*(.+?)(?=\n)", block) - expired_at_match = re.search(r"Expired At:\s*(.+?)(?=\n)", block) - valid_at_match = re.search(r"Valid At:\s*(.+?)(?=\n)", block) - invalid_at_match = re.search(r"Invalid At:\s*(.+?)(?=\n)", block) - chunk_id_match = re.search(r"Chunk Id:\s*(.+?)(?=\n)", block) - - # 构造字典 - if statement_match: - statement_data = { - "id": id_match.group(1).strip() if id_match else generate_random_string(), - "group_id": group_id_match.group(1).strip() if group_id_match else "group_example", - "statement": statement_match.group(1).strip(), - "created_at": created_at_match.group(1).strip() if created_at_match else None, - "expired_at": expired_at_match.group(1).strip() if expired_at_match else None, - "valid_at": valid_at_match.group(1).strip() if valid_at_match else None, - "invalid_at": invalid_at_match.group(1).strip() if invalid_at_match else None, - "chunk_id": chunk_id_match.group(1).strip() if chunk_id_match else "chunk_example", - "entity_ids": [] - } - - # 将 "None" 字符串转换为 None - for key in ["created_at", "expired_at", "valid_at", "invalid_at"]: - if statement_data[key] == "None": - statement_data[key] = None - - results.append(statement_data) - - return results - - -if __name__ == "__main__": - print(f"获取数据如下:\n {get_example_data()}") \ No newline at end of file diff --git a/app/core/memory/utils/config/litellm_config.py b/app/core/memory/utils/config/litellm_config.py deleted file mode 100644 index f5a9667f..00000000 --- a/app/core/memory/utils/config/litellm_config.py +++ /dev/null @@ -1,516 +0,0 @@ -""" -LiteLLM Configuration for Enhanced Retry Logic and Usage Tracking with Native QPS Monitoring -""" - -import litellm -from typing import Dict, Any, List -import json -from datetime import datetime, timedelta -import os -import time -from collections import defaultdict -import threading -from queue import Queue - -class LiteLLMConfig: - """Configuration class for LiteLLM with enhanced retry and tracking capabilities""" - - def __init__(self): - self.usage_data = [] - self.error_data = [] - self.module_stats = defaultdict(lambda: { - 'requests': 0, - 'tokens_in': 0, - 'tokens_out': 0, - 'cost': 0.0, - 'errors': 0, - 'start_time': None, - 'last_request_time': None, - 'request_timestamps': [], # Store precise timestamps - 'current_qps': 0.0, - 'max_qps': 0.0, - 'qps_history': [] # Store QPS measurements over time - }) - self.start_time = datetime.now() - self.global_request_timestamps = [] - self.global_max_qps = 0.0 - - # Rate limiting for AWS Bedrock (conservative limits) - self.rate_limits = { - 'bedrock': { - 'requests_per_minute': 2, # AWS Bedrock default is very low - 'requests_per_second': 0.033, # 2/60 = 0.033 RPS - 'last_request_time': 0, - 'request_queue': Queue(), - 'lock': threading.Lock() - } - } - self.rate_limiting_enabled = True - - def setup_enhanced_config(self, max_retries: int = 3): - """Configure LiteLLM with retry logic and instant QPS tracking""" - - litellm.num_retries = max_retries - litellm.request_timeout = 300 - - litellm.retry_policy = { - "RateLimitError": { - "max_retries": 5, - "exponential_backoff": True, - "initial_delay": 1, - "max_delay": 60, - "jitter": True - }, - "APIConnectionError": { - "max_retries": 3, - "exponential_backoff": True, - "initial_delay": 2, - "max_delay": 30, - "jitter": True - }, - "InternalServerError": { - "max_retries": 2, - "exponential_backoff": True, - "initial_delay": 5, - "max_delay": 60, - "jitter": True - }, - "BadRequestError": { - "max_retries": 1, - "exponential_backoff": False, - "initial_delay": 1, - "max_delay": 5 - } - } - - litellm.success_callback = [self._success_callback] - litellm.failure_callback = [self._failure_callback] - litellm.completion_cost_tracking = True - litellm.set_verbose = False - litellm.modify_params = True - - print("✅ LiteLLM configured with instant QPS tracking and rate limiting") - - def _success_callback(self, kwargs, completion_response, start_time, end_time): - """Callback for successful requests with module-specific QPS tracking""" - try: - # Extract usage information - usage = completion_response.get('usage', {}) - model = kwargs.get('model', 'unknown') - - # Extract module information from metadata or model name - module = self._extract_module_name(kwargs, model) - - # Calculate cost - cost = 0.0 - try: - cost = litellm.completion_cost(completion_response) - except: - pass - - # Calculate duration - duration_seconds = (end_time - start_time).total_seconds() if hasattr(end_time - start_time, 'total_seconds') else float(end_time - start_time) - - # Record usage data - usage_record = { - "timestamp": datetime.now().isoformat(), - "model": model, - "module": module, - "input_tokens": usage.get('prompt_tokens', 0), - "output_tokens": usage.get('completion_tokens', 0), - "total_tokens": usage.get('total_tokens', 0), - "cost": cost, - "duration_seconds": duration_seconds, - "status": "success" - } - - self.usage_data.append(usage_record) - - # Update module-specific stats for QPS tracking - self._update_module_stats(module, usage_record, success=True) - - # Print real-time feedback - print(f"✓ {model}: {usage_record['input_tokens']}→{usage_record['output_tokens']} tokens, ${cost:.4f}, {usage_record['duration_seconds']:.2f}s") - - except Exception as e: - print(f"Warning: Success callback failed: {e}") - - def _failure_callback(self, kwargs, completion_response, start_time, end_time): - """Callback for failed requests with module-specific error tracking""" - try: - model = kwargs.get('model', 'unknown') - module = self._extract_module_name(kwargs, model) - - duration_seconds = (end_time - start_time).total_seconds() if hasattr(end_time - start_time, 'total_seconds') else float(end_time - start_time) - - # Handle different error response formats - error_message = "Unknown error" - error_type = "UnknownError" - - # According to LiteLLM docs, completion_response contains the exception for failures - if completion_response is not None: - error_message = str(completion_response) - error_type = type(completion_response).__name__ - - # Also check kwargs for exception (LiteLLM passes exception in kwargs for failure events) - elif 'exception' in kwargs: - exception = kwargs['exception'] - error_message = str(exception) - error_type = type(exception).__name__ - - # Check for other error formats in kwargs - elif 'error' in kwargs: - error = kwargs['error'] - error_message = str(error) - error_type = type(error).__name__ - - # Check log_event_type to confirm this is a failure event - log_event_type = kwargs.get('log_event_type', '') - if log_event_type == 'failed_api_call' and 'exception' in kwargs: - exception = kwargs['exception'] - error_message = str(exception) - error_type = type(exception).__name__ - - error_record = { - "timestamp": datetime.now().isoformat(), - "model": model, - "module": module, - "error": error_message, - "error_type": error_type, - "duration_seconds": duration_seconds, - "status": "failed" - } - - self.error_data.append(error_record) - - # Update module-specific stats for error tracking - self._update_module_stats(module, error_record, success=False) - - # Print error feedback - print(f"✗ {model}: {error_type} - {error_message[:100]}") - - except Exception as e: - print(f"Warning: Failure callback failed: {e}") - # Debug: print the actual parameters to understand the structure - print(f"Debug - kwargs keys: {list(kwargs.keys()) if kwargs else 'None'}") - print(f"Debug - completion_response type: {type(completion_response)}") - print(f"Debug - completion_response: {completion_response}") - - def _should_rate_limit(self, model: str) -> bool: - """Check if the model should be rate limited""" - if not self.rate_limiting_enabled: - return False - return model.startswith('bedrock/') or 'bedrock' in model.lower() - - def _enforce_rate_limit(self, model: str): - """Enforce rate limiting for AWS Bedrock models""" - if not self._should_rate_limit(model): - return - - provider = 'bedrock' - if provider not in self.rate_limits: - return - - rate_config = self.rate_limits[provider] - - with rate_config['lock']: - current_time = time.time() - time_since_last = current_time - rate_config['last_request_time'] - min_interval = 1.0 / rate_config['requests_per_second'] - - if time_since_last < min_interval: - sleep_time = min_interval - time_since_last - print(f"⏳ Rate limiting: sleeping {sleep_time:.2f}s for {model}") - time.sleep(sleep_time) - - rate_config['last_request_time'] = time.time() - - def _extract_module_name(self, kwargs: Dict[str, Any], model: str) -> str: - """Extract module name from request context""" - # Try to get module from metadata - metadata = kwargs.get('metadata', {}) - if 'module' in metadata: - return metadata['module'] - - # Try to infer from model name or other context - if 'claude' in model.lower(): - return 'bedrock_client' - elif 'gpt' in model.lower() or 'openai' in model.lower(): - return 'openai_client' - elif 'embed' in model.lower(): - return 'embedder' - else: - return 'unknown' - - def _update_module_stats(self, module: str, record: Dict[str, Any], success: bool): - """Update module-specific statistics with instant QPS tracking""" - current_timestamp = time.time() - current_time = datetime.now() - - # Initialize module stats if first request - if self.module_stats[module]['start_time'] is None: - self.module_stats[module]['start_time'] = current_time - - # Update counters - self.module_stats[module]['requests'] += 1 - self.module_stats[module]['last_request_time'] = current_time - self.module_stats[module]['request_timestamps'].append(current_timestamp) - self.global_request_timestamps.append(current_timestamp) - - # Calculate instant QPS for this module - self._calculate_instant_qps(module, current_timestamp) - - # Calculate global instant QPS - self._calculate_global_instant_qps(current_timestamp) - - if success: - self.module_stats[module]['tokens_in'] += record.get('input_tokens', 0) - self.module_stats[module]['tokens_out'] += record.get('output_tokens', 0) - self.module_stats[module]['cost'] += record.get('cost', 0.0) - else: - self.module_stats[module]['errors'] += 1 - - def _calculate_instant_qps(self, module: str, current_timestamp: float): - """Calculate instant QPS for a specific module using sliding window""" - # Keep only timestamps from last 1 second for instant QPS - cutoff_time = current_timestamp - 1.0 - timestamps = self.module_stats[module]['request_timestamps'] - - # Remove old timestamps - self.module_stats[module]['request_timestamps'] = [ - ts for ts in timestamps if ts >= cutoff_time - ] - - # Calculate current QPS (requests in last second) - current_qps = len(self.module_stats[module]['request_timestamps']) - self.module_stats[module]['current_qps'] = current_qps - - # Update max QPS if current is higher - if current_qps > self.module_stats[module]['max_qps']: - self.module_stats[module]['max_qps'] = current_qps - - # Store QPS history (keep last 60 measurements) - self.module_stats[module]['qps_history'].append(current_qps) - if len(self.module_stats[module]['qps_history']) > 60: - self.module_stats[module]['qps_history'].pop(0) - - def _calculate_global_instant_qps(self, current_timestamp: float): - """Calculate global instant QPS across all modules""" - # Keep only timestamps from last 1 second - cutoff_time = current_timestamp - 1.0 - self.global_request_timestamps = [ - ts for ts in self.global_request_timestamps if ts >= cutoff_time - ] - - # Calculate current global QPS - current_global_qps = len(self.global_request_timestamps) - - # Update max global QPS - if current_global_qps > self.global_max_qps: - self.global_max_qps = current_global_qps - - def get_instant_qps(self, module: str = None) -> Dict[str, Any]: - """Get instant QPS data for modules""" - if module: - if module in self.module_stats: - return { - 'module': module, - 'current_qps': self.module_stats[module]['current_qps'], - 'max_qps': self.module_stats[module]['max_qps'], - 'avg_qps_last_minute': sum(self.module_stats[module]['qps_history'][-60:]) / min(60, len(self.module_stats[module]['qps_history'])) if self.module_stats[module]['qps_history'] else 0 - } - else: - return {'module': module, 'current_qps': 0, 'max_qps': 0, 'avg_qps_last_minute': 0} - else: - # Return data for all modules plus global - result = { - 'global': { - 'current_qps': len([ts for ts in self.global_request_timestamps if ts >= time.time() - 1.0]), - 'max_qps': self.global_max_qps - }, - 'modules': {} - } - - for mod in self.module_stats.keys(): - result['modules'][mod] = { - 'current_qps': self.module_stats[mod]['current_qps'], - 'max_qps': self.module_stats[mod]['max_qps'], - 'avg_qps_last_minute': sum(self.module_stats[mod]['qps_history'][-60:]) / min(60, len(self.module_stats[mod]['qps_history'])) if self.module_stats[mod]['qps_history'] else 0 - } - - return result - - def get_usage_summary(self) -> Dict[str, Any]: - """Get essential usage statistics""" - if not self.usage_data: - return { - "total_requests": 0, - "total_cost": 0.0, - "error_rate": 0.0, - "message": "No usage data available" - } - - total_requests = len(self.usage_data) - total_errors = len(self.error_data) - total_cost = sum(record['cost'] for record in self.usage_data) - total_input_tokens = sum(record['input_tokens'] for record in self.usage_data) - total_output_tokens = sum(record['output_tokens'] for record in self.usage_data) - - # Calculate session duration - duration_minutes = (datetime.now() - self.start_time).total_seconds() / 60 - - # Build module statistics - module_stats = {} - for module, stats in self.module_stats.items(): - if stats['requests'] > 0: - module_stats[module] = { - "requests": stats['requests'], - "errors": stats['errors'], - "success_rate": ((stats['requests'] - stats['errors']) / stats['requests'] * 100) if stats['requests'] > 0 else 0, - "tokens_in": stats['tokens_in'], - "tokens_out": stats['tokens_out'], - "cost": stats['cost'], - "current_qps": stats['current_qps'], - "max_qps": stats['max_qps'] - } - - return { - "session_duration_minutes": duration_minutes, - "total_requests": total_requests, - "total_errors": total_errors, - "error_rate": (total_errors / total_requests * 100) if total_requests > 0 else 0, - "total_input_tokens": total_input_tokens, - "total_output_tokens": total_output_tokens, - "total_cost": total_cost, - "module_stats": module_stats, - "global_max_qps": self.global_max_qps - } - - def print_usage_summary(self): - """Print essential usage summary""" - stats = self.get_usage_summary() - - if stats.get('message'): - print(f"📊 {stats['message']}") - return - - print(f"\n📊 USAGE SUMMARY") - print(f"{'='*50}") - print(f"⏱️ Duration: {stats['session_duration_minutes']:.1f} min") - print(f"📈 Requests: {stats['total_requests']}") - print(f"❌ Errors: {stats['total_errors']}") - print(f"💰 Cost: ${stats['total_cost']:.4f}") - print(f"🏆 Global Max QPS: {stats['global_max_qps']}") - - # Module statistics - if stats.get('module_stats'): - print(f"\n📦 MODULES:") - for module, mod_stats in stats['module_stats'].items(): - print(f" {module}: {mod_stats['requests']} req, Max QPS: {mod_stats['max_qps']}, Current: {mod_stats['current_qps']}") - - print(f"{'='*50}") - - def save_usage_data(self, filename: str = "litellm_usage.json"): - """Save usage data to JSON file""" - data = { - "summary": self.get_usage_summary(), - "detailed_usage": self.usage_data, - "errors": self.error_data, - "export_timestamp": datetime.now().isoformat() - } - - with open(filename, 'w') as f: - json.dump(data, f, indent=2) - - print(f"📁 Usage data saved to {filename}") - - def reset_tracking(self): - """Reset all tracking data""" - self.usage_data = [] - self.error_data = [] - self.module_stats = defaultdict(lambda: { - 'requests': 0, - 'tokens_in': 0, - 'tokens_out': 0, - 'cost': 0.0, - 'errors': 0, - 'start_time': None, - 'last_request_time': None, - 'request_timestamps': [], - 'current_qps': 0.0, - 'max_qps': 0.0, - 'qps_history': [] - }) - self.global_request_timestamps = [] - self.global_max_qps = 0.0 - self.start_time = datetime.now() - print("🔄 All tracking data reset") - -# Global instance for easy access -litellm_config = LiteLLMConfig() - -def setup_litellm_enhanced(max_retries: int = 3): - """ - Quick setup function for LiteLLM enhanced configuration - - Args: - max_retries: Maximum number of retries for failed requests - """ - litellm_config.setup_enhanced_config(max_retries) - return litellm_config - -def get_usage_summary(): - """Get current usage summary""" - return litellm_config.get_usage_summary() - -def print_usage_summary(): - """Print current usage summary""" - litellm_config.print_usage_summary() - -def save_usage_data(filename: str = "litellm_usage.json"): - """Save usage data to file""" - litellm_config.save_usage_data(filename) - -def get_instant_qps(module: str = None) -> Dict[str, Any]: - """Get instant QPS data for modules""" - return litellm_config.get_instant_qps(module) - -def print_instant_qps(module: str = None): - """Print instant QPS information""" - qps_data = get_instant_qps(module) - - print(f"\n⚡ INSTANT QPS MONITOR") - print(f"{'='*60}") - - if module: - print(f"Module: {qps_data['module']}") - print(f" Current QPS: {qps_data['current_qps']}") - print(f" Max QPS: {qps_data['max_qps']}") - print(f" Avg (1min): {qps_data['avg_qps_last_minute']:.2f}") - else: - # Global stats - global_data = qps_data.get('global', {}) - print(f"🌍 GLOBAL:") - print(f" Current QPS: {global_data.get('current_qps', 0)}") - print(f" Max QPS: {global_data.get('max_qps', 0)}") - - # Module stats - modules = qps_data.get('modules', {}) - if modules: - print(f"\n📦 MODULES:") - for mod, data in modules.items(): - print(f" {mod}:") - print(f" Current: {data['current_qps']} QPS") - print(f" Max: {data['max_qps']} QPS") - print(f" Avg: {data['avg_qps_last_minute']:.2f} QPS") - - print(f"{'='*60}") - -def reset_tracking(): - """Reset all tracking data""" - litellm_config.reset_tracking() - -def get_module_stats() -> Dict[str, Dict[str, Any]]: - """Get detailed module statistics""" - summary = get_usage_summary() - return summary.get('module_stats', {}) diff --git a/app/core/memory/utils/config/overrides.py b/app/core/memory/utils/config/overrides.py deleted file mode 100644 index e333bb29..00000000 --- a/app/core/memory/utils/config/overrides.py +++ /dev/null @@ -1,611 +0,0 @@ -""" -运行时配置覆写工具 - 统一配置加载器 - -本模块作为统一的配置加载器,负责从多个来源加载配置并按优先级覆写。 - -配置来源优先级(从高到低): -1. 数据库配置(PostgreSQL data_config 表) -2. 环境变量配置(.env 文件) -3. 默认配置(runtime.json 文件) - -支持的配置加载方式: -- 基于 config_id 的配置加载(从 dbrun.json 读取或前端传入) -- 基于 group_id 的配置加载(从 dbrun.json 读取) -- 环境变量覆写(支持 INTERNAL/EXTERNAL 网络模式) - -主要功能: -- 从 PostgreSQL 数据库读取配置 -- 从环境变量读取配置 -- 从 runtime.json 读取默认配置 -- 按优先级覆写配置项(仅在内存中,不修改文件) -- 支持多种配置字段:selections、statement_extraction、deduplication、forgetting_engine、pruning、reflexion - -使用场景: -- 应用启动时自动加载配置 -- 前端切换配置时动态重新加载 -- 多租户场景下的配置隔离 -- 内外网环境自动切换 -""" -import os -import json -import socket -from typing import Optional, Dict, Any, Literal - -NetworkMode = Literal['internal', 'external'] - - -def _set_if_present(target: Dict[str, Any], target_key: str, src: Dict[str, Any], src_key: str, caster): - """安全地设置目标字典的值(如果源字典中存在且不为 None) - - Args: - target: 目标字典 - target_key: 目标字典的键 - src: 源字典 - src_key: 源字典的键 - caster: 类型转换函数 - """ - try: - if src_key in src and src.get(src_key) is not None: - try: - target[target_key] = caster(src.get(src_key)) - except Exception: - pass - except Exception: - pass - - -def _to_bool(val: Any) -> bool: - """将各种类型的值转换为布尔值 - - 支持的输入: - - bool: 直接返回 - - int/float: 非零为 True - - str: "true", "1", "on", "yes" 为 True;"false", "0", "off", "no" 为 False - - Args: - val: 要转换的值 - - Returns: - bool: 转换后的布尔值 - """ - try: - if isinstance(val, bool): - return val - if isinstance(val, (int, float)): - return bool(val) - if isinstance(val, str): - m = val.strip().lower() - if m in {"true", "1", "on", "yes"}: - return True - if m in {"false", "0", "off", "no"}: - return False - return bool(val) - except Exception: - return False - - -def _make_pgsql_conn() -> Optional[object]: - """创建 PostgreSQL 数据库连接 - - 使用环境变量配置连接参数: - - DB_HOST: 数据库主机地址(默认 localhost) - - DB_PORT: 数据库端口(默认 5432) - - DB_USER: 数据库用户名 - - DB_PASSWORD: 数据库密码 - - DB_NAME: 数据库名称 - - Returns: - Optional[object]: 数据库连接对象,失败时返回 None - """ - host = os.getenv("DB_HOST", "localhost") - user = os.getenv("DB_USER") - password = os.getenv("DB_PASSWORD") - dbname = os.getenv("DB_NAME") - port_str = os.getenv("DB_PORT") - - try: - import psycopg2 # type: ignore - from psycopg2.extras import RealDictCursor # type: ignore - - port = int(port_str) if port_str else 5432 - conn = psycopg2.connect( - host=host, - port=port, - user=user, - password=password, - dbname=dbname, - ) - conn.autocommit = True - return conn - except Exception: - return None - - -def _fetch_db_config_by_group_id(group_id: str) -> Optional[Dict[str, Any]]: - """根据 group_id 从数据库查询配置 - - Args: - group_id: 组标识符 - - Returns: - Optional[Dict[str, Any]]: 配置字典,未找到时返回 None - """ - conn = _make_pgsql_conn() - if conn is None: - return None - - try: - from psycopg2.extras import RealDictCursor # type: ignore - cur = conn.cursor(cursor_factory=RealDictCursor) - - try: - cur.execute("SET TIME ZONE %s", ("Asia/Shanghai",)) - except Exception: - pass - - sql = ( - "SELECT group_id, user_id, apply_id, chunker_strategy, " - " enable_llm_dedup_blockwise, enable_llm_disambiguation " - "FROM data_config WHERE group_id = %s ORDER BY updated_at DESC LIMIT 1" - ) - cur.execute(sql, (group_id,)) - row = cur.fetchone() - return row if row else None - except Exception: - return None - finally: - try: - cur.close() - except Exception: - pass - try: - conn.close() - except Exception: - pass - - -def _fetch_db_config_by_config_id(config_id: int | str) -> Optional[Dict[str, Any]]: - """根据 config_id 从数据库查询配置 - - Args: - config_id: 配置标识符(整数或字符串,会自动转换为整数) - - Returns: - Optional[Dict[str, Any]]: 配置字典,未找到时返回 None - """ - conn = _make_pgsql_conn() - if conn is None: - try: - pass - except Exception: - pass - return None - - try: - from psycopg2.extras import RealDictCursor # type: ignore - cur = conn.cursor(cursor_factory=RealDictCursor) - - try: - cur.execute("SET TIME ZONE %s", ("Asia/Shanghai",)) - except Exception: - pass - - # config_id 在数据库中是 Integer 类型,需要转换 - try: - config_id_int = int(config_id) - except (ValueError, TypeError) as e: - try: - pass - except Exception: - pass - return None - - sql = ( - "SELECT config_id, group_id, user_id, apply_id, chunker_strategy, " - " enable_llm_dedup_blockwise, enable_llm_disambiguation, " - " deep_retrieval, t_type_strict, t_name_strict, t_overall, state, " - " statement_granularity, include_dialogue_context, max_context, " - " \"offset\" AS offset, lambda_time, lambda_mem, " - " pruning_enabled, pruning_scene, pruning_threshold, " - " llm_id, embedding_id " - "FROM data_config WHERE config_id = %s LIMIT 1" - ) - cur.execute(sql, (config_id_int,)) - row = cur.fetchone() - - if row: - try: - pass - except Exception: - pass - else: - pass - - return row if row else None - except Exception as e: - pass - return None - finally: - try: - cur.close() - except Exception: - pass - try: - conn.close() - except Exception: - pass - - -def _load_dbrun_group_id(project_root: str) -> Optional[str]: - """从 dbrun.json 读取 group_id - - Args: - project_root: 项目根目录路径 - - Returns: - Optional[str]: group_id,未找到时返回 None - """ - try: - path = os.path.join(project_root, "dbrun.json") - if not os.path.isfile(path): - return None - - with open(path, "r", encoding="utf-8") as f: - data = json.load(f) - - if isinstance(data, dict): - if "group_id" in data: - return str(data.get("group_id")) - sel = data.get("selections", {}) - if isinstance(sel, dict) and "group_id" in sel: - return str(sel.get("group_id")) - - return None - except Exception: - return None - - -def _load_dbrun_config_id(project_root: str) -> Optional[str]: - """从 dbrun.json 读取 config_id - - Args: - project_root: 项目根目录路径 - - Returns: - Optional[str]: config_id,未找到时返回 None - """ - try: - path = os.path.join(project_root, "dbrun.json") - if not os.path.isfile(path): - return None - - with open(path, "r", encoding="utf-8") as f: - data = json.load(f) - - if isinstance(data, dict): - if "config_id" in data: - return str(data.get("config_id")) - sel = data.get("selections", {}) - if isinstance(sel, dict) and "config_id" in sel: - return str(sel.get("config_id")) - - return None - except Exception: - return None - - -def _apply_overrides_from_db_row( - runtime_cfg: Dict[str, Any], - db_row: Optional[Dict[str, Any]], - identifier: str, - identifier_type: str = "config_id" -) -> Dict[str, Any]: - """从数据库行数据覆写运行时配置(统一处理函数) - - Args: - runtime_cfg: 运行时配置字典 - db_row: 数据库查询结果行 - identifier: 标识符值(group_id 或 config_id) - identifier_type: 标识符类型("group_id" 或 "config_id") - - Returns: - Dict[str, Any]: 覆写后的运行时配置 - """ - try: - selections = runtime_cfg.setdefault("selections", {}) - selections[identifier_type] = identifier - - if not db_row: - return runtime_cfg - - # 覆写 selections 字段 - for tk in ("group_id", "user_id", "apply_id", "chunker_strategy", "state", - "t_type_strict", "t_name_strict", "t_overall", - "statement_granularity", "include_dialogue_context"): - _set_if_present(selections, tk, db_row, tk, str) - - # 特殊处理 UUID 字段,确保转换为字符串格式 - for uuid_field in ("llm_id", "embedding_id"): - if uuid_field in db_row and db_row.get(uuid_field) is not None: - try: - value = db_row.get(uuid_field) - # 如果是 UUID 对象,转换为字符串(带连字符的标准格式) - if hasattr(value, 'hex'): - selections[uuid_field] = str(value) - else: - selections[uuid_field] = str(value) - except Exception: - pass - - # 覆写 statement_extraction 字段 - stmt = runtime_cfg.setdefault("statement_extraction", {}) - _set_if_present(stmt, "statement_granularity", db_row, "statement_granularity", int) - _set_if_present(stmt, "include_dialogue_context", db_row, "include_dialogue_context", _to_bool) - _set_if_present(stmt, "max_dialogue_context_chars", db_row, "max_context", int) - - # 覆写 deduplication 字段 - dedup = runtime_cfg.setdefault("deduplication", {}) - for tk in ("enable_llm_dedup_blockwise", "enable_llm_disambiguation"): - _set_if_present(dedup, tk, db_row, tk, _to_bool) - _set_if_present(dedup, "deep_retrieval", db_row, "deep_retrieval", _to_bool) - - # 覆写 forgetting_engine 字段 - forgetting = runtime_cfg.setdefault("forgetting_engine", {}) - _set_if_present(forgetting, "offset", db_row, "offset", float) - _set_if_present(forgetting, "lambda_time", db_row, "lambda_time", float) - _set_if_present(forgetting, "lambda_mem", db_row, "lambda_mem", float) - - # 覆写 pruning 字段 - pruning = runtime_cfg.setdefault("pruning", {}) - _set_if_present(pruning, "enabled", db_row, "pruning_enabled", _to_bool) - _set_if_present(pruning, "scene", db_row, "pruning_scene", str) - - # 阈值需要转为 float,且限制在 [0.0, 0.9] - try: - if "pruning_threshold" in db_row and db_row.get("pruning_threshold") is not None: - thr = float(db_row.get("pruning_threshold")) - thr = max(0.0, min(0.9, thr)) # 限制在 [0.0, 0.9] - pruning["threshold"] = thr - except Exception: - pass - - return runtime_cfg - except Exception as e: - pass - return runtime_cfg - - -def apply_runtime_overrides_by_group(project_root: str, runtime_cfg: Dict[str, Any]) -> Dict[str, Any]: - """基于 group_id 从数据库覆写运行时配置 - - 工作流程: - 1. 从 dbrun.json 读取 group_id - 2. 根据 group_id 查询数据库配置 - 3. 覆写运行时配置(仅在内存中) - - Args: - project_root: 项目根目录路径 - runtime_cfg: 运行时配置字典 - - Returns: - Dict[str, Any]: 覆写后的运行时配置 - """ - try: - selected_gid = _load_dbrun_group_id(project_root) - if not selected_gid: - return runtime_cfg - - db_row = _fetch_db_config_by_group_id(selected_gid) - if not db_row: - # 如果数据库中没有配置,仍然设置 group_id - runtime_cfg.setdefault("selections", {})["group_id"] = selected_gid - return runtime_cfg - - return _apply_overrides_from_db_row(runtime_cfg, db_row, selected_gid, "group_id") - except Exception: - return runtime_cfg - - -def apply_runtime_overrides_by_config(project_root: str, runtime_cfg: Dict[str, Any]) -> Dict[str, Any]: - """基于 config_id 从数据库覆写运行时配置(从 dbrun.json 读取) - - 工作流程: - 1. 从 dbrun.json 读取 config_id - 2. 根据 config_id 查询数据库配置 - 3. 覆写运行时配置(仅在内存中) - - Args: - project_root: 项目根目录路径 - runtime_cfg: 运行时配置字典 - - Returns: - Dict[str, Any]: 覆写后的运行时配置 - """ - try: - selected_cid = _load_dbrun_config_id(project_root) - if not selected_cid: - return runtime_cfg - - db_row = _fetch_db_config_by_config_id(selected_cid) - return _apply_overrides_from_db_row(runtime_cfg, db_row, selected_cid, "config_id") - except Exception: - return runtime_cfg - - -def apply_runtime_overrides_with_config_id( - project_root: str, - runtime_cfg: Dict[str, Any], - config_id: str -) -> tuple[Dict[str, Any], bool]: - """使用指定的 config_id 从数据库覆写运行时配置(不读 dbrun.json) - - 用于前端动态切换配置的场景。 - - Args: - project_root: 项目根目录路径 - runtime_cfg: 运行时配置字典 - config_id: 配置标识符 - - Returns: - tuple[Dict[str, Any], bool]: (覆写后的运行时配置, 是否成功从数据库加载) - """ - try: - selected_cid = str(config_id).strip() - if not selected_cid: - return runtime_cfg, False - - db_row = _fetch_db_config_by_config_id(selected_cid) - if db_row is None: - return runtime_cfg, False - - updated_cfg = _apply_overrides_from_db_row(runtime_cfg, db_row, selected_cid, "config_id") - return updated_cfg, True - except Exception as e: - pass - return runtime_cfg, False - - -# ============================================================================ -# 以下函数已注释:不再需要网络模式自动检测功能 -# ============================================================================ - -# def get_server_ip() -> str: -# """ -# 获取当前服务器的IP地址 -# -# Returns: -# 服务器IP地址字符串 -# """ -# try: -# # 方式1:从环境变量获取(优先) -# server_ip = os.getenv('SERVER_IP') -# if server_ip and server_ip not in ['127.0.0.1', 'localhost', '0.0.0.0']: -# return server_ip -# -# # 方式2:通过socket获取 -# hostname = socket.gethostname() -# ip_address = socket.gethostbyname(hostname) -# -# # 如果是本地回环地址,尝试获取真实IP -# if ip_address.startswith('127.'): -# # 尝试连接外部地址来获取本机IP -# s = socket.socket(socket.AF_INET, socket.SOCK_DGRAM) -# try: -# s.connect(('8.8.8.8', 80)) -# ip_address = s.getsockname()[0] -# finally: -# s.close() -# -# return ip_address -# except Exception as e: -# print(f"[overrides] 获取服务器IP失败: {e},使用默认值 127.0.0.1") -# return '127.0.0.1' - - -# def auto_detect_network_mode() -> NetworkMode: -# """ -# 自动检测网络模式(基于服务器IP) -# -# 规则: -# - 如果服务器IP在内网IP列表中 → internal(内网) -# - 其他IP → external(外网) -# -# 可以通过环境变量 INTERNAL_SERVER_IPS 自定义内网IP列表(逗号分隔) -# -# Returns: -# 'internal' 或 'external' -# """ -# server_ip = get_server_ip() -# -# # 从环境变量获取内网IP列表(支持多个IP,逗号分隔) -# internal_ips_str = os.getenv('INTERNAL_SERVER_IPS', '119.45.181.55') -# internal_ips = [ip.strip() for ip in internal_ips_str.split(',')] -# -# # 判断当前IP是否在内网IP列表中 -# if server_ip in internal_ips: -# print(f"[overrides] 自动检测:服务器IP {server_ip} 属于内网,使用 INTERNAL 配置") -# return 'internal' -# else: -# print(f"[overrides] 自动检测:服务器IP {server_ip} 属于外网,使用 EXTERNAL 配置") -# return 'external' - - -# ============================================================================ -# 环境变量覆写功能已废弃 - 不再使用 -# ============================================================================ -# def _apply_env_var_overrides(runtime_cfg: Dict[str, Any], network_mode: NetworkMode = None, force_override: bool = False) -> Dict[str, Any]: -# """ -# 从环境变量覆写配置(已废弃) -# """ -# return runtime_cfg - - -def load_unified_config( - project_root: str, - config_id: Optional[int | str] = None, - group_id: Optional[str] = None, - network_mode: NetworkMode = None, - env_override_models: bool = True -) -> Dict[str, Any]: - """ - 统一配置加载器 - 按优先级加载配置 - - 配置加载优先级: - 1. PG数据库配置(最高优先级,通过 dbrun.json 中的 config_id 读取) - 2. runtime.json 默认配置(最低优先级) - - Args: - project_root: 项目根目录路径 - config_id: 配置ID(整数或字符串,可选,优先从 dbrun.json 读取) - group_id: 组ID(可选) - network_mode: 已废弃,保留参数仅为向后兼容 - env_override_models: 已废弃,保留参数仅为向后兼容 - - Returns: - Dict[str, Any]: 最终的运行时配置 - """ - try: - # 步骤 1: 加载 runtime.json 作为基础配置 - runtime_config_path = os.path.join(project_root, "runtime.json") - try: - with open(runtime_config_path, "r", encoding="utf-8") as f: - runtime_cfg = json.load(f) - except (FileNotFoundError, json.JSONDecodeError) as e: - runtime_cfg = {"selections": {}} - - # 步骤 2: 尝试从 dbrun.json 读取 config_id 并应用数据库配置(最高优先级) - if config_id: - # 优先使用传入的 config_id - db_row = _fetch_db_config_by_config_id(config_id) - if db_row: - runtime_cfg = _apply_overrides_from_db_row(runtime_cfg, db_row, config_id, "config_id") - pass - elif group_id: - # 其次使用 group_id - db_row = _fetch_db_config_by_group_id(group_id) - if db_row: - runtime_cfg = _apply_overrides_from_db_row(runtime_cfg, db_row, group_id, "group_id") - pass - else: - # 尝试从 dbrun.json 读取 - dbrun_config_id = _load_dbrun_config_id(project_root) - if dbrun_config_id: - db_row = _fetch_db_config_by_config_id(dbrun_config_id) - if db_row: - runtime_cfg = _apply_overrides_from_db_row(runtime_cfg, db_row, dbrun_config_id, "config_id") - pass - else: - dbrun_group_id = _load_dbrun_group_id(project_root) - if dbrun_group_id: - db_row = _fetch_db_config_by_group_id(dbrun_group_id) - if db_row: - runtime_cfg = _apply_overrides_from_db_row(runtime_cfg, db_row, dbrun_group_id, "group_id") - pass - return runtime_cfg - - except Exception as e: - return {"selections": {}} - - -# 向后兼容的别名 -apply_runtime_overrides = apply_runtime_overrides_by_config diff --git a/app/core/memory/utils/data/__init__.py b/app/core/memory/utils/data/__init__.py deleted file mode 100644 index 706053b9..00000000 --- a/app/core/memory/utils/data/__init__.py +++ /dev/null @@ -1,43 +0,0 @@ -""" -数据处理模块 - -包含所有数据处理相关的工具函数,包括文本处理、时间处理和本体定义。 -""" - -# 从子模块导出常用函数和类,保持向后兼容 -from .text_utils import ( - escape_lucene_query, - extract_plain_query, -) -from .time_utils import ( - validate_date_format, - normalize_date, - normalize_date_safe, - preprocess_date_string, -) -from .ontology import ( - PREDICATE_DEFINITIONS, - LABEL_DEFINITIONS, - Predicate, - StatementType, - TemporalInfo, - RelevenceInfo, -) - -__all__ = [ - # text_utils - "escape_lucene_query", - "extract_plain_query", - # time_utils - "validate_date_format", - "normalize_date", - "normalize_date_safe", - "preprocess_date_string", - # ontology - "PREDICATE_DEFINITIONS", - "LABEL_DEFINITIONS", - "Predicate", - "StatementType", - "TemporalInfo", - "RelevenceInfo", -] diff --git a/app/core/memory/utils/data/ontology.py b/app/core/memory/utils/data/ontology.py deleted file mode 100644 index 19bddaa7..00000000 --- a/app/core/memory/utils/data/ontology.py +++ /dev/null @@ -1,199 +0,0 @@ -from enum import StrEnum - - -# Use jinja template.render -PREDICATE_DEFINITIONS = { - "IS_A": "Denotes a class-or-type relationship between two entities (e.g., 'Model Y IS_A electric-SUV'). Includes 'is' and 'was'.", - "HAS_A": "Denotes a part-whole relationship between two entities (e.g., 'Model Y HAS_A electric-engine'). Includes 'has' and 'had'.", - "LOCATED_IN": "Specifies geographic or organisational containment or proximity (e.g., headquarters LOCATED_IN Berlin).", - "HOLDS_ROLE": "Connects a person to a formal office or title within an organisation (CEO, Chair, Director, etc.).", - "PRODUCES": "Indicates that an entity manufactures, builds, or creates a product, service, or infrastructure (includes scale-ups and component inclusion).", - "SELLS": "Marks a commercial seller-to-customer relationship for a product or service (markets, distributes, sells).", - "LAUNCHED": "Captures the official first release, shipment, or public start of a product, service, or initiative.", - "DEVELOPED": "Shows design, R&D, or innovation origin of a technology, product, or capability. Includes 'researched' or 'created'.", - "ADOPTED_BY": "Indicates that a technology or product has been taken up, deployed, or implemented by another entity.", - "INVESTS_IN": "Represents the flow of capital or resources from one entity into another (equity, funding rounds, strategic investment).", - "COLLABORATES_WITH": "Generic partnership, alliance, joint venture, or licensing relationship between entities.", - "SUPPLIES": "Captures vendor–client supply-chain links or dependencies (provides to, sources from).", - "HAS_REVENUE": "Associates an entity with a revenue amount or metric—actual, reported, or projected.", - "INCREASED": "Expresses an upward change in a metric (revenue, market share, output) relative to a prior period or baseline.", - "DECREASED": "Expresses a downward change in a metric relative to a prior period or baseline.", - "RESULTED_IN": "Captures a causal relationship where one event or factor leads to a specific outcome (positive or negative).", - "TARGETS": "Denotes a strategic objective, market segment, or customer group that an entity seeks to reach.", - "PART_OF": "Expresses hierarchical membership or subset relationships (division, subsidiary, managed by, belongs to).", - "DISCONTINUED": "Indicates official end-of-life, shutdown, or termination of a product, service, or relationship.", - "SECURED": "Marks the successful acquisition of funding, contracts, assets, or rights by an entity.", - "MENTIONS": "Denotes a reference or mention of an entity in a text or document.", - - # 移除了过于宽泛的谓语集合 - # "MENTIONS": "Denotes a reference or mention of an entity in a text or document." , - # "FEELS" : "Denotes a subjective opinion or feeling about an entity (e.g., 'I feel like X').Includes 'THINKS'.", - # "HELPS" :"Express a action that make it easier or possible for (someone) to do something by offering one's services or resources. Includes 'assist', 'aid' and 'support' " , - # "IS_DOING" : "Denotes a subjective action or activity about an entity (e.g., 'I am doing X').Includes 'DOES'.", - # "LIKES": "Express enjoy or approve of something or someone (e.g., 'I like roses').Includes 'LIKES'.", - # "DISLIKES": "Express dislike or disapprove of something or someone (e.g., 'I dislike roses').Includes 'DISLIKES'.", - # "HAS_ATTRIBUTE": "Express that an entity has a certain attribute (e.g., 'X has a red car').Includes 'HAS'.", - -} - -LABEL_DEFINITIONS: dict[str, dict[str, dict[str, str]]] = { - "statement_labelling": { - "FACT": dict( - definition=( - "Statements that are objective and can be independently " - "verified or falsified through evidence." - ), - date_handling_guidance=( - "These statements can be made up of multiple static and " - "dynamic temporal events marking for example the start, end, " - "and duration of the fact described statement." - ), - date_handling_example=( - "'Company A owns Company B in 2022', 'X caused Y to happen', " - "or 'John said X at Event' are verifiable facts which currently " - "hold true unless we have a contradictory fact." - ), - ), - "OPINION": dict( - definition=( - "Statements that contain personal opinions, feelings, values, " - "or judgments that are not independently verifiable. It also " - "includes hypothetical and speculative statements." - ), - date_handling_guidance=( - "This statement is always static. It is a record of the date the " - "opinion was made." - ), - date_handling_example=( - "'I like Company A's strategy', 'X may have caused Y to happen', " - "or 'The event felt like X' are opinions and down to the reporters " - "interpretation." - ), - ), - "PREDICTION": dict( - definition=( - "Uncertain statements about the future on something that might happen, " - "a hypothetical outcome, unverified claims. " - "If the tense of the statement changed, the statement " - "would then become a fact." - ), - date_handling_guidance=( - "This statement is always static. It is a record of the date the " - "prediction was made." - ), - date_handling_example=( - "'It is rumoured that Dave will resign next month', 'Company A expects " - "X to happen', or 'X suggests Y' are all predictions." - ), - ), - "SUGGESTION": dict( - definition=( - "A proposal or recommendation for action, often implying a future course of conduct. " - " It's not a statement of fact or a prediction, but rather an advised path. " - "It's a suggestion for action that is not yet implemented." - ), - date_handling_guidance=( - "This statement is always static." - ), - date_handling_example=( - "'They should launch the new product next quarter', 'You could try a different approach', " - "or 'I would recommend moving the headquarters to Berlin' are all suggestions." - ), - ), - }, - "temporal_labelling": { - "STATIC": dict( - definition=( - "Often past tense, think -ed verbs, describing single points-in-time. " - "These statements are valid from the day they occurred and are never " - "invalid. Refer to single points in time at which an event occurred, " - "the fact X occurred on that date will always hold true." - ), - date_handling_guidance=( - "The valid_at date is the date the event occurred. The invalid_at date " - "is None." - ), - date_handling_example=( - "'John was appointed CEO on 4th Jan 2024', 'Company A reported X percent " - "growth from last FY', or 'X resulted in Y to happen' are valid the day " - "they occurred and are never invalid." - ), - ), - "DYNAMIC": dict( - definition=( - "Often present tense, think -ing verbs, describing a period of time. " - "These statements are valid for a specific period of time and are usually " - "invalidated by a Static fact marking the end of the event or start of a " - "contradictory new one. The statement could already be referring to a " - "discrete time period (invalid) or may be an ongoing relationship (not yet " - "invalid)." - ), - date_handling_guidance=( - "The valid_at date is the date the event started. The invalid_at date is " - "the date the event or relationship ended, for ongoing events this is None." - ), - date_handling_example=( - "'John is the CEO', 'Company A remains a market leader', or 'X is continuously " - "causing Y to decrease' are valid from when the event started and are invalidated " - "by a new event." - ), - ), - "ATEMPORAL": dict( - definition=( - "Statements that will always hold true regardless of time therefore have no " - "temporal bounds." - ), - date_handling_guidance=( - "These statements are assumed to be atemporal and have no temporal bounds. Both " - "their valid_at and invalid_at are None." - ), - date_handling_example=( - "'A stock represents a unit of ownership in a company', 'The earth is round', or " - "'Europe is a continent'. These statements are true regardless of time." - ), - ), - }, -} - -class Predicate(StrEnum): - """Enumeration of normalised predicates.""" - - IS_A = "IS_A" - HAS_A = "HAS_A" - LOCATED_IN = "LOCATED_IN" - HOLDS_ROLE = "HOLDS_ROLE" - PRODUCES = "PRODUCES" - SELLS = "SELLS" - LAUNCHED = "LAUNCHED" - DEVELOPED = "DEVELOPED" - ADOPTED_BY = "ADOPTED_BY" - INVESTS_IN = "INVESTS_IN" - COLLABORATES_WITH = "COLLABORATES_WITH" - SUPPLIES = "SUPPLIES" - HAS_REVENUE = "HAS_REVENUE" - INCREASED = "INCREASED" - DECREASED = "DECREASED" - RESULTED_IN = "RESULTED_IN" - TARGETS = "TARGETS" - PART_OF = "PART_OF" - DISCONTINUED = "DISCONTINUED" - SECURED = "SECURED" - MENTIONS = "MENTIONS" - - -class StatementType(StrEnum): - FACT = "FACT" - OPINION = "OPINION" - PREDICTION = "PREDICTION" - SUGGESTION = "SUGGESTION" - -class TemporalInfo(StrEnum): - ATEMPORAL = "ATEMPORAL" - STATIC = "STATIC" - DYNAMIC = "DYNAMIC" - -# Relevance labelling for statements -class RelevenceInfo(StrEnum): - RELEVANT = "RELEVANT" - IRRELEVANT = "IRRELEVANT" - diff --git a/app/core/memory/utils/data/text_utils.py b/app/core/memory/utils/data/text_utils.py deleted file mode 100644 index 133990f7..00000000 --- a/app/core/memory/utils/data/text_utils.py +++ /dev/null @@ -1,81 +0,0 @@ -import json - - -def escape_lucene_query(query: str) -> str: - """Escape Lucene special characters in a free-text query. - - This prevents ParseException when using Neo4j full-text procedures. - """ - if query is None: - return "" - - s = str(query) - # Normalize whitespace - s = s.replace("\r", " ").replace("\n", " ").strip() - - # Lucene reserved tokens/special characters - specials = ['&&', '||', '\\', '+', '-', '!', '(', ')', '{', '}', '[', ']', '^', '"', '~', '*', '?', ':'] - # Replace longer tokens first to avoid partial double-escaping - for token in sorted(specials, key=len, reverse=True): - s = s.replace(token, f"\\{token}") - - return s - -def extract_plain_query(query_input: str) -> str: - """Extract clean, plain-text query from various input forms. - - - Strips surrounding quotes and whitespace - - If input looks like JSON, prefers the 'original' field - - Fallbacks to the raw string when parsing fails - """ - if query_input is None: - return "" - - # Directly handle dict-like input - if isinstance(query_input, dict): - original = query_input.get("original") - if isinstance(original, str) and original.strip(): - return original.strip() - context = query_input.get("context") - if isinstance(context, dict): - for key, val in context.items(): - if isinstance(key, str) and key.strip(): - return key.strip() - if isinstance(val, list) and val: - first = val[0] - if isinstance(first, str) and first.strip(): - return first.strip() - # Fallback to string conversion below - - s = str(query_input).strip() - - # Remove surrounding single/double quotes if present - if (s.startswith("'") and s.endswith("'")) or (s.startswith('"') and s.endswith('"')): - s = s[1:-1].strip() - - # Attempt to parse JSON and extract the 'original' field - if s.startswith("{") and s.endswith("}"): - try: - data = json.loads(s) - # Prefer 'original' field if available - original = data.get("original") - if isinstance(original, str) and original.strip(): - return original.strip() - # Fallbacks: try common nested structures - context = data.get("context") - if isinstance(context, dict): - # Take the first key or first string value in context - for key, val in context.items(): - if isinstance(key, str) and key.strip(): - return key.strip() - if isinstance(val, list) and val: - first = val[0] - if isinstance(first, str) and first.strip(): - return first.strip() - except Exception: - # Not valid JSON; keep as-is after best-effort unescape below - pass - - # Best-effort unescape common escaped newlines/tabs without altering unicode - s = s.replace("\\n", " ").replace("\\t", " ") - return s diff --git a/app/core/memory/utils/data/time_utils.py b/app/core/memory/utils/data/time_utils.py deleted file mode 100644 index c6791dfc..00000000 --- a/app/core/memory/utils/data/time_utils.py +++ /dev/null @@ -1,127 +0,0 @@ -import re -from dateutil import parser -from datetime import datetime - -def validate_date_format(date_str: str) -> bool: - """ - Validate if the date string is in the format YYYY-MM-DD. - """ - pattern = r"^\d{4}-\d{1,2}-\d{1,2}$" - return bool(re.match(pattern, date_str)) - - -def normalize_date(date_str: str) -> str: - """ - 更强大的日期标准化函数,支持多种日期格式转换为 Y-M-D 格式 - - Args: - date_str: 各种格式的日期字符串 - - Returns: - Y-M-D 格式的标准化日期字符串 - """ - if not date_str or not isinstance(date_str, str): - return date_str - - # 移除首尾空格 - date_str = date_str.strip().replace(' ', '').replace('/', '').replace('.', '').replace('_', '').replace('-', '') - - try: - # 预处理:识别并规范化特殊格式 - preprocessed_str = preprocess_date_string(date_str) - - # 使用 dateutil.parser 进行解析[citation:1][citation:7] - dt = parser.parse(preprocessed_str, dayfirst=False, yearfirst=True) - - return dt.strftime('%Y-%m-%d') - - except (ValueError, TypeError, OverflowError): - # 如果智能解析失败,尝试格式匹配 - return fallback_parse(date_str) - - -def preprocess_date_string(date_str: str) -> str: - """预处理日期字符串,处理特殊格式""" - - # 处理类似 "20259/28" 的格式(年份后直接跟月份没有分隔) - match = re.match(r'^(\d{4,5})[/\.\-_]?(\d{1,2})[/\.\-_]?(\d{1,2})$', date_str) - if match: - year, month, day = match.groups() - # 如果年份超过4位,可能是年份和月份连在一起 - if len(year) > 4: - # 取前4位作为年份,剩余作为月份 - actual_year = year[:4] - actual_month = year[4:] + (month if month else '') - # 重新组合 - if day: - return f"{actual_year}-{actual_month.zfill(2)}-{day.zfill(2)}" - else: - return f"{actual_year}-{actual_month.zfill(2)}" - else: - return f"{year}-{month.zfill(2)}-{day.zfill(2)}" if day else f"{year}-{month.zfill(2)}" - - # 处理无分隔符的纯数字格式[citation:4] - if re.match(r'^\d{6,8}$', date_str): - if len(date_str) == 8: # YYYYMMDD - return f"{date_str[:4]}-{date_str[4:6]}-{date_str[6:8]}" - elif len(date_str) == 6: # YYMMDD 或 MMDDYY - # 尝试不同解释 - if 1 <= int(date_str[:2]) <= 12: # 可能是 MMDDYY - return f"20{date_str[4:6]}-{date_str[:2]}-{date_str[2:4]}" - else: # 可能是 YYMMDD - return f"20{date_str[:2]}-{date_str[2:4]}-{date_str[4:6]}" - - # 处理混合分隔符,统一为 - - date_str = re.sub(r'[/\._]', '-', date_str) - - return date_str - - -def fallback_parse(date_str: str) -> str: - """备选解析方案""" - - # 尝试常见的日期格式[citation:4][citation:5] - formats_to_try = [ - '%Y-%m-%d', '%Y/%m/%d', '%Y.%m.%d', - '%Y%m%d', '%y%m%d', - '%m-%d-%Y', '%m/%d/%Y', '%m.%d.%Y', - '%d-%m-%Y', '%d/%m/%Y', '%d.%m.%Y', - '%Y-%m', '%Y/%m', '%Y.%m' - ] - - for fmt in formats_to_try: - try: - dt = datetime.strptime(date_str, fmt) - return dt.strftime('%Y-%m-%d') - except ValueError: - continue - - # 所有方法都失败时,返回原字符串或抛出异常 - return date_str - - -def normalize_date_safe(date_str: str, default: str = None) -> str: - """ - 安全的日期标准化函数,提供默认值处理 - - Args: - date_str: 日期字符串 - default: 解析失败时的默认返回值 - - Returns: - 标准化日期字符串或默认值 - """ - try: - result = normalize_date(date_str) - # 检查结果是否是有效的日期格式 - if validate_date_format(result): - return result - else: - return default if default is not None else date_str - except: - return default if default is not None else date_str - -if __name__ == "__main__": - start_dates = ["2025/10/28", "2025.10.28", "2025_10_28", "20251028"] - for date in start_dates: - print(normalize_date_safe(date)) diff --git a/app/core/memory/utils/llm/__init__.py b/app/core/memory/utils/llm/__init__.py deleted file mode 100644 index 321aee97..00000000 --- a/app/core/memory/utils/llm/__init__.py +++ /dev/null @@ -1,18 +0,0 @@ -""" -LLM 工具模块 - -包含所有 LLM 客户端相关的工具函数。 -""" - -# 从子模块导出常用函数,保持向后兼容 -from .llm_utils import ( - get_llm_client, - get_reranker_client, - handle_response, -) - -__all__ = [ - "get_llm_client", - "get_reranker_client", - "handle_response", -] diff --git a/app/core/memory/utils/llm/llm_utils.py b/app/core/memory/utils/llm/llm_utils.py deleted file mode 100644 index dc80d0a5..00000000 --- a/app/core/memory/utils/llm/llm_utils.py +++ /dev/null @@ -1,77 +0,0 @@ -import os -from pydantic import BaseModel - -from app.core.memory.src.llm_tools.openai_client import OpenAIClient -from app.core.memory.utils.config.config_utils import get_model_config -from app.core.memory.utils.config import definitions as config_defs -from app.core.models.base import RedBearModelConfig - -async def handle_response(response: type[BaseModel]) -> dict: - return response.model_dump() - - -def get_llm_client(llm_id: str | None = None): - llm_id = llm_id or config_defs.SELECTED_LLM_ID - - # Validate LLM ID exists before attempting to get config - if not llm_id: - raise ValueError("LLM ID is required but was not provided") - - try: - model_config = get_model_config(llm_id) - except Exception as e: - # Re-raise with clear error message about invalid LLM ID - raise ValueError(f"Invalid LLM ID '{llm_id}': {str(e)}") from e - - try: - # 移除调试打印,避免污染终端输出 - # print(model_config) - llm_client = OpenAIClient(RedBearModelConfig( - model_name=model_config.get("model_name"), - provider=model_config.get("provider"), - api_key=model_config.get("api_key"), - base_url=model_config.get("base_url") - ),type_=model_config.get("type")) - # print(llm.dict()) - return llm_client - except Exception as e: - model_name = model_config.get('model_name', 'unknown') - raise ValueError(f"Failed to initialize LLM client for model '{model_name}': {str(e)}") from e - - -def get_reranker_client(rerank_id: str | None = None): - """ - Get an LLM client configured for reranking. - - Args: - rerank_id: Optional reranker model ID. If None, uses SELECTED_RERANK_ID. - - Returns: - OpenAIClient: Initialized client for the reranker model - - Raises: - ValueError: If rerank_id is invalid or client initialization fails - """ - rerank_id = rerank_id or config_defs.SELECTED_RERANK_ID - - # Validate rerank ID exists before attempting to get config - if not rerank_id: - raise ValueError("Rerank ID is required but was not provided") - - try: - model_config = get_model_config(rerank_id) - except Exception as e: - # Re-raise with clear error message about invalid rerank ID - raise ValueError(f"Invalid rerank ID '{rerank_id}': {str(e)}") from e - - try: - reranker_client = OpenAIClient(RedBearModelConfig( - model_name=model_config.get("model_name"), - provider=model_config.get("provider"), - api_key=model_config.get("api_key"), - base_url=model_config.get("base_url") - ),type_=model_config.get("type")) - return reranker_client - except Exception as e: - model_name = model_config.get('model_name', 'unknown') - raise ValueError(f"Failed to initialize reranker client for model '{model_name}': {str(e)}") from e \ No newline at end of file diff --git a/app/core/memory/utils/log/__init__.py b/app/core/memory/utils/log/__init__.py deleted file mode 100644 index 7386f911..00000000 --- a/app/core/memory/utils/log/__init__.py +++ /dev/null @@ -1,24 +0,0 @@ -""" -日志管理模块 - -包含所有日志相关的工具函数。 -""" - -# 从子模块导出常用函数,保持向后兼容 -from .logging_utils import ( - log_prompt_rendering, - log_template_rendering, - log_time, - prompt_logger, -) -from .audit_logger import audit_logger - -__all__ = [ - # logging_utils - "log_prompt_rendering", - "log_template_rendering", - "log_time", - "prompt_logger", - # audit_logger - "audit_logger", -] diff --git a/app/core/memory/utils/log/audit_logger.py b/app/core/memory/utils/log/audit_logger.py deleted file mode 100644 index 9010aad5..00000000 --- a/app/core/memory/utils/log/audit_logger.py +++ /dev/null @@ -1,182 +0,0 @@ -""" -配置审计日志记录器 - -提供专门的审计日志功能,用于追踪配置变更和操作记录。 -""" -import logging -import os -from datetime import datetime -from typing import Optional, Dict, Any - - -def _format_value(value: Any) -> str: - """ - 格式化值为字符串,特殊处理 UUID 等对象 - - Args: - value: 要格式化的值 - - Returns: - str: 格式化后的字符串 - """ - if value is None: - return "None" - elif isinstance(value, bool): - return str(value) - elif hasattr(value, 'hex'): # UUID 对象有 hex 属性 - return str(value) # 使用标准的 UUID 字符串格式(带连字符) - else: - return str(value) - - -class ConfigAuditLogger: - """配置审计日志记录器""" - - def __init__(self, log_file: str = "logs/config_audit.log"): - """ - 初始化审计日志记录器 - - Args: - log_file: 日志文件路径 - """ - self.logger = logging.getLogger("config_audit") - self.logger.setLevel(logging.INFO) - - # 避免重复添加处理器 - if not self.logger.handlers: - # 确保日志目录存在 - log_dir = os.path.dirname(log_file) - if log_dir and not os.path.exists(log_dir): - os.makedirs(log_dir, exist_ok=True) - - # 创建文件处理器 - handler = logging.FileHandler(log_file, encoding='utf-8') - formatter = logging.Formatter( - '%(asctime)s [AUDIT] %(message)s', - datefmt='%Y-%m-%d %H:%M:%S' - ) - handler.setFormatter(formatter) - self.logger.addHandler(handler) - - def log_config_load( - self, - config_id: str, - user_id: Optional[str] = None, - group_id: Optional[str] = None, - success: bool = True, - details: Optional[Dict[str, Any]] = None - ): - """ - 记录配置加载事件 - - Args: - config_id: 配置 ID - user_id: 用户 ID(可选) - group_id: 组 ID(可选) - success: 是否成功 - details: 详细信息(可选) - """ - result = "SUCCESS" if success else "FAILED" - msg = ( - f"CONFIG_LOAD config_id={config_id} " - f"user={user_id or 'N/A'} group={group_id or 'N/A'} " - f"result={result}" - ) - if details: - # 格式化详细信息,确保所有值都正确转换为字符串 - details_str = ", ".join(f"{k}={_format_value(v)}" for k, v in details.items()) - msg += f" details=[{details_str}]" - self.logger.info(msg) - - def log_config_change( - self, - config_id: str, - old_values: Dict[str, Any], - new_values: Dict[str, Any], - user_id: Optional[str] = None - ): - """ - 记录配置变更事件 - - Args: - config_id: 配置 ID - old_values: 旧配置值 - new_values: 新配置值 - user_id: 用户 ID(可选) - """ - changes = [] - for key in new_values: - if key in old_values and old_values[key] != new_values[key]: - changes.append(f"{key}: {old_values[key]} -> {new_values[key]}") - - if changes: - msg = ( - f"CONFIG_CHANGE config_id={config_id} " - f"user={user_id or 'N/A'} " - f"changes=[{', '.join(changes)}]" - ) - self.logger.info(msg) - - def log_operation( - self, - operation: str, - config_id: str, - group_id: str, - success: bool = True, - duration: Optional[float] = None, - error: Optional[str] = None, - details: Optional[Dict[str, Any]] = None - ): - """ - 记录操作事件 - - Args: - operation: 操作类型(WRITE, READ 等) - config_id: 配置 ID - group_id: 组 ID - success: 是否成功 - duration: 操作耗时(秒) - error: 错误信息(可选) - details: 详细信息(可选) - """ - result = "SUCCESS" if success else "FAILED" - msg = ( - f"{operation.upper()} config_id={config_id} " - f"group={group_id} result={result}" - ) - if duration is not None: - msg += f" duration={duration:.2f}s" - if error: - msg += f" error={error}" - if details: - # 格式化详细信息,确保所有值都正确转换为字符串 - details_str = ", ".join(f"{k}={_format_value(v)}" for k, v in details.items()) - msg += f" details=[{details_str}]" - self.logger.info(msg) - - def log_cache_event( - self, - event_type: str, - config_id: Optional[str] = None, - details: Optional[Dict[str, Any]] = None - ): - """ - 记录缓存事件 - - Args: - event_type: 事件类型(HIT, MISS, CLEAR, EXPIRE) - config_id: 配置 ID(可选) - details: 详细信息(可选) - """ - msg = f"CACHE_{event_type.upper()}" - if config_id: - msg += f" config_id={config_id}" - if details: - # 格式化详细信息,确保所有值都正确转换为字符串 - details_str = ", ".join(f"{k}={_format_value(v)}" for k, v in details.items()) - msg += f" details=[{details_str}]" - self.logger.info(msg) - - -# 全局审计日志记录器实例 -audit_logger = ConfigAuditLogger() diff --git a/app/core/memory/utils/log/logging_utils.py b/app/core/memory/utils/log/logging_utils.py deleted file mode 100644 index ca32f201..00000000 --- a/app/core/memory/utils/log/logging_utils.py +++ /dev/null @@ -1,38 +0,0 @@ -"""Logging utilities for prompt rendering and timing. - -This module provides backward-compatible access to memory module logging utilities -that have been unified into the centralized logging system (app.core.logging_config). - -All logging functions are now imported from the centralized configuration to ensure -consistent behavior, formatting, and configuration across the entire application. - -For new code, consider importing directly from app.core.logging_config: - from app.core.logging_config import log_prompt_rendering, log_template_rendering, log_time - -This module maintains backward compatibility for existing code that imports from here. -""" - -# Import from centralized logging configuration -from app.core.logging_config import ( - log_prompt_rendering as _log_prompt_rendering, - log_template_rendering as _log_template_rendering, - log_time as _log_time, - get_prompt_logger as _get_prompt_logger, -) - -# Re-export functions to maintain backward compatibility -log_prompt_rendering = _log_prompt_rendering -log_template_rendering = _log_template_rendering -log_time = _log_time - -# Re-export prompt_logger for backward compatibility with code that uses it directly -# This provides the same logger instance that was previously created in this module -prompt_logger = _get_prompt_logger() - -# Expose functions in __all__ for explicit exports -__all__ = [ - 'log_prompt_rendering', - 'log_template_rendering', - 'log_time', - 'prompt_logger', -] diff --git a/app/core/memory/utils/paths/__init__.py b/app/core/memory/utils/paths/__init__.py deleted file mode 100644 index 27c31f91..00000000 --- a/app/core/memory/utils/paths/__init__.py +++ /dev/null @@ -1,16 +0,0 @@ -""" -路径管理模块 - -包含所有路径管理相关的工具函数。 -""" - -# 从子模块导出常用函数,保持向后兼容 -from .output_paths import ( - get_output_dir, - get_output_path, -) - -__all__ = [ - "get_output_dir", - "get_output_path", -] diff --git a/app/core/memory/utils/paths/output_paths.py b/app/core/memory/utils/paths/output_paths.py deleted file mode 100644 index d6df6ad4..00000000 --- a/app/core/memory/utils/paths/output_paths.py +++ /dev/null @@ -1,133 +0,0 @@ -""" -Output Path Management for Memory Module - -This module provides utilities for managing output file paths in the memory module. -All output files are now centralized in the logs/memory-output directory. - -Migration from: app/core/memory/src/pipeline_output/ -Migration to: logs/memory-output/ -""" - -import os -from pathlib import Path -from typing import Optional - -try: - from app.core.config import settings - USE_UNIFIED_CONFIG = True -except ImportError: - USE_UNIFIED_CONFIG = False - settings = None - - -def get_output_dir() -> str: - """ - Get the base output directory for memory module files. - - Returns: - str: Path to the output directory - """ - if USE_UNIFIED_CONFIG: - return settings.MEMORY_OUTPUT_DIR - else: - # Fallback to default path - return "logs/memory-output" - - -def get_output_path(filename: str) -> str: - """ - Get the full path for a memory module output file. - - Args: - filename: Name of the output file - - Returns: - str: Full path to the output file - """ - if USE_UNIFIED_CONFIG: - return settings.get_memory_output_path(filename) - else: - # Fallback to default path - return os.path.join("logs/memory-output", filename) - - -def ensure_output_dir() -> None: - """ - Ensure the output directory exists. - Creates the directory if it doesn't exist. - """ - if USE_UNIFIED_CONFIG: - settings.ensure_memory_output_dir() - else: - # Fallback: create directory manually - output_dir = Path("logs/memory-output") - output_dir.mkdir(parents=True, exist_ok=True) - - -# Standard output file names (for consistency across the module) -class OutputFiles: - """Standard output file names for the memory module.""" - - # Chunker output - CHUNKER_TEST_OUTPUT = "chunker_test_output.txt" - - # Preprocessing output - PREPROCESSED_DATA = "preprocessed_data.json" - PRUNED_DATA = "pruned_data.json" - PRUNED_TERMINAL = "pruned_terminal.json" - - # Extraction output - STATEMENT_EXTRACTION = "statement_extraction.txt" - RELATIONS_OUTPUT = "relations_output.txt" - EXTRACTED_TRIPLETS = "extracted_triplets.txt" - EXTRACTED_ENTITIES_EDGES = "extracted_entities_edges.txt" - EXTRACTED_TEMPORAL_DATA = "extracted_temporal_data.txt" - - # Deduplication output - DEDUP_ENTITY_OUTPUT = "dedup_entity_output.txt" - - # Summary output - EXTRACTED_RESULT = "extracted_result.json" - EXTRACTED_RESULT_READABLE = "extracted_result_readable.txt" - - # Analytics output - USER_DASHBOARD = "User-Dashboard.json" - SIGNBOARD = "Signboard.json" - - -def get_standard_output_path(file_constant: str) -> str: - """ - Get the full path for a standard output file. - - Args: - file_constant: One of the OutputFiles constants - - Returns: - str: Full path to the output file - """ - return get_output_path(file_constant) - - -# Backward compatibility: Legacy path resolution -def resolve_legacy_path(legacy_path: str) -> str: - """ - Resolve a legacy pipeline_output path to the new unified output path. - - This function helps migrate code that uses hardcoded pipeline_output paths. - - Args: - legacy_path: Path containing 'pipeline_output' - - Returns: - str: New path using unified output directory - """ - if "pipeline_output" in legacy_path: - # Extract filename from legacy path - filename = os.path.basename(legacy_path) - return get_output_path(filename) - return legacy_path - - -# Aliases for backward compatibility with test code -get_memory_output_dir = get_output_dir -get_memory_output_path = get_output_path diff --git a/app/core/memory/utils/prompt/__init__.py b/app/core/memory/utils/prompt/__init__.py deleted file mode 100644 index 012bb311..00000000 --- a/app/core/memory/utils/prompt/__init__.py +++ /dev/null @@ -1,34 +0,0 @@ -""" -提示词管理模块 - -包含所有提示词渲染和模板管理相关的工具函数。 -""" - -# 从子模块导出常用函数,保持向后兼容 -from .prompt_utils import ( - get_prompts, - render_statement_extraction_prompt, - render_temporal_extraction_prompt, - render_entity_dedup_prompt, - render_triplet_extraction_prompt, - render_memory_summary_prompt, - prompt_env, -) -from .template_render import ( - render_evaluate_prompt, - render_reflexion_prompt, -) - -__all__ = [ - # prompt_utils - "get_prompts", - "render_statement_extraction_prompt", - "render_temporal_extraction_prompt", - "render_entity_dedup_prompt", - "render_triplet_extraction_prompt", - "render_memory_summary_prompt", - "prompt_env", - # template_render - "render_evaluate_prompt", - "render_reflexion_prompt", -] diff --git a/app/core/memory/utils/prompt/prompt_utils.py b/app/core/memory/utils/prompt/prompt_utils.py deleted file mode 100644 index 77a23e0f..00000000 --- a/app/core/memory/utils/prompt/prompt_utils.py +++ /dev/null @@ -1,240 +0,0 @@ -import os -from jinja2 import Environment, FileSystemLoader - -from app.core.memory.utils.log.logging_utils import log_prompt_rendering, log_template_rendering - -# Setup Jinja2 environment -# Get the directory of this file (app/core/memory/utils/prompt/) -current_dir = os.path.dirname(os.path.abspath(__file__)) -prompt_dir = os.path.join(current_dir, "prompts") -prompt_env = Environment(loader=FileSystemLoader(prompt_dir)) - -async def get_prompts(message: str) -> list[dict]: - """ - Renders system and user prompts using Jinja2 templates. - """ - system_template = prompt_env.get_template("system.jinja2") - user_template = prompt_env.get_template("user.jinja2") - - system_prompt = system_template.render() - user_prompt = user_template.render(message=message) - - # 记录渲染结果到提示日志(与示例日志结构一致) - log_prompt_rendering('system', system_prompt) - log_prompt_rendering('user', user_prompt) - # 可选:记录模板渲染信息(仅当 prompt_templates.log 存在时生效) - log_template_rendering('system.jinja2', {}) - log_template_rendering('user.jinja2', {'message': message}) - return [ - {"role": "system", "content": system_prompt}, - {"role": "user", "content": user_prompt}, - ] - -async def render_statement_extraction_prompt( - chunk_content: str, - definitions: dict, - json_schema: dict, - granularity: int | None = None, - include_dialogue_context: bool = False, - dialogue_content: str | None = None, - max_dialogue_chars: int | None = None, -) -> str: - """ - Renders the statement extraction prompt using the extract_statement.jinja2 template. - - Args: - chunk_content: The content of the chunk to process - definitions: Label definitions for statement classification - json_schema: JSON schema for the expected output format - - Returns: - Rendered prompt content as string - """ - template = prompt_env.get_template("extract_statement.jinja2") - # Optional clipping of dialogue context - ctx = None - if include_dialogue_context and dialogue_content: - try: - if isinstance(max_dialogue_chars, int) and max_dialogue_chars > 0: - ctx = dialogue_content[:max_dialogue_chars] - else: - ctx = dialogue_content - except Exception: - ctx = dialogue_content - - rendered_prompt = template.render( - inputs={"chunk": chunk_content}, - definitions=definitions, - json_schema=json_schema, - granularity=granularity, - include_dialogue_context=include_dialogue_context, - dialogue_context=ctx, - ) - # 记录渲染结果到提示日志(与示例日志结构一致) - log_prompt_rendering('statement extraction', rendered_prompt) - # 可选:记录模板渲染信息 - log_template_rendering('extract_statement.jinja2', { - 'inputs': 'chunk', - 'definitions': 'LABEL_DEFINITIONS', - 'json_schema': 'StatementExtractionResponse.schema', - 'granularity': 'int|None', - 'include_dialogue_context': include_dialogue_context, - 'dialogue_context_len': (len(ctx) if isinstance(ctx, str) else 0), - }) - - return rendered_prompt - -async def render_temporal_extraction_prompt( - ref_dates: dict, - statement: dict, - temporal_guide: dict, - statement_guide: dict, - json_schema: dict, -) -> str: - """ - Renders the temporal extraction prompt using the extract_temporal.jinja2 template. - - Args: - ref_dates: Reference dates for context. - statement: The statement to process. - temporal_guide: Guidance on temporal types. - statement_guide: Guidance on statement types. - json_schema: JSON schema for the expected output format. - - Returns: - Rendered prompt content as a string. - """ - template = prompt_env.get_template("extract_temporal.jinja2") - inputs = ref_dates | statement - rendered_prompt = template.render( - inputs=inputs, - temporal_guide=temporal_guide, - statement_guide=statement_guide, - json_schema=json_schema, - ) - # 记录渲染结果到提示日志(与示例日志结构一致) - log_prompt_rendering('temporal extraction', rendered_prompt) - # 可选:记录模板渲染信息 - log_template_rendering('extract_temporal.jinja2', { - 'inputs': 'ref_dates|statement', - 'temporal_guide': 'dict', - 'statement_guide': 'dict', - 'json_schema': 'Temporal.schema' - }) - - return rendered_prompt - -def render_entity_dedup_prompt( - entity_a: dict, - entity_b: dict, - context: dict, - json_schema: dict, - disambiguation_mode: bool = False, -) -> str: - """ - Render the entity deduplication prompt using the entity_dedup.jinja2 template. - - Args: - entity_a: Dict of entity A attributes - entity_b: Dict of entity B attributes - context: Dict of computed signals (group/type gate, similarities, co-occurrence, relation statements) - json_schema: JSON schema for the structured output (EntityDedupDecision) - - Returns: - Rendered prompt content as string - """ - template = prompt_env.get_template("entity_dedup.jinja2") - rendered_prompt = template.render( - entity_a=entity_a, - entity_b=entity_b, - same_group=context.get("same_group", False), - type_ok=context.get("type_ok", False), - type_similarity=context.get("type_similarity", 0.0), - name_text_sim=context.get("name_text_sim", 0.0), - name_embed_sim=context.get("name_embed_sim", 0.0), - name_contains=context.get("name_contains", False), - co_occurrence=context.get("co_occurrence", False), - relation_statements=context.get("relation_statements", []), - json_schema=json_schema, - disambiguation_mode=disambiguation_mode, - ) - - # prompt_logger.info("\n=== RENDERED ENTITY DEDUP PROMPT ===") - # prompt_logger.info(rendered_prompt) - # prompt_logger.info("\n" + "="*50 + "\n") - - return rendered_prompt - - -# async def render_entity_dedup_prompt( -# entity_a: dict, -# entity_b: dict, -# context: dict, -# json_schema: dict, -# ) -> str: -# """ -# Render the entity deduplication prompt using the entity_dedup.jinja2 template. - -# Args: -# entity_a: Dict of entity A attributes -async def render_triplet_extraction_prompt(statement: str, chunk_content: str, json_schema: dict, predicate_instructions: dict = None) -> str: - """ - Renders the triplet extraction prompt using the extract_triplet.jinja2 template. - - Args: - statement: Statement text to process - chunk_content: The content of the chunk to process - json_schema: JSON schema for the expected output format - predicate_instructions: Optional predicate instructions - - Returns: - Rendered prompt content as string - """ - template = prompt_env.get_template("extract_triplet.jinja2") - rendered_prompt = template.render( - statement=statement, - chunk_content=chunk_content, - json_schema=json_schema, - predicate_instructions=predicate_instructions - ) - # 记录渲染结果到提示日志(与示例日志结构一致) - log_prompt_rendering('triplet extraction', rendered_prompt) - # 可选:记录模板渲染信息 - log_template_rendering('extract_triplet.jinja2', { - 'statement': 'str', - 'chunk_content': 'str', - 'json_schema': 'TripletExtractionResponse.schema', - 'predicate_instructions': 'PREDICATE_DEFINITIONS' - }) - - return rendered_prompt - -async def render_memory_summary_prompt( - chunk_texts: str, - json_schema: dict, - max_words: int = 200, -) -> str: - """ - Renders the memory summary prompt using the memory_summary.jinja2 template. - - Args: - chunk_texts: Concatenated text of conversation chunks - json_schema: JSON schema for the expected output format - max_words: Maximum words for the summary - - Returns: - Rendered prompt content as string. - """ - template = prompt_env.get_template("memory_summary.jinja2") - rendered_prompt = template.render( - chunk_texts=chunk_texts, - json_schema=json_schema, - max_words=max_words, - ) - log_prompt_rendering('memory summary', rendered_prompt) - log_template_rendering('memory_summary.jinja2', { - 'chunk_texts_len': len(chunk_texts or ""), - 'max_words': max_words, - 'json_schema': 'MemorySummaryResponse.schema' - }) - return rendered_prompt diff --git a/app/core/memory/utils/prompt/prompts/entity_dedup.jinja2 b/app/core/memory/utils/prompt/prompts/entity_dedup.jinja2 deleted file mode 100644 index b83e7b92..00000000 --- a/app/core/memory/utils/prompt/prompts/entity_dedup.jinja2 +++ /dev/null @@ -1,60 +0,0 @@ -===任务=== -你是一个实体去重/消歧判断助手。你将被提供两个实体的详细信息和上下文,请严格根据指引判断它们是否是同一真实世界实体,并在需要时进行类型消歧。 - -模式: {{ '消歧模式' if disambiguation_mode else '去重模式' }} - -===输入=== -实体A: -- 名称: "{{ entity_a.name | default('') }}" -- 类型: "{{ entity_a.entity_type | default('') }}" -- 描述: "{{ entity_a.description | default('') }}" -- 别名: {{ entity_a.aliases | default([]) }} -- 摘要: "{{ entity_a.fact_summary | default('') }}" -- 连接强弱: "{{ entity_a.connect_strength | default('') }}" - -实体B: -- 名称: "{{ entity_b.name | default('') }}" -- 类型: "{{ entity_b.entity_type | default('') }}" -- 描述: "{{ entity_b.description | default('') }}" -- 别名: {{ entity_b.aliases | default([]) }} -- 摘要: "{{ entity_b.fact_summary | default('') }}" -- 连接强弱: "{{ entity_b.connect_strength | default('') }}" - -上下文: -- 同组: {{ same_group | default(false) }} -- 类型一致或未知类型: {{ type_ok | default(false) }} -- 类型相似度(0-1): {{ type_similarity | default(0.0) }} -- 名称文本相似度(0-1): {{ name_text_sim | default(0.0) }} -- 名称向量相似度(0-1): {{ name_embed_sim | default(0.0) }} -- 名称包含关系: {{ name_contains | default(false) }} -- 上下文同源(同一语句指向两者): {{ co_occurrence | default(false) }} -- 两者相关的关系陈述(来自实体-实体边): -{% for s in relation_statements %} - - {{ s }} -{% endfor %} - -===判定指引=== -{% if disambiguation_mode %} -- 这是“同名但类型不同”的消歧场景。请判断两者是否指向同一真实世界实体。 -- 综合名称文本/向量相似度、别名、描述、摘要与上下文关系(同源与关系陈述)进行判断。 -- 若无法充分确定,应保守处理:不合并,并建议阻断该对在其他模糊/启发式合并中出现(block_pair=true)。 -- 若需要合并(should_merge=true),请选择“规范实体”(canonical_idx)并在可能的情况下给出建议统一类型(suggested_type),建议类型需与上下文一致。 -- 规范实体优先级:连接强度(strong/both)更高者;其余相同则保留描述/摘要更丰富者;再相同时保留实体A(canonical_idx=0)。 -{% else %} -- 若实体类型相同或任一为UNKNOWN/空,可放行作为候选;若类型明显冲突(如人 vs 物品),除非别名与描述高度一致,否则判定不同实体。 -- 综合名称文本/向量相似度、别名、描述、摘要以及上下文关系判断是否为同一实体。 -- 当上下文同源或存在明确的关系陈述支持同一性(例如同一对象反复被提及或别名对应),可以适度降低判定阈值。 -- 保守决策:当无法充分确定,不要合并(same_entity=false)。 -- 若需要合并,选择“保留的规范实体”(canonical_idx)为更合适的一个: - - 优先保留连接强度更强(strong/both)者;其余相同则保留描述/摘要更丰富者;再相同时保留实体A(canonical_idx=0)。 -{% endif %} - -**Output format** -**CRITICAL JSON FORMATTING REQUIREMENTS:** -1. Use only standard ASCII double quotes (") for JSON structure - never use Chinese quotation marks ("") or other Unicode quotes -2. If the extracted statement text contains quotation marks, escape them properly using backslashes (\") -3. Ensure all JSON strings are properly closed and comma-separated -4. Do not include line breaks within JSON string values - -The output language should always be the same as the input language. -{{ json_schema }} \ No newline at end of file diff --git a/app/core/memory/utils/prompt/prompts/evaluate.jinja2 b/app/core/memory/utils/prompt/prompts/evaluate.jinja2 deleted file mode 100644 index cb5b917d..00000000 --- a/app/core/memory/utils/prompt/prompts/evaluate.jinja2 +++ /dev/null @@ -1,19 +0,0 @@ -你将收到一组记忆对象:{{ evaluate_data }}。 -任务:多维度判断这些记忆是否与已有记忆存在冲突,并给出冲突的对应记忆。(冗余不算冲突) - -仅输出一个合法 JSON 对象,严格遵循下述结构: -{ - "data": [ ...与输入同结构的记忆对象数组... ], - "conflict": true 或 false, - "conflict_memory": 若冲突为 true,则填写与其冲突的记忆对象;否则为 null -} - -必须遵守: -- 只输出 JSON,不要添加解释或多余文本。 -- 使用标准双引号,必要时对内部引号进行转义。 -- 字段名与结构必须与给定模式一致。 - -模式参考: -[ - {{ json_schema }} -] \ No newline at end of file diff --git a/app/core/memory/utils/prompt/prompts/extracat_Pruning.jinja2 b/app/core/memory/utils/prompt/prompts/extracat_Pruning.jinja2 deleted file mode 100644 index 8cd47c6e..00000000 --- a/app/core/memory/utils/prompt/prompts/extracat_Pruning.jinja2 +++ /dev/null @@ -1,49 +0,0 @@ -{# - 对话级抽取与相关性判定模板(用于剪枝加速) - 输入:pruning_scene, dialog_text - 输出:严格 JSON(不要包含任何多余文本),字段: - - is_related: bool,是否与所选场景相关 - - times: [string],从对话中抽取的时间相关文本(日期、时间、时间段、有效期等) - - ids: [string],编号/ID/订单号/申请号/账号等 - - amounts: [string],金额/费用/价格相关(带单位或货币符号) - - contacts: [string],联系方式(电话/手机号/邮箱/微信/QQ等) - - addresses: [string],地址/地点相关文本 - - keywords: [string],其它有助于保留的重要关键词(与场景强相关的术语) - - 要求: - - 必须只输出上述 JSON,且键名一致;不得输出解释、前后缀;不得包含注释。 - - times/ids/amounts/contacts/addresses/keywords 仅抽取原文片段或规范化后的简单字符串。 - - 仅输出上述键;避免多余解释或字段。 -#} - -{% set scene_instructions = { - 'education': '教育场景:教学、课程、考试、作业、老师/学生互动、学习资源、学校管理等。', - 'online_service': '在线客服场景:客户咨询、问题排查、服务工单、售后支持、订单/退款、工单升级等。', - 'outbound': '外呼场景:电话外呼、邀约、调研问卷、线索跟进、对话脚本、回访记录等。' -} %} - -{% set scene_key = pruning_scene %} -{% if scene_key not in scene_instructions %} -{% set scene_key = 'education' %} -{% endif %} - -{% set instruction = scene_instructions[scene_key] %} - -请在下方对话全文基础上,按该场景进行一次性抽取并判定相关性: -场景说明:{{ instruction }} - -对话全文: -""" -{{ dialog_text }} -""" - -只输出严格 JSON(键固定、顺序不限): -{ - "is_related": , - "times": [...], - "ids": [...], - "amounts": [...], - "contacts": [...], - "addresses": [...], - "keywords": [...] -} \ No newline at end of file diff --git a/app/core/memory/utils/prompt/prompts/extract_statement.jinja2 b/app/core/memory/utils/prompt/prompts/extract_statement.jinja2 deleted file mode 100644 index be1e1917..00000000 --- a/app/core/memory/utils/prompt/prompts/extract_statement.jinja2 +++ /dev/null @@ -1,207 +0,0 @@ -{% macro tidy(name) -%} - {{ name.replace('_', ' ')}} -{%- endmacro %} - - -===Tasks=== - -Your task is to identify and extract declarative statements from the provided conversational chunk based on the detailed extraction guidelines. -Each statement must be labeled as per the criteria mentioned below. - -===Inputs=== -{% if inputs %} -{% for key, val in inputs.items() %} -- {{ key }}: {{val}} -{% endfor %} -{% endif %} - - -===Extraction Instructions=== -{% if granularity %} -{% if granularity == 3 %} -Atomic & Clear: Structure statements to clearly show a single subject-predicate-object relationship. It is better to have multiple smaller statements than one complex one. -Context-Independent: Statements must be understandable without needing to read the entire conversation. -{% elif granularity == 2 %} -Extract statements at the sentence level. Each statement should correspond to a single, complete thought (typically a full sentence from the source) but be rephrased for maximum clarity, removing conversational filler (e.g., 'um,' 'like,' interjections). -{% elif granularity == 1 %} -Extract only essence sentences and summarize the chunk into multiple, standalone statements, each focusing on factual statements, user preferences, relationships, and salient temporal context. -{% endif %} -{% endif %} - -Context Resolution Requirements: -- Resolve demonstrative pronouns ("that," "this," "those","这个", "那个") to their specific referents -- If a statement contains vague references that cannot be resolved from the conversation context, either: - a) Expand the statement to include the missing context from earlier in the conversation - b) Mark the statement as requiring additional context - c) Skip extraction if the statement becomes meaningless without context - -Conversational Context & Co-reference Resolution: -- Attribute every statement to the participant who uttered it. -- If the participant list provides a name for a speaker (e.g., "李雪 (用户)"), use the specific name ("李雪") in the extracted statement, not the generic role ("用户"). -- Resolve all pronouns to the specific person or entity from the conversation's context. -- Identify and resolve abstract references to their specific names if mentioned. -- Expand abbreviations and acronyms to their full form. - -{% if include_dialogue_context %} -===Full Dialogue Context=== -The following is the complete dialogue context to help you understand references, pronouns, and conversational flow: - -{{ dialogue_context }} - -===End of Dialogue Context=== -{% endif %} - -Filtering and Formatting: - -- Extract only declarative statements. - DO NOT extract questions, commands, greetings, or conversational filler. -Temporal Precision: - -Include any explicit dates, times, or quantitative qualifiers. -If a sentence describes both the start of an event (static) and its ongoing nature (dynamic), extract both as separate statements. - -{%- if definitions %} - {%- for section_key, section_dict in definitions.items() %} -==== {{ tidy(section_key) | upper }} DEFINITIONS & GUIDANCE ==== - {%- for category, details in section_dict.items() %} -{{ loop.index }}. {{ category }} -- Definition: {{ details.get("definition", "") }} - {% endfor -%} - {% endfor -%} -{% endif -%} - -===Examples=== -Example 1: English Conversation -Example Chunk: """ -Date: March 15, 2024 -Participants: -- Sarah Chen (User) -- Assistant (AI) - -User: "I've been trying watercolor painting recently and painted some flowers." -AI: "Watercolor painting is very interesting! Watercolor paints are typically made from pigments mixed with binders like gum arabic. How do you like it?" -User: "I think the color combinations could use some improvement, but I really like roses and lilies." -""" - -Example Output: { - "statements": [ - { - "statement": "Sarah Chen has been trying watercolor painting recently.", - "statement_type": "FACT", - "temporal_type": "DYNAMIC", - "relevance": "RELEVANT" - }, - { - "statement": "Sarah Chen painted some flowers.", - "statement_type": "FACT", - "temporal_type": "DYNAMIC", - "relevance": "RELEVANT" - }, - { - "statement": "Watercolor paints are typically made from pigments mixed with binders like gum arabic.", - "statement_type": "FACT", - "temporal_type": "ATEMPORAL", - "relevance": "IRRELEVANT" - }, - { - "statement": "Sarah Chen thinks the color combinations in her watercolor paintings could use some improvement.", - "statement_type": "OPINION", - "temporal_type": "STATIC", - "relevance": "RELEVANT" - }, - { - "statement": "Sarah Chen really likes roses and lilies.", - "statement_type": "FACT", - "temporal_type": "STATIC", - "relevance": "RELEVANT" - } - ] -} - -Example 2: Chinese Conversation (中文对话示例) -Example Chunk: """ -日期: 2024年3月15日 -参与者: -- 张曼婷 (用户) -- 小助手 (AI助手) - -用户: "我最近在尝试水彩画,画了一些花朵。" -AI: "水彩画很有趣!水彩颜料通常由颜料和阿拉伯树胶等粘合剂混合而成。你觉得怎么样?" -用户: "我觉得色彩搭配还有提升的空间,不过我很喜欢玫瑰和百合这两种花。" -""" - -Example Output: { - "statements": [ - { - "statement": "张曼婷最近在尝试水彩画。", - "statement_type": "FACT", - "temporal_type": "DYNAMIC", - "relevance": "RELEVANT" - }, - { - "statement": "张曼婷画了一些花朵。", - "statement_type": "FACT", - "temporal_type": "DYNAMIC", - "relevance": "RELEVANT" - }, - { - "statement": "水彩颜料通常由颜料和阿拉伯树胶等粘合剂混合而成。", - "statement_type": "FACT", - "temporal_type": "ATEMPORAL", - "relevance": "IRRELEVANT" - }, - { - "statement": "张曼婷觉得水彩画的色彩搭配还有提升的空间。", - "statement_type": "OPINION", - "temporal_type": "STATIC", - "relevance": "RELEVANT" - }, - { - "statement": "张曼婷很喜欢玫瑰和百合。", - "statement_type": "FACT", - "temporal_type": "STATIC", - "relevance": "RELEVANT" - } - ] -} -===End of Examples=== - -===Reflection Process=== - -After extracting statements, perform the following self-review steps: - -**Step 1: Attribution Check** -- Confirm every statement is properly attributed to the correct speaker -- Verify speaker names are used consistently throughout -- Check that AI assistant statements are properly attributed - -**Step 2: Completeness Review** -- Ensure no important declarative statements were missed -- Check that temporal information is preserved - -**Step 3: Classification Validation** -- Review statement_type classifications (FACT/OPINION/PREDICTION/SUGGESTION) -- Verify temporal_type assignments (STATIC/DYNAMIC/ATEMPORAL) -- Ensure classifications align with the provided definitions - -**Step 4: Final Quality Check** -- Remove any questions, commands, or conversational filler -- Verify JSON format compliance -- Confirm output language matches input language - -**Output format** -**CRITICAL JSON FORMATTING REQUIREMENTS:** -1. Use only standard ASCII double quotes (") for JSON structure - never use Chinese quotation marks ("") or other Unicode quotes -2. If the extracted statement text contains quotation marks, escape them properly using backslashes (\") -3. Ensure all JSON strings are properly closed and comma-separated -4. Do not include line breaks within JSON string values -5. Example of proper escaping: "statement": "John said: \"I really like this book.\"" - -**LANGUAGE REQUIREMENT:** -- The output language should ALWAYS match the input language -- If input is in English, extract statements in English -- If input is in Chinese, extract statements in Chinese -- Preserve the original language and do not translate - -Return only a list of extracted labelled statements in the JSON ARRAY of objects that match the schema below: -{{ json_schema }} \ No newline at end of file diff --git a/app/core/memory/utils/prompt/prompts/extract_temporal.jinja2 b/app/core/memory/utils/prompt/prompts/extract_temporal.jinja2 deleted file mode 100644 index a77d6093..00000000 --- a/app/core/memory/utils/prompt/prompts/extract_temporal.jinja2 +++ /dev/null @@ -1,81 +0,0 @@ - -{% macro tidy(name) -%} - {{ name.replace('_', ' ')}} -{%- endmacro %} -{# - This prompt (template) is adapted from [getzep/graphiti] - Licensed under the Apache License, Version 2.0 - - Original work: - https://github.com/getzep/graphiti/blob/main/graphiti_core/prompts/extract_edge_dates.py - - Modifications made by Ke Sun on 2025-09-01 - See the LICENSE file for the full Apache 2.0 license text. -#} -# Task - -Extract temporal information (dates and time ranges) from the provided statement. Determine when the relationship or event described became valid and when it ended (if applicable). - -# Input Data -{% if inputs %} -{% for key, val in inputs.items() %} -- {{ key }}: {{val}} -{% endfor %} -{% endif %} - -# Temporal Fields - -- **valid_at**: When the relationship/event started or became true (ISO 8601 format) -- **invalid_at**: When the relationship/event ended or stopped being true (ISO 8601 format, or null if ongoing) - -# Extraction Rules - -## Core Principles -1. **Only use explicitly stated temporal information** - do not infer dates from external knowledge -2. **Use the reference/publication date as "now"** when interpreting relative times -3. **Set dates only if they relate to the validity of the relationship** - ignore incidental time mentions -4. **For point-in-time events**, set only `valid_at` - -## Date Format Requirements -- Use ISO 8601: `YYYY-MM-DDTHH:MM:SS.SSSSSSZ` -- If no time specified, use `00:00:00` (midnight) -- If only year mentioned, use `YYYY-01-01` (start) or `YYYY-12-31` (end) as appropriate -- If only month mentioned, use first or last day of month -- Always include timezone (use `Z` for UTC if unspecified) -- Convert relative times ("two weeks ago", "last year") to absolute dates based on reference date - -## Statement Type Rules - -{{ inputs.get("statement_type") | upper }} Statement Guidance: -{%for key, guide in statement_guide.items() %} -- {{ tidy(key) | capitalize }}: {{ guide }} -{% endfor %} - -**Special Cases:** -- **Opinion statements**: Set only `valid_at` (when opinion was expressed) -- **Prediction statements**: Set `invalid_at` to the end of the prediction window if explicitly mentioned - -## Temporal Type Rules - -{{ inputs.get("temporal_type") | upper }} Temporal Type Guidance: -{% for key, guide in temporal_guide.items() %} -- {{ tidy(key) | capitalize }}: {{ guide }} -{% endfor %} - -{% if inputs.get('quarter') and inputs.get('publication_date') %} -## Quarter Reference -Assume {{ inputs.quarter }} ends on {{ inputs.publication_date }}. Calculate dates for any quarter references (Q1, Q2, etc.) from this baseline. -{% endif %} - -# Output Requirements - -## JSON Formatting (CRITICAL) -1. Use **only standard ASCII double quotes** (") - never use Chinese quotes ("") or other Unicode variants -2. Escape internal quotes with backslash: `\"` -3. No line breaks within JSON string values -4. Properly close and comma-separate all fields - -## Language -Output language must match input language. - -{{ json_schema }} diff --git a/app/core/memory/utils/prompt/prompts/extract_triplet.jinja2 b/app/core/memory/utils/prompt/prompts/extract_triplet.jinja2 deleted file mode 100644 index 0bfc5eb7..00000000 --- a/app/core/memory/utils/prompt/prompts/extract_triplet.jinja2 +++ /dev/null @@ -1,248 +0,0 @@ -{% macro tidy(name) -%} - {{ name.replace('_', ' ')}} -{%- endmacro %} - -===Task=== -Extract entities and knowledge triplets from the given statement. - -===Inputs=== -**Chunk Content:** "{{ chunk_content }}" -**Statement:** "{{ statement }}" - -===Guidelines=== - -**Entity Extraction:** -- Extract entities with their types and context-independent descriptions -- Exclude lengthy quotes, calendar dates, temporal ranges, and temporal expressions -- For numeric values: extract as separate entities (instance_of: 'Numeric', name: units, numeric_value: value) - Example: £30 → name: 'GBP', numeric_value: 30, instance_of: 'Numeric' - -**Triplet Extraction:** -- Extract (subject, predicate, object) triplets where: - - Subject: main entity performing the action or being described - - Predicate: relationship between entities (e.g., 'is', 'works at', 'believes') - - Object: entity, value, or concept affected by the predicate -- Exclude all temporal expressions from every field -- Use ONLY the predicates listed in "Predicate Instructions" (uppercase English tokens) -- Do NOT translate predicate tokens -- Do NOT include `statement_id` field (assigned automatically) - -**When NOT to extract triplets:** -- Non-propositional utterances (emotions, fillers, onomatopoeia) -- No clear predicate from the given definitions applies -- Standalone noun phrases or checklist items (e.g., "三脚架", "备用电池") → extract as entities only -- Do NOT invent generic predicates (e.g., "IS_DOING", "FEELS", "MENTIONS") - -**If no valid triplet exists:** Return triplets: [], extract entities if present, otherwise both arrays empty. -{%- if predicate_instructions -%} - -**Predicate Instructions:** -Use ONLY these predicates. If none fits, set triplets to []. -{%- for pred, instruction in predicate_instructions.items() %} -- {{ pred }}: {{ instruction }} -{%- endfor -%} -{%- endif -%} - - -===Examples=== - -**Example 1 (English):** "I plan to travel to Paris next week and visit the Louvre." -Output: -{ - "triplets": [ - { - "subject_name": "I", - "subject_id": 0, - "predicate": "PLANS_TO_VISIT", - "object_name": "Paris", - "object_id": 1, - "value": null - }, - { - "subject_name": "I", - "subject_id": 0, - "predicate": "PLANS_TO_VISIT", - "object_name": "Louvre", - "object_id": 2, - "value": null - } - ], - "entities": [ - { - "entity_idx": 0, - "name": "I", - "type": "Person", - "description": "The user" - }, - { - "entity_idx": 1, - "name": "Paris", - "type": "Location", - "description": "Capital city of France" - }, - { - "entity_idx": 2, - "name": "Louvre", - "type": "Location", - "description": "World-famous museum located in Paris" - } - ] -} - -**Example 2 (English):** "John Smith works at Google and is responsible for AI product development." -Output: -{ - "triplets": [ - { - "subject_name": "John Smith", - "subject_id": 0, - "predicate": "WORKS_AT", - "object_name": "Google", - "object_id": 1, - "value": null - }, - { - "subject_name": "John Smith", - "subject_id": 0, - "predicate": "RESPONSIBLE_FOR", - "object_name": "AI product development", - "object_id": 2, - "value": null - } - ], - "entities": [ - { - "entity_idx": 0, - "name": "John Smith", - "type": "Person", - "description": "Individual person name" - }, - { - "entity_idx": 1, - "name": "Google", - "type": "Organization", - "description": "American technology company" - }, - { - "entity_idx": 2, - "name": "AI product development", - "type": "WorkRole", - "description": "Artificial intelligence product development work" - } - ] -} - -**Example 3 (Chinese):** "我计划下周去巴黎旅行,参观卢浮宫。" -Output: -{ - "triplets": [ - { - "subject_name": "我", - "subject_id": 0, - "predicate": "PLANS_TO_VISIT", - "object_name": "巴黎", - "object_id": 1, - "value": null - }, - { - "subject_name": "我", - "subject_id": 0, - "predicate": "PLANS_TO_VISIT", - "object_name": "卢浮宫", - "object_id": 2, - "value": null - } - ], - "entities": [ - { - "entity_idx": 0, - "name": "我", - "type": "Person", - "description": "用户本人" - }, - { - "entity_idx": 1, - "name": "巴黎", - "type": "Location", - "description": "法国首都城市" - }, - { - "entity_idx": 2, - "name": "卢浮宫", - "type": "Location", - "description": "位于巴黎的世界著名博物馆" - } - ] -} - -**Example 4 (Chinese):** "张明在腾讯工作,负责AI产品开发。" -Output: -{ - "triplets": [ - { - "subject_name": "张明", - "subject_id": 0, - "predicate": "WORKS_AT", - "object_name": "腾讯", - "object_id": 1, - "value": null - }, - { - "subject_name": "张明", - "subject_id": 0, - "predicate": "RESPONSIBLE_FOR", - "object_name": "AI产品开发", - "object_id": 2, - "value": null - } - ], - "entities": [ - { - "entity_idx": 0, - "name": "张明", - "type": "Person", - "description": "个人姓名" - }, - { - "entity_idx": 1, - "name": "腾讯", - "type": "Organization", - "description": "中国科技公司" - }, - { - "entity_idx": 2, - "name": "AI产品开发", - "type": "WorkRole", - "description": "人工智能产品研发工作" - } - ] -} - -**Example 5 (Entity Only):** "Tripod" or "三脚架" -Output: -{ - "triplets": [], - "entities": [ - { - "entity_idx": 0, - "name": "Tripod", - "type": "Equipment", - "description": "Photography equipment accessory" - } - ] -} - -===Output Format=== - -**JSON Requirements:** -- Use only ASCII double quotes (") for JSON structure -- Never use Chinese quotation marks ("") or Unicode quotes -- Escape quotation marks in text with backslashes (\") -- Ensure proper string closure and comma separation -- No line breaks within JSON string values -- The output language should ALWAYS match the input language -- If input is in English, extract statements in English -- If input is in Chinese, extract statements in Chinese -- Preserve the original language and do not translate - -{{ json_schema }} \ No newline at end of file diff --git a/app/core/memory/utils/prompt/prompts/memory_summary.jinja2 b/app/core/memory/utils/prompt/prompts/memory_summary.jinja2 deleted file mode 100644 index 1dd86ca3..00000000 --- a/app/core/memory/utils/prompt/prompts/memory_summary.jinja2 +++ /dev/null @@ -1,29 +0,0 @@ -{% macro tidy(name) -%} - {{ name.replace('_', ' ') }} -{%- endmacro %} - -=== Task === -Summarize the provided conversation chunks into a concise Memory summary. - -=== Requirements === -- Focus on factual statements, user preferences, relationships, and salient temporal context. -- Avoid repetition and filler; be specific. -- Keep it under {{ max_words or 200 }} words. -- Output must be valid JSON conforming to the schema below. - -=== Input === -{% if chunk_texts %} -{{ chunk_texts }} -{% endif %} - -=== Output Schema === -**CRITICAL JSON FORMATTING REQUIREMENTS:** -1. Use only standard ASCII double quotes (") for JSON structure - never use Chinese quotation marks ("") or other Unicode quotes -2. If the extracted statement text contains quotation marks, escape them properly using backslashes (\") -3. Ensure all JSON strings are properly closed and comma-separated -4. Do not include line breaks within JSON string values -5. Example of proper escaping: "statement": "张曼婷说:\"我很喜欢这本书。\"" - -The output language should always be the same as the input language. -Return only a list of extracted labelled statements in the JSON ARRAY of objects that match the schema below: -{{ json_schema }} \ No newline at end of file diff --git a/app/core/memory/utils/prompt/prompts/reflexion.jinja2 b/app/core/memory/utils/prompt/prompts/reflexion.jinja2 deleted file mode 100644 index 3f78b137..00000000 --- a/app/core/memory/utils/prompt/prompts/reflexion.jinja2 +++ /dev/null @@ -1,23 +0,0 @@ -你将收到一条冲突判定对象:{{ data }}。 -任务:分析冲突产生原因,给出解决方案,并生成设为失效后的记忆。 - -仅输出一个合法 JSON 对象,严格遵循下述结构: -{ - "conflict": 与输入同结构,包含 data 与 conflict_memory, - "reflexion": { "reason": string, "solution": string }, - "resolved": { - "original_memory_id": 被设为失效的记忆 id, - "resolved_memory": 完整的设为失效后的记忆对象 - } -} - -必须遵守: -- 只输出 JSON,不要添加解释或多余文本。 -- 使用标准双引号,必要时对内部引号进行转义。 -- 字段名与结构必须与给定模式一致。 -- 当 conflict 为 false 时,resolved 必须为 null。 - - 其中 conflict.data 必须为数组形式,即使只有一个对象也需使用 [ ] 包裹。 -模式参考: -[ - {{ json_schema }} -] diff --git a/app/core/memory/utils/prompt/prompts/system.jinja2 b/app/core/memory/utils/prompt/prompts/system.jinja2 deleted file mode 100644 index 4975876b..00000000 --- a/app/core/memory/utils/prompt/prompts/system.jinja2 +++ /dev/null @@ -1,2 +0,0 @@ -You are an AI assistant that extracts entity nodes from conversational messages. -Your primary task is to extract and classify the speaker and other significant entities mentioned in the conversation. \ No newline at end of file diff --git a/app/core/memory/utils/prompt/prompts/user.jinja2 b/app/core/memory/utils/prompt/prompts/user.jinja2 deleted file mode 100644 index f20146c7..00000000 --- a/app/core/memory/utils/prompt/prompts/user.jinja2 +++ /dev/null @@ -1,5 +0,0 @@ -You are given a conversation context and a CURRENT MESSAGE. -Your task is to extract user name and age mentioned **explicitly or implicitly** in the CURRENT MESSAGE. -Pronoun references such as he/she/they or this/that/those should be disambiguated to the names of the reference entities. - -{{ message }} \ No newline at end of file diff --git a/app/core/memory/utils/prompt/template_render.py b/app/core/memory/utils/prompt/template_render.py deleted file mode 100644 index c783e095..00000000 --- a/app/core/memory/utils/prompt/template_render.py +++ /dev/null @@ -1,42 +0,0 @@ -import os -from jinja2 import Environment, FileSystemLoader -from typing import List, Dict, Any - - -# Setup Jinja2 environment -prompt_dir = os.path.join(os.path.dirname(__file__), "prompts") -prompt_env = Environment(loader=FileSystemLoader(prompt_dir)) - -async def render_evaluate_prompt(evaluate_data: List[Any], schema: Dict[str, Any]) -> str: - """ - Renders the evaluate prompt using the evaluate.jinja2 template. - - Args: - evaluate_data: The data to evaluate - schema: The JSON schema to use for the output. - - Returns: - Rendered prompt content as string - """ - template = prompt_env.get_template("evaluate.jinja2") - - rendered_prompt = template.render(evaluate_data=evaluate_data, json_schema=schema) - - return rendered_prompt - -async def render_reflexion_prompt(data: Dict[str, Any], schema: Dict[str, Any]) -> str: - """ - Renders the reflexion prompt using the extract_temporal.jinja2 template. - - Args: - data: The data to reflex on. - schema: The JSON schema to use for the output. - - Returns: - Rendered prompt content as a string. - """ - template = prompt_env.get_template("reflexion.jinja2") - - rendered_prompt = template.render(data=data, json_schema=schema) - - return rendered_prompt diff --git a/app/core/memory/utils/self_reflexion_utils/__init__.py b/app/core/memory/utils/self_reflexion_utils/__init__.py deleted file mode 100644 index 422a83e3..00000000 --- a/app/core/memory/utils/self_reflexion_utils/__init__.py +++ /dev/null @@ -1,16 +0,0 @@ -# -*- coding: utf-8 -*- -"""自我反思工具模块 - -本模块提供自我反思引擎的核心功能,包括: -- 记忆冲突判定 -- 反思执行 -- 记忆更新 - -从 app.core.memory.src.data_config_api 迁移而来。 -""" - -from app.core.memory.utils.self_reflexion_utils.evaluate import conflict -from app.core.memory.utils.self_reflexion_utils.reflexion import reflexion -from app.core.memory.utils.self_reflexion_utils.self_reflexion import self_reflexion - -__all__ = ["conflict", "reflexion", "self_reflexion"] diff --git a/app/core/memory/utils/self_reflexion_utils/evaluate.py b/app/core/memory/utils/self_reflexion_utils/evaluate.py deleted file mode 100644 index 0ea68461..00000000 --- a/app/core/memory/utils/self_reflexion_utils/evaluate.py +++ /dev/null @@ -1,49 +0,0 @@ -# -*- coding: utf-8 -*- -"""记忆冲突判定模块 - -本模块提供记忆冲突判定功能,使用LLM判断记忆数据中是否存在冲突。 -从 app.core.memory.src.data_config_api.evaluate 迁移而来。 -""" - -import logging -from typing import List, Any -import time - -from app.core.memory.utils.prompt.template_render import render_evaluate_prompt -from app.core.memory.utils.llm.llm_utils import get_llm_client -from app.schemas.memory_storage_schema import ConflictResultSchema -from pydantic import BaseModel - - -async def conflict(evaluate_data: List[Any]) -> List[Any]: - """ - Evaluates memory conflict using the evaluate.jinja2 template. - - Args: - evaluate_data: 反思数据列表。 - Returns: - 冲突记忆列表(JSON 数组)。 - """ - from app.core.memory.utils.config import definitions as config_defs - client = get_llm_client(config_defs.SELECTED_LLM_ID) - rendered_prompt = await render_evaluate_prompt(evaluate_data, ConflictResultSchema) - messages = [{"role": "user", "content": rendered_prompt}] - print(f"提示词长度: {len(rendered_prompt)}") - print(f"====== 冲突判定开始 ======\n") - start_time = time.time() - response = await client.response_structured(messages, ConflictResultSchema) - end_time = time.time() - print(f"冲突判定耗时: {end_time - start_time} 秒") - print(f"冲突判定原始输出:(type={type(response)})\n{response}") - - if not response: - logging.error("LLM 冲突判定输出解析失败,返回空列表以继续流程。") - return [] - try: - return [response.model_dump()] if isinstance(response, BaseModel) else [response] - except Exception: - try: - return [response.dict()] - except Exception: - logging.warning("无法标准化冲突判定返回类型,尝试直接封装为列表。") - return [response] diff --git a/app/core/memory/utils/self_reflexion_utils/reflexion.py b/app/core/memory/utils/self_reflexion_utils/reflexion.py deleted file mode 100644 index 6835b868..00000000 --- a/app/core/memory/utils/self_reflexion_utils/reflexion.py +++ /dev/null @@ -1,51 +0,0 @@ -# -*- coding: utf-8 -*- -"""反思执行模块 - -本模块提供反思执行功能,使用LLM对冲突记忆进行反思和解决。 -从 app.core.memory.src.data_config_api.reflexion 迁移而来。 -""" - -import logging -from typing import List, Any -import time - -from app.core.memory.utils.prompt.template_render import render_reflexion_prompt -from app.core.memory.utils.llm.llm_utils import get_llm_client -from app.schemas.memory_storage_schema import ReflexionResultSchema -from pydantic import BaseModel - - -async def reflexion(ref_data: List[Any]) -> List[Any]: - """ - Reflexes on the given reference data using the reflexion.jinja2 template. - - Args: - ref_data: 反思数据列表。 - Returns: - 反思结果列表(JSON 数组)。 - """ - from app.core.memory.utils.config import definitions as config_defs - client = get_llm_client(config_defs.SELECTED_LLM_ID) - rendered_prompt = await render_reflexion_prompt(ref_data, ReflexionResultSchema) - messages = [{"role": "user", "content": rendered_prompt}] - print(f"提示词长度: {len(rendered_prompt)}") - - print(f"====== 反思开始 ======\n") - start_time = time.time() - response = await client.response_structured(messages, ReflexionResultSchema) - end_time = time.time() - print(f"反思耗时: {end_time - start_time} 秒") - print(f"反思原始输出:(type={type(response)})\n{response}") - - if not response: - logging.error("LLM 反思输出解析失败,返回空列表以继续流程。") - return [] - # 统一返回为列表[dict],便于自我反思主流程更新数据库 - try: - return [response.model_dump()] if isinstance(response, BaseModel) else [response] - except Exception: - try: - return [response.dict()] - except Exception: - logging.warning("无法标准化反思返回类型,尝试直接封装为列表。") - return [response] diff --git a/app/core/memory/utils/self_reflexion_utils/self_reflexion.py b/app/core/memory/utils/self_reflexion_utils/self_reflexion.py deleted file mode 100644 index 5687223d..00000000 --- a/app/core/memory/utils/self_reflexion_utils/self_reflexion.py +++ /dev/null @@ -1,250 +0,0 @@ -# -*- coding: utf-8 -*- -"""自我反思主执行模块 - -本模块提供自我反思引擎的主流程,包括: -- 获取反思数据 -- 冲突判断 -- 反思执行 -- 记忆更新 - -从 app.core.memory.src.data_config_api.self_reflexion 迁移而来。 -""" - -import os -import json -import logging -import asyncio -from typing import List, Dict, Any -import uuid - -from app.core.memory.utils.config.definitions import ( - REFLEXION_ENABLED, - REFLEXION_ITERATION_PERIOD, - REFLEXION_RANGE, - REFLEXION_BASELINE, -) -from app.db import get_db -from sqlalchemy.orm import Session -from app.models.retrieval_info import RetrievalInfo -from app.core.memory.utils.config.get_data import get_data -from app.core.memory.utils.self_reflexion_utils.evaluate import conflict -from app.core.memory.utils.self_reflexion_utils.reflexion import reflexion -from app.repositories.neo4j.cypher_queries import UPDATE_STATEMENT_INVALID_AT -from app.repositories.neo4j.neo4j_connector import Neo4jConnector - - -# 并发限制(可通过环境变量覆盖) -CONCURRENCY = int(os.getenv("REFLEXION_CONCURRENCY", "5")) - -# 确保 INFO 级别日志输出到终端 -_root_logger = logging.getLogger() -if not _root_logger.handlers: - logging.basicConfig(level=logging.INFO, format="%(asctime)s %(levelname)s %(message)s") -else: - _root_logger.setLevel(logging.INFO) - - -async def get_reflexion_data(host_id: uuid.UUID) -> List[Any]: - """ - 根据反思范围获取判断的记忆数据。 - - Args: - host_id: 主机ID - Returns: - 符合反思范围的记忆数据列表。 - """ - if REFLEXION_RANGE == "retrieval": - return await get_data(host_id) - elif REFLEXION_RANGE == "database": - return [] - else: - raise ValueError(f"未知的反思范围: {REFLEXION_RANGE}") - - -async def run_conflict(conflict_data: List[Any]) -> List[Any]: - """ - 判断反思数据中是否存在冲突。 - - Args: - conflict_data: 冲突数据列表。 - Returns: - 如果存在冲突则返回冲突记忆列表,否则返回空列表。 - """ - if not conflict_data: - return [] - - conflict_data = await conflict(conflict_data) - # 仅保留存在冲突的条目(conflict == True) - try: - return [c for c in conflict_data if isinstance(c, dict) and c.get("conflict") is True] - except Exception: - return [] - - -async def run_reflexion(reflexion_data: List[Any]) -> Any: - """ - 执行反思,解决冲突。 - - Args: - reflexion_data: 反思数据列表。 - Returns: - 解决冲突后的反思结果(由 LLM 返回)。 - """ - if not reflexion_data: - return [] - # 并行对每个冲突进行反思,整体缩短等待时间 - sem = asyncio.Semaphore(CONCURRENCY) - - async def _reflex_one(item: Any) -> Dict[str, Any] | None: - async with sem: - try: - result_list = await reflexion([item]) - if not result_list: - return None - obj = result_list[0] - if hasattr(obj, "model_dump"): - return obj.model_dump() - elif hasattr(obj, "dict"): - return obj.dict() - elif isinstance(obj, dict): - return obj - except Exception as e: - logging.warning(f"反思失败,跳过一项: {e}") - return None - - tasks = [_reflex_one(item) for item in reflexion_data] - results = await asyncio.gather(*tasks, return_exceptions=False) - return [r for r in results if r] - - -async def update_memory(solved_data: List[Any], host_id: uuid.UUID) -> str: - """ - 更新记忆库,将解决冲突后的记忆更新到记忆库中。 - - Args: - solved_data: 解决冲突后的记忆(由 LLM 返回)。 - host_id: 主机ID - Returns: - 更新结果(成功或失败)。 - """ - flag = False - if not solved_data: - return "数据缺失,更新失败" - if not isinstance(solved_data, list): - return "数据格式错误,更新失败" - neo4j_connector = Neo4jConnector() - try: - print(f"====== 更新记忆开始 ======\n") - - sem = asyncio.Semaphore(CONCURRENCY) - success_count = 0 - - async def _update_one(item: Dict[str, Any]) -> bool: - async with sem: - try: - if not isinstance(item, dict): - return False - if not item: - return False - resolved = item.get("resolved") - if not isinstance(resolved, dict) or not resolved: - logging.warning(f"反思结果无可更新内容,跳过此项: {item}") - return False - resolved_mem = resolved.get("resolved_memory") - if not isinstance(resolved_mem, dict) or not resolved_mem: - logging.warning(f"反思结果缺少 resolved_memory,跳过此项: {item}") - return False - group_id = resolved_mem.get("group_id") - id = resolved_mem.get("id") - # 使用 invalid_at 字段作为新的失效时间 - new_invalid_at = resolved_mem.get("invalid_at") - if not all([group_id, id, new_invalid_at]): - logging.warning(f"记忆更新参数缺失,跳过此项: {item}") - return False - await neo4j_connector.execute_query( - UPDATE_STATEMENT_INVALID_AT, - group_id=group_id, - id=id, - new_invalid_at=new_invalid_at, - ) - return True - except Exception as e: - logging.error(f"更新单条记忆失败: {e}") - return False - - tasks = [_update_one(item) for item in solved_data if isinstance(item, dict)] - results = await asyncio.gather(*tasks, return_exceptions=False) - success_count = sum(1 for r in results if r) - - logging.info(f"成功更新 {success_count} 条记忆") - flag = success_count > 0 - return "更新成功" if flag else "更新失败" - except Exception as e: - logging.error(f"更新记忆库失败: {e}") - return "更新失败" - finally: - if flag: # 删除数据库中的检索数据 - db: Session = next(get_db()) - try: - db.query(RetrievalInfo).filter(RetrievalInfo.host_id == host_id).delete() - db.commit() - logging.info(f"成功删除 {success_count} 条检索数据") - except Exception as e: - logging.error(f"删除数据库中的检索数据失败: {e}") - - -async def _append_json(label: str, data: Any) -> None: - """记录冲突记忆(后台线程写入,避免阻塞事件循环)""" - def _write(): - with open("reflexion_data.json", "a", encoding="utf-8") as f: - f.write(f"### {label} ###\n") - json.dump(data, f, ensure_ascii=False, indent=4) - f.write("\n\n") - # 正确地在协程内等待后台线程执行,避免未等待的协程警告 - await asyncio.to_thread(_write) - - -async def self_reflexion(host_id: uuid.UUID) -> str: - """ - 自我反思引擎,执行反思流程。 - - Args: - host_id: 主机ID - - Returns: - 反思结果描述字符串 - """ - if not REFLEXION_ENABLED: - return "未开启反思..." - print(f"====== 自我反思流程开始 ======\n") - reflexion_data = await get_reflexion_data(host_id) - if not reflexion_data: - print(f"====== 自我反思流程结束 ======\n") - return "无反思数据,结束反思" - print(f"反思数据获取成功,共 {len(reflexion_data)} 条") - - conflict_data = await run_conflict(reflexion_data) - if not conflict_data: - print(f"====== 自我反思流程结束 ======\n") - return "无冲突,无需反思" - print(f"冲突记忆类型: {type(conflict_data)}") - await _append_json("conflict", conflict_data) - - solved_data = await run_reflexion(conflict_data) - if not solved_data: - print(f"====== 自我反思流程结束 ======\n") - return "反思失败,未解决冲突" - print(f"解决冲突后的记忆类型: {type(solved_data)}") - await _append_json("solved_data", solved_data) - - result = await update_memory(solved_data, host_id) - print(f"更新记忆库结果: {result}") - print(f"====== 自我反思流程结束 ======\n") - return result - - -if __name__ == "__main__": - import asyncio - # host_id = uuid.UUID("3f6ff1eb-50c7-4765-8e89-e4566be33333") - host_id = uuid.UUID("2f6ff1eb-50c7-4765-8e89-e4566be19122") - asyncio.run(self_reflexion(host_id)) diff --git a/app/core/memory/utils/visualization/__init__.py b/app/core/memory/utils/visualization/__init__.py deleted file mode 100644 index 8e3541e9..00000000 --- a/app/core/memory/utils/visualization/__init__.py +++ /dev/null @@ -1,26 +0,0 @@ -""" -可视化模块 - -包含所有可视化相关的工具函数,主要用于遗忘曲线的可视化。 -""" - -# 从子模块导出常用函数,保持向后兼容 -from .forgetting_visualizer import ( - export_memory_curve_numpy, - export_memory_curves_multiple_strengths, - export_parameter_sweep_numpy, - visualize_forgetting_curve, - plot_3d_forgetting_surface, - create_comparison_visualization, - save_memory_curves_to_file, -) - -__all__ = [ - "export_memory_curve_numpy", - "export_memory_curves_multiple_strengths", - "export_parameter_sweep_numpy", - "visualize_forgetting_curve", - "plot_3d_forgetting_surface", - "create_comparison_visualization", - "save_memory_curves_to_file", -] diff --git a/app/core/memory/utils/visualization/forgetting_visualizer.py b/app/core/memory/utils/visualization/forgetting_visualizer.py deleted file mode 100644 index fe82302b..00000000 --- a/app/core/memory/utils/visualization/forgetting_visualizer.py +++ /dev/null @@ -1,386 +0,0 @@ -""" -Memory Visualization Utilities - -This module provides visualization functions for the modified Ebbinghaus forgetting curve -and utilities to export memory curves as numpy arrays. -""" - -import numpy as np -import matplotlib.pyplot as plt -from typing import Optional, Tuple, List, Dict, Any -import math - - -def export_memory_curve_numpy(forgetting_engine, - time_range: Tuple[float, float] = (0, 10), - memory_strength: float = 1.0, - num_points: int = 1000) -> Tuple[np.ndarray, np.ndarray]: - """ - Export memory curve as numpy arrays for time and retention values. - - Args: - forgetting_engine: Instance of ForgettingEngine - time_range: Tuple of (start_time, end_time) - memory_strength: Memory strength value to use - num_points: Number of points to generate - - Returns: - Tuple of (time_array, retention_array) - """ - start_time, end_time = time_range - time_array = np.linspace(start_time, end_time, num_points) - retention_array = np.array([ - forgetting_engine.forgetting_curve(t, memory_strength) - for t in time_array - ]) - - return time_array, retention_array - - -def export_memory_curves_multiple_strengths(forgetting_engine, - time_range: Tuple[float, float] = (0, 10), - memory_strengths: List[float] = None, - num_points: int = 1000) -> Dict[str, np.ndarray]: - """ - Export memory curves for multiple memory strengths as numpy arrays. - - Args: - forgetting_engine: Instance of ForgettingEngine - time_range: Tuple of (start_time, end_time) - memory_strengths: List of memory strength values - num_points: Number of points to generate - - Returns: - Dictionary with 'time' and retention arrays for each strength - """ - if memory_strengths is None: - memory_strengths = [0.5, 1.0, 2.0, 5.0] - - start_time, end_time = time_range - time_array = np.linspace(start_time, end_time, num_points) - - result = {'time': time_array} - - for strength in memory_strengths: - retention_array = np.array([ - forgetting_engine.forgetting_curve(t, strength) - for t in time_array - ]) - result[f'strength_{strength}'] = retention_array - - return result - - -def export_parameter_sweep_numpy(base_engine, - parameter_name: str, - parameter_values: List[float], - time_range: Tuple[float, float] = (0, 10), - memory_strength: float = 1.0, - num_points: int = 1000) -> Dict[str, np.ndarray]: - """ - Export memory curves for parameter sweep as numpy arrays. - - Args: - base_engine: Base ForgettingEngine instance - parameter_name: Name of parameter to sweep ('offset', 'lambda_time', 'lambda_mem') - parameter_values: List of parameter values to test - time_range: Tuple of (start_time, end_time) - memory_strength: Memory strength value to use - num_points: Number of points to generate - - Returns: - Dictionary with 'time' and retention arrays for each parameter value - """ - from app.core.memory.storage_services.forgetting_engine import ForgettingEngine - from app.core.memory.models.variate_config import ForgettingEngineConfig - - start_time, end_time = time_range - time_array = np.linspace(start_time, end_time, num_points) - - result = {'time': time_array} - - for param_value in parameter_values: - # Create new engine with modified parameter - if parameter_name == 'offset': - config = ForgettingEngineConfig(offset=param_value, lambda_time=base_engine.lambda_time, lambda_mem=base_engine.lambda_mem) - elif parameter_name == 'lambda_time': - config = ForgettingEngineConfig(offset=base_engine.offset, lambda_time=param_value, lambda_mem=base_engine.lambda_mem) - elif parameter_name == 'lambda_mem': - config = ForgettingEngineConfig(offset=base_engine.offset, lambda_time=base_engine.lambda_time, lambda_mem=param_value) - else: - raise ValueError(f"Unknown parameter: {parameter_name}") - - engine = ForgettingEngine(config) - - retention_array = np.array([ - engine.forgetting_curve(t, memory_strength) - for t in time_array - ]) - result[f'{parameter_name}_{param_value}'] = retention_array - - return result - - -def visualize_forgetting_curve(forgetting_engine, - max_time: float = 10.0, - memory_strengths: Optional[List[float]] = None, - figsize: Tuple[int, int] = (12, 8)) -> None: - """ - Visualize the modified Ebbinghaus forgetting curve. - - Args: - forgetting_engine: Instance of ForgettingEngine - max_time: Maximum time to plot - memory_strengths: List of memory strength values to plot - figsize: Figure size for the plot - """ - if memory_strengths is None: - memory_strengths = [0.5, 1.0, 2.0, 5.0] - - # Create time array - t = np.linspace(0, max_time, 1000) - - # Create subplots - fig, ((ax1, ax2), (ax3, ax4)) = plt.subplots(2, 2, figsize=figsize) - fig.suptitle('Modified Ebbinghaus Forgetting Curve Analysis', fontsize=16, fontweight='bold') - - # Plot 1: Different memory strengths - ax1.set_title('Effect of Memory Strength (S)') - for S in memory_strengths: - retention = [forgetting_engine.forgetting_curve(time, S) for time in t] - ax1.plot(t, retention, label=f'S = {S}', linewidth=2) - ax1.set_xlabel('Time') - ax1.set_ylabel('Memory Retention') - ax1.legend() - ax1.grid(True, alpha=0.3) - ax1.set_ylim(0, 1) - - # Plot 2: Different lambda_time values - ax2.set_title('Effect of λ_time') - lambda_times = [0.5, 1.0, 0.3] - lambda_mem = [0.5,0.3,1.0] - offset_mem = [0.1,0.05,0.2] - for i in range(len(lambda_times)): - lt = lambda_times[i] - lm = lambda_mem[i] - off = offset_mem[i] - from app.core.memory.storage_services.forgetting_engine import ForgettingEngine - from app.core.memory.models.variate_config import ForgettingEngineConfig - config = ForgettingEngineConfig(offset=off, lambda_time=lt, lambda_mem=lm) - temp_engine = ForgettingEngine(config) - retention = [temp_engine.forgetting_curve(time, 1.0) for time in t] - ax2.plot(t, retention, label=f'λ_time = {lt}', linewidth=2) - ax2.set_xlabel('Time') - ax2.set_ylabel('Memory Retention') - ax2.legend() - ax2.grid(True, alpha=0.3) - ax2.set_ylim(0, 1) - - plt.tight_layout() - plt.show() - - -def plot_3d_forgetting_surface(forgetting_engine, - max_time: float = 10.0, - max_strength: float = 5.0, - figsize: Tuple[int, int] = (12, 9)) -> None: - """ - Create a 3D surface plot of the forgetting curve. - - Args: - forgetting_engine: Instance of ForgettingEngine - max_time: Maximum time to plot - max_strength: Maximum memory strength to plot - figsize: Figure size for the plot - """ - # Create meshgrid - t = np.linspace(0.1, max_time, 50) - S = np.linspace(0.1, max_strength, 50) - T, S_mesh = np.meshgrid(t, S) - - # Calculate retention for each point - R = np.zeros_like(T) - for i in range(T.shape[0]): - for j in range(T.shape[1]): - R[i, j] = forgetting_engine.forgetting_curve(T[i, j], S_mesh[i, j]) - - # Create 3D plot - fig = plt.figure(figsize=figsize) - ax = fig.add_subplot(111, projection='3d') - - surface = ax.plot_surface(T, S_mesh, R, cmap='viridis', alpha=0.8) - - ax.set_xlabel('Time (t)') - ax.set_ylabel('Memory Strength (S)') - ax.set_zlabel('Memory Retention (R)') - ax.set_title(f'3D Forgetting Curve Surface\n(offset={forgetting_engine.offset}, λ_time={forgetting_engine.lambda_time}, λ_mem={forgetting_engine.lambda_mem})') - - # Add colorbar - fig.colorbar(surface, shrink=0.5, aspect=5) - - plt.show() - - -def create_comparison_visualization(forgetting_engine, figsize: Tuple[int, int] = (15, 10)) -> None: - """ - Create a comparison visualization of different curve configurations. - - Args: - forgetting_engine: Instance of ForgettingEngine - figsize: Figure size for the plot - """ - # Create figure with multiple subplots - fig, axes = plt.subplots(2, 2, figsize=figsize) - fig.suptitle('Modified Ebbinghaus Forgetting Curve - Parameter Comparison', fontsize=16, fontweight='bold') - - t = np.linspace(0, 10, 100) - - # Plot 1: Original vs Modified curve - ax1 = axes[0, 0] - ax1.set_title('Original vs Modified Ebbinghaus Curve') - - # Original Ebbinghaus: R = e^(-t/S) - S = 2.0 - original = np.exp(-t / S) - ax1.plot(t, original, 'r--', label='Original: R = e^(-t/S)', linewidth=2) - - # Modified with offset - modified = [forgetting_engine.forgetting_curve(time, S) for time in t] - ax1.plot(t, modified, 'b-', label='Modified: offset + (1-offset)*e^(-λ_time*t/λ_mem*S)', linewidth=2) - - ax1.set_xlabel('Time') - ax1.set_ylabel('Memory Retention') - ax1.legend() - ax1.grid(True, alpha=0.3) - ax1.set_ylim(0, 1) - - # Plot 2: Different offset values - ax2 = axes[0, 1] - ax2.set_title('Effect of Offset Parameter') - - for offset in [0.0, 0.1, 0.2, 0.3]: - from forgetting.forgetting_engine import ForgettingEngine - from app.core.memory.models.variate_config import ForgettingEngineConfig - config = ForgettingEngineConfig(offset=offset, lambda_time=1.0, lambda_mem=1.0) - engine = ForgettingEngine(config) - retention = [engine.forgetting_curve(time, 1.0) for time in t] - ax2.plot(t, retention, label=f'offset = {offset}', linewidth=2) - - ax2.set_xlabel('Time') - ax2.set_ylabel('Memory Retention') - ax2.legend() - ax2.grid(True, alpha=0.3) - ax2.set_ylim(0, 1) - - # Plot 3: Lambda time effect - ax3 = axes[1, 0] - ax3.set_title('Effect of λ_time (Time Sensitivity)') - - for lambda_time in [0.5, 1.0, 2.0, 3.0]: - from forgetting.forgetting_engine import ForgettingEngine - from app.core.memory.models.config_models import ForgettingEngineConfig - config = ForgettingEngineConfig(offset=0.1, lambda_time=lambda_time, lambda_mem=1.0) - engine = ForgettingEngine(config) - retention = [engine.forgetting_curve(time, 1.0) for time in t] - ax3.plot(t, retention, label=f'λ_time = {lambda_time}', linewidth=2) - - ax3.set_xlabel('Time') - ax3.set_ylabel('Memory Retention') - ax3.legend() - ax3.grid(True, alpha=0.3) - ax3.set_ylim(0, 1) - - # Plot 4: Memory strength effect - ax4 = axes[1, 1] - ax4.set_title('Effect of Memory Strength (S)') - - for strength in [0.5, 1.0, 2.0, 4.0]: - retention = [forgetting_engine.forgetting_curve(time, strength) for time in t] - ax4.plot(t, retention, label=f'S = {strength}', linewidth=2) - - ax4.set_xlabel('Time') - ax4.set_ylabel('Memory Retention') - ax4.legend() - ax4.grid(True, alpha=0.3) - ax4.set_ylim(0, 1) - - plt.tight_layout() - plt.show() - - -def save_memory_curves_to_file(forgetting_engine, - filename: str, - time_range: Tuple[float, float] = (0, 10), - memory_strengths: List[float] = None, - num_points: int = 1000, - format: str = 'npz') -> None: - """ - Save memory curves to file in various formats. - - Args: - forgetting_engine: Instance of ForgettingEngine - filename: Output filename (without extension) - time_range: Tuple of (start_time, end_time) - memory_strengths: List of memory strength values - num_points: Number of points to generate - format: Output format ('npz', 'csv', 'json') - """ - if memory_strengths is None: - memory_strengths = [0.5, 1.0, 2.0, 5.0] - - curves_data = export_memory_curves_multiple_strengths( - forgetting_engine, time_range, memory_strengths, num_points - ) - - if format == 'npz': - np.savez(f"{filename}.npz", **curves_data) - elif format == 'csv': - import pandas as pd - df = pd.DataFrame(curves_data) - df.to_csv(f"{filename}.csv", index=False) - elif format == 'json': - import json - # Convert numpy arrays to lists for JSON serialization - json_data = {k: v.tolist() if isinstance(v, np.ndarray) else v - for k, v in curves_data.items()} - with open(f"{filename}.json", 'w') as f: - json.dump(json_data, f, indent=2) - else: - raise ValueError(f"Unsupported format: {format}") - - -if __name__ == "__main__": - # Example usage - from app.core.memory.storage_services.forgetting_engine import ForgettingEngine - - print("Memory Visualization Utilities Demo") - print("=" * 40) - - # Create engine - from app.core.memory.models.variate_config import ForgettingEngineConfig - config = ForgettingEngineConfig(offset=0.1, lambda_time=0.5, lambda_mem=0.5) - engine = ForgettingEngine(config) - - # # Export single curve as numpy - # time_arr, retention_arr = export_memory_curve_numpy(engine, (0, 10), 1.0, 100) - # print(f"Exported single curve: {len(time_arr)} points") - # print(f"Time range: {time_arr[0]:.2f} to {time_arr[-1]:.2f}") - # print(f"Retention range: {retention_arr.min():.4f} to {retention_arr.max():.4f}") - - # # Export multiple curves - # curves = export_memory_curves_multiple_strengths(engine, (0, 10), [0.5, 1.0, 2.0]) - # print(f"\nExported multiple curves: {list(curves.keys())}") - - # # Parameter sweep - # param_sweep = export_parameter_sweep_numpy(engine, 'offset', [0.0, 0.1, 0.2, 0.3]) - # print(f"Parameter sweep results: {list(param_sweep.keys())}") - - # print("\nVisualization functions are ready to use!") - visualize_forgetting_curve(engine) - create_comparison_visualization(engine) - - - - - - diff --git a/app/core/models/__init__.py b/app/core/models/__init__.py deleted file mode 100644 index f54afc08..00000000 --- a/app/core/models/__init__.py +++ /dev/null @@ -1,13 +0,0 @@ -from .base import RedBearModelConfig, get_provider_llm_class, RedBearModelFactory -from .llm import RedBearLLM -from .embedding import RedBearEmbeddings -from .rerank import RedBearRerank - -__all__ = [ - "RedBearModelConfig", - "RedBearLLM", - "RedBearEmbeddings", - "RedBearRerank", - "RedBearModelFactory", - "get_provider_llm_class" -] \ No newline at end of file diff --git a/app/core/models/base.py b/app/core/models/base.py deleted file mode 100644 index e33fd102..00000000 --- a/app/core/models/base.py +++ /dev/null @@ -1,167 +0,0 @@ -from __future__ import annotations -import asyncio, httpx, time, os -from abc import ABC, abstractmethod -from typing import Any, Dict, List, Optional, TypeVar, Callable -from langchain_community.document_compressors import JinaRerank -from pydantic import BaseModel, Field -from langchain_core.runnables import RunnableSerializable -from langchain_core.callbacks import CallbackManagerForLLMRun -from langchain_core.language_models import BaseLLM, BaseLanguageModel -from langchain_core.outputs import LLMResult, Generation -from langchain_core.embeddings import Embeddings -from langchain_core.retrievers import BaseRetriever - -from app.models.models_model import ModelProvider, ModelType -from app.core.exceptions import BusinessException -from app.core.error_codes import BizCode - -T = TypeVar("T") - -class RedBearModelConfig(BaseModel): - """模型配置基类""" - model_name: str - provider: str - api_key: str - base_url: Optional[str] = None - # 请求超时时间(秒)- 默认120秒以支持复杂的LLM调用,可通过环境变量 LLM_TIMEOUT 配置 - timeout: float = Field(default_factory=lambda: float(os.getenv("LLM_TIMEOUT", "120.0"))) - # 最大重试次数 - 默认2次以避免过长等待,可通过环境变量 LLM_MAX_RETRIES 配置 - max_retries: int = Field(default_factory=lambda: int(os.getenv("LLM_MAX_RETRIES", "2"))) - concurrency: int = 5 # 并发限流 - extra_params: Dict[str, Any] = {} - -class RedBearModelFactory: - """模型工厂类""" - - @classmethod - def get_model_params(cls, config: RedBearModelConfig) -> Dict[str, Any]: - """根据提供商获取模型参数""" - provider = config.provider.lower() - - # 打印供应商信息用于调试 - from app.core.logging_config import get_business_logger - logger = get_business_logger() - logger.debug(f"获取模型参数 - Provider: {provider}, Model: {config.model_name}") - - if provider in [ModelProvider.OPENAI, ModelProvider.XINFERENCE, ModelProvider.GPUSTACK, ModelProvider.OLLAMA]: - # 使用 httpx.Timeout 对象来设置详细的超时配置 - # 这样可以分别控制连接超时和读取超时 - import httpx - timeout_config = httpx.Timeout( - timeout=config.timeout, # 总超时时间 - connect=60.0, # 连接超时:60秒(足够建立 TCP 连接) - read=config.timeout, # 读取超时:使用配置的超时时间 - write=60.0, # 写入超时:60秒 - pool=10.0, # 连接池超时:10秒 - ) - return { - "model": config.model_name, - "base_url": config.base_url, - "api_key": config.api_key, - "timeout": timeout_config, - "max_retries": config.max_retries, - **config.extra_params - } - elif provider == ModelProvider.DASHSCOPE: - # DashScope (通义千问) 使用自己的参数格式 - # 注意: DashScopeEmbeddings 不支持 timeout 和 base_url 参数 - # 只支持: model, dashscope_api_key, max_retries, client - return { - "model": config.model_name, - "dashscope_api_key": config.api_key, - "max_retries": config.max_retries, - **config.extra_params - } - elif provider == ModelProvider.BEDROCK: - # Bedrock 使用 AWS 凭证 - # api_key 格式: "access_key_id:secret_access_key" 或只是 access_key_id - # region 从 base_url 或 extra_params 获取 - params = { - "model_id": config.model_name, - **config.extra_params - } - - # 解析 API key (格式: access_key_id:secret_access_key) - if config.api_key and ":" in config.api_key: - access_key_id, secret_access_key = config.api_key.split(":", 1) - params["aws_access_key_id"] = access_key_id - params["aws_secret_access_key"] = secret_access_key - elif config.api_key: - params["aws_access_key_id"] = config.api_key - - # 设置 region - if config.base_url: - params["region_name"] = config.base_url - elif "region_name" not in params: - params["region_name"] = "us-east-1" # 默认区域 - - return params - else: - raise BusinessException(f"不支持的提供商: {provider}", code=BizCode.PROVIDER_NOT_SUPPORTED) - - @classmethod - def get_rerank_model_params(cls, config: RedBearModelConfig) -> Dict[str, Any]: - """根据提供商获取模型参数""" - provider = config.provider.lower() - if provider in [ModelProvider.XINFERENCE, ModelProvider.GPUSTACK]: - return { - "model": config.model_name, - # "base_url": config.base_url, - "jina_api_key": config.api_key, - **config.extra_params - } - else: - raise BusinessException(f"不支持的提供商: {provider}", code=BizCode.PROVIDER_NOT_SUPPORTED) - -def get_provider_llm_class(config:RedBearModelConfig, type: ModelType=ModelType.LLM) -> type[BaseLLM]: - """根据模型提供商获取对应的模型类""" - provider = config.provider.lower() - if provider in [ModelProvider.OPENAI, ModelProvider.XINFERENCE, ModelProvider.GPUSTACK] : - if type == ModelType.LLM: - from langchain_openai import OpenAI - return OpenAI - elif type == ModelType.CHAT: - from langchain_openai import ChatOpenAI - return ChatOpenAI - elif provider == ModelProvider.DASHSCOPE: - from langchain_community.chat_models import ChatTongyi - return ChatTongyi - elif provider == ModelProvider.OLLAMA: - from langchain_ollama import OllamaLLM - return OllamaLLM - elif provider == ModelProvider.BEDROCK: - from langchain_aws import ChatBedrock, ChatBedrockConverse - - return ChatBedrock - else: - raise BusinessException(f"不支持的模型提供商: {provider}", code=BizCode.PROVIDER_NOT_SUPPORTED) - -def get_provider_embedding_class(provider: str) -> type[Embeddings]: - """根据模型提供商获取对应的模型类""" - provider = provider.lower() - if provider in [ModelProvider.OPENAI, ModelProvider.XINFERENCE, ModelProvider.GPUSTACK] : - from langchain_openai import OpenAIEmbeddings - return OpenAIEmbeddings - elif provider == ModelProvider.DASHSCOPE: - from langchain_community.embeddings import DashScopeEmbeddings - return DashScopeEmbeddings - elif provider == ModelProvider.OLLAMA: - from langchain_ollama import OllamaEmbeddings - return OllamaEmbeddings - elif provider == ModelProvider.BEDROCK: - from langchain_aws import BedrockEmbeddings - return BedrockEmbeddings - else: - raise BusinessException(f"不支持的模型提供商: {provider}", code=BizCode.PROVIDER_NOT_SUPPORTED) - -def get_provider_rerank_class(provider: str): - """根据模型提供商获取对应的模型类""" - provider = provider.lower() - if provider in [ModelProvider.XINFERENCE, ModelProvider.GPUSTACK] : - from langchain_community.document_compressors import JinaRerank - return JinaRerank - # elif provider == ModelProvider.OLLAMA: - # from langchain_ollama import OllamaEmbeddings - # return OllamaEmbeddings - else: - raise BusinessException(f"不支持的模型提供商: {provider}", code=BizCode.PROVIDER_NOT_SUPPORTED) \ No newline at end of file diff --git a/app/core/models/embedding.py b/app/core/models/embedding.py deleted file mode 100644 index 16af2567..00000000 --- a/app/core/models/embedding.py +++ /dev/null @@ -1,23 +0,0 @@ - -from typing import Any, Dict, List, Optional, TypeVar, Callable -from langchain_core.embeddings import Embeddings - -from app.core.models.base import RedBearModelConfig,get_provider_embedding_class,RedBearModelFactory - -class RedBearEmbeddings(Embeddings): - """Embedding → 完全符合 LangChain Embeddings""" - def __init__(self, config: RedBearModelConfig): - self._model = self._create_model(config) - self._config = config - - def _create_model(self, config: RedBearModelConfig) -> Embeddings: - """根据配置创建模型""" - embedding_class = get_provider_embedding_class(config.provider) - model_params = RedBearModelFactory.get_model_params(config) - return embedding_class(**model_params) - - def embed_documents(self, texts: list[str]) -> list[list[float]]: - return self._model.embed_documents(texts) - - def embed_query(self, text: str) -> List[float]: - return self._model.embed_query(text) diff --git a/app/core/models/factory.py b/app/core/models/factory.py deleted file mode 100644 index 7cd858bf..00000000 --- a/app/core/models/factory.py +++ /dev/null @@ -1,16 +0,0 @@ -# from typing import Optional -# from app.core.model_client import RedBearEmbeddings, RedBearLLM, RedBearRerank, ModelConfig - - -# class RedBearModelFactory: -# @staticmethod -# def llm(model: str, api_key: str, base_url: Optional[str] = None) -> RedBearLLM: -# return RedBearLLM(ModelConfig(model_name=model, api_key=api_key, base_url=base_url)) - -# @staticmethod -# def embeddings(model: str, api_key: str, base_url: Optional[str] = None) -> RedBearEmbeddings: -# return RedBearEmbeddings(ModelConfig(model_name=model, api_key=api_key, base_url=base_url)) - -# @staticmethod -# def reranker(model: str, api_key: str, base_url: Optional[str] = None) -> RedBearRerank: -# return RedBearRerank(ModelConfig(model_name=model, api_key=api_key, base_url=base_url)) diff --git a/app/core/models/llm.py b/app/core/models/llm.py deleted file mode 100644 index 5808d31a..00000000 --- a/app/core/models/llm.py +++ /dev/null @@ -1,133 +0,0 @@ -from __future__ import annotations -from typing import Any, Dict, List, Optional -from langchain_core.callbacks import CallbackManagerForLLMRun, AsyncCallbackManagerForLLMRun -from langchain_core.language_models import BaseLLM -from langchain_core.outputs import LLMResult - -from app.core.models import RedBearModelConfig, RedBearModelFactory, get_provider_llm_class -from app.models.models_model import ModelType - - -class RedBearLLM(BaseLLM): - """ - RedBear LLM 模型包装器 - 完全动态代理实现 - - 这个包装器自动将所有方法调用委托给内部模型, - 同时提供优雅的回退机制和错误处理。 - """ - - def __init__(self, config: RedBearModelConfig, type: ModelType=ModelType.LLM): - self._model = self._create_model(config, type) - self._config = config - - @property - def _llm_type(self) -> str: - """返回LLM类型标识符""" - return self._model._llm_type - - def _generate( - self, - prompts: List[str], - stop: Optional[List[str]] = None, - run_manager: Optional[CallbackManagerForLLMRun] = None, - **kwargs: Any - ) -> LLMResult: - """同步生成文本""" - return self._model._generate(prompts, stop=stop, run_manager=run_manager, **kwargs) - - async def _agenerate( - self, - prompts: List[str], - stop: Optional[List[str]] = None, - run_manager: Optional[AsyncCallbackManagerForLLMRun] = None, - **kwargs: Any - ) -> LLMResult: - """异步生成文本""" - return await self._model._agenerate(prompts, stop=stop, run_manager=run_manager, **kwargs) - - # 关键:覆盖 invoke/ainvoke,直接委托到底层模型,避免 BaseLLM 的字符串化行为 - def invoke(self, input: Any, config: Optional[dict] = None, **kwargs: Any) -> Any: - """直接调用底层模型以支持 ChatPrompt 和消息列表。""" - try: - return self._model.invoke(input, config=config, **kwargs) - except AttributeError as e: - # 只在属性错误时回退(说明底层模型不支持该方法) - if 'invoke' in str(e): - return super().invoke(input, config=config, **kwargs) - # 其他 AttributeError 直接抛出 - raise - except Exception: - # 其他所有异常(包括 ValidationException)直接抛出,不回退 - raise - - async def ainvoke(self, input: Any, config: Optional[dict] = None, **kwargs: Any) -> Any: - """异步直接调用底层模型以支持 ChatPrompt 和消息列表。""" - try: - return await self._model.ainvoke(input, config=config, **kwargs) - except AttributeError as e: - # 只在属性错误时回退(说明底层模型不支持该方法) - if 'ainvoke' in str(e): - return await super().ainvoke(input, config=config, **kwargs) - # 其他 AttributeError 直接抛出 - raise - except Exception: - # 其他所有异常(包括 ValidationException)直接抛出,不回退 - raise - - def __getattr__(self, name): - """ - 动态代理:将所有未定义的属性和方法调用委托给内部模型 - - 这是最优雅的包装器实现方式,完全避免了方法重复定义 - """ - # 处理特殊属性以避免递归 - if name in ('__isabstractmethod__', '__dict__', '__class__'): - raise AttributeError(f"'{type(self).__name__}' object has no attribute '{name}'") - - # 检查内部模型是否有该属性(使用安全的方式避免递归) - try: - # 使用 object.__getattribute__ 来安全地检查内部模型的属性 - attr = object.__getattribute__(self._model, name) - - # 如果是方法,返回一个包装器来处理调用 - if callable(attr): - # 流式方法直接返回,不包装(保持生成器特性) - if name in ('_stream', '_astream', 'stream', 'astream'): - return attr - - # 非流式方法使用包装器处理异常 - def method_wrapper(*args, **kwargs): - return attr(*args, **kwargs) - - # 保持方法的元信息 - method_wrapper.__name__ = name - method_wrapper.__doc__ = getattr(attr, '__doc__', f"Delegated method: {name}") - return method_wrapper - - # 如果是普通属性,直接返回 - return attr - - except AttributeError: - # 内部模型没有该属性,尝试回退实现 - pass - - # 检查是否有回退方法(使用安全的方式避免递归) - fallback_name = f'_fallback_{name}' - try: - fallback_method = object.__getattribute__(self, fallback_name) - return fallback_method - except AttributeError: - # 没有回退方法,抛出适当的错误 - pass - - # 如果都没有,抛出适当的错误 - raise AttributeError(f"'{type(self).__name__}' object has no attribute '{name}'") - - def _create_model(self, config: RedBearModelConfig, type: ModelType) -> BaseLLM: - """创建内部模型实例""" - llm_class = get_provider_llm_class(config, type) - model_params = RedBearModelFactory.get_model_params(config) - return llm_class(**model_params) - - - \ No newline at end of file diff --git a/app/core/models/rerank copy.py b/app/core/models/rerank copy.py deleted file mode 100644 index 927c94fe..00000000 --- a/app/core/models/rerank copy.py +++ /dev/null @@ -1,35 +0,0 @@ - -# from typing import Any, Dict, List, Optional -# from langchain_core.runnables import RunnableSerializable - -# from app.core.models.base import RedBearModelConfig - -# class RedBearRerank(RunnableSerializable[str, List[float]]): -# """ Rerank → 作为 Runnable 插入任意 LCEL 链""" -# def __init__(self, config: RedBearModelConfig): -# super().__init__(self, config) - -# def invoke(self, input: Dict[str, Any], config: Optional[Dict] = None) -> List[float]: -# query, docs = input["query"], input["documents"] -# url = (self.config.base_url or "https://api.cohere.ai/v1") + "/rerank" -# body = { -# "query": query, -# "documents": docs, -# "model": self.config.model_name, -# "top_n": len(docs), -# } -# js = self._sync_post(url, body) -# scores = [0.0] * len(docs) -# for item in js["results"]: -# scores[item["index"]] = item["relevance_score"] -# return scores - -# async def ainvoke(self, input: Dict[str, Any], config: Optional[Dict] = None) -> List[float]: -# query, docs = input["query"], input["documents"] -# url = (self.config.base_url or "https://api.cohere.ai/v1") + "/rerank" -# body = {"query": query, "documents": docs, "model": self.config.model_name, "top_n": len(docs)} -# js = await self._async_post(url, body) -# scores = [0.0] * len(docs) -# for item in js["results"]: -# scores[item["index"]] = item["relevance_score"] -# return scores \ No newline at end of file diff --git a/app/core/models/rerank.py b/app/core/models/rerank.py deleted file mode 100644 index 64b3b566..00000000 --- a/app/core/models/rerank.py +++ /dev/null @@ -1,80 +0,0 @@ - -from typing import Any, Dict, List, Optional, Sequence, Type, Union -from copy import deepcopy -from urllib.parse import urlparse -from langchain_core.documents import BaseDocumentCompressor, Document -from langchain_core.runnables import RunnableSerializable -from langchain_core.callbacks import Callbacks -from app.core.models.base import RedBearModelConfig, get_provider_rerank_class, RedBearModelFactory -from app.models import ModelProvider - -class RedBearRerank(BaseDocumentCompressor): - """ Rerank → 作为 Runnable 插入任意 LCEL 链""" - def __init__(self, config: RedBearModelConfig): - self._model = self._create_model(config) - self._config = config - - def _create_model(self, config: RedBearModelConfig): - """创建内部模型实例""" - model_class = get_provider_rerank_class(config.provider) - model_params = RedBearModelFactory.get_rerank_model_params(config) - print(model_params) - return model_class(**model_params) - - def compress_documents( - self, - documents: Sequence[Document], - query: str, - callbacks: Optional[Callbacks] = None, - ) -> Sequence[Document]: - """ - Compress documents using Jina's Rerank API. - - Args: - documents: A sequence of documents to compress. - query: The query to use for compressing the documents. - callbacks: Callbacks to run during the compression process. - - Returns: - A sequence of compressed documents. - """ - compressed = [] - for res in self.rerank(documents, query): - doc = documents[res["index"]] - doc_copy = Document(doc.page_content, metadata=deepcopy(doc.metadata)) - doc_copy.metadata["relevance_score"] = res["relevance_score"] - compressed.append(doc_copy) - return compressed - - - def rerank( - self, - documents: Sequence[Union[str, Document, dict]], - query: str, - *, - top_n: Optional[int] = -1, - ) -> List[Dict[str, Any]]: - provider = self._config.provider.lower() - if provider in [ModelProvider.XINFERENCE, ModelProvider.GPUSTACK] : - import langchain_community.document_compressors.jina_rerank as jina_mod - # 规范化:如果不以 /v1/rerank 结尾,则补齐;若已以 /v1 结尾,则补 /rerank - def _normalize_jina_base(base_url: Optional[str]) -> Optional[str]: - if not base_url: - return None - url = base_url.rstrip('/') - if url.endswith("/v1/rerank"): - return url - if url.endswith("/v1"): - return url + "/rerank" - return url + "/v1/rerank" - - jina_base = _normalize_jina_base(self._config.base_url) - if jina_base: - # 设置完整的 rerank 端点,例如 http://host:port/v1/rerank - jina_mod.JINA_API_URL = jina_base - from langchain_community.document_compressors import JinaRerank - model_instance : JinaRerank = self._model - return model_instance.rerank(documents = documents, query = query, top_n=top_n) - else: - raise ValueError(f"不支持的模型提供商: {provider}") - \ No newline at end of file diff --git a/app/core/permissions/__init__.py b/app/core/permissions/__init__.py deleted file mode 100644 index 5294e37c..00000000 --- a/app/core/permissions/__init__.py +++ /dev/null @@ -1,17 +0,0 @@ -""" -Permission management module. - -This module provides a unified permission service for managing access control -across the application. -""" - -from app.core.permissions.models import Action, ResourceType, Resource, Subject -from app.core.permissions.service import permission_service - -__all__ = [ - "Action", - "ResourceType", - "Resource", - "Subject", - "permission_service", -] diff --git a/app/core/permissions/models.py b/app/core/permissions/models.py deleted file mode 100644 index f26a49e3..00000000 --- a/app/core/permissions/models.py +++ /dev/null @@ -1,133 +0,0 @@ -""" -Permission models for access control. - -Defines the core models used in the permission system: -- Action: Types of operations that can be performed -- ResourceType: Types of resources in the system -- Resource: Represents a resource with ownership and tenant information -- Subject: Represents a user/actor performing an action -""" - -from enum import Enum -from typing import Set, Optional -from dataclasses import dataclass, field -from uuid import UUID - - -class Action(Enum): - """Operation types that can be performed on resources.""" - CREATE = "create" - READ = "read" - UPDATE = "update" - DELETE = "delete" - SHARE = "share" - MANAGE = "manage" - ACTIVATE = "activate" - DEACTIVATE = "deactivate" - - -class ResourceType(Enum): - """Types of resources in the system.""" - FILE = "file" - WORKSPACE = "workspace" - KNOWLEDGE = "knowledge" - APP = "app" - USER = "user" - DOCUMENT = "document" - MODEL = "model" - CHUNK = "chunk" - - -@dataclass -class Resource: - """ - Represents a resource in the system. - - Attributes: - type: The type of resource - id: Unique identifier of the resource - owner_id: ID of the user who owns the resource - tenant_id: ID of the tenant the resource belongs to - is_public: Whether the resource is publicly accessible within the tenant - metadata: Additional resource-specific metadata - """ - type: ResourceType - id: UUID - owner_id: UUID - tenant_id: UUID - is_public: bool = False - metadata: dict = field(default_factory=dict) - - @classmethod - def from_file(cls, file_obj) -> "Resource": - """Create a Resource from a GenericFile model instance.""" - return cls( - type=ResourceType.FILE, - id=file_obj.id, - owner_id=file_obj.created_by, - tenant_id=file_obj.tenant_id, - is_public=getattr(file_obj, 'is_public', False), - metadata={ - "file_name": file_obj.file_name, - "context": file_obj.context, - } - ) - - @classmethod - def from_workspace(cls, workspace_obj) -> "Resource": - """Create a Resource from a Workspace model instance.""" - return cls( - type=ResourceType.WORKSPACE, - id=workspace_obj.id, - owner_id=workspace_obj.tenant_id, - tenant_id=workspace_obj.tenant_id, - is_public=False, - metadata={ - "name": workspace_obj.name, - } - ) - - @classmethod - def from_user(cls, user_obj) -> "Resource": - """Create a Resource from a User model instance.""" - return cls( - type=ResourceType.USER, - id=user_obj.id, - owner_id=user_obj.id, # User owns themselves - tenant_id=user_obj.tenant_id, - is_public=False, - metadata={ - "username": user_obj.username, - "is_superuser": user_obj.is_superuser, - } - ) - - -@dataclass -class Subject: - """ - Represents a user/actor performing an action. - - Attributes: - id: User ID - tenant_id: Tenant ID the user belongs to - is_superuser: Whether the user is a superuser - roles: Set of role names the user has - workspace_memberships: Set of workspace IDs the user is a member of - """ - id: UUID - tenant_id: UUID - is_superuser: bool = False - roles: Set[str] = field(default_factory=set) - workspace_memberships: Set[UUID] = field(default_factory=set) - - @classmethod - def from_user(cls, user_obj, workspace_memberships: Optional[Set[UUID]] = None) -> "Subject": - """Create a Subject from a User model instance.""" - return cls( - id=user_obj.id, - tenant_id=user_obj.tenant_id, - is_superuser=user_obj.is_superuser, - roles=set(getattr(user_obj, 'roles', [])), - workspace_memberships=workspace_memberships or set() - ) diff --git a/app/core/permissions/policies.py b/app/core/permissions/policies.py deleted file mode 100644 index 10be9b5e..00000000 --- a/app/core/permissions/policies.py +++ /dev/null @@ -1,151 +0,0 @@ -""" -Permission policies for access control. - -Defines various policy classes that implement different permission rules: -- SuperuserPolicy: Superusers can perform any action -- OwnerPolicy: Resource owners can perform any action on their resources -- TenantPolicy: Users in the same tenant can access public resources -- RoleBasedPolicy: Permission based on user roles -- WorkspaceMemberPolicy: Workspace members can access workspace resources -""" - -from abc import ABC, abstractmethod -from typing import Set -from app.core.permissions.models import Subject, Resource, Action, ResourceType - - -class PermissionPolicy(ABC): - """Base class for permission policies.""" - - @abstractmethod - def can_perform(self, subject: Subject, action: Action, resource: Resource) -> bool: - """ - Determine if a subject can perform an action on a resource. - - Args: - subject: The user/actor attempting the action - action: The action being attempted - resource: The resource being acted upon - - Returns: - True if the action is allowed, False otherwise - """ - pass - - -class SuperuserPolicy(PermissionPolicy): - """Superusers can perform any action on any resource.""" - - def can_perform(self, subject: Subject, action: Action, resource: Resource) -> bool: - return subject.is_superuser - - -class OwnerPolicy(PermissionPolicy): - """Resource owners can perform any action on their own resources.""" - - def can_perform(self, subject: Subject, action: Action, resource: Resource) -> bool: - return subject.id == resource.owner_id - - -class TenantPolicy(PermissionPolicy): - """ - Users in the same tenant can access public resources. - - Args: - allowed_actions: Set of actions allowed on public resources (default: READ only) - """ - - def __init__(self, allowed_actions: Set[Action] = None): - self.allowed_actions = allowed_actions or {Action.READ} - - def can_perform(self, subject: Subject, action: Action, resource: Resource) -> bool: - return ( - subject.tenant_id == resource.tenant_id and - resource.is_public and - action in self.allowed_actions - ) - - -class RoleBasedPolicy(PermissionPolicy): - """ - Permission based on user roles. - - Args: - required_roles: Set of roles that grant permission - allowed_actions: Set of actions these roles can perform - """ - - def __init__(self, required_roles: Set[str], allowed_actions: Set[Action]): - self.required_roles = required_roles - self.allowed_actions = allowed_actions - - def can_perform(self, subject: Subject, action: Action, resource: Resource) -> bool: - has_role = bool(subject.roles & self.required_roles) - return has_role and action in self.allowed_actions - - -class WorkspaceMemberPolicy(PermissionPolicy): - """ - Workspace members can access workspace resources. - - Args: - allowed_actions: Set of actions workspace members can perform - """ - - def __init__(self, allowed_actions: Set[Action] = None): - self.allowed_actions = allowed_actions or {Action.READ, Action.UPDATE} - - def can_perform(self, subject: Subject, action: Action, resource: Resource) -> bool: - if resource.type != ResourceType.WORKSPACE: - return False - - return ( - resource.id in subject.workspace_memberships and - action in self.allowed_actions - ) - - -class SameTenantSuperuserPolicy(PermissionPolicy): - """ - Superusers in the same tenant can perform specific actions. - - This is useful for tenant-scoped admin operations where even superusers - should be limited to their own tenant. - - Args: - allowed_actions: Set of actions allowed (default: all actions) - """ - - def __init__(self, allowed_actions: Set[Action] = None): - self.allowed_actions = allowed_actions or set(Action) - - def can_perform(self, subject: Subject, action: Action, resource: Resource) -> bool: - return ( - subject.is_superuser and - subject.tenant_id == resource.tenant_id and - action in self.allowed_actions - ) - - -class SelfAccessPolicy(PermissionPolicy): - """ - Users can access their own user resource. - - This is specifically for user resources where users should be able - to read/update their own profile. - - Args: - allowed_actions: Set of actions users can perform on themselves - """ - - def __init__(self, allowed_actions: Set[Action] = None): - self.allowed_actions = allowed_actions or {Action.READ, Action.UPDATE} - - def can_perform(self, subject: Subject, action: Action, resource: Resource) -> bool: - if resource.type != ResourceType.USER: - return False - - return ( - subject.id == resource.id and - action in self.allowed_actions - ) diff --git a/app/core/permissions/service.py b/app/core/permissions/service.py deleted file mode 100644 index 7e032866..00000000 --- a/app/core/permissions/service.py +++ /dev/null @@ -1,176 +0,0 @@ -""" -Unified permission service for centralized access control. - -This service provides a single point for all permission checks in the application, -replacing scattered inline permission logic. -""" - -from typing import List, Optional -from app.core.permissions.models import Subject, Resource, Action -from app.core.permissions.policies import ( - PermissionPolicy, - SuperuserPolicy, - OwnerPolicy, - TenantPolicy, - SelfAccessPolicy, -) -from app.core.exceptions import PermissionDeniedException -from app.core.logging_config import get_logger - -logger = get_logger(__name__) - - -class PermissionService: - """ - Centralized permission service. - - Uses a chain of permission policies to determine if an action is allowed. - Any policy in the chain can grant permission (OR logic). - """ - - def __init__(self): - # Default policy chain - order matters for performance - # Most common/permissive policies first - self.policies: List[PermissionPolicy] = [ - SuperuserPolicy(), # Check superuser first (most common bypass) - OwnerPolicy(), # Then check ownership - SelfAccessPolicy(), # Then self-access for user resources - TenantPolicy(), # Finally tenant-level access - ] - - def add_policy(self, policy: PermissionPolicy, position: Optional[int] = None): - """ - Add a permission policy to the chain. - - Args: - policy: The policy to add - position: Optional position in the chain (default: append to end) - """ - if position is not None: - self.policies.insert(position, policy) - else: - self.policies.append(policy) - - def remove_policy(self, policy_class: type): - """ - Remove all policies of a specific class from the chain. - - Args: - policy_class: The class of policies to remove - """ - self.policies = [p for p in self.policies if not isinstance(p, policy_class)] - - def can_perform( - self, - subject: Subject, - action: Action, - resource: Resource - ) -> bool: - """ - Check if a subject can perform an action on a resource. - - Args: - subject: The user/actor attempting the action - action: The action being attempted - resource: The resource being acted upon - - Returns: - True if any policy grants permission, False otherwise - """ - # Policy chain: any policy can grant permission (OR logic) - for policy in self.policies: - try: - if policy.can_perform(subject, action, resource): - logger.debug( - f"permission_granted: policy={policy.__class__.__name__}, " - f"subject_id={subject.id}, action={action.value}, " - f"resource_type={resource.type.value}, resource_id={resource.id}" - ) - return True - except Exception as e: - # Log policy errors but continue checking other policies - logger.error( - f"permission_policy_error: policy={policy.__class__.__name__}, " - f"error={str(e)}, subject_id={subject.id}, action={action.value}, " - f"resource_type={resource.type.value}" - ) - - logger.warning( - f"permission_denied: subject_id={subject.id}, action={action.value}, " - f"resource_type={resource.type.value}, resource_id={resource.id}, " - f"subject_tenant={subject.tenant_id}, resource_tenant={resource.tenant_id}, " - f"is_superuser={subject.is_superuser}" - ) - return False - - def require_permission( - self, - subject: Subject, - action: Action, - resource: Resource, - error_message: Optional[str] = None - ): - """ - Require permission, raising an exception if not granted. - - Args: - subject: The user/actor attempting the action - action: The action being attempted - resource: The resource being acted upon - error_message: Custom error message (optional) - - Raises: - PermissionDeniedException: If permission is not granted - """ - if not self.can_perform(subject, action, resource): - message = error_message or ( - f"无权对 {resource.type.value} 执行 {action.value} 操作" - ) - raise PermissionDeniedException(message) - - def check_superuser(self, subject: Subject, error_message: Optional[str] = None): - """ - Require that the subject is a superuser. - - Args: - subject: The user/actor to check - error_message: Custom error message (optional) - - Raises: - PermissionDeniedException: If subject is not a superuser - """ - if not subject.is_superuser: - message = error_message or "需要超级管理员权限" - logger.warning( - f"superuser_required: subject_id={subject.id}, is_superuser={subject.is_superuser}" - ) - raise PermissionDeniedException(message) - - def check_same_tenant( - self, - subject: Subject, - resource: Resource, - error_message: Optional[str] = None - ): - """ - Require that the subject and resource are in the same tenant. - - Args: - subject: The user/actor to check - resource: The resource to check - error_message: Custom error message (optional) - - Raises: - PermissionDeniedException: If not in the same tenant - """ - if subject.tenant_id != resource.tenant_id: - message = error_message or "无权访问其他租户的资源" - logger.warning( - f"tenant_mismatch: subject_id={subject.id}, " - f"subject_tenant={subject.tenant_id}, resource_tenant={resource.tenant_id}" - ) - raise PermissionDeniedException(message) - - -# Global permission service instance -permission_service = PermissionService() diff --git a/app/core/rag/__init__.py b/app/core/rag/__init__.py deleted file mode 100644 index e69de29b..00000000 diff --git a/app/core/rag/app/__init__.py b/app/core/rag/app/__init__.py deleted file mode 100644 index e69de29b..00000000 diff --git a/app/core/rag/app/audio.py b/app/core/rag/app/audio.py deleted file mode 100644 index 1bddc048..00000000 --- a/app/core/rag/app/audio.py +++ /dev/null @@ -1,42 +0,0 @@ -import os -import re -import tempfile - -from app.core.rag.nlp import rag_tokenizer, tokenize - - -def chunk(filename, binary, lang, callback=None, seq2txt_mdl=None, **kwargs): - doc = {"docnm_kwd": filename, "title_tks": rag_tokenizer.tokenize(re.sub(r"\.[a-zA-Z]+$", "", filename))} - doc["title_sm_tks"] = rag_tokenizer.fine_grained_tokenize(doc["title_tks"]) - - # is it English - eng = lang.lower() == "english" # is_english(sections) - try: - _, ext = os.path.splitext(filename) - if not ext: - raise RuntimeError("No extension detected.") - - if ext not in [".da", ".wave", ".wav", ".mp3", ".aac", ".flac", ".ogg", ".aiff", ".au", ".midi", ".wma", ".realaudio", ".vqf", ".oggvorbis", ".ape"]: - raise RuntimeError(f"Extension {ext} is not supported yet.") - - tmp_path = "" - with tempfile.NamedTemporaryFile(suffix=ext, delete=False) as tmpf: - tmpf.write(binary) - tmpf.flush() - tmp_path = os.path.abspath(tmpf.name) - - callback(0.1, "USE Sequence2Txt LLM to transcription the audio") - ans = seq2txt_mdl.transcription(tmp_path) - callback(0.8, "Sequence2Txt LLM respond: %s ..." % ans[:32]) - - tokenize(doc, ans, eng) - return [doc] - except Exception as e: - callback(prog=-1, msg=str(e)) - finally: - if tmp_path and os.path.exists(tmp_path): - try: - os.unlink(tmp_path) - except Exception: - pass - return [] diff --git a/app/core/rag/app/book.py b/app/core/rag/app/book.py deleted file mode 100644 index baf5af5b..00000000 --- a/app/core/rag/app/book.py +++ /dev/null @@ -1,170 +0,0 @@ -import logging -import re -from io import BytesIO - -from app.core.rag.deepdoc.parser.utils import get_text -from . import naive -from .naive import by_plaintext, PARSERS -from app.core.rag.nlp import bullets_category, is_english,remove_contents_table, \ - hierarchical_merge, make_colon_as_title, naive_merge, random_choices, tokenize_table, \ - tokenize_chunks -from app.core.rag.nlp import rag_tokenizer -from app.core.rag.deepdoc.parser import PdfParser, HtmlParser -from app.core.rag.deepdoc.parser.figure_parser import vision_figure_parser_docx_wrapper -from PIL import Image - - -class Pdf(PdfParser): - def __call__(self, filename, binary=None, from_page=0, - to_page=100000, zoomin=3, callback=None): - from timeit import default_timer as timer - start = timer() - callback(msg="OCR started") - self.__images__( - filename if not binary else binary, - zoomin, - from_page, - to_page, - callback) - callback(msg="OCR finished ({:.2f}s)".format(timer() - start)) - - start = timer() - self._layouts_rec(zoomin) - callback(0.67, "Layout analysis ({:.2f}s)".format(timer() - start)) - logging.debug("layouts: {}".format(timer() - start)) - - start = timer() - self._table_transformer_job(zoomin) - callback(0.68, "Table analysis ({:.2f}s)".format(timer() - start)) - - start = timer() - self._text_merge() - tbls = self._extract_table_figure(True, zoomin, True, True) - self._naive_vertical_merge() - self._filter_forpages() - self._merge_with_same_bullet() - callback(0.8, "Text extraction ({:.2f}s)".format(timer() - start)) - - return [(b["text"] + self._line_tag(b, zoomin), b.get("layoutno", "")) - for b in self.boxes], tbls - - -def chunk(filename, binary=None, from_page=0, to_page=100000, - lang="Chinese", callback=None, **kwargs): - """ - Supported file formats are docx, pdf, txt. - Since a book is long and not all the parts are useful, if it's a PDF, - please setup the page ranges for every book in order eliminate negative effects and save elapsed computing time. - """ - parser_config = kwargs.get( - "parser_config", { - "chunk_token_num": 512, "delimiter": "\n!?。;!?", "layout_recognize": "DeepDOC"}) - doc = { - "docnm_kwd": filename, - "title_tks": rag_tokenizer.tokenize(re.sub(r"\.[a-zA-Z]+$", "", filename)) - } - doc["title_sm_tks"] = rag_tokenizer.fine_grained_tokenize(doc["title_tks"]) - pdf_parser = None - sections, tbls = [], [] - if re.search(r"\.docx$", filename, re.IGNORECASE): - callback(0.1, "Start to parse.") - doc_parser = naive.Docx() - # TODO: table of contents need to be removed - sections, tbls = doc_parser( - filename, binary=binary, from_page=from_page, to_page=to_page) - remove_contents_table(sections, eng=is_english( - random_choices([t for t, _ in sections], k=200))) - tbls=vision_figure_parser_docx_wrapper(sections=sections,tbls=tbls,callback=callback,**kwargs) - # tbls = [((None, lns), None) for lns in tbls] - sections=[(item[0],item[1] if item[1] is not None else "") for item in sections if not isinstance(item[1], Image.Image)] - callback(0.8, "Finish parsing.") - - elif re.search(r"\.pdf$", filename, re.IGNORECASE): - layout_recognizer = parser_config.get("layout_recognize", "DeepDOC") - - if isinstance(layout_recognizer, bool): - layout_recognizer = "DeepDOC" if layout_recognizer else "Plain Text" - - name = layout_recognizer.strip().lower() - parser = PARSERS.get(name, by_plaintext) - callback(0.1, "Start to parse.") - - sections, tables, pdf_parser = parser( - filename = filename, - binary = binary, - from_page = from_page, - to_page = to_page, - lang = lang, - callback = callback, - pdf_cls = Pdf, - **kwargs - ) - - if not sections and not tables: - return [] - - if name in ["tcadp", "docling", "mineru"]: - parser_config["chunk_token_num"] = 0 - - callback(0.8, "Finish parsing.") - elif re.search(r"\.txt$", filename, re.IGNORECASE): - callback(0.1, "Start to parse.") - txt = get_text(filename, binary) - sections = txt.split("\n") - sections = [(line, "") for line in sections if line] - remove_contents_table(sections, eng=is_english( - random_choices([t for t, _ in sections], k=200))) - callback(0.8, "Finish parsing.") - - elif re.search(r"\.(htm|html)$", filename, re.IGNORECASE): - callback(0.1, "Start to parse.") - sections = HtmlParser()(filename, binary) - sections = [(line, "") for line in sections if line] - remove_contents_table(sections, eng=is_english( - random_choices([t for t, _ in sections], k=200))) - callback(0.8, "Finish parsing.") - - elif re.search(r"\.doc$", filename, re.IGNORECASE): - callback(0.1, "Start to parse.") - binary = BytesIO(binary) - doc_parsed = parser.from_buffer(binary) - sections = doc_parsed['content'].split('\n') - sections = [(line, "") for line in sections if line] - remove_contents_table(sections, eng=is_english( - random_choices([t for t, _ in sections], k=200))) - callback(0.8, "Finish parsing.") - - else: - raise NotImplementedError( - "file type not supported yet(doc, docx, pdf, txt supported)") - - make_colon_as_title(sections) - bull = bullets_category( - [t for t in random_choices([t for t, _ in sections], k=100)]) - if bull >= 0: - chunks = ["\n".join(ck) - for ck in hierarchical_merge(bull, sections, 5)] - else: - sections = [s.split("@") for s, _ in sections] - sections = [(pr[0], "@" + pr[1]) if len(pr) == 2 else (pr[0], '') for pr in sections ] - chunks = naive_merge( - sections, kwargs.get( - "chunk_token_num", 256), kwargs.get( - "delimer", "\n。;!?")) - - # is it English - # is_english(random_choices([t for t, _ in sections], k=218)) - eng = lang.lower() == "english" - - res = tokenize_table(tbls, doc, eng) - res.extend(tokenize_chunks(chunks, doc, eng, pdf_parser)) - - return res - - -if __name__ == "__main__": - import sys - - def dummy(prog=None, msg=""): - pass - chunk(sys.argv[1], from_page=1, to_page=10, callback=dummy) diff --git a/app/core/rag/app/laws.py b/app/core/rag/app/laws.py deleted file mode 100644 index 60f3c5b0..00000000 --- a/app/core/rag/app/laws.py +++ /dev/null @@ -1,219 +0,0 @@ -import logging -import re -from io import BytesIO -from docx import Document - -from app.core.rag.common.constants import ParserType -from app.core.rag.deepdoc.parser.utils import get_text -from app.core.rag.nlp import bullets_category, remove_contents_table, \ - make_colon_as_title, tokenize_chunks, docx_question_level, tree_merge -from app.core.rag.nlp import rag_tokenizer, Node -from app.core.rag.deepdoc.parser import PdfParser, DocxParser, HtmlParser -from app.core.rag.app.naive import by_plaintext, PARSERS - - - - -class Docx(DocxParser): - def __init__(self): - pass - - def __clean(self, line): - line = re.sub(r"\u3000", " ", line).strip() - return line - - def old_call(self, filename, binary=None, from_page=0, to_page=100000): - self.doc = Document( - filename) if not binary else Document(BytesIO(binary)) - pn = 0 - lines = [] - for p in self.doc.paragraphs: - if pn > to_page: - break - if from_page <= pn < to_page and p.text.strip(): - lines.append(self.__clean(p.text)) - for run in p.runs: - if 'lastRenderedPageBreak' in run._element.xml: - pn += 1 - continue - if 'w:br' in run._element.xml and 'type="page"' in run._element.xml: - pn += 1 - return [line for line in lines if line] - - def __call__(self, filename, binary=None, from_page=0, to_page=100000): - self.doc = Document( - filename) if not binary else Document(BytesIO(binary)) - pn = 0 - lines = [] - level_set = set() - bull = bullets_category([p.text for p in self.doc.paragraphs]) - for p in self.doc.paragraphs: - if pn > to_page: - break - question_level, p_text = docx_question_level(p, bull) - if not p_text.strip("\n"): - continue - lines.append((question_level, p_text)) - level_set.add(question_level) - for run in p.runs: - if 'lastRenderedPageBreak' in run._element.xml: - pn += 1 - continue - if 'w:br' in run._element.xml and 'type="page"' in run._element.xml: - pn += 1 - - sorted_levels = sorted(level_set) - - h2_level = sorted_levels[1] if len(sorted_levels) > 1 else 1 - h2_level = sorted_levels[-2] if h2_level == sorted_levels[-1] and len(sorted_levels) > 2 else h2_level - - root = Node(level=0, depth=h2_level, texts=[]) - root.build_tree(lines) - - return [element for element in root.get_tree() if element] - - - def __str__(self) -> str: - return f''' - question:{self.question}, - answer:{self.answer}, - level:{self.level}, - childs:{self.childs} - ''' - - -class Pdf(PdfParser): - def __init__(self): - self.model_speciess = ParserType.LAWS.value - super().__init__() - - def __call__(self, filename, binary=None, from_page=0, - to_page=100000, zoomin=3, callback=None): - from timeit import default_timer as timer - start = timer() - callback(msg="OCR started") - self.__images__( - filename if not binary else binary, - zoomin, - from_page, - to_page, - callback - ) - callback(msg="OCR finished ({:.2f}s)".format(timer() - start)) - - start = timer() - self._layouts_rec(zoomin) - callback(0.67, "Layout analysis ({:.2f}s)".format(timer() - start)) - logging.debug("layouts:".format( - )) - self._naive_vertical_merge() - - callback(0.8, "Text extraction ({:.2f}s)".format(timer() - start)) - - return [(b["text"], self._line_tag(b, zoomin)) - for b in self.boxes], None - - -def chunk(filename, binary=None, from_page=0, to_page=100000, - lang="Chinese", callback=None, **kwargs): - """ - Supported file formats are docx, pdf, txt. - """ - parser_config = kwargs.get( - "parser_config", { - "chunk_token_num": 512, "delimiter": "\n!?。;!?", "layout_recognize": "DeepDOC"}) - doc = { - "docnm_kwd": filename, - "title_tks": rag_tokenizer.tokenize(re.sub(r"\.[a-zA-Z]+$", "", filename)) - } - doc["title_sm_tks"] = rag_tokenizer.fine_grained_tokenize(doc["title_tks"]) - pdf_parser = None - sections = [] - # is it English - eng = lang.lower() == "english" # is_english(sections) - - if re.search(r"\.docx$", filename, re.IGNORECASE): - callback(0.1, "Start to parse.") - chunks = Docx()(filename, binary) - callback(0.7, "Finish parsing.") - return tokenize_chunks(chunks, doc, eng, None) - - elif re.search(r"\.pdf$", filename, re.IGNORECASE): - layout_recognizer = parser_config.get("layout_recognize", "DeepDOC") - - if isinstance(layout_recognizer, bool): - layout_recognizer = "DeepDOC" if layout_recognizer else "Plain Text" - - name = layout_recognizer.strip().lower() - parser = PARSERS.get(name, by_plaintext) - callback(0.1, "Start to parse.") - - raw_sections, tables, pdf_parser = parser( - filename = filename, - binary = binary, - from_page = from_page, - to_page = to_page, - lang = lang, - callback = callback, - pdf_cls = Pdf, - **kwargs - ) - - if not raw_sections and not tables: - return [] - - if name in ["tcadp", "docling", "mineru"]: - parser_config["chunk_token_num"] = 0 - - for txt, poss in raw_sections: - sections.append(txt + poss) - - callback(0.8, "Finish parsing.") - elif re.search(r"\.(txt|md|markdown|mdx)$", filename, re.IGNORECASE): - callback(0.1, "Start to parse.") - txt = get_text(filename, binary) - sections = txt.split("\n") - sections = [s for s in sections if s] - callback(0.8, "Finish parsing.") - - elif re.search(r"\.(htm|html)$", filename, re.IGNORECASE): - callback(0.1, "Start to parse.") - sections = HtmlParser()(filename, binary) - sections = [s for s in sections if s] - callback(0.8, "Finish parsing.") - - elif re.search(r"\.doc$", filename, re.IGNORECASE): - callback(0.1, "Start to parse.") - binary = BytesIO(binary) - doc_parsed = parser.from_buffer(binary) - sections = doc_parsed['content'].split('\n') - sections = [s for s in sections if s] - callback(0.8, "Finish parsing.") - - else: - raise NotImplementedError( - "file type not supported yet(doc, docx, pdf, txt supported)") - - - # Remove 'Contents' part - remove_contents_table(sections, eng) - - make_colon_as_title(sections) - bull = bullets_category(sections) - res = tree_merge(bull, sections, 2) - - - if not res: - callback(0.99, "No chunk parsed out.") - - return tokenize_chunks(res, doc, eng, pdf_parser) - - # chunks = hierarchical_merge(bull, sections, 5) - # return tokenize_chunks(["\n".join(ck)for ck in chunks], doc, eng, pdf_parser) - -if __name__ == "__main__": - import sys - - def dummy(prog=None, msg=""): - pass - chunk(sys.argv[1], callback=dummy) diff --git a/app/core/rag/app/mail.py b/app/core/rag/app/mail.py deleted file mode 100644 index c4318402..00000000 --- a/app/core/rag/app/mail.py +++ /dev/null @@ -1,114 +0,0 @@ -import logging -from email import policy -from email.parser import BytesParser -from .naive import chunk as naive_chunk -import re -from app.core.rag.nlp import rag_tokenizer, naive_merge, tokenize_chunks -from app.core.rag.deepdoc.parser import HtmlParser, TxtParser -from timeit import default_timer as timer -import io - - -def chunk( - filename, - binary=None, - from_page=0, - to_page=100000, - lang="Chinese", - callback=None, - **kwargs, -): - """ - Only eml is supported - """ - eng = lang.lower() == "english" # is_english(cks) - parser_config = kwargs.get( - "parser_config", - {"chunk_token_num": 512, "delimiter": "\n!?。;!?", "layout_recognize": "DeepDOC"}, - ) - doc = { - "docnm_kwd": filename, - "title_tks": rag_tokenizer.tokenize(re.sub(r"\.[a-zA-Z]+$", "", filename)), - } - doc["title_sm_tks"] = rag_tokenizer.fine_grained_tokenize(doc["title_tks"]) - main_res = [] - attachment_res = [] - - if binary: - msg = BytesParser(policy=policy.default).parse(io.BytesIO(binary)) - else: - msg = BytesParser(policy=policy.default).parse(open(filename, "rb")) - - text_txt, html_txt = [], [] - # get the email header info - for header, value in msg.items(): - text_txt.append(f"{header}: {value}") - - # get the email main info - def _add_content(msg, content_type): - def _decode_payload(payload, charset, target_list): - try: - target_list.append(payload.decode(charset)) - except (UnicodeDecodeError, LookupError): - for enc in ["utf-8", "gb2312", "gbk", "gb18030", "latin1"]: - try: - target_list.append(payload.decode(enc)) - break - except UnicodeDecodeError: - continue - else: - target_list.append(payload.decode("utf-8", errors="ignore")) - - if content_type == "text/plain": - payload = msg.get_payload(decode=True) - charset = msg.get_content_charset() or "utf-8" - _decode_payload(payload, charset, text_txt) - elif content_type == "text/html": - payload = msg.get_payload(decode=True) - charset = msg.get_content_charset() or "utf-8" - _decode_payload(payload, charset, html_txt) - elif "multipart" in content_type: - if msg.is_multipart(): - for part in msg.iter_parts(): - _add_content(part, part.get_content_type()) - - _add_content(msg, msg.get_content_type()) - - sections = TxtParser.parser_txt("\n".join(text_txt)) + [ - (line, "") for line in HtmlParser.parser_txt("\n".join(html_txt), chunk_token_num=parser_config["chunk_token_num"]) if line - ] - - st = timer() - chunks = naive_merge( - sections, - int(parser_config.get("chunk_token_num", 128)), - parser_config.get("delimiter", "\n!?。;!?"), - ) - - main_res.extend(tokenize_chunks(chunks, doc, eng, None)) - logging.debug("naive_merge({}): {}".format(filename, timer() - st)) - # get the attachment info - for part in msg.iter_attachments(): - content_disposition = part.get("Content-Disposition") - if content_disposition: - dispositions = content_disposition.strip().split(";") - if dispositions[0].lower() == "attachment": - filename = part.get_filename() - payload = part.get_payload(decode=True) - try: - attachment_res.extend( - naive_chunk(filename, payload, callback=callback, **kwargs) - ) - except Exception: - pass - - return main_res + attachment_res - - -if __name__ == "__main__": - import sys - - def dummy(prog=None, msg=""): - pass - - chunk(sys.argv[1], callback=dummy) diff --git a/app/core/rag/app/manual.py b/app/core/rag/app/manual.py deleted file mode 100644 index 050a550c..00000000 --- a/app/core/rag/app/manual.py +++ /dev/null @@ -1,299 +0,0 @@ -import logging -import copy -import re - -from app.core.rag.common.constants import ParserType -from io import BytesIO -from app.core.rag.nlp import rag_tokenizer, tokenize, tokenize_table, bullets_category, title_frequency, tokenize_chunks, docx_question_level -from app.core.rag.common.token_utils import num_tokens_from_string -from app.core.rag.deepdoc.parser import PdfParser, DocxParser -from app.core.rag.deepdoc.parser.figure_parser import vision_figure_parser_pdf_wrapper,vision_figure_parser_docx_wrapper -from docx import Document -from PIL import Image -from .naive import by_plaintext, PARSERS - -class Pdf(PdfParser): - def __init__(self): - self.model_speciess = ParserType.MANUAL.value - super().__init__() - - def __call__(self, filename, binary=None, from_page=0, - to_page=100000, zoomin=3, callback=None): - from timeit import default_timer as timer - start = timer() - callback(msg="OCR started") - self.__images__( - filename if not binary else binary, - zoomin, - from_page, - to_page, - callback - ) - callback(msg="OCR finished ({:.2f}s)".format(timer() - start)) - logging.debug("OCR: {}".format(timer() - start)) - - start = timer() - self._layouts_rec(zoomin) - callback(0.65, "Layout analysis ({:.2f}s)".format(timer() - start)) - logging.debug("layouts: {}".format(timer() - start)) - - start = timer() - self._table_transformer_job(zoomin) - callback(0.67, "Table analysis ({:.2f}s)".format(timer() - start)) - - start = timer() - self._text_merge() - tbls = self._extract_table_figure(True, zoomin, True, True) - self._concat_downward() - self._filter_forpages() - callback(0.68, "Text merged ({:.2f}s)".format(timer() - start)) - - # clean mess - for b in self.boxes: - b["text"] = re.sub(r"([\t  ]|\u3000){2,}", " ", b["text"].strip()) - - return [(b["text"], b.get("layoutno", ""), self.get_position(b, zoomin)) - for i, b in enumerate(self.boxes)], tbls - - -class Docx(DocxParser): - def __init__(self): - pass - - def get_picture(self, document, paragraph): - img = paragraph._element.xpath('.//pic:pic') - if not img: - return None - try: - img = img[0] - embed = img.xpath('.//a:blip/@r:embed')[0] - related_part = document.part.related_parts[embed] - image = related_part.image - if image is not None: - image = Image.open(BytesIO(image.blob)) - return image - elif related_part.blob is not None: - image = Image.open(BytesIO(related_part.blob)) - return image - else: - return None - except Exception: - return None - - def concat_img(self, img1, img2): - if img1 and not img2: - return img1 - if not img1 and img2: - return img2 - if not img1 and not img2: - return None - width1, height1 = img1.size - width2, height2 = img2.size - - new_width = max(width1, width2) - new_height = height1 + height2 - new_image = Image.new('RGB', (new_width, new_height)) - - new_image.paste(img1, (0, 0)) - new_image.paste(img2, (0, height1)) - - return new_image - - def __call__(self, filename, binary=None, from_page=0, to_page=100000, callback=None): - self.doc = Document( - filename) if not binary else Document(BytesIO(binary)) - pn = 0 - last_answer, last_image = "", None - question_stack, level_stack = [], [] - ti_list = [] - for p in self.doc.paragraphs: - if pn > to_page: - break - question_level, p_text = 0, '' - if from_page <= pn < to_page and p.text.strip(): - question_level, p_text = docx_question_level(p) - if not question_level or question_level > 6: # not a question - last_answer = f'{last_answer}\n{p_text}' - current_image = self.get_picture(self.doc, p) - last_image = self.concat_img(last_image, current_image) - else: # is a question - if last_answer or last_image: - sum_question = '\n'.join(question_stack) - if sum_question: - ti_list.append((f'{sum_question}\n{last_answer}', last_image)) - last_answer, last_image = '', None - - i = question_level - while question_stack and i <= level_stack[-1]: - question_stack.pop() - level_stack.pop() - question_stack.append(p_text) - level_stack.append(question_level) - for run in p.runs: - if 'lastRenderedPageBreak' in run._element.xml: - pn += 1 - continue - if 'w:br' in run._element.xml and 'type="page"' in run._element.xml: - pn += 1 - if last_answer: - sum_question = '\n'.join(question_stack) - if sum_question: - ti_list.append((f'{sum_question}\n{last_answer}', last_image)) - - tbls = [] - for tb in self.doc.tables: - html= "" - for r in tb.rows: - html += "" - i = 0 - while i < len(r.cells): - span = 1 - c = r.cells[i] - for j in range(i+1, len(r.cells)): - if c.text == r.cells[j].text: - span += 1 - i = j - else: - break - i += 1 - html += f"" if span == 1 else f"" - html += "" - html += "
{c.text}{c.text}
" - tbls.append(((None, html), "")) - return ti_list, tbls - - -def chunk(filename, binary=None, from_page=0, to_page=100000, - lang="Chinese", callback=None, **kwargs): - """ - Only pdf is supported. - """ - parser_config = kwargs.get( - "parser_config", { - "chunk_token_num": 512, "delimiter": "\n!?。;!?", "layout_recognize": "DeepDOC"}) - pdf_parser = None - doc = { - "docnm_kwd": filename - } - doc["title_tks"] = rag_tokenizer.tokenize(re.sub(r"\.[a-zA-Z]+$", "", doc["docnm_kwd"])) - doc["title_sm_tks"] = rag_tokenizer.fine_grained_tokenize(doc["title_tks"]) - # is it English - eng = lang.lower() == "english" # pdf_parser.is_english - if re.search(r"\.pdf$", filename, re.IGNORECASE): - layout_recognizer = parser_config.get("layout_recognize", "DeepDOC") - - if isinstance(layout_recognizer, bool): - layout_recognizer = "DeepDOC" if layout_recognizer else "Plain Text" - - name = layout_recognizer.strip().lower() - pdf_parser = PARSERS.get(name, by_plaintext) - callback(0.1, "Start to parse.") - - sections, tbls, pdf_parser = pdf_parser( - filename = filename, - binary = binary, - from_page = from_page, - to_page = to_page, - lang = lang, - callback = callback, - pdf_cls = Pdf, - **kwargs - ) - - if not sections and not tbls: - return [] - - if name in ["tcadp", "docling", "mineru"]: - parser_config["chunk_token_num"] = 0 - - callback(0.8, "Finish parsing.") - - if len(sections) > 0 and len(pdf_parser.outlines) / len(sections) > 0.03: - max_lvl = max([lvl for _, lvl in pdf_parser.outlines]) - most_level = max(0, max_lvl - 1) - levels = [] - for txt, _, _ in sections: - for t, lvl in pdf_parser.outlines: - tks = set([t[i] + t[i + 1] for i in range(len(t) - 1)]) - tks_ = set([txt[i] + txt[i + 1] - for i in range(min(len(t), len(txt) - 1))]) - if len(set(tks & tks_)) / max([len(tks), len(tks_), 1]) > 0.8: - levels.append(lvl) - break - else: - levels.append(max_lvl + 1) - - else: - bull = bullets_category([txt for txt, _, _ in sections]) - most_level, levels = title_frequency( - bull, [(txt, lvl) for txt, lvl, _ in sections]) - - assert len(sections) == len(levels) - sec_ids = [] - sid = 0 - for i, lvl in enumerate(levels): - if lvl <= most_level and i > 0 and lvl != levels[i - 1]: - sid += 1 - sec_ids.append(sid) - - sections = [(txt, sec_ids[i], poss) - for i, (txt, _, poss) in enumerate(sections)] - for (img, rows), poss in tbls: - if not rows: - continue - sections.append((rows if isinstance(rows, str) else rows[0], -1, - [(p[0] + 1 - from_page, p[1], p[2], p[3], p[4]) for p in poss])) - - def tag(pn, left, right, top, bottom): - if pn + left + right + top + bottom == 0: - return "" - return "@@{}\t{:.1f}\t{:.1f}\t{:.1f}\t{:.1f}##" \ - .format(pn, left, right, top, bottom) - - chunks = [] - last_sid = -2 - tk_cnt = 0 - for txt, sec_id, poss in sorted(sections, key=lambda x: ( - x[-1][0][0], x[-1][0][3], x[-1][0][1])): - poss = "\t".join([tag(*pos) for pos in poss]) - if tk_cnt < 32 or (tk_cnt < 1024 and (sec_id == last_sid or sec_id == -1)): - if chunks: - chunks[-1] += "\n" + txt + poss - tk_cnt += num_tokens_from_string(txt) - continue - chunks.append(txt + poss) - tk_cnt = num_tokens_from_string(txt) - if sec_id > -1: - last_sid = sec_id - tbls=vision_figure_parser_pdf_wrapper(tbls=tbls,callback=callback,**kwargs) - res = tokenize_table(tbls, doc, eng) - res.extend(tokenize_chunks(chunks, doc, eng, pdf_parser)) - return res - - elif re.search(r"\.docx?$", filename, re.IGNORECASE): - docx_parser = Docx() - ti_list, tbls = docx_parser(filename, binary, - from_page=0, to_page=10000, callback=callback) - tbls=vision_figure_parser_docx_wrapper(sections=ti_list,tbls=tbls,callback=callback,**kwargs) - res = tokenize_table(tbls, doc, eng) - for text, image in ti_list: - d = copy.deepcopy(doc) - if image: - d['image'] = image - d["doc_type_kwd"] = "image" - tokenize(d, text, eng) - res.append(d) - return res - else: - raise NotImplementedError("file type not supported yet(pdf and docx supported)") - - -if __name__ == "__main__": - import sys - - - def dummy(prog=None, msg=""): - pass - - - chunk(sys.argv[1], callback=dummy) diff --git a/app/core/rag/app/naive.py b/app/core/rag/app/naive.py deleted file mode 100644 index 95aad2d2..00000000 --- a/app/core/rag/app/naive.py +++ /dev/null @@ -1,849 +0,0 @@ -import logging -import re -import os -from functools import reduce -from io import BytesIO -from timeit import default_timer as timer -from docx import Document -from docx.image.exceptions import InvalidImageStreamError, UnexpectedEndOfFileError, UnrecognizedImageError -from docx.opc.pkgreader import _SerializedRelationships, _SerializedRelationship -from docx.opc.oxml import parse_xml -from markdown import markdown -from PIL import Image -import copy - -from app.core.rag.llm.cv_model import AzureGptV4, QWenCV -from app.core.rag.common.file_utils import get_project_base_directory -from app.core.rag.utils.file_utils import extract_embed_file, extract_links_from_pdf, extract_links_from_docx, extract_html -from app.core.rag.deepdoc.parser import DocxParser, ExcelParser, HtmlParser, JsonParser, MarkdownElementExtractor, MarkdownParser, PdfParser, TxtParser -from app.core.rag.deepdoc.parser.figure_parser import VisionFigureParser,vision_figure_parser_docx_wrapper,vision_figure_parser_pdf_wrapper -from app.core.rag.deepdoc.parser.pdf_parser import PlainParser, VisionParser -from app.core.rag.deepdoc.parser.mineru_parser import MinerUParser -from app.core.rag.nlp import concat_img, find_codec, naive_merge, naive_merge_with_images, naive_merge_docx, tokenize, rag_tokenizer, tokenize_chunks, tokenize_chunks_with_images, tokenize_table - -def by_deepdoc(filename, binary=None, from_page=0, to_page=100000, lang="Chinese", callback=None, vision_model=None, pdf_cls = None ,**kwargs): - callback = callback - binary = binary - pdf_parser = pdf_cls() if pdf_cls else Pdf() - sections, tables = pdf_parser( - filename if not binary else binary, - from_page=from_page, - to_page=to_page, - callback=callback - ) - - tables = vision_figure_parser_pdf_wrapper(tbls=tables, - callback=callback, - vision_model=vision_model, - **kwargs) - return sections, tables, pdf_parser - - -def by_mineru(filename, binary=None, from_page=0, to_page=100000, lang="Chinese", callback=None, vision_model=None, pdf_cls = None ,**kwargs): - mineru_executable = os.environ.get("MINERU_EXECUTABLE", "mineru") - mineru_api = os.environ.get("MINERU_APISERVER", "http://host.docker.internal:9987") - pdf_parser = MinerUParser(mineru_path=mineru_executable, mineru_api=mineru_api) - - if not pdf_parser.check_installation(): - callback(-1, "MinerU not found.") - return None, None, pdf_parser - - sections, tables = pdf_parser.parse_pdf( - filepath=filename, - binary=binary, - callback=callback, - output_dir=os.environ.get("MINERU_OUTPUT_DIR", ""), - backend=os.environ.get("MINERU_BACKEND", "pipeline"), - delete_output=bool(int(os.environ.get("MINERU_DELETE_OUTPUT", 1))), - ) - return sections, tables, pdf_parser - - -def by_textln(filename, binary=None, from_page=0, to_page=100000, lang="Chinese", callback=None, vision_model=None, pdf_cls = None ,**kwargs): - textln_app_id = os.environ.get("TEXTLN_APP_ID", "") - textln_secret_code = os.environ.get("TEXTLN_SECRET_CODE", "") - textln_api = os.environ.get("TEXTLN_APISERVER", "https://api.textin.com/ai/service/v1/pdf_to_markdown") - pdf_parser = MinerUParser(mineru_path=textln_app_id, mineru_api=textln_api) - - if not pdf_parser.check_installation(): - callback(-1, "MinerU not found.") - return None, None, pdf_parser - - sections, tables = pdf_parser.parse_pdf( - filepath=filename, - binary=binary, - callback=callback, - output_dir=os.environ.get("MINERU_OUTPUT_DIR", ""), - backend=os.environ.get("MINERU_BACKEND", "pipeline"), - delete_output=bool(int(os.environ.get("MINERU_DELETE_OUTPUT", 1))), - ) - return sections, tables, pdf_parser - - -def by_plaintext(filename, binary=None, from_page=0, to_page=100000, callback=None, vision_model=None, **kwargs): - if kwargs.get("layout_recognizer", "") == "Plain Text": - pdf_parser = PlainParser() - else: - pdf_parser = VisionParser(vision_model=vision_model, **kwargs) - - sections, tables = pdf_parser( - filename if not binary else binary, - from_page=from_page, - to_page=to_page, - callback=callback - ) - return sections, tables, pdf_parser - - -PARSERS = { - "deepdoc": by_deepdoc, - "mineru": by_mineru, - "textln": by_textln, - "plaintext": by_plaintext, # default -} - - -class Docx(DocxParser): - def __init__(self): - pass - - def get_picture(self, document, paragraph): - imgs = paragraph._element.xpath('.//pic:pic') - if not imgs: - return None - res_img = None - for img in imgs: - embed = img.xpath('.//a:blip/@r:embed') - if not embed: - continue - embed = embed[0] - try: - related_part = document.part.related_parts[embed] - image_blob = related_part.image.blob - except UnrecognizedImageError: - logging.info("Unrecognized image format. Skipping image.") - continue - except UnexpectedEndOfFileError: - logging.info("EOF was unexpectedly encountered while reading an image stream. Skipping image.") - continue - except InvalidImageStreamError: - logging.info("The recognized image stream appears to be corrupted. Skipping image.") - continue - except UnicodeDecodeError: - logging.info("The recognized image stream appears to be corrupted. Skipping image.") - continue - except Exception: - logging.info("The recognized image stream appears to be corrupted. Skipping image.") - continue - try: - image = Image.open(BytesIO(image_blob)).convert('RGB') - if res_img is None: - res_img = image - else: - res_img = concat_img(res_img, image) - except Exception: - continue - - return res_img - - def __clean(self, line): - line = re.sub(r"\u3000", " ", line).strip() - return line - - def __get_nearest_title(self, table_index, filename): - """Get the hierarchical title structure before the table""" - import re - from docx.text.paragraph import Paragraph - - titles = [] - blocks = [] - - # Get document name from filename parameter - doc_name = re.sub(r"\.[a-zA-Z]+$", "", filename) - if not doc_name: - doc_name = "Untitled Document" - - # Collect all document blocks while maintaining document order - try: - # Iterate through all paragraphs and tables in document order - for i, block in enumerate(self.doc._element.body): - if block.tag.endswith('p'): # Paragraph - p = Paragraph(block, self.doc) - blocks.append(('p', i, p)) - elif block.tag.endswith('tbl'): # Table - blocks.append(('t', i, None)) # Table object will be retrieved later - except Exception as e: - logging.error(f"Error collecting blocks: {e}") - return "" - - # Find the target table position - target_table_pos = -1 - table_count = 0 - for i, (block_type, pos, _) in enumerate(blocks): - if block_type == 't': - if table_count == table_index: - target_table_pos = pos - break - table_count += 1 - - if target_table_pos == -1: - return "" # Target table not found - - # Find the nearest heading paragraph in reverse order - nearest_title = None - for i in range(len(blocks)-1, -1, -1): - block_type, pos, block = blocks[i] - if pos >= target_table_pos: # Skip blocks after the table - continue - - if block_type != 'p': - continue - - if block.style and block.style.name and re.search(r"Heading\s*(\d+)", block.style.name, re.I): - try: - level_match = re.search(r"(\d+)", block.style.name) - if level_match: - level = int(level_match.group(1)) - if level <= 7: # Support up to 7 heading levels - title_text = block.text.strip() - if title_text: # Avoid empty titles - nearest_title = (level, title_text) - break - except Exception as e: - logging.error(f"Error parsing heading level: {e}") - - if nearest_title: - # Add current title - titles.append(nearest_title) - current_level = nearest_title[0] - - # Find all parent headings, allowing cross-level search - while current_level > 1: - found = False - for i in range(len(blocks)-1, -1, -1): - block_type, pos, block = blocks[i] - if pos >= target_table_pos: # Skip blocks after the table - continue - - if block_type != 'p': - continue - - if block.style and re.search(r"Heading\s*(\d+)", block.style.name, re.I): - try: - level_match = re.search(r"(\d+)", block.style.name) - if level_match: - level = int(level_match.group(1)) - # Find any heading with a higher level - if level < current_level: - title_text = block.text.strip() - if title_text: # Avoid empty titles - titles.append((level, title_text)) - current_level = level - found = True - break - except Exception as e: - logging.error(f"Error parsing parent heading: {e}") - - if not found: # Break if no parent heading is found - break - - # Sort by level (ascending, from highest to lowest) - titles.sort(key=lambda x: x[0]) - # Organize titles (from highest to lowest) - hierarchy = [doc_name] + [t[1] for t in titles] - return " > ".join(hierarchy) - - return "" - - def __call__(self, filename, binary=None, from_page=0, to_page=100000): - self.doc = Document( - filename) if not binary else Document(BytesIO(binary)) - pn = 0 - lines = [] - last_image = None - for p in self.doc.paragraphs: - if pn > to_page: - break - if from_page <= pn < to_page: - if p.text.strip(): - if p.style and p.style.name == 'Caption': - former_image = None - if lines and lines[-1][1] and lines[-1][2] != 'Caption': - former_image = lines[-1][1].pop() - elif last_image: - former_image = last_image - last_image = None - lines.append((self.__clean(p.text), [former_image], p.style.name)) - else: - current_image = self.get_picture(self.doc, p) - image_list = [current_image] - if last_image: - image_list.insert(0, last_image) - last_image = None - lines.append((self.__clean(p.text), image_list, p.style.name if p.style else "")) - else: - if current_image := self.get_picture(self.doc, p): - if lines: - lines[-1][1].append(current_image) - else: - last_image = current_image - for run in p.runs: - if 'lastRenderedPageBreak' in run._element.xml: - pn += 1 - continue - if 'w:br' in run._element.xml and 'type="page"' in run._element.xml: - pn += 1 - new_line = [(line[0], reduce(concat_img, line[1]) if line[1] else None) for line in lines] - - tbls = [] - for i, tb in enumerate(self.doc.tables): - title = self.__get_nearest_title(i, filename) - html = "" - if title: - html += f"" - for r in tb.rows: - html += "" - i = 0 - try: - while i < len(r.cells): - span = 1 - c = r.cells[i] - for j in range(i + 1, len(r.cells)): - if c.text == r.cells[j].text: - span += 1 - i = j - else: - break - i += 1 - html += f"" if span == 1 else f"" - except Exception as e: - logging.warning(f"Error parsing table, ignore: {e}") - html += "" - html += "
Table Location: {title}
{c.text}{c.text}
" - tbls.append(((None, html), "")) - return new_line, tbls - - def to_markdown(self, filename=None, binary=None, inline_images: bool = True): - """ - This function uses mammoth, licensed under the BSD 2-Clause License. - """ - - import base64 - import uuid - - import mammoth - from markdownify import markdownify - - docx_file = BytesIO(binary) if binary else open(filename, "rb") - - def _convert_image_to_base64(image): - try: - with image.open() as image_file: - image_bytes = image_file.read() - encoded = base64.b64encode(image_bytes).decode("utf-8") - base64_url = f"data:{image.content_type};base64,{encoded}" - - alt_name = "image" - alt_name = f"img_{uuid.uuid4().hex[:8]}" - - return {"src": base64_url, "alt": alt_name} - except Exception as e: - logging.warning(f"Failed to convert image to base64: {e}") - return {"src": "", "alt": "image"} - - try: - if inline_images: - result = mammoth.convert_to_html(docx_file, convert_image=mammoth.images.img_element(_convert_image_to_base64)) - else: - result = mammoth.convert_to_html(docx_file) - - html = result.value - - markdown_text = markdownify(html) - return markdown_text - - finally: - if not binary: - docx_file.close() - - -class Pdf(PdfParser): - def __init__(self): - super().__init__() - - def __call__(self, filename, binary=None, from_page=0, - to_page=100000, zoomin=3, callback=None, separate_tables_figures=False): - start = timer() - first_start = start - callback(msg="OCR started") - self.__images__( - filename if not binary else binary, - zoomin, - from_page, - to_page, - callback - ) - callback(msg="OCR finished ({:.2f}s)".format(timer() - start)) - logging.info("OCR({}~{}): {:.2f}s".format(from_page, to_page, timer() - start)) - - start = timer() - self._layouts_rec(zoomin) - callback(0.63, "Layout analysis ({:.2f}s)".format(timer() - start)) - - start = timer() - self._table_transformer_job(zoomin) - callback(0.65, "Table analysis ({:.2f}s)".format(timer() - start)) - - start = timer() - self._text_merge(zoomin=zoomin) - callback(0.67, "Text merged ({:.2f}s)".format(timer() - start)) - - if separate_tables_figures: - tbls, figures = self._extract_table_figure(True, zoomin, True, True, True) - self._concat_downward() - logging.info("layouts cost: {}s".format(timer() - first_start)) - return [(b["text"], self._line_tag(b, zoomin)) for b in self.boxes], tbls, figures - else: - tbls = self._extract_table_figure(True, zoomin, True, True) - self._naive_vertical_merge() - self._concat_downward() - self._final_reading_order_merge() - # self._filter_forpages() - logging.info("layouts cost: {}s".format(timer() - first_start)) - return [(b["text"], self._line_tag(b, zoomin)) for b in self.boxes], tbls - - -class Markdown(MarkdownParser): - def md_to_html(self, sections): - if not sections: - return [] - if isinstance(sections, type("")): - text = sections - elif isinstance(sections[0], type("")): - text = sections[0] - else: - return [] - - from bs4 import BeautifulSoup - html_content = markdown(text) - soup = BeautifulSoup(html_content, 'html.parser') - return soup - - def get_picture_urls(self, soup): - if soup: - return [img.get('src') for img in soup.find_all('img') if img.get('src')] - return [] - - def get_hyperlink_urls(self, soup): - if soup: - return set([a.get('href') for a in soup.find_all('a') if a.get('href')]) - return [] - - def get_pictures(self, text): - """Download and open all images from markdown text.""" - import requests - soup = self.md_to_html(text) - image_urls = self.get_picture_urls(soup) - images = [] - # Find all image URLs in text - for url in image_urls: - if not url: - continue - try: - # check if the url is a local file or a remote URL - if url.startswith(('http://', 'https://')): - # For remote URLs, download the image - response = requests.get(url, stream=True, timeout=30) - if response.status_code == 200 and response.headers['Content-Type'] and response.headers['Content-Type'].startswith('image/'): - img = Image.open(BytesIO(response.content)).convert('RGB') - images.append(img) - else: - # For local file paths, open the image directly - from pathlib import Path - local_path = Path(url) - if not local_path.exists(): - logging.warning(f"Local image file not found: {url}") - continue - img = Image.open(url).convert('RGB') - images.append(img) - except Exception as e: - logging.error(f"Failed to download/open image from {url}: {e}") - continue - - return images if images else None - - def __call__(self, filename, binary=None, separate_tables=True,delimiter=None): - if binary: - encoding = find_codec(binary) - txt = binary.decode(encoding, errors="ignore") - else: - with open(filename, "r") as f: - txt = f.read() - - remainder, tables = self.extract_tables_and_remainder(f'{txt}\n', separate_tables=separate_tables) - # To eliminate duplicate tables in chunking result, uncomment code below and set separate_tables to True in line 410. - # extractor = MarkdownElementExtractor(remainder) - extractor = MarkdownElementExtractor(txt) - element_sections = extractor.extract_elements(delimiter) - sections = [(element, "") for element in element_sections] - tbls = [] - for table in tables: - tbls.append(((None, markdown(table, extensions=['markdown.extensions.tables'])), "")) - return sections, tbls - -def load_from_xml_v2(baseURI, rels_item_xml): - """ - Return |_SerializedRelationships| instance loaded with the - relationships contained in *rels_item_xml*. Returns an empty - collection if *rels_item_xml* is |None|. - """ - srels = _SerializedRelationships() - if rels_item_xml is not None: - rels_elm = parse_xml(rels_item_xml) - for rel_elm in rels_elm.Relationship_lst: - if rel_elm.target_ref in ('../NULL', 'NULL'): - continue - srels._srels.append(_SerializedRelationship(baseURI, rel_elm)) - return srels - -def chunk(filename, binary=None, from_page=0, to_page=100000, - lang="Chinese", callback=None, vision_model=None, **kwargs): - """ - Supported file formats are docx, doc, pdf, excel, txt, markdown, html, json. - This method apply the naive ways to chunk files. - Successive text will be sliced into pieces using 'delimiter'. - Next, these successive pieces are merge into chunks whose token number is no more than 'Max token number'. - """ - urls = set() - url_res = [] - - - is_english = lang.lower() == "english" # is_english(cks) - parser_config = kwargs.get( - "parser_config", { - "layout_recognize": "DeepDOC", "chunk_token_num": 512, "delimiter": "\n!?。;!?", "analyze_hyperlink": True}) - doc = { - "docnm_kwd": filename, - "title_tks": rag_tokenizer.tokenize(re.sub(r"\.[a-zA-Z]+$", "", filename)) - } - doc["title_sm_tks"] = rag_tokenizer.fine_grained_tokenize(doc["title_tks"]) - res = [] - pdf_parser = None - section_images = None - - is_root = kwargs.get("is_root", True) - embed_res = [] - if is_root: - # Only extract embedded files at the root call - embeds = [] - if binary is not None: - embeds = extract_embed_file(binary) - else: - raise Exception("Embedding extraction from file path is not supported.") - - # Recursively chunk each embedded file and collect results - for embed_filename, embed_bytes in embeds: - try: - sub_res = chunk(embed_filename, binary=embed_bytes, lang=lang, callback=callback, vision_model=vision_model, is_root=False, **kwargs) or [] - embed_res.extend(sub_res) - except Exception as e: - if callback: - callback(0.05, f"Failed to chunk embed {embed_filename}: {e}") - continue - - if re.search(r"\.docx$", filename, re.IGNORECASE): - callback(0.1, "Start to parse.") - if parser_config.get("analyze_hyperlink", False) and is_root: - urls = extract_links_from_docx(binary) - for index, url in enumerate(urls): - html_bytes, metadata = extract_html(url) - if not html_bytes: - continue - try: - sub_url_res = chunk(url, html_bytes, lang=lang, callback=callback, vision_model=vision_model, is_root=False, **kwargs) - except Exception as e: - logging.info(f"Failed to chunk url in registered file type {url}: {e}") - sub_url_res = chunk(f"{index}.html", html_bytes, lang=lang, callback=callback, vision_model=vision_model, is_root=False, **kwargs) - url_res.extend(sub_url_res) - - # fix "There is no item named 'word/NULL' in the archive", referring to https://github.com/python-openxml/python-docx/issues/1105#issuecomment-1298075246 - _SerializedRelationships.load_from_xml = load_from_xml_v2 - sections, tables = Docx()(filename, binary) - - tables=vision_figure_parser_docx_wrapper(sections=sections,tbls=tables,callback=callback, vision_model=vision_model, **kwargs) - - res = tokenize_table(tables, doc, is_english) - callback(0.8, "Finish parsing.") - - st = timer() - - chunks, images = naive_merge_docx( - sections, int(parser_config.get( - "chunk_token_num", 128)), parser_config.get( - "delimiter", "\n!?。;!?")) - - if kwargs.get("section_only", False): - chunks.extend(embed_res) - chunks.extend(url_res) - return chunks - - res.extend(tokenize_chunks_with_images(chunks, doc, is_english, images)) - logging.info("naive_merge({}): {}".format(filename, timer() - st)) - res.extend(embed_res) - res.extend(url_res) - return res - - elif re.search(r"\.pdf$", filename, re.IGNORECASE): - layout_recognizer = parser_config.get("layout_recognize", "DeepDOC") - if parser_config.get("analyze_hyperlink", False) and is_root: - urls = extract_links_from_pdf(binary) - - if isinstance(layout_recognizer, bool): - layout_recognizer = "DeepDOC" if layout_recognizer else "Plain Text" - - name = layout_recognizer.strip().lower() - parser = PARSERS.get(name, by_plaintext) - callback(0.1, "Start to parse.") - - sections, tables, pdf_parser = parser( - filename=filename, - binary=binary, - from_page=from_page, - to_page=to_page, - lang=lang, - callback=callback, - vision_model=vision_model, - layout_recognizer=layout_recognizer, - **kwargs - ) - - if not sections and not tables: - return [] - - if name in ["mineru", "textln"]: - parser_config["chunk_token_num"] = 0 - - res = tokenize_table(tables, doc, is_english) - callback(0.8, "Finish parsing.") - - elif re.search(r"\.pptx?$", filename, re.IGNORECASE): - if not binary: - with open(filename, "rb") as f: - binary = f.read() - from app.core.rag.app.presentation import Ppt - ppt_parser = Ppt() - for pn, (txt, img) in enumerate(ppt_parser( - filename if not binary else binary, from_page, to_page, callback)): - d = copy.deepcopy(doc) - pn += from_page - d["image"] = img - d["doc_type_kwd"] = "image" - d["page_num_int"] = [pn + 1] - d["top_int"] = [0] - d["position_int"] = [(pn + 1, 0, img.size[0], 0, img.size[1])] - tokenize(d, txt, is_english) - res.append(d) - return res - - elif re.search(r"\.(da|wave|wav|mp3|aac|flac|ogg|aiff|au|midi|wma|realaudio|vqf|oggvorbis|ape?)$", filename, re.IGNORECASE): - if not binary: - with open(filename, "rb") as f: - binary = f.read() - from app.core.rag.app.audio import chunk as parser - return parser(filename, binary, lang=lang, callback=callback, seq2txt_mdl=vision_model, **kwargs) - - elif re.search(r"\.(png|jpeg|jpg|gif|bmp|svg|mp4|mov|avi|flv|mpeg|mpg|webm|wmv|3gp|3gpp|mkv?)$", filename, re.IGNORECASE): - if not binary: - with open(filename, "rb") as f: - binary = f.read() - from app.core.rag.app.picture import chunk as parser - return parser(filename, binary, lang=lang, callback=callback, vision_model=vision_model, **kwargs) - - elif re.search(r"\.(csv|xlsx?)$", filename, re.IGNORECASE): - callback(0.1, "Start to parse.") - if not binary: - with open(filename, "rb") as f: - binary = f.read() - excel_parser = ExcelParser() - if parser_config.get("html4excel"): - sections = [(_, "") for _ in excel_parser.html(binary, 12) if _] - else: - sections = [(_, "") for _ in excel_parser(binary) if _] - parser_config["chunk_token_num"] = 12800 - - elif re.search(r"\.(txt|py|js|java|c|cpp|h|php|go|ts|sh|cs|kt|sql)$", filename, re.IGNORECASE): - callback(0.1, "Start to parse.") - sections = TxtParser()(filename, binary, - parser_config.get("chunk_token_num", 128), - parser_config.get("delimiter", "\n!?;。;!?")) - callback(0.8, "Finish parsing.") - - elif re.search(r"\.(md|markdown)$", filename, re.IGNORECASE): - callback(0.1, "Start to parse.") - markdown_parser = Markdown(int(parser_config.get("chunk_token_num", 128))) - sections, tables = markdown_parser(filename, binary, separate_tables=False,delimiter=parser_config.get("delimiter", "\n!?;。;!?")) - - if vision_model: - # Process images for each section - section_images = [] - for idx, (section_text, _) in enumerate(sections): - images = markdown_parser.get_pictures(section_text) if section_text else None - - if images: - # If multiple images found, combine them using concat_img - combined_image = reduce(concat_img, images) if len(images) > 1 else images[0] - section_images.append(combined_image) - markdown_vision_parser = VisionFigureParser(vision_model=vision_model, figures_data= [((combined_image, ["markdown image"]), [(0, 0, 0, 0, 0)])], **kwargs) - boosted_figures = markdown_vision_parser(callback=callback) - sections[idx] = (section_text + "\n\n" + "\n\n".join([fig[0][1][0] for fig in boosted_figures]), sections[idx][1]) - else: - section_images.append(None) - - else: - logging.warning("No visual model detected. Skipping figure parsing enhancement.") - - if parser_config.get("hyperlink_urls", False) and is_root: - for idx, (section_text, _) in enumerate(sections): - soup = markdown_parser.md_to_html(section_text) - hyperlink_urls = markdown_parser.get_hyperlink_urls(soup) - urls.update(hyperlink_urls) - res = tokenize_table(tables, doc, is_english) - callback(0.8, "Finish parsing.") - - elif re.search(r"\.(htm|html)$", filename, re.IGNORECASE): - callback(0.1, "Start to parse.") - chunk_token_num = int(parser_config.get("chunk_token_num", 128)) - sections = HtmlParser()(filename, binary, chunk_token_num) - sections = [(_, "") for _ in sections if _] - callback(0.8, "Finish parsing.") - - elif re.search(r"\.(json|jsonl|ldjson)$", filename, re.IGNORECASE): - callback(0.1, "Start to parse.") - chunk_token_num = int(parser_config.get("chunk_token_num", 128)) - sections = JsonParser(chunk_token_num)(filename) - sections = [(_, "") for _ in sections if _] - callback(0.8, "Finish parsing.") - - elif re.search(r"\.doc$", filename, re.IGNORECASE): - callback(0.1, "Start to parse.") - - try: - import tika - os.environ['TIKA_SERVER_JAR'] = "/tmp/tika-server.jar" - os.environ['TIKA_SERVER_PORT'] = '9998' - # java11 Initialize Tika 3.1.0.jar service url:http://localhost:9998 view process:lsof -i :9998 - tika.initVM() - from tika import parser as tika_parser - except Exception as e: - callback(0.8, f"tika not available: {e}. Unsupported .doc parsing.") - logging.warning(f"tika not available: {e}. Unsupported .doc parsing for {filename}.") - return [] - - doc_parsed = tika_parser.from_file(filename) - if doc_parsed.get('content', None) is not None: - sections = doc_parsed['content'].split('\n') - sections = [(_, "") for _ in sections if _] - callback(0.8, "Finish parsing.") - else: - callback(0.8, f"tika.parser got empty content from {filename}.") - logging.warning(f"tika.parser got empty content from {filename}.") - return [] - else: - raise NotImplementedError( - "file type not supported yet(pdf, xlsx, doc, docx, txt supported)") - - st = timer() - if section_images: - # if all images are None, set section_images to None - if all(image is None for image in section_images): - section_images = None - - if section_images: - chunks, images = naive_merge_with_images(sections, section_images, - int(parser_config.get( - "chunk_token_num", 128)), parser_config.get( - "delimiter", "\n!?。;!?")) - if kwargs.get("section_only", False): - chunks.extend(embed_res) - return chunks - - res.extend(tokenize_chunks_with_images(chunks, doc, is_english, images)) - else: - chunks = naive_merge( - sections, int(parser_config.get( - "chunk_token_num", 128)), parser_config.get( - "delimiter", "\n!?。;!?")) - if kwargs.get("section_only", False): - chunks.extend(embed_res) - return chunks - - res.extend(tokenize_chunks(chunks, doc, is_english, pdf_parser)) - - if urls and parser_config.get("analyze_hyperlink", False) and is_root: - for index, url in enumerate(urls): - html_bytes, metadata = extract_html(url) - if not html_bytes: - continue - try: - sub_url_res = chunk(url, html_bytes, callback=callback, lang=lang, is_root=False, **kwargs) - except Exception as e: - logging.info(f"Failed to chunk url in registered file type {url}: {e}") - sub_url_res = chunk(f"{index}.html", html_bytes, lang=lang, callback=callback, vision_model=vision_model, is_root=False, **kwargs) - url_res.extend(sub_url_res) - - logging.info("naive_merge({}): {}".format(filename, timer() - st)) - - if embed_res: - res.extend(embed_res) - if url_res: - res.extend(url_res) - return res - - -if __name__ == "__main__": - # import sys - # chunk(sys.argv[1], from_page=0, to_page=10, callback=dummy) - - # Prepare to configure vision_model information - vision_model = QWenCV( - key="sk-8e9e40cd171749858ce2d3722ea75669", - model_name="qwen-vl-max", - lang="chinese", # 默认使用中文 - base_url="https://dashscope.aliyuncs.com/compatible-mode/v1" - ) - - def progress_callback(prog=None, msg=None): - print(f"prog: {prog} msg: {msg}\n") - - # file_path = "/Users/sbtjfdn/Downloads/记忆科学/files/1.txt" - # file_path = "/Users/sbtjfdn/Downloads/记忆科学/files/2.md" - # file_path = "/Users/sbtjfdn/Downloads/记忆科学/files/3.md" # 带图url - file_path = "/Users/sbtjfdn/Downloads/记忆科学/files/义务教育教科书·中国历史七年级上册 (2)_Compressed.md" - # file_path = "/Users/sbtjfdn/Downloads/记忆科学/files/4.doc" - # file_path = "/Users/sbtjfdn/Downloads/记忆科学/files/5.json" - # file_path = "/Users/sbtjfdn/Downloads/记忆科学/files/6.html" - # file_path = "/Users/sbtjfdn/Downloads/记忆科学/files/7.xlsx" - # file_path = "/Users/sbtjfdn/Downloads/记忆科学/files/8.pdf" - # file_path = "/Users/sbtjfdn/Downloads/记忆科学/files/9.pptx" - # file_path = "/Users/sbtjfdn/Downloads/记忆科学/files/10.png" - # file_path = "/Users/sbtjfdn/Downloads/记忆科学/files/11.mp4" - # file_path = "/Users/sbtjfdn/Downloads/记忆科学/files/12.mp3" - res = chunk(filename=file_path, - from_page=0, - to_page=10, - callback=progress_callback, - vision_model=vision_model, - parser_config={ - "layout_recognize": "DeepDOC", - "chunk_token_num": 128, - "delimiter": "\n", - "analyze_hyperlink": True, - "auto_keywords": 0, - "auto_questions": 0, - "html4excel": "false" - }, - is_root=False) - for index, item in enumerate(res): - print(f"Index: {index}\n----") - print(item) - print("----") diff --git a/app/core/rag/app/one.py b/app/core/rag/app/one.py deleted file mode 100644 index a3a2a685..00000000 --- a/app/core/rag/app/one.py +++ /dev/null @@ -1,149 +0,0 @@ -import logging -from io import BytesIO -import re - -from app.core.rag.deepdoc.parser.utils import get_text -from . import naive -from app.core.rag.nlp import rag_tokenizer, tokenize -from app.core.rag.deepdoc.parser import PdfParser, ExcelParser, HtmlParser -from app.core.rag.deepdoc.parser.figure_parser import vision_figure_parser_docx_wrapper -from app.core.rag.app.naive import by_plaintext, PARSERS - -class Pdf(PdfParser): - def __call__(self, filename, binary=None, from_page=0, - to_page=100000, zoomin=3, callback=None): - from timeit import default_timer as timer - start = timer() - callback(msg="OCR started") - self.__images__( - filename if not binary else binary, - zoomin, - from_page, - to_page, - callback - ) - callback(msg="OCR finished ({:.2f}s)".format(timer() - start)) - - start = timer() - self._layouts_rec(zoomin, drop=False) - callback(0.63, "Layout analysis ({:.2f}s)".format(timer() - start)) - logging.debug("layouts cost: {}s".format(timer() - start)) - - start = timer() - self._table_transformer_job(zoomin) - callback(0.65, "Table analysis ({:.2f}s)".format(timer() - start)) - - start = timer() - self._text_merge() - callback(0.67, "Text merged ({:.2f}s)".format(timer() - start)) - tbls = self._extract_table_figure(True, zoomin, True, True) - self._concat_downward() - - sections = [(b["text"], self.get_position(b, zoomin)) - for i, b in enumerate(self.boxes)] - return [(txt, "") for txt, _ in sorted(sections, key=lambda x: ( - x[-1][0][0], x[-1][0][3], x[-1][0][1]))], tbls - - -def chunk(filename, binary=None, from_page=0, to_page=100000, - lang="Chinese", callback=None, **kwargs): - """ - Supported file formats are docx, pdf, excel, txt. - One file forms a chunk which maintains original text order. - """ - parser_config = kwargs.get( - "parser_config", { - "chunk_token_num": 512, "delimiter": "\n!?。;!?", "layout_recognize": "DeepDOC"}) - eng = lang.lower() == "english" # is_english(cks) - - if re.search(r"\.docx$", filename, re.IGNORECASE): - callback(0.1, "Start to parse.") - sections, tbls = naive.Docx()(filename, binary) - tbls=vision_figure_parser_docx_wrapper(sections=sections,tbls=tbls,callback=callback,**kwargs) - sections = [s for s, _ in sections if s] - for (_, html), _ in tbls: - sections.append(html) - callback(0.8, "Finish parsing.") - - elif re.search(r"\.pdf$", filename, re.IGNORECASE): - layout_recognizer = parser_config.get("layout_recognize", "DeepDOC") - - if isinstance(layout_recognizer, bool): - layout_recognizer = "DeepDOC" if layout_recognizer else "Plain Text" - - name = layout_recognizer.strip().lower() - parser = PARSERS.get(name, by_plaintext) - callback(0.1, "Start to parse.") - - sections, tbls, pdf_parser = parser( - filename = filename, - binary = binary, - from_page = from_page, - to_page = to_page, - lang = lang, - callback = callback, - pdf_cls = Pdf, - **kwargs - ) - - if not sections and not tbls: - return [] - - if name in ["tcadp", "docling", "mineru"]: - parser_config["chunk_token_num"] = 0 - - callback(0.8, "Finish parsing.") - - for (img, rows), poss in tbls: - if not rows: - continue - sections.append((rows if isinstance(rows, str) else rows[0], - [(p[0] + 1 - from_page, p[1], p[2], p[3], p[4]) for p in poss])) - sections = [s for s, _ in sections if s] - - elif re.search(r"\.xlsx?$", filename, re.IGNORECASE): - callback(0.1, "Start to parse.") - excel_parser = ExcelParser() - sections = excel_parser.html(binary, 1000000000) - - elif re.search(r"\.(txt|md|markdown)$", filename, re.IGNORECASE): - callback(0.1, "Start to parse.") - txt = get_text(filename, binary) - sections = txt.split("\n") - sections = [s for s in sections if s] - callback(0.8, "Finish parsing.") - - elif re.search(r"\.(htm|html)$", filename, re.IGNORECASE): - callback(0.1, "Start to parse.") - sections = HtmlParser()(filename, binary) - sections = [s for s in sections if s] - callback(0.8, "Finish parsing.") - - elif re.search(r"\.doc$", filename, re.IGNORECASE): - callback(0.1, "Start to parse.") - binary = BytesIO(binary) - doc_parsed = parser.from_buffer(binary) - sections = doc_parsed['content'].split('\n') - sections = [s for s in sections if s] - callback(0.8, "Finish parsing.") - - else: - raise NotImplementedError( - "file type not supported yet(doc, docx, pdf, txt supported)") - - doc = { - "docnm_kwd": filename, - "title_tks": rag_tokenizer.tokenize(re.sub(r"\.[a-zA-Z]+$", "", filename)) - } - doc["title_sm_tks"] = rag_tokenizer.fine_grained_tokenize(doc["title_tks"]) - tokenize(doc, "\n".join(sections), eng) - return [doc] - - -if __name__ == "__main__": - import sys - - def dummy(prog=None, msg=""): - pass - - chunk(sys.argv[1], from_page=0, to_page=10, callback=dummy) diff --git a/app/core/rag/app/paper.py b/app/core/rag/app/paper.py deleted file mode 100644 index c2c63824..00000000 --- a/app/core/rag/app/paper.py +++ /dev/null @@ -1,284 +0,0 @@ -import logging -import copy -import re - -from app.core.rag.deepdoc.parser.figure_parser import vision_figure_parser_pdf_wrapper -from app.core.rag.common.constants import ParserType -from app.core.rag.nlp import rag_tokenizer, tokenize, tokenize_table, add_positions, bullets_category, title_frequency, tokenize_chunks -from app.core.rag.deepdoc.parser import PdfParser, PlainParser -import numpy as np - -class Pdf(PdfParser): - def __init__(self): - self.model_speciess = ParserType.PAPER.value - super().__init__() - - def __call__(self, filename, binary=None, from_page=0, - to_page=100000, zoomin=3, callback=None): - from timeit import default_timer as timer - start = timer() - callback(msg="OCR started") - self.__images__( - filename if not binary else binary, - zoomin, - from_page, - to_page, - callback - ) - callback(msg="OCR finished ({:.2f}s)".format(timer() - start)) - - start = timer() - self._layouts_rec(zoomin) - callback(0.63, "Layout analysis ({:.2f}s)".format(timer() - start)) - logging.debug(f"layouts cost: {timer() - start}s") - - start = timer() - self._table_transformer_job(zoomin) - callback(0.68, "Table analysis ({:.2f}s)".format(timer() - start)) - - start = timer() - self._text_merge() - tbls = self._extract_table_figure(True, zoomin, True, True) - column_width = np.median([b["x1"] - b["x0"] for b in self.boxes]) - self._concat_downward() - self._filter_forpages() - callback(0.75, "Text merged ({:.2f}s)".format(timer() - start)) - - # clean mess - if column_width < self.page_images[0].size[0] / zoomin / 2: - logging.debug("two_column................... {} {}".format(column_width, - self.page_images[0].size[0] / zoomin / 2)) - self.boxes = self.sort_X_by_page(self.boxes, column_width / 2) - for b in self.boxes: - b["text"] = re.sub(r"([\t  ]|\u3000){2,}", " ", b["text"].strip()) - - def _begin(txt): - return re.match( - "[0-9. 一、i]*(introduction|abstract|摘要|引言|keywords|key words|关键词|background|背景|目录|前言|contents)", - txt.lower().strip()) - - if from_page > 0: - return { - "title": "", - "authors": "", - "abstract": "", - "sections": [(b["text"] + self._line_tag(b, zoomin), b.get("layoutno", "")) for b in self.boxes if - re.match(r"(text|title)", b.get("layoutno", "text"))], - "tables": tbls - } - # get title and authors - title = "" - authors = [] - i = 0 - while i < min(32, len(self.boxes)-1): - b = self.boxes[i] - i += 1 - if b.get("layoutno", "").find("title") >= 0: - title = b["text"] - if _begin(title): - title = "" - break - for j in range(3): - if _begin(self.boxes[i + j]["text"]): - break - authors.append(self.boxes[i + j]["text"]) - break - break - # get abstract - abstr = "" - i = 0 - while i + 1 < min(32, len(self.boxes)): - b = self.boxes[i] - i += 1 - txt = b["text"].lower().strip() - if re.match("(abstract|摘要)", txt): - if len(txt.split()) > 32 or len(txt) > 64: - abstr = txt + self._line_tag(b, zoomin) - break - txt = self.boxes[i]["text"].lower().strip() - if len(txt.split()) > 32 or len(txt) > 64: - abstr = txt + self._line_tag(self.boxes[i], zoomin) - i += 1 - break - if not abstr: - i = 0 - - callback( - 0.8, "Page {}~{}: Text merging finished".format( - from_page, min( - to_page, self.total_page))) - for b in self.boxes: - logging.debug("{} {}".format(b["text"], b.get("layoutno"))) - logging.debug("{}".format(tbls)) - - return { - "title": title, - "authors": " ".join(authors), - "abstract": abstr, - "sections": [(b["text"] + self._line_tag(b, zoomin), b.get("layoutno", "")) for b in self.boxes[i:] if - re.match(r"(text|title)", b.get("layoutno", "text"))], - "tables": tbls - } - - -def chunk(filename, binary=None, from_page=0, to_page=100000, - lang="Chinese", callback=None, **kwargs): - """ - Only pdf is supported. - The abstract of the paper will be sliced as an entire chunk, and will not be sliced partly. - """ - parser_config = kwargs.get( - "parser_config", { - "chunk_token_num": 512, "delimiter": "\n!?。;!?", "layout_recognize": "DeepDOC"}) - if re.search(r"\.pdf$", filename, re.IGNORECASE): - if parser_config.get("layout_recognize", "DeepDOC") == "Plain Text": - pdf_parser = PlainParser() - paper = { - "title": filename, - "authors": " ", - "abstract": "", - "sections": pdf_parser(filename if not binary else binary, from_page=from_page, to_page=to_page)[0], - "tables": [] - } - else: - pdf_parser = Pdf() - paper = pdf_parser(filename if not binary else binary, - from_page=from_page, to_page=to_page, callback=callback) - tbls=paper["tables"] - tbls=vision_figure_parser_pdf_wrapper(tbls=tbls,callback=callback,**kwargs) - paper["tables"] = tbls - else: - raise NotImplementedError("file type not supported yet(pdf supported)") - - doc = {"docnm_kwd": filename, "authors_tks": rag_tokenizer.tokenize(paper["authors"]), - "title_tks": rag_tokenizer.tokenize(paper["title"] if paper["title"] else filename)} - doc["title_sm_tks"] = rag_tokenizer.fine_grained_tokenize(doc["title_tks"]) - doc["authors_sm_tks"] = rag_tokenizer.fine_grained_tokenize(doc["authors_tks"]) - # is it English - eng = lang.lower() == "english" # pdf_parser.is_english - logging.debug("It's English.....{}".format(eng)) - - res = tokenize_table(paper["tables"], doc, eng) - - if paper["abstract"]: - d = copy.deepcopy(doc) - txt = pdf_parser.remove_tag(paper["abstract"]) - d["important_kwd"] = ["abstract", "总结", "概括", "summary", "summarize"] - d["important_tks"] = " ".join(d["important_kwd"]) - d["image"], poss = pdf_parser.crop( - paper["abstract"], need_position=True) - add_positions(d, poss) - tokenize(d, txt, eng) - res.append(d) - - sorted_sections = paper["sections"] - # set pivot using the most frequent type of title, - # then merge between 2 pivot - bull = bullets_category([txt for txt, _ in sorted_sections]) - most_level, levels = title_frequency(bull, sorted_sections) - assert len(sorted_sections) == len(levels) - sec_ids = [] - sid = 0 - for i, lvl in enumerate(levels): - if lvl <= most_level and i > 0 and lvl != levels[i - 1]: - sid += 1 - sec_ids.append(sid) - logging.debug("{} {} {} {}".format(lvl, sorted_sections[i][0], most_level, sid)) - - chunks = [] - last_sid = -2 - for (txt, _), sec_id in zip(sorted_sections, sec_ids): - if sec_id == last_sid: - if chunks: - chunks[-1] += "\n" + txt - continue - chunks.append(txt) - last_sid = sec_id - res.extend(tokenize_chunks(chunks, doc, eng, pdf_parser)) - return res - - -""" - readed = [0] * len(paper["lines"]) - # find colon firstly - i = 0 - while i + 1 < len(paper["lines"]): - txt = pdf_parser.remove_tag(paper["lines"][i][0]) - j = i - if txt.strip("\n").strip()[-1] not in "::": - i += 1 - continue - i += 1 - while i < len(paper["lines"]) and not paper["lines"][i][0]: - i += 1 - if i >= len(paper["lines"]): break - proj = [paper["lines"][i][0].strip()] - i += 1 - while i < len(paper["lines"]) and paper["lines"][i][0].strip()[0] == proj[-1][0]: - proj.append(paper["lines"][i]) - i += 1 - for k in range(j, i): readed[k] = True - txt = txt[::-1] - if eng: - r = re.search(r"(.*?) ([\\.;?!]|$)", txt) - txt = r.group(1)[::-1] if r else txt[::-1] - else: - r = re.search(r"(.*?) ([。?;!]|$)", txt) - txt = r.group(1)[::-1] if r else txt[::-1] - for p in proj: - d = copy.deepcopy(doc) - txt += "\n" + pdf_parser.remove_tag(p) - d["image"], poss = pdf_parser.crop(p, need_position=True) - add_positions(d, poss) - tokenize(d, txt, eng) - res.append(d) - - i = 0 - chunk = [] - tk_cnt = 0 - def add_chunk(): - nonlocal chunk, res, doc, pdf_parser, tk_cnt - d = copy.deepcopy(doc) - ck = "\n".join(chunk) - tokenize(d, pdf_parser.remove_tag(ck), pdf_parser.is_english) - d["image"], poss = pdf_parser.crop(ck, need_position=True) - add_positions(d, poss) - res.append(d) - chunk = [] - tk_cnt = 0 - - while i < len(paper["lines"]): - if tk_cnt > 128: - add_chunk() - if readed[i]: - i += 1 - continue - readed[i] = True - txt, layouts = paper["lines"][i] - txt_ = pdf_parser.remove_tag(txt) - i += 1 - cnt = num_tokens_from_string(txt_) - if any([ - layouts.find("title") >= 0 and chunk, - cnt + tk_cnt > 128 and tk_cnt > 32, - ]): - add_chunk() - chunk = [txt] - tk_cnt = cnt - else: - chunk.append(txt) - tk_cnt += cnt - - if chunk: add_chunk() - for i, d in enumerate(res): - print(d) - # d["image"].save(f"./logs/{i}.jpg") - return res -""" - -if __name__ == "__main__": - import sys - - def dummy(prog=None, msg=""): - pass - chunk(sys.argv[1], callback=dummy) diff --git a/app/core/rag/app/picture.py b/app/core/rag/app/picture.py deleted file mode 100644 index addc7d9b..00000000 --- a/app/core/rag/app/picture.py +++ /dev/null @@ -1,96 +0,0 @@ -import io -import re - -import numpy as np -from PIL import Image - -from app.core.rag.deepdoc.vision import OCR -from app.core.rag.nlp import rag_tokenizer, tokenize -from app.core.rag.common.string_utils import clean_markdown_block - -ocr = OCR() - -# Gemini supported MIME types -VIDEO_EXTS = [".mp4", ".mov", ".avi", ".flv", ".mpeg", ".mpg", ".webm", ".wmv", ".3gp", ".3gpp", ".mkv"] - - -def chunk(filename, binary, lang, callback=None, vision_model=None, **kwargs): - doc = { - "docnm_kwd": filename, - "title_tks": rag_tokenizer.tokenize(re.sub(r"\.[a-zA-Z]+$", "", filename)), - } - eng = lang.lower() == "english" - - if any(filename.lower().endswith(ext) for ext in VIDEO_EXTS): - try: - doc.update({"doc_type_kwd": "video"}) - ans = vision_model.chat(system="", history=[], gen_conf={}, video_bytes=binary, filename=filename) - callback(0.8, "CV LLM respond: %s ..." % ans[:32]) - ans += "\n" + ans - tokenize(doc, ans, eng) - return [doc] - except Exception as e: - callback(prog=-1, msg=str(e)) - else: - img = Image.open(io.BytesIO(binary)).convert("RGB") - doc.update( - { - "image": img, - "doc_type_kwd": "image", - } - ) - bxs = ocr(np.array(img)) - txt = "\n".join([t[0] for _, t in bxs if t[0]]) - callback(0.4, "Finish OCR: (%s ...)" % txt[:12]) - if (eng and len(txt.split()) > 32) or len(txt) > 32: - tokenize(doc, txt, eng) - callback(0.8, "OCR results is too long to use CV LLM.") - return [doc] - - try: - callback(0.4, "Use CV LLM to describe the picture.") - img_binary = io.BytesIO() - img.save(img_binary, format="JPEG") - img_binary.seek(0) - ans = vision_model.describe(img_binary.read()) - callback(0.8, "CV LLM respond: %s ..." % ans[:32]) - txt += "\n" + ans - tokenize(doc, txt, eng) - return [doc] - except Exception as e: - callback(prog=-1, msg=str(e)) - - return [] - - -def vision_llm_chunk(binary, vision_model, prompt=None, callback=None): - """ - A simple wrapper to process image to markdown texts via VLM. - - Returns: - Simple markdown texts generated by VLM. - """ - callback = callback or (lambda prog, msg: None) - - img = binary - txt = "" - - try: - with io.BytesIO() as img_binary: - try: - img.save(img_binary, format="JPEG") - except Exception: - img_binary.seek(0) - img_binary.truncate() - img.save(img_binary, format="PNG") - - img_binary.seek(0) - description, token_count = vision_model.describe_with_prompt(img_binary.read(), prompt) - ans = clean_markdown_block(description) - txt += "\n" + ans - return txt - - except Exception as e: - callback(-1, str(e)) - - return "" diff --git a/app/core/rag/app/presentation.py b/app/core/rag/app/presentation.py deleted file mode 100644 index d7b23d66..00000000 --- a/app/core/rag/app/presentation.py +++ /dev/null @@ -1,164 +0,0 @@ -import copy -import re -from io import BytesIO -from PIL import Image - -from app.core.rag.nlp import tokenize, is_english -from app.core.rag.nlp import rag_tokenizer -from app.core.rag.deepdoc.parser import PdfParser, PptParser, PlainParser -from PyPDF2 import PdfReader as pdf2_read -from app.core.rag.app.naive import by_plaintext, PARSERS - -class Ppt(PptParser): - def __call__(self, fnm, from_page, to_page, callback=None): - txts = super().__call__(fnm, from_page, to_page) - - callback(0.5, "Text extraction finished.") - import aspose.slides as slides - import aspose.pydrawing as drawing - imgs = [] - with slides.Presentation(BytesIO(fnm)) as presentation: - for i, slide in enumerate(presentation.slides[from_page: to_page]): - try: - with BytesIO() as buffered: - slide.get_thumbnail( - 0.1, 0.1).save( - buffered, drawing.imaging.ImageFormat.jpeg) - buffered.seek(0) - imgs.append(Image.open(buffered).copy()) - except RuntimeError as e: - raise RuntimeError(f'ppt parse error at page {i+1}, original error: {str(e)}') from e - assert len(imgs) == len( - txts), "Slides text and image do not match: {} vs. {}".format(len(imgs), len(txts)) - callback(0.9, "Image extraction finished") - self.is_english = is_english(txts) - return [(txts[i], imgs[i]) for i in range(len(txts))] - -class Pdf(PdfParser): - def __init__(self): - super().__init__() - - def __garbage(self, txt): - txt = txt.lower().strip() - if re.match(r"[0-9\.,%/-]+$", txt): - return True - if len(txt) < 3: - return True - return False - - def __call__(self, filename, binary=None, from_page=0, - to_page=100000, zoomin=3, callback=None): - from timeit import default_timer as timer - start = timer() - callback(msg="OCR started") - self.__images__(filename if not binary else binary, - zoomin, from_page, to_page, callback) - callback(msg="Page {}~{}: OCR finished ({:.2f}s)".format(from_page, min(to_page, self.total_page), timer() - start)) - assert len(self.boxes) == len(self.page_images), "{} vs. {}".format( - len(self.boxes), len(self.page_images)) - res = [] - for i in range(len(self.boxes)): - lines = "\n".join([b["text"] for b in self.boxes[i] - if not self.__garbage(b["text"])]) - res.append((lines, self.page_images[i])) - callback(0.9, "Page {}~{}: Parsing finished".format( - from_page, min(to_page, self.total_page))) - return res, [] - - -class PlainPdf(PlainParser): - def __call__(self, filename, binary=None, from_page=0, - to_page=100000, callback=None, **kwargs): - self.pdf = pdf2_read(filename if not binary else BytesIO(binary)) - page_txt = [] - for page in self.pdf.pages[from_page: to_page]: - page_txt.append(page.extract_text()) - callback(0.9, "Parsing finished") - return [(txt, None) for txt in page_txt], [] - - -def chunk(filename, binary=None, from_page=0, to_page=100000, - lang="Chinese", callback=None, vision_model=None, parser_config=None, **kwargs): - """ - The supported file formats are pdf, pptx. - Every page will be treated as a chunk. And the thumbnail of every page will be stored. - PPT file will be parsed by using this method automatically, setting-up for every PPT file is not necessary. - """ - if parser_config is None: - parser_config = {} - eng = lang.lower() == "english" - doc = { - "docnm_kwd": filename, - "title_tks": rag_tokenizer.tokenize(re.sub(r"\.[a-zA-Z]+$", "", filename)) - } - doc["title_sm_tks"] = rag_tokenizer.fine_grained_tokenize(doc["title_tks"]) - res = [] - if re.search(r"\.pptx?$", filename, re.IGNORECASE): - if not binary: - with open(filename, "rb") as f: - binary = f.read() - ppt_parser = Ppt() - for pn, (txt, img) in enumerate(ppt_parser( - filename if not binary else binary, from_page, 1000000, callback)): - d = copy.deepcopy(doc) - pn += from_page - d["image"] = img - d["doc_type_kwd"] = "image" - d["page_num_int"] = [pn + 1] - d["top_int"] = [0] - d["position_int"] = [(pn + 1, 0, img.size[0], 0, img.size[1])] - tokenize(d, txt, eng) - res.append(d) - return res - elif re.search(r"\.pdf$", filename, re.IGNORECASE): - layout_recognizer = parser_config.get("layout_recognize", "DeepDOC") - - if isinstance(layout_recognizer, bool): - layout_recognizer = "DeepDOC" if layout_recognizer else "Plain Text" - - name = layout_recognizer.strip().lower() - parser = PARSERS.get(name, by_plaintext) - callback(0.1, "Start to parse.") - - sections, _, _ = parser( - filename=filename, - binary=binary, - from_page=from_page, - to_page=to_page, - lang=lang, - callback=callback, - vision_model=vision_model, - pdf_cls=Pdf, - **kwargs - ) - - if not sections: - return [] - - if name in ["tcadp", "docling", "mineru"]: - parser_config["chunk_token_num"] = 0 - - callback(0.8, "Finish parsing.") - - for pn, (txt, img) in enumerate(sections): - d = copy.deepcopy(doc) - pn += from_page - if img: - d["image"] = img - d["page_num_int"] = [pn + 1] - d["top_int"] = [0] - d["position_int"] = [(pn + 1, 0, img.size[0] if img else 0, 0, img.size[1] if img else 0)] - tokenize(d, txt, eng) - res.append(d) - return res - - raise NotImplementedError( - "file type not supported yet(pptx, pdf supported)") - - -if __name__ == "__main__": - import sys - - def dummy(a, b): - pass - chunk(sys.argv[1], callback=dummy) diff --git a/app/core/rag/app/qa.py b/app/core/rag/app/qa.py deleted file mode 100644 index 248dbe0a..00000000 --- a/app/core/rag/app/qa.py +++ /dev/null @@ -1,455 +0,0 @@ -import logging -import re -import csv -from copy import deepcopy -from io import BytesIO -from timeit import default_timer as timer -from openpyxl import load_workbook - -from app.core.rag.deepdoc.parser.utils import get_text -from app.core.rag.nlp import is_english, random_choices, qbullets_category, add_positions, has_qbullet, docx_question_level -from app.core.rag.nlp import rag_tokenizer, tokenize_table, concat_img -from app.core.rag.deepdoc.parser import PdfParser, ExcelParser, DocxParser -from docx import Document -from PIL import Image -from markdown import markdown - -from app.core.rag.common.float_utils import get_float - - -class Excel(ExcelParser): - def __call__(self, fnm, binary=None, callback=None): - if not binary: - wb = load_workbook(fnm) - else: - wb = load_workbook(BytesIO(binary)) - total = 0 - for sheetname in wb.sheetnames: - total += len(list(wb[sheetname].rows)) - - res, fails = [], [] - for sheetname in wb.sheetnames: - ws = wb[sheetname] - rows = list(ws.rows) - for i, r in enumerate(rows): - q, a = "", "" - for cell in r: - if not cell.value: - continue - if not q: - q = str(cell.value) - elif not a: - a = str(cell.value) - else: - break - if q and a: - res.append((q, a)) - else: - fails.append(str(i + 1)) - if len(res) % 999 == 0: - callback(len(res) * - 0.6 / - total, ("Extract pairs: {}".format(len(res)) + - (f"{len(fails)} failure, line: %s..." % - (",".join(fails[:3])) if fails else ""))) - - callback(0.6, ("Extract pairs: {}. ".format(len(res)) + ( - f"{len(fails)} failure, line: %s..." % (",".join(fails[:3])) if fails else ""))) - self.is_english = is_english( - [rmPrefix(q) for q, _ in random_choices(res, k=30) if len(q) > 1]) - return res - - -class Pdf(PdfParser): - def __call__(self, filename, binary=None, from_page=0, - to_page=100000, zoomin=3, callback=None): - start = timer() - callback(msg="OCR started") - self.__images__( - filename if not binary else binary, - zoomin, - from_page, - to_page, - callback - ) - callback(msg="OCR finished ({:.2f}s)".format(timer() - start)) - logging.debug("OCR({}~{}): {:.2f}s".format(from_page, to_page, timer() - start)) - start = timer() - self._layouts_rec(zoomin, drop=False) - callback(0.63, "Layout analysis ({:.2f}s)".format(timer() - start)) - - start = timer() - self._table_transformer_job(zoomin) - callback(0.65, "Table analysis ({:.2f}s)".format(timer() - start)) - - start = timer() - self._text_merge() - callback(0.67, "Text merged ({:.2f}s)".format(timer() - start)) - tbls = self._extract_table_figure(True, zoomin, True, True) - #self._naive_vertical_merge() - # self._concat_downward() - #self._filter_forpages() - logging.debug("layouts: {}".format(timer() - start)) - sections = [b["text"] for b in self.boxes] - bull_x0_list = [] - q_bull, reg = qbullets_category(sections) - if q_bull == -1: - raise ValueError("Unable to recognize Q&A structure.") - qai_list = [] - last_q, last_a, last_tag = '', '', '' - last_index = -1 - last_box = {'text':''} - last_bull = None - def sort_key(element): - tbls_pn = element[1][0][0] - tbls_top = element[1][0][3] - return tbls_pn, tbls_top - tbls.sort(key=sort_key) - tbl_index = 0 - last_pn, last_bottom = 0, 0 - tbl_pn, tbl_left, tbl_right, tbl_top, tbl_bottom, tbl_tag, tbl_text = 1, 0, 0, 0, 0, '@@0\t0\t0\t0\t0##', '' - for box in self.boxes: - section, line_tag = box['text'], self._line_tag(box, zoomin) - has_bull, index = has_qbullet(reg, box, last_box, last_index, last_bull, bull_x0_list) - last_box, last_index, last_bull = box, index, has_bull - line_pn = get_float(line_tag.lstrip('@@').split('\t')[0]) - line_top = get_float(line_tag.rstrip('##').split('\t')[3]) - tbl_pn, tbl_left, tbl_right, tbl_top, tbl_bottom, tbl_tag, tbl_text = self.get_tbls_info(tbls, tbl_index) - if not has_bull: # No question bullet - if not last_q: - if tbl_pn < line_pn or (tbl_pn == line_pn and tbl_top <= line_top): # image passed - tbl_index += 1 - continue - else: - sum_tag = line_tag - sum_section = section - while ((tbl_pn == last_pn and tbl_top>= last_bottom) or (tbl_pn > last_pn)) \ - and ((tbl_pn == line_pn and tbl_top <= line_top) or (tbl_pn < line_pn)): # add image at the middle of current answer - sum_tag = f'{tbl_tag}{sum_tag}' - sum_section = f'{tbl_text}{sum_section}' - tbl_index += 1 - tbl_pn, tbl_left, tbl_right, tbl_top, tbl_bottom, tbl_tag, tbl_text = self.get_tbls_info(tbls, tbl_index) - last_a = f'{last_a}{sum_section}' - last_tag = f'{last_tag}{sum_tag}' - else: - if last_q: - while ((tbl_pn == last_pn and tbl_top>= last_bottom) or (tbl_pn > last_pn)) \ - and ((tbl_pn == line_pn and tbl_top <= line_top) or (tbl_pn < line_pn)): # add image at the end of last answer - last_tag = f'{last_tag}{tbl_tag}' - last_a = f'{last_a}{tbl_text}' - tbl_index += 1 - tbl_pn, tbl_left, tbl_right, tbl_top, tbl_bottom, tbl_tag, tbl_text = self.get_tbls_info(tbls, tbl_index) - image, poss = self.crop(last_tag, need_position=True) - qai_list.append((last_q, last_a, image, poss)) - last_q, last_a, last_tag = '', '', '' - last_q = has_bull.group() - _, end = has_bull.span() - last_a = section[end:] - last_tag = line_tag - last_bottom = float(line_tag.rstrip('##').split('\t')[4]) - last_pn = line_pn - if last_q: - qai_list.append((last_q, last_a, *self.crop(last_tag, need_position=True))) - return qai_list, tbls - - def get_tbls_info(self, tbls, tbl_index): - if tbl_index >= len(tbls): - return 1, 0, 0, 0, 0, '@@0\t0\t0\t0\t0##', '' - tbl_pn = tbls[tbl_index][1][0][0]+1 - tbl_left = tbls[tbl_index][1][0][1] - tbl_right = tbls[tbl_index][1][0][2] - tbl_top = tbls[tbl_index][1][0][3] - tbl_bottom = tbls[tbl_index][1][0][4] - tbl_tag = "@@{}\t{:.1f}\t{:.1f}\t{:.1f}\t{:.1f}##" \ - .format(tbl_pn, tbl_left, tbl_right, tbl_top, tbl_bottom) - _tbl_text = ''.join(tbls[tbl_index][0][1]) - return tbl_pn, tbl_left, tbl_right, tbl_top, tbl_bottom, tbl_tag, _tbl_text - - -class Docx(DocxParser): - def __init__(self): - pass - - def get_picture(self, document, paragraph): - img = paragraph._element.xpath('.//pic:pic') - if not img: - return None - img = img[0] - embed = img.xpath('.//a:blip/@r:embed')[0] - related_part = document.part.related_parts[embed] - image = related_part.image - image = Image.open(BytesIO(image.blob)).convert('RGB') - return image - - def __call__(self, filename, binary=None, from_page=0, to_page=100000, callback=None): - self.doc = Document( - filename) if not binary else Document(BytesIO(binary)) - pn = 0 - last_answer, last_image = "", None - question_stack, level_stack = [], [] - qai_list = [] - for p in self.doc.paragraphs: - if pn > to_page: - break - question_level, p_text = 0, '' - if from_page <= pn < to_page and p.text.strip(): - question_level, p_text = docx_question_level(p) - if not question_level or question_level > 6: # not a question - last_answer = f'{last_answer}\n{p_text}' - current_image = self.get_picture(self.doc, p) - last_image = concat_img(last_image, current_image) - else: # is a question - if last_answer or last_image: - sum_question = '\n'.join(question_stack) - if sum_question: - qai_list.append((sum_question, last_answer, last_image)) - last_answer, last_image = '', None - - i = question_level - while question_stack and i <= level_stack[-1]: - question_stack.pop() - level_stack.pop() - question_stack.append(p_text) - level_stack.append(question_level) - for run in p.runs: - if 'lastRenderedPageBreak' in run._element.xml: - pn += 1 - continue - if 'w:br' in run._element.xml and 'type="page"' in run._element.xml: - pn += 1 - if last_answer: - sum_question = '\n'.join(question_stack) - if sum_question: - qai_list.append((sum_question, last_answer, last_image)) - - tbls = [] - for tb in self.doc.tables: - html= "" - for r in tb.rows: - html += "" - i = 0 - while i < len(r.cells): - span = 1 - c = r.cells[i] - for j in range(i+1, len(r.cells)): - if c.text == r.cells[j].text: - span += 1 - i = j - i += 1 - html += f"" if span == 1 else f"" - html += "" - html += "
{c.text}{c.text}
" - tbls.append(((None, html), "")) - return qai_list, tbls - - -def rmPrefix(txt): - return re.sub( - r"^(问题|答案|回答|user|assistant|Q|A|Question|Answer|问|答)[\t:: ]+", "", txt.strip(), flags=re.IGNORECASE) - - -def beAdocPdf(d, q, a, eng, image, poss): - qprefix = "Question: " if eng else "问题:" - aprefix = "Answer: " if eng else "回答:" - d["content_with_weight"] = "\t".join( - [qprefix + rmPrefix(q), aprefix + rmPrefix(a)]) - d["content_ltks"] = rag_tokenizer.tokenize(q) - d["content_sm_ltks"] = rag_tokenizer.fine_grained_tokenize(d["content_ltks"]) - if image: - d["image"] = image - d["doc_type_kwd"] = "image" - add_positions(d, poss) - return d - - -def beAdocDocx(d, q, a, eng, image, row_num=-1): - qprefix = "Question: " if eng else "问题:" - aprefix = "Answer: " if eng else "回答:" - d["content_with_weight"] = "\t".join( - [qprefix + rmPrefix(q), aprefix + rmPrefix(a)]) - d["content_ltks"] = rag_tokenizer.tokenize(q) - d["content_sm_ltks"] = rag_tokenizer.fine_grained_tokenize(d["content_ltks"]) - if image: - d["image"] = image - d["doc_type_kwd"] = "image" - if row_num >= 0: - d["top_int"] = [row_num] - return d - - -def beAdoc(d, q, a, eng, row_num=-1): - qprefix = "Question: " if eng else "问题:" - aprefix = "Answer: " if eng else "回答:" - d["content_with_weight"] = "\t".join( - [qprefix + rmPrefix(q), aprefix + rmPrefix(a)]) - d["content_ltks"] = rag_tokenizer.tokenize(q) - d["content_sm_ltks"] = rag_tokenizer.fine_grained_tokenize(d["content_ltks"]) - if row_num >= 0: - d["top_int"] = [row_num] - return d - - -def mdQuestionLevel(s): - match = re.match(r'#*', s) - return (len(match.group(0)), s.lstrip('#').lstrip()) if match else (0, s) - - -def chunk(filename, binary=None, from_page=0, to_page=100000, lang="Chinese", callback=None, **kwargs): - """ - Excel and csv(txt) format files are supported. - If the file is in excel format, there should be 2 column question and answer without header. - And question column is ahead of answer column. - And it's O.K if it has multiple sheets as long as the columns are rightly composed. - - If it's in csv format, it should be UTF-8 encoded. Use TAB as delimiter to separate question and answer. - - All the deformed lines will be ignored. - Every pair of Q&A will be treated as a chunk. - """ - eng = lang.lower() == "english" - res = [] - doc = { - "docnm_kwd": filename, - "title_tks": rag_tokenizer.tokenize(re.sub(r"\.[a-zA-Z]+$", "", filename)) - } - if re.search(r"\.xlsx?$", filename, re.IGNORECASE): - callback(0.1, "Start to parse.") - excel_parser = Excel() - for ii, (q, a) in enumerate(excel_parser(filename, binary, callback)): - res.append(beAdoc(deepcopy(doc), q, a, eng, ii)) - return res - - elif re.search(r"\.(txt)$", filename, re.IGNORECASE): - callback(0.1, "Start to parse.") - txt = get_text(filename, binary) - lines = txt.split("\n") - comma, tab = 0, 0 - for line in lines: - if len(line.split(",")) == 2: - comma += 1 - if len(line.split("\t")) == 2: - tab += 1 - delimiter = "\t" if tab >= comma else "," - - fails = [] - question, answer = "", "" - i = 0 - while i < len(lines): - arr = lines[i].split(delimiter) - if len(arr) != 2: - if question: - answer += "\n" + lines[i] - else: - fails.append(str(i+1)) - elif len(arr) == 2: - if question and answer: - res.append(beAdoc(deepcopy(doc), question, answer, eng, i)) - question, answer = arr - i += 1 - if len(res) % 999 == 0: - callback(len(res) * 0.6 / len(lines), ("Extract Q&A: {}".format(len(res)) + ( - f"{len(fails)} failure, line: %s..." % (",".join(fails[:3])) if fails else ""))) - - if question: - res.append(beAdoc(deepcopy(doc), question, answer, eng, len(lines))) - - callback(0.6, ("Extract Q&A: {}".format(len(res)) + ( - f"{len(fails)} failure, line: %s..." % (",".join(fails[:3])) if fails else ""))) - - return res - - elif re.search(r"\.(csv)$", filename, re.IGNORECASE): - callback(0.1, "Start to parse.") - txt = get_text(filename, binary) - lines = txt.split("\n") - delimiter = "\t" if any("\t" in line for line in lines) else "," - - fails = [] - question, answer = "", "" - res = [] - reader = csv.reader(lines, delimiter=delimiter) - - for i, row in enumerate(reader): - if len(row) != 2: - if question: - answer += "\n" + lines[i] - else: - fails.append(str(i + 1)) - elif len(row) == 2: - if question and answer: - res.append(beAdoc(deepcopy(doc), question, answer, eng, i)) - question, answer = row - if len(res) % 999 == 0: - callback(len(res) * 0.6 / len(lines), ("Extract Q&A: {}".format(len(res)) + ( - f"{len(fails)} failure, line: %s..." % (",".join(fails[:3])) if fails else ""))) - - if question: - res.append(beAdoc(deepcopy(doc), question, answer, eng, len(list(reader)))) - - callback(0.6, ("Extract Q&A: {}".format(len(res)) + ( - f"{len(fails)} failure, line: %s..." % (",".join(fails[:3])) if fails else ""))) - return res - - elif re.search(r"\.pdf$", filename, re.IGNORECASE): - callback(0.1, "Start to parse.") - pdf_parser = Pdf() - qai_list, tbls = pdf_parser(filename if not binary else binary, - from_page=from_page, to_page=to_page, callback=callback) - for q, a, image, poss in qai_list: - res.append(beAdocPdf(deepcopy(doc), q, a, eng, image, poss)) - return res - - elif re.search(r"\.(md|markdown)$", filename, re.IGNORECASE): - callback(0.1, "Start to parse.") - txt = get_text(filename, binary) - lines = txt.split("\n") - _last_question, last_answer = "", "" - question_stack, level_stack = [], [] - code_block = False - for index, line in enumerate(lines): - if line.strip().startswith('```'): - code_block = not code_block - question_level, question = 0, '' - if not code_block: - question_level, question = mdQuestionLevel(line) - - if not question_level or question_level > 6: # not a question - last_answer = f'{last_answer}\n{line}' - else: # is a question - if last_answer.strip(): - sum_question = '\n'.join(question_stack) - if sum_question: - res.append(beAdoc(deepcopy(doc), sum_question, markdown(last_answer, extensions=['markdown.extensions.tables']), eng, index)) - last_answer = '' - - i = question_level - while question_stack and i <= level_stack[-1]: - question_stack.pop() - level_stack.pop() - question_stack.append(question) - level_stack.append(question_level) - if last_answer.strip(): - sum_question = '\n'.join(question_stack) - if sum_question: - res.append(beAdoc(deepcopy(doc), sum_question, markdown(last_answer, extensions=['markdown.extensions.tables']), eng, index)) - return res - - elif re.search(r"\.docx$", filename, re.IGNORECASE): - docx_parser = Docx() - qai_list, tbls = docx_parser(filename, binary, - from_page=0, to_page=10000, callback=callback) - res = tokenize_table(tbls, doc, eng) - for i, (q, a, image) in enumerate(qai_list): - res.append(beAdocDocx(deepcopy(doc), q, a, eng, image, i)) - return res - - raise NotImplementedError( - "Excel, csv(txt), pdf, markdown and docx format files are supported.") - - -if __name__ == "__main__": - import sys - - def dummy(prog=None, msg=""): - pass - chunk(sys.argv[1], from_page=0, to_page=10, callback=dummy) diff --git a/app/core/rag/common/__init__.py b/app/core/rag/common/__init__.py deleted file mode 100644 index e69de29b..00000000 diff --git a/app/core/rag/common/connection_utils.py b/app/core/rag/common/connection_utils.py deleted file mode 100644 index 349caa27..00000000 --- a/app/core/rag/common/connection_utils.py +++ /dev/null @@ -1,106 +0,0 @@ -import os -import queue -import threading -from typing import Any, Callable, Coroutine, Optional, Type, Union -import asyncio -import trio -from functools import wraps -from flask import make_response, jsonify -from .constants import RetCode - -TimeoutException = Union[Type[BaseException], BaseException] -OnTimeoutCallback = Union[Callable[..., Any], Coroutine[Any, Any, Any]] - - -def timeout(seconds: float | int | str = None, attempts: int = 2, *, exception: Optional[TimeoutException] = None, - on_timeout: Optional[OnTimeoutCallback] = None): - if isinstance(seconds, str): - seconds = float(seconds) - - def decorator(func): - @wraps(func) - def wrapper(*args, **kwargs): - result_queue = queue.Queue(maxsize=1) - - def target(): - try: - result = func(*args, **kwargs) - result_queue.put(result) - except Exception as e: - result_queue.put(e) - - thread = threading.Thread(target=target) - thread.daemon = True - thread.start() - - for a in range(attempts): - try: - if os.environ.get("ENABLE_TIMEOUT_ASSERTION"): - result = result_queue.get(timeout=seconds) - else: - result = result_queue.get() - if isinstance(result, Exception): - raise result - return result - except queue.Empty: - pass - raise TimeoutError(f"Function '{func.__name__}' timed out after {seconds} seconds and {attempts} attempts.") - - @wraps(func) - async def async_wrapper(*args, **kwargs) -> Any: - if seconds is None: - return await func(*args, **kwargs) - - for a in range(attempts): - try: - if os.environ.get("ENABLE_TIMEOUT_ASSERTION"): - with trio.fail_after(seconds): - return await func(*args, **kwargs) - else: - return await func(*args, **kwargs) - except trio.TooSlowError: - if a < attempts - 1: - continue - if on_timeout is not None: - if callable(on_timeout): - result = on_timeout() - if isinstance(result, Coroutine): - return await result - return result - return on_timeout - - if exception is None: - raise TimeoutError(f"Operation timed out after {seconds} seconds and {attempts} attempts.") - - if isinstance(exception, BaseException): - raise exception - - if isinstance(exception, type) and issubclass(exception, BaseException): - raise exception(f"Operation timed out after {seconds} seconds and {attempts} attempts.") - - raise RuntimeError("Invalid exception type provided") - - if asyncio.iscoroutinefunction(func): - return async_wrapper - return wrapper - - return decorator - - -def construct_response(code=RetCode.SUCCESS, message="success", data=None, auth=None): - result_dict = {"code": code, "message": message, "data": data} - response_dict = {} - for key, value in result_dict.items(): - if value is None and key != "code": - continue - else: - response_dict[key] = value - response = make_response(jsonify(response_dict)) - if auth: - response.headers["Authorization"] = auth - response.headers["Access-Control-Allow-Origin"] = "*" - response.headers["Access-Control-Allow-Method"] = "*" - response.headers["Access-Control-Allow-Headers"] = "*" - response.headers["Access-Control-Allow-Headers"] = "*" - response.headers["Access-Control-Expose-Headers"] = "Authorization" - return response diff --git a/app/core/rag/common/constants.py b/app/core/rag/common/constants.py deleted file mode 100644 index 12d5d6d2..00000000 --- a/app/core/rag/common/constants.py +++ /dev/null @@ -1,180 +0,0 @@ -from enum import Enum, IntEnum -from strenum import StrEnum - -SERVICE_CONF = "service_conf.yaml" -RAG_SERVICE_NAME = "rag" - -class CustomEnum(Enum): - @classmethod - def valid(cls, value): - try: - cls(value) - return True - except BaseException: - return False - - @classmethod - def values(cls): - return [member.value for member in cls.__members__.values()] - - @classmethod - def names(cls): - return [member.name for member in cls.__members__.values()] - - -class RetCode(IntEnum, CustomEnum): - SUCCESS = 0 - NOT_EFFECTIVE = 10 - EXCEPTION_ERROR = 100 - ARGUMENT_ERROR = 101 - DATA_ERROR = 102 - OPERATING_ERROR = 103 - CONNECTION_ERROR = 105 - RUNNING = 106 - PERMISSION_ERROR = 108 - AUTHENTICATION_ERROR = 109 - UNAUTHORIZED = 401 - SERVER_ERROR = 500 - FORBIDDEN = 403 - NOT_FOUND = 404 - - -class StatusEnum(Enum): - VALID = "1" - INVALID = "0" - - -class ActiveEnum(Enum): - ACTIVE = "1" - INACTIVE = "0" - - -class LLMType(StrEnum): - CHAT = 'chat' - EMBEDDING = 'embedding' - SPEECH2TEXT = 'speech2text' - IMAGE2TEXT = 'image2text' - RERANK = 'rerank' - TTS = 'tts' - - -class TaskStatus(StrEnum): - UNSTART = "0" - RUNNING = "1" - CANCEL = "2" - DONE = "3" - FAIL = "4" - SCHEDULE = "5" - - -VALID_TASK_STATUS = {TaskStatus.UNSTART, TaskStatus.RUNNING, TaskStatus.CANCEL, TaskStatus.DONE, TaskStatus.FAIL, - TaskStatus.SCHEDULE} - - -class ParserType(StrEnum): - PRESENTATION = "presentation" - LAWS = "laws" - MANUAL = "manual" - PAPER = "paper" - RESUME = "resume" - BOOK = "book" - QA = "qa" - TABLE = "table" - NAIVE = "naive" - PICTURE = "picture" - ONE = "one" - AUDIO = "audio" - EMAIL = "email" - KG = "knowledge_graph" - TAG = "tag" - - -class FileSource(StrEnum): - LOCAL = "" - KNOWLEDGEBASE = "knowledgebase" - S3 = "s3" - NOTION = "notion" - DISCORD = "discord" - CONFLUENCE = "confluence" - GMAIL = "gmail" - GOOGLE_DRIVE = "google_drive" - JIRA = "jira" - SHAREPOINT = "sharepoint" - SLACK = "slack" - TEAMS = "teams" - - -class PipelineTaskType(StrEnum): - PARSE = "Parse" - DOWNLOAD = "Download" - RAPTOR = "RAPTOR" - GRAPH_RAG = "GraphRAG" - MINDMAP = "Mindmap" - - -VALID_PIPELINE_TASK_TYPES = {PipelineTaskType.PARSE, PipelineTaskType.DOWNLOAD, PipelineTaskType.RAPTOR, - PipelineTaskType.GRAPH_RAG, PipelineTaskType.MINDMAP} - -class MCPServerType(StrEnum): - SSE = "sse" - STREAMABLE_HTTP = "streamable-http" - -VALID_MCP_SERVER_TYPES = {MCPServerType.SSE, MCPServerType.STREAMABLE_HTTP} - -class Storage(Enum): - MINIO = 1 - AZURE_SPN = 2 - AZURE_SAS = 3 - AWS_S3 = 4 - OSS = 5 - OPENDAL = 6 - -# environment -# ENV_STRONG_TEST_COUNT = "STRONG_TEST_COUNT" -# ENV_RAG_SECRET_KEY = "RAG_SECRET_KEY" -# ENV_REGISTER_ENABLED = "REGISTER_ENABLED" -# ENV_DOC_ENGINE = "DOC_ENGINE" -# ENV_SANDBOX_ENABLED = "SANDBOX_ENABLED" -# ENV_SANDBOX_HOST = "SANDBOX_HOST" -# ENV_MAX_CONTENT_LENGTH = "MAX_CONTENT_LENGTH" -# ENV_COMPONENT_EXEC_TIMEOUT = "COMPONENT_EXEC_TIMEOUT" -# ENV_TRINO_USE_TLS = "TRINO_USE_TLS" -# ENV_MAX_FILE_NUM_PER_USER = "MAX_FILE_NUM_PER_USER" -# ENV_MACOS = "MACOS" -# ENV_RAG_DEBUGPY_LISTEN = "RAG_DEBUGPY_LISTEN" -# ENV_WERKZEUG_RUN_MAIN = "WERKZEUG_RUN_MAIN" -# ENV_DISABLE_SDK = "DISABLE_SDK" -# ENV_ENABLE_TIMEOUT_ASSERTION = "ENABLE_TIMEOUT_ASSERTION" -# ENV_LOG_LEVELS = "LOG_LEVELS" -# ENV_TENSORRT_DLA_SVR = "TENSORRT_DLA_SVR" -# ENV_OCR_GPU_MEM_LIMIT_MB = "OCR_GPU_MEM_LIMIT_MB" -# ENV_OCR_ARENA_EXTEND_STRATEGY = "OCR_ARENA_EXTEND_STRATEGY" -# ENV_MAX_CONCURRENT_PROCESS_AND_EXTRACT_CHUNK = "MAX_CONCURRENT_PROCESS_AND_EXTRACT_CHUNK" -# ENV_MAX_MAX_CONCURRENT_CHATS = "MAX_CONCURRENT_CHATS" -# ENV_RAG_MCP_BASE_URL = "RAG_MCP_BASE_URL" -# ENV_RAG_MCP_HOST = "RAG_MCP_HOST" -# ENV_RAG_MCP_PORT = "RAG_MCP_PORT" -# ENV_RAG_MCP_LAUNCH_MODE = "RAG_MCP_LAUNCH_MODE" -# ENV_RAG_MCP_HOST_API_KEY = "RAG_MCP_HOST_API_KEY" -# ENV_MINERU_EXECUTABLE = "MINERU_EXECUTABLE" -# ENV_MINERU_APISERVER = "MINERU_APISERVER" -# ENV_MINERU_OUTPUT_DIR = "MINERU_OUTPUT_DIR" -# ENV_MINERU_BACKEND = "MINERU_BACKEND" -# ENV_MINERU_DELETE_OUTPUT = "MINERU_DELETE_OUTPUT" -# ENV_TCADP_OUTPUT_DIR = "TCADP_OUTPUT_DIR" -# ENV_LM_TIMEOUT_SECONDS = "LM_TIMEOUT_SECONDS" -# ENV_LLM_MAX_RETRIES = "LLM_MAX_RETRIES" -# ENV_LLM_BASE_DELAY = "LLM_BASE_DELAY" -# ENV_OLLAMA_KEEP_ALIVE = "OLLAMA_KEEP_ALIVE" -# ENV_DOC_BULK_SIZE = "DOC_BULK_SIZE" -# ENV_EMBEDDING_BATCH_SIZE = "EMBEDDING_BATCH_SIZE" -# ENV_MAX_CONCURRENT_TASKS = "MAX_CONCURRENT_TASKS" -# ENV_MAX_CONCURRENT_CHUNK_BUILDERS = "MAX_CONCURRENT_CHUNK_BUILDERS" -# ENV_MAX_CONCURRENT_MINIO = "MAX_CONCURRENT_MINIO" -# ENV_WORKER_HEARTBEAT_TIMEOUT = "WORKER_HEARTBEAT_TIMEOUT" -# ENV_TRACE_MALLOC_ENABLED = "TRACE_MALLOC_ENABLED" - -PAGERANK_FLD = "pagerank_fea" -SVR_QUEUE_NAME = "rag_svr_queue" -SVR_CONSUMER_GROUP_NAME = "rag_svr_task_broker" -TAG_FLD = "tag_feas" diff --git a/app/core/rag/common/file_utils.py b/app/core/rag/common/file_utils.py deleted file mode 100644 index 36a033c7..00000000 --- a/app/core/rag/common/file_utils.py +++ /dev/null @@ -1,28 +0,0 @@ -import os - -PROJECT_BASE = os.getenv("RAG_PROJECT_BASE") or os.getenv("RAG_DEPLOY_BASE") - - -def get_project_base_directory(*args): - global PROJECT_BASE - if PROJECT_BASE is None: - PROJECT_BASE = os.path.abspath( - os.path.join( - os.path.dirname(os.path.realpath(__file__)), - os.pardir, - os.pardir, - os.pardir, - os.pardir, - ) - ) - - if args: - return os.path.join(PROJECT_BASE, *args) - return PROJECT_BASE - - -def traversal_files(base): - for root, ds, fs in os.walk(base): - for f in fs: - fullname = os.path.join(root, f) - yield fullname diff --git a/app/core/rag/common/float_utils.py b/app/core/rag/common/float_utils.py deleted file mode 100644 index 583d3b99..00000000 --- a/app/core/rag/common/float_utils.py +++ /dev/null @@ -1,30 +0,0 @@ -def get_float(v): - """ - Convert a value to float, handling None and exceptions gracefully. - - Attempts to convert the input value to a float. If the value is None or - cannot be converted to float, returns negative infinity as a default value. - - Args: - v: The value to convert to float. Can be any type that float() accepts, - or None. - - Returns: - float: The converted float value if successful, otherwise float('-inf'). - - Examples: - >>> get_float("3.14") - 3.14 - >>> get_float(None) - -inf - >>> get_float("invalid") - -inf - >>> get_float(42) - 42.0 - """ - if v is None: - return float('-inf') - try: - return float(v) - except Exception: - return float('-inf') \ No newline at end of file diff --git a/app/core/rag/common/misc_utils.py b/app/core/rag/common/misc_utils.py deleted file mode 100644 index 758c3aaf..00000000 --- a/app/core/rag/common/misc_utils.py +++ /dev/null @@ -1,92 +0,0 @@ -import base64 -import hashlib -import uuid -import requests -import threading -import subprocess -import sys -import os -import logging - -def get_uuid(): - return uuid.uuid1().hex - - -def download_img(url): - if not url: - return "" - response = requests.get(url) - return "data:" + \ - response.headers.get('Content-Type', 'image/jpg') + ";" + \ - "base64," + base64.b64encode(response.content).decode("utf-8") - - -def hash_str2int(line: str, mod: int = 10 ** 8) -> int: - return int(hashlib.sha1(line.encode("utf-8")).hexdigest(), 16) % mod - -def convert_bytes(size_in_bytes: int) -> str: - """ - Format size in bytes. - """ - if size_in_bytes == 0: - return "0 B" - - units = ['B', 'KB', 'MB', 'GB', 'TB', 'PB'] - i = 0 - size = float(size_in_bytes) - - while size >= 1024 and i < len(units) - 1: - size /= 1024 - i += 1 - - if i == 0 or size >= 100: - return f"{size:.0f} {units[i]}" - elif size >= 10: - return f"{size:.1f} {units[i]}" - else: - return f"{size:.2f} {units[i]}" - - -def once(func): - """ - A thread-safe decorator that ensures the decorated function runs exactly once, - caching and returning its result for all subsequent calls. This prevents - race conditions in multi-thread environments by using a lock to protect - the execution state. - - Args: - func (callable): The function to be executed only once. - - Returns: - callable: A wrapper function that executes `func` on the first call - and returns the cached result thereafter. - - Example: - @once - def compute_expensive_value(): - print("Computing...") - return 42 - - # First call: executes and prints - # Subsequent calls: return 42 without executing - """ - executed = False - result = None - lock = threading.Lock() - def wrapper(*args, **kwargs): - nonlocal executed, result - with lock: - if not executed: - executed = True - result = func(*args, **kwargs) - return result - return wrapper - -@once -def pip_install_torch(): - device = os.getenv("DEVICE", "cpu") - if device=="cpu": - return - logging.info("Installing pytorch") - pkg_names = ["torch>=2.5.0,<3.0.0"] - subprocess.check_call([sys.executable, "-m", "pip", "install", *pkg_names]) diff --git a/app/core/rag/common/settings.py b/app/core/rag/common/settings.py deleted file mode 100644 index e04e7aea..00000000 --- a/app/core/rag/common/settings.py +++ /dev/null @@ -1,2 +0,0 @@ -PARALLEL_DEVICES: int = 0 - diff --git a/app/core/rag/common/string_utils.py b/app/core/rag/common/string_utils.py deleted file mode 100644 index 43152f56..00000000 --- a/app/core/rag/common/string_utils.py +++ /dev/null @@ -1,57 +0,0 @@ -import re - - -def remove_redundant_spaces(txt: str): - """ - Remove redundant spaces around punctuation marks while preserving meaningful spaces. - - This function performs two main operations: - 1. Remove spaces after left-boundary characters (opening brackets, etc.) - 2. Remove spaces before right-boundary characters (closing brackets, punctuation, etc.) - - Args: - txt (str): Input text to process - - Returns: - str: Text with redundant spaces removed - """ - # First pass: Remove spaces after left-boundary characters - # Matches: [non-alphanumeric-and-specific-right-punctuation] + [non-space] - # Removes spaces after characters like '(', '<', and other non-alphanumeric chars - # Examples: - # "( test" → "(test" - txt = re.sub(r"([^a-z0-9.,\)>]) +([^ ])", r"\1\2", txt, flags=re.IGNORECASE) - - # Second pass: Remove spaces before right-boundary characters - # Matches: [non-space] + [non-alphanumeric-and-specific-left-punctuation] - # Removes spaces before characters like non-')', non-',', non-'.', and non-alphanumeric chars - # Examples: - # "world !" → "world!" - return re.sub(r"([^ ]) +([^a-z0-9.,\(<])", r"\1\2", txt, flags=re.IGNORECASE) - - -def clean_markdown_block(text): - """ - Remove Markdown code block syntax from the beginning and end of text. - - This function cleans Markdown code blocks by removing: - - Opening ```Markdown tags (with optional whitespace and newlines) - - Closing ``` tags (with optional whitespace and newlines) - - Args: - text (str): Input text that may be wrapped in Markdown code blocks - - Returns: - str: Cleaned text with Markdown code block syntax removed, and stripped of surrounding whitespace - - """ - # Remove opening ```markdown tag with optional whitespace and newlines - # Matches: optional whitespace + ```markdown + optional whitespace + optional newline - text = re.sub(r'^\s*```markdown\s*\n?', '', text) - - # Remove closing ``` tag with optional whitespace and newlines - # Matches: optional newline + optional whitespace + ``` + optional whitespace at end - text = re.sub(r'\n?\s*```\s*$', '', text) - - # Return text with surrounding whitespace removed - return text.strip() diff --git a/app/core/rag/common/token_utils.py b/app/core/rag/common/token_utils.py deleted file mode 100644 index e8390300..00000000 --- a/app/core/rag/common/token_utils.py +++ /dev/null @@ -1,59 +0,0 @@ -import os -import tiktoken - -from .file_utils import get_project_base_directory - -tiktoken_cache_dir = os.path.join(get_project_base_directory(), "res") -os.environ["TIKTOKEN_CACHE_DIR"] = tiktoken_cache_dir -# encoder = tiktoken.encoding_for_model("gpt-3.5-turbo") -encoder = tiktoken.get_encoding("cl100k_base") - - -def num_tokens_from_string(string: str) -> int: - """Returns the number of tokens in a text string.""" - try: - code_list = encoder.encode(string) - return len(code_list) - except Exception: - return 0 - -def total_token_count_from_response(resp): - if resp is None: - return 0 - - if hasattr(resp, "usage") and hasattr(resp.usage, "total_tokens"): - try: - return resp.usage.total_tokens - except Exception: - pass - - if hasattr(resp, "usage_metadata") and hasattr(resp.usage_metadata, "total_tokens"): - try: - return resp.usage_metadata.total_tokens - except Exception: - pass - - if 'usage' in resp and 'total_tokens' in resp['usage']: - try: - return resp["usage"]["total_tokens"] - except Exception: - pass - - if 'usage' in resp and 'input_tokens' in resp['usage'] and 'output_tokens' in resp['usage']: - try: - return resp["usage"]["input_tokens"] + resp["usage"]["output_tokens"] - except Exception: - pass - - if 'meta' in resp and 'tokens' in resp['meta'] and 'input_tokens' in resp['meta']['tokens'] and 'output_tokens' in resp['meta']['tokens']: - try: - return resp["meta"]["tokens"]["input_tokens"] + resp["meta"]["tokens"]["output_tokens"] - except Exception: - pass - return 0 - - -def truncate(string: str, max_len: int) -> str: - """Returns truncated text if the length of text exceed max_len.""" - return encoder.decode(encoder.encode(string)[:max_len]) - diff --git a/app/core/rag/deepdoc/README.md b/app/core/rag/deepdoc/README.md deleted file mode 100644 index 14c7947b..00000000 --- a/app/core/rag/deepdoc/README.md +++ /dev/null @@ -1,122 +0,0 @@ -English | [简体中文](./README_zh.md) - -# *Deep*Doc - -- [1. Introduction](#1) -- [2. Vision](#2) -- [3. Parser](#3) - - -## 1. Introduction - -With a bunch of documents from various domains with various formats and along with diverse retrieval requirements, -an accurate analysis becomes a very challenge task. *Deep*Doc is born for that purpose. -There are 2 parts in *Deep*Doc so far: vision and parser. -You can run the flowing test programs if you're interested in our results of OCR, layout recognition and TSR. -```bash -python deepdoc/vision/t_ocr.py -h -usage: t_ocr.py [-h] --inputs INPUTS [--output_dir OUTPUT_DIR] - -options: - -h, --help show this help message and exit - --inputs INPUTS Directory where to store images or PDFs, or a file path to a single image or PDF - --output_dir OUTPUT_DIR - Directory where to store the output images. Default: './ocr_outputs' -``` -```bash -python deepdoc/vision/t_recognizer.py -h -usage: t_recognizer.py [-h] --inputs INPUTS [--output_dir OUTPUT_DIR] [--threshold THRESHOLD] [--mode {layout,tsr}] - -options: - -h, --help show this help message and exit - --inputs INPUTS Directory where to store images or PDFs, or a file path to a single image or PDF - --output_dir OUTPUT_DIR - Directory where to store the output images. Default: './layouts_outputs' - --threshold THRESHOLD - A threshold to filter out detections. Default: 0.5 - --mode {layout,tsr} Task mode: layout recognition or table structure recognition -``` - -Our models are served on HuggingFace. If you have trouble downloading HuggingFace models, this might help!! -```bash -export HF_ENDPOINT=https://hf-mirror.com -``` - - -## 2. Vision - -We use vision information to resolve problems as human being. - - OCR. Since a lot of documents presented as images or at least be able to transform to image, - OCR is a very essential and fundamental or even universal solution for text extraction. - ```bash - python deepdoc/vision/t_ocr.py --inputs=path_to_images_or_pdfs --output_dir=path_to_store_result - ``` - The inputs could be directory to images or PDF, or a image or PDF. - You can look into the folder 'path_to_store_result' where has images which demonstrate the positions of results, - txt files which contain the OCR text. -
- -
- - - Layout recognition. Documents from different domain may have various layouts, - like, newspaper, magazine, book and résumé are distinct in terms of layout. - Only when machine have an accurate layout analysis, it can decide if these text parts are successive or not, - or this part needs Table Structure Recognition(TSR) to process, or this part is a figure and described with this caption. - We have 10 basic layout components which covers most cases: - - Text - - Title - - Figure - - Figure caption - - Table - - Table caption - - Header - - Footer - - Reference - - Equation - - Have a try on the following command to see the layout detection results. - ```bash - python deepdoc/vision/t_recognizer.py --inputs=path_to_images_or_pdfs --threshold=0.2 --mode=layout --output_dir=path_to_store_result - ``` - The inputs could be directory to images or PDF, or a image or PDF. - You can look into the folder 'path_to_store_result' where has images which demonstrate the detection results as following: -
- -
- - - Table Structure Recognition(TSR). Data table is a frequently used structure to present data including numbers or text. - And the structure of a table might be very complex, like hierarchy headers, spanning cells and projected row headers. - Along with TSR, we also reassemble the content into sentences which could be well comprehended by LLM. - We have five labels for TSR task: - - Column - - Row - - Column header - - Projected row header - - Spanning cell - - Have a try on the following command to see the layout detection results. - ```bash - python deepdoc/vision/t_recognizer.py --inputs=path_to_images_or_pdfs --threshold=0.2 --mode=tsr --output_dir=path_to_store_result - ``` - The inputs could be directory to images or PDF, or a image or PDF. - You can look into the folder 'path_to_store_result' where has both images and html pages which demonstrate the detection results as following: -
- -
- - -## 3. Parser - -Four kinds of document formats as PDF, DOCX, EXCEL and PPT have their corresponding parser. -The most complex one is PDF parser since PDF's flexibility. The output of PDF parser includes: - - Text chunks with their own positions in PDF(page number and rectangular positions). - - Tables with cropped image from the PDF, and contents which has already translated into natural language sentences. - - Figures with caption and text in the figures. - -### Résumé - -The résumé is a very complicated kind of document. A résumé which is composed of unstructured text -with various layouts could be resolved into structured data composed of nearly a hundred of fields. -We haven't opened the parser yet, as we open the processing method after parsing procedure. - - \ No newline at end of file diff --git a/app/core/rag/deepdoc/README_zh.md b/app/core/rag/deepdoc/README_zh.md deleted file mode 100644 index 4ada7edb..00000000 --- a/app/core/rag/deepdoc/README_zh.md +++ /dev/null @@ -1,116 +0,0 @@ -[English](./README.md) | 简体中文 - -# *Deep*Doc - -- [*Deep*Doc](#deepdoc) - - [1. 介绍](#1-介绍) - - [2. 视觉处理](#2-视觉处理) - - [3. 解析器](#3-解析器) - - [简历](#简历) - - -## 1. 介绍 - -对于来自不同领域、具有不同格式和不同检索要求的大量文档,准确的分析成为一项极具挑战性的任务。*Deep*Doc 就是为了这个目的而诞生的。到目前为止,*Deep*Doc 中有两个组成部分:视觉处理和解析器。如果您对我们的OCR、布局识别和TSR结果感兴趣,您可以运行下面的测试程序。 - -```bash -python deepdoc/vision/t_ocr.py -h -usage: t_ocr.py [-h] --inputs INPUTS [--output_dir OUTPUT_DIR] - -options: - -h, --help show this help message and exit - --inputs INPUTS Directory where to store images or PDFs, or a file path to a single image or PDF - --output_dir OUTPUT_DIR - Directory where to store the output images. Default: './ocr_outputs' -``` - -```bash -python deepdoc/vision/t_recognizer.py -h -usage: t_recognizer.py [-h] --inputs INPUTS [--output_dir OUTPUT_DIR] [--threshold THRESHOLD] [--mode {layout,tsr}] - -options: - -h, --help show this help message and exit - --inputs INPUTS Directory where to store images or PDFs, or a file path to a single image or PDF - --output_dir OUTPUT_DIR - Directory where to store the output images. Default: './layouts_outputs' - --threshold THRESHOLD - A threshold to filter out detections. Default: 0.5 - --mode {layout,tsr} Task mode: layout recognition or table structure recognition -``` - -HuggingFace为我们的模型提供服务。如果你在下载HuggingFace模型时遇到问题,这可能会有所帮助!! - -```bash -export HF_ENDPOINT=https://hf-mirror.com -``` - - -## 2. 视觉处理 - -作为人类,我们使用视觉信息来解决问题。 - - - **OCR(Optical Character Recognition,光学字符识别)**。由于许多文档都是以图像形式呈现的,或者至少能够转换为图像,因此OCR是文本提取的一个非常重要、基本,甚至通用的解决方案。 - - ```bash - python deepdoc/vision/t_ocr.py --inputs=path_to_images_or_pdfs --output_dir=path_to_store_result - ``` - - 输入可以是图像或PDF的目录,或者单个图像、PDF文件。您可以查看文件夹 `path_to_store_result` ,其中有演示结果位置的图像,以及包含OCR文本的txt文件。 - -
- -
- - - 布局识别(Layout recognition)。来自不同领域的文件可能有不同的布局,如报纸、杂志、书籍和简历在布局方面是不同的。只有当机器有准确的布局分析时,它才能决定这些文本部分是连续的还是不连续的,或者这个部分需要表结构识别(Table Structure Recognition,TSR)来处理,或者这个部件是一个图形并用这个标题来描述。我们有10个基本布局组件,涵盖了大多数情况: - - 文本 - - 标题 - - 配图 - - 配图标题 - - 表格 - - 表格标题 - - 页头 - - 页尾 - - 参考引用 - - 公式 - - 请尝试以下命令以查看布局检测结果。 - - ```bash - python deepdoc/vision/t_recognizer.py --inputs=path_to_images_or_pdfs --threshold=0.2 --mode=layout --output_dir=path_to_store_result - ``` - - 输入可以是图像或PDF的目录,或者单个图像、PDF文件。您可以查看文件夹 `path_to_store_result` ,其中有显示检测结果的图像,如下所示: -
- -
- - - **TSR(Table Structure Recognition,表结构识别)**。数据表是一种常用的结构,用于表示包括数字或文本在内的数据。表的结构可能非常复杂,比如层次结构标题、跨单元格和投影行标题。除了TSR,我们还将内容重新组合成LLM可以很好理解的句子。TSR任务有五个标签: - - 列 - - 行 - - 列标题 - - 行标题 - - 合并单元格 - - 请尝试以下命令以查看布局检测结果。 - - ```bash - python deepdoc/vision/t_recognizer.py --inputs=path_to_images_or_pdfs --threshold=0.2 --mode=tsr --output_dir=path_to_store_result - ``` - - 输入可以是图像或PDF的目录,或者单个图像、PDF文件。您可以查看文件夹 `path_to_store_result` ,其中包含图像和html页面,这些页面展示了以下检测结果: - -
- -
- - -## 3. 解析器 - -PDF、DOCX、EXCEL和PPT四种文档格式都有相应的解析器。最复杂的是PDF解析器,因为PDF具有灵活性。PDF解析器的输出包括: - - 在PDF中有自己位置的文本块(页码和矩形位置)。 - - 带有PDF裁剪图像的表格,以及已经翻译成自然语言句子的内容。 - - 图中带标题和文字的图。 - -### 简历 - -简历是一种非常复杂的文档。由各种格式的非结构化文本构成的简历可以被解析为包含近百个字段的结构化数据。我们还没有启用解析器,因为在解析过程之后才会启动处理方法。 diff --git a/app/core/rag/deepdoc/__init__.py b/app/core/rag/deepdoc/__init__.py deleted file mode 100644 index 6bfdd33d..00000000 --- a/app/core/rag/deepdoc/__init__.py +++ /dev/null @@ -1,2 +0,0 @@ -from beartype.claw import beartype_this_package -beartype_this_package() diff --git a/app/core/rag/deepdoc/parser/__init__.py b/app/core/rag/deepdoc/parser/__init__.py deleted file mode 100644 index 4cc4bada..00000000 --- a/app/core/rag/deepdoc/parser/__init__.py +++ /dev/null @@ -1,24 +0,0 @@ -from .docx_parser import RAGDocxParser as DocxParser -from .excel_parser import RAGExcelParser as ExcelParser -from .html_parser import RAGHtmlParser as HtmlParser -from .json_parser import RAGJsonParser as JsonParser -from .markdown_parser import MarkdownElementExtractor -from .markdown_parser import RAGMarkdownParser as MarkdownParser -from .pdf_parser import PlainParser -from .pdf_parser import RAGPdfParser as PdfParser -from .ppt_parser import RAGPptParser as PptParser -from .txt_parser import RAGTxtParser as TxtParser - -__all__ = [ - "PdfParser", - "PlainParser", - "DocxParser", - "ExcelParser", - "PptParser", - "HtmlParser", - "JsonParser", - "MarkdownParser", - "TxtParser", - "MarkdownElementExtractor", -] - diff --git a/app/core/rag/deepdoc/parser/docx_parser.py b/app/core/rag/deepdoc/parser/docx_parser.py deleted file mode 100644 index 7cf5f434..00000000 --- a/app/core/rag/deepdoc/parser/docx_parser.py +++ /dev/null @@ -1,123 +0,0 @@ -from docx import Document -import re -import pandas as pd -from collections import Counter -from app.core.rag.nlp import rag_tokenizer -from io import BytesIO - - -class RAGDocxParser: - - def __extract_table_content(self, tb): - df = [] - for row in tb.rows: - df.append([c.text for c in row.cells]) - return self.__compose_table_content(pd.DataFrame(df)) - - def __compose_table_content(self, df): - - def blockType(b): - pattern = [ - ("^(20|19)[0-9]{2}[年/-][0-9]{1,2}[月/-][0-9]{1,2}日*$", "Dt"), - (r"^(20|19)[0-9]{2}年$", "Dt"), - (r"^(20|19)[0-9]{2}[年/-][0-9]{1,2}月*$", "Dt"), - ("^[0-9]{1,2}[月/-][0-9]{1,2}日*$", "Dt"), - (r"^第*[一二三四1-4]季度$", "Dt"), - (r"^(20|19)[0-9]{2}年*[一二三四1-4]季度$", "Dt"), - (r"^(20|19)[0-9]{2}[ABCDE]$", "DT"), - ("^[0-9.,+%/ -]+$", "Nu"), - (r"^[0-9A-Z/\._~-]+$", "Ca"), - (r"^[A-Z]*[a-z' -]+$", "En"), - (r"^[0-9.,+-]+[0-9A-Za-z/$¥%<>()()' -]+$", "NE"), - (r"^.{1}$", "Sg") - ] - for p, n in pattern: - if re.search(p, b): - return n - tks = [t for t in rag_tokenizer.tokenize(b).split() if len(t) > 1] - if len(tks) > 3: - if len(tks) < 12: - return "Tx" - else: - return "Lx" - - if len(tks) == 1 and rag_tokenizer.tag(tks[0]) == "nr": - return "Nr" - - return "Ot" - - if len(df) < 2: - return [] - max_type = Counter([blockType(str(df.iloc[i, j])) for i in range( - 1, len(df)) for j in range(len(df.iloc[i, :]))]) - max_type = max(max_type.items(), key=lambda x: x[1])[0] - - colnm = len(df.iloc[0, :]) - hdrows = [0] # header is not necessarily appear in the first line - if max_type == "Nu": - for r in range(1, len(df)): - tys = Counter([blockType(str(df.iloc[r, j])) - for j in range(len(df.iloc[r, :]))]) - tys = max(tys.items(), key=lambda x: x[1])[0] - if tys != max_type: - hdrows.append(r) - - lines = [] - for i in range(1, len(df)): - if i in hdrows: - continue - hr = [r - i for r in hdrows] - hr = [r for r in hr if r < 0] - t = len(hr) - 1 - while t > 0: - if hr[t] - hr[t - 1] > 1: - hr = hr[t:] - break - t -= 1 - headers = [] - for j in range(len(df.iloc[i, :])): - t = [] - for h in hr: - x = str(df.iloc[i + h, j]).strip() - if x in t: - continue - t.append(x) - t = ",".join(t) - if t: - t += ": " - headers.append(t) - cells = [] - for j in range(len(df.iloc[i, :])): - if not str(df.iloc[i, j]): - continue - cells.append(headers[j] + str(df.iloc[i, j])) - lines.append(";".join(cells)) - - if colnm > 3: - return lines - return ["\n".join(lines)] - - def __call__(self, fnm, from_page=0, to_page=100000000): - self.doc = Document(fnm) if isinstance( - fnm, str) else Document(BytesIO(fnm)) - pn = 0 # parsed page - secs = [] # parsed contents - for p in self.doc.paragraphs: - if pn > to_page: - break - - runs_within_single_paragraph = [] # save runs within the range of pages - for run in p.runs: - if pn > to_page: - break - if from_page <= pn < to_page and p.text.strip(): - runs_within_single_paragraph.append(run.text) # append run.text first - - # wrap page break checker into a static method - if 'lastRenderedPageBreak' in run._element.xml: - pn += 1 - - secs.append(("".join(runs_within_single_paragraph), p.style.name if hasattr(p.style, 'name') else '')) # then concat run.text as part of the paragraph - - tbls = [self.__extract_table_content(tb) for tb in self.doc.tables] - return secs, tbls diff --git a/app/core/rag/deepdoc/parser/excel_parser.py b/app/core/rag/deepdoc/parser/excel_parser.py deleted file mode 100644 index b6e1e4a1..00000000 --- a/app/core/rag/deepdoc/parser/excel_parser.py +++ /dev/null @@ -1,210 +0,0 @@ -import logging -import re -import sys -from io import BytesIO - -import pandas as pd -from openpyxl import Workbook, load_workbook - -from app.core.rag.nlp import find_codec - -# copied from `/openpyxl/cell/cell.py` -ILLEGAL_CHARACTERS_RE = re.compile(r"[\000-\010]|[\013-\014]|[\016-\037]") - - -class RAGExcelParser: - @staticmethod - def _load_excel_to_workbook(file_like_object): - if isinstance(file_like_object, bytes): - file_like_object = BytesIO(file_like_object) - - # Read first 4 bytes to determine file type - file_like_object.seek(0) - file_head = file_like_object.read(4) - file_like_object.seek(0) - - if not (file_head.startswith(b"PK\x03\x04") or file_head.startswith(b"\xd0\xcf\x11\xe0")): - logging.info("Not an Excel file, converting CSV to Excel Workbook") - - try: - file_like_object.seek(0) - df = pd.read_csv(file_like_object) - return RAGExcelParser._dataframe_to_workbook(df) - - except Exception as e_csv: - raise Exception(f"Failed to parse CSV and convert to Excel Workbook: {e_csv}") - - try: - return load_workbook(file_like_object, data_only=True) - except Exception as e: - logging.info(f"openpyxl load error: {e}, try pandas instead") - try: - file_like_object.seek(0) - try: - dfs = pd.read_excel(file_like_object, sheet_name=None) - return RAGExcelParser._dataframe_to_workbook(dfs) - except Exception as ex: - logging.info(f"pandas with default engine load error: {ex}, try calamine instead") - file_like_object.seek(0) - df = pd.read_excel(file_like_object, engine="calamine") - return RAGExcelParser._dataframe_to_workbook(df) - except Exception as e_pandas: - raise Exception(f"pandas.read_excel error: {e_pandas}, original openpyxl error: {e}") - - @staticmethod - def _clean_dataframe(df: pd.DataFrame): - def clean_string(s): - if isinstance(s, str): - return ILLEGAL_CHARACTERS_RE.sub(" ", s) - return s - - return df.apply(lambda col: col.map(clean_string)) - - @staticmethod - def _dataframe_to_workbook(df): - # if contains multiple sheets use _dataframes_to_workbook - if isinstance(df, dict) and len(df) > 1: - return RAGExcelParser._dataframes_to_workbook(df) - - df = RAGExcelParser._clean_dataframe(df) - wb = Workbook() - ws = wb.active - ws.title = "Data" - - for col_num, column_name in enumerate(df.columns, 1): - ws.cell(row=1, column=col_num, value=column_name) - - for row_num, row in enumerate(df.values, 2): - for col_num, value in enumerate(row, 1): - ws.cell(row=row_num, column=col_num, value=value) - - return wb - - @staticmethod - def _dataframes_to_workbook(dfs: dict): - wb = Workbook() - default_sheet = wb.active - wb.remove(default_sheet) - - for sheet_name, df in dfs.items(): - df = RAGExcelParser._clean_dataframe(df) - ws = wb.create_sheet(title=sheet_name) - for col_num, column_name in enumerate(df.columns, 1): - ws.cell(row=1, column=col_num, value=column_name) - for row_num, row in enumerate(df.values, 2): - for col_num, value in enumerate(row, 1): - ws.cell(row=row_num, column=col_num, value=value) - return wb - - def html(self, fnm, chunk_rows=256): - from html import escape - - file_like_object = BytesIO(fnm) if not isinstance(fnm, str) else fnm - wb = RAGExcelParser._load_excel_to_workbook(file_like_object) - tb_chunks = [] - - def _fmt(v): - if v is None: - return "" - return str(v).strip() - - for sheetname in wb.sheetnames: - ws = wb[sheetname] - try: - rows = list(ws.rows) - except Exception as e: - logging.warning(f"Skip sheet '{sheetname}' due to rows access error: {e}") - continue - - if not rows: - continue - - tb_rows_0 = "" - for t in list(rows[0]): - tb_rows_0 += f"{escape(_fmt(t.value))}" - tb_rows_0 += "" - - for chunk_i in range((len(rows) - 1) // chunk_rows + 1): - tb = "" - tb += f"" - tb += tb_rows_0 - for r in list(rows[1 + chunk_i * chunk_rows : min(1 + (chunk_i + 1) * chunk_rows, len(rows))]): - tb += "" - for i, c in enumerate(r): - if c.value is None: - tb += "" - else: - tb += f"" - tb += "" - tb += "
{sheetname}
{escape(_fmt(c.value))}
\n" - tb_chunks.append(tb) - - return tb_chunks - - def markdown(self, fnm): - import pandas as pd - - file_like_object = BytesIO(fnm) if not isinstance(fnm, str) else fnm - try: - file_like_object.seek(0) - df = pd.read_excel(file_like_object) - except Exception as e: - logging.warning(f"Parse spreadsheet error: {e}, trying to interpret as CSV file") - file_like_object.seek(0) - df = pd.read_csv(file_like_object) - df = df.replace(r"^\s*$", "", regex=True) - return df.to_markdown(index=False) - - def __call__(self, fnm): - file_like_object = BytesIO(fnm) if not isinstance(fnm, str) else fnm - wb = RAGExcelParser._load_excel_to_workbook(file_like_object) - - res = [] - for sheetname in wb.sheetnames: - ws = wb[sheetname] - try: - rows = list(ws.rows) - except Exception as e: - logging.warning(f"Skip sheet '{sheetname}' due to rows access error: {e}") - continue - if not rows: - continue - ti = list(rows[0]) - for r in list(rows[1:]): - fields = [] - for i, c in enumerate(r): - if not c.value: - continue - t = str(ti[i].value) if i < len(ti) else "" - t += (":" if t else "") + str(c.value) - fields.append(t) - line = "; ".join(fields) - if sheetname.lower().find("sheet") < 0: - line += " ——" + sheetname - res.append(line) - return res - - @staticmethod - def row_number(fnm, binary): - if fnm.split(".")[-1].lower().find("xls") >= 0: - wb = RAGExcelParser._load_excel_to_workbook(BytesIO(binary)) - total = 0 - - for sheetname in wb.sheetnames: - try: - ws = wb[sheetname] - total += len(list(ws.rows)) - except Exception as e: - logging.warning(f"Skip sheet '{sheetname}' due to rows access error: {e}") - continue - return total - - if fnm.split(".")[-1].lower() in ["csv", "txt"]: - encoding = find_codec(binary) - txt = binary.decode(encoding, errors="ignore") - return len(txt.split("\n")) - - -if __name__ == "__main__": - psr = RAGExcelParser() - psr(sys.argv[1]) diff --git a/app/core/rag/deepdoc/parser/figure_parser.py b/app/core/rag/deepdoc/parser/figure_parser.py deleted file mode 100644 index 123ff596..00000000 --- a/app/core/rag/deepdoc/parser/figure_parser.py +++ /dev/null @@ -1,118 +0,0 @@ -from concurrent.futures import ThreadPoolExecutor, as_completed - -from PIL import Image - -from app.core.rag.common.constants import LLMType -from app.core.rag.common.connection_utils import timeout -from app.core.rag.app.picture import vision_llm_chunk as picture_vision_llm_chunk -from app.core.rag.prompts.generator import vision_llm_figure_describe_prompt - - -def vision_figure_parser_figure_data_wrapper(figures_data_without_positions): - return [ - ( - (figure_data[1], [figure_data[0]]), - [(0, 0, 0, 0, 0)], - ) - for figure_data in figures_data_without_positions - if isinstance(figure_data[1], Image.Image) - ] - -def vision_figure_parser_docx_wrapper(sections,tbls,callback=None,vision_model=None,**kwargs): - if vision_model: - figures_data = vision_figure_parser_figure_data_wrapper(sections) - try: - docx_vision_parser = VisionFigureParser(vision_model=vision_model, figures_data=figures_data, **kwargs) - boosted_figures = docx_vision_parser(callback=callback) - tbls.extend(boosted_figures) - except Exception as e: - callback(0.8, f"Visual model error: {e}. Skipping figure parsing enhancement.") - return tbls - -def vision_figure_parser_pdf_wrapper(tbls,callback=None,vision_model=None,**kwargs): - if vision_model: - def is_figure_item(item): - return ( - isinstance(item[0][0], Image.Image) and - isinstance(item[0][1], list) - ) - figures_data = [item for item in tbls if is_figure_item(item)] - try: - docx_vision_parser = VisionFigureParser(vision_model=vision_model, figures_data=figures_data, **kwargs) - boosted_figures = docx_vision_parser(callback=callback) - tbls = [item for item in tbls if not is_figure_item(item)] - tbls.extend(boosted_figures) - except Exception as e: - callback(0.8, f"Visual model error: {e}. Skipping figure parsing enhancement.") - return tbls - -shared_executor = ThreadPoolExecutor(max_workers=10) - - -class VisionFigureParser: - def __init__(self, vision_model, figures_data, *args, **kwargs): - self.vision_model = vision_model - self._extract_figures_info(figures_data) - assert len(self.figures) == len(self.descriptions) - assert not self.positions or (len(self.figures) == len(self.positions)) - - def _extract_figures_info(self, figures_data): - self.figures = [] - self.descriptions = [] - self.positions = [] - - for item in figures_data: - # position - if len(item) == 2 and isinstance(item[0], tuple) and len(item[0]) == 2 and isinstance(item[1], list) and isinstance(item[1][0], tuple) and len(item[1][0]) == 5: - img_desc = item[0] - assert len(img_desc) == 2 and isinstance(img_desc[0], Image.Image) and isinstance(img_desc[1], list), "Should be (figure, [description])" - self.figures.append(img_desc[0]) - self.descriptions.append(img_desc[1]) - self.positions.append(item[1]) - else: - assert len(item) == 2 and isinstance(item[0], Image.Image) and isinstance(item[1], list), f"Unexpected form of figure data: get {len(item)=}, {item=}" - self.figures.append(item[0]) - self.descriptions.append(item[1]) - - def _assemble(self): - self.assembled = [] - self.has_positions = len(self.positions) != 0 - for i in range(len(self.figures)): - figure = self.figures[i] - desc = self.descriptions[i] - pos = self.positions[i] if self.has_positions else None - - figure_desc = (figure, desc) - - if pos is not None: - self.assembled.append((figure_desc, pos)) - else: - self.assembled.append((figure_desc,)) - - return self.assembled - - def __call__(self, **kwargs): - callback = kwargs.get("callback", lambda prog, msg: None) - - @timeout(30, 3) - def process(figure_idx, figure_binary): - description_text = picture_vision_llm_chunk( - binary=figure_binary, - vision_model=self.vision_model, - prompt=vision_llm_figure_describe_prompt(), - callback=callback, - ) - return figure_idx, description_text - - futures = [] - for idx, img_binary in enumerate(self.figures or []): - futures.append(shared_executor.submit(process, idx, img_binary)) - - for future in as_completed(futures): - figure_num, txt = future.result() - if txt: - self.descriptions[figure_num] = txt + "\n".join(self.descriptions[figure_num]) - - self._assemble() - - return self.assembled diff --git a/app/core/rag/deepdoc/parser/html_parser.py b/app/core/rag/deepdoc/parser/html_parser.py deleted file mode 100644 index 5db44e29..00000000 --- a/app/core/rag/deepdoc/parser/html_parser.py +++ /dev/null @@ -1,197 +0,0 @@ -from app.core.rag.nlp import find_codec, rag_tokenizer -import uuid -import chardet -from bs4 import BeautifulSoup, NavigableString, Tag, Comment -import html - -def get_encoding(file): - with open(file,'rb') as f: - tmp = chardet.detect(f.read()) - return tmp['encoding'] - -BLOCK_TAGS = [ - "h1", "h2", "h3", "h4", "h5", "h6", - "p", "div", "article", "section", "aside", - "ul", "ol", "li", - "table", "pre", "code", "blockquote", - "figure", "figcaption" -] -TITLE_TAGS = {"h1": "#", "h2": "##", "h3": "###", "h4": "#####", "h5": "#####", "h6": "######"} - - -class RAGHtmlParser: - def __call__(self, fnm, binary=None, chunk_token_num=512): - if binary: - encoding = find_codec(binary) - txt = binary.decode(encoding, errors="ignore") - else: - with open(fnm, "r",encoding=get_encoding(fnm)) as f: - txt = f.read() - return self.parser_txt(txt, chunk_token_num) - - @classmethod - def parser_txt(cls, txt, chunk_token_num): - if not isinstance(txt, str): - raise TypeError("txt type should be string!") - - temp_sections = [] - soup = BeautifulSoup(txt, "html5lib") - # delete - - - %s - - -""" % TableStructureRecognizer.construct_table(boxes, html=True) - return html - - -if __name__ == "__main__": - parser = argparse.ArgumentParser() - parser.add_argument('--inputs', - help="Directory where to store images or PDFs, or a file path to a single image or PDF", - required=True) - parser.add_argument('--output_dir', help="Directory where to store the output images. Default: './layouts_outputs'", - default="./layouts_outputs") - parser.add_argument( - '--threshold', - help="A threshold to filter out detections. Default: 0.5", - default=0.5) - parser.add_argument('--mode', help="Task mode: layout recognition or table structure recognition", choices=["layout", "tsr"], - default="layout") - args = parser.parse_args() - main(args) diff --git a/app/core/rag/deepdoc/vision/table_structure_recognizer.py b/app/core/rag/deepdoc/vision/table_structure_recognizer.py deleted file mode 100644 index bd89d6a7..00000000 --- a/app/core/rag/deepdoc/vision/table_structure_recognizer.py +++ /dev/null @@ -1,597 +0,0 @@ -import logging -import os -import re -from collections import Counter - -import numpy as np -from huggingface_hub import snapshot_download - -from app.core.rag.common.file_utils import get_project_base_directory -from app.core.rag.nlp import rag_tokenizer - -from .recognizer import Recognizer - - -class TableStructureRecognizer(Recognizer): - labels = [ - "table", - "table column", - "table row", - "table column header", - "table projected row header", - "table spanning cell", - ] - - def __init__(self): - try: - super().__init__(self.labels, "tsr", os.path.join(get_project_base_directory(), "res/deepdoc")) - except Exception: - super().__init__( - self.labels, - "tsr", - snapshot_download( - repo_id="InfiniFlow/deepdoc", - local_dir=os.path.join(get_project_base_directory(), "res/deepdoc"), - local_dir_use_symlinks=False, - ), - ) - - def __call__(self, images, thr=0.2): - table_structure_recognizer_type = os.getenv("TABLE_STRUCTURE_RECOGNIZER_TYPE", "onnx").lower() - if table_structure_recognizer_type not in ["onnx", "ascend"]: - raise RuntimeError("Unsupported table structure recognizer type.") - - if table_structure_recognizer_type == "onnx": - logging.debug("Using Onnx table structure recognizer") - tbls = super().__call__(images, thr) - else: # ascend - logging.debug("Using Ascend table structure recognizer") - tbls = self._run_ascend_tsr(images, thr) - - res = [] - # align left&right for rows, align top&bottom for columns - for tbl in tbls: - lts = [ - { - "label": b["type"], - "score": b["score"], - "x0": b["bbox"][0], - "x1": b["bbox"][2], - "top": b["bbox"][1], - "bottom": b["bbox"][-1], - } - for b in tbl - ] - if not lts: - continue - - left = [b["x0"] for b in lts if b["label"].find("row") > 0 or b["label"].find("header") > 0] - right = [b["x1"] for b in lts if b["label"].find("row") > 0 or b["label"].find("header") > 0] - if not left: - continue - left = np.mean(left) if len(left) > 4 else np.min(left) - right = np.mean(right) if len(right) > 4 else np.max(right) - for b in lts: - if b["label"].find("row") > 0 or b["label"].find("header") > 0: - if b["x0"] > left: - b["x0"] = left - if b["x1"] < right: - b["x1"] = right - - top = [b["top"] for b in lts if b["label"] == "table column"] - bottom = [b["bottom"] for b in lts if b["label"] == "table column"] - if not top: - res.append(lts) - continue - top = np.median(top) if len(top) > 4 else np.min(top) - bottom = np.median(bottom) if len(bottom) > 4 else np.max(bottom) - for b in lts: - if b["label"] == "table column": - if b["top"] > top: - b["top"] = top - if b["bottom"] < bottom: - b["bottom"] = bottom - - res.append(lts) - return res - - @staticmethod - def is_caption(bx): - patt = [r"[图表]+[ 0-9::]{2,}"] - if any([re.match(p, bx["text"].strip()) for p in patt]) or bx.get("layout_type", "").find("caption") >= 0: - return True - return False - - @staticmethod - def blockType(b): - patt = [ - ("^(20|19)[0-9]{2}[年/-][0-9]{1,2}[月/-][0-9]{1,2}日*$", "Dt"), - (r"^(20|19)[0-9]{2}年$", "Dt"), - (r"^(20|19)[0-9]{2}[年-][0-9]{1,2}月*$", "Dt"), - ("^[0-9]{1,2}[月-][0-9]{1,2}日*$", "Dt"), - (r"^第*[一二三四1-4]季度$", "Dt"), - (r"^(20|19)[0-9]{2}年*[一二三四1-4]季度$", "Dt"), - (r"^(20|19)[0-9]{2}[ABCDE]$", "Dt"), - ("^[0-9.,+%/ -]+$", "Nu"), - (r"^[0-9A-Z/\._~-]+$", "Ca"), - (r"^[A-Z]*[a-z' -]+$", "En"), - (r"^[0-9.,+-]+[0-9A-Za-z/$¥%<>()()' -]+$", "NE"), - (r"^.{1}$", "Sg"), - ] - for p, n in patt: - if re.search(p, b["text"].strip()): - return n - tks = [t for t in rag_tokenizer.tokenize(b["text"]).split() if len(t) > 1] - if len(tks) > 3: - if len(tks) < 12: - return "Tx" - else: - return "Lx" - - if len(tks) == 1 and rag_tokenizer.tag(tks[0]) == "nr": - return "Nr" - - return "Ot" - - @staticmethod - def construct_table(boxes, is_english=False, html=True, **kwargs): - cap = "" - i = 0 - while i < len(boxes): - if TableStructureRecognizer.is_caption(boxes[i]): - if is_english: - cap + " " - cap += boxes[i]["text"] - boxes.pop(i) - i -= 1 - i += 1 - - if not boxes: - return [] - for b in boxes: - b["btype"] = TableStructureRecognizer.blockType(b) - max_type = Counter([b["btype"] for b in boxes]).items() - max_type = max(max_type, key=lambda x: x[1])[0] if max_type else "" - logging.debug("MAXTYPE: " + max_type) - - rowh = [b["R_bott"] - b["R_top"] for b in boxes if "R" in b] - rowh = np.min(rowh) if rowh else 0 - boxes = Recognizer.sort_R_firstly(boxes, rowh / 2) - # for b in boxes:print(b) - boxes[0]["rn"] = 0 - rows = [[boxes[0]]] - btm = boxes[0]["bottom"] - for b in boxes[1:]: - b["rn"] = len(rows) - 1 - lst_r = rows[-1] - if lst_r[-1].get("R", "") != b.get("R", "") or (b["top"] >= btm - 3 and lst_r[-1].get("R", "-1") != b.get("R", "-2")): # new row - btm = b["bottom"] - b["rn"] += 1 - rows.append([b]) - continue - btm = (btm + b["bottom"]) / 2.0 - rows[-1].append(b) - - colwm = [b["C_right"] - b["C_left"] for b in boxes if "C" in b] - colwm = np.min(colwm) if colwm else 0 - crosspage = len(set([b["page_number"] for b in boxes])) > 1 - if crosspage: - boxes = Recognizer.sort_X_firstly(boxes, colwm / 2) - else: - boxes = Recognizer.sort_C_firstly(boxes, colwm / 2) - boxes[0]["cn"] = 0 - cols = [[boxes[0]]] - right = boxes[0]["x1"] - for b in boxes[1:]: - b["cn"] = len(cols) - 1 - lst_c = cols[-1] - if (int(b.get("C", "1")) - int(lst_c[-1].get("C", "1")) == 1 and b["page_number"] == lst_c[-1]["page_number"]) or ( - b["x0"] >= right and lst_c[-1].get("C", "-1") != b.get("C", "-2") - ): # new col - right = b["x1"] - b["cn"] += 1 - cols.append([b]) - continue - right = (right + b["x1"]) / 2.0 - cols[-1].append(b) - - tbl = [[[] for _ in range(len(cols))] for _ in range(len(rows))] - for b in boxes: - tbl[b["rn"]][b["cn"]].append(b) - - if len(rows) >= 4: - # remove single in column - j = 0 - while j < len(tbl[0]): - e, ii = 0, 0 - for i in range(len(tbl)): - if tbl[i][j]: - e += 1 - ii = i - if e > 1: - break - if e > 1: - j += 1 - continue - f = (j > 0 and tbl[ii][j - 1] and tbl[ii][j - 1][0].get("text")) or j == 0 - ff = (j + 1 < len(tbl[ii]) and tbl[ii][j + 1] and tbl[ii][j + 1][0].get("text")) or j + 1 >= len(tbl[ii]) - if f and ff: - j += 1 - continue - bx = tbl[ii][j][0] - logging.debug("Relocate column single: " + bx["text"]) - # j column only has one value - left, right = 100000, 100000 - if j > 0 and not f: - for i in range(len(tbl)): - if tbl[i][j - 1]: - left = min(left, np.min([bx["x0"] - a["x1"] for a in tbl[i][j - 1]])) - if j + 1 < len(tbl[0]) and not ff: - for i in range(len(tbl)): - if tbl[i][j + 1]: - right = min(right, np.min([a["x0"] - bx["x1"] for a in tbl[i][j + 1]])) - assert left < 100000 or right < 100000 - if left < right: - for jj in range(j, len(tbl[0])): - for i in range(len(tbl)): - for a in tbl[i][jj]: - a["cn"] -= 1 - if tbl[ii][j - 1]: - tbl[ii][j - 1].extend(tbl[ii][j]) - else: - tbl[ii][j - 1] = tbl[ii][j] - for i in range(len(tbl)): - tbl[i].pop(j) - - else: - for jj in range(j + 1, len(tbl[0])): - for i in range(len(tbl)): - for a in tbl[i][jj]: - a["cn"] -= 1 - if tbl[ii][j + 1]: - tbl[ii][j + 1].extend(tbl[ii][j]) - else: - tbl[ii][j + 1] = tbl[ii][j] - for i in range(len(tbl)): - tbl[i].pop(j) - cols.pop(j) - assert len(cols) == len(tbl[0]), "Column NO. miss matched: %d vs %d" % (len(cols), len(tbl[0])) - - if len(cols) >= 4: - # remove single in row - i = 0 - while i < len(tbl): - e, jj = 0, 0 - for j in range(len(tbl[i])): - if tbl[i][j]: - e += 1 - jj = j - if e > 1: - break - if e > 1: - i += 1 - continue - f = (i > 0 and tbl[i - 1][jj] and tbl[i - 1][jj][0].get("text")) or i == 0 - ff = (i + 1 < len(tbl) and tbl[i + 1][jj] and tbl[i + 1][jj][0].get("text")) or i + 1 >= len(tbl) - if f and ff: - i += 1 - continue - - bx = tbl[i][jj][0] - logging.debug("Relocate row single: " + bx["text"]) - # i row only has one value - up, down = 100000, 100000 - if i > 0 and not f: - for j in range(len(tbl[i - 1])): - if tbl[i - 1][j]: - up = min(up, np.min([bx["top"] - a["bottom"] for a in tbl[i - 1][j]])) - if i + 1 < len(tbl) and not ff: - for j in range(len(tbl[i + 1])): - if tbl[i + 1][j]: - down = min(down, np.min([a["top"] - bx["bottom"] for a in tbl[i + 1][j]])) - assert up < 100000 or down < 100000 - if up < down: - for ii in range(i, len(tbl)): - for j in range(len(tbl[ii])): - for a in tbl[ii][j]: - a["rn"] -= 1 - if tbl[i - 1][jj]: - tbl[i - 1][jj].extend(tbl[i][jj]) - else: - tbl[i - 1][jj] = tbl[i][jj] - tbl.pop(i) - - else: - for ii in range(i + 1, len(tbl)): - for j in range(len(tbl[ii])): - for a in tbl[ii][j]: - a["rn"] -= 1 - if tbl[i + 1][jj]: - tbl[i + 1][jj].extend(tbl[i][jj]) - else: - tbl[i + 1][jj] = tbl[i][jj] - tbl.pop(i) - rows.pop(i) - - # which rows are headers - hdset = set([]) - for i in range(len(tbl)): - cnt, h = 0, 0 - for j, arr in enumerate(tbl[i]): - if not arr: - continue - cnt += 1 - if max_type == "Nu" and arr[0]["btype"] == "Nu": - continue - if any([a.get("H") for a in arr]) or (max_type == "Nu" and arr[0]["btype"] != "Nu"): - h += 1 - if h / cnt > 0.5: - hdset.add(i) - - if html: - return TableStructureRecognizer.__html_table(cap, hdset, TableStructureRecognizer.__cal_spans(boxes, rows, cols, tbl, True)) - - return TableStructureRecognizer.__desc_table(cap, hdset, TableStructureRecognizer.__cal_spans(boxes, rows, cols, tbl, False), is_english) - - @staticmethod - def __html_table(cap, hdset, tbl): - # constrcut HTML - html = "" - if cap: - html += f"" - for i in range(len(tbl)): - row = "" - txts = [] - for j, arr in enumerate(tbl[i]): - if arr is None: - continue - if not arr: - row += "" if i not in hdset else "" - continue - txt = "" - if arr: - h = min(np.min([c["bottom"] - c["top"] for c in arr]) / 2, 10) - txt = " ".join([c["text"] for c in Recognizer.sort_Y_firstly(arr, h)]) - txts.append(txt) - sp = "" - if arr[0].get("colspan"): - sp = "colspan={}".format(arr[0]["colspan"]) - if arr[0].get("rowspan"): - sp += " rowspan={}".format(arr[0]["rowspan"]) - if i in hdset: - row += f"" - else: - row += f"" - - if i in hdset: - if all([t in hdset for t in txts]): - continue - for t in txts: - hdset.add(t) - - if row != "": - row += "" - else: - row = "" - html += "\n" + row - html += "\n
{cap}
" + txt + "" + txt + "
" - return html - - @staticmethod - def __desc_table(cap, hdr_rowno, tbl, is_english): - # get text of every colomn in header row to become header text - clmno = len(tbl[0]) - rowno = len(tbl) - headers = {} - hdrset = set() - lst_hdr = [] - de = "的" if not is_english else " for " - for r in sorted(list(hdr_rowno)): - headers[r] = ["" for _ in range(clmno)] - for i in range(clmno): - if not tbl[r][i]: - continue - txt = " ".join([a["text"].strip() for a in tbl[r][i]]) - headers[r][i] = txt - hdrset.add(txt) - if all([not t for t in headers[r]]): - del headers[r] - hdr_rowno.remove(r) - continue - for j in range(clmno): - if headers[r][j]: - continue - if j >= len(lst_hdr): - break - headers[r][j] = lst_hdr[j] - lst_hdr = headers[r] - for i in range(rowno): - if i not in hdr_rowno: - continue - for j in range(i + 1, rowno): - if j not in hdr_rowno: - break - for k in range(clmno): - if not headers[j - 1][k]: - continue - if headers[j][k].find(headers[j - 1][k]) >= 0: - continue - if len(headers[j][k]) > len(headers[j - 1][k]): - headers[j][k] += (de if headers[j][k] else "") + headers[j - 1][k] - else: - headers[j][k] = headers[j - 1][k] + (de if headers[j - 1][k] else "") + headers[j][k] - - logging.debug(f">>>>>>>>>>>>>>>>>{cap}:SIZE:{rowno}X{clmno} Header: {hdr_rowno}") - row_txt = [] - for i in range(rowno): - if i in hdr_rowno: - continue - rtxt = [] - - def append(delimer): - nonlocal rtxt, row_txt - rtxt = delimer.join(rtxt) - if row_txt and len(row_txt[-1]) + len(rtxt) < 64: - row_txt[-1] += "\n" + rtxt - else: - row_txt.append(rtxt) - - r = 0 - if len(headers.items()): - _arr = [(i - r, r) for r, _ in headers.items() if r < i] - if _arr: - _, r = min(_arr, key=lambda x: x[0]) - - if r not in headers and clmno <= 2: - for j in range(clmno): - if not tbl[i][j]: - continue - txt = "".join([a["text"].strip() for a in tbl[i][j]]) - if txt: - rtxt.append(txt) - if rtxt: - append(":") - continue - - for j in range(clmno): - if not tbl[i][j]: - continue - txt = "".join([a["text"].strip() for a in tbl[i][j]]) - if not txt: - continue - ctt = headers[r][j] if r in headers else "" - if ctt: - ctt += ":" - ctt += txt - if ctt: - rtxt.append(ctt) - - if rtxt: - row_txt.append("; ".join(rtxt)) - - if cap: - if is_english: - from_ = " in " - else: - from_ = "来自" - row_txt = [t + f"\t——{from_}“{cap}”" for t in row_txt] - return row_txt - - @staticmethod - def __cal_spans(boxes, rows, cols, tbl, html=True): - # caculate span - clft = [np.mean([c.get("C_left", c["x0"]) for c in cln]) for cln in cols] - crgt = [np.mean([c.get("C_right", c["x1"]) for c in cln]) for cln in cols] - rtop = [np.mean([c.get("R_top", c["top"]) for c in row]) for row in rows] - rbtm = [np.mean([c.get("R_btm", c["bottom"]) for c in row]) for row in rows] - for b in boxes: - if "SP" not in b: - continue - b["colspan"] = [b["cn"]] - b["rowspan"] = [b["rn"]] - # col span - for j in range(0, len(clft)): - if j == b["cn"]: - continue - if clft[j] + (crgt[j] - clft[j]) / 2 < b["H_left"]: - continue - if crgt[j] - (crgt[j] - clft[j]) / 2 > b["H_right"]: - continue - b["colspan"].append(j) - # row span - for j in range(0, len(rtop)): - if j == b["rn"]: - continue - if rtop[j] + (rbtm[j] - rtop[j]) / 2 < b["H_top"]: - continue - if rbtm[j] - (rbtm[j] - rtop[j]) / 2 > b["H_bott"]: - continue - b["rowspan"].append(j) - - def join(arr): - if not arr: - return "" - return "".join([t["text"] for t in arr]) - - # rm the spaning cells - for i in range(len(tbl)): - for j, arr in enumerate(tbl[i]): - if not arr: - continue - if all(["rowspan" not in a and "colspan" not in a for a in arr]): - continue - rowspan, colspan = [], [] - for a in arr: - if isinstance(a.get("rowspan", 0), list): - rowspan.extend(a["rowspan"]) - if isinstance(a.get("colspan", 0), list): - colspan.extend(a["colspan"]) - rowspan, colspan = set(rowspan), set(colspan) - if len(rowspan) < 2 and len(colspan) < 2: - for a in arr: - if "rowspan" in a: - del a["rowspan"] - if "colspan" in a: - del a["colspan"] - continue - rowspan, colspan = sorted(rowspan), sorted(colspan) - rowspan = list(range(rowspan[0], rowspan[-1] + 1)) - colspan = list(range(colspan[0], colspan[-1] + 1)) - assert i in rowspan, rowspan - assert j in colspan, colspan - arr = [] - for r in rowspan: - for c in colspan: - arr_txt = join(arr) - if tbl[r][c] and join(tbl[r][c]) != arr_txt: - arr.extend(tbl[r][c]) - tbl[r][c] = None if html else arr - for a in arr: - if len(rowspan) > 1: - a["rowspan"] = len(rowspan) - elif "rowspan" in a: - del a["rowspan"] - if len(colspan) > 1: - a["colspan"] = len(colspan) - elif "colspan" in a: - del a["colspan"] - tbl[rowspan[0]][colspan[0]] = arr - - return tbl - - def _run_ascend_tsr(self, image_list, thr=0.2, batch_size=16): - import math - - from ais_bench.infer.interface import InferSession - - model_dir = os.path.join(get_project_base_directory(), "res/deepdoc") - model_file_path = os.path.join(model_dir, "tsr.om") - - if not os.path.exists(model_file_path): - raise ValueError(f"Model file not found: {model_file_path}") - - device_id = int(os.getenv("ASCEND_LAYOUT_RECOGNIZER_DEVICE_ID", 0)) - session = InferSession(device_id=device_id, model_path=model_file_path) - - images = [np.array(im) if not isinstance(im, np.ndarray) else im for im in image_list] - results = [] - - conf_thr = max(thr, 0.08) - - batch_loop_cnt = math.ceil(float(len(images)) / batch_size) - for bi in range(batch_loop_cnt): - s = bi * batch_size - e = min((bi + 1) * batch_size, len(images)) - batch_images = images[s:e] - - inputs_list = self.preprocess(batch_images) - for ins in inputs_list: - feeds = [] - if "image" in ins: - feeds.append(ins["image"]) - else: - feeds.append(ins[self.input_names[0]]) - output_list = session.infer(feeds=feeds, mode="static") - bb = self.postprocess(output_list, ins, conf_thr) - results.append(bb) - return results diff --git a/app/core/rag/graphrag/__init__.py b/app/core/rag/graphrag/__init__.py deleted file mode 100644 index e69de29b..00000000 diff --git a/app/core/rag/graphrag/utils.py b/app/core/rag/graphrag/utils.py deleted file mode 100644 index 65beb31f..00000000 --- a/app/core/rag/graphrag/utils.py +++ /dev/null @@ -1,19 +0,0 @@ -import xxhash -from app.aioRedis import aio_redis_set, aio_redis_get - -def get_llm_cache(llmnm, txt, history, genconf): - hasher = xxhash.xxh64() - hasher.update((str(llmnm)+str(txt)+str(history)+str(genconf)).encode("utf-8")) - - k = hasher.hexdigest() - bin = aio_redis_get(k) - if not bin: - return None - return bin - - -def set_llm_cache(llmnm, txt, v, history, genconf): - hasher = xxhash.xxh64() - hasher.update((str(llmnm)+str(txt)+str(history)+str(genconf)).encode("utf-8")) - k = hasher.hexdigest() - aio_redis_set(k, v.encode("utf-8"), 24 * 3600) diff --git a/app/core/rag/llm/__init__.py b/app/core/rag/llm/__init__.py deleted file mode 100644 index e69de29b..00000000 diff --git a/app/core/rag/llm/chat_model.py b/app/core/rag/llm/chat_model.py deleted file mode 100644 index b9a5f87e..00000000 --- a/app/core/rag/llm/chat_model.py +++ /dev/null @@ -1,670 +0,0 @@ -import json -import logging -import os -import random -import re -import time -from abc import ABC -from copy import deepcopy -from typing import Any, Protocol -from urllib.parse import urljoin - -import json_repair -import openai -from openai import OpenAI -from openai.lib.azure import AzureOpenAI -from strenum import StrEnum - -from app.core.rag.nlp import is_chinese, is_english -from app.core.rag.common.token_utils import num_tokens_from_string, total_token_count_from_response - - -# Error message constants -class LLMErrorCode(StrEnum): - ERROR_RATE_LIMIT = "RATE_LIMIT_EXCEEDED" - ERROR_AUTHENTICATION = "AUTH_ERROR" - ERROR_INVALID_REQUEST = "INVALID_REQUEST" - ERROR_SERVER = "SERVER_ERROR" - ERROR_TIMEOUT = "TIMEOUT" - ERROR_CONNECTION = "CONNECTION_ERROR" - ERROR_MODEL = "MODEL_ERROR" - ERROR_MAX_ROUNDS = "ERROR_MAX_ROUNDS" - ERROR_CONTENT_FILTER = "CONTENT_FILTERED" - ERROR_QUOTA = "QUOTA_EXCEEDED" - ERROR_MAX_RETRIES = "MAX_RETRIES_EXCEEDED" - ERROR_GENERIC = "GENERIC_ERROR" - - -class ReActMode(StrEnum): - FUNCTION_CALL = "function_call" - REACT = "react" - - -ERROR_PREFIX = "**ERROR**" -LENGTH_NOTIFICATION_CN = "······\n由于大模型的上下文窗口大小限制,回答已经被大模型截断。" -LENGTH_NOTIFICATION_EN = "...\nThe answer is truncated by your chosen LLM due to its limitation on context length." - - -class ToolCallSession(Protocol): - def tool_call(self, name: str, arguments: dict[str, Any]) -> str: ... - - -class Base(ABC): - def __init__(self, key, model_name, base_url, **kwargs): - timeout = int(os.environ.get("LLM_TIMEOUT_SECONDS", 600)) - self.client = OpenAI(api_key=key, base_url=base_url, timeout=timeout) - self.model_name = model_name - # Configure retry parameters - self.max_retries = kwargs.get("max_retries", int(os.environ.get("LLM_MAX_RETRIES", 5))) - self.base_delay = kwargs.get("retry_interval", float(os.environ.get("LLM_BASE_DELAY", 2.0))) - self.max_rounds = kwargs.get("max_rounds", 5) - self.is_tools = False - self.tools = [] - self.toolcall_sessions = {} - - def _get_delay(self): - """Calculate retry delay time""" - return self.base_delay * random.uniform(10, 150) - - def _classify_error(self, error): - """Classify error based on error message content""" - error_str = str(error).lower() - - keywords_mapping = [ - (["quota", "capacity", "credit", "billing", "balance", "欠费"], LLMErrorCode.ERROR_QUOTA), - (["rate limit", "429", "tpm limit", "too many requests", "requests per minute"], LLMErrorCode.ERROR_RATE_LIMIT), - (["auth", "key", "apikey", "401", "forbidden", "permission"], LLMErrorCode.ERROR_AUTHENTICATION), - (["invalid", "bad request", "400", "format", "malformed", "parameter"], LLMErrorCode.ERROR_INVALID_REQUEST), - (["server", "503", "502", "504", "500", "unavailable"], LLMErrorCode.ERROR_SERVER), - (["timeout", "timed out"], LLMErrorCode.ERROR_TIMEOUT), - (["connect", "network", "unreachable", "dns"], LLMErrorCode.ERROR_CONNECTION), - (["filter", "content", "policy", "blocked", "safety", "inappropriate"], LLMErrorCode.ERROR_CONTENT_FILTER), - (["model", "not found", "does not exist", "not available"], LLMErrorCode.ERROR_MODEL), - (["max rounds"], LLMErrorCode.ERROR_MODEL), - ] - for words, code in keywords_mapping: - if re.search("({})".format("|".join(words)), error_str): - return code - - return LLMErrorCode.ERROR_GENERIC - - def _clean_conf(self, gen_conf): - if "max_tokens" in gen_conf: - del gen_conf["max_tokens"] - - allowed_conf = { - "temperature", - "max_completion_tokens", - "top_p", - "stream", - "stream_options", - "stop", - "n", - "presence_penalty", - "frequency_penalty", - "functions", - "function_call", - "logit_bias", - "user", - "response_format", - "seed", - "tools", - "tool_choice", - "logprobs", - "top_logprobs", - "extra_headers" - } - - gen_conf = {k: v for k, v in gen_conf.items() if k in allowed_conf} - - return gen_conf - - def _chat(self, history, gen_conf, **kwargs): - logging.info("[HISTORY]" + json.dumps(history, ensure_ascii=False, indent=2)) - if self.model_name.lower().find("qwq") >= 0: - logging.info(f"[INFO] {self.model_name} detected as reasoning model, using _chat_streamly") - - final_ans = "" - tol_token = 0 - for delta, tol in self._chat_streamly(history, gen_conf, with_reasoning=False, **kwargs): - if delta.startswith("") or delta.endswith(""): - continue - final_ans += delta - tol_token = tol - - if len(final_ans.strip()) == 0: - final_ans = "**ERROR**: Empty response from reasoning model" - - return final_ans.strip(), tol_token - - if self.model_name.lower().find("qwen3") >= 0: - kwargs["extra_body"] = {"enable_thinking": False} - - response = self.client.chat.completions.create(model=self.model_name, messages=history, **gen_conf, **kwargs) - - if not response.choices or not response.choices[0].message or not response.choices[0].message.content: - return "", 0 - ans = response.choices[0].message.content.strip() - if response.choices[0].finish_reason == "length": - ans = self._length_stop(ans) - return ans, total_token_count_from_response(response) - - def _chat_streamly(self, history, gen_conf, **kwargs): - logging.info("[HISTORY STREAMLY]" + json.dumps(history, ensure_ascii=False, indent=4)) - reasoning_start = False - - if kwargs.get("stop") or "stop" in gen_conf: - response = self.client.chat.completions.create(model=self.model_name, messages=history, stream=True, **gen_conf, stop=kwargs.get("stop")) - else: - response = self.client.chat.completions.create(model=self.model_name, messages=history, stream=True, **gen_conf) - - for resp in response: - if not resp.choices: - continue - if not resp.choices[0].delta.content: - resp.choices[0].delta.content = "" - if kwargs.get("with_reasoning", True) and hasattr(resp.choices[0].delta, "reasoning_content") and resp.choices[0].delta.reasoning_content: - ans = "" - if not reasoning_start: - reasoning_start = True - ans = "" - ans += resp.choices[0].delta.reasoning_content + "" - else: - reasoning_start = False - ans = resp.choices[0].delta.content - - tol = total_token_count_from_response(resp) - if not tol: - tol = num_tokens_from_string(resp.choices[0].delta.content) - - if resp.choices[0].finish_reason == "length": - if is_chinese(ans): - ans += LENGTH_NOTIFICATION_CN - else: - ans += LENGTH_NOTIFICATION_EN - yield ans, tol - - def _length_stop(self, ans): - if is_chinese([ans]): - return ans + LENGTH_NOTIFICATION_CN - return ans + LENGTH_NOTIFICATION_EN - - @property - def _retryable_errors(self) -> set[str]: - return { - LLMErrorCode.ERROR_RATE_LIMIT, - LLMErrorCode.ERROR_SERVER, - } - - def _should_retry(self, error_code: str) -> bool: - return error_code in self._retryable_errors - - def _exceptions(self, e, attempt) -> str | None: - logging.exception("OpenAI chat_with_tools") - # Classify the error - error_code = self._classify_error(e) - if attempt == self.max_retries: - error_code = LLMErrorCode.ERROR_MAX_RETRIES - - if self._should_retry(error_code): - delay = self._get_delay() - logging.warning(f"Error: {error_code}. Retrying in {delay:.2f} seconds... (Attempt {attempt + 1}/{self.max_retries})") - time.sleep(delay) - return None - - return f"{ERROR_PREFIX}: {error_code} - {str(e)}" - - def _verbose_tool_use(self, name, args, res): - return "" + json.dumps({"name": name, "args": args, "result": res}, ensure_ascii=False, indent=2) + "" - - def _append_history(self, hist, tool_call, tool_res): - hist.append( - { - "role": "assistant", - "tool_calls": [ - { - "index": tool_call.index, - "id": tool_call.id, - "function": { - "name": tool_call.function.name, - "arguments": tool_call.function.arguments, - }, - "type": "function", - }, - ], - } - ) - try: - if isinstance(tool_res, dict): - tool_res = json.dumps(tool_res, ensure_ascii=False) - finally: - hist.append({"role": "tool", "tool_call_id": tool_call.id, "content": str(tool_res)}) - return hist - - def bind_tools(self, toolcall_session, tools): - if not (toolcall_session and tools): - return - self.is_tools = True - self.toolcall_session = toolcall_session - self.tools = tools - - def chat_with_tools(self, system: str, history: list, gen_conf: dict = {}): - gen_conf = self._clean_conf(gen_conf) - if system and history and history[0].get("role") != "system": - history.insert(0, {"role": "system", "content": system}) - - ans = "" - tk_count = 0 - hist = deepcopy(history) - # Implement exponential backoff retry strategy - for attempt in range(self.max_retries + 1): - history = hist - try: - for _ in range(self.max_rounds + 1): - logging.info(f"{self.tools=}") - response = self.client.chat.completions.create(model=self.model_name, messages=history, tools=self.tools, tool_choice="auto", **gen_conf) - tk_count += total_token_count_from_response(response) - if any([not response.choices, not response.choices[0].message]): - raise Exception(f"500 response structure error. Response: {response}") - - if not hasattr(response.choices[0].message, "tool_calls") or not response.choices[0].message.tool_calls: - if hasattr(response.choices[0].message, "reasoning_content") and response.choices[0].message.reasoning_content: - ans += "" + response.choices[0].message.reasoning_content + "" - - ans += response.choices[0].message.content - if response.choices[0].finish_reason == "length": - ans = self._length_stop(ans) - - return ans, tk_count - - for tool_call in response.choices[0].message.tool_calls: - logging.info(f"Response {tool_call=}") - name = tool_call.function.name - try: - args = json_repair.loads(tool_call.function.arguments) - tool_response = self.toolcall_session.tool_call(name, args) - history = self._append_history(history, tool_call, tool_response) - ans += self._verbose_tool_use(name, args, tool_response) - except Exception as e: - logging.exception(msg=f"Wrong JSON argument format in LLM tool call response: {tool_call}") - history.append({"role": "tool", "tool_call_id": tool_call.id, "content": f"Tool call error: \n{tool_call}\nException:\n" + str(e)}) - ans += self._verbose_tool_use(name, {}, str(e)) - - logging.warning(f"Exceed max rounds: {self.max_rounds}") - history.append({"role": "user", "content": f"Exceed max rounds: {self.max_rounds}"}) - response, token_count = self._chat(history, gen_conf) - ans += response - tk_count += token_count - return ans, tk_count - except Exception as e: - e = self._exceptions(e, attempt) - if e: - return e, tk_count - - assert False, "Shouldn't be here." - - def chat(self, system, history, gen_conf={}, **kwargs): - if system and history and history[0].get("role") != "system": - history.insert(0, {"role": "system", "content": system}) - gen_conf = self._clean_conf(gen_conf) - - # Implement exponential backoff retry strategy - for attempt in range(self.max_retries + 1): - try: - return self._chat(history, gen_conf, **kwargs) - except Exception as e: - e = self._exceptions(e, attempt) - if e: - return e, 0 - assert False, "Shouldn't be here." - - def _wrap_toolcall_message(self, stream): - final_tool_calls = {} - - for chunk in stream: - for tool_call in chunk.choices[0].delta.tool_calls or []: - index = tool_call.index - - if index not in final_tool_calls: - final_tool_calls[index] = tool_call - - final_tool_calls[index].function.arguments += tool_call.function.arguments - - return final_tool_calls - - def chat_streamly_with_tools(self, system: str, history: list, gen_conf: dict = {}): - gen_conf = self._clean_conf(gen_conf) - tools = self.tools - if system and history and history[0].get("role") != "system": - history.insert(0, {"role": "system", "content": system}) - - total_tokens = 0 - hist = deepcopy(history) - # Implement exponential backoff retry strategy - for attempt in range(self.max_retries + 1): - history = hist - try: - for _ in range(self.max_rounds + 1): - reasoning_start = False - logging.info(f"{tools=}") - response = self.client.chat.completions.create(model=self.model_name, messages=history, stream=True, tools=tools, tool_choice="auto", **gen_conf) - final_tool_calls = {} - answer = "" - for resp in response: - if resp.choices[0].delta.tool_calls: - for tool_call in resp.choices[0].delta.tool_calls or []: - index = tool_call.index - - if index not in final_tool_calls: - if not tool_call.function.arguments: - tool_call.function.arguments = "" - final_tool_calls[index] = tool_call - else: - final_tool_calls[index].function.arguments += tool_call.function.arguments if tool_call.function.arguments else "" - continue - - if any([not resp.choices, not resp.choices[0].delta, not hasattr(resp.choices[0].delta, "content")]): - raise Exception("500 response structure error.") - - if not resp.choices[0].delta.content: - resp.choices[0].delta.content = "" - - if hasattr(resp.choices[0].delta, "reasoning_content") and resp.choices[0].delta.reasoning_content: - ans = "" - if not reasoning_start: - reasoning_start = True - ans = "" - ans += resp.choices[0].delta.reasoning_content + "" - yield ans - else: - reasoning_start = False - answer += resp.choices[0].delta.content - yield resp.choices[0].delta.content - - tol = total_token_count_from_response(resp) - if not tol: - total_tokens += num_tokens_from_string(resp.choices[0].delta.content) - else: - total_tokens = tol - - finish_reason = resp.choices[0].finish_reason if hasattr(resp.choices[0], "finish_reason") else "" - if finish_reason == "length": - yield self._length_stop("") - - if answer: - yield total_tokens - return - - for tool_call in final_tool_calls.values(): - name = tool_call.function.name - try: - args = json_repair.loads(tool_call.function.arguments) - yield self._verbose_tool_use(name, args, "Begin to call...") - tool_response = self.toolcall_session.tool_call(name, args) - history = self._append_history(history, tool_call, tool_response) - yield self._verbose_tool_use(name, args, tool_response) - except Exception as e: - logging.exception(msg=f"Wrong JSON argument format in LLM tool call response: {tool_call}") - history.append({"role": "tool", "tool_call_id": tool_call.id, "content": f"Tool call error: \n{tool_call}\nException:\n" + str(e)}) - yield self._verbose_tool_use(name, {}, str(e)) - - logging.warning(f"Exceed max rounds: {self.max_rounds}") - history.append({"role": "user", "content": f"Exceed max rounds: {self.max_rounds}"}) - response = self.client.chat.completions.create(model=self.model_name, messages=history, stream=True, **gen_conf) - for resp in response: - if any([not resp.choices, not resp.choices[0].delta, not hasattr(resp.choices[0].delta, "content")]): - raise Exception("500 response structure error.") - if not resp.choices[0].delta.content: - resp.choices[0].delta.content = "" - continue - tol = total_token_count_from_response(resp) - if not tol: - total_tokens += num_tokens_from_string(resp.choices[0].delta.content) - else: - total_tokens = tol - answer += resp.choices[0].delta.content - yield resp.choices[0].delta.content - - yield total_tokens - return - - except Exception as e: - e = self._exceptions(e, attempt) - if e: - yield e - yield total_tokens - return - - assert False, "Shouldn't be here." - - def chat_streamly(self, system, history, gen_conf: dict = {}, **kwargs): - if system and history and history[0].get("role") != "system": - history.insert(0, {"role": "system", "content": system}) - gen_conf = self._clean_conf(gen_conf) - ans = "" - total_tokens = 0 - try: - for delta_ans, tol in self._chat_streamly(history, gen_conf, **kwargs): - yield delta_ans - total_tokens += tol - except openai.APIError as e: - yield ans + "\n**ERROR**: " + str(e) - - yield total_tokens - - def _calculate_dynamic_ctx(self, history): - """Calculate dynamic context window size""" - - def count_tokens(text): - """Calculate token count for text""" - # Simple calculation: 1 token per ASCII character - # 2 tokens for non-ASCII characters (Chinese, Japanese, Korean, etc.) - total = 0 - for char in text: - if ord(char) < 128: # ASCII characters - total += 1 - else: # Non-ASCII characters (Chinese, Japanese, Korean, etc.) - total += 2 - return total - - # Calculate total tokens for all messages - total_tokens = 0 - for message in history: - content = message.get("content", "") - # Calculate content tokens - content_tokens = count_tokens(content) - # Add role marker token overhead - role_tokens = 4 - total_tokens += content_tokens + role_tokens - - # Apply 1.2x buffer ratio - total_tokens_with_buffer = int(total_tokens * 1.2) - - if total_tokens_with_buffer <= 8192: - ctx_size = 8192 - else: - ctx_multiplier = (total_tokens_with_buffer // 8192) + 1 - ctx_size = ctx_multiplier * 8192 - - return ctx_size - - -class GptTurbo(Base): - _FACTORY_NAME = "OpenAI" - - def __init__(self, key, model_name="gpt-3.5-turbo", base_url="https://api.openai.com/v1", **kwargs): - if not base_url: - base_url = "https://api.openai.com/v1" - super().__init__(key, model_name, base_url, **kwargs) - - -class XinferenceChat(Base): - _FACTORY_NAME = "Xinference" - - def __init__(self, key=None, model_name="", base_url="", **kwargs): - if not base_url: - raise ValueError("Local llm url cannot be None") - base_url = urljoin(base_url, "v1") - super().__init__(key, model_name, base_url, **kwargs) - - -class HuggingFaceChat(Base): - _FACTORY_NAME = "HuggingFace" - - def __init__(self, key=None, model_name="", base_url="", **kwargs): - if not base_url: - raise ValueError("Local llm url cannot be None") - base_url = urljoin(base_url, "v1") - super().__init__(key, model_name.split("___")[0], base_url, **kwargs) - - -class ModelScopeChat(Base): - _FACTORY_NAME = "ModelScope" - - def __init__(self, key=None, model_name="", base_url="", **kwargs): - if not base_url: - raise ValueError("Local llm url cannot be None") - base_url = urljoin(base_url, "v1") - super().__init__(key, model_name.split("___")[0], base_url, **kwargs) - - -class AzureChat(Base): - _FACTORY_NAME = "Azure-OpenAI" - - def __init__(self, key, model_name, base_url, **kwargs): - api_key = json.loads(key).get("api_key", "") - api_version = json.loads(key).get("api_version", "2024-02-01") - super().__init__(key, model_name, base_url, **kwargs) - self.client = AzureOpenAI(api_key=api_key, azure_endpoint=base_url, api_version=api_version) - self.model_name = model_name - - @property - def _retryable_errors(self) -> set[str]: - return { - LLMErrorCode.ERROR_RATE_LIMIT, - LLMErrorCode.ERROR_SERVER, - LLMErrorCode.ERROR_QUOTA, - } - - -class BaiChuanChat(Base): - _FACTORY_NAME = "BaiChuan" - - def __init__(self, key, model_name="Baichuan3-Turbo", base_url="https://api.baichuan-ai.com/v1", **kwargs): - if not base_url: - base_url = "https://api.baichuan-ai.com/v1" - super().__init__(key, model_name, base_url, **kwargs) - - @staticmethod - def _format_params(params): - return { - "temperature": params.get("temperature", 0.3), - "top_p": params.get("top_p", 0.85), - } - - def _clean_conf(self, gen_conf): - return { - "temperature": gen_conf.get("temperature", 0.3), - "top_p": gen_conf.get("top_p", 0.85), - } - - def _chat(self, history, gen_conf={}, **kwargs): - response = self.client.chat.completions.create( - model=self.model_name, - messages=history, - extra_body={"tools": [{"type": "web_search", "web_search": {"enable": True, "search_mode": "performance_first"}}]}, - **gen_conf, - ) - ans = response.choices[0].message.content.strip() - if response.choices[0].finish_reason == "length": - if is_chinese([ans]): - ans += LENGTH_NOTIFICATION_CN - else: - ans += LENGTH_NOTIFICATION_EN - return ans, total_token_count_from_response(response) - - def chat_streamly(self, system, history, gen_conf={}, **kwargs): - if system and history and history[0].get("role") != "system": - history.insert(0, {"role": "system", "content": system}) - if "max_tokens" in gen_conf: - del gen_conf["max_tokens"] - ans = "" - total_tokens = 0 - try: - response = self.client.chat.completions.create( - model=self.model_name, - messages=history, - extra_body={"tools": [{"type": "web_search", "web_search": {"enable": True, "search_mode": "performance_first"}}]}, - stream=True, - **self._format_params(gen_conf), - ) - for resp in response: - if not resp.choices: - continue - if not resp.choices[0].delta.content: - resp.choices[0].delta.content = "" - ans = resp.choices[0].delta.content - tol = total_token_count_from_response(resp) - if not tol: - total_tokens += num_tokens_from_string(resp.choices[0].delta.content) - else: - total_tokens = tol - if resp.choices[0].finish_reason == "length": - if is_chinese([ans]): - ans += LENGTH_NOTIFICATION_CN - else: - ans += LENGTH_NOTIFICATION_EN - yield ans - - except Exception as e: - yield ans + "\n**ERROR**: " + str(e) - - yield total_tokens - - -class LocalAIChat(Base): - _FACTORY_NAME = "LocalAI" - - def __init__(self, key, model_name, base_url=None, **kwargs): - super().__init__(key, model_name, base_url=base_url, **kwargs) - - if not base_url: - raise ValueError("Local llm url cannot be None") - base_url = urljoin(base_url, "v1") - self.client = OpenAI(api_key="empty", base_url=base_url) - self.model_name = model_name.split("___")[0] - - -class VolcEngineChat(Base): - _FACTORY_NAME = "VolcEngine" - - def __init__(self, key, model_name, base_url="https://ark.cn-beijing.volces.com/api/v3", **kwargs): - """ - Since do not want to modify the original database fields, and the VolcEngine authentication method is quite special, - Assemble ark_api_key, ep_id into api_key, store it as a dictionary type, and parse it for use - model_name is for display only - """ - base_url = base_url if base_url else "https://ark.cn-beijing.volces.com/api/v3" - ark_api_key = json.loads(key).get("ark_api_key", "") - model_name = json.loads(key).get("ep_id", "") + json.loads(key).get("endpoint_id", "") - super().__init__(ark_api_key, model_name, base_url, **kwargs) - - -class OpenAI_APIChat(Base): - _FACTORY_NAME = ["VLLM", "OpenAI-API-Compatible"] - - def __init__(self, key, model_name, base_url, **kwargs): - if not base_url: - raise ValueError("url cannot be None") - model_name = model_name.split("___")[0] - super().__init__(key, model_name, base_url, **kwargs) - - -class GPUStackChat(Base): - _FACTORY_NAME = "GPUStack" - - def __init__(self, key=None, model_name="", base_url="", **kwargs): - if not base_url: - raise ValueError("Local llm url cannot be None") - base_url = urljoin(base_url, "v1") - super().__init__(key, model_name, base_url, **kwargs) diff --git a/app/core/rag/llm/cv_model.py b/app/core/rag/llm/cv_model.py deleted file mode 100644 index 663272ce..00000000 --- a/app/core/rag/llm/cv_model.py +++ /dev/null @@ -1,470 +0,0 @@ -import base64 -import json -import os -import tempfile -import logging -from abc import ABC -from copy import deepcopy -from io import BytesIO -from pathlib import Path -from urllib.parse import urljoin -import requests -from openai import OpenAI -from openai.lib.azure import AzureOpenAI -from app.core.rag.nlp import is_english -from app.core.rag.prompts.generator import vision_llm_describe_prompt -from app.core.rag.common.token_utils import num_tokens_from_string, total_token_count_from_response - - -class Base(ABC): - def __init__(self, **kwargs): - # Configure retry parameters - self.max_retries = kwargs.get("max_retries", int(os.environ.get("LLM_MAX_RETRIES", 5))) - self.base_delay = kwargs.get("retry_interval", float(os.environ.get("LLM_BASE_DELAY", 2.0))) - self.max_rounds = kwargs.get("max_rounds", 5) - self.is_tools = False - self.tools = [] - self.toolcall_sessions = {} - self.extra_body = None - - def describe(self, image): - raise NotImplementedError("Please implement encode method!") - - def describe_with_prompt(self, image, prompt=None): - raise NotImplementedError("Please implement encode method!") - - def _form_history(self, system, history, images=None): - hist = [] - if system: - hist.append({"role": "system", "content": system}) - for h in history: - if images and h["role"] == "user": - h["content"] = self._image_prompt(h["content"], images) - images = [] - hist.append(h) - return hist - - def _image_prompt(self, text, images): - if not images: - return text - - if isinstance(images, str) or "bytes" in type(images).__name__: - images = [images] - - pmpt = [{"type": "text", "text": text}] - for img in images: - pmpt.append({ - "type": "image_url", - "image_url": { - "url": img if isinstance(img, str) and img.startswith("data:") else f"data:image/png;base64,{img}" - } - }) - return pmpt - - def chat(self, system, history, gen_conf, images=None, **kwargs): - try: - response = self.client.chat.completions.create( - model=self.model_name, - messages=self._form_history(system, history, images), - extra_body=self.extra_body, - ) - return response.choices[0].message.content.strip(), response.usage.total_tokens - except Exception as e: - return "**ERROR**: " + str(e), 0 - - def chat_streamly(self, system, history, gen_conf, images=None, **kwargs): - ans = "" - tk_count = 0 - try: - response = self.client.chat.completions.create( - model=self.model_name, - messages=self._form_history(system, history, images), - stream=True, - extra_body=self.extra_body, - ) - for resp in response: - if not resp.choices[0].delta.content: - continue - delta = resp.choices[0].delta.content - ans = delta - if resp.choices[0].finish_reason == "length": - ans += "...\nFor the content length reason, it stopped, continue?" if is_english([ans]) else "······\n由于长度的原因,回答被截断了,要继续吗?" - if resp.choices[0].finish_reason == "stop": - tk_count += resp.usage.total_tokens - yield ans - except Exception as e: - yield ans + "\n**ERROR**: " + str(e) - - yield tk_count - - @staticmethod - def image2base64_rawvalue(self, image): - # Return a base64 string without data URL header - if isinstance(image, bytes): - b64 = base64.b64encode(image).decode("utf-8") - return b64 - if isinstance(image, BytesIO): - data = image.getvalue() - b64 = base64.b64encode(data).decode("utf-8") - return b64 - with BytesIO() as buffered: - try: - image.save(buffered, format="JPEG") - except Exception: - # reset buffer before saving PNG - buffered.seek(0) - buffered.truncate() - image.save(buffered, format="PNG") - data = buffered.getvalue() - b64 = base64.b64encode(data).decode("utf-8") - return b64 - - @staticmethod - def image2base64(image): - # Return a data URL with the correct MIME to avoid provider mismatches - if isinstance(image, bytes): - # Best-effort magic number sniffing - mime = "image/png" - if len(image) >= 2 and image[0] == 0xFF and image[1] == 0xD8: - mime = "image/jpeg" - b64 = base64.b64encode(image).decode("utf-8") - return f"data:{mime};base64,{b64}" - if isinstance(image, BytesIO): - data = image.getvalue() - mime = "image/png" - if len(data) >= 2 and data[0] == 0xFF and data[1] == 0xD8: - mime = "image/jpeg" - b64 = base64.b64encode(data).decode("utf-8") - return f"data:{mime};base64,{b64}" - with BytesIO() as buffered: - fmt = "jpeg" - try: - image.save(buffered, format="JPEG") - except Exception: - # reset buffer before saving PNG - buffered.seek(0) - buffered.truncate() - image.save(buffered, format="PNG") - fmt = "png" - data = buffered.getvalue() - b64 = base64.b64encode(data).decode("utf-8") - mime = f"image/{fmt}" - return f"data:{mime};base64,{b64}" - - def prompt(self, b64): - return [ - { - "role": "user", - "content": self._image_prompt( - "请用中文详细描述一下图中的内容,比如时间,地点,人物,事情,人物心情等,如果有数据请提取出数据。" - if self.lang.lower() == "chinese" - else "Please describe the content of this picture, like where, when, who, what happen. If it has number data, please extract them out.", - b64 - ) - } - ] - - def vision_llm_prompt(self, b64, prompt=None): - return [ - { - "role": "user", - "content": self._image_prompt(prompt if prompt else vision_llm_describe_prompt(), b64) - } - ] - - -class GptV4(Base): - _FACTORY_NAME = "OpenAI" - - def __init__(self, key, model_name="gpt-4-vision-preview", lang="Chinese", base_url="https://api.openai.com/v1", **kwargs): - if not base_url: - base_url = "https://api.openai.com/v1" - self.api_key = key - self.client = OpenAI(api_key=key, base_url=base_url) - self.model_name = model_name - self.lang = lang - super().__init__(**kwargs) - - def describe(self, image): - b64 = self.image2base64(image) - res = self.client.chat.completions.create( - model=self.model_name, - messages=self.prompt(b64), - extra_body=self.extra_body, - ) - return res.choices[0].message.content.strip(), total_token_count_from_response(res) - - def describe_with_prompt(self, image, prompt=None): - b64 = self.image2base64(image) - res = self.client.chat.completions.create( - model=self.model_name, - messages=self.vision_llm_prompt(b64, prompt), - extra_body=self.extra_body, - ) - return res.choices[0].message.content.strip(),total_token_count_from_response(res) - - -class AzureGptV4(GptV4): - _FACTORY_NAME = "Azure-OpenAI" - - def __init__(self, key, model_name, lang="Chinese", **kwargs): - api_key = json.loads(key).get("api_key", "") - api_version = json.loads(key).get("api_version", "2024-02-01") - self.client = AzureOpenAI(api_key=api_key, azure_endpoint=kwargs["base_url"], api_version=api_version) - self.model_name = model_name - self.lang = lang - Base.__init__(self, **kwargs) - - -class QWenCV(GptV4): - _FACTORY_NAME = "Tongyi-Qianwen" - - def __init__(self, key, model_name="qwen-vl-chat-v1", lang="Chinese", base_url=None, **kwargs): - if not base_url: - base_url = "https://dashscope.aliyuncs.com/compatible-mode/v1" - super().__init__(key, model_name, lang=lang, base_url=base_url, **kwargs) - - def chat(self, system, history, gen_conf, images=None, video_bytes=None, filename="", **kwargs): - if video_bytes: - try: - summary, summary_num_tokens = self._process_video(video_bytes, filename) - return summary, summary_num_tokens - except Exception as e: - return "**ERROR**: " + str(e), 0 - - return "**ERROR**: Method chat not supported yet.", 0 - - def _process_video(self, video_bytes, filename): - from dashscope import MultiModalConversation - - video_suffix = Path(filename).suffix or ".mp4" - with tempfile.NamedTemporaryFile(delete=False, suffix=video_suffix) as tmp: - tmp.write(video_bytes) - tmp_path = tmp.name - - video_path = f"file://{tmp_path}" - messages = [ - { - "role": "user", - "content": [ - { - "video": video_path, - "fps": 2, - }, - { - "text": "Please summarize this video in proper sentences.", - }, - ], - } - ] - - def call_api(): - response = MultiModalConversation.call( - api_key=self.api_key, - model=self.model_name, - messages=messages, - ) - summary = response["output"]["choices"][0]["message"].content[0]["text"] - return summary, num_tokens_from_string(summary) - - try: - return call_api() - except Exception as e1: - import dashscope - - dashscope.base_http_api_url = "https://dashscope-intl.aliyuncs.com/api/v1" - try: - return call_api() - except Exception as e2: - raise RuntimeError(f"Both default and intl endpoint failed.\nFirst error: {e1}\nSecond error: {e2}") - - -class XinferenceCV(GptV4): - _FACTORY_NAME = "Xinference" - - def __init__(self, key, model_name="", lang="Chinese", base_url="", **kwargs): - base_url = urljoin(base_url, "v1") - self.client = OpenAI(api_key=key, base_url=base_url) - self.model_name = model_name - self.lang = lang - Base.__init__(self, **kwargs) - - -class GPUStackCV(GptV4): - _FACTORY_NAME = "GPUStack" - - def __init__(self, key, model_name, lang="Chinese", base_url="", **kwargs): - if not base_url: - raise ValueError("Local llm url cannot be None") - base_url = urljoin(base_url, "v1") - self.client = OpenAI(api_key=key, base_url=base_url) - self.model_name = model_name - self.lang = lang - Base.__init__(self, **kwargs) - - -class OllamaCV(Base): - _FACTORY_NAME = "Ollama" - - def __init__(self, key, model_name, lang="Chinese", **kwargs): - from ollama import Client - self.client = Client(host=kwargs["base_url"]) - self.model_name = model_name - self.lang = lang - self.keep_alive = kwargs.get("ollama_keep_alive", int(os.environ.get("OLLAMA_KEEP_ALIVE", -1))) - Base.__init__(self, **kwargs) - - - def _clean_img(self, img): - if not isinstance(img, str): - return img - - #remove the header like "data/*;base64," - if img.startswith("data:") and ";base64," in img: - img = img.split(";base64,")[1] - return img - - def _clean_conf(self, gen_conf): - options = {} - if "temperature" in gen_conf: - options["temperature"] = gen_conf["temperature"] - if "top_p" in gen_conf: - options["top_k"] = gen_conf["top_p"] - if "presence_penalty" in gen_conf: - options["presence_penalty"] = gen_conf["presence_penalty"] - if "frequency_penalty" in gen_conf: - options["frequency_penalty"] = gen_conf["frequency_penalty"] - return options - - def _form_history(self, system, history, images=None): - hist = deepcopy(history) - if system and hist[0]["role"] == "user": - hist.insert(0, {"role": "system", "content": system}) - if not images: - return hist - temp_images = [] - for img in images: - temp_images.append(self._clean_img(img)) - for his in hist: - if his["role"] == "user": - his["images"] = temp_images - break - return hist - - def describe(self, image): - prompt = self.prompt("") - try: - response = self.client.generate( - model=self.model_name, - prompt=prompt[0]["content"], - images=[image], - ) - ans = response["response"].strip() - return ans, 128 - except Exception as e: - return "**ERROR**: " + str(e), 0 - - def describe_with_prompt(self, image, prompt=None): - vision_prompt = self.vision_llm_prompt("", prompt) if prompt else self.vision_llm_prompt("") - try: - response = self.client.generate( - model=self.model_name, - prompt=vision_prompt[0]["content"], - images=[image], - ) - ans = response["response"].strip() - return ans, 128 - except Exception as e: - return "**ERROR**: " + str(e), 0 - - def chat(self, system, history, gen_conf, images=None, **kwargs): - try: - response = self.client.chat( - model=self.model_name, - messages=self._form_history(system, history, images), - options=self._clean_conf(gen_conf), - keep_alive=self.keep_alive - ) - - ans = response["message"]["content"].strip() - return ans, response["eval_count"] + response.get("prompt_eval_count", 0) - except Exception as e: - return "**ERROR**: " + str(e), 0 - - def chat_streamly(self, system, history, gen_conf, images=None, **kwargs): - ans = "" - try: - response = self.client.chat( - model=self.model_name, - messages=self._form_history(system, history, images), - stream=True, - options=self._clean_conf(gen_conf), - keep_alive=self.keep_alive - ) - for resp in response: - if resp["done"]: - yield resp.get("prompt_eval_count", 0) + resp.get("eval_count", 0) - ans = resp["message"]["content"] - yield ans - except Exception as e: - yield ans + "\n**ERROR**: " + str(e) - yield 0 - - -if __name__ == "__main__": - # import sys - # chunk(sys.argv[1], from_page=0, to_page=10, callback=dummy) - - # # 准备配置vision_model信息 - # azure_config = { - # "api_key": "xxxxx", - # "api_version": "2024-02-01" - # } - # # 转换为 JSON 字符串,因为类中使用 json.loads(key) 解析 - # key = json.dumps(azure_config) - # # 初始化 AzureGptV4 - # vision_model = AzureGptV4( - # key=key, # JSON 字符串形式的配置 - # model_name="gpt-4o", - # lang="Chinese", # 默认使用中文 - # base_url="https://fosun-openai-east-us-001.openai.azure.com/" # Azure OpenAI 端点 - # ) - # try: - # # 测试图像描述功能 - # image_path = "/Users/sbtjfdn/Downloads/记忆科学/files/aippt.cn.png" - # with open(image_path, "rb") as image_file: - # image_data = image_file.read() - # - # # 使用 describe 方法 - # description, token_count = vision_model.describe(image_data) - # # from app.core.rag.prompts.generator import vision_llm_figure_describe_prompt - # # description, token_count = vision_model.describe_with_prompt(image_data, prompt=vision_llm_figure_describe_prompt()) - # print(f"描述: {description}") - # print(f"使用的令牌数: {token_count}") - # - # except Exception as e: - # print(f"初始化或处理过程中出错: {str(e)}") - - # 准备配置vision_model信息 - # 初始化 QWenCV - vision_model = QWenCV( - key="sk-8e9e40cd171749858ce2d3722ea75669", - model_name="qwen-vl-max", - lang="Chinese", # 默认使用中文 - base_url="https://dashscope.aliyuncs.com/compatible-mode/v1" - ) - try: - # 测试图像描述功能 - image_path = "/Users/sbtjfdn/Downloads/记忆科学/files/10.png" - with open(image_path, "rb") as image_file: - image_data = image_file.read() - - # 使用 describe 方法 - description, token_count = vision_model.describe(image_data) - # from app.core.rag.prompts.generator import vision_llm_figure_describe_prompt - # description, token_count = vision_model.describe_with_prompt(image_data, prompt=vision_llm_figure_describe_prompt()) - print(f"描述: {description}") - print(f"使用的令牌数: {token_count}") - - except Exception as e: - print(f"初始化或处理过程中出错: {str(e)}") diff --git a/app/core/rag/llm/sequence2txt_model.py b/app/core/rag/llm/sequence2txt_model.py deleted file mode 100644 index dcea9346..00000000 --- a/app/core/rag/llm/sequence2txt_model.py +++ /dev/null @@ -1,179 +0,0 @@ -import base64 -import io -import json -import os -import re -from abc import ABC - -import requests -from openai import OpenAI -from openai.lib.azure import AzureOpenAI - -from app.core.rag.common.token_utils import num_tokens_from_string - - -class Base(ABC): - def __init__(self, key, model_name, **kwargs): - """ - Abstract base class constructor. - Parameters are not stored; initialization is left to subclasses. - """ - pass - - def transcription(self, audio_path, **kwargs): - audio_file = open(audio_path, "rb") - transcription = self.client.audio.transcriptions.create(model=self.model_name, file=audio_file) - return transcription.text.strip(), num_tokens_from_string(transcription.text.strip()) - - def audio2base64(self, audio): - if isinstance(audio, bytes): - return base64.b64encode(audio).decode("utf-8") - if isinstance(audio, io.BytesIO): - return base64.b64encode(audio.getvalue()).decode("utf-8") - raise TypeError("The input audio file should be in binary format.") - - -class GPTSeq2txt(Base): - _FACTORY_NAME = "OpenAI" - - def __init__(self, key, model_name="whisper-1", base_url="https://api.openai.com/v1", **kwargs): - if not base_url: - base_url = "https://api.openai.com/v1" - self.client = OpenAI(api_key=key, base_url=base_url) - self.model_name = model_name - - -class QWenSeq2txt(Base): - _FACTORY_NAME = "Tongyi-Qianwen" - - def __init__(self, key, model_name="qwen-audio-asr", **kwargs): - import dashscope - - dashscope.api_key = key - self.model_name = model_name - - def transcription(self, audio_path): - if "paraformer" in self.model_name or "sensevoice" in self.model_name: - return f"**ERROR**: model {self.model_name} is not suppported yet.", 0 - - from dashscope import MultiModalConversation - - audio_path = f"file://{audio_path}" - messages = [ - { - "role": "user", - "content": [{"audio": audio_path}], - } - ] - - response = None - full_content = "" - try: - response = MultiModalConversation.call(model="qwen-audio-asr", messages=messages, result_format="message", stream=True) - for response in response: - try: - full_content += response["output"]["choices"][0]["message"].content[0]["text"] - except Exception: - pass - return full_content, num_tokens_from_string(full_content) - except Exception as e: - return "**ERROR**: " + str(e), 0 - - -class AzureSeq2txt(Base): - _FACTORY_NAME = "Azure-OpenAI" - - def __init__(self, key, model_name, lang="Chinese", **kwargs): - self.client = AzureOpenAI(api_key=key, azure_endpoint=kwargs["base_url"], api_version="2024-02-01") - self.model_name = model_name - self.lang = lang - - -class XinferenceSeq2txt(Base): - _FACTORY_NAME = "Xinference" - - def __init__(self, key, model_name="whisper-small", **kwargs): - self.base_url = kwargs.get("base_url", None) - self.model_name = model_name - self.key = key - - def transcription(self, audio, language="zh", prompt=None, response_format="json", temperature=0.7): - if isinstance(audio, str): - audio_file = open(audio, "rb") - audio_data = audio_file.read() - audio_file_name = audio.split("/")[-1] - else: - audio_data = audio - audio_file_name = "audio.wav" - - payload = {"model": self.model_name, "language": language, "prompt": prompt, "response_format": response_format, "temperature": temperature} - - files = {"file": (audio_file_name, audio_data, "audio/wav")} - - try: - response = requests.post(f"{self.base_url}/v1/audio/transcriptions", files=files, data=payload) - response.raise_for_status() - result = response.json() - - if "text" in result: - transcription_text = result["text"].strip() - return transcription_text, num_tokens_from_string(transcription_text) - else: - return "**ERROR**: Failed to retrieve transcription.", 0 - - except requests.exceptions.RequestException as e: - return f"**ERROR**: {str(e)}", 0 - - -class GPUStackSeq2txt(Base): - _FACTORY_NAME = "GPUStack" - - def __init__(self, key, model_name, base_url): - if not base_url: - raise ValueError("url cannot be None") - if base_url.split("/")[-1] != "v1": - base_url = os.path.join(base_url, "v1") - self.base_url = base_url - self.model_name = model_name - self.key = key - - -class ZhipuSeq2txt(Base): - _FACTORY_NAME = "ZHIPU-AI" - - def __init__(self, key, model_name="glm-asr", base_url="https://open.bigmodel.cn/api/paas/v4", **kwargs): - if not base_url: - base_url = "https://open.bigmodel.cn/api/paas/v4" - self.base_url = base_url - self.api_key = key - self.model_name = model_name - self.gen_conf = kwargs.get("gen_conf", {}) - self.stream = kwargs.get("stream", False) - - def transcription(self, audio_path): - payload = { - "model": self.model_name, - "temperature": str(self.gen_conf.get("temperature", 0.75)) or "0.75", - "stream": self.stream, - } - - headers = {"Authorization": f"Bearer {self.api_key}"} - with open(audio_path, "rb") as audio_file: - files = {"file": audio_file} - - try: - response = requests.post( - url=f"{self.base_url}/audio/transcriptions", - data=payload, - files=files, - headers=headers, - ) - body = response.json() - if response.status_code == 200: - full_content = body["text"] - return full_content, num_tokens_from_string(full_content) - else: - error = body["error"] - return f"**ERROR**: code: {error['code']}, message: {error['message']}", 0 - except Exception as e: - return "**ERROR**: " + str(e), 0 diff --git a/app/core/rag/models/__init__.py b/app/core/rag/models/__init__.py deleted file mode 100644 index e69de29b..00000000 diff --git a/app/core/rag/models/chunk.py b/app/core/rag/models/chunk.py deleted file mode 100644 index 731924e1..00000000 --- a/app/core/rag/models/chunk.py +++ /dev/null @@ -1,72 +0,0 @@ -from pydantic import BaseModel, Field - - -class ChildDocumentChunk(BaseModel): - """Class for storing a piece of text and associated metadata.""" - - page_content: str - - vector: list[float] | None = None - - """Arbitrary metadata about the page content (e.g., source, relationships to other - documents, etc.). - """ - metadata: dict = Field(default_factory=dict) - - -class DocumentChunk(BaseModel): - """Class for storing a piece of text and associated metadata.""" - - page_content: str - - vector: list[float] | None = None - - """Arbitrary metadata about the page content (e.g., source, relationships to other - documents, etc.). - """ - metadata: dict = Field(default_factory=dict) - - children: list[ChildDocumentChunk] | None = None - - -class GeneralStructureChunk(BaseModel): - """ - General Structure Chunk. - """ - - general_chunks: list[str] - - -class ParentChildChunk(BaseModel): - """ - Parent Child Chunk. - """ - - parent_content: str - child_contents: list[str] - - -class ParentChildStructureChunk(BaseModel): - """ - Parent Child Structure Chunk. - """ - - parent_child_chunks: list[ParentChildChunk] - parent_mode: str = "paragraph" - - -class QAChunk(BaseModel): - """ - QA Chunk. - """ - - question: str - answer: str - - -class QAStructureChunk(BaseModel): - """ - QAStructureChunk. - """ - - qa_chunks: list[QAChunk] diff --git a/app/core/rag/nlp/__init__.py b/app/core/rag/nlp/__init__.py deleted file mode 100644 index dae97e0e..00000000 --- a/app/core/rag/nlp/__init__.py +++ /dev/null @@ -1,857 +0,0 @@ -import logging -import random -from collections import Counter - -from app.core.rag.common.token_utils import num_tokens_from_string -from . import rag_tokenizer -import re -import copy -import roman_numbers as r -from word2number import w2n -from cn2an import cn2an -from PIL import Image - -import chardet - -all_codecs = [ - 'utf-8', 'gb2312', 'gbk', 'utf_16', 'ascii', 'big5', 'big5hkscs', - 'cp037', 'cp273', 'cp424', 'cp437', - 'cp500', 'cp720', 'cp737', 'cp775', 'cp850', 'cp852', 'cp855', 'cp856', 'cp857', - 'cp858', 'cp860', 'cp861', 'cp862', 'cp863', 'cp864', 'cp865', 'cp866', 'cp869', - 'cp874', 'cp875', 'cp932', 'cp949', 'cp950', 'cp1006', 'cp1026', 'cp1125', - 'cp1140', 'cp1250', 'cp1251', 'cp1252', 'cp1253', 'cp1254', 'cp1255', 'cp1256', - 'cp1257', 'cp1258', 'euc_jp', 'euc_jis_2004', 'euc_jisx0213', 'euc_kr', - 'gb18030', 'hz', 'iso2022_jp', 'iso2022_jp_1', 'iso2022_jp_2', - 'iso2022_jp_2004', 'iso2022_jp_3', 'iso2022_jp_ext', 'iso2022_kr', 'latin_1', - 'iso8859_2', 'iso8859_3', 'iso8859_4', 'iso8859_5', 'iso8859_6', 'iso8859_7', - 'iso8859_8', 'iso8859_9', 'iso8859_10', 'iso8859_11', 'iso8859_13', - 'iso8859_14', 'iso8859_15', 'iso8859_16', 'johab', 'koi8_r', 'koi8_t', 'koi8_u', - 'kz1048', 'mac_cyrillic', 'mac_greek', 'mac_iceland', 'mac_latin2', 'mac_roman', - 'mac_turkish', 'ptcp154', 'shift_jis', 'shift_jis_2004', 'shift_jisx0213', - 'utf_32', 'utf_32_be', 'utf_32_le', 'utf_16_be', 'utf_16_le', 'utf_7', 'windows-1250', 'windows-1251', - 'windows-1252', 'windows-1253', 'windows-1254', 'windows-1255', 'windows-1256', - 'windows-1257', 'windows-1258', 'latin-2' -] - - -def find_codec(blob): - detected = chardet.detect(blob[:1024]) - if detected['confidence'] > 0.5: - if detected['encoding'] == "ascii": - return "utf-8" - - for c in all_codecs: - try: - blob[:1024].decode(c) - return c - except Exception: - pass - try: - blob.decode(c) - return c - except Exception: - pass - - return "utf-8" - - -QUESTION_PATTERN = [ - r"第([零一二三四五六七八九十百0-9]+)问", - r"第([零一二三四五六七八九十百0-9]+)条", - r"[\((]([零一二三四五六七八九十百]+)[\))]", - r"第([0-9]+)问", - r"第([0-9]+)条", - r"([0-9]{1,2})[\. 、]", - r"([零一二三四五六七八九十百]+)[ 、]", - r"[\((]([0-9]{1,2})[\))]", - r"QUESTION (ONE|TWO|THREE|FOUR|FIVE|SIX|SEVEN|EIGHT|NINE|TEN)", - r"QUESTION (I+V?|VI*|XI|IX|X)", - r"QUESTION ([0-9]+)", -] - - -def has_qbullet(reg, box, last_box, last_index, last_bull, bull_x0_list): - section, last_section = box['text'], last_box['text'] - q_reg = r'(\w|\W)*?(?:?|\?|\n|$)+' - full_reg = reg + q_reg - has_bull = re.match(full_reg, section) - index_str = None - if has_bull: - if 'x0' not in last_box: - last_box['x0'] = box['x0'] - if 'top' not in last_box: - last_box['top'] = box['top'] - if last_bull and box['x0'] - last_box['x0'] > 10: - return None, last_index - if not last_bull and box['x0'] >= last_box['x0'] and box['top'] - last_box['top'] < 20: - return None, last_index - avg_bull_x0 = 0 - if bull_x0_list: - avg_bull_x0 = sum(bull_x0_list) / len(bull_x0_list) - else: - avg_bull_x0 = box['x0'] - if box['x0'] - avg_bull_x0 > 10: - return None, last_index - index_str = has_bull.group(1) - index = index_int(index_str) - if last_section[-1] == ':' or last_section[-1] == ':': - return None, last_index - if not last_index or index >= last_index: - bull_x0_list.append(box['x0']) - return has_bull, index - if section[-1] == '?' or section[-1] == '?': - bull_x0_list.append(box['x0']) - return has_bull, index - if box['layout_type'] == 'title': - bull_x0_list.append(box['x0']) - return has_bull, index - pure_section = section.lstrip(re.match(reg, section).group()).lower() - ask_reg = r'(what|when|where|how|why|which|who|whose|为什么|为啥|哪)' - if re.match(ask_reg, pure_section): - bull_x0_list.append(box['x0']) - return has_bull, index - return None, last_index - - -def index_int(index_str): - res = -1 - try: - res = int(index_str) - except ValueError: - try: - res = w2n.word_to_num(index_str) - except ValueError: - try: - res = cn2an(index_str) - except ValueError: - try: - res = r.number(index_str) - except ValueError: - return -1 - return res - - -def qbullets_category(sections): - global QUESTION_PATTERN - hits = [0] * len(QUESTION_PATTERN) - for i, pro in enumerate(QUESTION_PATTERN): - for sec in sections: - if re.match(pro, sec) and not not_bullet(sec): - hits[i] += 1 - break - maxium = 0 - res = -1 - for i, h in enumerate(hits): - if h <= maxium: - continue - res = i - maxium = h - return res, QUESTION_PATTERN[res] - - -BULLET_PATTERN = [[ - r"第[零一二三四五六七八九十百0-9]+(分?编|部分)", - r"第[零一二三四五六七八九十百0-9]+章", - r"第[零一二三四五六七八九十百0-9]+节", - r"第[零一二三四五六七八九十百0-9]+条", - r"[\((][零一二三四五六七八九十百]+[\))]", -], [ - r"第[0-9]+章", - r"第[0-9]+节", - r"[0-9]{,2}[\. 、]", - r"[0-9]{,2}\.[0-9]{,2}[^a-zA-Z/%~-]", - r"[0-9]{,2}\.[0-9]{,2}\.[0-9]{,2}", - r"[0-9]{,2}\.[0-9]{,2}\.[0-9]{,2}\.[0-9]{,2}", -], [ - r"第[零一二三四五六七八九十百0-9]+章", - r"第[零一二三四五六七八九十百0-9]+节", - r"[零一二三四五六七八九十百]+[ 、]", - r"[\((][零一二三四五六七八九十百]+[\))]", - r"[\((][0-9]{,2}[\))]", -], [ - r"PART (ONE|TWO|THREE|FOUR|FIVE|SIX|SEVEN|EIGHT|NINE|TEN)", - r"Chapter (I+V?|VI*|XI|IX|X)", - r"Section [0-9]+", - r"Article [0-9]+" -], [ - r"^#[^#]", - r"^##[^#]", - r"^###.*", - r"^####.*", - r"^#####.*", - r"^######.*", -] -] - - -def random_choices(arr, k): - k = min(len(arr), k) - return random.choices(arr, k=k) - - -def not_bullet(line): - patt = [ - r"0", r"[0-9]+ +[0-9~个只-]", r"[0-9]+\.{2,}" - ] - return any([re.match(r, line) for r in patt]) - - -def bullets_category(sections): - global BULLET_PATTERN - hits = [0] * len(BULLET_PATTERN) - for i, pro in enumerate(BULLET_PATTERN): - for sec in sections: - sec = sec.strip() - for p in pro: - if re.match(p, sec) and not not_bullet(sec): - hits[i] += 1 - break - maxium = 0 - res = -1 - for i, h in enumerate(hits): - if h <= maxium: - continue - res = i - maxium = h - return res - - -def is_english(texts): - if not texts: - return False - - pattern = re.compile(r"[`a-zA-Z0-9\s.,':;/\"?<>!\(\)\-]") - - if isinstance(texts, str): - texts = list(texts) - elif isinstance(texts, list): - texts = [t for t in texts if isinstance(t, str) and t.strip()] - else: - return False - - if not texts: - return False - - eng = sum(1 for t in texts if pattern.fullmatch(t.strip())) - return (eng / len(texts)) > 0.8 - - -def is_chinese(text): - if not text: - return False - chinese = 0 - for ch in text: - if '\u4e00' <= ch <= '\u9fff': - chinese += 1 - if chinese / len(text) > 0.2: - return True - return False - - -def tokenize(d, t, eng): - d["content_with_weight"] = t - t = re.sub(r"]{0,12})?>", " ", t) - d["content_ltks"] = rag_tokenizer.tokenize(t) - d["content_sm_ltks"] = rag_tokenizer.fine_grained_tokenize(d["content_ltks"]) - - -def tokenize_chunks(chunks, doc, eng, pdf_parser=None): - res = [] - # wrap up as es documents - for ii, ck in enumerate(chunks): - if len(ck.strip()) == 0: - continue - logging.debug("-- {}".format(ck)) - d = copy.deepcopy(doc) - if pdf_parser: - try: - d["image"], poss = pdf_parser.crop(ck, need_position=True) - add_positions(d, poss) - ck = pdf_parser.remove_tag(ck) - except NotImplementedError: - pass - else: - add_positions(d, [[ii]*5]) - tokenize(d, ck, eng) - res.append(d) - return res - - -def tokenize_chunks_with_images(chunks, doc, eng, images): - res = [] - # wrap up as es documents - for ii, (ck, image) in enumerate(zip(chunks, images)): - if len(ck.strip()) == 0: - continue - logging.debug("-- {}".format(ck)) - d = copy.deepcopy(doc) - d["image"] = image - add_positions(d, [[ii]*5]) - tokenize(d, ck, eng) - res.append(d) - return res - - -def tokenize_table(tbls, doc, eng, batch_size=10): - res = [] - # add tables - for (img, rows), poss in tbls: - if not rows: - continue - if isinstance(rows, str): - d = copy.deepcopy(doc) - tokenize(d, rows, eng) - d["content_with_weight"] = rows - if img: - d["image"] = img - d["doc_type_kwd"] = "image" - if poss: - add_positions(d, poss) - res.append(d) - continue - de = "; " if eng else "; " - for i in range(0, len(rows), batch_size): - d = copy.deepcopy(doc) - r = de.join(rows[i:i + batch_size]) - tokenize(d, r, eng) - if img: - d["image"] = img - d["doc_type_kwd"] = "image" - add_positions(d, poss) - res.append(d) - return res - - -def add_positions(d, poss): - if not poss: - return - page_num_int = [] - position_int = [] - top_int = [] - for pn, left, right, top, bottom in poss: - page_num_int.append(int(pn + 1)) - top_int.append(int(top)) - position_int.append((int(pn + 1), int(left), int(right), int(top), int(bottom))) - d["page_num_int"] = page_num_int - d["position_int"] = position_int - d["top_int"] = top_int - - -def remove_contents_table(sections, eng=False): - i = 0 - while i < len(sections): - def get(i): - nonlocal sections - return (sections[i] if isinstance(sections[i], - type("")) else sections[i][0]).strip() - - if not re.match(r"(contents|目录|目次|table of contents|致谢|acknowledge)$", - re.sub(r"( | |\u3000)+", "", get(i).split("@@")[0], flags=re.IGNORECASE)): - i += 1 - continue - sections.pop(i) - if i >= len(sections): - break - prefix = get(i)[:3] if not eng else " ".join(get(i).split()[:2]) - while not prefix: - sections.pop(i) - if i >= len(sections): - break - prefix = get(i)[:3] if not eng else " ".join(get(i).split()[:2]) - sections.pop(i) - if i >= len(sections) or not prefix: - break - for j in range(i, min(i + 128, len(sections))): - if not re.match(prefix, get(j)): - continue - for _ in range(i, j): - sections.pop(i) - break - - -def make_colon_as_title(sections): - if not sections: - return [] - if isinstance(sections[0], type("")): - return sections - i = 0 - while i < len(sections): - txt, layout = sections[i] - i += 1 - txt = txt.split("@")[0].strip() - if not txt: - continue - if txt[-1] not in "::": - continue - txt = txt[::-1] - arr = re.split(r"([。?!!?;;]| \.)", txt) - if len(arr) < 2 or len(arr[1]) < 32: - continue - sections.insert(i - 1, (arr[0][::-1], "title")) - i += 1 - - -def title_frequency(bull, sections): - bullets_size = len(BULLET_PATTERN[bull]) - levels = [bullets_size + 1 for _ in range(len(sections))] - if not sections or bull < 0: - return bullets_size + 1, levels - - for i, (txt, layout) in enumerate(sections): - for j, p in enumerate(BULLET_PATTERN[bull]): - if re.match(p, txt.strip()) and not not_bullet(txt): - levels[i] = j - break - else: - if re.search(r"(title|head)", layout) and not not_title(txt.split("@")[0]): - levels[i] = bullets_size - most_level = bullets_size + 1 - for level, c in sorted(Counter(levels).items(), key=lambda x: x[1] * -1): - if level <= bullets_size: - most_level = level - break - return most_level, levels - - -def not_title(txt): - if re.match(r"第[零一二三四五六七八九十百0-9]+条", txt): - return False - if len(txt.split()) > 12 or (txt.find(" ") < 0 and len(txt) >= 32): - return True - return re.search(r"[,;,。;!!]", txt) - -def tree_merge(bull, sections, depth): - - if not sections or bull < 0: - return sections - if isinstance(sections[0], type("")): - sections = [(s, "") for s in sections] - - # filter out position information in pdf sections - sections = [(t, o) for t, o in sections if - t and len(t.split("@")[0].strip()) > 1 and not re.match(r"[0-9]+$", t.split("@")[0].strip())] - - def get_level(bull, section): - text, layout = section - text = re.sub(r"\u3000", " ", text).strip() - - for i, title in enumerate(BULLET_PATTERN[bull]): - if re.match(title, text.strip()): - return i+1, text - else: - if re.search(r"(title|head)", layout) and not not_title(text): - return len(BULLET_PATTERN[bull])+1, text - else: - return len(BULLET_PATTERN[bull])+2, text - level_set = set() - lines = [] - for section in sections: - level, text = get_level(bull, section) - if not text.strip("\n"): - continue - - lines.append((level, text)) - level_set.add(level) - - sorted_levels = sorted(list(level_set)) - - if depth <= len(sorted_levels): - target_level = sorted_levels[depth - 1] - else: - target_level = sorted_levels[-1] - - if target_level == len(BULLET_PATTERN[bull]) + 2: - target_level = sorted_levels[-2] if len(sorted_levels) > 1 else sorted_levels[0] - - root = Node(level=0, depth=target_level, texts=[]) - root.build_tree(lines) - - return [("\n").join(element) for element in root.get_tree() if element] - -def hierarchical_merge(bull, sections, depth): - - if not sections or bull < 0: - return [] - if isinstance(sections[0], type("")): - sections = [(s, "") for s in sections] - sections = [(t, o) for t, o in sections if - t and len(t.split("@")[0].strip()) > 1 and not re.match(r"[0-9]+$", t.split("@")[0].strip())] - bullets_size = len(BULLET_PATTERN[bull]) - levels = [[] for _ in range(bullets_size + 2)] - - for i, (txt, layout) in enumerate(sections): - for j, p in enumerate(BULLET_PATTERN[bull]): - if re.match(p, txt.strip()): - levels[j].append(i) - break - else: - if re.search(r"(title|head)", layout) and not not_title(txt): - levels[bullets_size].append(i) - else: - levels[bullets_size + 1].append(i) - sections = [t for t, _ in sections] - - # for s in sections: print("--", s) - - def binary_search(arr, target): - if not arr: - return -1 - if target > arr[-1]: - return len(arr) - 1 - if target < arr[0]: - return -1 - s, e = 0, len(arr) - while e - s > 1: - i = (e + s) // 2 - if target > arr[i]: - s = i - continue - elif target < arr[i]: - e = i - continue - else: - assert False - return s - - cks = [] - readed = [False] * len(sections) - levels = levels[::-1] - for i, arr in enumerate(levels[:depth]): - for j in arr: - if readed[j]: - continue - readed[j] = True - cks.append([j]) - if i + 1 == len(levels) - 1: - continue - for ii in range(i + 1, len(levels)): - jj = binary_search(levels[ii], j) - if jj < 0: - continue - if levels[ii][jj] > cks[-1][-1]: - cks[-1].pop(-1) - cks[-1].append(levels[ii][jj]) - for ii in cks[-1]: - readed[ii] = True - - if not cks: - return cks - - for i in range(len(cks)): - cks[i] = [sections[j] for j in cks[i][::-1]] - logging.debug("\n* ".join(cks[i])) - - res = [[]] - num = [0] - for ck in cks: - if len(ck) == 1: - n = num_tokens_from_string(re.sub(r"@@[0-9]+.*", "", ck[0])) - if n + num[-1] < 218: - res[-1].append(ck[0]) - num[-1] += n - continue - res.append(ck) - num.append(n) - continue - res.append(ck) - num.append(218) - - return res - - -def naive_merge(sections: str | list, chunk_token_num=128, delimiter="\n。;!?", overlapped_percent=0): - from app.core.rag.deepdoc.parser.pdf_parser import RAGPdfParser - if not sections: - return [] - if isinstance(sections, str): - sections = [sections] - if isinstance(sections[0], str): - sections = [(s, "") for s in sections] - cks = [""] - tk_nums = [0] - - def add_chunk(t, pos): - nonlocal cks, tk_nums, delimiter - tnum = num_tokens_from_string(t) - if not pos: - pos = "" - if tnum < 8: - pos = "" - # Ensure that the length of the merged chunk does not exceed chunk_token_num - if cks[-1] == "" or tk_nums[-1] > chunk_token_num * (100 - overlapped_percent)/100.: - if cks: - overlapped = RAGPdfParser.remove_tag(cks[-1]) - t = overlapped[int(len(overlapped)*(100-overlapped_percent)/100.):] + t - if t.find(pos) < 0: - t += pos - cks.append(t) - tk_nums.append(tnum) - else: - if cks[-1].find(pos) < 0: - t += pos - cks[-1] += t - tk_nums[-1] += tnum - - dels = get_delimiters(delimiter) - for sec, pos in sections: - if num_tokens_from_string(sec) < chunk_token_num: - add_chunk("\n"+sec, pos) - continue - split_sec = re.split(r"(%s)" % dels, sec, flags=re.DOTALL) - for sub_sec in split_sec: - if re.match(f"^{dels}$", sub_sec): - continue - add_chunk("\n"+sub_sec, pos) - - return cks - - -def naive_merge_with_images(texts, images, chunk_token_num=128, delimiter="\n。;!?", overlapped_percent=0): - from app.core.rag.deepdoc.parser.pdf_parser import RAGPdfParser - if not texts or len(texts) != len(images): - return [], [] - cks = [""] - result_images = [None] - tk_nums = [0] - - def add_chunk(t, image, pos=""): - nonlocal cks, result_images, tk_nums, delimiter - tnum = num_tokens_from_string(t) - if not pos: - pos = "" - if tnum < 8: - pos = "" - # Ensure that the length of the merged chunk does not exceed chunk_token_num - if cks[-1] == "" or tk_nums[-1] > chunk_token_num * (100 - overlapped_percent)/100.: - if cks: - overlapped = RAGPdfParser.remove_tag(cks[-1]) - t = overlapped[int(len(overlapped)*(100-overlapped_percent)/100.):] + t - if t.find(pos) < 0: - t += pos - cks.append(t) - result_images.append(image) - tk_nums.append(tnum) - else: - if cks[-1].find(pos) < 0: - t += pos - cks[-1] += t - if result_images[-1] is None: - result_images[-1] = image - else: - result_images[-1] = concat_img(result_images[-1], image) - tk_nums[-1] += tnum - - dels = get_delimiters(delimiter) - for text, image in zip(texts, images): - # if text is tuple, unpack it - if isinstance(text, tuple): - text_str = text[0] - text_pos = text[1] if len(text) > 1 else "" - split_sec = re.split(r"(%s)" % dels, text_str) - for sub_sec in split_sec: - if re.match(f"^{dels}$", sub_sec): - continue - add_chunk("\n"+sub_sec, image, text_pos) - else: - split_sec = re.split(r"(%s)" % dels, text) - for sub_sec in split_sec: - if re.match(f"^{dels}$", sub_sec): - continue - add_chunk("\n"+sub_sec, image) - - return cks, result_images - -def docx_question_level(p, bull=-1): - txt = re.sub(r"\u3000", " ", p.text).strip() - if p.style.name.startswith('Heading'): - return int(p.style.name.split(' ')[-1]), txt - else: - if bull < 0: - return 0, txt - for j, title in enumerate(BULLET_PATTERN[bull]): - if re.match(title, txt): - return j + 1, txt - return len(BULLET_PATTERN[bull])+1, txt - - -def concat_img(img1, img2): - if img1 and not img2: - return img1 - if not img1 and img2: - return img2 - if not img1 and not img2: - return None - - if img1 is img2: - return img1 - - if isinstance(img1, Image.Image) and isinstance(img2, Image.Image): - pixel_data1 = img1.tobytes() - pixel_data2 = img2.tobytes() - if pixel_data1 == pixel_data2: - return img1 - - width1, height1 = img1.size - width2, height2 = img2.size - - new_width = max(width1, width2) - new_height = height1 + height2 - new_image = Image.new('RGB', (new_width, new_height)) - - new_image.paste(img1, (0, 0)) - new_image.paste(img2, (0, height1)) - return new_image - - -def naive_merge_docx(sections, chunk_token_num=128, delimiter="\n。;!?"): - if not sections: - return [], [] - - cks = [""] - images = [None] - tk_nums = [0] - - def add_chunk(t, image, pos=""): - nonlocal cks, tk_nums, delimiter - tnum = num_tokens_from_string(t) - if tnum < 8: - pos = "" - if cks[-1] == "" or tk_nums[-1] > chunk_token_num: - if t.find(pos) < 0: - t += pos - cks.append(t) - images.append(image) - tk_nums.append(tnum) - else: - if cks[-1].find(pos) < 0: - t += pos - cks[-1] += t - images[-1] = concat_img(images[-1], image) - tk_nums[-1] += tnum - - dels = get_delimiters(delimiter) - line = "" - for sec, image in sections: - if not image: - line += sec + "\n" - continue - split_sec = re.split(r"(%s)" % dels, line + sec) - for sub_sec in split_sec: - if re.match(f"^{dels}$", sub_sec): - continue - add_chunk("\n"+sub_sec, image,"") - line = "" - - if line: - split_sec = re.split(r"(%s)" % dels, line) - for sub_sec in split_sec: - if re.match(f"^{dels}$", sub_sec): - continue - add_chunk("\n"+sub_sec, image,"") - - return cks, images - - -def extract_between(text: str, start_tag: str, end_tag: str) -> list[str]: - pattern = re.escape(start_tag) + r"(.*?)" + re.escape(end_tag) - return re.findall(pattern, text, flags=re.DOTALL) - - -def get_delimiters(delimiters: str): - dels = [] - s = 0 - for m in re.finditer(r"`([^`]+)`", delimiters, re.I): - f, t = m.span() - dels.append(m.group(1)) - dels.extend(list(delimiters[s: f])) - s = t - if s < len(delimiters): - dels.extend(list(delimiters[s:])) - - dels.sort(key=lambda x: -len(x)) - dels = [re.escape(d) for d in dels if d] - dels = [d for d in dels if d] - dels_pattern = "|".join(dels) - - return dels_pattern - -class Node: - def __init__(self, level, depth=-1, texts=None): - self.level = level - self.depth = depth - self.texts = texts or [] - self.children = [] - - def add_child(self, child_node): - self.children.append(child_node) - - def get_children(self): - return self.children - - def get_level(self): - return self.level - - def get_texts(self): - return self.texts - - def set_texts(self, texts): - self.texts = texts - - def add_text(self, text): - self.texts.append(text) - - def clear_text(self): - self.texts = [] - - def __repr__(self): - return f"Node(level={self.level}, texts={self.texts}, children={len(self.children)})" - - def build_tree(self, lines): - stack = [self] - for level, text in lines: - if self.depth != -1 and level > self.depth: - # Beyond target depth: merge content into the current leaf instead of creating deeper nodes - stack[-1].add_text(text) - continue - - # Move up until we find the proper parent whose level is strictly smaller than current - while len(stack) > 1 and level <= stack[-1].get_level(): - stack.pop() - - node = Node(level=level, texts=[text]) - # Attach as child of current parent and descend - stack[-1].add_child(node) - stack.append(node) - - return self - - def get_tree(self): - tree_list = [] - self._dfs(self, tree_list, []) - return tree_list - - def _dfs(self, node, tree_list, titles): - level = node.get_level() - texts = node.get_texts() - child = node.get_children() - - if level == 0 and texts: - tree_list.append("\n".join(titles+texts)) - - # Titles within configured depth are accumulated into the current path - if 1 <= level <= self.depth: - path_titles = titles + texts - else: - path_titles = titles - - # Body outside the depth limit becomes its own chunk under the current title path - if level > self.depth and texts: - tree_list.append("\n".join(path_titles + texts)) - - # A leaf title within depth emits its title path as a chunk (header-only section) - elif not child and (1 <= level <= self.depth): - tree_list.append("\n".join(path_titles)) - - # Recurse into children with the updated title path - for c in child: - self._dfs(c, tree_list, path_titles) \ No newline at end of file diff --git a/app/core/rag/nlp/query.py b/app/core/rag/nlp/query.py deleted file mode 100644 index 0922b97e..00000000 --- a/app/core/rag/nlp/query.py +++ /dev/null @@ -1,261 +0,0 @@ -import logging -import json -import re -from collections import defaultdict - -from app.core.rag.utils.doc_store_conn import MatchTextExpr -from . import rag_tokenizer, term_weight, synonym - - -class FulltextQueryer: - def __init__(self): - self.tw = term_weight.Dealer() - self.syn = synonym.Dealer() - self.query_fields = [ - "title_tks^10", - "title_sm_tks^5", - "important_kwd^30", - "important_tks^20", - "question_tks^20", - "content_ltks^2", - "content_sm_ltks", - ] - - @staticmethod - def subSpecialChar(line): - return re.sub(r"([:\{\}/\[\]\-\*\"\(\)\|\+~\^])", r"\\\1", line).strip() - - @staticmethod - def isChinese(line): - arr = re.split(r"[ \t]+", line) - if len(arr) <= 3: - return True - e = 0 - for t in arr: - if not re.match(r"[a-zA-Z]+$", t): - e += 1 - return e * 1.0 / len(arr) >= 0.7 - - @staticmethod - def rmWWW(txt): - patts = [ - ( - r"是*(怎么办|什么样的|哪家|一下|那家|请问|啥样|咋样了|什么时候|何时|何地|何人|是否|是不是|多少|哪里|怎么|哪儿|怎么样|如何|哪些|是啥|啥是|啊|吗|呢|吧|咋|什么|有没有|呀|谁|哪位|哪个)是*", - "", - ), - (r"(^| )(what|who|how|which|where|why)('re|'s)? ", " "), - ( - r"(^| )('s|'re|is|are|were|was|do|does|did|don't|doesn't|didn't|has|have|be|there|you|me|your|my|mine|just|please|may|i|should|would|wouldn't|will|won't|done|go|for|with|so|the|a|an|by|i'm|it's|he's|she's|they|they're|you're|as|by|on|in|at|up|out|down|of|to|or|and|if) ", - " ") - ] - otxt = txt - for r, p in patts: - txt = re.sub(r, p, txt, flags=re.IGNORECASE) - if not txt: - txt = otxt - return txt - - @staticmethod - def add_space_between_eng_zh(txt): - # (ENG/ENG+NUM) + ZH - txt = re.sub(r'([A-Za-z]+[0-9]+)([\u4e00-\u9fa5]+)', r'\1 \2', txt) - # ENG + ZH - txt = re.sub(r'([A-Za-z])([\u4e00-\u9fa5]+)', r'\1 \2', txt) - # ZH + (ENG/ENG+NUM) - txt = re.sub(r'([\u4e00-\u9fa5]+)([A-Za-z]+[0-9]+)', r'\1 \2', txt) - txt = re.sub(r'([\u4e00-\u9fa5]+)([A-Za-z])', r'\1 \2', txt) - return txt - - def question(self, txt, tbl="qa", min_match: float = 0.6): - txt = FulltextQueryer.add_space_between_eng_zh(txt) - txt = re.sub( - r"[ :|\r\n\t,,。??/`!!&^%%()\[\]{}<>]+", - " ", - rag_tokenizer.tradi2simp(rag_tokenizer.strQ2B(txt.lower())), - ).strip() - otxt = txt - txt = FulltextQueryer.rmWWW(txt) - - if not self.isChinese(txt): - txt = FulltextQueryer.rmWWW(txt) - tks = rag_tokenizer.tokenize(txt).split() - keywords = [t for t in tks if t] - tks_w = self.tw.weights(tks, preprocess=False) - tks_w = [(re.sub(r"[ \\\"'^]", "", tk), w) for tk, w in tks_w] - tks_w = [(re.sub(r"^[a-z0-9]$", "", tk), w) for tk, w in tks_w if tk] - tks_w = [(re.sub(r"^[\+-]", "", tk), w) for tk, w in tks_w if tk] - tks_w = [(tk.strip(), w) for tk, w in tks_w if tk.strip()] - syns = [] - for tk, w in tks_w[:256]: - syn = self.syn.lookup(tk) - syn = rag_tokenizer.tokenize(" ".join(syn)).split() - keywords.extend(syn) - syn = ["\"{}\"^{:.4f}".format(s, w / 4.) for s in syn if s.strip()] - syns.append(" ".join(syn)) - - q = ["({}^{:.4f}".format(tk, w) + " {})".format(syn) for (tk, w), syn in zip(tks_w, syns) if - tk and not re.match(r"[.^+\(\)-]", tk)] - for i in range(1, len(tks_w)): - left, right = tks_w[i - 1][0].strip(), tks_w[i][0].strip() - if not left or not right: - continue - q.append( - '"%s %s"^%.4f' - % ( - tks_w[i - 1][0], - tks_w[i][0], - max(tks_w[i - 1][1], tks_w[i][1]) * 2, - ) - ) - if not q: - q.append(txt) - query = " ".join(q) - return MatchTextExpr( - self.query_fields, query, 100 - ), keywords - - def need_fine_grained_tokenize(tk): - if len(tk) < 3: - return False - if re.match(r"[0-9a-z\.\+#_\*-]+$", tk): - return False - return True - - txt = FulltextQueryer.rmWWW(txt) - qs, keywords = [], [] - for tt in self.tw.split(txt)[:256]: # .split(): - if not tt: - continue - keywords.append(tt) - twts = self.tw.weights([tt]) - syns = self.syn.lookup(tt) - if syns and len(keywords) < 32: - keywords.extend(syns) - logging.debug(json.dumps(twts, ensure_ascii=False)) - tms = [] - for tk, w in sorted(twts, key=lambda x: x[1] * -1): - sm = ( - rag_tokenizer.fine_grained_tokenize(tk).split() - if need_fine_grained_tokenize(tk) - else [] - ) - sm = [ - re.sub( - r"[ ,\./;'\[\]\\`~!@#$%\^&\*\(\)=\+_<>\?:\"\{\}\|,。;‘’【】、!¥……()——《》?:“”-]+", - "", - m, - ) - for m in sm - ] - sm = [FulltextQueryer.subSpecialChar(m) for m in sm if len(m) > 1] - sm = [m for m in sm if len(m) > 1] - - if len(keywords) < 32: - keywords.append(re.sub(r"[ \\\"']+", "", tk)) - keywords.extend(sm) - - tk_syns = self.syn.lookup(tk) - tk_syns = [FulltextQueryer.subSpecialChar(s) for s in tk_syns] - if len(keywords) < 32: - keywords.extend([s for s in tk_syns if s]) - tk_syns = [rag_tokenizer.fine_grained_tokenize(s) for s in tk_syns if s] - tk_syns = [f"\"{s}\"" if s.find(" ") > 0 else s for s in tk_syns] - - if len(keywords) >= 32: - break - - tk = FulltextQueryer.subSpecialChar(tk) - if tk.find(" ") > 0: - tk = '"%s"' % tk - if tk_syns: - tk = f"({tk} OR (%s)^0.2)" % " ".join(tk_syns) - if sm: - tk = f'{tk} OR "%s" OR ("%s"~2)^0.5' % (" ".join(sm), " ".join(sm)) - if tk.strip(): - tms.append((tk, w)) - - tms = " ".join([f"({t})^{w}" for t, w in tms]) - - if len(twts) > 1: - tms += ' ("%s"~2)^1.5' % rag_tokenizer.tokenize(tt) - - syns = " OR ".join( - [ - '"%s"' - % rag_tokenizer.tokenize(FulltextQueryer.subSpecialChar(s)) - for s in syns - ] - ) - if syns and tms: - tms = f"({tms})^5 OR ({syns})^0.7" - - qs.append(tms) - - if qs: - query = " OR ".join([f"({t})" for t in qs if t]) - if not query: - query = otxt - return MatchTextExpr( - self.query_fields, query, 100, {"minimum_should_match": min_match} - ), keywords - return None, keywords - - def hybrid_similarity(self, avec, bvecs, atks, btkss, tkweight=0.3, vtweight=0.7): - from sklearn.metrics.pairwise import cosine_similarity as CosineSimilarity - import numpy as np - - sims = CosineSimilarity([avec], bvecs) - tksim = self.token_similarity(atks, btkss) - if np.sum(sims[0]) == 0: - return np.array(tksim), tksim, sims[0] - return np.array(sims[0]) * vtweight + np.array(tksim) * tkweight, tksim, sims[0] - - def token_similarity(self, atks, btkss): - def toDict(tks): - if isinstance(tks, str): - tks = tks.split() - d = defaultdict(int) - wts = self.tw.weights(tks, preprocess=False) - for i, (t, c) in enumerate(wts): - d[t] += c - return d - - atks = toDict(atks) - btkss = [toDict(tks) for tks in btkss] - return [self.similarity(atks, btks) for btks in btkss] - - def similarity(self, qtwt, dtwt): - if isinstance(dtwt, type("")): - dtwt = {t: w for t, w in self.tw.weights(self.tw.split(dtwt), preprocess=False)} - if isinstance(qtwt, type("")): - qtwt = {t: w for t, w in self.tw.weights(self.tw.split(qtwt), preprocess=False)} - s = 1e-9 - for k, v in qtwt.items(): - if k in dtwt: - s += v #* dtwt[k] - q = 1e-9 - for k, v in qtwt.items(): - q += v #* v - return s/q #math.sqrt(3. * (s / q / math.log10( len(dtwt.keys()) + 512 ))) - - def paragraph(self, content_tks: str, keywords: list = [], keywords_topn=30): - if isinstance(content_tks, str): - content_tks = [c.strip() for c in content_tks.strip() if c.strip()] - tks_w = self.tw.weights(content_tks, preprocess=False) - - keywords = [f'"{k.strip()}"' for k in keywords] - for tk, w in sorted(tks_w, key=lambda x: x[1] * -1)[:keywords_topn]: - tk_syns = self.syn.lookup(tk) - tk_syns = [FulltextQueryer.subSpecialChar(s) for s in tk_syns] - tk_syns = [rag_tokenizer.fine_grained_tokenize(s) for s in tk_syns if s] - tk_syns = [f"\"{s}\"" if s.find(" ") > 0 else s for s in tk_syns] - tk = FulltextQueryer.subSpecialChar(tk) - if tk.find(" ") > 0: - tk = '"%s"' % tk - if tk_syns: - tk = f"({tk} OR (%s)^0.2)" % " ".join(tk_syns) - if tk: - keywords.append(f"{tk}^{w}") - - return MatchTextExpr(self.query_fields, " ".join(keywords), 100, - {"minimum_should_match": min(3, len(keywords) // 10)}) diff --git a/app/core/rag/nlp/rag_tokenizer.py b/app/core/rag/nlp/rag_tokenizer.py deleted file mode 100644 index 55dc7d95..00000000 --- a/app/core/rag/nlp/rag_tokenizer.py +++ /dev/null @@ -1,499 +0,0 @@ -import logging -import copy -import datrie -import math -import os -import re -import string -import sys -from hanziconv import HanziConv -from nltk import word_tokenize -from nltk.stem import PorterStemmer, WordNetLemmatizer -from app.core.rag.common.file_utils import get_project_base_directory - - -class RagTokenizer: - def key_(self, line): - return str(line.lower().encode("utf-8"))[2:-1] - - def rkey_(self, line): - return str(("DD" + (line[::-1].lower())).encode("utf-8"))[2:-1] - - def loadDict_(self, fnm): - logging.info(f"[HUQIE]:Build trie from {fnm}") - try: - of = open(fnm, "r", encoding='utf-8') - while True: - line = of.readline() - if not line: - break - line = re.sub(r"[\r\n]+", "", line) - line = re.split(r"[ \t]", line) - k = self.key_(line[0]) - F = int(math.log(float(line[1]) / self.DENOMINATOR) + .5) - if k not in self.trie_ or self.trie_[k][0] < F: - self.trie_[self.key_(line[0])] = (F, line[2]) - self.trie_[self.rkey_(line[0])] = 1 - - trie_file_name = os.path.join(get_project_base_directory(), "res", "huqie") + ".txt.trie" - logging.info(f"[HUQIE]:Build trie cache to {trie_file_name}") - self.trie_.save(trie_file_name) - of.close() - except Exception: - logging.exception(f"[HUQIE]:Build trie {fnm} failed") - - def __init__(self, debug=False): - self.DEBUG = debug - self.DENOMINATOR = 1000000 - - self.stemmer = PorterStemmer() - self.lemmatizer = WordNetLemmatizer() - - self.SPLIT_CHAR = r"([ ,\.<>/?;:'\[\]\\`!@#$%^&*\(\)\{\}\|_+=《》,。?、;‘’:“”【】~!¥%……()——-]+|[a-zA-Z0-9,\.-]+)" - - trie_file_name = os.path.join(get_project_base_directory(), "res", "huqie") + ".txt.trie" - # check if trie file existence - if os.path.exists(trie_file_name): - try: - # load trie from file - self.trie_ = datrie.Trie.load(trie_file_name) - return - except Exception: - # fail to load trie from file, build default trie - logging.exception(f"[HUQIE]:Fail to load trie file {trie_file_name}, build the default trie file") - self.trie_ = datrie.Trie(string.printable) - else: - # file not exist, build default trie - logging.info(f"[HUQIE]:Trie file {trie_file_name} not found, build the default trie file") - self.trie_ = datrie.Trie(string.printable) - - # load data from dict file and save to trie file - self.loadDict_(os.path.join(get_project_base_directory(), "app/core/rag/res", "huqie") + ".txt") - - def loadUserDict(self, fnm): - try: - self.trie_ = datrie.Trie.load(fnm + ".trie") - return - except Exception: - self.trie_ = datrie.Trie(string.printable) - self.loadDict_(fnm) - - def addUserDict(self, fnm): - self.loadDict_(fnm) - - def _strQ2B(self, ustring): - """Convert full-width characters to half-width characters""" - rstring = "" - for uchar in ustring: - inside_code = ord(uchar) - if inside_code == 0x3000: - inside_code = 0x0020 - else: - inside_code -= 0xfee0 - if inside_code < 0x0020 or inside_code > 0x7e: # After the conversion, if it's not a half-width character, return the original character. - rstring += uchar - else: - rstring += chr(inside_code) - return rstring - - def _tradi2simp(self, line): - return HanziConv.toSimplified(line) - - def dfs_(self, chars, s, preTks, tkslist, _depth=0, _memo=None): - if _memo is None: - _memo = {} - MAX_DEPTH = 10 - if _depth > MAX_DEPTH: - if s < len(chars): - copy_pretks = copy.deepcopy(preTks) - remaining = "".join(chars[s:]) - copy_pretks.append((remaining, (-12, ''))) - tkslist.append(copy_pretks) - return s - - state_key = (s, tuple(tk[0] for tk in preTks)) if preTks else (s, None) - if state_key in _memo: - return _memo[state_key] - - res = s - if s >= len(chars): - tkslist.append(preTks) - _memo[state_key] = s - return s - if s < len(chars) - 4: - is_repetitive = True - char_to_check = chars[s] - for i in range(1, 5): - if s + i >= len(chars) or chars[s + i] != char_to_check: - is_repetitive = False - break - if is_repetitive: - end = s - while end < len(chars) and chars[end] == char_to_check: - end += 1 - mid = s + min(10, end - s) - t = "".join(chars[s:mid]) - k = self.key_(t) - copy_pretks = copy.deepcopy(preTks) - if k in self.trie_: - copy_pretks.append((t, self.trie_[k])) - else: - copy_pretks.append((t, (-12, ''))) - next_res = self.dfs_(chars, mid, copy_pretks, tkslist, _depth + 1, _memo) - res = max(res, next_res) - _memo[state_key] = res - return res - - S = s + 1 - if s + 2 <= len(chars): - t1 = "".join(chars[s:s + 1]) - t2 = "".join(chars[s:s + 2]) - if self.trie_.has_keys_with_prefix(self.key_(t1)) and not self.trie_.has_keys_with_prefix(self.key_(t2)): - S = s + 2 - if len(preTks) > 2 and len(preTks[-1][0]) == 1 and len(preTks[-2][0]) == 1 and len(preTks[-3][0]) == 1: - t1 = preTks[-1][0] + "".join(chars[s:s + 1]) - if self.trie_.has_keys_with_prefix(self.key_(t1)): - S = s + 2 - - for e in range(S, len(chars) + 1): - t = "".join(chars[s:e]) - k = self.key_(t) - if e > s + 1 and not self.trie_.has_keys_with_prefix(k): - break - if k in self.trie_: - pretks = copy.deepcopy(preTks) - pretks.append((t, self.trie_[k])) - res = max(res, self.dfs_(chars, e, pretks, tkslist, _depth + 1, _memo)) - - if res > s: - _memo[state_key] = res - return res - - t = "".join(chars[s:s + 1]) - k = self.key_(t) - copy_pretks = copy.deepcopy(preTks) - if k in self.trie_: - copy_pretks.append((t, self.trie_[k])) - else: - copy_pretks.append((t, (-12, ''))) - result = self.dfs_(chars, s + 1, copy_pretks, tkslist, _depth + 1, _memo) - _memo[state_key] = result - return result - - def freq(self, tk): - k = self.key_(tk) - if k not in self.trie_: - return 0 - return int(math.exp(self.trie_[k][0]) * self.DENOMINATOR + 0.5) - - def tag(self, tk): - k = self.key_(tk) - if k not in self.trie_: - return "" - return self.trie_[k][1] - - def score_(self, tfts): - B = 30 - F, L, tks = 0, 0, [] - for tk, (freq, tag) in tfts: - F += freq - L += 0 if len(tk) < 2 else 1 - tks.append(tk) - #F /= len(tks) - L /= len(tks) - logging.debug("[SC] {} {} {} {} {}".format(tks, len(tks), L, F, B / len(tks) + L + F)) - return tks, B / len(tks) + L + F - - def sortTks_(self, tkslist): - res = [] - for tfts in tkslist: - tks, s = self.score_(tfts) - res.append((tks, s)) - return sorted(res, key=lambda x: x[1], reverse=True) - - def merge_(self, tks): - # if split chars is part of token - res = [] - tks = re.sub(r"[ ]+", " ", tks).split() - s = 0 - while True: - if s >= len(tks): - break - E = s + 1 - for e in range(s + 2, min(len(tks) + 2, s + 6)): - tk = "".join(tks[s:e]) - if re.search(self.SPLIT_CHAR, tk) and self.freq(tk): - E = e - res.append("".join(tks[s:E])) - s = E - - return " ".join(res) - - def maxForward_(self, line): - res = [] - s = 0 - while s < len(line): - e = s + 1 - t = line[s:e] - while e < len(line) and self.trie_.has_keys_with_prefix( - self.key_(t)): - e += 1 - t = line[s:e] - - while e - 1 > s and self.key_(t) not in self.trie_: - e -= 1 - t = line[s:e] - - if self.key_(t) in self.trie_: - res.append((t, self.trie_[self.key_(t)])) - else: - res.append((t, (0, ''))) - - s = e - - return self.score_(res) - - def maxBackward_(self, line): - res = [] - s = len(line) - 1 - while s >= 0: - e = s + 1 - t = line[s:e] - while s > 0 and self.trie_.has_keys_with_prefix(self.rkey_(t)): - s -= 1 - t = line[s:e] - - while s + 1 < e and self.key_(t) not in self.trie_: - s += 1 - t = line[s:e] - - if self.key_(t) in self.trie_: - res.append((t, self.trie_[self.key_(t)])) - else: - res.append((t, (0, ''))) - - s -= 1 - - return self.score_(res[::-1]) - - def english_normalize_(self, tks): - return [self.stemmer.stem(self.lemmatizer.lemmatize(t)) if re.match(r"[a-zA-Z_-]+$", t) else t for t in tks] - - def _split_by_lang(self, line): - txt_lang_pairs = [] - arr = re.split(self.SPLIT_CHAR, line) - for a in arr: - if not a: - continue - s = 0 - e = s + 1 - zh = is_chinese(a[s]) - while e < len(a): - _zh = is_chinese(a[e]) - if _zh == zh: - e += 1 - continue - txt_lang_pairs.append((a[s: e], zh)) - s = e - e = s + 1 - zh = _zh - if s >= len(a): - continue - txt_lang_pairs.append((a[s: e], zh)) - return txt_lang_pairs - - def tokenize(self, line): - line = re.sub(r"\W+", " ", line) - line = self._strQ2B(line).lower() - line = self._tradi2simp(line) - - arr = self._split_by_lang(line) - res = [] - for L,lang in arr: - if not lang: - res.extend([self.stemmer.stem(self.lemmatizer.lemmatize(t)) for t in word_tokenize(L)]) - continue - if len(L) < 2 or re.match( - r"[a-z\.-]+$", L) or re.match(r"[0-9\.-]+$", L): - res.append(L) - continue - - # use maxforward for the first time - tks, s = self.maxForward_(L) - tks1, s1 = self.maxBackward_(L) - if self.DEBUG: - logging.debug("[FW] {} {}".format(tks, s)) - logging.debug("[BW] {} {}".format(tks1, s1)) - - i, j, _i, _j = 0, 0, 0, 0 - same = 0 - while i + same < len(tks1) and j + same < len(tks) and tks1[i + same] == tks[j + same]: - same += 1 - if same > 0: - res.append(" ".join(tks[j: j + same])) - _i = i + same - _j = j + same - j = _j + 1 - i = _i + 1 - - while i < len(tks1) and j < len(tks): - tk1, tk = "".join(tks1[_i:i]), "".join(tks[_j:j]) - if tk1 != tk: - if len(tk1) > len(tk): - j += 1 - else: - i += 1 - continue - - if tks1[i] != tks[j]: - i += 1 - j += 1 - continue - # backward tokens from_i to i are different from forward tokens from _j to j. - tkslist = [] - self.dfs_("".join(tks[_j:j]), 0, [], tkslist) - res.append(" ".join(self.sortTks_(tkslist)[0][0])) - - same = 1 - while i + same < len(tks1) and j + same < len(tks) and tks1[i + same] == tks[j + same]: - same += 1 - res.append(" ".join(tks[j: j + same])) - _i = i + same - _j = j + same - j = _j + 1 - i = _i + 1 - - if _i < len(tks1): - assert _j < len(tks) - assert "".join(tks1[_i:]) == "".join(tks[_j:]) - tkslist = [] - self.dfs_("".join(tks[_j:]), 0, [], tkslist) - res.append(" ".join(self.sortTks_(tkslist)[0][0])) - - res = " ".join(res) - logging.debug("[TKS] {}".format(self.merge_(res))) - return self.merge_(res) - - def fine_grained_tokenize(self, tks): - tks = tks.split() - zh_num = len([1 for c in tks if c and is_chinese(c[0])]) - if zh_num < len(tks) * 0.2: - res = [] - for tk in tks: - res.extend(tk.split("/")) - return " ".join(res) - - res = [] - for tk in tks: - if len(tk) < 3 or re.match(r"[0-9,\.-]+$", tk): - res.append(tk) - continue - tkslist = [] - if len(tk) > 10: - tkslist.append(tk) - else: - self.dfs_(tk, 0, [], tkslist) - if len(tkslist) < 2: - res.append(tk) - continue - stk = self.sortTks_(tkslist)[1][0] - if len(stk) == len(tk): - stk = tk - else: - if re.match(r"[a-z\.-]+$", tk): - for t in stk: - if len(t) < 3: - stk = tk - break - else: - stk = " ".join(stk) - else: - stk = " ".join(stk) - - res.append(stk) - - return " ".join(self.english_normalize_(res)) - - -def is_chinese(s): - if s >= u'\u4e00' and s <= u'\u9fa5': - return True - else: - return False - - -def is_number(s): - if s >= u'\u0030' and s <= u'\u0039': - return True - else: - return False - - -def is_alphabet(s): - if (s >= u'\u0041' and s <= u'\u005a') or ( - s >= u'\u0061' and s <= u'\u007a'): - return True - else: - return False - - -def naiveQie(txt): - tks = [] - for t in txt.split(): - if tks and re.match(r".*[a-zA-Z]$", tks[-1] - ) and re.match(r".*[a-zA-Z]$", t): - tks.append(" ") - tks.append(t) - return tks - - -tokenizer = RagTokenizer() -tokenize = tokenizer.tokenize -fine_grained_tokenize = tokenizer.fine_grained_tokenize -tag = tokenizer.tag -freq = tokenizer.freq -loadUserDict = tokenizer.loadUserDict -addUserDict = tokenizer.addUserDict -tradi2simp = tokenizer._tradi2simp -strQ2B = tokenizer._strQ2B - -if __name__ == '__main__': - tknzr = RagTokenizer(debug=True) - # huqie.addUserDict("/tmp/tmp.new.tks.dict") - tks = tknzr.tokenize( - "哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈") - logging.info(tknzr.fine_grained_tokenize(tks)) - tks = tknzr.tokenize( - "公开征求意见稿提出,境外投资者可使用自有人民币或外汇投资。使用外汇投资的,可通过债券持有人在香港人民币业务清算行及香港地区经批准可进入境内银行间外汇市场进行交易的境外人民币业务参加行(以下统称香港结算行)办理外汇资金兑换。香港结算行由此所产生的头寸可到境内银行间外汇市场平盘。使用外汇投资的,在其投资的债券到期或卖出后,原则上应兑换回外汇。") - logging.info(tknzr.fine_grained_tokenize(tks)) - tks = tknzr.tokenize( - "多校划片就是一个小区对应多个小学初中,让买了学区房的家庭也不确定到底能上哪个学校。目的是通过这种方式为学区房降温,把就近入学落到实处。南京市长江大桥") - logging.info(tknzr.fine_grained_tokenize(tks)) - tks = tknzr.tokenize( - "实际上当时他们已经将业务中心偏移到安全部门和针对政府企业的部门 Scripts are compiled and cached aaaaaaaaa") - logging.info(tknzr.fine_grained_tokenize(tks)) - tks = tknzr.tokenize("虽然我不怎么玩") - logging.info(tknzr.fine_grained_tokenize(tks)) - tks = tknzr.tokenize("蓝月亮如何在外资夹击中生存,那是全宇宙最有意思的") - logging.info(tknzr.fine_grained_tokenize(tks)) - tks = tknzr.tokenize( - "涡轮增压发动机num最大功率,不像别的共享买车锁电子化的手段,我们接过来是否有意义,黄黄爱美食,不过,今天阿奇要讲到的这家农贸市场,说实话,还真蛮有特色的!不仅环境好,还打出了") - logging.info(tknzr.fine_grained_tokenize(tks)) - tks = tknzr.tokenize("这周日你去吗?这周日你有空吗?") - logging.info(tknzr.fine_grained_tokenize(tks)) - tks = tknzr.tokenize("Unity3D开发经验 测试开发工程师 c++双11双11 985 211 ") - logging.info(tknzr.fine_grained_tokenize(tks)) - tks = tknzr.tokenize( - "数据分析项目经理|数据分析挖掘|数据分析方向|商品数据分析|搜索数据分析 sql python hive tableau Cocos2d-") - logging.info(tknzr.fine_grained_tokenize(tks)) - if len(sys.argv) < 2: - sys.exit() - tknzr.DEBUG = False - tknzr.loadUserDict(sys.argv[1]) - of = open(sys.argv[2], "r") - while True: - line = of.readline() - if not line: - break - logging.info(tknzr.tokenize(line)) - of.close() diff --git a/app/core/rag/nlp/search.py b/app/core/rag/nlp/search.py deleted file mode 100644 index 006dc5b1..00000000 --- a/app/core/rag/nlp/search.py +++ /dev/null @@ -1,192 +0,0 @@ -import uuid -from typing import Dict, List, Any -from sqlalchemy.orm import Session - -from langchain_core.documents import Document -from app.db import get_db -from app.core.models.base import RedBearModelConfig -from app.core.models import RedBearLLM, RedBearRerank -from app.models.models_model import ModelApiKey -from app.models import knowledge_model -from app.core.rag.models.chunk import DocumentChunk -from app.repositories import knowledge_repository, knowledgeshare_repository -from app.services.model_service import ModelConfigService -from app.core.rag.vdb.elasticsearch.elasticsearch_vector import ElasticSearchVectorFactory - - -def knowledge_retrieval( - query: str, - config: Dict[str, Any], - user_ids: List[str] = None, -) -> list[DocumentChunk]: - """ - Knowledge retrieval with multiple knowledge bases and reranking - - Args: - query: Search query string - config: Configuration dictionary containing: - - knowledge_bases: List of knowledge base configs with: - - kb_id: Knowledge base ID - - similarity_threshold: float - - vector_similarity_weight: float - - top_k: int - - retrieve_type: "participle" or "semantic" or "hybrid" - - merge_strategy: "weight" or other strategies - - reranker_id: UUID of the reranker to use - - reranker_top_k: int - - Returns: - Rearranged document block list (in descending order of relevance) - """ - db = next(get_db()) # Manually call the generator - try: - # parse configuration - knowledge_bases = config.get("knowledge_bases", []) - merge_strategy = config.get("merge_strategy", "weight") - reranker_id = config.get("reranker_id") - reranker_top_k = config.get("reranker_top_k", 1024) - - file_names_filter=[] - if user_ids: - file_names_filter.extend([f"{user_id}.txt" for user_id in user_ids]) - - if not knowledge_bases: - return [] - - all_results = [] - # Search each knowledge base - for kb_config in knowledge_bases: - kb_id = kb_config["kb_id"] - try: - # Check whether the knowledge base exists and is available - db_knowledge = knowledge_repository.get_knowledge_by_id(db, knowledge_id=kb_id) - if db_knowledge and db_knowledge.chunk_num > 0 and db_knowledge.status == 1: - # Process shared knowledge base - if db_knowledge.permission_id.lower() == knowledge_model.PermissionType.Share: - knowledgeshare = knowledgeshare_repository.get_knowledgeshare_by_id(db=db, - knowledgeshare_id=db_knowledge.id) - if knowledgeshare: - db_knowledge = knowledge_repository.get_knowledge_by_id(db, - knowledge_id=knowledgeshare.source_kb_id) - if not (db_knowledge and db_knowledge.chunk_num > 0 and db_knowledge.status == 1): - continue - else: - continue - - vector_service = ElasticSearchVectorFactory().init_vector(knowledge=db_knowledge) - # Retrieve according to the configured retrieval type - match kb_config["retrieve_type"]: - case "participle": - rs = vector_service.search_by_full_text( - query=query, - top_k=kb_config["top_k"], - score_threshold=kb_config["similarity_threshold"], - file_names_filter=file_names_filter - ) - case "semantic": - rs = vector_service.search_by_vector( - query=query, - top_k=kb_config["top_k"], - score_threshold=kb_config["vector_similarity_weight"], - file_names_filter=file_names_filter - ) - case _: # hybrid - rs1 = vector_service.search_by_vector( - query=query, - top_k=kb_config["top_k"], - score_threshold=kb_config["vector_similarity_weight"], - file_names_filter=file_names_filter - ) - rs2 = vector_service.search_by_full_text( - query=query, - top_k=kb_config["top_k"], - score_threshold=kb_config["similarity_threshold"], - file_names_filter=file_names_filter - ) - - # Deduplication of merge results - seen_ids = set() - unique_rs = [] - for doc in rs1 + rs2: - if doc.metadata["doc_id"] not in seen_ids: - seen_ids.add(doc.metadata["doc_id"]) - unique_rs.append(doc) - rs = unique_rs - - all_results.extend(rs) - except Exception as e: - # Failure of retrieval in a single knowledge base does not affect other knowledge bases - print(f"retrieval knowledge({kb_id}) failed: {str(e)}") - continue - - # Use the specified reranker for re-ranking - if reranker_id: - return rerank(db=db, reranker_id=reranker_id, query=query, docs=all_results, top_k=reranker_top_k) - return all_results - - except Exception as e: - print(f"retrieval knowledge failed: {str(e)}") - finally: - db.close() - - -def rerank(db: Session, reranker_id: uuid, query: str, docs: list[DocumentChunk], top_k: int) -> list[DocumentChunk]: - """ - Reorder the list of document blocks and return the top_k results most relevant to the query - Args: - reranker_id: reranker model id - query: query string - docs: List of document blocks to be rearranged - top_k: Number of top-level documents returned - - Returns: - Rearranged document block list (in descending order of relevance) - - Raises: - ValueError: If the input document list is empty or top_k is invalid - """ - # 参数校验 - if not reranker_id: - raise ValueError("reranker_id be empty") - if not docs: - raise ValueError("retrieval chunks be empty") - if top_k <= 0: - raise ValueError("top_k must be a positive integer") - try: - # initialize reranker - config = ModelConfigService.get_model_by_id(db=db, model_id=reranker_id) - apiConfig: ModelApiKey = config.api_keys[0] - reranker = RedBearRerank(RedBearModelConfig( - model_name=apiConfig.model_name, - provider=apiConfig.provider, - api_key=apiConfig.api_key, - base_url=apiConfig.api_base - )) - # Convert to LangChain Document object - documents = [ - Document( - page_content=doc.page_content, # Ensure that DocumentChunk possesses this attribute - metadata=doc.metadata or {} # Deal with possible None metadata - ) - for doc in docs - ] - - # Perform reordering (compress_documents will automatically handle relevance scores and indexing) - reranked_docs = list(reranker.compress_documents(documents, query)) - print(reranked_docs) - - # Sort in descending order based on relevance score - reranked_docs.sort( - key=lambda x: x.metadata.get("relevance_score", 0), - reverse=True - ) - # Convert back to a list of DocumentChunk, and save the relevance_score to metadata["score"] - result = [] - for item in reranked_docs[:top_k]: - for doc in docs: - if doc.page_content == item.page_content: - doc.metadata["score"] = item.metadata["relevance_score"] - result.append(doc) - return result - except Exception as e: - raise RuntimeError(f"Failed to rerank documents: {str(e)}") from e diff --git a/app/core/rag/nlp/surname.py b/app/core/rag/nlp/surname.py deleted file mode 100644 index 22a8657a..00000000 --- a/app/core/rag/nlp/surname.py +++ /dev/null @@ -1,126 +0,0 @@ -m = set(["赵","钱","孙","李", -"周","吴","郑","王", -"冯","陈","褚","卫", -"蒋","沈","韩","杨", -"朱","秦","尤","许", -"何","吕","施","张", -"孔","曹","严","华", -"金","魏","陶","姜", -"戚","谢","邹","喻", -"柏","水","窦","章", -"云","苏","潘","葛", -"奚","范","彭","郎", -"鲁","韦","昌","马", -"苗","凤","花","方", -"俞","任","袁","柳", -"酆","鲍","史","唐", -"费","廉","岑","薛", -"雷","贺","倪","汤", -"滕","殷","罗","毕", -"郝","邬","安","常", -"乐","于","时","傅", -"皮","卞","齐","康", -"伍","余","元","卜", -"顾","孟","平","黄", -"和","穆","萧","尹", -"姚","邵","湛","汪", -"祁","毛","禹","狄", -"米","贝","明","臧", -"计","伏","成","戴", -"谈","宋","茅","庞", -"熊","纪","舒","屈", -"项","祝","董","梁", -"杜","阮","蓝","闵", -"席","季","麻","强", -"贾","路","娄","危", -"江","童","颜","郭", -"梅","盛","林","刁", -"钟","徐","邱","骆", -"高","夏","蔡","田", -"樊","胡","凌","霍", -"虞","万","支","柯", -"昝","管","卢","莫", -"经","房","裘","缪", -"干","解","应","宗", -"丁","宣","贲","邓", -"郁","单","杭","洪", -"包","诸","左","石", -"崔","吉","钮","龚", -"程","嵇","邢","滑", -"裴","陆","荣","翁", -"荀","羊","於","惠", -"甄","曲","家","封", -"芮","羿","储","靳", -"汲","邴","糜","松", -"井","段","富","巫", -"乌","焦","巴","弓", -"牧","隗","山","谷", -"车","侯","宓","蓬", -"全","郗","班","仰", -"秋","仲","伊","宫", -"宁","仇","栾","暴", -"甘","钭","厉","戎", -"祖","武","符","刘", -"景","詹","束","龙", -"叶","幸","司","韶", -"郜","黎","蓟","薄", -"印","宿","白","怀", -"蒲","邰","从","鄂", -"索","咸","籍","赖", -"卓","蔺","屠","蒙", -"池","乔","阴","鬱", -"胥","能","苍","双", -"闻","莘","党","翟", -"谭","贡","劳","逄", -"姬","申","扶","堵", -"冉","宰","郦","雍", -"郤","璩","桑","桂", -"濮","牛","寿","通", -"边","扈","燕","冀", -"郏","浦","尚","农", -"温","别","庄","晏", -"柴","瞿","阎","充", -"慕","连","茹","习", -"宦","艾","鱼","容", -"向","古","易","慎", -"戈","廖","庾","终", -"暨","居","衡","步", -"都","耿","满","弘", -"匡","国","文","寇", -"广","禄","阙","东", -"欧","殳","沃","利", -"蔚","越","夔","隆", -"师","巩","厍","聂", -"晁","勾","敖","融", -"冷","訾","辛","阚", -"那","简","饶","空", -"曾","母","沙","乜", -"养","鞠","须","丰", -"巢","关","蒯","相", -"查","后","荆","红", -"游","竺","权","逯", -"盖","益","桓","公", -"兰","原","乞","西","阿","肖","丑","位","曽","巨","德","代","圆","尉","仵","纳","仝","脱","丘","但","展","迪","付","覃","晗","特","隋","苑","奥","漆","谌","郄","练","扎","邝","渠","信","门","陳","化","原","密","泮","鹿","赫", -"万俟","司马","上官","欧阳", -"夏侯","诸葛","闻人","东方", -"赫连","皇甫","尉迟","公羊", -"澹台","公冶","宗政","濮阳", -"淳于","单于","太叔","申屠", -"公孙","仲孙","轩辕","令狐", -"钟离","宇文","长孙","慕容", -"鲜于","闾丘","司徒","司空", -"亓官","司寇","仉督","子车", -"颛孙","端木","巫马","公西", -"漆雕","乐正","壤驷","公良", -"拓跋","夹谷","宰父","榖梁", -"晋","楚","闫","法","汝","鄢","涂","钦", -"段干","百里","东郭","南门", -"呼延","归","海","羊舌","微","生", -"岳","帅","缑","亢","况","后","有","琴", -"梁丘","左丘","东门","西门", -"商","牟","佘","佴","伯","赏","南宫", -"墨","哈","谯","笪","年","爱","阳","佟", -"第五","言","福"]) - -def isit(n):return n.strip() in m - diff --git a/app/core/rag/nlp/synonym.py b/app/core/rag/nlp/synonym.py deleted file mode 100644 index 8be688a8..00000000 --- a/app/core/rag/nlp/synonym.py +++ /dev/null @@ -1,85 +0,0 @@ -import logging -import json -import os -import time -import re -from nltk.corpus import wordnet -from app.core.rag.common.file_utils import get_project_base_directory - - -class Dealer: - def __init__(self, redis=None): - - self.lookup_num = 100000000 - self.load_tm = time.time() - 1000000 - self.dictionary = None - path = os.path.join(get_project_base_directory(), "app/core/rag/res", "synonym.json") - try: - self.dictionary = json.load(open(path, 'r')) - self.dictionary = { (k.lower() if isinstance(k, str) else k): v for k, v in self.dictionary.items() } - except Exception: - logging.warning("Missing synonym.json") - self.dictionary = {} - - if not redis: - logging.warning( - "Realtime synonym is disabled, since no redis connection.") - if not len(self.dictionary.keys()): - logging.warning("Fail to load synonym") - - self.redis = redis - self.load() - - def load(self): - if not self.redis: - return - - if self.lookup_num < 100: - return - tm = time.time() - if tm - self.load_tm < 3600: - return - - self.load_tm = time.time() - self.lookup_num = 0 - d = self.redis.get("kevin_synonyms") - if not d: - return - try: - d = json.loads(d) - self.dictionary = d - except Exception as e: - logging.error("Fail to load synonym!" + str(e)) - - - def lookup(self, tk, topn=8): - if not tk or not isinstance(tk, str): - return [] - - # 1) Check the custom dictionary first (both keys and tk are already lowercase) - self.lookup_num += 1 - self.load() - key = re.sub(r"[ \t]+", " ", tk.strip()) - res = self.dictionary.get(key, []) - if isinstance(res, str): - res = [res] - if res: # Found in dictionary → return directly - return res[:topn] - - # 2) If not found and tk is purely alphabetical → fallback to WordNet - if re.fullmatch(r"[a-z]+", tk): - wn_set = { - re.sub("_", " ", syn.name().split(".")[0]) - for syn in wordnet.synsets(tk) - } - wn_set.discard(tk) # Remove the original token itself - wn_res = [t for t in wn_set if t] - return wn_res[:topn] - - # 3) Nothing found in either source - return [] - - -if __name__ == '__main__': - dl = Dealer() - print(dl.dictionary) diff --git a/app/core/rag/nlp/term_weight.py b/app/core/rag/nlp/term_weight.py deleted file mode 100644 index 14541ba1..00000000 --- a/app/core/rag/nlp/term_weight.py +++ /dev/null @@ -1,228 +0,0 @@ -import logging -import math -import json -import re -import os -import numpy as np -from . import rag_tokenizer -from app.core.rag.common.file_utils import get_project_base_directory - - -class Dealer: - def __init__(self): - self.stop_words = set(["请问", - "您", - "你", - "我", - "他", - "是", - "的", - "就", - "有", - "于", - "及", - "即", - "在", - "为", - "最", - "有", - "从", - "以", - "了", - "将", - "与", - "吗", - "吧", - "中", - "#", - "什么", - "怎么", - "哪个", - "哪些", - "啥", - "相关"]) - - def load_dict(fnm): - res = {} - f = open(fnm, "r") - while True: - line = f.readline() - if not line: - break - arr = line.replace("\n", "").split("\t") - if len(arr) < 2: - res[arr[0]] = 0 - else: - res[arr[0]] = int(arr[1]) - - c = 0 - for _, v in res.items(): - c += v - if c == 0: - return set(res.keys()) - return res - - fnm = os.path.join(get_project_base_directory(), "app/core/rag/res") - self.ne, self.df = {}, {} - try: - self.ne = json.load(open(os.path.join(fnm, "ner.json"), "r")) - except Exception: - logging.warning("Load ner.json FAIL!") - try: - self.df = load_dict(os.path.join(fnm, "term.freq")) - except Exception: - logging.warning("Load term.freq FAIL!") - - def pretoken(self, txt, num=False, stpwd=True): - patt = [ - r"[~—\t @#%!<>,\.\?\":;'\{\}\[\]_=\(\)\|,。?》•●○↓《;‘’:“”【¥ 】…¥!、·()×`&\\/「」\\]" - ] - rewt = [ - ] - for p, r in rewt: - txt = re.sub(p, r, txt) - - res = [] - for t in rag_tokenizer.tokenize(txt).split(): - tk = t - if (stpwd and tk in self.stop_words) or ( - re.match(r"[0-9]$", tk) and not num): - continue - for p in patt: - if re.match(p, t): - tk = "#" - break - #tk = re.sub(r"([\+\\-])", r"\\\1", tk) - if tk != "#" and tk: - res.append(tk) - return res - - def tokenMerge(self, tks): - def oneTerm(t): return len(t) == 1 or re.match(r"[0-9a-z]{1,2}$", t) - - res, i = [], 0 - while i < len(tks): - j = i - if i == 0 and oneTerm(tks[i]) and len( - tks) > 1 and (len(tks[i + 1]) > 1 and not re.match(r"[0-9a-zA-Z]", tks[i + 1])): # 多 工位 - res.append(" ".join(tks[0:2])) - i = 2 - continue - - while j < len( - tks) and tks[j] and tks[j] not in self.stop_words and oneTerm(tks[j]): - j += 1 - if j - i > 1: - if j - i < 5: - res.append(" ".join(tks[i:j])) - i = j - else: - res.append(" ".join(tks[i:i + 2])) - i = i + 2 - else: - if len(tks[i]) > 0: - res.append(tks[i]) - i += 1 - return [t for t in res if t] - - def ner(self, t): - if not self.ne: - return "" - res = self.ne.get(t, "") - if res: - return res - - def split(self, txt): - tks = [] - for t in re.sub(r"[ \t]+", " ", txt).split(): - if tks and re.match(r".*[a-zA-Z]$", tks[-1]) and \ - re.match(r".*[a-zA-Z]$", t) and tks and \ - self.ne.get(t, "") != "func" and self.ne.get(tks[-1], "") != "func": - tks[-1] = tks[-1] + " " + t - else: - tks.append(t) - return tks - - def weights(self, tks, preprocess=True): - num_pattern = re.compile(r"[0-9,.]{2,}$") - short_letter_pattern = re.compile(r"[a-z]{1,2}$") - num_space_pattern = re.compile(r"[0-9. -]{2,}$") - letter_pattern = re.compile(r"[a-z. -]+$") - - def ner(t): - if num_pattern.match(t): - return 2 - if short_letter_pattern.match(t): - return 0.01 - if not self.ne or t not in self.ne: - return 1 - m = {"toxic": 2, "func": 1, "corp": 3, "loca": 3, "sch": 3, "stock": 3, - "firstnm": 1} - return m[self.ne[t]] - - def postag(t): - t = rag_tokenizer.tag(t) - if t in set(["r", "c", "d"]): - return 0.3 - if t in set(["ns", "nt"]): - return 3 - if t in set(["n"]): - return 2 - if re.match(r"[0-9-]+", t): - return 2 - return 1 - - def freq(t): - if num_space_pattern.match(t): - return 3 - s = rag_tokenizer.freq(t) - if not s and letter_pattern.match(t): - return 300 - if not s: - s = 0 - - if not s and len(t) >= 4: - s = [tt for tt in rag_tokenizer.fine_grained_tokenize(t).split() if len(tt) > 1] - if len(s) > 1: - s = np.min([freq(tt) for tt in s]) / 6. - else: - s = 0 - - return max(s, 10) - - def df(t): - if num_space_pattern.match(t): - return 5 - if t in self.df: - return self.df[t] + 3 - elif letter_pattern.match(t): - return 300 - elif len(t) >= 4: - s = [tt for tt in rag_tokenizer.fine_grained_tokenize(t).split() if len(tt) > 1] - if len(s) > 1: - return max(3, np.min([df(tt) for tt in s]) / 6.) - - return 3 - - def idf(s, N): return math.log10(10 + ((N - s + 0.5) / (s + 0.5))) - - tw = [] - if not preprocess: - idf1 = np.array([idf(freq(t), 10000000) for t in tks]) - idf2 = np.array([idf(df(t), 1000000000) for t in tks]) - wts = (0.3 * idf1 + 0.7 * idf2) * \ - np.array([ner(t) * postag(t) for t in tks]) - wts = [s for s in wts] - tw = list(zip(tks, wts)) - else: - for tk in tks: - tt = self.tokenMerge(self.pretoken(tk, True)) - idf1 = np.array([idf(freq(t), 10000000) for t in tt]) - idf2 = np.array([idf(df(t), 1000000000) for t in tt]) - wts = (0.3 * idf1 + 0.7 * idf2) * \ - np.array([ner(t) * postag(t) for t in tt]) - wts = [s for s in wts] - tw.extend(zip(tt, wts)) - - S = np.sum([s for _, s in tw]) - return [(t, s / S) for t, s in tw] diff --git a/app/core/rag/prompts/__init__.py b/app/core/rag/prompts/__init__.py deleted file mode 100644 index b8b924b9..00000000 --- a/app/core/rag/prompts/__init__.py +++ /dev/null @@ -1,6 +0,0 @@ -from . import generator - -__all__ = [name for name in dir(generator) - if not name.startswith('_')] - -globals().update({name: getattr(generator, name) for name in __all__}) \ No newline at end of file diff --git a/app/core/rag/prompts/analyze_task_system.md b/app/core/rag/prompts/analyze_task_system.md deleted file mode 100644 index 148e4113..00000000 --- a/app/core/rag/prompts/analyze_task_system.md +++ /dev/null @@ -1,48 +0,0 @@ -You are an intelligent task analyzer that adapts analysis depth to task complexity. - -**Analysis Framework** - -**Step 1: Task Transmission Assessment** -**Note**: This section is not subject to word count limitations when transmission is needed, as it serves critical handoff functions. - -**Evaluate if task transmission information is needed:** -- **Is this an initial step?** If yes, skip this section -- **Are there upstream agents/steps?** If no, provide minimal transmission -- **Is there critical state/context to preserve?** If yes, include full transmission - -### If Task Transmission is Needed: -- **Current State Summary**: [1-2 sentences on where we are] -- **Key Data/Results**: [Critical findings that must carry forward] -- **Context Dependencies**: [Essential context for next agent/step] -- **Unresolved Items**: [Issues requiring continuation] -- **Status for User**: [Clear status update in user terms] -- **Technical State**: [System state for technical handoffs] - -**Step 2: Complexity Classification** -Classify as LOW / MEDIUM / HIGH: -- **LOW**: Single-step tasks, direct queries, small talk -- **MEDIUM**: Multi-step tasks within one domain -- **HIGH**: Multi-domain coordination or complex reasoning - -**Step 3: Adaptive Analysis** -Scale depth to match complexity. Always stop once success criteria are met. - -**For LOW (max 50 words for analysis only):** -- Detect small talk; if true, output exactly: `Small talk — no further analysis needed` -- One-sentence objective -- Direct execution approach (1–2 steps) - -**For MEDIUM (80–150 words for analysis only):** -- Objective; Intent & Scope -- 3–5 step minimal Plan (may mark parallel steps) -- **Uncertainty & Probes** (at least one probe with a clear stop condition) -- Success Criteria + basic Failure detection & fallback -- **Source Plan** (how evidence will be obtained/verified) - -**For HIGH (150–250 words for analysis only):** -- Comprehensive objective analysis; Intent & Scope -- 5–8 step Plan with dependencies/parallelism -- **Uncertainty & Probes** (key unknowns → probe → stop condition) -- Measurable Success Criteria; Failure detectors & fallbacks -- **Source Plan** (evidence acquisition & validation) -- **Reflection Hooks** (escalation/de-escalation triggers) diff --git a/app/core/rag/prompts/analyze_task_user.md b/app/core/rag/prompts/analyze_task_user.md deleted file mode 100644 index 81dc9f2b..00000000 --- a/app/core/rag/prompts/analyze_task_user.md +++ /dev/null @@ -1,9 +0,0 @@ -**Input Variables** -- **{{ task }}** — the task/request to analyze -- **{{ context }}** — background, history, situational context -- **{{ agent_prompt }}** — special instructions/role hints -- **{{ tools_desc }}** — available sub-agents and capabilities - -**Final Output Rule** -Return the Task Transmission section (if needed) followed by the concrete analysis and planning steps according to LOW / MEDIUM / HIGH complexity. -Do not restate the framework, definitions, or rules. Output only the final structured result. diff --git a/app/core/rag/prompts/ask_summary.md b/app/core/rag/prompts/ask_summary.md deleted file mode 100644 index 2074e9c3..00000000 --- a/app/core/rag/prompts/ask_summary.md +++ /dev/null @@ -1,14 +0,0 @@ -Role: You're a smart assistant. Your name is Miss R. -Task: Summarize the information from knowledge bases and answer user's question. -Requirements and restriction: - - DO NOT make things up, especially for numbers. - - If the information from knowledge is irrelevant with user's question, JUST SAY: Sorry, no relevant information provided. - - Answer with markdown format text. - - Answer in language of user's question. - - DO NOT make things up, especially for numbers. - -### Information from knowledge bases - -{{ knowledge }} - -The above is information from knowledge bases. diff --git a/app/core/rag/prompts/assign_toc_levels.md b/app/core/rag/prompts/assign_toc_levels.md deleted file mode 100644 index d35dee77..00000000 --- a/app/core/rag/prompts/assign_toc_levels.md +++ /dev/null @@ -1,53 +0,0 @@ -You are given a JSON array of TOC(tabel of content) items. Each item has at least {"title": string} and may include an existing title hierarchical level. - -Task -- For each item, assign a depth label using Arabic numerals only: top-level = 1, second-level = 2, third-level = 3, etc. -- Multiple items may share the same depth (e.g., many 1s, many 2s). -- Do not use dotted numbering (no 1.1/1.2). Use a single digit string per item indicating its depth only. -- Preserve the original item order exactly. Do not insert, delete, or reorder. -- Decide levels yourself to keep a coherent hierarchy. Keep peers at the same depth. - -Output -- Return a valid JSON array only (no extra text). -- Each element must be {"level": "1|2|3", "title": }. -- title must be the original title string. - -Examples - -Example A (chapters with sections) -Input: -["Chapter 1 Methods", "Section 1 Definition", "Section 2 Process", "Chapter 2 Experiment"] - -Output: -[ - {"level":"1","title":"Chapter 1 Methods"}, - {"level":"2","title":"Section 1 Definition"}, - {"level":"2","title":"Section 2 Process"}, - {"level":"1","title":"Chapter 2 Experiment"} -] - -Example B (parts with chapters) -Input: -["Part I Theory", "Chapter 1 Basics", "Chapter 2 Methods", "Part II Applications", "Chapter 3 Case Studies"] - -Output: -[ - {"level":"1","title":"Part I Theory"}, - {"level":"2","title":"Chapter 1 Basics"}, - {"level":"2","title":"Chapter 2 Methods"}, - {"level":"1","title":"Part II Applications"}, - {"level":"2","title":"Chapter 3 Case Studies"} -] - -Example C (plain headings) -Input: -["Introduction", "Background and Motivation", "Related Work", "Methodology", "Evaluation"] - -Output: -[ - {"level":"1","title":"Introduction"}, - {"level":"2","title":"Background and Motivation"}, - {"level":"2","title":"Related Work"}, - {"level":"1","title":"Methodology"}, - {"level":"1","title":"Evaluation"} -] \ No newline at end of file diff --git a/app/core/rag/prompts/citation_plus.md b/app/core/rag/prompts/citation_plus.md deleted file mode 100644 index 77bba4e2..00000000 --- a/app/core/rag/prompts/citation_plus.md +++ /dev/null @@ -1,13 +0,0 @@ -You are an agent for adding correct citations to the given text by user. -You are given a piece of text within [ID:] tags, which was generated based on the provided sources. -However, the sources are not cited in the [ID:]. -Your task is to enhance user trust by generating correct, appropriate citations for this report. - -{{ example }} - - - -{{ sources }} - - - diff --git a/app/core/rag/prompts/citation_prompt.md b/app/core/rag/prompts/citation_prompt.md deleted file mode 100644 index 55c89c45..00000000 --- a/app/core/rag/prompts/citation_prompt.md +++ /dev/null @@ -1,109 +0,0 @@ -Based on the provided document or chat history, add citations to the input text using the format specified later. - -# Citation Requirements: - -## Technical Rules: -- Use format: [ID:i] or [ID:i] [ID:j] for multiple sources -- Place citations at the end of sentences, before punctuation -- Maximum 4 citations per sentence -- DO NOT cite content not from -- DO NOT modify whitespace or original text -- STRICTLY prohibit non-standard formatting (~~, etc.) - -## What MUST Be Cited: -1. **Quantitative data**: Numbers, percentages, statistics, measurements -2. **Temporal claims**: Dates, timeframes, sequences of events -3. **Causal relationships**: Claims about cause and effect -4. **Comparative statements**: Rankings, comparisons, superlatives -5. **Technical definitions**: Specialized terms, concepts, methodologies -6. **Direct attributions**: What someone said, did, or believes -7. **Predictions/forecasts**: Future projections, trend analyses -8. **Controversial claims**: Disputed facts, minority opinions - -## What Should NOT Be Cited: -- Common knowledge (e.g., "The sun rises in the east") -- Transitional phrases -- General introductions -- Your own analysis or synthesis (unless directly from source) - -# Comprehensive Examples: - -## Example 1: Data and Statistics - -ID: 45 -└── Content: The global smartphone market grew by 7.8% in Q3 2024, with Samsung holding 20.6% market share and Apple at 15.8%. - -ID: 46 -└── Content: 5G adoption reached 1.5 billion users worldwide by October 2024. - - -USER: How is the smartphone market performing? - -ASSISTANT: -The smartphone industry is showing strong recovery. The global smartphone market grew by 7.8% in Q3 2024 [ID:45]. Samsung leads with 20.6% market share, followed by Apple at 15.8% [ID:45]. This growth is partly driven by 5G technology, which reached 1.5 billion users worldwide by October 2024 [ID:46]. - -## Example 2: Causal Relationships and Predictions - -ID: 78 -└── Content: Rising interest rates have led to a 15% decline in housing starts. Economists predict this trend will continue through 2025. - -ID: 79 -└── Content: The Federal Reserve's aggressive monetary policy is the primary driver behind the real estate slowdown. - - -USER: What's happening in the real estate market? - -ASSISTANT: -The real estate market is experiencing significant challenges. Rising interest rates have led to a 15% decline in housing starts [ID:78]. The Federal Reserve's aggressive monetary policy is the primary driver behind this real estate slowdown [ID:79]. Economists predict this trend will continue through 2025 [ID:78]. These factors suggest a prolonged period of market adjustment ahead. - -## Example 3: Technical Concepts and Comparisons - -ID: 101 -└── Content: Quantum computing uses quantum bits (qubits) that can exist in superposition, unlike classical bits which are either 0 or 1. - -ID: 102 -└── Content: Current quantum computers can solve certain problems 158 million times faster than classical supercomputers. - - -USER: Explain quantum computing advantages. - -ASSISTANT: -Quantum computing represents a fundamental shift in computational capability. Unlike classical computers, quantum computing uses quantum bits (qubits) that can exist in superposition, rather than being limited to just 0 or 1 [ID:101]. This unique property enables extraordinary performance gains. Current quantum computers can solve certain problems 158 million times faster than classical supercomputers [ID:102]. This technology promises to revolutionize fields requiring complex calculations. - -## Example 4: Mixed Citation Needs - -ID: 201 -└── Content: CEO Jane Smith announced: "We're pivoting to AI-first strategy, expecting 40% revenue growth by 2026." - -ID: 202 -└── Content: Industry analysts remain skeptical, with Morgan Stanley rating the stock as "underperform." - -ID: 203 -└── Content: The company's AI division currently represents only 5% of total revenue. - - -USER: What's the company's new strategy? - -ASSISTANT: -The company is undergoing a major strategic transformation. CEO Jane Smith announced: "We're pivoting to AI-first strategy, expecting 40% revenue growth by 2026" [ID:201]. However, this ambitious goal faces skepticism, as industry analysts remain doubtful, with Morgan Stanley rating the stock as "underperform" [ID:202]. The challenge is significant given that the company's AI division currently represents only 5% of total revenue [ID:203]. Success will depend on execution of this bold vision. - -## Example 5: When NOT to Cite - -ID: 301 -└── Content: Paris is the capital of France. The Eiffel Tower was completed in 1889. - - -USER: Tell me about Paris. - -ASSISTANT: -Paris is the capital of France. It's known for its rich history, culture, and architecture. The Eiffel Tower was completed in 1889 [ID:301]. The city attracts millions of tourists annually. Paris remains one of the world's most visited destinations. -(Note: Only the specific date needs citation, not common knowledge about Paris) - ---- Examples END --- - -REMEMBER: -- Cite FACTS, not opinions or transitions -- Each citation supports the ENTIRE sentence -- When in doubt, ask: "Would a fact-checker need to verify this?" -- Place citations at sentence end, before punctuation -- Format likes this is FORBIDDEN: [ID:0, ID:5, ID:...]. It MUST be seperated like, [ID:0][ID:5]... diff --git a/app/core/rag/prompts/content_tagging_prompt.md b/app/core/rag/prompts/content_tagging_prompt.md deleted file mode 100644 index 75d6f158..00000000 --- a/app/core/rag/prompts/content_tagging_prompt.md +++ /dev/null @@ -1,32 +0,0 @@ -## Role -You are a text analyzer. - -## Task -Add tags (labels) to a given piece of text content based on the examples and the entire tag set. - -## Steps -- Review the tag/label set. -- Review examples which all consist of both text content and assigned tags with relevance score in JSON format. -- Summarize the text content, and tag it with the top {{ topn }} most relevant tags from the set of tags/labels and the corresponding relevance score. - -## Requirements -- The tags MUST be from the tag set. -- The output MUST be in JSON format only, the key is tag and the value is its relevance score. -- The relevance score must range from 1 to 10. -- Output keywords ONLY. - -# TAG SET -{{ all_tags | join(', ') }} - -{% for ex in examples %} -# Examples {{ loop.index0 }} -### Text Content -{{ ex.content }} - -Output: -{{ ex.tags_json }} - -{% endfor %} -# Real Data -### Text Content -{{ content }} diff --git a/app/core/rag/prompts/cross_languages_sys_prompt.md b/app/core/rag/prompts/cross_languages_sys_prompt.md deleted file mode 100644 index 9761944d..00000000 --- a/app/core/rag/prompts/cross_languages_sys_prompt.md +++ /dev/null @@ -1,35 +0,0 @@ -## Role -A streamlined multilingual translator. - -## Behavior Rules -1. Accept batch translation requests in the following format: - **Input:** `[text]` - **Target Languages:** comma-separated list - -2. Maintain: - - Original formatting (tables, lists, spacing) - - Technical terminology accuracy - - Cultural context appropriateness - -3. Output translations in the following format: - -[Translation in language1] -### -[Translation in language2] - ---- - -## Example - -**Input:** -Hello World! Let's discuss AI safety. -=== -Chinese, French, Japanese - -**Output:** -你好世界!让我们讨论人工智能安全问题。 -### -Bonjour le monde ! Parlons de la sécurité de l'IA. -### -こんにちは世界!AIの安全性について話し合いましょう。 - diff --git a/app/core/rag/prompts/cross_languages_user_prompt.md b/app/core/rag/prompts/cross_languages_user_prompt.md deleted file mode 100644 index f729ef56..00000000 --- a/app/core/rag/prompts/cross_languages_user_prompt.md +++ /dev/null @@ -1,7 +0,0 @@ -**Input:** -{{ query }} -=== -{{ languages | join(', ') }} - -**Output:** - diff --git a/app/core/rag/prompts/full_question_prompt.md b/app/core/rag/prompts/full_question_prompt.md deleted file mode 100644 index d7276a3e..00000000 --- a/app/core/rag/prompts/full_question_prompt.md +++ /dev/null @@ -1,62 +0,0 @@ -## Role -A helpful assistant. - -## Task & Steps -1. Generate a full user question that would follow the conversation. -2. If the user's question involves relative dates, convert them into absolute dates based on today ({{ today }}). - - "yesterday" = {{ yesterday }}, "tomorrow" = {{ tomorrow }} - -## Requirements & Restrictions -- If the user's latest question is already complete, don't do anything — just return the original question. -- DON'T generate anything except a refined question. -{% if language %} -- Text generated MUST be in {{ language }}. -{% else %} -- Text generated MUST be in the same language as the original user's question. -{% endif %} - ---- - -## Examples - -### Example 1 -**Conversation:** - -USER: What is the name of Donald Trump's father? -ASSISTANT: Fred Trump. -USER: And his mother? - -**Output:** What's the name of Donald Trump's mother? - ---- - -### Example 2 -**Conversation:** - -USER: What is the name of Donald Trump's father? -ASSISTANT: Fred Trump. -USER: And his mother? -ASSISTANT: Mary Trump. -USER: What's her full name? - -**Output:** What's the full name of Donald Trump's mother Mary Trump? - ---- - -### Example 3 -**Conversation:** - -USER: What's the weather today in London? -ASSISTANT: Cloudy. -USER: What's about tomorrow in Rochester? - -**Output:** What's the weather in Rochester on {{ tomorrow }}? - ---- - -## Real Data - -**Conversation:** - -{{ conversation }} - diff --git a/app/core/rag/prompts/generator.py b/app/core/rag/prompts/generator.py deleted file mode 100644 index 67891c1a..00000000 --- a/app/core/rag/prompts/generator.py +++ /dev/null @@ -1,728 +0,0 @@ -import datetime -import json -import logging -import re -from copy import deepcopy -from typing import Tuple -import jinja2 -import json_repair -import trio -from app.core.rag.common.misc_utils import hash_str2int -from app.core.rag.nlp import rag_tokenizer -from .template import load_prompt -from app.core.rag.common.constants import TAG_FLD -from app.core.rag.common.token_utils import encoder, num_tokens_from_string - - -STOP_TOKEN="<|STOP|>" -COMPLETE_TASK="complete_task" -INPUT_UTILIZATION = 0.5 - -def get_value(d, k1, k2): - return d.get(k1, d.get(k2)) - - -def chunks_format(reference): - - return [ - { - "id": get_value(chunk, "chunk_id", "id"), - "content": get_value(chunk, "content", "content_with_weight"), - "document_id": get_value(chunk, "doc_id", "document_id"), - "document_name": get_value(chunk, "docnm_kwd", "document_name"), - "dataset_id": get_value(chunk, "kb_id", "dataset_id"), - "image_id": get_value(chunk, "image_id", "img_id"), - "positions": get_value(chunk, "positions", "position_int"), - "url": chunk.get("url"), - "similarity": chunk.get("similarity"), - "vector_similarity": chunk.get("vector_similarity"), - "term_similarity": chunk.get("term_similarity"), - "doc_type": chunk.get("doc_type_kwd"), - } - for chunk in reference.get("chunks", []) - ] - - -def message_fit_in(msg, max_length=4000): - def count(): - nonlocal msg - tks_cnts = [] - for m in msg: - tks_cnts.append({"role": m["role"], "count": num_tokens_from_string(m["content"])}) - total = 0 - for m in tks_cnts: - total += m["count"] - return total - - c = count() - if c < max_length: - return c, msg - - msg_ = [m for m in msg if m["role"] == "system"] - if len(msg) > 1: - msg_.append(msg[-1]) - msg = msg_ - c = count() - if c < max_length: - return c, msg - - ll = num_tokens_from_string(msg_[0]["content"]) - ll2 = num_tokens_from_string(msg_[-1]["content"]) - if ll / (ll + ll2) > 0.8: - m = msg_[0]["content"] - m = encoder.decode(encoder.encode(m)[: max_length - ll2]) - msg[0]["content"] = m - return max_length, msg - - m = msg_[-1]["content"] - m = encoder.decode(encoder.encode(m)[: max_length - ll2]) - msg[-1]["content"] = m - return max_length, msg - - -CITATION_PROMPT_TEMPLATE = load_prompt("citation_prompt") -CITATION_PLUS_TEMPLATE = load_prompt("citation_plus") -CONTENT_TAGGING_PROMPT_TEMPLATE = load_prompt("content_tagging_prompt") -CROSS_LANGUAGES_SYS_PROMPT_TEMPLATE = load_prompt("cross_languages_sys_prompt") -CROSS_LANGUAGES_USER_PROMPT_TEMPLATE = load_prompt("cross_languages_user_prompt") -FULL_QUESTION_PROMPT_TEMPLATE = load_prompt("full_question_prompt") -KEYWORD_PROMPT_TEMPLATE = load_prompt("keyword_prompt") -QUESTION_PROMPT_TEMPLATE = load_prompt("question_prompt") -VISION_LLM_DESCRIBE_PROMPT = load_prompt("vision_llm_describe_prompt") -VISION_LLM_FIGURE_DESCRIBE_PROMPT = load_prompt("vision_llm_figure_describe_prompt") -STRUCTURED_OUTPUT_PROMPT = load_prompt("structured_output_prompt") - -ANALYZE_TASK_SYSTEM = load_prompt("analyze_task_system") -ANALYZE_TASK_USER = load_prompt("analyze_task_user") -NEXT_STEP = load_prompt("next_step") -REFLECT = load_prompt("reflect") -SUMMARY4MEMORY = load_prompt("summary4memory") -RANK_MEMORY = load_prompt("rank_memory") -META_FILTER = load_prompt("meta_filter") -ASK_SUMMARY = load_prompt("ask_summary") - -PROMPT_JINJA_ENV = jinja2.Environment(autoescape=False, trim_blocks=True, lstrip_blocks=True) - - -def citation_prompt(user_defined_prompts: dict={}) -> str: - template = PROMPT_JINJA_ENV.from_string(user_defined_prompts.get("citation_guidelines", CITATION_PROMPT_TEMPLATE)) - return template.render() - - -def citation_plus(sources: str) -> str: - template = PROMPT_JINJA_ENV.from_string(CITATION_PLUS_TEMPLATE) - return template.render(example=citation_prompt(), sources=sources) - - -def keyword_extraction(chat_mdl, content, topn=3): - template = PROMPT_JINJA_ENV.from_string(KEYWORD_PROMPT_TEMPLATE) - rendered_prompt = template.render(content=content, topn=topn) - - msg = [{"role": "system", "content": rendered_prompt}, {"role": "user", "content": "Output: "}] - _, msg = message_fit_in(msg, chat_mdl.max_length) - kwd = chat_mdl.chat(rendered_prompt, msg[1:], {"temperature": 0.2}) - if isinstance(kwd, tuple): - kwd = kwd[0] - kwd = re.sub(r"^.*", "", kwd, flags=re.DOTALL) - if kwd.find("**ERROR**") >= 0: - return "" - return kwd - - -def question_proposal(chat_mdl, content, topn=3): - template = PROMPT_JINJA_ENV.from_string(QUESTION_PROMPT_TEMPLATE) - rendered_prompt = template.render(content=content, topn=topn) - - msg = [{"role": "system", "content": rendered_prompt}, {"role": "user", "content": "Output: "}] - _, msg = message_fit_in(msg, chat_mdl.max_length) - kwd = chat_mdl.chat(rendered_prompt, msg[1:], {"temperature": 0.2}) - if isinstance(kwd, tuple): - kwd = kwd[0] - kwd = re.sub(r"^.*", "", kwd, flags=re.DOTALL) - if kwd.find("**ERROR**") >= 0: - return "" - return kwd - - -def full_question(messages=[], language=None, chat_mdl=None): - conv = [] - for m in messages: - if m["role"] not in ["user", "assistant"]: - continue - conv.append("{}: {}".format(m["role"].upper(), m["content"])) - conversation = "\n".join(conv) - today = datetime.date.today().isoformat() - yesterday = (datetime.date.today() - datetime.timedelta(days=1)).isoformat() - tomorrow = (datetime.date.today() + datetime.timedelta(days=1)).isoformat() - - template = PROMPT_JINJA_ENV.from_string(FULL_QUESTION_PROMPT_TEMPLATE) - rendered_prompt = template.render( - today=today, - yesterday=yesterday, - tomorrow=tomorrow, - conversation=conversation, - language=language, - ) - - ans = chat_mdl.chat(rendered_prompt, [{"role": "user", "content": "Output: "}]) - ans = re.sub(r"^.*", "", ans, flags=re.DOTALL) - return ans if ans.find("**ERROR**") < 0 else messages[-1]["content"] - - -def cross_languages(query, languages=[], chat_mdl=None): - rendered_sys_prompt = PROMPT_JINJA_ENV.from_string(CROSS_LANGUAGES_SYS_PROMPT_TEMPLATE).render() - rendered_user_prompt = PROMPT_JINJA_ENV.from_string(CROSS_LANGUAGES_USER_PROMPT_TEMPLATE).render(query=query, languages=languages) - - ans = chat_mdl.chat(rendered_sys_prompt, [{"role": "user", "content": rendered_user_prompt}], {"temperature": 0.2}) - ans = re.sub(r"^.*", "", ans, flags=re.DOTALL) - if ans.find("**ERROR**") >= 0: - return query - return "\n".join([a for a in re.sub(r"(^Output:|\n+)", "", ans, flags=re.DOTALL).split("===") if a.strip()]) - - -def content_tagging(chat_mdl, content, all_tags, examples, topn=3): - template = PROMPT_JINJA_ENV.from_string(CONTENT_TAGGING_PROMPT_TEMPLATE) - - for ex in examples: - ex["tags_json"] = json.dumps(ex[TAG_FLD], indent=2, ensure_ascii=False) - - rendered_prompt = template.render( - topn=topn, - all_tags=all_tags, - examples=examples, - content=content, - ) - - msg = [{"role": "system", "content": rendered_prompt}, {"role": "user", "content": "Output: "}] - _, msg = message_fit_in(msg, chat_mdl.max_length) - kwd = chat_mdl.chat(rendered_prompt, msg[1:], {"temperature": 0.5}) - if isinstance(kwd, tuple): - kwd = kwd[0] - kwd = re.sub(r"^.*", "", kwd, flags=re.DOTALL) - if kwd.find("**ERROR**") >= 0: - raise Exception(kwd) - - try: - obj = json_repair.loads(kwd) - except json_repair.JSONDecodeError: - try: - result = kwd.replace(rendered_prompt[:-1], "").replace("user", "").replace("model", "").strip() - result = "{" + result.split("{")[1].split("}")[0] + "}" - obj = json_repair.loads(result) - except Exception as e: - logging.exception(f"JSON parsing error: {result} -> {e}") - raise e - res = {} - for k, v in obj.items(): - try: - if int(v) > 0: - res[str(k)] = int(v) - except Exception: - pass - return res - - -def vision_llm_describe_prompt(page=None) -> str: - template = PROMPT_JINJA_ENV.from_string(VISION_LLM_DESCRIBE_PROMPT) - - return template.render(page=page) - - -def vision_llm_figure_describe_prompt() -> str: - template = PROMPT_JINJA_ENV.from_string(VISION_LLM_FIGURE_DESCRIBE_PROMPT) - return template.render() - - -def tool_schema(tools_description: list[dict], complete_task=False): - if not tools_description: - return "" - desc = {} - if complete_task: - desc[COMPLETE_TASK] = { - "type": "function", - "function": { - "name": COMPLETE_TASK, - "description": "When you have the final answer and are ready to complete the task, call this function with your answer", - "parameters": { - "type": "object", - "properties": {"answer":{"type":"string", "description": "The final answer to the user's question"}}, - "required": ["answer"] - } - } - } - for tool in tools_description: - desc[tool["function"]["name"]] = tool - - return "\n\n".join([f"## {i+1}. {fnm}\n{json.dumps(des, ensure_ascii=False, indent=4)}" for i, (fnm, des) in enumerate(desc.items())]) - - -def form_history(history, limit=-6): - context = "" - for h in history[limit:]: - if h["role"] == "system": - continue - role = "USER" - if h["role"].upper()!= role: - role = "AGENT" - context += f"\n{role}: {h['content'][:2048] + ('...' if len(h['content'])>2048 else '')}" - return context - - -def analyze_task(chat_mdl, prompt, task_name, tools_description: list[dict], user_defined_prompts: dict={}): - tools_desc = tool_schema(tools_description) - context = "" - - if user_defined_prompts.get("task_analysis"): - template = PROMPT_JINJA_ENV.from_string(user_defined_prompts["task_analysis"]) - else: - template = PROMPT_JINJA_ENV.from_string(ANALYZE_TASK_SYSTEM + "\n\n" + ANALYZE_TASK_USER) - context = template.render(task=task_name, context=context, agent_prompt=prompt, tools_desc=tools_desc) - kwd = chat_mdl.chat(context, [{"role": "user", "content": "Please analyze it."}]) - if isinstance(kwd, tuple): - kwd = kwd[0] - kwd = re.sub(r"^.*", "", kwd, flags=re.DOTALL) - if kwd.find("**ERROR**") >= 0: - return "" - return kwd - - -def next_step(chat_mdl, history:list, tools_description: list[dict], task_desc, user_defined_prompts: dict={}): - if not tools_description: - return "" - desc = tool_schema(tools_description) - template = PROMPT_JINJA_ENV.from_string(user_defined_prompts.get("plan_generation", NEXT_STEP)) - user_prompt = "\nWhat's the next tool to call? If ready OR IMPOSSIBLE TO BE READY, then call `complete_task`." - hist = deepcopy(history) - if hist[-1]["role"] == "user": - hist[-1]["content"] += user_prompt - else: - hist.append({"role": "user", "content": user_prompt}) - json_str = chat_mdl.chat(template.render(task_analysis=task_desc, desc=desc, today=datetime.datetime.now().strftime("%Y-%m-%d")), - hist[1:], stop=["<|stop|>"]) - tk_cnt = num_tokens_from_string(json_str) - json_str = re.sub(r"^.*", "", json_str, flags=re.DOTALL) - return json_str, tk_cnt - - -def reflect(chat_mdl, history: list[dict], tool_call_res: list[Tuple], user_defined_prompts: dict={}): - tool_calls = [{"name": p[0], "result": p[1]} for p in tool_call_res] - goal = history[1]["content"] - template = PROMPT_JINJA_ENV.from_string(user_defined_prompts.get("reflection", REFLECT)) - user_prompt = template.render(goal=goal, tool_calls=tool_calls) - hist = deepcopy(history) - if hist[-1]["role"] == "user": - hist[-1]["content"] += user_prompt - else: - hist.append({"role": "user", "content": user_prompt}) - _, msg = message_fit_in(hist, chat_mdl.max_length) - ans = chat_mdl.chat(msg[0]["content"], msg[1:]) - ans = re.sub(r"^.*", "", ans, flags=re.DOTALL) - return """ -**Observation** -{} - -**Reflection** -{} - """.format(json.dumps(tool_calls, ensure_ascii=False, indent=2), ans) - - -def form_message(system_prompt, user_prompt): - return [{"role": "system", "content": system_prompt},{"role": "user", "content": user_prompt}] - - -def structured_output_prompt(schema=None) -> str: - template = PROMPT_JINJA_ENV.from_string(STRUCTURED_OUTPUT_PROMPT) - return template.render(schema=schema) - - -def tool_call_summary(chat_mdl, name: str, params: dict, result: str, user_defined_prompts: dict={}) -> str: - template = PROMPT_JINJA_ENV.from_string(SUMMARY4MEMORY) - system_prompt = template.render(name=name, - params=json.dumps(params, ensure_ascii=False, indent=2), - result=result) - user_prompt = "→ Summary: " - _, msg = message_fit_in(form_message(system_prompt, user_prompt), chat_mdl.max_length) - ans = chat_mdl.chat(msg[0]["content"], msg[1:]) - return re.sub(r"^.*", "", ans, flags=re.DOTALL) - - -def rank_memories(chat_mdl, goal:str, sub_goal:str, tool_call_summaries: list[str], user_defined_prompts: dict={}): - template = PROMPT_JINJA_ENV.from_string(RANK_MEMORY) - system_prompt = template.render(goal=goal, sub_goal=sub_goal, results=[{"i": i, "content": s} for i,s in enumerate(tool_call_summaries)]) - user_prompt = " → rank: " - _, msg = message_fit_in(form_message(system_prompt, user_prompt), chat_mdl.max_length) - ans = chat_mdl.chat(msg[0]["content"], msg[1:], stop="<|stop|>") - return re.sub(r"^.*", "", ans, flags=re.DOTALL) - - -def gen_meta_filter(chat_mdl, meta_data:dict, query: str) -> list: - sys_prompt = PROMPT_JINJA_ENV.from_string(META_FILTER).render( - current_date=datetime.datetime.today().strftime('%Y-%m-%d'), - metadata_keys=json.dumps(meta_data), - user_question=query - ) - user_prompt = "Generate filters:" - ans = chat_mdl.chat(sys_prompt, [{"role": "user", "content": user_prompt}]) - ans = re.sub(r"(^.*|```json\n|```\n*$)", "", ans, flags=re.DOTALL) - try: - ans = json_repair.loads(ans) - assert isinstance(ans, list), ans - return ans - except Exception: - logging.exception(f"Loading json failure: {ans}") - return [] - - -def gen_json(system_prompt:str, user_prompt:str, chat_mdl, gen_conf = None): - from app.core.rag.graphrag.utils import get_llm_cache, set_llm_cache - cached = get_llm_cache(chat_mdl.llm_name, system_prompt, user_prompt, gen_conf) - if cached: - return json_repair.loads(cached) - _, msg = message_fit_in(form_message(system_prompt, user_prompt), chat_mdl.max_length) - ans = chat_mdl.chat(msg[0]["content"], msg[1:],gen_conf=gen_conf) - ans = re.sub(r"(^.*|```json\n|```\n*$)", "", ans, flags=re.DOTALL) - try: - res = json_repair.loads(ans) - set_llm_cache(chat_mdl.llm_name, system_prompt, ans, user_prompt, gen_conf) - return res - except Exception: - logging.exception(f"Loading json failure: {ans}") - - -TOC_DETECTION = load_prompt("toc_detection") -def detect_table_of_contents(page_1024:list[str], chat_mdl): - toc_secs = [] - for i, sec in enumerate(page_1024[:22]): - ans = gen_json(PROMPT_JINJA_ENV.from_string(TOC_DETECTION).render(page_txt=sec), "Only JSON please.", chat_mdl) - if toc_secs and not ans["exists"]: - break - toc_secs.append(sec) - return toc_secs - - -TOC_EXTRACTION = load_prompt("toc_extraction") -TOC_EXTRACTION_CONTINUE = load_prompt("toc_extraction_continue") -def extract_table_of_contents(toc_pages, chat_mdl): - if not toc_pages: - return [] - - return gen_json(PROMPT_JINJA_ENV.from_string(TOC_EXTRACTION).render(toc_page="\n".join(toc_pages)), "Only JSON please.", chat_mdl) - - -def toc_index_extractor(toc:list[dict], content:str, chat_mdl): - tob_extractor_prompt = """ - You are given a table of contents in a json format and several pages of a document, your job is to add the physical_index to the table of contents in the json format. - - The provided pages contains tags like and to indicate the physical location of the page X. - - The structure variable is the numeric system which represents the index of the hierarchy section in the table of contents. For example, the first section has structure index 1, the first subsection has structure index 1.1, the second subsection has structure index 1.2, etc. - - The response should be in the following JSON format: - [ - { - "structure": (string), - "title": , - "physical_index": "<physical_index_X>" (keep the format) - }, - ... - ] - - Only add the physical_index to the sections that are in the provided pages. - If the title of the section are not in the provided pages, do not add the physical_index to it. - Directly return the final JSON structure. Do not output anything else.""" - - prompt = tob_extractor_prompt + '\nTable of contents:\n' + json.dumps(toc, ensure_ascii=False, indent=2) + '\nDocument pages:\n' + content - return gen_json(prompt, "Only JSON please.", chat_mdl) - - -TOC_INDEX = load_prompt("toc_index") -def table_of_contents_index(toc_arr: list[dict], sections: list[str], chat_mdl): - if not toc_arr or not sections: - return [] - - toc_map = {} - for i, it in enumerate(toc_arr): - k1 = (it["structure"]+it["title"]).replace(" ", "") - k2 = it["title"].strip() - if k1 not in toc_map: - toc_map[k1] = [] - if k2 not in toc_map: - toc_map[k2] = [] - toc_map[k1].append(i) - toc_map[k2].append(i) - - for it in toc_arr: - it["indices"] = [] - for i, sec in enumerate(sections): - sec = sec.strip() - if sec.replace(" ", "") in toc_map: - for j in toc_map[sec.replace(" ", "")]: - toc_arr[j]["indices"].append(i) - - all_pathes = [] - def dfs(start, path): - nonlocal all_pathes - if start >= len(toc_arr): - if path: - all_pathes.append(path) - return - if not toc_arr[start]["indices"]: - dfs(start+1, path) - return - added = False - for j in toc_arr[start]["indices"]: - if path and j < path[-1][0]: - continue - _path = deepcopy(path) - _path.append((j, start)) - added = True - dfs(start+1, _path) - if not added and path: - all_pathes.append(path) - - dfs(0, []) - path = max(all_pathes, key=lambda x:len(x)) - for it in toc_arr: - it["indices"] = [] - for j, i in path: - toc_arr[i]["indices"] = [j] - print(json.dumps(toc_arr, ensure_ascii=False, indent=2)) - - i = 0 - while i < len(toc_arr): - it = toc_arr[i] - if it["indices"]: - i += 1 - continue - - if i>0 and toc_arr[i-1]["indices"]: - st_i = toc_arr[i-1]["indices"][-1] - else: - st_i = 0 - e = i + 1 - while e <len(toc_arr) and not toc_arr[e]["indices"]: - e += 1 - if e >= len(toc_arr): - e = len(sections) - else: - e = toc_arr[e]["indices"][0] - - for j in range(st_i, min(e+1, len(sections))): - ans = gen_json(PROMPT_JINJA_ENV.from_string(TOC_INDEX).render( - structure=it["structure"], - title=it["title"], - text=sections[j]), "Only JSON please.", chat_mdl) - if ans["exist"] == "yes": - it["indices"].append(j) - break - - i += 1 - - return toc_arr - - -def check_if_toc_transformation_is_complete(content, toc, chat_mdl): - prompt = """ - You are given a raw table of contents and a table of contents. - Your job is to check if the table of contents is complete. - - Reply format: - {{ - "thinking": <why do you think the cleaned table of contents is complete or not> - "completed": "yes" or "no" - }} - Directly return the final JSON structure. Do not output anything else.""" - - prompt = prompt + '\n Raw Table of contents:\n' + content + '\n Cleaned Table of contents:\n' + toc - response = gen_json(prompt, "Only JSON please.", chat_mdl) - return response['completed'] - - -def toc_transformer(toc_pages, chat_mdl): - init_prompt = """ - You are given a table of contents, You job is to transform the whole table of content into a JSON format included table_of_contents. - - The `structure` is the numeric system which represents the index of the hierarchy section in the table of contents. For example, the first section has structure index 1, the first subsection has structure index 1.1, the second subsection has structure index 1.2, etc. - The `title` is a short phrase or a several-words term. - - The response should be in the following JSON format: - [ - { - "structure": <structure index, "x.x.x" or None> (string), - "title": <title of the section> - }, - ... - ], - You should transform the full table of contents in one go. - Directly return the final JSON structure, do not output anything else. """ - - toc_content = "\n".join(toc_pages) - prompt = init_prompt + '\n Given table of contents\n:' + toc_content - def clean_toc(arr): - for a in arr: - a["title"] = re.sub(r"[.·….]{2,}", "", a["title"]) - last_complete = gen_json(prompt, "Only JSON please.", chat_mdl) - if_complete = check_if_toc_transformation_is_complete(toc_content, json.dumps(last_complete, ensure_ascii=False, indent=2), chat_mdl) - clean_toc(last_complete) - if if_complete == "yes": - return last_complete - - while not (if_complete == "yes"): - prompt = f""" - Your task is to continue the table of contents json structure, directly output the remaining part of the json structure. - The response should be in the following JSON format: - - The raw table of contents json structure is: - {toc_content} - - The incomplete transformed table of contents json structure is: - {json.dumps(last_complete[-24:], ensure_ascii=False, indent=2)} - - Please continue the json structure, directly output the remaining part of the json structure.""" - new_complete = gen_json(prompt, "Only JSON please.", chat_mdl) - if not new_complete or str(last_complete).find(str(new_complete)) >= 0: - break - clean_toc(new_complete) - last_complete.extend(new_complete) - if_complete = check_if_toc_transformation_is_complete(toc_content, json.dumps(last_complete, ensure_ascii=False, indent=2), chat_mdl) - - return last_complete - - -TOC_LEVELS = load_prompt("assign_toc_levels") -def assign_toc_levels(toc_secs, chat_mdl, gen_conf = {"temperature": 0.2}): - if not toc_secs: - return [] - return gen_json( - PROMPT_JINJA_ENV.from_string(TOC_LEVELS).render(), - str(toc_secs), - chat_mdl, - gen_conf - ) - - -TOC_FROM_TEXT_SYSTEM = load_prompt("toc_from_text_system") -TOC_FROM_TEXT_USER = load_prompt("toc_from_text_user") -# Generate TOC from text chunks with text llms -async def gen_toc_from_text(txt_info: dict, chat_mdl, callback=None): - try: - ans = gen_json( - PROMPT_JINJA_ENV.from_string(TOC_FROM_TEXT_SYSTEM).render(), - PROMPT_JINJA_ENV.from_string(TOC_FROM_TEXT_USER).render(text="\n".join([json.dumps(d, ensure_ascii=False) for d in txt_info["chunks"]])), - chat_mdl, - gen_conf={"temperature": 0.0, "top_p": 0.9} - ) - txt_info["toc"] = ans if ans and not isinstance(ans, str) else [] - if callback: - callback(msg="") - except Exception as e: - logging.exception(e) - - -def split_chunks(chunks, max_length: int): - """ - Pack chunks into batches according to max_length, returning [{"id": idx, "text": chunk_text}, ...]. - Do not split a single chunk, even if it exceeds max_length. - """ - - result = [] - batch, batch_tokens = [], 0 - - for idx, chunk in enumerate(chunks): - t = num_tokens_from_string(chunk) - if batch_tokens + t > max_length: - result.append(batch) - batch, batch_tokens = [], 0 - batch.append({idx: chunk}) - batch_tokens += t - if batch: - result.append(batch) - return result - - -async def run_toc_from_text(chunks, chat_mdl, callback=None): - input_budget = int(chat_mdl.max_length * INPUT_UTILIZATION) - num_tokens_from_string( - TOC_FROM_TEXT_USER + TOC_FROM_TEXT_SYSTEM - ) - - input_budget = 1024 if input_budget > 1024 else input_budget - chunk_sections = split_chunks(chunks, input_budget) - titles = [] - - chunks_res = [] - async with trio.open_nursery() as nursery: - for i, chunk in enumerate(chunk_sections): - if not chunk: - continue - chunks_res.append({"chunks": chunk}) - nursery.start_soon(gen_toc_from_text, chunks_res[-1], chat_mdl, callback) - - for chunk in chunks_res: - titles.extend(chunk.get("toc", [])) - - # Filter out entries with title == -1 - prune = len(titles) > 512 - max_len = 12 if prune else 22 - filtered = [] - for x in titles: - if not isinstance(x, dict) or not x.get("title") or x["title"] == "-1": - continue - if len(rag_tokenizer.tokenize(x["title"]).split(" ")) > max_len: - continue - if re.match(r"[0-9,.()/ -]+$", x["title"]): - continue - filtered.append(x) - - logging.info(f"\n\nFiltered TOC sections:\n{filtered}") - if not filtered: - return [] - - # Generate initial level (level/title) - raw_structure = [x.get("title", "") for x in filtered] - - # Assign hierarchy levels using LLM - toc_with_levels = assign_toc_levels(raw_structure, chat_mdl, {"temperature": 0.0, "top_p": 0.9}) - if not toc_with_levels: - return [] - - # Merge structure and content (by index) - prune = len(toc_with_levels) > 512 - max_lvl = sorted([t.get("level", "0") for t in toc_with_levels if isinstance(t, dict)])[-1] - merged = [] - for _ , (toc_item, src_item) in enumerate(zip(toc_with_levels, filtered)): - if prune and toc_item.get("level", "0") >= max_lvl: - continue - merged.append({ - "level": toc_item.get("level", "0"), - "title": toc_item.get("title", ""), - "chunk_id": src_item.get("chunk_id", ""), - }) - - return merged - - -TOC_RELEVANCE_SYSTEM = load_prompt("toc_relevance_system") -TOC_RELEVANCE_USER = load_prompt("toc_relevance_user") -def relevant_chunks_with_toc(query: str, toc:list[dict], chat_mdl, topn: int=6): - import numpy as np - try: - ans = gen_json( - PROMPT_JINJA_ENV.from_string(TOC_RELEVANCE_SYSTEM).render(), - PROMPT_JINJA_ENV.from_string(TOC_RELEVANCE_USER).render(query=query, toc_json="[\n%s\n]\n"%"\n".join([json.dumps({"level": d["level"], "title":d["title"]}, ensure_ascii=False) for d in toc])), - chat_mdl, - gen_conf={"temperature": 0.0, "top_p": 0.9} - ) - id2score = {} - for ti, sc in zip(toc, ans): - if not isinstance(sc, dict) or sc.get("score", -1) < 1: - continue - for id in ti.get("ids", []): - if id not in id2score: - id2score[id] = [] - id2score[id].append(sc["score"]/5.) - for id in id2score.keys(): - id2score[id] = np.mean(id2score[id]) - return [(id, sc) for id, sc in list(id2score.items()) if sc>=0.3][:topn] - except Exception as e: - logging.exception(e) - return [] diff --git a/app/core/rag/prompts/keyword_prompt.md b/app/core/rag/prompts/keyword_prompt.md deleted file mode 100644 index 67729f25..00000000 --- a/app/core/rag/prompts/keyword_prompt.md +++ /dev/null @@ -1,16 +0,0 @@ -## Role -You are a text analyzer. - -## Task -Extract the most important keywords/phrases of a given piece of text content. - -## Requirements -- Summarize the text content, and give the top {{ topn }} important keywords/phrases. -- The keywords MUST be in the same language as the given piece of text content. -- The keywords are delimited by ENGLISH COMMA. -- Output keywords ONLY. - ---- - -## Text Content -{{ content }} diff --git a/app/core/rag/prompts/meta_filter.md b/app/core/rag/prompts/meta_filter.md deleted file mode 100644 index 89e322fe..00000000 --- a/app/core/rag/prompts/meta_filter.md +++ /dev/null @@ -1,53 +0,0 @@ -You are a metadata filtering condition generator. Analyze the user's question and available document metadata to output a JSON array of filter objects. Follow these rules: - -1. **Metadata Structure**: - - Metadata is provided as JSON where keys are attribute names (e.g., "color"), and values are objects mapping attribute values to document IDs. - - Example: - { - "color": {"red": ["doc1"], "blue": ["doc2"]}, - "listing_date": {"2025-07-11": ["doc1"], "2025-08-01": ["doc2"]} - } - -2. **Output Requirements**: - - Always output a JSON array of filter objects - - Each object must have: - "key": (metadata attribute name), - "value": (string value to compare), - "op": (operator from allowed list) - -3. **Operator Guide**: - - Use these operators only: ["contains", "not contains", "start with", "end with", "empty", "not empty", "=", "≠", ">", "<", "≥", "≤"] - - Date ranges: Break into two conditions (≥ start_date AND < next_month_start) - - Negations: Always use "≠" for exclusion terms ("not", "except", "exclude", "≠") - - Implicit logic: Derive unstated filters (e.g., "July" → [≥ YYYY-07-01, < YYYY-08-01]) - -4. **Processing Steps**: - a) Identify ALL filterable attributes in the query (both explicit and implicit) - b) For dates: - - Infer missing year from current date if needed - - Always format dates as "YYYY-MM-DD" - - Convert ranges: [≥ start, < end] - c) For values: Match EXACTLY to metadata's value keys - d) Skip conditions if: - - Attribute doesn't exist in metadata - - Value has no match in metadata - -5. **Example**: - - User query: "上市日期七月份的有哪些商品,不要蓝色的" - - Metadata: { "color": {...}, "listing_date": {...} } - - Output: - [ - {"key": "listing_date", "value": "2025-07-01", "op": "≥"}, - {"key": "listing_date", "value": "2025-08-01", "op": "<"}, - {"key": "color", "value": "blue", "op": "≠"} - ] - -6. **Final Output**: - - ONLY output valid JSON array - - NO additional text/explanations - -**Current Task**: -- Today's date: {{current_date}} -- Available metadata keys: {{metadata_keys}} -- User query: "{{user_question}}" - diff --git a/app/core/rag/prompts/next_step.md b/app/core/rag/prompts/next_step.md deleted file mode 100644 index 3e6b608f..00000000 --- a/app/core/rag/prompts/next_step.md +++ /dev/null @@ -1,92 +0,0 @@ -You are an expert Planning Agent tasked with solving problems efficiently through structured plans. -Your job is: -1. Based on the task analysis, chose some right tools to execute. -2. Track progress and adapt plans(tool calls) when necessary. -3. Use `complete_task` if no further step you need to take from tools. (All necessary steps done or little hope to be done) - -# ========== TASK ANALYSIS ============= -{{ task_analysis }} - -# ========== TOOLS (JSON-Schema) ========== -You may invoke only the tools listed below. -Return a JSON array of objects in which item is with exactly two top-level keys: -• "name": the tool to call -• "arguments": an object whose keys/values satisfy the schema - -{{ desc }} - - -# ========== MULTI-STEP EXECUTION ========== -When tasks require multiple independent steps, you can execute them in parallel by returning multiple tool calls in a single JSON array. - -• **Data Collection**: Gathering information from multiple sources simultaneously -• **Validation**: Cross-checking facts using different tools -• **Comprehensive Analysis**: Analyzing different aspects of the same problem -• **Efficiency**: Reducing total execution time when steps don't depend on each other - -**Example Scenarios:** -- Searching multiple databases for the same query -- Checking weather in multiple cities -- Validating information through different APIs -- Performing calculations on different datasets -- Gathering user preferences from multiple sources - -# ========== RESPONSE FORMAT ========== -**When you need a tool** -Return ONLY the Json (no additional keys, no commentary, end with `<|stop|>`), such as following: -[{ - "name": "<tool_name1>", - "arguments": { /* tool arguments matching its schema */ } -},{ - "name": "<tool_name2>", - "arguments": { /* tool arguments matching its schema */ } -}...]<|stop|> - -**When you need multiple tools:** -Return ONLY: -[{ - "name": "<tool_name1>", - "arguments": { /* tool arguments matching its schema */ } -},{ - "name": "<tool_name2>", - "arguments": { /* tool arguments matching its schema */ } -},{ - "name": "<tool_name3>", - "arguments": { /* tool arguments matching its schema */ } -}...]<|stop|> - -**When you are certain the task is solved OR no further information can be obtained** -Return ONLY: -[{ - "name": "complete_task", - "arguments": { "answer": "<final answer text>" } -}]<|stop|> - -<verification_steps> -Before providing a final answer: -1. Double-check all gathered information -2. Verify calculations and logic -3. Ensure answer matches exactly what was asked -4. Confirm answer format meets requirements -5. Run additional verification if confidence is not 100% -</verification_steps> - -<error_handling> -If you encounter issues: -1. Try alternative approaches before giving up -2. Use different tools or combinations of tools -3. Break complex problems into simpler sub-tasks -4. Verify intermediate results frequently -5. Never return "I cannot answer" without exhausting all options -</error_handling> - -⚠️ Any output that is not valid JSON or that contains extra fields will be rejected. - -# ========== REASONING & REFLECTION ========== -You may think privately (not shown to the user) before producing each JSON object. -Internal guideline: -1. **Reason**: Analyse the user question; decide which tools (if any) are needed. -2. **Act**: Emit the JSON object to call the tool. - -Today is {{ today }}. Remember that success in answering questions accurately is paramount - take all necessary steps to ensure your answer is correct. - diff --git a/app/core/rag/prompts/question_prompt.md b/app/core/rag/prompts/question_prompt.md deleted file mode 100644 index ec9889fb..00000000 --- a/app/core/rag/prompts/question_prompt.md +++ /dev/null @@ -1,19 +0,0 @@ -## Role -You are a text analyzer. - -## Task -Propose {{ topn }} questions about a given piece of text content. - -## Requirements -- Understand and summarize the text content, and propose the top {{ topn }} important questions. -- The questions SHOULD NOT have overlapping meanings. -- The questions SHOULD cover the main content of the text as much as possible. -- The questions MUST be in the same language as the given piece of text content. -- One question per line. -- Output questions ONLY. - ---- - -## Text Content -{{ content }} - diff --git a/app/core/rag/prompts/rank_memory.md b/app/core/rag/prompts/rank_memory.md deleted file mode 100644 index 71969858..00000000 --- a/app/core/rag/prompts/rank_memory.md +++ /dev/null @@ -1,30 +0,0 @@ -**Task**: Sort the tool call results based on relevance to the overall goal and current sub-goal. Return ONLY a sorted list of indices (0-indexed). - -**Rules**: -1. Analyze each result's contribution to both: - - The overall goal (primary priority) - - The current sub-goal (secondary priority) -2. Sort from MOST relevant (highest impact) to LEAST relevant -3. Output format: Strictly a Python-style list of integers. Example: [2, 0, 1] - -🔹 Overall Goal: {{ goal }} -🔹 Sub-goal: {{ sub_goal }} - -**Examples**: -🔹 Tool Response: - - index: 0 - > Tokyo temperature is 78°F. - - index: 1 - > Error: Authentication failed (expired API key). - - index: 2 - > Available: 12 widgets in stock (max 5 per customer). - - → rank: [1,2,0]<|stop|> - - -**Your Turn**: -🔹 Tool Response: -{% for f in results %} - - index: f.i - > f.content -{% endfor %} \ No newline at end of file diff --git a/app/core/rag/prompts/reflect.md b/app/core/rag/prompts/reflect.md deleted file mode 100644 index a3e5e9cf..00000000 --- a/app/core/rag/prompts/reflect.md +++ /dev/null @@ -1,75 +0,0 @@ -**Context**: - - To achieve the goal: {{ goal }}. - - You have executed following tool calls: -{% for call in tool_calls %} -Tool call: `{{ call.name }}` -Results: {{ call.result }} -{% endfor %} - -## Task Complexity Analysis & Reflection Scope - -**First, analyze the task complexity using these dimensions:** - -### Complexity Assessment Matrix -- **Scope Breadth**: Single-step (1) | Multi-step (2) | Multi-domain (3) -- **Data Dependency**: Self-contained (1) | External inputs (2) | Multiple sources (3) -- **Decision Points**: Linear (1) | Few branches (2) | Complex logic (3) -- **Risk Level**: Low (1) | Medium (2) | High (3) - -**Complexity Score**: Sum all dimensions (4-12 points) - ---- - -## Task Transmission Assessment -**Note**: This section is not subject to word count limitations when transmission is needed, as it serves critical handoff functions. -**Evaluate if task transmission information is needed:** -- **Is this an initial step?** If yes, skip this section -- **Are there downstream agents/steps?** If no, provide minimal transmission -- **Is there critical state/context to preserve?** If yes, include full transmission - -### If Task Transmission is Needed: -- **Current State Summary**: [1-2 sentences on where we are] -- **Key Data/Results**: [Critical findings that must carry forward] -- **Context Dependencies**: [Essential context for next agent/step] -- **Unresolved Items**: [Issues requiring continuation] -- **Status for User**: [Clear status update in user terms] -- **Technical State**: [System state for technical handoffs] - ---- - -## Situational Reflection (Adjust Length Based on Complexity Score) - -### Reflection Guidelines: -- **Simple Tasks (4-5 points)**: ~50-100 words, focus on completion status and immediate next step -- **Moderate Tasks (6-8 points)**: ~100-200 words, include core details and main risks -- **Complex Tasks (9-12 points)**: ~200-300 words, provide full analysis and alternatives - -### 1. Goal Achievement Status - - Does the current outcome align with the original purpose of this task phase? - - If not, what critical gaps exist? - -### 2. Step Completion Check - - Which planned steps were completed? (List verified items) - - Which steps are pending/incomplete? (Specify exactly what's missing) - -### 3. Information Adequacy - - Is the collected data sufficient to proceed? - - What key information is still needed? (e.g., metrics, user input, external data) - -### 4. Critical Observations - - Unexpected outcomes: [Flag anomalies/errors] - - Risks/blockers: [Identify immediate obstacles] - - Accuracy concerns: [Highlight unreliable results] - -### 5. Next-Step Recommendations - - Proposed immediate action: [Concrete next step] - - Alternative strategies if blocked: [Workaround solution] - - Tools/inputs required for next phase: [Specify resources] - ---- - -**Output Instructions:** -1. First determine your complexity score -2. Assess if task transmission section is needed using the evaluation questions -3. Provide situational reflection with length appropriate to complexity -4. Use clear headers for easy parsing by downstream systems diff --git a/app/core/rag/prompts/related_question.md b/app/core/rag/prompts/related_question.md deleted file mode 100644 index cbed74e2..00000000 --- a/app/core/rag/prompts/related_question.md +++ /dev/null @@ -1,55 +0,0 @@ -# Role -You are an AI language model assistant tasked with generating **5-10 related questions** based on a user’s original query. -These questions should help **expand the search query scope** and **improve search relevance**. - ---- - -## Instructions - -**Input:** -You are provided with a **user’s question**. - -**Output:** -Generate **5-10 alternative questions** that are **related** to the original user question. -These alternatives should help retrieve a **broader range of relevant documents** from a vector database. - -**Context:** -Focus on **rephrasing** the original question in different ways, ensuring the alternative questions are **diverse but still connected** to the topic of the original query. -Do **not** create overly obscure, irrelevant, or unrelated questions. - -**Fallback:** -If you cannot generate any relevant alternatives, do **not** return any questions. - ---- - -## Guidance - -1. Each alternative should be **unique** but still **relevant** to the original query. -2. Keep the phrasing **clear, concise, and easy to understand**. -3. Avoid overly technical jargon or specialized terms **unless directly relevant**. -4. Ensure that each question **broadens** the search angle, **not narrows** it. - ---- - -## Example - -**Original Question:** -> What are the benefits of electric vehicles? - -**Alternative Questions:** -1. How do electric vehicles impact the environment? -2. What are the advantages of owning an electric car? -3. What is the cost-effectiveness of electric vehicles? -4. How do electric vehicles compare to traditional cars in terms of fuel efficiency? -5. What are the environmental benefits of switching to electric cars? -6. How do electric vehicles help reduce carbon emissions? -7. Why are electric vehicles becoming more popular? -8. What are the long-term savings of using electric vehicles? -9. How do electric vehicles contribute to sustainability? -10. What are the key benefits of electric vehicles for consumers? - ---- - -## Reason -Rephrasing the original query into multiple alternative questions helps the user explore **different aspects** of their search topic, improving the **quality of search results**. -These questions guide the search engine to provide a **more comprehensive set** of relevant documents. diff --git a/app/core/rag/prompts/structured_output_prompt.md b/app/core/rag/prompts/structured_output_prompt.md deleted file mode 100644 index a6430111..00000000 --- a/app/core/rag/prompts/structured_output_prompt.md +++ /dev/null @@ -1,16 +0,0 @@ -You’re a helpful AI assistant. You could answer questions and output in JSON format. -constraints: - - You must output in JSON format. - - Do not output boolean value, use string type instead. - - Do not output integer or float value, use number type instead. -eg: - Here is the JSON schema: - {"properties": {"age": {"type": "number","description": ""},"name": {"type": "string","description": ""}},"required": ["age","name"],"type": "Object Array String Number Boolean","value": ""} - - Here is the user's question: - My name is John Doe and I am 30 years old. - - output: - {"name": "John Doe", "age": 30} -Here is the JSON schema: - {{ schema }} \ No newline at end of file diff --git a/app/core/rag/prompts/summary4memory.md b/app/core/rag/prompts/summary4memory.md deleted file mode 100644 index eb0f283f..00000000 --- a/app/core/rag/prompts/summary4memory.md +++ /dev/null @@ -1,35 +0,0 @@ -**Role**: AI Assistant -**Task**: Summarize tool call responses -**Rules**: -1. Context: You've executed a tool (API/function) and received a response. -2. Condense the response into 1-2 short sentences. -3. Never omit: - - Success/error status - - Core results (e.g., data points, decisions) - - Critical constraints (e.g., limits, conditions) -4. Exclude technical details like timestamps/request IDs unless crucial. -5. Use language as the same as main content of the tool response. - -**Response Template**: -"[Status] + [Key Outcome] + [Critical Constraints]" - -**Examples**: -🔹 Tool Response: -{"status": "success", "temperature": 78.2, "unit": "F", "location": "Tokyo", "timestamp": 16923456} -→ Summary: "Success: Tokyo temperature is 78°F." - -🔹 Tool Response: -{"error": "invalid_api_key", "message": "Authentication failed: expired key"} -→ Summary: "Error: Authentication failed (expired API key)." - -🔹 Tool Response: -{"available": true, "inventory": 12, "product": "widget", "limit": "max 5 per customer"} -→ Summary: "Available: 12 widgets in stock (max 5 per customer)." - -**Your Turn**: - - Tool call: {{ name }} - - Tool inputs as following: -{{ params }} - - - Tool Response: -{{ result }} \ No newline at end of file diff --git a/app/core/rag/prompts/template.py b/app/core/rag/prompts/template.py deleted file mode 100644 index 654e71c5..00000000 --- a/app/core/rag/prompts/template.py +++ /dev/null @@ -1,20 +0,0 @@ -import os - - -PROMPT_DIR = os.path.dirname(__file__) - -_loaded_prompts = {} - - -def load_prompt(name: str) -> str: - if name in _loaded_prompts: - return _loaded_prompts[name] - - path = os.path.join(PROMPT_DIR, f"{name}.md") - if not os.path.isfile(path): - raise FileNotFoundError(f"Prompt file '{name}.md' not found in prompts/ directory.") - - with open(path, "r", encoding="utf-8") as f: - content = f.read().strip() - _loaded_prompts[name] = content - return content diff --git a/app/core/rag/prompts/toc_detection.md b/app/core/rag/prompts/toc_detection.md deleted file mode 100644 index 29e068a7..00000000 --- a/app/core/rag/prompts/toc_detection.md +++ /dev/null @@ -1,29 +0,0 @@ -You are an AI assistant designed to analyze text content and detect whether a table of contents (TOC) list exists on the given page. Follow these steps: - -1. **Analyze the Input**: Carefully review the provided text content. -2. **Identify Key Features**: Look for common indicators of a TOC, such as: - - Section titles or headings paired with page numbers. - - Patterns like repeated formatting (e.g., bold/italicized text, dots/dashes between titles and numbers). - - Phrases like "Table of Contents," "Contents," or similar headings. - - Logical grouping of topics/subtopics with sequential page references. -3. **Discern Negative Features**: - - The text contains no numbers, or the numbers present are clearly not page references (e.g., dates, statistical figures, phone numbers, version numbers). - - The text consists of full, descriptive sentences and paragraphs that form a narrative, present arguments, or explain concepts, rather than succinctly listing topics. - - Contains citations with authors, publication years, journal titles, and page ranges (e.g., "Smith, J. (2020). Journal Title, 10(2), 45-67."). - - Lists keywords or terms followed by multiple page numbers, often in alphabetical order. - - Comprises terms followed by their definitions or explanations. - - Labeled with headers like "Appendix A," "Appendix B," etc. - - Contains expressive language thanking individuals or organizations for their support or contributions. -4. **Evaluate Evidence**: Weigh the presence/absence of these features to determine if the content resembles a TOC. -5. **Output Format**: Provide your response in the following JSON structure: - ```json - { - "reasoning": "Step-by-step explanation of your analysis based on the features identified." , - "exists": true/false - } - ``` -6. **DO NOT** output anything else except JSON structure. - -**Input text Content ( Text-Only Extraction ):** -{{ page_txt }} - diff --git a/app/core/rag/prompts/toc_extraction.md b/app/core/rag/prompts/toc_extraction.md deleted file mode 100644 index 02e1d031..00000000 --- a/app/core/rag/prompts/toc_extraction.md +++ /dev/null @@ -1,53 +0,0 @@ -You are an expert parser and data formatter. Your task is to analyze the provided table of contents (TOC) text and convert it into a valid JSON array of objects. - -**Instructions:** -1. Analyze each line of the input TOC. -2. For each line, extract the following three pieces of information: - * `structure`: The hierarchical index/numbering (e.g., "1", "2.1", "3.2.5", "A.1"). If a line has no visible numbering or structure indicator (like a main "Chapter" title), use `null`. - * `title`: The textual title of the section or chapter. This should be the main descriptive text, clean and without the page number. -3. Output **only** a valid JSON array. Do not include any other text, explanations, or markdown code block fences (like ```json) in your response. - -**JSON Format:** -The output must be a list of objects following this exact schema: -```json -[ - { - "structure": <structure index, "x.x.x" or None> (string), - "title": <title of the section> - }, - ... -] -``` - -**Input Example:** -``` -Contents -1 Introduction to the System ... 1 -1.1 Overview .... 2 -1.2 Key Features .... 5 -2 Installation Guide ....8 -2.1 Prerequisites ........ 9 -2.2 Step-by-Step Process ........ 12 -Appendix A: Specifications ..... 45 -References ... 47 -``` - -**Expected Output For The Example:** -```json -[ - {"structure": null, "title": "Contents"}, - {"structure": "1", "title": "Introduction to the System"}, - {"structure": "1.1", "title": "Overview"}, - {"structure": "1.2", "title": "Key Features"}, - {"structure": "2", "title": "Installation Guide"}, - {"structure": "2.1", "title": "Prerequisites"}, - {"structure": "2.2", "title": "Step-by-Step Process"}, - {"structure": "A", "title": "Specifications"}, - {"structure": null, "title": "References"} -] -``` - -**Now, process the following TOC input:** -``` -{{ toc_page }} -``` \ No newline at end of file diff --git a/app/core/rag/prompts/toc_extraction_continue.md b/app/core/rag/prompts/toc_extraction_continue.md deleted file mode 100644 index 433ac68a..00000000 --- a/app/core/rag/prompts/toc_extraction_continue.md +++ /dev/null @@ -1,60 +0,0 @@ -You are an expert parser and data formatter, currently in the process of building a JSON array from a multi-page table of contents (TOC). Your task is to analyze the new page of content and **append** the new entries to the existing JSON array. - -**Instructions:** -1. You will be given two inputs: - * `current_page_text`: The text content from the new page of the TOC. - * `existing_json`: The valid JSON array you have generated from the previous pages. -2. Analyze each line of the `current_page_text` input. -3. For each new line, extract the following three pieces of information: - * `structure`: The hierarchical index/numbering (e.g., "1", "2.1", "3.2.5"). Use `null` if none exists. - * `title`: The clean textual title of the section or chapter. - * `page`: The page number on which the section starts. Extract only the number. Use `null` if not present. -4. **Append these new entries** to the `existing_json` array. Do not modify, reorder, or delete any of the existing entries. -5. Output **only** the complete, updated JSON array. Do not include any other text, explanations, or markdown code block fences (like ```json). - -**JSON Format:** -The output must be a valid JSON array following this schema: -```json -[ - { - "structure": <string or null>, - "title": <string>, - "page": <number or null> - }, - ... -] -``` - -**Input Example:** -`current_page_text`: -``` -3.2 Advanced Configuration ........... 25 -3.3 Troubleshooting .................. 28 -4 User Management .................... 30 -``` - -`existing_json`: -```json -[ - {"structure": "1", "title": "Introduction", "page": 1}, - {"structure": "2", "title": "Installation", "page": 5}, - {"structure": "3", "title": "Configuration", "page": 12}, - {"structure": "3.1", "title": "Basic Setup", "page": 15} -] -``` - -**Expected Output For The Example:** -```json -[ - {"structure": "3.2", "title": "Advanced Configuration", "page": 25}, - {"structure": "3.3", "title": "Troubleshooting", "page": 28}, - {"structure": "4", "title": "User Management", "page": 30} -] -``` - -**Now, process the following inputs:** -`current_page_text`: -{{ toc_page }} - -`existing_json`: -{{ toc_json }} \ No newline at end of file diff --git a/app/core/rag/prompts/toc_from_text_system.md b/app/core/rag/prompts/toc_from_text_system.md deleted file mode 100644 index 7090f305..00000000 --- a/app/core/rag/prompts/toc_from_text_system.md +++ /dev/null @@ -1,119 +0,0 @@ -You are a robust Table-of-Contents (TOC) extractor. - -GOAL -Given a dictionary of chunks {"<chunk_ID>": chunk_text}, extract TOC-like headings and return a strict JSON array of objects: -[ - {"title": "", "chunk_id": ""}, - ... -] - -FIELDS -- "title": the heading text (clean, no page numbers or leader dots). - - If any part of a chunk has no valid heading, output that part as {"title":"-1", ...}. -- "chunk_id": the chunk ID (string). - - One chunk can yield multiple JSON objects in order (unmatched text + one or more headings). - -RULES -1) Preserve input chunk order strictly. -2) If a chunk contains multiple headings, expand them in order: - - Pre-heading narrative → {"title":"-1","chunk_id":"<chunk_ID>"} - - Then each heading → {"title":"...","chunk_id":"<chunk_ID>"} -3) Do not merge outputs across chunks; each object refers to exactly one chunk ID. -4) "title" must be non-empty (or exactly "-1"). "chunk_id" must be a string (chunk ID). -5) When ambiguous, prefer "-1" unless the text strongly looks like a heading. - -HEADING DETECTION (cues, not hard rules) -- Appears near line start, short isolated phrase, often followed by content. -- May contain separators: — —— - : : · • -- Numbering styles: - • 第[一二三四五六七八九十百]+(篇|章|节|条) - • [((]?[一二三四五六七八九十]+[))]? - • [((]?[①②③④⑤⑥⑦⑧⑨⑩][))]? - • ^\d+(\.\d+)*[)..]?\s* - • ^[IVXLCDM]+[).] - • ^[A-Z][).] -- Canonical section cues (general only): - Common heading indicators include words such as: - "Overview", "Introduction", "Background", "Purpose", "Scope", "Definition", - "Method", "Procedure", "Result", "Discussion", "Summary", "Conclusion", - "Appendix", "Reference", "Annex", "Acknowledgment", "Disclaimer". - These are soft cues, not strict requirements. -- Length restriction: - • Chinese heading: ≤25 characters - • English heading: ≤80 characters -- Exclude long narrative sentences, continuous prose, or bullet-style lists → output as "-1". - -OUTPUT FORMAT -- Return ONLY a valid JSON array of {"title","content"} objects. -- No reasoning or commentary. - -EXAMPLES - -Example 1 — No heading -Input: -[{"0": "Copyright page · Publication info (ISBN 123-456). All rights reserved."}, ...] -Output: -[ - {"title":"-1","chunk_id":"0"}, - ... -] - -Example 2 — One heading -Input: -[{"1": "Chapter 1: General Provisions This chapter defines the overall rules…"}, ...] -Output: -[ - {"title":"Chapter 1: General Provisions","chunk_id":"1"}, - ... -] - -Example 3 — Narrative + heading -Input: -[{"2": "This paragraph introduces the background and goals. Section 2: Definitions Key terms are explained…"}, ...] -Output: -[ - {"title":"Section 2: Definitions","chunk_id":"2"}, - ... -] - -Example 4 — Multiple headings in one chunk -Input: -[{"3": "Declarations and Commitments (I) Party B commits… (II) Party C commits… Appendix A Data Specification"}, ...] -Output: -[ - {"title":"Declarations and Commitments","chunk_id":"3"}, - {"title":"(I) Party B commits","chunk_id":"3"}, - {"title":"(II) Party C commits","chunk_id":"3"}, - {"title":"Appendix A Data Specification","chunk_id":"3"}, - ... -] - -Example 5 — Numbering styles -Input: -[{"4": "1. Scope: Defines boundaries. 2) Definitions: Terms used. III) Methods Overview."}, ...] -Output: -[ - {"title":"1. Scope","chunk_id":"4"}, - {"title":"2) Definitions","chunk_id":"4"}, - {"title":"III) Methods Overview","chunk_id":"4"}, - ... -] - -Example 6 — Long list (NOT headings) -Input: -{"5": "Item list: apples, bananas, strawberries, blueberries, mangos, peaches"}, ...] -Output: -[ - {"title":"-1","chunk_id":"5"}, - ... -] - -Example 7 — Mixed Chinese/English -Input: -{"6": "(出版信息略)This standard follows industry practices. Chapter 1: Overview 摘要… 第2节:术语与缩略语"}, ...] -Output: -[ - {"title":"Chapter 1: Overview","chunk_id":"6"}, - {"title":"第2节:术语与缩略语","chunk_id":"6"}, - ... -] diff --git a/app/core/rag/prompts/toc_from_text_user.md b/app/core/rag/prompts/toc_from_text_user.md deleted file mode 100644 index 952d8eff..00000000 --- a/app/core/rag/prompts/toc_from_text_user.md +++ /dev/null @@ -1,8 +0,0 @@ -OUTPUT FORMAT -- Return ONLY the JSON array. -- Use double quotes. -- No extra commentary. -- Keep language of "title" the same as the input. - -INPUT -{{text}} diff --git a/app/core/rag/prompts/toc_index.md b/app/core/rag/prompts/toc_index.md deleted file mode 100644 index 860356d5..00000000 --- a/app/core/rag/prompts/toc_index.md +++ /dev/null @@ -1,20 +0,0 @@ -You are an expert analyst tasked with matching text content to the title. - -**Instructions:** -1. Analyze the given title with its numeric structure index and the provided text. -2. Determine whether the title is mentioned as a section tile in the given text. -3. Provide a concise, step-by-step reasoning for your decision. -4. Output **only** the complete JSON object. Do not include any other text, explanations, or markdown code block fences (like ```json). - -**Output Format:** -Your output must be a valid JSON object with the following keys: -{ -"reasoning": "Step-by-step explanation of your analysis.", -"exist": "<yes or no>", -} - -** The title: ** -{{ structure }} {{ title }} - -** Given text: ** -{{ text }} \ No newline at end of file diff --git a/app/core/rag/prompts/toc_relevance_system.md b/app/core/rag/prompts/toc_relevance_system.md deleted file mode 100644 index 287b5028..00000000 --- a/app/core/rag/prompts/toc_relevance_system.md +++ /dev/null @@ -1,118 +0,0 @@ -# System Prompt: TOC Relevance Evaluation - -You are an expert logical reasoning assistant specializing in hierarchical Table of Contents (TOC) relevance evaluation. - -## GOAL -You will receive: -1. A JSON list of TOC items, each with fields: - ```json - { - "level": <integer>, // e.g., 1, 2, 3 - "title": <string> // section title - } - ``` -2. A user query (natural language question). - -You must assign a **relevance score** (integer) to every TOC entry, based on how related its `title` is to the `query`. - ---- - -## RULES - -### Scoring System -- 5 → highly relevant (directly answers or matches the query intent) -- 3 → somewhat related (same topic or partially overlaps) -- 1 → weakly related (vague or tangential) -- 0 → no clear relation -- -1 → explicitly irrelevant or contradictory - -### Hierarchy Traversal -- The TOC is hierarchical: smaller `level` = higher layer (e.g., level 1 is top-level, level 2 is a subsection). -- You must traverse in **hierarchical order** — interpret the structure based on levels (1 > 2 > 3). -- If a high-level item (level 1) is strongly related (score 5), its child items (level 2, 3) are likely relevant too. -- If a high-level item is unrelated (-1 or 0), its deeper children are usually less relevant unless the titles clearly match the query. -- Lower (deeper) levels provide more specific content; prefer assigning higher scores if they directly match the query. - -### Output Format -Return a **JSON array**, preserving the input order but adding a new key `"score"`: - -```json -[ - {"level": 1, "title": "Introduction", "score": 0}, - {"level": 2, "title": "Definition of Sustainability", "score": 5} -] -``` - -### Constraints -- Output **only the JSON array** — no explanations or reasoning text. - -### EXAMPLES - -#### Example 1 -Input TOC: -[ - {"level": 1, "title": "Machine Learning Overview"}, - {"level": 2, "title": "Supervised Learning"}, - {"level": 2, "title": "Unsupervised Learning"}, - {"level": 3, "title": "Applications of Deep Learning"} -] - -Query: -"How is deep learning used in image classification?" - -Output: -[ - {"level": 1, "title": "Machine Learning Overview", "score": 3}, - {"level": 2, "title": "Supervised Learning", "score": 3}, - {"level": 2, "title": "Unsupervised Learning", "score": 0}, - {"level": 3, "title": "Applications of Deep Learning", "score": 5} -] - ---- - -#### Example 2 -Input TOC: -[ - {"level": 1, "title": "Marketing Basics"}, - {"level": 2, "title": "Consumer Behavior"}, - {"level": 2, "title": "Digital Marketing"}, - {"level": 3, "title": "Social Media Campaigns"}, - {"level": 3, "title": "SEO Optimization"} -] - -Query: -"What are the best online marketing methods?" - -Output: -[ - {"level": 1, "title": "Marketing Basics", "score": 3}, - {"level": 2, "title": "Consumer Behavior", "score": 1}, - {"level": 2, "title": "Digital Marketing", "score": 5}, - {"level": 3, "title": "Social Media Campaigns", "score": 5}, - {"level": 3, "title": "SEO Optimization", "score": 5} -] - ---- - -#### Example 3 -Input TOC: -[ - {"level": 1, "title": "Physics Overview"}, - {"level": 2, "title": "Classical Mechanics"}, - {"level": 3, "title": "Newton’s Laws"}, - {"level": 2, "title": "Thermodynamics"}, - {"level": 3, "title": "Entropy and Heat Transfer"} -] - -Query: -"What is entropy?" - -Output: -[ - {"level": 1, "title": "Physics Overview", "score": 3}, - {"level": 2, "title": "Classical Mechanics", "score": 0}, - {"level": 3, "title": "Newton’s Laws", "score": -1}, - {"level": 2, "title": "Thermodynamics", "score": 5}, - {"level": 3, "title": "Entropy and Heat Transfer", "score": 5} -] - diff --git a/app/core/rag/prompts/toc_relevance_user.md b/app/core/rag/prompts/toc_relevance_user.md deleted file mode 100644 index 2a5167ad..00000000 --- a/app/core/rag/prompts/toc_relevance_user.md +++ /dev/null @@ -1,17 +0,0 @@ -# User Prompt: TOC Relevance Evaluation - -You will now receive: -1. A JSON list of TOC items (each with `level` and `title`) -2. A user query string. - -Traverse the TOC hierarchically based on level numbers and assign scores (5,3,1,0,-1) according to the rules in the system prompt. -Output **only** the JSON array with the added `"score"` field. - ---- - -**Input TOC:** -{{ toc_json }} - -**Query:** -{{ query }} - diff --git a/app/core/rag/prompts/tool_call_summary.md b/app/core/rag/prompts/tool_call_summary.md deleted file mode 100644 index b1c77dd4..00000000 --- a/app/core/rag/prompts/tool_call_summary.md +++ /dev/null @@ -1,19 +0,0 @@ -**Task Instruction:** - -You are tasked with reading and analyzing tool call result based on the following inputs: **Inputs for current call**, and **Results**. Your objective is to extract relevant and helpful information for **Inputs for current call** from the **Results** and seamlessly integrate this information into the previous steps to continue reasoning for the original question. - -**Guidelines:** - -1. **Analyze the Results:** - - Carefully review the content of each results of tool call. - - Identify factual information that is relevant to the **Inputs for current call** and can aid in the reasoning process for the original question. - -2. **Extract Relevant Information:** - - Select the information from the Searched Web Pages that directly contributes to advancing the previous reasoning steps. - - Ensure that the extracted information is accurate and relevant. - - - **Inputs for current call:** - {{ inputs }} - - - **Results:** - {{ results }} diff --git a/app/core/rag/prompts/vision_llm_describe_prompt.md b/app/core/rag/prompts/vision_llm_describe_prompt.md deleted file mode 100644 index 8800703d..00000000 --- a/app/core/rag/prompts/vision_llm_describe_prompt.md +++ /dev/null @@ -1,23 +0,0 @@ -## INSTRUCTION -Transcribe the content from the provided PDF page image into clean Markdown format. - -- Only output the content transcribed from the image. -- Do NOT output this instruction or any other explanation. -- If the content is missing or you do not understand the input, return an empty string. - -## RULES -1. Do NOT generate examples, demonstrations, or templates. -2. Do NOT output any extra text such as 'Example', 'Example Output', or similar. -3. Do NOT generate any tables, headings, or content that is not explicitly present in the image. -4. Transcribe content word-for-word. Do NOT modify, translate, or omit any content. -5. Do NOT explain Markdown or mention that you are using Markdown. -6. Do NOT wrap the output in ```markdown or ``` blocks. -7. Only apply Markdown structure to headings, paragraphs, lists, and tables, strictly based on the layout of the image. Do NOT create tables unless an actual table exists in the image. -8. Preserve the original language, information, and order exactly as shown in the image. - -{% if page %} -At the end of the transcription, add the page divider: `--- Page {{ page }} ---`. -{% endif %} - -> If you do not detect valid content in the image, return an empty string. - diff --git a/app/core/rag/prompts/vision_llm_figure_describe_prompt.md b/app/core/rag/prompts/vision_llm_figure_describe_prompt.md deleted file mode 100644 index 7e528564..00000000 --- a/app/core/rag/prompts/vision_llm_figure_describe_prompt.md +++ /dev/null @@ -1,24 +0,0 @@ -## ROLE -You are an expert visual data analyst. - -## GOAL -Analyze the image and provide a comprehensive description of its content. Focus on identifying the type of visual data representation (e.g., bar chart, pie chart, line graph, table, flowchart), its structure, and any text captions or labels included in the image. - -## TASKS -1. Describe the overall structure of the visual representation. Specify if it is a chart, graph, table, or diagram. -2. Identify and extract any axes, legends, titles, or labels present in the image. Provide the exact text where available. -3. Extract the data points from the visual elements (e.g., bar heights, line graph coordinates, pie chart segments, table rows and columns). -4. Analyze and explain any trends, comparisons, or patterns shown in the data. -5. Capture any annotations, captions, or footnotes, and explain their relevance to the image. -6. Only include details that are explicitly present in the image. If an element (e.g., axis, legend, or caption) does not exist or is not visible, do not mention it. - -## OUTPUT FORMAT (Include only sections relevant to the image content) -- Visual Type: [Type] -- Title: [Title text, if available] -- Axes / Legends / Labels: [Details, if available] -- Data Points: [Extracted data] -- Trends / Insights: [Analysis and interpretation] -- Captions / Annotations: [Text and relevance, if available] - -> Ensure high accuracy, clarity, and completeness in your analysis, and include only the information present in the image. Avoid unnecessary statements about missing elements. - diff --git a/app/core/rag/utils/__init__.py b/app/core/rag/utils/__init__.py deleted file mode 100644 index e69de29b..00000000 diff --git a/app/core/rag/utils/doc_store_conn.py b/app/core/rag/utils/doc_store_conn.py deleted file mode 100644 index b8a20884..00000000 --- a/app/core/rag/utils/doc_store_conn.py +++ /dev/null @@ -1,255 +0,0 @@ -from abc import ABC, abstractmethod -from dataclasses import dataclass -import numpy as np - -DEFAULT_MATCH_VECTOR_TOPN = 10 -DEFAULT_MATCH_SPARSE_TOPN = 10 -VEC = list | np.ndarray - - -@dataclass -class SparseVector: - indices: list[int] - values: list[float] | list[int] | None = None - - def __post_init__(self): - assert (self.values is None) or (len(self.indices) == len(self.values)) - - def to_dict_old(self): - d = {"indices": self.indices} - if self.values is not None: - d["values"] = self.values - return d - - def to_dict(self): - if self.values is None: - raise ValueError("SparseVector.values is None") - result = {} - for i, v in zip(self.indices, self.values): - result[str(i)] = v - return result - - @staticmethod - def from_dict(d): - return SparseVector(d["indices"], d.get("values")) - - def __str__(self): - return f"SparseVector(indices={self.indices}{'' if self.values is None else f', values={self.values}'})" - - def __repr__(self): - return str(self) - - -class MatchTextExpr(ABC): - def __init__( - self, - fields: list[str], - matching_text: str, - topn: int, - extra_options: dict = dict(), - ): - self.fields = fields - self.matching_text = matching_text - self.topn = topn - self.extra_options = extra_options - - -class MatchDenseExpr(ABC): - def __init__( - self, - vector_column_name: str, - embedding_data: VEC, - embedding_data_type: str, - distance_type: str, - topn: int = DEFAULT_MATCH_VECTOR_TOPN, - extra_options: dict = dict(), - ): - self.vector_column_name = vector_column_name - self.embedding_data = embedding_data - self.embedding_data_type = embedding_data_type - self.distance_type = distance_type - self.topn = topn - self.extra_options = extra_options - - -class MatchSparseExpr(ABC): - def __init__( - self, - vector_column_name: str, - sparse_data: SparseVector | dict, - distance_type: str, - topn: int, - opt_params: dict | None = None, - ): - self.vector_column_name = vector_column_name - self.sparse_data = sparse_data - self.distance_type = distance_type - self.topn = topn - self.opt_params = opt_params - - -class MatchTensorExpr(ABC): - def __init__( - self, - column_name: str, - query_data: VEC, - query_data_type: str, - topn: int, - extra_option: dict | None = None, - ): - self.column_name = column_name - self.query_data = query_data - self.query_data_type = query_data_type - self.topn = topn - self.extra_option = extra_option - - -class FusionExpr(ABC): - def __init__(self, method: str, topn: int, fusion_params: dict | None = None): - self.method = method - self.topn = topn - self.fusion_params = fusion_params - - -MatchExpr = MatchTextExpr | MatchDenseExpr | MatchSparseExpr | MatchTensorExpr | FusionExpr - -class OrderByExpr(ABC): - def __init__(self): - self.fields = list() - def asc(self, field: str): - self.fields.append((field, 0)) - return self - def desc(self, field: str): - self.fields.append((field, 1)) - return self - def fields(self): - return self.fields - -class DocStoreConnection(ABC): - """ - Database operations - """ - - @abstractmethod - def dbType(self) -> str: - """ - Return the type of the database. - """ - raise NotImplementedError("Not implemented") - - @abstractmethod - def health(self) -> dict: - """ - Return the health status of the database. - """ - raise NotImplementedError("Not implemented") - - """ - Table operations - """ - - @abstractmethod - def createIdx(self, indexName: str, knowledgebaseId: str, vectorSize: int): - """ - Create an index with given name - """ - raise NotImplementedError("Not implemented") - - @abstractmethod - def deleteIdx(self, indexName: str, knowledgebaseId: str): - """ - Delete an index with given name - """ - raise NotImplementedError("Not implemented") - - @abstractmethod - def indexExist(self, indexName: str, knowledgebaseId: str) -> bool: - """ - Check if an index with given name exists - """ - raise NotImplementedError("Not implemented") - - """ - CRUD operations - """ - - @abstractmethod - def search( - self, selectFields: list[str], - highlightFields: list[str], - condition: dict, - matchExprs: list[MatchExpr], - orderBy: OrderByExpr, - offset: int, - limit: int, - indexNames: str|list[str], - knowledgebaseIds: list[str], - aggFields: list[str] = [], - rank_feature: dict | None = None - ): - """ - Search with given conjunctive equivalent filtering condition and return all fields of matched documents - """ - raise NotImplementedError("Not implemented") - - @abstractmethod - def get(self, chunkId: str, indexName: str, knowledgebaseIds: list[str]) -> dict | None: - """ - Get single chunk with given id - """ - raise NotImplementedError("Not implemented") - - @abstractmethod - def insert(self, rows: list[dict], indexName: str, knowledgebaseId: str = None) -> list[str]: - """ - Update or insert a bulk of rows - """ - raise NotImplementedError("Not implemented") - - @abstractmethod - def update(self, condition: dict, newValue: dict, indexName: str, knowledgebaseId: str) -> bool: - """ - Update rows with given conjunctive equivalent filtering condition - """ - raise NotImplementedError("Not implemented") - - @abstractmethod - def delete(self, condition: dict, indexName: str, knowledgebaseId: str) -> int: - """ - Delete rows with given conjunctive equivalent filtering condition - """ - raise NotImplementedError("Not implemented") - - """ - Helper functions for search result - """ - - @abstractmethod - def getTotal(self, res): - raise NotImplementedError("Not implemented") - - @abstractmethod - def getChunkIds(self, res): - raise NotImplementedError("Not implemented") - - @abstractmethod - def getFields(self, res, fields: list[str]) -> dict[str, dict]: - raise NotImplementedError("Not implemented") - - @abstractmethod - def getHighlight(self, res, keywords: list[str], fieldnm: str): - raise NotImplementedError("Not implemented") - - @abstractmethod - def getAggregation(self, res, fieldnm: str): - raise NotImplementedError("Not implemented") - - """ - SQL - """ - @abstractmethod - def sql(sql: str, fetch_size: int, format: str): - """ - Run the sql generated by text-to-sql - """ - raise NotImplementedError("Not implemented") diff --git a/app/core/rag/utils/file_utils.py b/app/core/rag/utils/file_utils.py deleted file mode 100644 index 87fc30c0..00000000 --- a/app/core/rag/utils/file_utils.py +++ /dev/null @@ -1,247 +0,0 @@ -import io -import hashlib -import zipfile -import requests -from requests.exceptions import Timeout, RequestException -from io import BytesIO -from typing import List, Union, Tuple, Optional, Dict -import PyPDF2 -from docx import Document -import olefile - -def _is_zip(h: bytes) -> bool: - return h.startswith(b"PK\x03\x04") or h.startswith(b"PK\x05\x06") or h.startswith(b"PK\x07\x08") - -def _is_pdf(h: bytes) -> bool: - return h.startswith(b"%PDF-") - -def _is_ole(h: bytes) -> bool: - return h.startswith(b"\xD0\xCF\x11\xE0\xA1\xB1\x1A\xE1") - -def _sha10(b: bytes) -> str: - return hashlib.sha256(b).hexdigest()[:10] - -def _guess_ext(b: bytes) -> str: - h = b[:8] - if _is_zip(h): - try: - with zipfile.ZipFile(io.BytesIO(b), "r") as z: - names = [n.lower() for n in z.namelist()] - if any(n.startswith("word/") for n in names): - return ".docx" - if any(n.startswith("ppt/") for n in names): - return ".pptx" - if any(n.startswith("xl/") for n in names): - return ".xlsx" - except Exception: - pass - return ".zip" - if _is_pdf(h): - return ".pdf" - if _is_ole(h): - return ".doc" - return ".bin" - -# Try to extract the real embedded payload from OLE's Ole10Native -def _extract_ole10native_payload(data: bytes) -> bytes: - try: - pos = 0 - if len(data) < 4: - return data - _ = int.from_bytes(data[pos:pos+4], "little") - pos += 4 - # filename/src/tmp (NUL-terminated ANSI) - for _ in range(3): - z = data.index(b"\x00", pos) - pos = z + 1 - # skip unknown 4 bytes - pos += 4 - if pos + 4 > len(data): - return data - size = int.from_bytes(data[pos:pos+4], "little") - pos += 4 - if pos + size <= len(data): - return data[pos:pos+size] - except Exception: - pass - return data - -def extract_embed_file(target: Union[bytes, bytearray]) -> List[Tuple[str, bytes]]: - """ - Only extract the 'first layer' of embedding, returning raw (filename, bytes). - """ - top = bytes(target) - head = top[:8] - out: List[Tuple[str, bytes]] = [] - seen = set() - - def push(b: bytes, name_hint: str = ""): - h10 = _sha10(b) - if h10 in seen: - return - seen.add(h10) - ext = _guess_ext(b) - # If name_hint has an extension use its basename; else fallback to guessed ext - if "." in name_hint: - fname = name_hint.split("/")[-1] - else: - fname = f"{h10}{ext}" - out.append((fname, b)) - - # OOXML/ZIP container (docx/xlsx/pptx) - if _is_zip(head): - try: - with zipfile.ZipFile(io.BytesIO(top), "r") as z: - embed_dirs = ( - "word/embeddings/", "word/objects/", "word/activex/", - "xl/embeddings/", "ppt/embeddings/" - ) - for name in z.namelist(): - low = name.lower() - if any(low.startswith(d) for d in embed_dirs): - try: - b = z.read(name) - push(b, name) - except Exception: - pass - except Exception: - pass - return out - - # OLE container (doc/ppt/xls) - if _is_ole(head): - try: - with olefile.OleFileIO(io.BytesIO(top)) as ole: - for entry in ole.listdir(): - p = "/".join(entry) - try: - data = ole.openstream(entry).read() - except Exception: - continue - if not data: - continue - if "Ole10Native" in p or "ole10native" in p.lower(): - data = _extract_ole10native_payload(data) - push(data, p) - except Exception: - pass - return out - - return out - - -def extract_links_from_docx(docx_bytes: bytes): - """ - Extract all hyperlinks from a Word (.docx) document binary stream. - - Args: - docx_bytes (bytes): Raw bytes of a .docx file. - - Returns: - set[str]: A set of unique hyperlink URLs. - """ - links = set() - with BytesIO(docx_bytes) as bio: - document = Document(bio) - - # Each relationship may represent a hyperlink, image, footer, etc. - for rel in document.part.rels.values(): - if rel.reltype == ( - "http://schemas.openxmlformats.org/officeDocument/2006/relationships/hyperlink" - ): - links.add(rel.target_ref) - - return links - - -def extract_links_from_pdf(pdf_bytes: bytes): - """ - Extract all clickable hyperlinks from a PDF binary stream. - - Args: - pdf_bytes (bytes): Raw bytes of a PDF file. - - Returns: - set[str]: A set of unique hyperlink URLs (unordered). - """ - links = set() - with BytesIO(pdf_bytes) as bio: - pdf = PyPDF2.PdfReader(bio) - - for page in pdf.pages: - annots = page.get("/Annots") - if not annots or isinstance(annots, PyPDF2.generic.IndirectObject): - continue - for annot in annots: - obj = annot.get_object() - a = obj.get("/A") - if a and a.get("/URI"): - links.add(a["/URI"]) - - return links - - -_GLOBAL_SESSION: Optional[requests.Session] = None -def _get_session(headers: Optional[Dict[str, str]] = None) -> requests.Session: - """Get or create a global reusable session.""" - global _GLOBAL_SESSION - if _GLOBAL_SESSION is None: - _GLOBAL_SESSION = requests.Session() - _GLOBAL_SESSION.headers.update({ - "User-Agent": ( - "Mozilla/5.0 (X11; Linux x86_64) " - "AppleWebKit/537.36 (KHTML, like Gecko) " - "Chrome/121.0 Safari/537.36" - ) - }) - if headers: - _GLOBAL_SESSION.headers.update(headers) - return _GLOBAL_SESSION - - -def extract_html( - url: str, - timeout: float = 60.0, - headers: Optional[Dict[str, str]] = None, - max_retries: int = 2, -) -> Tuple[Optional[bytes], Dict[str, str]]: - """ - Extract the full HTML page as raw bytes from a given URL. - Automatically reuses a persistent HTTP session and applies robust timeout & retry logic. - - Args: - url (str): Target webpage URL. - timeout (float): Request timeout in seconds (applies to connect + read). - headers (dict, optional): Extra HTTP headers. - max_retries (int): Number of retries on timeout or transient errors. - - Returns: - tuple(bytes|None, dict): - - html_bytes: Raw HTML content (or None if failed) - - metadata: HTTP info (status_code, content_type, final_url, error if any) - """ - session = _get_session(headers=headers) - metadata = {"final_url": url, "status_code": "", "content_type": "", "error": ""} - - for attempt in range(1, max_retries + 1): - try: - resp = session.get(url, timeout=timeout) - resp.raise_for_status() - - html_bytes = resp.content - metadata.update({ - "final_url": resp.url, - "status_code": str(resp.status_code), - "content_type": resp.headers.get("Content-Type", ""), - }) - return html_bytes, metadata - - except Timeout: - metadata["error"] = f"Timeout after {timeout}s (attempt {attempt}/{max_retries})" - if attempt >= max_retries: - continue - except RequestException as e: - metadata["error"] = f"Request failed: {e}" - continue - - return None, metadata \ No newline at end of file diff --git a/app/core/rag/vdb/__init__.py b/app/core/rag/vdb/__init__.py deleted file mode 100644 index e69de29b..00000000 diff --git a/app/core/rag/vdb/elasticsearch/__init__.py b/app/core/rag/vdb/elasticsearch/__init__.py deleted file mode 100644 index e69de29b..00000000 diff --git a/app/core/rag/vdb/elasticsearch/elasticsearch_vector.py b/app/core/rag/vdb/elasticsearch/elasticsearch_vector.py deleted file mode 100644 index 176f996a..00000000 --- a/app/core/rag/vdb/elasticsearch/elasticsearch_vector.py +++ /dev/null @@ -1,779 +0,0 @@ -import os -import logging -from typing import Any, cast -from urllib.parse import urlparse -import uuid - -import requests -from elasticsearch import Elasticsearch, helpers -from elasticsearch.helpers import BulkIndexError -from packaging.version import parse as parse_version -from pydantic import BaseModel, model_validator -from abc import ABC -# langchain-community -# langchain-xinference -# from langchain_community.embeddings import XinferenceEmbeddings -# from langchain_xinference import XinferenceRerank -from langchain_core.documents import Document -from app.core.models.base import RedBearModelConfig -from app.core.models import RedBearLLM, RedBearRerank -from app.core.models.embedding import RedBearEmbeddings -from app.models.models_model import ModelConfig, ModelApiKey -from app.services.model_service import ModelConfigService - -from app.models.knowledge_model import Knowledge -from app.core.rag.vdb.field import Field -from app.core.rag.vdb.vector_base import BaseVector -from app.core.rag.models.chunk import DocumentChunk - -logger = logging.getLogger(__name__) - - -class ElasticSearchConfig(BaseModel): - # Regular Elasticsearch config - host: str | None = None - port: int | None = None - username: str | None = None - password: str | None = None - - # Common config - ca_certs: str | None = None - verify_certs: bool = False - request_timeout: int = 100000 - retry_on_timeout: bool = True - max_retries: int = 10000 - - @model_validator(mode="before") - @classmethod - def validate_config(cls, values: dict): - # Regular Elasticsearch validation - if not values.get("host"): - raise ValueError("config HOST is required for regular Elasticsearch") - if not values.get("port"): - raise ValueError("config PORT is required for regular Elasticsearch") - if not values.get("username"): - raise ValueError("config USERNAME is required for regular Elasticsearch") - if not values.get("password"): - raise ValueError("config PASSWORD is required for regular Elasticsearch") - return values - - -class ElasticSearchVector(BaseVector): - def __init__(self, index_name: str, config: ElasticSearchConfig, embedding_config: ModelApiKey, reranker_config: ModelApiKey): - super().__init__(index_name.lower()) - # self.embeddings = XinferenceEmbeddings( - # server_url=os.getenv("XINFERENCE_URL", "http://127.0.0.1"), # Default Xinference port - # model_uid="bge-m3" # replace model_uid with the model UID return from launching the model - # ) - # Remove debug printing to avoid leaking sensitive information - # print("embedding:" + embedding_config.model_name + "|" + embedding_config.provider + "|" + embedding_config.api_key + "|" + embedding_config.api_base) - self.embeddings = RedBearEmbeddings(RedBearModelConfig( - model_name=embedding_config.model_name, - provider=embedding_config.provider, - api_key=embedding_config.api_key, - base_url=embedding_config.api_base - )) - # self.reranker = XinferenceRerank( - # server_url=os.getenv("XINFERENCE_URL", "http://127.0.0.1"), - # model_uid="bge-reranker-large" - # ) - # Remove debug printing to avoid leaking sensitive information - # print("reranker:"+ reranker_config.model_name + "|" + reranker_config.provider + "|" + reranker_config.api_key + "|" + reranker_config.api_base) - self.reranker = RedBearRerank(RedBearModelConfig( - model_name=reranker_config.model_name, - provider=reranker_config.provider, - api_key=reranker_config.api_key, - base_url=reranker_config.api_base - )) - self._client = self._init_client(config) - self._version = self._get_version() - self._check_version() - - def _init_client(self, config: ElasticSearchConfig) -> Elasticsearch: - """ - Initialize Elasticsearch client for regular Elasticsearch. - """ - try: - # Regular Elasticsearch configuration - parsed_url = urlparse(config.host or "") - if parsed_url.scheme in {"http", "https"}: - hosts = f"{config.host}:{config.port}" - use_https = parsed_url.scheme == "https" - else: - hosts = f"https://{config.host}:{config.port}" - use_https = False - - client_config = { - "hosts": [hosts], - "basic_auth": (config.username, config.password), - "request_timeout": config.request_timeout, - "retry_on_timeout": config.retry_on_timeout, - "max_retries": config.max_retries, - } - - # Only add SSL settings if using HTTPS - if use_https: - client_config["verify_certs"] = config.verify_certs - if config.ca_certs: - client_config["ca_certs"] = config.ca_certs - - client = Elasticsearch(**client_config) - - # Test connection - if not client.ping(): - raise ConnectionError("Failed to connect to Elasticsearch") - - except requests.ConnectionError as e: - raise ConnectionError(f"Vector database connection error: {str(e)}") - except Exception as e: - raise ConnectionError(f"Elasticsearch client initialization failed: {str(e)}") - - return client - - def _get_version(self) -> str: - info = self._client.info() - return cast(str, info["version"]["number"]) - - def _check_version(self): - if parse_version(self._version) < parse_version("8.0.0"): - raise ValueError("Elasticsearch vector database version must be greater than 8.0.0") - - def get_type(self) -> str: - return "elasticsearch" - - def add_chunks(self, chunks: list[DocumentChunk], **kwargs): - # 实现 Elasticsearch 保存向量 - texts = [chunk.page_content for chunk in chunks] - embeddings = self.embeddings.embed_documents(list(texts)) - self.create(chunks, embeddings, **kwargs) - - def create(self, chunks: list[DocumentChunk], embeddings: list[list[float]], **kwargs): - metadatas = [chunk.metadata if chunk.metadata is not None else {} for chunk in chunks] - if not self._client.indices.exists(index=self._collection_name): - self.create_collection(embeddings, metadatas) - self.add_texts(chunks, embeddings, **kwargs) - - def add_texts(self, chunks: list[DocumentChunk], embeddings: list[list[float]], **kwargs): - uuids = self._get_uuids(chunks) - actions = [] - for i, chunk in enumerate(chunks): - action = { - "_index": self._collection_name, - "_source": { - Field.CONTENT_KEY.value: chunk.page_content, - Field.METADATA_KEY.value: chunk.metadata or {}, - Field.VECTOR.value: embeddings[i] or None - } - } - actions.append(action) - # using bulk mode - result = helpers.bulk(self._client, actions) - logger.info(f"add_texts result:{result}") - return uuids - - def text_exists(self, id: str) -> bool: - if not self._client.indices.exists(index=self._collection_name): - return False - result = self._client.search( - index=self._collection_name, - from_=0, - size=5, - query={ - "bool": { - "must": { - "match": { - Field.DOC_ID.value: id - } - } - } - }, - ) - - if "errors" in result: - raise ValueError(f"Error during query: {result['errors']}") - - count = result["hits"]["total"]["value"] - if count == 0: - return False - - return True - - def delete_by_ids(self, ids: list[str]): - if not ids: - return - if not self._client.indices.exists(index=self._collection_name): - logger.warning(f"Index {self._collection_name} does not exist") - return - - # Obtaining All Actual ES _id,not metadata.doc_id - actual_ids = [] - - for doc_id in ids: - es_ids = self.get_ids_by_metadata_field('doc_id', doc_id) - if es_ids: - actual_ids.extend(es_ids) - else: - logger.warning(f"Document with metadata doc_id {doc_id} not found for deletion") - - if actual_ids: - actions = [{"_op_type": "delete", "_index": self._collection_name, "_id": es_id} for es_id in actual_ids] - try: - helpers.bulk(self._client, actions) - except BulkIndexError as e: - for error in e.errors: - delete_error = error.get('delete', {}) - status = delete_error.get('status') - doc_id = delete_error.get('_id') - - if status == 404: - logger.warning(f"Document not found for deletion: {doc_id}") - else: - logger.error(f"Error deleting document: {error}") - - def get_ids_by_metadata_field(self, key: str, value: str): - query = {"query": {"term": {f"{Field.METADATA_KEY.value}.{key}": value}}} - response = self._client.search(index=self._collection_name, body=query, size=10000) - if response['hits']['hits']: - return [hit['_id'] for hit in response['hits']['hits']] - else: - return None - - def delete_by_metadata_field(self, key: str, value: str): - if not self._client.indices.exists(index=self._collection_name): - return False - actual_ids = self.get_ids_by_metadata_field(key, value) - - if actual_ids: - actions = [{"_op_type": "delete", "_index": self._collection_name, "_id": es_id} for es_id in actual_ids] - try: - helpers.bulk(self._client, actions) - except BulkIndexError as e: - for error in e.errors: - delete_error = error.get('delete', {}) - status = delete_error.get('status') - doc_id = delete_error.get('_id') - - if status == 404: - logger.warning(f"Document not found for deletion: {doc_id}") - else: - logger.error(f"Error deleting document: {error}") - - def delete(self): - if self._client.indices.exists(index=self._collection_name): - self._client.indices.delete(index=self._collection_name, ignore=[400, 404]) - - def search_by_segment(self, document_id: str | None = None, query: str | None = None, pagesize: int = 10, page: int = 1, asc: bool = True, **kwargs) -> tuple[int, list[DocumentChunk]]: # 返回 (total, results): - """ - Search documents by segment (pagination) with optional keyword query. - - Args: - document_id: If provided, filter results where `metadata.document_id` matches this value. - query: Optional keywords used to match chunk content. - pagesize: Number of documents per page. - page: 1-based page number. - **kwargs: Additional search parameters (e.g., indices). - - Returns: - List of DocumentChunk objects that match the query. - """ - indices = kwargs.get("indices", self._collection_name) # Default single index, multiple indexes are also supported, such as "index1, index2, index3" - - # Calculate the start position for the current page - from_ = pagesize * (page-1) - - # Construct the query with optional keyword matching - query_str = { - "query": { - "bool": { - "must": [] - } - }, - "sort": [ - {Field.SORT_ID.value: "asc" if asc else "desc"} # Sort by the specified metadata field - ] - } - - if document_id: - query_str["query"]["bool"]["must"].append({ - "term": { - Field.DOCUMENT_ID.value: document_id # exact match document_id - } - }) - - if query: - query_str["query"]["bool"]["must"].append({ - "match": { - Field.CONTENT_KEY.value: { - "query": query, - "analyzer": "ik_max_word" # Use the same analyzer as in create_collection - } - } - }) - - # For simplicity, we use from/size here which has a limit (usually up to 10,000). - result = self._client.search( - index=indices, - from_=from_, # Only use from_ for the first page (simplified) - size=pagesize, - body=query_str, - ) - - if "errors" in result: - raise ValueError(f"Error during query: {result['errors']}") - total = result["hits"]["total"]["value"] # Get total count - - docs_and_scores = [] - for res in result["hits"]["hits"]: - source = res["_source"] - page_content = source.get(Field.CONTENT_KEY.value) - # vector = source.get(Field.VECTOR.value) - vector = None - metadata = source.get(Field.METADATA_KEY.value, {}) - score = res["_score"] - docs_and_scores.append((DocumentChunk(page_content=page_content, vector=vector, metadata=metadata), score)) - - docs = [] - for doc, score in docs_and_scores: - if doc.metadata is not None: - doc.metadata["score"] = score - docs.append(doc) - - return total, docs - - def get_by_segment(self, doc_id: str, **kwargs) -> tuple[int, list[DocumentChunk]]: # 返回 (total, results): - """ - Search documents by segment with optional keyword query. - - Args: - doc_id: If provided, filter results where `metadata.doc_id` matches this value. - **kwargs: Additional search parameters (e.g., indices). - - Returns: - List of DocumentChunk objects that match the query. - """ - indices = kwargs.get("indices", self._collection_name) # Default single index, multi-index available,etc "index1,index2,index3" - query_str = {"query": {"term": {f"{Field.DOC_ID.value}": doc_id}}} - result = self._client.search( - index=indices, - from_=0, # Only use from_ for the first page (simplified) - size=1, - body=query_str, - ) - # print(result) - if "errors" in result: - raise ValueError(f"Error during query: {result['errors']}") - total = result["hits"]["total"]["value"] # Get total count - - docs_and_scores = [] - for res in result["hits"]["hits"]: - source = res["_source"] - page_content = source.get(Field.CONTENT_KEY.value) - vector = source.get(Field.VECTOR.value) - metadata = source.get(Field.METADATA_KEY.value, {}) - score = res["_score"] - docs_and_scores.append((DocumentChunk(page_content=page_content, vector=vector, metadata=metadata), score)) - - docs = [] - for doc, score in docs_and_scores: - if doc.metadata is not None: - doc.metadata["score"] = score - docs.append(doc) - - return total, docs - - def update_by_segment(self, chunk: DocumentChunk, **kwargs) -> str: - """ - update documents by segment. - - Args: - doc_id: If provided, filter results where `metadata.doc_id` matches this value. - chunk: updated segment - **kwargs: Additional search parameters (e.g., indices). - - Returns: - updated count. - """ - indices = kwargs.get("indices", self._collection_name) # Default single index, multi-index available,etc "index1,index2,index3" - chunk.vector = self.embeddings.embed_query(chunk.page_content) - - body = { - "script": { - "source": """ - ctx._source.page_content = params.new_content; - ctx._source.vector = params.new_vector; - """, - "params": { - "new_content": chunk.page_content, - "new_vector": chunk.vector - } - }, - "query": { - "term": { - Field.DOC_ID.value: chunk.metadata["doc_id"] # exact match doc_id - } - } - } - result = self._client.update_by_query( - index=indices, - body=body, - ) - # Remove debug printing and use logging instead - # print(result) - # print(f"Update successful, number of affected documents: {result['updated']}") - return result['updated'] - - def change_status_by_document_id(self, document_id: str, status: int, **kwargs) -> str: - """ - Update the metadata.status field of all documents with the specified document_id - Args: - document_id: Document ID to be updated - status: The new state value to be set (0 或 1) - """ - indices = kwargs.get("indices", self._collection_name) # Default single index, multi-index available,etc "index1,index2,index3" - body = { - "script": { - "source": "ctx._source.metadata.status = params.new_status", - "params": { - "new_status": status - } - }, - "query": { - "term": { - Field.DOCUMENT_ID.value: document_id # exact match document_id - } - } - } - result = self._client.update_by_query( - index=indices, - body=body, - ) - # Remove debug printing and use logging instead - # print(result) - # print(f"Update successful, number of affected documents: {result['updated']}") - return result['updated'] - - def search_by_vector(self, query: str, **kwargs: Any) -> list[DocumentChunk]: - """Search the nearest neighbors to a vector.""" - query_vector = self.embeddings.embed_query(query) - top_k = kwargs.get("top_k", 1024) - score_threshold = float(kwargs.get("score_threshold") or 0.3) - indices = kwargs.get("indices", self._collection_name) # Default single index, multi-index available,etc "index1,index2,index3" - file_names_filter = kwargs.get("file_names_filter") # ["doc1", "doc2", "doc3"] - - query_str: dict[str, Any] = { - "bool": { - "must": { - "script_score": { - "query": { - "match_all": {} - }, - "script": { - "source": f"cosineSimilarity(params.query_vector, '{Field.VECTOR.value}') + 1.0", - # The script_score query calculates the cosine similarity between the embedding field of each document and the query vector. The addition of +1.0 is to ensure that the scores returned by the script are non-negative, as the range of cosine similarity is [-1, 1] - "params": {"query_vector": query_vector} - } - } - }, - "filter": { # Add the filter condition of status=1 - "term": { - "metadata.status": 1 - } - } - } - } - # If file_names_filter is passed in, merge the filtering conditions - if file_names_filter: - query_str = { - "bool": { - "must": { - "script_score": { - "query": { - "match_all": {} - }, - "script": { - "source": f"cosineSimilarity(params.query_vector, '{Field.VECTOR.value}') + 1.0", - # The script_score query calculates the cosine similarity between the embedding field of each document and the query vector. The addition of +1.0 is to ensure that the scores returned by the script are non-negative, as the range of cosine similarity is [-1, 1] - "params": {"query_vector": query_vector} - } - } - }, - "filter": [ - { - "term": { - "metadata.status": 1 - } - }, - { - "terms": { - "metadata.file_name": file_names_filter # Additional file_name filtering - } - } - ], - } - } - - result = self._client.search( - index=indices, - from_=0, - size=top_k, - query=query_str - ) - # logger.info(result) - - if "errors" in result: - raise ValueError(f"Error during query: {result['errors']}") - - docs_and_scores = [] - for res in result["hits"]["hits"]: - source = res["_source"] - page_content = source.get(Field.CONTENT_KEY.value) - metadata = source.get(Field.METADATA_KEY.value, {}) - score = res["_score"] - score = score / 2 # Normalized [0-1] - docs_and_scores.append((DocumentChunk(page_content=page_content, metadata=metadata), score)) - - docs = [] - for doc, score in docs_and_scores: - # check score threshold - if score > score_threshold: - if doc.metadata is not None: - doc.metadata["score"] = score - docs.append(doc) - - return docs - - def search_by_full_text(self, query: str, **kwargs: Any) -> list[DocumentChunk]: - """Return docs using BM25F. - - Args: - query: Text to look up documents similar to. - k: Number of Documents to return. Defaults to 4. - - Returns: - List of Documents most similar to the query. - """ - top_k = kwargs.get("top_k", 1024) - score_threshold = float(kwargs.get("score_threshold") or 0.2) - indices = kwargs.get("indices", self._collection_name) # Default single index, multiple indexes are also supported, such as "index1, index2, index3" - file_names_filter = kwargs.get("file_names_filter") # ["doc1", "doc2", "doc3"] - - # Basic Query(BM25) - query_str: dict[str, Any] = { - "bool": { - "must": { - "match": { - Field.CONTENT_KEY.value: { - "query": query, - "analyzer": "ik_max_word" # tokenizer - } - } - }, - "filter": { # Add the filter condition of status=1 - "term": { - "metadata.status": 1 - } - } - } - } - - # If file_names_filter is passed in, merge the filtering conditions - if file_names_filter: - query_str = { - "bool": { - "must": { - "match": { - Field.CONTENT_KEY.value: { - "query": query, - "analyzer": "ik_max_word" # tokenizer - } - } - }, - "filter": [ - { - "term": { - "metadata.status": 1 - } - }, - { - "terms": { - "metadata.file_name": file_names_filter # Additional file_name filtering - } - } - ], - } - } - - result = self._client.search( - index=indices, - from_=0, - size=top_k, - query=query_str, - ) - # logger.info(result) - - if "errors" in result: - raise ValueError(f"Error during query: {result['errors']}") - - docs_and_scores = [] - max_score = result["hits"]["max_score"] or 1.0 # Get the maximum score. If it is None, use 1.0 - for res in result["hits"]["hits"]: - source = res["_source"] - page_content = source.get(Field.CONTENT_KEY.value) - metadata = source.get(Field.METADATA_KEY.value, {}) - # Normalize the score to the [0,1] interval - normalized_score = res["_score"] / max_score - docs_and_scores.append((DocumentChunk(page_content=page_content, metadata=metadata), normalized_score)) - - docs = [] - for doc, score in docs_and_scores: - # check score threshold - if score > score_threshold: - if doc.metadata is not None: - doc.metadata["score"] = score - docs.append(doc) - - return docs - - def rerank(self, query: str, docs: list[DocumentChunk], top_k: int) -> list[DocumentChunk]: - """ - Reorder the list of document blocks and return the top_k results most relevant to the query - Args: - query: query string - docs: List of document chunk to be rearranged - top_k: The number of top-level documents returned - - Returns: - Rearranged document chunk list (sorted in descending order of relevance) - - Raises: - ValueError: If the input document list is empty or top_k is invalid - """ - # parameter validation - if not docs: - raise ValueError("retrieval chunks be empty") - if top_k <= 0: - raise ValueError("top_k must be a positive integer") - try: - # Convert to LangChain Document object - documents = [ - Document( - page_content=doc.page_content, # Ensure that DocumentChunk possesses this attribute - metadata=doc.metadata or {} # Deal with possible None metadata - ) - for doc in docs - ] - - # Perform reordering (compress_documents will automatically handle relevance scores and indexing) - reranked_docs = list(self.reranker.compress_documents(documents, query)) - print(reranked_docs) - - # Sort in descending order based on relevance score - reranked_docs.sort( - key=lambda x: x.metadata.get("relevance_score", 0), - reverse=True - ) - # Convert back to a list of DocumentChunk, and save the relevance_score to metadata["score"] - result = [] - for item in reranked_docs[:top_k]: - for doc in docs: - if doc.page_content == item.page_content: - doc.metadata["score"] = item.metadata["relevance_score"] - result.append(doc) - return result - except Exception as e: - raise RuntimeError(f"Failed to rerank documents: {str(e)}") from e - - def create_collection( - self, - embeddings: list[list[float]], - metadatas: list[dict[Any, Any]] | None = None, - index_params: dict | None = None, - ): - if not self._client.indices.exists(index=self._collection_name): - index_mapping = { - "mappings": { - "properties": { - Field.CONTENT_KEY.value: { - "type": "text", - "analyzer": "ik_max_word" # tokenizer - }, - Field.METADATA_KEY.value: { - "type": "object", - "properties": { - "doc_id": { - "type": "keyword" # Map doc_id to keyword type - }, - "file_id": { - "type": "keyword" - }, - "file_name": { - "type": "keyword" - }, - "file_created_at": { - "type": "date", # Store as date type - "format": "epoch_millis" # Specify a millisecond-level Unix timestamp - }, - "document_id": { - "type": "keyword" - }, - "knowledge_id": { - "type": "keyword" - }, - "sort_id": { - "type": "long" # sort field - }, - "status": { - "type": "integer" - } - } - }, - Field.VECTOR.value: { - "type": "dense_vector", - "dims": len(embeddings[0]), # Make sure the dimension is correct here,The dimension size of the vector. When index is true, it cannot exceed 1024; when index is false or not specified, it cannot exceed 2048, which can improve retrieval efficiency - "index": True, - "similarity": "cosine" - } - } - } - } - print(index_mapping) - self._client.indices.create(index=self._collection_name, body=index_mapping) - - -class ElasticSearchVectorFactory(ABC): - def init_vector(self, knowledge: Knowledge) -> ElasticSearchVector: - collection_name = f"Vector_index_{knowledge.id}_Node" - - # Use regular Elasticsearch with config values - config_dict = { - "host": os.getenv("ELASTICSEARCH_HOST", "127.0.0.1"), - "port": os.getenv("ELASTICSEARCH_PORT", 9200), - "username": os.getenv("ELASTICSEARCH_USERNAME", "elastic"), - "password": os.getenv("ELASTICSEARCH_PASSWORD", "elastic"), - } - - # Common configuration - config_dict.update( - { - "ca_certs": str(os.getenv("ELASTICSEARCH_CA_CERTS")) if os.getenv("ELASTICSEARCH_CA_CERTS") else None, - "verify_certs": os.getenv("ELASTICSEARCH_VERIFY_CERTS", False) == "true", - "request_timeout": int(os.getenv("ELASTICSEARCH_REQUEST_TIMEOUT", 100000)), - "retry_on_timeout": os.getenv("ELASTICSEARCH_RETRY_ON_TIMEOUT", True) == "true", - "max_retries": int(os.getenv("ELASTICSEARCH_MAX_RETRIES", 10000)), - } - ) - - if knowledge.embedding and knowledge.reranker: - return ElasticSearchVector( - index_name=collection_name, - config=ElasticSearchConfig(**config_dict), - embedding_config=knowledge.embedding.api_keys[0], - reranker_config=knowledge.reranker.api_keys[0] - ) - else: - if knowledge.embedding is None: - raise ValueError(f"embedding_id config error: {str(knowledge.embedding_id)}") - if knowledge.reranker is None: - raise ValueError(f"reranker_id config error: {str(knowledge.reranker_id)}") - - diff --git a/app/core/rag/vdb/field.py b/app/core/rag/vdb/field.py deleted file mode 100644 index 86d39060..00000000 --- a/app/core/rag/vdb/field.py +++ /dev/null @@ -1,16 +0,0 @@ -from enum import StrEnum, auto - - -class Field(StrEnum): - CONTENT_KEY = "page_content" - METADATA_KEY = "metadata" - GROUP_KEY = "group_id" - VECTOR = auto() - # Sparse Vector aims to support full text search - SPARSE_VECTOR = auto() - TEXT_KEY = "text" - PRIMARY_KEY = "id" - DOC_ID = "metadata.doc_id" - DOCUMENT_ID = "metadata.document_id" - KNOWLEDGE_ID = "metadata.knowledge_id" - SORT_ID = "metadata.sort_id" diff --git a/app/core/rag/vdb/vector_base.py b/app/core/rag/vdb/vector_base.py deleted file mode 100644 index df3ac7d8..00000000 --- a/app/core/rag/vdb/vector_base.py +++ /dev/null @@ -1,67 +0,0 @@ -from __future__ import annotations - -from abc import ABC, abstractmethod -from typing import Any - -from app.core.rag.models.chunk import DocumentChunk - - -class BaseVector(ABC): - def __init__(self, collection_name: str): - self._collection_name = collection_name - - @abstractmethod - def get_type(self) -> str: - raise NotImplementedError - - @abstractmethod - def create(self, chunks: list[DocumentChunk], embeddings: list[list[float]], **kwargs): - raise NotImplementedError - - @abstractmethod - def add_texts(self, chunks: list[DocumentChunk], embeddings: list[list[float]], **kwargs): - raise NotImplementedError - - @abstractmethod - def text_exists(self, id: str) -> bool: - raise NotImplementedError - - @abstractmethod - def delete_by_ids(self, ids: list[str]): - raise NotImplementedError - - def get_ids_by_metadata_field(self, key: str, value: str): - raise NotImplementedError - - @abstractmethod - def delete_by_metadata_field(self, key: str, value: str): - raise NotImplementedError - - @abstractmethod - def search_by_vector(self, query: str, **kwargs: Any) -> list[DocumentChunk]: - raise NotImplementedError - - @abstractmethod - def search_by_full_text(self, query: str, **kwargs: Any) -> list[DocumentChunk]: - raise NotImplementedError - - @abstractmethod - def delete(self): - raise NotImplementedError - - def _filter_duplicate_texts(self, chunks: list[DocumentChunk]) -> list[DocumentChunk]: - for chunk in chunks.copy(): - if chunk.metadata and "doc_id" in chunk.metadata: - doc_id = chunk.metadata["doc_id"] - exists_duplicate_node = self.text_exists(doc_id) - if exists_duplicate_node: - chunks.remove(chunk) - - return chunks - - def _get_uuids(self, chunks: list[DocumentChunk]) -> list[str]: - return [chunk.metadata["doc_id"] for chunk in chunks if chunk.metadata and "doc_id" in chunk.metadata] - - @property - def collection_name(self): - return self._collection_name diff --git a/app/core/rag_utils/README.md b/app/core/rag_utils/README.md deleted file mode 100644 index 7e3dc5e4..00000000 --- a/app/core/rag_utils/README.md +++ /dev/null @@ -1,116 +0,0 @@ -# RAG Chunk 分析工具 - -这个模块提供了对 RAG chunk 内容进行分析的工具函数,包括: - -## 功能模块 - -### 1. chunk_summary.py - Chunk 摘要生成 -- `generate_chunk_summary(chunks, max_chunks=10)`: 为给定的 chunk 列表生成简洁摘要 -- 使用 LLM 提取核心信息和关键要点 -- 摘要长度控制在 100-150 字 - -### 2. chunk_tags.py - 标签提取 -- `extract_chunk_tags(chunks, max_tags=10, max_chunks=10)`: 从 chunk 中提取关键标签 -- `extract_chunk_tags_with_frequency(chunks, max_tags=10)`: 提取标签并统计频率 -- 使用 LLM 识别核心概念和专业术语 -- 自动过滤无意义词汇 - -### 3. chunk_insight.py - 洞察分析 -- `generate_chunk_insight(chunks, max_chunks=15)`: 生成深度洞察报告 -- `classify_chunk_domain(chunk)`: 对 chunk 进行领域分类 -- `analyze_domain_distribution(chunks, max_chunks=20)`: 分析领域分布 -- 提供内容的主题、特点和价值分析 - -## 使用示例 - -```python -from app.core.rag_utils import ( - generate_chunk_summary, - extract_chunk_tags, - generate_chunk_insight -) - -# 示例 chunk 数据 -chunks = [ - "机器学习是人工智能的一个重要分支...", - "深度学习使用神经网络进行特征学习...", - # ... -] - -# 生成摘要 -summary = await generate_chunk_summary(chunks, max_chunks=10) -print(f"摘要: {summary}") - -# 提取标签 -tags = await extract_chunk_tags(chunks, max_tags=10) -print(f"标签: {tags}") - -# 生成洞察 -insight = await generate_chunk_insight(chunks, max_chunks=15) -print(f"洞察: {insight}") -``` - -## API 接口 - -在 `memory_dashboard_controller.py` 中提供了两个对外接口: - -### 1. GET /dashboard/chunk_summary_tag -获取 chunk 总结和提取的标签 - -**参数:** -- `end_user_id` (必填): 宿主ID -- `limit` (可选, 默认15): 返回的chunk数量 -- `max_tags` (可选, 默认10): 最大标签数量 - -**返回:** -```json -{ - "code": 200, - "msg": "chunk摘要和标签获取成功", - "data": { - "summary": "chunk内容的总结...", - "tags": [ - {"tag": "机器学习", "frequency": 5}, - {"tag": "深度学习", "frequency": 3} - ] - } -} -``` - -### 2. GET /dashboard/chunk_insight -获取 chunk 的洞察内容 - -**参数:** -- `end_user_id` (必填): 宿主ID -- `limit` (可选, 默认15): 返回的chunk数量 - -**返回:** -```json -{ - "code": 200, - "msg": "chunk洞察获取成功", - "data": { - "insight": "该知识库主要聚焦于技术领域(60%)..." - } -} -``` - -## 技术特点 - -1. **异步处理**: 所有函数都是异步的,支持高并发 -2. **LLM 驱动**: 使用大语言模型进行智能分析 -3. **可配置**: 支持自定义处理的 chunk 数量和标签数量 -4. **错误处理**: 完善的异常处理和日志记录 -5. **模块化设计**: 每个功能独立,易于维护和扩展 - -## 依赖 - -- `app.core.memory.utils.llm_utils`: LLM 客户端 -- `app.core.logging_config`: 日志配置 -- `pydantic`: 数据验证和结构化输出 - -## 注意事项 - -1. 所有函数都需要在异步上下文中调用(使用 `await`) -2. 处理大量 chunk 时建议设置合理的 `max_chunks` 参数以控制 token 消耗 -3. LLM 调用可能需要一定时间,建议在前端显示加载状态 diff --git a/app/core/rag_utils/__init__.py b/app/core/rag_utils/__init__.py deleted file mode 100644 index d5a8ce1c..00000000 --- a/app/core/rag_utils/__init__.py +++ /dev/null @@ -1,14 +0,0 @@ -""" -RAG chunk analysis utilities. -""" - -from .chunk_summary import generate_chunk_summary -from .chunk_tags import extract_chunk_tags, extract_chunk_persona -from .chunk_insight import generate_chunk_insight - -__all__ = [ - "generate_chunk_summary", - "extract_chunk_tags", - "extract_chunk_persona", - "generate_chunk_insight", -] diff --git a/app/core/rag_utils/chunk_insight.py b/app/core/rag_utils/chunk_insight.py deleted file mode 100644 index 2c96160e..00000000 --- a/app/core/rag_utils/chunk_insight.py +++ /dev/null @@ -1,205 +0,0 @@ -""" -Generate insights from RAG chunks. - -This module provides functionality to analyze chunk content and generate insights using LLM. -""" - -import asyncio -from typing import List, Dict, Any -from collections import Counter -from pydantic import BaseModel, Field - -from app.core.memory.utils.llm.llm_utils import get_llm_client -from app.core.logging_config import get_business_logger - -business_logger = get_business_logger() - - -class ChunkInsight(BaseModel): - """Pydantic model for chunk insight.""" - insight: str = Field(..., description="对chunk内容的深度洞察分析") - - -class DomainClassification(BaseModel): - """Pydantic model for domain classification.""" - domain: str = Field( - ..., - description="内容所属的领域分类", - examples=["技术", "商业", "教育", "生活", "娱乐", "健康", "其他"] - ) - - -async def classify_chunk_domain(chunk: str) -> str: - """ - Classify a chunk into a specific domain. - - Args: - chunk: Chunk content string - - Returns: - Domain name - """ - try: - llm_client = get_llm_client() - - prompt = f"""请将以下文本内容归类到最合适的领域中。 - -可选领域及其关键词: -- 技术:编程、软件、硬件、算法、数据、网络、系统、开发、工程等 -- 商业:市场、销售、管理、财务、投资、创业、营销、战略等 -- 教育:学习、课程、培训、教学、知识、技能、考试、研究等 -- 生活:日常、家庭、饮食、购物、旅行、休闲、娱乐等 -- 娱乐:游戏、电影、音乐、体育、艺术、文化等 -- 健康:医疗、养生、运动、心理、保健、疾病等 -- 其他:无法归入以上类别的内容 - -文本内容: {chunk[:500]}... - -请直接返回最合适的领域名称。""" - - messages = [ - {"role": "system", "content": "你是一个专业的文本分类助手。请仔细分析文本内容,选择最合适的领域分类。"}, - {"role": "user", "content": prompt} - ] - - classification = await llm_client.response_structured( - messages=messages, - response_model=DomainClassification - ) - - return classification.domain if classification else "其他" - - except Exception as e: - business_logger.error(f"分类chunk领域失败: {str(e)}") - return "其他" - - -async def analyze_domain_distribution(chunks: List[str], max_chunks: int = 20) -> Dict[str, float]: - """ - Analyze the domain distribution of chunks. - - Args: - chunks: List of chunk content strings - max_chunks: Maximum number of chunks to analyze - - Returns: - Dictionary of domain -> percentage - """ - if not chunks: - return {} - - try: - # 限制分析的chunk数量 - chunks_to_analyze = chunks[:max_chunks] - - # 为每个chunk分类 - domain_counts = Counter() - for chunk in chunks_to_analyze: - domain = await classify_chunk_domain(chunk) - domain_counts[domain] += 1 - - # 计算百分比 - total = sum(domain_counts.values()) - domain_distribution = { - domain: count / total - for domain, count in domain_counts.items() - } - - # 按百分比降序排序 - return dict(sorted(domain_distribution.items(), key=lambda x: x[1], reverse=True)) - - except Exception as e: - business_logger.error(f"分析领域分布失败: {str(e)}") - return {} - - -async def generate_chunk_insight(chunks: List[str], max_chunks: int = 15) -> str: - """ - Generate insights from the given chunks. - - Args: - chunks: List of chunk content strings - max_chunks: Maximum number of chunks to analyze - - Returns: - A comprehensive insight report - """ - if not chunks: - business_logger.warning("没有提供chunk内容用于生成洞察") - return "暂无足够数据生成洞察报告" - - try: - # 1. 分析领域分布 - domain_dist = await analyze_domain_distribution(chunks, max_chunks=max_chunks) - - # 2. 统计基本信息 - total_chunks = len(chunks) - avg_length = sum(len(chunk) for chunk in chunks) / total_chunks if total_chunks > 0 else 0 - - # 3. 构建洞察prompt - prompt_parts = [] - - if domain_dist: - top_domains = ", ".join([f"{k}({v:.0%})" for k, v in list(domain_dist.items())[:3]]) - prompt_parts.append(f"- 内容领域分布: {top_domains}") - - prompt_parts.append(f"- 内容规模: 共{total_chunks}个知识片段,平均长度{avg_length:.0f}字") - - # 添加部分chunk内容作为参考 - sample_chunks = chunks[:5] - sample_content = "\n".join([f"示例{i+1}: {chunk[:200]}..." for i, chunk in enumerate(sample_chunks)]) - prompt_parts.append(f"\n内容示例:\n{sample_content}") - - system_prompt = """你是一位专业的知识内容分析师。你的任务是根据提供的信息,生成一段简洁、有洞察力的分析报告。 - -重要规则: -1. 报告需要将所有要点流畅地串联成一个段落 -2. 语言风格要专业、客观,同时易于理解 -3. 不要添加任何额外的解释或标题,直接输出报告内容 -4. 基于提供的数据和示例内容进行分析,不要编造信息 -5. 重点关注内容的主题、特点和价值 -6. 报告长度控制在150-200字 - -例如,如果输入是: -- 内容领域分布: 技术(60%), 商业(25%), 教育(15%) -- 内容规模: 共50个知识片段,平均长度320字 -内容示例: [示例内容...] - -你的输出应该类似: -"该知识库主要聚焦于技术领域(60%),涵盖商业(25%)和教育(15%)相关内容。共包含50个知识片段,平均每个片段约320字,内容详实。从示例来看,内容涉及[具体主题],体现了[特点],对[目标用户]具有较高的参考价值。" -""" - - user_prompt = "\n".join(prompt_parts) - - messages = [ - {"role": "system", "content": system_prompt}, - {"role": "user", "content": user_prompt} - ] - - # 调用LLM生成洞察 - llm_client = get_llm_client() - response = await llm_client.chat(messages=messages) - - insight = response.content.strip() - business_logger.info(f"成功生成chunk洞察,分析了 {min(len(chunks), max_chunks)} 个片段") - - return insight - - except Exception as e: - business_logger.error(f"生成chunk洞察失败: {str(e)}") - return "洞察生成失败" - - -if __name__ == "__main__": - # 测试代码 - test_chunks = [ - "Python是一种高级编程语言,以其简洁的语法和强大的功能而闻名。它广泛应用于Web开发、数据分析、人工智能等领域。", - "机器学习算法可以从数据中自动学习模式,无需显式编程。常见的算法包括决策树、随机森林、神经网络等。", - "深度学习是机器学习的一个分支,使用多层神经网络来学习数据的层次化表示。它在图像识别、语音识别等任务中表现出色。", - "自然语言处理技术使计算机能够理解和生成人类语言。应用包括机器翻译、情感分析、文本摘要等。", - "数据科学结合了统计学、计算机科学和领域知识,用于从数据中提取有价值的洞察。" - ] - - print("开始生成chunk洞察...") - insight = asyncio.run(generate_chunk_insight(test_chunks)) - print(f"\n生成的洞察:\n{insight}") diff --git a/app/core/rag_utils/chunk_summary.py b/app/core/rag_utils/chunk_summary.py deleted file mode 100644 index 971d6907..00000000 --- a/app/core/rag_utils/chunk_summary.py +++ /dev/null @@ -1,99 +0,0 @@ -""" -Generate summary for RAG chunks. - -This module provides functionality to summarize chunk content using LLM. -""" - -import asyncio -from typing import List, Dict, Any -from pydantic import BaseModel, Field - -from app.core.memory.utils.llm.llm_utils import get_llm_client -from app.core.logging_config import get_business_logger - -business_logger = get_business_logger() - - -class ChunkSummary(BaseModel): - """Pydantic model for chunk summary.""" - summary: str = Field(..., description="简洁的chunk内容摘要") - - -async def generate_chunk_summary(chunks: List[str], max_chunks: int = 10) -> str: - """ - Generate a summary for the given chunks. - - Args: - chunks: List of chunk content strings - max_chunks: Maximum number of chunks to process (default: 10) - - Returns: - A concise summary of the chunks - """ - if not chunks: - business_logger.warning("没有提供chunk内容用于生成摘要") - return "暂无内容" - - try: - # 限制处理的chunk数量,避免token过多 - chunks_to_process = chunks[:max_chunks] - - # 合并chunk内容 - combined_content = "\n\n".join([f"片段{i+1}: {chunk}" for i, chunk in enumerate(chunks_to_process)]) - - # 构建prompt - system_prompt = ( - "你是一位专业的文本摘要助手。请基于提供的文本片段,生成简洁的摘要。要求:\n" - "- 摘要长度控制在100-150字;\n" - "- 提取核心信息和关键要点;\n" - "- 使用客观、清晰的语言;\n" - "- 避免冗余和重复;\n" - "- 如果内容涉及多个主题,按重要性排序呈现。" - ) - - user_prompt = f"请为以下文本片段生成摘要:\n\n{combined_content}" - - messages = [ - {"role": "system", "content": system_prompt}, - {"role": "user", "content": user_prompt}, - ] - - # 调用LLM生成摘要 - llm_client = get_llm_client() - response = await llm_client.chat(messages=messages) - - summary = response.content.strip() - business_logger.info(f"成功生成chunk摘要,处理了 {len(chunks_to_process)} 个片段") - - return summary - - except Exception as e: - business_logger.error(f"生成chunk摘要失败: {str(e)}") - return "摘要生成失败" - - -async def generate_chunk_summary_batch(chunks_list: List[List[str]]) -> List[str]: - """ - Generate summaries for multiple chunk lists in batch. - - Args: - chunks_list: List of chunk lists - - Returns: - List of summaries - """ - tasks = [generate_chunk_summary(chunks) for chunks in chunks_list] - return await asyncio.gather(*tasks) - - -if __name__ == "__main__": - # 测试代码 - test_chunks = [ - "这是第一段测试内容,讲述了关于机器学习的基础知识。", - "第二段内容介绍了深度学习的应用场景和发展历史。", - "第三段讨论了自然语言处理技术的最新进展。" - ] - - print("开始生成chunk摘要...") - summary = asyncio.run(generate_chunk_summary(test_chunks)) - print(f"\n生成的摘要:\n{summary}") diff --git a/app/core/rag_utils/chunk_tags.py b/app/core/rag_utils/chunk_tags.py deleted file mode 100644 index 5d633be9..00000000 --- a/app/core/rag_utils/chunk_tags.py +++ /dev/null @@ -1,191 +0,0 @@ -""" -Extract tags from RAG chunks. - -This module provides functionality to extract meaningful tags from chunk content using LLM. -""" - -import asyncio -from collections import Counter -from typing import List, Tuple -from pydantic import BaseModel, Field - -from app.core.memory.utils.llm.llm_utils import get_llm_client -from app.core.logging_config import get_business_logger - -business_logger = get_business_logger() - - -class ExtractedTags(BaseModel): - """Pydantic model for extracted tags.""" - tags: List[str] = Field(..., description="从文本中提取的关键标签列表") - - -class ExtractedPersona(BaseModel): - """Pydantic model for extracted persona.""" - personas: List[str] = Field(..., description="从文本中提取的人物形象列表,如'产品设计师'、'旅行爱好者'等") - - -async def extract_chunk_tags(chunks: List[str], max_tags: int = 10, max_chunks: int = 10) -> List[Tuple[str, int]]: - """ - Extract meaningful tags from the given chunks. - - Args: - chunks: List of chunk content strings - max_tags: Maximum number of tags to return (default: 10) - max_chunks: Maximum number of chunks to process (default: 10) - - Returns: - List of tuples (tag, frequency), sorted by frequency in descending order - """ - if not chunks: - business_logger.warning("没有提供chunk内容用于提取标签") - return [] - - try: - # 限制处理的chunk数量 - chunks_to_process = chunks[:max_chunks] - - # 构建prompt - system_prompt = ( - "你是一位专业的文本分析专家,擅长从文本中提取关键标签。请遵循以下规则:\n\n" - "1. **提取核心概念**: 识别文本中最重要的名词、专业术语、主题词;\n" - "2. **过滤无意义词**: 排除过于宽泛的词(如'内容'、'信息'、'数据');\n" - "3. **保持具体性**: 优先选择具体的、有代表性的词语;\n" - "4. **标签数量**: 提取5-15个最具代表性的标签;\n" - "5. **去重合并**: 语义相近的标签只保留一个最核心的。\n\n" - "标签应该是名词或名词短语,能够准确概括文本的核心内容。" - ) - - llm_client = get_llm_client() - - # 为每个chunk单独提取标签,然后统计频率 - all_tags = [] - for chunk in chunks_to_process: - single_chunk_prompt = f"请从以下文本中提取关键标签:\n\n{chunk}" - single_messages = [ - {"role": "system", "content": system_prompt}, - {"role": "user", "content": single_chunk_prompt}, - ] - - try: - single_response = await llm_client.response_structured( - messages=single_messages, - response_model=ExtractedTags - ) - all_tags.extend(single_response.tags) - except Exception as e: - business_logger.warning(f"处理单个chunk时出错: {str(e)}") - continue - - # 统计标签频率 - tag_counter = Counter(all_tags) - - # 获取最常见的标签,限制数量 - most_common_tags = tag_counter.most_common(max_tags) - - business_logger.info(f"成功提取 {len(most_common_tags)} 个标签,处理了 {len(chunks_to_process)} 个片段") - - return most_common_tags - - except Exception as e: - business_logger.error(f"提取chunk标签失败: {str(e)}") - return [] - - -async def extract_chunk_tags_with_frequency(chunks: List[str], max_tags: int = 10) -> List[Tuple[str, int]]: - """ - Extract tags with actual frequency calculation across all chunks. - - This is an alias for extract_chunk_tags for backward compatibility. - - Args: - chunks: List of chunk content strings - max_tags: Maximum number of tags to return - - Returns: - List of tuples (tag, frequency), sorted by frequency - """ - return await extract_chunk_tags(chunks, max_tags=max_tags, max_chunks=len(chunks)) - - -async def extract_chunk_persona(chunks: List[str], max_personas: int = 5, max_chunks: int = 20) -> List[str]: - """ - Extract persona (人物形象) from the given chunks. - - Args: - chunks: List of chunk content strings - max_personas: Maximum number of personas to return (default: 5) - max_chunks: Maximum number of chunks to process (default: 20) - - Returns: - List of persona strings like "产品设计师", "旅行爱好者", "摄影发烧友" - """ - if not chunks: - business_logger.warning("没有提供chunk内容用于提取人物形象") - return [] - - try: - # 限制处理的chunk数量 - chunks_to_process = chunks[:max_chunks] - - # 合并chunk内容 - combined_content = "\n\n".join([f"片段{i+1}: {chunk}" for i, chunk in enumerate(chunks_to_process)]) - - # 构建prompt - system_prompt = ( - "你是一位专业的人物画像分析专家,擅长从文本中提取人物形象标签。请遵循以下规则:\n\n" - "1. **职业身份**: 识别职业、专业领域(如'产品设计师'、'软件工程师'、'创业者');\n" - "2. **兴趣爱好**: 提取核心兴趣和爱好(如'旅行爱好者'、'摄影发烧友'、'咖啡控');\n" - "3. **生活方式**: 概括生活态度和习惯(如'极简主义者'、'户外探险家'、'阅读爱好者');\n" - "4. **个性特征**: 提炼显著的性格特点(如'思考者'、'行动派'、'完美主义者');\n" - "5. **数量控制**: 提取3-8个最具代表性的人物形象标签;\n" - "6. **简洁明确**: 每个标签应该是简短的名词或名词短语(2-6个字)。\n\n" - "人物形象标签应该能够准确刻画这个人的核心特征和身份定位。" - ) - - user_prompt = f"请从以下文本中提取人物形象标签:\n\n{combined_content}" - - messages = [ - {"role": "system", "content": system_prompt}, - {"role": "user", "content": user_prompt}, - ] - - # 调用LLM提取人物形象 - llm_client = get_llm_client() - structured_response = await llm_client.response_structured( - messages=messages, - response_model=ExtractedPersona - ) - - # 去重并限制数量 - personas = list(dict.fromkeys(structured_response.personas))[:max_personas] - - business_logger.info(f"成功提取 {len(personas)} 个人物形象,处理了 {len(chunks_to_process)} 个片段") - - return personas - - except Exception as e: - business_logger.error(f"提取人物形象失败: {str(e)}") - return [] - - -if __name__ == "__main__": - # 测试代码 - test_chunks = [ - "我是一名产品设计师,平时喜欢旅行和摄影。周末经常去户外徒步,探索新的风景。", - "最近在学习咖啡拉花,已经能做出简单的图案了。每天早上都会给自己冲一杯手冲咖啡。", - "喜欢阅读各类书籍,尤其是设计和心理学相关的。记录生活是我的习惯,用镜头捕捉美好瞬间。" - ] - - print("开始提取chunk标签...") - tags = asyncio.run(extract_chunk_tags(test_chunks)) - print(f"\n提取的标签:") - for tag, freq in tags: - print(f"- {tag} (频率: {freq})") - - print("\n" + "="*50) - print("开始提取人物形象...") - personas = asyncio.run(extract_chunk_persona(test_chunks)) - print(f"\n提取的人物形象:") - for persona in personas: - print(f"- {persona}") diff --git a/app/core/response_utils.py b/app/core/response_utils.py deleted file mode 100644 index 127dcee4..00000000 --- a/app/core/response_utils.py +++ /dev/null @@ -1,22 +0,0 @@ -import time -from typing import Any, Optional - - -def success(data: Optional[Any] = None, msg: str = "OK") -> dict: - return { - "code": 0, - "msg": msg, - "data": data if data is not None else {}, - "error": "", - "time": int(time.time() * 1000), - } - - -def fail(code: int, msg: str, error: str = "", data: Optional[Any] = None) -> dict: - return { - "code": code, - "msg": msg, - "data": data if data is not None else {}, - "error": error, - "time": int(time.time() * 1000), - } \ No newline at end of file diff --git a/app/core/security.py b/app/core/security.py deleted file mode 100644 index 19ac7c40..00000000 --- a/app/core/security.py +++ /dev/null @@ -1,126 +0,0 @@ -from datetime import datetime, timedelta, timezone -from typing import Any, Union, Optional -import uuid - -from jose import jwt, JWTError -from passlib.context import CryptContext - -from app.core.config import settings - -pwd_context = CryptContext(schemes=["scrypt"], deprecated="auto") - - -def create_access_token(subject: Union[str, Any], expires_delta: timedelta = None) -> tuple[str, str]: - """创建访问token - - Args: - subject: token主体(通常是用户名) - expires_delta: 过期时间间隔 - - Returns: - tuple: (token字符串, token_id) - """ - now = datetime.now(timezone.utc) - if expires_delta: - expire = now + expires_delta - else: - expire = now + timedelta( - minutes=settings.ACCESS_TOKEN_EXPIRE_MINUTES - ) - - token_id = str(uuid.uuid4()) - to_encode = { - "exp": expire, - "iat": now, # Issued at time - "sub": str(subject), - "type": "access", - "jti": token_id # JWT ID - } - encoded_jwt = jwt.encode(to_encode, settings.SECRET_KEY, algorithm=settings.ALGORITHM) - return encoded_jwt, token_id - - -def create_refresh_token(subject: Union[str, Any], expires_delta: timedelta = None) -> tuple[str, str]: - """创建刷新token - - Args: - subject: token主体(通常是用户名) - expires_delta: 过期时间间隔 - - Returns: - tuple: (token字符串, token_id) - """ - now = datetime.now(timezone.utc) - if expires_delta: - expire = now + expires_delta - else: - expire = now + timedelta( - days=settings.REFRESH_TOKEN_EXPIRE_DAYS - ) - - token_id = str(uuid.uuid4()) - to_encode = { - "exp": expire, - "iat": now, # Issued at time - "sub": str(subject), - "type": "refresh", - "jti": token_id # JWT ID - } - encoded_jwt = jwt.encode(to_encode, settings.SECRET_KEY, algorithm=settings.ALGORITHM) - return encoded_jwt, token_id - - -def verify_token(token: str, token_type: str = "access") -> Union[str, None]: - try: - payload = jwt.decode(token, settings.SECRET_KEY, algorithms=[settings.ALGORITHM]) - userId: str = payload.get("sub") - token_type_in_payload: str = payload.get("type") - - if userId is None or token_type_in_payload != token_type: - return None - return userId - except JWTError: - return None - - -def get_token_id(token: str) -> Optional[str]: - """从token中提取token ID - - Args: - token: JWT token字符串 - - Returns: - token ID或None - """ - try: - payload = jwt.decode(token, settings.SECRET_KEY, algorithms=[settings.ALGORITHM]) - return payload.get("jti") - except JWTError: - return None - - -def get_token_expiry(token: str) -> Optional[datetime]: - """从token中提取过期时间 - - Args: - token: JWT token字符串 - - Returns: - 过期时间或None - """ - try: - payload = jwt.decode(token, settings.SECRET_KEY, algorithms=[settings.ALGORITHM]) - exp_timestamp = payload.get("exp") - if exp_timestamp: - return datetime.fromtimestamp(exp_timestamp, tz=timezone.utc) - return None - except JWTError: - return None - - -def verify_password(plain_password: str, hashed_password: str) -> bool: - return pwd_context.verify(plain_password, hashed_password) - - -def get_password_hash(password: str) -> str: - return pwd_context.hash(password) diff --git a/app/core/sensitive_filter.py b/app/core/sensitive_filter.py deleted file mode 100644 index 9348325d..00000000 --- a/app/core/sensitive_filter.py +++ /dev/null @@ -1,210 +0,0 @@ -""" -敏感信息过滤器 -用于在日志和异常消息中过滤敏感数据 -""" -import re -from typing import Any, Dict, List, Set, Union - - -class SensitiveDataFilter: - """敏感数据过滤器""" - - # 是否启用过滤(从配置读取) - _enabled: bool = None - - @classmethod - def is_enabled(cls) -> bool: - """检查过滤器是否启用""" - if cls._enabled is None: - from app.core.config import settings - cls._enabled = settings.ENABLE_SENSITIVE_DATA_FILTER - return cls._enabled - - # 敏感字段关键词(不区分大小写) - SENSITIVE_KEYS: Set[str] = { - "password", - "passwd", - "pwd", - "token", - "access_token", - "refresh_token", - "token_id", - "secret", - "api_key", - "apikey", - "authorization", - "auth", - "private_key", - "secret_key", - "session_id", - "sessionid", - "csrf_token", - "credit_card", - "card_number", - "cvv", - "ssn", - } - - # 敏感数据的正则模式 - SENSITIVE_PATTERNS: List[tuple] = [ - # Email地址 - (re.compile(r'\b[A-Za-z0-9._%+-]+@[A-Za-z0-9.-]+\.[A-Z|a-z]{2,}\b'), "[EMAIL]"), - # 手机号(中国11位) - (re.compile(r'\b1[3-9]\d{9}\b'), "[PHONE]"), - # 信用卡号(15-19位数字) - (re.compile(r'\b\d{15,19}\b'), "[CARD]"), - # JWT Token (格式: xxx.yyy.zzz) - 必须以eyJ开头,包含至少两个点 - (re.compile(r'\beyJ[A-Za-z0-9_-]+\.eyJ[A-Za-z0-9_-]+\.[A-Za-z0-9_-]+'), "[TOKEN]"), - # JWT Token 部分匹配(只有header和payload,没有signature) - (re.compile(r'\beyJ[A-Za-z0-9_-]+\.[A-Za-z0-9_-]+(?:\.[A-Za-z0-9_-]*)?'), "[TOKEN]"), - # UUID格式的token或ID - (re.compile(r'\b[0-9a-f]{8}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{12}\b', re.IGNORECASE), "[UUID]"), - # API密钥格式(32位以上的字母数字组合) - (re.compile(r'\b[A-Za-z0-9]{32,}\b'), "[API_KEY]"), - ] - - # 替换文本 - REDACTED_TEXT = "***REDACTED***" - - @classmethod - def filter_dict(cls, data: Dict[str, Any], deep: bool = True) -> Dict[str, Any]: - """ - 过滤字典中的敏感数据 - - Args: - data: 要过滤的字典 - deep: 是否深度过滤嵌套字典 - - Returns: - 过滤后的字典副本 - """ - if not cls.is_enabled() or not isinstance(data, dict): - return data - - filtered = {} - for key, value in data.items(): - # 检查键名是否为敏感字段 - if cls._is_sensitive_key(key): - filtered[key] = cls.REDACTED_TEXT - elif isinstance(value, dict) and deep: - # 递归过滤嵌套字典 - filtered[key] = cls.filter_dict(value, deep=True) - elif isinstance(value, list) and deep: - # 过滤列表中的字典 - filtered[key] = [ - cls.filter_dict(item, deep=True) if isinstance(item, dict) else item - for item in value - ] - elif isinstance(value, str): - # 过滤字符串中的敏感模式 - filtered[key] = cls.filter_string(value) - else: - filtered[key] = value - - return filtered - - @classmethod - def filter_string(cls, text: str) -> str: - """ - 过滤字符串中的敏感数据 - - Args: - text: 要过滤的字符串 - - Returns: - 过滤后的字符串 - """ - if not cls.is_enabled() or not isinstance(text, str): - return text - - filtered_text = text - for pattern, replacement in cls.SENSITIVE_PATTERNS: - filtered_text = pattern.sub(replacement, filtered_text) - - return filtered_text - - @classmethod - def filter_message(cls, message: str, context: Dict[str, Any] = None) -> tuple: - """ - 过滤异常消息和上下文 - - Args: - message: 异常消息 - context: 异常上下文字典 - - Returns: - (过滤后的消息, 过滤后的上下文) - """ - filtered_message = cls.filter_string(message) - filtered_context = cls.filter_dict(context) if context else {} - - return filtered_message, filtered_context - - @classmethod - def filter_log_record(cls, record: Dict[str, Any]) -> Dict[str, Any]: - """ - 过滤日志记录 - - Args: - record: 日志记录字典 - - Returns: - 过滤后的日志记录 - """ - filtered = record.copy() - - # 过滤消息 - if "message" in filtered: - filtered["message"] = cls.filter_string(str(filtered["message"])) - - # 过滤额外字段 - if "extra" in filtered and isinstance(filtered["extra"], dict): - filtered["extra"] = cls.filter_dict(filtered["extra"]) - - # 过滤异常信息 - if "exc_info" in filtered and filtered["exc_info"]: - # 不过滤堆栈跟踪,但过滤异常消息 - pass - - return filtered - - @classmethod - def _is_sensitive_key(cls, key: str) -> bool: - """ - 检查键名是否为敏感字段 - - Args: - key: 字段名 - - Returns: - 是否为敏感字段 - """ - key_lower = key.lower() - return any(sensitive_key in key_lower for sensitive_key in cls.SENSITIVE_KEYS) - - @classmethod - def sanitize_for_display(cls, value: Any, max_length: int = 100) -> str: - """ - 清理数据用于显示(用于日志或错误消息) - - Args: - value: 要清理的值 - max_length: 最大长度 - - Returns: - 清理后的字符串 - """ - if value is None: - return "None" - - # 转换为字符串 - str_value = str(value) - - # 过滤敏感信息 - filtered_value = cls.filter_string(str_value) - - # 截断过长的内容 - if len(filtered_value) > max_length: - filtered_value = filtered_value[:max_length] + "..." - - return filtered_value diff --git a/app/core/share_utils.py b/app/core/share_utils.py deleted file mode 100644 index 684f9877..00000000 --- a/app/core/share_utils.py +++ /dev/null @@ -1,94 +0,0 @@ -import secrets -import string -import hashlib - - -def generate_share_token(length: int = 16) -> str: - """生成唯一的分享 token - - Args: - length: token 长度,默认 16 - - Returns: - 随机字符串,包含大小写字母和数字 - """ - alphabet = string.ascii_letters + string.digits - return ''.join(secrets.choice(alphabet) for _ in range(length)) - - -def hash_password(password: str) -> str: - """加密密码 - - Args: - password: 明文密码 - - Returns: - 密码哈希(使用 SHA-256) - """ - # 使用 SHA-256 + salt - salt = secrets.token_hex(16) - pwd_hash = hashlib.sha256((password + salt).encode()).hexdigest() - return f"{salt}${pwd_hash}" - - -def verify_password(plain_password: str, hashed_password: str) -> bool: - """验证密码 - - Args: - plain_password: 明文密码 - hashed_password: 密码哈希 - - Returns: - 是否匹配 - """ - try: - salt, pwd_hash = hashed_password.split('$') - computed_hash = hashlib.sha256((plain_password + salt).encode()).hexdigest() - return computed_hash == pwd_hash - except (ValueError, AttributeError): - return False - - -def build_share_url(share_token: str, base_url: str = None) -> str: - """构建分享 URL - - Args: - share_token: 分享 token - base_url: 基础 URL,如果为 None 则使用相对路径 - - Returns: - 完整的分享 URL - """ - if base_url: - return f"{base_url.rstrip('/')}/public/share/{share_token}" - return f"/public/share/{share_token}" - - -def generate_embed_code(share_token: str, width: str = "100%", height: str = "600px", base_url: str = None) -> dict: - """生成嵌入代码 - - Args: - share_token: 分享 token - width: iframe 宽度 - height: iframe 高度 - base_url: 基础 URL - - Returns: - 包含 iframe_code 和 preview_url 的字典 - """ - preview_url = build_share_url(share_token, base_url) - - iframe_code = f'''<iframe - src="{preview_url}" - width="{width}" - height="{height}" - frameborder="0" - allowfullscreen> -</iframe>''' - - return { - "iframe_code": iframe_code, - "preview_url": preview_url, - "width": width, - "height": height - } diff --git a/app/core/storage_strategy.py b/app/core/storage_strategy.py deleted file mode 100644 index 337108af..00000000 --- a/app/core/storage_strategy.py +++ /dev/null @@ -1,198 +0,0 @@ -""" -Storage strategy interface and concrete implementations for file upload system. -""" -from abc import ABC, abstractmethod -from pathlib import Path -from typing import Dict, Any -import uuid - -from app.core.upload_enums import UploadContext -from app.core.upload_policies import UploadPolicy, get_upload_policy -from app.core.config import settings - - -class StorageStrategy(ABC): - """Abstract base class for storage strategies.""" - - @abstractmethod - def get_storage_path( - self, - tenant_id: uuid.UUID, - file_id: uuid.UUID, - file_extension: str, - metadata: Dict[str, Any] - ) -> Path: - """ - Generate the storage path for a file. - - Args: - tenant_id: The tenant ID - file_id: The unique file ID - file_extension: The file extension (e.g., ".jpg") - metadata: Additional metadata that may influence path generation - - Returns: - Path object representing the file storage location - """ - pass - - @abstractmethod - def get_upload_policy(self) -> UploadPolicy: - """ - Get the upload policy for this storage strategy. - - Returns: - UploadPolicy object with constraints and rules - """ - pass - - -class AvatarStorageStrategy(StorageStrategy): - """Storage strategy for user avatar files.""" - - def get_storage_path( - self, - tenant_id: uuid.UUID, - file_id: uuid.UUID, - file_extension: str, - metadata: Dict[str, Any] - ) -> Path: - """ - Generate storage path for avatar files. - Path format: {GENERIC_FILE_PATH}/avatars/{tenant_id}/{file_id}{extension} - """ - base_path = Path(settings.GENERIC_FILE_PATH) - return base_path / "avatars" / str(tenant_id) / f"{file_id}{file_extension}" - - def get_upload_policy(self) -> UploadPolicy: - """Get upload policy for avatar context.""" - return get_upload_policy(UploadContext.AVATAR) - - -class AppIconStorageStrategy(StorageStrategy): - """Storage strategy for application icon files.""" - - def get_storage_path( - self, - tenant_id: uuid.UUID, - file_id: uuid.UUID, - file_extension: str, - metadata: Dict[str, Any] - ) -> Path: - """ - Generate storage path for app icon files. - Path format: {GENERIC_FILE_PATH}/app_icons/{tenant_id}/{file_id}{extension} - """ - base_path = Path(settings.GENERIC_FILE_PATH) - return base_path / "app_icons" / str(tenant_id) / f"{file_id}{file_extension}" - - def get_upload_policy(self) -> UploadPolicy: - """Get upload policy for app_icon context.""" - return get_upload_policy(UploadContext.APP_ICON) - - -class KnowledgeBaseStorageStrategy(StorageStrategy): - """Storage strategy for knowledge base files.""" - - def get_storage_path( - self, - tenant_id: uuid.UUID, - file_id: uuid.UUID, - file_extension: str, - metadata: Dict[str, Any] - ) -> Path: - """ - Generate storage path for knowledge base files. - Path format: {GENERIC_FILE_PATH}/knowledge_base/{tenant_id}/{kb_id}/{file_id}{extension} - - If kb_id is provided in metadata, it will be included in the path for compatibility - with existing knowledge base file structure. - """ - base_path = Path(settings.GENERIC_FILE_PATH) - kb_id = metadata.get("kb_id") - - if kb_id: - # Include kb_id in path for compatibility with existing structure - return base_path / "knowledge_base" / str(tenant_id) / str(kb_id) / f"{file_id}{file_extension}" - else: - # Default path without kb_id - return base_path / "knowledge_base" / str(tenant_id) / f"{file_id}{file_extension}" - - def get_upload_policy(self) -> UploadPolicy: - """Get upload policy for knowledge_base context.""" - return get_upload_policy(UploadContext.KNOWLEDGE_BASE) - - -class TempStorageStrategy(StorageStrategy): - """Storage strategy for temporary files.""" - - def get_storage_path( - self, - tenant_id: uuid.UUID, - file_id: uuid.UUID, - file_extension: str, - metadata: Dict[str, Any] - ) -> Path: - """ - Generate storage path for temporary files. - Path format: {GENERIC_FILE_PATH}/temp/{tenant_id}/{file_id}{extension} - """ - base_path = Path(settings.GENERIC_FILE_PATH) - return base_path / "temp" / str(tenant_id) / f"{file_id}{file_extension}" - - def get_upload_policy(self) -> UploadPolicy: - """Get upload policy for temp context.""" - return get_upload_policy(UploadContext.TEMP) - - -class AttachmentStorageStrategy(StorageStrategy): - """Storage strategy for attachment files.""" - - def get_storage_path( - self, - tenant_id: uuid.UUID, - file_id: uuid.UUID, - file_extension: str, - metadata: Dict[str, Any] - ) -> Path: - """ - Generate storage path for attachment files. - Path format: {GENERIC_FILE_PATH}/attachments/{tenant_id}/{file_id}{extension} - """ - base_path = Path(settings.GENERIC_FILE_PATH) - return base_path / "attachments" / str(tenant_id) / f"{file_id}{file_extension}" - - def get_upload_policy(self) -> UploadPolicy: - """Get upload policy for attachment context.""" - return get_upload_policy(UploadContext.ATTACHMENT) - - -class StrategyFactory: - """Factory class for creating storage strategies based on upload context.""" - - _strategies = { - UploadContext.AVATAR: AvatarStorageStrategy, - UploadContext.APP_ICON: AppIconStorageStrategy, - UploadContext.KNOWLEDGE_BASE: KnowledgeBaseStorageStrategy, - UploadContext.TEMP: TempStorageStrategy, - UploadContext.ATTACHMENT: AttachmentStorageStrategy, - } - - @classmethod - def get_strategy(cls, context: UploadContext) -> StorageStrategy: - """ - Get the appropriate storage strategy for the given context. - - Args: - context: The upload context - - Returns: - An instance of the appropriate StorageStrategy - - Raises: - ValueError: If no strategy is defined for the given context - """ - strategy_class = cls._strategies.get(context) - if strategy_class is None: - raise ValueError(f"No storage strategy defined for context: {context}") - return strategy_class() diff --git a/app/core/transaction_monitor.py b/app/core/transaction_monitor.py deleted file mode 100644 index f6ec3a8a..00000000 --- a/app/core/transaction_monitor.py +++ /dev/null @@ -1,230 +0,0 @@ -# app/core/transaction_monitor.py -""" -事务监控模块 - -提供事务持续时间监控、长事务检测和告警功能。 -""" - -import time -import threading -from typing import Optional, Callable, Dict, Any -from contextlib import contextmanager -from app.core.logging_config import get_logger - -logger = get_logger(__name__) - - -class TransactionMonitor: - """ - 事务监控器 - - 功能: - - 监控事务持续时间 - - 检测长事务 - - 记录事务统计信息 - - 发出长事务告警 - """ - - # 默认长事务阈值(秒) - DEFAULT_LONG_TRANSACTION_THRESHOLD = 5.0 - - # 警告阈值(秒) - DEFAULT_WARNING_THRESHOLD = 2.0 - - def __init__( - self, - long_transaction_threshold: float = DEFAULT_LONG_TRANSACTION_THRESHOLD, - warning_threshold: float = DEFAULT_WARNING_THRESHOLD, - enable_monitoring: bool = True - ): - """ - 初始化事务监控器 - - Args: - long_transaction_threshold: 长事务阈值(秒),超过此时间视为长事务 - warning_threshold: 警告阈值(秒),超过此时间发出警告 - enable_monitoring: 是否启用监控 - """ - self.long_transaction_threshold = long_transaction_threshold - self.warning_threshold = warning_threshold - self.enable_monitoring = enable_monitoring - - # 事务统计 - self._stats = { - "total_transactions": 0, - "long_transactions": 0, - "warning_transactions": 0, - "total_duration": 0.0, - "max_duration": 0.0, - "min_duration": float('inf') - } - - # 线程本地存储,用于跟踪当前事务 - self._local = threading.local() - - @contextmanager - def monitor_transaction( - self, - transaction_name: str = "unnamed", - context: Optional[Dict[str, Any]] = None - ): - """ - 监控事务执行 - - 使用示例: - with monitor.monitor_transaction("create_user"): - # 执行事务操作 - pass - - Args: - transaction_name: 事务名称,用于日志记录 - context: 事务上下文信息(如 user_id, tenant_id 等) - """ - if not self.enable_monitoring: - yield - return - - # 记录开始时间 - start_time = time.time() - context = context or {} - - # 存储到线程本地 - self._local.transaction_name = transaction_name - self._local.start_time = start_time - self._local.context = context - - logger.debug( - "transaction_started", - transaction_name=transaction_name, - **context - ) - - try: - yield - finally: - # 计算持续时间 - duration = time.time() - start_time - - # 更新统计 - self._update_stats(duration) - - # 检查是否为长事务 - self._check_transaction_duration( - transaction_name, - duration, - context - ) - - logger.debug( - "transaction_completed", - transaction_name=transaction_name, - duration_seconds=round(duration, 3), - **context - ) - - def _update_stats(self, duration: float): - """更新事务统计信息""" - self._stats["total_transactions"] += 1 - self._stats["total_duration"] += duration - self._stats["max_duration"] = max(self._stats["max_duration"], duration) - self._stats["min_duration"] = min(self._stats["min_duration"], duration) - - if duration >= self.long_transaction_threshold: - self._stats["long_transactions"] += 1 - elif duration >= self.warning_threshold: - self._stats["warning_transactions"] += 1 - - def _check_transaction_duration( - self, - transaction_name: str, - duration: float, - context: Dict[str, Any] - ): - """ - 检查事务持续时间并发出告警 - - Args: - transaction_name: 事务名称 - duration: 事务持续时间(秒) - context: 事务上下文 - """ - if duration >= self.long_transaction_threshold: - # 长事务告警 - logger.warning( - f"Long transaction detected: {transaction_name} took {round(duration, 3)}s " - f"(threshold: {self.long_transaction_threshold}s). " - f"Consider breaking down the transaction or moving non-critical operations outside. " - f"Context: {context}" - ) - elif duration >= self.warning_threshold: - # 警告级别 - logger.info( - f"Slow transaction detected: {transaction_name} took {round(duration, 3)}s " - f"(threshold: {self.warning_threshold}s). " - f"Monitor this transaction for potential optimization. " - f"Context: {context}" - ) - - def get_stats(self) -> Dict[str, Any]: - """ - 获取事务统计信息 - - Returns: - 包含统计信息的字典 - """ - if self._stats["total_transactions"] == 0: - avg_duration = 0.0 - else: - avg_duration = self._stats["total_duration"] / self._stats["total_transactions"] - - return { - **self._stats, - "avg_duration": round(avg_duration, 3), - "long_transaction_rate": ( - self._stats["long_transactions"] / self._stats["total_transactions"] - if self._stats["total_transactions"] > 0 else 0.0 - ), - "warning_transaction_rate": ( - self._stats["warning_transactions"] / self._stats["total_transactions"] - if self._stats["total_transactions"] > 0 else 0.0 - ) - } - - def reset_stats(self): - """重置统计信息""" - self._stats = { - "total_transactions": 0, - "long_transactions": 0, - "warning_transactions": 0, - "total_duration": 0.0, - "max_duration": 0.0, - "min_duration": float('inf') - } - logger.info("transaction_stats_reset") - - def print_stats(self): - """打印统计信息(用于调试)""" - stats = self.get_stats() - print("\n" + "=" * 60) - print("Transaction Statistics") - print("=" * 60) - print(f"Total Transactions: {stats['total_transactions']}") - print(f"Long Transactions: {stats['long_transactions']} ({stats['long_transaction_rate']:.1%})") - print(f"Warning Transactions: {stats['warning_transactions']} ({stats['warning_transaction_rate']:.1%})") - print(f"Average Duration: {stats['avg_duration']:.3f}s") - print(f"Max Duration: {stats['max_duration']:.3f}s") - print(f"Min Duration: {stats['min_duration']:.3f}s") - print("=" * 60 + "\n") - - -# 全局事务监控器实例 -transaction_monitor = TransactionMonitor( - long_transaction_threshold=5.0, # 5秒 - warning_threshold=2.0, # 2秒 - enable_monitoring=True -) - - -def get_transaction_monitor() -> TransactionMonitor: - """获取全局事务监控器实例""" - return transaction_monitor diff --git a/app/core/uow.py b/app/core/uow.py deleted file mode 100644 index a71dc35b..00000000 --- a/app/core/uow.py +++ /dev/null @@ -1,265 +0,0 @@ -""" -Unit of Work Pattern Implementation -Manages database transactions and coordinates multiple repositories. - -事务边界管理: -- 使用 with 语句明确事务边界 -- 所有数据库操作必须在 with 块内执行 -- 必须显式调用 commit() 提交事务 -- 异常会自动触发回滚 - -长事务监控: -- 自动监控事务持续时间 -- 检测并告警长事务(默认 > 5秒) -- 提供事务性能统计 -""" -from abc import ABC, abstractmethod -from typing import Callable, TypeVar, Generic, Optional, Dict, Any -from sqlalchemy.orm import Session -import time - -from app.repositories.generic_file_repository import GenericFileRepository -from app.repositories.user_repository import UserRepository -from app.repositories.workspace_repository import WorkspaceRepository -from app.repositories.workspace_invite_repository import WorkspaceInviteRepository -from app.repositories.tenant_repository import TenantRepository -from app.repositories.model_repository import ModelConfigRepository, ModelApiKeyRepository -from app.core.logging_config import get_logger - -logger = get_logger(__name__) - -T = TypeVar('T') - - -class IUnitOfWork(ABC): - """工作单元接口""" - - files: GenericFileRepository - users: UserRepository - workspaces: WorkspaceRepository - workspace_invites: WorkspaceInviteRepository - tenants: TenantRepository - model_configs: ModelConfigRepository - model_api_keys: ModelApiKeyRepository - - @abstractmethod - def __enter__(self): - """进入上下文""" - pass - - @abstractmethod - def __exit__(self, exc_type, exc_val, exc_tb): - """退出上下文""" - pass - - @abstractmethod - def commit(self): - """提交事务""" - pass - - @abstractmethod - def rollback(self): - """回滚事务""" - pass - - -class SqlAlchemyUnitOfWork(IUnitOfWork): - """ - SQLAlchemy 工作单元实现 - - 事务边界说明: - - __enter__: 开始事务 (创建新的 session) - - __exit__: 结束事务 (自动回滚异常,关闭 session) - - commit(): 显式提交事务 - - rollback(): 显式回滚事务 - - 长事务监控: - - 自动记录事务开始时间 - - 在事务结束时计算持续时间 - - 超过阈值时发出告警 - - 使用示例: - with uow: - # 事务开始 - user = uow.users.create_user(data) - workspace = uow.workspaces.create_workspace(data) - # 所有操作在同一事务中 - uow.commit() - # 事务提交 - # 事务结束,session 关闭 - """ - - # 长事务阈值(秒) - LONG_TRANSACTION_THRESHOLD = 5.0 - WARNING_THRESHOLD = 2.0 - - def __init__( - self, - session_factory: Callable[[], Session], - transaction_name: Optional[str] = None, - context: Optional[Dict[str, Any]] = None, - enable_monitoring: bool = True - ): - self.session_factory = session_factory - self._session: Session = None - self._transaction_active = False - self._transaction_name = transaction_name if transaction_name is not None else "unnamed" - self._context = context or {} - self._enable_monitoring = enable_monitoring - self._start_time: Optional[float] = None - - def __enter__(self): - """ - 进入事务上下文 - 创建新的数据库 session 并开始事务 - 同时开始监控事务持续时间 - """ - self._session = self.session_factory() - self._transaction_active = True - - # 记录事务开始时间 - if self._enable_monitoring: - self._start_time = time.time() - logger.debug( - "transaction_started", - transaction_name=self._transaction_name, - **self._context - ) - - # 初始化所有仓储,共享同一个 session - # 确保所有仓储操作在同一事务中 - self.files = GenericFileRepository(self._session) - self.users = UserRepository(self._session) - self.workspaces = WorkspaceRepository(self._session) - self.workspace_invites = WorkspaceInviteRepository(self._session) - self.tenants = TenantRepository(self._session) - - # Note: ModelConfigRepository and ModelApiKeyRepository use static methods - # They don't need session in constructor, but we provide access to the session - self.model_configs = ModelConfigRepository - self.model_api_keys = ModelApiKeyRepository - self.session = self._session # Provide direct access to session for static method repositories - - return self - - def __exit__(self, exc_type, exc_val, exc_tb): - """ - 退出事务上下文 - - 如果发生异常: - - 自动回滚事务 - - 关闭 session - - 如果没有异常: - - 仅关闭 session (需要显式调用 commit) - - 同时检查事务持续时间并发出告警 - """ - try: - if exc_type is not None: - # 异常发生,自动回滚 - self.rollback() - finally: - # 检查事务持续时间 - if self._enable_monitoring and self._start_time is not None: - duration = time.time() - self._start_time - self._check_transaction_duration(duration) - - logger.debug( - "transaction_completed", - transaction_name=self._transaction_name, - duration_seconds=round(duration, 3), - **self._context - ) - - # 无论如何都要关闭 session - self._session.close() - self._transaction_active = False - - def commit(self): - """ - 显式提交事务 - - 注意: 必须在 with 块内调用 - 提交后事务仍然活跃,可以继续操作 - """ - if not self._transaction_active: - raise RuntimeError("Cannot commit: transaction is not active") - - logger.debug("Committing transaction") - self._session.commit() - logger.debug("Transaction committed successfully") - - def rollback(self): - """ - 显式回滚事务 - - 注意: 必须在 with 块内调用 - 回滚后事务仍然活跃,可以继续操作 - """ - if not self._transaction_active: - raise RuntimeError("Cannot rollback: transaction is not active") - - logger.debug("Rolling back transaction") - self._session.rollback() - logger.debug("Transaction rolled back successfully") - - def _check_transaction_duration(self, duration: float): - """ - 检查事务持续时间并发出告警 - - Args: - duration: 事务持续时间(秒) - """ - if duration >= self.LONG_TRANSACTION_THRESHOLD: - # 长事务告警 - logger.warning( - f"Long transaction detected: {self._transaction_name} took {round(duration, 3)}s " - f"(threshold: {self.LONG_TRANSACTION_THRESHOLD}s). " - f"Consider breaking down the transaction or moving non-critical operations outside. " - f"Context: {self._context}" - ) - elif duration >= self.WARNING_THRESHOLD: - # 警告级别 - logger.info( - f"Slow transaction detected: {self._transaction_name} took {round(duration, 3)}s " - f"(threshold: {self.WARNING_THRESHOLD}s). " - f"Monitor this transaction for potential optimization. " - f"Context: {self._context}" - ) - - def execute_in_transaction(self, func: Callable[[IUnitOfWork], T]) -> T: - """ - 在事务中执行函数,自动管理事务边界 - - 这是一个便捷方法,用于明确事务边界: - - 自动开始事务 - - 执行函数 - - 自动提交事务 - - 异常时自动回滚 - - Args: - func: 接受 UoW 作为参数的函数 - - Returns: - 函数的返回值 - - Example: - def create_user_and_workspace(uow): - user = uow.users.create_user(user_data) - workspace = uow.workspaces.create_workspace(ws_data) - return user, workspace - - result = uow.execute_in_transaction(create_user_and_workspace) - """ - logger.debug("Starting transaction execution") - with self: - try: - result = func(self) - self.commit() - logger.debug("Transaction execution completed successfully") - return result - except Exception as e: - logger.error(f"Transaction execution failed: {str(e)}") - # Rollback is automatic in __exit__ - raise diff --git a/app/core/upload_enums.py b/app/core/upload_enums.py deleted file mode 100644 index dd29d520..00000000 --- a/app/core/upload_enums.py +++ /dev/null @@ -1,10 +0,0 @@ -from enum import Enum - - -class UploadContext(str, Enum): - """上传上下文枚举,定义文件上传的目的和分类""" - AVATAR = "avatar" - APP_ICON = "app_icon" - KNOWLEDGE_BASE = "knowledge_base" - TEMP = "temp" - ATTACHMENT = "attachment" diff --git a/app/core/upload_policies.py b/app/core/upload_policies.py deleted file mode 100644 index 56443a45..00000000 --- a/app/core/upload_policies.py +++ /dev/null @@ -1,80 +0,0 @@ -from dataclasses import dataclass -from typing import List, Optional -from app.core.upload_enums import UploadContext - - -@dataclass -class UploadPolicy: - """上传策略,定义文件大小限制、允许的文件类型等规则""" - max_file_size: int # 最大文件大小(字节) - allowed_extensions: List[str] # 允许的文件扩展名列表 - allowed_mime_types: List[str] # 允许的 MIME 类型列表 - require_authentication: bool = True # 是否需要认证 - enable_virus_scan: bool = False # 是否启用病毒扫描 - enable_compression: bool = False # 是否启用压缩 - auto_delete_after_days: Optional[int] = None # 自动删除天数(None 表示不自动删除) - - -# 各上下文的上传策略配置 -UPLOAD_POLICIES = { - UploadContext.AVATAR: UploadPolicy( - max_file_size=5 * 1024 * 1024, # 5MB - allowed_extensions=[".jpg", ".jpeg", ".png", ".gif", ".webp"], - allowed_mime_types=["image/jpeg", "image/png", "image/gif", "image/webp"], - require_authentication=True, - enable_compression=True, - ), - UploadContext.APP_ICON: UploadPolicy( - max_file_size=2 * 1024 * 1024, # 2MB - allowed_extensions=[".jpg", ".jpeg", ".png", ".svg"], - allowed_mime_types=["image/jpeg", "image/png", "image/svg+xml"], - require_authentication=True, - enable_compression=True, - ), - UploadContext.KNOWLEDGE_BASE: UploadPolicy( - max_file_size=50 * 1024 * 1024, # 50MB - allowed_extensions=[".pdf", ".doc", ".docx", ".txt", ".md", ".xlsx", ".csv"], - allowed_mime_types=[ - "application/pdf", - "application/msword", - "application/vnd.openxmlformats-officedocument.wordprocessingml.document", - "text/plain", - "text/markdown", - "application/vnd.openxmlformats-officedocument.spreadsheetml.sheet", - "text/csv", - ], - require_authentication=True, - enable_virus_scan=True, - ), - UploadContext.TEMP: UploadPolicy( - max_file_size=10 * 1024 * 1024, # 10MB - allowed_extensions=[], # 允许所有类型 - allowed_mime_types=[], # 允许所有类型 - require_authentication=True, - auto_delete_after_days=7, - ), - UploadContext.ATTACHMENT: UploadPolicy( - max_file_size=20 * 1024 * 1024, # 20MB - allowed_extensions=[], # 允许所有类型 - allowed_mime_types=[], # 允许所有类型 - require_authentication=True, - ), -} - - -def get_upload_policy(context: UploadContext) -> UploadPolicy: - """ - 根据上传上下文获取对应的上传策略 - - Args: - context: 上传上下文 - - Returns: - 对应的上传策略 - - Raises: - ValueError: 如果上下文不存在对应的策略 - """ - if context not in UPLOAD_POLICIES: - raise ValueError(f"未定义上传上下文 '{context}' 的策略") - return UPLOAD_POLICIES[context] diff --git a/app/core/validators/__init__.py b/app/core/validators/__init__.py deleted file mode 100644 index a53b6b71..00000000 --- a/app/core/validators/__init__.py +++ /dev/null @@ -1,6 +0,0 @@ -""" -Validators for file upload system. -""" -from app.core.validators.file_validator import FileValidator, ValidationResult - -__all__ = ["FileValidator", "ValidationResult"] diff --git a/app/core/validators/file_validator.py b/app/core/validators/file_validator.py deleted file mode 100644 index f5de46ed..00000000 --- a/app/core/validators/file_validator.py +++ /dev/null @@ -1,357 +0,0 @@ -""" -File validator for generic file upload system. -Validates file size, type, content, and upload policies. -""" -import mimetypes -from typing import Optional, List -from dataclasses import dataclass -from fastapi import UploadFile - -from app.core.upload_policies import UploadPolicy -from app.core.exceptions import BusinessException -from app.core.error_codes import BizCode - - -# Magic numbers for common file types (first few bytes of file) -MAGIC_NUMBERS = { - # Images - b'\xFF\xD8\xFF': ['.jpg', '.jpeg'], - b'\x89PNG\r\n\x1a\n': ['.png'], - b'GIF87a': ['.gif'], - b'GIF89a': ['.gif'], - b'RIFF': ['.webp'], # Note: WEBP has additional checks needed - b'<svg': ['.svg'], - b'<?xml': ['.svg'], - - # Documents - b'%PDF': ['.pdf'], - b'PK\x03\x04': ['.docx', '.xlsx', '.zip'], # ZIP-based formats - b'\xD0\xCF\x11\xE0\xA1\xB1\x1A\xE1': ['.doc', '.xls'], # MS Office old format - - # Text files (no specific magic number, will be validated differently) -} - - -@dataclass -class ValidationResult: - """Result of a validation operation.""" - is_valid: bool - error_message: Optional[str] = None - error_code: Optional[int] = None - - -class FileUploadError(BusinessException): - """Base exception for file upload errors.""" - pass - - -class FileSizeExceededError(FileUploadError): - """Exception raised when file size exceeds the limit.""" - def __init__(self, max_size: int, actual_size: int): - super().__init__( - f"文件大小 {actual_size} 字节超过限制 {max_size} 字节", - code=BizCode.BAD_REQUEST - ) - self.max_size = max_size - self.actual_size = actual_size - - -class FileTypeNotAllowedError(FileUploadError): - """Exception raised when file type is not allowed.""" - def __init__(self, file_type: str, allowed_types: List[str]): - allowed_str = ', '.join(allowed_types) if allowed_types else '任意类型' - super().__init__( - f"文件类型 '{file_type}' 不在允许列表中: {allowed_str}", - code=BizCode.BAD_REQUEST - ) - self.file_type = file_type - self.allowed_types = allowed_types - - -class EmptyFileError(FileUploadError): - """Exception raised when file is empty.""" - def __init__(self): - super().__init__( - "文件内容为空,无法上传", - code=BizCode.BAD_REQUEST - ) - - -class FileContentMismatchError(FileUploadError): - """Exception raised when file content doesn't match its extension.""" - def __init__(self, extension: str): - super().__init__( - f"文件内容与扩展名 '{extension}' 不匹配,可能存在文件类型伪装", - code=BizCode.BAD_REQUEST - ) - self.extension = extension - - -class FileValidator: - """ - Validator for file uploads. - Validates file size, type, content, and upload policies. - """ - - def __init__(self): - """Initialize the file validator.""" - # Initialize mimetypes - mimetypes.init() - - def validate_file_size(self, file_size: int, max_size: int) -> ValidationResult: - """ - Validate that file size does not exceed the maximum allowed size. - - Args: - file_size: Size of the file in bytes - max_size: Maximum allowed size in bytes - - Returns: - ValidationResult indicating if validation passed - """ - if file_size == 0: - return ValidationResult( - is_valid=False, - error_message="文件大小为 0 字节,无法上传空文件", - error_code=BizCode.BAD_REQUEST - ) - - if file_size > max_size: - return ValidationResult( - is_valid=False, - error_message=f"文件大小 {file_size} 字节超过限制 {max_size} 字节", - error_code=BizCode.BAD_REQUEST - ) - - return ValidationResult(is_valid=True) - - def validate_file_type( - self, - file_extension: str, - allowed_extensions: List[str], - mime_type: Optional[str] = None, - allowed_mime_types: Optional[List[str]] = None - ) -> ValidationResult: - """ - Validate file type based on extension and MIME type. - - Args: - file_extension: File extension (e.g., '.jpg') - allowed_extensions: List of allowed extensions (empty list means all allowed) - mime_type: MIME type of the file (optional) - allowed_mime_types: List of allowed MIME types (empty list means all allowed) - - Returns: - ValidationResult indicating if validation passed - """ - # Normalize extension to lowercase and ensure it starts with a dot - if not file_extension.startswith('.'): - file_extension = f'.{file_extension}' - file_extension = file_extension.lower() - - # If allowed_extensions is empty, all extensions are allowed - if allowed_extensions: - # Normalize allowed extensions - normalized_allowed = [ext.lower() if ext.startswith('.') else f'.{ext.lower()}' - for ext in allowed_extensions] - - if file_extension not in normalized_allowed: - return ValidationResult( - is_valid=False, - error_message=f"文件扩展名 '{file_extension}' 不在允许列表中: {', '.join(normalized_allowed)}", - error_code=BizCode.BAD_REQUEST - ) - - # Validate MIME type if provided - if mime_type and allowed_mime_types: - mime_type_lower = mime_type.lower() - allowed_mime_lower = [mt.lower() for mt in allowed_mime_types] - - if mime_type_lower not in allowed_mime_lower: - return ValidationResult( - is_valid=False, - error_message=f"文件 MIME 类型 '{mime_type}' 不在允许列表中: {', '.join(allowed_mime_types)}", - error_code=BizCode.BAD_REQUEST - ) - - return ValidationResult(is_valid=True) - - def validate_file_content( - self, - file_content: bytes, - file_extension: str - ) -> ValidationResult: - """ - Validate file content by checking magic numbers (file signatures). - This helps prevent file type spoofing. - - Args: - file_content: First bytes of the file content (at least 16 bytes recommended) - file_extension: Expected file extension - - Returns: - ValidationResult indicating if validation passed - """ - if not file_content: - return ValidationResult( - is_valid=False, - error_message="文件内容为空", - error_code=BizCode.BAD_REQUEST - ) - - # Normalize extension - if not file_extension.startswith('.'): - file_extension = f'.{file_extension}' - file_extension = file_extension.lower() - - # For text-based files, we skip magic number validation - text_extensions = ['.txt', '.md', '.csv', '.json', '.xml', '.html', '.css', '.js'] - if file_extension in text_extensions: - return ValidationResult(is_valid=True) - - # Check magic numbers for binary files - content_matched = False - for magic_bytes, extensions in MAGIC_NUMBERS.items(): - if file_content.startswith(magic_bytes): - # Special handling for WEBP (needs additional check) - if magic_bytes == b'RIFF' and len(file_content) >= 12: - if file_content[8:12] == b'WEBP': - extensions = ['.webp'] - else: - continue - - # Check if the detected type matches the extension - if file_extension in extensions: - content_matched = True - break - else: - # Content doesn't match extension - return ValidationResult( - is_valid=False, - error_message=f"文件内容与扩展名 '{file_extension}' 不匹配,检测到的类型为: {', '.join(extensions)}", - error_code=BizCode.BAD_REQUEST - ) - - # If we checked for magic numbers but didn't find a match, it might be an issue - # However, for some file types (like .docx, .xlsx which are ZIP files), - # we allow them through if they match the ZIP signature - if not content_matched and file_extension in ['.docx', '.xlsx', '.zip']: - if file_content.startswith(b'PK\x03\x04'): - content_matched = True - - # For file types we don't have magic numbers for, we'll allow them through - # This is a pragmatic approach - we validate what we can - return ValidationResult(is_valid=True) - - def validate_upload_policy( - self, - file: UploadFile, - policy: UploadPolicy - ) -> ValidationResult: - """ - Validate a file against an upload policy. - This is a comprehensive validation that checks size, type, and content. - - Args: - file: The uploaded file - policy: The upload policy to validate against - - Returns: - ValidationResult indicating if validation passed - """ - # Get file extension from filename - filename = file.filename or "" - file_extension = "" - if "." in filename: - file_extension = "." + filename.rsplit(".", 1)[1].lower() - - # Get file size - file.file.seek(0, 2) # Seek to end - file_size = file.file.tell() - file.file.seek(0) # Reset to beginning - - # Validate file size - size_result = self.validate_file_size(file_size, policy.max_file_size) - if not size_result.is_valid: - return size_result - - # Get MIME type - mime_type = file.content_type - - # Validate file type (extension and MIME type) - type_result = self.validate_file_type( - file_extension, - policy.allowed_extensions, - mime_type, - policy.allowed_mime_types - ) - if not type_result.is_valid: - return type_result - - # Read first bytes for content validation (read up to 16 bytes for magic number check) - file_content = file.file.read(16) - file.file.seek(0) # Reset to beginning - - # Validate file content (magic numbers) - content_result = self.validate_file_content(file_content, file_extension) - if not content_result.is_valid: - return content_result - - return ValidationResult(is_valid=True) - - def validate_and_raise( - self, - file: UploadFile, - policy: UploadPolicy - ) -> None: - """ - Validate a file against a policy and raise an exception if validation fails. - This is a convenience method for use in services. - - Args: - file: The uploaded file - policy: The upload policy to validate against - - Raises: - FileSizeExceededError: If file size exceeds limit - FileTypeNotAllowedError: If file type is not allowed - EmptyFileError: If file is empty - FileContentMismatchError: If file content doesn't match extension - """ - result = self.validate_upload_policy(file, policy) - - if not result.is_valid: - # Determine which specific error to raise based on the error message - if "大小" in result.error_message and "超过" in result.error_message: - # Extract sizes from the validation result - filename = file.filename or "" - file_extension = "" - if "." in filename: - file_extension = "." + filename.rsplit(".", 1)[1].lower() - - file.file.seek(0, 2) - file_size = file.file.tell() - file.file.seek(0) - - raise FileSizeExceededError(policy.max_file_size, file_size) - - elif "为 0 字节" in result.error_message or "内容为空" in result.error_message: - raise EmptyFileError() - - elif "不匹配" in result.error_message: - filename = file.filename or "" - file_extension = "" - if "." in filename: - file_extension = "." + filename.rsplit(".", 1)[1].lower() - raise FileContentMismatchError(file_extension) - - elif "扩展名" in result.error_message or "MIME 类型" in result.error_message: - filename = file.filename or "" - file_extension = "" - if "." in filename: - file_extension = "." + filename.rsplit(".", 1)[1].lower() - raise FileTypeNotAllowedError(file_extension, policy.allowed_extensions) - - else: - # Generic error - raise FileUploadError(result.error_message, code=result.error_code) diff --git a/app/core/workflow/__init__.py b/app/core/workflow/__init__.py deleted file mode 100644 index e69de29b..00000000 diff --git a/app/db.py b/app/db.py deleted file mode 100644 index 895e81c2..00000000 --- a/app/db.py +++ /dev/null @@ -1,20 +0,0 @@ -import os -from sqlalchemy import create_engine -from sqlalchemy.orm import sessionmaker -from sqlalchemy.ext.declarative import declarative_base -from app.core.config import settings - -SQLALCHEMY_DATABASE_URL = f"postgresql://{settings.DB_USER}:{settings.DB_PASSWORD}@{settings.DB_HOST}:{settings.DB_PORT}/{settings.DB_NAME}" - -engine = create_engine(SQLALCHEMY_DATABASE_URL) -SessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=engine) - -Base = declarative_base() - -# Dependency to get a DB session -def get_db(): - db = SessionLocal() - try: - yield db - finally: - db.close() diff --git a/app/dependencies.py b/app/dependencies.py deleted file mode 100644 index c2c4d99a..00000000 --- a/app/dependencies.py +++ /dev/null @@ -1,459 +0,0 @@ -from fastapi import Depends, HTTPException, status -from fastapi.security import OAuth2PasswordBearer -from sqlalchemy.orm import Session -from jose import jwt, JWTError -import uuid -from functools import wraps - -from app.db import get_db, SessionLocal -from app.schemas import token_schema -from app.core.config import settings -from app.core.security import get_token_id -from app.repositories import user_repository, tenant_repository -from app.repositories import workspace_repository -from app.models.user_model import User -from app.models.tenant_model import Tenants -from app.models.workspace_model import Workspace -from app.services.session_service import SessionService -from app.core.logging_config import get_auth_logger, get_security_logger -from app.core.uow import SqlAlchemyUnitOfWork, IUnitOfWork -from app.core.exceptions import PermissionDeniedException - -# 获取专用日志器 -auth_logger = get_auth_logger() -security_logger = get_security_logger() - -oauth2_scheme = OAuth2PasswordBearer(tokenUrl="token") - -async def get_current_user( - token: str = Depends(oauth2_scheme), - db: Session = Depends(get_db) -) -> User: - """ - 获取当前认证用户 - """ - credentials_exception = HTTPException( - status_code=status.HTTP_401_UNAUTHORIZED, - detail="Could not validate credentials", - headers={"WWW-Authenticate": "Bearer"}, - ) - - try: - auth_logger.debug("开始解析JWT token") - payload = jwt.decode(token, settings.SECRET_KEY, algorithms=[settings.ALGORITHM]) - user_id: str = payload.get("sub") - - if user_id is None: - auth_logger.warning("JWT token中缺少用户ID") - raise credentials_exception - - token_data = token_schema.TokenData(userId=user_id) - auth_logger.debug(f"JWT解析成功,用户ID: {user_id}") - - except JWTError as e: - auth_logger.warning(f"JWT解析失败: {str(e)}") - raise credentials_exception - - # 检查单点登录黑名单和用户token失效 - try: - auth_logger.debug("检查单点登录黑名单") - token_id = get_token_id(token) - session_service = SessionService() - - if await session_service.is_token_blacklisted(token_id): - auth_logger.warning(f"Token已被列入黑名单: {token_id}") - raise credentials_exception - - # 检查用户是否重置了密码(所有旧token失效) - invalidation_time_str = await session_service.get_user_token_invalidation_time(user_id) - if invalidation_time_str: - from datetime import datetime, timezone - invalidation_time = datetime.fromisoformat(invalidation_time_str) - token_issued_at = datetime.fromtimestamp(payload.get("iat", 0), tz=timezone.utc) if payload.get("iat") else None - - if token_issued_at and token_issued_at < invalidation_time: - auth_logger.warning(f"Token在密码重置前签发,已失效: user_id={user_id}") - raise credentials_exception - - auth_logger.debug("单点登录检查通过") - - except HTTPException: - raise - except Exception as e: - auth_logger.error(f"检查token有效性时发生错误: {str(e)}") - raise credentials_exception - - try: - auth_logger.debug(f"查询用户信息: {token_data.userId}") - user = user_repository.get_user_by_id(db, user_id=token_data.userId) - - if user is None: - auth_logger.warning(f"用户不存在: {token_data.userId}") - raise credentials_exception - if not user.is_active: - auth_logger.warning(f"用户已被停用: {user.username} (ID: {user.id})") - raise credentials_exception - - auth_logger.info(f"用户认证成功: {user.username} (ID: {user.id})") - return user - - except Exception as e: - auth_logger.error(f"查询用户信息时发生错误: {str(e)}") - raise credentials_exception - - -async def get_current_tenant( - current_user: User = Depends(get_current_user), - db: Session = Depends(get_db) -) -> Tenants: - """ - 获取当前用户的租户 - 由于每个用户只属于一个租户,直接返回用户的租户 - """ - auth_logger.debug(f"获取用户 {current_user.username} 的租户信息") - - try: - # 直接从用户模型获取租户 - if current_user.tenant: - auth_logger.info(f"用户 {current_user.username} 的租户: {current_user.tenant.name}") - return current_user.tenant - else: - auth_logger.warning(f"用户 {current_user.username} 没有关联的租户") - raise HTTPException( - status_code=status.HTTP_404_NOT_FOUND, - detail="用户没有关联的租户" - ) - - except HTTPException: - raise - except Exception as e: - auth_logger.error(f"获取租户信息时发生错误: {str(e)}") - raise HTTPException( - status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, - detail="获取租户信息失败" - ) - - -async def get_user_tenants( - current_user: User = Depends(get_current_user), - db: Session = Depends(get_db) -) -> list[Tenants]: - """ - 获取当前用户所属的所有租户 - 由于每个用户只属于一个租户,返回包含该租户的列表 - """ - auth_logger.debug(f"获取用户 {current_user.username} 的所有租户") - - try: - if current_user.tenant: - tenants = [current_user.tenant] - auth_logger.info(f"用户 {current_user.username} 属于 1 个租户") - return tenants - else: - auth_logger.info(f"用户 {current_user.username} 没有关联的租户") - return [] - - except Exception as e: - auth_logger.error(f"获取用户租户列表时发生错误: {str(e)}") - raise HTTPException( - status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, - detail="获取租户列表失败" - ) - - -async def get_current_superuser( - current_user: User = Depends(get_current_user) -) -> User: - """ - 检查当前用户是否为超级管理员 - """ - auth_logger.debug(f"检查用户 {current_user.username} 是否为超级管理员") - - if not current_user.is_superuser: - auth_logger.warning(f"用户 {current_user.username} 尝试访问超管功能但不是超级管理员") - raise HTTPException( - status_code=status.HTTP_403_FORBIDDEN, - detail="只有超级管理员才能执行此操作" - ) - - auth_logger.info(f"超级管理员 {current_user.username} 访问超管功能") - return current_user - - -# ---------------------- -# Workspace Access Guard -# ---------------------- - -# async def require_workspace_access( -# workspace_id: uuid.UUID, -# db: Session = Depends(get_db), -# current_user: User = Depends(get_current_user), -# ) -> Workspace: -# """ -# 校验当前用户对指定工作空间的访问权限: -# - 工作空间必须存在 -# - 超级管理员且与工作空间同租户可访问 -# - 普通用户必须是该工作空间成员 - -# 返回工作空间对象以便后续使用;无权限时抛出 HTTPException。 -# """ -# auth_logger.debug(f"校验工作空间访问权限: workspace_id={workspace_id}, user={current_user.id}") - -# # 1) 工作空间存在性 -# workspace = workspace_repository.get_workspace_by_id(db=db, workspace_id=workspace_id) -# if not workspace: -# auth_logger.warning(f"工作空间不存在: {workspace_id}") -# raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Workspace not found") - -# # 2) 超级管理员(同租户)直接放行 -# if current_user.is_superuser: -# if workspace.tenant_id == current_user.tenant_id: -# auth_logger.debug(f"超管同租户访问放行: user={current_user.id}, workspace={workspace_id}") -# return workspace -# # 超管跨租户访问不允许 -# auth_logger.warning( -# f"超管跨租户访问被拒: user_tenant={current_user.tenant_id}, workspace_tenant={workspace.tenant_id}" -# ) -# raise HTTPException(status_code=status.HTTP_403_FORBIDDEN, detail="Forbidden") - -# # 3) 普通用户需要是成员 -# member = workspace_repository.get_member_in_workspace( -# db=db, user_id=current_user.id, workspace_id=workspace_id -# ) -# if not member: -# auth_logger.warning(f"非成员访问被拒: user={current_user.id}, workspace={workspace_id}") -# raise HTTPException(status_code=status.HTTP_403_FORBIDDEN, detail="Forbidden") - -# auth_logger.debug(f"成员访问通过: user={current_user.id}, workspace={workspace_id}") -# return workspace - - -# # 针对创建应用的请求体(包含 workspace_id)提供便捷校验 -# from app.schemas.app_schema import AppCreate - -# async def require_workspace_access_for_app_create( -# payload: AppCreate, -# db: Session = Depends(get_db), -# current_user: User = Depends(get_current_user), -# ) -> Workspace: -# return await require_workspace_access(payload.workspace_id, db, current_user) - - -# ---------------------- -# Decorator (@) version -# ---------------------- - -def _check_workspace_access_sync(db: Session, user: User, workspace_id: uuid.UUID) -> Workspace: - """同步校验版本,供装饰器在同步端点中调用 - 使用权限服务""" - auth_logger.debug(f"同步校验工作空间访问权限: workspace_id={workspace_id}, user={user.id}") - - # 1) 工作空间存在性 - workspace = workspace_repository.get_workspace_by_id(db=db, workspace_id=workspace_id) - if not workspace: - auth_logger.warning(f"工作空间不存在: {workspace_id}") - raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Workspace not found") - - # 2) 超级用户跳过成员检查,直接验证租户 - if user.is_superuser: - if user.tenant_id == workspace.tenant_id: - auth_logger.debug(f"超级用户访问同租户工作空间: workspace_id={workspace_id}, user={user.id}") - return workspace - else: - auth_logger.warning(f"超级用户尝试访问其他租户工作空间: workspace_id={workspace_id}, user={user.id}") - raise HTTPException(status_code=status.HTTP_403_FORBIDDEN, detail="Forbidden") - - # 3) 普通用户使用权限服务检查访问权限 - from app.core.permissions import permission_service, Subject, Resource, Action - from app.core.permissions.policies import WorkspaceMemberPolicy, SameTenantSuperuserPolicy - - # Check if user is a member - member = workspace_repository.get_member_in_workspace( - db=db, user_id=user.id, workspace_id=workspace_id - ) - workspace_memberships = {workspace_id} if member else set() - - subject = Subject.from_user(user, workspace_memberships=workspace_memberships) - resource = Resource.from_workspace(workspace) - - # Add workspace member policy - temp_service = permission_service - if member: - temp_service.add_policy(WorkspaceMemberPolicy(allowed_actions={Action.READ, Action.UPDATE, Action.MANAGE})) - temp_service.add_policy(SameTenantSuperuserPolicy()) - - try: - permission_service.require_permission( - subject, - Action.READ, - resource, - error_message="Forbidden" - ) - return workspace - except PermissionDeniedException: - auth_logger.warning(f"工作空间访问被拒绝: workspace_id={workspace_id}, user={user.id}") - raise HTTPException(status_code=status.HTTP_403_FORBIDDEN, detail="Forbidden") - - -def workspace_access_guard(get_workspace_id_from_body: bool = False): - """ - @ 装饰器:在端点进入前执行工作空间访问校验。 - 要求端点函数签名包含: - - db: Session = Depends(get_db) - - user 或 current_user: User = Depends(get_current_user) - - workspace_id: uuid.UUID (query/path 参数)或 payload: AppCreate(body,含 workspace_id) - - 支持同步和异步函数。 - """ - import asyncio - - def _decorator(func): - # 检查函数是否是异步的 - if asyncio.iscoroutinefunction(func): - @wraps(func) - async def _async_wrapper(*args, **kwargs): - db: Session = kwargs.get("db") - user: User = kwargs.get("user") or kwargs.get("current_user") - - if get_workspace_id_from_body: - payload = kwargs.get("payload") - if not payload or not hasattr(payload, "workspace_id"): - raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="workspace_id missing in body") - workspace_id = payload.workspace_id - else: - workspace_id = kwargs.get("workspace_id") - if workspace_id is None: - raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="workspace_id is required") - - _check_workspace_access_sync(db, user, workspace_id) - return await func(*args, **kwargs) - return _async_wrapper - else: - @wraps(func) - def _sync_wrapper(*args, **kwargs): - db: Session = kwargs.get("db") - user: User = kwargs.get("user") or kwargs.get("current_user") - - if get_workspace_id_from_body: - payload = kwargs.get("payload") - if not payload or not hasattr(payload, "workspace_id"): - raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="workspace_id missing in body") - workspace_id = payload.workspace_id - else: - workspace_id = kwargs.get("workspace_id") - if workspace_id is None: - raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="workspace_id is required") - - _check_workspace_access_sync(db, user, workspace_id) - return func(*args, **kwargs) - return _sync_wrapper - - return _decorator - - -def get_uow() -> IUnitOfWork: - """ - 获取工作单元实例 - - Returns: - IUnitOfWork: 工作单元实例 - """ - return SqlAlchemyUnitOfWork(SessionLocal) - - -def cur_workspace_access_guard(): - """ - @ 装饰器:在端点进入前执行工作空间访问校验。 - 要求端点函数签名包含: - - db: Session = Depends(get_db) - - current_user: User = Depends(get_current_user) - - 支持同步和异步函数。 - """ - import asyncio - import inspect - - def _decorator(func): - # 检查函数是否是异步的 - if asyncio.iscoroutinefunction(func): - @wraps(func) - async def _async_wrapper(*args, **kwargs): - db: Session = kwargs.get("db") - user: User = kwargs.get("current_user") - workspace_id = user.current_workspace_id - if workspace_id is None: - raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="workspace_id is required") - _check_workspace_access_sync(db, user, workspace_id) - return await func(*args, **kwargs) - return _async_wrapper - else: - @wraps(func) - def _sync_wrapper(*args, **kwargs): - db: Session = kwargs.get("db") - user: User = kwargs.get("current_user") - workspace_id = user.current_workspace_id - if workspace_id is None: - raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="workspace_id is required") - _check_workspace_access_sync(db, user, workspace_id) - return func(*args, **kwargs) - return _sync_wrapper - - return _decorator - -class ShareTokenData: - """分享 token 数据""" - def __init__(self, user_id: str, share_token: str): - self.user_id = user_id - self.share_token = share_token - - -async def get_share_user_id( - token: str = Depends(oauth2_scheme), - db: Session = Depends(get_db) -) -> ShareTokenData: - """ - 从分享访问 token 中获取用户 ID 和 share_token - - 这个函数用于公开分享的接口,验证访问 token 并返回用户信息 - 不需要验证用户是否存在或激活,只需要验证 token 的有效性和 share_token 是否有效 - - Returns: - ShareTokenData: 包含 user_id 和 share_token - """ - from app.services.auth_service import decode_access_token - from app.services.release_share_service import ReleaseShareService - from app.core.exceptions import BusinessException - - credentials_exception = HTTPException( - status_code=status.HTTP_401_UNAUTHORIZED, - detail="Could not validate credentials", - headers={"WWW-Authenticate": "Bearer"}, - ) - - try: - auth_logger.debug("开始解析分享访问 token") - - # 解码 token 获取 user_id 和 share_token - payload = decode_access_token(token) - user_id = payload["user_id"] - share_token = payload["share_token"] - - auth_logger.debug(f"Token 解析成功,用户ID: {user_id}, share_token: {share_token}") - - # 验证 share_token 是否有效 - service = ReleaseShareService(db) - share_info = service.get_shared_release_info(share_token=share_token) - - if not share_info: - auth_logger.warning(f"分享 token 无效: {share_token}") - raise credentials_exception - - auth_logger.info(f"分享访问验证成功: user_id={user_id}, share_token={share_token}") - return ShareTokenData(user_id=user_id, share_token=share_token) - - except BusinessException as e: - auth_logger.warning(f"分享访问验证失败: {str(e)}") - raise credentials_exception - except Exception as e: - auth_logger.error(f"验证分享访问 token 时发生错误: {str(e)}") - raise credentials_exception - diff --git a/app/main.py b/app/main.py deleted file mode 100644 index e3a571d7..00000000 --- a/app/main.py +++ /dev/null @@ -1,382 +0,0 @@ -import os -from dotenv import load_dotenv -from fastapi import FastAPI, HTTPException, Request -from fastapi.middleware.cors import CORSMiddleware -from app.core.config import settings -from contextlib import asynccontextmanager -from fastapi.responses import JSONResponse -from app.core.response_utils import fail -from app.core.logging_config import LoggingConfig, get_logger -from app.core.exceptions import BusinessException -from app.core.error_codes import BizCode, HTTP_MAPPING -from app.controllers import ( - model_controller, - task_controller, - test_controller, - user_controller, - auth_controller, - workspace_controller, - setup_controller, - file_controller, - document_controller, - knowledge_controller, - chunk_controller, - knowledgeshare_controller, - app_controller, - upload_controller, - memory_agent_controller, - memory_storage_controller, - memory_dashboard_controller, - multi_agent_controller, -) - -from fastapi import FastAPI, APIRouter - - -app = FastAPI(title="Data Config API", version="1.0.0") -router = APIRouter(prefix="/memory", tags=["Memory"]) - -# 管理端 API (JWT 认证) -from app.controllers import manager_router - -# 服务端 API (API Key 认证) -from app.controllers.service import service_router - -# Initialize logging system -LoggingConfig.setup_logging() -logger = get_logger(__name__) - -@asynccontextmanager -async def lifespan(app: FastAPI): - """使用 FastAPI lifespan 替代 on_event 处理启动/关闭事件""" - # 应用启动事件 - - # 检查是否需要自动升级数据库 - if settings.DB_AUTO_UPGRADE: - logger.info("开始自动升级数据库...") - try: - import subprocess - result = subprocess.run( - ["alembic", "upgrade", "head"], - capture_output=True, - text=True, - check=True - ) - logger.info(f"数据库升级成功: {result.stdout}") - except subprocess.CalledProcessError as e: - logger.error(f"数据库升级失败: {e.stderr}") - raise RuntimeError(f"数据库升级失败: {e.stderr}") - except Exception as e: - logger.error(f"运行数据库升级时出错: {str(e)}") - raise - else: - logger.info("自动数据库升级已禁用 (DB_AUTO_UPGRADE=false)") - - logger.info("应用程序启动完成") - yield - # 应用关闭事件 - logger.info("应用程序正在关闭") - -app = FastAPI( - title="redbera-mem", - description="redbera-mem", - version="1.0.0", - lifespan=lifespan, -) - -# Enable CORS for frontend access with environment-extendable origins -default_origins = [ - settings.WEB_URL -] -allowed_origins = list({o for o in (default_origins + settings.CORS_ORIGINS) if o}) - -app.add_middleware( - CORSMiddleware, - allow_origins=allowed_origins, - allow_credentials=True, - allow_methods=["*"], - allow_headers=["*"], -) - -logger.info("FastAPI应用程序启动") - - -@app.get("/", tags=["General"]) -def read_root(): - """ - A simple health check endpoint. - """ - logger.debug("健康检查端点被访问") - return {"message": "FastAPI is running"} - - -# 生命周期事件由 lifespan 管理,无需 on_event - - -# 注册路由 -# 管理端 API (JWT 认证) -app.include_router(manager_router, prefix="/api") - -# 服务端 API (API Key 认证) -app.include_router(service_router, prefix="/v1") - - -logger.info("所有路由已注册完成") - - -# Import additional exception types for specific handling -from app.core.exceptions import ( - ValidationException, - ResourceNotFoundException, - PermissionDeniedException, - AuthenticationException, - AuthorizationException, - FileUploadException -) -from app.core.sensitive_filter import SensitiveDataFilter -import traceback - - -# 处理验证异常 -@app.exception_handler(ValidationException) -async def validation_exception_handler(request: Request, exc: ValidationException): - """处理验证异常""" - # 过滤敏感信息 - filtered_message, filtered_context = SensitiveDataFilter.filter_message(exc.message, exc.context) - - logger.warning( - f"Validation error: {filtered_message}", - extra={ - "path": request.url.path, - "method": request.method, - "context": filtered_context, - "error_code": exc.code.value if isinstance(exc.code, BizCode) else exc.code, - "cause": str(exc.cause) if exc.cause else None - }, - exc_info=exc.cause is not None - ) - biz_code = exc.code if isinstance(exc.code, BizCode) else BizCode.VALIDATION_FAILED - status_code = HTTP_MAPPING.get(biz_code, 400) - return JSONResponse( - status_code=status_code, - content=fail(code=biz_code.value, msg=filtered_message, error=filtered_message) - ) - - -# 处理资源不存在异常 -@app.exception_handler(ResourceNotFoundException) -async def not_found_exception_handler(request: Request, exc: ResourceNotFoundException): - """处理资源不存在异常""" - # 过滤敏感信息 - filtered_message, filtered_context = SensitiveDataFilter.filter_message(exc.message, exc.context) - - logger.info( - f"Resource not found: {filtered_message}", - extra={ - "path": request.url.path, - "method": request.method, - "context": filtered_context, - "error_code": exc.code.value if isinstance(exc.code, BizCode) else exc.code, - "cause": str(exc.cause) if exc.cause else None - } - ) - biz_code = exc.code if isinstance(exc.code, BizCode) else BizCode.FILE_NOT_FOUND - status_code = HTTP_MAPPING.get(biz_code, 404) - return JSONResponse( - status_code=status_code, - content=fail(code=biz_code.value, msg=filtered_message, error=filtered_message) - ) - - -# 处理权限拒绝异常 -@app.exception_handler(PermissionDeniedException) -async def permission_denied_handler(request: Request, exc: PermissionDeniedException): - """处理权限拒绝异常""" - # 过滤敏感信息 - filtered_message, filtered_context = SensitiveDataFilter.filter_message(exc.message, exc.context) - - logger.warning( - f"Permission denied: {filtered_message}", - extra={ - "path": request.url.path, - "method": request.method, - "user": getattr(request.state, "user_id", None), - "context": filtered_context, - "error_code": exc.code.value if isinstance(exc.code, BizCode) else exc.code, - "cause": str(exc.cause) if exc.cause else None - } - ) - biz_code = exc.code if isinstance(exc.code, BizCode) else BizCode.FORBIDDEN - status_code = HTTP_MAPPING.get(biz_code, 403) - return JSONResponse( - status_code=status_code, - content=fail(code=biz_code.value, msg=filtered_message, error=filtered_message) - ) - - -# 处理认证异常 -@app.exception_handler(AuthenticationException) -async def authentication_exception_handler(request: Request, exc: AuthenticationException): - """处理认证异常""" - # 过滤敏感信息 - filtered_message, filtered_context = SensitiveDataFilter.filter_message(exc.message, exc.context) - - logger.warning( - f"Authentication error: {filtered_message}", - extra={ - "path": request.url.path, - "method": request.method, - "context": filtered_context, - "error_code": exc.code.value if isinstance(exc.code, BizCode) else exc.code, - "cause": str(exc.cause) if exc.cause else None - } - ) - biz_code = exc.code if isinstance(exc.code, BizCode) else BizCode.UNAUTHORIZED - status_code = HTTP_MAPPING.get(biz_code, 401) - return JSONResponse( - status_code=status_code, - content=fail(code=biz_code.value, msg=filtered_message, error=filtered_message) - ) - - -# 处理授权异常 -@app.exception_handler(AuthorizationException) -async def authorization_exception_handler(request: Request, exc: AuthorizationException): - """处理授权异常""" - # 过滤敏感信息 - filtered_message, filtered_context = SensitiveDataFilter.filter_message(exc.message, exc.context) - - logger.warning( - f"Authorization error: {filtered_message}", - extra={ - "path": request.url.path, - "method": request.method, - "context": filtered_context, - "error_code": exc.code.value if isinstance(exc.code, BizCode) else exc.code, - "cause": str(exc.cause) if exc.cause else None - } - ) - biz_code = exc.code if isinstance(exc.code, BizCode) else BizCode.FORBIDDEN - status_code = HTTP_MAPPING.get(biz_code, 403) - return JSONResponse( - status_code=status_code, - content=fail(code=biz_code.value, msg=filtered_message, error=filtered_message) - ) - - -# 处理文件上传异常 -@app.exception_handler(FileUploadException) -async def file_upload_exception_handler(request: Request, exc: FileUploadException): - """处理文件上传异常""" - # 过滤敏感信息 - filtered_message, filtered_context = SensitiveDataFilter.filter_message(exc.message, exc.context) - - logger.error( - f"File upload error: {filtered_message}", - extra={ - "path": request.url.path, - "method": request.method, - "context": filtered_context, - "error_code": exc.code.value if isinstance(exc.code, BizCode) else exc.code, - "cause": str(exc.cause) if exc.cause else None - }, - exc_info=exc.cause is not None - ) - biz_code = exc.code if isinstance(exc.code, BizCode) else BizCode.FILE_READ_ERROR - status_code = HTTP_MAPPING.get(biz_code, 500) - return JSONResponse( - status_code=status_code, - content=fail(code=biz_code.value, msg=filtered_message, error=filtered_message) - ) - - -# 业务异常统一处理(使用业务错误码) -@app.exception_handler(BusinessException) -async def business_exception_handler(request: Request, exc: BusinessException): - """处理通用业务异常""" - # 过滤敏感信息 - filtered_message, filtered_context = SensitiveDataFilter.filter_message(exc.message, exc.context) - - logger.error( - f"Business error: {filtered_message}", - extra={ - "path": request.url.path, - "method": request.method, - "context": filtered_context, - "error_code": exc.code.value if isinstance(exc.code, BizCode) else exc.code, - "cause": str(exc.cause) if exc.cause else None - }, - exc_info=exc.cause is not None - ) - raw_code = exc.code - if isinstance(raw_code, BizCode): - biz_code = raw_code - elif isinstance(raw_code, int): - try: - biz_code = BizCode(raw_code) - except ValueError: - biz_code = BizCode.BAD_REQUEST - else: - biz_code = BizCode.BAD_REQUEST - - status_code = HTTP_MAPPING.get(biz_code, 400) - return JSONResponse( - status_code=status_code, - content=fail(code=biz_code.value, msg=filtered_message, error=filtered_message) - ) - - -# 统一异常处理:将HTTPException转换为统一响应结构 -@app.exception_handler(HTTPException) -async def http_exception_handler(request: Request, exc: HTTPException): - """处理HTTP异常""" - # 过滤敏感信息 - filtered_detail = SensitiveDataFilter.filter_string(str(exc.detail)) - - logger.warning( - f"HTTP exception: {filtered_detail}", - extra={ - "path": request.url.path, - "method": request.method, - "status_code": exc.status_code - } - ) - return JSONResponse( - status_code=exc.status_code, - content=fail(code=exc.status_code, msg=filtered_detail, error=filtered_detail) - ) - - -# 捕获未处理的异常,返回统一错误结构 -@app.exception_handler(Exception) -async def unhandled_exception_handler(request: Request, exc: Exception): - """处理未捕获的异常""" - # 记录完整的堆栈跟踪(日志过滤器会自动过滤敏感信息) - logger.error( - f"Unhandled exception: {exc}", - extra={ - "path": request.url.path, - "method": request.method, - "exception_type": type(exc).__name__, - "traceback": traceback.format_exc() - }, - exc_info=True - ) - - # 生产环境隐藏详细错误信息 - environment = os.getenv("ENVIRONMENT", "development") - if environment == "production": - message = "服务器内部错误,请稍后重试" - else: - # 开发环境也要过滤敏感信息 - message = SensitiveDataFilter.filter_string(str(exc)) - - return JSONResponse( - status_code=500, - content=fail(code=BizCode.INTERNAL_ERROR.value, msg=message, error=message) - ) - - -if __name__ == "__main__": - import uvicorn - uvicorn.run(app, host="0.0.0.0", port=8000) \ No newline at end of file diff --git a/app/models/__init__.py b/app/models/__init__.py deleted file mode 100644 index 238e4d47..00000000 --- a/app/models/__init__.py +++ /dev/null @@ -1,52 +0,0 @@ -from .tenant_model import Tenants -from .user_model import User -from .workspace_model import Workspace, WorkspaceMember, WorkspaceRole -from .knowledge_model import Knowledge -from .document_model import Document -from .file_model import File -from .generic_file_model import GenericFile -from .models_model import ModelConfig, ModelProvider, ModelType, ModelApiKey -from .knowledgeshare_model import KnowledgeShare -from .app_model import App -from .agent_app_config_model import AgentConfig -from .app_release_model import AppRelease -from .memory_increment_model import MemoryIncrement -from .end_user_model import EndUser -from .appshare_model import AppShare -from .release_share_model import ReleaseShare -from .conversation_model import Conversation, Message -from .api_key_model import ApiKey, ApiKeyLog, ApiKeyType -from .data_config_model import DataConfig -from .multi_agent_model import MultiAgentConfig, AgentInvocation - -__all__ = [ - "Tenants", - "User", - "Workspace", - "WorkspaceMember", - "WorkspaceRole", - "Knowledge", - "Document", - "File", - "GenericFile", - "ModelConfig", - "ModelProvider", - "ModelType", - "ModelApiKey", - "KnowledgeShare", - "App", - "AgentConfig", - "AppRelease", - "MemoryIncrement", - "EndUser", - "AppShare", - "ReleaseShare", - "Conversation", - "Message", - "ApiKey", - "ApiKeyLog", - "ApiKeyType", - "DataConfig", - "MultiAgentConfig", - "AgentInvocation" -] diff --git a/app/models/agent_app_config_model.py b/app/models/agent_app_config_model.py deleted file mode 100644 index 373de92c..00000000 --- a/app/models/agent_app_config_model.py +++ /dev/null @@ -1,44 +0,0 @@ -import datetime -import uuid -from sqlalchemy import Column, String, Boolean, DateTime, Text, ForeignKey -from sqlalchemy.dialects.postgresql import UUID, JSON -from sqlalchemy.orm import relationship -from app.db import Base - - -class AgentConfig(Base): - __tablename__ = "agent_configs" - - id = Column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4, index=True) - - # 一对一关联到 App - app_id = Column(UUID(as_uuid=True), ForeignKey("apps.id"), nullable=False, unique=True, index=True) - - # Agent 行为配置 - system_prompt = Column(Text, nullable=True, comment="系统提示词") - default_model_config_id = Column(UUID(as_uuid=True), ForeignKey("model_configs.id"), nullable=True, index=True, comment="默认模型配置ID") - - # 结构化配置(直接存储 JSON) - model_parameters = Column(JSON, nullable=True, comment="模型参数配置(temperature、max_tokens等)") - knowledge_retrieval = Column(JSON, nullable=True, comment="知识库检索配置") - memory = Column(JSON, nullable=True, comment="记忆配置") - variables = Column(JSON, default=list, nullable=True, comment="变量配置") - tools = Column(JSON, default=dict, nullable=True, comment="工具配置") - - # 多 Agent 相关字段 - agent_role = Column(String(20), comment="Agent 角色: master|sub|standalone") - agent_domain = Column(String(50), comment="专业领域: customer_service|technical_support|sales 等") - parent_agent_id = Column(UUID(as_uuid=True), ForeignKey("agent_configs.id"), comment="父 Agent ID") - capabilities = Column(JSON, default=list, comment="Agent 能力列表") - - # 状态与时间戳 - is_active = Column(Boolean, default=True, nullable=False) - created_at = Column(DateTime, default=datetime.datetime.now) - updated_at = Column(DateTime, default=datetime.datetime.now, onupdate=datetime.datetime.now) - - # 关系 - app = relationship("App", back_populates="agent_config") - parent_agent = relationship("AgentConfig", remote_side=[id], backref="sub_agents") - - def __repr__(self): - return f"<AgentConfig(id={self.id}, app_id={self.app_id})>" \ No newline at end of file diff --git a/app/models/api_key_model.py b/app/models/api_key_model.py deleted file mode 100644 index 70f17b1d..00000000 --- a/app/models/api_key_model.py +++ /dev/null @@ -1,90 +0,0 @@ -"""API Key 数据模型""" -import datetime -import uuid -from enum import StrEnum -from sqlalchemy import Column, String, Boolean, DateTime, Integer, ForeignKey, Text -from sqlalchemy.dialects.postgresql import UUID, JSONB -from sqlalchemy.orm import relationship - -from app.db import Base - - -class ApiKeyType(StrEnum): - """API Key 类型""" - APP = "app" # 应用 API Key - RAG = "rag" # RAG API Key - MEMORY = "memory" # Memory API Key - GENERAL = "general" # 通用 API Key - - -class ApiKey(Base): - """API Key 表""" - __tablename__ = "api_keys" - - id = Column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4, index=True) - - # 基本信息 - name = Column(String(255), nullable=False, comment="API Key 名称") - description = Column(Text, comment="描述") - key_prefix = Column(String(20), nullable=False, comment="Key 前缀") - key_hash = Column(String(255), nullable=False, unique=True, index=True, comment="Key 哈希值") - - # 类型和权限 - type = Column(String(50), nullable=False, index=True, comment="API Key 类型") - scopes = Column(JSONB, nullable=False, default=list, comment="权限范围列表") - - # 关联资源 - workspace_id = Column(UUID(as_uuid=True), ForeignKey("workspaces.id", ondelete="CASCADE"), nullable=False, index=True, comment="所属工作空间") - resource_id = Column(UUID(as_uuid=True), index=True, comment="关联资源ID") - resource_type = Column(String(50), comment="资源类型") - - # 限制和配额 - rate_limit = Column(Integer, default=100, comment="速率限制(请求/分钟)") - quota_limit = Column(Integer, comment="配额限制(总请求数)") - quota_used = Column(Integer, default=0, comment="已使用配额") - - # 有效期 - expires_at = Column(DateTime, comment="过期时间") - - # 状态 - is_active = Column(Boolean, default=True, nullable=False, comment="是否激活") - last_used_at = Column(DateTime, comment="最后使用时间") - usage_count = Column(Integer, default=0, comment="使用次数") - - # 审计 - created_by = Column(UUID(as_uuid=True), ForeignKey("users.id"), nullable=False, comment="创建者") - created_at = Column(DateTime, nullable=False, default=datetime.datetime.now, comment="创建时间") - updated_at = Column(DateTime, nullable=False, default=datetime.datetime.now, onupdate=datetime.datetime.now, comment="更新时间") - - # 关系 - workspace = relationship("Workspace", back_populates="api_keys") - creator = relationship("User", foreign_keys=[created_by]) - logs = relationship("ApiKeyLog", back_populates="api_key", cascade="all, delete-orphan") - - -class ApiKeyLog(Base): - """API Key 使用日志表""" - __tablename__ = "api_key_logs" - - id = Column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4, index=True) - - api_key_id = Column(UUID(as_uuid=True), ForeignKey("api_keys.id", ondelete="CASCADE"), nullable=False, index=True, comment="API Key ID") - - # 请求信息 - endpoint = Column(String(255), nullable=False, comment="请求端点") - method = Column(String(10), nullable=False, comment="HTTP 方法") - ip_address = Column(String(50), comment="IP 地址") - user_agent = Column(Text, comment="User Agent") - - # 响应信息 - status_code = Column(Integer, comment="响应状态码") - response_time = Column(Integer, comment="响应时间(毫秒)") - - # Token 使用 - tokens_used = Column(Integer, comment="使用的 Token 数") - - # 时间 - created_at = Column(DateTime, nullable=False, default=datetime.datetime.now, index=True, comment="创建时间") - - # 关系 - api_key = relationship("ApiKey", back_populates="logs") diff --git a/app/models/app_model.py b/app/models/app_model.py deleted file mode 100644 index 7897eb62..00000000 --- a/app/models/app_model.py +++ /dev/null @@ -1,115 +0,0 @@ -import datetime -from enum import StrEnum -from re import LOCALE -import uuid -from sqlalchemy import Column, String, Boolean, DateTime, ForeignKey -from sqlalchemy.dialects.postgresql import UUID, JSON -from sqlalchemy.orm import relationship -from app.db import Base - -class IconType(StrEnum): - """图标类型枚举""" - LOCALE = "locale" - REMOTE = "remote" - -# 可见性:private | workspace | public -class AppVisibility(StrEnum): - """可见性枚举""" - PRIVATE = "private" - WORKSPACE = "workspace" - PUBLIC = "public" - -# 应用类型:agent | workflow | multi_agent -class AppType(StrEnum): - """应用类型枚举""" - AGENT = "agent" - WORKFLOW = "workflow" - MULTI_AGENT = "multi_agent" - - -# 应用状态:draft | active | archived -class AppStatus(StrEnum): - """应用状态枚举""" - DRAFT = "draft" - ACTIVE = "active" - ARCHIVED = "archived" - - -class App(Base): - __tablename__ = "apps" - - id = Column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4, index=True) - workspace_id = Column(UUID(as_uuid=True), nullable=False, comment="workspaces.id") - created_by = Column(UUID(as_uuid=True), nullable=False, comment="users.id") - - name = Column(String, index=True, nullable=False) - description = Column(String, nullable=True) - icon = Column(String, nullable=True) - icon_type = Column(String, nullable=True) - - # 应用类型:agent | workflow 等 - type = Column(String, index=True, nullable=False) - - # 可见性:private | workspace | public - visibility = Column(String, default="workspace") - - # 状态:draft | active | archived - status = Column(String, default="draft") - - # 标签或扩展元数据 - tags = Column(JSON, default=list) - - # 当前已发布版本指针(发布后指向快照,不受编辑影响) - current_release_id = Column( - UUID(as_uuid=True), - ForeignKey("app_releases.id", use_alter=True, name="fk_apps_current_release_id"), - nullable=True, - index=True, - ) - - is_active = Column(Boolean, default=True, nullable=False) - created_at = Column(DateTime, default=datetime.datetime.now) - updated_at = Column(DateTime, default=datetime.datetime.now, onupdate=datetime.datetime.now) - - # 一对一:Agent 配置(仅当 type=agent 时有效) - agent_config = relationship( - "AgentConfig", - back_populates="app", - uselist=False, - cascade="all, delete-orphan", - ) - - # 一对一:多 Agent 配置(仅当 type=multi_agent 时有效) - multi_agent_config = relationship( - "MultiAgentConfig", - back_populates="app", - uselist=False, - cascade="all, delete-orphan", - ) - - # 发布版本关联 - current_release = relationship("AppRelease", foreign_keys=[current_release_id]) - # 指定外键以避免与 current_release_id 造成歧义 - releases = relationship( - "AppRelease", - back_populates="app", - cascade="all, delete-orphan", - foreign_keys="AppRelease.app_id", - ) - - # 会话关联 - conversations = relationship( - "Conversation", - back_populates="app", - cascade="all, delete-orphan" - ) - - # 与 EndUser 的反向关系 - end_users = relationship( - "EndUser", - back_populates="app", - cascade="all, delete-orphan", - ) - - def __repr__(self): - return f"<App(id={self.id}, name={self.name}, type={self.type})>" \ No newline at end of file diff --git a/app/models/app_release_model.py b/app/models/app_release_model.py deleted file mode 100644 index 8119f3df..00000000 --- a/app/models/app_release_model.py +++ /dev/null @@ -1,68 +0,0 @@ -import datetime -import uuid -from sqlalchemy import Column, String, Boolean, DateTime, Integer, ForeignKey, UniqueConstraint -from sqlalchemy.dialects.postgresql import UUID, JSON -from sqlalchemy.orm import relationship -from app.db import Base -from app.models.app_model import IconType - - -class AppRelease(Base): - __tablename__ = "app_releases" - __table_args__ = ( - UniqueConstraint("app_id", "version", name="uq_app_release_app_version"), - ) - - id = Column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4, index=True) - app_id = Column(UUID(as_uuid=True), ForeignKey("apps.id"), nullable=False, index=True) - - # 版本号(按应用内递增) - version = Column(Integer, nullable=False, default=1, index=True) - # 版本号,显示用 - version_name = Column(String, nullable=False) - # 版本说明 - release_notes = Column(String, nullable=True, comment="版本说明") - - # 基础信息快照(发布时冻结) - name = Column(String, nullable=False) - description = Column(String, nullable=True) - icon = Column(String, nullable=True) - icon_type = Column(String, nullable=True) - type = Column(String, nullable=False) - visibility = Column(String, default="private") - - # 类型特定配置快照(针对 agent/workflow 等统一存放) - config = Column(JSON, default=dict) - - # 便于查询的索引字段(例如 agent 的默认模型) - default_model_config_id = Column(UUID(as_uuid=True), ForeignKey("model_configs.id"), nullable=True, index=True) - - # 发布信息 - published_by = Column(UUID(as_uuid=True), nullable=False, comment="users.id") - published_at = Column(DateTime, default=datetime.datetime.now) - - is_active = Column(Boolean, default=True, nullable=False) - created_at = Column(DateTime, default=datetime.datetime.now) - updated_at = Column(DateTime, default=datetime.datetime.now, onupdate=datetime.datetime.now) - - # 关系: 指定外键,避免与 App.current_release_id 引起歧义 - app = relationship("App", back_populates="releases", foreign_keys=[app_id]) - - # 发布人关系 - 使用 primaryjoin 明确指定关联条件 - publisher = relationship( - "User", - primaryjoin="AppRelease.published_by == User.id", - foreign_keys=[published_by], - lazy="joined", - viewonly=True # 只读关系,不会尝试更新 - ) - - @property - def publisher_name(self) -> str: - """发布人名称""" - if self.publisher: - return self.publisher.username or self.publisher.email or "未知用户" - return "未知用户" - - def __repr__(self): - return f"<AppRelease(id={self.id}, app_id={self.app_id}, version={self.version})>" \ No newline at end of file diff --git a/app/models/appshare_model.py b/app/models/appshare_model.py deleted file mode 100644 index 57ea59bc..00000000 --- a/app/models/appshare_model.py +++ /dev/null @@ -1,28 +0,0 @@ -import datetime -import uuid -from sqlalchemy import Column, DateTime, ForeignKey -from sqlalchemy.dialects.postgresql import UUID -from app.db import Base -from sqlalchemy.orm import relationship - - -class AppShare(Base): - """应用分享模型 - - 记录应用从一个工作空间分享到另一个工作空间的关系 - """ - __tablename__ = "app_shares" - - id = Column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4, index=True) - source_app_id = Column(UUID(as_uuid=True), ForeignKey('apps.id', ondelete='CASCADE'), nullable=False, comment="源应用ID") - source_workspace_id = Column(UUID(as_uuid=True), ForeignKey('workspaces.id'), nullable=False, comment="源工作空间ID") - target_workspace_id = Column(UUID(as_uuid=True), ForeignKey('workspaces.id'), nullable=False, comment="目标工作空间ID") - shared_by = Column(UUID(as_uuid=True), ForeignKey('users.id'), nullable=False, comment="分享者用户ID") - created_at = Column(DateTime, default=datetime.datetime.now) - updated_at = Column(DateTime, default=datetime.datetime.now) - - # Relationships - source_app = relationship("App", foreign_keys=[source_app_id], backref="shares") - source_workspace = relationship("Workspace", foreign_keys=[source_workspace_id]) - target_workspace = relationship("Workspace", foreign_keys=[target_workspace_id]) - shared_user = relationship("User", backref="app_shares") diff --git a/app/models/conversation_model.py b/app/models/conversation_model.py deleted file mode 100644 index e7f9e8c4..00000000 --- a/app/models/conversation_model.py +++ /dev/null @@ -1,80 +0,0 @@ -""" -会话和消息模型 -""" -import uuid -import datetime -from sqlalchemy import Column, String, DateTime, ForeignKey, Boolean, Integer, Text, JSON -from sqlalchemy.dialects.postgresql import UUID -from sqlalchemy.orm import relationship - -from app.db import Base - - -class Conversation(Base): - """会话表 - - 会话类型说明: - - 草稿会话 (is_draft=True): 使用应用的当前草稿配置,用于开发和测试 - - 发布会话 (is_draft=False): 使用应用的当前发布版本配置,用于生产环境 - - 工作空间隔离: - - 每个会话属于一个工作空间(workspace_id) - - 同一个应用在不同工作空间有独立的会话记录 - - 支持应用分享后的会话隔离 - """ - __tablename__ = "conversations" - - id = Column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4, index=True) - - # 关联信息 - app_id = Column(UUID(as_uuid=True), ForeignKey("apps.id"), nullable=False, comment="应用ID") - workspace_id = Column(UUID(as_uuid=True), ForeignKey("workspaces.id"), nullable=False, comment="工作空间ID") - user_id = Column(String, nullable=True, comment="用户ID(外部系统)") - - # 会话信息 - title = Column(String(255), comment="会话标题") - summary = Column(Text, comment="会话摘要") - - # 会话类型:True=草稿会话(使用草稿配置),False=发布会话(使用发布配置) - is_draft = Column(Boolean, default=True, nullable=False, comment="是否为草稿会话") - - # 配置快照:保存创建会话时的完整配置,用于审计和问题追溯 - config_snapshot = Column(JSON, comment="配置快照(Agent配置、模型配置等)") - - # 统计信息 - message_count = Column(Integer, default=0, comment="消息数量") - - # 状态 - is_active = Column(Boolean, default=True, nullable=False, comment="是否活跃") - - # 时间戳 - created_at = Column(DateTime, default=datetime.datetime.now, comment="创建时间") - updated_at = Column(DateTime, default=datetime.datetime.now, onupdate=datetime.datetime.now, comment="更新时间") - - # 关联关系 - app = relationship("App", back_populates="conversations") - workspace = relationship("Workspace") - messages = relationship("Message", back_populates="conversation", cascade="all, delete-orphan") - - -class Message(Base): - """消息表""" - __tablename__ = "messages" - - id = Column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4, index=True) - - # 关联信息 - conversation_id = Column(UUID(as_uuid=True), ForeignKey("conversations.id"), nullable=False, comment="会话ID") - - # 消息内容 - role = Column(String(20), nullable=False, comment="角色: user/assistant/system") - content = Column(Text, nullable=False, comment="消息内容") - - # 元数据(避免使用 metadata 保留字) - meta_data = Column(JSON, comment="消息元数据(如模型、token使用等)") - - # 时间戳 - created_at = Column(DateTime, default=datetime.datetime.now, comment="创建时间") - - # 关联关系 - conversation = relationship("Conversation", back_populates="messages") diff --git a/app/models/data_config_model.py b/app/models/data_config_model.py deleted file mode 100644 index 9f27562c..00000000 --- a/app/models/data_config_model.py +++ /dev/null @@ -1,71 +0,0 @@ -import datetime -import uuid -from sqlalchemy import Column, String, Boolean, DateTime, Integer, Float -from sqlalchemy.dialects.postgresql import UUID -from app.db import Base - - -class DataConfig(Base): - """数据配置表 - 用于存储记忆系统的配置参数""" - __tablename__ = "data_config" - - # 主键 - config_id = Column(Integer, primary_key=True, autoincrement=True, comment="配置ID") - - # 基本信息 - config_name = Column(String, nullable=False, comment="配置名称") - config_desc = Column(String, nullable=True, comment="配置描述") - - # 组织信息 - workspace_id = Column(UUID(as_uuid=True), nullable=True, comment="工作空间ID") - group_id = Column(String, nullable=True, comment="组ID") - user_id = Column(String, nullable=True, comment="用户ID") - apply_id = Column(String, nullable=True, comment="应用ID") - - # 模型选择(从workspace继承) - llm_id = Column(String, nullable=True, comment="LLM模型配置ID") - embedding_id = Column(String, nullable=True, comment="嵌入模型配置ID") - rerank_id = Column(String, nullable=True, comment="重排序模型配置ID") - llm = Column(String, nullable=True, comment="LLM模型配置ID") - - # 记忆萃取引擎配置 - enable_llm_dedup_blockwise = Column(Boolean, default=True, comment="启用LLM决策去重") - enable_llm_disambiguation = Column(Boolean, default=True, comment="启用LLM决策消歧") - deep_retrieval = Column(Boolean, default=True, comment="深度检索开关") - - # 阈值配置 (0-1 之间的浮点数) - t_type_strict = Column(Float, default=0.8, comment="类型严格阈值") - t_name_strict = Column(Float, default=0.8, comment="名称严格阈值") - t_overall = Column(Float, default=0.8, comment="综合阈值") - - # 状态配置 - state = Column(Boolean, default=False, comment="配置使用状态") - - # 分块策略 - chunker_strategy = Column(String, default="RecursiveChunker", comment="分块策略") - - # 剪枝配置 - pruning_enabled = Column(Boolean, default=False, comment="是否启动智能语义剪枝") - pruning_scene = Column(String, nullable=True, comment="智能剪枝场景:education/online_service/outbound") - pruning_threshold = Column(Float, nullable=True, comment="智能语义剪枝阈值(0-0.9)") - - # 自我反思配置 - enable_self_reflexion = Column(Boolean, default=False, comment="是否启用自我反思") - iteration_period = Column(String, default="3", comment="反思迭代周期") - reflexion_range = Column(String, default="retrieval", comment="反思范围:部分/全部") - baseline = Column(String, default="time", comment="基线:时间/事实/时间和事实") - - # 遗忘引擎配置 - statement_granularity = Column(Integer, default=2, comment="陈述提取颗粒度,挡位 1/2/3") - include_dialogue_context = Column(Boolean, default=False, comment="是否包含对话上下文") - max_context = Column(Integer, default=1000, comment="对话语境中包含字符的最大数量") - lambda_time = Column("lambda_time", Float, default=0.5, comment="最低保持度,0-1 小数") - lambda_mem = Column("lambda_mem", Float, default=0.5, comment="遗忘率,0-1 小数") - offset = Column("offset", Float, default=0.0, comment="偏移度,0-1 小数") - - # 时间戳 - created_at = Column(DateTime, default=datetime.datetime.now, comment="创建时间") - updated_at = Column(DateTime, default=datetime.datetime.now, onupdate=datetime.datetime.now, comment="更新时间") - - def __repr__(self): - return f"<DataConfig(config_id={self.config_id}, config_name={self.config_name})>" diff --git a/app/models/document_model.py b/app/models/document_model.py deleted file mode 100644 index 44012a56..00000000 --- a/app/models/document_model.py +++ /dev/null @@ -1,28 +0,0 @@ -import datetime -import uuid -from sqlalchemy import Column, Integer, String, JSON, DateTime, ForeignKey, Float -from sqlalchemy.dialects.postgresql import UUID -from app.db import Base - -class Document(Base): - __tablename__ = "documents" - - id = Column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4, index=True) - kb_id = Column(UUID(as_uuid=True), nullable=False, comment="knowledges.id") - created_by = Column(UUID(as_uuid=True), nullable=False, comment="users.id") - file_id = Column(UUID(as_uuid=True), nullable=False, comment="files.id") - file_name = Column(String, index=True, nullable=False, comment="file name") - file_ext = Column(String, index=True, nullable=False, comment="file extension") - file_size = Column(Integer, default=0, comment="file size(byte)") - file_meta = Column(JSON, nullable=False, default={}) - parser_id = Column(String, index=True, nullable=False, comment="default parser ID") - parser_config = Column(JSON, nullable=False, default={"layout_recognize": "DeepDOC", "chunk_token_num": 128, "delimiter": "\n"}, comment="default parser config") - chunk_num = Column(Integer, default=0, comment="chunk num") - progress = Column(Float, default=0) - progress_msg = Column(String, default="", comment="process message") - process_begin_at = Column(DateTime, default=datetime.datetime.now) - process_duration = Column(Float, default=0) - run = Column(Integer, default=0, comment="start to run processing or cancel.(1: run it; 2: cancel)") - status = Column(Integer, default=1, comment="is it validate(0: wasted, 1: validate)") - created_at = Column(DateTime, default=datetime.datetime.now) - updated_at = Column(DateTime, default=datetime.datetime.now) \ No newline at end of file diff --git a/app/models/end_user_model.py b/app/models/end_user_model.py deleted file mode 100644 index a2c02f84..00000000 --- a/app/models/end_user_model.py +++ /dev/null @@ -1,24 +0,0 @@ -import datetime -import uuid -from sqlalchemy import Column, String, DateTime, ForeignKey -from sqlalchemy.dialects.postgresql import UUID -from sqlalchemy.orm import relationship -from app.db import Base - -class EndUser(Base): - __tablename__ = "end_users" - - id = Column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4, nullable=False, index=True) - app_id = Column(UUID(as_uuid=True), ForeignKey("apps.id"), nullable=False) - # end_user_id = Column(String, nullable=False, index=True) - other_id = Column(String, nullable=True) # Store original user_id - other_name = Column(String, default="", nullable=False) - other_address = Column(String, default="", nullable=False) - created_at = Column(DateTime, default=datetime.datetime.now) - updated_at = Column(DateTime, default=datetime.datetime.now, onupdate=datetime.datetime.now) - - # 与 App 的反向关系 - app = relationship( - "App", - back_populates="end_users" - ) \ No newline at end of file diff --git a/app/models/file_model.py b/app/models/file_model.py deleted file mode 100644 index 842e3dc8..00000000 --- a/app/models/file_model.py +++ /dev/null @@ -1,17 +0,0 @@ -import datetime -import uuid -from sqlalchemy import Column, Integer, String, DateTime, ForeignKey -from sqlalchemy.dialects.postgresql import UUID -from app.db import Base - -class File(Base): - __tablename__ = "files" - - id = Column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4, index=True) - kb_id = Column(UUID(as_uuid=True), nullable=False, comment="knowledges.id") - created_by = Column(UUID(as_uuid=True), nullable=False, comment="users.id") - parent_id = Column(UUID(as_uuid=True), nullable=True, default=None, comment="parent folder id") - file_name = Column(String, index=True, nullable=False, comment="file name or folder name,default folder name is /") - file_ext = Column(String, index=True, nullable=False, comment="file extension:folder|pdf") - file_size = Column(Integer, default=0, comment="file size(byte)") - created_at = Column(DateTime, default=datetime.datetime.now) \ No newline at end of file diff --git a/app/models/generic_file_model.py b/app/models/generic_file_model.py deleted file mode 100644 index 5e3a08d7..00000000 --- a/app/models/generic_file_model.py +++ /dev/null @@ -1,52 +0,0 @@ -import datetime -import uuid -from sqlalchemy import Column, Integer, String, DateTime, Boolean, Index, JSON -from sqlalchemy.dialects.postgresql import UUID -from app.db import Base - - -class GenericFile(Base): - """ - 通用文件模型,支持多种上传上下文(头像、应用图标、知识库文件、临时文件等) - """ - __tablename__ = "generic_files" - - # 主键和租户信息 - id = Column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4, index=True, comment="文件唯一标识") - tenant_id = Column(UUID(as_uuid=True), nullable=False, index=True, comment="租户ID") - created_by = Column(UUID(as_uuid=True), nullable=False, index=True, comment="创建者用户ID") - - # 文件基本信息 - file_name = Column(String, nullable=False, comment="原始文件名") - file_ext = Column(String, nullable=False, index=True, comment="文件扩展名") - file_size = Column(Integer, nullable=False, comment="文件大小(字节)") - mime_type = Column(String, nullable=True, comment="MIME类型") - - # 上传上下文 - context = Column(String, nullable=False, index=True, comment="上传上下文(avatar/app_icon/knowledge_base/temp/attachment)") - - # 存储信息 - storage_path = Column(String, nullable=False, comment="文件存储路径") - - # 元数据(JSON格式,存储业务相关信息) - file_metadata = Column(JSON, nullable=True, default={}, comment="业务元数据") - - # 状态和访问控制 - status = Column(String, default="active", index=True, comment="文件状态(active/processing/deleted)") - is_public = Column(Boolean, default=False, comment="是否公开访问") - access_url = Column(String, nullable=True, comment="访问URL") - - # 引用计数(用于判断文件是否可以删除) - reference_count = Column(Integer, default=0, comment="引用计数") - - # 时间戳 - created_at = Column(DateTime, default=datetime.datetime.now, comment="创建时间") - updated_at = Column(DateTime, default=datetime.datetime.now, onupdate=datetime.datetime.now, comment="更新时间") - deleted_at = Column(DateTime, nullable=True, comment="删除时间(软删除)") - - # 复合索引 - __table_args__ = ( - Index('idx_tenant_context', 'tenant_id', 'context'), - Index('idx_tenant_status', 'tenant_id', 'status'), - Index('idx_created_at', 'created_at'), - ) diff --git a/app/models/knowledge_model.py b/app/models/knowledge_model.py deleted file mode 100644 index bdb97678..00000000 --- a/app/models/knowledge_model.py +++ /dev/null @@ -1,69 +0,0 @@ -import datetime -import uuid -import enum -from sqlalchemy import Column, Integer, String, JSON, DateTime, ForeignKey -from sqlalchemy.dialects.postgresql import UUID -from app.db import Base -from sqlalchemy.orm import relationship - - -class KnowledgeType(enum.StrEnum): - General = "General" - Web = "Web" - ThirdParty = "Third-party" - FOLDER = "Folder" - - -class ParserType(enum.StrEnum): - NAIVE = "naive" - QA = "qa" - MANUAL = "manual" - TABLE = "table" - PRESENTATION = "presentation" - LAWS = "laws" - PAPER = "paper" - RESUME = "resume" - BOOK = "book" - ONE = "one" - AUDIO = "audio" - EMAIL = "email" - TAG = "tag" - KG = "knowledge_graph" - - -class PermissionType(enum.StrEnum): - Private = "Private" - Share = "Share" - -class Knowledge(Base): - __tablename__ = "knowledges" - - id = Column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4, index=True) - workspace_id = Column(UUID(as_uuid=True), nullable=False, comment="workspaces.id") - created_by = Column(UUID(as_uuid=True), ForeignKey('users.id'), nullable=False, comment="users.id") - parent_id = Column(UUID(as_uuid=True), nullable=True, default=None, comment="parent folder id when type is Folder") - name = Column(String, index=True, nullable=False, comment="KB name") - description = Column(String, comment="KB description") - avatar = Column(String, comment="avatar url") - type = Column(String, default="General", comment="Type:General|Web|Third-party|Folder") - permission_id = Column(String, default="Private", comment="permission ID:Private|Share") - embedding_id = Column(UUID(as_uuid=True), ForeignKey('model_configs.id', ondelete="SET NULL"), nullable=True, comment="default embedding model ID") - reranker_id = Column(UUID(as_uuid=True), ForeignKey('model_configs.id', ondelete="SET NULL"), nullable=True, comment="default reranker model ID") - llm_id = Column(UUID(as_uuid=True), ForeignKey('model_configs.id', ondelete="SET NULL"), nullable=True, comment="default llm model ID") - image2text_id = Column(UUID(as_uuid=True), ForeignKey('model_configs.id', ondelete="SET NULL"), nullable=True, comment="default image2text model ID") - doc_num = Column(Integer, default=0, comment="doc num") - chunk_num = Column(Integer, default=0, comment="chunk num") - parser_id = Column(String, index=True, default="naive", comment="default parser ID") - parser_config = Column(JSON, nullable=False, - default={"layout_recognize": "DeepDOC", "chunk_token_num": 128, "delimiter": "\n"}, - comment="default parser config") - status = Column(Integer, index=True, default=1, comment="is it validate(0: disable, 1: enable, 2:Soft-delete)") - created_at = Column(DateTime, default=datetime.datetime.now) - updated_at = Column(DateTime, default=datetime.datetime.now) - - # Relationships - created_user = relationship("User", backref="created_user") - embedding = relationship("ModelConfig", foreign_keys=[embedding_id], uselist=False, backref="embedding") - reranker = relationship("ModelConfig", foreign_keys=[reranker_id], uselist=False, backref="reranker") - llm = relationship("ModelConfig", foreign_keys=[llm_id], uselist=False, backref="llm") - image2text = relationship("ModelConfig", foreign_keys=[image2text_id], uselist=False, backref="image2text") diff --git a/app/models/knowledgeshare_model.py b/app/models/knowledgeshare_model.py deleted file mode 100644 index d285cb37..00000000 --- a/app/models/knowledgeshare_model.py +++ /dev/null @@ -1,24 +0,0 @@ -import datetime -import uuid -from sqlalchemy import Column, Integer, String, JSON, DateTime, ForeignKey -from sqlalchemy.dialects.postgresql import UUID -from app.db import Base -from sqlalchemy.orm import relationship - - -class KnowledgeShare(Base): - __tablename__ = "knowledge_shares" - - id = Column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4, index=True) - source_kb_id = Column(UUID(as_uuid=True), nullable=False, comment="source knowledges.id") - source_workspace_id = Column(UUID(as_uuid=True), nullable=False, comment="source workspaces.id") - target_kb_id = Column(UUID(as_uuid=True), ForeignKey('knowledges.id'), nullable=False, comment="target knowledges.id") - target_workspace_id = Column(UUID(as_uuid=True), ForeignKey('workspaces.id'), nullable=False, comment="target workspaces.id") - shared_by = Column(UUID(as_uuid=True), ForeignKey('users.id'), nullable=False, comment="shared users.id") - created_at = Column(DateTime, default=datetime.datetime.now) - updated_at = Column(DateTime, default=datetime.datetime.now) - - # Relationships - target_kb = relationship("Knowledge", backref="target_kb") - target_workspace = relationship("Workspace", backref="target_workspace") - shared_user = relationship("User", backref="shared_user") diff --git a/app/models/memory_increment_model.py b/app/models/memory_increment_model.py deleted file mode 100644 index 53ba069d..00000000 --- a/app/models/memory_increment_model.py +++ /dev/null @@ -1,18 +0,0 @@ -import uuid -import datetime -from sqlalchemy import Column, ForeignKey, Integer, Date, DateTime -from sqlalchemy.dialects.postgresql import UUID -from sqlalchemy.orm import relationship -from app.db import Base - -class MemoryIncrement(Base): - __tablename__ = "memory_increments" - - id = Column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4, index=True) - workspace_id = Column(UUID(as_uuid=True), ForeignKey("workspaces.id"), index=True, nullable=False) - total_num = Column(Integer, default=0, nullable=False) - created_at = Column(DateTime, default=datetime.datetime.now) - updated_at = Column(DateTime, default=datetime.datetime.now, onupdate=datetime.datetime.now) - - # 与 App 的关系(指向映射类名,而非表名) - workspace = relationship("Workspace", back_populates="memory_increments") diff --git a/app/models/models_model.py b/app/models/models_model.py deleted file mode 100644 index e5215018..00000000 --- a/app/models/models_model.py +++ /dev/null @@ -1,104 +0,0 @@ -import datetime -import uuid -from enum import StrEnum -from typing import Optional, List -from sqlalchemy import Column, String, Boolean, DateTime, Text, ForeignKey, Enum as SQLEnum -from sqlalchemy.dialects.postgresql import UUID, JSON -from sqlalchemy.orm import relationship -from app.db import Base - - -class ModelType(StrEnum): - """模型类型枚举""" - LLM = "llm" - CHAT = "chat" - EMBEDDING = "embedding" - RERANK = "rerank" - - -class ModelProvider(StrEnum): - """模型提供商枚举""" - OPENAI = "openai" - # ANTHROPIC = "anthropic" - # GOOGLE = "google" - # BAIDU = "baidu" - DASHSCOPE = "dashscope" - # ZHIPU = "zhipu" - # MOONSHOT = "moonshot" - # DEEPSEEK = "deepseek" - OLLAMA = "ollama" - XINFERENCE = "xinference" - GPUSTACK = "gpustack" - BEDROCK = "bedrock" - - -class ModelConfig(Base): - """模型配置表""" - __tablename__ = "model_configs" - - id = Column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4, index=True) - name = Column(String, nullable=False, comment="模型显示名称") - type = Column(String, nullable=False, index=True, comment="模型类型") - description = Column(String, comment="模型描述") - - # 模型配置参数 - config = Column(JSON, comment="模型配置参数") - # - temperature : 控制生成文本的随机性。值越高,输出越随机、越有创造性;值越低,输出越确定、越保守。 - # - top_p : 一种替代 temperature 的采样方法,控制模型从概率最高的词中选择的范围。 - # - presence_penalty : 对新出现的主题进行惩罚,鼓励模型谈论已经提到过的话题。 - # - frequency_penalty : 对高频词进行惩罚,降低重复相同词语的可能性。 - # - stop 或 stop_sequences : 一个或多个字符串序列,当模型生成这些序列时会停止输出。 - # - 特定于提供商的参数 : 比如某些模型可能支持的 stream (流式输出) 开关、 seed (随机种子) 等。 - - # # 模型能力参数 - # max_tokens = Column(String, comment="最大token数") - # context_length = Column(String, comment="上下文长度") - - # 状态管理 - is_active = Column(Boolean, default=True, nullable=False, comment="是否激活") - is_public = Column(Boolean, default=False, nullable=False, comment="是否公开") - - # 时间戳 - created_at = Column(DateTime, default=datetime.datetime.now, comment="创建时间") - updated_at = Column(DateTime, default=datetime.datetime.now, onupdate=datetime.datetime.now, comment="更新时间") - - # 关联关系 - api_keys = relationship("ModelApiKey", back_populates="model_config", cascade="all, delete-orphan") - - def __repr__(self): - return f"<ModelConfig(id={self.id}, name={self.name}, type={self.type})>" - - -class ModelApiKey(Base): - """模型API密钥表""" - __tablename__ = "model_api_keys" - - id = Column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4, index=True) - model_config_id = Column(UUID(as_uuid=True), ForeignKey("model_configs.id"), nullable=False, comment="模型配置ID") - - # API Key 信息 - model_name = Column(String, nullable=False, comment="模型实际名称") - provider = Column(String, nullable=False, comment="API Key提供商") - api_key = Column(String, nullable=False, comment="API密钥") - api_base = Column(String, comment="API基础URL") - - # 配置参数 - config = Column(JSON, comment="API Key特定配置") - - # 使用统计 - usage_count = Column(String, default="0", comment="使用次数") - last_used_at = Column(DateTime, comment="最后使用时间") - - # 状态管理 - is_active = Column(Boolean, default=True, nullable=False, comment="是否激活") - priority = Column(String, default="1", comment="优先级") - - # 时间戳 - created_at = Column(DateTime, default=datetime.datetime.now, comment="创建时间") - updated_at = Column(DateTime, default=datetime.datetime.now, onupdate=datetime.datetime.now, comment="更新时间") - - # 关联关系 - model_config = relationship("ModelConfig", back_populates="api_keys") - - def __repr__(self): - return f"<ModelApiKey(id={self.id}, model_name={self.model_name}, provider={self.provider}, model_config_id={self.model_config_id})>" diff --git a/app/models/multi_agent_model.py b/app/models/multi_agent_model.py deleted file mode 100644 index 061ecffa..00000000 --- a/app/models/multi_agent_model.py +++ /dev/null @@ -1,143 +0,0 @@ -"""多 Agent 相关数据模型""" -import datetime -import uuid -from sqlalchemy import Column, String, Boolean, DateTime, Integer, Float, Text, ForeignKey -from sqlalchemy.dialects.postgresql import UUID, JSON -from sqlalchemy.orm import relationship - -from app.db import Base - - -class MultiAgentConfig(Base): - """多 Agent 配置表""" - __tablename__ = "multi_agent_configs" - - id = Column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4, index=True) - - # 关联应用 - app_id = Column(UUID(as_uuid=True), ForeignKey("apps.id"), nullable=False, unique=True, index=True, comment="关联应用") - - # 主 Agent (存储发布版本 ID) - master_agent_id = Column(UUID(as_uuid=True), ForeignKey("app_releases.id"), nullable=False, comment="主 Agent 发布版本 ID") - master_agent_name = Column(String(100), comment="主 Agent 名称") - - # 协作模式 - orchestration_mode = Column( - String(20), - nullable=False, - default="conditional", - comment="协作模式: sequential|parallel|conditional|loop" - ) - - # 子 Agent 列表 - sub_agents = Column( - JSON, - nullable=False, - default=list, - comment="子 Agent 列表: [{'agent_id': 'uuid', 'name': '...', 'role': '...', 'priority': 1}]" - ) - - # 路由规则 - routing_rules = Column( - JSON, - comment="路由规则: [{'condition': '...', 'target_agent_id': 'uuid', 'priority': 1}]" - ) - - # 执行配置 - execution_config = Column( - JSON, - nullable=False, - default=dict, - comment="执行配置: {'max_iterations': 5, 'timeout': 60, 'parallel_limit': 3}" - ) - - # 结果整合策略 - aggregation_strategy = Column( - String(20), - nullable=False, - default="merge", - comment="结果整合策略: merge|vote|priority|custom" - ) - - # 状态 - is_active = Column(Boolean, default=True, nullable=False) - created_at = Column(DateTime, default=datetime.datetime.now) - updated_at = Column(DateTime, default=datetime.datetime.now, onupdate=datetime.datetime.now) - - # 关系 - app = relationship("App") - master_agent_release = relationship("AppRelease", foreign_keys=[master_agent_id]) - - def __repr__(self): - return f"<MultiAgentConfig(id={self.id}, app_id={self.app_id}, mode={self.orchestration_mode})>" - - -class AgentInvocation(Base): - """Agent 调用记录表""" - __tablename__ = "agent_invocations" - - id = Column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4, index=True) - - # 调用关系 - caller_agent_id = Column( - UUID(as_uuid=True), - ForeignKey("agent_configs.id"), - nullable=False, - index=True, - comment="调用者 Agent ID" - ) - callee_agent_id = Column( - UUID(as_uuid=True), - ForeignKey("agent_configs.id"), - nullable=False, - index=True, - comment="被调用者 Agent ID" - ) - - # 关联信息 - conversation_id = Column( - UUID(as_uuid=True), - index=True, - comment="关联会话 ID(不使用外键约束,避免循环依赖)" - ) - parent_invocation_id = Column( - UUID(as_uuid=True), - ForeignKey("agent_invocations.id"), - index=True, - comment="父调用 ID(用于追踪调用链)" - ) - - # 输入输出 - input_message = Column(Text, nullable=False, comment="输入消息") - output_message = Column(Text, comment="输出消息") - context = Column(JSON, comment="上下文信息") - - # 状态 - status = Column( - String(20), - nullable=False, - default="pending", - index=True, - comment="状态: pending|running|completed|failed" - ) - error_message = Column(Text, comment="错误信息") - - # 性能指标 - started_at = Column(DateTime, nullable=False, default=datetime.datetime.now, index=True) - completed_at = Column(DateTime) - elapsed_time = Column(Float, comment="耗时(秒)") - token_usage = Column(JSON, comment="Token 使用情况") - - # 元数据 - meta_data = Column(JSON, comment="额外元数据") - - created_at = Column(DateTime, default=datetime.datetime.now) - - # 关系 - caller = relationship("AgentConfig", foreign_keys=[caller_agent_id]) - callee = relationship("AgentConfig", foreign_keys=[callee_agent_id]) - # conversation 不使用 relationship,避免外键约束问题 - parent_invocation = relationship("AgentInvocation", remote_side=[id], backref="child_invocations") - - def __repr__(self): - return f"<AgentInvocation(id={self.id}, caller={self.caller_agent_id}, callee={self.callee_agent_id}, status={self.status})>" diff --git a/app/models/release_share_model.py b/app/models/release_share_model.py deleted file mode 100644 index 13b41e1c..00000000 --- a/app/models/release_share_model.py +++ /dev/null @@ -1,47 +0,0 @@ -import datetime -import uuid -from sqlalchemy import Column, String, Boolean, DateTime, Integer, ForeignKey, UniqueConstraint -from sqlalchemy.dialects.postgresql import UUID, JSON -from sqlalchemy.orm import relationship -from app.db import Base - - -class ReleaseShare(Base): - """应用发布版本分享配置""" - __tablename__ = "release_shares" - __table_args__ = ( - UniqueConstraint("release_id", name="uq_release_share_release_id"), - ) - - id = Column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4, index=True) - release_id = Column(UUID(as_uuid=True), ForeignKey("app_releases.id", ondelete="CASCADE"), nullable=False, unique=True, index=True) - app_id = Column(UUID(as_uuid=True), ForeignKey("apps.id", ondelete="CASCADE"), nullable=False, index=True) - - # 分享配置 - is_enabled = Column(Boolean, default=True, nullable=False, comment="是否启用公开分享") - share_token = Column(String, nullable=False, unique=True, index=True, comment="公开访问的唯一标识") - - # 访问控制 - require_password = Column(Boolean, default=False, nullable=False, comment="是否需要密码访问") - password_hash = Column(String, nullable=True, comment="访问密码哈希") - - # 嵌入配置 - allow_embed = Column(Boolean, default=False, nullable=False, comment="是否允许嵌入") - embed_domains = Column(JSON, default=list, comment="允许嵌入的域名白名单") - - # 统计数据 - view_count = Column(Integer, default=0, nullable=False, comment="访问次数") - last_accessed_at = Column(DateTime, nullable=True, comment="最后访问时间") - - # 元数据 - created_by = Column(UUID(as_uuid=True), ForeignKey("users.id"), nullable=False, comment="创建者") - created_at = Column(DateTime, default=datetime.datetime.now) - updated_at = Column(DateTime, default=datetime.datetime.now, onupdate=datetime.datetime.now) - - # 关系 - release = relationship("AppRelease", backref="share") - app = relationship("App") - creator = relationship("User") - - def __repr__(self): - return f"<ReleaseShare(id={self.id}, release_id={self.release_id}, share_token={self.share_token})>" diff --git a/app/models/retrieval_info.py b/app/models/retrieval_info.py deleted file mode 100644 index 335f27db..00000000 --- a/app/models/retrieval_info.py +++ /dev/null @@ -1,13 +0,0 @@ -import datetime -import uuid -from sqlalchemy import Column, DateTime, Text -from sqlalchemy.dialects.postgresql import UUID -from app.db import Base - -class RetrievalInfo(Base): - __tablename__ = "retrieval_info" - - id = Column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4, nullable=False, index=True) - host_id = Column(UUID(as_uuid=True), nullable=False) - retrieve_info = Column(Text, default="", nullable=True) - created_at = Column(DateTime, default=datetime.datetime.now) diff --git a/app/models/tenant_model.py b/app/models/tenant_model.py deleted file mode 100644 index fd3d9a31..00000000 --- a/app/models/tenant_model.py +++ /dev/null @@ -1,23 +0,0 @@ -import datetime -import uuid -from sqlalchemy import Column, String, DateTime, Boolean -from sqlalchemy.dialects.postgresql import UUID -from sqlalchemy.orm import relationship -from app.db import Base - - -class Tenants(Base): - __tablename__ = "tenants" - - id = Column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4, index=True) - name = Column(String, index=True, nullable=False) - description = Column(String, nullable=True) - created_at = Column(DateTime, default=datetime.datetime.now) - updated_at = Column(DateTime, default=datetime.datetime.now, onupdate=datetime.datetime.now) - is_active = Column(Boolean, default=True) - - # Relationship to users - one tenant has many users - users = relationship("User", back_populates="tenant") - - # Relationship to workspaces owned by the tenant - owned_workspaces = relationship("Workspace", back_populates="tenant") diff --git a/app/models/user_model.py b/app/models/user_model.py deleted file mode 100644 index 89971a3a..00000000 --- a/app/models/user_model.py +++ /dev/null @@ -1,30 +0,0 @@ -import datetime -import uuid -from sqlalchemy import Column, String, Boolean, DateTime, ForeignKey -from sqlalchemy.dialects.postgresql import UUID -from sqlalchemy.orm import relationship -from app.db import Base - -class User(Base): - __tablename__ = "users" - - id = Column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4, index=True) - username = Column(String, unique=True, index=True, nullable=False) - email = Column(String, unique=True, index=True, nullable=False) - hashed_password = Column(String, nullable=False) - is_active = Column(Boolean, default=True, nullable=False) - is_superuser = Column(Boolean, default=False, nullable=False) - created_at = Column(DateTime, default=datetime.datetime.now) - updated_at = Column(DateTime, default=datetime.datetime.now, onupdate=datetime.datetime.now) - last_login_at = Column(DateTime, nullable=True) # 最后登录时间,可为空 - - current_workspace_id = Column(UUID(as_uuid=True), ForeignKey("workspaces.id"), nullable=True) # 当前工作空间ID,可为空 - - # Foreign key to tenant - each user belongs to exactly one tenant - tenant_id = Column(UUID(as_uuid=True), ForeignKey("tenants.id"), nullable=False) - - # Relationship to workspace memberships - users collaborate in workspaces through membership - workspaces = relationship("WorkspaceMember", back_populates="user") - - # Relationship to tenant - one-to-one relationship - tenant = relationship("Tenants", back_populates="users") diff --git a/app/models/workspace_model.py b/app/models/workspace_model.py deleted file mode 100644 index abb5adeb..00000000 --- a/app/models/workspace_model.py +++ /dev/null @@ -1,70 +0,0 @@ -import datetime -from enum import StrEnum -import uuid -from sqlalchemy import Column, Integer, String, DateTime, ForeignKey, Boolean -from sqlalchemy.dialects.postgresql import UUID -from sqlalchemy.orm import relationship -from app.db import Base - -class WorkspaceRole(StrEnum): - manager = "manager" - member = "member" - -class InviteStatus(StrEnum): - pending = "pending" - accepted = "accepted" - revoked = "revoked" - expired = "expired" - -class Workspace(Base): - __tablename__ = "workspaces" - - id = Column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4, index=True) - name = Column(String, index=True, nullable=False) - icon = Column(String, nullable=True) - iconType = Column(String, nullable=True) - description = Column(String, nullable=True) - tenant_id = Column(UUID(as_uuid=True), ForeignKey("tenants.id"), nullable=False) # belongs to tenant - storage_type = Column(String, nullable=True) - llm = Column(String, nullable=True) - embedding = Column(String, nullable=True) - rerank = Column(String, nullable=True) - created_at = Column(DateTime, default=datetime.datetime.now) - updated_at = Column(DateTime, default=datetime.datetime.now, onupdate=datetime.datetime.now) - is_active = Column(Boolean, default=True) - - # Relationships - tenant = relationship("Tenants", back_populates="owned_workspaces") # belongs to tenant - members = relationship("WorkspaceMember", back_populates="workspace") # users collaborate through membership - api_keys = relationship("ApiKey", back_populates="workspace", cascade="all, delete-orphan") # API Keys - memory_increments = relationship("MemoryIncrement", back_populates="workspace") - -class WorkspaceMember(Base): - __tablename__ = "workspace_members" - - id = Column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4, index=True) - user_id = Column(UUID(as_uuid=True), ForeignKey("users.id"), nullable=False) - workspace_id = Column(UUID(as_uuid=True), ForeignKey("workspaces.id"), nullable=False) - role = Column(String, nullable=False) - is_active = Column(Boolean, default=True) - user = relationship("User", back_populates="workspaces") - workspace = relationship("Workspace", back_populates="members") - -class WorkspaceInvite(Base): - __tablename__ = "workspace_invites" - - id = Column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4, index=True) - workspace_id = Column(UUID(as_uuid=True), ForeignKey("workspaces.id"), nullable=False) - email = Column(String, nullable=False, index=True) - role = Column(String, nullable=False) # WorkspaceRole: manager or member - token_hash = Column(String, nullable=False, unique=True, index=True) - status = Column(String, nullable=False, default=InviteStatus.pending) # InviteStatus - expires_at = Column(DateTime, nullable=False) - accepted_at = Column(DateTime, nullable=True) - created_by_user_id = Column(UUID(as_uuid=True), ForeignKey("users.id"), nullable=False) - created_at = Column(DateTime, default=datetime.datetime.now) - updated_at = Column(DateTime, default=datetime.datetime.now, onupdate=datetime.datetime.now) - - # Relationships - workspace = relationship("Workspace") - created_by = relationship("User", foreign_keys=[created_by_user_id]) diff --git a/app/repositories/__init__.py b/app/repositories/__init__.py deleted file mode 100644 index bb509219..00000000 --- a/app/repositories/__init__.py +++ /dev/null @@ -1,171 +0,0 @@ -# -*- coding: utf-8 -*- -"""仓储模块 - -本模块提供统一的数据访问层,包括PostgreSQL和Neo4j的仓储实现。 - -Classes: - RepositoryFactory: 仓储工厂,统一管理所有数据库的仓储实例 -""" - -from typing import Optional -from sqlalchemy.orm import Session - -from app.repositories.neo4j.neo4j_connector import Neo4jConnector -from app.repositories.neo4j.dialog_repository import DialogRepository -from app.repositories.neo4j.statement_repository import StatementRepository -from app.repositories.neo4j.entity_repository import EntityRepository -from app.repositories.user_repository import UserRepository -from app.repositories.workspace_repository import WorkspaceRepository -from app.repositories.app_repository import AppRepository - - -class RepositoryFactory: - """仓储工厂 - 统一管理所有数据库的仓储 - - 这个工厂类提供了获取各种仓储实例的统一接口。 - 支持Neo4j图数据库和PostgreSQL关系数据库的仓储。 - - Attributes: - neo4j_connector: Neo4j连接器实例(可选) - db_session: SQLAlchemy数据库会话(可选) - - Example: - >>> # 创建工厂实例 - >>> factory = RepositoryFactory( - ... neo4j_connector=Neo4jConnector(), - ... db_session=db_session - ... ) - >>> - >>> # 获取Neo4j仓储 - >>> dialog_repo = factory.get_dialog_repository() - >>> statement_repo = factory.get_statement_repository() - >>> - >>> # 获取PostgreSQL仓储 - >>> knowledge_repo = factory.get_knowledge_repository() - """ - - def __init__( - self, - neo4j_connector: Optional[Neo4jConnector] = None, - db_session: Optional[Session] = None - ): - """初始化仓储工厂 - - Args: - neo4j_connector: Neo4j连接器实例(可选) - db_session: SQLAlchemy数据库会话(可选) - """ - self.neo4j_connector = neo4j_connector - self.db_session = db_session - - # ==================== Neo4j 仓储 ==================== - - def get_dialog_repository(self) -> DialogRepository: - """获取对话仓储 - - Returns: - DialogRepository: 对话仓储实例 - - Raises: - ValueError: 如果Neo4j连接器未初始化 - """ - if not self.neo4j_connector: - raise ValueError("Neo4j connector not initialized") - return DialogRepository(self.neo4j_connector) - - def get_statement_repository(self) -> StatementRepository: - """获取陈述句仓储 - - Returns: - StatementRepository: 陈述句仓储实例 - - Raises: - ValueError: 如果Neo4j连接器未初始化 - """ - if not self.neo4j_connector: - raise ValueError("Neo4j connector not initialized") - return StatementRepository(self.neo4j_connector) - - def get_entity_repository(self) -> EntityRepository: - """获取实体仓储 - - Returns: - EntityRepository: 实体仓储实例 - - Raises: - ValueError: 如果Neo4j连接器未初始化 - """ - if not self.neo4j_connector: - raise ValueError("Neo4j connector not initialized") - return EntityRepository(self.neo4j_connector) - - # ==================== PostgreSQL 仓储 ==================== - # 注意:现有的PostgreSQL仓储保持不变,这里只是提供统一的访问接口 - # 部分仓储(如knowledge_repository、document_repository)使用函数式接口 - # 部分仓储(如user_repository、workspace_repository)使用类接口 - - def get_user_repository(self) -> UserRepository: - """获取用户仓储 - - Returns: - UserRepository: 用户仓储实例 - - Raises: - ValueError: 如果数据库会话未初始化 - """ - if not self.db_session: - raise ValueError("Database session not initialized") - return UserRepository(self.db_session) - - def get_workspace_repository(self) -> WorkspaceRepository: - """获取工作空间仓储 - - Returns: - WorkspaceRepository: 工作空间仓储实例 - - Raises: - ValueError: 如果数据库会话未初始化 - """ - if not self.db_session: - raise ValueError("Database session not initialized") - return WorkspaceRepository(self.db_session) - - def get_app_repository(self) -> AppRepository: - """获取应用仓储 - - Returns: - AppRepository: 应用仓储实例 - - Raises: - ValueError: 如果数据库会话未初始化 - """ - if not self.db_session: - raise ValueError("Database session not initialized") - return AppRepository(self.db_session) - - def get_db_session(self) -> Session: - """获取数据库会话 - - 用于访问函数式仓储(如knowledge_repository、document_repository) - - Returns: - Session: SQLAlchemy数据库会话 - - Raises: - ValueError: 如果数据库会话未初始化 - - Example: - >>> factory = RepositoryFactory(db_session=session) - >>> db = factory.get_db_session() - >>> # 使用函数式仓储 - >>> from app.repositories import knowledge_repository - >>> knowledges = knowledge_repository.get_knowledges_paginated(db, [], 1, 10) - """ - if not self.db_session: - raise ValueError("Database session not initialized") - return self.db_session - - -__all__ = [ - 'RepositoryFactory', -] diff --git a/app/repositories/api_key_repository.py b/app/repositories/api_key_repository.py deleted file mode 100644 index ceeb99cd..00000000 --- a/app/repositories/api_key_repository.py +++ /dev/null @@ -1,138 +0,0 @@ -"""API Key Repository""" -from sqlalchemy.orm import Session -from sqlalchemy import select, func, and_ -from typing import Optional, List, Tuple -import uuid -import datetime - -from app.models.api_key_model import ApiKey, ApiKeyLog -from app.schemas import api_key_schema - - -class ApiKeyRepository: - """API Key 数据访问层""" - - @staticmethod - def create(db: Session, api_key_data: dict) -> ApiKey: - """创建 API Key""" - api_key = ApiKey(**api_key_data) - db.add(api_key) - db.flush() - return api_key - - @staticmethod - def get_by_id(db: Session, api_key_id: uuid.UUID) -> Optional[ApiKey]: - """根据 ID 获取 API Key""" - return db.get(ApiKey, api_key_id) - - @staticmethod - def get_by_hash(db: Session, key_hash: str) -> Optional[ApiKey]: - """根据哈希值获取 API Key""" - stmt = select(ApiKey).where(ApiKey.key_hash == key_hash) - return db.scalars(stmt).first() - - @staticmethod - def list_by_workspace( - db: Session, - workspace_id: uuid.UUID, - query: api_key_schema.ApiKeyQuery - ) -> Tuple[List[ApiKey], int]: - """列出工作空间的 API Keys""" - stmt = select(ApiKey).where(ApiKey.workspace_id == workspace_id) - - # 过滤条件 - if query.type: - stmt = stmt.where(ApiKey.type == query.type) - if query.is_active is not None: - stmt = stmt.where(ApiKey.is_active == query.is_active) - if query.resource_id: - stmt = stmt.where(ApiKey.resource_id == query.resource_id) - - # 总数 - count_stmt = select(func.count()).select_from(stmt.subquery()) - total = db.execute(count_stmt).scalar() - - # 分页 - stmt = stmt.order_by(ApiKey.created_at.desc()) - stmt = stmt.offset((query.page - 1) * query.pagesize).limit(query.pagesize) - - items = db.scalars(stmt).all() - return list(items), total - - @staticmethod - def update(db: Session, api_key_id: uuid.UUID, update_data: dict) -> ApiKey: - """更新 API Key""" - api_key = db.get(ApiKey, api_key_id) - if api_key: - for key, value in update_data.items(): - if value is not None: - setattr(api_key, key, value) - api_key.updated_at = datetime.datetime.now() - db.flush() - return api_key - - @staticmethod - def delete(db: Session, api_key_id: uuid.UUID) -> bool: - """删除 API Key""" - api_key = db.get(ApiKey, api_key_id) - if api_key: - db.delete(api_key) - db.flush() - return True - return False - - @staticmethod - def update_usage(db: Session, api_key_id: uuid.UUID) -> bool: - """更新使用统计""" - api_key = db.get(ApiKey, api_key_id) - if api_key: - api_key.usage_count += 1 - api_key.quota_used += 1 - api_key.last_used_at = datetime.datetime.now() - db.flush() - return True - return False - - @staticmethod - def get_stats(db: Session, api_key_id: uuid.UUID) -> dict: - """获取使用统计""" - api_key = db.get(ApiKey, api_key_id) - if not api_key: - return {} - - # 今日请求数 - today_start = datetime.datetime.now().replace(hour=0, minute=0, second=0, microsecond=0) - today_count_stmt = select(func.count()).select_from(ApiKeyLog).where( - and_( - ApiKeyLog.api_key_id == api_key_id, - ApiKeyLog.created_at >= today_start - ) - ) - requests_today = db.execute(today_count_stmt).scalar() or 0 - - # 平均响应时间 - avg_time_stmt = select(func.avg(ApiKeyLog.response_time)).where( - ApiKeyLog.api_key_id == api_key_id - ) - avg_response_time = db.execute(avg_time_stmt).scalar() - - return { - "total_requests": api_key.usage_count, - "requests_today": requests_today, - "quota_used": api_key.quota_used, - "quota_limit": api_key.quota_limit, - "last_used_at": api_key.last_used_at, - "avg_response_time": float(avg_response_time) if avg_response_time else None - } - - -class ApiKeyLogRepository: - """API Key 日志数据访问层""" - - @staticmethod - def create(db: Session, log_data: dict) -> ApiKeyLog: - """创建日志""" - log = ApiKeyLog(**log_data) - db.add(log) - db.flush() - return log diff --git a/app/repositories/app_repository.py b/app/repositories/app_repository.py deleted file mode 100644 index 5630238d..00000000 --- a/app/repositories/app_repository.py +++ /dev/null @@ -1,30 +0,0 @@ -from sqlalchemy.orm import Session -from typing import List, Optional -import uuid - -from app.models.app_model import App - -from app.core.logging_config import get_db_logger - -# 获取数据库专用日志器 -db_logger = get_db_logger() - - -class AppRepository: - def __init__(self, db: Session): - self.db = db - - def get_apps_by_workspace_id(self, workspace_id: uuid.UUID) -> List[App]: - """根据工作空间ID查询应用""" - try: - apps = self.db.query(App).filter(App.workspace_id == workspace_id).all() - db_logger.info(f"成功查询工作空间 {workspace_id} 下的 {len(apps)} 个应用") - return apps - except Exception as e: - db_logger.error(f"查询工作空间 {workspace_id} 下应用时出错: {str(e)}") - raise - -def get_apps_by_workspace_id(db: Session, workspace_id: uuid.UUID) -> List[App]: - """根据工作空间ID查询应用""" - repo = AppRepository(db) - return repo.get_apps_by_workspace_id(workspace_id) diff --git a/app/repositories/base_repository.py b/app/repositories/base_repository.py deleted file mode 100644 index a62404ec..00000000 --- a/app/repositories/base_repository.py +++ /dev/null @@ -1,108 +0,0 @@ -# -*- coding: utf-8 -*- -"""基础仓储接口模块 - -本模块定义了通用的仓储接口,适用于所有数据库类型(PostgreSQL、Neo4j等)。 -遵循仓储模式(Repository Pattern),提供统一的数据访问抽象。 - -Classes: - BaseRepository: 基础仓储接口,定义CRUD操作的抽象方法 -""" - -from abc import ABC, abstractmethod -from typing import Generic, TypeVar, List, Optional, Dict, Any - -T = TypeVar('T') - - -class BaseRepository(ABC, Generic[T]): - """基础仓储接口 - 适用于所有数据库类型 - - 这是一个抽象基类,定义了所有仓储必须实现的基本CRUD操作。 - 使用泛型T来支持不同的实体类型。 - - Type Parameters: - T: 实体类型,通常是Pydantic模型或ORM模型 - - Methods: - create: 创建新实体 - get_by_id: 根据ID获取实体 - update: 更新现有实体 - delete: 删除实体 - find: 根据条件查询实体列表 - """ - - @abstractmethod - async def create(self, entity: T) -> T: - """创建实体 - - Args: - entity: 要创建的实体对象 - - Returns: - T: 创建后的实体对象(可能包含生成的ID等) - - Raises: - Exception: 创建失败时抛出异常 - """ - pass - - @abstractmethod - async def get_by_id(self, entity_id: str) -> Optional[T]: - """根据ID获取实体 - - Args: - entity_id: 实体的唯一标识符 - - Returns: - Optional[T]: 找到的实体对象,如果不存在则返回None - - Raises: - Exception: 查询失败时抛出异常 - """ - pass - - @abstractmethod - async def update(self, entity: T) -> T: - """更新实体 - - Args: - entity: 要更新的实体对象(必须包含ID) - - Returns: - T: 更新后的实体对象 - - Raises: - Exception: 更新失败时抛出异常 - """ - pass - - @abstractmethod - async def delete(self, entity_id: str) -> bool: - """删除实体 - - Args: - entity_id: 要删除的实体ID - - Returns: - bool: 删除成功返回True,否则返回False - - Raises: - Exception: 删除失败时抛出异常 - """ - pass - - @abstractmethod - async def find(self, filters: Dict[str, Any], limit: int = 100) -> List[T]: - """查询实体列表 - - Args: - filters: 查询条件字典,键为字段名,值为期望的值 - limit: 返回结果的最大数量,默认100 - - Returns: - List[T]: 符合条件的实体列表 - - Raises: - Exception: 查询失败时抛出异常 - """ - pass diff --git a/app/repositories/data_config_repository.py b/app/repositories/data_config_repository.py deleted file mode 100644 index d1d1af90..00000000 --- a/app/repositories/data_config_repository.py +++ /dev/null @@ -1,408 +0,0 @@ -# -*- coding: utf-8 -*- -"""数据配置Repository模块 - -本模块提供data_config表的数据访问层,包括SQL查询构建和Neo4j Cypher查询。 -从 app.core.memory.src.data_config_api.sql_queries 迁移而来。 - -Classes: - DataConfigRepository: 数据配置仓储类,提供CRUD操作和查询构建 -""" - -from typing import Dict, Tuple, List -from sqlalchemy.orm import Session - -from app.schemas.memory_storage_schema import ( - ConfigParamsCreate, - ConfigParamsDelete, - ConfigUpdate, - ConfigUpdateExtracted, - ConfigUpdateForget, - ConfigKey, -) -from app.core.logging_config import get_db_logger - -# 获取数据库专用日志器 -db_logger = get_db_logger() - -# 表名常量 -TABLE_NAME = "data_config" - - -class DataConfigRepository: - """数据配置Repository - - 提供data_config表的数据访问方法,包括: - - SQL查询构建(PostgreSQL) - - Neo4j Cypher查询常量 - """ - - # ==================== Neo4j Cypher 查询常量 ==================== - - # Dialogue count by group - SEARCH_FOR_DIALOGUE = """ - MATCH (n:Dialogue) WHERE n.group_id = $group_id RETURN COUNT(n) AS num - """ - - # Chunk count by group - SEARCH_FOR_CHUNK = """ - MATCH (n:Chunk) WHERE n.group_id = $group_id RETURN COUNT(n) AS num - """ - - # Statement count by group - SEARCH_FOR_STATEMENT = """ - MATCH (n:Statement) WHERE n.group_id = $group_id RETURN COUNT(n) AS num - """ - - # ExtractedEntity count by group - SEARCH_FOR_ENTITY = """ - MATCH (n:ExtractedEntity) WHERE n.group_id = $group_id RETURN COUNT(n) AS num - """ - - # All counts by label and total - SEARCH_FOR_ALL = """ - OPTIONAL MATCH (n:Dialogue) WHERE n.group_id = $group_id RETURN 'Dialogue' AS Label, COUNT(n) AS Count - UNION ALL - OPTIONAL MATCH (n:Chunk) WHERE n.group_id = $group_id RETURN 'Chunk' AS Label, COUNT(n) AS Count - UNION ALL - OPTIONAL MATCH (n:Statement) WHERE n.group_id = $group_id RETURN 'Statement' AS Label, COUNT(n) AS Count - UNION ALL - OPTIONAL MATCH (n:ExtractedEntity) WHERE n.group_id = $group_id RETURN 'ExtractedEntity' AS Label, COUNT(n) AS Count - UNION ALL - OPTIONAL MATCH (n) WHERE n.group_id = $group_id RETURN 'ALL' AS Label, COUNT(n) AS Count - """ - - # Extracted entity details within group/app/user - SEARCH_FOR_DETIALS = """ - MATCH (n:ExtractedEntity) - WHERE n.group_id = $group_id - RETURN n.entity_idx AS entity_idx, - n.connect_strength AS connect_strength, - n.description AS description, - n.entity_type AS entity_type, - n.name AS name, - n.fact_summary AS fact_summary, - n.group_id AS group_id, - n.apply_id AS apply_id, - n.user_id AS user_id, - n.id AS id - """ - - # Edges between extracted entities within group/app/user - SEARCH_FOR_EDGES = """ - MATCH (n:ExtractedEntity)-[r]->(m:ExtractedEntity) - WHERE n.group_id = $group_id - RETURN - r.group_id AS group_id, - r.apply_id AS apply_id, - r.user_id AS user_id, - elementId(r) AS rel_id, - startNode(r).id AS source_id, - endNode(r).id AS target_id, - r.predicate AS predicate, - r.statement_id AS statement_id, - r.statement AS statement - """ - - # Entity graph within group (source node, edge, target node) - SEARCH_FOR_ENTITY_GRAPH = """ - MATCH (n:ExtractedEntity)-[r]->(m:ExtractedEntity) - WHERE n.group_id = $group_id - RETURN - { - entity_idx: n.entity_idx, - connect_strength: n.connect_strength, - description: n.description, - entity_type: n.entity_type, - name: n.name, - fact_summary: n.fact_summary, - id: n.id - } AS sourceNode, - { - rel_id: elementId(r), - source_id: startNode(r).id, - target_id: endNode(r).id, - predicate: r.predicate, - statement_id: r.statement_id, - statement: r.statement - } AS edge, - { - entity_idx: m.entity_idx, - connect_strength: m.connect_strength, - description: m.description, - entity_type: m.entity_type, - name: m.name, - fact_summary: m.fact_summary, - id: m.id - } AS targetNode - """ - - # ==================== SQL 查询构建方法 ==================== - - @staticmethod - def build_insert(params: ConfigParamsCreate) -> Tuple[str, Dict]: - """构建插入语句(PostgreSQL 命名参数) - - Args: - params: 配置参数创建模型 - - Returns: - Tuple[str, Dict]: (SQL查询字符串, 参数字典) - """ - db_logger.debug(f"构建插入语句: config_name={params.config_name}, workspace_id={params.workspace_id}") - - columns = [ - "config_name", - "config_desc", - "workspace_id", - "llm_id", - "embedding_id", - "rerank_id", - "created_at", - ] - placeholders = [ - "%(config_name)s", - "%(config_desc)s", - "%(workspace_id)s::uuid", - "%(llm_id)s", - "%(embedding_id)s", - "%(rerank_id)s", - "timezone('Asia/Shanghai', now())", - ] - query = f"INSERT INTO {TABLE_NAME} (" + ",".join(columns) + ") VALUES (" + ",".join(placeholders) + ")" - # 将 UUID 转换为字符串 - workspace_id_str = str(params.workspace_id) if params.workspace_id else None - params_dict = { - "config_name": params.config_name, - "config_desc": params.config_desc, - "workspace_id": workspace_id_str, - "llm_id": params.llm_id, - "embedding_id": params.embedding_id, - "rerank_id": params.rerank_id, - } - return query, params_dict - - @staticmethod - def build_update(update: ConfigUpdate) -> Tuple[str, Dict]: - """构建基础配置更新语句(PostgreSQL 命名参数) - - Args: - update: 配置更新模型 - - Returns: - Tuple[str, Dict]: (SQL查询字符串, 参数字典) - - Raises: - ValueError: 没有字段需要更新时抛出 - """ - db_logger.debug(f"构建更新语句: config_id={update.config_id}") - - key_where = "config_id = %(config_id)s" - set_fields: List[str] = [] - params: Dict = { - "config_id": update.config_id, - } - - mapping = { - "config_name": "config_name", - "config_desc": "config_desc", - } - - for api_field, db_col in mapping.items(): - value = getattr(update, api_field) - if value is not None: - set_fields.append(f"{db_col} = %({api_field})s") - params[api_field] = value - - set_fields.append("updated_at = timezone('Asia/Shanghai', now())") - if not set_fields: - raise ValueError("No fields to update") - query = f"UPDATE {TABLE_NAME} SET " + ", ".join(set_fields) + f" WHERE {key_where}" - return query, params - - - @staticmethod - def build_update_extracted(update: ConfigUpdateExtracted) -> Tuple[str, Dict]: - """构建记忆萃取引擎配置更新语句(PostgreSQL 命名参数) - - Args: - update: 萃取配置更新模型 - - Returns: - Tuple[str, Dict]: (SQL查询字符串, 参数字典) - - Raises: - ValueError: 没有字段需要更新时抛出 - """ - db_logger.debug(f"构建萃取配置更新语句: config_id={update.config_id}") - - key_where = "config_id = %(config_id)s" - set_fields: List[str] = [] - params: Dict = { - "config_id": update.config_id, - } - - mapping = { - # 模型选择 - "llm_id": "llm", - "embedding_id": "embedding", - "rerank_id": "rerank", - # 记忆萃取引擎 - "enable_llm_dedup_blockwise": "enable_llm_dedup_blockwise", - "enable_llm_disambiguation": "enable_llm_disambiguation", - "deep_retrieval": "deep_retrieval", - "t_type_strict": "t_type_strict", - "t_name_strict": "t_name_strict", - "t_overall": "t_overall", - "state": "state", - "chunker_strategy": "chunker_strategy", - # 句子提取 - "statement_granularity": "statement_granularity", - "include_dialogue_context": "include_dialogue_context", - "max_context": "max_context", - # 剪枝配置 - "pruning_enabled": "pruning_enabled", - "pruning_scene": "pruning_scene", - "pruning_threshold": "pruning_threshold", - # 自我反思配置 - "enable_self_reflexion": "enable_self_reflexion", - "iteration_period": "iteration_period", - "reflexion_range": "reflexion_range", - "baseline": "baseline", - } - - for api_field, db_col in mapping.items(): - value = getattr(update, api_field) - if value is not None: - set_fields.append(f"{db_col} = %({api_field})s") - params[api_field] = value - - set_fields.append("updated_at = timezone('Asia/Shanghai', now())") - if not set_fields: - raise ValueError("No fields to update") - query = f"UPDATE {TABLE_NAME} SET " + ", ".join(set_fields) + f" WHERE {key_where}" - return query, params - - @staticmethod - def build_update_forget(update: ConfigUpdateForget) -> Tuple[str, Dict]: - """构建遗忘引擎配置更新语句(PostgreSQL 命名参数) - - Args: - update: 遗忘配置更新模型 - - Returns: - Tuple[str, Dict]: (SQL查询字符串, 参数字典) - - Raises: - ValueError: 没有字段需要更新时抛出 - """ - db_logger.debug(f"构建遗忘配置更新语句: config_id={update.config_id}") - - key_where = "config_id = %(config_id)s" - set_fields: List[str] = [] - params: Dict = { - "config_id": update.config_id, - } - - mapping = { - # 遗忘引擎 - "lambda_time": "lambda_time", - "lambda_mem": "lambda_mem", - # 由于 PostgreSQL 中 OFFSET 是保留字,需使用双引号包裹列名 - "offset": '"offset"', - } - - for api_field, db_col in mapping.items(): - value = getattr(update, api_field) - if value is not None: - set_fields.append(f"{db_col} = %({api_field})s") - params[api_field] = value - - set_fields.append("updated_at = timezone('Asia/Shanghai', now())") - if not set_fields: - raise ValueError("No fields to update") - query = f"UPDATE {TABLE_NAME} SET " + ", ".join(set_fields) + f" WHERE {key_where}" - return query, params - - @staticmethod - def build_select_extracted(key: ConfigKey) -> Tuple[str, Dict]: - """构建萃取配置查询语句,通过主键查询某条配置(PostgreSQL 命名参数) - - Args: - key: 配置键模型 - - Returns: - Tuple[str, Dict]: (SQL查询字符串, 参数字典) - """ - db_logger.debug(f"构建萃取配置查询语句: config_id={key.config_id}") - # f"SELECT statement_granularity, include_dialogue_context, max_context, " - - query = ( - f"SELECT llm_id, embedding_id, rerank_id, " - f"enable_llm_dedup_blockwise, enable_llm_disambiguation, deep_retrieval, " - f"t_type_strict, t_name_strict, t_overall, chunker_strategy, " - f"statement_granularity, include_dialogue_context, max_context, " - f"pruning_enabled, pruning_scene, pruning_threshold, " - f"enable_self_reflexion, iteration_period, reflexion_range, baseline " - f"FROM {TABLE_NAME} WHERE config_id = %(config_id)s" - ) - params = {"config_id": key.config_id} - return query, params - - @staticmethod - def build_select_forget(key: ConfigKey) -> Tuple[str, Dict]: - """构建遗忘配置查询语句,通过主键查询某条配置(PostgreSQL 命名参数) - - Args: - key: 配置键模型 - - Returns: - Tuple[str, Dict]: (SQL查询字符串, 参数字典) - """ - db_logger.debug(f"构建遗忘配置查询语句: config_id={key.config_id}") - - query = ( - f"SELECT lambda_time, lambda_mem, \"offset\" " # 用双引号包裹保留字别名 - f"FROM {TABLE_NAME} WHERE config_id = %(config_id)s" - ) - params = {"config_id": key.config_id} - return query, params - - @staticmethod - def build_select_all(workspace_id = None) -> Tuple[str, Dict]: - """构建查询所有配置参数的语句(PostgreSQL 命名参数) - - Args: - workspace_id: 工作空间ID(UUID或字符串),用于过滤查询结果 - - Returns: - Tuple[str, Dict]: (SQL查询字符串, 参数字典) - """ - db_logger.debug(f"构建查询所有配置语句: workspace_id={workspace_id}") - - if workspace_id: - # 将 UUID 转换为字符串以便在 SQL 中使用 - workspace_id_str = str(workspace_id) if workspace_id else None - query = f"SELECT * FROM {TABLE_NAME} WHERE workspace_id = %(workspace_id)s::uuid ORDER BY updated_at DESC NULLS LAST" - params = {"workspace_id": workspace_id_str} - else: - query = f"SELECT * FROM {TABLE_NAME} ORDER BY updated_at DESC NULLS LAST" - params = {} - return query, params - - @staticmethod - def build_delete(key: ConfigParamsDelete) -> Tuple[str, Dict]: - """构建删除语句,通过配置ID删除(PostgreSQL 命名参数) - - Args: - key: 配置删除模型 - - Returns: - Tuple[str, Dict]: (SQL查询字符串, 参数字典) - """ - db_logger.debug(f"构建删除语句: config_id={key.config_id}") - - query = ( - f"DELETE FROM {TABLE_NAME} WHERE config_id = %(config_id)s" - ) - params = {"config_id": key.config_id} - return query, params diff --git a/app/repositories/document_repository.py b/app/repositories/document_repository.py deleted file mode 100644 index 52e46bdb..00000000 --- a/app/repositories/document_repository.py +++ /dev/null @@ -1,153 +0,0 @@ -import uuid -import datetime -from sqlalchemy.orm import Session -from app.models.document_model import Document -from app.schemas import document_schema -from app.core.logging_config import get_db_logger - -# Obtain a dedicated logger for the database -db_logger = get_db_logger() - - -def get_documents_paginated( - db: Session, - filters: list, - page: int, - pagesize: int, - orderby: str = None, - desc: bool = False -) -> tuple[int, list]: - """ - Paged query document (with filtering and sorting) - """ - db_logger.debug(f"Query documents in pages: page={page}, pagesize={pagesize}, orderby={orderby}, desc={desc}, filters_count={len(filters)}") - - try: - query = db.query(Document) - - # Apply filter conditions - for filter_cond in filters: - query = query.filter(filter_cond) - - # Calculate the total count (for pagination) - total = query.count() - db_logger.debug(f"Total number of document queries: {total}") - - # sort - if orderby: - order_attr = getattr(Document, orderby, None) - if order_attr is not None: - if desc: - query = query.order_by(order_attr.desc()) - else: - query = query.order_by(order_attr.asc()) - db_logger.debug(f"sort: {orderby}, desc={desc}") - - # pagination - items = query.offset((page - 1) * pagesize).limit(pagesize).all() - db_logger.info(f"The document paging query has been successful: total={total}, Number of current page={len(items)}") - - return total, [document_schema.Document.model_validate(item) for item in items] - except Exception as e: - db_logger.error(f"Querying document pagination failed: page={page}, pagesize={pagesize} - {str(e)}") - raise - - -def create_document(db: Session, document: document_schema.DocumentCreate) -> Document: - db_logger.debug(f"Create a document record: file_name={document.file_name}") - - try: - db_document = Document(**document.model_dump()) - db.add(db_document) - db.commit() - db_logger.info(f"Document record created successfully: {document.file_name} (ID: {db_document.id})") - return db_document - except Exception as e: - db_logger.error(f"Failed to create a document record: title={document.file_name} - {str(e)}") - db.rollback() - raise - - -def get_document_by_id(db: Session, document_id: uuid.UUID) -> Document | None: - db_logger.debug(f"Query documents based on ID: document_id={document_id}") - - try: - document = db.query(Document).filter(Document.id == document_id).first() - if document: - db_logger.debug(f"Document query successful: {document.file_name} (ID: {document_id})") - else: - db_logger.debug(f"Document does not exist: document_id={document_id}") - return document - except Exception as e: - db_logger.error(f"Failed to query the document based on the ID: document_id={document_id} - {str(e)}") - raise - - -def reset_documents_progress_by_kb_id(db: Session, kb_id: uuid.UUID) -> int: - """ - Reset the processing progress of all documents under the specified knowledge base - - Args: - db: database session - kb_id: Knowledge Base ID - - Returns: - int: Number of updated documents - """ - db_logger.debug(f"Reset the processing progress of all documents under the specified knowledge base: kb_id={kb_id}") - try: - # Build update conditions - filters = [ - Document.kb_id == kb_id - ] - - # Build updated data - update_data = { - Document.chunk_num: 0, - Document.progress: 0, - Document.progress_msg: "Pending", - Document.process_duration: 0, - Document.run: 0, # Reset run status - Document.updated_at: datetime.datetime.now() - } - - # Perform batch update - result = db.query(Document).filter(*filters).update( - update_data, - synchronize_session=False - ) - - # commit transaction - db.commit() - db_logger.debug(f"Successfully reset the processing progress of all documents under the specified knowledge base: kb_id: {kb_id}") - return result - - except Exception as e: - db.rollback() - db_logger.error(f"Failed to reset the processing progress of all documents under the specified knowledge base: kb_id={kb_id} - {str(e)}") - raise - - - -def delete_document_by_id(db: Session, document_id: uuid.UUID): - db_logger.debug(f"Delete document record: document_id={document_id}") - - try: - # First, query the document information for logging purposes - document = db.query(Document).filter(Document.id == document_id).first() - if document: - file_name = document.file_name - else: - file_name = "unknown" - - result = db.query(Document).filter(Document.id == document_id).delete() - db.commit() - - if result > 0: - db_logger.info(f"Document record deleted successfully: {file_name} (ID: {document_id})") - else: - db_logger.warning(f"The document record does not exist, and cannot be deleted: document_id={document_id}") - except Exception as e: - db_logger.error(f"Failed to delete document record: document_id={document_id} - {str(e)}") - db.rollback() - raise diff --git a/app/repositories/end_user_repository.py b/app/repositories/end_user_repository.py deleted file mode 100644 index 9005fda0..00000000 --- a/app/repositories/end_user_repository.py +++ /dev/null @@ -1,105 +0,0 @@ -from sqlalchemy.orm import Session -from typing import List, Optional -import uuid - -from app.models.end_user_model import EndUser - -from app.core.logging_config import get_db_logger - -# 获取数据库专用日志器 -db_logger = get_db_logger() - - -class EndUserRepository: - def __init__(self, db: Session): - self.db = db - - def get_end_users_by_app_id(self, app_id: uuid.UUID) -> List[EndUser]: - """根据应用ID查询宿主""" - try: - end_users = ( - self.db.query(EndUser) - .filter(EndUser.app_id == app_id) - .all() - ) - db_logger.info(f"成功查询应用 {app_id} 下的 {len(end_users)} 个宿主") - return end_users - except Exception as e: - self.db.rollback() - db_logger.error(f"查询应用 {app_id} 下宿主时出错: {str(e)}") - raise - - def get_end_user_by_id(self, end_user_id: uuid.UUID) -> Optional[EndUser]: - """根据 end_user_id 查询宿主""" - try: - end_user = ( - self.db.query(EndUser) - .filter(EndUser.id == end_user_id) - .first() - ) - if end_user: - db_logger.info(f"成功查询到宿主 {end_user_id}") - else: - db_logger.info(f"未找到宿主 {end_user_id}") - return end_user - except Exception as e: - self.db.rollback() - db_logger.error(f"查询宿主 {end_user_id} 时出错: {str(e)}") - raise - - def get_or_create_end_user( - self, - app_id: uuid.UUID, - other_id: str, - original_user_id: Optional[str] = None - ) -> EndUser: - """获取或创建终端用户 - - Args: - app_id: 应用ID - other_id: 第三方ID - original_user_id: 原始用户ID (存储到 other_id) - """ - try: - # 尝试查找现有用户 - end_user = ( - self.db.query(EndUser) - .filter( - EndUser.app_id == app_id, - EndUser.other_id == other_id - ) - .first() - ) - - if end_user: - db_logger.debug(f"找到现有终端用户: 应用ID {app_id}、第三方ID {other_id}") - return end_user - - # 创建新用户 - end_user = EndUser( - app_id=app_id, - other_id=other_id - ) - self.db.add(end_user) - self.db.commit() - self.db.refresh(end_user) - - db_logger.info(f"创建新终端用户: (other_id: {other_id}) for app {app_id}") - return end_user - - except Exception as e: - self.db.rollback() - db_logger.error(f"获取或创建终端用户时出错: {str(e)}") - raise - -def get_end_users_by_app_id(db: Session, app_id: uuid.UUID) -> List[EndUser]: - """根据应用ID查询宿主(返回 EndUser ORM 列表)""" - repo = EndUserRepository(db) - end_users = repo.get_end_users_by_app_id(app_id) - return end_users - -def get_end_user_by_id(db: Session, end_user_id: uuid.UUID) -> Optional[EndUser]: - """根据 end_user_id 查询对应宿主""" - repo = EndUserRepository(db) - end_user = repo.get_end_user_by_id(end_user_id) - return end_user \ No newline at end of file diff --git a/app/repositories/file_repository.py b/app/repositories/file_repository.py deleted file mode 100644 index 49d21b74..00000000 --- a/app/repositories/file_repository.py +++ /dev/null @@ -1,121 +0,0 @@ -import uuid -from sqlalchemy.orm import Session -from app.models.file_model import File -from app.schemas import file_schema -from app.core.logging_config import get_db_logger - -# Obtain a dedicated logger for the database -db_logger = get_db_logger() - - -def get_files_paginated( - db: Session, - filters: list, - page: int, - pagesize: int, - orderby: str = None, - desc: bool = False -) -> tuple[int, list]: - """ - Paged query file (with filtering and sorting) - """ - db_logger.debug(f"Query file in pages: page={page}, pagesize={pagesize}, orderby={orderby}, desc={desc}, filters_count={len(filters)}") - - try: - query = db.query(File) - - # Apply filter conditions - for filter_cond in filters: - query = query.filter(filter_cond) - - # Calculate the total count (for pagination) - total = query.count() - db_logger.debug(f"Total number of file queries: {total}") - - # sort - if orderby: - order_attr = getattr(File, orderby, None) - if order_attr is not None: - if desc: - query = query.order_by(order_attr.desc()) - else: - query = query.order_by(order_attr.asc()) - db_logger.debug(f"sort: {orderby}, desc={desc}") - - # pagination - items = query.offset((page - 1) * pagesize).limit(pagesize).all() - db_logger.info(f"The file paging query has been successful: total={total}, Number of current page={len(items)}") - - return total, [file_schema.File.model_validate(item) for item in items] - except Exception as e: - db_logger.error(f"Querying file pagination failed: page={page}, pagesize={pagesize} - {str(e)}") - raise - - -def create_file(db: Session, file: file_schema.FileCreate) -> File: - db_logger.debug(f"Create a file record: filename={file.file_name}") - - try: - db_file = File(**file.model_dump()) - db.add(db_file) - db.commit() - db_logger.info(f"File record created successfully: {file.file_name} (ID: {db_file.id})") - return db_file - except Exception as e: - db_logger.error(f"Failed to create a file record: filename={file.file_name} - {str(e)}") - db.rollback() - raise - - -def get_file_by_id(db: Session, file_id: uuid.UUID) -> File | None: - db_logger.debug(f"Query file based on ID: file_id={file_id}") - - try: - file = db.query(File).filter(File.id == file_id).first() - if file: - db_logger.debug(f"File query successful: {file.file_name} (ID: {file_id})") - else: - db_logger.debug(f"File does not exist: file_id={file_id}") - return file - except Exception as e: - db_logger.error(f"Failed to query the file based on the ID: file_id={file_id} - {str(e)}") - raise - - -def get_files_by_parent_id(db: Session, parent_id: uuid.UUID | None) -> list | None: - db_logger.debug(f"Query file based on folder ID: parent_id={parent_id}") - - try: - query = db.query(File) - if parent_id: - query = query.filter(File.parent_id == parent_id) - files = query.all() - db_logger.debug(f"Folder query file successful: parent_id={parent_id}, file_num={len(files)}") - return files - except Exception as e: - db_logger.error(f"Failed to query files based on folder ID: parent_id={parent_id} - {str(e)}") - raise - - -def delete_file_by_id(db: Session, file_id: uuid.UUID): - db_logger.debug(f"Delete file record: file_id={file_id}") - - try: - # First, query the file information for logging purposes - file = db.query(File).filter(File.id == file_id).first() - if file: - filename = file.file_name - else: - filename = "unknown" - - result = db.query(File).filter(File.id == file_id).delete() - db.commit() - - if result > 0: - db_logger.info(f"File record deleted successfully: {filename} (ID: {file_id})") - else: - db_logger.warning(f"The file record does not exist, and cannot be deleted: file_id={file_id}") - except Exception as e: - db_logger.error(f"Failed to delete file record: file_id={file_id} - {str(e)}") - db.rollback() - raise diff --git a/app/repositories/generic_file_repository.py b/app/repositories/generic_file_repository.py deleted file mode 100644 index e24eed41..00000000 --- a/app/repositories/generic_file_repository.py +++ /dev/null @@ -1,243 +0,0 @@ -""" -Generic File Repository -Handles database operations for generic file uploads. -""" -import uuid -from typing import Optional, List, Tuple, Dict, Any -from datetime import datetime -from sqlalchemy.orm import Session -from sqlalchemy import and_, or_, func - -from app.models.generic_file_model import GenericFile -from app.core.upload_enums import UploadContext -from app.core.logging_config import get_db_logger - -# Get database logger -db_logger = get_db_logger() - - -class GenericFileRepository: - """Repository for generic file operations""" - - def __init__(self, db: Session): - self.db = db - - def create_file(self, file_data: Dict[str, Any]) -> GenericFile: - """ - Create a new file record in the database. - - Args: - file_data: Dictionary containing file information - - Returns: - GenericFile: Created file record - - Raises: - Exception: If database operation fails - """ - db_logger.debug(f"Creating file record: filename={file_data.get('file_name')}") - - try: - db_file = GenericFile(**file_data) - self.db.add(db_file) - self.db.flush() - db_logger.info(f"File record created successfully: {file_data.get('file_name')} (ID: {db_file.id})") - return db_file - except Exception as e: - db_logger.error(f"Failed to create file record: filename={file_data.get('file_name')} - {str(e)}") - raise - - def get_file_by_id(self, file_id: uuid.UUID) -> Optional[GenericFile]: - """ - Get a file by its ID. - - Args: - file_id: UUID of the file - - Returns: - Optional[GenericFile]: File record if found, None otherwise - """ - db_logger.debug(f"Querying file by ID: file_id={file_id}") - - try: - file = self.db.query(GenericFile).filter( - and_( - GenericFile.id == file_id, - GenericFile.deleted_at.is_(None) - ) - ).first() - - if file: - db_logger.debug(f"File found: {file.file_name} (ID: {file_id})") - else: - db_logger.debug(f"File not found: file_id={file_id}") - - return file - except Exception as e: - db_logger.error(f"Failed to query file by ID: file_id={file_id} - {str(e)}") - raise - - def update_file(self, file_id: uuid.UUID, update_data: Dict[str, Any]) -> Optional[GenericFile]: - """ - Update file metadata. - - Args: - file_id: UUID of the file to update - update_data: Dictionary containing fields to update - - Returns: - Optional[GenericFile]: Updated file record if found, None otherwise - """ - db_logger.debug(f"Updating file: file_id={file_id}") - - try: - file = self.get_file_by_id(file_id) - if not file: - db_logger.debug(f"File not found for update: file_id={file_id}") - return None - - # Update allowed fields - for field, value in update_data.items(): - if hasattr(file, field) and field not in ['id', 'created_by', 'created_at', 'tenant_id']: - setattr(file, field, value) - - # Update timestamp - file.updated_at = datetime.now() - - self.db.flush() - db_logger.info(f"File updated successfully: {file.file_name} (ID: {file_id})") - return file - except Exception as e: - db_logger.error(f"Failed to update file: file_id={file_id} - {str(e)}") - raise - - def delete_file(self, file_id: uuid.UUID) -> bool: - """ - Soft delete a file by setting deleted_at timestamp. - - Args: - file_id: UUID of the file to delete - - Returns: - bool: True if file was deleted, False if not found - """ - db_logger.debug(f"Soft deleting file: file_id={file_id}") - - try: - file = self.get_file_by_id(file_id) - if not file: - db_logger.debug(f"File not found for deletion: file_id={file_id}") - return False - - # Soft delete by setting deleted_at - file.deleted_at = datetime.now() - file.status = "deleted" - file.updated_at = datetime.now() - - self.db.flush() - db_logger.info(f"File soft deleted successfully: {file.file_name} (ID: {file_id})") - return True - except Exception as e: - db_logger.error(f"Failed to delete file: file_id={file_id} - {str(e)}") - raise - - def get_files_by_context( - self, - context: UploadContext, - tenant_id: uuid.UUID, - page: int = 1, - pagesize: int = 20, - status: Optional[str] = "active", - created_by: Optional[uuid.UUID] = None - ) -> Tuple[int, List[GenericFile]]: - """ - Get files by context with pagination. - - Args: - context: Upload context (avatar, app_icon, etc.) - tenant_id: Tenant ID for isolation - page: Page number (1-indexed) - pagesize: Number of items per page - status: File status filter (default: "active") - created_by: Optional filter by creator user ID - - Returns: - Tuple[int, List[GenericFile]]: Total count and list of files - """ - db_logger.debug( - f"Querying files by context: context={context}, tenant_id={tenant_id}, " - f"page={page}, pagesize={pagesize}, status={status}" - ) - - try: - query = self.db.query(GenericFile).filter( - and_( - GenericFile.context == context, - GenericFile.tenant_id == tenant_id, - GenericFile.deleted_at.is_(None) - ) - ) - - # Apply status filter - if status: - query = query.filter(GenericFile.status == status) - - # Apply creator filter - if created_by: - query = query.filter(GenericFile.created_by == created_by) - - # Get total count - total = query.count() - db_logger.debug(f"Total files found: {total}") - - # Apply pagination and ordering - files = query.order_by(GenericFile.created_at.desc()).offset((page - 1) * pagesize).limit(pagesize).all() - - db_logger.info( - f"Files query successful: context={context}, total={total}, " - f"returned={len(files)}" - ) - - return total, files - except Exception as e: - db_logger.error( - f"Failed to query files by context: context={context}, " - f"tenant_id={tenant_id} - {str(e)}" - ) - raise - - -# Convenience functions for backward compatibility -def create_file(db: Session, file_data: Dict[str, Any]) -> GenericFile: - """Create a new file record""" - return GenericFileRepository(db).create_file(file_data) - - -def get_file_by_id(db: Session, file_id: uuid.UUID) -> Optional[GenericFile]: - """Get a file by its ID""" - return GenericFileRepository(db).get_file_by_id(file_id) - - -def update_file(db: Session, file_id: uuid.UUID, update_data: Dict[str, Any]) -> Optional[GenericFile]: - """Update file metadata""" - return GenericFileRepository(db).update_file(file_id, update_data) - - -def delete_file(db: Session, file_id: uuid.UUID) -> bool: - """Soft delete a file""" - return GenericFileRepository(db).delete_file(file_id) - - -def get_files_by_context( - db: Session, - context: UploadContext, - tenant_id: uuid.UUID, - page: int = 1, - pagesize: int = 20, - status: Optional[str] = "active", - created_by: Optional[uuid.UUID] = None -) -> Tuple[int, List[GenericFile]]: - """Get files by context with pagination""" - return GenericFileRepository(db).get_files_by_context( - context, tenant_id, page, pagesize, status, created_by - ) diff --git a/app/repositories/knowledge_repository.py b/app/repositories/knowledge_repository.py deleted file mode 100644 index 73f7a494..00000000 --- a/app/repositories/knowledge_repository.py +++ /dev/null @@ -1,211 +0,0 @@ -import uuid -from sqlalchemy.orm import Session -from app.models.knowledge_model import Knowledge -from app.schemas import knowledge_schema -from app.core.logging_config import get_db_logger - -# Obtain a dedicated logger for the database -db_logger = get_db_logger() - - -def get_knowledges_paginated( - db: Session, - filters: list, - page: int, - pagesize: int, - orderby: str = None, - desc: bool = False -) -> tuple[int, list]: - """ - Paged query knowledge base (with filtering and sorting) - """ - db_logger.debug(f"Query knowledge base in pages: page={page}, pagesize={pagesize}, orderby={orderby}, desc={desc}, filters_count={len(filters)}") - - try: - query = db.query(Knowledge) - - # Apply filter conditions - for filter_cond in filters: - query = query.filter(filter_cond) - - # Calculate the total count (for pagination) - total = query.count() - db_logger.debug(f"Total number of knowledge base queries: {total}") - - # sort - if orderby: - order_attr = getattr(Knowledge, orderby, None) - if order_attr is not None: - if desc: - query = query.order_by(order_attr.desc()) - else: - query = query.order_by(order_attr.asc()) - db_logger.debug(f"sort: {orderby}, desc={desc}") - - # pagination - items = query.offset((page - 1) * pagesize).limit(pagesize).all() - db_logger.info(f"The knowledge base paging query has been successful: total={total}, Number of current page={len(items)}") - - return total, [knowledge_schema.Knowledge.model_validate(item) for item in items] - except Exception as e: - db_logger.error(f"Querying knowledge base pagination failed: page={page}, pagesize={pagesize} - {str(e)}") - raise - - -def get_chunded_knowledgeids( - db: Session, - filters: list -) -> list: - """ - Query the list of vectorized knowledge base IDs - Return: list[UUID] - List of knowledge base IDs - """ - db_logger.debug(f"Query the list of vectorized knowledge base IDs: filters_count={len(filters)}") - - try: - # Only query the id field - query = db.query(Knowledge.id) - - # Apply filter conditions - for filter_cond in filters: - query = query.filter(filter_cond) - - # Get all IDs - items = query.all() - db_logger.info(f"Querying the vectorized knowledge base id list succeeded: count={len(items)}") - - # Return the list of IDs directly. Since only the ID field is queried, the returned data is a single column - return [item[0] for item in items] - except Exception as e: - db_logger.error(f"Querying the vectorized knowledge base id list failed: {str(e)}") - raise - - -def create_knowledge(db: Session, knowledge: knowledge_schema.KnowledgeCreate) -> Knowledge: - db_logger.debug(f"Create a knowledge base record: name={knowledge.name}") - - try: - db_knowledge = Knowledge(**knowledge.model_dump()) - db.add(db_knowledge) - db.commit() - db_logger.info(f"knowledge base record created successfully: {knowledge.name} (ID: {db_knowledge.id})") - return db_knowledge - except Exception as e: - db_logger.error(f"Failed to create a knowledge base record: name={knowledge.name} - {str(e)}") - db.rollback() - raise - - -def get_knowledge_by_id(db: Session, knowledge_id: uuid.UUID) -> Knowledge | None: - db_logger.debug(f"Query knowledge base based on ID: knowledge_id={knowledge_id}") - - try: - knowledge = db.query(Knowledge).filter(Knowledge.id == knowledge_id).first() - if knowledge: - db_logger.debug(f"knowledge base query successful: {knowledge.name} (ID: {knowledge_id})") - else: - db_logger.debug(f"knowledge base does not exist: knowledge_id={knowledge_id}") - return knowledge - except Exception as e: - db_logger.error(f"Failed to query the knowledge base based on the ID: knowledge_id={knowledge_id} - {str(e)}") - raise - - -def get_knowledge_by_name(db: Session, name: str, workspace_id: uuid.UUID) -> Knowledge | None: - db_logger.debug(f"Query knowledge base based on name and workspace_id: name={name}, workspace_id={workspace_id}") - - try: - knowledge = db.query(Knowledge).filter(Knowledge.name == name).filter(Knowledge.workspace_id == workspace_id).first() - if knowledge: - db_logger.debug(f"knowledge base query successful: {name} (ID: {knowledge.id})") - else: - db_logger.debug(f"knowledge base does not exist: name={name}, workspace_id={workspace_id}") - return knowledge - except Exception as e: - db_logger.error(f"Failed to query the knowledge base based on the name and workspace_id: name={name}, workspace_id={workspace_id} - {str(e)}") - raise - - -def delete_knowledge_by_id(db: Session, knowledge_id: uuid.UUID): - db_logger.debug(f"Delete knowledge base record: knowledge_id={knowledge_id}") - - try: - # First, query the knowledge base information for logging purposes - knowledge = db.query(Knowledge).filter(Knowledge.id == knowledge_id).first() - if knowledge: - knowledge_name = knowledge.name - else: - knowledge_name = "unknown" - - result = db.query(Knowledge).filter(Knowledge.id == knowledge_id).delete() - db.commit() - - if result > 0: - db_logger.info(f"knowledge base record deleted successfully: {knowledge_name} (ID: {knowledge_id})") - else: - db_logger.warning(f"The knowledge base record does not exist, and cannot be deleted: knowledge_id={knowledge_id}") - except Exception as e: - db_logger.error(f"Failed to delete knowledge base record: knowledge_id={knowledge_id} - {str(e)}") - db.rollback() - raise - - -def get_total_doc_num_by_workspace(db: Session, workspace_id: uuid.UUID) -> int: - """ - 根据workspace_id查询knowledges表所有doc_num的总和 - """ - db_logger.debug(f"Query total doc_num by workspace_id: workspace_id={workspace_id}") - - try: - from sqlalchemy import func - result = db.query(func.sum(Knowledge.doc_num)).filter( - Knowledge.workspace_id == workspace_id, - Knowledge.status == 1 - ).scalar() - - total = result if result is not None else 0 - db_logger.info(f"Total doc_num query successful: workspace_id={workspace_id}, total={total}") - return total - except Exception as e: - db_logger.error(f"Failed to query total doc_num: workspace_id={workspace_id} - {str(e)}") - raise - - -def get_total_chunk_num_by_workspace(db: Session, workspace_id: uuid.UUID) -> int: - """ - 根据workspace_id查询knowledges表所有chunk_num的总和 - """ - db_logger.debug(f"Query total chunk_num by workspace_id: workspace_id={workspace_id}") - - try: - from sqlalchemy import func - result = db.query(func.sum(Knowledge.chunk_num)).filter( - Knowledge.workspace_id == workspace_id, - Knowledge.status == 1 - ).scalar() - - total = result if result is not None else 0 - db_logger.info(f"Total chunk_num query successful: workspace_id={workspace_id}, total={total}") - return total - except Exception as e: - db_logger.error(f"Failed to query total chunk_num: workspace_id={workspace_id} - {str(e)}") - raise - - -def get_total_kb_count_by_workspace(db: Session, workspace_id: uuid.UUID) -> int: - """ - 根据workspace_id查询knowledges表所有不同id的数量(知识库总数) - """ - db_logger.debug(f"Query total knowledge base count by workspace_id: workspace_id={workspace_id}") - - try: - count = db.query(Knowledge).filter( - Knowledge.workspace_id == workspace_id, - Knowledge.status == 1 - ).count() - - db_logger.info(f"Total knowledge base count query successful: workspace_id={workspace_id}, count={count}") - return count - except Exception as e: - db_logger.error(f"Failed to query total knowledge base count: workspace_id={workspace_id} - {str(e)}") - raise diff --git a/app/repositories/knowledgeshare_repository.py b/app/repositories/knowledgeshare_repository.py deleted file mode 100644 index e4976b8d..00000000 --- a/app/repositories/knowledgeshare_repository.py +++ /dev/null @@ -1,142 +0,0 @@ -import uuid -from sqlalchemy.orm import Session -from app.models.knowledgeshare_model import KnowledgeShare -from app.schemas import knowledgeshare_schema -from app.core.logging_config import get_db_logger -from sqlalchemy.orm import joinedload -from sqlalchemy import or_ - -# Obtain a dedicated logger for the database -db_logger = get_db_logger() - - -def get_knowledgeshares_paginated( - db: Session, - filters: list, - page: int, - pagesize: int, - orderby: str = None, - desc: bool = False -) -> tuple[int, list]: - """ - Paged query knowledge base sharing (with filtering and sorting) - """ - db_logger.debug( - f"Query knowledge base sharing in pages: page={page}, pagesize={pagesize}, orderby={orderby}, desc={desc}, filters_count={len(filters)}") - - try: - query = db.query(KnowledgeShare) - - # Apply filter conditions - for filter_cond in filters: - query = query.filter(filter_cond) - - # Calculate the total count (for pagination) - total = query.count() - db_logger.debug(f"Total number of knowledge base sharing queries: {total}") - - # sort - if orderby: - order_attr = getattr(KnowledgeShare, orderby, None) - if order_attr is not None: - if desc: - query = query.order_by(order_attr.desc()) - else: - query = query.order_by(order_attr.asc()) - db_logger.debug(f"sort: {orderby}, desc={desc}") - - # pagination - items = query.offset((page - 1) * pagesize).limit(pagesize).all() - db_logger.info(f"The knowledge base sharing paging query has been successful: total={total}, Number of current page={len(items)}") - - return total, [knowledgeshare_schema.KnowledgeShare.model_validate(item) for item in items] - except Exception as e: - db_logger.error(f"Querying knowledge base sharing pagination failed: page={page}, pagesize={pagesize} - {str(e)}") - raise - - -def get_source_kb_ids_by_target_kb_id( - db: Session, - filters: list -) -> list: - """ - Query the original knowledge base ID list by sharing the knowledge base - Return: list[UUID] - List of knowledge base IDs - """ - db_logger.debug( - f"Query the original knowledge base id list by sharing the knowledge base: filters_count={len(filters)}") - - try: - # Only query the id field - query = db.query(KnowledgeShare.source_kb_id) - - # Apply filter conditions - for filter_cond in filters: - query = query.filter(filter_cond) - - # Get all IDs - items = query.all() - db_logger.info(f"Successfully queried the original knowledge base ID list by sharing the knowledge base: count={len(items)}") - - # Return the list of IDs directly. Since only the ID field is queried, the returned data is a single column - return [item[0] for item in items] - except Exception as e: - db_logger.error(f"Failed to query the original knowledge base ID list through knowledge base sharing: {str(e)}") - raise - - -def create_knowledgeshare(db: Session, knowledgeshare: knowledgeshare_schema.KnowledgeShareCreate) -> KnowledgeShare: - db_logger.debug(f"Create a knowledge base sharing record: source_kb_id={knowledgeshare.source_kb_id}") - - try: - db_knowledgeshare = KnowledgeShare(**knowledgeshare.model_dump()) - db.add(db_knowledgeshare) - db.commit() - db_logger.info(f"knowledge base sharing record created successfully: (ID: {db_knowledgeshare.id})") - return db_knowledgeshare - except Exception as e: - db_logger.error(f"Failed to create a knowledge base sharing record: source_kb_id={knowledgeshare.source_kb_id} - {str(e)}") - db.rollback() - raise - - -def get_knowledgeshare_by_id(db: Session, knowledgeshare_id: uuid.UUID) -> KnowledgeShare | None: - db_logger.debug(f"Query knowledge base sharing based on ID: knowledgeshare_id={knowledgeshare_id}") - - try: - knowledgeshare = db.query(KnowledgeShare).filter( - or_( - KnowledgeShare.id == knowledgeshare_id, - KnowledgeShare.target_kb_id == knowledgeshare_id - ) - ).first() - if knowledgeshare: - db_logger.debug(f"knowledge base sharing query successful: (ID: {knowledgeshare_id})") - else: - db_logger.debug(f"knowledge base sharing does not exist: knowledgeshare_id={knowledgeshare_id}") - return knowledgeshare - except Exception as e: - db_logger.error(f"Failed to query the knowledge base sharing based on the ID: knowledgeshare_id={knowledgeshare_id} - {str(e)}") - raise - - -def delete_knowledgeshare_by_id(db: Session, knowledgeshare_id: uuid.UUID): - db_logger.debug(f"Delete knowledge base sharing record: knowledgeshare_id={knowledgeshare_id}") - - try: - result = db.query(KnowledgeShare).filter( - or_( - KnowledgeShare.id == knowledgeshare_id, - KnowledgeShare.target_kb_id == knowledgeshare_id - ) - ).delete() - db.commit() - - if result > 0: - db_logger.info(f"knowledge base sharing record deleted successfully: (ID: {knowledgeshare_id})") - else: - db_logger.warning(f"The knowledge base sharing record does not exist, and cannot be deleted: knowledgeshare_id={knowledgeshare_id}") - except Exception as e: - db_logger.error(f"Failed to delete knowledge base sharing record: knowledgeshare_id={knowledgeshare_id} - {str(e)}") - db.rollback() - raise diff --git a/app/repositories/memory_increment_repository.py b/app/repositories/memory_increment_repository.py deleted file mode 100644 index 37396fbd..00000000 --- a/app/repositories/memory_increment_repository.py +++ /dev/null @@ -1,110 +0,0 @@ -from sqlalchemy import func -from sqlalchemy.orm import Session, aliased -from typing import List, Optional -import uuid -import datetime - -from app.models.memory_increment_model import MemoryIncrement - -from app.core.logging_config import get_db_logger - -# 获取数据库专用日志器 -db_logger = get_db_logger() - - -class MemoryIncrementRepository: - def __init__(self, db: Session): - self.db = db - - def get_memory_increments_by_workspace_id(self, workspace_id: uuid.UUID, limit: int) -> List[MemoryIncrement]: - """根据工作空间ID查询内存增量:通过 MemoryIncrement 关联查询 MemoryIncrement 列表""" - try: - # 使用窗口函数按日期分区并排序 - subquery = ( - self.db.query( - MemoryIncrement, - func.row_number().over( - partition_by=func.date(MemoryIncrement.created_at), # 按日期分区 - order_by=MemoryIncrement.created_at.desc() # 按时间戳升序排序 - ).label('row_num') - ) - .filter(MemoryIncrement.workspace_id == workspace_id) - .subquery() - ) - - memory_increment_alias = aliased(MemoryIncrement, subquery) - - memory_increments = ( - self.db.query(memory_increment_alias) - .filter(subquery.c.row_num == 1) # 只取每个日期的第一条(最新的) - .order_by(memory_increment_alias.created_at.asc()) # 按时间戳降序排序 - .limit(limit) - .all() - ) - db_logger.info(f"成功查询工作空间 {workspace_id} 下的内存增量") - return memory_increments - except Exception as e: - db_logger.error(f"查询工作空间 {workspace_id} 下内存增量时出错: {str(e)}") - raise - - def get_latest_memory_increment_by_workspace_id(self, workspace_id: uuid.UUID) -> Optional[MemoryIncrement]: - """根据工作空间ID查询最新的内存增量记录""" - try: - memory_increment = ( - self.db.query(MemoryIncrement) - .filter(MemoryIncrement.workspace_id == workspace_id) - .order_by(MemoryIncrement.created_at.desc(), MemoryIncrement.id.desc()) - .first() - ) - if memory_increment: - db_logger.info(f"成功查询工作空间 {workspace_id} 下的最新内存增量") - else: - db_logger.warning(f"未找到工作空间 {workspace_id} 下的内存增量记录") - return memory_increment - except Exception as e: - db_logger.error(f"查询工作空间 {workspace_id} 下最新内存增量时出错: {str(e)}") - raise - - def write_memory_increment( - self, - workspace_id: uuid.UUID, - total_num: int - ) -> MemoryIncrement: - """写入内存增量""" - try: - memory_increment = MemoryIncrement( - workspace_id=workspace_id, - total_num=total_num, - created_at=datetime.datetime.now(), - updated_at=datetime.datetime.now() - ) - self.db.add(memory_increment) - self.db.commit() - self.db.refresh(memory_increment) - db_logger.info(f"成功写入内存增量: workspace_id={workspace_id}, total_num={total_num}") - return memory_increment - except Exception as e: - db_logger.error(f"写入内存增量失败: workspace_id={workspace_id}, total_num={total_num} - {str(e)}") - raise - - -def get_memory_increments_by_workspace_id(db: Session, workspace_id: uuid.UUID, limit: int) -> List[MemoryIncrement]: - """根据工作空间ID查询内存增量(返回 MemoryIncrement ORM 列表)""" - repo = MemoryIncrementRepository(db) - memory_increments = repo.get_memory_increments_by_workspace_id(workspace_id, limit) - return memory_increments - -def write_memory_increment( - db: Session, - workspace_id: uuid.UUID, - total_num: int -) -> MemoryIncrement: - """写入内存增量""" - repo = MemoryIncrementRepository(db) - memory_increment = repo.write_memory_increment(workspace_id, total_num) - return memory_increment - -def get_latest_memory_increment_by_workspace_id(db: Session, workspace_id: uuid.UUID) -> Optional[MemoryIncrement]: - """根据工作空间ID查询最新的内存增量记录""" - repo = MemoryIncrementRepository(db) - return repo.get_latest_memory_increment_by_workspace_id(workspace_id) \ No newline at end of file diff --git a/app/repositories/model_repository.py b/app/repositories/model_repository.py deleted file mode 100644 index 20c1af40..00000000 --- a/app/repositories/model_repository.py +++ /dev/null @@ -1,386 +0,0 @@ -from sqlalchemy.orm import Session, joinedload -from sqlalchemy import and_, or_, func, desc -from typing import List, Optional, Dict, Any, Tuple -import uuid - -from app.models.models_model import ModelConfig, ModelApiKey, ModelType, ModelProvider -from app.schemas.model_schema import ( - ModelConfigCreate, ModelConfigUpdate, ModelApiKeyCreate, ModelApiKeyUpdate, - ModelConfigQuery -) -from app.core.logging_config import get_db_logger - -# 获取数据库专用日志器 -db_logger = get_db_logger() - - -class ModelConfigRepository: - """模型配置Repository""" - - @staticmethod - def get_by_id(db: Session, model_id: uuid.UUID) -> Optional[ModelConfig]: - """根据ID获取模型配置""" - db_logger.debug(f"根据ID查询模型配置: model_id={model_id}") - - try: - model = db.query(ModelConfig).options( - joinedload(ModelConfig.api_keys) - ).filter(ModelConfig.id == model_id).first() - - if model: - db_logger.debug(f"模型配置查询成功: {model.name} (ID: {model_id})") - else: - db_logger.debug(f"模型配置不存在: model_id={model_id}") - return model - except Exception as e: - db_logger.error(f"根据ID查询模型配置失败: model_id={model_id} - {str(e)}") - raise - - @staticmethod - def get_by_name(db: Session, name: str) -> Optional[ModelConfig]: - """根据名称获取模型配置""" - db_logger.debug(f"根据名称查询模型配置: name={name}") - - try: - model = db.query(ModelConfig).filter(ModelConfig.name == name).first() - if model: - db_logger.debug(f"模型配置查询成功: {model.name}") - return model - except Exception as e: - db_logger.error(f"根据名称查询模型配置失败: name={name} - {str(e)}") - raise - - @staticmethod - def search_by_name(db: Session, name: str, limit: int = 10) -> List[ModelConfig]: - """按名称模糊匹配获取模型配置列表 - - Args: - name: 模型名称关键词(模糊匹配) - limit: 返回数量上限 - Returns: - 模型配置列表 - """ - db_logger.debug(f"按名称模糊查询模型配置: name~{name}, limit={limit}") - try: - models = ( - db.query(ModelConfig) - .filter(ModelConfig.name.ilike(f"%{name}%")) - .order_by(ModelConfig.name) - .limit(limit) - .all() - ) - db_logger.debug(f"模糊查询成功: 返回数量={len(models)}") - return models - except Exception as e: - db_logger.error(f"按名称模糊查询模型配置失败: name~{name} - {str(e)}") - raise - - @staticmethod - def get_list(db: Session, query: ModelConfigQuery) -> Tuple[List[ModelConfig], int]: - """获取模型配置列表""" - db_logger.debug(f"查询模型配置列表: {query.dict()}") - - try: - # 构建查询条件 - filters = [] - - # 支持多个 type 值(使用 IN 查询) - if query.type: - filters.append(ModelConfig.type.in_(query.type)) - - if query.is_active is not None: - filters.append(ModelConfig.is_active == query.is_active) - - if query.is_public is not None: - filters.append(ModelConfig.is_public == query.is_public) - - if query.search: - # 搜索逻辑需要join ModelApiKey表来搜索model_name - search_filter = or_( - ModelConfig.name.ilike(f"%{query.search}%"), - # ModelConfig.description.ilike(f"%{query.search}%") - ) - filters.append(search_filter) - - # 构建基础查询 - base_query = db.query(ModelConfig).options( - joinedload(ModelConfig.api_keys) - ) - - # 如果需要按provider筛选,需要join ModelApiKey表 - if query.provider: - base_query = base_query.join(ModelApiKey).filter( - ModelApiKey.provider == query.provider - ).distinct() - - if filters: - base_query = base_query.filter(and_(*filters)) - - # 获取总数 - total = base_query.count() - - # 分页查询 - models = base_query.order_by(desc(ModelConfig.updated_at)).offset( - (query.page - 1) * query.pagesize - ).limit(query.pagesize).all() - - db_logger.debug(f"模型配置列表查询成功: 总数={total}, 当前页={len(models)}, type筛选={query.type}") - return models, total - - except Exception as e: - db_logger.error(f"查询模型配置列表失败: {str(e)}") - raise - - @staticmethod - def get_by_type(db: Session, model_type: ModelType, is_active: bool = True) -> List[ModelConfig]: - """根据类型获取模型配置""" - db_logger.debug(f"根据类型查询模型配置: type={model_type}, is_active={is_active}") - - try: - query = db.query(ModelConfig).options( - joinedload(ModelConfig.api_keys) - ).filter(ModelConfig.type == model_type) - - if is_active: - query = query.filter(ModelConfig.is_active == True) - - models = query.order_by(ModelConfig.name).all() - db_logger.debug(f"根据类型查询模型配置成功: 数量={len(models)}") - return models - - except Exception as e: - db_logger.error(f"根据类型查询模型配置失败: type={model_type} - {str(e)}") - raise - - @staticmethod - def create(db: Session, model_data: dict) -> ModelConfig: - """创建模型配置""" - db_logger.debug(f"创建模型配置: {model_data.get('name')}") - - try: - db_model = ModelConfig(**model_data) - db.add(db_model) - - db_logger.info(f"模型配置已添加到会话: {db_model.name}") - return db_model - - except Exception as e: - db.rollback() - db_logger.error(f"创建模型配置失败: {model_data.get('name')} - {str(e)}") - raise - - @staticmethod - def update(db: Session, model_id: uuid.UUID, model_data: ModelConfigUpdate) -> Optional[ModelConfig]: - """更新模型配置""" - db_logger.debug(f"更新模型配置: model_id={model_id}") - - try: - db_model = db.query(ModelConfig).filter(ModelConfig.id == model_id).first() - if not db_model: - db_logger.warning(f"模型配置不存在: model_id={model_id}") - return None - - # 更新字段 - update_data = model_data.dict(exclude_unset=True) - for field, value in update_data.items(): - setattr(db_model, field, value) - - db.commit() - db.refresh(db_model) - - db_logger.info(f"模型配置更新成功: {db_model.name} (ID: {model_id})") - return db_model - - except Exception as e: - db.rollback() - db_logger.error(f"更新模型配置失败: model_id={model_id} - {str(e)}") - raise - - @staticmethod - def delete(db: Session, model_id: uuid.UUID) -> bool: - """删除模型配置""" - db_logger.debug(f"删除模型配置: model_id={model_id}") - - try: - db_model = db.query(ModelConfig).filter(ModelConfig.id == model_id).first() - if not db_model: - db_logger.warning(f"模型配置不存在: model_id={model_id}") - return False - - db.delete(db_model) - db.commit() - - db_logger.info(f"模型配置删除成功: model_id={model_id}") - return True - - except Exception as e: - db.rollback() - db_logger.error(f"删除模型配置失败: model_id={model_id} - {str(e)}") - raise - - @staticmethod - def get_stats(db: Session) -> Dict[str, Any]: - """获取模型统计信息""" - db_logger.debug("获取模型统计信息") - - try: - # 总数统计 - total_models = db.query(ModelConfig).count() - active_models = db.query(ModelConfig).filter(ModelConfig.is_active == True).count() - - # 按类型统计 - llm_count = db.query(ModelConfig).filter(ModelConfig.type == ModelType.LLM).count() - embedding_count = db.query(ModelConfig).filter(ModelConfig.type == ModelType.EMBEDDING).count() - rerank_count = db.query(ModelConfig).filter(ModelConfig.type == ModelType.RERANK).count() - - # 按提供商统计 - 现在从ModelApiKey表获取 - provider_stats = {} - provider_results = db.query( - ModelApiKey.provider, func.count(func.distinct(ModelApiKey.model_config_id)) - ).group_by(ModelApiKey.provider).all() - - for provider, count in provider_results: - provider_stats[provider.value] = count - - stats = { - "total_models": total_models, - "active_models": active_models, - "llm_count": llm_count, - "embedding_count": embedding_count, - "rerank_count": rerank_count, - "provider_stats": provider_stats - } - - db_logger.debug(f"模型统计信息获取成功: {stats}") - return stats - - except Exception as e: - db_logger.error(f"获取模型统计信息失败: {str(e)}") - raise - - -class ModelApiKeyRepository: - """模型API Key Repository""" - - @staticmethod - def get_by_id(db: Session, api_key_id: uuid.UUID) -> Optional[ModelApiKey]: - """根据ID获取API Key""" - db_logger.debug(f"根据ID查询API Key: api_key_id={api_key_id}") - - try: - api_key = db.query(ModelApiKey).filter(ModelApiKey.id == api_key_id).first() - if api_key: - db_logger.debug(f"API Key查询成功: {api_key.model_name} (ID: {api_key_id})") - return api_key - except Exception as e: - db_logger.error(f"根据ID查询API Key失败: api_key_id={api_key_id} - {str(e)}") - raise - - @staticmethod - def get_by_model_config(db: Session, model_config_id: uuid.UUID, is_active: bool = True) -> List[ModelApiKey]: - """根据模型配置ID获取API Key列表""" - db_logger.debug(f"根据模型配置ID查询API Key: model_config_id={model_config_id}") - - try: - query = db.query(ModelApiKey).filter(ModelApiKey.model_config_id == model_config_id) - - if is_active: - query = query.filter(ModelApiKey.is_active == True) - - api_keys = query.order_by(ModelApiKey.priority, ModelApiKey.created_at).all() - db_logger.debug(f"API Key列表查询成功: 数量={len(api_keys)}") - return api_keys - - except Exception as e: - db_logger.error(f"根据模型配置ID查询API Key失败: model_config_id={model_config_id} - {str(e)}") - raise - - @staticmethod - def create(db: Session, api_key_data: ModelApiKeyCreate) -> ModelApiKey: - """创建API Key""" - db_logger.debug(f"创建API Key: {api_key_data.provider}") - - try: - db_api_key = ModelApiKey(**api_key_data.dict()) - db.add(db_api_key) - - db_logger.info(f"API Key已添加到会话: {db_api_key.provider}") - return db_api_key - - except Exception as e: - db.rollback() - db_logger.error(f"创建API Key失败: {api_key_data.provider} - {str(e)}") - raise - - @staticmethod - def update(db: Session, api_key_id: uuid.UUID, api_key_data: ModelApiKeyUpdate) -> Optional[ModelApiKey]: - """更新API Key""" - db_logger.debug(f"更新API Key: api_key_id={api_key_id}") - - try: - db_api_key = db.query(ModelApiKey).filter(ModelApiKey.id == api_key_id).first() - if not db_api_key: - db_logger.warning(f"API Key不存在: api_key_id={api_key_id}") - return None - - # 更新字段 - update_data = api_key_data.dict(exclude_unset=True) - for field, value in update_data.items(): - setattr(db_api_key, field, value) - - db.commit() - db.refresh(db_api_key) - - db_logger.info(f"API Key更新成功: {db_api_key.model_name} (ID: {api_key_id})") - return db_api_key - - except Exception as e: - db.rollback() - db_logger.error(f"更新API Key失败: api_key_id={api_key_id} - {str(e)}") - raise - - @staticmethod - def delete(db: Session, api_key_id: uuid.UUID) -> bool: - """删除API Key""" - db_logger.debug(f"删除API Key: api_key_id={api_key_id}") - - try: - db_api_key = db.query(ModelApiKey).filter(ModelApiKey.id == api_key_id).first() - if not db_api_key: - db_logger.warning(f"API Key不存在: api_key_id={api_key_id}") - return False - - db.delete(db_api_key) - db.commit() - - db_logger.info(f"API Key删除成功: api_key_id={api_key_id}") - return True - - except Exception as e: - db.rollback() - db_logger.error(f"删除API Key失败: api_key_id={api_key_id} - {str(e)}") - raise - - @staticmethod - def update_usage(db: Session, api_key_id: uuid.UUID) -> bool: - """更新API Key使用统计""" - db_logger.debug(f"更新API Key使用统计: api_key_id={api_key_id}") - - try: - db_api_key = db.query(ModelApiKey).filter(ModelApiKey.id == api_key_id).first() - if not db_api_key: - return False - - # 更新使用次数和最后使用时间 - current_count = int(db_api_key.usage_count or "0") - db_api_key.usage_count = str(current_count + 1) - db_api_key.last_used_at = func.now() - - db.commit() - db_logger.debug(f"API Key使用统计更新成功: api_key_id={api_key_id}") - return True - - except Exception as e: - db.rollback() - db_logger.error(f"更新API Key使用统计失败: api_key_id={api_key_id} - {str(e)}") - raise \ No newline at end of file diff --git a/app/repositories/neo4j/__init__.py b/app/repositories/neo4j/__init__.py deleted file mode 100644 index 7f9e2ed8..00000000 --- a/app/repositories/neo4j/__init__.py +++ /dev/null @@ -1,32 +0,0 @@ -# -*- coding: utf-8 -*- -"""Neo4j仓储模块 - -本模块包含Neo4j图数据库的仓储实现,用于管理知识图谱的节点和边。 - -Modules: - neo4j_connector: Neo4j数据库连接器 - base_neo4j_repository: Neo4j仓储基类 - dialog_repository: 对话仓储 - statement_repository: 陈述句仓储 - entity_repository: 实体仓储 - cypher_queries: Cypher查询语句 - graph_search: 图搜索功能 - graph_saver: 图数据保存功能 - add_nodes: 添加节点功能 - add_edges: 添加边功能 - create_indexes: 创建索引功能 -""" - -from app.repositories.neo4j.neo4j_connector import Neo4jConnector -from app.repositories.neo4j.base_neo4j_repository import BaseNeo4jRepository -from app.repositories.neo4j.dialog_repository import DialogRepository -from app.repositories.neo4j.statement_repository import StatementRepository -from app.repositories.neo4j.entity_repository import EntityRepository - -__all__ = [ - 'Neo4jConnector', - 'BaseNeo4jRepository', - 'DialogRepository', - 'StatementRepository', - 'EntityRepository', -] diff --git a/app/repositories/neo4j/add_edges.py b/app/repositories/neo4j/add_edges.py deleted file mode 100644 index 1d4c050b..00000000 --- a/app/repositories/neo4j/add_edges.py +++ /dev/null @@ -1,102 +0,0 @@ -from typing import List, Optional -import hashlib -from datetime import datetime -from uuid import uuid4 -from app.repositories.neo4j.cypher_queries import CHUNK_STATEMENT_EDGE_SAVE, MEMORY_SUMMARY_STATEMENT_EDGE_SAVE -from app.core.memory.models.message_models import Chunk -# 使用新的仓储层 -from app.repositories.neo4j.neo4j_connector import Neo4jConnector -from app.core.memory.models.graph_models import MemorySummaryNode - -async def add_chunk_statement_edges(chunks: List[Chunk], connector: Neo4jConnector) -> Optional[List[str]]: - """Add edges between chunk nodes and their statement nodes in Neo4j. - - Args: - chunks: List of Chunk objects containing the statements - connector: Neo4j connector instance - - Returns: - List of created edge UUIDs or None if failed - """ - if not chunks: - print("No chunks provided to create edges") - return [] - - try: - # Build edges deterministically per (chunk, statement) pair - edges: List[dict] = [] - for chunk in chunks: - for stmt in getattr(chunk, "statements", []) or []: - stable_edge_id = hashlib.sha1(f"{chunk.id}|{stmt.id}".encode("utf-8")).hexdigest() - edge = { - "id": stable_edge_id, - "source": chunk.id, - "target": stmt.id, - "group_id": getattr(stmt, 'group_id', None), - "user_id":getattr(stmt, 'user_id', None), - "apply_id": getattr(stmt, 'apply_id', None), - "run_id": getattr(stmt, 'run_id', None) or getattr(chunk, 'run_id', None), - "created_at": getattr(stmt, 'created_at', None), - "expired_at": getattr(stmt, 'expired_at', None), - # "created_at": getattr(statement, 'created_at', None), - # "expired_at": None # Set to None or appropriate default - } - edges.append(edge) - - if not edges: - print("No statements found in chunks to create edges") - return [] - - # Execute the query to create edges - result = await connector.execute_query( - CHUNK_STATEMENT_EDGE_SAVE, - chunk_statement_edges=edges - ) - created_uuids = [record.get("uuid") for record in result] if result else [] - print(f"Successfully created {len(created_uuids)} chunk-statement edges") - return created_uuids - except Exception as e: - print(f"Error creating chunk-statement edges: {e}") - return None - -async def add_memory_summary_statement_edges(summaries: List[MemorySummaryNode], connector: Neo4jConnector) -> Optional[List[str]]: - """Create edges from MemorySummary to Statements via their chunk_ids. - - For each summary and each chunk_id in it, this links the summary to all statements - contained in that chunk using DERIVED_FROM_STATEMENT. This supports queries like - summary -> statement -> entity with minimal hops. - - Args: - summaries: List of MemorySummaryNode objects - connector: Neo4j connector instance - - Returns: - List of created edge elementIds or None if failed - """ - if not summaries: - return [] - - try: - edges: List[dict] = [] - for s in summaries: - for chunk_id in getattr(s, "chunk_ids", []) or []: - edges.append({ - "summary_id": s.id, - "chunk_id": chunk_id, - "group_id": s.group_id, - "run_id": s.run_id, - "created_at": s.created_at.isoformat() if s.created_at else None, - "expired_at": s.expired_at.isoformat() if s.expired_at else None, - }) - - if not edges: - return [] - - result = await connector.execute_query( - MEMORY_SUMMARY_STATEMENT_EDGE_SAVE, - edges=edges - ) - created = [record.get("uuid") for record in result] if result else [] - return created - except Exception: - return None diff --git a/app/repositories/neo4j/add_nodes.py b/app/repositories/neo4j/add_nodes.py deleted file mode 100644 index d339879f..00000000 --- a/app/repositories/neo4j/add_nodes.py +++ /dev/null @@ -1,215 +0,0 @@ -from typing import List, Optional - -from app.repositories.neo4j.cypher_queries import DIALOGUE_NODE_SAVE, STATEMENT_NODE_SAVE, CHUNK_NODE_SAVE,MEMORY_SUMMARY_NODE_SAVE -from app.core.memory.models.graph_models import DialogueNode, StatementNode, ChunkNode, MemorySummaryNode -# 使用新的仓储层 -from app.repositories.neo4j.neo4j_connector import Neo4jConnector - - -async def delete_all_nodes(group_id: str, connector: Neo4jConnector): - """Delete all nodes in the database.""" - result = await connector.execute_query(f"MATCH (n {{group_id: '{group_id}'}}) DETACH DELETE n") - print(f"All group_id: {group_id} node and edge deleted successfully") - return result - -async def add_dialogue_nodes(dialogues: List[DialogueNode], connector: Neo4jConnector) -> Optional[List[str]]: - """Add dialogue nodes to Neo4j database. - - Args: - dialogues: List of DialogueNode objects to save - connector: Neo4j connector instance - - Returns: - List of created node UUIDs or None if failed - """ - if not dialogues: - print("No dialogues to save") - return [] - - try: - # Flatten DialogueNode objects to match Cypher expected fields - flattened_dialogues = [] - for dialogue in dialogues: - flattened_dialogues.append({ - "id": dialogue.id, - "group_id": dialogue.group_id, - "user_id": dialogue.user_id, - "apply_id": dialogue.apply_id, - "run_id": dialogue.run_id, - "ref_id": dialogue.ref_id, - "name": dialogue.name, - "created_at": dialogue.created_at.isoformat() if dialogue.created_at else None, - "expired_at": dialogue.expired_at.isoformat() if dialogue.expired_at else None, - "content": dialogue.content, - "dialog_embedding": dialogue.dialog_embedding - }) - - result = await connector.execute_query( - DIALOGUE_NODE_SAVE, - dialogues=flattened_dialogues - ) - - created_uuids = [record["uuid"] for record in result] - print(f"Successfully created {len(created_uuids)} dialogue nodes: {created_uuids}") - return created_uuids - - except Exception as e: - print(f"Error creating dialogue nodes: {e}") - return None - - -async def add_statement_nodes(statements: List[StatementNode], connector: Neo4jConnector) -> Optional[List[str]]: - """Add statement nodes to Neo4j database. - - Args: - statements: List of StatementNode objects to save - connector: Neo4j connector instance - - Returns: - List of created node UUIDs or None if failed - """ - if not statements: - print("No statements to save") - return [] - - try: - # Flatten StatementNode objects to only include primitive types - flattened_statements = [] - for statement in statements: - flattened_statement = { - "id": statement.id, - "name": statement.name, - "group_id": statement.group_id, - "user_id": statement.user_id, - "apply_id": statement.apply_id, - "run_id": statement.run_id, - "chunk_id": statement.chunk_id, - # "created_at": statement.created_at.isoformat(), - "created_at": statement.created_at.isoformat() if statement.created_at else None, - "expired_at": statement.expired_at.isoformat() if statement.expired_at else None, - "stmt_type": statement.stmt_type, - "temporal_info": statement.temporal_info.value, - "statement": statement.statement, - "connect_strength": statement.connect_strength, - "chunk_embedding": statement.chunk_embedding if statement.chunk_embedding else None, - # "temporal_validity_valid_at": statement.temporal_validity_valid_at.isoformat() if statement.temporal_validity_valid_at else None, - # "temporal_validity_invalid_at": statement.temporal_validity_invalid_at.isoformat() if statement.temporal_validity_invalid_at else None, - "valid_at": statement.valid_at.isoformat() if statement.valid_at else None, - "invalid_at": statement.invalid_at.isoformat() if statement.invalid_at else None, - # "triplet_extraction_info": json.dumps({ - # "triplets": [triplet.model_dump() for triplet in statement.triplet_extraction_info.triplets] if statement.triplet_extraction_info else [], - # "entities": [entity.model_dump() for entity in statement.triplet_extraction_info.entities] if statement.triplet_extraction_info else [] - # }) if statement.triplet_extraction_info else json.dumps({"triplets": [], "entities": []}), - "statement_embedding": statement.statement_embedding if statement.statement_embedding else None - } - flattened_statements.append(flattened_statement) - - result = await connector.execute_query( - STATEMENT_NODE_SAVE, - statements=flattened_statements - ) - - created_uuids = [record["uuid"] for record in result] - print(f"Successfully created {len(created_uuids)} statement nodes") - return created_uuids - - except Exception as e: - print(f"Error creating statement nodes: {e}") - return None - -async def add_chunk_nodes(chunks: List[ChunkNode], connector: Neo4jConnector) -> Optional[List[str]]: - """Add chunk nodes to Neo4j in batch. - - Args: - chunks: List of ChunkNode objects to add - connector: Neo4j connector instance - - Returns: - List of created chunk UUIDs or None if failed - """ - if not chunks: - print("No chunk nodes to add") - return [] - - try: - # Convert chunk nodes to dictionaries for the query - flattened_chunks = [] - for chunk in chunks: - # Flatten metadata properties to avoid Neo4j Map type issues - metadata = chunk.metadata if chunk.metadata else {} - flattened_chunk = { - "id": chunk.id, - "name": chunk.name, - "group_id": chunk.group_id, - "user_id": chunk.user_id, - "apply_id": chunk.apply_id, - "run_id": chunk.run_id, - "created_at": chunk.created_at.isoformat() if chunk.created_at else None, - "expired_at": chunk.expired_at.isoformat() if chunk.expired_at else None, - "dialog_id": chunk.dialog_id, - "content": chunk.content, - "chunk_embedding": chunk.chunk_embedding if chunk.chunk_embedding else None, - "sequence_number": chunk.sequence_number, - "start_index": metadata.get("start_index"), - "end_index": metadata.get("end_index") - } - flattened_chunks.append(flattened_chunk) - - result = await connector.execute_query( - CHUNK_NODE_SAVE, - chunks=flattened_chunks - ) - - created_uuids = [record["uuid"] for record in result] - print(f"Successfully created {len(created_uuids)} chunk nodes") - return created_uuids - - except Exception as e: - print(f"Error creating chunk nodes: {e}") - return None - - - -async def add_memory_summary_nodes(summaries: List[MemorySummaryNode], connector: Neo4jConnector) -> Optional[List[str]]: - """Add memory summary nodes to Neo4j in batch. - - Args: - summaries: List of MemorySummaryNode objects to add - connector: Neo4j connector instance - - Returns: - List of created summary node ids or None if failed - """ - if not summaries: - print("No memory summary nodes to add") - return [] - - try: - flattened = [] - for s in summaries: - flattened.append({ - "id": s.id, - "name": s.name, - "group_id": s.group_id, - "user_id": s.user_id, - "apply_id": s.apply_id, - "run_id": s.run_id, - "created_at": s.created_at.isoformat() if s.created_at else None, - "expired_at": s.expired_at.isoformat() if s.expired_at else None, - "dialog_id": s.dialog_id, - "chunk_ids": s.chunk_ids, - "content": s.content, - "summary_embedding": s.summary_embedding if s.summary_embedding else None, - "config_id": s.config_id, # 添加 config_id - }) - - result = await connector.execute_query( - MEMORY_SUMMARY_NODE_SAVE, - summaries=flattened - ) - created_ids = [record.get("uuid") for record in result] - return created_ids - except Exception: - return None - - diff --git a/app/repositories/neo4j/base_neo4j_repository.py b/app/repositories/neo4j/base_neo4j_repository.py deleted file mode 100644 index 51a90078..00000000 --- a/app/repositories/neo4j/base_neo4j_repository.py +++ /dev/null @@ -1,175 +0,0 @@ -# -*- coding: utf-8 -*- -"""Neo4j仓储基类模块 - -本模块提供Neo4j仓储的基类实现,封装了通用的Neo4j节点操作。 - -Classes: - BaseNeo4jRepository: Neo4j仓储基类,实现通用的CRUD操作 -""" - -from typing import List, Optional, Dict, Any, TypeVar -from app.repositories.base_repository import BaseRepository -from app.repositories.neo4j.neo4j_connector import Neo4jConnector - -T = TypeVar('T') - - -class BaseNeo4jRepository(BaseRepository[T]): - """Neo4j仓储基类 - 实现通用的Neo4j节点操作 - - 这个基类封装了Neo4j节点的通用CRUD操作,子类只需要实现 - 特定的映射逻辑和业务查询方法。 - - Attributes: - connector: Neo4j连接器实例 - node_label: 节点标签(如"Dialogue", "Statement"等) - - Type Parameters: - T: 实体类型,通常是Pydantic模型 - """ - - def __init__(self, connector: Neo4jConnector, node_label: str): - """初始化Neo4j仓储 - - Args: - connector: Neo4j连接器实例 - node_label: 节点标签,用于Cypher查询 - """ - self.connector = connector - self.node_label = node_label - - async def create(self, entity: T) -> T: - """创建节点 - - 将实体对象转换为Neo4j节点并保存到数据库。 - - Args: - entity: 要创建的实体对象 - - Returns: - T: 创建后的实体对象 - - Example: - >>> dialog = DialogueNode(id="123", name="对话1", ...) - >>> created = await repository.create(dialog) - """ - query = f""" - CREATE (n:{self.node_label} $props) - RETURN n - """ - result = await self.connector.execute_query( - query, - props=entity.model_dump() - ) - return entity - - async def get_by_id(self, entity_id: str) -> Optional[T]: - """根据ID获取节点 - - Args: - entity_id: 节点ID - - Returns: - Optional[T]: 找到的实体对象,如果不存在则返回None - """ - query = f""" - MATCH (n:{self.node_label} {{id: $id}}) - RETURN n - """ - result = await self.connector.execute_query(query, id=entity_id) - if result: - return self._map_to_entity(result[0]) - return None - - async def update(self, entity: T) -> T: - """更新节点 - - 更新现有节点的属性。使用SET +=语法合并属性。 - - Args: - entity: 要更新的实体对象(必须包含id字段) - - Returns: - T: 更新后的实体对象 - """ - query = f""" - MATCH (n:{self.node_label} {{id: $id}}) - SET n += $props - RETURN n - """ - await self.connector.execute_query( - query, - id=entity.id, - props=entity.model_dump() - ) - return entity - - async def delete(self, entity_id: str) -> bool: - """删除节点 - - 删除指定ID的节点。使用DETACH DELETE同时删除相关的边。 - - Args: - entity_id: 要删除的节点ID - - Returns: - bool: 删除成功返回True,否则返回False - """ - query = f""" - MATCH (n:{self.node_label} {{id: $id}}) - DETACH DELETE n - RETURN count(n) as deleted - """ - result = await self.connector.execute_query(query, id=entity_id) - return result[0]['deleted'] > 0 if result else False - - async def find(self, filters: Dict[str, Any], limit: int = 100) -> List[T]: - """查询节点 - - 根据过滤条件查询节点列表。 - - Args: - filters: 查询条件字典,键为属性名,值为期望的值 - limit: 返回结果的最大数量 - - Returns: - List[T]: 符合条件的实体列表 - - Example: - >>> results = await repository.find( - ... {"group_id": "group_123", "user_id": "user_456"}, - ... limit=50 - ... ) - """ - # 构建查询条件 - where_clauses = [f"n.{key} = ${key}" for key in filters.keys()] - where_str = " AND ".join(where_clauses) if where_clauses else "1=1" - - query = f""" - MATCH (n:{self.node_label}) - WHERE {where_str} - RETURN n - LIMIT $limit - """ - results = await self.connector.execute_query( - query, - limit=limit, - **filters - ) - return [self._map_to_entity(r) for r in results] - - def _map_to_entity(self, node_data: Dict) -> T: - """将节点数据映射为实体对象 - - 这是一个抽象方法,子类必须实现具体的映射逻辑。 - - Args: - node_data: 从Neo4j查询返回的节点数据字典 - - Returns: - T: 映射后的实体对象 - - Raises: - NotImplementedError: 如果子类未实现此方法 - """ - raise NotImplementedError("Subclasses must implement _map_to_entity method") diff --git a/app/repositories/neo4j/create_indexes.py b/app/repositories/neo4j/create_indexes.py deleted file mode 100644 index 55dead1b..00000000 --- a/app/repositories/neo4j/create_indexes.py +++ /dev/null @@ -1,332 +0,0 @@ -from app.repositories.neo4j.neo4j_connector import Neo4jConnector - - -async def create_fulltext_indexes(): - """Create full-text indexes for keyword search with BM25 scoring.""" - connector = Neo4jConnector() - try: - print("\n" + "=" * 70) - print("Creating Full-Text Indexes (for keyword search)") - print("=" * 70) - - # 创建 Statements 索引 - await connector.execute_query(""" - CREATE FULLTEXT INDEX statementsFulltext IF NOT EXISTS FOR (s:Statement) ON EACH [s.statement] - OPTIONS { indexConfig: { `fulltext.analyzer`: 'cjk' } } - """) - print("✓ Created: statementsFulltext") - - # # 创建 Dialogues 索引 - # await connector.execute_query(""" - # CREATE FULLTEXT INDEX dialoguesFulltext IF NOT EXISTS FOR (d:Dialogue) ON EACH [d.content] - # OPTIONS { indexConfig: { `fulltext.analyzer`: 'cjk' } } - # """) - - # 创建 Entities 索引 - await connector.execute_query(""" - CREATE FULLTEXT INDEX entitiesFulltext IF NOT EXISTS FOR (e:ExtractedEntity) ON EACH [e.name] - OPTIONS { indexConfig: { `fulltext.analyzer`: 'cjk' } } - """) - print("✓ Created: entitiesFulltext") - - # 创建 Chunks 索引 - await connector.execute_query(""" - CREATE FULLTEXT INDEX chunksFulltext IF NOT EXISTS FOR (c:Chunk) ON EACH [c.content] - OPTIONS { indexConfig: { `fulltext.analyzer`: 'cjk' } } - """) - print("✓ Created: chunksFulltext") - - # 创建 MemorySummary 索引 - await connector.execute_query(""" - CREATE FULLTEXT INDEX summariesFulltext IF NOT EXISTS FOR (m:MemorySummary) ON EACH [m.content] - OPTIONS { indexConfig: { `fulltext.analyzer`: 'cjk' } } - """) - print("✓ Created: summariesFulltext") - - print("\nFull-text indexes created successfully with BM25 support.") - except Exception as e: - print(f"✗ Error creating full-text indexes: {e}") - finally: - await connector.close() - - -async def create_vector_indexes(): - """Create vector indexes for fast embedding similarity search. - - Vector indexes provide 10-100x faster similarity search compared to manual cosine calculation. - This is critical for performance - reduces embedding search from ~1.4s to ~0.05-0.2s! - """ - connector = Neo4jConnector() - try: - print("\n" + "=" * 70) - print("Creating Vector Indexes (for embedding search)") - print("=" * 70) - print("Note: Adjust vector.dimensions if using different embedding model") - print(" Current setting: 1024 dimensions (for bge-m3)") - print() - - # Statement embedding index - await connector.execute_query(""" - CREATE VECTOR INDEX statement_embedding_index IF NOT EXISTS - FOR (s:Statement) - ON s.statement_embedding - OPTIONS {indexConfig: { - `vector.dimensions`: 1024, - `vector.similarity_function`: 'cosine' - }} - """) - print("✓ Created: statement_embedding_index") - - # Chunk embedding index - await connector.execute_query(""" - CREATE VECTOR INDEX chunk_embedding_index IF NOT EXISTS - FOR (c:Chunk) - ON c.chunk_embedding - OPTIONS {indexConfig: { - `vector.dimensions`: 1024, - `vector.similarity_function`: 'cosine' - }} - """) - print("✓ Created: chunk_embedding_index") - - # Entity name embedding index - await connector.execute_query(""" - CREATE VECTOR INDEX entity_embedding_index IF NOT EXISTS - FOR (e:ExtractedEntity) - ON e.name_embedding - OPTIONS {indexConfig: { - `vector.dimensions`: 1024, - `vector.similarity_function`: 'cosine' - }} - """) - print("✓ Created: entity_embedding_index") - - # Memory summary embedding index - await connector.execute_query(""" - CREATE VECTOR INDEX summary_embedding_index IF NOT EXISTS - FOR (m:MemorySummary) - ON m.summary_embedding - OPTIONS {indexConfig: { - `vector.dimensions`: 1024, - `vector.similarity_function`: 'cosine' - }} - """) - print("✓ Created: summary_embedding_index") - - # Dialogue embedding index (optional) - await connector.execute_query(""" - CREATE VECTOR INDEX dialogue_embedding_index IF NOT EXISTS - FOR (d:Dialogue) - ON d.dialog_embedding - OPTIONS {indexConfig: { - `vector.dimensions`: 1024, - `vector.similarity_function`: 'cosine' - }} - """) - print("✓ Created: dialogue_embedding_index") - - print("\nVector indexes created successfully!") - print("\nExpected performance improvement:") - print(" Before: ~1.4s for embedding search") - print(" After: ~0.05-0.2s for embedding search (10-30x faster!)") - - except Exception as e: - print(f"✗ Error creating vector indexes: {e}") - finally: - await connector.close() - - -async def create_config_id_indexes(): - """Create indexes on config_id fields for improved query performance. - - These indexes enable fast filtering of nodes by configuration ID, - which is essential for configuration isolation and multi-tenant scenarios. - """ - connector = Neo4jConnector() - try: - print("\n" + "=" * 70) - print("Creating Config ID Indexes") - print("=" * 70) - - # Dialogue.config_id index - await connector.execute_query(""" - CREATE INDEX dialogue_config_id_index IF NOT EXISTS - FOR (d:Dialogue) ON (d.config_id) - """) - print("✓ Created: dialogue_config_id_index") - - # Statement.config_id index - await connector.execute_query(""" - CREATE INDEX statement_config_id_index IF NOT EXISTS - FOR (s:Statement) ON (s.config_id) - """) - print("✓ Created: statement_config_id_index") - - # ExtractedEntity.config_id index - await connector.execute_query(""" - CREATE INDEX entity_config_id_index IF NOT EXISTS - FOR (e:ExtractedEntity) ON (e.config_id) - """) - print("✓ Created: entity_config_id_index") - - # MemorySummary.config_id index - await connector.execute_query(""" - CREATE INDEX summary_config_id_index IF NOT EXISTS - FOR (m:MemorySummary) ON (m.config_id) - """) - print("✓ Created: summary_config_id_index") - - print("\nConfig ID indexes created successfully!") - print("These indexes enable fast filtering by configuration ID.") - - except Exception as e: - print(f"✗ Error creating config_id indexes: {e}") - finally: - await connector.close() - - -async def create_unique_constraints(): - """Create uniqueness constraints for core node identifiers. - - Ensures concurrent MERGE operations remain safe and prevents duplicates. - """ - connector = Neo4jConnector() - try: - print("\n" + "=" * 70) - print("Creating Unique Constraints") - print("=" * 70) - - # Dialogue.id unique - await connector.execute_query( - """ - CREATE CONSTRAINT dialog_id_unique IF NOT EXISTS - FOR (d:Dialogue) REQUIRE d.id IS UNIQUE - """ - ) - print("✓ Created: dialog_id_unique") - - # Statement.id unique - await connector.execute_query( - """ - CREATE CONSTRAINT statement_id_unique IF NOT EXISTS - FOR (s:Statement) REQUIRE s.id IS UNIQUE - """ - ) - print("✓ Created: statement_id_unique") - - # Chunk.id unique - await connector.execute_query( - """ - CREATE CONSTRAINT chunk_id_unique IF NOT EXISTS - FOR (c:Chunk) REQUIRE c.id IS UNIQUE - """ - ) - print("✓ Created: chunk_id_unique") - - print("\nUnique constraints ensured for Dialogue, Statement, and Chunk.") - except Exception as e: - print(f"✗ Error creating unique constraints: {e}") - finally: - await connector.close() - - -async def create_all_indexes(): - """Create all indexes and constraints in one go.""" - print("\n" + "=" * 70) - print("Neo4j Index & Constraint Setup") - print("=" * 70) - print("This will create:") - print(" 1. Full-text indexes (for keyword/BM25 search)") - print(" 2. Vector indexes (for embedding similarity search)") - print(" 3. Config ID indexes (for configuration isolation)") - print(" 4. Unique constraints (for data integrity)") - print("=" * 70) - - await create_fulltext_indexes() - await create_vector_indexes() - await create_config_id_indexes() - await create_unique_constraints() - - print("\n" + "=" * 70) - print("✓ All indexes and constraints created successfully!") - print("=" * 70) - print("\nTo verify, run in Neo4j Browser:") - print(" SHOW INDEXES") - print(" SHOW CONSTRAINTS") - print() - - -async def check_indexes(): - """Check what indexes currently exist.""" - connector = Neo4jConnector() - - try: - print("\n" + "=" * 70) - print("Checking Existing Indexes") - print("=" * 70) - - query = "SHOW INDEXES" - result = await connector.execute_query(query) - - fulltext_indexes = [idx for idx in result if idx.get('type') == 'FULLTEXT'] - vector_indexes = [idx for idx in result if idx.get('type') == 'VECTOR'] - range_indexes = [idx for idx in result if idx.get('type') == 'RANGE'] - - print(f"\nFull-text indexes: {len(fulltext_indexes)}") - for idx in fulltext_indexes: - print(f" ✓ {idx.get('name')}") - - print(f"\nVector indexes: {len(vector_indexes)}") - for idx in vector_indexes: - print(f" ✓ {idx.get('name')}") - - print(f"\nRange indexes (including config_id): {len(range_indexes)}") - for idx in range_indexes: - print(f" ✓ {idx.get('name')}") - - if not vector_indexes: - print("\n⚠️ WARNING: No vector indexes found!") - print(" Embedding search will be VERY SLOW (~1.4s)") - print(" Run: python create_indexes.py") - - # Check for config_id indexes - config_id_indexes = [idx for idx in range_indexes if 'config_id' in idx.get('name', '')] - if len(config_id_indexes) < 4: - print("\n⚠️ WARNING: Not all config_id indexes found!") - print(f" Expected 4, found {len(config_id_indexes)}") - print(" Run: python create_indexes.py config_id") - - print("=" * 70) - - finally: - await connector.close() - - -if __name__ == "__main__": - import asyncio - import sys - - if len(sys.argv) > 1: - command = sys.argv[1] - if command == "check": - asyncio.run(check_indexes()) - elif command == "fulltext": - asyncio.run(create_fulltext_indexes()) - elif command == "vector": - asyncio.run(create_vector_indexes()) - elif command == "config_id": - asyncio.run(create_config_id_indexes()) - elif command == "constraints": - asyncio.run(create_unique_constraints()) - else: - print(f"Unknown command: {command}") - print("\nUsage:") - print(" python create_indexes.py # Create all indexes") - print(" python create_indexes.py check # Check existing indexes") - print(" python create_indexes.py fulltext # Create only full-text indexes") - print(" python create_indexes.py vector # Create only vector indexes") - print(" python create_indexes.py config_id # Create only config_id indexes") - print(" python create_indexes.py constraints # Create only constraints") - else: - asyncio.run(create_all_indexes()) - diff --git a/app/repositories/neo4j/cypher_queries.py b/app/repositories/neo4j/cypher_queries.py deleted file mode 100644 index 1f9943f8..00000000 --- a/app/repositories/neo4j/cypher_queries.py +++ /dev/null @@ -1,684 +0,0 @@ - -DIALOGUE_NODE_SAVE = """ - UNWIND $dialogues AS dialogue - MERGE (n:Dialogue {id: dialogue.id}) - SET n.uuid = coalesce(n.uuid, dialogue.id), - n.group_id = dialogue.group_id, - n.user_id = dialogue.user_id, - n.apply_id = dialogue.apply_id, - n.run_id = dialogue.run_id, - n.ref_id = dialogue.ref_id, - n.created_at = dialogue.created_at, - n.expired_at = dialogue.expired_at, - n.content = dialogue.content, - n.dialog_embedding = dialogue.dialog_embedding - RETURN n.id AS uuid -""" - -STATEMENT_NODE_SAVE = """ -UNWIND $statements AS statement -MERGE (s:Statement {id: statement.id}) -SET s += { - id: statement.id, - group_id: statement.group_id, - user_id: statement.user_id, - apply_id: statement.apply_id, - chunk_id: statement.chunk_id, - run_id: statement.run_id, - created_at: statement.created_at, - expired_at: statement.expired_at, - stmt_type: statement.stmt_type, - temporal_info: statement.temporal_info, - relevence_info: statement.relevence_info, - statement: statement.statement, - valid_at: statement.valid_at, - invalid_at: statement.invalid_at, - statement_embedding: statement.statement_embedding -} -RETURN s.id AS uuid -""" - -CHUNK_NODE_SAVE = """ -UNWIND $chunks AS chunk -MERGE (c:Chunk {id: chunk.id}) -SET c += { - id: chunk.id, - name: chunk.name, - group_id: chunk.group_id, - user_id: chunk.user_id, - apply_id: chunk.apply_id, - run_id: chunk.run_id, - created_at: chunk.created_at, - expired_at: chunk.expired_at, - dialog_id: chunk.dialog_id, - content: chunk.content, - chunk_embedding: chunk.chunk_embedding, - sequence_number: chunk.sequence_number, - start_index: chunk.start_index, - end_index: chunk.end_index -} -RETURN c.id AS uuid -""" -# bug修改点 - -EXTRACTED_ENTITY_NODE_SAVE = """ -// Upsert entity nodes safely: preserve existing non-empty fields when incoming is empty -UNWIND $entities AS entity -MERGE (e:ExtractedEntity {id: entity.id}) -SET e.name = CASE WHEN entity.name IS NOT NULL AND entity.name <> '' THEN entity.name ELSE e.name END, - e.group_id = CASE WHEN entity.group_id IS NOT NULL AND entity.group_id <> '' THEN entity.group_id ELSE e.group_id END, - e.user_id = CASE WHEN entity.user_id IS NOT NULL AND entity.user_id <> '' THEN entity.user_id ELSE e.user_id END, - e.apply_id = CASE WHEN entity.apply_id IS NOT NULL AND entity.apply_id <> '' THEN entity.apply_id ELSE e.apply_id END, - e.run_id = CASE WHEN entity.run_id IS NOT NULL AND entity.run_id <> '' THEN entity.run_id ELSE e.run_id END, - e.created_at = CASE - WHEN entity.created_at IS NOT NULL AND (e.created_at IS NULL OR entity.created_at < e.created_at) - THEN entity.created_at ELSE e.created_at END, - e.expired_at = CASE - WHEN entity.expired_at IS NOT NULL AND (e.expired_at IS NULL OR entity.expired_at > e.expired_at) - THEN entity.expired_at ELSE e.expired_at END, - e.entity_idx = CASE WHEN e.entity_idx IS NULL OR e.entity_idx = 0 THEN entity.entity_idx ELSE e.entity_idx END, - e.entity_type = CASE WHEN entity.entity_type IS NOT NULL AND entity.entity_type <> '' THEN entity.entity_type ELSE e.entity_type END, - e.description = CASE - WHEN entity.description IS NOT NULL AND entity.description <> '' - AND (e.description IS NULL OR size(e.description) = 0 OR size(entity.description) > size(e.description)) - THEN entity.description ELSE e.description END, - e.statement_id = CASE WHEN entity.statement_id IS NOT NULL AND entity.statement_id <> '' THEN entity.statement_id ELSE e.statement_id END, - e.aliases = CASE - WHEN entity.aliases IS NOT NULL AND size(entity.aliases) > 0 - THEN CASE WHEN e.aliases IS NULL THEN entity.aliases ELSE e.aliases + entity.aliases END - ELSE e.aliases END, - e.name_embedding = CASE - WHEN entity.name_embedding IS NOT NULL AND size(entity.name_embedding) > 0 THEN entity.name_embedding - ELSE e.name_embedding END, - e.fact_summary = CASE - WHEN entity.fact_summary IS NOT NULL AND entity.fact_summary <> '' - AND (e.fact_summary IS NULL OR size(e.fact_summary) = 0 OR size(entity.fact_summary) > size(e.fact_summary)) - THEN entity.fact_summary ELSE e.fact_summary END, - e.connect_strength = CASE - WHEN entity.connect_strength IS NULL OR entity.connect_strength = '' THEN e.connect_strength - ELSE CASE - WHEN e.connect_strength = 'strong' AND entity.connect_strength = 'weak' THEN 'both' - WHEN e.connect_strength = 'weak' AND entity.connect_strength = 'strong' THEN 'both' - WHEN e.connect_strength IS NULL OR e.connect_strength = '' THEN entity.connect_strength - ELSE e.connect_strength - END - END -RETURN e.id AS uuid -""" - -# Add back ENTITY_RELATIONSHIP_SAVE to be used by graph_saver.save_entities_and_relationships -ENTITY_RELATIONSHIP_SAVE = """ -UNWIND $relationships AS rel -// Match entities by stable id within group, do not constrain by run_id -MATCH (subject:ExtractedEntity {id: rel.source_id, group_id: rel.group_id}) -MATCH (object:ExtractedEntity {id: rel.target_id, group_id: rel.group_id}) -// Avoid duplicate edges across runs for the same endpoints -MERGE (subject)-[r:EXTRACTED_RELATIONSHIP]->(object) -SET r.predicate = rel.predicate, - r.statement_id = rel.statement_id, - r.value = rel.value, - r.statement = rel.statement, - r.valid_at = rel.valid_at, - r.invalid_at = rel.invalid_at, - r.created_at = rel.created_at, - r.expired_at = rel.expired_at, - r.run_id = rel.run_id, - r.group_id = rel.group_id -RETURN elementId(r) AS uuid -""" - -# 在 Neo4j 5及后续版本中,id() 函数已被标记为弃用,用elementId() 函数替代 - -# 保存弱关系实体,设置 e.is_weak = true;不维护 e.relations 聚合字段 -WEAK_ENTITY_NODE_SAVE = """ -UNWIND $weak_entities AS entity -MERGE (e:ExtractedEntity {id: entity.id, run_id: entity.run_id}) -SET e += { - name: entity.name, - group_id: entity.group_id, - run_id: entity.run_id, - description: entity.description, - chunk_id: entity.chunk_id, - dialog_id: entity.dialog_id -} -// Independent weak flag,仅标记弱关系,不再维护 relations 聚合字段 -SET e.is_weak = true -RETURN e.id AS id -""" - -# 为强关系三元组中的主语和宾语创建/更新实体节点,仅设置 e.is_strong = true,不维护 e.relations 字段 -SAVE_STRONG_TRIPLE_ENTITIES = """ -UNWIND $items AS item -MERGE (s:ExtractedEntity {id: item.source_id, run_id: item.run_id}) -SET s += {name: item.subject, group_id: item.group_id, run_id: item.run_id} -// Independent strong flag -SET s.is_strong = true -MERGE (o:ExtractedEntity {id: item.target_id, run_id: item.run_id}) -SET o += {name: item.object, group_id: item.group_id, run_id: item.run_id} -// Independent strong flag -SET o.is_strong = true -""" - - -DIALOGUE_STATEMENT_EDGE_SAVE = """ - UNWIND $dialogue_statement_edges AS edge - // 支持按 uuid 或 ref_id 连接到 Dialogue,避免因来源 ID 不一致而断链 - MATCH (dialogue:Dialogue) - WHERE dialogue.uuid = edge.source OR dialogue.ref_id = edge.source - MATCH (statement:Statement {id: edge.target}) - // 仅按端点去重,关系属性可更新 - MERGE (dialogue)-[e:MENTIONS]->(statement) - SET e.uuid = edge.id, - e.group_id = edge.group_id, - e.created_at = edge.created_at, - e.expired_at = edge.expired_at - RETURN e.uuid AS uuid -""" - -# 在 Neo4j 5及后续版本中,id() 函数已被标记为弃用,用elementId() 函数替代 - - -CHUNK_STATEMENT_EDGE_SAVE = """ - UNWIND $chunk_statement_edges AS edge - MATCH (statement:Statement {id: edge.source, run_id: edge.run_id}) - MATCH (chunk:Chunk {id: edge.target, run_id: edge.run_id}) - MERGE (chunk)-[e:CONTAINS {id: edge.id}]->(statement) - SET e.group_id = edge.group_id, - e.run_id = edge.run_id, - e.created_at = edge.created_at, - e.expired_at = edge.expired_at - RETURN e.id AS uuid -""" - -STATEMENT_ENTITY_EDGE_SAVE = """ -UNWIND $relationships AS rel -// Statement nodes are per-run; keep run_id constraint on statements -// Statement nodes are per-run; keep run_id constraint on statements -MATCH (statement:Statement {id: rel.source, run_id: rel.run_id}) -// Entities are shared across runs within a group; do not constrain by run_id -MATCH (entity:ExtractedEntity {id: rel.target, group_id: rel.group_id}) -// Avoid duplicate edges across runs for same endpoints -MERGE (statement)-[r:REFERENCES_ENTITY]->(entity) -SET r.group_id = rel.group_id, - r.run_id = rel.run_id, - r.created_at = rel.created_at, - r.expired_at = rel.expired_at, - r.connect_strength = rel.connect_strength -RETURN elementId(r) AS uuid -""" - -ENTITY_EMBEDDING_SEARCH = """ -CALL db.index.vector.queryNodes('entity_embedding_index', $limit * 100, $embedding) -YIELD node AS e, score -WHERE e.name_embedding IS NOT NULL - AND ($group_id IS NULL OR e.group_id = $group_id) -RETURN e.id AS id, - e.name AS name, - e.group_id AS group_id, - e.entity_type AS entity_type, - score -ORDER BY score DESC -LIMIT $limit -""" -# Embedding-based search: cosine similarity on Statement.statement_embedding -STATEMENT_EMBEDDING_SEARCH = """ -CALL db.index.vector.queryNodes('statement_embedding_index', $limit * 100, $embedding) -YIELD node AS s, score -WHERE s.statement_embedding IS NOT NULL - AND ($group_id IS NULL OR s.group_id = $group_id) -RETURN s.id AS id, - s.statement AS statement, - s.group_id AS group_id, - s.chunk_id AS chunk_id, - s.created_at AS created_at, - s.expired_at AS expired_at, - s.valid_at AS valid_at, - s.invalid_at AS invalid_at, - score -ORDER BY score DESC -LIMIT $limit -""" - -# Embedding-based search: cosine similarity on Chunk.chunk_embedding -CHUNK_EMBEDDING_SEARCH = """ -CALL db.index.vector.queryNodes('chunk_embedding_index', $limit * 100, $embedding) -YIELD node AS c, score -WHERE c.chunk_embedding IS NOT NULL - AND ($group_id IS NULL OR c.group_id = $group_id) -RETURN c.id AS chunk_id, - c.group_id AS group_id, - c.content AS content, - c.dialog_id AS dialog_id, - score -ORDER BY score DESC -LIMIT $limit -""" - -SEARCH_STATEMENTS_BY_KEYWORD = """ -CALL db.index.fulltext.queryNodes("statementsFulltext", $q) YIELD node AS s, score -WHERE ($group_id IS NULL OR s.group_id = $group_id) -OPTIONAL MATCH (c:Chunk)-[:CONTAINS]->(s) -OPTIONAL MATCH (s)-[:REFERENCES_ENTITY]->(e:ExtractedEntity) -RETURN s.id AS id, - s.statement AS statement, - s.group_id AS group_id, - s.chunk_id AS chunk_id, - s.created_at AS created_at, - s.expired_at AS expired_at, - s.valid_at AS valid_at, - s.invalid_at AS invalid_at, - c.id AS chunk_id_from_rel, - collect(DISTINCT e.id) AS entity_ids, - score -ORDER BY score DESC -LIMIT $limit -""" -# 查询实体名称包含指定字符串的实体 -SEARCH_ENTITIES_BY_NAME = """ -CALL db.index.fulltext.queryNodes("entitiesFulltext", $q) YIELD node AS e, score -WHERE ($group_id IS NULL OR e.group_id = $group_id) -OPTIONAL MATCH (s:Statement)-[:REFERENCES_ENTITY]->(e) -OPTIONAL MATCH (c:Chunk)-[:CONTAINS]->(s) -RETURN e.id AS id, - e.name AS name, - e.group_id AS group_id, - e.entity_type AS entity_type, - e.apply_id AS apply_id, - e.user_id AS user_id, - e.created_at AS created_at, - e.expired_at AS expired_at, - e.entity_idx AS entity_idx, - e.statement_id AS statement_id, - e.description AS description, - e.aliases AS aliases, - e.name_embedding AS name_embedding, - e.fact_summary AS fact_summary, - e.connect_strength AS connect_strength, - collect(DISTINCT s.id) AS statement_ids, - collect(DISTINCT c.id) AS chunk_ids, - score -ORDER BY score DESC -LIMIT $limit -""" - -SEARCH_CHUNKS_BY_CONTENT = """ -CALL db.index.fulltext.queryNodes("chunksFulltext", $q) YIELD node AS c, score -WHERE ($group_id IS NULL OR c.group_id = $group_id) -OPTIONAL MATCH (c)-[:CONTAINS]->(s:Statement) -OPTIONAL MATCH (s)-[:REFERENCES_ENTITY]->(e:ExtractedEntity) -RETURN c.id AS chunk_id, - c.group_id AS group_id, - c.content AS content, - c.dialog_id AS dialog_id, - c.sequence_number AS sequence_number, - collect(DISTINCT s.id) AS statement_ids, - collect(DISTINCT e.id) AS entity_ids, - score -ORDER BY score DESC -LIMIT $limit -""" - -# 以下是关于第二层去重消歧与数据库进行检索的语句,在最近的规划中不再使用 - -# # 同组group_id下按“精确名字或别名+可选类型一致”来检索 -# SECOND_LAYER_CANDIDATE_MATCH_BATCH = """ -# UNWIND $rows AS row -# MATCH (e:ExtractedEntity) -# WHERE e.group_id = row.group_id -# AND (toLower(e.name) = toLower(row.name) OR any(a IN e.aliases WHERE toLower(a) = toLower(row.name))) -# AND (row.entity_type IS NULL OR e.entity_type = row.entity_type) -# RETURN row.id AS incoming_id, -# e.id AS id, -# e.name AS name, -# e.group_id AS group_id, -# e.entity_idx AS entity_idx, -# e.entity_type AS entity_type, -# e.description AS description, -# e.statement_id AS statement_id, -# e.aliases AS aliases, -# e.name_embedding AS name_embedding, -# e.fact_summary AS fact_summary, -# e.connect_strength AS connect_strength, -# e.created_at AS created_at, -# e.expired_at AS expired_at -# """ -# # 同组group_id下按name contains召回补充 -# SECOND_LAYER_CANDIDATE_CONTAINS_BATCH = """ -# UNWIND $rows AS row -# MATCH (e:ExtractedEntity) -# WHERE e.group_id = row.group_id -# AND toLower(e.name) CONTAINS toLower(row.name) -# RETURN row.id AS incoming_id, -# e.id AS id, -# e.name AS name, -# e.group_id AS group_id, -# e.entity_idx AS entity_idx, -# e.entity_type AS entity_type, -# e.description AS description, -# e.statement_id AS statement_id, -# e.aliases AS aliases, -# e.name_embedding AS name_embedding, -# e.fact_summary AS fact_summary, -# e.connect_strength AS connect_strength, -# e.created_at AS created_at, -# e.expired_at AS expired_at -# """ - -SEARCH_DIALOGUE_BY_DIALOG_ID = """ -MATCH (d:Dialogue) -WHERE ($group_id IS NULL OR d.group_id = $group_id) - AND d.id = $dialog_id -RETURN d.id AS dialog_id, - d.group_id AS group_id, - d.content AS content, - d.created_at AS created_at, - d.expired_at AS expired_at -ORDER BY d.created_at DESC -LIMIT $limit -""" - -SEARCH_CHUNK_BY_CHUNK_ID = """ -MATCH (c:Chunk) -WHERE ($group_id IS NULL OR c.group_id = $group_id) - AND c.id = $chunk_id -RETURN c.id AS chunk_id, - c.group_id AS group_id, - c.content AS content, - c.dialog_id AS dialog_id, - c.created_at AS created_at, - c.expired_at AS expired_at, - c.sequence_number AS sequence_number -ORDER BY c.created_at DESC -LIMIT $limit -""" - -SEARCH_STATEMENTS_BY_TEMPORAL = """ -MATCH (s:Statement) -WHERE ($group_id IS NULL OR s.group_id = $group_id) - AND ($apply_id IS NULL OR s.apply_id = $apply_id) - AND ($user_id IS NULL OR s.user_id = $user_id) - AND ((($start_date IS NULL OR datetime(s.created_at) >= datetime($start_date)) - AND ($end_date IS NULL OR datetime(s.created_at) <= datetime($end_date))) - OR (($valid_date IS NULL OR (s.valid_at IS NOT NULL AND datetime(s.valid_at) >= datetime($valid_date))) - AND ($invalid_date IS NULL OR (s.invalid_at IS NOT NULL AND datetime(s.invalid_at) <= datetime($invalid_date))))) -RETURN s.id AS id, - s.statement AS statement, - s.group_id AS group_id, - s.apply_id AS apply_id, - s.user_id AS user_id, - s.chunk_id AS chunk_id, - s.created_at AS created_at, - s.valid_at AS valid_at, - s.invalid_at AS invalid_at, - collect(DISTINCT s.id) AS statement_ids -ORDER BY datetime(s.created_at) DESC -LIMIT $limit -""" - -SEARCH_STATEMENTS_BY_KEYWORD_TEMPORAL = """ -CALL db.index.fulltext.queryNodes("statementsFulltext", $q) YIELD node AS s, score -WHERE ($group_id IS NULL OR s.group_id = $group_id) - AND ($apply_id IS NULL OR s.apply_id = $apply_id) - AND ($user_id IS NULL OR s.user_id = $user_id) - AND ((($start_date IS NULL OR (s.created_at IS NOT NULL AND datetime(s.created_at) >= datetime($start_date))) - AND ($end_date IS NULL OR (s.created_at IS NOT NULL AND datetime(s.created_at) <= datetime($end_date)))) - OR (($valid_date IS NULL OR (s.valid_at IS NOT NULL AND datetime(s.valid_at) >= datetime($valid_date))) - AND ($invalid_date IS NULL OR (s.invalid_at IS NOT NULL AND datetime(s.invalid_at) <= datetime($invalid_date))))) -OPTIONAL MATCH (c:Chunk)-[:CONTAINS]->(s) -OPTIONAL MATCH (s)-[:REFERENCES_ENTITY]->(e:ExtractedEntity) -RETURN s.id AS id, - s.statement AS statement, - s.group_id AS group_id, - s.apply_id AS apply_id, - s.user_id AS user_id, - s.chunk_id AS chunk_id, - s.created_at AS created_at, - s.valid_at AS valid_at, - s.invalid_at AS invalid_at, - c.id AS chunk_id_from_rel, - collect(DISTINCT e.id) AS entity_ids, - score -ORDER BY s.created_at DESC, score DESC -LIMIT $limit -""" - -SEARCH_STATEMENTS_BY_CREATED_AT = """ -MATCH (n:Statement) -WHERE ($group_id IS NULL OR n.group_id = $group_id) - AND ($apply_id IS NULL OR n.apply_id = $apply_id) - AND ($user_id IS NULL OR n.user_id = $user_id) - AND ($created_at IS NOT NULL AND date(substring(n.created_at, 0, 10)) = date($created_at)) -RETURN n.id AS id, - n.statement AS statement, - n.group_id AS group_id, - n.apply_id AS apply_id, - n.user_id AS user_id, - n.chunk_id AS chunk_id, - n.created_at AS created_at, - n.valid_at AS valid_at, - n.invalid_at AS invalid_at, - collect(DISTINCT n.id) AS statement_ids -ORDER BY n.created_at DESC -LIMIT $limit -""" - -SEARCH_STATEMENTS_BY_VALID_AT = """ -MATCH (n:Statement) -WHERE ($group_id IS NULL OR n.group_id = $group_id) - AND ($apply_id IS NULL OR n.apply_id = $apply_id) - AND ($user_id IS NULL OR n.user_id = $user_id) - AND ($valid_at IS NOT NULL AND date(substring(n.valid_at, 0, 10)) = date($valid_at)) -RETURN n.id AS id, - n.statement AS statement, - n.group_id AS group_id, - n.apply_id AS apply_id, - n.user_id AS user_id, - n.chunk_id AS chunk_id, - n.created_at AS created_at, - n.valid_at AS valid_at, - n.invalid_at AS invalid_at, - collect(DISTINCT n.id) AS statement_ids -ORDER BY n.valid_at DESC -LIMIT $limit -""" - -SEARCH_STATEMENTS_G_CREATED_AT = """ -MATCH (n:Statement) -WHERE ($group_id IS NULL OR n.group_id = $group_id) - AND ($apply_id IS NULL OR n.apply_id = $apply_id) - AND ($user_id IS NULL OR n.user_id = $user_id) - AND ($created_at IS NOT NULL AND date(substring(n.created_at, 0, 19)) = date($created_at)) -RETURN n.id AS id, - n.statement AS statement, - n.group_id AS group_id, - n.apply_id AS apply_id, - n.user_id AS user_id, - n.chunk_id AS chunk_id, - n.created_at AS created_at, - n.valid_at AS valid_at, - n.invalid_at AS invalid_at, - collect(DISTINCT n.id) AS statement_ids -ORDER BY n.created_at DESC -LIMIT $limit -""" - -SEARCH_STATEMENTS_L_CREATED_AT = """ -MATCH (n:Statement) -WHERE ($group_id IS NULL OR n.group_id = $group_id) - AND ($apply_id IS NULL OR n.apply_id = $apply_id) - AND ($user_id IS NULL OR n.user_id = $user_id) - AND ($created_at IS NOT NULL AND date(substring(n.created_at, 0, 19)) < date($created_at)) -RETURN n.id AS id, - n.statement AS statement, - n.group_id AS group_id, - n.apply_id AS apply_id, - n.user_id AS user_id, - n.chunk_id AS chunk_id, - n.created_at AS created_at, - n.valid_at AS valid_at, - n.invalid_at AS invalid_at, - collect(DISTINCT n.id) AS statement_ids -ORDER BY n.created_at DESC -LIMIT $limit -""" - -SEARCH_STATEMENTS_G_VALID_AT = """ -MATCH (n:Statement) -WHERE ($group_id IS NULL OR n.group_id = $group_id) - AND ($apply_id IS NULL OR n.apply_id = $apply_id) - AND ($user_id IS NULL OR n.user_id = $user_id) - AND ($valid_at IS NOT NULL AND date(substring(n.valid_at, 0, 10)) > date($valid_at)) -RETURN n.id AS id, - n.statement AS statement, - n.group_id AS group_id, - n.apply_id AS apply_id, - n.user_id AS user_id, - n.chunk_id AS chunk_id, - n.created_at AS created_at, - n.valid_at AS valid_at, - n.invalid_at AS invalid_at, - collect(DISTINCT n.id) AS statement_ids -ORDER BY n.valid_at DESC -LIMIT $limit -""" - -SEARCH_STATEMENTS_L_VALID_AT = """ -MATCH (n:Statement) -WHERE ($group_id IS NULL OR n.group_id = $group_id) - AND ($apply_id IS NULL OR n.apply_id = $apply_id) - AND ($user_id IS NULL OR n.user_id = $user_id) - AND ($valid_at IS NOT NULL AND date(substring(n.valid_at, 0, 10)) < date($valid_at)) -RETURN n.id AS id, - n.statement AS statement, - n.group_id AS group_id, - n.apply_id AS apply_id, - n.user_id AS user_id, - n.chunk_id AS chunk_id, - n.created_at AS created_at, - n.valid_at AS valid_at, - n.invalid_at AS invalid_at, - collect(DISTINCT n.id) AS statement_ids -ORDER BY n.valid_at DESC -LIMIT $limit -""" - -# 以下是关于第二层去重消歧与数据库进行检索的语句,在最近的规划中不再使用 - -# # 同组group_id下按“精确名字或别名+可选类型一致”来检索 -# SECOND_LAYER_CANDIDATE_MATCH_BATCH = """ -# UNWIND $rows AS row -# MATCH (e:ExtractedEntity) -# WHERE e.group_id = row.group_id -# AND (toLower(e.name) = toLower(row.name) OR any(a IN e.aliases WHERE toLower(a) = toLower(row.name))) -# AND (row.entity_type IS NULL OR e.entity_type = row.entity_type) -# RETURN row.id AS incoming_id, -# e.id AS id, -# e.name AS name, -# e.group_id AS group_id, -# e.entity_idx AS entity_idx, -# e.entity_type AS entity_type, -# e.description AS description, -# e.statement_id AS statement_id, -# e.aliases AS aliases, -# e.name_embedding AS name_embedding, -# e.fact_summary AS fact_summary, -# e.connect_strength AS connect_strength, -# e.created_at AS created_at, -# e.expired_at AS expired_at -# """ -# # 同组group_id下按name contains召回补充 -# SECOND_LAYER_CANDIDATE_CONTAINS_BATCH = """ -# UNWIND $rows AS row -# MATCH (e:ExtractedEntity) -# WHERE e.group_id = row.group_id -# AND toLower(e.name) CONTAINS toLower(row.name) -# RETURN row.id AS incoming_id, -# e.id AS id, -# e.name AS name, -# e.group_id AS group_id, -# e.entity_idx AS entity_idx, -# e.entity_type AS entity_type, -# e.description AS description, -# e.statement_id AS statement_id, -# e.aliases AS aliases, -# e.name_embedding AS name_embedding, -# e.fact_summary AS fact_summary, -# e.connect_strength AS connect_strength, -# e.created_at AS created_at, -# e.expired_at AS expired_at -# """ - -# 根据id修改句子的invalid_at的值 -UPDATE_STATEMENT_INVALID_AT = """ -MATCH (n:Statement {group_id: $group_id, id: $id}) -SET n.invalid_at = $new_invalid_at -""" - -# MemorySummary keyword search using fulltext index -SEARCH_MEMORY_SUMMARIES_BY_KEYWORD = """ -CALL db.index.fulltext.queryNodes("summariesFulltext", $q) YIELD node AS m, score -WHERE ($group_id IS NULL OR m.group_id = $group_id) -OPTIONAL MATCH (m)-[:DERIVED_FROM_STATEMENT]->(s:Statement) -RETURN m.id AS id, - m.name AS name, - m.group_id AS group_id, - m.dialog_id AS dialog_id, - m.chunk_ids AS chunk_ids, - m.content AS content, - m.created_at AS created_at, - score -ORDER BY score DESC -LIMIT $limit -""" - -# Embedding-based search: cosine similarity on MemorySummary.summary_embedding -MEMORY_SUMMARY_EMBEDDING_SEARCH = """ -CALL db.index.vector.queryNodes('summary_embedding_index', $limit * 100, $embedding) -YIELD node AS m, score -WHERE m.summary_embedding IS NOT NULL - AND ($group_id IS NULL OR m.group_id = $group_id) -RETURN m.id AS id, - m.name AS name, - m.group_id AS group_id, - m.dialog_id AS dialog_id, - m.chunk_ids AS chunk_ids, - m.content AS content, - m.created_at AS created_at, - score -ORDER BY score DESC -LIMIT $limit -""" - -MEMORY_SUMMARY_NODE_SAVE = """ -UNWIND $summaries AS summary -MERGE (m:MemorySummary {id: summary.id}) -SET m += { - id: summary.id, - name: summary.name, - group_id: summary.group_id, - user_id: summary.user_id, - apply_id: summary.apply_id, - run_id: summary.run_id, - created_at: summary.created_at, - expired_at: summary.expired_at, - dialog_id: summary.dialog_id, - chunk_ids: summary.chunk_ids, - content: summary.content, - summary_embedding: summary.summary_embedding, - config_id: summary.config_id -} -RETURN m.id AS uuid -""" - -MEMORY_SUMMARY_STATEMENT_EDGE_SAVE = """ -UNWIND $edges AS e -MATCH (ms:MemorySummary {id: e.summary_id, run_id: e.run_id}) -MATCH (c:Chunk {id: e.chunk_id, run_id: e.run_id}) -MATCH (c)-[:CONTAINS]->(s:Statement {run_id: e.run_id}) -MERGE (ms)-[r:DERIVED_FROM_STATEMENT]->(s) -SET r.group_id = e.group_id, - r.run_id = e.run_id, - r.created_at = e.created_at, - r.expired_at = e.expired_at -RETURN elementId(r) AS uuid -""" diff --git a/app/repositories/neo4j/dialog_repository.py b/app/repositories/neo4j/dialog_repository.py deleted file mode 100644 index ccb3d94c..00000000 --- a/app/repositories/neo4j/dialog_repository.py +++ /dev/null @@ -1,185 +0,0 @@ -# -*- coding: utf-8 -*- -"""对话仓储模块 - -本模块提供对话节点的数据访问功能。 - -Classes: - DialogRepository: 对话仓储,管理DialogueNode的CRUD操作 -""" - -from typing import List, Optional, Dict -from datetime import datetime - -from app.repositories.neo4j.base_neo4j_repository import BaseNeo4jRepository -from app.core.memory.models.graph_models import DialogueNode -from app.repositories.neo4j.neo4j_connector import Neo4jConnector - - -class DialogRepository(BaseNeo4jRepository[DialogueNode]): - """对话仓储 - - 管理对话节点的创建、查询、更新和删除操作。 - 提供按group_id、user_id、ref_id等条件查询对话的方法。 - - Attributes: - connector: Neo4j连接器实例 - node_label: 节点标签,固定为"Dialogue" - """ - - def __init__(self, connector: Neo4jConnector): - """初始化对话仓储 - - Args: - connector: Neo4j连接器实例 - """ - super().__init__(connector, "Dialogue") - - def _map_to_entity(self, node_data: Dict) -> DialogueNode: - """将节点数据映射为对话实体 - - Args: - node_data: 从Neo4j查询返回的节点数据字典 - - Returns: - DialogueNode: 对话实体对象 - """ - # 从查询结果中提取节点数据 - n = node_data.get('n', node_data) - - # 处理datetime字段 - if isinstance(n.get('created_at'), str): - n['created_at'] = datetime.fromisoformat(n['created_at']) - if n.get('expired_at') and isinstance(n['expired_at'], str): - n['expired_at'] = datetime.fromisoformat(n['expired_at']) - - return DialogueNode(**n) - - async def find_by_group_id(self, group_id: str, limit: int = 100) -> List[DialogueNode]: - """根据group_id查询对话 - - Args: - group_id: 组ID - limit: 返回结果的最大数量 - - Returns: - List[DialogueNode]: 对话列表 - """ - return await self.find({"group_id": group_id}, limit=limit) - - async def find_by_user_id(self, user_id: str, limit: int = 100) -> List[DialogueNode]: - """根据user_id查询对话 - - Args: - user_id: 用户ID - limit: 返回结果的最大数量 - - Returns: - List[DialogueNode]: 对话列表 - """ - return await self.find({"user_id": user_id}, limit=limit) - - async def find_by_ref_id(self, ref_id: str) -> Optional[DialogueNode]: - """根据ref_id查询对话 - - ref_id是外部对话系统的引用ID,通常是唯一的。 - - Args: - ref_id: 引用ID - - Returns: - Optional[DialogueNode]: 找到的对话,如果不存在则返回None - """ - results = await self.find({"ref_id": ref_id}, limit=1) - return results[0] if results else None - - async def find_by_group_and_user( - self, - group_id: str, - user_id: str, - limit: int = 100 - ) -> List[DialogueNode]: - """根据group_id和user_id查询对话 - - Args: - group_id: 组ID - user_id: 用户ID - limit: 返回结果的最大数量 - - Returns: - List[DialogueNode]: 对话列表 - """ - return await self.find( - {"group_id": group_id, "user_id": user_id}, - limit=limit - ) - - async def find_recent_dialogs( - self, - group_id: str, - days: int = 7, - limit: int = 100 - ) -> List[DialogueNode]: - """查询最近的对话 - - Args: - group_id: 组ID - days: 查询最近多少天的对话 - limit: 返回结果的最大数量 - - Returns: - List[DialogueNode]: 对话列表,按创建时间倒序排列 - """ - query = f""" - MATCH (n:{self.node_label}) - WHERE n.group_id = $group_id - AND n.created_at >= datetime() - duration({{days: $days}}) - RETURN n - ORDER BY n.created_at DESC - LIMIT $limit - """ - results = await self.connector.execute_query( - query, - group_id=group_id, - days=days, - limit=limit - ) - return [self._map_to_entity(r) for r in results] - - async def find_by_config_id( - self, - config_id: str, - limit: int = 100 - ) -> List[DialogueNode]: - """根据config_id查询对话 - - Args: - config_id: 配置ID - limit: 返回结果的最大数量 - - Returns: - List[DialogueNode]: 对话列表 - """ - return await self.find({"config_id": config_id}, limit=limit) - - async def find_by_config_and_group( - self, - config_id: str, - group_id: str, - limit: int = 100 - ) -> List[DialogueNode]: - """根据config_id和group_id查询对话 - - 支持按配置ID和组ID同时过滤,确保只返回使用特定配置处理的对话。 - - Args: - config_id: 配置ID - group_id: 组ID - limit: 返回结果的最大数量 - - Returns: - List[DialogueNode]: 对话列表 - """ - return await self.find( - {"config_id": config_id, "group_id": group_id}, - limit=limit - ) diff --git a/app/repositories/neo4j/entity_repository.py b/app/repositories/neo4j/entity_repository.py deleted file mode 100644 index ef2e5170..00000000 --- a/app/repositories/neo4j/entity_repository.py +++ /dev/null @@ -1,339 +0,0 @@ -# -*- coding: utf-8 -*- -"""实体仓储模块 - -本模块提供实体节点的数据访问功能。 - -Classes: - EntityRepository: 实体仓储,管理ExtractedEntityNode的CRUD操作 -""" - -from typing import List, Optional, Dict -from datetime import datetime - -from app.repositories.neo4j.base_neo4j_repository import BaseNeo4jRepository -from app.core.memory.models.graph_models import ExtractedEntityNode -from app.repositories.neo4j.neo4j_connector import Neo4jConnector - - -class EntityRepository(BaseNeo4jRepository[ExtractedEntityNode]): - """实体仓储 - - 管理实体节点的创建、查询、更新和删除操作。 - 提供按类型、名称、向量相似度等条件查询实体的方法。 - - Attributes: - connector: Neo4j连接器实例 - node_label: 节点标签,固定为"ExtractedEntity" - """ - - def __init__(self, connector: Neo4jConnector): - """初始化实体仓储 - - Args: - connector: Neo4j连接器实例 - """ - super().__init__(connector, "ExtractedEntity") - - def _map_to_entity(self, node_data: Dict) -> ExtractedEntityNode: - """将节点数据映射为实体对象 - - Args: - node_data: 从Neo4j查询返回的节点数据字典 - - Returns: - ExtractedEntityNode: 实体对象 - """ - # 从查询结果中提取节点数据 - n = node_data.get('n', node_data) - - # 处理datetime字段 - if isinstance(n.get('created_at'), str): - n['created_at'] = datetime.fromisoformat(n['created_at']) - if n.get('expired_at') and isinstance(n['expired_at'], str): - n['expired_at'] = datetime.fromisoformat(n['expired_at']) - - return ExtractedEntityNode(**n) - - async def find_by_type(self, entity_type: str, limit: int = 100) -> List[ExtractedEntityNode]: - """根据实体类型查询 - - Args: - entity_type: 实体类型(如"Person", "Organization"等) - limit: 返回结果的最大数量 - - Returns: - List[ExtractedEntityNode]: 实体列表 - """ - return await self.find({"entity_type": entity_type}, limit=limit) - - async def find_by_group_id(self, group_id: str, limit: int = 100) -> List[ExtractedEntityNode]: - """根据group_id查询实体 - - Args: - group_id: 组ID - limit: 返回结果的最大数量 - - Returns: - List[ExtractedEntityNode]: 实体列表 - """ - return await self.find({"group_id": group_id}, limit=limit) - - async def find_by_name( - self, - name: str, - group_id: Optional[str] = None, - limit: int = 100 - ) -> List[ExtractedEntityNode]: - """根据名称查询实体 - - 支持模糊匹配(CONTAINS)。 - - Args: - name: 实体名称 - group_id: 可选的组ID过滤 - limit: 返回结果的最大数量 - - Returns: - List[ExtractedEntityNode]: 实体列表 - """ - where_clause = "n.name CONTAINS $name" - if group_id: - where_clause += " AND n.group_id = $group_id" - - query = f""" - MATCH (n:{self.node_label}) - WHERE {where_clause} - RETURN n - LIMIT $limit - """ - - params = {"name": name, "limit": limit} - if group_id: - params["group_id"] = group_id - - results = await self.connector.execute_query(query, **params) - return [self._map_to_entity(r) for r in results] - - async def find_related_entities( - self, - entity_id: str, - relation_type: Optional[str] = None, - limit: int = 100 - ) -> List[ExtractedEntityNode]: - """查询相关实体 - - 查询与指定实体有关系的其他实体。 - - Args: - entity_id: 实体ID - relation_type: 可选的关系类型过滤 - limit: 返回结果的最大数量 - - Returns: - List[ExtractedEntityNode]: 相关实体列表 - """ - if relation_type: - query = """ - MATCH (e1:ExtractedEntity {id: $entity_id})-[r:RELATES_TO {relation_type: $relation_type}]->(e2:ExtractedEntity) - RETURN e2 as n - LIMIT $limit - """ - results = await self.connector.execute_query( - query, - entity_id=entity_id, - relation_type=relation_type, - limit=limit - ) - else: - query = """ - MATCH (e1:ExtractedEntity {id: $entity_id})-[r:RELATES_TO]->(e2:ExtractedEntity) - RETURN e2 as n - LIMIT $limit - """ - results = await self.connector.execute_query( - query, - entity_id=entity_id, - limit=limit - ) - - return [self._map_to_entity(r) for r in results] - - async def search_by_embedding( - self, - embedding: List[float], - group_id: Optional[str] = None, - limit: int = 10, - min_score: float = 0.7 - ) -> List[Dict]: - """基于向量相似度搜索实体 - - 使用余弦相似度计算查询向量与实体名称向量的相似度。 - - Args: - embedding: 查询向量 - group_id: 可选的组ID过滤 - limit: 返回结果的最大数量 - min_score: 最小相似度分数阈值 - - Returns: - List[Dict]: 包含实体和相似度分数的字典列表 - 每个字典包含: entity (ExtractedEntityNode), score (float) - """ - where_clause = "n.name_embedding IS NOT NULL" - if group_id: - where_clause += " AND n.group_id = $group_id" - - query = f""" - MATCH (n:{self.node_label}) - WHERE {where_clause} - WITH n, gds.similarity.cosine(n.name_embedding, $embedding) AS score - WHERE score > $min_score - RETURN n, score - ORDER BY score DESC - LIMIT $limit - """ - - params = { - "embedding": embedding, - "min_score": min_score, - "limit": limit - } - if group_id: - params["group_id"] = group_id - - results = await self.connector.execute_query(query, **params) - - return [ - { - "entity": self._map_to_entity(r), - "score": r.get("score", 0.0) - } - for r in results - ] - - async def find_by_statement_id(self, statement_id: str) -> List[ExtractedEntityNode]: - """根据陈述句ID查询实体 - - 查询从指定陈述句中提取的所有实体。 - - Args: - statement_id: 陈述句ID - - Returns: - List[ExtractedEntityNode]: 实体列表 - """ - return await self.find({"statement_id": statement_id}) - - async def find_strong_entities( - self, - group_id: str, - limit: int = 100 - ) -> List[ExtractedEntityNode]: - """查询强连接的实体 - - Args: - group_id: 组ID - limit: 返回结果的最大数量 - - Returns: - List[ExtractedEntityNode]: 强连接的实体列表 - """ - return await self.find( - {"group_id": group_id, "connect_strength": "Strong"}, - limit=limit - ) - - async def get_entity_count_by_type(self, group_id: str) -> Dict[str, int]: - """统计各类型实体的数量 - - Args: - group_id: 组ID - - Returns: - Dict[str, int]: 实体类型到数量的映射 - """ - query = """ - MATCH (n:ExtractedEntity {group_id: $group_id}) - RETURN n.entity_type as entity_type, count(n) as count - ORDER BY count DESC - """ - results = await self.connector.execute_query(query, group_id=group_id) - return {r["entity_type"]: r["count"] for r in results} - - async def find_by_config_id( - self, - config_id: str, - limit: int = 100 - ) -> List[ExtractedEntityNode]: - """根据config_id查询实体 - - Args: - config_id: 配置ID - limit: 返回结果的最大数量 - - Returns: - List[ExtractedEntityNode]: 实体列表 - """ - return await self.find({"config_id": config_id}, limit=limit) - - async def search_by_embedding_with_config( - self, - embedding: List[float], - config_id: Optional[str] = None, - group_id: Optional[str] = None, - limit: int = 10, - min_score: float = 0.7 - ) -> List[Dict]: - """基于向量相似度搜索实体,可选择按config_id过滤 - - 使用余弦相似度计算查询向量与实体名称向量的相似度。 - 支持按config_id过滤结果,确保只返回使用特定配置处理的实体。 - - Args: - embedding: 查询向量 - config_id: 可选的配置ID过滤 - group_id: 可选的组ID过滤 - limit: 返回结果的最大数量 - min_score: 最小相似度分数阈值 - - Returns: - List[Dict]: 包含实体和相似度分数的字典列表 - 每个字典包含: entity (ExtractedEntityNode), score (float) - """ - # 构建查询条件 - where_clauses = ["n.name_embedding IS NOT NULL"] - params = { - "embedding": embedding, - "min_score": min_score, - "limit": limit - } - - if config_id: - where_clauses.append("n.config_id = $config_id") - params["config_id"] = config_id - - if group_id: - where_clauses.append("n.group_id = $group_id") - params["group_id"] = group_id - - where_str = " AND ".join(where_clauses) - - query = f""" - MATCH (n:{self.node_label}) - WHERE {where_str} - WITH n, gds.similarity.cosine(n.name_embedding, $embedding) AS score - WHERE score > $min_score - RETURN n, score - ORDER BY score DESC - LIMIT $limit - """ - - results = await self.connector.execute_query(query, **params) - - return [ - { - "entity": self._map_to_entity(r), - "score": r.get("score", 0.0) - } - for r in results - ] diff --git a/app/repositories/neo4j/graph_saver.py b/app/repositories/neo4j/graph_saver.py deleted file mode 100644 index 13215e0f..00000000 --- a/app/repositories/neo4j/graph_saver.py +++ /dev/null @@ -1,216 +0,0 @@ -from typing import List - -# 使用新的仓储层 -from app.repositories.neo4j.neo4j_connector import Neo4jConnector -from app.repositories.neo4j.add_nodes import add_dialogue_nodes, add_statement_nodes, add_chunk_nodes -from app.repositories.neo4j.cypher_queries import ( - STATEMENT_ENTITY_EDGE_SAVE, - ENTITY_RELATIONSHIP_SAVE, - EXTRACTED_ENTITY_NODE_SAVE, - CHUNK_STATEMENT_EDGE_SAVE, - STATEMENT_ENTITY_EDGE_SAVE, - ENTITY_RELATIONSHIP_SAVE, - EXTRACTED_ENTITY_NODE_SAVE, -) -from app.core.memory.models.graph_models import ( - DialogueNode, - ChunkNode, - StatementChunkEdge, - StatementEntityEdge, - StatementNode, - ExtractedEntityNode, - EntityEntityEdge, -) - -async def save_entities_and_relationships( - entity_nodes: List[ExtractedEntityNode], - entity_entity_edges: List[EntityEntityEdge], - connector: Neo4jConnector -): - """Save entities and their relationships using graph models""" - all_entities = [entity.model_dump() for entity in entity_nodes] - all_relationships = [] - - for edge in entity_entity_edges: - relationship = { - 'source_id': edge.source, - 'target_id': edge.target, - 'predicate': edge.relation_type, - 'statement_id': edge.source_statement_id, - 'value': edge.relation_value, - 'statement': edge.statement, - 'valid_at': edge.valid_at.isoformat() if edge.valid_at else None, - 'invalid_at': edge.invalid_at.isoformat() if edge.invalid_at else None, - 'created_at': edge.created_at.isoformat(), - 'expired_at': edge.expired_at.isoformat(), - 'run_id': edge.run_id, - 'group_id': edge.group_id, - 'user_id': edge.user_id, - 'apply_id': edge.apply_id, - } - all_relationships.append(relationship) - - # Save entities - if all_entities: - entity_uuids = await connector.execute_query(EXTRACTED_ENTITY_NODE_SAVE, entities=all_entities) - if entity_uuids: - print(f"Successfully saved {len(entity_uuids)} entity nodes to Neo4j") - else: - print("Failed to save entity nodes to Neo4j") - else: - print("No entity nodes to save") - - # Create relationships - if all_relationships: - relationship_uuids = await connector.execute_query(ENTITY_RELATIONSHIP_SAVE, relationships=all_relationships) - if relationship_uuids: - print(f"Successfully saved {len(relationship_uuids)} entity relationships (edges) to Neo4j") - else: - print("Failed to save entity relationships to Neo4j") - else: - print("No entity relationships to save") - - -async def save_chunk_nodes( - chunk_nodes: List[ChunkNode], - connector: Neo4jConnector -): - """Save chunk nodes using graph models""" - if not chunk_nodes: - print("No chunk nodes to save") - return - - chunk_uuids = await add_chunk_nodes(chunk_nodes, connector) - if chunk_uuids: - print(f"Successfully saved {len(chunk_uuids)} chunk nodes to Neo4j") - else: - print("Failed to save chunk nodes to Neo4j") - - -async def save_statement_chunk_edges( - statement_chunk_edges: List[StatementChunkEdge], - connector: Neo4jConnector -): - """Save statement-chunk edges using graph models""" - if not statement_chunk_edges: - return - - all_sc_edges = [] - for edge in statement_chunk_edges: - all_sc_edges.append({ - "id": edge.id, - "source": edge.source, - "target": edge.target, - "group_id": edge.group_id, - "user_id": edge.user_id, - "apply_id": edge.apply_id, - "run_id": edge.run_id, - "created_at": edge.created_at.isoformat() if edge.created_at else None, - "expired_at": edge.expired_at.isoformat() if edge.expired_at else None, - }) - - try: - await connector.execute_query( - CHUNK_STATEMENT_EDGE_SAVE, - chunk_statement_edges=all_sc_edges - ) - except Exception: - pass - - -async def save_statement_entity_edges( - statement_entity_edges: List[StatementEntityEdge], - connector: Neo4jConnector -): - """Save statement-entity edges using graph models""" - if not statement_entity_edges: - print("No statement-entity edges to save") - return - - all_se_edges = [] - for edge in statement_entity_edges: - edge_data = { - "source": edge.source, - "target": edge.target, - "group_id": edge.group_id, - "user_id": edge.user_id, - "apply_id": edge.apply_id, - "run_id": edge.run_id, - "connect_strength": edge.connect_strength, - "created_at": edge.created_at.isoformat() if edge.created_at else None, - "expired_at": edge.expired_at.isoformat() if edge.expired_at else None, - } - all_se_edges.append(edge_data) - - if all_se_edges: - try: - await connector.execute_query( - STATEMENT_ENTITY_EDGE_SAVE, - relationships=all_se_edges - ) - except Exception: - pass - - -async def save_dialog_and_statements_to_neo4j( - dialogue_nodes: List[DialogueNode], - chunk_nodes: List[ChunkNode], - statement_nodes: List[StatementNode], - entity_nodes: List[ExtractedEntityNode], - entity_edges: List[EntityEntityEdge], - statement_chunk_edges: List[StatementChunkEdge], - statement_entity_edges: List[StatementEntityEdge], - connector: Neo4jConnector -) -> bool: - """Save dialogue nodes, chunk nodes, statement nodes, entities, and all relationships to Neo4j using graph models. - - Args: - dialogue_nodes: List of DialogueNode objects to save - chunk_nodes: List of ChunkNode objects to save - statement_nodes: List of StatementNode objects to save - entity_nodes: List of ExtractedEntityNode objects to save - entity_edges: List of EntityEntityEdge objects to save - statement_chunk_edges: List of StatementChunkEdge objects to save - statement_entity_edges: List of StatementEntityEdge objects to save - connector: Neo4j connector instance - - Returns: - bool: True if successful, False otherwise - """ - try: - # Save all dialogue nodes in batch - dialogue_uuids = await add_dialogue_nodes(dialogue_nodes, connector) - if dialogue_uuids: - print(f"Dialogues saved to Neo4j with UUIDs: {dialogue_uuids}") - else: - print("Failed to save dialogues to Neo4j") - return False - - # Save all chunk nodes in batch - await save_chunk_nodes(chunk_nodes, connector) - - # Save all statement nodes in batch - if statement_nodes: - statement_uuids = await add_statement_nodes(statement_nodes, connector) - if statement_uuids: - print(f"Successfully saved {len(statement_uuids)} statement nodes to Neo4j") - else: - print("Failed to save statement nodes to Neo4j") - return False - else: - print("No statement nodes to save") - - # Save entities and relationships - await save_entities_and_relationships(entity_nodes, entity_edges, connector) - print("Successfully saved entities and relationships to Neo4j") - - # Save new edges - await save_statement_chunk_edges(statement_chunk_edges, connector) - await save_statement_entity_edges(statement_entity_edges, connector) - - return True - - except Exception as e: - print(f"Neo4j integration error: {e}") - print("Continuing without database storage...") - return False diff --git a/app/repositories/neo4j/graph_search.py b/app/repositories/neo4j/graph_search.py deleted file mode 100644 index ab2b28ac..00000000 --- a/app/repositories/neo4j/graph_search.py +++ /dev/null @@ -1,584 +0,0 @@ -from typing import Any, Dict, List, Optional -import asyncio - -# 使用新的仓储层 -from app.repositories.neo4j.neo4j_connector import Neo4jConnector -from app.repositories.neo4j.cypher_queries import ( - SEARCH_STATEMENTS_BY_KEYWORD, - SEARCH_ENTITIES_BY_NAME, - SEARCH_CHUNKS_BY_CONTENT, - STATEMENT_EMBEDDING_SEARCH, - CHUNK_EMBEDDING_SEARCH, - ENTITY_EMBEDDING_SEARCH, - SEARCH_MEMORY_SUMMARIES_BY_KEYWORD, - MEMORY_SUMMARY_EMBEDDING_SEARCH, - SEARCH_STATEMENTS_BY_TEMPORAL, - SEARCH_STATEMENTS_BY_KEYWORD_TEMPORAL, - SEARCH_DIALOGUE_BY_DIALOG_ID, - SEARCH_CHUNK_BY_CHUNK_ID, - SEARCH_STATEMENTS_BY_CREATED_AT, - SEARCH_STATEMENTS_BY_VALID_AT, - SEARCH_STATEMENTS_G_CREATED_AT, - SEARCH_STATEMENTS_L_CREATED_AT, - SEARCH_STATEMENTS_G_VALID_AT, - SEARCH_STATEMENTS_L_VALID_AT, -) - - -async def search_graph( - connector: Neo4jConnector, - q: str, - group_id: Optional[str] = None, - limit: int = 50, - include: List[str] = None, -) -> Dict[str, List[Dict[str, Any]]]: - """ - Search across Statements, Entities, Chunks, and Summaries using a free-text query. - - OPTIMIZED: Runs all queries in parallel using asyncio.gather() - - - Statements: matches s.statement CONTAINS q - - Entities: matches e.name CONTAINS q - - Chunks: matches s.content CONTAINS q (from Statement nodes) - - Summaries: matches ms.content CONTAINS q - - Args: - connector: Neo4j connector - q: Query text - group_id: Optional group filter - limit: Max results per category - include: List of categories to search (default: all) - - Returns: - Dictionary with search results per category - """ - if include is None: - include = ["statements", "chunks", "entities", "summaries"] - - # Prepare tasks for parallel execution - tasks = [] - task_keys = [] - - if "statements" in include: - tasks.append(connector.execute_query( - SEARCH_STATEMENTS_BY_KEYWORD, - q=q, - group_id=group_id, - limit=limit, - )) - task_keys.append("statements") - - if "entities" in include: - tasks.append(connector.execute_query( - SEARCH_ENTITIES_BY_NAME, - q=q, - group_id=group_id, - limit=limit, - )) - task_keys.append("entities") - - if "chunks" in include: - tasks.append(connector.execute_query( - SEARCH_CHUNKS_BY_CONTENT, - q=q, - group_id=group_id, - limit=limit, - )) - task_keys.append("chunks") - - if "summaries" in include: - tasks.append(connector.execute_query( - SEARCH_MEMORY_SUMMARIES_BY_KEYWORD, - q=q, - group_id=group_id, - limit=limit, - )) - task_keys.append("summaries") - - # Execute all queries in parallel - task_results = await asyncio.gather(*tasks, return_exceptions=True) - - # Build results dictionary - results = {} - for key, result in zip(task_keys, task_results): - if isinstance(result, Exception): - results[key] = [] - else: - results[key] = result - - return results - - -async def search_graph_by_embedding( - connector: Neo4jConnector, - embedder_client, - query_text: str, - group_id: Optional[str] = None, - limit: int = 50, - include: List[str] = ["statements", "chunks", "entities","summaries"], -) -> Dict[str, List[Dict[str, Any]]]: - """ - Embedding-based semantic search across Statements, Chunks, and Entities. - - OPTIMIZED: Runs all queries in parallel using asyncio.gather() - - - Computes query embedding with the provided embedder_client - - Ranks by cosine similarity in Cypher - - Filters by group_id if provided - - Returns up to 'limit' per included type - """ - import time - - # Get embedding for the query - embed_start = time.time() - embeddings = await embedder_client.response([query_text]) - embed_time = time.time() - embed_start - print(f"[PERF] Embedding generation took: {embed_time:.4f}s") - - if not embeddings or not embeddings[0]: - return {"statements": [], "chunks": [], "entities": [], "summaries": []} - embedding = embeddings[0] - - # Prepare tasks for parallel execution - tasks = [] - task_keys = [] - - # Statements (embedding) - if "statements" in include: - tasks.append(connector.execute_query( - STATEMENT_EMBEDDING_SEARCH, - embedding=embedding, - group_id=group_id, - limit=limit, - )) - task_keys.append("statements") - - # Chunks (embedding) - if "chunks" in include: - tasks.append(connector.execute_query( - CHUNK_EMBEDDING_SEARCH, - embedding=embedding, - group_id=group_id, - limit=limit, - )) - task_keys.append("chunks") - - # Entities - if "entities" in include: - tasks.append(connector.execute_query( - ENTITY_EMBEDDING_SEARCH, - embedding=embedding, - group_id=group_id, - limit=limit, - )) - task_keys.append("entities") - - # Memory summaries - if "summaries" in include: - tasks.append(connector.execute_query( - MEMORY_SUMMARY_EMBEDDING_SEARCH, - embedding=embedding, - group_id=group_id, - limit=limit, - )) - task_keys.append("summaries") - - # Execute all queries in parallel - query_start = time.time() - task_results = await asyncio.gather(*tasks, return_exceptions=True) - query_time = time.time() - query_start - print(f"[PERF] Neo4j queries (parallel) took: {query_time:.4f}s") - - # Build results dictionary - results: Dict[str, List[Dict[str, Any]]] = { - "statements": [], - "chunks": [], - "entities": [], - "summaries": [], - } - - for key, result in zip(task_keys, task_results): - if isinstance(result, Exception): - results[key] = [] - else: - results[key] = result - - return results -async def get_dedup_candidates_for_entities( # 适配新版查询:使用全文索引按名称检索候选实体 - connector: Neo4jConnector, - group_id: str, - entities: List[Dict[str, Any]], - use_contains_fallback: bool = True, - batch_size: int = 500, - max_concurrency: int = 5, -) -> Dict[str, List[Dict[str, Any]]]: - """ - 为第二层去重消歧批量检索候选实体(适配新版 cypher_queries): - - 使用全文索引查询 `SEARCH_ENTITIES_BY_NAME` 按 (group_id, name) 检索候选; - - 保留并发控制与返回结构(incoming_id -> [db_entity_props...]); - - 若提供 `entity_type`,在本地对返回结果做类型过滤; - - `use_contains_fallback` 保留形参以兼容,必要时可扩展二次查询策略。 - - 返回:incoming_id -> [db_entity_props...] - """ - - if not entities: - return {} - - sem = asyncio.Semaphore(max_concurrency) - - async def _query_by_name(incoming: Dict[str, Any]) -> tuple[str, List[Dict[str, Any]]]: - async with sem: - inc_id = incoming.get("id") or "__unknown__" - name = (incoming.get("name") or "").strip() - if not name: - return inc_id, [] - try: - # 全文索引按名称检索(包含 CONTAINS 语义) - rows = await connector.execute_query( - SEARCH_ENTITIES_BY_NAME, - q=name, - group_id=group_id, - limit=100, - ) - except Exception: - rows = [] - - # 可选本地类型过滤(若输入实体提供类型) - typ = incoming.get("entity_type") - if typ: - try: - rows = [r for r in rows if (r.get("entity_type") == typ)] - except Exception: - pass - - # 注入 incoming_id 以保持兼容下游合并逻辑 - for r in rows: - r["incoming_id"] = inc_id - - # 简单的降级:若为空且允许 fallback,可按小写名再次查询 - if use_contains_fallback and not rows and name: - try: - rows = await connector.execute_query( - SEARCH_ENTITIES_BY_NAME, - q=name.lower(), - group_id=group_id, - limit=100, - ) - for r in rows: - r["incoming_id"] = inc_id - except Exception: - pass - - return inc_id, rows - - tasks = [_query_by_name(e) for e in entities] - results = await asyncio.gather(*tasks, return_exceptions=True) - - merged: Dict[str, List[Dict[str, Any]]] = {} - for res in results: - if isinstance(res, Exception): - # 静默跳过单条失败 - continue - inc_id, rows = res - inc_id = inc_id or "__unknown__" - merged.setdefault(inc_id, []) - existing_ids = {x.get("id") for x in merged[inc_id]} - for rec in rows: - if rec.get("id") not in existing_ids: - merged[inc_id].append(rec) - return merged - - -async def search_graph_by_keyword_temporal( - connector: Neo4jConnector, - query_text: str, - group_id: Optional[str] = None, - apply_id: Optional[str] = None, - user_id: Optional[str] = None, - start_date: Optional[str] = None, - end_date: Optional[str] = None, - valid_date: Optional[str] = None, - invalid_date: Optional[str] = None, - limit: int = 50, -) -> Dict[str, List[Any]]: - """ - Temporal keyword search across Statements. - - - Matches statements containing query_text created between start_date and end_date - - Optionally filters by group_id, apply_id, user_id - - Returns up to 'limit' statements - """ - if not query_text: - print(f"query_text不能为空") - return {"statements": []} - statements = await connector.execute_query( - SEARCH_STATEMENTS_BY_KEYWORD_TEMPORAL, - q=query_text, - group_id=group_id, - apply_id=apply_id, - user_id=user_id, - start_date=start_date, - end_date=end_date, - valid_date=valid_date, - invalid_date=invalid_date, - limit=limit, - ) - print(f"查询结果为:\n{statements}") - - return {"statements": statements} - - -async def search_graph_by_temporal( - connector: Neo4jConnector, - group_id: Optional[str] = None, - apply_id: Optional[str] = None, - user_id: Optional[str] = None, - start_date: Optional[str] = None, - end_date: Optional[str] = None, - valid_date: Optional[str] = None, - invalid_date: Optional[str] = None, - limit: int = 10, -) -> Dict[str, List[Dict[str, Any]]]: - """ - Temporal search across Statements. - - - Matches statements created between start_date and end_date - - Optionally filters by group_id, apply_id, user_id - - Returns up to 'limit' statements - """ - statements = await connector.execute_query( - SEARCH_STATEMENTS_BY_TEMPORAL, - group_id=group_id, - apply_id=apply_id, - user_id=user_id, - start_date=start_date, - end_date=end_date, - valid_date=valid_date, - invalid_date=invalid_date, - limit=limit, - ) - - print(f"查询语句为:\n{SEARCH_STATEMENTS_BY_TEMPORAL}") - print(f"查询参数为:\n{{group_id: {group_id}, apply_id: {apply_id}, user_id: {user_id}, start_date: {start_date}, end_date: {end_date}, valid_date: {valid_date}, invalid_date: {invalid_date}, limit: {limit}}}") - print(f"查询结果为:\n{statements}") - return {"statements": statements} - - -async def search_graph_by_dialog_id( - connector: Neo4jConnector, - dialog_id: str, - group_id: Optional[str] = None, - limit: int = 1, -) -> Dict[str, List[Dict[str, Any]]]: - """ - Temporal search across Dialogues. - - - Matches dialogues with dialog_id - - Optionally filters by group_id - - Returns up to 'limit' dialogues - """ - if not dialog_id: - print(f"dialog_id不能为空") - return {"dialogues": []} - - dialogues = await connector.execute_query( - SEARCH_DIALOGUE_BY_DIALOG_ID, - group_id=group_id, - dialog_id=dialog_id, - limit=limit, - ) - return {"dialogues": dialogues} - - -async def search_graph_by_chunk_id( - connector: Neo4jConnector, - chunk_id : str, - group_id: Optional[str] = None, - limit: int = 1, -) -> Dict[str, List[Dict[str, Any]]]: - if not chunk_id: - print(f"chunk_id不能为空") - return {"chunks": []} - chunks = await connector.execute_query( - SEARCH_CHUNK_BY_CHUNK_ID, - group_id=group_id, - chunk_id=chunk_id, - limit=limit, - ) - return {"chunks": chunks} - - -async def search_graph_by_created_at( - connector: Neo4jConnector, - group_id: Optional[str] = None, - apply_id: Optional[str] = None, - user_id: Optional[str] = None, - created_at: Optional[str] = None, - limit: int = 1, -) -> Dict[str, List[Dict[str, Any]]]: - """ - Temporal search across Statements. - - - Matches statements created at created_at - - Optionally filters by group_id, apply_id, user_id - - Returns up to 'limit' statements - """ - statements = await connector.execute_query( - SEARCH_STATEMENTS_BY_CREATED_AT, - group_id=group_id, - apply_id=apply_id, - user_id=user_id, - created_at=created_at, - limit=limit, - ) - - print(f"查询语句为:\n{SEARCH_STATEMENTS_BY_CREATED_AT}") - print(f"查询参数为:\n{{group_id: {group_id}, apply_id: {apply_id}, user_id: {user_id}, created_at: {created_at}, limit: {limit}}}") - print(f"查询结果为:\n{statements}") - return {"statements": statements} - -async def search_graph_by_valid_at( - connector: Neo4jConnector, - group_id: Optional[str] = None, - apply_id: Optional[str] = None, - user_id: Optional[str] = None, - valid_at: Optional[str] = None, - limit: int = 1, -) -> Dict[str, List[Dict[str, Any]]]: - """ - Temporal search across Statements. - - - Matches statements valid at valid_at - - Optionally filters by group_id, apply_id, user_id - - Returns up to 'limit' statements - """ - statements = await connector.execute_query( - SEARCH_STATEMENTS_BY_VALID_AT, - group_id=group_id, - apply_id=apply_id, - user_id=user_id, - valid_at=valid_at, - limit=limit, - ) - - print(f"查询语句为:\n{SEARCH_STATEMENTS_BY_VALID_AT}") - print(f"查询参数为:\n{{group_id: {group_id}, apply_id: {apply_id}, user_id: {user_id}, valid_at: {valid_at}, limit: {limit}}}") - print(f"查询结果为:\n{statements}") - return {"statements": statements} - -async def search_graph_g_created_at( - connector: Neo4jConnector, - group_id: Optional[str] = None, - apply_id: Optional[str] = None, - user_id: Optional[str] = None, - created_at: Optional[str] = None, - limit: int = 1, -) -> Dict[str, List[Dict[str, Any]]]: - """ - Temporal search across Statements. - - - Matches statements created at created_at - - Optionally filters by group_id, apply_id, user_id - - Returns up to 'limit' statements - """ - statements = await connector.execute_query( - SEARCH_STATEMENTS_G_CREATED_AT, - group_id=group_id, - apply_id=apply_id, - user_id=user_id, - created_at=created_at, - limit=limit, - ) - - print(f"查询语句为:\n{SEARCH_STATEMENTS_G_CREATED_AT}") - print(f"查询参数为:\n{{group_id: {group_id}, apply_id: {apply_id}, user_id: {user_id}, created_at: {created_at}, limit: {limit}}}") - print(f"查询结果为:\n{statements}") - return {"statements": statements} - -async def search_graph_g_valid_at( - connector: Neo4jConnector, - group_id: Optional[str] = None, - apply_id: Optional[str] = None, - user_id: Optional[str] = None, - valid_at: Optional[str] = None, - limit: int = 1, -) -> Dict[str, List[Dict[str, Any]]]: - """ - Temporal search across Statements. - - - Matches statements valid at valid_at - - Optionally filters by group_id, apply_id, user_id - - Returns up to 'limit' statements - """ - statements = await connector.execute_query( - SEARCH_STATEMENTS_G_VALID_AT, - group_id=group_id, - apply_id=apply_id, - user_id=user_id, - valid_at=valid_at, - limit=limit, - ) - - print(f"查询语句为:\n{SEARCH_STATEMENTS_G_VALID_AT}") - print(f"查询参数为:\n{{group_id: {group_id}, apply_id: {apply_id}, user_id: {user_id}, valid_at: {valid_at}, limit: {limit}}}") - print(f"查询结果为:\n{statements}") - return {"statements": statements} - -async def search_graph_l_created_at( - connector: Neo4jConnector, - group_id: Optional[str] = None, - apply_id: Optional[str] = None, - user_id: Optional[str] = None, - created_at: Optional[str] = None, - limit: int = 1, -) -> Dict[str, List[Dict[str, Any]]]: - """ - Temporal search across Statements. - - - Matches statements created at created_at - - Optionally filters by group_id, apply_id, user_id - - Returns up to 'limit' statements - """ - statements = await connector.execute_query( - SEARCH_STATEMENTS_L_CREATED_AT, - group_id=group_id, - apply_id=apply_id, - user_id=user_id, - created_at=created_at, - limit=limit, - ) - - print(f"查询语句为:\n{SEARCH_STATEMENTS_L_CREATED_AT}") - print(f"查询参数为:\n{{group_id: {group_id}, apply_id: {apply_id}, user_id: {user_id}, created_at: {created_at}, limit: {limit}}}") - print(f"查询结果为:\n{statements}") - return {"statements": statements} - -async def search_graph_l_valid_at( - connector: Neo4jConnector, - group_id: Optional[str] = None, - apply_id: Optional[str] = None, - user_id: Optional[str] = None, - valid_at: Optional[str] = None, - limit: int = 1, -) -> Dict[str, List[Dict[str, Any]]]: - """ - Temporal search across Statements. - - - Matches statements valid at valid_at - - Optionally filters by group_id, apply_id, user_id - - Returns up to 'limit' statements - """ - statements = await connector.execute_query( - SEARCH_STATEMENTS_L_VALID_AT, - group_id=group_id, - apply_id=apply_id, - user_id=user_id, - valid_at=valid_at, - limit=limit, - ) - - print(f"查询语句为:\n{SEARCH_STATEMENTS_L_VALID_AT}") - print(f"查询参数为:\n{{group_id: {group_id}, apply_id: {apply_id}, user_id: {user_id}, valid_at: {valid_at}, limit: {limit}}}") - print(f"查询结果为:\n{statements}") - return {"statements": statements} diff --git a/app/repositories/neo4j/neo4j_connector.py b/app/repositories/neo4j/neo4j_connector.py deleted file mode 100644 index 642661d4..00000000 --- a/app/repositories/neo4j/neo4j_connector.py +++ /dev/null @@ -1,114 +0,0 @@ -# -*- coding: utf-8 -*- -"""Neo4j连接器模块 - -本模块提供Neo4j图数据库的连接和查询功能。 -从 app/core/memory/src/database/neo4j_connector.py 迁移而来。 - -Classes: - Neo4jConnector: Neo4j数据库连接器,提供异步查询接口 -""" - -import os -from typing import Any, List, Dict - -from neo4j import AsyncGraphDatabase, basic_auth - -from app.core.config import settings - - -class Neo4jConnector: - """Neo4j数据库连接器 - - 提供与Neo4j图数据库的连接和查询功能。 - 使用异步驱动程序以支持高并发操作。 - - Attributes: - driver: Neo4j异步驱动程序实例 - - Methods: - close: 关闭数据库连接 - execute_query: 执行Cypher查询 - delete_group: 删除指定组的所有数据 - """ - - def __init__(self): - """初始化Neo4j连接器 - - 从配置文件和环境变量中读取连接信息。 - - Raises: - RuntimeError: 如果NEO4J_PASSWORD环境变量未设置 - """ - # 从全局配置和环境变量获取 Neo4j 配置 - uri = settings.NEO4J_URI - username = settings.NEO4J_USERNAME - password = settings.NEO4J_PASSWORD - - if not password: - raise RuntimeError( - "NEO4J_PASSWORD is not set. Create a .env with NEO4J_PASSWORD or export it before running." - ) - self.driver = AsyncGraphDatabase.driver( - uri, - auth=basic_auth(username, password) - ) - - async def close(self): - """关闭数据库连接 - - 释放数据库连接资源。应在应用程序关闭时调用。 - """ - await self.driver.close() - - async def execute_query(self, query: str, **kwargs: Any) -> List[Dict[str, Any]]: - """执行Cypher查询 - - Args: - query: Cypher查询语句 - **kwargs: 查询参数,将作为参数传递给Cypher查询 - - Returns: - List[Dict[str, Any]]: 查询结果列表,每个元素是一个字典 - - Example: - >>> connector = Neo4jConnector() - >>> results = await connector.execute_query( - ... "MATCH (n:Person {name: $name}) RETURN n", - ... name="Alice" - ... ) - """ - result = await self.driver.execute_query( - query, - database="neo4j", - **kwargs - ) - records, summary, keys = result - return [record.data() for record in records] - - async def delete_group(self, group_id: str): - """删除指定组的所有数据 - - 删除所有属于指定group_id的节点和边。 - 这是一个危险操作,会永久删除数据。 - - Args: - group_id: 要删除的组ID - - Example: - >>> connector = Neo4jConnector() - >>> await connector.delete_group("group_123") - Group group_123 deleted. - """ - # 删除节点(DETACH DELETE会同时删除相关的边) - await self.driver.execute_query( - "MATCH (n) WHERE n.group_id = $group_id DETACH DELETE n", - database="neo4j", - group_id=group_id - ) - # 删除独立的边(如果有的话) - await self.driver.execute_query( - "MATCH ()-[r]->() WHERE r.group_id = $group_id DELETE r", - database="neo4j", - group_id=group_id - ) - print(f"Group {group_id} deleted.") diff --git a/app/repositories/neo4j/statement_repository.py b/app/repositories/neo4j/statement_repository.py deleted file mode 100644 index 816bf06e..00000000 --- a/app/repositories/neo4j/statement_repository.py +++ /dev/null @@ -1,319 +0,0 @@ -# -*- coding: utf-8 -*- -"""陈述句仓储模块 - -本模块提供陈述句节点的数据访问功能。 - -Classes: - StatementRepository: 陈述句仓储,管理StatementNode的CRUD操作 -""" - -from typing import List, Optional, Dict -from datetime import datetime - -from app.repositories.neo4j.base_neo4j_repository import BaseNeo4jRepository -from app.core.memory.models.graph_models import StatementNode -from app.repositories.neo4j.neo4j_connector import Neo4jConnector -from app.core.memory.utils.data.ontology import TemporalInfo - - -class StatementRepository(BaseNeo4jRepository[StatementNode]): - """陈述句仓储 - - 管理陈述句节点的创建、查询、更新和删除操作。 - 提供按chunk_id、group_id、向量相似度等条件查询陈述句的方法。 - - Attributes: - connector: Neo4j连接器实例 - node_label: 节点标签,固定为"Statement" - """ - - def __init__(self, connector: Neo4jConnector): - """初始化陈述句仓储 - - Args: - connector: Neo4j连接器实例 - """ - super().__init__(connector, "Statement") - - def _map_to_entity(self, node_data: Dict) -> StatementNode: - """将节点数据映射为陈述句实体 - - Args: - node_data: 从Neo4j查询返回的节点数据字典 - - Returns: - StatementNode: 陈述句实体对象 - """ - # 从查询结果中提取节点数据 - n = node_data.get('n', node_data) - - # 处理datetime字段 - if isinstance(n.get('created_at'), str): - n['created_at'] = datetime.fromisoformat(n['created_at']) - if n.get('expired_at') and isinstance(n['expired_at'], str): - n['expired_at'] = datetime.fromisoformat(n['expired_at']) - if n.get('valid_at') and isinstance(n['valid_at'], str): - n['valid_at'] = datetime.fromisoformat(n['valid_at']) - if n.get('invalid_at') and isinstance(n['invalid_at'], str): - n['invalid_at'] = datetime.fromisoformat(n['invalid_at']) - - # 处理temporal_info字段 - if isinstance(n.get('temporal_info'), dict): - n['temporal_info'] = TemporalInfo(**n['temporal_info']) - elif not n.get('temporal_info'): - # 如果没有temporal_info,创建一个默认的 - n['temporal_info'] = TemporalInfo() - - return StatementNode(**n) - - async def find_by_chunk_id(self, chunk_id: str) -> List[StatementNode]: - """根据chunk_id查询陈述句 - - Args: - chunk_id: 分块ID - - Returns: - List[StatementNode]: 陈述句列表 - """ - return await self.find({"chunk_id": chunk_id}) - - async def find_by_group_id(self, group_id: str, limit: int = 100) -> List[StatementNode]: - """根据group_id查询陈述句 - - Args: - group_id: 组ID - limit: 返回结果的最大数量 - - Returns: - List[StatementNode]: 陈述句列表 - """ - return await self.find({"group_id": group_id}, limit=limit) - - async def search_by_embedding( - self, - embedding: List[float], - group_id: Optional[str] = None, - limit: int = 10, - min_score: float = 0.7 - ) -> List[Dict]: - """基于向量相似度搜索陈述句 - - 使用余弦相似度计算查询向量与陈述句向量的相似度。 - - Args: - embedding: 查询向量 - group_id: 可选的组ID过滤 - limit: 返回结果的最大数量 - min_score: 最小相似度分数阈值 - - Returns: - List[Dict]: 包含陈述句和相似度分数的字典列表 - 每个字典包含: statement (StatementNode), score (float) - """ - # 构建查询条件 - where_clause = "n.statement_embedding IS NOT NULL" - if group_id: - where_clause += " AND n.group_id = $group_id" - - query = f""" - MATCH (n:{self.node_label}) - WHERE {where_clause} - WITH n, gds.similarity.cosine(n.statement_embedding, $embedding) AS score - WHERE score > $min_score - RETURN n, score - ORDER BY score DESC - LIMIT $limit - """ - - params = { - "embedding": embedding, - "min_score": min_score, - "limit": limit - } - if group_id: - params["group_id"] = group_id - - results = await self.connector.execute_query(query, **params) - - return [ - { - "statement": self._map_to_entity(r), - "score": r.get("score", 0.0) - } - for r in results - ] - - async def search_by_keyword( - self, - keyword: str, - group_id: Optional[str] = None, - limit: int = 50 - ) -> List[StatementNode]: - """基于关键词搜索陈述句 - - Args: - keyword: 搜索关键词 - group_id: 可选的组ID过滤 - limit: 返回结果的最大数量 - - Returns: - List[StatementNode]: 陈述句列表 - """ - where_clause = "n.statement CONTAINS $keyword" - if group_id: - where_clause += " AND n.group_id = $group_id" - - query = f""" - MATCH (n:{self.node_label}) - WHERE {where_clause} - RETURN n - LIMIT $limit - """ - - params = {"keyword": keyword, "limit": limit} - if group_id: - params["group_id"] = group_id - - results = await self.connector.execute_query(query, **params) - return [self._map_to_entity(r) for r in results] - - async def find_by_temporal_range( - self, - group_id: str, - start_date: Optional[datetime] = None, - end_date: Optional[datetime] = None, - limit: int = 100 - ) -> List[StatementNode]: - """根据时间范围查询陈述句 - - 查询在指定时间范围内有效的陈述句。 - - Args: - group_id: 组ID - start_date: 开始日期(可选) - end_date: 结束日期(可选) - limit: 返回结果的最大数量 - - Returns: - List[StatementNode]: 陈述句列表 - """ - where_clauses = ["n.group_id = $group_id"] - params = {"group_id": group_id, "limit": limit} - - if start_date: - where_clauses.append("n.valid_at >= $start_date") - params["start_date"] = start_date.isoformat() - - if end_date: - where_clauses.append("(n.invalid_at IS NULL OR n.invalid_at <= $end_date)") - params["end_date"] = end_date.isoformat() - - where_str = " AND ".join(where_clauses) - - query = f""" - MATCH (n:{self.node_label}) - WHERE {where_str} - RETURN n - ORDER BY n.created_at DESC - LIMIT $limit - """ - - results = await self.connector.execute_query(query, **params) - return [self._map_to_entity(r) for r in results] - - async def find_strong_statements( - self, - group_id: str, - limit: int = 100 - ) -> List[StatementNode]: - """查询强连接的陈述句 - - Args: - group_id: 组ID - limit: 返回结果的最大数量 - - Returns: - List[StatementNode]: 强连接的陈述句列表 - """ - return await self.find( - {"group_id": group_id, "connect_strength": "Strong"}, - limit=limit - ) - - async def find_by_config_id( - self, - config_id: str, - limit: int = 100 - ) -> List[StatementNode]: - """根据config_id查询陈述句 - - Args: - config_id: 配置ID - limit: 返回结果的最大数量 - - Returns: - List[StatementNode]: 陈述句列表 - """ - return await self.find({"config_id": config_id}, limit=limit) - - async def search_by_embedding_with_config( - self, - embedding: List[float], - config_id: Optional[str] = None, - group_id: Optional[str] = None, - limit: int = 10, - min_score: float = 0.7 - ) -> List[Dict]: - """基于向量相似度搜索陈述句,可选择按config_id过滤 - - 使用余弦相似度计算查询向量与陈述句向量的相似度。 - 支持按config_id过滤结果,确保只返回使用特定配置处理的陈述句。 - - Args: - embedding: 查询向量 - config_id: 可选的配置ID过滤 - group_id: 可选的组ID过滤 - limit: 返回结果的最大数量 - min_score: 最小相似度分数阈值 - - Returns: - List[Dict]: 包含陈述句和相似度分数的字典列表 - 每个字典包含: statement (StatementNode), score (float) - """ - # 构建查询条件 - where_clauses = ["n.statement_embedding IS NOT NULL"] - params = { - "embedding": embedding, - "min_score": min_score, - "limit": limit - } - - if config_id: - where_clauses.append("n.config_id = $config_id") - params["config_id"] = config_id - - if group_id: - where_clauses.append("n.group_id = $group_id") - params["group_id"] = group_id - - where_str = " AND ".join(where_clauses) - - query = f""" - MATCH (n:{self.node_label}) - WHERE {where_str} - WITH n, gds.similarity.cosine(n.statement_embedding, $embedding) AS score - WHERE score > $min_score - RETURN n, score - ORDER BY score DESC - LIMIT $limit - """ - - results = await self.connector.execute_query(query, **params) - - return [ - { - "statement": self._map_to_entity(r), - "score": r.get("score", 0.0) - } - for r in results - ] diff --git a/app/repositories/release_share_repository.py b/app/repositories/release_share_repository.py deleted file mode 100644 index 714b04e5..00000000 --- a/app/repositories/release_share_repository.py +++ /dev/null @@ -1,59 +0,0 @@ -import uuid -from typing import Optional -from sqlalchemy.orm import Session -from sqlalchemy import select -from app.models import ReleaseShare - - -class ReleaseShareRepository: - """发布版本分享仓储""" - - def __init__(self, db: Session): - self.db = db - - def create(self, release_share: ReleaseShare) -> ReleaseShare: - """创建分享配置""" - self.db.add(release_share) - self.db.commit() - self.db.refresh(release_share) - return release_share - - def get_by_id(self, share_id: uuid.UUID) -> Optional[ReleaseShare]: - """根据 ID 获取分享配置""" - return self.db.get(ReleaseShare, share_id) - - def get_by_release_id(self, release_id: uuid.UUID) -> Optional[ReleaseShare]: - """根据发布版本 ID 获取分享配置""" - stmt = select(ReleaseShare).where(ReleaseShare.release_id == release_id) - return self.db.scalars(stmt).first() - - def get_by_share_token(self, share_token: str) -> Optional[ReleaseShare]: - """根据分享 token 获取分享配置""" - stmt = select(ReleaseShare).where(ReleaseShare.share_token == share_token) - return self.db.scalars(stmt).first() - - def update(self, release_share: ReleaseShare) -> ReleaseShare: - """更新分享配置""" - self.db.commit() - self.db.refresh(release_share) - return release_share - - def delete(self, release_share: ReleaseShare) -> None: - """删除分享配置""" - self.db.delete(release_share) - self.db.commit() - - def token_exists(self, share_token: str) -> bool: - """检查 token 是否已存在""" - stmt = select(ReleaseShare.id).where(ReleaseShare.share_token == share_token) - return self.db.scalars(stmt).first() is not None - - def increment_view_count(self, share_id: uuid.UUID) -> None: - """增加访问次数(异步更新,不阻塞)""" - from datetime import datetime - stmt = select(ReleaseShare).where(ReleaseShare.id == share_id) - share = self.db.scalars(stmt).first() - if share: - share.view_count += 1 - share.last_accessed_at = datetime.now() - self.db.commit() diff --git a/app/repositories/tenant_repository.py b/app/repositories/tenant_repository.py deleted file mode 100644 index 97e422bd..00000000 --- a/app/repositories/tenant_repository.py +++ /dev/null @@ -1,167 +0,0 @@ -import uuid -from sqlalchemy.orm import Session, joinedload -from sqlalchemy import and_, or_, func -from typing import List, Optional - -from app.models.tenant_model import Tenants -from app.models.user_model import User -from app.schemas.tenant_schema import TenantCreate, TenantUpdate - - -class TenantRepository: - """租户数据访问层""" - - def __init__(self, db: Session): - self.db = db - - def create_tenant(self, tenant_data: TenantCreate) -> Tenants: - """创建租户""" - db_tenant = Tenants( - name=tenant_data.name, - id=uuid.uuid4(), - description=tenant_data.description, - is_active=tenant_data.is_active - ) - self.db.add(db_tenant) - self.db.flush() - return db_tenant - - def get_tenant_by_id(self, tenant_id: uuid.UUID) -> Optional[Tenants]: - """根据ID获取租户""" - return self.db.query(Tenants).filter(Tenants.id == tenant_id).first() - - def get_tenant_by_name(self, name: str) -> Optional[Tenants]: - """根据名称获取租户""" - return self.db.query(Tenants).filter(Tenants.name == name).first() - - def get_tenants( - self, - skip: int = 0, - limit: int = 100, - is_active: Optional[bool] = None, - search: Optional[str] = None - ) -> List[Tenants]: - """获取租户列表""" - query = self.db.query(Tenants) - - if is_active is not None: - query = query.filter(Tenants.is_active == is_active) - - if search: - query = query.filter( - or_( - Tenants.name.ilike(f"%{search}%"), - Tenants.description.ilike(f"%{search}%") - ) - ) - - return query.offset(skip).limit(limit).all() - - def count_tenants( - self, - is_active: Optional[bool] = None, - search: Optional[str] = None - ) -> int: - """统计租户数量""" - query = self.db.query(func.count(Tenants.id)) - - if is_active is not None: - query = query.filter(Tenants.is_active == is_active) - - if search: - query = query.filter( - or_( - Tenants.name.ilike(f"%{search}%"), - Tenants.description.ilike(f"%{search}%") - ) - ) - - return query.scalar() - - def update_tenant(self, tenant_id: uuid.UUID, tenant_data: TenantUpdate) -> Optional[Tenants]: - """更新租户""" - db_tenant = self.get_tenant_by_id(tenant_id) - if not db_tenant: - return None - - for field, value in tenant_data.dict(exclude_unset=True).items(): - setattr(db_tenant, field, value) - - self.db.flush() - return db_tenant - - def delete_tenant(self, tenant_id: uuid.UUID) -> bool: - """删除租户""" - db_tenant = self.get_tenant_by_id(tenant_id) - if not db_tenant: - return False - - self.db.delete(db_tenant) - return True - - def get_tenant_users(self, tenant_id: uuid.UUID, is_active: Optional[bool] = None) -> List[User]: - """获取租户下的所有用户""" - query = self.db.query(User).filter(User.tenant_id == tenant_id) - - if is_active is not None: - query = query.filter(User.is_active == is_active) - - return query.all() - - def get_user_tenant(self, user_id: uuid.UUID) -> Optional[Tenants]: - """获取用户所属的租户""" - user = self.db.query(User).filter(User.id == user_id).first() - if not user or not user.tenant_id: - return None - - return self.get_tenant_by_id(user.tenant_id) - - def assign_user_to_tenant(self, user_id: uuid.UUID, tenant_id: uuid.UUID) -> bool: - """将用户分配给租户""" - user = self.db.query(User).filter(User.id == user_id).first() - if not user: - return False - - # 验证租户存在 - tenant = self.get_tenant_by_id(tenant_id) - if not tenant: - return False - - user.tenant_id = tenant_id - self.db.flush() - return True - - def count_tenant_users(self, tenant_id: uuid.UUID, is_active: Optional[bool] = None) -> int: - """统计租户下的用户数量""" - query = self.db.query(func.count(User.id)).filter(User.tenant_id == tenant_id) - - if is_active is not None: - query = query.filter(User.is_active == is_active) - - return query.scalar() - - -# 便利函数,保持向后兼容 -def create_tenant(db: Session, tenant_data: TenantCreate) -> Tenants: - """创建租户""" - return TenantRepository(db).create_tenant(tenant_data) - -def get_tenant_by_id(db: Session, tenant_id: uuid.UUID) -> Optional[Tenants]: - """根据ID获取租户""" - return TenantRepository(db).get_tenant_by_id(tenant_id) - -def get_tenant_by_name(db: Session, name: str) -> Optional[Tenants]: - """根据名称获取租户""" - return TenantRepository(db).get_tenant_by_name(name) - -def get_tenants(db: Session, skip: int = 0, limit: int = 100) -> List[Tenants]: - """获取租户列表""" - return TenantRepository(db).get_tenants(skip=skip, limit=limit) - -def get_user_tenant(db: Session, user_id: uuid.UUID) -> Optional[Tenants]: - """获取用户所属的租户""" - return TenantRepository(db).get_user_tenant(user_id) - -def get_tenant_users(db: Session, tenant_id: uuid.UUID) -> List[User]: - """获取租户下的所有用户""" - return TenantRepository(db).get_tenant_users(tenant_id) \ No newline at end of file diff --git a/app/repositories/user_repository.py b/app/repositories/user_repository.py deleted file mode 100644 index ffdd6ec1..00000000 --- a/app/repositories/user_repository.py +++ /dev/null @@ -1,322 +0,0 @@ -from sqlalchemy.orm import Session, joinedload -from sqlalchemy import and_, or_, func -from typing import List, Optional -import uuid - -from app.models.user_model import User -from app.models.tenant_model import Tenants -from app.schemas.user_schema import UserCreate, UserUpdate -from app.core.logging_config import get_db_logger - -# 获取数据库专用日志器 -db_logger = get_db_logger() - - -class UserRepository: - """用户数据访问层""" - - def __init__(self, db: Session): - self.db = db - - def get_user_by_id(self, user_id: uuid.UUID) -> Optional[User]: - """根据ID获取用户""" - db_logger.debug(f"根据ID查询用户: user_id={user_id}") - - try: - user = self.db.query(User).options(joinedload(User.tenant)).filter(User.id == user_id).first() - if user: - db_logger.debug(f"用户查询成功: {user.username} (ID: {user_id})") - else: - db_logger.debug(f"用户不存在: user_id={user_id}") - return user - except Exception as e: - db_logger.error(f"根据ID查询用户失败: user_id={user_id} - {str(e)}") - raise - - def get_user_by_email(self, email: str) -> Optional[User]: - """根据邮箱获取用户""" - db_logger.debug(f"根据邮箱查询用户: email={email}") - - try: - user = self.db.query(User).options(joinedload(User.tenant)).filter(User.email == email).first() - if user: - db_logger.debug(f"用户查询成功: {user.username} (email: {email})") - else: - db_logger.debug(f"用户不存在: email={email}") - return user - except Exception as e: - db_logger.error(f"根据邮箱查询用户失败: email={email} - {str(e)}") - raise - - def get_user_by_username(self, username: str) -> Optional[User]: - """根据用户名获取用户""" - db_logger.debug(f"根据用户名查询用户: username={username}") - - try: - user = self.db.query(User).options(joinedload(User.tenant)).filter(User.username == username).first() - if user: - db_logger.debug(f"用户查询成功: {user.username} (ID: {user.id})") - else: - db_logger.debug(f"用户不存在: username={username}") - return user - except Exception as e: - db_logger.error(f"根据用户名查询用户失败: username={username} - {str(e)}") - raise - - def get_superuser(self) -> Optional[User]: - """获取超级用户""" - db_logger.debug("查询超级用户") - - try: - user = self.db.query(User).options(joinedload(User.tenant)).filter(User.is_active == True).filter(User.is_superuser == True).first() - if user: - db_logger.debug(f"超级用户查询成功: {user.username}") - else: - db_logger.debug("超级用户不存在") - return user - except Exception as e: - db_logger.error(f"查询超级用户失败: {str(e)}") - raise - def check_superuser_only(self) -> bool: - """检查是否只有一个超级用户""" - db_logger.debug("检查是否只有一个超级用户") - - try: - count = self.db.query(User).options(joinedload(User.tenant)).filter(User.is_active == True).filter(User.is_superuser == True).count() - return count == 1 - except Exception as e: - db_logger.error(f"检查超级用户数量失败: {str(e)}") - raise - - def create_user( - self, - user_data: UserCreate, - hashed_password: str, - tenant_id: Optional[uuid.UUID] = None, - is_superuser: bool = False - ) -> User: - """创建用户""" - db_logger.debug(f"创建用户记录: username={user_data.username}, email={user_data.email}, is_superuser={is_superuser}") - - try: - db_user = User( - username=user_data.username, - email=user_data.email, - hashed_password=hashed_password, - tenant_id=tenant_id, - is_superuser=is_superuser, - ) - self.db.add(db_user) - self.db.flush() - db_logger.info(f"用户记录创建成功: {user_data.username} (email: {user_data.email})") - return db_user - except Exception as e: - db_logger.error(f"创建用户记录失败: username={user_data.username} - {str(e)}") - raise - - def update_user(self, user_id: uuid.UUID, user_data: UserUpdate) -> Optional[User]: - """更新用户""" - db_logger.debug(f"更新用户: user_id={user_id}") - - try: - user = self.get_user_by_id(user_id) - if not user: - db_logger.debug(f"用户不存在: user_id={user_id}") - return None - - for field, value in user_data.dict(exclude_unset=True).items(): - setattr(user, field, value) - - self.db.flush() - db_logger.info(f"用户更新成功: {user.username}") - return user - except Exception as e: - db_logger.error(f"更新用户失败: user_id={user_id} - {str(e)}") - raise - - def delete_user(self, user_id: uuid.UUID) -> bool: - """删除用户""" - db_logger.debug(f"删除用户: user_id={user_id}") - - try: - user = self.get_user_by_id(user_id) - if not user: - db_logger.debug(f"用户不存在: user_id={user_id}") - return False - - self.db.delete(user) - self.db.flush() - db_logger.info(f"用户删除成功: {user.username}") - return True - except Exception as e: - db_logger.error(f"删除用户失败: user_id={user_id} - {str(e)}") - raise - - def get_users_by_tenant( - self, - tenant_id: uuid.UUID, - skip: int = 0, - limit: int = 100, - is_active: Optional[bool] = None, - search: Optional[str] = None - ) -> List[User]: - """获取租户下的用户列表""" - db_logger.debug(f"查询租户用户: tenant_id={tenant_id}") - - try: - query = self.db.query(User).options(joinedload(User.tenant)).filter(User.tenant_id == tenant_id) - - if is_active is not None: - query = query.filter(User.is_active == is_active) - - if search: - query = query.filter( - or_( - User.username.ilike(f"%{search}%"), - User.email.ilike(f"%{search}%") - ) - ) - - users = query.offset(skip).limit(limit).all() - db_logger.debug(f"租户用户查询成功: tenant_id={tenant_id}, count={len(users)}") - return users - except Exception as e: - db_logger.error(f"查询租户用户失败: tenant_id={tenant_id} - {str(e)}") - raise - - def count_users_by_tenant( - self, - tenant_id: uuid.UUID, - is_active: Optional[bool] = None, - search: Optional[str] = None - ) -> int: - """统计租户下的用户数量""" - try: - query = self.db.query(func.count(User.id)).filter(User.tenant_id == tenant_id) - - if is_active is not None: - query = query.filter(User.is_active == is_active) - - if search: - query = query.filter( - or_( - User.username.ilike(f"%{search}%"), - User.email.ilike(f"%{search}%") - ) - ) - - return query.scalar() - except Exception as e: - db_logger.error(f"统计租户用户失败: tenant_id={tenant_id} - {str(e)}") - raise - - def get_superusers_by_tenant( - self, - tenant_id: uuid.UUID, - is_active: Optional[bool] = True - ) -> List[User]: - """获取租户下的超管用户列表""" - db_logger.debug(f"查询租户超管用户: tenant_id={tenant_id}") - - try: - query = self.db.query(User).options(joinedload(User.tenant)).filter( - and_( - User.tenant_id == tenant_id, - User.is_superuser == True - ) - ) - - if is_active is not None: - query = query.filter(User.is_active == is_active) - - users = query.all() - db_logger.debug(f"租户超管用户查询成功: tenant_id={tenant_id}, count={len(users)}") - return users - except Exception as e: - db_logger.error(f"查询租户超管用户失败: tenant_id={tenant_id} - {str(e)}") - raise - - def assign_user_to_tenant(self, user_id: uuid.UUID, tenant_id: uuid.UUID) -> bool: - """将用户分配给租户""" - db_logger.debug(f"分配用户到租户: user_id={user_id}, tenant_id={tenant_id}") - - try: - user = self.get_user_by_id(user_id) - if not user: - db_logger.debug(f"用户不存在: user_id={user_id}") - return False - - # 验证租户存在 - tenant = self.db.query(Tenants).filter(Tenants.id == tenant_id).first() - if not tenant: - db_logger.debug(f"租户不存在: tenant_id={tenant_id}") - return False - - user.tenant_id = tenant_id - self.db.flush() - db_logger.info(f"用户分配成功: user={user.username}, tenant={tenant.name}") - return True - except Exception as e: - db_logger.error(f"分配用户到租户失败: user_id={user_id}, tenant_id={tenant_id} - {str(e)}") - raise - - def get_users_without_tenant( - self, - skip: int = 0, - limit: int = 100, - is_active: Optional[bool] = None - ) -> List[User]: - """获取没有租户的用户列表""" - try: - query = self.db.query(User).filter(User.tenant_id.is_(None)) - - if is_active is not None: - query = query.filter(User.is_active == is_active) - - return query.offset(skip).limit(limit).all() - except Exception as e: - db_logger.error(f"查询无租户用户失败: {str(e)}") - raise - - -# 便利函数,保持向后兼容 -def get_user_by_id(db: Session, user_id: uuid.UUID) -> Optional[User]: - """根据ID获取用户""" - return UserRepository(db).get_user_by_id(user_id) - -def get_user_by_email(db: Session, email: str) -> Optional[User]: - """根据邮箱获取用户""" - return UserRepository(db).get_user_by_email(email) - -def get_user_by_username(db: Session, username: str) -> Optional[User]: - """根据用户名获取用户""" - return UserRepository(db).get_user_by_username(username) - -def get_superuser(db: Session) -> Optional[User]: - """获取超级用户""" - return UserRepository(db).get_superuser() - -def check_superuser_only(db: Session) -> Optional[User]: - """检查是否只有一个超级用户""" - return UserRepository(db).check_superuser_only() - -def create_user( - db: Session, - user: UserCreate, - hashed_password: str, - tenant_id: Optional[uuid.UUID] = None, - is_superuser: bool = False -) -> User: - """创建用户(函数式接口)""" - repo = UserRepository(db) - return repo.create_user(user, hashed_password, tenant_id, is_superuser) - - -def get_superusers_by_tenant( - db: Session, - tenant_id: uuid.UUID, - is_active: Optional[bool] = True -) -> List[User]: - """获取租户下的超管用户列表(函数式接口)""" - repo = UserRepository(db) - return repo.get_superusers_by_tenant(tenant_id, is_active) diff --git a/app/repositories/workspace_invite_repository.py b/app/repositories/workspace_invite_repository.py deleted file mode 100644 index 73a9418d..00000000 --- a/app/repositories/workspace_invite_repository.py +++ /dev/null @@ -1,134 +0,0 @@ -from sqlalchemy.orm import Session -from sqlalchemy import and_, or_ -from typing import List, Optional -import datetime -import uuid - -from app.models.workspace_model import WorkspaceInvite, InviteStatus -from app.schemas.workspace_schema import WorkspaceInviteCreate - - -class WorkspaceInviteRepository: - def __init__(self, db: Session): - self.db = db - - def create_invite( - self, - workspace_id: uuid.UUID, - invite_data: WorkspaceInviteCreate, - token_hash: str, - created_by_user_id: uuid.UUID - ) -> WorkspaceInvite: - """创建工作空间邀请""" - expires_at = datetime.datetime.now() + datetime.timedelta(days=invite_data.expires_in_days) - - db_invite = WorkspaceInvite( - workspace_id=workspace_id, - email=invite_data.email, - role=invite_data.role, - token_hash=token_hash, - status=InviteStatus.pending, - expires_at=expires_at, - created_by_user_id=created_by_user_id - ) - - self.db.add(db_invite) - self.db.commit() - self.db.refresh(db_invite) - return db_invite - - def get_invite_by_token_hash(self, token_hash: str) -> Optional[WorkspaceInvite]: - """根据令牌哈希获取邀请""" - return self.db.query(WorkspaceInvite).filter( - WorkspaceInvite.token_hash == token_hash - ).first() - - def get_invite_by_id(self, invite_id: uuid.UUID) -> Optional[WorkspaceInvite]: - """根据ID获取邀请""" - return self.db.query(WorkspaceInvite).filter( - WorkspaceInvite.id == invite_id - ).first() - - def get_workspace_invites( - self, - workspace_id: uuid.UUID, - status: Optional[InviteStatus] = None, - limit: int = 50, - offset: int = 0 - ) -> List[WorkspaceInvite]: - """获取工作空间的邀请列表""" - query = self.db.query(WorkspaceInvite).filter( - WorkspaceInvite.workspace_id == workspace_id - ) - - if status: - query = query.filter(WorkspaceInvite.status == status) - - return query.order_by(WorkspaceInvite.created_at.desc()).offset(offset).limit(limit).all() - - def get_pending_invite_by_email_and_workspace( - self, - email: str, - workspace_id: uuid.UUID - ) -> Optional[WorkspaceInvite]: - """获取指定邮箱在指定工作空间的待处理邀请""" - return self.db.query(WorkspaceInvite).filter( - and_( - WorkspaceInvite.email == email, - WorkspaceInvite.workspace_id == workspace_id, - WorkspaceInvite.status == InviteStatus.pending - ) - ).first() - - def update_invite_status( - self, - invite_id: uuid.UUID, - status: InviteStatus, - accepted_at: Optional[datetime.datetime] = None - ) -> Optional[WorkspaceInvite]: - """更新邀请状态""" - invite = self.get_invite_by_id(invite_id) - if invite: - invite.status = status - if accepted_at: - invite.accepted_at = accepted_at - invite.updated_at = datetime.datetime.now() - self.db.commit() - self.db.refresh(invite) - return invite - - def revoke_invite(self, invite_id: uuid.UUID) -> Optional[WorkspaceInvite]: - """撤销邀请""" - return self.update_invite_status(invite_id, InviteStatus.revoked) - - def expire_old_invites(self) -> int: - """将过期的邀请标记为已过期""" - now = datetime.datetime.now() - expired_count = self.db.query(WorkspaceInvite).filter( - and_( - WorkspaceInvite.status == InviteStatus.pending, - WorkspaceInvite.expires_at < now - ) - ).update( - { - WorkspaceInvite.status: InviteStatus.expired, - WorkspaceInvite.updated_at: now - } - ) - self.db.commit() - return expired_count - - def count_workspace_invites( - self, - workspace_id: uuid.UUID, - status: Optional[InviteStatus] = None - ) -> int: - """统计工作空间邀请数量""" - query = self.db.query(WorkspaceInvite).filter( - WorkspaceInvite.workspace_id == workspace_id - ) - - if status: - query = query.filter(WorkspaceInvite.status == status) - - return query.count() \ No newline at end of file diff --git a/app/repositories/workspace_repository.py b/app/repositories/workspace_repository.py deleted file mode 100644 index 106830be..00000000 --- a/app/repositories/workspace_repository.py +++ /dev/null @@ -1,383 +0,0 @@ -from sqlalchemy.orm import Session, joinedload -from app.models.user_model import User -from typing import List, Optional -import uuid -from app.models.workspace_model import Workspace, WorkspaceMember, WorkspaceRole -from app.schemas.workspace_schema import WorkspaceCreate, WorkspaceUpdate -from app.core.logging_config import get_db_logger - -# 获取数据库专用日志器 -db_logger = get_db_logger() - - -class WorkspaceRepository: - """工作空间数据访问层""" - - def __init__(self, db: Session): - self.db = db - - def create_workspace(self, workspace_data: WorkspaceCreate, tenant_id: uuid.UUID) -> Workspace: - """创建工作空间""" - db_logger.debug(f"创建工作空间记录: name={workspace_data.name}, tenant_id={tenant_id}") - - try: - db_workspace = Workspace( - name=workspace_data.name, - description=workspace_data.description, - icon=workspace_data.icon, - iconType=workspace_data.iconType, - storage_type=workspace_data.storage_type, - llm=workspace_data.llm, - embedding=workspace_data.embedding, - rerank=workspace_data.rerank, - tenant_id=tenant_id - ) - self.db.add(db_workspace) - self.db.flush() - db_logger.info(f"工作空间记录创建成功: {workspace_data.name} (ID: {db_workspace.id}), storage_type: {workspace_data.storage_type}") - return db_workspace - except Exception as e: - db_logger.error(f"创建工作空间记录失败: name={workspace_data.name} - {str(e)}") - raise - - def get_workspace_by_id(self, workspace_id: uuid.UUID) -> Optional[Workspace]: - """根据ID获取工作空间""" - db_logger.debug(f"根据ID查询工作空间: workspace_id={workspace_id}") - - try: - workspace = self.db.query(Workspace).filter(Workspace.id == workspace_id).first() - if workspace: - db_logger.debug(f"工作空间查询成功: {workspace.name} (ID: {workspace_id})") - else: - db_logger.debug(f"工作空间不存在: workspace_id={workspace_id}") - return workspace - except Exception as e: - db_logger.error(f"根据ID查询工作空间失败: workspace_id={workspace_id} - {str(e)}") - raise - - def get_workspace_models_configs(self, workspace_id: uuid.UUID) -> Optional[dict]: - """根据workspace_id获取模型配置(llm, embedding, rerank) - - Args: - workspace_id: 工作空间ID - - Returns: - 包含 llm, embedding, rerank 的字典,如果工作空间不存在则返回 None - """ - db_logger.debug(f"查询工作空间模型配置: workspace_id={workspace_id}") - - try: - workspace = self.db.query(Workspace).filter(Workspace.id == workspace_id).first() - if workspace: - configs = { - "llm": workspace.llm, - "embedding": workspace.embedding, - "rerank": workspace.rerank - } - db_logger.debug( - f"工作空间模型配置查询成功: workspace_id={workspace_id}, " - f"llm={configs['llm']}, embedding={configs['embedding']}, rerank={configs['rerank']}" - ) - return configs - else: - db_logger.debug(f"工作空间不存在: workspace_id={workspace_id}") - return None - except Exception as e: - db_logger.error(f"查询工作空间模型配置失败: workspace_id={workspace_id} - {str(e)}") - raise - - def get_workspaces_by_user(self, user_id: uuid.UUID) -> List[Workspace]: - """获取用户参与的所有工作空间(包括用户创建的和作为成员的)""" - db_logger.debug(f"查询用户参与的工作空间: user_id={user_id}") - - try: - # 首先获取用户信息以获取 tenant_id - from app.models.user_model import User - user = self.db.query(User).filter(User.id == user_id).first() - if not user: - db_logger.warning(f"用户不存在: user_id={user_id}") - return [] - - if user.is_superuser: - # 超级用户获取对应tenantid所有工作空间 - workspaces = ( - self.db.query(Workspace) - .filter(Workspace.tenant_id == user.tenant_id) - .filter(Workspace.is_active == True) - .order_by(Workspace.updated_at.desc()) - .all() - ) - db_logger.debug(f"超用户查询所有工作空间: user_id={user_id}, 数量={len(workspaces)}") - return workspaces - - # 获取用户作为成员的工作空间 - member_workspaces = ( - self.db.query(Workspace) - .join(WorkspaceMember, Workspace.id == WorkspaceMember.workspace_id) - .filter(WorkspaceMember.user_id == user_id) - .filter(Workspace.is_active == True) - .order_by(Workspace.updated_at.desc()) - .all() - ) - - db_logger.debug(f"用户工作空间查询成功: user_id={user_id}, 数量={len(member_workspaces)}") - return member_workspaces - except Exception as e: - db_logger.error(f"查询用户工作空间失败: user_id={user_id} - {str(e)}") - raise - - def get_workspaces_by_tenant(self, tenant_id: uuid.UUID) -> List[Workspace]: - """获取租户的所有工作空间""" - db_logger.debug(f"查询租户的工作空间: tenant_id={tenant_id}") - - try: - workspaces = ( - self.db.query(Workspace) - .filter(Workspace.tenant_id == tenant_id) - .filter(Workspace.is_active == True) - .all() - ) - db_logger.debug(f"租户工作空间查询成功: tenant_id={tenant_id}, 数量={len(workspaces)}") - return workspaces - except Exception as e: - db_logger.error(f"查询租户工作空间失败: tenant_id={tenant_id} - {str(e)}") - raise - - def add_member(self, workspace_id: uuid.UUID, user_id: uuid.UUID, role: WorkspaceRole = WorkspaceRole.member) -> WorkspaceMember: - """添加工作空间成员""" - db_logger.debug(f"添加工作空间成员: user_id={user_id}, workspace_id={workspace_id}, role={role}") - - try: - db_member = WorkspaceMember( - user_id=user_id, - workspace_id=workspace_id, - role=role - ) - self.db.add(db_member) - self.db.flush() - db_logger.info(f"工作空间成员添加成功: user_id={user_id}, workspace_id={workspace_id}, role={role}") - return db_member - except Exception as e: - db_logger.error(f"添加工作空间成员失败: user_id={user_id}, workspace_id={workspace_id} - {str(e)}") - raise - - def get_member(self, user_id: uuid.UUID, workspace_id: uuid.UUID) -> Optional[WorkspaceMember]: - """获取工作空间成员""" - db_logger.debug(f"查询工作空间成员: user_id={user_id}, workspace_id={workspace_id}") - - try: - member = self.db.query(WorkspaceMember).filter( - WorkspaceMember.user_id == user_id, - WorkspaceMember.workspace_id == workspace_id, - WorkspaceMember.is_active == True, - ).first() - if member: - db_logger.debug(f"工作空间成员查询成功: user_id={user_id}, workspace_id={workspace_id}, role={member.role}") - else: - db_logger.debug(f"工作空间成员不存在: user_id={user_id}, workspace_id={workspace_id}") - return member - except Exception as e: - db_logger.error(f"查询工作空间成员失败: user_id={user_id}, workspace_id={workspace_id} - {str(e)}") - raise - - def get_members_by_workspace(self, workspace_id: uuid.UUID) -> List[WorkspaceMember]: - """按工作空间获取成员列表,并预加载 user 与 workspace 关系""" - db_logger.debug(f"查询工作空间的成员列表: workspace_id={workspace_id}") - try: - members = ( - self.db.query(WorkspaceMember) - .join(User, WorkspaceMember.user_id == User.id) - .options(joinedload(WorkspaceMember.user), joinedload(WorkspaceMember.workspace)) - .filter(WorkspaceMember.workspace_id == workspace_id) - .filter(WorkspaceMember.is_active == True) - .filter(User.is_active == True) - .all() - ) - db_logger.debug(f"成员列表查询成功: workspace_id={workspace_id}, 数量={len(members)}") - return members - except Exception as e: - db_logger.error(f"查询成员列表失败: workspace_id={workspace_id} - {str(e)}") - raise - - def get_member_by_id(self, member_id: uuid.UUID) -> WorkspaceMember: - """按成员ID获取工作空间成员,并预加载 user 与 workspace 关系""" - db_logger.debug(f"查询成员的工作空间: member_id={member_id}") - try: - member = ( - self.db.query(WorkspaceMember) - .join(User, WorkspaceMember.user_id == User.id) - .options(joinedload(WorkspaceMember.user), joinedload(WorkspaceMember.workspace)) - .filter(WorkspaceMember.id == member_id) - .filter(WorkspaceMember.is_active == True) - .filter(User.is_active == True) - .first() - ) - if member: - db_logger.debug(f"成员查询成功: member_id={member_id}, workspace_id={member.workspace_id}, role={member.role}") - else: - db_logger.debug(f"成员不存在: member_id={member_id}") - return member - except Exception as e: - db_logger.error(f"查询成员列表失败: member_id={member_id} - {str(e)}") - raise - - def update_member_role(self, workspace_id: uuid.UUID, user_id: uuid.UUID, role: WorkspaceRole) -> Optional[WorkspaceMember]: - try: - member = self.db.query(WorkspaceMember).filter( - WorkspaceMember.workspace_id == workspace_id, - WorkspaceMember.user_id == user_id, - WorkspaceMember.is_active == True, - ).first() - if not member: - return None - member.role = role - self.db.commit() - self.db.refresh(member) - return member - except Exception as e: - db_logger.error(f"更新成员角色失败: workspace_id={workspace_id}, user_id={user_id} - {str(e)}") - raise - - def deactivate_member(self, workspace_id: uuid.UUID, user_id: uuid.UUID) -> Optional[WorkspaceMember]: - try: - member = self.db.query(WorkspaceMember).filter( - WorkspaceMember.workspace_id == workspace_id, - WorkspaceMember.user_id == user_id, - WorkspaceMember.is_active == True, - ).first() - if not member: - return None - member.is_active = False - self.db.commit() - self.db.refresh(member) - return member - except Exception as e: - db_logger.error(f"删除成员失败: workspace_id={workspace_id}, user_id={user_id} - {str(e)}") - raise - - def delete_member_by_id(self, member_id: uuid.UUID) -> Optional[WorkspaceMember]: - try: - member = self.db.query(WorkspaceMember).filter( - WorkspaceMember.id == member_id, - WorkspaceMember.is_active == True, - ).first() - if not member: - return None - member.is_active = False - self.db.commit() - self.db.refresh(member) - return member - except Exception as e: - db_logger.error(f"删除成员失败: id={member_id} - {str(e)}") - raise - - def update_member_role_by_id(self, id: uuid.UUID, role: WorkspaceRole) -> Optional[WorkspaceMember]: - try: - member = self.db.query(WorkspaceMember).filter( - WorkspaceMember.id == id, - WorkspaceMember.is_active == True, - ).first() - if not member: - return None - member.role = role - self.db.commit() - self.db.refresh(member) - return member - except Exception as e: - db_logger.error(f"更新成员角色失败: id={id} - {str(e)}") - raise - -# 保持向后兼容的函数 -def get_workspace_by_id(db: Session, workspace_id: uuid.UUID) -> Workspace | None: - repo = WorkspaceRepository(db) - return repo.get_workspace_by_id(workspace_id) - - -def get_workspaces_by_user(db: Session, user_id: uuid.UUID) -> List[Workspace]: - repo = WorkspaceRepository(db) - return repo.get_workspaces_by_user(user_id) - - -def get_workspaces_by_tenant(db: Session, tenant_id: uuid.UUID) -> List[Workspace]: - repo = WorkspaceRepository(db) - return repo.get_workspaces_by_tenant(tenant_id) - - -def get_member_in_workspace(db: Session, user_id: uuid.UUID, workspace_id: uuid.UUID) -> WorkspaceMember | None: - repo = WorkspaceRepository(db) - return repo.get_member(user_id, workspace_id) - - -def create_workspace(db: Session, workspace: WorkspaceCreate, tenant_id: uuid.UUID) -> Workspace: - repo = WorkspaceRepository(db) - return repo.create_workspace(workspace, tenant_id) - - -def add_member_to_workspace( - db: Session, user_id: uuid.UUID, workspace_id: uuid.UUID, role: WorkspaceRole -) -> WorkspaceMember: - repo = WorkspaceRepository(db) - return repo.add_member(workspace_id, user_id, role) - - -def get_members_by_workspace(db: Session, workspace_id: uuid.UUID) -> List[WorkspaceMember]: - repo = WorkspaceRepository(db) - return repo.get_members_by_workspace(workspace_id) - -def get_member_by_id(db: Session, member_id: uuid.UUID) -> WorkspaceMember | None: - repo = WorkspaceRepository(db) - return repo.get_member_by_id(member_id) - -def update_member_role_in_workspace( - db: Session, - user_id: uuid.UUID, - workspace_id: uuid.UUID, - role: WorkspaceRole, -) -> Optional[WorkspaceMember]: - repo = WorkspaceRepository(db) - return repo.update_member_role(workspace_id, user_id, role) - -def remove_member_from_workspace( - db: Session, - user_id: uuid.UUID, - workspace_id: uuid.UUID, -) -> Optional[WorkspaceMember]: - repo = WorkspaceRepository(db) - return repo.deactivate_member(workspace_id, user_id) - -def remove_member_from_workspace_by_id( - db: Session, - member_id: uuid.UUID, -) -> Optional[WorkspaceMember]: - repo = WorkspaceRepository(db) - return repo.delete_member_by_id(member_id) - - -def update_member_role_by_id( - db: Session, - id: uuid.UUID, - role: WorkspaceRole, -) -> Optional[WorkspaceMember]: - repo = WorkspaceRepository(db) - return repo.update_member_role_by_id(id, role) - - -def get_workspace_models_configs(db: Session, workspace_id: uuid.UUID) -> Optional[dict]: - """根据workspace_id获取模型配置(llm, embedding, rerank) - - Args: - db: 数据库会话 - workspace_id: 工作空间ID - - Returns: - 包含 llm, embedding, rerank 的字典,如果工作空间不存在则返回 None - - Example: - >>> configs = get_workspace_models_configs(db, workspace_id) - >>> if configs: - >>> print(f"LLM: {configs['llm']}") - >>> print(f"Embedding: {configs['embedding']}") - >>> print(f"Rerank: {configs['rerank']}") - """ - repo = WorkspaceRepository(db) - return repo.get_workspace_models_configs(workspace_id) diff --git a/app/schemas/__init__.py b/app/schemas/__init__.py deleted file mode 100644 index 208adc68..00000000 --- a/app/schemas/__init__.py +++ /dev/null @@ -1,108 +0,0 @@ -from .item_schema import Item -from .user_schema import User, UserCreate, UserUpdate -from .workspace_schema import Workspace, WorkspaceCreate, WorkspaceMember, WorkspaceMemberCreate -from .token_schema import Token, TokenData -from .knowledge_schema import Knowledge, KnowledgeCreate, KnowledgeUpdate -from .document_schema import Document, DocumentCreate, DocumentUpdate -from .file_schema import File, FileCreate, FileUpdate -from .tenant_schema import Tenant, TenantCreate, TenantUpdate -from .chunk_schema import ChunkCreate, ChunkUpdate, ChunkRetrieve -from .knowledgeshare_schema import KnowledgeShare, KnowledgeShareCreate -from .app_schema import ( - DraftRunRequest, - DraftRunResponse, - DraftRunStreamChunk, - App, - AppCreate, - AppUpdate, - AgentConfig, - AgentConfigCreate, - AgentConfigUpdate, - AppRelease, - ModelParameters, - KnowledgeRetrievalConfig, - MemoryConfig, - ToolConfig, - VariableDefinition, -) -from .conversation_schema import ( - Conversation, - ConversationCreate, - ConversationWithMessages, - Message, - MessageCreate, - ChatRequest, - ChatResponse, -) -from .multi_agent_schema import ( - SubAgentConfig, - RoutingRule, - ExecutionConfig, - MultiAgentConfigCreate, - MultiAgentConfigUpdate, - MultiAgentConfigSchema, - MultiAgentRunRequest, - MultiAgentRunResponse, - SubAgentResult, -) - -__all__ = [ - "Item", - "User", - "UserCreate", - "UserUpdate", - "Workspace", - "WorkspaceCreate", - "WorkspaceMember", - "WorkspaceMemberCreate", - "Token", - "Knowledge", - "KnowledgeCreate", - "KnowledgeUpdate", - "Document", - "DocumentCreate", - "DocumentUpdate", - "File", - "FileCreate", - "FileUpdate", - "Tenant", - "TenantCreate", - "TenantUpdate", - "ChunkCreate", - "ChunkUpdate", - "ChunkRetrieve", - "KnowledgeShare", - "KnowledgeShareCreate", - "DraftRunRequest", - "DraftRunResponse", - "DraftRunStreamChunk", - "App", - "AppCreate", - "AppUpdate", - "AgentConfig", - "AgentConfigCreate", - "AgentConfigUpdate", - "AppRelease", - "ModelParameters", - "KnowledgeRetrievalConfig", - "MemoryConfig", - "ToolConfig", - "VariableDefinition", - "Conversation", - "ConversationCreate", - "ConversationWithMessages", - "Message", - "MessageCreate", - "ChatRequest", - "ChatResponse", - # Multi-Agent Schemas - "SubAgentConfig", - "RoutingRule", - "ExecutionConfig", - "MultiAgentConfigCreate", - "MultiAgentConfigUpdate", - "MultiAgentConfigSchema", - "MultiAgentRunRequest", - "MultiAgentRunResponse", - "SubAgentResult", -] diff --git a/app/schemas/api_key_schema.py b/app/schemas/api_key_schema.py deleted file mode 100644 index 36c0d457..00000000 --- a/app/schemas/api_key_schema.py +++ /dev/null @@ -1,104 +0,0 @@ -"""API Key Schema""" -from pydantic import BaseModel, Field, ConfigDict -from typing import Optional, List -import datetime -import uuid - -from app.models.api_key_model import ApiKeyType - - -class ApiKeyCreate(BaseModel): - """创建 API Key""" - name: str = Field(..., description="API Key 名称", max_length=255) - description: Optional[str] = Field(None, description="描述") - type: ApiKeyType = Field(..., description="API Key 类型") - scopes: List[str] = Field(default_factory=list, description="权限范围列表") - resource_id: Optional[uuid.UUID] = Field(None, description="关联资源ID") - resource_type: Optional[str] = Field(None, description="资源类型") - rate_limit: Optional[int] = Field(100, description="速率限制(请求/分钟)", ge=1) - quota_limit: Optional[int] = Field(None, description="配额限制(总请求数)", ge=1) - expires_at: Optional[datetime.datetime] = Field(None, description="过期时间") - - -class ApiKeyUpdate(BaseModel): - """更新 API Key""" - name: Optional[str] = Field(None, description="API Key 名称", max_length=255) - description: Optional[str] = Field(None, description="描述") - scopes: Optional[List[str]] = Field(None, description="权限范围列表") - rate_limit: Optional[int] = Field(None, description="速率限制(请求/分钟)", ge=1) - quota_limit: Optional[int] = Field(None, description="配额限制(总请求数)", ge=1) - is_active: Optional[bool] = Field(None, description="是否激活") - expires_at: Optional[datetime.datetime] = Field(None, description="过期时间") - - -class ApiKeyResponse(BaseModel): - """API Key 响应(创建时返回,包含明文 Key)""" - model_config = ConfigDict(from_attributes=True) - - id: uuid.UUID - name: str - description: Optional[str] - api_key: str = Field(..., description="API Key 明文(仅创建时返回)") - key_prefix: str - type: str - scopes: List[str] - resource_id: Optional[uuid.UUID] - resource_type: Optional[str] - rate_limit: int - quota_limit: Optional[int] - expires_at: Optional[datetime.datetime] - created_at: datetime.datetime - - -class ApiKey(BaseModel): - """API Key 信息(不包含明文 Key)""" - model_config = ConfigDict(from_attributes=True) - - id: uuid.UUID - name: str - description: Optional[str] - key_prefix: str - type: str - scopes: List[str] - resource_id: Optional[uuid.UUID] - resource_type: Optional[str] - rate_limit: int - quota_limit: Optional[int] - quota_used: int - expires_at: Optional[datetime.datetime] - is_active: bool - last_used_at: Optional[datetime.datetime] - usage_count: int - workspace_id: uuid.UUID - created_by: uuid.UUID - created_at: datetime.datetime - updated_at: datetime.datetime - - -class ApiKeyStats(BaseModel): - """API Key 使用统计""" - total_requests: int = Field(..., description="总请求数") - requests_today: int = Field(..., description="今日请求数") - quota_used: int = Field(..., description="已使用配额") - quota_limit: Optional[int] = Field(None, description="配额限制") - last_used_at: Optional[datetime.datetime] = Field(None, description="最后使用时间") - avg_response_time: Optional[float] = Field(None, description="平均响应时间(毫秒)") - - -class ApiKeyQuery(BaseModel): - """API Key 查询参数""" - type: Optional[ApiKeyType] = Field(None, description="API Key 类型") - is_active: Optional[bool] = Field(None, description="是否激活") - resource_id: Optional[uuid.UUID] = Field(None, description="关联资源ID") - page: int = Field(1, ge=1, description="页码") - pagesize: int = Field(10, ge=1, le=100, description="每页数量") - - -class ApiKeyAuth(BaseModel): - """API Key 认证信息""" - api_key_id: uuid.UUID - workspace_id: uuid.UUID - type: str - scopes: List[str] - resource_id: Optional[uuid.UUID] - resource_type: Optional[str] diff --git a/app/schemas/app_schema.py b/app/schemas/app_schema.py deleted file mode 100644 index c387cee9..00000000 --- a/app/schemas/app_schema.py +++ /dev/null @@ -1,425 +0,0 @@ -import uuid -import datetime -from typing import Optional, Any, List, Dict, TYPE_CHECKING -from pydantic import BaseModel, Field, ConfigDict, field_serializer, field_validator - - -# ---------- Input Schemas ---------- - -class KnowledgeBaseConfig(BaseModel): - """单个知识库配置""" - kb_id: str = Field(..., description="知识库ID") - top_k: int = Field(default=3, ge=1, le=20, description="检索返回的文档数量") - similarity_threshold: float = Field(default=0.7, ge=0.0, le=1.0, description="相似度阈值") - strategy: str = Field(default="hybrid", description="检索策略: hybrid | bm25 | dense") - weight: float = Field(default=1.0, ge=0.0, le=1.0, description="知识库权重(用于多知识库融合)") - vector_similarity_weight: float = Field(default=0.5, ge=0.0, le=1.0, description="向量相似度权重") - retrieve_type: str = Field(default="hybrid", description="检索方式participle| semantic|hybrid") - - -class KnowledgeRetrievalConfig(BaseModel): - """知识库检索配置(支持多个知识库,每个有独立配置)""" - knowledge_bases: List[KnowledgeBaseConfig] = Field( - default_factory=list, - description="关联的知识库列表,每个知识库有独立配置" - ) - - # 多知识库融合策略 - merge_strategy: str = Field( - default="weighted", - description="多知识库结果融合策略: weighted | rrf | concat" - ) - reranker_id: Optional[str] = Field(default=None, description="多知识库结果融合的模型ID") - reranker_top_k: int = Field(default=10, ge=0, le=1024, description="多知识库结果融合的模型参数") - - - -class ToolConfig(BaseModel): - """工具配置""" - enabled: bool = Field(default=False, description="是否启用该工具") - config: Optional[Dict[str, Any]] = Field(default_factory=dict, description="工具特定配置") - - -class MemoryConfig(BaseModel): - """记忆配置""" - enabled: bool = Field(default=True, description="是否启用对话历史记忆") - memory_content: Optional[str] = Field(default=None, description="选择记忆的内容类型") - max_history: int = Field(default=10, ge=0, le=100, description="最大保留的历史对话轮数") - - -class ModelParameters(BaseModel): - """模型参数配置""" - temperature: float = Field(default=0.7, ge=0.0, le=2.0, description="温度参数,控制输出的随机性") - max_tokens: int = Field(default=2000, ge=1, le=32000, description="最大生成token数") - top_p: float = Field(default=1.0, ge=0.0, le=1.0, description="核采样参数") - frequency_penalty: float = Field(default=0.0, ge=-2.0, le=2.0, description="频率惩罚") - presence_penalty: float = Field(default=0.0, ge=-2.0, le=2.0, description="存在惩罚") - n: int = Field(default=1, ge=1, le=10, description="生成的回复数量") - stop: Optional[List[str]] = Field(default=None, description="停止序列") - - -class VariableDefinition(BaseModel): - """变量定义""" - name: str = Field(..., description="变量名称(标识符)") - display_name: Optional[str] = Field(None, description="显示名称(用户看到的名称)") - type: str = Field( - default="string", - description="变量类型: string(单行文本) | text(多行文本) | number(数字)" - ) - required: bool = Field(default=False, description="是否必填") - description: Optional[str] = Field(default=None, description="变量描述") - max_length: Optional[int] = Field(default=None, description="最大长度(用于文本类型)") - - -class AgentConfigCreate(BaseModel): - """Agent 行为配置""" - # 提示词配置 - system_prompt: Optional[str] = Field(default=None, description="系统提示词,定义 Agent 的角色和行为准则") - - # 模型配置 - default_model_config_id: Optional[uuid.UUID] = Field(default=None, description="默认使用的模型配置ID") - model_parameters: ModelParameters = Field( - default_factory=ModelParameters, - description="模型参数配置(temperature、max_tokens 等)" - ) - - # 知识库关联 - knowledge_retrieval: Optional[KnowledgeRetrievalConfig] = Field( - default=None, - description="知识库检索配置" - ) - - # 记忆配置 - memory: MemoryConfig = Field( - default_factory=lambda: MemoryConfig(enabled=True), - description="对话历史记忆配置" - ) - - # 变量配置 - variables: List[VariableDefinition] = Field( - default_factory=list, - description="Agent 可用的变量列表" - ) - - # 工具配置 - tools: Dict[str, ToolConfig] = Field( - default_factory=dict, - description="工具配置,key 为工具名称(web_search, code_interpreter, image_generation 等)" - ) - - -class AppCreate(BaseModel): - name: str - description: Optional[str] = None - icon: Optional[str] = None - icon_type: Optional[str] = None - type: str = Field(pattern=r"^(agent|workflow|multi_agent)$") - visibility: Optional[str] = None - status: Optional[str] = None - tags: Optional[List[str]] = Field(default_factory=list) - - # only for type=agent - agent_config: Optional[AgentConfigCreate] = None - - # only for type=multi_agent - multi_agent_config: Optional[Dict[str, Any]] = None - - -class AppUpdate(BaseModel): - name: Optional[str] = None - description: Optional[str] = None - icon: Optional[str] = None - icon_type: Optional[str] = None - visibility: Optional[str] = None - status: Optional[str] = None - tags: Optional[List[str]] = None - - -class AgentConfigUpdate(BaseModel): - """更新 Agent 行为配置""" - # 提示词配置 - system_prompt: Optional[str] = Field(default=None, description="系统提示词") - - # 模型配置 - default_model_config_id: Optional[uuid.UUID] = Field(default=None, description="默认模型配置ID") - model_parameters: Optional[ModelParameters] = Field(default=None, description="模型参数配置") - - # 知识库关联 - knowledge_retrieval: Optional[KnowledgeRetrievalConfig] = Field( - default=None, - description="知识库检索配置" - ) - - # 记忆配置 - memory: Optional[MemoryConfig] = Field(default=None, description="对话历史记忆配置") - - # 变量配置 - variables: Optional[List[VariableDefinition]] = Field(default=None, description="变量列表") - - # 工具配置 - tools: Optional[Dict[str, ToolConfig]] = Field(default=None, description="工具配置") - - -# ---------- Output Schemas ---------- - -class App(BaseModel): - model_config = ConfigDict(from_attributes=True) - - id: uuid.UUID - workspace_id: uuid.UUID - created_by: uuid.UUID - name: str - description: Optional[str] = None - icon: Optional[str] = None - icon_type: Optional[str] = None - type: str - visibility: str - status: str - tags: List[str] = [] - current_release_id: Optional[uuid.UUID] = None - is_active: bool - is_shared: bool = False # 是否是共享应用(从其他工作空间共享来的) - created_at: datetime.datetime - updated_at: datetime.datetime - - @field_serializer("created_at", when_used="json") - def _serialize_created_at(self, dt: datetime.datetime): - return int(dt.timestamp() * 1000) if dt else None - - @field_serializer("updated_at", when_used="json") - def _serialize_updated_at(self, dt: datetime.datetime): - return int(dt.timestamp() * 1000) if dt else None - - -class AgentConfig(BaseModel): - """Agent 配置输出 Schema""" - model_config = ConfigDict(from_attributes=True) - - id: uuid.UUID - app_id: uuid.UUID - - # 提示词 - system_prompt: Optional[str] = None - - # 模型配置 - default_model_config_id: Optional[uuid.UUID] = None - model_parameters: ModelParameters = Field(default_factory=ModelParameters) - - # 知识库检索 - knowledge_retrieval: Optional[KnowledgeRetrievalConfig] = None - - # 记忆配置 - memory: MemoryConfig = Field(default_factory=lambda: MemoryConfig(enabled=True)) - - # 变量配置 - variables: List[VariableDefinition] = [] - - # 工具配置 - tools: Dict[str, ToolConfig] = {} - - is_active: bool - created_at: datetime.datetime - updated_at: datetime.datetime - - @field_validator("model_parameters", mode="before") - @classmethod - def validate_model_parameters(cls, v): - """处理 None 值,返回默认的 ModelParameters""" - if v is None: - return ModelParameters() - return v - - @field_validator("memory", mode="before") - @classmethod - def validate_memory(cls, v): - """处理 None 值,返回默认的 MemoryConfig""" - if v is None: - return MemoryConfig(enabled=True) - return v - - @field_validator("variables", mode="before") - @classmethod - def validate_variables(cls, v): - """处理 None 值,返回空列表""" - if v is None: - return [] - return v - - @field_validator("tools", mode="before") - @classmethod - def validate_tools(cls, v): - """处理 None 值,返回空字典""" - if v is None: - return {} - return v - - @field_serializer("created_at", when_used="json") - def _serialize_created_at(self, dt: datetime.datetime): - return int(dt.timestamp() * 1000) if dt else None - - @field_serializer("updated_at", when_used="json") - def _serialize_updated_at(self, dt: datetime.datetime): - return int(dt.timestamp() * 1000) if dt else None - - -class PublishRequest(BaseModel): - """发布应用请求""" - version_name: str - release_notes: Optional[str] = Field(None, description="版本说明") - - -class AppRelease(BaseModel): - model_config = ConfigDict(from_attributes=True) - - id: uuid.UUID - app_id: uuid.UUID - version: int - release_notes: Optional[str] = None - version_name: str - description: Optional[str] = None - icon: Optional[str] = None - icon_type: Optional[str] = None - name: str - type: str - visibility: str - config: Dict[str, Any] = {} - default_model_config_id: Optional[uuid.UUID] = None - published_by: uuid.UUID - publisher_name: str - published_at: datetime.datetime - is_active: bool - created_at: datetime.datetime - updated_at: datetime.datetime - - @field_serializer("created_at", when_used="json") - def _serialize_created_at(self, dt: datetime.datetime): - return int(dt.timestamp() * 1000) if dt else None - - @field_serializer("updated_at", when_used="json") - def _serialize_updated_at(self, dt: datetime.datetime): - return int(dt.timestamp() * 1000) if dt else None - - @field_serializer("published_at", when_used="json") - def _serialize_published_at(self, dt: datetime.datetime): - return int(dt.timestamp() * 1000) if dt else None - - -# ---------- App Share Schemas ---------- - -class AppShareCreate(BaseModel): - """应用分享请求""" - target_workspace_ids: List[uuid.UUID] = Field(..., description="目标工作空间ID列表") - - -class AppShare(BaseModel): - """应用分享输出""" - model_config = ConfigDict(from_attributes=True) - - id: uuid.UUID - source_app_id: uuid.UUID - source_workspace_id: uuid.UUID - target_workspace_id: uuid.UUID - shared_by: uuid.UUID - created_at: datetime.datetime - updated_at: datetime.datetime - - @field_serializer("created_at", when_used="json") - def _serialize_created_at(self, dt: datetime.datetime): - return int(dt.timestamp() * 1000) if dt else None - - @field_serializer("updated_at", when_used="json") - def _serialize_updated_at(self, dt: datetime.datetime): - return int(dt.timestamp() * 1000) if dt else None - - -# ---------- Draft Run Schemas ---------- - -class DraftRunRequest(BaseModel): - """试运行请求""" - message: str = Field(..., description="用户消息") - conversation_id: Optional[str] = Field(default=None, description="会话ID(用于多轮对话)") - user_id: Optional[str] = Field(default=None, description="用户ID(用于会话管理)") - variables: Optional[Dict[str, Any]] = Field(default=None, description="自定义变量参数值") - stream: bool = Field(default=False, description="是否流式返回") - - -class DraftRunResponse(BaseModel): - """试运行响应(非流式)""" - message: str = Field(..., description="AI 回复消息") - conversation_id: Optional[str] = Field(default=None, description="会话ID(用于多轮对话)") - usage: Optional[Dict[str, Any]] = Field(default=None, description="Token 使用情况") - elapsed_time: Optional[float] = Field(default=None, description="耗时(秒)") - - -class DraftRunStreamChunk(BaseModel): - """试运行流式响应块""" - event: str = Field(..., description="事件类型: start | message | end | error") - data: Dict[str, Any] = Field(..., description="事件数据") - - -# ---------- Draft Run Compare Schemas ---------- - -class ModelCompareItem(BaseModel): - """单个对比模型配置""" - model_config_id: uuid.UUID = Field(..., description="模型配置ID") - model_parameters: Optional[Dict[str, Any]] = Field( - None, - description="覆盖模型参数,如 temperature, max_tokens 等" - ) - label: Optional[str] = Field( - None, - description="自定义显示标签,用于区分同一模型的不同配置" - ) - conversation_id: Optional[str] = Field( - None, - description="会话ID,用于为每个模型指定独立的会话历史" - ) - - -class DraftRunCompareRequest(BaseModel): - """多模型对比试运行请求""" - message: str = Field(..., description="用户消息") - conversation_id: Optional[str] = Field(None, description="会话ID") - user_id: Optional[str] = Field(None, description="用户ID") - variables: Optional[Dict[str, Any]] = Field(None, description="变量参数") - - models: List[ModelCompareItem] = Field( - ..., - min_length=1, - max_length=5, - description="要对比的模型列表(1-5个)" - ) - - parallel: bool = Field(True, description="是否并行执行") - stream: bool = Field(False, description="是否流式返回") - timeout: Optional[int] = Field(60, ge=10, le=300, description="超时时间(秒)") - - -class ModelRunResult(BaseModel): - """单个模型运行结果""" - model_config_id: uuid.UUID - model_name: str - label: Optional[str] = None - - parameters_used: Dict[str, Any] = Field(..., description="实际使用的参数") - - message: Optional[str] = None - usage: Optional[Dict[str, Any]] = None - elapsed_time: float - error: Optional[str] = None - - tokens_per_second: Optional[float] = None - cost_estimate: Optional[float] = None - conversation_id: Optional[str] = None - - -class DraftRunCompareResponse(BaseModel): - """多模型对比响应""" - results: List[ModelRunResult] - - total_elapsed_time: float - successful_count: int - failed_count: int - - fastest_model: Optional[str] = None - cheapest_model: Optional[str] = None diff --git a/app/schemas/chunk_schema.py b/app/schemas/chunk_schema.py deleted file mode 100644 index cda7ed94..00000000 --- a/app/schemas/chunk_schema.py +++ /dev/null @@ -1,26 +0,0 @@ -from pydantic import BaseModel, Field -import uuid -from enum import StrEnum - - -class RetrieveType(StrEnum): - """Retrieval type enumeration""" - PARTICIPLE = "participle" - SEMANTIC = "semantic" - HYBRID = "hybrid" - -class ChunkCreate(BaseModel): - content: str - - -class ChunkUpdate(BaseModel): - content: str | None = Field(None) - - -class ChunkRetrieve(BaseModel): - query: str - kb_ids: list[uuid.UUID] - similarity_threshold: float | None = Field(None) - vector_similarity_weight: float | None = Field(None) - top_k: int | None = Field(None) - retrieve_type: RetrieveType | None = Field(None) \ No newline at end of file diff --git a/app/schemas/conversation_schema.py b/app/schemas/conversation_schema.py deleted file mode 100644 index 63db6685..00000000 --- a/app/schemas/conversation_schema.py +++ /dev/null @@ -1,86 +0,0 @@ -"""会话和消息相关的 Schema""" -import uuid -import datetime -from typing import Optional, Dict, Any, List -from pydantic import BaseModel, Field, ConfigDict, field_serializer - - -# ---------- Input Schemas ---------- - -class ConversationCreate(BaseModel): - """创建会话请求""" - title: Optional[str] = Field(None, max_length=255, description="会话标题") - user_id: Optional[str] = Field(None, description="用户ID(外部系统)") - - -class MessageCreate(BaseModel): - """创建消息请求""" - content: str = Field(..., description="消息内容") - variables: Optional[Dict[str, Any]] = Field(None, description="变量参数") - - -class ChatRequest(BaseModel): - """聊天请求(基于 share_token)""" - message: str = Field(..., description="用户消息") - conversation_id: Optional[uuid.UUID] = Field(None, description="会话ID(多轮对话)") - user_id: Optional[str] = Field(None, description="用户ID(外部系统)") - variables: Optional[Dict[str, Any]] = Field(None, description="变量参数") - stream: bool = Field(default=False, description="是否流式返回") - web_search: bool = Field(default=False, description="是否启用网络搜索") - memory: bool = Field(default=True, description="是否启用记忆功能") - - -# ---------- Output Schemas ---------- - -class Message(BaseModel): - """消息输出""" - model_config = ConfigDict(from_attributes=True) - - id: uuid.UUID - conversation_id: uuid.UUID - role: str - content: str - meta_data: Optional[Dict[str, Any]] = None - created_at: datetime.datetime - - @field_serializer("created_at", when_used="json") - def _serialize_created_at(self, dt: datetime.datetime): - return int(dt.timestamp() * 1000) if dt else None - - -class Conversation(BaseModel): - """会话输出""" - model_config = ConfigDict(from_attributes=True) - - id: uuid.UUID - app_id: uuid.UUID - workspace_id: uuid.UUID - user_id: Optional[str] = None - title: Optional[str] = None - summary: Optional[str] = None - is_draft: bool - message_count: int - is_active: bool - created_at: datetime.datetime - updated_at: datetime.datetime - - @field_serializer("created_at", when_used="json") - def _serialize_created_at(self, dt: datetime.datetime): - return int(dt.timestamp() * 1000) if dt else None - - @field_serializer("updated_at", when_used="json") - def _serialize_updated_at(self, dt: datetime.datetime): - return int(dt.timestamp() * 1000) if dt else None - - -class ConversationWithMessages(Conversation): - """会话详情(包含消息列表)""" - messages: List[Message] = [] - - -class ChatResponse(BaseModel): - """聊天响应(非流式)""" - conversation_id: uuid.UUID - message: str - usage: Optional[Dict[str, Any]] = None - elapsed_time: Optional[float] = None diff --git a/app/schemas/document_schema.py b/app/schemas/document_schema.py deleted file mode 100644 index ae773b3c..00000000 --- a/app/schemas/document_schema.py +++ /dev/null @@ -1,63 +0,0 @@ -from pydantic import BaseModel, Field, field_serializer, ConfigDict -import datetime -import uuid - - -class DocumentBase(BaseModel): - kb_id: uuid.UUID - created_by: uuid.UUID | None = None - file_id: uuid.UUID - file_name: str - file_ext: str - file_size: int - file_meta: dict - parser_id: str - parser_config: dict - - -class DocumentCreate(DocumentBase): - pass - - -class DocumentUpdate(BaseModel): - file_id: uuid.UUID | None = Field(None) - file_name: str | None = Field(None) - file_ext: str | None = Field(None) - file_size: int | None = Field(None) - file_meta: dict | None = Field(None) - parser_id: str | None = Field(None) - parser_config: dict | None = Field(None) - chunk_num: int | None = Field(None) - progress: float | None = Field(None) - progress_msg: str | None = Field(None) - process_begin_at: datetime.datetime | None = Field(None) - process_duration: float | None = Field(None) - run: int | None = Field(None) - status: int | None = Field(None) - - -class Document(DocumentBase): - id: uuid.UUID - chunk_num: int - progress: float - progress_msg: str - process_begin_at: datetime.datetime - process_duration: float - run: int - status: int - created_at: datetime.datetime - updated_at: datetime.datetime - - @field_serializer("created_at", when_used="json") - def _serialize_created_at(self, dt: datetime.datetime): - return int(dt.timestamp() * 1000) if dt else None - - @field_serializer("updated_at", when_used="json") - def _serialize_updated_at(self, dt: datetime.datetime): - return int(dt.timestamp() * 1000) if dt else None - - model_config = ConfigDict(from_attributes=True) - - @field_serializer("process_begin_at", when_used="json") - def _serialize_process_begin_at(self, dt: datetime.datetime): - return int(dt.timestamp() * 1000) if dt else None diff --git a/app/schemas/end_user_schema.py b/app/schemas/end_user_schema.py deleted file mode 100644 index 30dafddd..00000000 --- a/app/schemas/end_user_schema.py +++ /dev/null @@ -1,17 +0,0 @@ -import uuid -import datetime -from typing import Optional -from pydantic import BaseModel, Field -from pydantic import ConfigDict - -class EndUser(BaseModel): - model_config = ConfigDict(from_attributes=True) - - id: uuid.UUID = Field(description="终端用户ID") - app_id: uuid.UUID = Field(description="应用ID") - # end_user_id: str = Field(description="终端用户ID") - other_id: Optional[str] = Field(description="第三方ID", default=None) - other_name: Optional[str] = Field(description="其他名称", default="") - other_address: Optional[str] = Field(description="其他地址", default="") - created_at: datetime.datetime = Field(description="创建时间", default_factory=datetime.datetime.now) - updated_at: datetime.datetime = Field(description="更新时间", default_factory=datetime.datetime.now) diff --git a/app/schemas/file_schema.py b/app/schemas/file_schema.py deleted file mode 100644 index 00f1a148..00000000 --- a/app/schemas/file_schema.py +++ /dev/null @@ -1,39 +0,0 @@ -from pydantic import BaseModel, Field, field_serializer, ConfigDict -import datetime -import uuid - - -class FileBase(BaseModel): - kb_id: uuid.UUID - created_by: uuid.UUID | None = None - parent_id: uuid.UUID | None = None - file_name: str - file_ext: str - file_size: int - - -class FileCreate(FileBase): - pass - - -class CustomTextFileCreate(BaseModel): - title: str - content: str - - -class FileUpdate(BaseModel): - parent_id: uuid.UUID | None = Field(None) - file_name: str | None = Field(None) - file_ext: str | None = Field(None) - file_size: str | None = Field(None) - - -class File(FileBase): - id: uuid.UUID - created_at: datetime.datetime - - model_config = ConfigDict(from_attributes=True) - - @field_serializer("created_at", when_used="json") - def _serialize_created_at(self, dt: datetime.datetime): - return int(dt.timestamp() * 1000) if dt else None diff --git a/app/schemas/generic_file_schema.py b/app/schemas/generic_file_schema.py deleted file mode 100644 index 507f8697..00000000 --- a/app/schemas/generic_file_schema.py +++ /dev/null @@ -1,69 +0,0 @@ -""" -Schemas for Generic File Upload System -""" -from pydantic import BaseModel, Field, ConfigDict -from typing import Optional, Dict, Any -import datetime -import uuid - -from app.core.upload_enums import UploadContext - - -class GenericFileBase(BaseModel): - """Base schema for generic file""" - file_name: str = Field(..., description="文件名") - context: UploadContext = Field(..., description="上传上下文") - is_public: bool = Field(False, description="是否公开") - file_metadata: Optional[Dict[str, Any]] = Field(default={}, description="文件元数据") - - -class GenericFileCreate(GenericFileBase): - """Schema for creating a generic file""" - tenant_id: uuid.UUID - created_by: uuid.UUID - file_ext: str - file_size: int - mime_type: Optional[str] = None - storage_path: str - - -class GenericFileResponse(BaseModel): - """Schema for generic file response""" - id: uuid.UUID = Field(..., description="文件ID") - file_name: str = Field(..., description="文件名") - file_ext: str = Field(..., description="文件扩展名") - file_size: int = Field(..., description="文件大小(字节)") - mime_type: Optional[str] = Field(None, description="MIME类型") - context: str = Field(..., description="上传上下文") - access_url: Optional[str] = Field(None, description="访问URL") - is_public: bool = Field(..., description="是否公开") - file_metadata: Dict[str, Any] = Field(default={}, description="文件元数据") - status: str = Field(..., description="文件状态") - model_config = ConfigDict(from_attributes=True) - - created_at: datetime.datetime = Field(..., description="创建时间") - updated_at: datetime.datetime = Field(..., description="更新时间") - - -class FileMetadataUpdate(BaseModel): - """Schema for updating file metadata""" - file_name: Optional[str] = Field(None, description="文件名") - file_metadata: Optional[Dict[str, Any]] = Field(None, description="文件元数据") - is_public: Optional[bool] = Field(None, description="是否公开") - - -class UploadResultSchema(BaseModel): - """Schema for upload result""" - success: bool = Field(..., description="是否成功") - file_id: Optional[uuid.UUID] = Field(None, description="文件ID") - file_name: str = Field(..., description="文件名") - error: Optional[str] = Field(None, description="错误信息") - file_info: Optional[GenericFileResponse] = Field(None, description="文件信息") - - -class BatchUploadResponse(BaseModel): - """Schema for batch upload response""" - total: int = Field(..., description="总文件数") - success_count: int = Field(..., description="成功数量") - failed_count: int = Field(..., description="失败数量") - results: list[UploadResultSchema] = Field(..., description="上传结果列表") diff --git a/app/schemas/item_schema.py b/app/schemas/item_schema.py deleted file mode 100644 index 474ac059..00000000 --- a/app/schemas/item_schema.py +++ /dev/null @@ -1,5 +0,0 @@ -from pydantic import BaseModel - -class Item(BaseModel): - name: str - price: float diff --git a/app/schemas/knowledge_schema.py b/app/schemas/knowledge_schema.py deleted file mode 100644 index 4e5ea7d2..00000000 --- a/app/schemas/knowledge_schema.py +++ /dev/null @@ -1,69 +0,0 @@ -from pydantic import BaseModel, Field, field_serializer, ConfigDict -import datetime -import uuid -from .user_schema import User -from .model_schema import ModelConfig -from typing import Optional -from app.models.knowledge_model import KnowledgeType, PermissionType - - -class KnowledgeBase(BaseModel): - workspace_id: uuid.UUID | None = None - created_by: uuid.UUID | None = None - parent_id: uuid.UUID | None = None - name: str - description: str | None = None - avatar: str | None = None - type: KnowledgeType | None = None - permission_id: PermissionType | None = None - embedding_id: uuid.UUID | None = None - reranker_id: uuid.UUID | None = None - llm_id: uuid.UUID | None = None - image2text_id: uuid.UUID | None = None - doc_num: int | None = None - chunk_num: int | None = None - parser_id: str | None = None - parser_config: dict | None = None - - -class KnowledgeCreate(KnowledgeBase): - pass - -class KnowledgeUpdate(BaseModel): - parent_id: uuid.UUID | None = Field(None) - name: str | None = Field(None) - description: str | None = Field(None) - avatar: str | None = Field(None) - type: KnowledgeType | None = Field(None) - permission_id: PermissionType | None = Field(None) - embedding_id: uuid.UUID | None = Field(None) - reranker_id: uuid.UUID | None = Field(None) - llm_id: uuid.UUID | None = Field(None) - image2text_id: uuid.UUID | None = Field(None) - doc_num: int | None = Field(None) - chunk_num: int | None = Field(None) - parser_id: str | None = Field(None) - parser_config: dict | None = Field(None) - status: int | None = Field(None) - - -class Knowledge(KnowledgeBase): - id: uuid.UUID - status: int - created_at: datetime.datetime - updated_at: datetime.datetime - created_user: User - embedding: Optional[ModelConfig] = None - reranker: Optional[ModelConfig] = None - llm: Optional[ModelConfig] = None - image2text: Optional[ModelConfig] = None - - model_config = ConfigDict(from_attributes=True) - - @field_serializer("created_at", when_used="json") - def _serialize_created_at(self, dt: datetime.datetime): - return int(dt.timestamp() * 1000) if dt else None - - @field_serializer("updated_at", when_used="json") - def _serialize_updated_at(self, dt: datetime.datetime): - return int(dt.timestamp() * 1000) if dt else None diff --git a/app/schemas/knowledgeshare_schema.py b/app/schemas/knowledgeshare_schema.py deleted file mode 100644 index faa79235..00000000 --- a/app/schemas/knowledgeshare_schema.py +++ /dev/null @@ -1,37 +0,0 @@ -from pydantic import BaseModel, Field, field_serializer, ConfigDict -import datetime -import uuid -from .knowledge_schema import Knowledge -from .workspace_schema import Workspace -from .user_schema import User - - -class KnowledgeShareBase(BaseModel): - source_kb_id: uuid.UUID - source_workspace_id: uuid.UUID | None = None - target_kb_id: uuid.UUID | None = None - target_workspace_id: uuid.UUID - shared_by: uuid.UUID | None = None - - -class KnowledgeShareCreate(KnowledgeShareBase): - pass - - -class KnowledgeShare(KnowledgeShareBase): - id: uuid.UUID - created_at: datetime.datetime - updated_at: datetime.datetime - target_kb: Knowledge - target_workspace: Workspace - shared_user: User - - @field_serializer("created_at", when_used="json") - def _serialize_created_at(self, dt: datetime.datetime): - return int(dt.timestamp() * 1000) if dt else None - - model_config = ConfigDict(from_attributes=True) - - @field_serializer("updated_at", when_used="json") - def _serialize_updated_at(self, dt: datetime.datetime): - return int(dt.timestamp() * 1000) if dt else None diff --git a/app/schemas/memory_agent_schema.py b/app/schemas/memory_agent_schema.py deleted file mode 100644 index e7c17407..00000000 --- a/app/schemas/memory_agent_schema.py +++ /dev/null @@ -1,17 +0,0 @@ -from typing import Optional - -from pydantic import BaseModel - - -class UserInput(BaseModel): - message: str - history: list[dict] - search_switch: str - group_id: str - config_id: Optional[str] = None - - -class Write_UserInput(BaseModel): - message: str - group_id: str - config_id: Optional[str] = None diff --git a/app/schemas/memory_increment_schema.py b/app/schemas/memory_increment_schema.py deleted file mode 100644 index 565286b5..00000000 --- a/app/schemas/memory_increment_schema.py +++ /dev/null @@ -1,18 +0,0 @@ -import uuid -import datetime -from typing import Optional -from pydantic import BaseModel, Field, field_serializer -from pydantic import ConfigDict - -class MemoryIncrement(BaseModel): - model_config = ConfigDict(from_attributes=True) - - workspace_id: uuid.UUID = Field(description="工作空间ID") - total_num: int = Field(description="增量总数") - created_at: datetime.datetime = Field(description="创建时间", default_factory=datetime.datetime.now()) - updated_at: datetime.datetime = Field(description="更新时间", default_factory=datetime.datetime.now()) - - @field_serializer('created_at', 'updated_at') - def serialize_datetime(self, dt: datetime.datetime, _info) -> str: - """将日期时间序列化为年月日格式""" - return dt.strftime('%Y-%m-%d') diff --git a/app/schemas/memory_storage_schema.py b/app/schemas/memory_storage_schema.py deleted file mode 100644 index 2ff773f3..00000000 --- a/app/schemas/memory_storage_schema.py +++ /dev/null @@ -1,343 +0,0 @@ -""" -所有的内容是放错误地方了,应该放在models -""" - -from typing import Any, Optional, List, Dict, Literal -import time -import uuid -from pydantic import BaseModel, Field, ConfigDict, field_validator, model_validator - - -# ============================================================================ -# 原 UserInput 相关 Schema (保留原有功能) -# ============================================================================ -class UserInput(BaseModel): - message: str - history: list[dict] - search_switch: str - group_id: str - - -class Write_UserInput(BaseModel): - message: str - group_id: str - - -# ============================================================================ -# 从 json_schema.py 迁移的 Schema -# ============================================================================ -class BaseDataSchema(BaseModel): - """Base schema for the data""" - id: str = Field(..., description="The unique identifier for the data entry.") - statement: str = Field(..., description="The statement text.") - group_id: str = Field(..., description="The group identifier.") - chunk_id: str = Field(..., description="The chunk identifier.") - created_at: str = Field(..., description="The creation timestamp in ISO 8601 format.") - expired_at: Optional[str] = Field(None, description="The expiration timestamp in ISO 8601 format.") - valid_at: Optional[str] = Field(None, description="The validation timestamp in ISO 8601 format.") - invalid_at: Optional[str] = Field(None, description="The invalidation timestamp in ISO 8601 format.") - entity_ids: List[str] = Field([], description="The list of entity identifiers.") - - -class ConflictResultSchema(BaseModel): - """Schema for the conflict result data in the reflexion_data.json file.""" - data: List[BaseDataSchema] = Field(..., description="The conflict memory data.") - conflict: bool = Field(..., description="Whether the memory is in conflict.") - conflict_memory: Optional[BaseDataSchema] = Field(None, description="The conflict memory data.") - - @model_validator(mode="before") - def _normalize_data(cls, v): - if isinstance(v, dict): - d = v.get("data") - if isinstance(d, dict): - v["data"] = [d] - return v - - -class ConflictSchema(BaseModel): - """Schema for the conflict data in the reflexion_data""" - data: List[BaseDataSchema] = Field(..., description="The conflict memory data.") - conflict_memory: Optional[BaseDataSchema] = Field(None, description="The conflict memory data.") - - @model_validator(mode="before") - def _normalize_data(cls, v): - if isinstance(v, dict): - d = v.get("data") - if isinstance(d, dict): - v["data"] = [d] - return v - - -class ReflexionSchema(BaseModel): - """Schema for the reflexion data in the reflexion_data""" - reason: str = Field(..., description="The reason for the reflexion.") - solution: str = Field(..., description="The solution for the reflexion.") - - -class ResolvedSchema(BaseModel): - """Schema for the resolved memory data in the reflexion_data""" - original_memory_id: Optional[str] = Field(None, description="The original memory identifier.") - resolved_memory: Optional[BaseDataSchema] = Field(None, description="The resolved memory data.") - - -class ReflexionResultSchema(BaseModel): - """Schema for the reflexion result data in the reflexion_data.json file.""" - # 模型输出中 "conflict" 为单个冲突对象(包含 data 与 conflict_memory),而非字典映射 - conflict: ConflictResultSchema = Field(..., description="The conflict result data.") - reflexion: Optional[ReflexionSchema] = Field(None, description="The reflexion data.") - resolved: Optional[ResolvedSchema] = Field(None, description="The resolved memory data.") - - @model_validator(mode="before") - def _normalize_resolved(cls, v): - if isinstance(v, dict): - conflict = v.get("conflict") - if isinstance(conflict, dict) and conflict.get("conflict") is False: - v["resolved"] = None - else: - resolved = v.get("resolved") - if isinstance(resolved, dict): - orig = resolved.get("original_memory_id") - mem = resolved.get("resolved_memory") - if orig is None and (mem is None or mem == {}): - v["resolved"] = None - return v - - -# ============================================================================ -# 从 messages.py 迁移的 Schema -# ============================================================================ - -# Composite key identifying a config row -class ConfigKey(BaseModel): # 配置参数键模型 - model_config = ConfigDict(populate_by_name=True, extra="forbid") - config_id: int = Field("config_id", description="配置唯一标识(字符串)") - user_id: str = Field("user_id", description="用户标识(字符串)") - apply_id: str = Field("apply_id", description="应用或场景标识(字符串)") - - -# Allowed chunking strategies (extendable later) -ChunkerStrategy = Literal[ # 分块策略枚举 - "RecursiveChunker", - "TokenChunker", - "SemanticChunker", - "NeuralChunker", - "HybridChunker", - "LLMChunker", - "SentenceChunker", - "LateChunker" -] - - -# 这是 Request body示例 -class ConfigParams(ConfigKey): # 创建配置参数模型 旧 - model_config = ConfigDict(populate_by_name=True, extra="forbid") - - # Boolean switches - enable_llm_dedup_blockwise: bool = Field(True, description="启用LLM决策去重") - enable_llm_disambiguation: bool = Field(True, description="启用LLM决策消歧") - deep_retrieval: bool = Field(True, description="深度检索开关(保留既有拼写)") - - # Thresholds in [0, 1] - t_type_strict: float = Field(0.8, ge=0.0, le=1.0, description="类型严格阈值") - t_name_strict: float = Field(0.8, ge=0.0, le=1.0, description="名称严格阈值") - t_overall: float = Field(0.8, ge=0.0, le=1.0, description="综合阈值") - state: bool = Field(False, description="配置使用状态(True/False)") - # Chunker strategy selection (must be one of the declared literals) - chunker_strategy: ChunkerStrategy = Field( - "RecursiveChunker", - description=( - "分块策略:RecursiveChunker/TokenChunker/SemanticChunker/NeuralChunker/" - "HybridChunker/LLMChunker/SentenceChunker/LateChunker" - ), - ) - - @field_validator("chunker_strategy", mode="before") - @classmethod - def map_chunker_aliases(cls, v: str): - # 允许常见别名并映射到合法枚举 - if isinstance(v, str): - m = v.strip().lower() - alias_map = { - "auto": "RecursiveChunker", - "by_sentence": "SentenceChunker", - "by_paragraph": "SemanticChunker", - "fixed_tokens": "TokenChunker", - "递归分块": "RecursiveChunker", - "token 分块": "TokenChunker", - "token分块": "TokenChunker", - "语义分块": "SemanticChunker", - "神经网络分块": "NeuralChunker", - "混合分块": "HybridChunker", - "llm 分块": "LLMChunker", - "llm分块": "LLMChunker", - "句子分块": "SentenceChunker", - "延迟分块": "LateChunker", - } - if m in alias_map: - return alias_map[m] - return v - - @field_validator("config_id", "user_id", "apply_id") - @classmethod - def non_empty_str(cls, v: str) -> str: - s = str(v).strip() if v is not None else "" - if not s: - raise ValueError("标识字段必须为非空字符串") - return s - - -class ConfigParamsCreate(BaseModel): # 创建配置参数模型(仅 body,去除主键) - model_config = ConfigDict(populate_by_name=True, extra="forbid") - config_name: str = Field("配置名称", description="配置名称(字符串)") - config_desc: str = Field("配置描述", description="配置描述(字符串)") - workspace_id: Optional[uuid.UUID] = Field(None, description="工作空间ID(UUID)") - - # 模型配置字段(可选,用于手动指定或自动填充) - llm_id: Optional[str] = Field(None, description="LLM模型配置ID") - embedding_id: Optional[str] = Field(None, description="嵌入模型配置ID") - rerank_id: Optional[str] = Field(None, description="重排序模型配置ID") - - -class ConfigParamsDelete(BaseModel): # 删除配置参数模型(请求体) - model_config = ConfigDict(populate_by_name=True, extra="forbid") - # config_name: str = Field("配置名称", description="配置名称(字符串)") - config_id: int = Field("配置ID", description="配置ID(字符串)") - - -class ConfigUpdate(BaseModel): # 更新记忆萃取引擎配置参数时使用的模型 - config_id: Optional[int] = None - config_name: str = Field("配置名称", description="配置名称(字符串)") - config_desc: str = Field("配置描述", description="配置描述(字符串)") - - -class ConfigUpdateExtracted(BaseModel): # 更新记忆萃取引擎配置参数时使用的模型 - config_id: Optional[int] = None - llm_id: Optional[str] = Field(None, description="LLM模型配置ID") - embedding_id: Optional[str] = Field(None, description="嵌入模型配置ID") - rerank_id: Optional[str] = Field(None, description="重排序模型配置ID") - enable_llm_dedup_blockwise: Optional[bool] = None - enable_llm_disambiguation: Optional[bool] = None - deep_retrieval: Optional[bool] = Field(None, validation_alias="deep_retrieval") - - t_type_strict: Optional[float] = Field(None, ge=0.0, le=1.0) - t_name_strict: Optional[float] = Field(None, ge=0.0, le=1.0) - t_overall: Optional[float] = Field(None, ge=0.0, le=1.0) - state: Optional[bool] = None - chunker_strategy: Optional[ChunkerStrategy] = None - # 句子提取 - statement_granularity: Optional[int] = Field(2, ge=1, le=3, description="陈述提取颗粒度,挡位 1/2/3;默认 2") - include_dialogue_context: Optional[bool] = None - max_context: Optional[int] = Field(1000, gt=100, description="对话语境中包含字符的最大数量(>100);默认 1000") - - # 剪枝配置:与 runtime.json 中 pruning 段对应 - pruning_enabled: Optional[bool] = Field(None, description="是否启动智能语义剪枝") - pruning_scene: Optional[Literal["education", "online_service", "outbound"]] = Field( - None, description="智能剪枝场景:education/online_service/outbound" - ) - pruning_threshold: Optional[float] = Field( - None, ge=0.0, le=0.9, description="智能语义剪枝阈值(0-0.9)" - ) - - # 反思配置 - enable_self_reflexion: Optional[bool] = Field(None, description="是否启用自我反思") - iteration_period: Optional[Literal["1", "3", "6", "12", "24"]] = Field( - "3", description="反思迭代周期,单位小时" - ) - reflexion_range: Optional[Literal["retrieval", "database"]] = Field( - "retrieval", description="反思范围:部分/全部" - ) - baseline: Optional[Literal["TIME", "FACT", "TIME-FACT"]] = Field( - "TIME", description="基线:时间/事实/时间和事实" - ) - - @field_validator("chunker_strategy", mode="before") - @classmethod - def map_chunker_aliases_update(cls, v: str): - if isinstance(v, str): - m = v.strip().lower() - alias_map = { - "auto": "RecursiveChunker", - "by_sentence": "SentenceChunker", - "by_paragraph": "SemanticChunker", - "fixed_tokens": "TokenChunker", - "递归分块": "RecursiveChunker", - "token 分块": "TokenChunker", - "token分块": "TokenChunker", - "语义分块": "SemanticChunker", - "神经网络分块": "NeuralChunker", - "混合分块": "HybridChunker", - "llm 分块": "LLMChunker", - "llm分块": "LLMChunker", - "句子分块": "SentenceChunker", - "延迟分块": "LateChunker", - } - if m in alias_map: - return alias_map[m] - return v - - -class ConfigUpdateForget(BaseModel): # 更新遗忘引擎配置参数时使用的模型 - # 遗忘引擎配置参数更新模型 - config_id: Optional[int] = None - lambda_time: Optional[float] = Field(0.5, ge=0.0, le=1.0, description="最低保持度,0-1 小数;默认 0.5") - lambda_mem: Optional[float] = Field(0.5, ge=0.0, le=1.0, description="遗忘率,0-1 小数;默认 0.5") - offset: Optional[float] = Field(0.0, ge=0.0, le=1.0, description="偏移度,0-1 小数;默认 0.0") - - -class ConfigPilotRun(BaseModel): # 试运行触发请求模型 - config_id: int = Field(..., description="配置ID(唯一)") - dialogue_text: str = Field(..., description="前端传入的对话文本,格式如 '用户: ...\nAI: ...' 可多行,试运行必填") - model_config = ConfigDict(populate_by_name=True, extra="forbid") - - -class ConfigFilter(BaseModel): # 查询配置参数时使用的模型 - model_config = ConfigDict(populate_by_name=True, extra="forbid") - - config_id: Optional[int] = None - user_id: Optional[str] = None - apply_id: Optional[str] = None - - limit: int = Field(20, ge=1, le=200, description="返回数量上限") - offset: int = Field(0, ge=0, description="起始偏移") - - -class ApiResponse(BaseModel): # 通用API响应模型 - model_config = ConfigDict(populate_by_name=True, extra="forbid") - code: int = Field(..., description="0=成功,非0=各类业务异常") - msg: str = Field("", description="说明信息") - data: Optional[Any] = Field(None, description="返回数据载荷") - error: str = Field("", description="错误信息,失败时有值,成功为空字符串") - time: Optional[int] = Field(None, description="响应时间(毫秒,Unix 时间戳)") - - -def _now_ms() -> int: - return int(round(time.time() * 1000)) - - -def ok(msg: str = "OK", data: Optional[Any] = None, time: Optional[int] = None) -> ApiResponse: - return ApiResponse(code=0, msg=msg, data=data, error="", time=time or _now_ms()) - - -def fail( - msg: str, - error_code: str = "ERROR", - data: Optional[Any] = None, - time: Optional[int] = None, - query_preview: Optional[str] = None, -) -> ApiResponse: - payload = data - if query_preview is not None: - if payload is None: - payload = {"query_preview": query_preview} - elif isinstance(payload, dict): - payload = {**payload, "query_preview": query_preview} - else: - payload = {"data": payload, "query_preview": query_preview} - - return ApiResponse( - code=1, - msg=msg, - data=payload, - error=error_code, - time=time or _now_ms(), - ) diff --git a/app/schemas/model_schema.py b/app/schemas/model_schema.py deleted file mode 100644 index 5b1fe6d9..00000000 --- a/app/schemas/model_schema.py +++ /dev/null @@ -1,162 +0,0 @@ -from pydantic import BaseModel, Field, field_serializer, ConfigDict -from typing import Optional, List, Dict, Any -import datetime -import uuid - -from app.models.models_model import ModelProvider, ModelType - - - -# ModelConfig Schemas -class ModelConfigBase(BaseModel): - """模型配置基础Schema""" - name: str = Field(..., description="模型显示名称", max_length=255) - type: ModelType = Field(..., description="模型类型") - description: Optional[str] = Field(None, description="模型描述") - config: Optional[Dict[str, Any]] = Field({}, description="模型配置参数") - is_active: bool = Field(True, description="是否激活") - is_public: bool = Field(False, description="是否公开") - - -class ApiKeyCreateNested(BaseModel): - """用于在创建模型时内嵌创建API Key的Schema""" - model_name: str = Field(..., description="模型实际名称", max_length=255) - provider: ModelProvider = Field(..., description="API Key提供商") - api_key: str = Field(..., description="API密钥", max_length=500) - api_base: Optional[str] = Field(None, description="API基础URL", max_length=500) - config: Optional[Dict[str, Any]] = Field({}, description="API Key特定配置") - priority: str = Field("1", description="优先级", max_length=10) - - -class ModelConfigCreate(ModelConfigBase): - """创建模型配置Schema""" - api_keys: Optional[ApiKeyCreateNested] = Field(None, description="同时创建的API Key配置") - skip_validation: Optional[bool] = Field(False, description="是否跳过配置验证") - - -class ModelConfigUpdate(BaseModel): - """更新模型配置Schema""" - name: Optional[str] = Field(None, description="模型显示名称", max_length=255) - type: Optional[ModelType] = Field(None, description="模型类型") - description: Optional[str] = Field(None, description="模型描述") - config: Optional[Dict[str, Any]] = Field(None, description="模型配置参数") - is_active: Optional[bool] = Field(None, description="是否激活") - is_public: Optional[bool] = Field(None, description="是否公开") - - -class ModelConfig(ModelConfigBase): - """模型配置Schema""" - model_config = ConfigDict(from_attributes=True) - - id: uuid.UUID - created_at: datetime.datetime - updated_at: datetime.datetime - api_keys: List["ModelApiKey"] = [] - - -# ModelApiKey Schemas -class ModelApiKeyBase(BaseModel): - """API Key基础Schema""" - model_name: str = Field(..., description="模型实际名称", max_length=255) - provider: ModelProvider = Field(..., description="API Key提供商") - api_key: str = Field(..., description="API密钥", max_length=500) - api_base: Optional[str] = Field(None, description="API基础URL", max_length=500) - config: Optional[Dict[str, Any]] = Field(None, description="API Key特定配置") - is_active: bool = Field(True, description="是否激活") - priority: str = Field("1", description="优先级", max_length=10) - - -class ModelApiKeyCreate(ModelApiKeyBase): - """创建API Key Schema""" - model_config_id: uuid.UUID = Field(..., description="模型配置ID") - - -class ModelApiKeyUpdate(BaseModel): - """更新API Key Schema""" - model_name: Optional[str] = Field(None, description="模型实际名称", max_length=255) - provider: Optional[ModelProvider] = Field(None, description="API Key提供商") - api_key: Optional[str] = Field(None, description="API密钥", max_length=500) - api_base: Optional[str] = Field(None, description="API基础URL", max_length=500) - config: Optional[Dict[str, Any]] = Field(None, description="API Key特定配置") - is_active: Optional[bool] = Field(None, description="是否激活") - priority: Optional[str] = Field(None, description="优先级", max_length=10) - - -class ModelApiKey(ModelApiKeyBase): - """API Key Schema""" - id: uuid.UUID - model_config_id: uuid.UUID - usage_count: str - last_used_at: Optional[datetime.datetime] - created_at: datetime.datetime - updated_at: datetime.datetime - - @field_serializer("created_at", when_used="json") - def _serialize_created_at(self, dt: datetime.datetime): - return int(dt.timestamp() * 1000) if dt else None - - @field_serializer("updated_at", when_used="json") - def _serialize_updated_at(self, dt: datetime.datetime): - return int(dt.timestamp() * 1000) if dt else None - - model_config = ConfigDict(from_attributes=True) - - @field_serializer("last_used_at", when_used="json") - def _serialize_last_used_at(self, dt: datetime.datetime): - return int(dt.timestamp() * 1000) if dt else None - - -# 查询和响应Schemas -class ModelConfigQuery(BaseModel): - """模型配置查询Schema""" - type: Optional[List[ModelType]] = Field(None, description="模型类型筛选(支持多个)") - provider: Optional[ModelProvider] = Field(None, description="提供商筛选(通过API Key)") - is_active: Optional[bool] = Field(None, description="激活状态筛选") - is_public: Optional[bool] = Field(None, description="公开状态筛选") - search: Optional[str] = Field(None, description="搜索关键词", max_length=255) - page: int = Field(1, description="页码", ge=1) - pagesize: int = Field(10, description="每页数量", ge=1, le=100) - -class ModelMarketplace(BaseModel): - """模型广场响应Schema""" - llm_models: List[ModelConfig] = [] - embedding_models: List[ModelConfig] = [] - rerank_models: List[ModelConfig] = [] - total_count: int - active_count: int - - -# 统计信息Schema -class ModelStats(BaseModel): - """模型统计信息Schema""" - total_models: int - active_models: int - llm_count: int - embedding_count: int - rerank_count: int - provider_stats: Dict[str, int] - - -# 验证模型配置Schema -class ModelValidateRequest(BaseModel): - """验证模型配置请求""" - model_name: str = Field(..., description="模型实际名称") - provider: ModelProvider = Field(..., description="API Key提供商") - api_key: str = Field(..., description="API密钥") - api_base: Optional[str] = Field(None, description="API基础URL") - model_type: Optional[ModelType] = Field(ModelType.LLM, description="模型类型") - test_message: Optional[str] = Field("Hello", description="测试消息") - - -class ModelValidateResponse(BaseModel): - """验证模型配置响应""" - valid: bool = Field(..., description="是否有效") - message: str = Field(..., description="验证消息") - response: Optional[str] = Field(None, description="模型响应内容") - elapsed_time: Optional[float] = Field(None, description="响应时间(秒)") - error: Optional[str] = Field(None, description="错误信息") - usage: Optional[Dict[str, Any]] = Field(None, description="Token使用情况") - - -# 更新前向引用 -ModelConfig.model_rebuild() \ No newline at end of file diff --git a/app/schemas/multi_agent_schema.py b/app/schemas/multi_agent_schema.py deleted file mode 100644 index a1547167..00000000 --- a/app/schemas/multi_agent_schema.py +++ /dev/null @@ -1,167 +0,0 @@ -"""多 Agent 相关的 Schema 定义""" -import uuid -import datetime -from typing import Optional, List, Dict, Any, Union -from pydantic import BaseModel, Field, ConfigDict, field_serializer - - -# ==================== 子 Agent 配置 ==================== - -class SubAgentConfig(BaseModel): - """子 Agent 配置""" - agent_id: uuid.UUID = Field(..., description="Agent ID") - name: str = Field(..., description="Agent 名称") - role: Optional[str] = Field(None, description="角色描述") - priority: int = Field(default=1, ge=1, le=100, description="优先级(1-100)") - capabilities: List[str] = Field(default_factory=list, description="能力列表") - - -class RoutingRule(BaseModel): - """路由规则""" - condition: str = Field(..., description="条件表达式") - target_agent_id: uuid.UUID = Field(..., description="目标 Agent ID") - priority: int = Field(default=1, ge=1, le=100, description="优先级") - - -class ExecutionConfig(BaseModel): - """执行配置""" - max_iterations: int = Field(default=5, ge=1, le=20, description="最大迭代次数") - timeout: int = Field(default=60, ge=10, le=300, description="超时时间(秒)") - parallel_limit: int = Field(default=3, ge=1, le=10, description="并行限制") - retry_on_failure: bool = Field(default=True, description="失败时是否重试") - max_retries: int = Field(default=3, ge=0, le=10, description="最大重试次数") - - -# ==================== 多 Agent 配置 ==================== - -class MultiAgentConfigCreate(BaseModel): - """创建多 Agent 配置""" - master_agent_id: uuid.UUID = Field(..., description="主 Agent ID") - master_agent_name: Optional[str] = Field(None, max_length=100, description="主 Agent 名称") - orchestration_mode: str = Field( - ..., - pattern="^(sequential|parallel|conditional|loop)$", - description="编排模式:sequential|parallel|conditional|loop" - ) - sub_agents: List[SubAgentConfig] = Field(..., description="子 Agent 列表") - routing_rules: Optional[List[RoutingRule]] = Field(None, description="路由规则") - execution_config: ExecutionConfig = Field(default_factory=ExecutionConfig, description="执行配置") - aggregation_strategy: str = Field( - default="merge", - pattern="^(merge|vote|priority|custom)$", - description="结果整合策略:merge|vote|priority|custom" - ) - - -class MultiAgentConfigUpdate(BaseModel): - """更新多 Agent 配置""" - master_agent_id: Optional[uuid.UUID] = None - master_agent_name: Optional[str] = Field(None, max_length=100, description="主 Agent 名称") - orchestration_mode: Optional[str] = Field( - None, - pattern="^(sequential|parallel|conditional|loop)$" - ) - sub_agents: Optional[List[SubAgentConfig]] = None - routing_rules: Optional[List[RoutingRule]] = None - execution_config: Optional[ExecutionConfig] = None - aggregation_strategy: Optional[str] = Field( - None, - pattern="^(merge|vote|priority|custom)$" - ) - is_active: Optional[bool] = None - - -class MultiAgentConfigSchema(BaseModel): - """多 Agent 配置输出""" - model_config = ConfigDict(from_attributes=True) - - id: uuid.UUID - app_id: uuid.UUID - master_agent_id: uuid.UUID - master_agent_name: Optional[str] - orchestration_mode: str - sub_agents: List[Dict[str, Any]] - routing_rules: Optional[List[Dict[str, Any]]] - execution_config: Dict[str, Any] - aggregation_strategy: str - is_active: bool - created_at: datetime.datetime - updated_at: datetime.datetime - - @field_serializer("created_at", when_used="json") - def _serialize_created_at(self, dt: datetime.datetime): - return int(dt.timestamp() * 1000) if dt else None - - @field_serializer("updated_at", when_used="json") - def _serialize_updated_at(self, dt: datetime.datetime): - return int(dt.timestamp() * 1000) if dt else None - - -# ==================== 多 Agent 运行 ==================== - -class MultiAgentRunRequest(BaseModel): - """多 Agent 运行请求""" - message: str = Field(..., description="用户消息") - conversation_id: Optional[uuid.UUID] = Field(None, description="会话 ID") - user_id: Optional[str] = Field(None, description="用户 ID") - variables: Optional[Dict[str, Any]] = Field(None, description="变量参数") - use_llm_routing: bool = Field(default=True, description="是否启用 LLM 路由(默认启用)") - stream: bool = Field(default=False, description="是否流式返回") - web_search: bool = Field(default=False, description="是否启用网络搜索") - memory: bool = Field(default=True, description="是否启用记忆功能") - - -class SubAgentResult(BaseModel): - """子 Agent 执行结果""" - agent_id: str - agent_name: str - result: Optional[Dict[str, Any]] = None - error: Optional[str] = None - elapsed_time: Optional[float] = None - - -class MultiAgentRunResponse(BaseModel): - """多 Agent 运行响应""" - message: str = Field(..., description="最终结果") - conversation_id: Optional[uuid.UUID] = Field(None, description="会话 ID") - elapsed_time: float = Field(..., description="总耗时(秒)") - mode: str = Field(..., description="执行模式") - sub_results: Union[List[Dict[str, Any]], Dict[str, Any]] = Field(..., description="子 Agent 结果") - usage: Optional[Dict[str, Any]] = Field(None, description="资源使用情况") - - -# ==================== 智能路由测试 ==================== - -class RoutingTestRequest(BaseModel): - """路由测试请求""" - message: str = Field(..., description="测试消息") - conversation_id: Optional[uuid.UUID] = Field(None, description="会话 ID(可选)") - routing_model_id: Optional[uuid.UUID] = Field(None, description="路由模型 ID(用于 LLM 路由)") - use_llm: bool = Field(default=False, description="是否启用 LLM 路由") - keyword_threshold: Optional[float] = Field( - default=0.8, - ge=0.0, - le=1.0, - description="关键词置信度阈值(0-1)" - ) - force_new: bool = Field(default=False, description="是否强制重新路由") - - -class RoutingTestCase(BaseModel): - """路由测试用例""" - message: str = Field(..., description="测试消息") - expected_agent_id: Optional[uuid.UUID] = Field(None, description="期望的 Agent ID") - description: Optional[str] = Field(None, description="测试用例描述") - - -class BatchRoutingTestRequest(BaseModel): - """批量路由测试请求""" - test_cases: List[RoutingTestCase] = Field(..., description="测试用例列表") - routing_model_id: Optional[uuid.UUID] = Field(None, description="路由模型 ID") - use_llm: bool = Field(default=False, description="是否启用 LLM 路由") - keyword_threshold: Optional[float] = Field( - default=0.8, - ge=0.0, - le=1.0, - description="关键词置信度阈值" - ) diff --git a/app/schemas/prompt_schema.py b/app/schemas/prompt_schema.py deleted file mode 100644 index 409d162e..00000000 --- a/app/schemas/prompt_schema.py +++ /dev/null @@ -1,61 +0,0 @@ -from jinja2 import Environment, Template, meta -from typing import Any, Dict -from enum import Enum -from pydantic import BaseModel, Field -from abc import ABC -from typing import Union, List - - -class PromptMessageRole(str, Enum): - SYSTEM = "system" - USER = "user" - ASSISTANT = "assistant" - -class TextPromptMessageContent(BaseModel): - type: str = Field(default="text") - data: str -PromptMessageContentUnionTypes = TextPromptMessageContent -class PromptMessage(ABC, BaseModel): - role: PromptMessageRole - content: Union[str, List[PromptMessageContentUnionTypes], None] = None - name: Union[str, None] = None - - model_config = {"arbitrary_types_allowed": True} - - def is_empty(self) -> bool: - return not self.content - - def get_text_content(self) -> str: - if isinstance(self.content, str): - return self.content - elif isinstance(self.content, list): - return "".join([item.data for item in self.content if isinstance(item, TextPromptMessageContent)]) - return "" - - -def render_prompt_message(template_str: str, role: PromptMessageRole, params: Dict[str, Any]) -> PromptMessage: - """ - 通用函数:自动解析模板变量,渲染PromptMessage - - template_str: Jinja2模板字符串 - - role: PromptMessageRole - - params: 提供模板变量的字典 - """ - env = Environment() - parsed_content = env.parse(template_str) - variables = meta.find_undeclared_variables(parsed_content) - - # 检查缺失参数,如果缺失则给默认值 '' - for var in variables: - if var not in params: - params[var] = "" - - # 渲染模板 - jinja_template = Template(template_str) - rendered_text = jinja_template.render(**params) - - return PromptMessage( - role=role, - content=[TextPromptMessageContent(data=rendered_text)] - ) - - diff --git a/app/schemas/release_share_schema.py b/app/schemas/release_share_schema.py deleted file mode 100644 index 069b78a9..00000000 --- a/app/schemas/release_share_schema.py +++ /dev/null @@ -1,104 +0,0 @@ -import uuid -import datetime -from typing import Optional, List, Dict, Any -from pydantic import BaseModel, Field, ConfigDict, field_serializer - - -# ---------- Input Schemas ---------- - -class ReleaseShareCreate(BaseModel): - """创建/启用分享配置""" - is_enabled: bool = Field(default=True, description="是否启用公开分享") - require_password: bool = Field(default=False, description="是否需要密码访问") - password: Optional[str] = Field(None, min_length=4, max_length=50, description="访问密码(4-50字符)") - allow_embed: bool = Field(default=False, description="是否允许嵌入") - embed_domains: Optional[List[str]] = Field(default=None, description="允许嵌入的域名白名单,空表示不限制") - - -class ReleaseShareUpdate(BaseModel): - """更新分享配置""" - is_enabled: Optional[bool] = Field(None, description="是否启用公开分享") - require_password: Optional[bool] = Field(None, description="是否需要密码访问") - password: Optional[str] = Field(None, min_length=4, max_length=50, description="访问密码") - allow_embed: Optional[bool] = Field(None, description="是否允许嵌入") - embed_domains: Optional[List[str]] = Field(None, description="允许嵌入的域名白名单") - - -class PasswordVerifyRequest(BaseModel): - """密码验证请求""" - password: str = Field(..., description="访问密码") - - -class TokenRequest(BaseModel): - """获取访问 token 请求""" - user_id: Optional[str] = Field(None, description="用户 ID(可选,不提供则自动生成)") - password: Optional[str] = Field(None, description="访问密码(如果需要)") - - -# ---------- Output Schemas ---------- - -class ReleaseShare(BaseModel): - """分享配置输出""" - model_config = ConfigDict(from_attributes=True) - - id: uuid.UUID - release_id: uuid.UUID - app_id: uuid.UUID - is_enabled: bool - share_token: str - share_url: str # 完整的公开访问 URL - require_password: bool - allow_embed: bool - embed_domains: List[str] = [] - view_count: int - last_accessed_at: Optional[datetime.datetime] = None - created_at: datetime.datetime - updated_at: datetime.datetime - - @field_serializer("created_at", when_used="json") - def _serialize_created_at(self, dt: datetime.datetime): - return int(dt.timestamp() * 1000) if dt else None - - @field_serializer("updated_at", when_used="json") - def _serialize_updated_at(self, dt: datetime.datetime): - return int(dt.timestamp() * 1000) if dt else None - - @field_serializer("last_accessed_at", when_used="json") - def _serialize_last_accessed_at(self, dt: Optional[datetime.datetime]): - return int(dt.timestamp() * 1000) if dt else None - - -class SharedReleaseInfo(BaseModel): - """公开访问返回的应用信息""" - app_name: str - app_description: Optional[str] = None - app_icon: Optional[str] = None - app_type: str - version: int - release_notes: Optional[str] = None - published_at: int - - # 根据应用类型返回不同配置 - config: Dict[str, Any] = {} - - # 访问控制信息 - require_password: bool - is_password_verified: bool = False # 当前是否已验证密码 - - # 嵌入配置 - allow_embed: bool - - -class EmbedCode(BaseModel): - """嵌入代码""" - iframe_code: str = Field(..., description="iframe 嵌入代码") - preview_url: str = Field(..., description="预览 URL") - width: str = Field(default="100%", description="宽度") - height: str = Field(default="600px", description="高度") - - -class ShareStats(BaseModel): - """分享统计""" - view_count: int - last_accessed_at: Optional[int] = None - created_at: int diff --git a/app/schemas/response_schema.py b/app/schemas/response_schema.py deleted file mode 100644 index 91505d91..00000000 --- a/app/schemas/response_schema.py +++ /dev/null @@ -1,22 +0,0 @@ -from pydantic import BaseModel, Field -from typing import Any, Optional -import time - - -class PageMeta(BaseModel): - page: int = Field(..., description="当前页码,从1开始") - pagesize: int = Field(..., description="每页数量") - total: int = Field(..., description="总条数") - hasnext: bool = Field(..., description="是否有下一页") - -class PageData(BaseModel): - page: PageMeta = Field(..., description="分页元数据") - items: list = Field(..., description="分页数据列表") - - -class ApiResponse(BaseModel): - code: int = Field(0, description="业务状态码,0=成功,非0=各类业务异常") - msg: str = Field("OK", description="给人看的简短提示") - data: Optional[Any] = Field(None, description="具体数据") - error: str = Field("", description="失败时的字段级错误信息,成功时为空字符串") - time: int = Field(default_factory=lambda: int(time.time()), description="Unix时间戳(秒)") \ No newline at end of file diff --git a/app/schemas/retrieval_info_schema.py b/app/schemas/retrieval_info_schema.py deleted file mode 100644 index 42ab126f..00000000 --- a/app/schemas/retrieval_info_schema.py +++ /dev/null @@ -1,13 +0,0 @@ -import uuid -import datetime -from typing import Optional, Text -from pydantic import BaseModel, Field -from pydantic import ConfigDict - -class Host(BaseModel): - model_config = ConfigDict(from_attributes=True) - - id: uuid.UUID = Field(description="宿主ID") - host_id: uuid.UUID = Field(description="其他ID") - retrieve_info: Optional[Text] = Field(description="检索信息") - created_at: datetime.datetime = Field(description="创建时间", default_factory=datetime.datetime.now) diff --git a/app/schemas/tenant_schema.py b/app/schemas/tenant_schema.py deleted file mode 100644 index 6e8bd158..00000000 --- a/app/schemas/tenant_schema.py +++ /dev/null @@ -1,65 +0,0 @@ -from pydantic import BaseModel, Field, field_validator, ConfigDict -from typing import Optional, List -import datetime -import uuid -from app.core.exceptions import ValidationException -from app.core.error_codes import BizCode - - -class TenantBase(BaseModel): - """租户基础Schema""" - name: str = Field(..., description="租户名称", max_length=255) - description: Optional[str] = Field(None, description="租户描述", max_length=1000) - is_active: bool = Field(True, description="是否激活") - - @field_validator('name') - @classmethod - def validate_name(cls, v): - if not v or not v.strip(): - raise ValidationException('租户名称不能为空', code=BizCode.VALIDATION_FAILED) - return v.strip() - - -class TenantCreate(TenantBase): - """创建租户Schema""" - pass - - -class TenantUpdate(BaseModel): - """更新租户Schema""" - name: Optional[str] = Field(None, description="租户名称", max_length=255) - description: Optional[str] = Field(None, description="租户描述", max_length=1000) - is_active: Optional[bool] = Field(None, description="是否激活") - - @field_validator('name') - @classmethod - def validate_name(cls, v): - if v is not None and (not v or not v.strip()): - raise ValidationException('租户名称不能为空', code=BizCode.VALIDATION_FAILED) - return v.strip() if v else v - - -class Tenant(TenantBase): - """租户Schema""" - model_config = ConfigDict(from_attributes=True) - - id: uuid.UUID - created_at: datetime.datetime - updated_at: datetime.datetime - - -class TenantQuery(BaseModel): - """租户查询Schema""" - is_active: Optional[bool] = Field(None, description="激活状态筛选") - search: Optional[str] = Field(None, description="搜索关键词", max_length=255) - page: int = Field(1, description="页码", ge=1) - size: int = Field(10, description="每页数量", ge=1, le=100) - - -class TenantList(BaseModel): - """租户列表响应Schema""" - items: List[Tenant] - total: int - page: int - size: int - pages: int \ No newline at end of file diff --git a/app/schemas/token_schema.py b/app/schemas/token_schema.py deleted file mode 100644 index 310e98a0..00000000 --- a/app/schemas/token_schema.py +++ /dev/null @@ -1,30 +0,0 @@ -from pydantic import BaseModel, EmailStr, field_serializer -from typing import Optional -import datetime - -class Token(BaseModel): - access_token: str - refresh_token: str - token_type: str - expires_at: datetime.datetime - refresh_expires_at: datetime.datetime - - @field_serializer("expires_at", when_used="json") - def _serialize_expires_at(self, dt: datetime.datetime): - return int(dt.timestamp() * 1000) if dt else None - - @field_serializer("refresh_expires_at", when_used="json") - def _serialize_refresh_expires_at(self, dt: datetime.datetime): - return int(dt.timestamp() * 1000) if dt else None - -class TokenData(BaseModel): - userId: Optional[str] = None - -class RefreshTokenRequest(BaseModel): - refresh_token: str - -class TokenRequest(BaseModel): - email: EmailStr - password: str - invite: Optional[str] = None - diff --git a/app/schemas/user_schema.py b/app/schemas/user_schema.py deleted file mode 100644 index 97006370..00000000 --- a/app/schemas/user_schema.py +++ /dev/null @@ -1,76 +0,0 @@ -from dataclasses import field -from pydantic import BaseModel, EmailStr, Field, field_validator, validator, ConfigDict -from typing import Optional -import datetime -import uuid - -from app.models import Workspace -from app.models.workspace_model import WorkspaceRole - - -class UserBase(BaseModel): - username: str - email: EmailStr - - -class UserCreate(UserBase): - password: str - - -class UserUpdate(BaseModel): - username: Optional[str] = None - email: Optional[EmailStr] = None - is_active: Optional[bool] = None - is_superuser: Optional[bool] = None - - -class ChangePasswordRequest(BaseModel): - """修改密码请求""" - old_password: str = Field(..., description="当前密码") - new_password: str = Field(..., min_length=6, description="新密码,至少6位") - - -class AdminChangePasswordRequest(BaseModel): - """管理员修改用户密码请求""" - user_id: uuid.UUID = Field(..., description="要修改密码的用户ID") - new_password: Optional[str] = Field(None, min_length=6, description="新密码,至少6位。如果不提供则自动生成随机密码") - - -class ChangePasswordResponse(BaseModel): - """修改密码响应""" - message: str - success: bool = True - generated_password: Optional[str] = Field(None, description="自动生成的密码(仅在管理员重置时返回)") - - -class User(UserBase): - id: uuid.UUID - is_active: bool - is_superuser: bool - created_at: int - last_login_at: Optional[int] = None - current_workspace_id: Optional[uuid.UUID] = None - current_workspace_name: Optional[str] = None - role: Optional[WorkspaceRole] = None - - # 将 datetime 转换为毫秒时间戳 - @validator("created_at", pre=True) - def _created_at_to_ms(cls, v): - if isinstance(v, datetime.datetime): - return int(v.timestamp() * 1000) - if isinstance(v, (int, float)): - return int(v) - return v - - model_config = ConfigDict(from_attributes=True) - - @field_validator("last_login_at", mode="before") - def _last_login_to_ms(cls, v): - if v is None: - return None - if isinstance(v, datetime.datetime): - return int(v.timestamp() * 1000) - if isinstance(v, (int, float)): - return int(v) - return v - diff --git a/app/schemas/workspace_schema.py b/app/schemas/workspace_schema.py deleted file mode 100644 index eb3e31e2..00000000 --- a/app/schemas/workspace_schema.py +++ /dev/null @@ -1,172 +0,0 @@ -import email -from pydantic import BaseModel, Field, EmailStr, field_serializer, computed_field, ConfigDict -import datetime -import uuid -from typing import Literal -from app.models.workspace_model import WorkspaceRole, InviteStatus - - -class WorkspaceBase(BaseModel): - name: str - description: str | None = None - icon: str | None = None - iconType: str | None = None - storage_type: str | None = None - llm: str | None = None - embedding: str | None = None - rerank: str | None = None - - -class WorkspaceCreate(WorkspaceBase): - pass - - - -class WorkspaceUpdate(BaseModel): - name: str | None = Field(None) - description: str | None = Field(None) - icon: str | None = Field(None) - iconType: str | None = Field(None) - storage_type: str | None = Field(None) - llm: str | None = Field(None) - embedding: str | None = Field(None) - rerank: str | None = Field(None) - - -class Workspace(WorkspaceBase): - model_config = ConfigDict(from_attributes=True) - - id: uuid.UUID - tenant_id: uuid.UUID - created_at: datetime.datetime - - @field_serializer("created_at", when_used="json") - def _serialize_created_at(self, dt: datetime.datetime): - return int(dt.timestamp() * 1000) if dt else None - - -class WorkspaceResponse(WorkspaceBase): - model_config = ConfigDict(from_attributes=True) - - id: uuid.UUID - tenant_id: uuid.UUID - created_at: datetime.datetime - is_active: bool - - @field_serializer("created_at", when_used="json") - def _serialize_created_at(self, dt: datetime.datetime): - return int(dt.timestamp()) if dt else None - - -class WorkspaceMemberBase(BaseModel): - user_id: uuid.UUID - role: WorkspaceRole - - -class WorkspaceMemberCreate(WorkspaceMemberBase): - pass - -class WorkspaceMemberUpdate(BaseModel): - id: uuid.UUID - role: WorkspaceRole - -class WorkspaceMember(WorkspaceMemberBase): - model_config = ConfigDict(from_attributes=True) - - id: uuid.UUID - workspace_id: uuid.UUID - email: str - - -# 简版嵌套模型用于成员详情的关系序列化 -class UserShort(BaseModel): - model_config = ConfigDict(from_attributes=True) - - id: uuid.UUID - username: str - email: EmailStr - - -class WorkspaceShort(BaseModel): - model_config = ConfigDict(from_attributes=True) - - id: uuid.UUID - name: str - - -class WorkspaceMemberDetail(BaseModel): - model_config = ConfigDict(from_attributes=True) - - id: uuid.UUID - role: WorkspaceRole - is_active: bool - user: UserShort - workspace: WorkspaceShort - - -# 成员管理表格视图项(扁平化字段,便于前端表格渲染) -class WorkspaceMemberItem(BaseModel): - model_config = ConfigDict(from_attributes=True) - - id: uuid.UUID - username: str - account: EmailStr - role: WorkspaceRole # 原始角色值:manager | member - last_login_at: datetime.datetime | None = None - - # 将最后登录时间序列化为毫秒时间戳,便于前端统一格式化 - @field_serializer("last_login_at", when_used="json") - def _serialize_last_login(self, dt: datetime.datetime | None): - return int(dt.timestamp() * 1000) if dt else None - - # # 动态计算角色中文标签 - # @computed_field - # def role_label(self) -> str: - # return "管理员" if self.role == WorkspaceRole.manager else "成员" - - -# Workspace Invite Schemas -class WorkspaceInviteCreate(BaseModel): - email: EmailStr = Field(..., description="被邀请者邮箱") - role: WorkspaceRole = Field(..., description="邀请角色:manager 或 member") - expires_in_days: int = Field(default=7, ge=1, le=30, description="邀请有效期天数,默认7天") - - -class WorkspaceInviteResponse(BaseModel): - id: uuid.UUID - workspace_id: uuid.UUID - email: str - role: WorkspaceRole - status: InviteStatus - expires_at: datetime.datetime - accepted_at: datetime.datetime | None - created_by_user_id: uuid.UUID - created_at: datetime.datetime - invite_token: str | None = Field(None, description="邀请令牌,仅在创建时返回") - - @field_serializer("expires_at", when_used="json") - def _serialize_expires_at(self, dt: datetime.datetime): - return int(dt.timestamp() * 1000) if dt else None - - @field_serializer("created_at", when_used="json") - def _serialize_created_at(self, dt: datetime.datetime): - return int(dt.timestamp() * 1000) if dt else None - - model_config = ConfigDict(from_attributes=True) - - @field_serializer("accepted_at", when_used="json") - def _serialize_accepted_at(self, dt: datetime.datetime): - return int(dt.timestamp() * 1000) if dt else None - - -class InviteValidateResponse(BaseModel): - workspace_name: str - workspace_id: uuid.UUID - email: str - role: WorkspaceRole - is_expired: bool - is_valid: bool - - -class InviteAcceptRequest(BaseModel): - token: str = Field(..., description="邀请令牌") diff --git a/app/services/__init__.py b/app/services/__init__.py deleted file mode 100644 index e69de29b..00000000 diff --git a/app/services/agent_config_converter.py b/app/services/agent_config_converter.py deleted file mode 100644 index 262c1c04..00000000 --- a/app/services/agent_config_converter.py +++ /dev/null @@ -1,116 +0,0 @@ -""" -Agent 配置格式转换器 -用于将 Pydantic 模型转换为数据库存储格式 -""" -from typing import Dict, Any, Optional -from app.schemas.app_schema import ( - KnowledgeRetrievalConfig, - MemoryConfig, - VariableDefinition, - ToolConfig, - AgentConfigCreate, - AgentConfigUpdate, -) - - -class AgentConfigConverter: - """Agent 配置格式转换器""" - - @staticmethod - def to_storage_format(config: AgentConfigCreate | AgentConfigUpdate) -> Dict[str, Any]: - """ - 将配置对象转换为数据库存储格式 - - Args: - config: AgentConfigCreate 或 AgentConfigUpdate 对象 - - Returns: - 包含数据库字段的字典 - """ - result = {} - - # 1. 模型参数配置 - if hasattr(config, 'model_parameters') and config.model_parameters: - result["model_parameters"] = config.model_parameters.model_dump() - - # 2. 知识库检索配置 - if config.knowledge_retrieval: - result["knowledge_retrieval"] = config.knowledge_retrieval.model_dump() - - # 3. 记忆配置 - if hasattr(config, 'memory') and config.memory: - result["memory"] = config.memory.model_dump() - - # 4. 变量配置 - if hasattr(config, 'variables') and config.variables: - result["variables"] = [var.model_dump() for var in config.variables] - - # 5. 工具配置 - if hasattr(config, 'tools') and config.tools: - result["tools"] = { - name: tool.model_dump() - for name, tool in config.tools.items() - } - - return result - - @staticmethod - def from_storage_format( - model_parameters: Optional[Dict[str, Any]], - knowledge_retrieval: Optional[Dict[str, Any]], - memory: Optional[Dict[str, Any]], - variables: Optional[list], - tools: Optional[Dict[str, Any]], - ) -> Dict[str, Any]: - """ - 将数据库存储格式转换为 Pydantic 对象 - - Args: - model_parameters: 模型参数配置 - knowledge_retrieval: 知识库检索配置 - memory: 记忆配置 - variables: 变量配置 - tools: 工具配置 - - Returns: - 包含 Pydantic 对象的字典 - """ - result = { - "model_parameters": None, - "knowledge_retrieval": None, - "memory": MemoryConfig(enabled=True), - "variables": [], - "tools": {}, - } - - # 1. 解析模型参数配置 - if model_parameters: - from app.schemas.app_schema import ModelParameters - result["model_parameters"] = ModelParameters(**model_parameters) - - # 2. 解析知识库检索配置 - if knowledge_retrieval: - result["knowledge_retrieval"] = KnowledgeRetrievalConfig(**knowledge_retrieval) - else: - # 提供默认的知识库配置(空列表) - result["knowledge_retrieval"] = KnowledgeRetrievalConfig( - knowledge_bases=[], - merge_strategy="weighted" - ) - - # 3. 解析记忆配置 - if memory: - result["memory"] = MemoryConfig(**memory) - - # 4. 解析变量配置 - if variables: - result["variables"] = [VariableDefinition(**var) for var in variables] - - # 5. 解析工具配置 - if tools: - result["tools"] = { - name: ToolConfig(**tool_data) - for name, tool_data in tools.items() - } - - return result diff --git a/app/services/agent_config_helper.py b/app/services/agent_config_helper.py deleted file mode 100644 index ae195913..00000000 --- a/app/services/agent_config_helper.py +++ /dev/null @@ -1,38 +0,0 @@ -""" -Agent 配置辅助函数 -用于增强 AgentConfig 对象,添加解析后的字段 -""" -from app.models import AgentConfig -from app.services.agent_config_converter import AgentConfigConverter - - -def enrich_agent_config(agent_cfg: AgentConfig) -> AgentConfig: - """ - 增强 AgentConfig 对象,添加解析后的配置字段 - - Args: - agent_cfg: AgentConfig ORM 对象 - - Returns: - 增强后的 AgentConfig 对象(添加了解析字段) - """ - if not agent_cfg: - return agent_cfg - - # 解析数据库存储格式 - parsed = AgentConfigConverter.from_storage_format( - model_parameters=agent_cfg.model_parameters, - knowledge_retrieval=agent_cfg.knowledge_retrieval, - memory=agent_cfg.memory, - variables=agent_cfg.variables, - tools=agent_cfg.tools, - ) - - # 将解析后的字段添加到对象上(用于序列化) - agent_cfg.model_parameters = parsed["model_parameters"] - agent_cfg.knowledge_retrieval = parsed["knowledge_retrieval"] - agent_cfg.memory = parsed["memory"] - agent_cfg.variables = parsed["variables"] - agent_cfg.tools = parsed["tools"] - - return agent_cfg diff --git a/app/services/agent_invocation_service.py b/app/services/agent_invocation_service.py deleted file mode 100644 index e69de29b..00000000 diff --git a/app/services/agent_registry.py b/app/services/agent_registry.py deleted file mode 100644 index 999d018b..00000000 --- a/app/services/agent_registry.py +++ /dev/null @@ -1,191 +0,0 @@ -"""Agent 注册表服务""" -import uuid -from typing import Optional, List, Dict, Any -from sqlalchemy.orm import Session -from sqlalchemy import select, or_, and_ - -from app.models import AgentConfig, App -from app.core.logging_config import get_business_logger - -logger = get_business_logger() - - -class AgentRegistry: - """Agent 注册表 - 管理所有可用的 Agent""" - - def __init__(self, db: Session): - self.db = db - self._cache: Dict[str, Dict[str, Any]] = {} - - def register_agent(self, agent: AgentConfig) -> None: - """注册 Agent 到系统 - - Args: - agent: Agent 配置对象 - """ - agent_info = self._to_agent_info(agent) - self._cache[str(agent.id)] = agent_info - - logger.info( - f"Agent 注册成功", - extra={ - "agent_id": str(agent.id), - "name": agent.app.name, - "domain": agent.agent_domain - } - ) - - def discover_agents( - self, - query: Optional[str] = None, - domain: Optional[str] = None, - capabilities: Optional[List[str]] = None, - workspace_id: Optional[uuid.UUID] = None - ) -> List[Dict[str, Any]]: - """发现可用的 Agent - - Args: - query: 搜索关键词 - domain: 专业领域 - capabilities: 所需能力列表 - workspace_id: 工作空间ID(权限过滤) - - Returns: - 匹配的 Agent 列表 - """ - # 构建查询 - stmt = select(AgentConfig).join(App).where( - AgentConfig.is_active == True, - App.is_active == True - ) - - # 工作空间过滤(同工作空间或公开) - if workspace_id: - stmt = stmt.where( - or_( - App.workspace_id == workspace_id, - App.visibility == "public" - ) - ) - - # 领域过滤 - if domain: - stmt = stmt.where(AgentConfig.agent_domain == domain) - - # 能力过滤 - if capabilities: - # PostgreSQL JSON 数组包含查询 - for cap in capabilities: - stmt = stmt.where( - AgentConfig.capabilities.contains([cap]) - ) - - # 关键词搜索 - if query: - stmt = stmt.where( - or_( - App.name.ilike(f"%{query}%"), - App.description.ilike(f"%{query}%") - ) - ) - - agents = self.db.scalars(stmt).all() - - logger.debug( - f"Agent 发现", - extra={ - "query": query, - "domain": domain, - "capabilities": capabilities, - "found_count": len(agents) - } - ) - - return [self._to_agent_info(agent) for agent in agents] - - def get_agent(self, agent_id: uuid.UUID) -> Optional[Dict[str, Any]]: - """获取 Agent 信息 - - Args: - agent_id: Agent ID - - Returns: - Agent 信息字典,如果不存在返回 None - """ - agent_id_str = str(agent_id) - - # 先查缓存 - if agent_id_str in self._cache: - return self._cache[agent_id_str] - - # 查数据库 - agent = self.db.get(AgentConfig, agent_id) - if agent and agent.is_active: - agent_info = self._to_agent_info(agent) - self._cache[agent_id_str] = agent_info - return agent_info - - return None - - def _to_agent_info(self, agent: AgentConfig) -> Dict[str, Any]: - """转换为 Agent 信息字典 - - Args: - agent: Agent 配置对象 - - Returns: - Agent 信息字典 - """ - return { - "id": str(agent.id), - "name": agent.app.name, - "description": agent.app.description, - "domain": agent.agent_domain, - "role": agent.agent_role, - "capabilities": agent.capabilities or [], - "tools": list(agent.tools.keys()) if agent.tools else [], - "knowledge_bases": self._extract_kb_ids(agent), - "system_prompt": self._truncate_prompt(agent.system_prompt), - "status": "active" if agent.is_active else "inactive", - "workspace_id": str(agent.app.workspace_id), - "visibility": agent.app.visibility - } - - def _extract_kb_ids(self, agent: AgentConfig) -> List[str]: - """提取知识库 ID 列表 - - Args: - agent: Agent 配置对象 - - Returns: - 知识库 ID 列表 - """ - if not agent.knowledge_retrieval: - return [] - - kb_config = agent.knowledge_retrieval - knowledge_bases = kb_config.get("knowledge_bases", []) - return [kb.get("kb_id") for kb in knowledge_bases if kb.get("kb_id")] - - def _truncate_prompt(self, prompt: Optional[str], max_length: int = 200) -> Optional[str]: - """截断提示词 - - Args: - prompt: 提示词 - max_length: 最大长度 - - Returns: - 截断后的提示词 - """ - if not prompt: - return None - - if len(prompt) <= max_length: - return prompt - - return prompt[:max_length] + "..." - - def clear_cache(self) -> None: - """清空缓存""" - self._cache.clear() - logger.debug("Agent 注册表缓存已清空") diff --git a/app/services/agent_server.py b/app/services/agent_server.py deleted file mode 100644 index 65d763fd..00000000 --- a/app/services/agent_server.py +++ /dev/null @@ -1,130 +0,0 @@ - - -from typing import Any, List - -from langchain_openai import ChatOpenAI -from langgraph.checkpoint.memory import InMemorySaver -from pydantic import BaseModel - -from langchain.agents import create_agent, AgentState -from langchain.agents.middleware import before_model -from langchain.tools import tool -from langchain_core.messages import RemoveMessage -from langgraph.graph.message import REMOVE_ALL_MESSAGES -from langgraph.runtime import Runtime - -from app.services.api_resquests_server import send_message, model, retrieval - - -class config(BaseModel): - template_str:str - params:dict - model_configs: List[dict] = [] - history_memory:bool - knowledge_base:bool - -class RemoryInput(BaseModel): - question: str - end_user_id: str - search_switch:str - -class ChatRequest(BaseModel): - end_user_id: str - message: str - search_switch:str - kb_ids: List[str] = [] - similarity_threshold:float - vector_similarity_weight:float - top_k:int - hybrid:bool - token:str - -class RetrievalInput(BaseModel): - message: str - kb_ids: List[str] = [] - similarity_threshold: float - vector_similarity_weight: float - top_k: int - hybrid: bool - token: str - -async def tool_Retrieval(req): - tool_result = retrieval_search.invoke({ - "message":req.message, "kb_ids":req.kb_ids, - "similarity_threshold":req.similarity_threshold, "vector_similarity_weight":req.vector_similarity_weight, - "top_k":req.top_k, "hybrid":req.hybrid, "token":req.token - }) - return tool_result -async def tool_memory(req): - tool_result = remory_sk.invoke({ - "question": req.message, - "end_user_id": req.end_user_id, - "search_switch": req.search_switch - }) - return tool_result - - -@before_model -# ========== 消息剪枝中间件 ========== -def trim_messages(state: AgentState, runtime: Runtime) -> dict[str, Any] | None: - """保留前1条 + 最近3~4条消息""" - messages = state["messages"] - if len(messages) <= 10: - return None - first_msg = messages[0] - recent_messages = messages[-10:] if len(messages) % 2 == 0 else messages[-11:] - new_messages = [first_msg] + recent_messages - - return { - "messages": [ - RemoveMessage(id=REMOVE_ALL_MESSAGES), - *new_messages - ] - } - -##-----------历史记忆------------ -@ tool(args_schema=RemoryInput) -def remory_sk(question: str, end_user_id: str, search_switch: str): - """ - 条件调用工具: - - 仅当 question 是疑问句时调用 send_message - - 根据 end_user_id 和 search_switch 参数决定是否执行检索 - - Args: - question: 用户的提问内容 - end_user_id: 用户唯一标识符 - search_switch: 搜索开关,控制是否执行检索 - - Returns: - 检索结果或空字符串 - """ - # 移除关于 configurable 的描述,避免混淆 - if not end_user_id or end_user_id == '123': - print("警告: 无效的 user_id 参数") - return '' - - if search_switch in ['on', 'off'] or not search_switch: - print("警告: 无效的 search_switch 参数") - return '' - return send_message(end_user_id, question, '[]', search_switch) - -#-------------检索------------ - - -@ tool(args_schema=RetrievalInput) -def retrieval_search(message,kb_ids,similarity_threshold,vector_similarity_weight,top_k,hybrid,token): - '''检索''' - search=retrieval(message,kb_ids,similarity_threshold,vector_similarity_weight,top_k,hybrid,token) - return search -async def create_dynamic_agent(model_name: str,model_id:str,PROMPT:str,token:str): - """根据模型名动态创建代理""" - model_name, api_key, api_base=await model(model_id,token) - llm = ChatOpenAI(model=model_name, base_url=api_base, temperature=0.2,api_key=api_key) - memory = InMemorySaver() - return create_agent( - llm, - tools=[remory_sk,retrieval_search], - middleware=[trim_messages], - checkpointer=memory, - system_prompt=PROMPT - ) \ No newline at end of file diff --git a/app/services/agent_tools.py b/app/services/agent_tools.py deleted file mode 100644 index 96032b7d..00000000 --- a/app/services/agent_tools.py +++ /dev/null @@ -1,331 +0,0 @@ -"""Agent 发现和调用工具""" -import uuid -import time -import datetime -from typing import Optional, Dict, Any, List -from langchain.tools import tool -from pydantic import BaseModel, Field -from sqlalchemy.orm import Session - -from app.models import AgentConfig, ModelConfig, AgentInvocation -from app.services.agent_registry import AgentRegistry -from app.core.exceptions import BusinessException, ResourceNotFoundException -from app.core.error_codes import BizCode -from app.core.logging_config import get_business_logger -from app.repositories import workspace_repository, knowledge_repository - -logger = get_business_logger() - - -# ==================== Agent 发现工具 ==================== - -class AgentDiscoveryInput(BaseModel): - """Agent 发现工具输入参数""" - query: Optional[str] = Field(None, description="搜索关键词,如:'客服'、'技术支持'") - domain: Optional[str] = Field(None, description="专业领域,如:'customer_service'、'technical_support'") - capabilities: Optional[List[str]] = Field(None, description="所需能力列表,如:['退货处理', '订单查询']") - - -def create_agent_discovery_tool(registry: AgentRegistry, workspace_id: uuid.UUID): - """创建 Agent 发现工具 - - Args: - registry: Agent 注册表 - workspace_id: 当前工作空间 ID - - Returns: - Agent 发现工具 - """ - - @tool(args_schema=AgentDiscoveryInput) - def discover_agents( - query: Optional[str] = None, - domain: Optional[str] = None, - capabilities: Optional[List[str]] = None - ) -> str: - """发现系统中可用的 Agent。当需要找到能够处理特定任务的 Agent 时使用此工具。 - - Args: - query: 搜索关键词(如:"客服"、"技术支持") - domain: 专业领域(如:"customer_service"、"technical_support") - capabilities: 所需能力(如:["退货处理", "订单查询"]) - - Returns: - 可用 Agent 的列表和描述 - """ - try: - agents = registry.discover_agents( - query=query, - domain=domain, - capabilities=capabilities, - workspace_id=workspace_id - ) - - if not agents: - return "未找到匹配的 Agent" - - # 格式化输出 - result = f"找到 {len(agents)} 个可用的 Agent:\n\n" - for i, agent in enumerate(agents, 1): - result += f"{i}. {agent['name']}\n" - result += f" ID: {agent['id']}\n" - if agent['description']: - result += f" 描述: {agent['description']}\n" - if agent['domain']: - result += f" 领域: {agent['domain']}\n" - if agent['capabilities']: - result += f" 能力: {', '.join(agent['capabilities'])}\n" - if agent['tools']: - result += f" 工具: {', '.join(agent['tools'])}\n" - result += "\n" - - logger.info( - f"Agent 发现成功", - extra={ - "query": query, - "domain": domain, - "found_count": len(agents) - } - ) - - return result - - except Exception as e: - logger.error(f"Agent 发现失败", extra={"error": str(e)}) - return f"发现 Agent 失败: {str(e)}" - - return discover_agents - - -# ==================== Agent 调用工具 ==================== - -class AgentInvocationInput(BaseModel): - """Agent 调用工具输入参数""" - agent_id: str = Field(..., description="要调用的 Agent ID(通过 discover_agents 工具获取)") - message: str = Field(..., description="发送给 Agent 的消息或任务描述") - context: Optional[Dict[str, Any]] = Field(None, description="可选的上下文信息(如:用户信息、历史记录等)") - - -def create_agent_invocation_tool( - db: Session, - registry: AgentRegistry, - workspace_id: uuid.UUID, - current_agent_id: uuid.UUID, - conversation_id: Optional[uuid.UUID] = None, - parent_invocation_id: Optional[uuid.UUID] = None, - invocation_chain: Optional[List[uuid.UUID]] = None -): - """创建 Agent 调用工具 - - Args: - db: 数据库会话 - registry: Agent 注册表 - workspace_id: 当前工作空间 ID - current_agent_id: 当前 Agent ID - conversation_id: 会话 ID - parent_invocation_id: 父调用 ID - invocation_chain: 调用链(用于检测循环调用) - - Returns: - Agent 调用工具 - """ - # 1. 获取工作空间的 storage_type - storage_type = 'neo4j' # 默认值 - user_rag_memory_id = None - - try: - workspace = workspace_repository.get_workspace_by_id(db, workspace_id) - if workspace and workspace.storage_type: - storage_type = workspace.storage_type - logger.debug( - f"获取工作空间存储类型成功", - extra={ - "workspace_id": str(workspace_id), - "storage_type": storage_type - } - ) - except Exception as e: - logger.warning( - f"获取工作空间存储类型失败,使用默认值 neo4j", - extra={"workspace_id": str(workspace_id), "error": str(e)} - ) - - # 2. 如果 storage_type 是 rag,获取知识库 ID - if storage_type == 'rag': - try: - knowledge = knowledge_repository.get_knowledge_by_name( - db=db, - name="USER_RAG_MEMORY", - workspace_id=workspace_id - ) - if knowledge: - user_rag_memory_id = str(knowledge.id) - logger.debug( - f"获取 RAG 知识库成功", - extra={ - "workspace_id": str(workspace_id), - "knowledge_id": user_rag_memory_id - } - ) - else: - logger.warning( - f"未找到名为 'USER_RAG_MEMORY' 的知识库,将使用 neo4j 存储", - extra={"workspace_id": str(workspace_id)} - ) - storage_type = 'neo4j' - except Exception as e: - logger.warning( - f"获取 RAG 知识库失败,将使用 neo4j 存储", - extra={"workspace_id": str(workspace_id), "error": str(e)} - ) - storage_type = 'neo4j' - - if invocation_chain is None: - invocation_chain = [] - - @tool(args_schema=AgentInvocationInput) - async def invoke_agent( - agent_id: str, - message: str, - context: Optional[Dict[str, Any]] = None - ) -> str: - """调用另一个 Agent 来处理任务。当当前 Agent 无法处理某个任务,需要其他专业 Agent 帮助时使用。 - - Args: - agent_id: 要调用的 Agent ID(通过 discover_agents 工具获取) - message: 发送给 Agent 的消息或任务描述 - context: 可选的上下文信息(如:用户信息、历史记录等) - - Returns: - 被调用 Agent 的响应结果 - """ - try: - # 1. 验证 Agent 存在 - agent_uuid = uuid.UUID(agent_id) - agent_info = registry.get_agent(agent_uuid) - if not agent_info: - return f"Agent {agent_id} 不存在" - - # 2. 验证权限(同工作空间或公开) - if agent_info["workspace_id"] != str(workspace_id) and agent_info["visibility"] != "public": - return f"无权访问 Agent {agent_info['name']}" - - # 3. 防止自己调用自己 - if agent_id == str(current_agent_id): - return "不能调用自己" - - # 4. 防止循环调用 - if agent_uuid in invocation_chain: - return f"检测到循环调用:{agent_info['name']} 已在调用链中" - - # 5. 检查调用深度 - max_depth = 5 - if len(invocation_chain) >= max_depth: - return f"调用深度超过限制(最大 {max_depth} 层)" - - # 6. 获取 Agent 配置 - agent_config = db.get(AgentConfig, agent_uuid) - if not agent_config: - return f"Agent 配置不存在" - - # 7. 获取模型配置 - model_config = db.get(ModelConfig, agent_config.default_model_config_id) - if not model_config: - return f"Agent 模型配置不存在" - - # 8. 创建调用记录 - invocation = AgentInvocation( - caller_agent_id=current_agent_id, - callee_agent_id=agent_uuid, - conversation_id=conversation_id, - parent_invocation_id=parent_invocation_id, - input_message=message, - context=context, - status="running", - started_at=datetime.datetime.now() - ) - db.add(invocation) - db.commit() - db.refresh(invocation) - - logger.info( - f"Agent 调用开始", - extra={ - "invocation_id": str(invocation.id), - "caller_agent_id": str(current_agent_id), - "callee_agent_id": agent_id, - "depth": len(invocation_chain) - } - ) - - start_time = time.time() - - try: - # 9. 调用 Agent - from app.services.draft_run_service import DraftRunService - draft_service = DraftRunService(db) - - result = await draft_service.run( - agent_config=agent_config, - model_config=model_config, - message=message, - workspace_id=workspace_id, - variables=context or {}, - storage_type=storage_type, - user_rag_memory_id=user_rag_memory_id - ) - - elapsed_time = time.time() - start_time - - # 10. 更新调用记录 - invocation.status = "completed" - invocation.output_message = result["message"] - invocation.completed_at = datetime.datetime.now() - invocation.elapsed_time = elapsed_time - invocation.token_usage = result.get("usage", {}) - db.commit() - - logger.info( - f"Agent 调用成功", - extra={ - "invocation_id": str(invocation.id), - "caller_agent_id": str(current_agent_id), - "callee_agent_id": agent_id, - "elapsed_time": elapsed_time - } - ) - - return result["message"] - - except Exception as e: - # 更新调用记录为失败 - invocation.status = "failed" - invocation.error_message = str(e) - invocation.completed_at = datetime.datetime.now() - invocation.elapsed_time = time.time() - start_time - db.commit() - - logger.error( - f"Agent 调用失败", - extra={ - "invocation_id": str(invocation.id), - "caller_agent_id": str(current_agent_id), - "callee_agent_id": agent_id, - "error": str(e) - } - ) - - raise - - except Exception as e: - logger.error( - f"Agent 调用异常", - extra={ - "caller_agent_id": str(current_agent_id), - "callee_agent_id": agent_id, - "error": str(e) - } - ) - return f"调用 Agent 失败: {str(e)}" - - return invoke_agent diff --git a/app/services/api_key_service.py b/app/services/api_key_service.py deleted file mode 100644 index 6deac112..00000000 --- a/app/services/api_key_service.py +++ /dev/null @@ -1,173 +0,0 @@ -"""API Key Service""" -from sqlalchemy.orm import Session -from typing import Optional, Tuple, List -import uuid -import datetime -import math - -from app.models.api_key_model import ApiKey, ApiKeyType -from app.repositories.api_key_repository import ApiKeyRepository -from app.schemas import api_key_schema -from app.schemas.response_schema import PageData, PageMeta -from app.core.api_key_utils import generate_api_key -from app.core.exceptions import BusinessException -from app.core.error_codes import BizCode -from app.core.logging_config import get_business_logger - -logger = get_business_logger() - - -class ApiKeyService: - """API Key 业务逻辑服务""" - - @staticmethod - def create_api_key( - db: Session, - *, - workspace_id: uuid.UUID, - user_id: uuid.UUID, - data: api_key_schema.ApiKeyCreate - ) -> Tuple[ApiKey, str]: - """创建 API Key - - Returns: - Tuple[ApiKey, str]: (API Key 对象, API Key 明文) - """ - # 生成 API Key - api_key, key_hash, key_prefix = generate_api_key(data.type) - - # 创建数据 - api_key_data = { - "id": uuid.uuid4(), - "name": data.name, - "description": data.description, - "key_prefix": key_prefix, - "key_hash": key_hash, - "type": data.type, - "scopes": data.scopes, - "workspace_id": workspace_id, - "resource_id": data.resource_id, - "resource_type": data.resource_type, - "rate_limit": data.rate_limit, - "quota_limit": data.quota_limit, - "expires_at": data.expires_at, - "created_by": user_id, - "created_at": datetime.datetime.now(), - "updated_at": datetime.datetime.now(), - } - - api_key_obj = ApiKeyRepository.create(db, api_key_data) - db.commit() - db.refresh(api_key_obj) - - logger.info(f"API Key 创建成功", extra={ - "api_key_id": str(api_key_obj.id), - "name": data.name, - "type": data.type - }) - - return api_key_obj, api_key - - @staticmethod - def get_api_key( - db: Session, - api_key_id: uuid.UUID, - workspace_id: uuid.UUID - ) -> ApiKey: - """获取 API Key""" - api_key = ApiKeyRepository.get_by_id(db, api_key_id) - if not api_key: - raise BusinessException("API Key 不存在", BizCode.NOT_FOUND) - - if api_key.workspace_id != workspace_id: - raise BusinessException("无权访问此 API Key", BizCode.FORBIDDEN) - - return api_key - - @staticmethod - def list_api_keys( - db: Session, - workspace_id: uuid.UUID, - query: api_key_schema.ApiKeyQuery - ) -> PageData: - """列出 API Keys""" - items, total = ApiKeyRepository.list_by_workspace(db, workspace_id, query) - pages = math.ceil(total / query.pagesize) if total > 0 else 0 - - return PageData( - page=PageMeta( - page=query.page, - pagesize=query.pagesize, - total=total, - hasnext=query.page < pages - ), - items=[api_key_schema.ApiKey.model_validate(item) for item in items] - ) - - @staticmethod - def update_api_key( - db: Session, - api_key_id: uuid.UUID, - workspace_id: uuid.UUID, - data: api_key_schema.ApiKeyUpdate - ) -> ApiKey: - """更新 API Key""" - api_key = ApiKeyService.get_api_key(db, api_key_id, workspace_id) - - update_data = data.model_dump(exclude_unset=True) - ApiKeyRepository.update(db, api_key_id, update_data) - db.commit() - db.refresh(api_key) - - logger.info(f"API Key 更新成功", extra={"api_key_id": str(api_key_id)}) - return api_key - - @staticmethod - def delete_api_key( - db: Session, - api_key_id: uuid.UUID, - workspace_id: uuid.UUID - ) -> bool: - """删除 API Key""" - api_key = ApiKeyService.get_api_key(db, api_key_id, workspace_id) - - ApiKeyRepository.delete(db, api_key_id) - db.commit() - - logger.info(f"API Key 删除成功", extra={"api_key_id": str(api_key_id)}) - return True - - @staticmethod - def regenerate_api_key( - db: Session, - api_key_id: uuid.UUID, - workspace_id: uuid.UUID - ) -> Tuple[ApiKey, str]: - """重新生成 API Key""" - api_key = ApiKeyService.get_api_key(db, api_key_id, workspace_id) - - # 生成新的 API Key - new_api_key, key_hash, key_prefix = generate_api_key(ApiKeyType(api_key.type)) - - # 更新 - ApiKeyRepository.update(db, api_key_id, { - "key_hash": key_hash, - "key_prefix": key_prefix - }) - db.commit() - db.refresh(api_key) - - logger.info(f"API Key 重新生成成功", extra={"api_key_id": str(api_key_id)}) - return api_key, new_api_key - - @staticmethod - def get_stats( - db: Session, - api_key_id: uuid.UUID, - workspace_id: uuid.UUID - ) -> api_key_schema.ApiKeyStats: - """获取使用统计""" - api_key = ApiKeyService.get_api_key(db, api_key_id, workspace_id) - - stats_data = ApiKeyRepository.get_stats(db, api_key_id) - return api_key_schema.ApiKeyStats(**stats_data) diff --git a/app/services/app_service.py b/app/services/app_service.py deleted file mode 100644 index 621f0aa2..00000000 --- a/app/services/app_service.py +++ /dev/null @@ -1,1903 +0,0 @@ -""" -应用服务层 - -提供应用管理的业务逻辑,包括: -- 应用的创建、更新、查询 -- Agent 配置管理 -- 应用发布和版本管理 -- 应用回滚 -""" -import datetime -import uuid -from typing import Optional, List, Dict, Any, Tuple - -from sqlalchemy.orm import Session -from sqlalchemy import select, func, or_, and_ - -from app.models import App, AgentConfig, AppRelease, MultiAgentConfig -from app.schemas import app_schema -from app.core.exceptions import ( - ResourceNotFoundException, - ValidationException, - BusinessException, -) -from app.core.error_codes import BizCode -from app.core.logging_config import get_business_logger -from app.services.agent_config_converter import AgentConfigConverter -from app.models.app_model import AppStatus, AppType - -# 获取业务日志器 -logger = get_business_logger() - - -class AppService: - """应用服务类 - - 负责应用相关的所有业务逻辑处理,遵循单一职责原则。 - """ - - def __init__(self, db: Session): - """初始化应用服务 - - Args: - db: 数据库会话 - """ - self.db = db - - # ==================== 私有辅助方法 ==================== - - def _validate_workspace_access(self, app: App, workspace_id: Optional[uuid.UUID]) -> None: - """验证工作空间访问权限(严格模式,用于修改操作) - - Args: - app: 应用对象 - workspace_id: 工作空间ID - - Raises: - BusinessException: 当应用不在指定工作空间时 - """ - if workspace_id is not None and app.workspace_id != workspace_id: - logger.warning( - f"工作空间访问被拒", - extra={"app_id": str(app.id), "workspace_id": str(workspace_id)} - ) - raise BusinessException("应用不在指定工作空间中", BizCode.WORKSPACE_NO_ACCESS) - - def _check_app_accessible(self, app: App, workspace_id: Optional[uuid.UUID]) -> bool: - """检查应用是否可访问(包括共享应用) - - Args: - app: 应用对象 - workspace_id: 工作空间ID - - Returns: - bool: 是否可访问 - """ - from app.models import AppShare - - if workspace_id is None: - return True - - # 1. 检查是否是本工作空间的应用 - if app.workspace_id == workspace_id: - return True - - # 2. 检查是否是共享给本工作空间的应用 - stmt = select(AppShare).where( - AppShare.source_app_id == app.id, - AppShare.target_workspace_id == workspace_id - ) - share = self.db.scalars(stmt).first() - - return share is not None - - def _validate_app_accessible(self, app: App, workspace_id: Optional[uuid.UUID]) -> None: - """验证应用是否可访问(包括共享应用,用于只读操作) - - Args: - app: 应用对象 - workspace_id: 工作空间ID - - Raises: - BusinessException: 当应用不可访问时 - """ - if not self._check_app_accessible(app, workspace_id): - logger.warning( - f"应用访问被拒", - extra={"app_id": str(app.id), "workspace_id": str(workspace_id)} - ) - raise BusinessException("应用不可访问", BizCode.WORKSPACE_NO_ACCESS) - - def _get_app_or_404(self, app_id: uuid.UUID) -> App: - """获取应用或抛出404异常 - - Args: - app_id: 应用ID - - Returns: - App: 应用对象 - - Raises: - ResourceNotFoundException: 当应用不存在时 - """ - app = self.db.get(App, app_id) - if not app: - logger.warning(f"应用不存在", extra={"app_id": str(app_id)}) - raise ResourceNotFoundException("应用", str(app_id)) - return app - - def _check_agent_config(self, app_id: uuid.UUID): - from app.models import AgentConfig, ModelConfig - from app.services.app_service import AppService - from app.models import AgentConfig, ModelConfig - from sqlalchemy import select - from app.core.exceptions import BusinessException - # 2. 获取 Agent 配置 - stmt = select(AgentConfig).where(AgentConfig.app_id == app_id) - agent_cfg = self.db.scalars(stmt).first() - if not agent_cfg: - raise BusinessException("Agent 配置不存在,无法试运行", BizCode.AGENT_CONFIG_MISSING) - - # 3. 获取模型配置 - model_config = None - if agent_cfg.default_model_config_id: - model_config = self.db.get(ModelConfig, agent_cfg.default_model_config_id) - - if not model_config: - raise BusinessException("模型配置不存在,无法试运行", BizCode.AGENT_CONFIG_MISSING) - - def _check_multi_agent_config(self, app_id: uuid.UUID): - """检查多智能体配置的完整性 - - 验证内容: - 1. 多智能体配置是否存在 - 2. 主 Agent 配置是否存在 - 3. 子 Agent 配置是否存在 - 4. 所有 Agent 的模型配置是否存在 - - Args: - app_id: 应用 ID - - Raises: - BusinessException: 配置不完整或不存在时抛出 - """ - from app.models import MultiAgentConfig, AgentConfig, ModelConfig - from app.services.multi_agent_service import MultiAgentService - - # 1. 检查多智能体配置是否存在 - service = MultiAgentService(self.db) - multi_agent_config = service.get_config(app_id) - - if not multi_agent_config: - raise BusinessException( - "多智能体配置不存在,无法运行", - BizCode.AGENT_CONFIG_MISSING - ) - - if not multi_agent_config.is_active: - raise BusinessException( - "多智能体配置未激活,无法运行", - BizCode.AGENT_CONFIG_MISSING - ) - - # 2. 检查主 Agent 配置 - if not multi_agent_config.master_agent_id: - raise BusinessException( - "未配置主 Agent,无法运行", - BizCode.AGENT_CONFIG_MISSING - ) - - master_agent_release = self.db.get(AppRelease, multi_agent_config.master_agent_id) - if not master_agent_release: - raise BusinessException( - f"主 Agent 配置不存在: {multi_agent_config.master_agent_id}", - BizCode.AGENT_CONFIG_MISSING - ) - - # 检查主 Agent 的模型配置 - if master_agent_release.default_model_config_id: - master_model = self.db.get(ModelConfig, master_agent_release.default_model_config_id) - if not master_model: - raise BusinessException( - f"主 Agent 的模型配置不存在: {master_agent_release.default_model_config_id}", - BizCode.MODEL_NOT_FOUND - ) - else: - raise BusinessException( - "主 Agent 未配置模型,无法运行", - BizCode.MODEL_NOT_FOUND - ) - - # 3. 检查子 Agent 配置 - if not multi_agent_config.sub_agents or len(multi_agent_config.sub_agents) == 0: - raise BusinessException( - "未配置子 Agent,无法运行", - BizCode.AGENT_CONFIG_MISSING - ) - - # 4. 验证每个子 Agent 及其模型配置 - for idx, sub_agent_data in enumerate(multi_agent_config.sub_agents): - agent_id = sub_agent_data.get('agent_id') - if not agent_id: - raise BusinessException( - f"子 Agent #{idx + 1} 缺少 agent_id", - BizCode.AGENT_CONFIG_MISSING - ) - - # 转换为 UUID - try: - from uuid import UUID - agent_uuid = UUID(agent_id) if isinstance(agent_id, str) else agent_id - except (ValueError, TypeError): - raise BusinessException( - f"子 Agent #{idx + 1} 的 agent_id 格式无效: {agent_id}", - BizCode.INVALID_PARAMETER - ) - - # 检查子 Agent 是否存在 - sub_agent_release = self.db.get(AppRelease, agent_uuid) - if not sub_agent_release: - raise BusinessException( - f"子 Agent 配置不存在: {agent_id} ({sub_agent_data.get('name', '未命名')})", - BizCode.AGENT_CONFIG_MISSING - ) - - # 检查子 Agent 的模型配置 - if sub_agent_release.default_model_config_id: - sub_model = self.db.get(ModelConfig, sub_agent_release.default_model_config_id) - if not sub_model: - raise BusinessException( - f"子 Agent '{sub_agent_data.get('name', '未命名')}' 的模型配置不存在: {sub_agent_release.default_model_config_id}", - BizCode.MODEL_NOT_FOUND - ) - else: - raise BusinessException( - f"子 Agent '{sub_agent_data.get('name', '未命名')}' 未配置模型,无法运行", - BizCode.MODEL_NOT_FOUND - ) - - logger.info( - f"多智能体配置检查通过", - extra={ - "app_id": str(app_id), - "master_agent_id": str(multi_agent_config.master_agent_id), - "sub_agent_count": len(multi_agent_config.sub_agents) - } - ) - - def _create_agent_config( - self, - app_id: uuid.UUID, - config_data: app_schema.AgentConfigCreate, - now: datetime.datetime - ) -> None: - """创建 Agent 配置(内部方法) - - Args: - app_id: 应用ID - config_data: Agent 配置数据 - now: 当前时间 - """ - storage_data = AgentConfigConverter.to_storage_format(config_data) - - agent_cfg = AgentConfig( - id=uuid.uuid4(), - app_id=app_id, - system_prompt=config_data.system_prompt, - default_model_config_id=config_data.default_model_config_id, - model_parameters=storage_data.get("model_parameters"), - knowledge_retrieval=storage_data.get("knowledge_retrieval"), - memory=storage_data.get("memory"), - variables=storage_data.get("variables", []), - tools=storage_data.get("tools", {}), - is_active=True, - created_at=now, - updated_at=now, - ) - self.db.add(agent_cfg) - logger.debug(f"Agent 配置已创建", extra={"app_id": str(app_id)}) - - def _create_multi_agent_config( - self, - app_id: uuid.UUID, - config_data: Dict[str, Any], - now: datetime.datetime - ) -> None: - """创建多 Agent 配置(内部方法) - - Args: - app_id: 应用ID - config_data: 多 Agent 配置数据(Dict) - now: 当前时间 - """ - # 将 Dict 转换为 MultiAgentConfigCreate - from app.schemas.multi_agent_schema import ( - MultiAgentConfigCreate, - SubAgentConfig, - RoutingRule, - ExecutionConfig - ) - - # 转换 sub_agents - sub_agents = [SubAgentConfig(**sa) for sa in config_data.get('sub_agents', [])] - - # 转换 routing_rules(如果有) - routing_rules = None - if config_data.get('routing_rules'): - routing_rules = [RoutingRule(**rr) for rr in config_data['routing_rules']] - - # 转换 execution_config - execution_config = ExecutionConfig(**config_data.get('execution_config', {})) - - # 创建 MultiAgentConfigCreate 对象 - config = MultiAgentConfigCreate( - master_agent_id=config_data['master_agent_id'], - orchestration_mode=config_data['orchestration_mode'], - sub_agents=sub_agents, - routing_rules=routing_rules, - execution_config=execution_config, - aggregation_strategy=config_data.get('aggregation_strategy', 'merge') - ) - - # 验证主 Agent 存在 - master_agent = self.db.get(AgentConfig, config.master_agent_id) - if not master_agent: - raise ResourceNotFoundException("主 Agent", str(config.master_agent_id)) - - # 验证子 Agent 存在 - for sub_agent in config.sub_agents: - agent = self.db.get(AgentConfig, sub_agent.agent_id) - if not agent: - raise ResourceNotFoundException("子 Agent", str(sub_agent.agent_id)) - - # 创建多 Agent 配置 - # 将 UUID 转换为字符串以便 JSON 序列化 - sub_agents_data = [] - for sub_agent in config.sub_agents: - sa_dict = sub_agent.model_dump() - sa_dict['agent_id'] = str(sa_dict['agent_id']) # UUID -> str - sub_agents_data.append(sa_dict) - - routing_rules_data = None - if config.routing_rules: - routing_rules_data = [] - for rule in config.routing_rules: - rule_dict = rule.model_dump() - rule_dict['target_agent_id'] = str(rule_dict['target_agent_id']) # UUID -> str - routing_rules_data.append(rule_dict) - - multi_agent_cfg = MultiAgentConfig( - id=uuid.uuid4(), - app_id=app_id, - master_agent_id=config.master_agent_id, - orchestration_mode=config.orchestration_mode, - sub_agents=sub_agents_data, - routing_rules=routing_rules_data, - execution_config=config.execution_config.model_dump(), - aggregation_strategy=config.aggregation_strategy, - is_active=True, - created_at=now, - updated_at=now, - ) - self.db.add(multi_agent_cfg) - logger.debug(f"多 Agent 配置已创建", extra={"app_id": str(app_id), "mode": config.orchestration_mode}) - - def _get_next_version(self, app_id: uuid.UUID) -> int: - """获取下一个版本号 - - Args: - app_id: 应用ID - - Returns: - int: 下一个版本号 - """ - stmt = select(func.max(AppRelease.version)).where(AppRelease.app_id == app_id) - max_ver = self.db.execute(stmt).scalar() - return 1 if max_ver is None else int(max_ver) + 1 - - def _convert_to_schema( - self, - app: App, - current_workspace_id: uuid.UUID - ) -> app_schema.App: - """将 App 模型转换为 Schema,并设置 is_shared 字段 - - Args: - app: App 模型实例 - current_workspace_id: 当前工作空间ID - - Returns: - app_schema.App: 应用 Schema - """ - app_dict = { - "id": app.id, - "workspace_id": app.workspace_id, - "created_by": app.created_by, - "name": app.name, - "description": app.description, - "icon": app.icon, - "icon_type": app.icon_type, - "type": app.type, - "visibility": app.visibility, - "status": app.status, - "tags": app.tags or [], - "current_release_id": app.current_release_id, - "is_active": app.is_active, - "is_shared": app.workspace_id != current_workspace_id, # 判断是否是共享应用 - "created_at": app.created_at, - "updated_at": app.updated_at - } - return app_schema.App(**app_dict) - - # ==================== 应用管理 ==================== - - def get_app( - self, - app_id: uuid.UUID, - workspace_id: Optional[uuid.UUID] = None - ) -> App: - """获取应用详情 - - Args: - app_id: 应用ID - workspace_id: 工作空间ID(用于权限验证,支持共享应用) - - Returns: - App: 应用对象 - - Raises: - ResourceNotFoundException: 当应用不存在时 - BusinessException: 当应用不可访问时 - """ - app = self._get_app_or_404(app_id) - self._validate_app_accessible(app, workspace_id) - return app - - def create_app( - self, - *, - user_id: uuid.UUID, - workspace_id: uuid.UUID, - data: app_schema.AppCreate - ) -> App: - """创建应用 - - Args: - user_id: 创建者用户ID - workspace_id: 工作空间ID - data: 应用创建数据 - - Returns: - App: 创建的应用对象 - - Raises: - BusinessException: 当创建失败时 - """ - logger.info( - f"创建应用", - extra={"app_name": data.name, "type": data.type, "workspace_id": str(workspace_id)} - ) - - try: - now = datetime.datetime.now() - - app = App( - id=uuid.uuid4(), - workspace_id=workspace_id, - created_by=user_id, - name=data.name, - description=data.description, - icon=data.icon, - icon_type=data.icon_type, - type=data.type, - visibility=data.visibility, - status=data.status, - tags=data.tags or [], - is_active=True, - created_at=now, - updated_at=now, - ) - self.db.add(app) - self.db.flush() # 获取 app.id - - # 如果是 agent 类型且提供了配置,创建 AgentConfig - if app.type == "agent" and data.agent_config: - self._create_agent_config(app.id, data.agent_config, now) - - # 如果是 multi_agent 类型且提供了配置,创建 MultiAgentConfig - if app.type == "multi_agent" and data.multi_agent_config: - self._create_multi_agent_config(app.id, data.multi_agent_config, now) - - self.db.commit() - self.db.refresh(app) - - logger.info(f"应用创建成功", extra={"app_id": str(app.id), "app_name": app.name}) - return app - - except Exception as e: - self.db.rollback() - logger.error(f"应用创建失败", extra={"app_name": data.name, "error": str(e)}) - raise BusinessException(f"应用创建失败: {str(e)}", BizCode.INTERNAL_ERROR, cause=e) - - def update_app( - self, - *, - app_id: uuid.UUID, - data: app_schema.AppUpdate, - workspace_id: Optional[uuid.UUID] = None - ) -> App: - """更新应用基本信息 - - Args: - app_id: 应用ID - data: 更新数据 - workspace_id: 工作空间ID(用于权限验证) - - Returns: - App: 更新后的应用对象 - - Raises: - ResourceNotFoundException: 当应用不存在时 - BusinessException: 当应用不在指定工作空间时 - """ - logger.info(f"更新应用", extra={"app_id": str(app_id)}) - - app = self._get_app_or_404(app_id) - self._validate_workspace_access(app, workspace_id) - - changed = False - for field in ["name", "description", "icon", "icon_type", "visibility", "status", "tags"]: - val = getattr(data, field, None) - if val is not None: - setattr(app, field, val) - changed = True - - if changed: - app.updated_at = datetime.datetime.now() - self.db.commit() - self.db.refresh(app) - logger.info(f"应用更新成功", extra={"app_id": str(app_id)}) - else: - logger.debug(f"应用无变更", extra={"app_id": str(app_id)}) - - return app - - def delete_app( - self, - *, - app_id: uuid.UUID, - workspace_id: Optional[uuid.UUID] = None - ) -> None: - """删除应用 - - Args: - app_id: 应用ID - workspace_id: 工作空间ID(用于权限验证) - - Raises: - ResourceNotFoundException: 当应用不存在时 - BusinessException: 当应用不在指定工作空间时 - """ - logger.info(f"删除应用", extra={"app_id": str(app_id)}) - - app = self._get_app_or_404(app_id) - self._validate_workspace_access(app, workspace_id) - - # 删除应用(级联删除相关数据) - self.db.delete(app) - self.db.commit() - - logger.info( - f"应用删除成功", - extra={ - "app_id": str(app_id), - "app_name": app.name, - "app_type": app.type - } - ) - - def copy_app( - self, - *, - app_id: uuid.UUID, - user_id: uuid.UUID, - workspace_id: Optional[uuid.UUID] = None, - new_name: Optional[str] = None - ) -> App: - """复制应用(包括基础信息和配置) - - Args: - app_id: 源应用ID - user_id: 创建者用户ID - workspace_id: 目标工作空间ID(如果为None,则复制到源应用所在工作空间) - new_name: 新应用名称(如果为None,则使用"源应用名称 - 副本") - - Returns: - App: 复制后的新应用对象 - - Raises: - ResourceNotFoundException: 当源应用不存在时 - BusinessException: 当复制失败时 - """ - logger.info(f"复制应用", extra={"source_app_id": str(app_id)}) - - try: - # 获取源应用 - source_app = self._get_app_or_404(app_id) - self._validate_app_accessible(source_app, workspace_id) - - # 确定目标工作空间 - target_workspace_id = workspace_id or source_app.workspace_id - - # 确定新应用名称 - if not new_name: - new_name = f"{source_app.name} - 副本" - - now = datetime.datetime.now() - - # 创建新应用(复制基础信息) - new_app = App( - id=uuid.uuid4(), - workspace_id=target_workspace_id, - created_by=user_id, - name=new_name, - description=source_app.description, - icon=source_app.icon, - icon_type=source_app.icon_type, - type=source_app.type, - visibility=source_app.visibility, - status="draft", # 复制的应用默认为草稿状态 - tags=source_app.tags.copy() if source_app.tags else [], - is_active=True, - created_at=now, - updated_at=now, - ) - self.db.add(new_app) - self.db.flush() - - # 如果是 agent 类型,复制 AgentConfig - if source_app.type == "agent": - source_config = self.db.query(AgentConfig).filter( - AgentConfig.app_id == source_app.id - ).first() - - if source_config: - new_config = AgentConfig( - id=uuid.uuid4(), - app_id=new_app.id, - system_prompt=source_config.system_prompt, - default_model_config_id=source_config.default_model_config_id, - model_parameters=source_config.model_parameters.copy() if source_config.model_parameters else None, - knowledge_retrieval=source_config.knowledge_retrieval.copy() if source_config.knowledge_retrieval else None, - memory=source_config.memory.copy() if source_config.memory else None, - variables=source_config.variables.copy() if source_config.variables else [], - tools=source_config.tools.copy() if source_config.tools else {}, - is_active=True, - created_at=now, - updated_at=now, - ) - self.db.add(new_config) - - self.db.commit() - self.db.refresh(new_app) - - logger.info( - f"应用复制成功", - extra={ - "source_app_id": str(app_id), - "new_app_id": str(new_app.id), - "new_app_name": new_app.name - } - ) - - return new_app - - except Exception as e: - self.db.rollback() - logger.error( - f"应用复制失败", - extra={"source_app_id": str(app_id), "error": str(e)} - ) - raise BusinessException(f"应用复制失败: {str(e)}", BizCode.INTERNAL_ERROR, cause=e) - - def list_apps( - self, - *, - workspace_id: uuid.UUID, - type: Optional[str] = None, - visibility: Optional[str] = None, - status: Optional[str] = None, - search: Optional[str] = None, - include_shared: bool = True, - page: int = 1, - pagesize: int = 10, - ) -> Tuple[List[App], int]: - """列出工作空间中的应用(分页) - - 包括: - 1. 本工作空间创建的应用 - 2. 其他工作空间分享给本工作空间的应用(如果 include_shared=True) - - Args: - workspace_id: 工作空间ID - type: 应用类型过滤 - visibility: 可见性过滤 - status: 状态过滤 - search: 搜索关键词 - include_shared: 是否包含分享的应用 - page: 页码(从1开始) - pagesize: 每页数量 - - Returns: - Tuple[List[App], int]: (应用列表, 总数) - """ - from app.models import AppShare - - logger.debug( - f"查询应用列表", - extra={ - "workspace_id": str(workspace_id), - "include_shared": include_shared, - "page": page, - "pagesize": pagesize - } - ) - - # 构建查询条件 - filters = [] - if type: - filters.append(App.type == type) - if visibility: - filters.append(App.visibility == visibility) - if status: - filters.append(App.status == status) - if search: - filters.append(func.lower(App.name).like(f"%{search.lower()}%")) - - # 基础查询:本工作空间的应用 - if include_shared: - # 查询本工作空间的应用 + 分享给本工作空间的应用 - # 使用 OR 条件:workspace_id = current OR app_id IN (shared apps) - - # 获取分享给本工作空间的应用ID列表 - shared_app_ids_stmt = ( - select(AppShare.source_app_id) - .where(AppShare.target_workspace_id == workspace_id) - ) - - # 构建主查询:本工作空间的应用 OR 分享的应用 - stmt = select(App).where( - or_( - App.workspace_id == workspace_id, - App.id.in_(shared_app_ids_stmt) - ) - ) - else: - # 只查询本工作空间的应用 - stmt = select(App).where(App.workspace_id == workspace_id) - - # 应用过滤条件 - if filters: - stmt = stmt.where(and_(*filters)) - - # 计算总数 - total_stmt = select(func.count()).select_from(stmt.subquery()) - total = self.db.execute(total_stmt).scalar() or 0 - - # 分页 - offset = (page - 1) * pagesize - stmt = stmt.order_by(App.created_at.desc()).offset(offset).limit(pagesize) - - items = list(self.db.scalars(stmt).all()) - - logger.debug( - f"应用列表查询完成", - extra={"total": total, "returned": len(items), "include_shared": include_shared} - ) - return items, int(total) - - # ==================== Agent 配置管理 ==================== - - def update_agent_config( - self, - *, - app_id: uuid.UUID, - data: app_schema.AgentConfigUpdate, - workspace_id: Optional[uuid.UUID] = None - ) -> AgentConfig: - """更新 Agent 配置 - - Args: - app_id: 应用ID - data: 配置更新数据 - workspace_id: 工作空间ID(用于权限验证) - - Returns: - AgentConfig: 更新后的配置对象 - - Raises: - ResourceNotFoundException: 当应用不存在时 - BusinessException: 当应用类型不支持或不在指定工作空间时 - """ - logger.info(f"更新 Agent 配置", extra={"app_id": str(app_id)}) - - app = self._get_app_or_404(app_id) - - if app.type != "agent": - raise BusinessException("只有 Agent 类型应用支持 Agent 配置", BizCode.APP_TYPE_NOT_SUPPORTED) - - self._validate_workspace_access(app, workspace_id) - - stmt = select(AgentConfig).where(AgentConfig.app_id == app_id, AgentConfig.is_active==True).order_by(AgentConfig.updated_at.desc()) - agent_cfg: Optional[AgentConfig] = self.db.scalars(stmt).first() - now = datetime.datetime.now() - - if not agent_cfg: - agent_cfg = AgentConfig( - id=uuid.uuid4(), - app_id=app_id, - is_active=True, - created_at=now, - updated_at=now, - ) - self.db.add(agent_cfg) - logger.debug(f"创建新的 Agent 配置", extra={"app_id": str(app_id)}) - - # 转换为存储格式 - storage_data = AgentConfigConverter.to_storage_format(data) - - # 更新字段 - # if data.system_prompt is not None: - agent_cfg.system_prompt = data.system_prompt - # if data.default_model_config_id is not None: - agent_cfg.default_model_config_id = data.default_model_config_id - # if data.model_parameters is not None: - agent_cfg.model_parameters = storage_data.get("model_parameters") - # if data.knowledge_retrieval is not None: - agent_cfg.knowledge_retrieval = storage_data.get("knowledge_retrieval") - # if data.memory is not None: - agent_cfg.memory = storage_data.get("memory") - # if data.variables is not None: - agent_cfg.variables = storage_data.get("variables", []) - # if data.tools is not None: - agent_cfg.tools = storage_data.get("tools", {}) - - agent_cfg.updated_at = now - - self.db.commit() - self.db.refresh(agent_cfg) - - logger.info(f"Agent 配置更新成功", extra={"app_id": str(app_id)}) - return agent_cfg - - def get_agent_config( - self, - *, - app_id: uuid.UUID, - workspace_id: Optional[uuid.UUID] = None - ) -> AgentConfig: - """获取 Agent 配置 - - 如果配置不存在,返回默认配置模板(不保存到数据库) - - Args: - app_id: 应用ID - workspace_id: 工作空间ID(用于权限验证) - - Returns: - AgentConfig: Agent 配置对象(存在的配置或默认模板) - - Raises: - ResourceNotFoundException: 当应用不存在时 - BusinessException: 当应用类型不支持或不可访问时 - """ - logger.debug(f"获取 Agent 配置", extra={"app_id": str(app_id)}) - - app = self._get_app_or_404(app_id) - - if app.type != "agent": - raise BusinessException("只有 Agent 类型应用支持 Agent 配置", BizCode.APP_TYPE_NOT_SUPPORTED) - - # 只读操作,允许访问共享应用 - self._validate_app_accessible(app, workspace_id) - - stmt = select(AgentConfig).where(AgentConfig.app_id == app_id, AgentConfig.is_active == True).order_by(AgentConfig.updated_at.desc()) - config = self.db.scalars(stmt).first() - - if config: - return config - - # 返回默认配置模板(不保存到数据库) - logger.debug(f"配置不存在,返回默认模板", extra={"app_id": str(app_id)}) - return self._create_default_agent_config(app_id) - - def _create_default_agent_config(self, app_id: uuid.UUID) -> AgentConfig: - """创建默认的 Agent 配置模板(不保存到数据库) - - Args: - app_id: 应用ID - - Returns: - AgentConfig: 默认配置对象 - """ - now = datetime.datetime.now() - - # 创建一个临时的配置对象,不添加到数据库 - default_config = AgentConfig( - id=uuid.uuid4(), # 临时ID - app_id=app_id, - system_prompt="你是一个专业的AI助手,你的职责是帮助用户解决问题。", - default_model_config_id=None, - model_parameters={ - "temperature": 0.7, - "max_tokens": 2000, - "top_p": 1.0, - "frequency_penalty": 0.0, - "presence_penalty": 0.0, - "n": 1, - "stop": None - }, - knowledge_retrieval={ - "knowledge_bases": [], - "merge_strategy": "weighted" - }, - memory={ - "enabled": True, - "memory_content": None, - "max_history": 10 - }, - variables=[], - tools={}, - is_active=True, - created_at=now, - updated_at=now, - ) - - return default_config - - # ==================== 应用发布管理 ==================== - - def publish( - self, - *, - app_id: uuid.UUID, - publisher_id: uuid.UUID, - version_name: str, - workspace_id: Optional[uuid.UUID] = None, - release_notes: Optional[str] = None - ) -> AppRelease: - """发布应用(创建不可变快照) - - Args: - app_id: 应用ID - publisher_id: 发布者用户ID - workspace_id: 工作空间ID(用于权限验证) - release_notes: 版本说明 - - Returns: - AppRelease: 发布版本对象 - - Raises: - ResourceNotFoundException: 当应用不存在时 - BusinessException: 当应用缺少配置或不在指定工作空间时 - """ - logger.info(f"发布应用", extra={"app_id": str(app_id), "publisher_id": str(publisher_id)}) - - app = self._get_app_or_404(app_id) - # 检查应用归属 - self._validate_workspace_access(app, workspace_id) - - # 构建快照配置 - config: Dict[str, Any] = {} - default_model_config_id = None - - if app.type == AppType.AGENT: - stmt = select(AgentConfig).where(AgentConfig.app_id == app_id, AgentConfig.is_active == True).order_by(AgentConfig.updated_at.desc()) - agent_cfg = self.db.scalars(stmt).first() - if not agent_cfg: - raise BusinessException("Agent 应用缺少配置,无法发布", BizCode.AGENT_CONFIG_MISSING) - - config = { - "system_prompt": agent_cfg.system_prompt, - "model_parameters": agent_cfg.model_parameters, - "knowledge_retrieval": agent_cfg.knowledge_retrieval, - "memory": agent_cfg.memory, - "variables": agent_cfg.variables or [], - "tools": agent_cfg.tools or {}, - } - # config = AgentConfigConverter.from_storage_format(agent_cfg) - default_model_config_id = agent_cfg.default_model_config_id - elif app.type == AppType.MULTI_AGENT: - # 1. 获取多智能体配置 - stmt = ( - select(MultiAgentConfig) - .where( - MultiAgentConfig.app_id == app_id, - MultiAgentConfig.is_active == True - ) - .order_by(MultiAgentConfig.updated_at.desc()) - ) - multi_agent_cfg = self.db.scalars(stmt).first() - if not multi_agent_cfg: - raise BusinessException("多 Agent 应用缺少有效配置,无法发布", BizCode.AGENT_CONFIG_MISSING) - - # 2. 检查配置完整性 - self._check_multi_agent_config(app_id) - - # 3. 获取主 Agent 的模型配置 ID - master_agent = self.db.get(AgentConfig, multi_agent_cfg.master_agent_id) - default_model_config_id = master_agent.default_model_config_id if master_agent else None - - # 4. 构建配置快照 - config = { - "master_agent_id": str(multi_agent_cfg.master_agent_id), - "orchestration_mode": multi_agent_cfg.orchestration_mode, - "sub_agents": multi_agent_cfg.sub_agents, - "routing_rules": multi_agent_cfg.routing_rules, - "execution_config": multi_agent_cfg.execution_config, - "aggregation_strategy": multi_agent_cfg.aggregation_strategy, - } - - logger.info( - f"多智能体应用发布配置准备完成", - extra={ - "app_id": str(app_id), - "master_agent_id": str(multi_agent_cfg.master_agent_id), - "sub_agent_count": len(multi_agent_cfg.sub_agents) if multi_agent_cfg.sub_agents else 0, - "orchestration_mode": multi_agent_cfg.orchestration_mode - } - ) - - now = datetime.datetime.now() - version = self._get_next_version(app_id) - - release = AppRelease( - id=uuid.uuid4(), - app_id=app_id, - version=version, - version_name = version_name, - release_notes=release_notes, - name=app.name, - description=app.description, - icon=app.icon, - icon_type=app.icon_type, - type=app.type, - visibility=app.visibility, - config=config, - default_model_config_id=default_model_config_id, - published_by=publisher_id, - published_at=now, - is_active=True, - created_at=now, - updated_at=now, - ) - self.db.add(release) - self.db.flush() # 先 flush,确保 release 已插入数据库 - - # 更新当前发布版本指针 - app.current_release_id = release.id - app.status = AppStatus.ACTIVE - app.updated_at = now - - self.db.commit() - self.db.refresh(release) - - logger.info( - f"应用发布成功", - extra={"app_id": str(app_id), "version": version, "release_id": str(release.id)} - ) - return release - - def get_current_release( - self, - *, - app_id: uuid.UUID, - workspace_id: Optional[uuid.UUID] = None - ) -> Optional[AppRelease]: - """获取当前发布版本 - - Args: - app_id: 应用ID - workspace_id: 工作空间ID(用于权限验证) - - Returns: - Optional[AppRelease]: 当前发布版本,如果未发布则返回 None - - Raises: - ResourceNotFoundException: 当应用不存在时 - BusinessException: 当应用不可访问时 - """ - logger.debug(f"获取当前发布版本", extra={"app_id": str(app_id)}) - - app = self._get_app_or_404(app_id) - # 只读操作,允许访问共享应用 - self._validate_app_accessible(app, workspace_id) - - if not app.current_release_id: - return None - - return self.db.get(AppRelease, app.current_release_id) - - def list_releases( - self, - *, - app_id: uuid.UUID, - workspace_id: Optional[uuid.UUID] = None - ) -> List[AppRelease]: - """列出应用的所有发布版本(倒序) - - Args: - app_id: 应用ID - workspace_id: 工作空间ID(用于权限验证) - - Returns: - List[AppRelease]: 发布版本列表 - - Raises: - ResourceNotFoundException: 当应用不存在时 - BusinessException: 当应用不可访问时 - """ - logger.debug(f"列出发布版本", extra={"app_id": str(app_id)}) - - app = self._get_app_or_404(app_id) - # 只读操作,允许访问共享应用 - self._validate_app_accessible(app, workspace_id) - - stmt = ( - select(AppRelease) - .where(AppRelease.app_id == app_id, AppRelease.is_active == True) - .order_by(AppRelease.version.desc()) - ) - return list(self.db.scalars(stmt).all()) - - def rollback( - self, - *, - app_id: uuid.UUID, - version: int, - workspace_id: Optional[uuid.UUID] = None - ) -> AppRelease: - """回滚到指定版本 - - Args: - app_id: 应用ID - version: 目标版本号 - workspace_id: 工作空间ID(用于权限验证) - - Returns: - AppRelease: 回滚到的版本对象 - - Raises: - ResourceNotFoundException: 当应用或版本不存在时 - BusinessException: 当应用不在指定工作空间时 - """ - logger.info(f"回滚应用", extra={"app_id": str(app_id), "version": version}) - - app = self._get_app_or_404(app_id) - self._validate_app_accessible(app, workspace_id) - - stmt = select(AppRelease).where( - AppRelease.app_id == app_id, - AppRelease.version == version - ) - release = self.db.scalars(stmt).first() - - if not release: - logger.warning( - f"发布版本不存在", - extra={"app_id": str(app_id), "version": version} - ) - raise ResourceNotFoundException("发布版本", f"app_id={app_id}, version={version}") - - app.current_release_id = release.id - app.updated_at = datetime.datetime.now() - - self.db.commit() - self.db.refresh(release) - - logger.info( - f"应用回滚成功", - extra={"app_id": str(app_id), "version": version, "release_id": str(release.id)} - ) - return release - - # ==================== 应用分享功能 ==================== - - def share_app( - self, - *, - app_id: uuid.UUID, - target_workspace_ids: List[uuid.UUID], - user_id: uuid.UUID, - workspace_id: Optional[uuid.UUID] = None - ) -> List["AppShare"]: - """分享应用到其他工作空间 - - Args: - app_id: 应用ID - target_workspace_ids: 目标工作空间ID列表 - user_id: 分享者用户ID - workspace_id: 当前工作空间ID(用于权限验证) - - Returns: - List[AppShare]: 创建的分享记录列表 - - Raises: - ResourceNotFoundException: 当应用不存在时 - BusinessException: 当应用不在指定工作空间或目标工作空间无效时 - """ - from app.models import AppShare, Workspace - - logger.info( - f"分享应用", - extra={ - "app_id": str(app_id), - "target_workspaces": [str(wid) for wid in target_workspace_ids], - "user_id": str(user_id) - } - ) - - # 1. 验证应用 - app = self._get_app_or_404(app_id) - self._validate_workspace_access(app, workspace_id) - - # 2. 验证目标工作空间 - for target_ws_id in target_workspace_ids: - target_ws = self.db.get(Workspace, target_ws_id) - if not target_ws: - raise ResourceNotFoundException("工作空间", str(target_ws_id)) - - # 不能分享给自己的工作空间 - if target_ws_id == app.workspace_id: - raise BusinessException( - "不能分享应用到自己的工作空间", - BizCode.INVALID_PARAMETER - ) - - # 3. 创建分享记录 - now = datetime.datetime.now() - shares = [] - - for target_ws_id in target_workspace_ids: - # 检查是否已经分享过 - stmt = select(AppShare).where( - AppShare.source_app_id == app_id, - AppShare.target_workspace_id == target_ws_id - ) - existing_share = self.db.scalars(stmt).first() - - if existing_share: - logger.debug( - f"应用已分享到该工作空间,跳过", - extra={"app_id": str(app_id), "target_workspace_id": str(target_ws_id)} - ) - shares.append(existing_share) - continue - - # 创建新的分享记录 - share = AppShare( - id=uuid.uuid4(), - source_app_id=app_id, - source_workspace_id=app.workspace_id, - target_workspace_id=target_ws_id, - shared_by=user_id, - created_at=now, - updated_at=now - ) - self.db.add(share) - shares.append(share) - - logger.debug( - f"创建分享记录", - extra={"app_id": str(app_id), "target_workspace_id": str(target_ws_id)} - ) - - self.db.commit() - - logger.info( - f"应用分享成功", - extra={ - "app_id": str(app_id), - "shared_count": len(shares), - "app_name": app.name - } - ) - - return shares - - def unshare_app( - self, - *, - app_id: uuid.UUID, - target_workspace_id: uuid.UUID, - workspace_id: Optional[uuid.UUID] = None - ) -> None: - """取消应用分享 - - Args: - app_id: 应用ID - target_workspace_id: 目标工作空间ID - workspace_id: 当前工作空间ID(用于权限验证) - - Raises: - ResourceNotFoundException: 当应用或分享记录不存在时 - BusinessException: 当应用不在指定工作空间时 - """ - from app.models import AppShare - - logger.info( - f"取消应用分享", - extra={ - "app_id": str(app_id), - "target_workspace_id": str(target_workspace_id) - } - ) - - # 1. 验证应用 - app = self._get_app_or_404(app_id) - self._validate_workspace_access(app, workspace_id) - - # 2. 查找分享记录 - stmt = select(AppShare).where( - AppShare.source_app_id == app_id, - AppShare.target_workspace_id == target_workspace_id - ) - share = self.db.scalars(stmt).first() - - if not share: - logger.warning( - f"分享记录不存在", - extra={"app_id": str(app_id), "target_workspace_id": str(target_workspace_id)} - ) - raise ResourceNotFoundException( - "分享记录", - f"app_id={app_id}, target_workspace_id={target_workspace_id}" - ) - - # 3. 删除分享记录 - self.db.delete(share) - self.db.commit() - - logger.info( - f"应用分享已取消", - extra={"app_id": str(app_id), "target_workspace_id": str(target_workspace_id)} - ) - - def list_app_shares( - self, - *, - app_id: uuid.UUID, - workspace_id: Optional[uuid.UUID] = None - ) -> List["AppShare"]: - """列出应用的所有分享记录 - - Args: - app_id: 应用ID - workspace_id: 当前工作空间ID(用于权限验证) - - Returns: - List[AppShare]: 分享记录列表 - - Raises: - ResourceNotFoundException: 当应用不存在时 - BusinessException: 当应用不在指定工作空间时 - """ - from app.models import AppShare - - logger.debug(f"列出应用分享记录", extra={"app_id": str(app_id)}) - - # 验证应用 - app = self._get_app_or_404(app_id) - self._validate_workspace_access(app, workspace_id) - - # 查询分享记录 - stmt = select(AppShare).where( - AppShare.source_app_id == app_id - ).order_by(AppShare.created_at.desc()) - - shares = list(self.db.scalars(stmt).all()) - - logger.debug( - f"应用分享记录查询完成", - extra={"app_id": str(app_id), "count": len(shares)} - ) - - return shares - - # ==================== 试运行功能 ==================== - - async def draft_run( - self, - *, - app_id: uuid.UUID, - message: str, - conversation_id: Optional[str] = None, - user_id: Optional[str] = None, - variables: Optional[Dict[str, Any]] = None, - workspace_id: Optional[uuid.UUID] = None - ) -> Dict[str, Any]: - """试运行 Agent(使用当前草稿配置) - - Args: - app_id: 应用ID - message: 用户消息 - conversation_id: 会话ID(用于多轮对话) - user_id: 用户ID(用于会话管理) - variables: 自定义变量参数值 - workspace_id: 工作空间ID(用于权限验证) - - Returns: - Dict: 包含 AI 回复和元数据的字典 - - Raises: - ResourceNotFoundException: 当应用不存在时 - BusinessException: 当应用类型不支持或配置缺失时 - """ - from app.services.draft_run_service import DraftRunService - - logger.info(f"试运行 Agent", extra={"app_id": str(app_id), "user_message": message[:50]}) - - # 1. 验证应用 - app = self._get_app_or_404(app_id) - - if app.type != "agent": - raise BusinessException("只有 Agent 类型应用支持试运行", BizCode.APP_TYPE_NOT_SUPPORTED) - - # 只读操作,允许访问共享应用 - self._validate_app_accessible(app, workspace_id) - - # 2. 获取 Agent 配置 - stmt = select(AgentConfig).where(AgentConfig.app_id == app_id) - agent_cfg = self.db.scalars(stmt).first() - - if not agent_cfg: - raise BusinessException("Agent 配置不存在,无法试运行", BizCode.AGENT_CONFIG_MISSING) - - # 3. 获取模型配置 - model_config = None - if agent_cfg.default_model_config_id: - from app.models import ModelConfig - model_config = self.db.get(ModelConfig, agent_cfg.default_model_config_id) - - if not model_config: - raise BusinessException("模型配置不存在,无法试运行", BizCode.AGENT_CONFIG_MISSING) - - # 4. 调用试运行服务 - logger.debug( - f"准备调用试运行服务", - extra={ - "app_id": str(app_id), - "model": model_config.name, - "has_conversation_id": bool(conversation_id), - "has_variables": bool(variables) - } - ) - - draft_service = DraftRunService(self.db) - result = await draft_service.run( - agent_config=agent_cfg, - model_config=model_config, - message=message, - workspace_id=workspace_id, - conversation_id=conversation_id, - user_id=user_id, - variables=variables - ) - - logger.debug( - f"试运行服务返回结果", - extra={ - "result_type": str(type(result)), - "result_keys": list(result.keys()) if isinstance(result, dict) else "not_dict", - "has_message": "message" in result if isinstance(result, dict) else False, - "has_conversation_id": "conversation_id" in result if isinstance(result, dict) else False - } - ) - - logger.info( - f"试运行完成", - extra={ - "app_id": str(app_id), - "elapsed_time": result.get("elapsed_time"), - "model": model_config.name - } - ) - - return result - - async def draft_run_stream( - self, - *, - app_id: uuid.UUID, - message: str, - conversation_id: Optional[str] = None, - user_id: Optional[str] = None, - variables: Optional[Dict[str, Any]] = None, - workspace_id: Optional[uuid.UUID] = None - ): - """试运行 Agent(流式返回) - - Args: - app_id: 应用ID - message: 用户消息 - conversation_id: 会话ID(用于多轮对话) - user_id: 用户ID(用于会话管理) - variables: 自定义变量参数值 - workspace_id: 工作空间ID(用于权限验证) - - Yields: - str: SSE 格式的事件数据 - - Raises: - ResourceNotFoundException: 当应用不存在时 - BusinessException: 当应用类型不支持或配置缺失时 - """ - from app.services.draft_run_service import DraftRunService - - logger.info(f"流式试运行 Agent", extra={"app_id": str(app_id), "user_message": message[:50]}) - - # 1. 验证应用 - app = self._get_app_or_404(app_id) - - if app.type != "agent": - raise BusinessException("只有 Agent 类型应用支持试运行", BizCode.APP_TYPE_NOT_SUPPORTED) - - # 只读操作,允许访问共享应用 - self._validate_app_accessible(app, workspace_id) - - # 2. 获取 Agent 配置 - stmt = select(AgentConfig).where(AgentConfig.app_id == app_id) - agent_cfg = self.db.scalars(stmt).first() - - if not agent_cfg: - raise BusinessException("Agent 配置不存在,无法试运行", BizCode.AGENT_CONFIG_MISSING) - - # 3. 获取模型配置 - model_config = None - if agent_cfg.default_model_config_id: - from app.models import ModelConfig - model_config = self.db.get(ModelConfig, agent_cfg.default_model_config_id) - - if not model_config: - raise BusinessException("模型配置不存在,无法试运行", BizCode.AGENT_CONFIG_MISSING) - - # 4. 调用流式试运行服务 - draft_service = DraftRunService(self.db) - async for event in draft_service.run_stream( - agent_config=agent_cfg, - model_config=model_config, - message=message, - workspace_id=workspace_id, - conversation_id=conversation_id, - user_id=user_id, - variables=variables - ): - yield event - - # ==================== 多模型对比试运行 ==================== - - async def draft_run_compare( - self, - *, - app_id: uuid.UUID, - message: str, - models: List[app_schema.ModelCompareItem], - conversation_id: Optional[str] = None, - user_id: Optional[str] = None, - variables: Optional[Dict[str, Any]] = None, - workspace_id: Optional[uuid.UUID] = None, - parallel: bool = True, - timeout: int = 60 - ) -> Dict[str, Any]: - """多模型对比试运行 - - Args: - app_id: 应用ID - message: 用户消息 - models: 要对比的模型列表 - conversation_id: 会话ID - user_id: 用户ID - variables: 变量参数 - workspace_id: 工作空间ID - parallel: 是否并行执行 - timeout: 超时时间(秒) - - Returns: - Dict: 对比结果 - """ - from app.services.draft_run_service import DraftRunService - from app.models import ModelConfig - - logger.info( - f"多模型对比试运行", - extra={ - "app_id": str(app_id), - "model_count": len(models), - "parallel": parallel - } - ) - - # 1. 验证应用 - app = self._get_app_or_404(app_id) - if app.type != "agent": - raise BusinessException("只有 Agent 类型应用支持试运行", BizCode.APP_TYPE_NOT_SUPPORTED) - - # 只读操作,允许访问共享应用 - self._validate_app_accessible(app, workspace_id) - - # 2. 获取 Agent 配置 - stmt = select(AgentConfig).where(AgentConfig.app_id == app_id) - agent_cfg = self.db.scalars(stmt).first() - if not agent_cfg: - raise BusinessException("Agent 配置不存在", BizCode.AGENT_CONFIG_MISSING) - - # 3. 准备所有模型配置 - model_configs = [] - for model_item in models: - model_config = self.db.get(ModelConfig, model_item.model_config_id) - if not model_config: - raise ResourceNotFoundException("模型配置", str(model_item.model_config_id)) - - # 合并参数:agent配置参数 + 请求覆盖参数 - merged_parameters = { - **(agent_cfg.model_parameters or {}), - **(model_item.model_parameters or {}) - } - - model_configs.append({ - "model_config": model_config, - "parameters": merged_parameters, - "label": model_item.label or model_config.name, - "model_config_id": model_item.model_config_id - }) - - # 4. 调用 DraftRunService 的对比方法 - draft_service = DraftRunService(self.db) - result = await draft_service.run_compare( - agent_config=agent_cfg, - models=model_configs, - message=message, - workspace_id=workspace_id, - conversation_id=conversation_id, - user_id=user_id, - variables=variables, - parallel=parallel, - timeout=timeout - ) - - logger.info( - f"多模型对比完成", - extra={ - "app_id": str(app_id), - "successful": result["successful_count"], - "failed": result["failed_count"] - } - ) - - return result - - async def draft_run_compare_stream( - self, - *, - app_id: uuid.UUID, - message: str, - models: List[app_schema.ModelCompareItem], - conversation_id: Optional[str] = None, - user_id: Optional[str] = None, - variables: Optional[Dict[str, Any]] = None, - workspace_id: Optional[uuid.UUID] = None, - parallel: bool = True, - timeout: int = 60 - ): - """多模型对比试运行(流式返回) - - Args: - app_id: 应用ID - message: 用户消息 - models: 要对比的模型列表 - conversation_id: 会话ID - user_id: 用户ID - variables: 变量参数 - workspace_id: 工作空间ID - timeout: 超时时间(秒) - - Yields: - str: SSE 格式的事件数据 - """ - from app.services.draft_run_service import DraftRunService - from app.models import ModelConfig - - logger.info( - f"多模型对比流式试运行", - extra={ - "app_id": str(app_id), - "model_count": len(models) - } - ) - - # 1. 验证应用 - app = self._get_app_or_404(app_id) - if app.type != "agent": - raise BusinessException("只有 Agent 类型应用支持试运行", BizCode.APP_TYPE_NOT_SUPPORTED) - - # 只读操作,允许访问共享应用 - self._validate_app_accessible(app, workspace_id) - - # 2. 获取 Agent 配置 - stmt = select(AgentConfig).where(AgentConfig.app_id == app_id) - agent_cfg = self.db.scalars(stmt).first() - if not agent_cfg: - raise BusinessException("Agent 配置不存在", BizCode.AGENT_CONFIG_MISSING) - - # 3. 准备所有模型配置 - model_configs = [] - for model_item in models: - model_config = self.db.get(ModelConfig, model_item.model_config_id) - if not model_config: - raise ResourceNotFoundException("模型配置", str(model_item.model_config_id)) - - # 合并参数:agent配置参数 + 请求覆盖参数 - merged_parameters = { - **(agent_cfg.model_parameters or {}), - **(model_item.model_parameters or {}) - } - - model_configs.append({ - "model_config": model_config, - "parameters": merged_parameters, - "label": model_item.label or model_config.name, - "model_config_id": model_item.model_config_id - }) - - # 4. 调用 DraftRunService 的流式对比方法 - draft_service = DraftRunService(self.db) - async for event in draft_service.run_compare_stream( - agent_config=agent_cfg, - models=model_configs, - message=message, - workspace_id=workspace_id, - conversation_id=conversation_id, - user_id=user_id, - variables=variables, - parallel=parallel, - timeout=timeout - ): - yield event - - logger.info( - f"多模型对比流式完成", - extra={"app_id": str(app_id)} - ) - - -# ==================== 向后兼容的函数接口 ==================== -# 保留函数接口以兼容现有代码,但内部使用服务类 - -def create_app(db: Session, *, user_id: uuid.UUID, workspace_id: uuid.UUID, data: app_schema.AppCreate) -> App: - """创建应用(向后兼容接口)""" - service = AppService(db) - return service.create_app(user_id=user_id, workspace_id=workspace_id, data=data) - - -def update_app(db: Session, *, app_id: uuid.UUID, data: app_schema.AppUpdate, workspace_id: uuid.UUID | None = None) -> App: - """更新应用(向后兼容接口)""" - service = AppService(db) - return service.update_app(app_id=app_id, data=data, workspace_id=workspace_id) - - -def delete_app(db: Session, *, app_id: uuid.UUID, workspace_id: uuid.UUID | None = None) -> None: - """删除应用(向后兼容接口)""" - service = AppService(db) - return service.delete_app(app_id=app_id, workspace_id=workspace_id) - - -def update_agent_config(db: Session, *, app_id: uuid.UUID, data: app_schema.AgentConfigUpdate, workspace_id: uuid.UUID | None = None) -> AgentConfig: - """更新 Agent 配置(向后兼容接口)""" - service = AppService(db) - return service.update_agent_config(app_id=app_id, data=data, workspace_id=workspace_id) - - -def get_agent_config(db: Session, *, app_id: uuid.UUID, workspace_id: uuid.UUID | None = None) -> AgentConfig: - """获取 Agent 配置(向后兼容接口) - - 如果配置不存在,返回默认配置模板 - """ - service = AppService(db) - return service.get_agent_config(app_id=app_id, workspace_id=workspace_id) - - -def publish(db: Session, *, app_id: uuid.UUID, publisher_id: uuid.UUID, workspace_id: uuid.UUID | None = None,version_name:str, release_notes: Optional[str] = None) -> AppRelease: - """发布应用(向后兼容接口)""" - service = AppService(db) - return service.publish(app_id=app_id, publisher_id=publisher_id,version_name = version_name, workspace_id=workspace_id, release_notes=release_notes) - - -def get_current_release(db: Session, *, app_id: uuid.UUID, workspace_id: uuid.UUID | None = None) -> Optional[AppRelease]: - """获取当前发布版本(向后兼容接口)""" - service = AppService(db) - return service.get_current_release(app_id=app_id, workspace_id=workspace_id) - - -def list_releases(db: Session, *, app_id: uuid.UUID, workspace_id: uuid.UUID | None = None) -> List[AppRelease]: - """列出发布版本(向后兼容接口)""" - service = AppService(db) - return service.list_releases(app_id=app_id, workspace_id=workspace_id) - - -def rollback(db: Session, *, app_id: uuid.UUID, version: int, workspace_id: uuid.UUID | None = None) -> AppRelease: - """回滚应用(向后兼容接口)""" - service = AppService(db) - return service.rollback(app_id=app_id, version=version, workspace_id=workspace_id) - - -def list_apps( - db: Session, - *, - workspace_id: uuid.UUID, - type: Optional[str] = None, - visibility: Optional[str] = None, - status: Optional[str] = None, - search: Optional[str] = None, - include_shared: bool = True, - page: int = 1, - pagesize: int = 10, -) -> Tuple[List[App], int]: - """列出应用(向后兼容接口)""" - service = AppService(db) - return service.list_apps( - workspace_id=workspace_id, - type=type, - visibility=visibility, - status=status, - search=search, - include_shared=include_shared, - page=page, - pagesize=pagesize, - ) - - -# ==================== 向后兼容的函数接口 ==================== - -async def draft_run( - db: Session, - *, - app_id: uuid.UUID, - message: str, - conversation_id: Optional[str] = None, - user_id: Optional[str] = None, - variables: Optional[Dict[str, Any]] = None, - workspace_id: Optional[uuid.UUID] = None -) -> Dict[str, Any]: - """试运行 Agent(向后兼容接口)""" - service = AppService(db) - return await service.draft_run( - app_id=app_id, - message=message, - conversation_id=conversation_id, - user_id=user_id, - variables=variables, - workspace_id=workspace_id - ) - - -async def draft_run_stream( - db: Session, - *, - app_id: uuid.UUID, - message: str, - conversation_id: Optional[str] = None, - user_id: Optional[str] = None, - variables: Optional[Dict[str, Any]] = None, - workspace_id: Optional[uuid.UUID] = None -): - """试运行 Agent 流式返回(向后兼容接口)""" - service = AppService(db) - async for event in service.draft_run_stream( - app_id=app_id, - message=message, - conversation_id=conversation_id, - user_id=user_id, - variables=variables, - workspace_id=workspace_id - ): - yield event diff --git a/app/services/auth_service.py b/app/services/auth_service.py deleted file mode 100644 index 118b6bc5..00000000 --- a/app/services/auth_service.py +++ /dev/null @@ -1,262 +0,0 @@ -from sqlalchemy.orm import Session -from typing import Optional, Tuple, Union -import jwt -import time - -from app.models.user_model import User -from app.repositories import user_repository -from app.core.security import verify_password -from app.core.config import settings -from app.core.exceptions import BusinessException -from app.core.error_codes import BizCode - -# Token 配置 -TOKEN_SECRET_KEY = settings.SECRET_KEY -TOKEN_ALGORITHM = "HS256" - -def authenticate_user(db: Session, email: str, password: str) -> Optional[User]: - """ - Authenticates a user. - - :param db: The database session. - :param email: The email. - :param password: The password. - :return: The user object if authentication is successful, otherwise None. - """ - user = user_repository.get_user_by_email(db, email=email) - if not user: - return None # User not found - if not user.is_active: - return None # User is inactive - if not verify_password(password, user.hashed_password): - return None # Incorrect password - return user # Authentication successful - - -def authenticate_user_with_status(db: Session, email: str, password: str) -> Tuple[bool, Optional[User], str]: - """ - 认证用户并返回详细状态(用于需要区分不同失败原因的场景) - - :param db: 数据库会话 - :param email: 用户邮箱 - :param password: 用户密码 - :return: (认证成功, 用户对象, 状态消息) - 状态消息: "success", "user_not_found", "user_inactive", "password_incorrect" - """ - from app.core.logging_config import get_auth_logger - - logger = get_auth_logger() - - # 查找用户 - user = user_repository.get_user_by_email(db, email=email) - if not user: - logger.warning(f"用户不存在: {email}") - return (False, None, "user_not_found") - - # 检查用户状态 - if not user.is_active: - logger.warning(f"用户未激活: {email}") - return (False, user, "user_inactive") - - # 验证密码 - if not verify_password(password, user.hashed_password): - logger.warning(f"密码错误: {email}") - return (False, user, "password_incorrect") - - logger.info(f"用户认证成功: {email}") - return (True, user, "success") - - -def authenticate_user_or_raise(db: Session, email: str, password: str) -> User: - """ - 认证用户,失败时抛出异常(推荐使用) - - :param db: 数据库会话 - :param email: 用户邮箱 - :param password: 用户密码 - :return: 用户对象 - :raises BusinessException: 认证失败时抛出 - """ - from app.core.exceptions import BusinessException - from app.core.error_codes import BizCode - from app.core.logging_config import get_auth_logger - - logger = get_auth_logger() - - # 查找用户 - user = user_repository.get_user_by_email(db, email=email) - if not user: - logger.warning(f"用户不存在: {email}") - raise BusinessException("用户不存在", code=BizCode.USER_NOT_FOUND) - - # 检查用户状态 - if not user.is_active: - logger.warning(f"用户未激活: {email}") - raise BusinessException("用户未激活", code=BizCode.USER_NOT_FOUND) - - # 验证密码 - if not verify_password(password, user.hashed_password): - logger.warning(f"密码错误: {email}") - raise BusinessException("密码错误", code=BizCode.PASSWORD_ERROR) - - logger.info(f"用户认证成功: {email}") - return user - - -def get_user_by_username(db: Session, username: str) -> Optional[User]: - """ - Get a user by username. - - :param db: The database session. - :param username: The username. - :return: The user object if found, otherwise None. - """ - return user_repository.get_user_by_username(db, username=username) - -def get_user_by_id(db: Session, user_id: str) -> Optional[User]: - """ - Get a user by user_id. - - :param db: The database session. - :param user_id: The user id (UUID string). - :return: The user object if found, otherwise None. - """ - return user_repository.get_user_by_id(db, user_id=user_id) - - -def register_user_with_invite( - db: Session, - email: str, - password: str, - invite_token: str, - workspace_id: str -) -> User: - """ - 使用邀请码注册新用户并加入工作空间 - - :param db: 数据库会话 - :param email: 用户邮箱 - :param password: 用户密码 - :param invite_token: 邀请令牌 - :param workspace_id: 工作空间ID - :return: 创建的用户对象 - """ - from app.schemas.user_schema import UserCreate - from app.schemas.workspace_schema import InviteAcceptRequest - from app.services import user_service, workspace_service - from app.core.logging_config import get_business_logger - - logger = get_business_logger() - logger.info(f"使用邀请码注册用户: {email}") - - try: - # 创建用户 - user_create = UserCreate( - email=email, - password=password, - username=email.split('@')[0] - ) - user = user_service.create_user(db=db, user=user_create) - logger.info(f"用户创建成功: {user.email} (ID: {user.id})") - - # 接受工作空间邀请(此时用户已成为工作空间成员,并且会 commit) - invite_accept = InviteAcceptRequest(token=invite_token) - workspace_service.accept_workspace_invite(db, invite_accept, user) - logger.info(f"用户接受邀请成功") - - # 重新查询用户对象以确保获取最新状态 - from app.repositories import user_repository - user = user_repository.get_user_by_id(db, str(user.id)) - - # 设置当前工作空间 - user.current_workspace_id = workspace_id - db.commit() - db.refresh(user) - - logger.info(f"用户注册并加入工作空间成功: {user.email}, workspace_id: {user.current_workspace_id}") - return user - - except Exception as e: - db.rollback() - logger.error(f"注册用户失败: {email} - {str(e)}") - raise - -def bind_workspace_with_invite( - db: Session, - user: User, - invite_token: str, - workspace_id: str -) -> User: - - from app.schemas.user_schema import UserCreate - from app.schemas.workspace_schema import InviteAcceptRequest - from app.services import user_service, workspace_service - from app.core.logging_config import get_business_logger - - logger = get_business_logger() - - try: - - # 接受工作空间邀请(此时用户已成为工作空间成员,并且会 commit) - invite_accept = InviteAcceptRequest(token=invite_token) - workspace_service.accept_workspace_invite(db, invite_accept, user) - logger.info(f"用户接受邀请成功") - - # 重新查询用户对象以确保获取最新状态 - from app.repositories import user_repository - user = user_repository.get_user_by_id(db, str(user.id)) - - # 设置当前工作空间 - user.current_workspace_id = workspace_id - db.commit() - db.refresh(user) - return user - - except Exception as e: - db.rollback() - logger.error(f"绑定工作空间失败: user={user.email} - {str(e)}") - raise - - -def create_access_token(user_id: str, share_token: str) -> str: - """创建访问 token - - Token 不设置过期时间,只要 share_token 有效,token 就有效 - - Args: - user_id: 用户 ID - share_token: 分享 token - - Returns: - JWT token - """ - payload = { - "user_id": user_id, - "share_token": share_token, - "iat": int(time.time()) # 签发时间 - } - - token = jwt.encode(payload, TOKEN_SECRET_KEY, algorithm=TOKEN_ALGORITHM) - return token - - -def decode_access_token(token: str) -> dict: - """解码访问 token - - Args: - token: JWT token - - Returns: - 包含 user_id 和 share_token 的字典 - - Raises: - BusinessException: token 无效 - """ - try: - payload = jwt.decode(token, TOKEN_SECRET_KEY, algorithms=[TOKEN_ALGORITHM]) - return { - "user_id": payload["user_id"], - "share_token": payload["share_token"] - } - except jwt.InvalidTokenError: - raise BusinessException("无效的访问 token", BizCode.INVALID_TOKEN) \ No newline at end of file diff --git a/app/services/conversation_service.py b/app/services/conversation_service.py deleted file mode 100644 index 42144441..00000000 --- a/app/services/conversation_service.py +++ /dev/null @@ -1,229 +0,0 @@ -"""会话服务""" -import uuid -from typing import Optional, List, Tuple -from sqlalchemy.orm import Session -from sqlalchemy import select, desc - -from app.models import Conversation, Message -from app.core.exceptions import ResourceNotFoundException, BusinessException -from app.core.error_codes import BizCode -from app.core.logging_config import get_business_logger - -logger = get_business_logger() - - -class ConversationService: - """会话服务""" - - def __init__(self, db: Session): - self.db = db - - def create_conversation( - self, - app_id: uuid.UUID, - workspace_id: uuid.UUID, - user_id: Optional[str] = None, - title: Optional[str] = None, - is_draft: bool = False, - config_snapshot: Optional[dict] = None - ) -> Conversation: - """创建会话""" - conversation = Conversation( - app_id=app_id, - workspace_id=workspace_id, - user_id=user_id, - title=title or "新会话", - is_draft=is_draft, - config_snapshot=config_snapshot - ) - - self.db.add(conversation) - self.db.commit() - self.db.refresh(conversation) - - logger.info( - f"创建会话成功", - extra={ - "conversation_id": str(conversation.id), - "app_id": str(app_id), - "workspace_id": str(workspace_id), - "is_draft": is_draft - } - ) - - return conversation - - def get_conversation( - self, - conversation_id: uuid.UUID, - workspace_id: Optional[uuid.UUID] = None - ) -> Conversation: - """获取会话""" - stmt = select(Conversation).where(Conversation.id == conversation_id) - - if workspace_id: - stmt = stmt.where(Conversation.workspace_id == workspace_id) - - conversation = self.db.scalars(stmt).first() - - if not conversation: - raise ResourceNotFoundException("会话", str(conversation_id)) - - return conversation - - def list_conversations( - self, - app_id: uuid.UUID, - workspace_id: uuid.UUID, - user_id: Optional[str] = None, - is_draft: Optional[bool] = None, - page: int = 1, - pagesize: int = 20 - ) -> Tuple[List[Conversation], int]: - """列出会话""" - stmt = select(Conversation).where( - Conversation.app_id == app_id, - Conversation.workspace_id == workspace_id, - Conversation.is_active == True - ) - - if user_id: - stmt = stmt.where(Conversation.user_id == user_id) - - if is_draft is not None: - stmt = stmt.where(Conversation.is_draft == is_draft) - - # 总数 - count_stmt = stmt.with_only_columns(Conversation.id) - total = len(self.db.execute(count_stmt).all()) - - # 分页 - stmt = stmt.order_by(desc(Conversation.updated_at)) - stmt = stmt.offset((page - 1) * pagesize).limit(pagesize) - - conversations = list(self.db.scalars(stmt).all()) - - return conversations, total - - def add_message( - self, - conversation_id: uuid.UUID, - role: str, - content: str, - meta_data: Optional[dict] = None - ) -> Message: - """添加消息""" - message = Message( - conversation_id=conversation_id, - role=role, - content=content, - meta_data=meta_data - ) - - self.db.add(message) - - # 更新会话的消息计数和更新时间 - conversation = self.get_conversation(conversation_id) - conversation.message_count += 1 - - # 如果是第一条用户消息,可以用它作为标题 - if conversation.message_count == 1 and role == "user": - conversation.title = content[:50] + ("..." if len(content) > 50 else "") - - self.db.commit() - self.db.refresh(message) - - return message - - def get_messages( - self, - conversation_id: uuid.UUID, - limit: Optional[int] = None - ) -> List[Message]: - """获取会话消息""" - stmt = select(Message).where( - Message.conversation_id == conversation_id - ).order_by(Message.created_at) - - if limit: - stmt = stmt.limit(limit) - - messages = list(self.db.scalars(stmt).all()) - - return messages - - def get_conversation_history( - self, - conversation_id: uuid.UUID, - max_history: Optional[int] = None - ) -> List[dict]: - """获取会话历史消息 - - Args: - conversation_id: 会话ID - max_history: 最大历史消息数量 - - Returns: - List[dict]: 历史消息列表,格式为 [{"role": "user", "content": "..."}, ...] - """ - messages = self.get_messages(conversation_id, limit=max_history) - - # 转换为字典格式 - history = [ - { - "role": msg.role, - "content": msg.content - } - for msg in messages - ] - - return history - - def save_conversation_messages( - self, - conversation_id: uuid.UUID, - user_message: str, - assistant_message: str - ): - """保存会话消息(用户消息和助手回复)""" - # 添加用户消息 - self.add_message( - conversation_id=conversation_id, - role="user", - content=user_message - ) - - # 添加助手消息 - self.add_message( - conversation_id=conversation_id, - role="assistant", - content=assistant_message - ) - - logger.debug( - f"保存会话消息成功", - extra={ - "conversation_id": str(conversation_id), - "user_message_length": len(user_message), - "assistant_message_length": len(assistant_message) - } - ) - - def delete_conversation( - self, - conversation_id: uuid.UUID, - workspace_id: uuid.UUID - ): - """删除会话(软删除)""" - conversation = self.get_conversation(conversation_id, workspace_id) - conversation.is_active = False - - self.db.commit() - - logger.info( - f"删除会话成功", - extra={ - "conversation_id": str(conversation_id), - "workspace_id": str(workspace_id) - } - ) diff --git a/app/services/conversation_state_manager.py b/app/services/conversation_state_manager.py deleted file mode 100644 index b279696a..00000000 --- a/app/services/conversation_state_manager.py +++ /dev/null @@ -1,261 +0,0 @@ -"""会话状态管理器 - 解决多轮对话路由错乱""" -import json -from typing import Optional, Dict, Any, List -from datetime import datetime -from app.core.logging_config import get_business_logger - -logger = get_business_logger() - - -class ConversationStateManager: - """会话状态管理器 - - 用于管理多轮对话中的会话状态,包括: - - 当前使用的 Agent - - 路由历史 - - 主题追踪 - - Agent 切换统计 - """ - - def __init__(self, storage_backend: Optional[Any] = None): - """初始化状态管理器 - - Args: - storage_backend: 存储后端(Redis/内存等) - """ - self.storage = storage_backend or InMemoryStorage() - self.ttl = 3600 # 1小时过期 - - def get_state(self, conversation_id: str) -> Dict[str, Any]: - """获取会话状态 - - Args: - conversation_id: 会话 ID - - Returns: - 会话状态字典 - """ - state = self.storage.get(f"conv_state:{conversation_id}") - - if not state: - logger.info(f"创建新会话状态: {conversation_id}") - return self._create_new_state(conversation_id) - - return state - - def update_state( - self, - conversation_id: str, - agent_id: str, - message: str, - topic: Optional[str] = None, - confidence: float = 1.0 - ) -> Dict[str, Any]: - """更新会话状态 - - Args: - conversation_id: 会话 ID - agent_id: 当前 Agent ID - message: 用户消息 - topic: 消息主题 - confidence: 路由置信度 - - Returns: - 更新后的状态 - """ - state = self.get_state(conversation_id) - - # 检测 Agent 切换 - agent_changed = False - if state["current_agent_id"] and state["current_agent_id"] != agent_id: - agent_changed = True - state["switch_count"] += 1 - state["previous_agent_id"] = state["current_agent_id"] - state["same_agent_turns"] = 0 - - logger.info( - f"Agent 切换", - extra={ - "conversation_id": conversation_id, - "from": state["current_agent_id"], - "to": agent_id, - "switch_count": state["switch_count"] - } - ) - else: - state["same_agent_turns"] += 1 - - # 更新当前 Agent - state["current_agent_id"] = agent_id - state["last_message"] = message - state["last_topic"] = topic - state["updated_at"] = datetime.now().isoformat() - - # 添加到历史 - history_item = { - "message": message[:100], # 截断长消息 - "agent_id": agent_id, - "topic": topic, - "confidence": confidence, - "agent_changed": agent_changed, - "timestamp": datetime.now().isoformat() - } - state["routing_history"].append(history_item) - - # 保持最近 10 条历史 - if len(state["routing_history"]) > 10: - state["routing_history"] = state["routing_history"][-10:] - - # 保存状态 - self.storage.set( - f"conv_state:{conversation_id}", - state, - ttl=self.ttl - ) - - return state - - def clear_state(self, conversation_id: str) -> None: - """清除会话状态 - - Args: - conversation_id: 会话 ID - """ - self.storage.delete(f"conv_state:{conversation_id}") - logger.info(f"清除会话状态: {conversation_id}") - - def get_routing_history( - self, - conversation_id: str, - limit: int = 10 - ) -> List[Dict[str, Any]]: - """获取路由历史 - - Args: - conversation_id: 会话 ID - limit: 返回数量限制 - - Returns: - 路由历史列表 - """ - state = self.get_state(conversation_id) - history = state.get("routing_history", []) - return history[-limit:] if history else [] - - def get_statistics(self, conversation_id: str) -> Dict[str, Any]: - """获取会话统计信息 - - Args: - conversation_id: 会话 ID - - Returns: - 统计信息 - """ - state = self.get_state(conversation_id) - history = state.get("routing_history", []) - - # 统计各 Agent 使用次数 - agent_usage = {} - for item in history: - agent_id = item["agent_id"] - agent_usage[agent_id] = agent_usage.get(agent_id, 0) + 1 - - # 统计主题分布 - topic_distribution = {} - for item in history: - topic = item.get("topic", "未知") - topic_distribution[topic] = topic_distribution.get(topic, 0) + 1 - - return { - "conversation_id": conversation_id, - "total_turns": len(history), - "switch_count": state.get("switch_count", 0), - "current_agent_id": state.get("current_agent_id"), - "same_agent_turns": state.get("same_agent_turns", 0), - "agent_usage": agent_usage, - "topic_distribution": topic_distribution, - "created_at": state.get("created_at"), - "updated_at": state.get("updated_at") - } - - def _create_new_state(self, conversation_id: str) -> Dict[str, Any]: - """创建新的会话状态 - - Args: - conversation_id: 会话 ID - - Returns: - 新的状态字典 - """ - state = { - "conversation_id": conversation_id, - "current_agent_id": None, - "previous_agent_id": None, - "routing_history": [], - "last_message": None, - "last_topic": None, - "switch_count": 0, - "same_agent_turns": 0, - "created_at": datetime.now().isoformat(), - "updated_at": datetime.now().isoformat() - } - - # 保存初始状态 - self.storage.set( - f"conv_state:{conversation_id}", - state, - ttl=self.ttl - ) - - return state - - -class InMemoryStorage: - """内存存储后端(用于开发和测试)""" - - def __init__(self): - self._storage: Dict[str, Dict[str, Any]] = {} - - def get(self, key: str) -> Optional[Dict[str, Any]]: - """获取数据""" - return self._storage.get(key) - - def set(self, key: str, value: Dict[str, Any], ttl: int = 3600) -> None: - """设置数据""" - self._storage[key] = value - - def delete(self, key: str) -> None: - """删除数据""" - if key in self._storage: - del self._storage[key] - - def clear(self) -> None: - """清空所有数据""" - self._storage.clear() - - -class RedisStorage: - """Redis 存储后端(用于生产环境)""" - - def __init__(self, redis_client): - """初始化 Redis 存储 - - Args: - redis_client: Redis 客户端实例 - """ - self.redis = redis_client - - def get(self, key: str) -> Optional[Dict[str, Any]]: - """获取数据""" - data = self.redis.get(key) - if data: - return json.loads(data) - return None - - def set(self, key: str, value: Dict[str, Any], ttl: int = 3600) -> None: - """设置数据""" - self.redis.setex(key, ttl, json.dumps(value)) - - def delete(self, key: str) -> None: - """删除数据""" - self.redis.delete(key) diff --git a/app/services/document_service.py b/app/services/document_service.py deleted file mode 100644 index 0ecd8945..00000000 --- a/app/services/document_service.py +++ /dev/null @@ -1,85 +0,0 @@ -import uuid -from sqlalchemy.orm import Session -from app.models.user_model import User -from app.models.document_model import Document -from app.schemas.document_schema import DocumentCreate, DocumentUpdate -from app.repositories import document_repository -from app.core.logging_config import get_business_logger - -# Obtain a dedicated logger for business logic -business_logger = get_business_logger() - - -def get_documents_paginated( - db: Session, - current_user: User, - filters: list, - page: int, - pagesize: int, - orderby: str = None, - desc: bool = False -) -> tuple[int, list]: - business_logger.debug(f"Query document in pages: username={current_user.username}, page={page}, pagesize={pagesize}, orderby={orderby}, desc={desc}") - - try: - total, items = document_repository.get_documents_paginated( - db=db, - filters=filters, - page=page, - pagesize=pagesize, - orderby=orderby, - desc=desc - ) - business_logger.info(f"The document paging query has been successful: username={current_user.username}, total={total}, Number of current page={len(items)}") - return total, items - except Exception as e: - business_logger.error(f"Querying document pagination failed: username={current_user.username} - {str(e)}") - raise - - -def create_document( - db: Session, document: DocumentCreate, current_user: User -) -> Document: - business_logger.info(f"Create a document: {document.file_name}, creator: {current_user.username}") - - try: - document.created_by = current_user.id - db_document = document_repository.create_document( - db=db, document=document - ) - business_logger.info(f"The document has been successfully created: {document.file_name} (ID: {db_document.id}), creator: {current_user.username}") - return db_document - except Exception as e: - business_logger.error(f"Failed to create a document: {document.file_name} - {str(e)}") - raise - - -def get_document_by_id(db: Session, document_id: uuid.UUID, current_user: User) -> Document | None: - business_logger.debug(f"Query document based on ID: document_id={document_id}, username: {current_user.username}") - - try: - document = document_repository.get_document_by_id(db=db, document_id=document_id) - if document: - business_logger.info(f"document query successful: {document.file_name} (ID: {document_id})") - else: - business_logger.warning(f"document does not exist: document_id={document_id}") - return document - except Exception as e: - business_logger.error(f"Failed to query the document based on the ID: document_id={document_id} - {str(e)}") - raise - - -def reset_documents_progress_by_kb_id(db: Session, kb_id: uuid.UUID, current_user: User) -> int: - business_logger.debug(f"Reset the processing progress of all documents under the specified knowledge base: kb_id=={kb_id}, username: {current_user.username}") - return document_repository.reset_documents_progress_by_kb_id(db=db, kb_id=kb_id) - - -def delete_document_by_id(db: Session, document_id: uuid.UUID, current_user: User) -> None: - business_logger.info(f"Delete document: document_id={document_id}, operator: {current_user.username}") - - try: - document_repository.delete_document_by_id(db=db, document_id=document_id) - business_logger.info(f"document deleted successfully: document_id={document_id}, operator: {current_user.username}") - except Exception as e: - business_logger.error(f"Failed to delete document: document_id={document_id} - {str(e)}") - raise diff --git a/app/services/draft_run_service.py b/app/services/draft_run_service.py deleted file mode 100644 index d65a3612..00000000 --- a/app/services/draft_run_service.py +++ /dev/null @@ -1,1630 +0,0 @@ -""" -试运行服务 - -提供 Agent 试运行功能,允许用户在不发布应用的情况下测试配置。 -""" -import time -import uuid -import json -import asyncio -import datetime -from typing import Dict, Any, Optional, List, AsyncGenerator -from langchain.tools import tool -from pydantic import BaseModel, Field -from sqlalchemy.orm import Session -from sqlalchemy import select - -from app.services.memory_konwledges_server import write_rag -from app.tasks import write_message_task -from app.models import AgentConfig, ModelConfig, ModelApiKey -from app.core.exceptions import BusinessException, ResourceNotFoundException -from app.core.error_codes import BizCode -from app.core.logging_config import get_business_logger -from app.schemas.prompt_schema import render_prompt_message, PromptMessageRole -from app.services.memory_agent_service import MemoryAgentService -from app.services.model_parameter_merger import ModelParameterMerger -from app.core.rag.nlp.search import knowledge_retrieval -from app.services.langchain_tool_server import Search -from app.services.task_service import get_task_memory_write_result - -logger = get_business_logger() -class KnowledgeRetrievalInput(BaseModel): - """知识库检索工具输入参数""" - query: str = Field(description="需要检索的问题或关键词") - - -class WebSearchInput(BaseModel): - """网络搜索工具输入参数""" - query: str = Field(description="需要搜索的问题或关键词") - - -class LongTermMemoryInput(BaseModel): - """长期记忆工具输入参数""" - question: str = Field(description="需要查询的问题") - -def create_long_term_memory_tool(memory_config: Dict[str, Any], end_user_id: str, storage_type: Optional[str] = None,user_rag_memory_id: Optional[str] = None): - """创建长期记忆工具 - - Args: - memory_config: 记忆配置 - end_user_id: 用户ID - storage_type: 存储类型(可选) - - Returns: - 长期记忆工具 - """ - # search_switch = memory_config.get("search_switch", "2") - config_id= memory_config.get("memory_content",'17') - - logger.info(f"创建长期记忆工具,配置: end_user_id={end_user_id}, config_id={config_id}, storage_type={storage_type}") - - @tool(args_schema=LongTermMemoryInput) - def long_term_memory(question: str) -> str: - """从长期记忆中检索历史对话信息。当需要回忆之前的对话内容、用户偏好或历史信息时使用此工具。 - - Args: - question: 需要查询的问题 - end_user_id: 用户唯一标识符 - search_switch: 搜索开关(on/off) - - Returns: - 检索到的历史记忆内容 - """ - logger.info(f" 长期记忆工具被调用!question={question}, user={end_user_id}") - - try: - memory_content = asyncio.run( - MemoryAgentService().read_memory( - group_id=end_user_id, - message=question, - history=[], - search_switch="2", - config_id=config_id, - storage_type=storage_type, - user_rag_memory_id=user_rag_memory_id - ) - ) - logger.info(f'用户ID:Agent:{end_user_id}') - logger.debug(f"调用长期记忆 API", extra={"question": question, "end_user_id": end_user_id}) - - logger.info( - f"长期记忆检索成功", - extra={ - "end_user_id": end_user_id, - "content_length": len(str(memory_content)) - } - ) - - return f"检索到以下历史记忆:\n\n{memory_content}" - except Exception as e: - logger.error(f"长期记忆检索失败", extra={"error": str(e), "error_type": type(e).__name__}) - return f"记忆检索失败: {str(e)}" - - return long_term_memory - - -def create_web_search_tool(web_search_config: Dict[str, Any]): - """创建网络搜索工具 - - Args: - web_search_config: 网络搜索配置 - - Returns: - 网络搜索工具 - """ - logger.info("创建网络搜索工具") - - @tool(args_schema=WebSearchInput) - def web_search_tool(query: str) -> str: - """从互联网搜索最新信息。当用户的问题需要实时信息、最新新闻或网络资料时,使用此工具进行搜索。 - - Args: - query: 需要搜索的问题或关键词 - - Returns: - 搜索到的相关网络信息 - """ - try: - logger.info(f"执行网络搜索: {query}") - - # 调用搜索服务 - search_result = Search(query) - logger.info( - "网络搜索成功", - extra={ - "query": query, - "result_length": len(search_result) - } - ) - - return f"搜索到以下网络信息:\n\n{search_result}" - - except Exception as e: - logger.error(f"网络搜索失败", extra={"error": str(e), "error_type": type(e).__name__}) - return f"搜索失败: {str(e)}" - - return web_search_tool - - -def create_knowledge_retrieval_tool(kb_config,kb_ids,user_id): - """从知识库中检索相关信息。当用户的问题需要参考知识库、文档或历史记录时,使用此工具进行检索。 - - Args: - query: 需要检索的问题或关键词 - - Returns: - 检索到的相关知识内容 - """ - logger.info(f"创建知识库检索工具,用户:{user_id}") - @tool(args_schema=KnowledgeRetrievalInput) - def knowledge_retrieval_tool(query: str) -> str: - """从知识库中检索相关信息。当用户的问题需要参考知识库、文档或历史记录时,使用此工具进行检索。 - - Args: - query: 需要检索的问题或关键词 - - Returns: - 检索到的相关知识内容 - """ - - - try: - - retrieve_chunks_result = knowledge_retrieval(query, kb_config) - if retrieve_chunks_result: - retrieval_knowledge = [i.page_content for i in retrieve_chunks_result] - context = '\n\n'.join(retrieval_knowledge) - logger.info( - f"知识库检索成功", - extra={ - "kb_ids": kb_ids, - "result_count": len(retrieval_knowledge), - "total_length": len(context) - } - ) - - return f"检索到以下相关信息:\n\n{context}" - else: - logger.warning("知识库检索未找到结果") - return "未找到相关信息" - except Exception as e: - logger.error(f"知识库检索失败", extra={"error": str(e), "error_type": type(e).__name__}) - return f"检索失败: {str(e)}" - - return knowledge_retrieval_tool - -class DraftRunService: - """试运行服务类""" - - def __init__(self, db: Session): - """初始化试运行服务 - - Args: - db: 数据库会话 - """ - self.db = db - - async def run( - self, - *, - agent_config: AgentConfig, - model_config: ModelConfig, - message: str, - workspace_id: uuid.UUID, - conversation_id: Optional[str] = None, - user_id: Optional[str] = None, - variables: Optional[Dict[str, Any]] = None, - storage_type: Optional[str] = None, - user_rag_memory_id: Optional[str] = None, - web_search: bool = True, - memory: bool = True - ) -> Dict[str, Any]: - """执行试运行(使用 LangChain Agent) - - Args: - agent_config: Agent 配置 - model_config: 模型配置 - message: 用户消息 - workspace_id: 工作空间ID(必须,用于会话隔离) - conversation_id: 会话ID(用于多轮对话) - user_id: 用户ID - variables: 自定义变量参数值 - - Returns: - Dict: 包含 AI 回复和元数据的字典 - """ - - print('===========',storage_type) - - print(user_id) - if variables == None: variables = {} - from app.core.agent.langchain_agent import LangChainAgent - - start_time = time.time() - - try: - # 1. 获取 API Key 配置 - api_key_config = await self._get_api_key(model_config.id) - logger.debug( - f"API Key 配置获取成功", - extra={ - "model_name": api_key_config["model_name"], - "has_api_key": bool(api_key_config["api_key"]), - "has_api_base": bool(api_key_config.get("api_base")) - } - ) - - # 2. 合并模型参数 - effective_params = ModelParameterMerger.get_effective_parameters( - model_config=model_config, - agent_config=agent_config - ) - - - items_params=variables - system_prompt = render_prompt_message( - agent_config.system_prompt, # 修正拼写错误 - PromptMessageRole.USER, - items_params - ) - - # 3. 处理系统提示词(支持变量替换) - system_prompt = system_prompt.get_text_content() or "你是一个专业的AI助手" - print('系统提示词:',system_prompt) - - # 4. 准备工具列表 - tools = [] - - # 添加网络搜索工具 - if web_search: - if agent_config.tools: - web_search_config = agent_config.tools.get("web_search", {}) - web_search_enable = web_search_config.get("enabled", False) - - if web_search_enable: - logger.info("网络搜索已启用") - # 创建网络搜索工具 - search_tool = create_web_search_tool(web_search_config) - tools.append(search_tool) - - logger.debug( - "已添加网络搜索工具", - extra={ - "tool_count": len(tools) - } - ) - - # 添加知识库检索工具 - if agent_config.knowledge_retrieval: - kb_config = agent_config.knowledge_retrieval - knowledge_bases = kb_config.get("knowledge_bases", []) - kb_ids = bool(knowledge_bases and knowledge_bases[0].get("kb_id")) - if kb_ids: - # 创建知识库检索工具 - kb_tool = create_knowledge_retrieval_tool(kb_config,kb_ids,user_id) - tools.append(kb_tool) - - logger.debug( - f"已添加知识库检索工具", - extra={ - "kb_ids": kb_ids, - "tool_count": len(tools) - } - ) - - # 添加长期记忆工具 - if memory: - if agent_config.memory and agent_config.memory.get("enabled"): - - memory_config = agent_config.memory - if user_id: - # 创建长期记忆工具 - memory_tool = create_long_term_memory_tool(memory_config, user_id,storage_type,user_rag_memory_id) - tools.append(memory_tool) - - logger.debug( - f"已添加长期记忆工具", - extra={ - "user_id": user_id, - "tool_count": len(tools) - } - ) - - # 4. 创建 LangChain Agent - agent = LangChainAgent( - model_name=api_key_config["model_name"], - api_key=api_key_config["api_key"], - provider=api_key_config.get("provider", "openai"), - api_base=api_key_config.get("api_base"), - temperature=effective_params.get("temperature", 0.7), - max_tokens=effective_params.get("max_tokens", 2000), - system_prompt=system_prompt, - tools=tools, - ) - - # 5. 处理会话ID(创建或验证) - conversation_id = await self._ensure_conversation( - conversation_id=conversation_id, - app_id=agent_config.app_id, - workspace_id=workspace_id, - user_id=user_id - ) - - # 6. 加载历史消息 - history = [] - if agent_config.memory and agent_config.memory.get("enabled"): - history = await self._load_conversation_history( - conversation_id=conversation_id, - max_history=agent_config.memory.get("max_history", 10) - ) - - # 6. 知识库检索 - context = None - - logger.debug( - f"准备调用 LangChain Agent", - extra={ - "model": api_key_config["model_name"], - "has_history": bool(history), - "has_context": bool(context) - } - ) - - memory_config_= agent_config.memory - config_id = memory_config_.get("memory_content") - - # 7. 调用 Agent - result = await agent.chat( - message=message, - history=history, - context=context, - end_user_id=user_id, - config_id=config_id, - storage_type=storage_type, - user_rag_memory_id=user_rag_memory_id - ) - - elapsed_time = time.time() - start_time - - # 8. 保存会话消息 - if agent_config.memory and agent_config.memory.get("enabled"): - await self._save_conversation_message( - conversation_id=conversation_id, - user_message=message, - assistant_message=result["content"], - app_id=agent_config.app_id, - user_id=user_id - ) - - response = { - "message": result["content"], - "conversation_id": conversation_id, - "usage": result.get("usage", { - "prompt_tokens": 0, - "completion_tokens": 0, - "total_tokens": 0 - }), - "elapsed_time": elapsed_time - } - - logger.info( - f"试运行完成", - extra={ - "model": model_config.name, - "elapsed_time": elapsed_time, - "message_length": len(result["content"]), - "total_tokens": result.get("usage", {}).get("total_tokens", 0) - } - ) - - return response - - except Exception as e: - logger.error(f"LangChain Agent 调用失败", extra={"error": str(e), "error_type": type(e).__name__}) - raise BusinessException(f"Agent 调用失败: {str(e)}", BizCode.INTERNAL_ERROR, cause=e) - - async def run_stream( - self, - *, - agent_config: AgentConfig, - model_config: ModelConfig, - message: str, - workspace_id: uuid.UUID, - conversation_id: Optional[str] = None, - user_id: Optional[str] = None, - variables: Optional[Dict[str, Any]] = None, - storage_type: Optional[str] = None, - user_rag_memory_id: Optional[str] = None, - web_search: bool = True, # 布尔类型默认值 - memory: bool = True # 布尔类型默认值 - - ) -> AsyncGenerator[str, None]: - """执行试运行(流式返回,使用 LangChain Agent) - - Args: - agent_config: Agent 配置 - model_config: 模型配置 - message: 用户消息 - workspace_id: 工作空间ID(必须,用于会话隔离) - conversation_id: 会话ID(用于多轮对话) - user_id: 用户ID - variables: 自定义变量参数值 - - Yields: - str: SSE 格式的事件数据 - """ - if variables==None:variables={} - - from app.core.agent.langchain_agent import LangChainAgent - - start_time = time.time() - - try: - # 1. 获取 API Key 配置 - api_key_config = await self._get_api_key(model_config.id) - - # 2. 合并模型参数 - effective_params = ModelParameterMerger.get_effective_parameters( - model_config=model_config, - agent_config=agent_config - ) - - items_params=variables - - system_prompt = render_prompt_message( - agent_config.system_prompt, # 修正拼写错误 - PromptMessageRole.USER, - items_params - ) - - # 3. 处理系统提示词(支持变量替换) - system_prompt = system_prompt.get_text_content() or "你是一个专业的AI助手" - - # 4. 准备工具列表 - tools = [] - - # 添加网络搜索工具 - if web_search: - if agent_config.tools: - web_search = agent_config.tools.get("web_search", {}) - web_search_enable = web_search.get("enable", False) - - if web_search_enable: - logger.info("网络搜索已启用(流式)") - # 创建网络搜索工具 - search_tool = create_web_search_tool(web_search) - tools.append(search_tool) - - logger.debug( - "已添加网络搜索工具(流式)", - extra={ - "tool_count": len(tools) - } - ) - - # 添加知识库检索工具 - if agent_config.knowledge_retrieval: - kb_config = agent_config.knowledge_retrieval - knowledge_bases = kb_config.get("knowledge_bases", []) - kb_ids = bool(knowledge_bases and knowledge_bases[0].get("kb_id")) - if kb_ids: - # 创建知识库检索工具 - kb_tool = create_knowledge_retrieval_tool(kb_config,kb_ids,user_id) - tools.append(kb_tool) - - logger.debug( - f"已添加知识库检索工具", - extra={ - "kb_ids": kb_ids, - "tool_count": len(tools) - } - ) - - # 添加长期记忆工具 - if memory: - if agent_config.memory and agent_config.memory.get("enabled"): - memory_config = agent_config.memory - if user_id: - # 创建长期记忆工具 - memory_tool = create_long_term_memory_tool(memory_config, user_id,storage_type,user_rag_memory_id) - tools.append(memory_tool) - - logger.debug( - f"已添加长期记忆工具", - extra={ - "user_id": user_id, - "tool_count": len(tools) - } - ) - - # 4. 创建 LangChain Agent - agent = LangChainAgent( - model_name=api_key_config["model_name"], - api_key=api_key_config["api_key"], - provider=api_key_config.get("provider", "openai"), - api_base=api_key_config.get("api_base"), - temperature=effective_params.get("temperature", 0.7), - max_tokens=effective_params.get("max_tokens", 2000), - system_prompt=system_prompt, - tools=tools, - streaming=True - ) - - # 5. 处理会话ID(创建或验证) - conversation_id = await self._ensure_conversation( - conversation_id=conversation_id, - app_id=agent_config.app_id, - workspace_id=workspace_id, - user_id=user_id - ) - - # 6. 加载历史消息 - history = [] - if agent_config.memory and agent_config.memory.get("enabled"): - history = await self._load_conversation_history( - conversation_id=conversation_id, - max_history=agent_config.memory.get("max_history", 10) - ) - - # 7. 知识库检索 - context = None - - # 8. 发送开始事件 - yield self._format_sse_event("start", { - "conversation_id": conversation_id, - "timestamp": time.time() - }) - - memory_config_ = agent_config.memory - config_id = memory_config_.get("memory_content") - - # 9. 流式调用 Agent - full_content = "" - async for chunk in agent.chat_stream( - message=message, - history=history, - context=context, - end_user_id=user_id, - config_id=config_id, - storage_type=storage_type, - user_rag_memory_id=user_rag_memory_id - ): - full_content += chunk - # 发送消息块事件 - yield self._format_sse_event("message", { - "content": chunk - }) - - if storage_type == "rag": - await write_rag(user_id, full_content, user_rag_memory_id) - else: - write_id = write_message_task.delay(user_id, full_content, config_id, storage_type, user_rag_memory_id) - write_status = get_task_memory_write_result(str(write_id)) - logger.info(f'Agent:{user_id};{full_content}--{write_status}') - - elapsed_time = time.time() - start_time - - # 10. 保存会话消息 - if agent_config.memory and agent_config.memory.get("enabled"): - await self._save_conversation_message( - conversation_id=conversation_id, - user_message=message, - assistant_message=full_content, - app_id=agent_config.app_id, - user_id=user_id - ) - - # 11. 发送结束事件 - yield self._format_sse_event("end", { - "conversation_id": conversation_id, - "elapsed_time": elapsed_time, - "message_length": len(full_content) - }) - - logger.info( - f"流式试运行完成", - extra={ - "model": model_config.name, - "elapsed_time": elapsed_time, - "message_length": len(full_content) - } - ) - - except Exception as e: - logger.error(f"流式 Agent 调用失败", extra={"error": str(e)}) - # 发送错误事件 - yield self._format_sse_event("error", { - "error": str(e), - "timestamp": time.time() - }) - - def _format_sse_event(self, event: str, data: Dict[str, Any]) -> str: - """格式化 SSE 事件 - - Args: - event: 事件类型 - data: 事件数据 - - Returns: - str: SSE 格式的字符串 - """ - return f"event: {event}\ndata: {json.dumps(data, ensure_ascii=False)}\n\n" - - async def _get_api_key(self, model_config_id: uuid.UUID) -> Dict[str, str]: - """获取模型的 API Key - - Args: - model_config_id: 模型配置ID - - Returns: - Dict: 包含 model_name, api_key, api_base 的字典 - - Raises: - BusinessException: 当没有可用的 API Key 时 - """ - stmt = ( - select(ModelApiKey) - .where( - ModelApiKey.model_config_id == model_config_id, - ModelApiKey.is_active == True - ) - .order_by(ModelApiKey.priority.desc()) - .limit(1) - ) - - api_key = self.db.scalars(stmt).first() - - if not api_key: - raise BusinessException("没有可用的 API Key", BizCode.AGENT_CONFIG_MISSING) - - return { - "model_name": api_key.model_name, - "provider": api_key.provider, - "api_key": api_key.api_key, - "api_base": api_key.api_base - } - - async def _ensure_conversation( - self, - conversation_id: Optional[str], - app_id: uuid.UUID, - workspace_id: uuid.UUID, - user_id: Optional[str] - ) -> str: - """确保会话存在(创建或验证) - - Args: - conversation_id: 会话ID(可选) - app_id: 应用ID - workspace_id: 工作空间ID(必须) - user_id: 用户ID - - Returns: - str: 会话ID - - Raises: - BusinessException: 当指定的会话不存在时 - """ - from app.services.conversation_service import ConversationService - from app.schemas.conversation_schema import ConversationCreate - from app.models import Conversation as ConversationModel - - conversation_service = ConversationService(self.db) - - # 如果没有提供会话ID,创建新会话 - if not conversation_id: - logger.info( - "创建新的草稿会话", - extra={"workspace_id": str(workspace_id)} - ) - - # 获取配置快照 - config_snapshot = await self._get_config_snapshot(app_id) - - # 创建新会话 - new_conv_id = str(uuid.uuid4()) - new_conversation = ConversationModel( - id=uuid.UUID(new_conv_id), - app_id=app_id, - workspace_id=workspace_id, - user_id=user_id, - is_draft=True, - title="草稿会话", - config_snapshot=config_snapshot - ) - self.db.add(new_conversation) - self.db.commit() - self.db.refresh(new_conversation) - - logger.info( - f"创建草稿会话成功", - extra={ - "conversation_id": new_conv_id, - "workspace_id": str(workspace_id) - } - ) - - return new_conv_id - - # 如果提供了会话ID,验证其存在性和工作空间归属 - try: - conv_uuid = uuid.UUID(conversation_id) - conversation = conversation_service.get_conversation(conv_uuid) - - # 验证会话属于当前工作空间 - if conversation.workspace_id != workspace_id: - logger.warning( - f"会话不属于当前工作空间", - extra={ - "conversation_id": conversation_id, - "conversation_workspace_id": str(conversation.workspace_id), - "current_workspace_id": str(workspace_id) - } - ) - raise BusinessException( - f"会话不属于当前工作空间", - BizCode.PERMISSION_DENIED - ) - - logger.debug( - f"使用现有会话", - extra={ - "conversation_id": conversation_id, - "workspace_id": str(workspace_id) - } - ) - return conversation_id - except BusinessException: - raise - except Exception as e: - logger.error( - f"会话不存在或无效", - extra={"conversation_id": conversation_id, "error": str(e)} - ) - raise BusinessException( - f"会话不存在: {conversation_id}", - BizCode.NOT_FOUND, - cause=e - ) - - async def _load_conversation_history( - self, - conversation_id: str, - max_history: int = 10 - ) -> List[Dict[str, str]]: - """加载会话历史消息 - - Args: - conversation_id: 会话ID - max_history: 最大历史消息数量 - - Returns: - List[Dict]: 历史消息列表 - """ - try: - from app.services.conversation_service import ConversationService - - conversation_service = ConversationService(self.db) - history = conversation_service.get_conversation_history( - conversation_id=uuid.UUID(conversation_id), - max_history=max_history - ) - - logger.debug( - f"加载会话历史", - extra={ - "conversation_id": conversation_id, - "max_history": max_history, - "loaded_count": len(history) - } - ) - - return history - - except Exception as e: - # 新会话没有历史记录是正常的 - logger.debug(f"加载会话历史失败(可能是新会话)", extra={"error": str(e)}) - return [] - - async def _save_conversation_message( - self, - conversation_id: str, - user_message: str, - assistant_message: str, - app_id: Optional[uuid.UUID] = None, - user_id: Optional[str] = None - ) -> None: - """保存会话消息(会话已通过 _ensure_conversation 确保存在) - - Args: - conversation_id: 会话ID - user_message: 用户消息 - assistant_message: AI 回复消息 - app_id: 应用ID(未使用,保留用于兼容性) - user_id: 用户ID(未使用,保留用于兼容性) - """ - try: - from app.services.conversation_service import ConversationService - - conversation_service = ConversationService(self.db) - conv_uuid = uuid.UUID(conversation_id) - - # 保存消息(会话已经存在) - # 保存用户消息 - conversation_service.add_message( - conversation_id=conv_uuid, - role="user", - content=user_message - ) - # 保存助手消息 - conversation_service.add_message( - conversation_id=conv_uuid, - role="assistant", - content=assistant_message - ) - - logger.debug( - f"保存会话消息", - extra={ - "conversation_id": conversation_id, - "user_message_length": len(user_message), - "assistant_message_length": len(assistant_message) - } - ) - - except Exception as e: - logger.warning(f"保存会话消息失败", extra={"error": str(e)}) - - async def _get_config_snapshot(self, app_id: uuid.UUID) -> Dict[str, Any]: - """获取当前配置快照 - - Args: - app_id: 应用ID - - Returns: - Dict: 配置快照 - """ - try: - from app.models import AgentConfig, ModelConfig - - # 获取 Agent 配置 - stmt = select(AgentConfig).where(AgentConfig.app_id == app_id) - agent_cfg = self.db.scalars(stmt).first() - - if not agent_cfg: - return {} - - # 获取模型配置 - model_config = None - if agent_cfg.default_model_config_id: - model_config = self.db.get(ModelConfig, agent_cfg.default_model_config_id) - - # 构建快照(确保所有值都可序列化) - def safe_serialize(value): - """安全序列化值""" - if value is None: - return None - if isinstance(value, (str, int, float, bool)): - return value - if isinstance(value, (dict, list)): - return value - # 对于 Pydantic 模型或其他对象,尝试转换为字典 - if hasattr(value, 'dict'): - return value.dict() - if hasattr(value, '__dict__'): - return value.__dict__ - return str(value) - - snapshot = { - "agent_config": { - "system_prompt": agent_cfg.system_prompt, - "model_parameters": safe_serialize(agent_cfg.model_parameters), - "knowledge_retrieval": safe_serialize(agent_cfg.knowledge_retrieval), - "memory": safe_serialize(agent_cfg.memory), - "variables": safe_serialize(agent_cfg.variables), - "tools": safe_serialize(agent_cfg.tools) - }, - "model_config": { - "model_name": model_config.name if model_config else None, - "provider": model_config.provider if model_config else None, - "type": model_config.type if model_config else None - } if model_config else None, - "snapshot_time": datetime.datetime.now().isoformat() - } - - return snapshot - - except Exception as e: - # 对于多 Agent 应用,没有直接的 AgentConfig 是正常的 - logger.debug(f"获取配置快照失败(可能是多 Agent 应用)", extra={"error": str(e)}) - return {} - - def _replace_variables( - self, - text: str, - values: Dict[str, Any], - definitions: List[Dict[str, Any]] - ) -> str: - """替换文本中的变量 - - Args: - text: 原始文本 - values: 变量值 - definitions: 变量定义 - - Returns: - str: 替换后的文本 - """ - result = text - - # 创建变量定义映射 - var_defs = {var["name"]: var for var in definitions} - - for var_name, var_value in values.items(): - # 检查变量是否在定义中 - if var_name not in var_defs: - logger.warning(f"未定义的变量: {var_name}") - continue - - # 替换变量(支持多种格式) - placeholders = [ - f"{{{{{var_name}}}}}", # {{var_name}} - f"{{{var_name}}}", # {var_name} - f"${{{var_name}}}", # ${var_name} - ] - - for placeholder in placeholders: - if placeholder in result: - result = result.replace(placeholder, str(var_value)) - - return result - - # ==================== 多模型对比试运行 ==================== - - async def run_compare( - self, - *, - agent_config: AgentConfig, - models: List[Dict[str, Any]], - message: str, - workspace_id: uuid.UUID, - conversation_id: Optional[str] = None, - user_id: Optional[str] = None, - variables: Optional[Dict[str, Any]] = None, - parallel: bool = True, - timeout: int = 60, - storage_type: Optional[str] = None, - user_rag_memory_id: Optional[str] = None, - web_search: bool = True, - memory: bool = True, - ) -> Dict[str, Any]: - """多模型对比试运行 - - Args: - agent_config: Agent 配置 - models: 模型配置列表,每项包含 model_config, parameters, label, model_config_id - message: 用户消息 - workspace_id: 工作空间ID - conversation_id: 会话ID - user_id: 用户ID - variables: 变量参数 - parallel: 是否并行执行 - timeout: 超时时间(秒) - - Returns: - Dict: 对比结果 - """ - logger.info( - f"多模型对比试运行", - extra={ - "model_count": len(models), - "parallel": parallel - } - ) - - async def run_single_model(model_info): - """运行单个模型""" - try: - start_time = time.time() - - # 临时修改参数(不使用 deepcopy 避免 SQLAlchemy 会话问题) - original_params = agent_config.model_parameters - agent_config.model_parameters = model_info["parameters"] - - # 使用模型自己的 conversation_id,如果没有则使用全局的 - model_conversation_id = model_info.get("conversation_id") or conversation_id - try: - result = await asyncio.wait_for( - self.run( - agent_config=agent_config, - model_config=model_info["model_config"], - message=message, - workspace_id=workspace_id, - conversation_id=model_conversation_id, - user_id=user_id, - variables=variables, - storage_type=storage_type, - user_rag_memory_id=user_rag_memory_id, - web_search=web_search, - memory=memory - ), - timeout=timeout - ) - finally: - # 恢复原始参数 - agent_config.model_parameters = original_params - - elapsed = time.time() - start_time - usage = result.get("usage", {}) - - return { - "model_config_id": model_info["model_config_id"], - "model_name": model_info["model_config"].name, - "label": model_info["label"], - "conversation_id":result['conversation_id'], - "parameters_used": model_info["parameters"], - "message": result.get("message"), - "usage": usage, - "elapsed_time": elapsed, - "tokens_per_second": ( - usage.get("completion_tokens", 0) / elapsed - if elapsed > 0 and usage.get("completion_tokens") else None - ), - "cost_estimate": self._estimate_cost(usage, model_info["model_config"]), - "error": None - } - - except asyncio.TimeoutError: - logger.warning( - f"模型运行超时", - extra={ - "model_config_id": str(model_info["model_config_id"]), - "timeout": timeout - } - ) - return { - "model_config_id": model_info["model_config_id"], - "model_name": model_info["model_config"].name, - "conversation_id": conversation_id, - "label": model_info["label"], - "parameters_used": model_info["parameters"], - "elapsed_time": timeout, - "error": f"执行超时({timeout}秒)" - } - except Exception as e: - logger.error( - f"模型运行失败", - extra={ - "model_config_id": str(model_info["model_config_id"]), - "error": str(e) - } - ) - return { - "model_config_id": model_info["model_config_id"], - "model_name": model_info["model_config"].name, - "label": model_info["label"], - "conversation_id": conversation_id, - "parameters_used": model_info["parameters"], - "elapsed_time": 0, - "error": str(e) - } - - # 执行所有模型(并行或串行) - if parallel: - logger.debug(f"并行执行 {len(models)} 个模型") - results = await asyncio.gather( - *[run_single_model(m) for m in models], - return_exceptions=False - ) - else: - logger.debug(f"串行执行 {len(models)} 个模型") - results = [] - for model_info in models: - result = await run_single_model(model_info) - results.append(result) - - # 统计分析 - successful = [r for r in results if not r.get("error")] - failed = [r for r in results if r.get("error")] - - fastest = min(successful, key=lambda x: x["elapsed_time"]) if successful else None - cheapest = min( - successful, - key=lambda x: x.get("cost_estimate") or float("inf") - ) if successful else None - - logger.info( - f"多模型对比完成", - extra={ - "successful": len(successful), - "failed": len(failed), - "total_time": sum(r.get("elapsed_time", 0) for r in results) - } - ) - - return { - "results": results, - "total_elapsed_time": sum(r.get("elapsed_time", 0) for r in results), - "successful_count": len(successful), - "failed_count": len(failed), - "fastest_model": fastest["label"] if fastest else None, - "cheapest_model": cheapest["label"] if cheapest else None - } - - def _estimate_cost(self, usage: Dict[str, Any], model_config) -> Optional[float]: - """估算成本 - - Args: - usage: Token 使用情况 - model_config: 模型配置 - - Returns: - Optional[float]: 估算成本(美元) - """ - if not usage: - return None - - prompt_tokens = usage.get("prompt_tokens", 0) - completion_tokens = usage.get("completion_tokens", 0) - - # 简化成本估算:暂时返回 None - # TODO: 实现基于模型名称或配置的成本估算 - # 需要从 ModelApiKey 获取实际的模型名称,或者在 ModelConfig 中添加 model 字段 - return None - - def _with_parameters(self, agent_config: AgentConfig, parameters: Dict[str, Any]) -> AgentConfig: - """创建一个带有覆盖参数的 agent_config(浅拷贝,只修改 model_parameters) - - Args: - agent_config: 原始 Agent 配置 - parameters: 要覆盖的参数 - - Returns: - AgentConfig: 修改后的配置(注意:这是同一个对象,只是临时修改了 model_parameters) - """ - # 保存原始参数 - original_params = agent_config.model_parameters - # 设置新参数 - agent_config.model_parameters = parameters - return agent_config, original_params - - async def run_compare_stream( - self, - *, - agent_config: AgentConfig, - models: List[Dict[str, Any]], - message: str, - workspace_id: uuid.UUID, - conversation_id: Optional[str] = None, - user_id: Optional[str] = None, - variables: Optional[Dict[str, Any]] = None, - storage_type: Optional[str] = None, - user_rag_memory_id: Optional[str] = None, - web_search: bool = True, - memory: bool = True, - parallel: bool = True, - timeout: int = 60 - ) -> AsyncGenerator[str, None]: - """多模型对比试运行(流式返回) - - 支持并行或串行执行,通过 model_index 区分不同模型的事件 - - Args: - agent_config: Agent 配置 - models: 模型配置列表 - message: 用户消息 - workspace_id: 工作空间ID - conversation_id: 会话ID - user_id: 用户ID - variables: 变量参数 - parallel: 是否并行执行 - timeout: 超时时间(秒) - - Yields: - str: SSE 格式的事件数据 - """ - logger.info( - f"多模型对比流式试运行", - extra={"model_count": len(models), "parallel": parallel} - ) - - # 确保有 conversation_id - if not conversation_id: - conversation_id = str(uuid.uuid4()) - - # 发送开始事件 - yield self._format_sse_event("compare_start", { - "conversation_id": conversation_id, - "model_count": len(models), - "parallel": parallel, - "timestamp": time.time() - }) - - results = [] - - if parallel: - # 并行执行所有模型 - import asyncio - - # 创建事件队列用于收集所有模型的事件 - event_queue = asyncio.Queue() - - async def run_single_model_stream(idx: int, model_info: Dict[str, Any]): - """运行单个模型并将事件放入队列""" - model_label = model_info["label"] - model_config_id = str(model_info["model_config_id"]) - # 使用模型自己的 conversation_id,如果没有则使用全局的 - model_conversation_id = model_info.get("conversation_id") or conversation_id - - try: - # 发送模型开始事件 - await event_queue.put(self._format_sse_event("model_start", { - "model_index": idx, - "model_config_id": model_config_id, - "model_name": model_info["model_config"].name, - "label": model_label, - "conversation_id": model_conversation_id, - "timestamp": time.time() - })) - - start_time = time.time() - full_content = "" - - # 临时修改参数(并行任务中安全) - original_params = agent_config.model_parameters - agent_config.model_parameters = model_info["parameters"] - - try: - # 流式调用单个模型 - async for event_str in self.run_stream( - agent_config=agent_config, - model_config=model_info["model_config"], - message=message, - workspace_id=workspace_id, - conversation_id=model_conversation_id, - user_id=user_id, - variables=variables, - storage_type=storage_type, - user_rag_memory_id=user_rag_memory_id, - web_search=web_search, - memory=memory - ): - # 解析原始事件 - try: - lines = event_str.strip().split('\n') - event_type = None - event_data = None - - for line in lines: - if line.startswith('event: '): - event_type = line[7:].strip() - elif line.startswith('data: '): - event_data = json.loads(line[6:]) - - # 从 start 事件中获取 conversation_id - if event_type == "start" and event_data: - returned_conv_id = event_data.get("conversation_id") - if returned_conv_id: - model_conversation_id = returned_conv_id - - if event_type == "message" and event_data: - chunk = event_data.get("content", "") - full_content += chunk - - # 转发消息块事件(带模型标识和 conversation_id) - await event_queue.put(self._format_sse_event("model_message", { - "model_index": idx, - "model_config_id": model_config_id, - "label": model_label, - "conversation_id": model_conversation_id, - "content": chunk - })) - except Exception as e: - logger.warning(f"解析流式事件失败: {e}") - finally: - # 恢复原始参数 - agent_config.model_parameters = original_params - - elapsed = time.time() - start_time - - # 模型完成 - result = { - "model_config_id": model_info["model_config_id"], - "model_name": model_info["model_config"].name, - "label": model_label, - "parameters_used": model_info["parameters"], - "message": full_content, - "elapsed_time": elapsed, - "error": None - } - - # 发送模型完成事件 - await event_queue.put(self._format_sse_event("model_end", { - "model_index": idx, - "model_config_id": model_config_id, - "label": model_label, - "conversation_id": model_conversation_id, - "elapsed_time": elapsed, - "message_length": len(full_content), - "timestamp": time.time() - })) - - return result - - except asyncio.TimeoutError: - logger.warning(f"模型运行超时: {model_label}") - result = { - "model_config_id": model_info["model_config_id"], - "model_name": model_info["model_config"].name, - "label": model_label, - "elapsed_time": timeout, - "error": f"执行超时({timeout}秒)" - } - - await event_queue.put(self._format_sse_event("model_error", { - "model_index": idx, - "model_config_id": model_config_id, - "label": model_label, - "conversation_id": model_conversation_id, - "error": result["error"], - "timestamp": time.time() - })) - - return result - - except Exception as e: - logger.error(f"模型运行失败: {model_label}, error: {e}") - result = { - "model_config_id": model_info["model_config_id"], - "model_name": model_info["model_config"].name, - "label": model_label, - "elapsed_time": 0, - "error": str(e) - } - - await event_queue.put(self._format_sse_event("model_error", { - "model_index": idx, - "model_config_id": model_config_id, - "label": model_label, - "conversation_id": model_conversation_id, - "error": str(e), - "timestamp": time.time() - })) - - return result - - # 启动所有模型的并行任务 - tasks = [ - asyncio.create_task(run_single_model_stream(idx, model_info)) - for idx, model_info in enumerate(models) - ] - - # 持续从队列中取出事件并发送 - completed_count = 0 - while completed_count < len(models): - try: - # 等待事件或任务完成 - event = await asyncio.wait_for(event_queue.get(), timeout=0.1) - yield event - except asyncio.TimeoutError: - # 检查是否有任务完成 - for task in tasks: - if task.done() and task not in [t for t in tasks if hasattr(t, '_result_retrieved')]: - result = await task - results.append(result) - task._result_retrieved = True - completed_count += 1 - continue - - # 等待所有任务完成 - all_results = await asyncio.gather(*tasks, return_exceptions=False) - results = [r for r in all_results if r not in results] - results.extend([r for r in all_results if r not in results]) - - # 清空队列中剩余的事件 - while not event_queue.empty(): - try: - event = event_queue.get_nowait() - yield event - except asyncio.QueueEmpty: - break - - else: - # 串行执行每个模型 - for idx, model_info in enumerate(models): - model_label = model_info["label"] - model_config_id = str(model_info["model_config_id"]) - # 使用模型自己的 conversation_id,如果没有则使用全局的 - model_conversation_id = model_info.get("conversation_id") or conversation_id - - # 发送模型开始事件 - yield self._format_sse_event("model_start", { - "model_index": idx, - "model_config_id": model_config_id, - "model_name": model_info["model_config"].name, - "label": model_label, - "conversation_id": model_conversation_id, - "timestamp": time.time() - }) - - try: - start_time = time.time() - full_content = "" - - # 临时修改参数 - original_params = agent_config.model_parameters - agent_config.model_parameters = model_info["parameters"] - - try: - # 流式调用单个模型 - async for event_str in self.run_stream( - agent_config=agent_config, - model_config=model_info["model_config"], - message=message, - workspace_id=workspace_id, - conversation_id=model_conversation_id, - user_id=user_id, - variables=variables, - storage_type=storage_type, - user_rag_memory_id=user_rag_memory_id, - web_search=web_search, - memory=memory - ): - # 解析原始事件 - try: - # SSE 格式: "event: xxx\ndata: {...}\n\n" - lines = event_str.strip().split('\n') - event_type = None - event_data = None - - for line in lines: - if line.startswith('event: '): - event_type = line[7:].strip() - elif line.startswith('data: '): - event_data = json.loads(line[6:]) - - if event_type == "message" and event_data: - # 累积内容 - chunk = event_data.get("content", "") - full_content += chunk - - # 转发消息块事件(带模型标识) - yield self._format_sse_event("model_message", { - "model_index": idx, - "model_config_id": model_config_id, - "label": model_label, - "content": chunk - }) - - except Exception as e: - logger.warning(f"解析流式事件失败: {e}") - finally: - # 恢复原始参数 - agent_config.model_parameters = original_params - - elapsed = time.time() - start_time - - # 模型完成 - result = { - "model_config_id": model_info["model_config_id"], - "model_name": model_info["model_config"].name, - "label": model_label, - "parameters_used": model_info["parameters"], - "message": full_content, - "elapsed_time": elapsed, - "error": None - } - results.append(result) - - # 发送模型完成事件 - yield self._format_sse_event("model_end", { - "model_index": idx, - "model_config_id": model_config_id, - "label": model_label, - "elapsed_time": elapsed, - "message_length": len(full_content), - "timestamp": time.time() - }) - - except asyncio.TimeoutError: - logger.warning(f"模型运行超时: {model_label}") - result = { - "model_config_id": model_info["model_config_id"], - "model_name": model_info["model_config"].name, - "label": model_label, - "elapsed_time": timeout, - "error": f"执行超时({timeout}秒)" - } - results.append(result) - - # 发送模型错误事件 - yield self._format_sse_event("model_error", { - "model_index": idx, - "model_config_id": model_config_id, - "label": model_label, - "error": result["error"], - "timestamp": time.time() - }) - - except Exception as e: - logger.error(f"模型运行失败: {model_label}, error: {e}") - result = { - "model_config_id": model_info["model_config_id"], - "model_name": model_info["model_config"].name, - "label": model_label, - "elapsed_time": 0, - "error": str(e) - } - results.append(result) - - # 发送模型错误事件 - yield self._format_sse_event("model_error", { - "model_index": idx, - "model_config_id": model_config_id, - "label": model_label, - "error": str(e), - "timestamp": time.time() - }) - - # 统计分析 - successful = [r for r in results if not r.get("error")] - failed = [r for r in results if r.get("error")] - - fastest = min(successful, key=lambda x: x["elapsed_time"]) if successful else None - - # 发送对比完成事件 - yield self._format_sse_event("compare_end", { - "conversation_id": conversation_id, - "total_elapsed_time": sum(r.get("elapsed_time", 0) for r in results), - "successful_count": len(successful), - "failed_count": len(failed), - "fastest_model": fastest["label"] if fastest else None, - "timestamp": time.time() - }) - - logger.info( - f"多模型对比流式完成", - extra={ - "successful": len(successful), - "failed": len(failed) - } - ) - - -async def draft_run( - db: Session, - *, - agent_config: AgentConfig, - model_config: ModelConfig, - message: str, - user_id: Optional[str] = None, - kb_ids: Optional[List[str]] = None, - similarity_threshold: float = 0.7, - top_k: int = 3 -) -> Dict[str, Any]: - """试运行 Agent(便捷函数) - - Args: - db: 数据库会话 - agent_config: Agent 配置 - model_config: 模型配置 - message: 用户消息 - user_id: 用户ID - kb_ids: 知识库ID列表 - similarity_threshold: 相似度阈值 - top_k: 检索返回的文档数量 - - Returns: - Dict: 包含 AI 回复和元数据的字典 - """ - service = DraftRunService(db) - return await service.run( - agent_config=agent_config, - model_config=model_config, - message=message, - user_id=user_id, - kb_ids=kb_ids, - similarity_threshold=similarity_threshold, - top_k=top_k - ) - diff --git a/app/services/file_service.py b/app/services/file_service.py deleted file mode 100644 index a7f20e46..00000000 --- a/app/services/file_service.py +++ /dev/null @@ -1,87 +0,0 @@ -import uuid -from sqlalchemy.orm import Session -from app.models.user_model import User -from app.models.file_model import File -from app.schemas.file_schema import FileCreate, FileUpdate -from app.repositories import file_repository -from app.core.logging_config import get_business_logger - -# Obtain a dedicated logger for business logic -business_logger = get_business_logger() - - -def get_files_paginated( - db: Session, - current_user: User, - filters: list, - page: int, - pagesize: int, - orderby: str = None, - desc: bool = False -) -> tuple[int, list]: - business_logger.debug(f"Query file in pages: username={current_user.username}, page={page}, pagesize={pagesize}, orderby={orderby}, desc={desc}") - - try: - total, items = file_repository.get_files_paginated( - db=db, - filters=filters, - page=page, - pagesize=pagesize, - orderby=orderby, - desc=desc - ) - business_logger.info(f"The file paging query has been successful: username={current_user.username}, total={total}, Number of current page={len(items)}") - return total, items - except Exception as e: - business_logger.error(f"Querying file pagination failed: username={current_user.username} - {str(e)}") - raise - - -def create_file( - db: Session, file: FileCreate, current_user: User -) -> File: - business_logger.info(f"Create a file: {file.file_name}, creator: {current_user.username}") - - try: - file.created_by = current_user.id - if file.parent_id is None: - file.parent_id = file.kb_id - db_file = file_repository.create_file( - db=db, file=file - ) - business_logger.info(f"The file has been successfully created: {file.file_name} (ID: {db_file.id}), creator: {current_user.username}") - return db_file - except Exception as e: - business_logger.error(f"Failed to create a file: {file.file_name} - {str(e)}") - raise - - -def get_file_by_id(db: Session, file_id: uuid.UUID) -> File | None: - business_logger.debug(f"Query file based on ID: file_id={file_id}") - - try: - file = file_repository.get_file_by_id(db=db, file_id=file_id) - if file: - business_logger.info(f"file query successful: {file.file_name} (ID: {file_id})") - else: - business_logger.warning(f"file does not exist: file_id={file_id}") - return file - except Exception as e: - business_logger.error(f"Failed to query the file based on the ID: file_id={file_id} - {str(e)}") - raise - - -def get_files_by_parent_id(db: Session, parent_id: uuid.UUID | None, current_user: User) -> list | None: - business_logger.debug(f"Query file based on folder ID: parent_id={parent_id}, username: {current_user.username}") - return file_repository.get_files_by_parent_id(db=db, parent_id=parent_id) - - -def delete_file_by_id(db: Session, file_id: uuid.UUID, current_user: User) -> None: - business_logger.info(f"Delete file: file_id={file_id}, operator: {current_user.username}") - - try: - file_repository.delete_file_by_id(db=db, file_id=file_id) - business_logger.info(f"file_id deleted successfully: file_id={file_id}, operator: {current_user.username}") - except Exception as e: - business_logger.error(f"Failed to delete file: file_id={file_id} - {str(e)}") - raise diff --git a/app/services/knowledge_service.py b/app/services/knowledge_service.py deleted file mode 100644 index b9d97c29..00000000 --- a/app/services/knowledge_service.py +++ /dev/null @@ -1,126 +0,0 @@ -import uuid -from sqlalchemy.orm import Session -from app.models.user_model import User -from app.models.knowledge_model import Knowledge -from app.schemas.knowledge_schema import KnowledgeCreate, KnowledgeUpdate -from app.repositories import knowledge_repository -from app.core.logging_config import get_business_logger - -# Obtain a dedicated logger for business logic -business_logger = get_business_logger() - - -def get_knowledges_paginated( - db: Session, - current_user: User, - filters: list, - page: int, - pagesize: int, - orderby: str = None, - desc: bool = False -) -> tuple[int, list]: - business_logger.debug(f"Query knowledge base in pages: username={current_user.username}, page={page}, pagesize={pagesize}, orderby={orderby}, desc={desc}") - - try: - total, items = knowledge_repository.get_knowledges_paginated( - db=db, - filters=filters, - page=page, - pagesize=pagesize, - orderby=orderby, - desc=desc - ) - business_logger.info(f"The knowledge base paging query has been successful: username={current_user.username}, total={total}, Number of current page={len(items)}") - return total, items - except Exception as e: - business_logger.error(f"Querying knowledge base pagination failed: username={current_user.username} - {str(e)}") - raise - - -def get_chunded_knowledgeids( - db: Session, - current_user: User, - filters: list -) -> list: - business_logger.debug(f"Query the list of vectorized knowledge base IDs: username={current_user.username}") - - try: - items = knowledge_repository.get_chunded_knowledgeids( - db=db, - filters=filters - ) - business_logger.info(f"Querying the vectorized knowledge base id list succeeded: username={current_user.username} count={len(items)}") - return items - except Exception as e: - business_logger.error(f"Querying the vectorized knowledge base id list failed: username={current_user.username} - {str(e)}") - raise - - -def create_knowledge( - db: Session, knowledge: KnowledgeCreate, current_user: User -) -> Knowledge: - business_logger.info(f"Create a knowledge base: {knowledge.name}, creator: {current_user.username}") - - try: - knowledge.created_by = current_user.id - if knowledge.workspace_id is None: - knowledge.workspace_id = current_user.current_workspace_id - if knowledge.parent_id is None: - knowledge.parent_id = knowledge.workspace_id - business_logger.debug(f"Start creating the knowledge base: {knowledge.name}") - db_knowledge = knowledge_repository.create_knowledge( - db=db, knowledge=knowledge - ) - business_logger.info(f"The knowledge base has been successfully created: {knowledge.name} (ID: {db_knowledge.id}), creator: {current_user.username}") - return db_knowledge - except Exception as e: - business_logger.error(f"Failed to create a knowledge base: {knowledge.name} - {str(e)}") - raise - - -def get_knowledge_by_id(db: Session, knowledge_id: uuid.UUID, current_user: User) -> Knowledge | None: - business_logger.debug(f"Query knowledge base based on ID: knowledge_id={knowledge_id}, username: {current_user.username}") - - try: - knowledge = knowledge_repository.get_knowledge_by_id(db=db, knowledge_id=knowledge_id) - if knowledge: - business_logger.info(f"knowledge base query successful: {knowledge.name} (ID: {knowledge_id})") - else: - business_logger.warning(f"knowledge base does not exist: knowledge_id={knowledge_id}") - return knowledge - except Exception as e: - business_logger.error(f"Failed to query the knowledge base based on the ID: knowledge_id={knowledge_id} - {str(e)}") - raise - - -def get_knowledge_by_name(db: Session, name: str, current_user: User) -> Knowledge | None: - business_logger.debug(f"Query knowledge base based on name: name={name}, username: {current_user.username}") - - try: - knowledge = knowledge_repository.get_knowledge_by_name(db=db, name=name, workspace_id=current_user.current_workspace_id) - if knowledge: - business_logger.info(f"knowledge base query successful: {name} (ID: {knowledge.id})") - else: - business_logger.warning(f"knowledge base does not exist: name={name}") - return knowledge - except Exception as e: - business_logger.error(f"Failed to query the knowledge base based on the name: name={name} - {str(e)}") - raise - - -def delete_knowledge_by_id(db: Session, knowledge_id: uuid.UUID, current_user: User) -> None: - business_logger.info(f"Delete knowledge base: knowledge_id={knowledge_id}, operator: {current_user.username}") - - try: - # First, query the knowledge base information for logging purposes - knowledge = knowledge_repository.get_knowledge_by_id(db=db, knowledge_id=knowledge_id) - if knowledge: - business_logger.debug(f"Execute knowledge base deletion: {knowledge.name} (ID: {knowledge_id})") - else: - business_logger.warning(f"The knowledge base to be deleted does not exist: knowledge_id={knowledge_id}") - - knowledge_repository.delete_knowledge_by_id(db=db, knowledge_id=knowledge_id) - business_logger.info(f"knowledge base record deleted successfully: knowledge_id={knowledge_id}, operator: {current_user.username}") - except Exception as e: - business_logger.error(f"Failed to delete knowledge base: knowledge_id={knowledge_id} - {str(e)}") - raise diff --git a/app/services/knowledgeshare_service.py b/app/services/knowledgeshare_service.py deleted file mode 100644 index b83ef0e5..00000000 --- a/app/services/knowledgeshare_service.py +++ /dev/null @@ -1,108 +0,0 @@ -import uuid -from sqlalchemy.orm import Session -from app.models.user_model import User -from app.models.knowledgeshare_model import KnowledgeShare -from app.schemas.knowledgeshare_schema import KnowledgeShareCreate -from app.repositories import knowledgeshare_repository -from app.core.logging_config import get_business_logger - -# Obtain a dedicated logger for business logic -business_logger = get_business_logger() - - -def get_knowledgeshares_paginated( - db: Session, - current_user: User, - filters: list, - page: int, - pagesize: int, - orderby: str = None, - desc: bool = False -) -> tuple[int, list]: - business_logger.debug(f"Query knowledge base sharing in pages: username={current_user.username}, page={page}, pagesize={pagesize}, orderby={orderby}, desc={desc}") - - try: - total, items = knowledgeshare_repository.get_knowledgeshares_paginated( - db=db, - filters=filters, - page=page, - pagesize=pagesize, - orderby=orderby, - desc=desc - ) - business_logger.info(f"The knowledge base sharing paging query has been successful: username={current_user.username}, total={total}, Number of current page={len(items)}") - return total, items - except Exception as e: - business_logger.error(f"Querying knowledge base sharing pagination failed: username={current_user.username} - {str(e)}") - raise - - -def get_source_kb_ids_by_target_kb_id( - db: Session, - current_user: User, - filters: list -) -> list: - business_logger.debug(f"Query the original knowledge base id list by sharing the knowledge base: username={current_user.username}") - - try: - items = knowledgeshare_repository.get_source_kb_ids_by_target_kb_id( - db=db, - filters=filters - ) - business_logger.info(f"Successfully queried the original knowledge base ID list by sharing the knowledge base: username={current_user.username} count={len(items)}") - return items - except Exception as e: - business_logger.error(f"Failed to query the original knowledge base ID list through knowledge base sharing: username={current_user.username} - {str(e)}") - raise - - -def create_knowledgeshare( - db: Session, knowledgeshare: KnowledgeShareCreate, current_user: User -) -> KnowledgeShare: - business_logger.info(f"Create a knowledge base sharing: creator: {current_user.username}") - - try: - knowledgeshare.source_workspace_id = current_user.current_workspace_id - knowledgeshare.shared_by = current_user.id - business_logger.debug("Start creating a knowledge base sharing") - db_knowledgeshare = knowledgeshare_repository.create_knowledgeshare( - db=db, knowledgeshare=knowledgeshare - ) - business_logger.info(f"knowledge base sharing created successfully: (ID: {db_knowledgeshare.id}), creator: {current_user.username}") - return db_knowledgeshare - except Exception as e: - business_logger.error(f"Failed to create a knowledge base sharing - {str(e)}") - raise - - -def get_knowledgeshare_by_id(db: Session, knowledgeshare_id: uuid.UUID, current_user: User) -> KnowledgeShare | None: - business_logger.debug(f"Query knowledge base sharing based on ID: knowledgeshare_id={knowledgeshare_id}, username: {current_user.username}") - - try: - knowledgeshare = knowledgeshare_repository.get_knowledgeshare_by_id(db=db, knowledgeshare_id=knowledgeshare_id) - if knowledgeshare: - business_logger.info(f"knowledge base sharing query successful: (ID: {knowledgeshare_id})") - else: - business_logger.warning(f"knowledge base sharing does not exist: knowledgeshare_id={knowledgeshare_id}") - return knowledgeshare - except Exception as e: - business_logger.error(f"Failed to query the knowledge base sharing: knowledgeshare_id={knowledgeshare_id} - {str(e)}") - raise - - -def delete_knowledgeshare_by_id(db: Session, knowledgeshare_id: uuid.UUID, current_user: User) -> None: - business_logger.info(f"Delete knowledge base sharing: knowledgeshare_id={knowledgeshare_id}, operator: {current_user.username}") - - try: - # First, query the knowledge base sharing information for logging purposes - knowledgeshare = knowledgeshare_repository.get_knowledgeshare_by_id(db=db, knowledgeshare_id=knowledgeshare_id) - if knowledgeshare: - business_logger.debug(f"Execute knowledge base sharing deletion: (ID: {knowledgeshare_id})") - else: - business_logger.warning(f"The knowledge base sharing does not exist: knowledgeshare_id={knowledgeshare_id}") - - knowledgeshare_repository.delete_knowledgeshare_by_id(db=db, knowledgeshare_id=knowledgeshare_id) - business_logger.info(f"knowledge base sharing deleted successfully: knowledgeshare_id={knowledgeshare_id}, operator: {current_user.username}") - except Exception as e: - business_logger.error(f"Failed to delete knowledge base sharing: knowledgeshare_id={knowledgeshare_id} - {str(e)}") - raise diff --git a/app/services/langchain_tool_server.py b/app/services/langchain_tool_server.py deleted file mode 100644 index f44e4cdc..00000000 --- a/app/services/langchain_tool_server.py +++ /dev/null @@ -1,51 +0,0 @@ -import requests -import json - -from dotenv import load_dotenv -import os - -# 加载.env文件 -load_dotenv() - -# 读取web_search环境变量 -web_search_value = os.getenv('web_search') -def Search(query): - url = "https://qianfan.baidubce.com/v2/ai_search/chat/completions" - api_key = web_search_value - payload = json.dumps({ - "messages": [ - { - "role": "user", - "content": query - } - ], #搜索输入 - "edition":"standard", #搜索版本。默认为standard。可选值:standard:完整版本。lite:标准版本,对召回规模和精排条数简化后的版本,时延表现更好,效果略弱于完整版。 - "search_source": "baidu_search_v2", #使用的搜索引擎版本 - "resource_type_filter": [{"type": "web","top_k": 20}], #支持设置网页、视频、图片、阿拉丁搜索模态,网页top_k最大取值为50,视频top_k最大为10,图片top_k最大为30,阿拉丁top_k最大为5 - "search_filter": { - "range": { - "page_time": { - "gte": "now-1w/d", #时间查询参数,大于或等于 - "lt": "now/d", #时间查询参数,小于 - "gt": "", #时间查询参数,大于 - "lte": "" #时间查询参数,小于或等于 - } - } - }, - "block_websites":["tieba.baidu.com"], #需要屏蔽的站点列表 - "search_recency_filter":"week", #根据网页发布时间进行筛选,可填值为:week,month,semiyear,year - "enable_full_content":True #是否输出网页完整原文 - }, ensure_ascii=False) - headers = { - 'Content-Type': 'application/json', - 'Authorization': f'Bearer {api_key}' - } - - response = requests.request("POST", url, headers=headers, data=payload.encode("utf-8")).json() - content=[] - for i in response['references']: - title=i['title'] - snippet=i['snippet'] - content.append(title+';'+snippet) - content='。'.join(content) - return content \ No newline at end of file diff --git a/app/services/llm_client.py b/app/services/llm_client.py deleted file mode 100644 index a7bc81b0..00000000 --- a/app/services/llm_client.py +++ /dev/null @@ -1,340 +0,0 @@ -"""LLM 客户端适配器 - 支持多种 LLM 提供商""" -import os -import json -from typing import Optional, Dict, Any -from abc import ABC, abstractmethod -from app.core.logging_config import get_business_logger - -logger = get_business_logger() - - -class BaseLLMClient(ABC): - """LLM 客户端基类""" - - @abstractmethod - async def chat(self, prompt: str, **kwargs) -> str: - """发送聊天请求 - - Args: - prompt: 提示词 - **kwargs: 其他参数 - - Returns: - LLM 响应文本 - """ - pass - - -class OpenAIClient(BaseLLMClient): - """OpenAI 客户端""" - - def __init__( - self, - api_key: Optional[str] = None, - model: str = "gpt-3.5-turbo", - base_url: Optional[str] = None - ): - """初始化 OpenAI 客户端 - - Args: - api_key: API 密钥 - model: 模型名称 - base_url: API 基础 URL(可选,用于兼容其他服务) - """ - self.api_key = api_key or os.getenv("OPENAI_API_KEY") - self.model = model - self.base_url = base_url - - if not self.api_key: - raise ValueError("OpenAI API key 未配置") - - try: - from openai import AsyncOpenAI - self.client = AsyncOpenAI( - api_key=self.api_key, - base_url=self.base_url - ) - except ImportError: - raise ImportError("请安装 openai 库: pip install openai") - - async def chat(self, prompt: str, **kwargs) -> str: - """发送聊天请求 - - Args: - prompt: 提示词 - **kwargs: 其他参数(temperature, max_tokens 等) - - Returns: - LLM 响应文本 - """ - try: - response = await self.client.chat.completions.create( - model=self.model, - messages=[{"role": "user", "content": prompt}], - temperature=kwargs.get("temperature", 0.3), - max_tokens=kwargs.get("max_tokens", 500) - ) - - return response.choices[0].message.content - - except Exception as e: - logger.error(f"OpenAI API 调用失败: {str(e)}") - raise - - -class AzureOpenAIClient(BaseLLMClient): - """Azure OpenAI 客户端""" - - def __init__( - self, - api_key: Optional[str] = None, - endpoint: Optional[str] = None, - deployment_name: Optional[str] = None, - api_version: str = "2024-02-15-preview" - ): - """初始化 Azure OpenAI 客户端 - - Args: - api_key: API 密钥 - endpoint: Azure 端点 - deployment_name: 部署名称 - api_version: API 版本 - """ - self.api_key = api_key or os.getenv("AZURE_OPENAI_API_KEY") - self.endpoint = endpoint or os.getenv("AZURE_OPENAI_ENDPOINT") - self.deployment_name = deployment_name or os.getenv("AZURE_OPENAI_DEPLOYMENT") - self.api_version = api_version - - if not all([self.api_key, self.endpoint, self.deployment_name]): - raise ValueError("Azure OpenAI 配置不完整") - - try: - from openai import AsyncAzureOpenAI - self.client = AsyncAzureOpenAI( - api_key=self.api_key, - azure_endpoint=self.endpoint, - api_version=self.api_version - ) - except ImportError: - raise ImportError("请安装 openai 库: pip install openai") - - async def chat(self, prompt: str, **kwargs) -> str: - """发送聊天请求""" - try: - response = await self.client.chat.completions.create( - model=self.deployment_name, - messages=[{"role": "user", "content": prompt}], - temperature=kwargs.get("temperature", 0.3), - max_tokens=kwargs.get("max_tokens", 500) - ) - - return response.choices[0].message.content - - except Exception as e: - logger.error(f"Azure OpenAI API 调用失败: {str(e)}") - raise - - -class AnthropicClient(BaseLLMClient): - """Anthropic Claude 客户端""" - - def __init__( - self, - api_key: Optional[str] = None, - model: str = "claude-3-sonnet-20240229" - ): - """初始化 Anthropic 客户端 - - Args: - api_key: API 密钥 - model: 模型名称 - """ - self.api_key = api_key or os.getenv("ANTHROPIC_API_KEY") - self.model = model - - if not self.api_key: - raise ValueError("Anthropic API key 未配置") - - try: - from anthropic import AsyncAnthropic - self.client = AsyncAnthropic(api_key=self.api_key) - except ImportError: - raise ImportError("请安装 anthropic 库: pip install anthropic") - - async def chat(self, prompt: str, **kwargs) -> str: - """发送聊天请求""" - try: - response = await self.client.messages.create( - model=self.model, - max_tokens=kwargs.get("max_tokens", 500), - temperature=kwargs.get("temperature", 0.3), - messages=[{"role": "user", "content": prompt}] - ) - - return response.content[0].text - - except Exception as e: - logger.error(f"Anthropic API 调用失败: {str(e)}") - raise - - -class LocalLLMClient(BaseLLMClient): - """本地 LLM 客户端(通过 HTTP API)""" - - def __init__( - self, - base_url: str = "http://localhost:8000", - model: str = "local-model" - ): - """初始化本地 LLM 客户端 - - Args: - base_url: API 基础 URL - model: 模型名称 - """ - self.base_url = base_url - self.model = model - - try: - import httpx - self.client = httpx.AsyncClient(timeout=30.0) - except ImportError: - raise ImportError("请安装 httpx 库: pip install httpx") - - async def chat(self, prompt: str, **kwargs) -> str: - """发送聊天请求""" - try: - response = await self.client.post( - f"{self.base_url}/v1/chat/completions", - json={ - "model": self.model, - "messages": [{"role": "user", "content": prompt}], - "temperature": kwargs.get("temperature", 0.3), - "max_tokens": kwargs.get("max_tokens", 500) - } - ) - - response.raise_for_status() - data = response.json() - - return data["choices"][0]["message"]["content"] - - except Exception as e: - logger.error(f"本地 LLM API 调用失败: {str(e)}") - raise - - -class MockLLMClient(BaseLLMClient): - """模拟 LLM 客户端(用于测试)""" - - def __init__(self): - """初始化模拟客户端""" - self.call_count = 0 - - async def chat(self, prompt: str, **kwargs) -> str: - """发送聊天请求(返回模拟结果)""" - self.call_count += 1 - - logger.info(f"模拟 LLM 调用 (第 {self.call_count} 次)") - - # 简单的规则匹配 - prompt_lower = prompt.lower() - - if "数学" in prompt_lower or "方程" in prompt_lower or "计算" in prompt_lower: - return json.dumps({ - "agent_id": "math-agent", - "confidence": 0.9, - "reason": "消息包含数学相关内容" - }, ensure_ascii=False) - - elif "化学" in prompt_lower or "反应" in prompt_lower or "元素" in prompt_lower: - return json.dumps({ - "agent_id": "chemistry-agent", - "confidence": 0.85, - "reason": "消息包含化学相关内容" - }, ensure_ascii=False) - - elif "物理" in prompt_lower or "力" in prompt_lower or "速度" in prompt_lower: - return json.dumps({ - "agent_id": "physics-agent", - "confidence": 0.88, - "reason": "消息包含物理相关内容" - }, ensure_ascii=False) - - elif "语文" in prompt_lower or "古诗" in prompt_lower or "作文" in prompt_lower: - return json.dumps({ - "agent_id": "chinese-agent", - "confidence": 0.87, - "reason": "消息包含语文相关内容" - }, ensure_ascii=False) - - elif "英语" in prompt_lower or "单词" in prompt_lower or "语法" in prompt_lower: - return json.dumps({ - "agent_id": "english-agent", - "confidence": 0.86, - "reason": "消息包含英语相关内容" - }, ensure_ascii=False) - - else: - return json.dumps({ - "agent_id": "math-agent", - "confidence": 0.5, - "reason": "无法明确判断,使用默认 Agent" - }, ensure_ascii=False) - - -class LLMClientFactory: - """LLM 客户端工厂""" - - @staticmethod - def create( - provider: str = "mock", - **kwargs - ) -> BaseLLMClient: - """创建 LLM 客户端 - - Args: - provider: 提供商名称 (openai, azure, anthropic, local, mock) - **kwargs: 客户端配置参数 - - Returns: - LLM 客户端实例 - """ - provider = provider.lower() - - if provider == "openai": - return OpenAIClient(**kwargs) - - elif provider == "azure": - return AzureOpenAIClient(**kwargs) - - elif provider == "anthropic": - return AnthropicClient(**kwargs) - - elif provider == "local": - return LocalLLMClient(**kwargs) - - elif provider == "mock": - return MockLLMClient() - - else: - raise ValueError(f"不支持的 LLM 提供商: {provider}") - - @staticmethod - def create_from_env() -> BaseLLMClient: - """从环境变量创建 LLM 客户端 - - 环境变量: - - LLM_PROVIDER: 提供商名称 - - OPENAI_API_KEY: OpenAI API 密钥 - - AZURE_OPENAI_API_KEY: Azure OpenAI API 密钥 - - ANTHROPIC_API_KEY: Anthropic API 密钥 - - Returns: - LLM 客户端实例 - """ - provider = os.getenv("LLM_PROVIDER", "mock") - - logger.info(f"从环境变量创建 LLM 客户端: {provider}") - - return LLMClientFactory.create(provider) diff --git a/app/services/llm_router.py b/app/services/llm_router.py deleted file mode 100644 index a6935862..00000000 --- a/app/services/llm_router.py +++ /dev/null @@ -1,685 +0,0 @@ -"""基于 LLM 的智能路由器 - 混合策略""" -import json -import re -import uuid -from typing import Dict, Any, List, Optional, Tuple -from sqlalchemy.orm import Session - -from app.services.conversation_state_manager import ConversationStateManager -from app.models import ModelConfig, AgentConfig -from app.core.logging_config import get_business_logger - -logger = get_business_logger() - - -class LLMRouter: - """基于 LLM 的智能路由器 - - 混合策略: - 1. 先用关键词快速筛选(置信度 > 0.8 直接返回) - 2. 对于模糊情况(置信度 0.3-0.8),调用 LLM 辅助 - 3. 对于完全不匹配(置信度 < 0.3),调用 LLM - 4. 缓存 LLM 结果,减少重复调用 - """ - - # 主题切换信号 - SWITCH_SIGNALS = [ - "换个话题", "另外", "还有", "对了", - "那这个呢", "再问一个", "顺便问下", - "我想问", "帮我", "请问", "换一个" - ] - - # 延续信号 - CONTINUATION_SIGNALS = [ - "继续", "还是", "也", "同样", "类似", - "这个", "那个", "它", "他", "她", "呢" - ] - - def __init__( - self, - db: Session, - state_manager: ConversationStateManager, - routing_rules: List[Dict[str, Any]], - sub_agents: Dict[str, Any], - routing_model_config: Optional[ModelConfig] = None, - use_llm: bool = True - ): - """初始化 LLM 路由器 - - Args: - db: 数据库会话 - state_manager: 会话状态管理器 - routing_rules: 路由规则列表 - sub_agents: 子 Agent 配置字典 - routing_model_config: 用于路由的模型配置(可选) - use_llm: 是否启用 LLM(默认 True) - """ - self.db = db - self.state_manager = state_manager - self.routing_rules = routing_rules - self.sub_agents = sub_agents - self.routing_model_config = routing_model_config - self.use_llm = use_llm and routing_model_config is not None - - # 配置参数 - self.min_confidence_for_switch = 0.7 - self.max_same_agent_turns = 10 - self.keyword_high_confidence_threshold = 0.8 # 关键词高置信度阈值 - self.keyword_low_confidence_threshold = 0.3 # 关键词低置信度阈值 - - # 缓存配置 - self.cache_enabled = True - self.cache_size = 1000 - - async def route( - self, - message: str, - conversation_id: Optional[str] = None, - force_new: bool = False - ) -> Dict[str, Any]: - """智能路由(混合策略) - - Args: - message: 用户消息 - conversation_id: 会话 ID - force_new: 是否强制重新路由 - - Returns: - 路由结果 - """ - logger.info( - f"开始 LLM 智能路由", - extra={ - "message_length": len(message), - "conversation_id": conversation_id, - "use_llm": self.use_llm - } - ) - - # 1. 获取会话状态 - state = None - if conversation_id and not force_new: - state = self.state_manager.get_state(conversation_id) - - # 2. 检测主题切换 - topic_changed = self._detect_topic_change(message, state) - - # 3. 提取当前主题 - topic = await self._extract_topic_with_llm(message) if self.use_llm else self._extract_topic(message) - - # 4. 选择路由策略 - if force_new: - agent_id, confidence, method = await self._route_with_hybrid(message) - strategy = "force_new" - reason = "用户强制重新路由" - - elif not state or not state.get("current_agent_id"): - agent_id, confidence, method = await self._route_with_hybrid(message) - strategy = "new_conversation" - reason = "新会话,首次路由" - - elif topic_changed: - agent_id, confidence, method = await self._route_with_hybrid(message) - strategy = "topic_changed" - reason = f"检测到主题切换: {state.get('last_topic')} -> {topic}" - - elif state.get("same_agent_turns", 0) >= self.max_same_agent_turns: - agent_id, confidence, method = await self._route_with_hybrid(message) - strategy = "max_turns_reached" - reason = f"同一 Agent 已使用 {state['same_agent_turns']} 轮" - - else: - current_agent_id = state["current_agent_id"] - should_continue, continue_confidence = self._should_continue_current_agent( - message, - current_agent_id - ) - - if should_continue: - agent_id = current_agent_id - confidence = continue_confidence - method = "keyword" - strategy = "continue_current" - reason = "消息在当前 Agent 能力范围内" - else: - new_agent_id, new_confidence, method = await self._route_with_hybrid(message) - - if new_confidence > continue_confidence + self.min_confidence_for_switch: - agent_id = new_agent_id - confidence = new_confidence - strategy = "switch_agent" - reason = f"新 Agent 置信度更高: {new_confidence:.2f} vs {continue_confidence:.2f}" - else: - agent_id = current_agent_id - confidence = continue_confidence - method = "keyword" - strategy = "keep_current" - reason = "置信度差距不足以切换 Agent" - - # 5. 更新会话状态 - if conversation_id: - self.state_manager.update_state( - conversation_id, - agent_id, - message, - topic, - confidence - ) - - result = { - "agent_id": agent_id, - "confidence": confidence, - "strategy": strategy, - "topic": topic, - "topic_changed": topic_changed, - "reason": reason, - "routing_method": method # "keyword", "llm", "hybrid" - } - - logger.info( - f"路由完成", - extra={ - "agent_id": agent_id, - "strategy": strategy, - "confidence": confidence, - "method": method - } - ) - - return result - - async def _route_with_hybrid(self, message: str) -> Tuple[str, float, str]: - """混合路由策略 - - Args: - message: 用户消息 - - Returns: - (agent_id, confidence, method) - """ - # 1. 先用关键词匹配 - keyword_agent_id, keyword_confidence = self._route_with_keywords(message) - - # 2. 判断是否需要 LLM - if not self.use_llm or not self.routing_model_config: - # 不使用 LLM,直接返回关键词结果 - return keyword_agent_id, keyword_confidence, "keyword" - - if keyword_confidence >= self.keyword_high_confidence_threshold: - # 关键词置信度很高,直接返回 - logger.info(f"关键词置信度高 ({keyword_confidence:.2f}),跳过 LLM") - return keyword_agent_id, keyword_confidence, "keyword" - - # 3. 使用 LLM 辅助决策 - logger.info(f"关键词置信度较低 ({keyword_confidence:.2f}),调用 LLM") - llm_agent_id, llm_confidence = await self._route_with_llm(message) - - # 4. 综合决策 - if llm_confidence > keyword_confidence: - # LLM 置信度更高 - final_confidence = llm_confidence * 0.7 + keyword_confidence * 0.3 - return llm_agent_id, final_confidence, "llm" - else: - # 关键词置信度更高或相当 - final_confidence = keyword_confidence * 0.7 + llm_confidence * 0.3 - return keyword_agent_id, final_confidence, "hybrid" - - def _route_with_keywords(self, message: str) -> Tuple[str, float]: - """基于关键词的路由 - - Args: - message: 用户消息 - - Returns: - (agent_id, confidence) - """ - best_agent_id = None - best_score = 0.0 - - for rule in self.routing_rules: - score = self._calculate_rule_score(message, rule) - - if score > best_score: - best_score = score - best_agent_id = rule.get("target_agent_id") - - if not best_agent_id or best_score < 0.3: - best_agent_id = self._get_default_agent_id() - best_score = 0.5 - - return best_agent_id, best_score - - async def _route_with_llm(self, message: str) -> Tuple[str, float]: - """基于 LLM 的路由 - - Args: - message: 用户消息 - - Returns: - (agent_id, confidence) - """ - # 检查缓存 - if self.cache_enabled: - cached_result = self._get_cached_llm_result(message) - if cached_result: - logger.info("使用缓存的 LLM 路由结果") - return cached_result - - # 构建 prompt - prompt = self._build_routing_prompt(message) - - try: - # 调用 LLM - response = await self._call_llm(prompt) - - # 解析结果 - agent_id, confidence = self._parse_llm_response(response) - - # 缓存结果 - if self.cache_enabled: - self._cache_llm_result(message, agent_id, confidence) - - return agent_id, confidence - - except Exception as e: - logger.error(f"LLM 路由失败: {str(e)}") - # 降级到关键词路由 - return self._route_with_keywords(message) - - def _build_routing_prompt(self, message: str) -> str: - """构建 LLM 路由 prompt - - Args: - message: 用户消息 - - Returns: - prompt 字符串 - """ - # 构建 Agent 描述 - agent_descriptions = [] - for agent_id, agent_data in self.sub_agents.items(): - # 获取 Agent 信息 - agent_info = agent_data.get("info", {}) - agent_config = agent_data.get("config") - - # 查找该 Agent 的路由规则 - rules = [r for r in self.routing_rules if r.get("target_agent_id") == agent_id] - - # 构建描述 - name = agent_info.get("name", "未命名 Agent") - role = agent_info.get("role", "") - capabilities = agent_info.get("capabilities", []) - - desc_parts = [f"- agent_id: {agent_id}", f" 名称: {name}"] - - if role: - desc_parts.append(f" 角色: {role}") - - # 从路由规则获取关键词 - if rules: - rule = rules[0] - keywords = rule.get("keywords", []) - if keywords: - desc_parts.append(f" 关键词: {', '.join(keywords[:5])}") - - # 从 Agent 信息获取能力 - if capabilities: - desc_parts.append(f" 擅长: {', '.join(capabilities[:5])}") - - agent_descriptions.append("\n".join(desc_parts)) - - agents_text = "\n\n".join(agent_descriptions) - - # 如果没有 Agent 描述,添加警告 - if not agents_text: - agents_text = "(警告:没有可用的 Agent 信息)" - - # 提取所有可用的 agent_id - available_agent_ids = list(self.sub_agents.keys()) - agent_ids_text = ", ".join(available_agent_ids) - - prompt = f"""你是一个智能路由助手,需要根据用户的消息,选择最合适的 Agent 来处理。 - -可用的 Agent: -{agents_text} - -用户消息:"{message}" - -**重要**:你必须从以下 agent_id 中选择一个:{agent_ids_text} - -请分析这条消息,选择最合适的 Agent。 - -要求: -1. 仔细理解消息的意图和主题 -2. 从上面列出的 agent_id 中选择最匹配的一个 -3. 给出置信度(0-1 之间的小数) -4. agent_id 必须是上面列出的其中一个,不能自己编造 - -请以 JSON 格式返回: -{{ - "agent_id": "从上面列表中选择的 agent_id", - "confidence": 0.95, - "reason": "选择理由" -}} -""" - return prompt - - async def _call_llm(self, prompt: str) -> str: - """调用 LLM API(使用系统的 RedBearLLM) - - Args: - prompt: 提示词 - - Returns: - LLM 响应 - """ - if not self.routing_model_config: - raise Exception("路由模型配置未设置") - - try: - # 使用系统的 RedBearLLM 来调用模型 - from app.core.models import RedBearLLM - from app.core.models.base import RedBearModelConfig - from app.models import ModelApiKey, ModelType - - # 获取 API Key 配置 - api_key_config = self.db.query(ModelApiKey).filter( - ModelApiKey.model_config_id == self.routing_model_config.id, - ModelApiKey.is_active == True - ).first() - - if not api_key_config: - raise Exception("路由模型没有可用的 API Key") - - # 打印供应商信息 - logger.info( - f"LLM 路由使用模型", - extra={ - "provider": api_key_config.provider, - "model_name": api_key_config.model_name, - "api_base": api_key_config.api_base, - "model_config_id": str(self.routing_model_config.id) - } - ) - - # 创建 RedBearModelConfig - model_config = RedBearModelConfig( - model_name=api_key_config.model_name, - provider=api_key_config.provider, - api_key=api_key_config.api_key, - base_url=api_key_config.api_base, - temperature=0.3, - max_tokens=500 - ) - - logger.debug(f"创建 LLM 实例 - Provider: {api_key_config.provider}, Model: {api_key_config.model_name}") - - # 创建 LLM 实例 - llm = RedBearLLM(model_config, type=ModelType.CHAT) - - # 调用模型 - response = await llm.ainvoke(prompt) - - # 提取响应内容 - if hasattr(response, 'content'): - return response.content - else: - return str(response) - - except Exception as e: - logger.error(f"LLM 路由调用失败: {str(e)}") - # 降级到关键词路由 - raise - - - - def _parse_llm_response(self, response: str) -> Tuple[str, float]: - """解析 LLM 响应 - - Args: - response: LLM 响应文本 - - Returns: - (agent_id, confidence) - """ - try: - # 提取 JSON - json_match = re.search(r'\{[^}]+\}', response) - if json_match: - result = json.loads(json_match.group()) - agent_id = result.get("agent_id") - confidence = float(result.get("confidence", 0.5)) - - # 验证 agent_id 是否有效 - if agent_id not in self.sub_agents: - logger.warning(f"LLM 返回的 agent_id 无效: {agent_id}") - agent_id = self._get_default_agent_id() - confidence = 0.5 - - return agent_id, confidence - else: - raise ValueError("无法从响应中提取 JSON") - - except Exception as e: - logger.error(f"解析 LLM 响应失败: {str(e)}") - return self._get_default_agent_id(), 0.5 - - def _get_cached_llm_result(self, message: str) -> Optional[Tuple[str, float]]: - """获取缓存的 LLM 结果 - - Args: - message: 用户消息 - - Returns: - 缓存的结果或 None - """ - # TODO: 实现真正的缓存机制(使用 Redis 或内存字典) - return None - - def _cache_llm_result(self, message: str, agent_id: str, confidence: float): - """缓存 LLM 结果 - - Args: - message: 用户消息 - agent_id: Agent ID - confidence: 置信度 - """ - # lru_cache 会自动处理缓存 - pass - - async def _extract_topic_with_llm(self, message: str) -> str: - """使用 LLM 提取主题 - - Args: - message: 用户消息 - - Returns: - 主题名称 - """ - if not self.routing_model_config: - return self._extract_topic(message) - - prompt = f"""请分析以下消息的主题,从这些选项中选择一个: -数学、物理、化学、语文、英语、历史、作业、学习规划、订单、退款、账户、支付、其他 - -消息:"{message}" - -只返回主题名称,不要其他内容。 -""" - - try: - response = await self._call_llm(prompt) - topic = response.strip() - - # 验证主题 - valid_topics = [ - "数学", "物理", "化学", "语文", "英语", "历史", - "作业", "学习规划", "订单", "退款", "账户", "支付", "其他" - ] - - if topic in valid_topics: - return topic - else: - return self._extract_topic(message) - - except Exception as e: - logger.error(f"LLM 提取主题失败: {str(e)}") - return self._extract_topic(message) - - # 以下方法与 SmartRouter 相同 - - def _detect_topic_change( - self, - message: str, - state: Optional[Dict[str, Any]] - ) -> bool: - """检测主题是否切换""" - if not state or not state.get("last_topic"): - return False - - for signal in self.SWITCH_SIGNALS: - if signal in message: - logger.info(f"检测到主题切换信号: {signal}") - return True - - current_topic = self._extract_topic(message) - last_topic = state.get("last_topic") - - if current_topic != last_topic and current_topic != "其他": - logger.info(f"主题变化: {last_topic} -> {current_topic}") - return True - - return False - - def _should_continue_current_agent( - self, - message: str, - current_agent_id: str - ) -> Tuple[bool, float]: - """判断是否应该继续使用当前 Agent""" - has_continuation_signal = any( - signal in message - for signal in self.CONTINUATION_SIGNALS - ) - - current_score = self._calculate_agent_score(message, current_agent_id) - - if has_continuation_signal and current_score > 0.3: - return True, min(current_score + 0.2, 1.0) - - if current_score > 0.6: - return True, current_score - - return False, current_score - - def _calculate_rule_score( - self, - message: str, - rule: Dict[str, Any] - ) -> float: - """计算规则匹配分数""" - score = 0.0 - message_lower = message.lower() - - keywords = rule.get("keywords", []) - if keywords: - matched_keywords = sum( - 1 for keyword in keywords - if keyword.lower() in message_lower - ) - keyword_score = matched_keywords / len(keywords) - score += keyword_score * 0.6 - - patterns = rule.get("patterns", []) - if patterns: - matched_patterns = sum( - 1 for pattern in patterns - if re.search(pattern, message, re.IGNORECASE) - ) - pattern_score = matched_patterns / len(patterns) - score += pattern_score * 0.3 - - exclude_keywords = rule.get("exclude_keywords", []) - if exclude_keywords: - has_exclude = any( - keyword.lower() in message_lower - for keyword in exclude_keywords - ) - if has_exclude: - score *= 0.5 - - min_keyword_count = rule.get("min_keyword_count", 0) - if keywords and min_keyword_count > 0: - matched_count = sum( - 1 for keyword in keywords - if keyword.lower() in message_lower - ) - if matched_count < min_keyword_count: - score *= 0.7 - - return min(score, 1.0) - - def _calculate_agent_score( - self, - message: str, - agent_id: str - ) -> float: - """计算 Agent 对消息的匹配分数""" - agent_rules = [ - rule for rule in self.routing_rules - if rule.get("target_agent_id") == agent_id - ] - - if not agent_rules: - return 0.0 - - max_score = max( - self._calculate_rule_score(message, rule) - for rule in agent_rules - ) - - return max_score - - def _extract_topic(self, message: str) -> str: - """提取消息主题(关键词方式)""" - topic_keywords = { - "数学": ["数学", "方程", "计算", "求解", "x", "y", "函数", "几何"], - "物理": ["物理", "力", "速度", "加速度", "能量", "功率", "电路"], - "化学": ["化学", "方程式", "反应", "元素", "分子", "原子", "化合物"], - "语文": ["语文", "古诗", "作文", "阅读", "文言文", "诗词"], - "英语": ["英语", "单词", "语法", "翻译", "时态", "句型"], - "历史": ["历史", "朝代", "事件", "人物", "战争", "革命"], - "作业": ["作业", "批改", "检查", "评分", "反馈"], - "学习规划": ["计划", "规划", "方法", "技巧", "时间", "安排"], - "订单": ["订单", "发货", "物流", "配送", "快递"], - "退款": ["退款", "退货", "售后", "换货", "维修"], - "账户": ["账户", "密码", "登录", "注册", "绑定"], - "支付": ["支付", "付款", "充值", "余额", "优惠券"] - } - - message_lower = message.lower() - - topic_scores = {} - for topic, keywords in topic_keywords.items(): - matched = sum( - 1 for keyword in keywords - if keyword in message_lower - ) - if matched > 0: - topic_scores[topic] = matched - - if topic_scores: - best_topic = max(topic_scores.items(), key=lambda x: x[1])[0] - return best_topic - - return "其他" - - def _get_default_agent_id(self) -> str: - """获取默认 Agent ID""" - if self.routing_rules: - return self.routing_rules[0].get("target_agent_id") - - if self.sub_agents: - return list(self.sub_agents.keys())[0] - - return "default-agent" diff --git a/app/services/memory_agent_service.py b/app/services/memory_agent_service.py deleted file mode 100644 index ab9b8195..00000000 --- a/app/services/memory_agent_service.py +++ /dev/null @@ -1,1035 +0,0 @@ -""" -Memory Agent Service - -Handles business logic for memory agent operations including read/write services, -health checks, and message type classification. -""" -import os -import re -import time -import json -import uuid -from threading import Lock -from typing import Dict, List, Optional, Any, AsyncGenerator -from app.services.memory_konwledges_server import write_rag -import redis -from langchain_mcp_adapters.client import MultiServerMCPClient -from langchain_mcp_adapters.tools import load_mcp_tools -from sqlalchemy.orm import Session -from sqlalchemy import func -from pydantic import BaseModel, Field - -from app.core.config import settings -from app.core.logging_config import get_logger -from app.core.memory.agent.langgraph_graph.read_graph import make_read_graph -from app.core.memory.agent.langgraph_graph.write_graph import make_write_graph -from app.core.memory.agent.logger_file.log_streamer import LogStreamer -from app.core.memory.agent.utils.llm_tools import PROJECT_ROOT_ -from app.core.memory.agent.utils.mcp_tools import get_mcp_server_config -from app.core.memory.agent.utils.type_classifier import status_typle -from app.db import get_db -from app.repositories.neo4j.neo4j_connector import Neo4jConnector -from app.core.memory.analytics.hot_memory_tags import get_hot_memory_tags -from app.core.memory.utils.llm.llm_utils import get_llm_client -from app.schemas.memory_storage_schema import ApiResponse, ok, fail -from app.models.knowledge_model import Knowledge, KnowledgeType -from app.repositories.data_config_repository import DataConfigRepository -from app.core.memory.agent.logger_file.log_streamer import LogStreamer -from app.services.memory_konwledges_server import memory_konwledges_up, SimpleUser, find_document_id_by_kb_and_filename -from app.core.memory.utils.config.definitions import reload_configuration_from_database -from app.schemas.file_schema import CustomTextFileCreate -try: - from app.core.memory.utils.log.audit_logger import audit_logger -except ImportError: - audit_logger = None -logger = get_logger(__name__) - -# Initialize Neo4j connector for analytics functions -_neo4j_connector = Neo4jConnector() -db_gen = get_db() -db = next(db_gen) - -class MemoryAgentService: - """Service for memory agent operations""" - - def __init__(self): - self.user_locks: Dict[str, Lock] = {} - self.locks_lock = Lock() - - def writer_messages_deal(self,messages,start_time,group_id,config_id,message): - messages = str(messages).replace("'", '"').replace('\\n', '').replace('\n', '').replace('\\', '') - countext = re.findall(r'"status": "(.*?)",', messages)[0] - duration = time.time() - start_time - - if countext == 'success': - logger.info(f"Write operation successful for group {group_id} with config_id {config_id}") - # 记录成功的操作 - if audit_logger: - audit_logger.log_operation(operation="WRITE", config_id=config_id, group_id=group_id, success=True, - duration=duration, details={"message_length": len(message)}) - return countext - else: - logger.warning(f"Write operation failed for group {group_id}") - - # 记录失败的操作 - if audit_logger: - audit_logger.log_operation( - operation="WRITE", - config_id=config_id, - group_id=group_id, - success=False, - duration=duration, - error=f"写入失败: {messages[:100]}" - ) - - raise ValueError(f"写入失败: {messages}") - - def get_group_lock(self, group_id: str) -> Lock: - """Get lock for specific group to prevent concurrent processing""" - with self.locks_lock: - if group_id not in self.user_locks: - self.user_locks[group_id] = Lock() - return self.user_locks[group_id] - - def extract_tool_call_info(self, event: Dict) -> bool: - """Extract tool call information from event""" - last_message = event["messages"][-1] - - # Check if AI message contains tool calls - if hasattr(last_message, 'tool_calls') and last_message.tool_calls: - tool_calls = last_message.tool_calls - for i, tool_call in enumerate(tool_calls): - if isinstance(tool_call, dict): - tool_call_id = tool_call.get('id') - tool_name = tool_call.get('name') - tool_args = tool_call.get('args', {}) - else: - tool_call_id = getattr(tool_call, 'id', None) - tool_name = getattr(tool_call, 'name', None) - tool_args = getattr(tool_call, 'args', {}) - - logger.debug(f"Tool Call {i + 1}: ID={tool_call_id}, Name={tool_name}, Args={tool_args}") - return True - - # Check if tool message - elif hasattr(last_message, 'tool_call_id'): - tool_call_id = getattr(last_message, 'tool_call_id', None) - if hasattr(last_message, 'name') and hasattr(last_message, 'content'): - tool_name = getattr(last_message, 'name', None) - try: - content = json.loads(getattr(last_message, 'content', '{}')) - tool_args = content.get('args', {}) - logger.debug(f"Tool Call 1: ID={tool_call_id}, Name={tool_name}, Args={tool_args}") - except: - logger.debug(f"Tool Response ID: {tool_call_id}") - else: - logger.debug(f"Tool Response ID: {tool_call_id}") - return True - - return False - - async def get_health_status(self) -> Dict: - """ - Get latest health status from Redis cache - - Returns health status information written by Celery periodic task - """ - logger.info("Checking health status") - - client = redis.Redis( - host=settings.REDIS_HOST, - port=settings.REDIS_PORT, - db=settings.REDIS_DB, - password=settings.REDIS_PASSWORD if settings.REDIS_PASSWORD else None - ) - payload = client.hgetall("memsci:health:read_service") or {} - - if payload: - # decode bytes to str - decoded = {k.decode("utf-8"): v.decode("utf-8") for k, v in payload.items()} - status = decoded.get("status", "unknown") - else: - status = "unknown" - - logger.info(f"Health status: {status}") - return {"status": status} - - def get_log_content(self) -> str: - """ - Read and return agent service log file content - - Returns cleaned log content using the same cleaning logic as transmission mode - - Returns cleaned log content using the same cleaning logic as transmission mode - """ - logger.info("Reading log file") - - # Use project root directory for logs - # Get the project root (redbear-mem directory) - current_file = os.path.abspath(__file__) # app/services/memory_agent_service.py - app_dir = os.path.dirname(os.path.dirname(current_file)) # app directory - project_root = os.path.dirname(app_dir) # redbear-mem directory - log_path = os.path.join(project_root, "logs", "agent_service.log") - - summer = '' - - with open(log_path, "r", encoding="utf-8") as infile: - for line in infile: - # Use the same cleaning logic as LogStreamer for consistency - cleaned = LogStreamer.clean_log_line(line) - summer += cleaned - - if len(summer) < 10: - raise ValueError("NO LOGS") - - logger.info(f"Log content retrieved, size: {len(summer)} bytes") - return summer - - async def stream_log_content(self) -> AsyncGenerator[str, None]: - """ - Stream log content in real-time using Server-Sent Events (SSE) - - This method establishes a streaming connection and transmits log entries - as they are written to the log file. It uses the LogStreamer to watch - the file and yields SSE-formatted messages. - - Yields: - SSE-formatted strings with the following event types: - - log: Contains log content and timestamp - - keepalive: Periodic keepalive messages to maintain connection - - error: Error information if streaming fails - - done: Indicates streaming has completed - - Raises: - FileNotFoundError: If log file doesn't exist at stream start - Exception: For other unexpected errors during streaming - """ - logger.info("Starting log content streaming") - - # Get log file path - use project root directory - current_file = os.path.abspath(__file__) # app/services/memory_agent_service.py - app_dir = os.path.dirname(os.path.dirname(current_file)) # app directory - project_root = os.path.dirname(app_dir) # redbear-mem directory - log_path = os.path.join(project_root, "logs", "agent_service.log") - - # Check if file exists before starting stream - if not os.path.exists(log_path): - logger.error(f"Log file not found: {log_path}") - # Send error event in SSE format - yield f"event: error\ndata: {json.dumps({'code': 4006, 'message': '日志文件不存在', 'error': f'File not found: {log_path}'})}\n\n" - return - - streamer = None - try: - # Initialize LogStreamer with keepalive interval from settings (default 300 seconds) - keepalive_interval = getattr(settings, 'LOG_STREAM_KEEPALIVE_INTERVAL', 300) - streamer = LogStreamer(log_path, keepalive_interval=keepalive_interval) - - logger.info(f"LogStreamer initialized for {log_path}") - - # Stream log content using read_existing_and_stream to get all existing content first - async for message in streamer.read_existing_and_stream(): - event_type = message.get("event") - data = message.get("data") - - # Format as SSE message - # SSE format: "event: <type>\ndata: <json_data>\n\n" - sse_message = f"event: {event_type}\ndata: {json.dumps(data)}\n\n" - - logger.debug(f"Streaming event: {event_type}") - yield sse_message - - # If error or done event, stop streaming - if event_type in ["error", "done"]: - logger.info(f"Stream ended with event: {event_type}") - break - - except FileNotFoundError as e: - logger.error(f"Log file not found during streaming: {e}") - yield f"event: error\ndata: {json.dumps({'code': 4006, 'message': '日志文件在流式传输期间变得不可用', 'error': str(e)})}\n\n" - - except Exception as e: - logger.error(f"Unexpected error during log streaming: {e}", exc_info=True) - yield f"event: error\ndata: {json.dumps({'code': 8001, 'message': '流式传输期间发生错误', 'error': str(e)})}\n\n" - - finally: - # Resource cleanup - logger.info("Log streaming completed, cleaning up resources") - # LogStreamer uses context manager for file handling, so cleanup is automatic - - async def write_memory(self, group_id: str, message: str, config_id: str,storage_type:str,user_rag_memory_id:str) -> str: - """ - Process write operation with config_id - - Args: - group_id: Group identifier - message: Message to write - config_id: Configuration ID from database - - Returns: - Write operation result status - - Raises: - ValueError: If config loading fails or write operation fails - """ - if config_id==None: - config_id = os.getenv("config_id") - import time - start_time = time.time() - - # 如果 config_id 为 None,使用默认值 "17" - config_loaded = reload_configuration_from_database(config_id) - if not config_loaded: - error_msg = f"Failed to load configuration for config_id: {config_id}" - logger.error(error_msg) - - # 记录失败的操作 - if audit_logger: - duration = time.time() - start_time - audit_logger.log_operation( operation="WRITE", config_id=config_id, group_id=group_id, success=False, duration=duration, error=error_msg ) - - raise ValueError(error_msg) - logger.info(f"Configuration loaded successfully for config_id: {config_id}") - mcp_config = get_mcp_server_config() - client = MultiServerMCPClient(mcp_config) - - if storage_type == "rag": - result = await write_rag(group_id, message, user_rag_memory_id) - return result - else: - async with client.session("data_flow") as session: - logger.debug("Connected to MCP Server: data_flow") - tools = await load_mcp_tools(session) - - # Pass config_id to the graph workflow - async with make_write_graph(group_id, tools, group_id, group_id, config_id=config_id) as graph: - logger.debug("Write graph created successfully") - - config = {"configurable": {"thread_id": group_id}} - - async for event in graph.astream( - {"messages": message, "config_id": config_id}, - stream_mode="values", - config=config - ): - messages = event.get('messages') - return self.writer_messages_deal(messages,start_time,group_id,config_id,message) - - async def read_memory( - self, - group_id: str, - message: str, - history: List[Dict], - search_switch: str, - config_id: str, - storage_type: str, - user_rag_memory_id: str - ) -> Dict: - """ - Process read operation with config_id - - search_switch values: - - "0": Requires verification - - "1": No verification, direct split - - "2": Direct answer based on context - - Args: - group_id: Group identifier - message: User message - history: Conversation history - search_switch: Search mode switch - config_id: Configuration ID from database - - Returns: - Dict with 'answer' and 'intermediate_outputs' keys - - Raises: - ValueError: If config loading fails - """ - - import time - start_time = time.time() - - if config_id==None: - config_id = os.getenv("config_id") - - logger.info(f"Read operation for group {group_id} with config_id {config_id}") - - # 导入审计日志记录器 - try: - from app.core.memory.utils.log.audit_logger import audit_logger - except ImportError: - audit_logger = None - - # Get group lock to prevent concurrent processing - group_lock = self.get_group_lock(group_id) - - with group_lock: - # Step 1: Load configuration from database - from app.core.memory.utils.config.definitions import reload_configuration_from_database - - config_loaded = reload_configuration_from_database(config_id) - if not config_loaded: - error_msg = f"Failed to load configuration for config_id: {config_id}" - logger.error(error_msg) - - # 记录失败的操作 - if audit_logger: - duration = time.time() - start_time - audit_logger.log_operation( - operation="READ", - config_id=config_id, - group_id=group_id, - success=False, - duration=duration, - error=error_msg - ) - - raise ValueError(error_msg) - - logger.info(f"Configuration loaded successfully for config_id: {config_id}") - - # Step 2: Prepare history - history.append({"role": "user", "content": message}) - logger.debug(f"Group ID:{group_id}, Message:{message}, History:{history}, Config ID:{config_id}") - - # Step 3: Initialize MCP client and execute read workflow - mcp_config = get_mcp_server_config() - client = MultiServerMCPClient(mcp_config) - - async with client.session('data_flow') as session: - logger.debug("Connected to MCP Server: data_flow") - tools = await load_mcp_tools(session) - outputs = [] - intermediate_outputs = [] - seen_intermediates = set() # Track seen intermediate outputs to avoid duplicates - - # Pass config_id to the graph workflow - async with make_read_graph(group_id, tools, search_switch, group_id, group_id, config_id=config_id,storage_type=storage_type,user_rag_memory_id=user_rag_memory_id) as graph: - start = time.time() - config = {"configurable": {"thread_id": group_id}} - - async for event in graph.astream( - {"messages": history, "config_id": config_id}, - stream_mode="values", - config=config - ): - messages = event.get('messages') - for msg in messages: - msg_content = msg.content - outputs.append({ - "role": msg.__class__.__name__.lower().replace("message", ""), - "content": msg_content - }) - - # Extract intermediate outputs - if hasattr(msg, 'content'): - try: - # Debug: log message type and content preview - msg_type = msg.__class__.__name__ - content_preview = str(msg_content)[:200] if msg_content else "empty" - logger.debug(f"Processing message type={msg_type}, content preview={content_preview}") - - # Try to parse content as JSON - if isinstance(msg_content, str): - try: - parsed = json.loads(msg_content) - if isinstance(parsed, dict): - # Debug: log what keys are in parsed - logger.debug(f"Parsed dict keys: {list(parsed.keys())}") - - # Check for single intermediate output - if '_intermediate' in parsed: - intermediate_data = parsed['_intermediate'] - output_key = self._create_intermediate_key(intermediate_data) - logger.debug(f"Found _intermediate: {intermediate_data.get('type', 'unknown')}") - - if output_key not in seen_intermediates: - seen_intermediates.add(output_key) - intermediate_outputs.append(self._format_intermediate_output(intermediate_data)) - - # Check for multiple intermediate outputs (from Retrieve) - if '_intermediates' in parsed: - logger.debug(f"Found _intermediates list with {len(parsed['_intermediates'])} items") - for intermediate_data in parsed['_intermediates']: - output_key = self._create_intermediate_key(intermediate_data) - logger.debug(f"Processing intermediate: {intermediate_data.get('type', 'unknown')}") - - if output_key not in seen_intermediates: - seen_intermediates.add(output_key) - intermediate_outputs.append(self._format_intermediate_output(intermediate_data)) - except (json.JSONDecodeError, ValueError): - pass - elif isinstance(msg_content, dict): - # Check for single intermediate output - if '_intermediate' in msg_content: - intermediate_data = msg_content['_intermediate'] - output_key = self._create_intermediate_key(intermediate_data) - - if output_key not in seen_intermediates: - seen_intermediates.add(output_key) - intermediate_outputs.append(self._format_intermediate_output(intermediate_data)) - - # Check for multiple intermediate outputs (from Retrieve) - if '_intermediates' in msg_content: - for intermediate_data in msg_content['_intermediates']: - output_key = self._create_intermediate_key(intermediate_data) - - if output_key not in seen_intermediates: - seen_intermediates.add(output_key) - intermediate_outputs.append(self._format_intermediate_output(intermediate_data)) - except Exception as e: - logger.debug(f"Failed to extract intermediate output: {e}") - - workflow_duration = time.time() - start - logger.info(f"Read graph workflow completed in {workflow_duration}s") - - # Extract final answer - final_answer = "" - for messages in outputs: - if messages['role'] == 'tool': - message = messages['content'] - try: - message = json.loads(message) if isinstance(message, str) else message - if isinstance(message, dict) and message.get('status') != '': - summary_result = message.get('summary_result') - if summary_result: - final_answer = summary_result - except (json.JSONDecodeError, ValueError): - pass - - # 记录成功的操作 - total_duration = time.time() - start_time - if audit_logger: - audit_logger.log_operation( - operation="READ", - config_id=config_id, - group_id=group_id, - success=True, - duration=total_duration, - details={ - "search_switch": search_switch, - "history_length": len(history), - "intermediate_outputs_count": len(intermediate_outputs), - "has_answer": bool(final_answer) - } - ) - - return { - "answer": final_answer, - "intermediate_outputs": intermediate_outputs - } - - def _create_intermediate_key(self, output: Dict) -> str: - """ - Create a unique key for an intermediate output to detect duplicates. - - Args: - output: Intermediate output dictionary - - Returns: - Unique string key for this output - """ - output_type = output.get('type', 'unknown') - - if output_type == 'problem_split': - # Use type + original query as key - return f"split:{output.get('original_query', '')}" - elif output_type == 'problem_extension': - # Use type + original query as key - return f"extension:{output.get('original_query', '')}" - elif output_type == 'search_result': - # Use type + query + index as key - return f"search:{output.get('query', '')}:{output.get('index', 0)}" - elif output_type == 'retrieval_summary': - # Use type + query as key - return f"summary:{output.get('query', '')}" - elif output_type == 'verification': - # Use type + query as key - return f"verification:{output.get('query', '')}" - elif output_type == 'input_summary': - # Use type + query as key - return f"input_summary:{output.get('query', '')}" - else: - # Fallback: use JSON representation - import json - return json.dumps(output, sort_keys=True) - - def _format_intermediate_output(self, output: Dict) -> Dict: - """Format intermediate output for frontend display.""" - output_type = output.get('type', 'unknown') - - if output_type == 'problem_split': - return { - 'type': 'problem_split', - 'title': '问题拆分', - 'data': output.get('data', []), - 'original_query': output.get('original_query', '') - } - elif output_type == 'problem_extension': - return { - 'type': 'problem_extension', - 'title': '问题扩展', - 'data': output.get('data', {}), - 'original_query': output.get('original_query', '') - } - elif output_type == 'search_result': - return { - 'type': 'search_result', - 'title': f'检索结果 ({output.get("index", 0)}/{output.get("total", 0)})', - 'query': output.get('query', ''), - 'raw_results': output.get('raw_results', ''), - 'index': output.get('index', 0), - 'total': output.get('total', 0) - } - elif output_type == 'retrieval_summary': - return { - 'type': 'retrieval_summary', - 'title': '检索总结', - 'summary': output.get('summary', ''), - 'query': output.get('query', ''), - 'raw_results': output.get('raw_results'), - - } - elif output_type == 'verification': - return { - 'type': 'verification', - 'title': '数据验证', - 'result': output.get('result', 'unknown'), - 'reason': output.get('reason', ''), - 'query': output.get('query', ''), - 'verified_count': output.get('verified_count', 0) - } - elif output_type == 'input_summary': - return { - 'type': 'input_summary', - 'title': '快速答案', - 'summary': output.get('summary', ''), - 'query': output.get('query', ''), - 'raw_results': output.get('raw_results'), - - } - else: - return output - - async def classify_message_type(self, message: str) -> Dict: - """ - Determine the type of user message (read or write) - - Args: - message: User message to classify - - Returns: - Type classification result - """ - logger.info("Classifying message type") - - status = await status_typle(message) - logger.debug(f"Message type: {status}") - return status - - # ==================== 新增的三个接口方法 ==================== - - async def get_knowledge_type_stats( - self, - end_user_id: Optional[str] = None, - only_active: bool = True, - current_workspace_id: Optional[uuid.UUID] = None, - db: Session = None - ) -> Dict[str, Any]: - """ - 统计知识库类型分布,包含: - 1. PostgreSQL 中的知识库类型:General, Web, Third-party, Folder(根据 workspace_id 过滤) - 2. Neo4j 中的 memory 类型(仅统计 Chunk 数量,根据 end_user_id/group_id 过滤) - 3. total: 所有类型的总和 - - 参数: - - end_user_id: 用户组ID(可选,未提供时 memory 统计为 0) - - only_active: 是否仅统计有效记录 - - current_workspace_id: 当前工作空间ID(可选,未提供时知识库统计为 0) - - db: 数据库会话 - - 返回格式: - { - "General": count, - "Web": count, - "Third-party": count, - "Folder": count, - "memory": chunk_count, - "total": sum_of_all - } - """ - result = {} - - # 1. 统计 PostgreSQL 中的知识库类型 - try: - if db is None: - from app.db import get_db - db_gen = get_db() - db = next(db_gen) - - # 初始化所有标准类型为 0 - for kb_type in KnowledgeType: - result[kb_type.value] = 0 - - # 如果提供了 workspace_id,则按 workspace_id 过滤 - if current_workspace_id: - # 构建查询条件 - query = db.query( - Knowledge.type, - func.count(Knowledge.id).label('count') - ).filter(Knowledge.workspace_id == current_workspace_id) - - # 检查 Knowledge 模型是否有 status 字段 - if only_active and hasattr(Knowledge, 'status'): - query = query.filter(Knowledge.status == 1) - - # 按类型分组 - type_counts = query.group_by(Knowledge.type).all() - - # 只填充标准类型的统计值,忽略其他类型 - valid_types = {kb_type.value for kb_type in KnowledgeType} - for type_name, count in type_counts: - if type_name in valid_types: - result[type_name] = count - - logger.info(f"知识库类型统计成功 (workspace_id={current_workspace_id}): {result}") - else: - # 没有提供 workspace_id,所有知识库类型返回 0 - logger.info(f"未提供 workspace_id,知识库类型统计全部为 0") - - except Exception as e: - logger.error(f"知识库类型统计失败: {e}") - raise Exception(f"知识库类型统计失败: {e}") - - # 2. 统计 Neo4j 中的 memory 总量(统计当前空间下所有宿主的 Chunk 总数) - try: - if current_workspace_id: - # 获取当前空间下的所有宿主 - from app.repositories import app_repository, end_user_repository - from app.schemas.app_schema import App as AppSchema - from app.schemas.end_user_schema import EndUser as EndUserSchema - - # 查询应用并转换为 Pydantic 模型 - apps_orm = app_repository.get_apps_by_workspace_id(db, current_workspace_id) - apps = [AppSchema.model_validate(h) for h in apps_orm] - app_ids = [app.id for app in apps] - - # 获取所有宿主 - end_users = [] - for app_id in app_ids: - end_user_orm_list = end_user_repository.get_end_users_by_app_id(db, app_id) - end_users.extend([EndUserSchema.model_validate(h) for h in end_user_orm_list]) - - # 统计所有宿主的 Chunk 总数 - total_chunks = 0 - for end_user in end_users: - end_user_id_str = str(end_user.id) - memory_query = """ - MATCH (n:Chunk) WHERE n.group_id = $group_id RETURN count(n) AS Count - """ - neo4j_result = await _neo4j_connector.execute_query( - memory_query, - group_id=end_user_id_str, - ) - chunk_count = neo4j_result[0]["Count"] if neo4j_result else 0 - total_chunks += chunk_count - logger.debug(f"EndUser {end_user_id_str} Chunk数量: {chunk_count}") - - result["memory"] = total_chunks - logger.info(f"Neo4j memory统计成功: 总Chunk数={total_chunks}, 宿主数={len(end_users)}") - else: - # 没有 workspace_id 时,返回 0 - result["memory"] = 0 - logger.info(f"未提供 workspace_id,memory 统计为 0") - - except Exception as e: - logger.error(f"Neo4j memory统计失败: {e}", exc_info=True) - # 如果 Neo4j 查询失败,memory 设为 0 - result["memory"] = 0 - - # 3. 计算知识库类型总和(不包括 memory) - result["total"] = ( - result.get("General", 0) + - result.get("Web", 0) + - result.get("Third-party", 0) + - result.get("Folder", 0) - ) - - return result - - - async def get_hot_memory_tags_by_user( - self, - end_user_id: Optional[str] = None, - limit: int = 20 - ) -> List[Dict[str, Any]]: - """ - 获取指定用户的热门记忆标签 - - 参数: - - end_user_id: 用户ID(可选),对应Neo4j中的group_id字段 - - limit: 返回标签数量限制 - - 返回格式: - [ - {"name": "标签名", "frequency": 频次}, - ... - ] - """ - try: - # by_user=False 表示按 group_id 查询(在Neo4j中,group_id就是用户维度) - tags = await get_hot_memory_tags(end_user_id, limit=limit, by_user=False) - payload = [{"name": t, "frequency": f} for t, f in tags] - return payload - except Exception as e: - logger.error(f"热门记忆标签查询失败: {e}") - raise Exception(f"热门记忆标签查询失败: {e}") - - - async def get_user_profile( - self, - end_user_id: Optional[str] = None, - current_user_id: Optional[str] = None - ) -> Dict[str, Any]: - """ - 获取用户详情,包含: - 1. 用户名字(直接使用 end_user_name) - 2. 用户标签(从摘要中用LLM总结3个标签) - 3. 热门记忆标签(从hot_memory_tags获取前4个) - - 参数: - - end_user_id: 用户ID(可选) - - current_user_id: 当前登录用户的ID(保留参数) - - 返回格式: - { - "name": "用户名", - "tags": ["产品设计师", "旅行爱好者", "摄影发烧友"], - "hot_tags": [ - {"name": "标签1", "frequency": 10}, - {"name": "标签2", "frequency": 8}, - ... - ] - } - """ - result = {} - - # 1. 根据 end_user_id 获取 end_user_name - try: - if end_user_id: - from app.repositories import end_user_repository - from app.schemas.end_user_schema import EndUser as EndUserSchema - - end_user_orm = end_user_repository.get_end_user_by_id(db, end_user_id) - if end_user_orm: - end_user = EndUserSchema.model_validate(end_user_orm) - end_user_name = end_user.other_name - else: - end_user_name = "默认用户" - else: - end_user_name = "默认用户" - except Exception as e: - logger.error(f"Failed to get end_user_name: {e}") - end_user_name = "默认用户" - - result["name"] = end_user_name - logger.debug(f"The end_user is: {end_user_name}") - - # 2. 使用LLM从语句和实体中提取标签 - try: - connector = Neo4jConnector() - - # 查询该用户的语句 - query = ( - "MATCH (s:Statement) " - "WHERE ($group_id IS NULL OR s.group_id = $group_id) AND s.statement IS NOT NULL " - "RETURN s.statement AS statement " - "ORDER BY s.created_at DESC LIMIT 100" - ) - rows = await connector.execute_query(query, group_id=end_user_id) - statements = [r.get("statement", "") for r in rows if r.get("statement")] - - # 查询该用户的热门实体 - entity_query = ( - "MATCH (e:ExtractedEntity) " - "WHERE ($group_id IS NULL OR e.group_id = $group_id) AND e.entity_type <> '人物' AND e.name IS NOT NULL " - "RETURN e.name AS name, count(e) AS frequency " - "ORDER BY frequency DESC LIMIT 20" - ) - entity_rows = await connector.execute_query(entity_query, group_id=end_user_id) - entities = [f"{r['name']} ({r['frequency']})" for r in entity_rows] - - await connector.close() - - if not statements: - result["tags"] = [] - else: - # 构建摘要文本 - summary_text = f"用户语句样本:{' | '.join(statements[:20])}\n核心实体:{', '.join(entities)}" - logger.debug(f"User data found: {len(statements)} statements, {len(entities)} entities") - - # 使用LLM提取标签 - llm_client = get_llm_client() - - # 定义标签提取的结构 - class UserTags(BaseModel): - tags: list[str] = Field(..., description="3个描述用户特征的标签,如:产品设计师、旅行爱好者、摄影发烧友") - - messages = [ - { - "role": "system", - "content": "你是一个信息提取助手。从用户的语句和实体中提取3个最能代表用户特征的标签。标签应该简洁(2-6个字),描述用户的职业、兴趣或特点。" - }, - { - "role": "user", - "content": f"请从以下用户信息中提取3个标签:\n\n{summary_text}" - } - ] - - user_tags = await llm_client.response_structured( - messages=messages, - response_model=UserTags - ) - - result["tags"] = user_tags.tags - logger.debug(f"Extracted tags: {user_tags.tags}") - - except Exception as e: - # 如果提取失败,使用默认值 - logger.error(f"Failed to extract user tags: {e}") - result["tags"] = [] - - try: - # 3. 获取热门记忆标签(前4个) - connector = Neo4jConnector() - names_to_exclude = ['AI', 'Caroline', 'Melanie', 'Jon', 'Gina', '用户', 'AI助手', 'John', 'Maria'] - hot_tag_query = ( - "MATCH (e:ExtractedEntity) " - "WHERE ($group_id IS NULL OR e.group_id = $group_id) AND e.entity_type <> '人物' " - "AND e.name IS NOT NULL AND NOT e.name IN $names_to_exclude " - "RETURN e.name AS name, count(e) AS frequency " - "ORDER BY frequency DESC LIMIT 4" - ) - hot_tag_rows = await connector.execute_query( - hot_tag_query, - group_id=end_user_id, - names_to_exclude=names_to_exclude - ) - await connector.close() - - result["hot_tags"] = [{"name": r["name"], "frequency": r["frequency"]} for r in hot_tag_rows] - logger.debug(f"Hot tags found: {len(result['hot_tags'])} tags") - except Exception as e: - logger.error(f"Failed to get hot tags: {e}") - result["hot_tags"] = [] - - return result - - async def stream_log_content(self) -> AsyncGenerator[str, None]: - """ - Stream log content in real-time using Server-Sent Events (SSE) - - This method establishes a streaming connection and transmits log entries - as they are written to the log file. It uses the LogStreamer to watch - the file and yields SSE-formatted messages. - - Yields: - SSE-formatted strings with the following event types: - - log: Contains log content and timestamp - - keepalive: Periodic keepalive messages to maintain connection - - error: Error information if streaming fails - - done: Indicates streaming has completed - - Raises: - FileNotFoundError: If log file doesn't exist at stream start - Exception: For other unexpected errors during streaming - """ - logger.info("Starting log content streaming") - - # Get log file path - use project root directory - current_file = os.path.abspath(__file__) # app/services/memory_agent_service.py - app_dir = os.path.dirname(os.path.dirname(current_file)) # app directory - project_root = os.path.dirname(app_dir) # redbear-mem directory - log_path = os.path.join(project_root, "logs", "agent_service.log") - - # Check if file exists before starting stream - if not os.path.exists(log_path): - logger.error(f"Log file not found: {log_path}") - # Send error event in SSE format - yield f"event: error\ndata: {json.dumps({'code': 4006, 'message': '日志文件不存在', 'error': f'File not found: {log_path}'})}\n\n" - return - - streamer = None - try: - # Initialize LogStreamer with keepalive interval from settings (default 300 seconds) - keepalive_interval = getattr(settings, 'LOG_STREAM_KEEPALIVE_INTERVAL', 300) - streamer = LogStreamer(log_path, keepalive_interval=keepalive_interval) - - logger.info(f"LogStreamer initialized for {log_path}") - - # Stream log content using read_existing_and_stream to get all existing content first - async for message in streamer.read_existing_and_stream(): - event_type = message.get("event") - data = message.get("data") - - # Format as SSE message - # SSE format: "event: <type>\ndata: <json_data>\n\n" - sse_message = f"event: {event_type}\ndata: {json.dumps(data)}\n\n" - - logger.debug(f"Streaming event: {event_type}") - yield sse_message - - # If error or done event, stop streaming - if event_type in ["error", "done"]: - logger.info(f"Stream ended with event: {event_type}") - break - - except FileNotFoundError as e: - logger.error(f"Log file not found during streaming: {e}") - yield f"event: error\ndata: {json.dumps({'code': 4006, 'message': '日志文件在流式传输期间变得不可用', 'error': str(e)})}\n\n" - - except Exception as e: - logger.error(f"Unexpected error during log streaming: {e}", exc_info=True) - yield f"event: error\ndata: {json.dumps({'code': 8001, 'message': '流式传输期间发生错误', 'error': str(e)})}\n\n" - - finally: - # Resource cleanup - logger.info("Log streaming completed, cleaning up resources") - # LogStreamer uses context manager for file handling, so cleanup is automatic - -# async def get_api_docs(self, file_path: Optional[str] = None) -> Dict[str, Any]: -# """ -# Parse and return API documentation - -# Args: -# file_path: Optional path to API docs file. If None, uses default path. - -# Returns: -# Dict containing parsed API documentation or error information -# """ -# try: -# target = file_path or get_default_docs_path() - -# if not os.path.isfile(target): -# return { -# "success": False, -# "msg": "API文档文件不存在", -# "error_code": "DOC_NOT_FOUND", -# "data": {"path": target} -# } - -# data = parse_api_docs(target) -# return { -# "success": True, -# "msg": "解析成功", -# "data": data -# } -# except Exception as e: -# logger.error(f"Failed to parse API docs: {e}") -# return { -# "success": False, -# "msg": "解析失败", -# "error_code": "DOC_PARSE_ERROR", -# "data": {"error": str(e)} -# } \ No newline at end of file diff --git a/app/services/memory_dashboard_service.py b/app/services/memory_dashboard_service.py deleted file mode 100644 index 31a6db3d..00000000 --- a/app/services/memory_dashboard_service.py +++ /dev/null @@ -1,595 +0,0 @@ -from sqlalchemy.orm import Session -from typing import List -import uuid -from fastapi import HTTPException - -from app.models.user_model import User -from app.models.app_model import App -from app.models.end_user_model import EndUser -from app.models.memory_increment_model import MemoryIncrement - -from app.repositories import ( - app_repository, - end_user_repository, - memory_increment_repository, - knowledge_repository -) -from app.schemas.end_user_schema import EndUser as EndUserSchema -from app.schemas.memory_increment_schema import MemoryIncrement as MemoryIncrementSchema -from app.schemas.app_schema import App as AppSchema -from app.core.logging_config import get_business_logger - - -# 获取业务逻辑专用日志器 -business_logger = get_business_logger() - - -def get_workspace_end_users( - db: Session, - workspace_id: uuid.UUID, - current_user: User -) -> List[EndUser]: - """获取工作空间的所有宿主""" - business_logger.info(f"获取工作空间宿主列表: workspace_id={workspace_id}, 操作者: {current_user.username}") - - try: - # 查询应用(ORM)并转换为 Pydantic 模型 - apps_orm = app_repository.get_apps_by_workspace_id(db, workspace_id) - apps = [AppSchema.model_validate(h) for h in apps_orm] - app_ids = [app.id for app in apps] - end_users = [] - for app_id in app_ids: - end_user_orm_list = end_user_repository.get_end_users_by_app_id(db, app_id) - end_users.extend([EndUserSchema.model_validate(h) for h in end_user_orm_list]) - - business_logger.info(f"成功获取 {len(end_users)} 个宿主记录") - return end_users - - except HTTPException: - raise - except Exception as e: - business_logger.error(f"获取工作空间宿主列表失败: workspace_id={workspace_id} - {str(e)}") - raise - - -def get_workspace_memory_increment( - db: Session, - workspace_id: uuid.UUID, - limit: int, - current_user: User -) -> List[MemoryIncrementSchema]: - """获取工作空间的记忆增量""" - business_logger.info(f"获取工作空间记忆增量: workspace_id={workspace_id}, 操作者: {current_user.username}") - - try: - # 查询记忆增量 - memory_increment_orm_list = memory_increment_repository.get_memory_increments_by_workspace_id(db, workspace_id, limit) - memory_increment = [MemoryIncrementSchema.model_validate(m) for m in memory_increment_orm_list] - - business_logger.info(f"成功获取 {len(memory_increment)} 条记忆增量记录") - return memory_increment - - except HTTPException: - raise - except Exception as e: - business_logger.error(f"获取工作空间记忆增量失败: workspace_id={workspace_id} - {str(e)}") - raise - - -def get_workspace_api_increment( - db: Session, - workspace_id: uuid.UUID, - current_user: User -) -> int: - """获取工作空间的API调用增量""" - business_logger.info(f"获取工作空间API调用增量: workspace_id={workspace_id}, 操作者: {current_user.username}") - - try: - # 查询API调用增量 - api_increment = 856 - - business_logger.info(f"成功获取 {api_increment} API调用增量") - return api_increment - - except HTTPException: - raise - except Exception as e: - business_logger.error(f"获取工作空间API调用增量失败: workspace_id={workspace_id} - {str(e)}") - raise - - -def write_workspace_total_memory( - db: Session, - workspace_id: uuid.UUID, - current_user: User -) -> int: - """写入工作空间的记忆总量""" - business_logger.info(f"写入工作空间记忆总量: workspace_id={workspace_id}, 操作者: {current_user.username}") - - try: - # 模拟记忆总量 - total_num = 1024 - - # 写入记忆总量 - memory_increment_repository.write_memory_increment(db, workspace_id, total_num) - - business_logger.info(f"成功写入记忆总量 {total_num}") - return total_num - - except HTTPException: - raise - except Exception as e: - business_logger.error(f"写入工作空间记忆总量失败: workspace_id={workspace_id} - {str(e)}") - raise - - -def get_workspace_memory_list( - db: Session, - workspace_id: uuid.UUID, - current_user: User, - limit: int = 7 -) -> dict: - """ - 获取工作空间的记忆列表(整合接口) - - 整合以下三个接口的数据: - 1. total_memory - 工作空间记忆总量 - 2. memory_increment - 工作空间记忆增量 - 3. hosts - 工作空间宿主列表 - """ - business_logger.info(f"获取工作空间记忆列表: workspace_id={workspace_id}, 操作者: {current_user.username}") - - result = {} - - try: - # 1. 获取记忆总量 - try: - total_memory = write_workspace_total_memory(db, workspace_id, current_user) - result["total_memory"] = total_memory - business_logger.info(f"成功获取记忆总量: {total_memory}") - except Exception as e: - business_logger.warning(f"获取记忆总量失败: {str(e)}") - result["total_memory"] = 0.0 - - # 2. 获取记忆增量 - try: - memory_increment = get_workspace_memory_increment(db, workspace_id, limit, current_user) - result["memory_increment"] = memory_increment - business_logger.info(f"成功获取 {len(memory_increment)} 条记忆增量记录") - except Exception as e: - business_logger.warning(f"获取记忆增量失败: {str(e)}") - result["memory_increment"] = [] - - # 3. 获取宿主列表 - try: - hosts = get_workspace_end_users(db, workspace_id, current_user) - result["hosts"] = hosts - business_logger.info(f"成功获取 {len(hosts)} 个宿主记录") - except Exception as e: - business_logger.warning(f"获取宿主列表失败: {str(e)}") - result["hosts"] = [] - - business_logger.info(f"成功获取工作空间记忆列表") - return result - - except HTTPException: - raise - except Exception as e: - business_logger.error(f"获取工作空间记忆列表失败: workspace_id={workspace_id} - {str(e)}") - raise - - -def get_workspace_total_end_users( - db: Session, - workspace_id: uuid.UUID, - current_user: User -) -> dict: - """ - 获取用户列表的总用户数 - """ - business_logger.info(f"获取用户列表的总用户数: workspace_id={workspace_id}, 操作者: {current_user.username}") - - try: - # 复用原有的 get_workspace_end_users 逻辑 - end_users = get_workspace_end_users(db, workspace_id, current_user) - - business_logger.info(f"成功获取 {len(end_users)} 个宿主记录") - return { - "total_num": len(end_users), - "online_num": len(end_users) - } - - except HTTPException: - raise - except Exception as e: - business_logger.error(f"获取用户列表失败: workspace_id={workspace_id} - {str(e)}") - raise - - -async def get_workspace_total_memory_count( - db: Session, - workspace_id: uuid.UUID, - current_user: User, - end_user_id: str = None -) -> dict: - """ - 获取工作空间的记忆总量(通过聚合所有host的记忆数) - - 逻辑: - 1. 从 memory_list 获取所有 host_id - 2. 对每个 host_id 调用 search_all 获取 total - 3. 将所有 total 求和返回 - """ - business_logger.info(f"获取工作空间记忆总量: workspace_id={workspace_id}, 操作者: {current_user.username}") - - try: - # 1. 获取所有 hosts - hosts = get_workspace_end_users(db, workspace_id, current_user) - business_logger.info(f"获取到 {len(hosts)} 个宿主") - - if not hosts: - business_logger.warning("未找到任何宿主,返回0") - return { - "total_memory_count": 0, - "host_count": 0, - "details": [] - } - - # 2. 对每个 host_id 调用 search_all 获取 total - from app.services import memory_storage_service - - total_count = 0 - details = [] - - # 如果提供了 end_user_id,只查询该用户 - if end_user_id: - search_result = await memory_storage_service.search_all(end_user_id=end_user_id) - return { - "total_memory_count": search_result.get("total", 0), - "host_count": 1, - "details": [{"end_user_id": end_user_id, "count": search_result.get("total", 0)}] - } - - for host in hosts: - try: - end_user_id_str = str(host.id) - - search_result = await memory_storage_service.search_all( - end_user_id=end_user_id_str - ) - - host_total = search_result.get("total", 0) - total_count += host_total - - details.append({ - "end_user_id": end_user_id_str, - "count": host_total - }) - - business_logger.debug(f"EndUser {end_user_id_str} 记忆数: {host_total}") - - except Exception as e: - business_logger.warning(f"获取 end_user {host.id} 记忆数失败: {str(e)}") - # 失败的 host 记为 0 - details.append({ - "end_user_id": str(host.id), - "count": 0 - }) - - result = { - "total_memory_count": total_count, - "host_count": len(hosts), - "details": details - } - - business_logger.info(f"成功获取工作空间记忆总量: {total_count} (来自 {len(hosts)} 个宿主)") - return result - - except HTTPException: - raise - except Exception as e: - business_logger.error(f"获取工作空间记忆总量失败: workspace_id={workspace_id} - {str(e)}") - raise - - -# ======== RAG 相关服务 ======== -def get_rag_total_doc( - db: Session, - current_user: User -) -> int: - """ - 根据当前用户所在的workspace_id查询konwledges表所有doc_num的总和 - """ - workspace_id = current_user.current_workspace_id - business_logger.info(f"获取RAG总文档数: workspace_id={workspace_id}, 操作者: {current_user.username}") - - try: - total_doc = knowledge_repository.get_total_doc_num_by_workspace(db, workspace_id) - business_logger.info(f"成功获取RAG总文档数: {total_doc}") - return total_doc - except Exception as e: - business_logger.error(f"获取RAG总文档数失败: workspace_id={workspace_id} - {str(e)}") - raise - - -def get_rag_total_chunk( - db: Session, - current_user: User -) -> int: - """ - 根据当前用户所在的workspace_id查询konwledges表所有chunk_num的总和 - """ - workspace_id = current_user.current_workspace_id - business_logger.info(f"获取RAG总chunk数: workspace_id={workspace_id}, 操作者: {current_user.username}") - - try: - total_chunk = knowledge_repository.get_total_chunk_num_by_workspace(db, workspace_id) - business_logger.info(f"成功获取RAG总chunk数: {total_chunk}") - return total_chunk - except Exception as e: - business_logger.error(f"获取RAG总chunk数失败: workspace_id={workspace_id} - {str(e)}") - raise - - -def get_rag_total_kb( - db: Session, - current_user: User -) -> int: - """ - 根据当前用户所在的workspace_id查询konwledges表所有不同id的数量 - """ - workspace_id = current_user.current_workspace_id - business_logger.info(f"获取RAG总知识库数: workspace_id={workspace_id}, 操作者: {current_user.username}") - - try: - total_kb = knowledge_repository.get_total_kb_count_by_workspace(db, workspace_id) - business_logger.info(f"成功获取RAG总知识库数: {total_kb}") - return total_kb - except Exception as e: - business_logger.error(f"获取RAG总知识库数失败: workspace_id={workspace_id} - {str(e)}") - raise - -def get_current_user_total_chunk( - end_user_id: str, - db: Session, - current_user: User -) -> int: - """ - 计算documents表中file_name=='end_user_id'+'.txt'的所有记录chunk_num的总和 - """ - business_logger.info(f"获取用户总chunk数: end_user_id={end_user_id}, 操作者: {current_user.username}") - - try: - from app.models.document_model import Document - from sqlalchemy import func - - # 构造文件名 - file_name = f"{end_user_id}.txt" - - # 查询并求和 - total_chunk = db.query(func.sum(Document.chunk_num)).filter( - Document.file_name == file_name - ).scalar() or 0 - - business_logger.info(f"成功获取用户总chunk数: {total_chunk} (file_name={file_name})") - return int(total_chunk) - - except Exception as e: - business_logger.error(f"获取用户总chunk数失败: end_user_id={end_user_id} - {str(e)}") - raise - -def get_rag_content( - end_user_id: str, - limit: int, - db: Session, - current_user: User -) -> dict: - """ - 先在documents表中查询file_name=='end_user_id'+'.txt'的id和kb_id, - 然后调用/chunks/{kb_id}/{document_id}/chunks接口的相关代码获取所有内容, - 接着对获取的内容进行提取,只要page_content的内容, - 最后返回数据 - """ - business_logger.info(f"获取RAG内容: end_user_id={end_user_id}, limit={limit}, 操作者: {current_user.username}") - - try: - from app.models.document_model import Document - from app.core.rag.vdb.elasticsearch.elasticsearch_vector import ElasticSearchVectorFactory - - # 1. 构造文件名 - file_name = f"{end_user_id}.txt" - - # 2. 查询documents表获取id和kb_id - documents = db.query(Document).filter( - Document.file_name == file_name - ).all() - - if not documents: - business_logger.warning(f"未找到文件: {file_name}") - return { - "total": 0, - "contents": [] - } - - business_logger.info(f"找到 {len(documents)} 个文档记录") - - # 3. 获取所有chunks的page_content - all_contents = [] - total_chunks = 0 - - for document in documents: - try: - # 获取知识库信息 - kb = knowledge_repository.get_knowledge_by_id(db, document.kb_id) - if not kb: - business_logger.warning(f"知识库不存在: kb_id={document.kb_id}") - continue - - # 初始化向量服务 - vector_service = ElasticSearchVectorFactory().init_vector(knowledge=kb) - - # 获取该文档的所有chunks(分页获取) - page = 1 - pagesize = 100 # 每页100条 - - while True: - total, items = vector_service.search_by_segment( - document_id=str(document.id), - query=None, - pagesize=pagesize, - page=page, - asc=True - ) - - if not items: - break - - # 提取page_content - for item in items: - all_contents.append(item.page_content) - total_chunks += 1 - - # # 如果达到limit限制,直接返回 - # if limit > 0 and total_chunks >= limit: - # business_logger.info(f"已达到limit限制: {limit}") - # return { - # "total": total_chunks, - # "contents": all_contents[:limit] - # } - - # 检查是否还有下一页 - if page * pagesize >= total: - break - - page += 1 - - business_logger.info(f"文档 {document.id} 获取了 {len(items)} 个chunks") - - except Exception as e: - business_logger.error(f"获取文档 {document.id} 的chunks失败: {str(e)}") - continue - - # 4. 返回结果 - result = { - "total": total_chunks, - "contents": all_contents[:limit] if limit > 0 else all_contents - } - - business_logger.info(f"成功获取RAG内容: total={total_chunks}, 返回={len(result['contents'])} 条") - return result - - except Exception as e: - business_logger.error(f"获取RAG内容失败: end_user_id={end_user_id} - {str(e)}") - raise - - -async def get_chunk_summary_and_tags( - end_user_id: str, - limit: int, - max_tags: int, - db: Session, - current_user: User -) -> dict: - """ - 获取chunk的总结、标签和人物形象 - - Args: - end_user_id: 宿主ID - limit: 返回的chunk数量限制 - max_tags: 最大标签数量 - db: 数据库会话 - current_user: 当前用户 - - Returns: - 包含summary、tags和personas的字典 - """ - business_logger.info(f"获取chunk摘要、标签和人物形象: end_user_id={end_user_id}, limit={limit}, 操作者: {current_user.username}") - - try: - # 1. 获取chunk内容 - rag_content = get_rag_content(end_user_id, limit, db, current_user) - chunks = rag_content.get("contents", []) - - if not chunks: - business_logger.warning(f"未找到chunk内容: end_user_id={end_user_id}") - return { - "summary": "暂无内容", - "tags": [], - "personas": [] - } - - # 2. 导入RAG工具函数 - from app.core.rag_utils import generate_chunk_summary, extract_chunk_tags, extract_chunk_persona - - # 3. 并发生成摘要、提取标签和人物形象 - import asyncio - summary_task = generate_chunk_summary(chunks, max_chunks=limit) - tags_task = extract_chunk_tags(chunks, max_tags=max_tags, max_chunks=limit) - personas_task = extract_chunk_persona(chunks, max_personas=5, max_chunks=limit) - - summary, tags_with_freq, personas = await asyncio.gather(summary_task, tags_task, personas_task) - - # 4. 格式化标签数据 - tags = [{"tag": tag, "frequency": freq} for tag, freq in tags_with_freq] - - result = { - "summary": summary, - "tags": tags, - "personas": personas - } - - business_logger.info(f"成功获取chunk摘要、{len(tags)} 个标签和 {len(personas)} 个人物形象") - return result - - except Exception as e: - business_logger.error(f"获取chunk摘要、标签和人物形象失败: end_user_id={end_user_id} - {str(e)}") - raise - - -async def get_chunk_insight( - end_user_id: str, - limit: int, - db: Session, - current_user: User -) -> dict: - """ - 获取chunk的洞察分析 - - Args: - end_user_id: 宿主ID - limit: 返回的chunk数量限制 - db: 数据库会话 - current_user: 当前用户 - - Returns: - 包含insight的字典 - """ - business_logger.info(f"获取chunk洞察: end_user_id={end_user_id}, limit={limit}, 操作者: {current_user.username}") - - try: - # 1. 获取chunk内容 - rag_content = get_rag_content(end_user_id, limit, db, current_user) - chunks = rag_content.get("contents", []) - - if not chunks: - business_logger.warning(f"未找到chunk内容: end_user_id={end_user_id}") - return { - "insight": "暂无足够数据生成洞察报告" - } - - # 2. 导入RAG工具函数 - from app.core.rag_utils import generate_chunk_insight - - # 3. 生成洞察 - insight = await generate_chunk_insight(chunks, max_chunks=limit) - - result = { - "insight": insight - } - - business_logger.info(f"成功获取chunk洞察") - return result - - except Exception as e: - business_logger.error(f"获取chunk洞察失败: end_user_id={end_user_id} - {str(e)}") - raise \ No newline at end of file diff --git a/app/services/memory_konwledges_server.py b/app/services/memory_konwledges_server.py deleted file mode 100644 index dd9163a7..00000000 --- a/app/services/memory_konwledges_server.py +++ /dev/null @@ -1,582 +0,0 @@ -# 修改 memory_konwledges_server.py 文件 - -import asyncio -import os -import re -import uuid -from pathlib import Path -from typing import Optional - -from pydantic import BaseModel, Field - -from app.core.rag.models.chunk import DocumentChunk -from app.core.rag.vdb.elasticsearch.elasticsearch_vector import ElasticSearchVectorFactory -from app.core.response_utils import success -from app.db import get_db -from app.schemas import file_schema, document_schema -from fastapi import APIRouter, Depends, HTTPException, status, File, UploadFile, Query -from app.models.document_model import Document -import uuid -from sqlalchemy.orm import Session -from fastapi import HTTPException, status - -from app.core.config import settings -from app.models.user_model import User -from app.schemas.file_schema import CustomTextFileCreate -from app.services import document_service, file_service, knowledge_service -from app.celery_app import celery_app -from app.core.logging_config import get_api_logger -from app.schemas.file_schema import CustomTextFileCreate -from app.db import get_db -# 创建一个简单的用户类用于测试 -api_logger = get_api_logger() - -class ChunkCreate(BaseModel): - content: str -class SimpleUser: - def __init__(self, user_id: str): - # 确保ID是UUID类型 - self.id = user_id - self.username = user_id - -'''解析''' -async def parse_document_by_id(document_id: uuid.UUID, db: Session, current_user: User): - """ - 解析指定文档 - - Args: - document_id: 文档ID - db: 数据库会话 - current_user: 当前用户 - - Returns: - dict: 包含任务ID的结果字典 - - Raises: - HTTPException: 当文档、文件或知识库不存在时抛出异常 - """ - - try: - # 1. 检查文档是否存在 - api_logger.debug(f"检查文档是否存在: {document_id}") - db_document = document_service.get_document_by_id(db, document_id=document_id, current_user=current_user) - - if not db_document: - api_logger.warning(f"文档不存在或无访问权限: document_id={document_id}") - raise HTTPException( - status_code=status.HTTP_404_NOT_FOUND, - detail="文档不存在或无访问权限" - ) - - # 2. 检查文件是否存在 - api_logger.debug(f"检查文件是否存在: {db_document.file_id}") - db_file = file_service.get_file_by_id(db, file_id=db_document.file_id) - - if not db_file: - api_logger.warning(f"文件不存在或无访问权限: file_id={db_document.file_id}") - raise HTTPException( - status_code=status.HTTP_404_NOT_FOUND, - detail="文件不存在或无访问权限" - ) - - # 3. 构建文件路径:/files/{kb_id}/{parent_id}/{file.id}{file.file_ext} - file_path = os.path.join( - settings.FILE_PATH, - str(db_file.kb_id), - str(db_file.parent_id), - f"{db_file.id}{db_file.file_ext}" - ) - - # 4. 检查文件是否存在于磁盘上 - if not os.path.exists(file_path): - api_logger.warning(f"文件未找到(可能已被删除): file_path={file_path}") - raise HTTPException( - status_code=status.HTTP_404_NOT_FOUND, - detail="文件未找到(可能已被删除)" - ) - - # 5. 获取知识库信息 - api_logger.info(f"获取知识库详情: knowledge_id={db_document.kb_id}") - db_knowledge = knowledge_service.get_knowledge_by_id(db, knowledge_id=db_document.kb_id, - current_user=current_user) - if not db_knowledge: - api_logger.warning(f"知识库不存在或访问被拒绝: knowledge_id={db_document.kb_id}") - raise HTTPException( - status_code=status.HTTP_404_NOT_FOUND, - detail="知识库不存在或访问被拒绝" - ) - - # 6. 发送解析任务到Celery后台队列 - task = celery_app.send_task("app.core.rag.tasks.parse_document", args=[file_path, document_id]) - - result = { - "task_id": task.id - } - - api_logger.info(f"文档解析任务已接受: document_id={document_id}, task_id={task.id}") - return result - - except Exception as e: - api_logger.error(f"文档解析失败: document_id={document_id} - {str(e)}") - raise - -'''获取块ID''' -async def get_document_chunks( - kb_id: uuid.UUID, - document_id: uuid.UUID, - page: int = 1, - pagesize: int = 20, - keywords: Optional[str] = None, - db: Session = None, - current_user: User = None -): - """ - 分页查询文档块列表 - - Args: - kb_id: 知识库ID - document_id: 文档ID - page: 页码,默认为1 - pagesize: 每页大小,默认为20 - keywords: 用于匹配块内容的关键字 - db: 数据库会话 - current_user: 当前用户 - - Returns: - dict: 包含分页数据的响应结果 - - Raises: - HTTPException: 当知识库不存在或查询失败时抛出异常 - """ - api_logger.info( - f"分页查询文档块列表: kb_id={kb_id}, document_id={document_id}, page={page}, pagesize={pagesize}, keywords={keywords}, username: {current_user.username}") - - # 参数验证 - if page < 1 or pagesize < 1: - raise HTTPException( - status_code=status.HTTP_400_BAD_REQUEST, - detail="分页参数必须大于0" - ) - - # 获取知识库信息 - db_knowledge = knowledge_service.get_knowledge_by_id(db, knowledge_id=kb_id, current_user=current_user) - if not db_knowledge: - raise HTTPException( - status_code=status.HTTP_404_NOT_FOUND, - detail="知识库不存在或访问被拒绝" - ) - - # 执行分页查询 - try: - api_logger.debug(f"开始执行文档块查询") - vector_service = ElasticSearchVectorFactory().init_vector(knowledge=db_knowledge) - total, items = vector_service.search_by_segment( - document_id=str(document_id), - query=keywords, - pagesize=pagesize, - page=page, - asc=True - ) - api_logger.info(f"文档块查询成功: total={total}, returned={len(items)} records") - except Exception as e: - api_logger.error(f"文档块查询失败: {str(e)}") - raise HTTPException( - status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, - detail=f"查询失败: {str(e)}" - ) - - # 构造响应结果 - result = { - "items": items, - "page": { - "page": page, - "pagesize": pagesize, - "total": total, - "has_next": True if page * pagesize < total else False - } - } - - return success(data=result, msg="文档块列表查询成功") - -'''查找文档ID''' -def find_document_id_by_kb_and_filename( - db: Session, - kb_id: str, - file_name: str -) -> str | None: - """ - 通过 kb_id 和 file_name 在 documents 表中查找对应的 ID - - Args: - db: 数据库会话 - kb_id: 知识库ID - file_name: 文件名 - - Returns: - str | None: 找到的 document ID,未找到返回 None - """ - try: - # 查询 documents 表 - document = db.query(Document).filter( - Document.kb_id == kb_id, - Document.file_name == file_name - ).first() - - if document: - print(f"找到文档: ID={document.id}, kb_id={kb_id}, file_name={file_name}") - return str(document.id) - else: - return None - - except Exception as e: - return None - -'''获取知识库ID''' -def find_documents_by_kb_id( - db: Session, - kb_id: str, - limit: int = 10 -) -> list[dict]: - """ - 通过 kb_id 查找所有相关文档 - - Args: - db: 数据库会话 - kb_id: 知识库ID - limit: 返回结果数量限制 - - Returns: - list[dict]: 文档列表,包含 id, file_name, created_at 等信息 - """ - try: - documents = db.query(Document).filter( - Document.kb_id == kb_id - ).limit(limit).all() - - result = [] - for doc in documents: - result.append({ - "id": str(doc.id), - "file_name": doc.file_name, - "file_ext": doc.file_ext, - "file_size": doc.file_size, - "created_at": doc.created_at.isoformat() if doc.created_at else None, - "status": getattr(doc, 'status', None) - }) - return result - - except Exception as e: - return [] - -''''上传文件''' -async def memory_konwledges_up( - kb_id: str, - parent_id: str, - create_data: file_schema.CustomTextFileCreate, - db: Session = Depends(get_db), - current_user: SimpleUser = None, # 修改为SimpleUser -): - # 如果没有提供current_user,则创建一个默认的 - if current_user is None: - current_user = SimpleUser("5d27df0b-7eec-4fa6-9f8b-0f9b7e852f60") - - content_bytes = create_data.content.encode('utf-8') - file_size = len(content_bytes) - print(f"file size: {file_size} byte") - - if file_size == 0: - raise HTTPException( - status_code=status.HTTP_400_BAD_REQUEST, - detail="The content is empty." - ) - - # If the file size exceeds 50MB (50 * 1024 * 1024 bytes) - if file_size > settings.MAX_FILE_SIZE: - raise HTTPException( - status_code=status.HTTP_400_BAD_REQUEST, - detail=f"The content size exceeds the {settings.MAX_FILE_SIZE}byte limit" - ) - - upload_file = file_schema.FileCreate( - kb_id=kb_id, - created_by=current_user.id, # 现在是UUID类型 - parent_id=parent_id, - file_name=f"{create_data.title}.txt", - file_ext=".txt", - file_size=file_size, - ) - db_file = file_service.create_file(db=db, file=upload_file, current_user=current_user) - - # Construct a save path:/files/{kb_id}/{parent_id}/{file.id}{file_extension} - # 使用 settings.FILE_PATH 确保与 parse_document_by_id 一致 - save_dir = os.path.join(settings.FILE_PATH, str(kb_id), str(parent_id)) - - # 确保目录存在 - Path(save_dir).mkdir(parents=True, exist_ok=True) - - save_path = os.path.join(save_dir, f"{db_file.id}.txt") - - # Save file - with open(save_path, "wb") as f: - f.write(content_bytes) - - # Verify whether the file has been saved successfully - if not os.path.exists(save_path): - raise HTTPException( - status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, - detail="File save failed" - ) - - # Create a document - create_document_data = document_schema.DocumentCreate( - kb_id=kb_id, - created_by=current_user.id, - file_id=db_file.id, - file_name=db_file.file_name, - file_ext=db_file.file_ext, - file_size=db_file.file_size, - file_meta={}, - parser_id="naive", - parser_config={ - "layout_recognize": "DeepDOC", - "chunk_token_num": 128, - "delimiter": "\n", - "auto_keywords": 0, - "auto_questions": 0, - "html4excel": "false" - } - ) - db_document = document_service.create_document(db=db, document=create_document_data, current_user=current_user) - - return success(data=document_schema.Document.model_validate(db_document), msg="custom text upload successful") - -'''添加新块''' - - -async def create_document_chunk( - kb_id: uuid.UUID, - document_id: uuid.UUID, - create_data: ChunkCreate, - db: Session, - current_user: User -): - """ - 创建文档块 - - Args: - kb_id: 知识库ID - document_id: 文档ID - create_data: 创建数据 - db: 数据库会话 - current_user: 当前用户 - - Returns: - dict: 包含创建的文档块信息的成功响应 - - Raises: - HTTPException: 当知识库或文档不存在时抛出相应异常 - """ - api_logger.info( - f"创建文档块请求: kb_id={kb_id}, document_id={document_id}, content={create_data.content}, username: {current_user.username}") - - # 1. 获取知识库信息 - db_knowledge = knowledge_service.get_knowledge_by_id(db, knowledge_id=kb_id, current_user=current_user) - if not db_knowledge: - raise HTTPException( - status_code=status.HTTP_404_NOT_FOUND, - detail="知识库不存在或访问被拒绝" - ) - - # 2. 获取文档信息 - db_document = db.query(Document).filter(Document.id == document_id).first() - if not db_document: - raise HTTPException( - status_code=status.HTTP_404_NOT_FOUND, - detail="文档不存在或您无访问权限" - ) - - # 3. 初始化向量服务 - vector_service = ElasticSearchVectorFactory().init_vector(knowledge=db_knowledge) - - # 4. 获取排序ID(处理索引不存在的情况) - sort_id = 0 - try: - total, items = vector_service.search_by_segment(document_id=str(document_id), pagesize=1, page=1, asc=False) - if items: - sort_id = items[0].metadata["sort_id"] - except Exception as e: - # 如果索引不存在,从 0 开始 - error_msg = str(e) - if "index_not_found_exception" in error_msg or "no such index" in error_msg: - api_logger.warning(f"索引不存在,将从 sort_id=0 开始: {error_msg}") - sort_id = 0 - else: - # 其他错误则抛出 - api_logger.error(f"查询文档块失败: {error_msg}") - raise HTTPException( - status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, - detail=f"查询文档块失败: {error_msg}" - ) - - sort_id = sort_id + 1 - - # 5. 创建文档块 - doc_id = uuid.uuid4().hex - metadata = { - "doc_id": doc_id, - "file_id": str(db_document.file_id), - "file_name": db_document.file_name, - "file_created_at": int(db_document.created_at.timestamp() * 1000), - "document_id": str(document_id), - "knowledge_id": str(kb_id), - "sort_id": sort_id, - "status": 1, - } - chunk = DocumentChunk(page_content=create_data.content, metadata=metadata) - - # 6. 存储向量化的文档块(这会自动创建索引如果不存在) - try: - vector_service.add_chunks([chunk]) - except Exception as e: - api_logger.error(f"添加文档块到向量库失败: {str(e)}") - raise HTTPException( - status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, - detail=f"添加文档块到向量库失败: {str(e)}" - ) - - # 7. 更新 chunk_num - db_document.chunk_num += 1 - db.commit() - - return success(data=chunk, msg="文档块创建成功") - -async def write_rag(group_id, message, user_rag_memory_id): - """ - 将消息写入 RAG 知识库 - - Args: - group_id: 组ID,用作文件标题 - message: 消息内容 - user_rag_memory_id: 知识库ID(必须是有效的UUID) - - Returns: - 写入结果 - - Raises: - HTTPException: 当参数无效或操作失败时 - """ - # 验证 user_rag_memory_id 是否为有效的 UUID - if not user_rag_memory_id: - api_logger.error("user_rag_memory_id 为空,无法执行 RAG 写入操作") - raise HTTPException( - status_code=status.HTTP_400_BAD_REQUEST, - detail="知识库ID不能为空" - ) - - try: - # 尝试将字符串转换为 UUID 以验证格式 - kb_uuid = uuid.UUID(user_rag_memory_id) - except (ValueError, AttributeError) as e: - api_logger.error(f"user_rag_memory_id 不是有效的UUID: {user_rag_memory_id}") - raise HTTPException( - status_code=status.HTTP_400_BAD_REQUEST, - detail=f"知识库ID格式无效: {user_rag_memory_id}" - ) - - db_gen = get_db() - db = next(db_gen) - - try: - create_data = CustomTextFileCreate(title=group_id, content=message) - current_user = SimpleUser(user_rag_memory_id) - # 检查文档是否已存在 - document = find_document_id_by_kb_and_filename(db=db, kb_id=user_rag_memory_id, file_name=f"{group_id}.txt") - print('======',document) - api_logger.info(f"查找文档结果: document_id={document}") - if document is not None: - # 文档已存在,直接添加新块 - api_logger.info(f"文档已存在,添加新块: document_id={document}") - - create_chunks = ChunkCreate(content=message) - result = await create_document_chunk( - kb_id=kb_uuid, - document_id=uuid.UUID(document), - create_data=create_chunks, - db=db, - current_user=current_user - ) - return result - else: - # 文档不存在,创建新文档 - api_logger.info(f"文档不存在,创建新文档: group_id={group_id}") - result = await memory_konwledges_up( - kb_id=user_rag_memory_id, - parent_id=user_rag_memory_id, - create_data=create_data, - db=db, - current_user=current_user - ) - await parse_document_by_id(document, db=db, current_user=current_user) - return result - finally: - # 确保数据库会话被关闭 - db.close() -# 在异步环境中调用示例 - - -async def example_usage(): - - # 获取数据库会话 - db_gen = get_db() - db = next(db_gen) - - # 创建 CustomTextFileCreate 对象 - title = '2f6ff1eb-50c7-4765-8e89-e4566be19122' - create_data = CustomTextFileCreate( - title=title, - content="莫扎特在巴黎经历母亲去世后返回萨尔茨堡,他随后创作的交响曲主题是否与格鲁克在维也纳推动的“改革歌剧”理念存在共通之处?贝多芬早年曾师从海顿,而海顿又受雇于埃斯特哈齐家族——这种师承体系如何影响了当时欧洲宫廷音乐的传承结构?斯卡拉歌剧院选择萨列里的歌剧作为开幕演出,是否与当时米兰政治环境和奥地利宫廷影响有关?" - ) - - # 创建用户对象 - current_user = SimpleUser("6243c125-9420-402c-bbb5-d1977811ac96") - - # 上传文件 - result = await memory_konwledges_up( - kb_id="c71df60a-36a6-4759-a2ce-101e3087b401", - parent_id="c71df60a-36a6-4759-a2ce-101e3087b401", - create_data=create_data, - db=db, - current_user=current_user - ) - print(result) - #找到document_id - - # 使用刚创建的文档ID进行解析 - document = find_document_id_by_kb_and_filename(db=db, kb_id="c71df60a-36a6-4759-a2ce-101e3087b401", file_name=f"{title}.txt") - print('====',document) - res___=await parse_document_by_id(document, db=db, current_user=current_user) - print(res___) - - # result='e8cf9ace-d1a9-4af2-b0c4-3fc94f4f8042' - # document_id='d22e8173-50d0-4e10-a7de-aa638ef893bc' - # - # '''更新块''' - # - # new_content = "这是新的 chunk 内容,用来覆盖原来的内容" - # # 构造 ChunkUpdate 对象 - # update_data = ChunkCreate(content=new_content) - # updated_chunk = await create_document_chunk( - # kb_id= result, - # document_id=document_id, - # create_data= update_data, - # db=db, - # current_user=current_user - # ) - # print(updated_chunk) - return '','','' - - - -if __name__ == "__main__": - # asyncio.run(example_usage()) - asyncio.run(write_rag('1111','22222',"c71df60a-36a6-4759-a2ce-101e3087b401")) \ No newline at end of file diff --git a/app/services/memory_storage_service.py b/app/services/memory_storage_service.py deleted file mode 100644 index daf041e7..00000000 --- a/app/services/memory_storage_service.py +++ /dev/null @@ -1,568 +0,0 @@ -""" -Memory Storage Service - -Handles business logic for memory storage operations. -""" - -from typing import Dict, List, Optional, Any -import os -import json - -from dotenv import load_dotenv - -from app.core.logging_config import get_logger -from app.schemas.memory_storage_schema import ( - ConfigFilter, - ConfigPilotRun, - ConfigParamsCreate, - ConfigParamsDelete, - ConfigUpdate, - ConfigUpdateExtracted, - ConfigUpdateForget, - ConfigKey, -) -from app.repositories.data_config_repository import DataConfigRepository -from app.repositories.neo4j.neo4j_connector import Neo4jConnector -from app.core.memory.analytics.hot_memory_tags import get_hot_memory_tags -from app.core.memory.analytics.memory_insight import MemoryInsight -from app.core.memory.analytics.recent_activity_stats import get_recent_activity_stats -from app.core.memory.analytics.user_summary import generate_user_summary -from app.repositories.data_config_repository import DataConfigRepository - -logger = get_logger(__name__) - -# Load environment variables for Neo4j connector -load_dotenv() -_neo4j_connector = Neo4jConnector() - - -class MemoryStorageService: - """Service for memory storage operations""" - - def __init__(self): - logger.info("MemoryStorageService initialized") - - async def get_storage_info(self) -> dict: - """ - Example wrapper method - retrieves storage information - - Args: - - Returns: - Storage information dictionary - """ - logger.info(f"Getting storage info ") - - # Empty wrapper - implement your logic here - result = { - "status": "active", - "message": "This is an example wrapper" - } - - return result - -class DataConfigService: # 数据配置服务类(PostgreSQL) - """Service layer for config params CRUD. - - The DB connection is optional; when absent, methods return a failure - response containing an SQL preview to aid integration. - """ - - def __init__(self, db_conn: Optional[Any] = None) -> None: - self.db_conn = db_conn - - # --- Driver compatibility helpers --- - @staticmethod - def _is_pgsql_conn(conn: Any) -> bool: # 判断是否为 PostgreSQL 连接 - mod = type(conn).__module__ - return ("psycopg2" in mod) or ("psycopg" in mod) - - @staticmethod - def _convert_timestamps_to_format(data_list: List[Dict[str, Any]]) -> List[Dict[str, Any]]: - """将 created_at 和 updated_at 字段从 datetime 对象转换为 YYYYMMDDHHmmss 格式""" - from datetime import datetime - - for item in data_list: - for field in ['created_at', 'updated_at']: - if field in item and item[field] is not None: - value = item[field] - dt = None - - # 如果是 datetime 对象,直接使用 - if isinstance(value, datetime): - dt = value - # 如果是字符串,先解析 - elif isinstance(value, str): - try: - dt = datetime.fromisoformat(value.replace('Z', '+00:00')) - except Exception: - pass # 保持原值 - - # 转换为 YYYYMMDDHHmmss 格式 - if dt: - item[field] = dt.strftime('%Y%m%d%H%M%S') - - return data_list - - # --- Create --- - def create(self, params: ConfigParamsCreate) -> Dict[str, Any]: # 创建配置参数(仅名称与描述) - if self.db_conn is None: - raise ConnectionError("数据库连接未配置") - - # 如果workspace_id存在且模型字段未全部指定,则自动获取 - if params.workspace_id and not all([params.llm_id, params.embedding_id, params.rerank_id]): - configs = self._get_workspace_configs(params.workspace_id) - if configs is None: - raise ValueError(f"工作空间不存在: workspace_id={params.workspace_id}") - - # 只在未指定时填充(允许手动覆盖) - if not params.llm_id: - params.llm_id = configs.get('llm') - if not params.embedding_id: - params.embedding_id = configs.get('embedding') - if not params.rerank_id: - params.rerank_id = configs.get('rerank') - - query, qparams = DataConfigRepository.build_insert(params) - cur = self.db_conn.cursor() - # PostgreSQL 使用 psycopg2 的命名参数格式 - cur.execute(query, qparams) - self.db_conn.commit() - return {"affected": getattr(cur, "rowcount", None)} - - def _get_workspace_configs(self, workspace_id) -> Optional[Dict[str, Any]]: - """获取工作空间模型配置(内部方法,便于测试)""" - from app.db import SessionLocal - from app.repositories.workspace_repository import get_workspace_models_configs - - db_session = SessionLocal() - try: - return get_workspace_models_configs(db_session, workspace_id) - finally: - db_session.close() - - # --- Delete --- - def delete(self, key: ConfigParamsDelete) -> Dict[str, Any]: # 删除配置参数(按配置名称) - query, qparams = DataConfigRepository.build_delete(key) - if self.db_conn is None: - raise ConnectionError("数据库连接未配置") - - cur = self.db_conn.cursor() - cur.execute(query, qparams) - affected = getattr(cur, "rowcount", None) - self.db_conn.commit() - # 如果没有任何行被删除,抛出异常 - if not affected: - raise ValueError("未找到配置") - return {"affected": affected} - - # --- Update --- - def update(self, update: ConfigUpdate) -> Dict[str, Any]: # 部分更新配置参数 - query, qparams = DataConfigRepository.build_update(update) - - if self.db_conn is None: - raise ConnectionError("数据库连接未配置") - - cur = self.db_conn.cursor() - cur.execute(query, qparams) - affected = getattr(cur, "rowcount", None) - self.db_conn.commit() - if not affected: - raise ValueError("未找到配置") - return {"affected": affected} - - - - def update_extracted(self, update: ConfigUpdateExtracted) -> Dict[str, Any]: # 更新记忆萃取引擎配置参数 - query, qparams = DataConfigRepository.build_update_extracted(update) - - if self.db_conn is None: - raise ConnectionError("数据库连接未配置") - - cur = self.db_conn.cursor() - cur.execute(query, qparams) - affected = getattr(cur, "rowcount", None) - self.db_conn.commit() - if not affected: - raise ValueError("未找到配置") - return {"affected": affected} - - - # --- Forget config params --- - def update_forget(self, update: ConfigUpdateForget) -> Dict[str, Any]: # 保存遗忘引擎的配置 - query, qparams = DataConfigRepository.build_update_forget(update) - - if self.db_conn is None: - raise ConnectionError("数据库连接未配置") - - cur = self.db_conn.cursor() - cur.execute(query, qparams) - affected = getattr(cur, "rowcount", None) - self.db_conn.commit() - if not affected: - raise ValueError("未找到配置") - return {"affected": affected} - - # --- Read --- - def get_extracted(self, key: ConfigKey) -> Dict[str, Any]: # 获取配置参数 - query, qparams = DataConfigRepository.build_select_extracted(key) - if self.db_conn is None: - raise ConnectionError("数据库连接未配置") - - cur = self.db_conn.cursor() - cur.execute(query, qparams) - row = cur.fetchone() - if not row: - raise ValueError("未找到配置") - # Map row to dict (DB-API cursor description available for many drivers) - columns = [desc[0] for desc in cur.description] - raw = {columns[i]: row[i] for i in range(len(columns))} - # 将 created_at 和 updated_at 转换为 YYYYMMDDHHmmss 格式 - data_list = self._convert_timestamps_to_format([raw]) - return data_list[0] if data_list else raw - - def get_forget(self, key: ConfigKey) -> Dict[str, Any]: # 获取配置参数 - query, qparams = DataConfigRepository.build_select_forget(key) - if self.db_conn is None: - raise ConnectionError("数据库连接未配置") - - cur = self.db_conn.cursor() - cur.execute(query, qparams) - row = cur.fetchone() - if not row: - raise ValueError("未找到配置") - # Map row to dict (DB-API cursor description available for many drivers) - columns = [desc[0] for desc in cur.description] - raw = {columns[i]: row[i] for i in range(len(columns))} - # 将 created_at 和 updated_at 转换为 YYYYMMDDHHmmss 格式 - data_list = self._convert_timestamps_to_format([raw]) - return data_list[0] if data_list else raw - - # --- Read All --- - def get_all(self, workspace_id = None) -> List[Dict[str, Any]]: # 获取所有配置参数 - query, qparams = DataConfigRepository.build_select_all(workspace_id) - if self.db_conn is None: - raise ConnectionError("数据库连接未配置") - - cur = self.db_conn.cursor() - cur.execute(query, qparams) - rows = cur.fetchall() - # 如果没有查询到任何配置,返回空列表(这是正常情况,不应抛出异常) - if not rows: - return [] - # Map rows to list of dicts - columns = [desc[0] for desc in cur.description] - data_list = [dict(zip(columns, row)) for row in rows] - # 将 UUID 转换为字符串,将 created_at 和 updated_at 转换为 YYYYMMDDHHmmss 格式 - for item in data_list: - if 'workspace_id' in item and item['workspace_id'] is not None: - item['workspace_id'] = str(item['workspace_id']) - return self._convert_timestamps_to_format(data_list) - - - async def pilot_run(self, payload: ConfigPilotRun) -> Dict[str, Any]: - """ - 选择策略与内存覆写与同步版保持一致:优先 payload.config_id,其次 dbrun.json;两者皆无时报错。 - 支持 dialogue_text 参数用于试运行模式。 - """ - project_root = os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) - dbrun_path = os.path.join(project_root, "app", "core", "memory", "dbrun.json") - - payload_cid = str(getattr(payload, "config_id", "") or "").strip() - cid: Optional[str] = payload_cid if payload_cid else None - - if not cid and os.path.isfile(dbrun_path): - try: - with open(dbrun_path, "r", encoding="utf-8") as f: - dbrun = json.load(f) - if isinstance(dbrun, dict): - sel = dbrun.get("selections", {}) - if isinstance(sel, dict): - fallback_cid = str(sel.get("config_id") or "").strip() - cid = fallback_cid or None - except Exception: - cid = None - - if not cid: - raise ValueError("未提供 payload.config_id,且 dbrun.json 未设置 selections.config_id,禁止启动试运行") - - # 验证 dialogue_text 必须提供 - dialogue_text = payload.dialogue_text.strip() if payload.dialogue_text else "" - logger.info(f"[PILOT_RUN] Received dialogue_text length: {len(dialogue_text)}, preview: {dialogue_text[:100]}") - if not dialogue_text: - raise ValueError("试运行模式必须提供 dialogue_text 参数") - - # 应用内存覆写并刷新常量(在导入主管线前) - # 注意:仅在内存中覆写配置,不修改 runtime.json 文件 - from app.core.memory.utils.config.definitions import reload_configuration_from_database - - ok_override = reload_configuration_from_database(cid) - if not ok_override: - raise RuntimeError("运行时覆写失败,config_id 无效或刷新常量失败") - - # 导入并 await 主管线(使用当前 ASGI 事件循环) - from app.core.memory.main import main as pipeline_main - from app.core.memory.utils.self_reflexion_utils import reflexion - - logger.info(f"[PILOT_RUN] Calling pipeline_main with dialogue_text length: {len(dialogue_text)}, is_pilot_run=True") - await pipeline_main(dialogue_text=dialogue_text, is_pilot_run=True) - logger.info("[PILOT_RUN] pipeline_main completed") - - # 调用自我反思 - # data = [ - # { - # "data": { - # "id": "1", - # "statement": "张明现在在谷歌工作。", - # "group_id": "1", - # "chunk_id": "10", - # "created_at": "2023-01-01", - # "expired_at": "2023-01-02", - # "valid_at": "2023-01-01", - # "invalid_at": "2023-01-02", - # "entity_ids": [] - # }, - # "conflict": True, - # "conflict_memory": { - # "id": "1", - # "statement": "张明现在在清华大学当讲师。", - # "group_id": "1", - # "chunk_id": "1", - # "created_at": "2019-12-01T19:15:05.213210", - # "expired_at": None, - # "valid_at": None, - # "invalid_at": None, - # "entity_ids": [] - # } - # } - # ] - from app.core.memory.utils.config.get_example_data import get_example_data - data = get_example_data() - reflexion_result = await reflexion(data) - - # 读取输出,使用全局配置路径 - from app.core.config import settings - result_path = settings.get_memory_output_path("extracted_result.json") - if not os.path.isfile(result_path): - raise FileNotFoundError(f"试运行完成,但未找到提取结果文件: {result_path}") - - with open(result_path, "r", encoding="utf-8") as rf: - extracted_result = json.load(rf) - - extracted_result["self_reflexion"] = reflexion_result if reflexion_result else None - return { - "config_id": cid, - "time_log": os.path.join(project_root, "time.log"), - "extracted_result": extracted_result, - } - - -# -------------------- Neo4j Search & Analytics (fused from data_search_service.py) -------------------- -# Ensure env for connector (e.g., NEO4J_PASSWORD) -load_dotenv() -_neo4j_connector = Neo4jConnector() - - -async def search_dialogue(end_user_id: Optional[str] = None) -> Dict[str, Any]: - result = await _neo4j_connector.execute_query( - DataConfigRepository.SEARCH_FOR_DIALOGUE, - group_id=end_user_id, - ) - data = {"search_for": "dialogue", "num": result[0]["num"]} - return data - - -async def search_chunk(end_user_id: Optional[str] = None) -> Dict[str, Any]: - result = await _neo4j_connector.execute_query( - DataConfigRepository.SEARCH_FOR_CHUNK, - group_id=end_user_id, - ) - data = {"search_for": "chunk", "num": result[0]["num"]} - return data - - -async def search_statement(end_user_id: Optional[str] = None) -> Dict[str, Any]: - result = await _neo4j_connector.execute_query( - DataConfigRepository.SEARCH_FOR_STATEMENT, - group_id=end_user_id, - ) - data = {"search_for": "statement", "num": result[0]["num"]} - return data - - -async def search_entity(end_user_id: Optional[str] = None) -> Dict[str, Any]: - result = await _neo4j_connector.execute_query( - DataConfigRepository.SEARCH_FOR_ENTITY, - group_id=end_user_id, - ) - data = {"search_for": "entity", "num": result[0]["num"]} - return data - - -async def search_all(end_user_id: Optional[str] = None) -> Dict[str, Any]: - result = await _neo4j_connector.execute_query( - DataConfigRepository.SEARCH_FOR_ALL, - group_id=end_user_id, - ) - - # 检查结果是否为空或长度不足 - if not result or len(result) < 4: - data = { - "total": 0, - "counts": { - "dialogue": 0, - "chunk": 0, - "statement": 0, - "entity": 0, - }, - } - return data - - data = { - "total": result[-1]["Count"], - "counts": { - "dialogue": result[0]["Count"], - "chunk": result[1]["Count"], - "statement": result[2]["Count"], - "entity": result[3]["Count"], - }, - } - return data - - -async def kb_type_distribution(end_user_id: Optional[str] = None) -> Dict[str, Any]: - """统一知识库类型分布接口。 - - 聚合 dialogue/chunk/statement/entity 四类计数,返回统一的分布结构,便于前端一次性消费。 - """ - result = await _neo4j_connector.execute_query( - DataConfigRepository.SEARCH_FOR_ALL, - group_id=end_user_id, - ) - - # 检查结果是否为空或长度不足 - if not result or len(result) < 4: - data = { - "total": 0, - "distribution": [ - {"type": "dialogue", "count": 0}, - {"type": "chunk", "count": 0}, - {"type": "statement", "count": 0}, - {"type": "entity", "count": 0}, - ] - } - return data - - total = result[-1]["Count"] - distribution = [ - {"type": "dialogue", "count": result[0]["Count"]}, - {"type": "chunk", "count": result[1]["Count"]}, - {"type": "statement", "count": result[2]["Count"]}, - {"type": "entity", "count": result[3]["Count"]}, - ] - - data = {"total": total, "distribution": distribution} - return data - - -async def search_detials(end_user_id: Optional[str] = None) -> List[Dict[str, Any]]: - result = await _neo4j_connector.execute_query( - DataConfigRepository.SEARCH_FOR_DETIALS, - group_id=end_user_id, - ) - return result - - -async def search_edges(end_user_id: Optional[str] = None) -> List[Dict[str, Any]]: - result = await _neo4j_connector.execute_query( - DataConfigRepository.SEARCH_FOR_EDGES, - group_id=end_user_id, - ) - return result - - -async def search_entity_graph(end_user_id: Optional[str] = None) -> Dict[str, Any]: - """搜索所有实体之间的关系网络(group 维度)。""" - result = await _neo4j_connector.execute_query( - DataConfigRepository.SEARCH_FOR_ENTITY_GRAPH, - group_id=end_user_id, - ) - # 对source_node 和 target_node 的 fact_summary进行截取,只截取前三条的内容(需要提取前三条“来源”) - for item in result: - source_fact = item["sourceNode"]["fact_summary"] - target_fact = item["targetNode"]["fact_summary"] - # 截取前三条“来源” - item["sourceNode"]["fact_summary"] = source_fact.split("\n")[:4] if source_fact else [] - item["targetNode"]["fact_summary"] = target_fact.split("\n")[:4] if target_fact else [] - # 与现有返回风格保持一致,携带搜索类型、数量与详情 - data = { - "search_for": "entity_graph", - "num": len(result), - "detials": result, - } - return data - - -async def analytics_hot_memory_tags(end_user_id: Optional[str] = None, limit: int = 10) -> List[Dict[str, Any]]: - """ - 获取热门记忆标签,按数量排序并返回前N个 - """ - # 获取更多标签供LLM筛选(获取limit*4个标签) - raw_limit = limit * 4 - tags = await get_hot_memory_tags(end_user_id, limit=raw_limit) - - # 按频率降序排序(虽然数据库已经排序,但为了确保正确性再次排序) - sorted_tags = sorted(tags, key=lambda x: x[1], reverse=True) - - # 只返回前limit个 - top_tags = sorted_tags[:limit] - - return [{"name": t, "frequency": f} for t, f in top_tags] - - -async def analytics_memory_insight_report(end_user_id: Optional[str] = None) -> Dict[str, Any]: - insight = MemoryInsight(end_user_id) - report = await insight.generate_insight_report() - await insight.close() - data = {"report": report} - return data - - -async def analytics_recent_activity_stats() -> Dict[str, Any]: - stats, _msg = get_recent_activity_stats() - total = ( - stats.get("chunk_count", 0) - + stats.get("statements_count", 0) - + stats.get("triplet_entities_count", 0) - + stats.get("triplet_relations_count", 0) - + stats.get("temporal_count", 0) - ) - # 精简:仅提供“最新一次活动多久前” - latest_relative = None - try: - info = stats.get("log_path", "") - idx = info.rfind("最新:") - if idx != -1: - latest_path = info[idx + 3 :].strip() - if latest_path and os.path.exists(latest_path): - import time - diff = max(0.0, time.time() - os.path.getmtime(latest_path)) - m = int(diff // 60) - if m < 1: - latest_relative = "刚刚" - elif m < 60: - latest_relative = f"{m}分钟前" - else: - h = int(m // 60) - latest_relative = f"{h}小时前" if h < 24 else f"{int(h // 24)}天前" - except Exception: - pass - - data = {"total": total, "stats": stats, "latest_relative": latest_relative} - return data - - -async def analytics_user_summary(end_user_id: Optional[str] = None) -> Dict[str, Any]: - summary = await generate_user_summary(end_user_id) - data = {"summary": summary} - return data \ No newline at end of file diff --git a/app/services/model_parameter_merger.py b/app/services/model_parameter_merger.py deleted file mode 100644 index 25506f1c..00000000 --- a/app/services/model_parameter_merger.py +++ /dev/null @@ -1,160 +0,0 @@ -""" -模型参数合并器 - -用于合并 ModelConfig 和 AgentConfig 中的模型参数, -AgentConfig 中的参数优先级更高,可以覆盖 ModelConfig 的默认参数。 -""" -from typing import Dict, Any, Optional -from app.core.logging_config import get_business_logger - -logger = get_business_logger() - - -class ModelParameterMerger: - """模型参数合并器""" - - @staticmethod - def merge_parameters( - model_config_params: Optional[Dict[str, Any]], - agent_config_params: Optional[Dict[str, Any]] - ) -> Dict[str, Any]: - """ - 合并模型配置参数和 Agent 配置参数 - - 优先级:agent_config_params > model_config_params > 默认值 - - Args: - model_config_params: ModelConfig.config 中的参数 - agent_config_params: AgentConfig.model_parameters 中的参数 - - Returns: - 合并后的参数字典 - - Example: - >>> model_params = {"temperature": 0.5, "max_tokens": 1000} - >>> agent_params = {"temperature": 0.8} - >>> merged = ModelParameterMerger.merge_parameters(model_params, agent_params) - >>> merged - {"temperature": 0.8, "max_tokens": 1000} - """ - # 默认参数 - default_params = { - "temperature": 0.7, - "max_tokens": 2000, - "top_p": 1.0, - "frequency_penalty": 0.0, - "presence_penalty": 0.0, - "n": 1, - "stop": None - } - - # 合并参数:默认值 -> 模型配置 -> Agent 配置 - merged = default_params.copy() - - # 应用模型配置参数 - if model_config_params: - for key in default_params.keys(): - if key in model_config_params: - merged[key] = model_config_params[key] - - # 应用 Agent 配置参数(优先级最高) - if agent_config_params: - for key in default_params.keys(): - if key in agent_config_params and agent_config_params[key] is not None: - merged[key] = agent_config_params[key] - - # 移除 None 值 - merged = {k: v for k, v in merged.items() if v is not None} - - logger.debug( - f"参数合并完成", - extra={ - "model_params": model_config_params, - "agent_params": agent_config_params, - "merged": merged - } - ) - - return merged - - @staticmethod - def get_effective_parameters( - model_config: Optional[Any], - agent_config: Optional[Any] - ) -> Dict[str, Any]: - """ - 获取有效的模型参数(从 ORM 对象中提取并合并) - - Args: - model_config: ModelConfig ORM 对象 - agent_config: AgentConfig ORM 对象 - - Returns: - 合并后的参数字典 - """ - # 提取模型配置参数 - model_params = None - if model_config and hasattr(model_config, 'config'): - model_params = model_config.config - - # 提取 Agent 配置参数 - agent_params = None - if agent_config and hasattr(agent_config, 'model_parameters'): - agent_params = agent_config.model_parameters - - return ModelParameterMerger.merge_parameters(model_params, agent_params) - - @staticmethod - def format_for_llm_call(parameters: Dict[str, Any]) -> Dict[str, Any]: - """ - 格式化参数用于 LLM API 调用 - - 不同的 LLM 提供商可能需要不同的参数格式, - 这个方法可以根据需要进行转换。 - - Args: - parameters: 合并后的参数字典 - - Returns: - 格式化后的参数字典 - """ - # 基本格式化(可以根据不同提供商扩展) - formatted = parameters.copy() - - # 确保参数在有效范围内 - if "temperature" in formatted: - formatted["temperature"] = max(0.0, min(2.0, formatted["temperature"])) - - if "max_tokens" in formatted: - formatted["max_tokens"] = max(1, min(32000, formatted["max_tokens"])) - - if "top_p" in formatted: - formatted["top_p"] = max(0.0, min(1.0, formatted["top_p"])) - - if "frequency_penalty" in formatted: - formatted["frequency_penalty"] = max(-2.0, min(2.0, formatted["frequency_penalty"])) - - if "presence_penalty" in formatted: - formatted["presence_penalty"] = max(-2.0, min(2.0, formatted["presence_penalty"])) - - if "n" in formatted: - formatted["n"] = max(1, min(10, formatted["n"])) - - return formatted - - -def merge_model_parameters( - model_config_params: Optional[Dict[str, Any]], - agent_config_params: Optional[Dict[str, Any]] -) -> Dict[str, Any]: - """ - 合并模型参数的便捷函数 - - Args: - model_config_params: ModelConfig.config 中的参数 - agent_config_params: AgentConfig.model_parameters 中的参数 - - Returns: - 合并后的参数字典 - """ - return ModelParameterMerger.merge_parameters(model_config_params, agent_config_params) diff --git a/app/services/model_service.py b/app/services/model_service.py deleted file mode 100644 index b6fb0560..00000000 --- a/app/services/model_service.py +++ /dev/null @@ -1,409 +0,0 @@ -from sqlalchemy.orm import Session -from typing import List, Optional, Dict, Any -import uuid -import math -import time -import asyncio - -from app.models.models_model import ModelConfig, ModelApiKey, ModelType -from app.repositories.model_repository import ModelConfigRepository, ModelApiKeyRepository -from app.schemas import model_schema -from app.schemas.model_schema import ( - ModelConfigCreate, ModelConfigUpdate, ModelApiKeyCreate, ModelApiKeyUpdate, - ModelConfigQuery, ModelStats -) -from app.core.logging_config import get_business_logger -from app.schemas.response_schema import PageData, PageMeta -from app.core.exceptions import BusinessException -from app.core.error_codes import BizCode - -logger = get_business_logger() - - -class ModelConfigService: - """模型配置服务""" - - @staticmethod - def get_model_by_id(db: Session, model_id: uuid.UUID) -> ModelConfig: - """根据ID获取模型配置""" - model = ModelConfigRepository.get_by_id(db, model_id) - if not model: - raise BusinessException("模型配置不存在", BizCode.MODEL_NOT_FOUND) - return model - - @staticmethod - def get_model_list(db: Session, query: ModelConfigQuery) -> PageData: - """获取模型配置列表""" - models, total = ModelConfigRepository.get_list(db, query) - pages = math.ceil(total / query.pagesize) if total > 0 else 0 - - return PageData( - page=PageMeta( - page=query.page, - pagesize=query.pagesize, - total=total, - hasnext=query.page < pages - ), - items=[model_schema.ModelConfig.model_validate(model) for model in models] - ) - - @staticmethod - def get_model_by_name(db: Session, name: str) -> ModelConfig: - """根据名称获取模型配置""" - model = ModelConfigRepository.get_by_name(db, name) - if not model: - raise BusinessException("模型配置不存在", BizCode.MODEL_NOT_FOUND) - return model - - @staticmethod - def search_models_by_name(db: Session, name: str, limit: int = 10) -> List[ModelConfig]: - """按名称模糊匹配获取模型配置列表""" - return ModelConfigRepository.search_by_name(db, name, limit) - - @staticmethod - async def validate_model_config( - db: Session, - *, - model_name: str, - provider: str, - api_key: str, - api_base: Optional[str] = None, - model_type: str = "llm", - test_message: str = "Hello" - ) -> Dict[str, Any]: - """验证模型配置是否有效 - - Args: - db: 数据库会话 - model_name: 模型名称 - provider: 提供商 - api_key: API密钥 - api_base: API基础URL - model_type: 模型类型 (llm/chat/embedding/rerank) - test_message: 测试消息 - - Returns: - Dict: 验证结果 - """ - from app.core.models import RedBearLLM, RedBearRerank - from app.core.models.base import RedBearModelConfig - from app.core.models.embedding import RedBearEmbeddings - import traceback - - try: - start_time = time.time() - - model_config = RedBearModelConfig( - model_name=model_name, - provider=provider, - api_key=api_key, - base_url=api_base, - temperature=0.7, - max_tokens=100 - ) - - # 根据模型类型选择不同的验证方式 - model_type_lower = model_type.lower() - - if model_type_lower in ["llm", "chat"]: - # LLM/Chat 模型验证 - 统一使用字符串输入 - llm = RedBearLLM(model_config, type=ModelType.LLM if model_type_lower == "llm" else ModelType.CHAT) - response = await llm.ainvoke(test_message) - elapsed_time = time.time() - start_time - - content = response.content if hasattr(response, 'content') else str(response) - usage = None - if hasattr(response, 'usage_metadata'): - usage = { - "input_tokens": getattr(response.usage_metadata, 'input_tokens', 0), - "output_tokens": getattr(response.usage_metadata, 'output_tokens', 0), - "total_tokens": getattr(response.usage_metadata, 'total_tokens', 0) - } - - return { - "valid": True, - "message": f"{model_type.upper()} 模型配置验证成功", - "response": content, - "elapsed_time": elapsed_time, - "usage": usage, - "error": None - } - - elif model_type_lower == "embedding": - # Embedding 模型验证(在线程中运行同步方法) - embedding = RedBearEmbeddings(model_config) - test_texts = [test_message, "测试文本"] - vectors = await asyncio.to_thread(embedding.embed_documents, test_texts) - elapsed_time = time.time() - start_time - - return { - "valid": True, - "message": "Embedding 模型配置验证成功", - "response": f"成功生成 {len(vectors)} 个向量,维度: {len(vectors[0]) if vectors else 0}", - "elapsed_time": elapsed_time, - "usage": { - "input_tokens": len(test_message), - "vector_count": len(vectors), - "vector_dimension": len(vectors[0]) if vectors else 0 - }, - "error": None - } - - elif model_type_lower == "rerank": - # Rerank 模型验证(在线程中运行同步方法) - rerank = RedBearRerank(model_config) - query = test_message - documents = ["这是第一个文档", "这是第二个文档", "这是第三个文档"] - results = await asyncio.to_thread(rerank.rerank, query=query, documents=documents, top_n=3) - elapsed_time = time.time() - start_time - - return { - "valid": True, - "message": "Rerank 模型配置验证成功", - "response": f"成功对 {len(documents)} 个文档进行重排序,返回 top {len(results) if results else 0} 结果", - "elapsed_time": elapsed_time, - "usage": { - "query_length": len(query), - "document_count": len(documents), - "result_count": len(results) if results else 0 - }, - "error": None - } - - else: - return { - "valid": False, - "message": "不支持的模型类型", - "response": None, - "elapsed_time": None, - "usage": None, - "error": f"不支持的模型类型: {model_type}" - } - - except Exception as e: - # 提取详细的错误信息 - error_message = str(e) - error_type = type(e).__name__ - print("=========error_message:",error_message.lower()) - # 特殊处理常见的错误类型 - if "unsupported countries" in error_message.lower() or "unsupported region" in error_message.lower(): - # 区域/国家限制(适用于所有提供商) - error_message = "区域限制: 该模型在当前区域或国家/地区不可用,请检查提供商的服务区域限制" - elif "ValidationException" in error_type or "ValidationException" in error_message: - # 其他验证错误 - if "access denied" in error_message.lower(): - error_message = "访问被拒绝: 请检查 API 凭证和权限配置" - else: - error_message = f"验证失败: {error_message}" - elif "AuthenticationError" in error_type or "authentication" in error_message.lower(): - error_message = "认证失败: API Key 无效或已过期" - elif "RateLimitError" in error_type or "rate limit" in error_message.lower(): - error_message = "请求频率限制: 已超过 API 调用限制" - elif "InvalidRequestError" in error_type or "invalid request" in error_message.lower(): - error_message = f"无效请求: {error_message}" - elif "model_copy" in error_message: - error_message = "模型消息格式错误: 请确保使用正确的模型类型(LLM/Chat)" - - # 记录详细错误日志 - logger.error(f"模型验证失败 - 类型: {error_type}, 模型: {model_name}, 提供商: {provider}") - logger.error(f"错误详情: {error_message}") - logger.debug(f"完整堆栈: {traceback.format_exc()}") - - return { - "valid": False, - "message": f"{model_type.upper()} 模型配置验证失败", - "response": None, - "elapsed_time": None, - "usage": None, - "error": error_message, - "error_type": error_type - } - - @staticmethod - async def create_model(db: Session, model_data: ModelConfigCreate) -> ModelConfig: - """创建模型配置""" - # 检查名称是否已存在 - if ModelConfigRepository.get_by_name(db, model_data.name): - raise BusinessException("模型名称已存在", BizCode.DUPLICATE_NAME) - - # 验证配置 - if not model_data.skip_validation and model_data.api_keys: - api_key_data = model_data.api_keys - validation_result = await ModelConfigService.validate_model_config( - db=db, - model_name=api_key_data.model_name, - provider=api_key_data.provider, - api_key=api_key_data.api_key, - api_base=api_key_data.api_base, - model_type=model_data.type, # 传递模型类型 - test_message="Hello" - ) - if not validation_result["valid"]: - raise BusinessException( - f"模型配置验证失败: {validation_result['error']}", - BizCode.INVALID_PARAMETER - ) - - # 事务处理 - api_key_data = model_data.api_keys - model_config_data = model_data.dict(exclude={"api_keys", "skip_validation"}) - - model = ModelConfigRepository.create(db, model_config_data) - db.flush() # 获取生成的 ID - - if api_key_data: - api_key_create_schema = ModelApiKeyCreate( - model_config_id=model.id, - **api_key_data.dict() - ) - ModelApiKeyRepository.create(db, api_key_create_schema) - - db.commit() - db.refresh(model) - return model - - @staticmethod - def update_model(db: Session, model_id: uuid.UUID, model_data: ModelConfigUpdate) -> ModelConfig: - """更新模型配置""" - existing_model = ModelConfigRepository.get_by_id(db, model_id) - if not existing_model: - raise BusinessException("模型配置不存在", BizCode.MODEL_NOT_FOUND) - - if model_data.name and model_data.name != existing_model.name: - if ModelConfigRepository.get_by_name(db, model_data.name): - raise BusinessException("模型名称已存在", BizCode.DUPLICATE_NAME) - - model = ModelConfigRepository.update(db, model_id, model_data) - db.commit() - db.refresh(model) - return model - - @staticmethod - def delete_model(db: Session, model_id: uuid.UUID) -> bool: - """删除模型配置""" - if not ModelConfigRepository.get_by_id(db, model_id): - raise BusinessException("模型配置不存在", BizCode.MODEL_NOT_FOUND) - - success = ModelConfigRepository.delete(db, model_id) - db.commit() - return success - - @staticmethod - def get_model_stats(db: Session) -> ModelStats: - """获取模型统计信息""" - stats_data = ModelConfigRepository.get_stats(db) - return ModelStats( - total_models=stats_data["total_models"], - active_models=stats_data["active_models"], - llm_count=stats_data["llm_count"], - embedding_count=stats_data["embedding_count"], - rerank_count=stats_data["rerank_count"], - provider_stats=stats_data["provider_stats"] - ) - - -class ModelApiKeyService: - """模型API Key服务""" - - @staticmethod - def get_api_key_by_id(db: Session, api_key_id: uuid.UUID) -> ModelApiKey: - """根据ID获取API Key""" - api_key = ModelApiKeyRepository.get_by_id(db, api_key_id) - if not api_key: - raise BusinessException("API Key不存在", BizCode.NOT_FOUND) - return api_key - - @staticmethod - def get_api_keys_by_model(db: Session, model_config_id: uuid.UUID, is_active: bool = True) -> List[ModelApiKey]: - """根据模型配置ID获取API Key列表""" - if not ModelConfigRepository.get_by_id(db, model_config_id): - raise BusinessException("模型配置不存在", BizCode.MODEL_NOT_FOUND) - - return ModelApiKeyRepository.get_by_model_config(db, model_config_id, is_active) - - @staticmethod - async def create_api_key(db: Session, api_key_data: ModelApiKeyCreate) -> ModelApiKey: - """创建API Key""" - model_config = ModelConfigRepository.get_by_id(db, api_key_data.model_config_id) - if not model_config: - raise BusinessException("模型配置不存在", BizCode.MODEL_NOT_FOUND) - - validation_result = await ModelConfigService.validate_model_config( - db=db, - model_name=api_key_data.model_name, - provider=api_key_data.provider, - api_key=api_key_data.api_key, - api_base=api_key_data.api_base, - model_type=model_config.type, # 传递模型类型 - test_message="Hello" - ) - print(validation_result) - if not validation_result["valid"]: - raise BusinessException( - f"模型配置验证失败: {validation_result['error']}", - BizCode.INVALID_PARAMETER - ) - - api_key = ModelApiKeyRepository.create(db, api_key_data) - db.commit() - db.refresh(api_key) - return api_key - - @staticmethod - async def update_api_key(db: Session, api_key_id: uuid.UUID, api_key_data: ModelApiKeyUpdate) -> ModelApiKey: - """更新API Key""" - existing_api_key = ModelApiKeyRepository.get_by_id(db, api_key_id) - if not existing_api_key: - raise BusinessException("API Key不存在", BizCode.NOT_FOUND) - - # 获取关联的模型配置以获取模型类型 - model_config = ModelConfigRepository.get_by_id(db, existing_api_key.model_config_id) - if not model_config: - raise BusinessException("关联的模型配置不存在", BizCode.MODEL_NOT_FOUND) - - validation_result = await ModelConfigService.validate_model_config( - db=db, - model_name=api_key_data.model_name, - provider=api_key_data.provider, - api_key=api_key_data.api_key, - api_base=api_key_data.api_base, - model_type=model_config.type, # 传递模型类型 - test_message="Hello" - ) - print(validation_result) - if not validation_result["valid"]: - raise BusinessException( - f"模型配置验证失败: {validation_result['error']}", - BizCode.INVALID_PARAMETER - ) - - api_key = ModelApiKeyRepository.update(db, api_key_id, api_key_data) - db.commit() - db.refresh(api_key) - return api_key - - @staticmethod - def delete_api_key(db: Session, api_key_id: uuid.UUID) -> bool: - """删除API Key""" - if not ModelApiKeyRepository.get_by_id(db, api_key_id): - raise BusinessException("API Key不存在", BizCode.NOT_FOUND) - - success = ModelApiKeyRepository.delete(db, api_key_id) - db.commit() - return success - - @staticmethod - def get_available_api_key(db: Session, model_config_id: uuid.UUID) -> Optional[ModelApiKey]: - """获取可用的API Key(按优先级和负载均衡)""" - api_keys = ModelApiKeyRepository.get_by_model_config(db, model_config_id, is_active=True) - if not api_keys: - return None - return min(api_keys, key=lambda x: int(x.usage_count or "0")) - - @staticmethod - def record_api_key_usage(db: Session, api_key_id: uuid.UUID) -> bool: - """记录API Key使用""" - success = ModelApiKeyRepository.update_usage(db, api_key_id) - if success: - db.commit() - return success diff --git a/app/services/multi_agent_config_converter.py b/app/services/multi_agent_config_converter.py deleted file mode 100644 index fd26e586..00000000 --- a/app/services/multi_agent_config_converter.py +++ /dev/null @@ -1,191 +0,0 @@ -""" -多智能体配置格式转换器 -用于将 Pydantic 模型转换为数据库存储格式 -""" -from typing import Dict, Any, Optional, List -import uuid -from app.schemas.multi_agent_schema import ( - SubAgentConfig, - RoutingRule, - ExecutionConfig, - MultiAgentConfigCreate, - MultiAgentConfigUpdate, -) - - -class MultiAgentConfigConverter: - """多智能体配置格式转换器""" - - @staticmethod - def to_storage_format(config: MultiAgentConfigCreate | MultiAgentConfigUpdate) -> Dict[str, Any]: - """ - 将配置对象转换为数据库存储格式 - - Args: - config: MultiAgentConfigCreate 或 MultiAgentConfigUpdate 对象 - - Returns: - 包含数据库字段的字典 - """ - result = {} - - # 1. 子 Agent 配置 - if hasattr(config, 'sub_agents') and config.sub_agents: - result["sub_agents"] = [ - MultiAgentConfigConverter._convert_uuid_to_str(agent.model_dump()) - for agent in config.sub_agents - ] - - # 2. 路由规则配置 - if hasattr(config, 'routing_rules') and config.routing_rules: - result["routing_rules"] = [ - MultiAgentConfigConverter._convert_uuid_to_str(rule.model_dump()) - for rule in config.routing_rules - ] - - # 3. 执行配置 - if hasattr(config, 'execution_config') and config.execution_config: - result["execution_config"] = MultiAgentConfigConverter._convert_uuid_to_str( - config.execution_config.model_dump() - ) - - return result - - @staticmethod - def from_storage_format( - sub_agents: Optional[List[Dict[str, Any]]], - routing_rules: Optional[List[Dict[str, Any]]], - execution_config: Optional[Dict[str, Any]], - ) -> Dict[str, Any]: - """ - 将数据库存储格式转换为 Pydantic 对象 - - Args: - sub_agents: 子 Agent 配置列表 - routing_rules: 路由规则配置列表 - execution_config: 执行配置 - - Returns: - 包含 Pydantic 对象的字典 - """ - result = { - "sub_agents": [], - "routing_rules": [], - "execution_config": None, - } - - # 1. 解析子 Agent 配置 - if sub_agents: - result["sub_agents"] = [ - SubAgentConfig(**agent_data) - for agent_data in sub_agents - ] - - # 2. 解析路由规则配置 - if routing_rules: - result["routing_rules"] = [ - RoutingRule(**rule_data) - for rule_data in routing_rules - ] - else: - # 提供默认的空路由规则 - result["routing_rules"] = [] - - # 3. 解析执行配置 - if execution_config: - result["execution_config"] = ExecutionConfig(**execution_config) - else: - # 提供默认的执行配置 - result["execution_config"] = ExecutionConfig( - max_iterations=10, - timeout=300, - enable_parallel=False, - error_handling="stop" - ) - - return result - - @staticmethod - def _convert_uuid_to_str(obj: Any) -> Any: - """ - 递归转换对象中的所有 UUID 为字符串 - - Args: - obj: 要转换的对象(dict, list, UUID 等) - - Returns: - 转换后的对象 - """ - if isinstance(obj, uuid.UUID): - return str(obj) - elif isinstance(obj, dict): - return {k: MultiAgentConfigConverter._convert_uuid_to_str(v) for k, v in obj.items()} - elif isinstance(obj, list): - return [MultiAgentConfigConverter._convert_uuid_to_str(item) for item in obj] - else: - return obj - - @staticmethod - def enrich_with_published_configs( - sub_agents: List[Dict[str, Any]], - get_published_config_func - ) -> List[Dict[str, Any]]: - """ - 为子 Agent 配置添加发布的 config_id - - Args: - sub_agents: 子 Agent 配置列表 - get_published_config_func: 获取发布配置的函数 - - Returns: - 增强后的子 Agent 配置列表 - """ - enriched_agents = [] - - for agent in sub_agents: - agent_copy = agent.copy() - - # 获取该 Agent 当前发布的配置 - if 'agent_id' in agent: - try: - agent_id = uuid.UUID(agent['agent_id']) if isinstance(agent['agent_id'], str) else agent['agent_id'] - published_config = get_published_config_func(agent_id) - - if published_config: - agent_copy['published_config_id'] = str(published_config.get('id')) if isinstance(published_config, dict) else None - except Exception as e: - # 如果获取失败,记录但不中断 - from app.core.logging_config import get_business_logger - logger = get_business_logger() - logger.warning(f"获取 Agent {agent.get('agent_id')} 的发布配置失败: {e}") - - enriched_agents.append(agent_copy) - - return enriched_agents - - @staticmethod - def create_default_template(app_id: uuid.UUID) -> Dict[str, Any]: - """ - 创建默认的多智能体配置模板 - - Args: - app_id: 应用 ID - - Returns: - 默认配置模板 - """ - return { - "app_id": str(app_id), - "master_agent_id": None, - "orchestration_mode": "sequential", - "sub_agents": [], - "routing_rules": [], - "execution_config": { - "max_iterations": 10, - "timeout": 300, - "enable_parallel": False, - "error_handling": "stop" - }, - "aggregation_strategy": "last", - "is_active": False - } diff --git a/app/services/multi_agent_orchestrator.py b/app/services/multi_agent_orchestrator.py deleted file mode 100644 index b62ab690..00000000 --- a/app/services/multi_agent_orchestrator.py +++ /dev/null @@ -1,1116 +0,0 @@ -"""多 Agent 编排器""" -import uuid -import time -import asyncio -from typing import Dict, Any, List, Optional -from sqlalchemy.orm import Session - -from app.models import MultiAgentConfig, AgentConfig, ModelConfig -from app.services.agent_registry import AgentRegistry -from app.services.llm_router import LLMRouter -from app.services.conversation_state_manager import ConversationStateManager -from app.core.exceptions import BusinessException, ResourceNotFoundException -from app.core.error_codes import BizCode -from app.core.logging_config import get_business_logger - -logger = get_business_logger() - - -class MultiAgentOrchestrator: - """多 Agent 编排器 - 协调多个 Agent 协作完成任务""" - - def __init__(self, db: Session, config: MultiAgentConfig): - """初始化编排器 - - Args: - db: 数据库会话 - config: 多 Agent 配置 - """ - self.db = db - self.config = config - self.registry = AgentRegistry(db) - - # 加载主 Agent - self.master_agent = self._load_agent(config.master_agent_id) - - # 加载子 Agent - self.sub_agents = {} - for sub_agent_info in config.sub_agents: - agent_id = uuid.UUID(sub_agent_info["agent_id"]) - agent = self._load_agent(agent_id) - self.sub_agents[str(agent_id)] = { - "config": agent, - "info": sub_agent_info - } - - # 初始化 LLM 路由器(使用主 Agent 的模型) - self.llm_router = None - if self.master_agent and hasattr(self.master_agent, 'default_model_config_id'): - routing_model = self.db.get(ModelConfig, self.master_agent.default_model_config_id) - if routing_model: - state_manager = ConversationStateManager() - self.llm_router = LLMRouter( - db=db, - state_manager=state_manager, - routing_rules=config.routing_rules or [], - sub_agents=self.sub_agents, - routing_model_config=routing_model, - use_llm=True - ) - logger.info( - f"LLM 路由器已初始化(使用主 Agent 模型)", - extra={ - "routing_model": routing_model.name, - "routing_model_id": str(routing_model.id) - } - ) - - logger.info( - f"多 Agent 编排器初始化", - extra={ - "config_id": str(config.id), - "mode": config.orchestration_mode, - "sub_agent_count": len(self.sub_agents), - "has_llm_router": self.llm_router is not None - } - ) - - async def execute_stream( - self, - message: str, - conversation_id: Optional[uuid.UUID] = None, - user_id: Optional[str] = None, - variables: Optional[Dict[str, Any]] = None, - use_llm_routing: bool = True, - web_search: bool = True, - memory: bool = True, - storage_type: str = '', - user_rag_memory_id: str = '' - ): - """执行多 Agent 任务(流式返回) - - Args: - message: 用户消息 - conversation_id: 会话 ID - user_id: 用户 ID - variables: 变量参数 - use_llm_routing: 是否使用 LLM 路由 - - Yields: - SSE 格式的事件流 - """ - import json - - start_time = time.time() - - logger.info( - f"开始执行多 Agent 任务(流式)", - extra={ - "mode": self.config.orchestration_mode, - "message_length": len(message) - } - ) - - try: - # 发送开始事件 - yield self._format_sse_event("start", { - "mode": self.config.orchestration_mode, - "timestamp": time.time() - }) - - # 1. 主 Agent 分析任务 - task_analysis = await self._analyze_task(message, variables) - task_analysis["use_llm_routing"] = use_llm_routing - - # 2. 根据模式执行(流式) - if self.config.orchestration_mode == "conditional": - async for event in self._execute_conditional_stream( - task_analysis, - conversation_id, - user_id, - web_search, - memory, - storage_type, - user_rag_memory_id - ): - yield event - else: - # 其他模式暂时使用非流式执行,然后一次性返回 - if self.config.orchestration_mode == "sequential": - results = await self._execute_sequential( - task_analysis, - conversation_id, - user_id, - web_search, - memory, - storage_type, - user_rag_memory_id - ) - elif self.config.orchestration_mode == "parallel": - results = await self._execute_parallel( - task_analysis, - conversation_id, - user_id, - web_search, - memory, - storage_type, - user_rag_memory_id - ) - elif self.config.orchestration_mode == "loop": - results = await self._execute_loop( - task_analysis, - conversation_id, - user_id, - web_search, - memory, - storage_type, - user_rag_memory_id - ) - else: - raise BusinessException( - f"不支持的编排模式: {self.config.orchestration_mode}", - BizCode.INVALID_PARAMETER - ) - - # 整合结果 - final_result = await self._aggregate_results(results) - - # 提取会话 ID - sub_conversation_id = None - if isinstance(results, dict): - sub_conversation_id = results.get("conversation_id") or results.get("result", {}).get("conversation_id") - elif isinstance(results, list) and results: - for item in results: - if "result" in item: - sub_conversation_id = item["result"].get("conversation_id") - if sub_conversation_id: - break - - # 发送消息事件 - yield self._format_sse_event("message", { - "content": final_result, - "conversation_id": sub_conversation_id - }) - - elapsed_time = time.time() - start_time - - # 发送结束事件 - yield self._format_sse_event("end", { - "elapsed_time": elapsed_time, - "timestamp": time.time() - }) - - logger.info( - f"多 Agent 任务完成(流式)", - extra={ - "mode": self.config.orchestration_mode, - "elapsed_time": elapsed_time - } - ) - - except Exception as e: - logger.error( - f"多 Agent 任务执行失败(流式)", - extra={"error": str(e), "mode": self.config.orchestration_mode} - ) - # 发送错误事件 - yield self._format_sse_event("error", { - "error": str(e), - "timestamp": time.time() - }) - - async def execute( - self, - message: str, - conversation_id: Optional[uuid.UUID] = None, - user_id: Optional[str] = None, - variables: Optional[Dict[str, Any]] = None, - use_llm_routing: bool = True, - web_search: bool = False, - memory: bool = True - ) -> Dict[str, Any]: - """执行多 Agent 任务 - - Args: - message: 用户消息 - conversation_id: 会话 ID - user_id: 用户 ID - variables: 变量参数 - - Returns: - 执行结果 - """ - start_time = time.time() - - logger.info( - f"开始执行多 Agent 任务", - extra={ - "mode": self.config.orchestration_mode, - "message_length": len(message) - } - ) - - try: - # 1. 主 Agent 分析任务 - task_analysis = await self._analyze_task(message, variables) - task_analysis["use_llm_routing"] = use_llm_routing - - # 2. 根据模式执行 - if self.config.orchestration_mode == "sequential": - results = await self._execute_sequential( - task_analysis, - conversation_id, - user_id, - web_search, - memory - ) - elif self.config.orchestration_mode == "parallel": - results = await self._execute_parallel( - task_analysis, - conversation_id, - user_id, - web_search, - memory - ) - elif self.config.orchestration_mode == "conditional": - results = await self._execute_conditional( - task_analysis, - conversation_id, - user_id, - web_search, - memory - ) - elif self.config.orchestration_mode == "loop": - results = await self._execute_loop( - task_analysis, - conversation_id, - user_id, - web_search, - memory - ) - else: - raise BusinessException( - f"不支持的编排模式: {self.config.orchestration_mode}", - BizCode.INVALID_PARAMETER - ) - - # 3. 整合结果 - final_result = await self._aggregate_results(results) - - elapsed_time = time.time() - start_time - - # 4. 提取子 Agent 的 conversation_id(用于多轮对话) - sub_conversation_id = None - if isinstance(results, dict): - # conditional 或 loop 模式 - sub_conversation_id = results.get("conversation_id") or results.get("result", {}).get("conversation_id") - elif isinstance(results, list) and results: - # sequential 或 parallel 模式,使用第一个成功的结果 - for item in results: - if "result" in item: - sub_conversation_id = item["result"].get("conversation_id") - if sub_conversation_id: - break - - logger.info( - f"多 Agent 任务完成", - extra={ - "mode": self.config.orchestration_mode, - "elapsed_time": elapsed_time, - "sub_agent_count": len(results) if isinstance(results, list) else 1, - "sub_conversation_id": sub_conversation_id - } - ) - - return { - "message": final_result, - "conversation_id": sub_conversation_id, # 返回子 Agent 的会话 ID - "elapsed_time": elapsed_time, - "mode": self.config.orchestration_mode, - "sub_results": results - } - - except Exception as e: - logger.error( - f"多 Agent 任务执行失败", - extra={"error": str(e), "mode": self.config.orchestration_mode} - ) - raise - - async def _analyze_task( - self, - message: str, - variables: Optional[Dict[str, Any]] = None - ) -> Dict[str, Any]: - """主 Agent 分析任务 - - Args: - message: 用户消息 - variables: 变量参数 - - Returns: - 任务分析结果 - """ - # 简化版本:直接返回基本信息 - # 在实际应用中,可以让主 Agent 使用 LLM 分析任务 - return { - "message": message, - "variables": variables or {}, - "sub_agents": self.config.sub_agents, - "initial_context": variables or {} - } - - async def _execute_sequential( - self, - task_analysis: Dict[str, Any], - conversation_id: Optional[uuid.UUID], - user_id: Optional[str], - web_search: bool = False, - memory: bool = True, - storage_type: str = '', - user_rag_memory_id: str = '' - ) -> List[Dict[str, Any]]: - """顺序执行子 Agent - - Args: - task_analysis: 任务分析结果 - conversation_id: 会话 ID - user_id: 用户 ID - - Returns: - 执行结果列表 - """ - results = [] - context = task_analysis.get("initial_context", {}) - message = task_analysis.get("message", "") - - # 按优先级排序 - sub_agents = sorted( - task_analysis["sub_agents"], - key=lambda x: x.get("priority", 0) - ) - - for sub_agent_info in sub_agents: - agent_id = sub_agent_info["agent_id"] - agent_data = self.sub_agents.get(agent_id) - - if not agent_data: - logger.warning(f"子 Agent 不存在: {agent_id}") - continue - - logger.info( - f"执行子 Agent", - extra={ - "agent_id": agent_id, - "agent_name": sub_agent_info.get("name"), - "priority": sub_agent_info.get("priority") - } - ) - - # 执行子 Agent - result = await self._execute_sub_agent( - agent_data["config"], - message, - context, - conversation_id, - user_id, - web_search, - memory, - storage_type, - user_rag_memory_id - ) - - results.append({ - "agent_id": agent_id, - "agent_name": sub_agent_info.get("name"), - "result": result, - "conversation_id": result.get("conversation_id") # 保存会话 ID - }) - - # 更新上下文(后续 Agent 可以使用前面的结果) - context[f"result_from_{sub_agent_info.get('name', agent_id)}"] = result.get("message") - - return results - - async def _execute_parallel( - self, - task_analysis: Dict[str, Any], - conversation_id: Optional[uuid.UUID], - user_id: Optional[str], - web_search: bool = False, - memory: bool = True, - storage_type: str = '', - user_rag_memory_id: str = '' - ) -> List[Dict[str, Any]]: - """并行执行子 Agent - - Args: - task_analysis: 任务分析结果 - conversation_id: 会话 ID - user_id: 用户 ID - - Returns: - 执行结果列表 - """ - context = task_analysis.get("initial_context", {}) - message = task_analysis.get("message", "") - - # 获取并发限制 - parallel_limit = self.config.execution_config.get("parallel_limit", 3) - - # 创建任务列表 - tasks = [] - for sub_agent_info in task_analysis["sub_agents"]: - agent_id = sub_agent_info["agent_id"] - agent_data = self.sub_agents.get(agent_id) - - if not agent_data: - continue - - task = self._execute_sub_agent( - agent_data["config"], - message, - context, - conversation_id, - user_id, - web_search, - memory, - storage_type, - user_rag_memory_id - ) - tasks.append((agent_id, sub_agent_info.get("name"), task)) - - # 并行执行(带限制) - results = [] - for i in range(0, len(tasks), parallel_limit): - batch = tasks[i:i + parallel_limit] - batch_results = await asyncio.gather( - *[task for _, _, task in batch], - return_exceptions=True - ) - - for (agent_id, agent_name, _), result in zip(batch, batch_results): - if isinstance(result, Exception): - logger.error(f"子 Agent 执行失败: {agent_name}", extra={"error": str(result)}) - results.append({ - "agent_id": agent_id, - "agent_name": agent_name, - "error": str(result) - }) - else: - results.append({ - "agent_id": agent_id, - "agent_name": agent_name, - "result": result, - "conversation_id": result.get("conversation_id") # 保存会话 ID - }) - - return results - - async def _execute_conditional_stream( - self, - task_analysis: Dict[str, Any], - conversation_id: Optional[uuid.UUID], - user_id: Optional[str], - web_search: bool = False, - memory: bool = True, - storage_type: str = '', - user_rag_memory_id: str = '' - ): - """条件路由执行(流式) - - Args: - task_analysis: 任务分析结果 - conversation_id: 会话 ID - user_id: 用户 ID - - Yields: - SSE 格式的事件流 - """ - if not task_analysis["sub_agents"]: - raise BusinessException("没有可用的子 Agent", BizCode.AGENT_CONFIG_MISSING) - - message = task_analysis.get("message", "") - - # 使用路由规则选择 Agent - use_llm = task_analysis.get("use_llm_routing", True) - selected_agent_info = await self._route_by_rules( - message, - task_analysis["sub_agents"], - use_llm=use_llm, - conversation_id=str(conversation_id) if conversation_id else None - ) - - if not selected_agent_info: - selected_agent_info = task_analysis["sub_agents"][0] - logger.info("未匹配到路由规则,使用默认 Agent") - - agent_id = selected_agent_info["agent_id"] - agent_data = self.sub_agents.get(agent_id) - - if not agent_data: - raise BusinessException(f"子 Agent 不存在: {agent_id}", BizCode.AGENT_CONFIG_MISSING) - - logger.info( - f"条件路由选择 Agent(流式)", - extra={ - "agent_id": agent_id, - "agent_name": selected_agent_info.get("name"), - "message_preview": message[:50] - } - ) - - # 发送路由信息事件 - yield self._format_sse_event("agent_selected", { - "agent_id": agent_id, - "agent_name": selected_agent_info.get("name") - }) - - # 流式执行子 Agent - sub_conversation_id = None - async for event in self._execute_sub_agent_stream( - agent_data["config"], - message, - task_analysis.get("initial_context", {}), - conversation_id, - user_id, - web_search, - memory, - storage_type, - user_rag_memory_id - ): - # 解析事件以提取 conversation_id - if "data:" in event: - try: - import json - data_line = event.split("data: ", 1)[1].strip() - data = json.loads(data_line) - if "conversation_id" in data: - sub_conversation_id = data["conversation_id"] - except: - pass - - yield event - - # 如果有会话 ID,发送一个包含它的事件 - if sub_conversation_id: - yield self._format_sse_event("conversation", { - "conversation_id": sub_conversation_id - }) - - async def _execute_conditional( - self, - task_analysis: Dict[str, Any], - conversation_id: Optional[uuid.UUID], - user_id: Optional[str], - web_search: bool = False, - memory: bool = True, - storage_type: str = '', - user_rag_memory_id: str = '' - ) -> Dict[str, Any]: - """条件路由执行 - 根据路由规则选择合适的 Agent - - Args: - task_analysis: 任务分析结果 - conversation_id: 会话 ID - user_id: 用户 ID - - Returns: - 执行结果 - """ - if not task_analysis["sub_agents"]: - raise BusinessException("没有可用的子 Agent", BizCode.AGENT_CONFIG_MISSING) - - message = task_analysis.get("message", "") - - # 使用路由规则选择 Agent(默认启用 LLM) - use_llm = task_analysis.get("use_llm_routing", True) - selected_agent_info = await self._route_by_rules( - message, - task_analysis["sub_agents"], - use_llm=use_llm, - conversation_id=str(conversation_id) if conversation_id else None - ) - - if not selected_agent_info: - # 如果没有匹配的规则,使用第一个 Agent - selected_agent_info = task_analysis["sub_agents"][0] - logger.info("未匹配到路由规则,使用默认 Agent") - - agent_id = selected_agent_info["agent_id"] - agent_data = self.sub_agents.get(agent_id) - - if not agent_data: - raise BusinessException(f"子 Agent 不存在: {agent_id}", BizCode.AGENT_CONFIG_MISSING) - - logger.info( - f"条件路由选择 Agent", - extra={ - "agent_id": agent_id, - "agent_name": selected_agent_info.get("name"), - "message_preview": message[:50] - } - ) - - result = await self._execute_sub_agent( - agent_data["config"], - message, - task_analysis.get("initial_context", {}), - conversation_id, - user_id, - web_search, - memory, - storage_type, - user_rag_memory_id - ) - - # 确保返回子 Agent 的 conversation_id - return { - "agent_id": agent_id, - "agent_name": selected_agent_info.get("name"), - "result": result, - "conversation_id": result.get("conversation_id") # 传递子 Agent 的会话 ID - } - - async def _route_by_rules( - self, - message: str, - sub_agents: List[Dict[str, Any]], - use_llm: bool = True, - conversation_id: Optional[str] = None - ) -> Optional[Dict[str, Any]]: - """根据路由规则选择 Agent(支持 LLM 增强) - - Args: - message: 用户消息 - sub_agents: 子 Agent 列表 - use_llm: 是否使用 LLM 辅助路由 - conversation_id: 会话 ID(用于多轮对话状态管理) - - Returns: - 选中的 Agent 信息,如果没有匹配则返回 None - """ - # 如果配置了 LLM 路由器,优先使用 - if self.llm_router and use_llm: - try: - logger.info("使用 LLM 路由器进行智能路由") - routing_result = await self.llm_router.route( - message=message, - conversation_id=conversation_id, - force_new=False - ) - - selected_agent_id = routing_result["agent_id"] - confidence = routing_result["confidence"] - method = routing_result.get("routing_method", "unknown") - - logger.info( - f"LLM 路由完成", - extra={ - "agent_id": selected_agent_id, - "confidence": confidence, - "method": method, - "strategy": routing_result.get("strategy"), - "topic": routing_result.get("topic") - } - ) - - # 查找对应的 Agent - for agent in sub_agents: - if agent["agent_id"] == selected_agent_id: - return agent - - logger.warning(f"LLM 路由返回的 agent_id 不在子 Agent 列表中: {selected_agent_id}") - - except Exception as e: - logger.error(f"LLM 路由失败,降级到关键词路由: {str(e)}") - - # 降级到关键词路由 - if not self.config.routing_rules: - return None - - message_lower = message.lower() - best_match = None - best_score = 0 - - # 关键词匹配 - for rule in self.config.routing_rules: - target_agent_id = rule.get("target_agent_id") - condition = rule.get("condition", "") - priority = rule.get("priority", 1) - - # 解析条件表达式(简化版本:支持 contains_any) - score = self._evaluate_condition(condition, message_lower) - - # 考虑优先级 - weighted_score = score * priority - - if weighted_score > best_score: - # 找到对应的 Agent - for agent in sub_agents: - if agent["agent_id"] == target_agent_id: - best_match = agent - best_score = weighted_score - break - - if best_match: - logger.info( - f"关键词路由", - extra={ - "agent_name": best_match.get("name"), - "score": best_score - } - ) - - return best_match - - - def _evaluate_condition(self, condition: str, message: str) -> float: - """评估条件表达式 - - Args: - condition: 条件表达式,如 "contains_any(['数学', '物理'])" - message: 消息文本(已转小写) - - Returns: - 匹配分数 (0-1) - """ - import re - - # 解析 contains_any(['keyword1', 'keyword2', ...]) - match = re.search(r"contains_any\(\[(.*?)\]\)", condition) - if not match: - return 0 - - # 提取关键词列表 - keywords_str = match.group(1) - keywords = [k.strip().strip("'\"") for k in keywords_str.split(",")] - - # 计算匹配分数 - matched_count = 0 - for keyword in keywords: - if keyword.lower() in message: - matched_count += 1 - - if not keywords: - return 0 - - # 返回匹配比例 - return matched_count / len(keywords) - - async def _execute_loop( - self, - task_analysis: Dict[str, Any], - conversation_id: Optional[uuid.UUID], - user_id: Optional[str], - web_search: bool = False, - memory: bool = True, - storage_type: str = '', - user_rag_memory_id: str = '' - ) -> Dict[str, Any]: - """循环执行(迭代优化) - - Args: - task_analysis: 任务分析结果 - conversation_id: 会话 ID - user_id: 用户 ID - - Returns: - 执行结果 - """ - max_iterations = self.config.execution_config.get("max_iterations", 5) - - if not task_analysis["sub_agents"]: - raise BusinessException("没有可用的子 Agent", BizCode.AGENT_CONFIG_MISSING) - - agent_info = task_analysis["sub_agents"][0] - agent_id = agent_info["agent_id"] - agent_data = self.sub_agents.get(agent_id) - - if not agent_data: - raise BusinessException(f"子 Agent 不存在: {agent_id}", BizCode.AGENT_CONFIG_MISSING) - - context = task_analysis.get("initial_context", {}) - message = task_analysis.get("message", "") - - result = None - for i in range(max_iterations): - logger.info( - f"循环执行 Agent", - extra={ - "iteration": i + 1, - "max_iterations": max_iterations, - "agent_name": agent_info.get("name") - } - ) - - result = await self._execute_sub_agent( - agent_data["config"], - message, - context, - conversation_id, - user_id, - web_search, - memory, - storage_type, - user_rag_memory_id - ) - - # 简化版本:执行一次就返回 - # 在实际应用中,应该验证结果是否满足条件 - break - - return { - "agent_id": agent_id, - "agent_name": agent_info.get("name"), - "iterations": i + 1, - "result": result, - "conversation_id": result.get("conversation_id") if result else None # 保存会话 ID - } - - async def _execute_sub_agent_stream( - self, - agent_config: AgentConfig, - message: str, - context: Dict[str, Any], - conversation_id: Optional[uuid.UUID], - user_id: Optional[str], - web_search: bool = False, - memory: bool = True, - storage_type: str = '', - user_rag_memory_id: str = '' - ): - """执行单个子 Agent(流式) - - Args: - agent_config: Agent 配置 - message: 消息 - context: 上下文 - conversation_id: 会话 ID - user_id: 用户 ID - - Yields: - SSE 格式的事件流 - """ - from app.services.draft_run_service import DraftRunService - - # 获取模型配置 - model_config = self.db.get(ModelConfig, agent_config.default_model_config_id) - if not model_config: - raise BusinessException( - f"Agent 模型配置不存在", - BizCode.AGENT_CONFIG_MISSING - ) - - # 流式执行 Agent - draft_service = DraftRunService(self.db) - async for event in draft_service.run_stream( - agent_config=agent_config, - model_config=model_config, - message=message, - workspace_id=agent_config.app.workspace_id, - conversation_id=str(conversation_id) if conversation_id else None, - user_id=user_id, - variables=context, - storage_type=storage_type, - user_rag_memory_id=user_rag_memory_id, - web_search=web_search, - memory=memory - ): - yield event - - async def _execute_sub_agent( - self, - agent_config: AgentConfig, - message: str, - context: Dict[str, Any], - conversation_id: Optional[uuid.UUID], - user_id: Optional[str], - web_search: bool = False, - memory: bool = True, - storage_type: str = '', - user_rag_memory_id: str = '' - ) -> Dict[str, Any]: - """执行单个子 Agent - - Args: - agent_config: Agent 配置 - message: 消息 - context: 上下文 - conversation_id: 会话 ID - user_id: 用户 ID - - Returns: - 执行结果 - """ - from app.services.draft_run_service import DraftRunService - - # 获取模型配置 - model_config = self.db.get(ModelConfig, agent_config.default_model_config_id) - if not model_config: - raise BusinessException( - f"Agent 模型配置不存在", - BizCode.AGENT_CONFIG_MISSING - ) - - # 执行 Agent - draft_service = DraftRunService(self.db) - result = await draft_service.run( - agent_config=agent_config, - model_config=model_config, - message=message, - workspace_id=agent_config.app.workspace_id, - conversation_id=str(conversation_id) if conversation_id else None, - user_id=user_id, - variables=context, - web_search=web_search, - memory=memory, - storage_type=storage_type, - user_rag_memory_id=user_rag_memory_id - ) - - return result - - async def _aggregate_results( - self, - results: Any - ) -> str: - """整合子 Agent 的结果 - - Args: - results: 子 Agent 执行结果 - - Returns: - 整合后的结果 - """ - strategy = self.config.aggregation_strategy - - if strategy == "merge": - return self._merge_results(results) - elif strategy == "vote": - return self._vote_results(results) - elif strategy == "priority": - return self._priority_results(results) - else: - return self._merge_results(results) - - def _merge_results(self, results: Any) -> str: - """合并所有结果 - - Args: - results: 执行结果 - - Returns: - 合并后的结果 - """ - if isinstance(results, list): - # 顺序或并行执行的结果 - merged = [] - for item in results: - if "result" in item: - agent_name = item.get("agent_name", "Agent") - message = item["result"].get("message", "") - merged.append(f"【{agent_name}】\n{message}") - elif "error" in item: - agent_name = item.get("agent_name", "Agent") - merged.append(f"【{agent_name}】\n错误: {item['error']}") - - return "\n\n".join(merged) - elif isinstance(results, dict): - # 条件或循环执行的结果 - if "result" in results: - return results["result"].get("message", "") - return str(results) - - return str(results) - - def _vote_results(self, results: Any) -> str: - """投票选择最佳结果(简化版本) - - Args: - results: 执行结果 - - Returns: - 最佳结果 - """ - # 简化版本:返回第一个成功的结果 - if isinstance(results, list): - for item in results: - if "result" in item: - return item["result"].get("message", "") - - return self._merge_results(results) - - def _priority_results(self, results: Any) -> str: - """按优先级选择结果(简化版本) - - Args: - results: 执行结果 - - Returns: - 优先级最高的结果 - """ - # 简化版本:返回第一个结果 - if isinstance(results, list) and results: - if "result" in results[0]: - return results[0]["result"].get("message", "") - - return self._merge_results(results) - - def _format_sse_event(self, event: str, data: Dict[str, Any]) -> str: - """格式化 SSE 事件 - - Args: - event: 事件类型 - data: 事件数据 - - Returns: - SSE 格式的字符串 - """ - import json - return f"event: {event}\ndata: {json.dumps(data, ensure_ascii=False)}\n\n" - - def _load_agent(self, release_id: uuid.UUID): - """从发布版本加载 Agent 配置 - - Args: - release_id: 发布版本 ID - - Returns: - Agent 配置对象(包含发布版本的配置数据) - """ - from app.models import AppRelease, App - - # 获取发布版本 - release = self.db.get(AppRelease, release_id) - if not release: - raise ResourceNotFoundException("发布版本", str(release_id)) - - # 从发布版本的 config 中获取 Agent 配置 - config_data = release.config - if not config_data: - raise BusinessException(f"发布版本 {release_id} 缺少配置数据", BizCode.AGENT_CONFIG_MISSING) - - # 获取应用信息(用于 workspace_id) - app = self.db.get(App, release.app_id) - if not app: - raise ResourceNotFoundException("应用", str(release.app_id)) - - # 创建一个类似 AgentConfig 的对象,包含所有需要的属性 - class AgentConfigProxy: - """Agent 配置代理对象,模拟 AgentConfig 的接口""" - def __init__(self, release, app, config_data): - self.id = release.id - self.app_id = release.app_id - self.app = app - self.name = release.name - self.description = release.description - self.system_prompt = config_data.get("system_prompt") - self.model_parameters = config_data.get("model_parameters") - self.knowledge_retrieval = config_data.get("knowledge_retrieval") - self.memory = config_data.get("memory") - self.variables = config_data.get("variables", []) - self.tools = config_data.get("tools", {}) - self.default_model_config_id = release.default_model_config_id - - return AgentConfigProxy(release, app, config_data) diff --git a/app/services/multi_agent_service.py b/app/services/multi_agent_service.py deleted file mode 100644 index f6374dc5..00000000 --- a/app/services/multi_agent_service.py +++ /dev/null @@ -1,630 +0,0 @@ -"""多 Agent 配置管理服务""" -import uuid -from typing import Optional, List, Tuple, Any -from sqlalchemy.orm import Session -from sqlalchemy import select, desc - -from app.models import MultiAgentConfig, App, AgentConfig -from app.schemas.multi_agent_schema import ( - MultiAgentConfigCreate, - MultiAgentConfigUpdate, - MultiAgentRunRequest -) -from app.services.multi_agent_orchestrator import MultiAgentOrchestrator -from app.core.exceptions import ResourceNotFoundException, BusinessException -from app.core.error_codes import BizCode -from app.core.logging_config import get_business_logger -from app.models import AppRelease - -logger = get_business_logger() - - -def convert_uuids_to_str(obj: Any) -> Any: - """递归转换对象中的所有 UUID 为字符串 - - Args: - obj: 要转换的对象(dict, list, UUID 等) - - Returns: - 转换后的对象 - """ - if isinstance(obj, uuid.UUID): - return str(obj) - elif isinstance(obj, dict): - return {k: convert_uuids_to_str(v) for k, v in obj.items()} - elif isinstance(obj, list): - return [convert_uuids_to_str(item) for item in obj] - else: - return obj - - -class MultiAgentService: - """多 Agent 配置管理服务""" - - def __init__(self, db: Session): - self.db = db - - def create_config( - self, - app_id: uuid.UUID, - data: MultiAgentConfigCreate, - created_by: uuid.UUID - ) -> MultiAgentConfig: - """创建多 Agent 配置 - - Args: - app_id: 应用 ID - data: 配置数据 - created_by: 创建者 ID - - Returns: - 多 Agent 配置 - """ - # 1. 验证应用存在 - app = self.db.get(App, app_id) - if not app: - raise ResourceNotFoundException("应用", str(app_id)) - - # 2. 检查是否已有有效配置 - existing = self.db.scalars( - select(MultiAgentConfig) - .where( - MultiAgentConfig.app_id == app_id, - MultiAgentConfig.is_active == True - ) - .order_by(MultiAgentConfig.updated_at.desc()) - ).first() - if existing: - raise BusinessException("应用已有多 Agent 配置", BizCode.DUPLICATE_RESOURCE) - - # 3. 验证主 Agent 存在 - master_agent = self.db.get(AgentConfig, data.master_agent_id) - if not master_agent: - raise ResourceNotFoundException("主 Agent", str(data.master_agent_id)) - - # 4. 验证子 Agent 存在 - for sub_agent in data.sub_agents: - agent = self.db.get(AgentConfig, sub_agent.agent_id) - if not agent: - raise ResourceNotFoundException("子 Agent", str(sub_agent.agent_id)) - - # 5. 创建配置(转换 UUID 为字符串以支持 JSON 序列化) - sub_agents_data = [convert_uuids_to_str(sub_agent.model_dump()) for sub_agent in data.sub_agents] - routing_rules_data = [convert_uuids_to_str(rule.model_dump()) for rule in data.routing_rules] if data.routing_rules else None - - # 处理 execution_config(可能是 None、字典或 Pydantic 模型) - if data.execution_config is None: - execution_config_data = {} - elif isinstance(data.execution_config, dict): - execution_config_data = convert_uuids_to_str(data.execution_config) - else: - execution_config_data = convert_uuids_to_str(data.execution_config.model_dump()) - - config = MultiAgentConfig( - app_id=app_id, - master_agent_id=data.master_agent_id, - master_agent_name=data.master_agent_name, - orchestration_mode=data.orchestration_mode, - sub_agents=sub_agents_data, - routing_rules=routing_rules_data, - execution_config=execution_config_data, - aggregation_strategy=data.aggregation_strategy - ) - - self.db.add(config) - self.db.commit() - self.db.refresh(config) - - logger.info( - f"创建多 Agent 配置成功", - extra={ - "config_id": str(config.id), - "app_id": str(app_id), - "mode": data.orchestration_mode, - "sub_agent_count": len(data.sub_agents) - } - ) - - return config - - def get_config(self, app_id: uuid.UUID) -> Optional[MultiAgentConfig]: - """获取多 Agent 配置 - - Args: - app_id: 应用 ID - - Returns: - 多 Agent 配置,如果不存在返回 None - """ - return self.db.scalars( - select(MultiAgentConfig) - .where( - MultiAgentConfig.app_id == app_id, - MultiAgentConfig.is_active == True - ) - .order_by(MultiAgentConfig.updated_at.desc()) - ).first() - - def get_multi_agent_configs(self, app_id: uuid.UUID) -> Optional[dict]: - """通过 app_id 获取最新有效的多智能体配置,并将 agent_id 转换为 app_id - - Args: - app_id: 应用 ID - - Returns: - 转换后的配置字典,如果不存在返回 None - """ - config = self.get_config(app_id) - if not config: - return None - - # 转换 master_agent_id (release_id) 为 app_id - master_release = self.db.get(AppRelease, config.master_agent_id) - master_app_id = master_release.app_id if master_release else config.master_agent_id - - # 转换 sub_agents 中的 agent_id (release_id) 为 app_id - converted_sub_agents = [] - for sub_agent in config.sub_agents: - sub_agent_copy = sub_agent.copy() - release_id = sub_agent.get("agent_id") - if release_id: - try: - release_id_uuid = uuid.UUID(release_id) if isinstance(release_id, str) else release_id - sub_release = self.db.get(AppRelease, release_id_uuid) - if sub_release: - sub_agent_copy["agent_id"] = str(sub_release.app_id) - except Exception as e: - logger.warning(f"转换 sub_agent agent_id 失败: {release_id}, 错误: {str(e)}") - converted_sub_agents.append(sub_agent_copy) - - # 构建返回的配置字典 - return { - "id": config.id, - "app_id": config.app_id, - "master_agent_id": master_app_id, - "master_agent_name": config.master_agent_name, - "orchestration_mode": config.orchestration_mode, - "sub_agents": converted_sub_agents, - "routing_rules": config.routing_rules, - "execution_config": config.execution_config, - "aggregation_strategy": config.aggregation_strategy, - "is_active": config.is_active, - "created_at": config.created_at, - "updated_at": config.updated_at - } - - def get_published_config_by_agent_id(self, agent_id: uuid.UUID) -> Optional[dict]: - """通过 agent_id 获取当前发布版本的完整配置 - - Args: - agent_id: Agent 配置 ID - - Returns: - 当前发布版本的配置字典,如果没有发布版本则返回 None - """ - from app.models import AppRelease - - # 查询 Agent 配置 - agent_config = self.db.get(AgentConfig, agent_id) - if not agent_config: - logger.warning(f"Agent 配置不存在: {agent_id}") - return None - - # 获取关联的应用 - app = self.db.get(App, agent_config.app_id) - if not app or not app.current_release_id: - logger.warning(f"应用未发布或不存在: app_id={agent_config.app_id}") - return None - - # 获取当前发布版本 - release = self.db.get(AppRelease, app.current_release_id) - if not release: - logger.warning(f"发布版本不存在: release_id={app.current_release_id}") - return None - - # 从发布版本的 config 中获取完整配置 - # config 是一个 JSON 对象,包含了发布时的配置快照 - config_data = release.config - if config_data and isinstance(config_data, dict): - return config_data - - return None - - def get_published_by_agent_id(self, agent_id: uuid.UUID) -> Optional[AppRelease]: - """通过 agent_id 获取当前发布版本的完整配置 - - Args: - agent_id: Agent 配置 ID - - Returns: - 当前发布版本的配置字典,如果没有发布版本则返回 None - """ - - # 获取关联的应用 - app = self.db.get(App, agent_id) - if not app or not app.current_release_id: - logger.warning(f"应用未发布或不存在: app_id={agent_id}") - return None - - # 获取当前发布版本 - release = self.db.get(AppRelease, app.current_release_id) - if not release: - logger.warning(f"发布版本不存在: release_id={app.current_release_id}") - return None - return release - - def update_config( - self, - app_id: uuid.UUID, - data: MultiAgentConfigUpdate - ) -> MultiAgentConfig: - """更新多 Agent 配置 - - Args: - app_id: 应用 ID - data: 更新数据 - - Returns: - 更新后的配置 - """ - config = self.get_config(app_id) - if not config: - # 1. 验证应用存在 - app = self.db.get(App, app_id) - if not app: - raise ResourceNotFoundException("应用", str(app_id)) - - # 2. 验证主 Agent 存在并获取发布版本 ID - master_app_release = self.get_published_by_agent_id(data.master_agent_id) - if not master_app_release: - raise ResourceNotFoundException("主 Agent 未发布或不存在", str(data.master_agent_id)) - - # 使用发布版本 ID - data.master_agent_id = master_app_release.id - - # 3. 验证子 Agent 存在并获取发布版本 ID - for sub_agent in data.sub_agents: - agent_app_release = self.get_published_by_agent_id(sub_agent.agent_id) - if not agent_app_release: - raise ResourceNotFoundException("子 Agent 未发布或不存在", str(sub_agent.agent_id)) - - # 使用发布版本 ID - sub_agent.agent_id = agent_app_release.id - - - # 5. 创建配置(转换 UUID 为字符串以支持 JSON 序列化) - sub_agents_data = [convert_uuids_to_str(sub_agent.model_dump()) for sub_agent in data.sub_agents] - # routing_rules_data = [convert_uuids_to_str(rule.model_dump()) for rule in data.routing_rules] if data.routing_rules else None - - # 处理 execution_config(可能是 None、字典或 Pydantic 模型) - if data.execution_config is None: - execution_config_data = {} - elif isinstance(data.execution_config, dict): - execution_config_data = convert_uuids_to_str(data.execution_config) - else: - execution_config_data = convert_uuids_to_str(data.execution_config.model_dump()) - - config = MultiAgentConfig( - app_id=app_id, - master_agent_id=data.master_agent_id, - master_agent_name=data.master_agent_name, - orchestration_mode=data.orchestration_mode, - sub_agents=sub_agents_data, - # routing_rules=routing_rules_data, - execution_config=execution_config_data, - aggregation_strategy=data.aggregation_strategy - ) - - self.db.add(config) - self.db.commit() - self.db.refresh(config) - - logger.info( - f"创建多 Agent 配置成功", - extra={ - "config_id": str(config.id), - "app_id": str(app_id), - "mode": data.orchestration_mode, - "sub_agent_count": len(data.sub_agents) - } - ) - return config - # raise ResourceNotFoundException("多 Agent 配置", str(app_id)) - - # 更新字段 - if data.master_agent_id is not None: - # 验证主 Agent 存在 - # 3. 验证主 Agent 存在并获取发布配置 - master_app_release = self.get_published_by_agent_id(data.master_agent_id) - if not master_app_release: - raise ResourceNotFoundException("主 Agent 未发布或", str(data.master_agent_id)) - - config.master_agent_id = master_app_release.id - - if data.master_agent_name is not None: - config.master_agent_name = data.master_agent_name - - if data.orchestration_mode is not None: - config.orchestration_mode = data.orchestration_mode - - if data.sub_agents is not None: - # 验证子 Agent 存在,并获取其发布的 config_id - updated_sub_agents = [] - for sub_agent in data.sub_agents: - agent_app_release = self.get_published_by_agent_id(sub_agent.agent_id) - if not agent_app_release: - raise ResourceNotFoundException("子 Agent 未发布或", str(sub_agent.agent_id)) - sub_agent.agent_id = agent_app_release.id - sub_agent_dict = convert_uuids_to_str(sub_agent.model_dump()) - updated_sub_agents.append(sub_agent_dict) - - config.sub_agents = updated_sub_agents - - # if data.routing_rules is not None: - # config.routing_rules = [convert_uuids_to_str(rule.model_dump()) for rule in data.routing_rules] if data.routing_rules else None - - if data.execution_config is None: - execution_config_data = {} - elif isinstance(data.execution_config, dict): - execution_config_data = convert_uuids_to_str(data.execution_config) - else: - execution_config_data = convert_uuids_to_str(data.execution_config.model_dump()) - - if data.aggregation_strategy is not None: - config.aggregation_strategy = data.aggregation_strategy - - if data.is_active is not None: - config.is_active = data.is_active - - self.db.commit() - self.db.refresh(config) - - logger.info( - f"更新多 Agent 配置成功", - extra={ - "config_id": str(config.id), - "app_id": str(app_id) - } - ) - - return config - - def delete_config(self, app_id: uuid.UUID) -> None: - """删除多 Agent 配置 - - Args: - app_id: 应用 ID - """ - config = self.get_config(app_id) - if not config: - raise ResourceNotFoundException("多 Agent 配置", str(app_id)) - - self.db.delete(config) - self.db.commit() - - logger.info( - f"删除多 Agent 配置成功", - extra={ - "config_id": str(config.id), - "app_id": str(app_id) - } - ) - - async def run( - self, - app_id: uuid.UUID, - request: MultiAgentRunRequest - ) -> dict: - """运行多 Agent 任务 - - Args: - app_id: 应用 ID - request: 运行请求 - - Returns: - 执行结果 - """ - # 1. 获取配置 - config = self.get_config(app_id) - if not config: - raise ResourceNotFoundException("多 Agent 配置", str(app_id)) - - if not config.is_active: - raise BusinessException("多 Agent 配置已禁用", BizCode.RESOURCE_DISABLED) - - # 2. 创建编排器 - orchestrator = MultiAgentOrchestrator(self.db, config) - - # 3. 执行任务 - result = await orchestrator.execute( - message=request.message, - conversation_id=request.conversation_id, - user_id=request.user_id, - variables=request.variables, - use_llm_routing=getattr(request, 'use_llm_routing', True), # 默认启用 LLM 路由 - web_search=getattr(request, 'web_search', False), # 网络搜索参数 - memory=getattr(request, 'memory', True) # 记忆功能参数 - ) - - return result - - async def run_stream( - self, - app_id: uuid.UUID, - request: MultiAgentRunRequest, - storage_type :str, - user_rag_memory_id :str - ): - """运行多 Agent 任务(流式返回) - - Args: - app_id: 应用 ID - request: 运行请求 - - Yields: - SSE 格式的事件流 - """ - # 1. 获取配置 - config = self.get_config(app_id) - if not config: - raise ResourceNotFoundException("多 Agent 配置", str(app_id)) - - if not config.is_active: - raise BusinessException("多 Agent 配置已禁用", BizCode.RESOURCE_DISABLED) - - # 2. 创建编排器 - orchestrator = MultiAgentOrchestrator(self.db, config) - - # 3. 流式执行任务 - async for event in orchestrator.execute_stream( - message=request.message, - conversation_id=request.conversation_id, - user_id=request.user_id, - variables=request.variables, - use_llm_routing=getattr(request, 'use_llm_routing', True), - web_search=getattr(request, 'web_search', False), # 网络搜索参数 - memory=getattr(request, 'memory', True) , # 记忆功能参数 - storage_type=storage_type, - user_rag_memory_id=user_rag_memory_id - ): - yield event - - def add_sub_agent( - self, - app_id: uuid.UUID, - agent_id: uuid.UUID, - name: str, - role: Optional[str] = None, - priority: int = 1, - capabilities: Optional[List[str]] = None - ) -> MultiAgentConfig: - """添加子 Agent - - Args: - app_id: 应用 ID - agent_id: Agent ID - name: Agent 名称 - role: 角色描述 - priority: 优先级 - capabilities: 能力列表 - - Returns: - 更新后的配置 - """ - config = self.get_config(app_id) - if not config: - raise ResourceNotFoundException("多 Agent 配置", str(app_id)) - - # 验证 Agent 存在 - agent = self.db.get(AgentConfig, agent_id) - if not agent: - raise ResourceNotFoundException("Agent", str(agent_id)) - - # 检查是否已存在 - for sub_agent in config.sub_agents: - if sub_agent["agent_id"] == str(agent_id): - raise BusinessException("Agent 已存在于配置中", BizCode.DUPLICATE_RESOURCE) - - # 添加子 Agent - new_sub_agent = { - "agent_id": str(agent_id), - "name": name, - "role": role, - "priority": priority, - "capabilities": capabilities or [] - } - - config.sub_agents.append(new_sub_agent) - - # 标记为已修改 - self.db.add(config) - self.db.commit() - self.db.refresh(config) - - logger.info( - f"添加子 Agent 成功", - extra={ - "config_id": str(config.id), - "agent_id": str(agent_id), - "agent_name": name - } - ) - - return config - - def remove_sub_agent( - self, - app_id: uuid.UUID, - agent_id: uuid.UUID - ) -> MultiAgentConfig: - """移除子 Agent - - Args: - app_id: 应用 ID - agent_id: Agent ID - - Returns: - 更新后的配置 - """ - config = self.get_config(app_id) - if not config: - raise ResourceNotFoundException("多 Agent 配置", str(app_id)) - - # 查找并移除 - original_count = len(config.sub_agents) - config.sub_agents = [ - sub_agent for sub_agent in config.sub_agents - if sub_agent["agent_id"] != str(agent_id) - ] - - if len(config.sub_agents) == original_count: - raise ResourceNotFoundException("子 Agent", str(agent_id)) - - # 标记为已修改 - self.db.add(config) - self.db.commit() - self.db.refresh(config) - - logger.info( - f"移除子 Agent 成功", - extra={ - "config_id": str(config.id), - "agent_id": str(agent_id) - } - ) - - return config - - def list_configs( - self, - workspace_id: uuid.UUID, - page: int = 1, - pagesize: int = 20 - ) -> Tuple[List[MultiAgentConfig], int]: - """列出多 Agent 配置 - - Args: - workspace_id: 工作空间 ID - page: 页码 - pagesize: 每页数量 - - Returns: - 配置列表和总数 - """ - # 构建查询 - stmt = ( - select(MultiAgentConfig) - .join(App) - .where(App.workspace_id == workspace_id) - .order_by(desc(MultiAgentConfig.created_at)) - ) - - # 总数 - count_stmt = stmt.with_only_columns(MultiAgentConfig.id) - total = len(self.db.execute(count_stmt).all()) - - # 分页 - stmt = stmt.offset((page - 1) * pagesize).limit(pagesize) - configs = list(self.db.scalars(stmt).all()) - - return configs, total diff --git a/app/services/release_share_service.py b/app/services/release_share_service.py deleted file mode 100644 index 7278aea0..00000000 --- a/app/services/release_share_service.py +++ /dev/null @@ -1,444 +0,0 @@ -import uuid -from typing import Optional, Dict, Any -from sqlalchemy.orm import Session -from sqlalchemy import select - -from app.models import ReleaseShare, AppRelease, App, AgentConfig -from app.repositories.release_share_repository import ReleaseShareRepository -from app.core.share_utils import ( - generate_share_token, - hash_password, - verify_password, - build_share_url, - generate_embed_code -) -from app.core.exceptions import ResourceNotFoundException, BusinessException -from app.core.error_codes import BizCode -from app.core.logging_config import get_business_logger -from app.schemas import release_share_schema - -logger = get_business_logger() - - -class ReleaseShareService: - """发布版本分享服务""" - - def __init__(self, db: Session): - self.db = db - self.repo = ReleaseShareRepository(db) - - def create_or_update_share( - self, - release_id: uuid.UUID, - user_id: uuid.UUID, - workspace_id: uuid.UUID, - data: release_share_schema.ReleaseShareCreate, - base_url: Optional[str] = None - ) -> ReleaseShare: - """创建或更新分享配置 - - Args: - release_id: 发布版本 ID - user_id: 用户 ID - workspace_id: 工作空间 ID - data: 分享配置数据 - base_url: 基础 URL(用于生成完整的分享链接) - - Returns: - 分享配置 - """ - # 验证发布版本存在且属于该工作空间 - release = self._get_release_or_404(release_id) - self._validate_release_access(release, workspace_id) - - # 检查是否已存在分享配置 - existing_share = self.repo.get_by_release_id(release_id) - - if existing_share: - # 更新现有配置 - return self._update_share_internal(existing_share, data) - else: - # 创建新配置 - return self._create_share_internal(release, user_id, data) - - def _create_share_internal( - self, - release: AppRelease, - user_id: uuid.UUID, - data: release_share_schema.ReleaseShareCreate - ) -> ReleaseShare: - """内部方法:创建分享配置""" - # 生成唯一的 share_token - share_token = self._generate_unique_token() - - # 处理密码 - password_hash = None - if data.require_password and data.password: - password_hash = hash_password(data.password) - - # 创建分享配置 - share = ReleaseShare( - release_id=release.id, - app_id=release.app_id, - is_enabled=data.is_enabled, - share_token=share_token, - require_password=data.require_password, - password_hash=password_hash, - allow_embed=data.allow_embed, - embed_domains=data.embed_domains or [], - created_by=user_id - ) - - share = self.repo.create(share) - - logger.info( - f"创建分享配置", - extra={ - "share_id": str(share.id), - "release_id": str(release.id), - "app_id": str(release.app_id), - "share_token": share_token - } - ) - - return share - - def _update_share_internal( - self, - share: ReleaseShare, - data: release_share_schema.ReleaseShareUpdate - ) -> ReleaseShare: - """内部方法:更新分享配置""" - if data.is_enabled is not None: - share.is_enabled = data.is_enabled - - if data.require_password is not None: - share.require_password = data.require_password - - if data.password is not None: - if data.password: - share.password_hash = hash_password(data.password) - else: - share.password_hash = None - - if data.allow_embed is not None: - share.allow_embed = data.allow_embed - - if data.embed_domains is not None: - share.embed_domains = data.embed_domains or [] - - share = self.repo.update(share) - - logger.info( - f"更新分享配置", - extra={ - "share_id": str(share.id), - "release_id": str(share.release_id) - } - ) - - return share - - def update_share( - self, - release_id: uuid.UUID, - workspace_id: uuid.UUID, - data: release_share_schema.ReleaseShareUpdate - ) -> ReleaseShare: - """更新分享配置 - - Args: - release_id: 发布版本 ID - workspace_id: 工作空间 ID - data: 更新数据 - - Returns: - 更新后的分享配置 - """ - # 验证发布版本 - release = self._get_release_or_404(release_id) - self._validate_release_access(release, workspace_id) - - # 获取分享配置 - share = self.repo.get_by_release_id(release_id) - if not share: - raise ResourceNotFoundException("分享配置", str(release_id)) - - return self._update_share_internal(share, data) - - def get_share( - self, - release_id: uuid.UUID, - workspace_id: uuid.UUID, - base_url: Optional[str] = None - ) -> Optional[release_share_schema.ReleaseShare]: - """获取分享配置 - - Args: - release_id: 发布版本 ID - workspace_id: 工作空间 ID - base_url: 基础 URL - - Returns: - 分享配置 Schema - """ - # 验证发布版本 - release = self._get_release_or_404(release_id) - self._validate_release_access(release, workspace_id) - - share = self.repo.get_by_release_id(release_id) - if not share: - return None - - return self._convert_to_schema(share, base_url) - - def delete_share( - self, - release_id: uuid.UUID, - workspace_id: uuid.UUID - ) -> None: - """删除(禁用)分享配置 - - Args: - release_id: 发布版本 ID - workspace_id: 工作空间 ID - """ - # 验证发布版本 - release = self._get_release_or_404(release_id) - self._validate_release_access(release, workspace_id) - - share = self.repo.get_by_release_id(release_id) - if not share: - raise ResourceNotFoundException("分享配置", str(release_id)) - - self.repo.delete(share) - - logger.info( - f"删除分享配置", - extra={ - "share_id": str(share.id), - "release_id": str(release_id) - } - ) - - def regenerate_token( - self, - release_id: uuid.UUID, - workspace_id: uuid.UUID - ) -> ReleaseShare: - """重新生成分享 token(旧链接失效) - - Args: - release_id: 发布版本 ID - workspace_id: 工作空间 ID - - Returns: - 更新后的分享配置 - """ - # 验证发布版本 - release = self._get_release_or_404(release_id) - self._validate_release_access(release, workspace_id) - - share = self.repo.get_by_release_id(release_id) - if not share: - raise ResourceNotFoundException("分享配置", str(release_id)) - - # 生成新 token - old_token = share.share_token - share.share_token = self._generate_unique_token() - share = self.repo.update(share) - - logger.info( - f"重新生成分享 token", - extra={ - "share_id": str(share.id), - "old_token": old_token, - "new_token": share.share_token - } - ) - - return share - - def get_shared_release_info( - self, - share_token: str, - password: Optional[str] = None - ) -> release_share_schema.SharedReleaseInfo: - """获取公开分享的发布版本信息 - - Args: - share_token: 分享 token - password: 访问密码(如果需要) - - Returns: - 分享的发布版本信息 - """ - # 获取分享配置 - share = self.repo.get_by_share_token(share_token) - if not share: - raise ResourceNotFoundException("分享链接", share_token) - - # 检查是否启用 - if not share.is_enabled: - raise BusinessException("该分享链接已禁用", BizCode.SHARE_DISABLED) - - # 验证密码 - is_password_verified = False - if share.require_password: - if not password: - # 需要密码但未提供,返回基本信息 - release = self.db.get(AppRelease, share.release_id) - return release_share_schema.SharedReleaseInfo( - app_name=release.name, - app_description=release.description, - app_icon=release.icon, - app_type=release.type, - version=release.version, - release_notes=release.release_notes, - published_at=int(release.published_at.timestamp() * 1000), - config={}, - require_password=True, - is_password_verified=False, - allow_embed=share.allow_embed - ) - - # 验证密码 - if not share.password_hash or not verify_password(password, share.password_hash): - raise BusinessException("密码错误", BizCode.INVALID_PASSWORD) - - is_password_verified = True - - # 获取发布版本详细信息 - release = self.db.get(AppRelease, share.release_id) - if not release: - raise ResourceNotFoundException("发布版本", str(share.release_id)) - - # 异步更新访问统计(不阻塞响应) - try: - self.repo.increment_view_count(share.id) - except Exception as e: - logger.warning(f"更新访问统计失败: {str(e)}") - - # 返回完整信息 - return release_share_schema.SharedReleaseInfo( - app_name=release.name, - app_description=release.description, - app_icon=release.icon, - app_type=release.type, - version=release.version, - release_notes=release.release_notes, - published_at=int(release.published_at.timestamp() * 1000), - config=release.config or {}, - require_password=share.require_password, - is_password_verified=is_password_verified, - allow_embed=share.allow_embed - ) - - def verify_password( - self, - share_token: str, - password: str - ) -> bool: - """验证分享密码 - - Args: - share_token: 分享 token - password: 密码 - - Returns: - 是否验证成功 - """ - share = self.repo.get_by_share_token(share_token) - if not share: - raise ResourceNotFoundException("分享链接", share_token) - - if not share.is_enabled: - raise BusinessException("该分享链接已禁用", BizCode.SHARE_DISABLED) - - if not share.require_password: - return True - - if not share.password_hash: - return False - - return verify_password(password, share.password_hash) - - def get_embed_code( - self, - share_token: str, - width: str = "100%", - height: str = "600px", - base_url: Optional[str] = None - ) -> release_share_schema.EmbedCode: - """获取嵌入代码 - - Args: - share_token: 分享 token - width: 宽度 - height: 高度 - base_url: 基础 URL - - Returns: - 嵌入代码 - """ - share = self.repo.get_by_share_token(share_token) - if not share: - raise ResourceNotFoundException("分享链接", share_token) - - if not share.is_enabled: - raise BusinessException("该分享链接已禁用", BizCode.SHARE_DISABLED) - - if not share.allow_embed: - raise BusinessException("该分享不允许嵌入", BizCode.EMBED_NOT_ALLOWED) - - embed_data = generate_embed_code(share_token, width, height, base_url) - return release_share_schema.EmbedCode(**embed_data) - - def _generate_unique_token(self, max_attempts: int = 10) -> str: - """生成唯一的分享 token""" - for _ in range(max_attempts): - token = generate_share_token() - if not self.repo.token_exists(token): - return token - - raise BusinessException("生成唯一 token 失败,请重试", BizCode.INTERNAL_ERROR) - - def _get_release_or_404(self, release_id: uuid.UUID) -> AppRelease: - """获取发布版本或抛出 404""" - release = self.db.get(AppRelease, release_id) - if not release: - raise ResourceNotFoundException("发布版本", str(release_id)) - return release - - def _validate_release_access(self, release: AppRelease, workspace_id: uuid.UUID) -> None: - """验证发布版本访问权限""" - app = self.db.get(App, release.app_id) - if not app: - raise ResourceNotFoundException("应用", str(release.app_id)) - - if app.workspace_id != workspace_id: - raise BusinessException("无权访问该发布版本", BizCode.PERMISSION_DENIED) - - def _convert_to_schema( - self, - share: ReleaseShare, - base_url: Optional[str] = None - ) -> release_share_schema.ReleaseShare: - """转换为 Schema""" - share_url = build_share_url(share.share_token, base_url) - - return release_share_schema.ReleaseShare( - id=share.id, - release_id=share.release_id, - app_id=share.app_id, - is_enabled=share.is_enabled, - share_token=share.share_token, - share_url=share_url, - require_password=share.require_password, - allow_embed=share.allow_embed, - embed_domains=share.embed_domains or [], - view_count=share.view_count, - last_accessed_at=share.last_accessed_at, - created_at=share.created_at, - updated_at=share.updated_at - ) diff --git a/app/services/session_service.py b/app/services/session_service.py deleted file mode 100644 index 938590df..00000000 --- a/app/services/session_service.py +++ /dev/null @@ -1,160 +0,0 @@ -from typing import Optional -import json -from datetime import datetime, timedelta, timezone - -from app.aioRedis import aio_redis_set, aio_redis_get, aio_redis_delete -from app.core.config import settings - - -class SessionService: - """用户会话管理服务""" - - @staticmethod - def _get_user_session_key(username: str) -> str: - """获取用户会话的Redis键""" - return f"user_session:{username}" - - @staticmethod - def _get_token_blacklist_key(token_id: str) -> str: - """获取token黑名单的Redis键""" - return f"token_blacklist:{token_id}" - - @staticmethod - async def set_user_active_session(username: str, token_id: str, expires_at: datetime) -> None: - """设置用户的活跃会话 - - Args: - username: 用户名 - token_id: token的唯一标识 - expires_at: token过期时间 - """ - if not settings.ENABLE_SINGLE_SESSION: - return - - session_key = SessionService._get_user_session_key(username) - session_data = { - "token_id": token_id, - "created_at": datetime.now(timezone.utc).isoformat(), - "expires_at": expires_at.isoformat() - } - - # 计算过期时间(秒) - expire_seconds = int((expires_at - datetime.now(timezone.utc)).total_seconds()) - if expire_seconds > 0: - await aio_redis_set(session_key, session_data, expire=expire_seconds) - - @staticmethod - async def get_user_active_session(username: str) -> Optional[dict]: - """获取用户的活跃会话 - - Args: - username: 用户名 - - Returns: - 会话数据字典或None - """ - if not settings.ENABLE_SINGLE_SESSION: - return None - - session_key = SessionService._get_user_session_key(username) - session_data = await aio_redis_get(session_key) - - if session_data: - try: - return json.loads(session_data) if isinstance(session_data, str) else session_data - except json.JSONDecodeError: - return None - return None - - @staticmethod - async def invalidate_old_session(username: str, new_token_id: str) -> None: - """使旧会话失效 - - Args: - username: 用户名 - new_token_id: 新token的ID - """ - if not settings.ENABLE_SINGLE_SESSION: - return - - # 获取当前活跃会话 - current_session = await SessionService.get_user_active_session(username) - - if current_session and current_session.get("token_id") != new_token_id: - # 将旧token加入黑名单 - old_token_id = current_session.get("token_id") - if old_token_id: - await SessionService.blacklist_token(old_token_id) - - @staticmethod - async def blacklist_token(token_id: str, expire_seconds: int = None) -> None: - """将token加入黑名单 - - Args: - token_id: token的唯一标识 - expire_seconds: 黑名单过期时间(秒),默认为refresh token的过期时间 - """ - if expire_seconds is None: - # 默认使用refresh token的过期时间 - expire_seconds = settings.REFRESH_TOKEN_EXPIRE_DAYS * 24 * 60 * 60 - - blacklist_key = SessionService._get_token_blacklist_key(token_id) - await aio_redis_set(blacklist_key, "blacklisted", expire=expire_seconds) - - @staticmethod - async def is_token_blacklisted(token_id: str) -> bool: - """检查token是否在黑名单中 - - Args: - token_id: token的唯一标识 - - Returns: - True如果token在黑名单中,否则False - """ - if not settings.ENABLE_SINGLE_SESSION: - return False - - blacklist_key = SessionService._get_token_blacklist_key(token_id) - result = await aio_redis_get(blacklist_key) - return result is not None - - @staticmethod - async def clear_user_session(username: str) -> None: - """清除用户会话 - - Args: - username: 用户名 - """ - session_key = SessionService._get_user_session_key(username) - await aio_redis_delete(session_key) - - @staticmethod - async def invalidate_all_user_tokens(user_id: str) -> None: - """使用户的所有 tokens 失效(用于密码重置等场景) - - 通过在 Redis 中设置一个用户级别的失效标记来实现。 - 所有在此时间点之前签发的 tokens 都将被视为无效。 - - Args: - user_id: 用户ID - """ - invalidation_key = f"user_token_invalidation:{user_id}" - current_time = datetime.now(timezone.utc).isoformat() - - # 设置失效时间戳,过期时间为 refresh token 的最大有效期 - expire_seconds = settings.REFRESH_TOKEN_EXPIRE_DAYS * 24 * 60 * 60 - await aio_redis_set(invalidation_key, current_time, expire=expire_seconds) - - @staticmethod - async def get_user_token_invalidation_time(user_id: str) -> Optional[str]: - """获取用户 token 失效时间 - - Args: - user_id: 用户ID - - Returns: - 失效时间的 ISO 格式字符串,如果没有失效记录则返回 None - """ - invalidation_key = f"user_token_invalidation:{user_id}" - result = await aio_redis_get(invalidation_key) - return result if result else None \ No newline at end of file diff --git a/app/services/shared_chat_service.py b/app/services/shared_chat_service.py deleted file mode 100644 index fb32b648..00000000 --- a/app/services/shared_chat_service.py +++ /dev/null @@ -1,759 +0,0 @@ -"""基于分享链接的聊天服务""" -import uuid -import time -import asyncio -from typing import Optional, Dict, Any, AsyncGenerator -from sqlalchemy.orm import Session - -from app.models import ReleaseShare, AppRelease, Conversation -from app.services.conversation_service import ConversationService -from app.services.draft_run_service import create_web_search_tool -from app.services.release_share_service import ReleaseShareService -from app.core.exceptions import BusinessException, ResourceNotFoundException -from app.core.error_codes import BizCode -from app.core.logging_config import get_business_logger -from app.services.multi_agent_service import MultiAgentService -from app.models import MultiAgentConfig -from app.repositories import knowledge_repository -import json -logger = get_business_logger() - - -class SharedChatService: - """基于分享链接的聊天服务""" - - def __init__(self, db: Session): - self.db = db - self.conversation_service = ConversationService(db) - self.share_service = ReleaseShareService(db) - - def _get_release_by_share_token( - self, - share_token: str, - password: Optional[str] = None - ) -> tuple[ReleaseShare, AppRelease]: - """通过 share_token 获取发布版本""" - # 获取分享配置 - share = self.share_service.repo.get_by_share_token(share_token) - if not share: - raise ResourceNotFoundException("分享链接", share_token) - - # 验证分享是否启用 - if not share.is_enabled: - raise BusinessException("该分享链接已被禁用", BizCode.SHARE_DISABLED) - - # 验证密码 - if share.require_password: - if not password: - raise BusinessException("需要提供访问密码", BizCode.PASSWORD_REQUIRED) - - if not self.share_service.verify_password(share_token, password): - raise BusinessException("访问密码错误", BizCode.INVALID_PASSWORD) - - # 获取发布版本 - release = self.db.get(AppRelease, share.release_id) - if not release: - raise ResourceNotFoundException("发布版本", str(share.release_id)) - - # 更新访问统计 - try: - self.share_service.repo.increment_view_count(share.id) - except Exception as e: - logger.warning(f"更新访问统计失败: {str(e)}") - - return share, release - - def create_or_get_conversation( - self, - share_token: str, - conversation_id: Optional[uuid.UUID] = None, - user_id: Optional[str] = None, - password: Optional[str] = None - ) -> Conversation: - """创建或获取会话""" - share, release = self._get_release_by_share_token(share_token, password) - - # 如果提供了 conversation_id,尝试获取现有会话 - if conversation_id: - try: - conversation = self.conversation_service.get_conversation( - conversation_id=conversation_id, - workspace_id=release.app.workspace_id - ) - - # 验证会话是否属于该应用 - if conversation.app_id != release.app_id: - raise BusinessException("会话不属于该应用", BizCode.INVALID_CONVERSATION) - - return conversation - except ResourceNotFoundException: - logger.warning( - f"会话不存在,将创建新会话", - extra={"conversation_id": str(conversation_id)} - ) - - # 创建新会话(使用发布版本的配置) - conversation = self.conversation_service.create_conversation( - app_id=release.app_id, - workspace_id=release.app.workspace_id, - user_id=user_id, - is_draft=False, # 分享链接使用发布版本 - config_snapshot=release.config - ) - - logger.info( - f"为分享链接创建新会话", - extra={ - "conversation_id": str(conversation.id), - "share_token": share_token, - "release_id": str(release.id) - } - ) - - return conversation - - async def chat( - self, - share_token: str, - message: str, - conversation_id: Optional[uuid.UUID] = None, - user_id: Optional[str] = None, - variables: Optional[Dict[str, Any]] = None, - password: Optional[str] = None, - web_search: bool = False, - memory: bool = True - ) -> Dict[str, Any]: - """聊天(非流式)""" - from app.core.agent.langchain_agent import LangChainAgent - from app.services.draft_run_service import create_knowledge_retrieval_tool, create_long_term_memory_tool - from app.services.model_parameter_merger import ModelParameterMerger - from app.schemas.prompt_schema import render_prompt_message, PromptMessageRole - from sqlalchemy import select - from app.models import ModelApiKey - - start_time = time.time() - - if variables is None: - variables = {} - - # 获取发布版本和配置 - share, release = self._get_release_by_share_token(share_token, password) - - # 获取 Agent 配置 - config = release.config or {} - - - # 获取模型配置ID - model_config_id = release.default_model_config_id - if not model_config_id: - raise BusinessException("发布版本未配置模型", BizCode.AGENT_CONFIG_MISSING) - - # 获取模型配置 - from app.models import ModelConfig - model_config = self.db.get(ModelConfig, model_config_id) - if not model_config: - raise ResourceNotFoundException("模型配置", str(model_config_id)) - - # 获取 API Key - stmt = ( - select(ModelApiKey) - .where( - ModelApiKey.model_config_id == model_config_id, - ModelApiKey.is_active == True - ) - .order_by(ModelApiKey.priority.desc()) - .limit(1) - ) - api_key_obj = self.db.scalars(stmt).first() - if not api_key_obj: - raise BusinessException("没有可用的 API Key", BizCode.AGENT_CONFIG_MISSING) - - # 获取或创建会话 - conversation = self.create_or_get_conversation( - share_token=share_token, - conversation_id=conversation_id, - user_id=user_id, - password=password - ) - - # 处理系统提示词(支持变量替换) - system_prompt = config.get("system_prompt", "你是一个专业的AI助手") - if variables: - system_prompt_rendered = render_prompt_message( - system_prompt, - PromptMessageRole.USER, - variables - ) - system_prompt = system_prompt_rendered.get_text_content() or system_prompt - - # 准备工具列表 - tools = [] - - # 添加知识库检索工具 - knowledge_retrieval = config.get("knowledge_retrieval") - if knowledge_retrieval: - knowledge_bases = knowledge_retrieval.get("knowledge_bases", []) - kb_ids = [kb.get("kb_id") for kb in knowledge_bases if kb.get("kb_id")] - if kb_ids: - kb_tool = create_knowledge_retrieval_tool(knowledge_retrieval, kb_ids,user_id) - tools.append(kb_tool) - - # 添加长期记忆工具 - - if memory==True: - memory_config = config.get("memory", {}) - if memory_config.get("enabled") and user_id: - memory_tool = create_long_term_memory_tool(memory_config, user_id) - tools.append(memory_tool) - - web_tools=config.get("tools") - web_search_choice = web_tools.get("web_search", {}) - web_search_enable = web_search_choice.get("enabled",False) - if web_search==True: - if web_search_enable==True: - search_tool = create_web_search_tool({}) - tools.append(search_tool) - - logger.debug( - "已添加网络搜索工具", - extra={ - "tool_count": len(tools) - } - ) - - # 获取模型参数 - model_parameters = config.get("model_parameters", {}) - - # 创建 LangChain Agent - agent = LangChainAgent( - model_name=api_key_obj.model_name, - api_key=api_key_obj.api_key, - provider=api_key_obj.provider, - api_base=api_key_obj.api_base, - temperature=model_parameters.get("temperature", 0.7), - max_tokens=model_parameters.get("max_tokens", 2000), - system_prompt=system_prompt, - tools=tools, - ) - - # 加载历史消息 - history = [] - memory_config={"enabled":True,'max_history':10} - if memory_config.get("enabled"): - messages = self.conversation_service.get_messages( - conversation_id=conversation.id, - limit=memory_config.get("max_history", 10) - ) - history = [ - {"role": msg.role, "content": msg.content} - for msg in messages - ] - - # 调用 Agent - result = await agent.chat( - message=message, - history=history, - context=None, - end_user_id=user_id - ) - - # 保存消息 - self.conversation_service.save_conversation_messages( - conversation_id=conversation.id, - user_message=message, - assistant_message=result["content"] - ) - # self.conversation_service.add_message( - # conversation_id=conversation.id, - # role="user", - # content=message - # ) - - # self.conversation_service.add_message( - # conversation_id=conversation.id, - # role="assistant", - # content=result["content"], - # meta_data={ - # "model": api_key_obj.model_name, - # "usage": result.get("usage", {}) - # } - # ) - - elapsed_time = time.time() - start_time - - return { - "conversation_id": conversation.id, - "message": result["content"], - "usage": result.get("usage", { - "prompt_tokens": 0, - "completion_tokens": 0, - "total_tokens": 0 - }), - "elapsed_time": elapsed_time - } - - async def chat_stream( - self, - share_token: str, - message: str, - conversation_id: Optional[uuid.UUID] = None, - user_id: Optional[str] = None, - variables: Optional[Dict[str, Any]] = None, - password: Optional[str] = None, - web_search: bool = False, - memory: bool = True - ) -> AsyncGenerator[str, None]: - """聊天(流式)""" - from app.core.agent.langchain_agent import LangChainAgent - from app.services.draft_run_service import create_knowledge_retrieval_tool, create_long_term_memory_tool - from app.schemas.prompt_schema import render_prompt_message, PromptMessageRole - from sqlalchemy import select - from app.models import ModelApiKey - import json - - start_time = time.time() - - if variables is None: - variables = {} - memory_config = {"enabled": memory, "memory_content": "17", "max_history": 10} - - try: - # 获取发布版本和配置 - share, release = self._get_release_by_share_token(share_token, password) - - # 获取 Agent 配置 - config = release.config or {} - agent_config_data = config.get("agent_config", {}) - - # 获取模型配置ID - model_config_id = release.default_model_config_id - if not model_config_id: - raise BusinessException("发布版本未配置模型", BizCode.AGENT_CONFIG_MISSING) - - # 获取模型配置 - from app.models import ModelConfig - model_config = self.db.get(ModelConfig, model_config_id) - if not model_config: - raise ResourceNotFoundException("模型配置", str(model_config_id)) - - # 获取 API Key - stmt = ( - select(ModelApiKey) - .where( - ModelApiKey.model_config_id == model_config_id, - ModelApiKey.is_active == True - ) - .order_by(ModelApiKey.priority.desc()) - .limit(1) - ) - api_key_obj = self.db.scalars(stmt).first() - if not api_key_obj: - raise BusinessException("没有可用的 API Key", BizCode.AGENT_CONFIG_MISSING) - - # 获取或创建会话 - conversation = self.create_or_get_conversation( - share_token=share_token, - conversation_id=conversation_id, - user_id=user_id, - password=password - ) - - # 处理系统提示词(支持变量替换) - system_prompt = config.get("system_prompt", "你是一个专业的AI助手") - if variables: - system_prompt_rendered = render_prompt_message( - system_prompt, - PromptMessageRole.USER, - variables - ) - system_prompt = system_prompt_rendered.get_text_content() or system_prompt - - # 准备工具列表 - tools = [] - - # 添加知识库检索工具 - knowledge_retrieval = config.get("knowledge_retrieval") - if knowledge_retrieval: - knowledge_bases = knowledge_retrieval.get("knowledge_bases", []) - kb_ids = [kb.get("kb_id") for kb in knowledge_bases if kb.get("kb_id")] - if kb_ids: - kb_tool = create_knowledge_retrieval_tool(knowledge_retrieval, kb_ids,user_id) - tools.append(kb_tool) - - # 添加长期记忆工具 - if memory: - memory_config = config.get("memory", {}) - if memory_config.get("enabled") and user_id: - memory_tool = create_long_term_memory_tool(memory_config, user_id) - tools.append(memory_tool) - - web_tools = config.get("tools") - web_search_choice = web_tools.get("web_search", {}) - web_search_enable = web_search_choice.get("enabled", False) - if web_search == True: - if web_search_enable == True: - search_tool = create_web_search_tool({}) - tools.append(search_tool) - - logger.debug( - "已添加网络搜索工具", - extra={ - "tool_count": len(tools) - } - ) - - # 获取模型参数 - model_parameters = config.get("model_parameters", {}) - - # 创建 LangChain Agent - agent = LangChainAgent( - model_name=api_key_obj.model_name, - api_key=api_key_obj.api_key, - provider=api_key_obj.provider, - api_base=api_key_obj.api_base, - temperature=model_parameters.get("temperature", 0.7), - max_tokens=model_parameters.get("max_tokens", 2000), - system_prompt=system_prompt, - tools=tools, - streaming=True - ) - - # 加载历史消息 - history = [] - memory_config = {"enabled": True, 'max_history': 10} - if memory_config.get("enabled"): - messages = self.conversation_service.get_messages( - conversation_id=conversation.id, - limit=memory_config.get("max_history", 10) - ) - history = [ - {"role": msg.role, "content": msg.content} - for msg in messages - ] - - # 发送开始事件 - yield f"event: start\ndata: {json.dumps({'conversation_id': str(conversation.id)}, ensure_ascii=False)}\n\n" - - # 流式调用 Agent - full_content = "" - async for chunk in agent.chat_stream( - message=message, - history=history, - context=None, - end_user_id=user_id - ): - full_content += chunk - # 发送消息块事件 - yield f"event: message\ndata: {json.dumps({'content': chunk}, ensure_ascii=False)}\n\n" - - elapsed_time = time.time() - start_time - - # 保存消息 - self.conversation_service.add_message( - conversation_id=conversation.id, - role="user", - content=message - ) - - self.conversation_service.add_message( - conversation_id=conversation.id, - role="assistant", - content=full_content, - meta_data={ - "model": api_key_obj.model_name, - "usage": {} - } - ) - - # 发送结束事件 - end_data = {"elapsed_time": elapsed_time, "message_length": len(full_content)} - yield f"event: end\ndata: {json.dumps(end_data, ensure_ascii=False)}\n\n" - - logger.info( - f"流式聊天完成", - extra={ - "conversation_id": str(conversation.id), - "elapsed_time": elapsed_time, - "message_length": len(full_content) - } - ) - - except (GeneratorExit, asyncio.CancelledError): - # 生成器被关闭或任务被取消,正常退出 - logger.debug("流式聊天被中断") - raise - except Exception as e: - logger.error(f"流式聊天失败: {str(e)}", exc_info=True) - # 发送错误事件 - yield f"event: error\ndata: {json.dumps({'error': str(e)}, ensure_ascii=False)}\n\n" - - def get_conversation_messages( - self, - share_token: str, - conversation_id: uuid.UUID, - password: Optional[str] = None - ) -> Conversation: - """获取会话消息""" - share, release = self._get_release_by_share_token(share_token, password) - - # 获取会话 - conversation = self.conversation_service.get_conversation( - conversation_id=conversation_id, - workspace_id=release.app.workspace_id - ) - - # 验证会话是否属于该应用 - if conversation.app_id != release.app_id: - raise BusinessException("会话不属于该应用", BizCode.INVALID_CONVERSATION) - - return conversation - - def list_conversations( - self, - share_token: str, - user_id: Optional[str] = None, - password: Optional[str] = None, - page: int = 1, - pagesize: int = 20 - ) -> tuple[list[Conversation], int]: - """列出会话""" - share, release = self._get_release_by_share_token(share_token, password) - - conversations, total = self.conversation_service.list_conversations( - app_id=release.app_id, - workspace_id=release.app.workspace_id, - user_id=user_id, - is_draft=False, # 只显示发布版本的会话 - page=page, - pagesize=pagesize - ) - - return conversations, total - - async def multi_agent_chat( - self, - share_token: str, - message: str, - conversation_id: Optional[uuid.UUID] = None, - user_id: Optional[str] = None, - variables: Optional[Dict[str, Any]] = None, - password: Optional[str] = None, - web_search: bool = False, - memory: bool = True - ) -> Dict[str, Any]: - """多 Agent 聊天(非流式)""" - from app.services.multi_agent_service import MultiAgentService - from app.models import MultiAgentConfig - - start_time = time.time() - - if variables is None: - variables = {} - - # 获取发布版本和配置 - share, release = self._get_release_by_share_token(share_token, password) - - # 获取或创建会话 - conversation = self.create_or_get_conversation( - share_token=share_token, - conversation_id=conversation_id, - user_id=user_id, - password=password - ) - - # 获取多 Agent 配置 - multi_agent_config = self.db.query(MultiAgentConfig).filter( - MultiAgentConfig.app_id == release.app_id, - MultiAgentConfig.is_active == True - ).first() - - if not multi_agent_config: - raise BusinessException("多 Agent 配置不存在", BizCode.AGENT_CONFIG_MISSING) - - # 构建多 Agent 运行请求 - from app.schemas.multi_agent_schema import MultiAgentRunRequest - - multi_agent_request = MultiAgentRunRequest( - message=message, - conversation_id=conversation.id, - user_id=user_id, - variables=variables, - use_llm_routing=True, - web_search=web_search, - memory=memory - ) - - # 使用多 Agent 服务执行 - multi_agent_service = MultiAgentService(self.db) - result = await multi_agent_service.run( - app_id=release.app_id, - request=multi_agent_request - ) - - elapsed_time = time.time() - start_time - - # 保存消息 - self.conversation_service.add_message( - conversation_id=conversation.id, - role="user", - content=message - ) - - self.conversation_service.add_message( - conversation_id=conversation.id, - role="assistant", - content=result.get("message", ""), - meta_data={ - "mode": result.get("mode"), - "elapsed_time": result.get("elapsed_time"), - "sub_results": result.get("sub_results") - } - ) - - return { - "conversation_id": conversation.id, - "message": result.get("message", ""), - "usage": { - "prompt_tokens": 0, - "completion_tokens": 0, - "total_tokens": 0 - }, - "elapsed_time": elapsed_time - } - - async def multi_agent_chat_stream( - self, - share_token: str, - message: str, - conversation_id: Optional[uuid.UUID] = None, - user_id: Optional[str] = None, - variables: Optional[Dict[str, Any]] = None, - password: Optional[str] = None, - web_search: bool = False, - memory: bool = True - ) -> AsyncGenerator[str, None]: - """多 Agent 聊天(流式)""" - - start_time = time.time() - - if variables is None: - variables = {} - - try: - # 获取发布版本和配置 - share, release = self._get_release_by_share_token(share_token, password) - - # 获取或创建会话 - conversation = self.create_or_get_conversation( - share_token=share_token, - conversation_id=conversation_id, - user_id=user_id, - password=password - ) - - # 获取多 Agent 配置 - multi_agent_config = self.db.query(MultiAgentConfig).filter( - MultiAgentConfig.app_id == release.app_id, - MultiAgentConfig.is_active == True - ).first() - - if not multi_agent_config: - raise BusinessException("多 Agent 配置不存在", BizCode.AGENT_CONFIG_MISSING) - - # 获取 storage_type 和 user_rag_memory_id - workspace_id = release.app.workspace_id - storage_type = 'neo4j' # 默认值 - user_rag_memory_id = '' - - try: - # 获取工作空间的存储类型(不需要用户权限检查,因为是公开分享) - from app.models import Workspace - workspace = self.db.get(Workspace, workspace_id) - if workspace and workspace.storage_type: - storage_type = workspace.storage_type - - # 获取 USER_RAG_MERORY 知识库 ID - knowledge = knowledge_repository.get_knowledge_by_name( - db=self.db, - name="USER_RAG_MERORY", - workspace_id=workspace_id - ) - if knowledge: - user_rag_memory_id = str(knowledge.id) - except Exception as e: - logger.warning(f"获取 storage_type 或 user_rag_memory_id 失败,使用默认值: {str(e)}") - - # 发送开始事件 - yield f"event: start\ndata: {json.dumps({'conversation_id': str(conversation.id)}, ensure_ascii=False)}\n\n" - - # 构建多 Agent 运行请求 - from app.schemas.multi_agent_schema import MultiAgentRunRequest - - multi_agent_request = MultiAgentRunRequest( - message=message, - conversation_id=conversation.id, - user_id=user_id, - variables=variables, - use_llm_routing=True, - web_search=web_search, - memory=memory - ) - - # 使用多 Agent 服务流式执行 - multi_agent_service = MultiAgentService(self.db) - full_content = "" - - async for event in multi_agent_service.run_stream( - app_id=release.app_id, - request=multi_agent_request, - storage_type=storage_type, - user_rag_memory_id=user_rag_memory_id - ): - # 直接转发事件 - yield event - - # 尝试提取内容(用于保存) - if "data:" in event: - try: - data_line = event.split("data: ", 1)[1].strip() - data = json.loads(data_line) - if "content" in data: - full_content += data["content"] - except: - pass - - elapsed_time = time.time() - start_time - - # 保存消息 - self.conversation_service.add_message( - conversation_id=conversation.id, - role="user", - content=message - ) - - self.conversation_service.add_message( - conversation_id=conversation.id, - role="assistant", - content=full_content, - meta_data={ - "elapsed_time": elapsed_time - } - ) - - logger.info( - f"多 Agent 流式聊天完成", - extra={ - "conversation_id": str(conversation.id), - "elapsed_time": elapsed_time, - "message_length": len(full_content) - } - ) - - except (GeneratorExit, asyncio.CancelledError): - # 生成器被关闭或任务被取消,正常退出 - logger.debug("多 Agent 流式聊天被中断") - raise - except Exception as e: - logger.error(f"多 Agent 流式聊天失败: {str(e)}", exc_info=True) - # 发送错误事件 - yield f"event: error\ndata: {json.dumps({'error': str(e)}, ensure_ascii=False)}\n\n" diff --git a/app/services/smart_router.py b/app/services/smart_router.py deleted file mode 100644 index 5889440b..00000000 --- a/app/services/smart_router.py +++ /dev/null @@ -1,426 +0,0 @@ -"""智能路由器 - 解决多轮对话路由错乱""" -import re -from typing import Dict, Any, List, Optional, Tuple -from app.services.conversation_state_manager import ConversationStateManager -from app.core.logging_config import get_business_logger - -logger = get_business_logger() - - -class SmartRouter: - """智能路由器 - - 核心功能: - 1. 检测主题切换 - 2. 判断是否应该继续使用当前 Agent - 3. 智能选择最合适的 Agent - 4. 支持强制重新路由 - """ - - # 主题切换信号 - SWITCH_SIGNALS = [ - "换个话题", "另外", "还有", "对了", - "那这个呢", "再问一个", "顺便问下", - "我想问", "帮我", "请问", "换一个" - ] - - # 延续信号 - CONTINUATION_SIGNALS = [ - "继续", "还是", "也", "同样", "类似", - "这个", "那个", "它", "他", "她", "呢" - ] - - def __init__( - self, - state_manager: ConversationStateManager, - routing_rules: List[Dict[str, Any]], - sub_agents: Dict[str, Any] - ): - """初始化智能路由器 - - Args: - state_manager: 会话状态管理器 - routing_rules: 路由规则列表 - sub_agents: 子 Agent 配置字典 - """ - self.state_manager = state_manager - self.routing_rules = routing_rules - self.sub_agents = sub_agents - - # 配置参数 - self.min_confidence_for_switch = 0.7 # 切换 Agent 的最小置信度 - self.max_same_agent_turns = 10 # 同一 Agent 最大连续轮数 - - async def route( - self, - message: str, - conversation_id: Optional[str] = None, - force_new: bool = False - ) -> Dict[str, Any]: - """智能路由 - - Args: - message: 用户消息 - conversation_id: 会话 ID - force_new: 是否强制重新路由(忽略历史) - - Returns: - 路由结果 { - "agent_id": str, - "confidence": float, - "strategy": str, - "topic": str, - "topic_changed": bool, - "reason": str - } - """ - logger.info( - f"开始智能路由", - extra={ - "message_length": len(message), - "conversation_id": conversation_id, - "force_new": force_new - } - ) - - # 1. 获取会话状态 - state = None - if conversation_id and not force_new: - state = self.state_manager.get_state(conversation_id) - - # 2. 检测主题切换 - topic_changed = self._detect_topic_change(message, state) - - # 3. 提取当前主题 - topic = self._extract_topic(message) - - # 4. 选择路由策略 - if force_new: - # 强制重新路由 - agent_id, confidence = self._route_from_scratch(message) - strategy = "force_new" - reason = "用户强制重新路由" - - elif not state or not state.get("current_agent_id"): - # 新会话,从头路由 - agent_id, confidence = self._route_from_scratch(message) - strategy = "new_conversation" - reason = "新会话,首次路由" - - elif topic_changed: - # 主题切换,重新路由 - agent_id, confidence = self._route_from_scratch(message) - strategy = "topic_changed" - reason = f"检测到主题切换: {state.get('last_topic')} -> {topic}" - - elif state.get("same_agent_turns", 0) >= self.max_same_agent_turns: - # 同一 Agent 使用太久,强制重新评估 - agent_id, confidence = self._route_from_scratch(message) - strategy = "max_turns_reached" - reason = f"同一 Agent 已使用 {state['same_agent_turns']} 轮" - - else: - # 检查是否应该继续使用当前 Agent - current_agent_id = state["current_agent_id"] - should_continue, continue_confidence = self._should_continue_current_agent( - message, - current_agent_id - ) - - if should_continue: - # 继续使用当前 Agent - agent_id = current_agent_id - confidence = continue_confidence - strategy = "continue_current" - reason = "消息在当前 Agent 能力范围内" - else: - # 重新路由 - new_agent_id, new_confidence = self._route_from_scratch(message) - - # 只有新 Agent 的置信度明显更高时才切换 - if new_confidence > continue_confidence + self.min_confidence_for_switch: - agent_id = new_agent_id - confidence = new_confidence - strategy = "switch_agent" - reason = f"新 Agent 置信度更高: {new_confidence:.2f} vs {continue_confidence:.2f}" - else: - # 置信度差距不大,继续使用当前 Agent - agent_id = current_agent_id - confidence = continue_confidence - strategy = "keep_current" - reason = "置信度差距不足以切换 Agent" - - # 5. 更新会话状态 - if conversation_id: - self.state_manager.update_state( - conversation_id, - agent_id, - message, - topic, - confidence - ) - - result = { - "agent_id": agent_id, - "confidence": confidence, - "strategy": strategy, - "topic": topic, - "topic_changed": topic_changed, - "reason": reason - } - - logger.info( - f"路由完成", - extra={ - "agent_id": agent_id, - "strategy": strategy, - "confidence": confidence, - "topic": topic - } - ) - - return result - - def _detect_topic_change( - self, - message: str, - state: Optional[Dict[str, Any]] - ) -> bool: - """检测主题是否切换 - - Args: - message: 用户消息 - state: 会话状态 - - Returns: - 是否切换主题 - """ - if not state or not state.get("last_topic"): - return False - - # 检查明确的切换信号 - for signal in self.SWITCH_SIGNALS: - if signal in message: - logger.info(f"检测到主题切换信号: {signal}") - return True - - # 比较主题 - current_topic = self._extract_topic(message) - last_topic = state.get("last_topic") - - if current_topic != last_topic and current_topic != "其他": - logger.info(f"主题变化: {last_topic} -> {current_topic}") - return True - - return False - - def _should_continue_current_agent( - self, - message: str, - current_agent_id: str - ) -> Tuple[bool, float]: - """判断是否应该继续使用当前 Agent - - Args: - message: 用户消息 - current_agent_id: 当前 Agent ID - - Returns: - (是否继续, 置信度) - """ - # 检查延续信号 - has_continuation_signal = any( - signal in message - for signal in self.CONTINUATION_SIGNALS - ) - - # 计算当前 Agent 对消息的匹配度 - current_score = self._calculate_agent_score(message, current_agent_id) - - # 如果有延续信号且匹配度不太低,继续使用 - if has_continuation_signal and current_score > 0.3: - return True, min(current_score + 0.2, 1.0) - - # 如果匹配度高,继续使用 - if current_score > 0.6: - return True, current_score - - return False, current_score - - def _route_from_scratch(self, message: str) -> Tuple[str, float]: - """从头开始路由(不考虑历史) - - Args: - message: 用户消息 - - Returns: - (Agent ID, 置信度) - """ - best_agent_id = None - best_score = 0.0 - - # 遍历所有路由规则 - for rule in self.routing_rules: - score = self._calculate_rule_score(message, rule) - - if score > best_score: - best_score = score - best_agent_id = rule.get("target_agent_id") - - # 如果没有匹配的规则,使用默认 Agent - if not best_agent_id or best_score < 0.3: - best_agent_id = self._get_default_agent_id() - best_score = 0.5 - logger.warning(f"未找到匹配规则,使用默认 Agent: {best_agent_id}") - - return best_agent_id, best_score - - def _calculate_rule_score( - self, - message: str, - rule: Dict[str, Any] - ) -> float: - """计算规则匹配分数 - - Args: - message: 用户消息 - rule: 路由规则 - - Returns: - 匹配分数 (0-1) - """ - score = 0.0 - message_lower = message.lower() - - # 1. 关键词匹配 (权重 0.6) - keywords = rule.get("keywords", []) - if keywords: - matched_keywords = sum( - 1 for keyword in keywords - if keyword.lower() in message_lower - ) - keyword_score = matched_keywords / len(keywords) - score += keyword_score * 0.6 - - # 2. 正则匹配 (权重 0.3) - patterns = rule.get("patterns", []) - if patterns: - matched_patterns = sum( - 1 for pattern in patterns - if re.search(pattern, message, re.IGNORECASE) - ) - pattern_score = matched_patterns / len(patterns) - score += pattern_score * 0.3 - - # 3. 排除关键词 (负分) - exclude_keywords = rule.get("exclude_keywords", []) - if exclude_keywords: - has_exclude = any( - keyword.lower() in message_lower - for keyword in exclude_keywords - ) - if has_exclude: - score *= 0.5 # 减半 - - # 4. 最小关键词数量要求 - min_keyword_count = rule.get("min_keyword_count", 0) - if keywords and min_keyword_count > 0: - matched_count = sum( - 1 for keyword in keywords - if keyword.lower() in message_lower - ) - if matched_count < min_keyword_count: - score *= 0.7 # 惩罚 - - return min(score, 1.0) - - def _calculate_agent_score( - self, - message: str, - agent_id: str - ) -> float: - """计算 Agent 对消息的匹配分数 - - Args: - message: 用户消息 - agent_id: Agent ID - - Returns: - 匹配分数 (0-1) - """ - # 找到该 Agent 对应的所有规则 - agent_rules = [ - rule for rule in self.routing_rules - if rule.get("target_agent_id") == agent_id - ] - - if not agent_rules: - return 0.0 - - # 返回最高分数 - max_score = max( - self._calculate_rule_score(message, rule) - for rule in agent_rules - ) - - return max_score - - def _extract_topic(self, message: str) -> str: - """提取消息主题 - - Args: - message: 用户消息 - - Returns: - 主题名称 - """ - # 主题关键词映射 - topic_keywords = { - "数学": ["数学", "方程", "计算", "求解", "x", "y", "函数", "几何"], - "物理": ["物理", "力", "速度", "加速度", "能量", "功率", "电路"], - "化学": ["化学", "方程式", "反应", "元素", "分子", "原子", "化合物"], - "语文": ["语文", "古诗", "作文", "阅读", "文言文", "诗词"], - "英语": ["英语", "单词", "语法", "翻译", "时态", "句型"], - "历史": ["历史", "朝代", "事件", "人物", "战争", "革命"], - "作业": ["作业", "批改", "检查", "评分", "反馈"], - "学习规划": ["计划", "规划", "方法", "技巧", "时间", "安排"], - "订单": ["订单", "发货", "物流", "配送", "快递"], - "退款": ["退款", "退货", "售后", "换货", "维修"], - "账户": ["账户", "密码", "登录", "注册", "绑定"], - "支付": ["支付", "付款", "充值", "余额", "优惠券"] - } - - message_lower = message.lower() - - # 统计每个主题的匹配度 - topic_scores = {} - for topic, keywords in topic_keywords.items(): - matched = sum( - 1 for keyword in keywords - if keyword in message_lower - ) - if matched > 0: - topic_scores[topic] = matched - - # 返回匹配度最高的主题 - if topic_scores: - best_topic = max(topic_scores.items(), key=lambda x: x[1])[0] - return best_topic - - return "其他" - - def _get_default_agent_id(self) -> str: - """获取默认 Agent ID - - Returns: - 默认 Agent ID - """ - # 优先使用第一个路由规则的 Agent - if self.routing_rules: - return self.routing_rules[0].get("target_agent_id") - - # 否则使用第一个子 Agent - if self.sub_agents: - return list(self.sub_agents.keys())[0] - - return "default-agent" diff --git a/app/services/task_service.py b/app/services/task_service.py deleted file mode 100644 index 6350001b..00000000 --- a/app/services/task_service.py +++ /dev/null @@ -1,52 +0,0 @@ -from app.celery_app import celery_app - -def create_processing_task(item_data: dict) -> str: - """ - Sends a task to the Celery queue to process an item. - - :param item_data: The dictionary representation of the item. - :return: The ID of the created task. - """ - task = celery_app.send_task("tasks.process_item", args=[item_data]) - return task.id - -def get_task_result(task_id: str) -> dict: - """ - Checks the status and result of a Celery task. - - :param task_id: The ID of the task to check. - :return: A dictionary with the task's status and result (if ready). - """ - result = celery_app.AsyncResult(task_id) - - if result.ready(): - return {"status": result.status, "result": result.get()} - - return {"status": result.status} -def get_task_memory_read_result(task_id: str) -> dict: - """ - Checks the status and result of a memory read task. - - :param task_id: The ID of the task to check. - :return: A dictionary with the task's status and result (if ready). - """ - result = celery_app.AsyncResult(task_id) - - if result.ready(): - return {"status": result.status, "result": result.get()} - - return {"status": result.status} - -def get_task_memory_write_result(task_id: str) -> dict: - """ - Checks the status and result of a memory write task. - - :param task_id: The ID of the task to check. - :return: A dictionary with the task's status and result (if ready). - """ - result = celery_app.AsyncResult(task_id) - - if result.ready(): - return {"status": result.status, "result": result.get()} - - return {"status": result.status} diff --git a/app/services/tenant_service.py b/app/services/tenant_service.py deleted file mode 100644 index 2edb46df..00000000 --- a/app/services/tenant_service.py +++ /dev/null @@ -1,220 +0,0 @@ -from sqlalchemy.orm import Session -from typing import List, Optional -import uuid - -from app.core.logging_config import get_business_logger -from app.repositories.tenant_repository import TenantRepository -from app.repositories.user_repository import UserRepository -from app.repositories.workspace_repository import WorkspaceRepository -from app.schemas.tenant_schema import ( - TenantCreate, TenantUpdate, Tenant, TenantQuery, TenantList -) -from app.schemas.user_schema import User -from app.schemas.workspace_schema import WorkspaceCreate -from app.models.tenant_model import Tenants -from app.models.user_model import User as UserModel -from app.core.exceptions import BusinessException -from app.core.error_codes import BizCode - -# 获取业务逻辑专用日志器 -business_logger = get_business_logger() - -class TenantService: - """租户业务逻辑层""" - - def __init__(self, db: Session): - self.db = db - self.tenant_repo = TenantRepository(db) - self.user_repo = UserRepository(db) - self.workspace_repo = WorkspaceRepository(db) - - def create_tenant(self, tenant_data: TenantCreate) -> Tenants: - """创建租户""" - # 检查租户名称是否已存在 - existing_tenant = self.tenant_repo.get_tenant_by_name(tenant_data.name) - if existing_tenant: - raise BusinessException(f"租户名称 '{tenant_data.name}' 已存在", code=BizCode.DUPLICATE_NAME) - - try: - tenant = self.tenant_repo.create_tenant(tenant_data) - business_logger.info(f"创建租户成功: {tenant.name} (ID: {tenant.id})") - return tenant - except Exception as e: - business_logger.error(f"创建租户失败: {str(e)}") - raise BusinessException(f"创建租户失败: {str(e)}", code=BizCode.DB_ERROR) - - def create_tenant_and_assign_user(self, tenant_data: TenantCreate, user_id: uuid.UUID) -> Tenants: - """创建租户并分配用户""" - try: - # 创建租户 - tenant = self.create_tenant(tenant_data) - - # 将用户分配给租户 - success = self.user_repo.assign_user_to_tenant(user_id, tenant.id) - if not success: - raise BusinessException("分配用户到租户失败", code=BizCode.STATE_CONFLICT) - - business_logger.info(f"创建租户并分配用户成功: {tenant.name}") - return tenant - - except Exception as e: - business_logger.error(f"创建租户和分配用户失败: {str(e)}") - self.db.rollback() - raise BusinessException(f"创建租户失败: {str(e)}", code=BizCode.DB_ERROR) - - def get_tenant(self, tenant_id: uuid.UUID) -> Optional[Tenants]: - """获取租户""" - return self.tenant_repo.get_tenant_by_id(tenant_id) - - def get_tenant_by_name(self, name: str) -> Optional[Tenants]: - """根据名称获取租户""" - return self.tenant_repo.get_tenant_by_name(name) - - def get_tenants(self, query: TenantQuery) -> TenantList: - """获取租户列表""" - skip = (query.page - 1) * query.size - - tenants = self.tenant_repo.get_tenants( - skip=skip, - limit=query.size, - is_active=query.is_active, - search=query.search - ) - - total = self.tenant_repo.count_tenants( - is_active=query.is_active, - search=query.search - ) - - pages = (total + query.size - 1) // query.size - - return TenantList( - items=[Tenant.model_validate(tenant) for tenant in tenants], - total=total, - page=query.page, - size=query.size, - pages=pages - ) - - def update_tenant(self, tenant_id: uuid.UUID, tenant_data: TenantUpdate) -> Optional[Tenants]: - """更新租户""" - # 如果更新名称,检查是否重复 - if tenant_data.name: - existing_tenant = self.tenant_repo.get_tenant_by_name(tenant_data.name) - if existing_tenant and existing_tenant.id != tenant_id: - raise BusinessException(f"租户名称 '{tenant_data.name}' 已存在", code=BizCode.DUPLICATE_NAME) - - try: - tenant = self.tenant_repo.update_tenant(tenant_id, tenant_data) - if tenant: - business_logger.info(f"更新租户成功: {tenant.name} (ID: {tenant.id})") - return tenant - except Exception as e: - business_logger.error(f"更新租户失败: {str(e)}") - raise BusinessException(f"更新租户失败: {str(e)}", code=BizCode.DB_ERROR) - - def delete_tenant(self, tenant_id: uuid.UUID) -> bool: - """删除租户""" - try: - # 检查租户是否存在 - tenant = self.tenant_repo.get_tenant_by_id(tenant_id) - if not tenant: - return False - - # 检查是否有关联的用户 - users = self.tenant_repo.get_tenant_users(tenant_id) - if users: - raise BusinessException("无法删除租户,存在关联的用户", code=BizCode.STATE_CONFLICT) - - # 检查是否有关联的工作空间 - workspaces = self.workspace_repo.get_workspaces_by_tenant(tenant_id) - if workspaces: - raise BusinessException("无法删除租户,存在关联的工作空间", code=BizCode.STATE_CONFLICT) - - success = self.tenant_repo.delete_tenant(tenant_id) - if success: - business_logger.info(f"删除租户成功: {tenant.name} (ID: {tenant.id})") - return success - - except Exception as e: - business_logger.error(f"删除租户失败: {str(e)}") - raise BusinessException(f"删除租户失败: {str(e)}", code=BizCode.DB_ERROR) - - # 租户用户管理 - def get_tenant_users( - self, - tenant_id: uuid.UUID, - skip: int = 0, - limit: int = 100, - is_active: Optional[bool] = None, - search: Optional[str] = None - ) -> List[UserModel]: - """获取租户下的用户列表""" - return self.user_repo.get_users_by_tenant( - tenant_id=tenant_id, - skip=skip, - limit=limit, - is_active=is_active, - search=search - ) - - def count_tenant_users( - self, - tenant_id: uuid.UUID, - is_active: Optional[bool] = None, - search: Optional[str] = None - ) -> int: - """统计租户下的用户数量""" - return self.user_repo.count_users_by_tenant( - tenant_id=tenant_id, - is_active=is_active, - search=search - ) - - def assign_user_to_tenant(self, user_id: uuid.UUID, tenant_id: uuid.UUID) -> bool: - """将用户分配给租户""" - # 检查租户是否存在 - tenant = self.tenant_repo.get_tenant_by_id(tenant_id) - if not tenant: - raise BusinessException("租户不存在", code=BizCode.TENANT_NOT_FOUND) - - try: - success = self.user_repo.assign_user_to_tenant(user_id, tenant_id) - if success: - business_logger.info(f"分配用户到租户成功: 用户ID {user_id}, 租户ID {tenant_id}") - return success - except Exception as e: - business_logger.error(f"分配用户到租户失败: {str(e)}") - raise BusinessException(f"分配用户到租户失败: {str(e)}", code=BizCode.DB_ERROR) - - def get_user_tenant(self, user_id: uuid.UUID) -> Optional[Tenants]: - """获取用户所属的租户""" - return self.tenant_repo.get_user_tenant(user_id) - - def remove_user_from_tenant(self, user_id: uuid.UUID) -> bool: - """将用户从租户中移除(设置tenant_id为None)""" - try: - user = self.user_repo.get_user_by_id(user_id) - if not user: - return False - - success = self.user_repo.assign_user_to_tenant(user_id, None) - if success: - business_logger.info(f"移除用户租户关联成功: 用户ID {user_id}") - return success - except Exception as e: - business_logger.error(f"移除用户租户关联失败: {str(e)}") - raise BusinessException(f"移除用户租户关联失败: {str(e)}", code=BizCode.DB_ERROR) - - def get_users_without_tenant( - self, - skip: int = 0, - limit: int = 100, - is_active: Optional[bool] = None - ) -> List[UserModel]: - """获取没有租户的用户列表""" - return self.user_repo.get_users_without_tenant( - skip=skip, - limit=limit, - is_active=is_active - ) \ No newline at end of file diff --git a/app/services/upload_service.py b/app/services/upload_service.py deleted file mode 100644 index a149ebf2..00000000 --- a/app/services/upload_service.py +++ /dev/null @@ -1,617 +0,0 @@ -""" -Upload Service for Generic File Upload System -Handles file upload, storage, access, deletion, and metadata updates. -""" -import os -import uuid -import shutil -from pathlib import Path -from typing import Dict, Any, List, Optional -from datetime import datetime -from sqlalchemy.orm import Session -from fastapi import UploadFile - -from app.models.user_model import User -from app.models.generic_file_model import GenericFile -from app.repositories.generic_file_repository import GenericFileRepository -from app.core.upload_enums import UploadContext -from app.core.storage_strategy import StrategyFactory -from app.core.validators.file_validator import FileValidator -from app.core.exceptions import BusinessException, PermissionDeniedException -from app.core.error_codes import BizCode -from app.core.config import settings -from app.core.logging_config import get_logger -from app.core.uow import IUnitOfWork -from app.core.compensation import CompensationHandler - -# Get logger -logger = get_logger(__name__) - - -class FileNotFoundError(BusinessException): - """Exception raised when file is not found.""" - def __init__(self, file_id: uuid.UUID): - super().__init__( - f"文件 {file_id} 不存在", - code=BizCode.NOT_FOUND - ) - - -class FileAccessDeniedError(BusinessException): - """Exception raised when file access is denied.""" - def __init__(self, file_id: uuid.UUID): - super().__init__( - f"无权访问文件 {file_id}", - code=BizCode.FORBIDDEN - ) - - -class FileStorageError(BusinessException): - """Exception raised when file storage fails.""" - def __init__(self, reason: str): - super().__init__( - f"文件存储失败: {reason}", - code=BizCode.INTERNAL_ERROR - ) - - -class FileReferencedError(BusinessException): - """Exception raised when trying to delete a referenced file.""" - def __init__(self, file_id: uuid.UUID, reference_count: int): - super().__init__( - f"文件 {file_id} 被 {reference_count} 个资源引用,无法删除", - code=BizCode.BAD_REQUEST - ) - - -class UploadResult: - """Result of a file upload operation.""" - def __init__(self, success: bool, file_id: Optional[uuid.UUID] = None, - file_name: str = "", error: Optional[str] = None): - self.success = success - self.file_id = file_id - self.file_name = file_name - self.error = error - - -class UploadService: - """ - Service for handling file uploads and management. - Coordinates validation, storage, and database operations. - Uses Unit of Work pattern for transaction management. - """ - - def __init__(self, uow: IUnitOfWork = None): - self.validator = FileValidator() - self.uow = uow - - def upload_file( - self, - file: UploadFile, - context: UploadContext, - metadata: Optional[Dict[str, Any]], - current_user: User, - db: Session = None - ) -> GenericFile: - """ - Upload a single file using Unit of Work pattern with compensation transactions. - - Args: - file: The uploaded file - context: Upload context (avatar, app_icon, etc.) - metadata: Additional metadata for the file - current_user: The user uploading the file - db: Optional database session (for backward compatibility) - - Returns: - GenericFile: The created file record - - Raises: - FileSizeExceededError: If file size exceeds limit - FileTypeNotAllowedError: If file type is not allowed - EmptyFileError: If file is empty - FileStorageError: If file storage fails - """ - logger.info(f"Starting file upload: filename={file.filename}, context={context}, user={current_user.id}") - - if metadata is None: - metadata = {} - - # Get storage strategy for this context - strategy = StrategyFactory.get_strategy(context) - upload_policy = strategy.get_upload_policy() - - # Validate file against upload policy - logger.debug(f"Validating file: {file.filename}") - self.validator.validate_and_raise(file, upload_policy) - - # Generate file ID - file_id = uuid.uuid4() - - # Extract file information - filename = file.filename or "unknown" - file_extension = "" - if "." in filename: - file_extension = "." + filename.rsplit(".", 1)[1].lower() - - # Get file size - file.file.seek(0, 2) - file_size = file.file.tell() - file.file.seek(0) - - # Get storage path - storage_path = strategy.get_storage_path( - tenant_id=current_user.tenant_id, - file_id=file_id, - file_extension=file_extension, - metadata=metadata - ) - - logger.debug(f"Storage path: {storage_path}") - - # Use Unit of Work pattern with compensation handler - compensation = CompensationHandler() - - try: - # Use provided UoW or create a new one for backward compatibility - if self.uow: - uow = self.uow - should_manage_context = False - else: - # Backward compatibility: use provided db session - if db: - # Create a temporary UoW wrapper for the existing session - from app.core.uow import SqlAlchemyUnitOfWork - uow = SqlAlchemyUnitOfWork(lambda: db) - uow._session = db - uow.files = GenericFileRepository(db) - should_manage_context = False - else: - raise FileStorageError("Either uow or db session must be provided") - - # 1. Save physical file - self._save_physical_file(file, storage_path) - - # Register compensation: delete physical file if database operation fails - compensation.register(lambda: self._delete_physical_file(storage_path)) - - # 2. Generate access URL - access_url = None - if context in [UploadContext.AVATAR, UploadContext.APP_ICON]: - access_url = f"{settings.FILE_ACCESS_URL_PREFIX}/{file_id}" - - # 3. Create file data - file_data = { - "id": file_id, - "tenant_id": current_user.tenant_id, - "created_by": current_user.id, - "file_name": filename, - "file_ext": file_extension, - "file_size": file_size, - "mime_type": file.content_type, - "context": context.value, - "storage_path": str(storage_path), - "file_metadata": metadata, - "status": "active", - "is_public": metadata.get("is_public", False), - "access_url": access_url, - "reference_count": 0, - } - - # 4. Create database record - db_file = uow.files.create_file(file_data) - - # 5. Commit transaction (only if we're managing the session) - if should_manage_context: - uow.commit() - elif db: - db.commit() - - # Success - clear compensation operations - compensation.clear() - - logger.info(f"File upload completed successfully: {filename} (ID: {file_id})") - return db_file - - except Exception as e: - # Execute compensation operations - compensation.execute() - - # Rollback if we're managing the session - if db: - db.rollback() - - logger.error(f"File upload failed: {str(e)}") - raise FileStorageError(f"文件上传失败: {str(e)}") - - def _save_physical_file(self, file: UploadFile, storage_path: Path): - """ - Save physical file to filesystem. - - Args: - file: The uploaded file - storage_path: Path where file should be saved - - Raises: - FileStorageError: If file save fails - """ - try: - # Create directory if it doesn't exist - storage_path.parent.mkdir(parents=True, exist_ok=True) - - # Save file - with open(storage_path, "wb") as buffer: - shutil.copyfileobj(file.file, buffer) - - logger.info(f"File saved to filesystem: {storage_path}") - - except Exception as e: - logger.error(f"Failed to save file to filesystem: {str(e)}") - raise FileStorageError(f"无法保存文件到磁盘: {str(e)}") - - def _delete_physical_file(self, storage_path: Path): - """ - Delete physical file (compensation operation). - - Args: - storage_path: Path of file to delete - """ - try: - if os.path.exists(storage_path): - os.remove(storage_path) - logger.info(f"补偿操作:删除文件 {storage_path}") - except Exception as e: - logger.error(f"删除文件失败: {e}") - - def _restore_file_from_backup(self, backup_path: Path, original_path: Path): - """ - Restore file from backup (compensation operation). - - Args: - backup_path: Path of backup file - original_path: Path where file should be restored - """ - try: - if backup_path.exists(): - shutil.copy2(backup_path, original_path) - logger.info(f"补偿操作:从备份恢复文件 {original_path}") - # Clean up backup after restoration - os.remove(backup_path) - logger.debug(f"补偿操作:删除备份文件 {backup_path}") - except Exception as e: - logger.error(f"恢复文件失败: {e}") - - def upload_files_batch( - self, - files: List[UploadFile], - context: UploadContext, - metadata: Optional[Dict[str, Any]], - current_user: User, - db: Session = None - ) -> List[UploadResult]: - """ - Upload multiple files in batch. - Individual file failures do not affect other files. - - Args: - files: List of uploaded files - context: Upload context (avatar, app_icon, etc.) - metadata: Additional metadata for the files - current_user: The user uploading the files - db: Optional database session (for backward compatibility) - - Returns: - List[UploadResult]: List of upload results for each file - - Raises: - BusinessException: If batch size exceeds limit - """ - logger.info(f"Starting batch upload: {len(files)} files, context={context}, user={current_user.id}") - - # Validate batch size - MAX_BATCH_SIZE = 20 - if len(files) > MAX_BATCH_SIZE: - raise BusinessException( - f"批量上传文件数量不能超过 {MAX_BATCH_SIZE} 个", - code=BizCode.BAD_REQUEST, - context={ - "file_count": len(files), - "max_batch_size": MAX_BATCH_SIZE, - "user_id": str(current_user.id), - "tenant_id": str(current_user.tenant_id), - "context": context - } - ) - - results = [] - - for file in files: - try: - # Upload each file independently - db_file = self.upload_file(file, context, metadata, current_user, db) - - results.append(UploadResult( - success=True, - file_id=db_file.id, - file_name=file.filename or "unknown", - error=None - )) - - logger.info(f"Batch upload success: {file.filename}") - - except Exception as e: - # Log error but continue with other files - logger.error(f"Batch upload failed for {file.filename}: {str(e)}") - - results.append(UploadResult( - success=False, - file_id=None, - file_name=file.filename or "unknown", - error=str(e) - )) - - logger.info(f"Batch upload completed: {sum(1 for r in results if r.success)}/{len(files)} successful") - return results - - def get_file( - self, - file_id: uuid.UUID, - current_user: User, - db: Session = None - ) -> GenericFile: - """ - Get a file by ID with permission validation. - - Args: - file_id: UUID of the file - current_user: The user requesting the file - db: Optional database session (for backward compatibility) - - Returns: - GenericFile: The file record - - Raises: - FileNotFoundError: If file doesn't exist - FileAccessDeniedError: If user doesn't have permission - """ - logger.debug(f"Getting file: file_id={file_id}, user={current_user.id}") - - # Use UoW or provided db session - if self.uow: - with self.uow: - file = self.uow.files.get_file_by_id(file_id) - elif db: - repository = GenericFileRepository(db) - file = repository.get_file_by_id(file_id) - else: - raise FileStorageError("Either uow or db session must be provided") - - if not file: - logger.warning(f"File not found: {file_id}") - raise FileNotFoundError(file_id) - - # Check permissions using permission service - from app.core.permissions import permission_service, Subject, Resource, Action - - subject = Subject.from_user(current_user) - resource = Resource.from_file(file) - - try: - permission_service.require_permission( - subject, - Action.READ, - resource, - error_message=f"无权访问文件 {file_id}" - ) - except PermissionDeniedException: - logger.warning(f"Access denied: file_id={file_id}, user={current_user.id}") - raise FileAccessDeniedError(file_id) - - logger.debug(f"File access granted: {file.file_name}") - return file - - def delete_file( - self, - file_id: uuid.UUID, - current_user: User, - db: Session = None - ) -> None: - """ - Delete a file (both physical file and database record) using UoW pattern with compensation. - - This method uses compensation transactions to ensure data consistency: - 1. Delete physical file first - 2. Register compensation to restore file if DB deletion fails - 3. Delete database record - 4. Commit transaction - 5. Clear compensation on success - - Args: - file_id: UUID of the file to delete - current_user: The user requesting deletion - db: Optional database session (for backward compatibility) - - Raises: - FileNotFoundError: If file doesn't exist - FileAccessDeniedError: If user doesn't have permission - FileReferencedError: If file is still referenced - FileStorageError: If deletion fails - """ - logger.info(f"Deleting file: file_id={file_id}, user={current_user.id}") - - # Get file and check permissions - if self.uow: - with self.uow: - file = self.uow.files.get_file_by_id(file_id) - elif db: - repository = GenericFileRepository(db) - file = repository.get_file_by_id(file_id) - else: - raise FileStorageError("Either uow or db session must be provided") - - if not file: - logger.warning(f"File not found for deletion: {file_id}") - raise FileNotFoundError(file_id) - - # Check permissions using permission service - from app.core.permissions import permission_service, Subject, Resource, Action - - subject = Subject.from_user(current_user) - resource = Resource.from_file(file) - - try: - permission_service.require_permission( - subject, - Action.DELETE, - resource, - error_message=f"无权删除文件 {file_id}" - ) - except PermissionDeniedException: - logger.warning(f"Delete access denied: file_id={file_id}, user={current_user.id}") - raise FileAccessDeniedError(file_id) - - # Check reference count - if file.reference_count > 0: - logger.warning(f"Cannot delete referenced file: file_id={file_id}, references={file.reference_count}") - raise FileReferencedError(file_id, file.reference_count) - - # Store storage path and file content for potential restoration - storage_path = Path(file.storage_path) - backup_path = None - - # Use compensation handler for atomic deletion - compensation = CompensationHandler() - - try: - # 1. Backup and delete physical file first - if storage_path.exists(): - # Create backup in temp location - backup_path = storage_path.parent / f".backup_{file_id}{storage_path.suffix}" - shutil.copy2(storage_path, backup_path) - logger.debug(f"Created backup: {backup_path}") - - # Delete original file - os.remove(storage_path) - logger.info(f"Physical file deleted: {storage_path}") - - # Register compensation: restore file from backup if DB deletion fails - compensation.register(lambda: self._restore_file_from_backup(backup_path, storage_path)) - else: - logger.warning(f"Physical file not found: {storage_path}") - - # 2. Delete database record (soft delete) - if self.uow: - with self.uow: - self.uow.files.delete_file(file_id) - self.uow.commit() - elif db: - repository = GenericFileRepository(db) - repository.delete_file(file_id) - db.commit() - - logger.info(f"File record deleted successfully: {file.file_name} (ID: {file_id})") - - # 3. Success - clear compensations and remove backup - compensation.clear() - if backup_path and backup_path.exists(): - os.remove(backup_path) - logger.debug(f"Removed backup: {backup_path}") - - except Exception as e: - # Execute compensation to restore file - compensation.execute() - - # Rollback database if using db session - if db: - db.rollback() - - logger.error(f"Failed to delete file: {str(e)}") - raise FileStorageError(f"无法删除文件: {str(e)}") - - def update_file_metadata( - self, - file_id: uuid.UUID, - update_data: Dict[str, Any], - current_user: User, - db: Session = None - ) -> GenericFile: - """ - Update file metadata using UoW pattern. - - Args: - file_id: UUID of the file to update - update_data: Dictionary containing fields to update - current_user: The user requesting the update - db: Optional database session (for backward compatibility) - - Returns: - GenericFile: The updated file record - - Raises: - FileNotFoundError: If file doesn't exist - FileAccessDeniedError: If user doesn't have permission - """ - logger.info(f"Updating file metadata: file_id={file_id}, user={current_user.id}") - - # Get file and check permissions - if self.uow: - with self.uow: - file = self.uow.files.get_file_by_id(file_id) - elif db: - repository = GenericFileRepository(db) - file = repository.get_file_by_id(file_id) - else: - raise FileStorageError("Either uow or db session must be provided") - - if not file: - logger.warning(f"File not found for update: {file_id}") - raise FileNotFoundError(file_id) - - # Check permissions using permission service - from app.core.permissions import permission_service, Subject, Resource, Action - - subject = Subject.from_user(current_user) - resource = Resource.from_file(file) - - try: - permission_service.require_permission( - subject, - Action.UPDATE, - resource, - error_message=f"无权更新文件 {file_id}" - ) - except PermissionDeniedException: - logger.warning(f"Update access denied: file_id={file_id}, user={current_user.id}") - raise FileAccessDeniedError(file_id) - - # Filter allowed fields for update - # Users can only update: file_name, file_metadata, is_public - allowed_fields = ["file_name", "file_metadata", "is_public"] - filtered_update_data = { - key: value for key, value in update_data.items() - if key in allowed_fields - } - - if not filtered_update_data: - logger.warning(f"No valid fields to update for file: {file_id}") - return file - - # Update file metadata - try: - if self.uow: - with self.uow: - updated_file = self.uow.files.update_file(file_id, filtered_update_data) - self.uow.commit() - elif db: - repository = GenericFileRepository(db) - updated_file = repository.update_file(file_id, filtered_update_data) - db.commit() - - logger.info(f"File metadata updated successfully: {file.file_name} (ID: {file_id})") - return updated_file - - except Exception as e: - if db: - db.rollback() - logger.error(f"Failed to update file metadata: {str(e)}") - raise FileStorageError(f"无法更新文件元数据: {str(e)}") diff --git a/app/services/user_service.py b/app/services/user_service.py deleted file mode 100644 index d9b6ea9d..00000000 --- a/app/services/user_service.py +++ /dev/null @@ -1,570 +0,0 @@ -import datetime -import secrets -import string -from sqlalchemy.orm import Session -import uuid - -from app.models.user_model import User -from app.repositories import user_repository -from app.schemas.user_schema import UserCreate -from app.schemas.tenant_schema import TenantCreate -from app.services.tenant_service import TenantService -from app.services.session_service import SessionService -from app.core.security import get_password_hash, verify_password -from app.core.config import settings -from app.core.logging_config import get_business_logger -from app.core.exceptions import BusinessException, PermissionDeniedException -from app.core.error_codes import BizCode -# from app.services import workspace_service -# from app.schemas.workspace_schema import WorkspaceCreate - -# 获取业务逻辑专用日志器 -business_logger = get_business_logger() - - -def create_initial_superuser(db: Session): - business_logger.info("检查并创建初始超级用户") - - superuser = user_repository.get_superuser(db) - if superuser: - business_logger.info("超级用户已存在,跳过创建") - return - - user_in = UserCreate( - username=settings.FIRST_SUPERUSER_USERNAME, - email=settings.FIRST_SUPERUSER_EMAIL, - password=settings.FIRST_SUPERUSER_PASSWORD, - ) - - try: - business_logger.debug("开始创建初始租户") - # Create a default tenant for the superuser - default_tenant = TenantCreate( - name=f"{user_in.username}'s Tenant", - description=f"Default tenant for {user_in.username}", - ) - # Create tenant service and create tenant with user assignment - tenant_service = TenantService(db) - tenant = tenant_service.create_tenant(default_tenant) - db.flush() - business_logger.debug("开始创建初始超级用户") - - hashed_password = get_password_hash(user_in.password) - superuser = user_repository.create_user( - db=db, user=user_in, hashed_password=hashed_password, is_superuser=True, - tenant_id=tenant.id - ) - db.commit() - db.refresh(superuser) - business_logger.info(f"初始超级用户创建成功: {superuser.username} (ID: {superuser.id})") - return superuser - except Exception as e: - business_logger.error(f"初始超级用户创建失败: {str(e)}") - db.rollback() - raise BusinessException( - f"初始超级用户创建失败: {str(e)}", - code=BizCode.DB_ERROR, - context={"username": username, "email": email}, - cause=e - ) - - -def create_user(db: Session, user: UserCreate) -> User: - business_logger.info(f"创建用户: {user.username}, email: {user.email}") - - try: - # 检查用户名是否已存在 - business_logger.debug(f"检查用户名是否已存在: {user.username}") - db_user_by_username = user_repository.get_user_by_username(db, username=user.username) - if db_user_by_username: - business_logger.warning(f"用户名已存在: {user.username}") - raise BusinessException( - "用户名已存在", - code=BizCode.DUPLICATE_NAME, - context={"username": user.username, "email": user.email} - ) - - # 检查邮箱是否已注册 - business_logger.debug(f"检查邮箱是否已注册: {user.email}") - db_user_by_email = user_repository.get_user_by_email(db, email=user.email) - if db_user_by_email: - business_logger.warning(f"邮箱已注册: {user.email}") - raise BusinessException( - "邮箱已注册", - code=BizCode.DUPLICATE_NAME, - context={"email": user.email, "username": user.username} - ) - - # 创建普通用户,需要有默认租户 - business_logger.debug(f"开始创建用户: {user.username}") - hashed_password = get_password_hash(user.password) - - # 获取默认租户(第一个活跃租户) - from app.repositories.tenant_repository import TenantRepository - tenant_repo = TenantRepository(db) - tenants = tenant_repo.get_tenants(skip=0, limit=1, is_active=True) - - if not tenants: - business_logger.error("系统中没有可用的租户") - raise BusinessException( - "系统配置错误:没有可用的租户", - code=BizCode.TENANT_NOT_FOUND, - context={"username": user.username, "email": user.email} - ) - - default_tenant = tenants[0] - - new_user = user_repository.create_user( - db=db, user=user, hashed_password=hashed_password, - tenant_id=default_tenant.id, is_superuser=False - ) - - db.commit() - db.refresh(new_user) - business_logger.info(f"用户创建成功: {new_user.username} (ID: {new_user.id})") - return new_user - except Exception as e: - business_logger.error(f"用户创建失败: {user.username} - {str(e)}") - db.rollback() - raise BusinessException( - f"用户创建失败: {user.username} - {str(e)}", - code=BizCode.DB_ERROR, - context={"username": user.username, "email": user.email}, - cause=e - ) - - -def create_superuser(db: Session, user: UserCreate, current_user: User) -> User: - business_logger.info(f"创建超级管理员: {user.username}, email: {user.email}") - - # 检查当前用户是否为超级管理员 - from app.core.permissions import permission_service, Subject - - subject = Subject.from_user(current_user) - try: - permission_service.check_superuser( - subject, - error_message="只有超级管理员才能创建超级管理员用户" - ) - except PermissionDeniedException as e: - business_logger.warning(f"非超级管理员尝试创建超级管理员用户: {user.username}") - raise BusinessException( - str(e), - code=BizCode.FORBIDDEN, - context={ - "current_user_id": str(current_user.id), - "current_user_username": current_user.username, - "target_username": user.username - } - ) - - try: - # 检查用户名是否已存在 - business_logger.debug(f"检查用户名是否已存在: {user.username}") - db_user_by_username = user_repository.get_user_by_username(db, username=user.username) - if db_user_by_username: - business_logger.warning(f"用户名已存在: {user.username}") - raise BusinessException( - "用户名已存在", - code=BizCode.DUPLICATE_NAME, - context={ - "username": user.username, - "email": user.email, - "created_by": str(current_user.id) - } - ) - - # 检查邮箱是否已注册 - business_logger.debug(f"检查邮箱是否已注册: {user.email}") - db_user_by_email = user_repository.get_user_by_email(db, email=user.email) - if db_user_by_email: - business_logger.warning(f"邮箱已注册: {user.email}") - raise BusinessException( - "邮箱已注册", - code=BizCode.DUPLICATE_NAME, - context={ - "email": user.email, - "username": user.username, - "created_by": str(current_user.id) - } - ) - - # 创建超级管理员用户并加入当前用户的租户 - business_logger.debug(f"开始创建超级管理员: {user.username}") - hashed_password = get_password_hash(user.password) - - new_user = user_repository.create_user( - db=db, user=user, hashed_password=hashed_password, - tenant_id=current_user.tenant_id, is_superuser=True - ) - - db.commit() - db.refresh(new_user) - business_logger.info(f"超级管理员创建成功: {new_user.username} (ID: {new_user.id}), 已加入租户: {current_user.tenant_id}") - return new_user - except Exception as e: - business_logger.error(f"超级管理员创建失败: {user.username} - {str(e)}") - db.rollback() - raise BusinessException( - f"超级管理员创建失败: {user.username} - {str(e)}", - code=BizCode.DB_ERROR, - context={ - "username": user.username, - "email": user.email, - "created_by": str(current_user.id), - "tenant_id": str(current_user.tenant_id) - }, - cause=e - ) - - -def deactivate_user(db: Session, user_id_to_deactivate: uuid.UUID, current_user: User) -> User: - business_logger.info(f"停用用户: user_id={user_id_to_deactivate}, 操作者: {current_user.username}") - - try: - # 查找用户 - business_logger.debug(f"查找待停用用户: {user_id_to_deactivate}") - db_user = user_repository.get_user_by_id(db, user_id=user_id_to_deactivate) - if not db_user: - business_logger.warning(f"用户不存在: {user_id_to_deactivate}") - raise BusinessException( - "用户不存在", - code=BizCode.USER_NOT_FOUND, - context={"user_id": str(user_id_to_deactivate)} - ) - - # 权限检查 using permission service - from app.core.permissions import permission_service, Subject, Resource, Action - - subject = Subject.from_user(current_user) - resource = Resource.from_user(db_user) - - try: - permission_service.require_permission( - subject, - Action.DEACTIVATE, - resource, - error_message="没有权限停用该用户" - ) - except PermissionDeniedException as e: - business_logger.warning(f"权限不足: 用户 {current_user.username} 尝试停用用户 {user_id_to_deactivate}") - raise BusinessException( - str(e), - code=BizCode.FORBIDDEN, - context={ - "current_user_id": str(current_user.id), - "current_user_username": current_user.username, - "target_user_id": str(user_id_to_deactivate) - } - ) - # 检查用户类型,如果是超级管理员,判断一下不是唯一的一个 - if db_user.is_superuser: - is_only_superuser = user_repository.check_superuser_only(db) - if is_only_superuser: - business_logger.warning(f"停用超级管理员用户: {db_user.username} (ID: {user_id_to_deactivate})") - raise BusinessException( - "不能停用唯一的超级管理员用户", - code=BizCode.FORBIDDEN, - context={ - "user_id": str(user_id_to_deactivate), - "username": db_user.username - } - ) - - # 停用用户 - business_logger.debug(f"执行用户停用: {db_user.username} (ID: {user_id_to_deactivate})") - db_user.is_active = False - db.add(db_user) - db.commit() - db.refresh(db_user) - business_logger.info(f"用户停用成功: {db_user.username} (ID: {user_id_to_deactivate})") - return db_user - except Exception as e: - business_logger.error(f"用户停用失败: user_id={user_id_to_deactivate} - {str(e)}") - db.rollback() - if isinstance(e, BusinessException): - raise e - raise BusinessException(f"{str(e)}", code=BizCode.DB_ERROR) - -def activate_user(db: Session, user_id_to_activate: uuid.UUID, current_user: User) -> User: - business_logger.info(f"激活用户: user_id={user_id_to_activate}, 操作者: {current_user.username}") - - try: - # 查找用户 - business_logger.debug(f"查找待激活用户: {user_id_to_activate}") - db_user = user_repository.get_user_by_id(db, user_id=user_id_to_activate) - if not db_user: - business_logger.warning(f"用户不存在: {user_id_to_activate}") - raise BusinessException("用户不存在", code=BizCode.USER_NOT_FOUND) - - # 权限检查 using permission service - from app.core.permissions import permission_service, Subject, Resource, Action - - subject = Subject.from_user(current_user) - resource = Resource.from_user(db_user) - - try: - permission_service.require_permission( - subject, - Action.ACTIVATE, - resource, - error_message="没有权限激活该用户" - ) - except PermissionDeniedException as e: - business_logger.warning(f"权限不足: 用户 {current_user.username} 尝试激活用户 {user_id_to_activate}") - raise BusinessException(str(e), code=BizCode.FORBIDDEN) - - # 激活用户 - business_logger.debug(f"执行用户激活: {db_user.username} (ID: {user_id_to_activate})") - db_user.is_active = True - db.add(db_user) - db.commit() - db.refresh(db_user) - business_logger.info(f"用户激活成功: {db_user.username} (ID: {user_id_to_activate})") - return db_user - except Exception as e: - business_logger.error(f"用户激活失败: user_id={user_id_to_activate} - {str(e)}") - db.rollback() - raise BusinessException(f"用户激活失败: user_id={user_id_to_activate} - {str(e)}", code=BizCode.DB_ERROR) - - -def get_user(db: Session, user_id: uuid.UUID, current_user: User) -> User: - business_logger.info(f"获取用户信息: user_id={user_id}, 操作者: {current_user.username}") - - try: - # 查找用户 - business_logger.debug(f"查找用户: {user_id}") - db_user = user_repository.get_user_by_id(db, user_id=user_id) - if not db_user: - business_logger.warning(f"用户不存在: {user_id}") - raise BusinessException("用户不存在", code=BizCode.USER_NOT_FOUND) - - # 权限检查 using permission service - from app.core.permissions import permission_service, Subject, Resource, Action - - subject = Subject.from_user(current_user) - resource = Resource.from_user(db_user) - - try: - permission_service.require_permission( - subject, - Action.READ, - resource, - error_message="没有权限获取该用户信息" - ) - except PermissionDeniedException as e: - business_logger.warning(f"权限不足: 用户 {current_user.username} 尝试获取用户 {user_id} 信息") - raise BusinessException(str(e), code=BizCode.FORBIDDEN) - - # 返回用户信息 - business_logger.debug(f"返回用户信息: {db_user.username} (ID: {user_id})") - return db_user - except Exception as e: - business_logger.error(f"获取用户信息失败: user_id={user_id} - {str(e)}") - raise BusinessException(f"获取用户信息失败: user_id={user_id} - {str(e)}", code=BizCode.DB_ERROR) - - -def get_tenant_superusers(db: Session, current_user: User, include_inactive: bool = True) -> list[User]: - """获取当前租户下的超管账号列表""" - business_logger.info(f"获取租户超管列表: tenant_id={current_user.tenant_id}, 请求者: {current_user.username}, include_inactive={include_inactive}") - - try: - # 检查当前用户是否有权限查看(只有超管才能查看超管列表) - from app.core.permissions import permission_service, Subject - - subject = Subject.from_user(current_user) - try: - permission_service.check_superuser( - subject, - error_message="只有超级管理员才能查看超管列表" - ) - except PermissionDeniedException as e: - business_logger.warning(f"非超级管理员尝试查看超管列表: {current_user.username}") - raise BusinessException(str(e), code=BizCode.FORBIDDEN) - - # 检查用户是否有租户 - if not current_user.tenant_id: - business_logger.warning(f"用户没有租户信息: {current_user.username}") - raise BusinessException("用户没有租户信息", code=BizCode.TENANT_NOT_FOUND) - - # 获取租户下的超管列表 - business_logger.debug(f"查询租户超管: tenant_id={current_user.tenant_id}, include_inactive={include_inactive}") - is_active_filter = None if include_inactive else True - superusers = user_repository.get_superusers_by_tenant( - db=db, - tenant_id=current_user.tenant_id, - is_active=is_active_filter - ) - - business_logger.info(f"租户超管查询成功: tenant_id={current_user.tenant_id}, count={len(superusers)}") - return superusers - - except Exception as e: - business_logger.error(f"获取租户超管列表失败: tenant_id={current_user.tenant_id} - {str(e)}") - raise BusinessException(f"获取租户超管列表失败: tenant_id={current_user.tenant_id} - {str(e)}", code=BizCode.DB_ERROR) - - -def update_last_login_time(db: Session, user_id: uuid.UUID) -> User: - """更新用户的最后登录时间""" - business_logger.info(f"更新用户最后登录时间: user_id={user_id}") - - try: - # 获取用户 - db_user = user_repository.get_user_by_id(db=db, user_id=user_id) - if not db_user: - business_logger.warning(f"用户不存在: {user_id}") - raise BusinessException("用户不存在", code=BizCode.USER_NOT_FOUND) - - # 更新最后登录时间 - db_user.last_login_at = datetime.datetime.now() - db.commit() - db.refresh(db_user) - - business_logger.info(f"用户最后登录时间更新成功: {db_user.username} (ID: {user_id})") - return db_user - - except HTTPException: - raise - except Exception as e: - business_logger.error(f"更新用户最后登录时间失败: user_id={user_id} - {str(e)}") - db.rollback() - raise - - -async def change_password(db: Session, user_id: uuid.UUID, old_password: str, new_password: str, current_user: User) -> User: - """普通用户修改自己的密码""" - business_logger.info(f"用户修改密码请求: user_id={user_id}, current_user={current_user.id}") - - # 检查权限:只能修改自己的密码 - if current_user.id != user_id: - business_logger.warning(f"用户尝试修改他人密码: current_user={current_user.id}, target_user={user_id}") - raise HTTPException( - status_code=status.HTTP_403_FORBIDDEN, - detail="You can only change your own password" - ) - - try: - # 获取用户 - db_user = user_repository.get_user_by_id(db=db, user_id=user_id) - if not db_user: - business_logger.warning(f"用户不存在: {user_id}") - raise HTTPException( - status_code=status.HTTP_404_NOT_FOUND, detail="User not found" - ) - - # 验证旧密码 - if not verify_password(old_password, db_user.hashed_password): - business_logger.warning(f"用户旧密码验证失败: {user_id}") - raise BusinessException("当前密码不正确", code=BizCode.VALIDATION_FAILED) - - # 更新密码 - db_user.hashed_password = get_password_hash(new_password) - db.commit() - db.refresh(db_user) - - # 使所有旧 tokens 失效 - await SessionService.invalidate_all_user_tokens(str(user_id)) - - business_logger.info(f"用户密码修改成功: {db_user.username} (ID: {user_id})") - return db_user - - except Exception as e: - business_logger.error(f"修改用户密码失败: user_id={user_id} - {str(e)}") - db.rollback() - raise BusinessException(f"修改用户密码失败: user_id={user_id} - {str(e)}", code=BizCode.DB_ERROR) - - -async def admin_change_password(db: Session, target_user_id: uuid.UUID, new_password: str = None, current_user: User = None) -> tuple[User, str]: - """ - 超级管理员修改指定用户的密码 - - Args: - db: 数据库会话 - target_user_id: 目标用户ID - new_password: 新密码,如果为None则自动生成随机密码 - current_user: 当前用户(超级管理员) - - Returns: - tuple[User, str]: (更新后的用户对象, 实际使用的密码) - """ - business_logger.info(f"管理员修改用户密码请求: admin={current_user.id}, target_user={target_user_id}") - - # 检查权限:只有超级管理员可以修改他人密码 - from app.core.permissions import permission_service, Subject - - subject = Subject.from_user(current_user) - try: - permission_service.check_superuser( - subject, - error_message="只有超级管理员可以修改他人密码" - ) - except PermissionDeniedException as e: - business_logger.warning(f"非超管用户尝试修改他人密码: current_user={current_user.id}") - raise BusinessException(str(e), code=BizCode.FORBIDDEN) - - try: - # 获取目标用户 - target_user = user_repository.get_user_by_id(db=db, user_id=target_user_id) - if not target_user: - business_logger.warning(f"目标用户不存在: {target_user_id}") - raise BusinessException("目标用户不存在", code=BizCode.USER_NOT_FOUND) - - # 检查租户权限:超管只能修改同租户用户的密码 - if current_user.tenant_id != target_user.tenant_id: - business_logger.warning(f"跨租户密码修改尝试: admin_tenant={current_user.tenant_id}, target_tenant={target_user.tenant_id}") - raise BusinessException("不可跨租户修改用户密码", code=BizCode.FORBIDDEN) - - # 如果没有提供新密码,则生成随机密码 - actual_password = new_password if new_password else generate_random_password() - - # 更新密码 - target_user.hashed_password = get_password_hash(actual_password) - db.commit() - db.refresh(target_user) - - # 使所有旧 tokens 失效 - await SessionService.invalidate_all_user_tokens(str(target_user_id)) - - password_type = "指定密码" if new_password else "随机生成密码" - business_logger.info(f"管理员修改用户密码成功: admin={current_user.username}, target={target_user.username} (ID: {target_user_id}), 类型={password_type}") - return target_user, actual_password - - except Exception as e: - business_logger.error(f"管理员修改用户密码失败: admin={current_user.id}, target_user={target_user_id} - {str(e)}") - db.rollback() - raise BusinessException(f"管理员修改用户密码失败: admin={current_user.id}, target_user={target_user_id} - {str(e)}", code=BizCode.DB_ERROR) - - -def generate_random_password(length: int = 12) -> str: - """ - 生成随机密码 - - Args: - length: 密码长度,默认12位 - - Returns: - str: 生成的随机密码 - """ - # 确保密码包含大小写字母、数字和特殊字符 - lowercase = string.ascii_lowercase - uppercase = string.ascii_uppercase - digits = string.digits - special_chars = "!@#$%^&*" - - # 确保至少包含每种字符类型 - password = [ - secrets.choice(lowercase), - secrets.choice(uppercase), - secrets.choice(digits), - secrets.choice(special_chars) - ] - - # 填充剩余长度 - all_chars = lowercase + uppercase + digits + special_chars - for _ in range(length - 4): - password.append(secrets.choice(all_chars)) - - # 打乱顺序 - secrets.SystemRandom().shuffle(password) - - return ''.join(password) diff --git a/app/services/workspace_service.py b/app/services/workspace_service.py deleted file mode 100644 index dfe18435..00000000 --- a/app/services/workspace_service.py +++ /dev/null @@ -1,776 +0,0 @@ -from sqlalchemy.orm import Session -from typing import List, Optional -import uuid -import secrets -import hashlib -import datetime -from fastapi import HTTPException, status -from app.core.error_codes import BizCode -from app.core.exceptions import BusinessException, PermissionDeniedException -from app.models.tenant_model import Tenants -from app.models.user_model import User -from app.models.app_model import App -from app.models.end_user_model import EndUser -from app.models.workspace_model import Workspace, WorkspaceRole, WorkspaceInvite, InviteStatus, WorkspaceMember -from app.schemas.workspace_schema import ( - WorkspaceCreate, - WorkspaceUpdate, - WorkspaceInviteCreate, - WorkspaceInviteResponse, - InviteValidateResponse, - InviteAcceptRequest, - WorkspaceMemberUpdate -) -from app.repositories import workspace_repository -from app.repositories.workspace_invite_repository import WorkspaceInviteRepository -from app.core.logging_config import get_business_logger -from app.core.config import settings -from app.services import user_service -from os import getenv -# 获取业务逻辑专用日志器 -business_logger = get_business_logger() -import os # -from dotenv import load_dotenv -load_dotenv() -def switch_workspace( - db: Session, - workspace_id: uuid.UUID, - user: User, -): - """切换工作空间""" - business_logger.debug(f"用户 {user.username} 请求切换工作空间为 {workspace_id}") - - # 检查用户是否为成员或超级管理员 - _check_workspace_member_permission(db, workspace_id, user) - - # 更新当前用户的工作空间上下文 - try: - user.current_workspace_id = workspace_id - db.commit() - business_logger.info(f"用户 {user.username} 成功切换工作空间为 {workspace_id}") - return - except Exception as e: - db.rollback() - business_logger.error(f"切换工作空间失败 - 工作空间: {workspace_id}, 错误: {str(e)}") - raise BusinessException(f"切换工作空间失败: {str(e)}", BizCode.INTERNAL_ERROR) - - -def delete_workspace_member( - db: Session, - workspace_id: uuid.UUID, - member_id: uuid.UUID, - user: User, - ): - """删除工作空间成员""" - business_logger.debug(f"用户 {user.username} 请求删除工作空间 {workspace_id} 的成员 {member_id}") - _check_workspace_admin_permission(db, workspace_id, user) - workspace_member = workspace_repository.get_member_by_id(db=db, member_id=member_id) - if not workspace_member: - raise BusinessException(f"工作空间成员 {member_id} 不存在", BizCode.WORKSPACE_MEMBER_NOT_FOUND) - - if workspace_member.workspace_id != workspace_id: - raise BusinessException(f"工作空间成员 {member_id} 不存在于工作空间 {workspace_id}", BizCode.WORKSPACE_MEMBER_NOT_FOUND) - - try: - workspace_member.is_active = False - workspace_member.user.current_workspace_id = None - db.commit() - business_logger.info(f"用户 {user.username} 成功删除工作空间 {workspace_id} 的成员 {member_id}") - except Exception as e: - db.rollback() - business_logger.error(f"删除工作空间成员失败 - 工作空间: {workspace_id}, 成员: {member_id}, 错误: {str(e)}") - raise BusinessException(f"删除工作空间成员失败: {str(e)}", BizCode.INTERNAL_ERROR) - - -def get_user_workspaces(db: Session, user: User) -> List[Workspace]: - """获取当前用户参与的所有工作空间""" - business_logger.debug(f"获取用户工作空间列表: {user.username} (ID: {user.id})") - workspaces = workspace_repository.get_workspaces_by_user(db=db, user_id=user.id) - business_logger.info(f"用户 {user.username} 的工作空间数量: {len(workspaces)}") - return workspaces - - -def _create_workspace_only( - db: Session, workspace: WorkspaceCreate, owner: User -) -> Workspace: - business_logger.debug(f"创建工作空间: {workspace.name}, 创建者: {owner.username}") - - try: - # Create the workspace without adding any members - business_logger.debug(f"创建工作空间: {workspace.name}") - db_workspace = workspace_repository.create_workspace( - db=db, workspace=workspace, tenant_id=owner.tenant_id - ) - business_logger.info(f"工作空间创建成功: {db_workspace.name} (ID: {db_workspace.id}), 创建者: {owner.username}") - return db_workspace - except Exception as e: - business_logger.error(f"创建工作空间失败: {workspace.name} - {str(e)}") - raise - -def create_workspace( - db: Session, workspace: WorkspaceCreate, user: User -) -> Workspace: - business_logger.info( - f"创建工作空间: {workspace.name}, 创建者: {user.username}, " - f"storage_type: {workspace.storage_type}" - ) - llm=workspace.llm - embedding=workspace.embedding - rerank=workspace.rerank - try: - # Create the workspace without adding any members - business_logger.debug(f"创建工作空间: {workspace.name}") - db_workspace = workspace_repository.create_workspace( - db=db, workspace=workspace, tenant_id=user.tenant_id - ) - business_logger.info(f"工作空间创建成功: {db_workspace.name} (ID: {db_workspace.id}), 创建者: {user.username}") - db.commit() - db.refresh(db_workspace) - - # 如果 storage_type 是 "rag",自动创建知识库 - if workspace.storage_type == "rag": - business_logger.info( - f"检测到 storage_type 为 'rag',开始为工作空间 " - f"{db_workspace.id} 创建知识库" - ) - try: - import os - from app.schemas.knowledge_schema import KnowledgeCreate - from app.models.knowledge_model import KnowledgeType, PermissionType - from app.repositories import knowledge_repository - - # 创建知识库数据 - knowledge_data = KnowledgeCreate( - workspace_id=db_workspace.id, - created_by=user.id, - parent_id=db_workspace.id, - name="USER_RAG_MERORY", - description=f"工作空间 {workspace.name} 的默认知识库", - avatar='', - type=KnowledgeType.General, - permission_id=PermissionType.Private, - embedding_id=uuid.UUID(getenv('KB_embedding_id')) if None else embedding, - reranker_id=uuid.UUID(getenv('KB_reranker_id')) if None else rerank, - llm_id=uuid.UUID(getenv('KB_llm_id')) if None else llm, - image2text_id=uuid.UUID(getenv('KB_llm_id')) if None else llm, - parser_config={ - "layout_recognize": "DeepDOC", - "chunk_token_num": 256, - "delimiter": "\n", - "auto_keywords": 0, - "auto_questions": 0, - "html4excel": False - } - ) - - # 直接使用 repository 创建知识库,避免 service 层的额外逻辑 - db_knowledge = knowledge_repository.create_knowledge( - db=db, - knowledge=knowledge_data - ) - db.commit() - business_logger.info( - f"为工作空间 {db_workspace.id} 自动创建知识库成功: " - f"{db_knowledge.name} (ID: {db_knowledge.id})" - ) - except Exception as kb_error: - business_logger.error( - f"为工作空间 {db_workspace.id} 创建知识库失败: {str(kb_error)}" - ) - db.rollback() - raise BusinessException( - f"工作空间创建成功,但知识库创建失败: {str(kb_error)}", - BizCode.INTERNAL_ERROR - ) - - return db_workspace - - except Exception as e: - business_logger.error(f"工作空间创建失败: {workspace.name} - {str(e)}") - db.rollback() - raise - - -def update_workspace( - db: Session, workspace_id: uuid.UUID, workspace_in: WorkspaceUpdate, user: User -) -> Workspace: - business_logger.info(f"更新工作空间: workspace_id={workspace_id}, 操作者: {user.username}") - - db_workspace = _check_workspace_admin_permission(db,workspace_id,user) - try: - # 更新工作空间 - business_logger.debug(f"执行工作空间更新: {db_workspace.name} (ID: {workspace_id})") - update_data = workspace_in.model_dump(exclude_unset=True) - for field, value in update_data.items(): - setattr(db_workspace, field, value) - - db.add(db_workspace) - db.commit() - db.refresh(db_workspace) - business_logger.info(f"工作空间更新成功: {db_workspace.name} (ID: {workspace_id})") - return db_workspace - except Exception as e: - business_logger.error(f"工作空间更新失败: workspace_id={workspace_id} - {str(e)}") - db.rollback() - raise - - -def get_workspace_members( - db: Session, workspace_id: uuid.UUID, user: User -) -> List[WorkspaceMember]: - """获取某工作空间的成员列表(关系序列化由模型关系支持)""" - business_logger.info(f"获取工作空间成员: workspace_id={workspace_id}, 操作者: {user.username}") - - # 查找工作空间 - business_logger.debug(f"查找工作空间: {workspace_id}") - workspace = workspace_repository.get_workspace_by_id(db=db, workspace_id=workspace_id) - if not workspace: - business_logger.warning(f"工作空间不存在: {workspace_id}") - raise BusinessException( - message="Workspace not found", - code=BizCode.WORKSPACE_NOT_FOUND - ) - - # 权限检查:工作空间成员或超级管理员可以查看成员列表 - from app.core.permissions import permission_service, Subject, Resource, Action - member = workspace_repository.get_member_in_workspace( - db=db, user_id=user.id, workspace_id=workspace_id - ) - workspace_memberships = {workspace_id} if member else set() - - subject = Subject.from_user(user, workspace_memberships=workspace_memberships) - resource = Resource.from_workspace(workspace) - - try: - permission_service.require_permission( - subject, - Action.READ, - resource, - error_message=f"用户 {user.username} 没有查看工作空间 {workspace_id} 成员列表的权限" - ) - except PermissionDeniedException as e: - business_logger.warning( - f"权限不足: 用户 {user.username} 尝试获取工作空间 {workspace_id} 成员列表" - ) - raise BusinessException(str(e), BizCode.WORKSPACE_ACCESS_DENIED) - - # 查询成员并预加载 user/workspace 关系 - members = workspace_repository.get_members_by_workspace(db=db, workspace_id=workspace_id) - business_logger.info(f"工作空间成员数量: {len(members)} - workspace_id={workspace_id}") - return members - - - -# ==================== 邀请相关服务方法 ==================== - -def _generate_invite_token() -> tuple[str, str]: - """生成邀请令牌和其哈希值 - - Returns: - tuple: (原始令牌, 令牌哈希) - """ - # 生成32字节的随机令牌 - token = secrets.token_urlsafe(32) - # 生成令牌的SHA256哈希 - token_hash = hashlib.sha256(token.encode()).hexdigest() - return token, token_hash - - -def _check_workspace_member_permission(db: Session, workspace_id: uuid.UUID, user: User) -> Workspace | None: - """检查用户是否为工作空间成员或超级管理员(使用统一权限服务)""" - # 获取工作空间信息 - db_workspace = workspace_repository.get_workspace_by_id(db=db, workspace_id=workspace_id) - if not db_workspace: - raise BusinessException( - message="Workspace not found", - code=BizCode.WORKSPACE_NOT_FOUND - ) - - # 使用统一权限服务检查访问权限 - from app.core.permissions import permission_service, Subject, Resource, Action - - # 获取用户的工作空间成员关系 - member = workspace_repository.get_member_in_workspace( - db=db, user_id=user.id, workspace_id=workspace_id - ) - - # 任何成员都有访问权限 - workspace_memberships = {workspace_id} if member else set() - - subject = Subject.from_user(user, workspace_memberships=workspace_memberships) - resource = Resource.from_workspace(db_workspace) - - try: - permission_service.require_permission( - subject, - Action.READ, - resource, - error_message=f"用户 {user.username} 不是工作空间 {workspace_id} 的成员" - ) - business_logger.debug(f"用户 {user.username} 是工作空间 {workspace_id} 的成员或超级管理员") - except PermissionDeniedException as e: - business_logger.warning(f"权限不足: 用户 {user.username} 尝试访问工作空间 {workspace_id}") - raise BusinessException(str(e), BizCode.WORKSPACE_NO_ACCESS) - return db_workspace - - -def _check_workspace_admin_permission(db: Session, workspace_id: uuid.UUID, user: User) -> Workspace | None: - """检查用户是否有工作空间管理员权限(使用统一权限服务)""" - # 获取工作空间信息 - db_workspace = workspace_repository.get_workspace_by_id(db=db, workspace_id=workspace_id) - if not db_workspace: - raise BusinessException( - message="Workspace not found", - code=BizCode.WORKSPACE_NOT_FOUND - ) - - # 使用统一权限服务检查管理权限 - from app.core.permissions import permission_service, Subject, Resource, Action - - # 获取用户的工作空间成员关系 - member = workspace_repository.get_member_in_workspace( - db=db, user_id=user.id, workspace_id=workspace_id - ) - - # 只有 manager 才有管理权限 - workspace_memberships = {workspace_id} if (member and member.role == WorkspaceRole.manager) else set() - - subject = Subject.from_user(user, workspace_memberships=workspace_memberships) - resource = Resource.from_workspace(db_workspace) - - try: - permission_service.require_permission( - subject, - Action.MANAGE, - resource, - error_message=f"用户 {user.username} 没有管理工作空间 {workspace_id} 的权限" - ) - business_logger.debug(f"用户 {user.username} 有权限管理工作空间 {workspace_id}") - except PermissionDeniedException as e: - business_logger.warning(f"权限不足: 用户 {user.username} 尝试管理工作空间 {workspace_id}") - raise BusinessException(str(e), BizCode.WORKSPACE_ACCESS_DENIED) - return db_workspace - - -def create_workspace_invite( - db: Session, - workspace_id: uuid.UUID, - invite_data: WorkspaceInviteCreate, - user: User -) -> WorkspaceInviteResponse: - """创建工作空间邀请""" - business_logger.info(f"创建工作空间邀请: workspace_id={workspace_id}, email={invite_data.email}, 创建者: {user.username}") - - try: - # 检查权限 - _check_workspace_admin_permission(db, workspace_id, user) - if settings.ENABLE_SINGLE_WORKSPACE: - # 检查被邀请用户是否已经在工作空间中 - from app.repositories import user_repository - invited_user = user_repository.get_user_by_email(db, invite_data.email) - - if invited_user: - # 用户存在,检查是否已经是工作空间成员 - existing_member = workspace_repository.get_member_in_workspace( - db=db, - user_id=invited_user.id, - workspace_id=workspace_id - ) - if existing_member: - business_logger.warning(f"用户 {invite_data.email} 已经是工作空间成员") - raise BusinessException("该用户已经是工作空间成员", BizCode.RESOURCE_ALREADY_EXISTS) - - # 检查是否已有待处理的邀请 - invite_repo = WorkspaceInviteRepository(db) - existing_invite = invite_repo.get_pending_invite_by_email_and_workspace( - email=invite_data.email, - workspace_id=workspace_id - ) - - invite_token = None - if existing_invite: - business_logger.info(f"邮箱 {invite_data.email} 在工作空间 {workspace_id} 已有待处理邀请,返回现有邀请") - # 生成新的邀请链接(重新生成令牌) - token, token_hash = _generate_invite_token() - existing_invite.token_hash = token_hash - existing_invite.updated_at = datetime.datetime.now() - db.commit() - db.refresh(existing_invite) - invite_token = token - else: - # 生成邀请令牌 - token, token_hash = _generate_invite_token() - # 创建邀请 - db_invite = invite_repo.create_invite( - workspace_id=workspace_id, - invite_data=invite_data, - token_hash=token_hash, - created_by_user_id=user.id - ) - db.commit() - db.refresh(db_invite) - invite_token = token - - invite_obj = existing_invite or db_invite - business_logger.info(f"工作空间邀请创建成功: invite_id={invite_obj.id}, email={invite_data.email}") - - # 构造响应 - response = WorkspaceInviteResponse.model_validate(invite_obj) - response.invite_token = invite_token - return response - - - except Exception as e: - db.rollback() - business_logger.error(f"创建工作空间邀请失败: workspace_id={workspace_id}, email={invite_data.email} - {str(e)}") - raise - - -def get_workspace_invites( - db: Session, - workspace_id: uuid.UUID, - user: User, - status: Optional[InviteStatus] = None, - limit: int = 50, - offset: int = 0 -) -> List[WorkspaceInviteResponse]: - """获取工作空间邀请列表""" - business_logger.info(f"获取工作空间邀请列表: workspace_id={workspace_id}, 操作者: {user.username}") - - # 检查工作空间是否存在 - workspace = workspace_repository.get_workspace_by_id(db=db, workspace_id=workspace_id) - if not workspace: - raise BusinessException("工作空间不存在", BizCode.WORKSPACE_NOT_FOUND) - - # 检查权限 - _check_workspace_admin_permission(db, workspace_id, user) - - # 获取邀请列表 - invite_repo = WorkspaceInviteRepository(db) - invites = invite_repo.get_workspace_invites( - workspace_id=workspace_id, - status=status, - limit=limit, - offset=offset - ) - - return [WorkspaceInviteResponse.model_validate(invite) for invite in invites] - - -def validate_invite_token(db: Session, token: str) -> InviteValidateResponse: - """验证邀请令牌""" - business_logger.info(f"验证邀请令牌") - - # 生成令牌哈希 - token_hash = hashlib.sha256(token.encode()).hexdigest() - - # 查找邀请 - invite_repo = WorkspaceInviteRepository(db) - invite = invite_repo.get_invite_by_token_hash(token_hash) - - if not invite: - business_logger.warning(f"邀请令牌无效") - raise BusinessException("邀请令牌无效", BizCode.WORKSPACE_INVITE_NOT_FOUND) - - # 检查邀请状态和过期时间 - now = datetime.datetime.now() - is_expired = invite.expires_at < now or invite.status != InviteStatus.pending - is_valid = not is_expired - - # 获取工作空间信息 - workspace = workspace_repository.get_workspace_by_id(db=db, workspace_id=invite.workspace_id) - - business_logger.info(f"邀请令牌验证完成: valid={is_valid}, expired={is_expired}") - - return InviteValidateResponse( - workspace_name=workspace.name, - workspace_id=invite.workspace_id, - email=invite.email, - role=WorkspaceRole(invite.role), - is_expired=is_expired, - is_valid=is_valid - ) - - -def accept_workspace_invite( - db: Session, - accept_request: InviteAcceptRequest, - user: User -) -> dict: - """接受工作空间邀请""" - business_logger.info(f"接受工作空间邀请: 用户 {user.username}") - - try: - from app.core.config import settings - - # 生成令牌哈希 - token_hash = hashlib.sha256(accept_request.token.encode()).hexdigest() - - # 查找邀请 - invite_repo = WorkspaceInviteRepository(db) - invite = invite_repo.get_invite_by_token_hash(token_hash) - - if not invite: - business_logger.warning(f"邀请令牌无效") - raise BusinessException("邀请令牌无效", BizCode.WORKSPACE_INVITE_NOT_FOUND) - - # 检查邀请状态 - if invite.status != InviteStatus.pending: - business_logger.warning(f"邀请已被处理: status={invite.status}") - raise BusinessException(f"邀请已被{invite.status}", BizCode.WORKSPACE_INVITE_INVALID) - - # 检查过期时间 - now = datetime.datetime.now() - if invite.expires_at < now: - business_logger.warning(f"邀请已过期") - # 标记为过期 - invite_repo.update_invite_status(invite.id, InviteStatus.expired) - raise BusinessException("邀请已过期", BizCode.WORKSPACE_INVITE_EXPIRED) - - # 检查邮箱是否匹配 - if invite.email != user.email: - business_logger.warning(f"邮箱不匹配: invite_email={invite.email}, user_email={user.email}") - raise BusinessException("邮箱与邀请邮箱不匹配", BizCode.FORBIDDEN) - - # 如果启用单工作空间模式,检查用户是否已有工作空间 - if settings.ENABLE_SINGLE_WORKSPACE: - user_workspaces = workspace_repository.get_workspaces_by_user(db=db, user_id=user.id) - if user_workspaces: - business_logger.warning(f"单工作空间模式下用户已有工作空间: user={user.username}") - raise BusinessException("用户只能加入一个工作空间", BizCode.FORBIDDEN) - - # 检查用户是否已经是工作空间成员 - existing_member = workspace_repository.get_member_in_workspace( - db=db, - user_id=user.id, - workspace_id=invite.workspace_id - ) - - if existing_member: - business_logger.info(f"用户已是工作空间成员,更新邀请状态") - invite_repo.update_invite_status( - invite.id, - InviteStatus.accepted, - accepted_at=now - ) - db.commit() - workspace = workspace_repository.get_workspace_by_id(db=db, workspace_id=invite.workspace_id) - return { - "message": "You are already a member of this workspace", - "workspace": workspace - } - - # 将角色映射到工作空间角色(现在直接使用相同的角色) - workspace_role = invite.role - - # 添加用户到工作空间 - workspace_repository.add_member_to_workspace( - db=db, - user_id=user.id, - workspace_id=invite.workspace_id, - role=workspace_role - ) - - # 标记邀请为已接受 - invite_repo.update_invite_status( - invite.id, - InviteStatus.accepted, - accepted_at=now - ) - - db.commit() - - # 获取工作空间信息 - workspace = workspace_repository.get_workspace_by_id(db=db, workspace_id=invite.workspace_id) - - business_logger.info(f"用户成功加入工作空间: user={user.username}, workspace={workspace.name}, role={workspace_role}") - - return { - "message": "Successfully joined the workspace", - "workspace": workspace, - "role": workspace_role - } - - except Exception as e: - db.rollback() - business_logger.error(f"接受工作空间邀请失败: user={user.username} - {str(e)}") - raise - - -def revoke_workspace_invite( - db: Session, - workspace_id: uuid.UUID, - invite_id: uuid.UUID, - user: User -) -> dict: - """撤销工作空间邀请""" - business_logger.info(f"撤销工作空间邀请: workspace_id={workspace_id}, invite_id={invite_id}, 操作者: {user.username}") - - try: - # 检查权限 - _check_workspace_admin_permission(db, workspace_id, user) - - # 撤销邀请 - invite_repo = WorkspaceInviteRepository(db) - invite = invite_repo.revoke_invite(invite_id) - - if not invite: - business_logger.warning(f"邀请不存在: invite_id={invite_id}") - raise BusinessException("邀请不存在", BizCode.WORKSPACE_INVITE_NOT_FOUND) - - if invite.workspace_id != workspace_id: - business_logger.warning(f"邀请不属于指定工作空间: invite_id={invite_id}, workspace_id={workspace_id}") - raise BusinessException("邀请不属于指定工作空间", BizCode.BAD_REQUEST) - - db.commit() - business_logger.info(f"工作空间邀请撤销成功: invite_id={invite_id}") - return {"message": "邀请撤销成功"} - - except Exception as e: - db.rollback() - business_logger.error(f"撤销工作空间邀请失败: invite_id={invite_id} - {str(e)}") - raise - - -def update_workspace_member_roles( - db: Session, - workspace_id: uuid.UUID, - updates: List[WorkspaceMemberUpdate], - user: User, -) -> List[WorkspaceMember]: - """更新工作空间成员角色""" - business_logger.info(f"更新工作空间成员角色: workspace_id={workspace_id}, 操作者: {user.username}, 更新数量: {len(updates)}") - - # 检查管理员权限 - _check_workspace_admin_permission(db, workspace_id, user) - - # 获取所有当前成员 - all_members = workspace_repository.get_members_by_workspace(db=db, workspace_id=workspace_id) - member_map = {m.id: m for m in all_members} - - # 验证和业务规则检查 - update_ids = set() - for upd in updates: - # 检查成员是否存在 - if upd.id not in member_map: - raise BusinessException(f"成员 {upd.id} 不存在于工作空间 {workspace_id}", BizCode.WORKSPACE_MEMBER_NOT_FOUND) - - member = member_map[upd.id] - - # 检查成员是否属于该工作空间 - if member.workspace_id != workspace_id: - raise BusinessException(f"成员 {upd.id} 不属于工作空间 {workspace_id}", BizCode.WORKSPACE_MEMBER_NOT_FOUND) - - # 不能修改自己的角色 - if member.user_id == user.id: - raise BusinessException("不能修改自己的角色", BizCode.BAD_REQUEST) - - update_ids.add(upd.id) - - # 检查是否至少保留一个 manager - current_managers = [m for m in all_members if m.role == WorkspaceRole.manager] - managers_after_update = [ - m for m in all_members - if m.id not in update_ids and m.role == WorkspaceRole.manager - ] - - # 添加更新后会成为 manager 的成员 - for upd in updates: - if upd.role == WorkspaceRole.manager: - managers_after_update.append(member_map[upd.id]) - - if len(managers_after_update) == 0: - raise BusinessException("工作空间至少需要一个管理员", BizCode.BAD_REQUEST) - - # 执行更新 - try: - for upd in updates: - workspace_repository.update_member_role_by_id( - db=db, - id=upd.id, - role=upd.role, - ) - business_logger.debug(f"更新成员 {upd.id} 角色为 {upd.role}") - - db.commit() - - # 重新获取更新后的成员列表 - updated_members = workspace_repository.get_members_by_workspace(db=db, workspace_id=workspace_id) - business_logger.info(f"成员角色更新完成: workspace_id={workspace_id}, 更新数量={len(updates)}") - - return updated_members - - except Exception as e: - db.rollback() - business_logger.error(f"更新工作空间成员角色失败: workspace_id={workspace_id} - {str(e)}") - raise BusinessException(f"更新成员角色失败: {str(e)}", BizCode.INTERNAL_ERROR) - - -def get_workspace_storage_type( - db: Session, - workspace_id: uuid.UUID, - user: User, -) -> Optional[str]: - """获取工作空间的存储类型 - - Args: - db: 数据库会话 - workspace_id: 工作空间ID - user: 当前用户 - - Returns: - storage_type: 存储类型字符串,如果未设置则返回 None - """ - business_logger.info(f"用户 {user.username} 请求获取工作空间 {workspace_id} 的存储类型") - - # 检查用户是否有权限访问该工作空间 - _check_workspace_member_permission(db, workspace_id, user) - - # 查询工作空间 - workspace = workspace_repository.get_workspace_by_id(db=db, workspace_id=workspace_id) - if not workspace: - business_logger.error(f"工作空间不存在: workspace_id={workspace_id}") - raise BusinessException( - code=BizCode.WORKSPACE_NOT_FOUND, - message="工作空间不存在" - ) - - business_logger.info(f"成功获取工作空间 {workspace_id} 的存储类型: {workspace.storage_type}") - return workspace.storage_type - - -def get_workspace_models_configs( - db: Session, - workspace_id: uuid.UUID, - user: User, -) -> Optional[dict]: - """获取工作空间的模型配置(llm, embedding, rerank) - - Args: - db: 数据库会话 - workspace_id: 工作空间ID - user: 当前用户 - - Returns: - dict: 包含 llm, embedding, rerank 的字典,如果工作空间不存在则返回 None - """ - business_logger.info(f"用户 {user.username} 请求获取工作空间 {workspace_id} 的模型配置") - - # 检查用户是否有权限访问该工作空间 - _check_workspace_member_permission(db, workspace_id, user) - - # 查询工作空间模型配置 - configs = workspace_repository.get_workspace_models_configs(db=db, workspace_id=workspace_id) - - if configs is None: - business_logger.error(f"工作空间不存在: workspace_id={workspace_id}") - raise BusinessException( - code=BizCode.WORKSPACE_NOT_FOUND, - message="工作空间不存在" - ) - - business_logger.info( - f"成功获取工作空间 {workspace_id} 的模型配置: " - f"llm={configs.get('llm')}, embedding={configs.get('embedding')}, rerank={configs.get('rerank')}" - ) - return configs \ No newline at end of file diff --git a/app/tasks.py b/app/tasks.py deleted file mode 100644 index 42642726..00000000 --- a/app/tasks.py +++ /dev/null @@ -1,451 +0,0 @@ -import os -import asyncio -from typing import Any, Dict, List, Optional -import requests -from datetime import datetime, timezone -import time -import uuid -from math import ceil -import redis -import json - -from app.db import get_db -from app.models.document_model import Document -from app.models.knowledge_model import Knowledge -from app.core.rag.llm.cv_model import QWenCV -from app.core.rag.vdb.elasticsearch.elasticsearch_vector import ElasticSearchVectorFactory -from app.core.rag.models.chunk import DocumentChunk -from app.services.memory_agent_service import MemoryAgentService -from app.core.config import settings - -# Import a unified Celery instance -from app.celery_app import celery_app - - -@celery_app.task(name="tasks.process_item") -def process_item(item: dict): - """ - A simulated long-running task that processes an item. - In a real-world scenario, this could be anything: - - Sending an email - - Generating a report - - Performing a complex calculation - - Calling a third-party API - """ - print(f"Processing item: {item['name']}") - # Simulate work for 5 seconds - time.sleep(5) - result = f"Item '{item['name']}' processed successfully at a price of ${item['price']}." - print(result) - return result - - -@celery_app.task(name="app.core.rag.tasks.parse_document") -def parse_document(file_path: str, document_id: uuid.UUID): - """ - Document parsing, vectorization, and storage - """ - db = next(get_db()) # Manually call the generator - db_document = None - db_knowledge = None - progress_msg = f"{datetime.now().strftime('%H:%M:%S')} Task has been received.\n" - try: - db_document = db.query(Document).filter(Document.id == document_id).first() - db_knowledge = db.query(Knowledge).filter(Knowledge.id == db_document.kb_id).first() - # 1. Document parsing & segmentation - progress_msg += f"{datetime.now().strftime('%H:%M:%S')} Start to parse.\n" - start_time = time.time() - db_document.progress = 0.0 - db_document.progress_msg = progress_msg - db_document.process_begin_at = datetime.now(tz=timezone.utc) - db_document.process_duration = 0.0 - db_document.run = 1 - db.commit() - db.refresh(db_document) - - def progress_callback(prog=None, msg=None): - nonlocal progress_msg # Declare the use of an external progress_msg variable - progress_msg += f"{datetime.now().strftime('%H:%M:%S')} parse progress: {prog} msg: {msg}.\n" - # Prepare to configure vision_model information - vision_model = QWenCV( - key=db_knowledge.image2text.api_keys[0].api_key, - model_name=db_knowledge.image2text.api_keys[0].model_name, - lang="Chinese", - base_url=db_knowledge.image2text.api_keys[0].api_base - ) - from app.core.rag.app.naive import chunk - res = chunk(filename=file_path, - from_page=0, - to_page=100000, - callback=progress_callback, - vision_model=vision_model, - parser_config=db_document.parser_config, - is_root=False) - - progress_msg += f"{datetime.now().strftime('%H:%M:%S')} Finish parsing.\n" - db_document.progress = 0.8 - db_document.progress_msg = progress_msg - db.commit() - db.refresh(db_document) - - # 2. Document vectorization and storage - total_chunks = len(res) - progress_msg += f"{datetime.now().strftime('%H:%M:%S')} Generate {total_chunks} chunks.\n" - batch_size = 100 - total_batches = ceil(total_chunks / batch_size) - progress_per_batch = 0.2 / total_batches # Progress of each batch - vector_service = ElasticSearchVectorFactory().init_vector(knowledge=db_knowledge) - # 2.1 Delete document vector index - vector_service.delete_by_metadata_field(key="document_id", value=str(document_id)) - # 2.2 Vectorize and import batch documents - for batch_start in range(0, total_chunks, batch_size): - batch_end = min(batch_start + batch_size, total_chunks) # prevent out-of-bounds - batch = res[batch_start: batch_end] # Retrieve the current batch - chunks = [] - - # Process the current batch - for idx_in_batch, item in enumerate(batch): - global_idx = batch_start + idx_in_batch # Calculate global index - metadata = { - "doc_id": uuid.uuid4().hex, - "file_id": str(db_document.file_id), - "file_name": db_document.file_name, - "file_created_at": int(db_document.created_at.timestamp() * 1000), - "document_id": str(db_document.id), - "knowledge_id": str(db_document.kb_id), - "sort_id": global_idx, - "status": 1, - } - chunks.append(DocumentChunk(page_content=item["content_with_weight"], metadata=metadata)) - - # Bulk segmented vector import - vector_service.add_chunks(chunks) - - # Update progress - db_document.progress += progress_per_batch - progress_msg += f"{datetime.now().strftime('%H:%M:%S')} Embedding progress ({db_document.progress}).\n" - db_document.progress_msg = progress_msg - db_document.process_duration = time.time() - start_time - db_document.run = 0 - db.commit() - db.refresh(db_document) - - # Vectorization and data entry completed - progress_msg += f"{datetime.now().strftime('%H:%M:%S')} Indexing done.\n" - db_document.chunk_num = total_chunks - db_document.progress = 1.0 - db_document.process_duration = time.time() - start_time - progress_msg += f"{datetime.now().strftime('%H:%M:%S')} Task done ({db_document.process_duration}s).\n" - db_document.progress_msg = progress_msg - db_document.run = 0 - db.commit() - result = f"parse document '{db_document.file_name}' processed successfully." - return result - except Exception as e: - if 'db_document' in locals(): - db_document.progress_msg += f"Failed to vectorize and import the parsed document:{str(e)}\n" - db_document.run = 0 - db.commit() - result = f"parse document '{db_document.file_name}' failed." - return result - finally: - db.close() - - -@celery_app.task(name="app.core.memory.agent.read_message", bind=True) -def read_message_task(self, group_id: str, message: str, history: List[Dict[str, Any]], search_switch: str, config_id: str,storage_type:str,user_rag_memory_id:str) -> Dict[str, Any]: - - """Celery task to process a read message via MemoryAgentService. - - Args: - group_id: Group ID for the memory agent - message: User message to process - history: Conversation history - search_switch: Search switch parameter - config_id: Optional configuration ID - - Returns: - Dict containing the result and metadata - - Raises: - Exception on failure - """ - start_time = time.time() - - async def _run() -> str: - service = MemoryAgentService() - return await service.read_memory(group_id, message, history, search_switch, config_id,storage_type,user_rag_memory_id) - - try: - # 使用 nest_asyncio 来避免事件循环冲突 - try: - import nest_asyncio - nest_asyncio.apply() - except ImportError: - pass - - # 尝试获取现有事件循环,如果不存在则创建新的 - try: - loop = asyncio.get_event_loop() - if loop.is_closed(): - loop = asyncio.new_event_loop() - asyncio.set_event_loop(loop) - except RuntimeError: - loop = asyncio.new_event_loop() - asyncio.set_event_loop(loop) - - result = loop.run_until_complete(_run()) - elapsed_time = time.time() - start_time - - return { - "status": "SUCCESS", - "result": result, - "group_id": group_id, - "config_id": config_id, - "elapsed_time": elapsed_time, - "task_id": self.request.id - } - except Exception as e: - elapsed_time = time.time() - start_time - return { - "status": "FAILURE", - "error": str(e), - "group_id": group_id, - "config_id": config_id, - "elapsed_time": elapsed_time, - "task_id": self.request.id - } - - -@celery_app.task(name="app.core.memory.agent.write_message", bind=True) -def write_message_task(self, group_id: str, message: str, config_id: str,storage_type:str,user_rag_memory_id:str) -> Dict[str, Any]: - """Celery task to process a write message via MemoryAgentService. - - Args: - group_id: Group ID for the memory agent - message: Message to write - config_id: Optional configuration ID - - Returns: - Dict containing the result and metadata - - Raises: - Exception on failure - """ - start_time = time.time() - - async def _run() -> str: - service = MemoryAgentService() - return await service.write_memory(group_id, message, config_id,storage_type,user_rag_memory_id) - - try: - # 使用 nest_asyncio 来避免事件循环冲突 - try: - import nest_asyncio - nest_asyncio.apply() - except ImportError: - pass - - # 尝试获取现有事件循环,如果不存在则创建新的 - try: - loop = asyncio.get_event_loop() - if loop.is_closed(): - loop = asyncio.new_event_loop() - asyncio.set_event_loop(loop) - except RuntimeError: - loop = asyncio.new_event_loop() - asyncio.set_event_loop(loop) - - result = loop.run_until_complete(_run()) - elapsed_time = time.time() - start_time - - return { - "status": "SUCCESS", - "result": result, - "group_id": group_id, - "config_id": config_id, - "elapsed_time": elapsed_time, - "task_id": self.request.id - } - except Exception as e: - elapsed_time = time.time() - start_time - return { - "status": "FAILURE", - "error": str(e), - "group_id": group_id, - "config_id": config_id, - "elapsed_time": elapsed_time, - "task_id": self.request.id - } - - -def reflection_engine() -> None: - """Empty function placeholder for timed background reflection. - - Intentionally left blank; replace with real reflection logic later. - """ - from app.core.memory.utils.self_reflexion_utils.self_reflexion import self_reflexion - import asyncio - - host_id = uuid.UUID("2f6ff1eb-50c7-4765-8e89-e4566be19122") - asyncio.run(self_reflexion(host_id)) - - -@celery_app.task(name="app.core.memory.agent.reflection.timer") -def reflection_timer_task() -> None: - """Periodic Celery task that invokes reflection_engine. - - Raises an exception on failure. - """ - reflection_engine() - - -@celery_app.task(name="app.core.memory.agent.health.check_read_service") -def check_read_service_task() -> Dict[str, str]: - """Call read_service and write latest status to Redis. - - Returns status data dict that gets written to Redis. - """ - client = redis.Redis( - host=settings.REDIS_HOST, - port=settings.REDIS_PORT, - db=settings.REDIS_DB, - password=settings.REDIS_PASSWORD if settings.REDIS_PASSWORD else None - ) - try: - api_url = f"http://{settings.SERVER_IP}:8000/api/memory/read_service" - payload = { - "user_id": "健康检查", - "apply_id": "健康检查", - "group_id": "健康检查", - "message": "你好", - "history": [], - "search_switch": "2", - } - resp = requests.post(api_url, json=payload, timeout=15) - ok = resp.status_code == 200 - status = "Success" if ok else "Fail" - msg = "接口请求成功" if ok else f"接口请求失败: {resp.status_code}" - error = "" if ok else resp.text - code = 0 if ok else 500 - except Exception as e: - status = "Fail" - msg = "接口请求失败" - error = str(e) - code = 500 - - data = { - "status": status, - "msg": msg, - "error": error, - "code": str(code), - "time": str(int(time.time())), - } - - client.hset("memsci:health:read_service", mapping=data) - client.expire("memsci:health:read_service", int(settings.HEALTH_CHECK_SECONDS)) - - return data - - -@celery_app.task(name="app.controllers.memory_storage_controller.search_all") -def write_total_memory_task(workspace_id: str) -> Dict[str, Any]: - """定时任务:查询工作空间下所有宿主的记忆总量并写入数据库 - - Args: - workspace_id: 工作空间ID - - Returns: - 包含任务执行结果的字典 - """ - start_time = time.time() - - async def _run() -> Dict[str, Any]: - from app.services.memory_storage_service import search_all - from app.repositories.memory_increment_repository import write_memory_increment - from app.models.end_user_model import EndUser - from app.models.app_model import App - - db = next(get_db()) - try: - workspace_uuid = uuid.UUID(workspace_id) - - # 1. 查询当前workspace下的所有app - apps = db.query(App).filter(App.workspace_id == workspace_uuid).all() - - if not apps: - # 如果没有app,总量为0 - memory_increment = write_memory_increment( - db=db, - workspace_id=workspace_uuid, - total_num=0 - ) - return { - "status": "SUCCESS", - "workspace_id": workspace_id, - "total_num": 0, - "end_user_count": 0, - "memory_increment_id": str(memory_increment.id), - "created_at": memory_increment.created_at.isoformat(), - } - - # 2. 查询所有app下的end_user_id(去重) - app_ids = [app.id for app in apps] - end_users = db.query(EndUser.id).filter( - EndUser.app_id.in_(app_ids) - ).distinct().all() - - # 3. 遍历所有end_user,查询每个宿主的记忆总量并累加 - total_num = 0 - end_user_details = [] - - for (end_user_id,) in end_users: - try: - # 调用 search_all 接口查询该宿主的总量 - result = await search_all(str(end_user_id)) - user_total = result.get("total", 0) - total_num += user_total - end_user_details.append({ - "end_user_id": str(end_user_id), - "total": user_total - }) - except Exception as e: - # 记录单个用户查询失败,但继续处理其他用户 - end_user_details.append({ - "end_user_id": str(end_user_id), - "total": 0, - "error": str(e) - }) - - # 4. 写入数据库 - memory_increment = write_memory_increment( - db=db, - workspace_id=workspace_uuid, - total_num=total_num - ) - - return { - "status": "SUCCESS", - "workspace_id": workspace_id, - "total_num": total_num, - "end_user_count": len(end_users), - "end_user_details": end_user_details, - "memory_increment_id": str(memory_increment.id), - "created_at": memory_increment.created_at.isoformat(), - } - finally: - db.close() - - try: - result = asyncio.run(_run()) - elapsed_time = time.time() - start_time - result["elapsed_time"] = elapsed_time - return result - except Exception as e: - elapsed_time = time.time() - start_time - return { - "status": "FAILURE", - "error": str(e), - "workspace_id": workspace_id, - "elapsed_time": elapsed_time, - } \ No newline at end of file diff --git a/app/utils/volc_asr.py b/app/utils/volc_asr.py deleted file mode 100644 index f059df4f..00000000 --- a/app/utils/volc_asr.py +++ /dev/null @@ -1,112 +0,0 @@ -import requests -import json -import uuid -import os -import time -from datetime import datetime -from app.core.config import settings - - -# 火山的ASR -class VolcASR: - def __init__(self): - self.app_key = settings.VOLC_APP_KEY # 需要替换为实际的APP KEY - self.access_key = settings.VOLC_ACCESS_KEY # 需要替换为实际的Access Key - self.submit_url = settings.VOLC_SUBMIT_URL - self.query_url = settings.VOLC_QUERY_URL - - def get_headers(self): - request_id = str(uuid.uuid4()) - return { - "X-Api-App-Key": self.app_key, - "X-Api-Access-Key": self.access_key, - "X-Api-Resource-Id": "volc.bigasr.auc", - "X-Api-Request-Id": request_id, - "X-Api-Sequence": "-1", - "Content-Type": "application/json" - } - - def submit_task(self, audio_url): - headers = self.get_headers() - data = { - "audio": { - "url": audio_url - } - } - - response = requests.post(self.submit_url, headers=headers, json=data) - return response.headers.get("X-Api-Request-Id"), response.headers.get("X-Api-Status-Code") - - def query_result(self, task_id): - headers = self.get_headers() - headers["X-Api-Request-Id"] = task_id - - while True: - response = requests.post(self.query_url, headers=headers, json={}) - status_code = response.headers.get("X-Api-Status-Code") - - if status_code == "20000000": # 成功 - return response.json() - elif status_code in ["20000001", "20000002"]: # 处理中或在队列中 - time.sleep(1) - continue - elif status_code in ["20000003"]: # 静音音频 - raise Exception(f"静音音频: {status_code}") - elif status_code in ["45000001"]: # 请求参数无效 - raise Exception(f"请求参数无效: {status_code}") - elif status_code in ["45000002"]: # 空音频 - raise Exception(f"空音频: {status_code}") - elif status_code in ["45000151"]: # 音频格式不正确 - raise Exception(f"音频格式不正确: {status_code}") - elif status_code in ["55000031"]: # 服务器繁忙 - raise Exception(f"服务器繁忙: {status_code}") - else: - raise Exception(f"服务内部处理错误: {status_code}") - - -def main(): - # 音频URL - audio_url = "https://fosun-lcp-clickpaas.oss-cn-shanghai.aliyuncs.com/fosun-dify-files-images/test.mp3" - - # 输出目录 - output_dir = "/Users/sbtjfdn/Downloads" - os.makedirs(output_dir, exist_ok=True) - - # 初始化ASR客户端 - asr_client = VolcASR() - - try: - # 提交任务 - print("提交语音识别任务...") - task_id, status_code = asr_client.submit_task(audio_url) - if not task_id: - raise Exception("提交任务失败,未获取到任务ID") - - print(f"任务ID: {task_id}") - - # 查询结果 - print("等待识别结果...") - result = asr_client.query_result(task_id) - - # 保存结果 - timestamp = datetime.now().strftime("%Y%m%d_%H%M%S") - output_file = os.path.join(output_dir, f"result_{timestamp}.json") - - with open(output_file, "w", encoding="utf-8") as f: - json.dump(result, f, ensure_ascii=False, indent=2) - - print(f"识别结果已保存到: {output_file}") - - # 如果有识别文本,单独保存文本文件 - if "result" in result and "text" in result["result"]: - text_file = os.path.join(output_dir, f"text_{timestamp}.txt") - with open(text_file, "w", encoding="utf-8") as f: - f.write(result["result"]["text"]) - print(f"识别文本已保存到: {text_file}") - - except Exception as e: - print(f"发生错误: {str(e)}") - - -if __name__ == "__main__": - main() \ No newline at end of file diff --git a/docker-compose.yml b/docker-compose.yml deleted file mode 100644 index 74c69353..00000000 --- a/docker-compose.yml +++ /dev/null @@ -1,22 +0,0 @@ -version: '3.8' - -services: - api: - image: redbear-mem:latest - container_name: api - ports: - - "8000:8000" - env_file: - - .env - volumes: - - ./files:/files - command: uvicorn app.main:app --host 0.0.0.0 --port 8000 --reload --log-level debug - - worker: - image: redbear-mem:latest - container_name: worker - env_file: - - .env - volumes: - - ./files:/files - command: celery -A app.celery_worker.celery_app worker --loglevel=info \ No newline at end of file diff --git a/env.example b/env.example deleted file mode 100644 index f368d35d..00000000 --- a/env.example +++ /dev/null @@ -1,87 +0,0 @@ - -# Neo4j Configuration (记忆系统数据库) -NEO4J_URI= -NEO4J_USERNAME= -NEO4J_PASSWORD= - - -# Postgres Database configuration -DB_HOST= -DB_PORT= -DB_USER= -DB_PASSWORD= -DB_NAME= - -# Database Migration Configuration -# Set to true to automatically upgrade database schema on startup -DB_AUTO_UPGRADE=false - - - -# Redis configuration -REDIS_HOST= -REDIS_PORT= -REDIS_DB= -REDIS_PASSWORD=password - -#celery -BROKER_URL= -RESULT_BACKEND= -CELERY_BROKER= -CELERY_BACKEND= - -# ElasticSearch configuration -ELASTICSEARCH_HOST= -ELASTICSEARCH_PORT= -ELASTICSEARCH_USERNAME= -ELASTICSEARCH_PASSWORD= -ELASTICSEARCH_VERIFY_CERTS= -ELASTICSEARCH_CA_CERTS= -ELASTICSEARCH_REQUEST_TIMEOUT= -ELASTICSEARCH_RETRY_ON_TIMEOUT= -ELASTICSEARCH_MAX_RETRIES= - -# xinference configuration -XINFERENCE_URL= - -# LangSmith configuration -LANGCHAIN_TRACING_V2= -LANGCHAIN_TRACING= -LANGCHAIN_API_KEY= -LANGCHAIN_ENDPOINT= - -# This key is used for signing JWT tokens. -# It should be a long, random string and kept secret. -# Generate a new one with: openssl rand -hex 32 -SECRET_KEY=your-secret-key-here-generate-with-openssl-rand-hex-32 - -# JWT Token expiration settings -ACCESS_TOKEN_EXPIRE_MINUTES=30 -REFRESH_TOKEN_EXPIRE_DAYS=7 - -# Single Sign-On configuration -ENABLE_SINGLE_SESSION= - -# File Upload -MAX_FILE_SIZE=52428800 # 50MB:10 * 1024 * 1024 -FILE_PATH=/files - -# VOLC ASR -VOLC_APP_KEY= -VOLC_ACCESS_KEY= -VOLC_SUBMIT_URL= -VOLC_QUERY_URL= - -# Server Configuration -SERVER_IP=127.0.0.1 - - -web_search= -KB_embedding_id= -KB_reranker_id= -KB_llm_id= -KB_image2text_id= - -config_id= -reranker_id= - diff --git a/main.py b/main.py deleted file mode 100644 index 3c84c1a8..00000000 --- a/main.py +++ /dev/null @@ -1,6 +0,0 @@ -def main(): - print("Hello from redbear-mem!") - - -if __name__ == "__main__": - main() diff --git a/migrations/README b/migrations/README deleted file mode 100644 index 98e4f9c4..00000000 --- a/migrations/README +++ /dev/null @@ -1 +0,0 @@ -Generic single-database configuration. \ No newline at end of file diff --git a/migrations/env.py b/migrations/env.py deleted file mode 100644 index 95d74019..00000000 --- a/migrations/env.py +++ /dev/null @@ -1,141 +0,0 @@ -import os -import sys -import importlib -import inspect -from logging.config import fileConfig - -from dotenv import load_dotenv - -load_dotenv() # Moved to top - -from sqlalchemy import engine_from_config -from sqlalchemy import pool - -from alembic import context -# import app.models # <--- REMOVED THIS LINE - -# this is the Alembic Config object, which provides -# access to the values within the .ini file in use. -config = context.config - -# Set the database URL from the environment variables -DB_USER = os.getenv("DB_USER") -DB_PASSWORD = os.getenv("DB_PASSWORD") -DB_HOST = os.getenv("DB_HOST") -DB_PORT = os.getenv("DB_PORT") -DB_NAME = os.getenv("DB_NAME") -DB_URL = f"postgresql://{DB_USER}:{DB_PASSWORD}@{DB_HOST}:{DB_PORT}/{DB_NAME}" -config.set_main_option("sqlalchemy.url", DB_URL) - -# Interpret the config file for Python logging. -# This line sets up loggers basically. -if config.config_file_name is not None: - fileConfig(config.config_file_name) - -# add your model's MetaData object here -# for 'autogenerate' support -from app.db import Base # <--- Keep this import for Base - -target_metadata = Base.metadata -# target_metadata = None - -# <--- NEW FUNCTION START --- -def import_all_models_from_package(package_name: str): - """Dynamically imports all Python modules within a given package - to ensure SQLAlchemy models are registered with Base.metadata.""" - - # Add the project root to sys.path if not already there - # This is crucial for relative imports like 'app.db' to work - project_root = os.path.abspath(os.path.join(os.path.dirname(__file__), '..')) - if project_root not in sys.path: - sys.path.insert(0, project_root) - - try: - package = importlib.import_module(package_name) - package_dir = os.path.dirname(package.__file__) - except ImportError: - print(f"Warning: Could not import package {package_name}. Skipping model discovery.") - return - - for root, _, files in os.walk(package_dir): - for file in files: - if file.endswith(".py") and file != "__init__.py": - module_path = os.path.join(root, file) - # Calculate relative path from package_dir to module_path - rel_path = os.path.relpath(module_path, package_dir) - # Convert file path to module name (e.g., 'user_model.py' -> 'user_model') - module_name = os.path.splitext(rel_path)[0].replace(os.sep, '.') - full_module_name = f"{package_name}.{module_name}" - - try: - module = importlib.import_module(full_module_name) - # Optional: inspect module to ensure models inheriting from Base are loaded - # This step is mostly for verification; importing the module is usually enough - # for Base.metadata to pick up the models if they are defined correctly. - for name, obj in inspect.getmembers(module): - if inspect.isclass(obj) and issubclass(obj, Base) and obj != Base: - # Model found and registered with Base.metadata - pass - except Exception as e: - print(f"Warning: Could not import module {full_module_name}: {e}") - -# Call the function to import all models -import_all_models_from_package('app.models') # <--- NEW CALL -# <--- NEW FUNCTION END --- - -# other values from the config, defined by the needs of env.py, -# can be acquired: -# my_important_option = config.get_main_option("my_important_option") -# ... etc. - - -def run_migrations_offline() -> None: - """Run migrations in 'offline' mode. - - This configures the context with just a URL - and not an Engine, though an Engine is acceptable - here as well. By skipping the Engine creation - we don't even need a DBAPI to be available. - - Calls to context.execute() here emit the given string to the - script output. - - """ - url = config.get_main_option("sqlalchemy.url") - context.configure( - url=url, - target_metadata=target_metadata, - literal_binds=True, - dialect_opts={"paramstyle": "named"}, - ) - - with context.begin_transaction(): - context.run_migrations() - - -def run_migrations_online() -> None: - """Run migrations in 'online' mode. - - In this scenario we need to create an Engine - and associate a connection with the context. - - """ - connectable = engine_from_config( - config.get_section(config.config_ini_section, {}), - prefix="sqlalchemy.", - poolclass=pool.NullPool, - ) - - with connectable.connect() as connection: - context.configure( - connection=connection, target_metadata=target_metadata - ) - - with context.begin_transaction(): - context.run_migrations() - - -if context.is_offline_mode(): - run_migrations_offline() -else: - run_migrations_online() diff --git a/migrations/script.py.mako b/migrations/script.py.mako deleted file mode 100644 index fbc4b07d..00000000 --- a/migrations/script.py.mako +++ /dev/null @@ -1,26 +0,0 @@ -"""${message} - -Revision ID: ${up_revision} -Revises: ${down_revision | comma,n} -Create Date: ${create_date} - -""" -from typing import Sequence, Union - -from alembic import op -import sqlalchemy as sa -${imports if imports else ""} - -# revision identifiers, used by Alembic. -revision: str = ${repr(up_revision)} -down_revision: Union[str, None] = ${repr(down_revision)} -branch_labels: Union[str, Sequence[str], None] = ${repr(branch_labels)} -depends_on: Union[str, Sequence[str], None] = ${repr(depends_on)} - - -def upgrade() -> None: - ${upgrades if upgrades else "pass"} - - -def downgrade() -> None: - ${downgrades if downgrades else "pass"} diff --git a/pyproject.toml b/pyproject.toml deleted file mode 100644 index 4bb55bf5..00000000 --- a/pyproject.toml +++ /dev/null @@ -1,137 +0,0 @@ -[project] -name = "redbear-mem" -version = "0.1.0" -description = "Add your description here" -readme = "README.md" -requires-python = ">=3.12,<3.13" -dependencies = [ - "alembic==1.17.0", - "amqp==5.3.1", - "annotated-types==0.7.0", - "anyio==4.11.0", - "async-timeout==5.0.1", - "bcrypt==5.0.0", - "billiard==4.2.2", - "celery==5.5.3", - "cffi==2.0.0", - "click==8.3.0", - "click-didyoumean==0.3.1", - "click-plugins==1.1.1.2", - "click-repl==0.3.0", - "cryptography==46.0.3", - "ecdsa==0.19.1", - "email-validator>=2.3.0", - "exceptiongroup==1.3.0", - "fastapi==0.119.0", - "greenlet==3.2.4", - "h11==0.16.0", - "httptools==0.7.1", - "idna==3.11", - "kombu==5.5.4", - "mako==1.3.10", - "markupsafe==3.0.3", - "packaging==25.0", - "passlib==1.7.4", - "prompt-toolkit==3.0.52", - "psycopg2-binary==2.9.11", - "pyasn1==0.6.1", - "pycparser==2.23", - "pydantic==2.12.2", - "pydantic-core==2.41.4", - "python-dateutil==2.9.0.post0", - "python-dotenv==1.1.1", - "python-jose==3.5.0", - "python-multipart>=0.0.20", - "pyyaml==6.0.3", - "redis==6.4.0", - "rsa==4.9.1", - "six==1.17.0", - "sniffio==1.3.1", - "sqlalchemy==2.0.44", - "starlette==0.48.0", - "tomli==2.3.0", - "typing-extensions==4.15.0", - "typing-inspection==0.4.2", - "tzdata==2025.2", - "uvicorn==0.37.0", - "uvloop==0.22.1; sys_platform != 'win32'", - "vine==5.1.0", - "watchfiles==1.1.1", - "wcwidth==0.2.14", - "websockets==15.0.1", - "requests==2.32.5", - "elasticsearch==8.17.0", - "xinference-client==1.11.0", - "langchain-ollama", - "chardet==5.2.0", - "tiktoken==0.12.0", - "markdown==3.8", - "langchain>=1.0.3", - "langchain-openai>=1.0.2", - "langchain-community>=0.3.31", - "dashscope>=1.25.0", - "neo4j>=6.0.3", - "chonkie>=1.1.2", - "pandas>=2.3.3", - "jinja2>=3.1.6", - "mcp>=1.21.1", - "concurrent-log-handler>=0.9.28", - "langchain-mcp-adapters>=0.1.13", - "pytest>=9.0.1", - "matplotlib>=3.10.7", - "langfuse>=3.10.0", - "beartype==0.22.5", - "pdfplumber==0.11.7", - "olefile==0.47", - "cachetools==6.2.1", - "ruamel.yaml==0.18.10", - "strenum==0.4.15", - "aspose-slides==24.12.0", - "opencv-python==4.10.0.84", - "numpy>=1.26.0,<2.0.0", - "huggingface-hub==0.25.2", - "torch==2.2.2", - "onnxruntime==1.20.1", - "shapely==2.1.2", - "pyclipper==1.3.0.post6", - "trio==0.32.0", - "pillow==12.0.0", - "roman-numbers==1.0.2", - "word2number==1.1", - "cn2an==0.5.23", - "scikit-learn==1.7.2", - "datrie==0.8.3", - "hanziconv==0.3.2", - "nltk==3.9.2", - "python-pptx==1.0.2", - "xgboost==3.0.0", - "pypdf==6.1.3", - "beautifulsoup4==4.14.2", - "pandas==2.3.3", - "openpyxl==3.1.5", - "python-docx==1.2.0", - "demjson3==3.0.6", - "xpinyin==0.7.7", - "json-repair==0.53.0", - "jinja2==3.1.6", - "xxhash==3.6.0", - "tika==3.1.0", - "PyPDF2==3.0.1", - "mammoth==1.11.0", - "markdownify==1.2.0", - "flask==3.1.2", - "html5lib==1.1", - "jieba>=0.42.1", - "fastmcp>=2.13.1", - "pytest-asyncio>=1.3.0", - "uvicorn>=0.34.0", - "celery>=5.5.2", -] - -[tool.pytest.ini_options] -testpaths = ["tests"] -python_files = ["test_*.py"] -python_classes = ["Test*"] -python_functions = ["test_*"] -# 使用 anyio 作为异步测试后端 -anyio_backends = ["asyncio"]