Initial commit
This commit is contained in:
39
.gitignore
vendored
Normal file
39
.gitignore
vendored
Normal file
@@ -0,0 +1,39 @@
|
|||||||
|
# Python-generated files
|
||||||
|
__pycache__/
|
||||||
|
*.py[oc]
|
||||||
|
build/
|
||||||
|
dist/
|
||||||
|
wheels/
|
||||||
|
logs/
|
||||||
|
res/
|
||||||
|
*.egg-info
|
||||||
|
|
||||||
|
# Virtual environments
|
||||||
|
.venv
|
||||||
|
docs/
|
||||||
|
examples/
|
||||||
|
|
||||||
|
# Environment variables
|
||||||
|
.env
|
||||||
|
.kiro
|
||||||
|
.vscode/settings.json
|
||||||
|
.idea
|
||||||
|
|
||||||
|
# Temporary outputs
|
||||||
|
app/core/memory/agent/.DS_Store
|
||||||
|
app/core/memory/src/utils/.DS_Store
|
||||||
|
time.log
|
||||||
|
celerybeat-schedule.db
|
||||||
|
search_results.json
|
||||||
|
*.txt
|
||||||
|
*.json
|
||||||
|
migrations/versions
|
||||||
|
tmp
|
||||||
|
files
|
||||||
|
|
||||||
|
# Exclude dep files
|
||||||
|
huggingface.co/
|
||||||
|
nltk_data/
|
||||||
|
tika-server*.jar*
|
||||||
|
cl100k_base.tiktoken
|
||||||
|
libssl*.deb
|
||||||
97
Dockerfile
Normal file
97
Dockerfile
Normal file
@@ -0,0 +1,97 @@
|
|||||||
|
# Base image is pinned to the bookworm variant because the mirror entries
# written to /etc/apt/sources.list below hard-code the "bookworm" suite; a
# floating tag could move to a newer Debian release and mix suites.
FROM python:3.12-slim-bookworm
USER root
SHELL ["/bin/bash", "-c"]

# Set NEED_MIRROR=0 to use the upstream Debian / PyPI indexes instead of the
# Tsinghua mirrors.
ARG NEED_MIRROR=1

WORKDIR /code

# 1. Download dependencies through download_deps.py: python download_deps.py --china-mirrors
# 2. Copy models
COPY huggingface.co/InfiniFlow/deepdoc/ /code/res/deepdoc/
COPY huggingface.co/InfiniFlow/text_concat_xgb_v1.0/ /code/res/text_concat_xgb_v1.0/
COPY huggingface.co/InfiniFlow/huqie/huqie.txt.trie /code/res/

# https://github.com/chrismattmann/tika-python
# 3. This is the only way to run python-tika without internet access. Without this set, the default is to check the tika version and pull latest every time from Apache.
COPY nltk_data/ /root/nltk_data/
COPY tika-server-standard-3.1.0.jar /tmp/tika-server.jar
COPY tika-server-standard-3.1.0.jar.md5 /tmp/tika-server.jar.md5
COPY cl100k_base.tiktoken /code/res/9b5ad71b2ce5302211f9c61530b329a4922fc6a4

ENV TIKA_SERVER_JAR="file:///tmp/tika-server.jar"
ENV DEBIAN_FRONTEND=noninteractive

# 4. Setup apt
# Python package and implicit dependencies:
# opencv-python: libglib2.0-0 libglx-mesa0 libgl1
# aspose-slides: pkg-config libicu-dev libgdiplus libssl1.1_1.1.1f-1ubuntu2_amd64.deb
# python-pptx: default-jdk tika-server-standard-3.1.0.jar
# Building C extensions: libpython3-dev libgtk-4-1 libnss3 xdg-utils libgbm-dev
#
# libicu-dev is installed together with the other aspose-slides packages,
# after `apt update`. Installing it as the very first command fails on a slim
# base image (no package lists yet), and because that command was chained with
# `&&` in front of the mirror `if`, the failure also skipped the mirror setup.
RUN --mount=type=cache,id=mem_apt,target=/var/cache/apt,sharing=locked \
    if [ "$NEED_MIRROR" == "1" ]; then \
        rm -f /etc/apt/sources.list.d/debian.sources && \
        echo "deb https://mirrors.tuna.tsinghua.edu.cn/debian/ bookworm main contrib non-free non-free-firmware" > /etc/apt/sources.list && \
        echo "deb https://mirrors.tuna.tsinghua.edu.cn/debian/ bookworm-updates main contrib non-free non-free-firmware" >> /etc/apt/sources.list && \
        echo "deb https://mirrors.tuna.tsinghua.edu.cn/debian/ bookworm-backports main contrib non-free non-free-firmware" >> /etc/apt/sources.list && \
        echo "deb https://mirrors.tuna.tsinghua.edu.cn/debian-security bookworm-security main contrib non-free non-free-firmware" >> /etc/apt/sources.list; \
    fi; \
    rm -f /etc/apt/apt.conf.d/docker-clean && \
    echo 'Binary::apt::APT::Keep-Downloaded-Packages "true";' > /etc/apt/apt.conf.d/keep-cache && \
    chmod 1777 /tmp && \
    apt update && \
    apt --no-install-recommends install -y ca-certificates && \
    apt update && \
    apt install -y libglib2.0-0 libglx-mesa0 libgl1 && \
    apt install -y pkg-config libicu-dev libgdiplus && \
    apt install -y default-jdk && \
    apt install -y libpython3-dev libgtk-4-1 libnss3 xdg-utils libgbm-dev && \
    apt install -y libjemalloc-dev && \
    apt install -y python3-pip pipx nginx unzip curl wget git vim less && \
    apt install -y ghostscript

# Point pip and uv at the mirror when requested, then install uv via pipx.
RUN if [ "$NEED_MIRROR" == "1" ]; then \
        pip3 config set global.index-url https://pypi.tuna.tsinghua.edu.cn/simple && \
        pip3 config set global.trusted-host pypi.tuna.tsinghua.edu.cn; \
        mkdir -p /etc/uv && \
        echo "[[index]]" > /etc/uv/uv.toml && \
        echo 'url = "https://pypi.tuna.tsinghua.edu.cn/simple"' >> /etc/uv/uv.toml && \
        echo "default = true" >> /etc/uv/uv.toml; \
    fi; \
    pipx install uv

ENV PYTHONDONTWRITEBYTECODE=1 DOTNET_SYSTEM_GLOBALIZATION_INVARIANT=1
# pipx places the uv executable under /root/.local/bin
ENV PATH=/root/.local/bin:$PATH

# https://forum.aspose.com/t/aspose-slides-for-net-no-usable-version-of-libssl-found-with-linux-server/271344/13
# 5. aspose-slides on linux/arm64 is unavailable
COPY libssl1.1_1.1.1f-1ubuntu2_amd64.deb libssl1.1_1.1.1f-1ubuntu2_arm64.deb /tmp/
RUN if [ "$(uname -m)" = "x86_64" ]; then \
        dpkg -i /tmp/libssl1.1_1.1.1f-1ubuntu2_amd64.deb; \
    elif [ "$(uname -m)" = "aarch64" ]; then \
        dpkg -i /tmp/libssl1.1_1.1.1f-1ubuntu2_arm64.deb; \
    fi && \
    rm -f /tmp/libssl1.1_*.deb

# 6. install dependencies from uv.lock file
COPY ./pyproject.toml /code/pyproject.toml
COPY ./uv.lock /code/uv.lock
COPY ./app /code/app

# https://github.com/astral-sh/uv/issues/10462
# uv records index url into uv.lock but doesn't failover among multiple indexes
RUN --mount=type=cache,id=mem_uv,target=/root/.cache/uv,sharing=locked \
    if [ "$NEED_MIRROR" == "1" ]; then \
        sed -i 's|pypi.org|pypi.tuna.tsinghua.edu.cn|g' uv.lock; \
    else \
        sed -i 's|pypi.tuna.tsinghua.edu.cn|pypi.org|g' uv.lock; \
    fi; \
    uv lock && \
    uv sync --locked --no-dev

# Make the project virtualenv created by `uv sync` the default interpreter
ENV PATH=/code/.venv/bin:$PATH
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
201
LICENSE
Normal file
201
LICENSE
Normal file
@@ -0,0 +1,201 @@
|
|||||||
|
Apache License
|
||||||
|
Version 2.0, January 2004
|
||||||
|
http://www.apache.org/licenses/
|
||||||
|
|
||||||
|
TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
|
||||||
|
|
||||||
|
1. Definitions.
|
||||||
|
|
||||||
|
"License" shall mean the terms and conditions for use, reproduction,
|
||||||
|
and distribution as defined by Sections 1 through 9 of this document.
|
||||||
|
|
||||||
|
"Licensor" shall mean the copyright owner or entity authorized by
|
||||||
|
the copyright owner that is granting the License.
|
||||||
|
|
||||||
|
"Legal Entity" shall mean the union of the acting entity and all
|
||||||
|
other entities that control, are controlled by, or are under common
|
||||||
|
control with that entity. For the purposes of this definition,
|
||||||
|
"control" means (i) the power, direct or indirect, to cause the
|
||||||
|
direction or management of such entity, whether by contract or
|
||||||
|
otherwise, or (ii) ownership of fifty percent (50%) or more of the
|
||||||
|
outstanding shares, or (iii) beneficial ownership of such entity.
|
||||||
|
|
||||||
|
"You" (or "Your") shall mean an individual or Legal Entity
|
||||||
|
exercising permissions granted by this License.
|
||||||
|
|
||||||
|
"Source" form shall mean the preferred form for making modifications,
|
||||||
|
including but not limited to software source code, documentation
|
||||||
|
sources, and configuration files.
|
||||||
|
|
||||||
|
"Object" form shall mean any form resulting from mechanical
|
||||||
|
transformation or translation of a Source form, including but not
|
||||||
|
limited to compiled object code, generated documentation, and
|
||||||
|
conversions to other media types.
|
||||||
|
|
||||||
|
"Work" shall mean the work of authorship, whether in Source or
|
||||||
|
Object form, made available under the License, as indicated by a
|
||||||
|
copyright notice that is included in or attached to the work
|
||||||
|
(an example is provided in the Appendix below).
|
||||||
|
|
||||||
|
"Derivative Works" shall mean any work, whether in Source or Object
|
||||||
|
form, that is based on (or derived from) the Work and for which the
|
||||||
|
editorial revisions, annotations, elaborations, or other modifications
|
||||||
|
represent, as a whole, an original work of authorship. For the purposes
|
||||||
|
of this License, Derivative Works shall not include works that remain
|
||||||
|
separable from, or merely link (or bind by name) to the interfaces of,
|
||||||
|
the Work and Derivative Works thereof.
|
||||||
|
|
||||||
|
"Contribution" shall mean any work of authorship, including
|
||||||
|
the original version of the Work and any modifications or additions
|
||||||
|
to that Work or Derivative Works thereof, that is intentionally
|
||||||
|
submitted to Licensor for inclusion in the Work by the copyright owner
|
||||||
|
or by an individual or Legal Entity authorized to submit on behalf of
|
||||||
|
the copyright owner. For the purposes of this definition, "submitted"
|
||||||
|
means any form of electronic, verbal, or written communication sent
|
||||||
|
to the Licensor or its representatives, including but not limited to
|
||||||
|
communication on electronic mailing lists, source code control systems,
|
||||||
|
and issue tracking systems that are managed by, or on behalf of, the
|
||||||
|
Licensor for the purpose of discussing and improving the Work, but
|
||||||
|
excluding communication that is conspicuously marked or otherwise
|
||||||
|
designated in writing by the copyright owner as "Not a Contribution."
|
||||||
|
|
||||||
|
"Contributor" shall mean Licensor and any individual or Legal Entity
|
||||||
|
on behalf of whom a Contribution has been received by Licensor and
|
||||||
|
subsequently incorporated within the Work.
|
||||||
|
|
||||||
|
2. Grant of Copyright License. Subject to the terms and conditions of
|
||||||
|
this License, each Contributor hereby grants to You a perpetual,
|
||||||
|
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
|
||||||
|
copyright license to reproduce, prepare Derivative Works of,
|
||||||
|
publicly display, publicly perform, sublicense, and distribute the
|
||||||
|
Work and such Derivative Works in Source or Object form.
|
||||||
|
|
||||||
|
3. Grant of Patent License. Subject to the terms and conditions of
|
||||||
|
this License, each Contributor hereby grants to You a perpetual,
|
||||||
|
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
|
||||||
|
(except as stated in this section) patent license to make, have made,
|
||||||
|
use, offer to sell, sell, import, and otherwise transfer the Work,
|
||||||
|
where such license applies only to those patent claims licensable
|
||||||
|
by such Contributor that are necessarily infringed by their
|
||||||
|
Contribution(s) alone or by combination of their Contribution(s)
|
||||||
|
with the Work to which such Contribution(s) was submitted. If You
|
||||||
|
institute patent litigation against any entity (including a
|
||||||
|
cross-claim or counterclaim in a lawsuit) alleging that the Work
|
||||||
|
or a Contribution incorporated within the Work constitutes direct
|
||||||
|
or contributory patent infringement, then any patent licenses
|
||||||
|
granted to You under this License for that Work shall terminate
|
||||||
|
as of the date such litigation is filed.
|
||||||
|
|
||||||
|
4. Redistribution. You may reproduce and distribute copies of the
|
||||||
|
Work or Derivative Works thereof in any medium, with or without
|
||||||
|
modifications, and in Source or Object form, provided that You
|
||||||
|
meet the following conditions:
|
||||||
|
|
||||||
|
(a) You must give any other recipients of the Work or
|
||||||
|
Derivative Works a copy of this License; and
|
||||||
|
|
||||||
|
(b) You must cause any modified files to carry prominent notices
|
||||||
|
stating that You changed the files; and
|
||||||
|
|
||||||
|
(c) You must retain, in the Source form of any Derivative Works
|
||||||
|
that You distribute, all copyright, patent, trademark, and
|
||||||
|
attribution notices from the Source form of the Work,
|
||||||
|
excluding those notices that do not pertain to any part of
|
||||||
|
the Derivative Works; and
|
||||||
|
|
||||||
|
(d) If the Work includes a "NOTICE" text file as part of its
|
||||||
|
distribution, then any Derivative Works that You distribute must
|
||||||
|
include a readable copy of the attribution notices contained
|
||||||
|
within such NOTICE file, excluding those notices that do not
|
||||||
|
pertain to any part of the Derivative Works, in at least one
|
||||||
|
of the following places: within a NOTICE text file distributed
|
||||||
|
as part of the Derivative Works; within the Source form or
|
||||||
|
documentation, if provided along with the Derivative Works; or,
|
||||||
|
within a display generated by the Derivative Works, if and
|
||||||
|
wherever such third-party notices normally appear. The contents
|
||||||
|
of the NOTICE file are for informational purposes only and
|
||||||
|
do not modify the License. You may add Your own attribution
|
||||||
|
notices within Derivative Works that You distribute, alongside
|
||||||
|
or as an addendum to the NOTICE text from the Work, provided
|
||||||
|
that such additional attribution notices cannot be construed
|
||||||
|
as modifying the License.
|
||||||
|
|
||||||
|
You may add Your own copyright statement to Your modifications and
|
||||||
|
may provide additional or different license terms and conditions
|
||||||
|
for use, reproduction, or distribution of Your modifications, or
|
||||||
|
for any such Derivative Works as a whole, provided Your use,
|
||||||
|
reproduction, and distribution of the Work otherwise complies with
|
||||||
|
the conditions stated in this License.
|
||||||
|
|
||||||
|
5. Submission of Contributions. Unless You explicitly state otherwise,
|
||||||
|
any Contribution intentionally submitted for inclusion in the Work
|
||||||
|
by You to the Licensor shall be under the terms and conditions of
|
||||||
|
this License, without any additional terms or conditions.
|
||||||
|
Notwithstanding the above, nothing herein shall supersede or modify
|
||||||
|
the terms of any separate license agreement you may have executed
|
||||||
|
with Licensor regarding such Contributions.
|
||||||
|
|
||||||
|
6. Trademarks. This License does not grant permission to use the trade
|
||||||
|
names, trademarks, service marks, or product names of the Licensor,
|
||||||
|
except as required for reasonable and customary use in describing the
|
||||||
|
origin of the Work and reproducing the content of the NOTICE file.
|
||||||
|
|
||||||
|
7. Disclaimer of Warranty. Unless required by applicable law or
|
||||||
|
agreed to in writing, Licensor provides the Work (and each
|
||||||
|
Contributor provides its Contributions) on an "AS IS" BASIS,
|
||||||
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
|
||||||
|
implied, including, without limitation, any warranties or conditions
|
||||||
|
of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
|
||||||
|
PARTICULAR PURPOSE. You are solely responsible for determining the
|
||||||
|
appropriateness of using or redistributing the Work and assume any
|
||||||
|
risks associated with Your exercise of permissions under this License.
|
||||||
|
|
||||||
|
8. Limitation of Liability. In no event and under no legal theory,
|
||||||
|
whether in tort (including negligence), contract, or otherwise,
|
||||||
|
unless required by applicable law (such as deliberate and grossly
|
||||||
|
negligent acts) or agreed to in writing, shall any Contributor be
|
||||||
|
liable to You for damages, including any direct, indirect, special,
|
||||||
|
incidental, or consequential damages of any character arising as a
|
||||||
|
result of this License or out of the use or inability to use the
|
||||||
|
Work (including but not limited to damages for loss of goodwill,
|
||||||
|
work stoppage, computer failure or malfunction, or any and all
|
||||||
|
other commercial damages or losses), even if such Contributor
|
||||||
|
has been advised of the possibility of such damages.
|
||||||
|
|
||||||
|
9. Accepting Warranty or Additional Liability. While redistributing
|
||||||
|
the Work or Derivative Works thereof, You may choose to offer,
|
||||||
|
and charge a fee for, acceptance of support, warranty, indemnity,
|
||||||
|
or other liability obligations and/or rights consistent with this
|
||||||
|
License. However, in accepting such obligations, You may act only
|
||||||
|
on Your own behalf and on Your sole responsibility, not on behalf
|
||||||
|
of any other Contributor, and only if You agree to indemnify,
|
||||||
|
defend, and hold each Contributor harmless for any liability
|
||||||
|
incurred by, or claims asserted against, such Contributor by reason
|
||||||
|
of your accepting any such warranty or additional liability.
|
||||||
|
|
||||||
|
END OF TERMS AND CONDITIONS
|
||||||
|
|
||||||
|
APPENDIX: How to apply the Apache License to your work.
|
||||||
|
|
||||||
|
To apply the Apache License to your work, attach the following
|
||||||
|
boilerplate notice, with the fields enclosed by brackets "[]"
|
||||||
|
replaced with your own identifying information. (Don't include
|
||||||
|
the brackets!) The text should be enclosed in the appropriate
|
||||||
|
comment syntax for the file format. We also recommend that a
|
||||||
|
file or class name and description of purpose be included on the
|
||||||
|
same "printed page" as the copyright notice for easier
|
||||||
|
identification within third-party archives.
|
||||||
|
|
||||||
|
Copyright 2025 SuanmoSuanyangTechnology
|
||||||
|
|
||||||
|
Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
you may not use this file except in compliance with the License.
|
||||||
|
You may obtain a copy of the License at
|
||||||
|
|
||||||
|
http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
|
||||||
|
Unless required by applicable law or agreed to in writing, software
|
||||||
|
distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
See the License for the specific language governing permissions and
|
||||||
|
limitations under the License.
|
||||||
137
README.md
Normal file
137
README.md
Normal file
@@ -0,0 +1,137 @@
|
|||||||
|
# MemoryBear
|
||||||
|
|
||||||
|
## 项目简介
|
||||||
|
|
||||||
|
MemoryBear 是一个面向智能体的记忆系统与知识管理服务。它支持从对话与文档中萃取结构化知识、生成嵌入向量、构建图谱,提供关键词与语义的混合搜索,并内置遗忘机制与自我反思流程,以维持长期记忆的有效性与可用性。
|
||||||
|
|
||||||
|
## 核心特性
|
||||||
|
|
||||||
|
- 知识萃取:陈述句、三元组、时间信息与摘要生成
|
||||||
|
- 图谱存储:对接 Neo4j 管理实体与关系
|
||||||
|
- 混合搜索:关键词检索 + 语义向量检索
|
||||||
|
- 遗忘机制:按记忆强度与时效做逐步衰减
|
||||||
|
- 自我反思:定期回顾并优化已有记忆
|
||||||
|
- FastAPI 服务:统一暴露管理端与服务端 API
|
||||||
|
|
||||||
|
## 架构总览
|
||||||
|
|
||||||
|
- 萃取引擎(Extraction Engine):预处理、去重、结构化提取
|
||||||
|
- 遗忘引擎(Forgetting Engine):记忆强度模型与衰减策略
|
||||||
|
- 自我反思引擎(Reflection Engine):评价与重写记忆
|
||||||
|
- 检索服务:关键词、语义与混合检索
|
||||||
|
- Agent 与 MCP:提供多工具协作的智能体能力
|
||||||
|
|
||||||
|
## 快速开始
|
||||||
|
|
||||||
|
### 环境要求
|
||||||
|
|
||||||
|
- Python 3.12
|
||||||
|
- PostgreSQL 13+
|
||||||
|
- Neo4j 4.4+
|
||||||
|
- Redis 6.0+
|
||||||
|
|
||||||
|
### 安装依赖
|
||||||
|
|
||||||
|
```bash
|
||||||
|
python -m venv .venv
|
||||||
|
source .venv/bin/activate # Windows: .venv\Scripts\activate
|
||||||
|
|
||||||
|
# 方式一:基于 pyproject 安装
|
||||||
|
pip install .
|
||||||
|
|
||||||
|
# 方式二:使用 requirements.txt
|
||||||
|
pip install -r requirements.txt
|
||||||
|
```
|
||||||
|
|
||||||
|
### 配置环境变量
|
||||||
|
|
||||||
|
创建 `.env` 文件(示例):
|
||||||
|
|
||||||
|
```env
|
||||||
|
# Postgres
|
||||||
|
DB_HOST=127.0.0.1
|
||||||
|
DB_PORT=5432
|
||||||
|
DB_USER=postgres
|
||||||
|
DB_PASSWORD=your-password
|
||||||
|
DB_NAME=redbear-mem
|
||||||
|
DB_AUTO_UPGRADE=false
|
||||||
|
|
||||||
|
# Neo4j
|
||||||
|
NEO4J_URI=bolt://localhost:7687
|
||||||
|
NEO4J_USERNAME=neo4j
|
||||||
|
NEO4J_PASSWORD=your-password
|
||||||
|
|
||||||
|
# Redis
|
||||||
|
REDIS_HOST=127.0.0.1
|
||||||
|
REDIS_PORT=6379
|
||||||
|
REDIS_DB=1
|
||||||
|
|
||||||
|
# LLM / API Keys(按需)
|
||||||
|
OPENAI_API_KEY=your-openai-key
|
||||||
|
DASHSCOPE_API_KEY=your-dashscope-key
|
||||||
|
|
||||||
|
# 其他
|
||||||
|
WEB_URL=http://localhost:3000
|
||||||
|
LOG_LEVEL=INFO
|
||||||
|
```
|
||||||
|
|
||||||
|
### 初始化与启动
|
||||||
|
|
||||||
|
```bash
|
||||||
|
# 如需自动迁移数据库:设置 DB_AUTO_UPGRADE=true 或手动执行
|
||||||
|
alembic upgrade head
|
||||||
|
|
||||||
|
# 启动开发服务
|
||||||
|
uvicorn app.main:app --reload --port 8000
|
||||||
|
|
||||||
|
# 打开交互文档
|
||||||
|
# http://localhost:8000/docs
|
||||||
|
```
|
||||||
|
|
||||||
|
## 项目结构
|
||||||
|
|
||||||
|
```
|
||||||
|
app/
|
||||||
|
├── main.py # FastAPI 入口
|
||||||
|
├── controllers/ # 控制器与路由
|
||||||
|
├── core/ # 核心:配置、异常、日志等
|
||||||
|
│ └── memory/ # 记忆模块
|
||||||
|
│ ├── storage_services/ # 萃取/遗忘/反思/检索
|
||||||
|
│ ├── agent/ # Agent + MCP 服务
|
||||||
|
│ ├── utils/ # 工具与提示词
|
||||||
|
│ └── models/ # 领域模型
|
||||||
|
└── rag/ # RAG 能力与文档解析
|
||||||
|
|
||||||
|
logs/ # 日志与输出
|
||||||
|
LICENSE # 许可协议(Apache-2.0)
|
||||||
|
README.md # 项目说明
|
||||||
|
```
|
||||||
|
|
||||||
|
## API 与路由
|
||||||
|
|
||||||
|
- 管理端:`/api`(JWT 认证)
|
||||||
|
- 服务端:`/v1`(API Key 认证)
|
||||||
|
- 根路由健康检查:`GET /` 返回运行状态
|
||||||
|
- Swagger 文档:`/docs`
|
||||||
|
|
||||||
|
|
||||||
|
## 部署建议
|
||||||
|
|
||||||
|
- 使用 `gunicorn` + `uvicorn.workers.UvicornWorker` 作为生产入口
|
||||||
|
- 配置 `LOG_LEVEL=WARNING` 并启用文件日志
|
||||||
|
- 数据库与缓存请使用托管服务或独立实例
|
||||||
|
|
||||||
|
示例:
|
||||||
|
|
||||||
|
```bash
|
||||||
|
gunicorn app.main:app -w 4 -k uvicorn.workers.UvicornWorker
|
||||||
|
```
|
||||||
|
|
||||||
|
## 许可证
|
||||||
|
|
||||||
|
本项目采用 Apache License 2.0 开源协议,详情见 `LICENSE`。
|
||||||
|
|
||||||
|
## 致谢与交流
|
||||||
|
|
||||||
|
- 问题反馈与讨论:请提交 Issue 到代码仓库
|
||||||
|
- 欢迎贡献:提交 PR 前请先创建功能分支并遵循常规提交信息格式
|
||||||
116
alembic.ini
Normal file
116
alembic.ini
Normal file
@@ -0,0 +1,116 @@
|
|||||||
|
# A generic, single database configuration.
|
||||||
|
|
||||||
|
[alembic]
|
||||||
|
# path to migration scripts
|
||||||
|
script_location = migrations
|
||||||
|
|
||||||
|
# template used to generate migration file names; The default value is %%(rev)s_%%(slug)s
|
||||||
|
# Uncomment the line below if you want the files to be prepended with date and time
|
||||||
|
# see https://alembic.sqlalchemy.org/en/latest/tutorial.html#editing-the-ini-file
|
||||||
|
# for all available tokens
|
||||||
|
# file_template = %%(year)d_%%(month).2d_%%(day).2d_%%(hour).2d%%(minute).2d-%%(rev)s_%%(slug)s
|
||||||
|
|
||||||
|
# sys.path path, will be prepended to sys.path if present.
|
||||||
|
# defaults to the current working directory.
|
||||||
|
prepend_sys_path = .
|
||||||
|
|
||||||
|
# timezone to use when rendering the date within the migration file
|
||||||
|
# as well as the filename.
|
||||||
|
# If specified, requires the python>=3.9 or backports.zoneinfo library.
|
||||||
|
# Any required deps can be installed by adding `alembic[tz]` to the pip requirements
|
||||||
|
# string value is passed to ZoneInfo()
|
||||||
|
# leave blank for localtime
|
||||||
|
# timezone =
|
||||||
|
|
||||||
|
# max length of characters to apply to the
|
||||||
|
# "slug" field
|
||||||
|
# truncate_slug_length = 40
|
||||||
|
|
||||||
|
# set to 'true' to run the environment during
|
||||||
|
# the 'revision' command, regardless of autogenerate
|
||||||
|
# revision_environment = false
|
||||||
|
|
||||||
|
# set to 'true' to allow .pyc and .pyo files without
|
||||||
|
# a source .py file to be detected as revisions in the
|
||||||
|
# versions/ directory
|
||||||
|
# sourceless = false
|
||||||
|
|
||||||
|
# version location specification; This defaults
|
||||||
|
# to migrations/versions. When using multiple version
|
||||||
|
# directories, initial revisions must be specified with --version-path.
|
||||||
|
# The path separator used here should be the separator specified by "version_path_separator" below.
|
||||||
|
# version_locations = %(here)s/bar:%(here)s/bat:migrations/versions
|
||||||
|
|
||||||
|
# version path separator; As mentioned above, this is the character used to split
|
||||||
|
# version_locations. The default within new alembic.ini files is "os", which uses os.pathsep.
|
||||||
|
# If this key is omitted entirely, it falls back to the legacy behavior of splitting on spaces and/or commas.
|
||||||
|
# Valid values for version_path_separator are:
|
||||||
|
#
|
||||||
|
# version_path_separator = :
|
||||||
|
# version_path_separator = ;
|
||||||
|
# version_path_separator = space
|
||||||
|
version_path_separator = os # Use os.pathsep. Default configuration used for new projects.
|
||||||
|
|
||||||
|
# set to 'true' to search source files recursively
|
||||||
|
# in each "version_locations" directory
|
||||||
|
# new in Alembic version 1.10
|
||||||
|
# recursive_version_locations = false
|
||||||
|
|
||||||
|
# the output encoding used when revision files
|
||||||
|
# are written from script.py.mako
|
||||||
|
# output_encoding = utf-8
|
||||||
|
|
||||||
|
sqlalchemy.url = postgresql://user:password@localhost/dbname
|
||||||
|
|
||||||
|
|
||||||
|
[post_write_hooks]
|
||||||
|
# post_write_hooks defines scripts or Python functions that are run
|
||||||
|
# on newly generated revision scripts. See the documentation for further
|
||||||
|
# detail and examples
|
||||||
|
|
||||||
|
# format using "black" - use the console_scripts runner, against the "black" entrypoint
|
||||||
|
# hooks = black
|
||||||
|
# black.type = console_scripts
|
||||||
|
# black.entrypoint = black
|
||||||
|
# black.options = -l 79 REVISION_SCRIPT_FILENAME
|
||||||
|
|
||||||
|
# lint with attempts to fix using "ruff" - use the exec runner, execute a binary
|
||||||
|
# hooks = ruff
|
||||||
|
# ruff.type = exec
|
||||||
|
# ruff.executable = %(here)s/.venv/bin/ruff
|
||||||
|
# ruff.options = --fix REVISION_SCRIPT_FILENAME
|
||||||
|
|
||||||
|
# Logging configuration
|
||||||
|
[loggers]
|
||||||
|
keys = root,sqlalchemy,alembic
|
||||||
|
|
||||||
|
[handlers]
|
||||||
|
keys = console
|
||||||
|
|
||||||
|
[formatters]
|
||||||
|
keys = generic
|
||||||
|
|
||||||
|
[logger_root]
|
||||||
|
level = WARN
|
||||||
|
handlers = console
|
||||||
|
qualname =
|
||||||
|
|
||||||
|
[logger_sqlalchemy]
|
||||||
|
level = WARN
|
||||||
|
handlers =
|
||||||
|
qualname = sqlalchemy.engine
|
||||||
|
|
||||||
|
[logger_alembic]
|
||||||
|
level = INFO
|
||||||
|
handlers =
|
||||||
|
qualname = alembic
|
||||||
|
|
||||||
|
[handler_console]
|
||||||
|
class = StreamHandler
|
||||||
|
args = (sys.stderr,)
|
||||||
|
level = NOTSET
|
||||||
|
formatter = generic
|
||||||
|
|
||||||
|
[formatter_generic]
|
||||||
|
format = %(levelname)-5.5s [%(name)s] %(message)s
|
||||||
|
datefmt = %H:%M:%S
|
||||||
201
app/aioRedis.py
Normal file
201
app/aioRedis.py
Normal file
@@ -0,0 +1,201 @@
|
|||||||
|
import os
|
||||||
|
import asyncio
|
||||||
|
import json
|
||||||
|
import logging
|
||||||
|
from typing import Dict, Any, Optional
|
||||||
|
import redis.asyncio as redis
|
||||||
|
from redis.asyncio import ConnectionPool
|
||||||
|
from app.core.config import settings
|
||||||
|
|
||||||
|
# Module-level logger
logger = logging.getLogger(__name__)


# Connection pool shared by the Redis clients created in this module
pool = ConnectionPool.from_url(
    f"redis://{settings.REDIS_HOST}:{settings.REDIS_PORT}",
    max_connections=30,
    decode_responses=True,
    password=settings.REDIS_PASSWORD,
    db=settings.REDIS_DB,
)
# Default client used by the aio_redis_set / get / delete helpers
aio_redis = redis.StrictRedis(connection_pool=pool)
|
||||||
|
|
||||||
|
async def get_redis_connection():
    """Return a Redis client bound to the module-level connection pool.

    Returns:
        A new ``redis.StrictRedis`` client backed by ``pool``, or ``None``
        when constructing the client raises (the error is logged).
    """
    try:
        client = redis.StrictRedis(connection_pool=pool)
    except Exception as e:
        logger.error(f"Redis连接失败: {str(e)}")
        return None
    return client
|
||||||
|
|
||||||
|
async def aio_redis_set(key: str, val: str|dict, expire: int = None):
    """Store a value under a Redis key.

    Args:
        key: Redis key.
        val: Value to store; a dict is serialized to JSON first.
        expire: Time to live in seconds; ``None`` stores the key without
            an expiry.

    Errors are logged and swallowed; nothing is returned.
    """
    try:
        payload = json.dumps(val, ensure_ascii=False) if isinstance(val, dict) else val

        if expire is None:
            # No TTL requested: plain SET
            await aio_redis.set(key, payload)
        else:
            # SET with an expiry in seconds
            await aio_redis.set(key, payload, ex=expire)
    except Exception as e:
        logger.error(f"Redis set错误: {str(e)}")
|
||||||
|
|
||||||
|
async def aio_redis_get(key: str):
    """Fetch the value stored under a Redis key.

    Returns:
        Whatever ``aio_redis.get`` returns for ``key``, or ``None`` when the
        call raises (the error is logged).
    """
    try:
        value = await aio_redis.get(key)
    except Exception as e:
        logger.error(f"Redis get错误: {str(e)}")
        return None
    return value
|
||||||
|
|
||||||
|
async def aio_redis_delete(key: str):
    """Delete a Redis key.

    Returns:
        Whatever ``aio_redis.delete`` returns for ``key``, or ``None`` when
        the call raises (the error is logged).
    """
    try:
        removed = await aio_redis.delete(key)
    except Exception as e:
        logger.error(f"Redis delete错误: {str(e)}")
        return None
    return removed
|
||||||
|
|
||||||
|
async def aio_redis_publish(channel: str, message: Dict[str, Any]) -> bool:
    """Publish a JSON-encoded message to a Redis channel.

    Args:
        channel: Channel name to publish on.
        message: Payload; serialized with ``json.dumps`` before publishing.

    Returns:
        ``True`` when the publish call completed, ``False`` when no
        connection was available or any step raised (the error is logged).
    """
    try:
        conn = await get_redis_connection()
        if conn:
            encoded = json.dumps(message, ensure_ascii=False)
            await conn.publish(channel, encoded)
            return True
        return False
    except Exception as e:
        logger.error(f"Redis发布错误: {str(e)}")
        return False
|
||||||
|
|
||||||
|
class RedisSubscriber:
    """Subscribe to one Redis pub/sub channel and hand messages out via a queue.

    A background task polls the channel, JSON-decodes every string payload and
    puts the result on an internal ``asyncio.Queue``. When that task stops it
    queues ``None`` as an end-of-stream marker.
    """

    def __init__(self, channel: str):
        """Prepare a subscriber for ``channel``; nothing connects until start()."""
        self.channel = channel
        self.conn = None  # Redis client, set by the receive task
        self.pubsub = None  # PubSub handle, set by the receive task
        self.is_closed = False
        self._queue = asyncio.Queue()
        self._task = None  # background receive task

    async def start(self):
        """Launch the background receive task (no-op if closed or already started)."""
        if self.is_closed or self._task:
            return

        self._task = asyncio.create_task(self._receive_messages())
        logger.info(f"开始订阅: {self.channel}")

    async def _receive_messages(self):
        """Poll the channel and enqueue decoded messages until closed."""
        try:
            self.conn = await get_redis_connection()
            if not self.conn:
                return

            # Keep a local handle: _cleanup() resets self.pubsub to None.
            pubsub = self.pubsub = self.conn.pubsub()
            await pubsub.subscribe(self.channel)

            while not self.is_closed:
                try:
                    message = await pubsub.get_message(ignore_subscribe_messages=True, timeout=0.01)
                    if message and isinstance(message.get("data"), str):
                        try:
                            await self._queue.put(json.loads(message["data"]))
                        except json.JSONDecodeError:
                            logger.warning(f"消息解析失败: {message['data']}")
                    # Short pause between polls
                    await asyncio.sleep(0.01)
                except Exception as e:
                    # An error mentioning "closed" is treated as the signal to stop
                    if "closed" in str(e).lower():
                        break
                    logger.warning(f"接收消息错误: {str(e)}")
                    await asyncio.sleep(0.1)
        except Exception as e:
            logger.error(f"订阅错误: {str(e)}")
            await self._queue.put({"type": "error", "data": {"message": str(e), "status": "error"}})
        finally:
            # None marks end-of-stream for consumers waiting in get_message()
            await self._queue.put(None)
            await self._cleanup()

    async def _cleanup(self):
        """Release the pubsub handle and the client; safe to call repeatedly.

        Both close() and the receive task's ``finally`` end up here, so the
        handles are detached from ``self`` before the first await. A second
        call then finds ``None`` and does nothing instead of unsubscribing
        and closing the same objects again.
        """
        pubsub, self.pubsub = self.pubsub, None
        conn, self.conn = self.conn, None

        if pubsub:
            try:
                await pubsub.unsubscribe(self.channel)
                await pubsub.close()
            except Exception:
                # Best-effort teardown: record the failure, never raise
                logger.debug("Failed to close pubsub for channel %s", self.channel, exc_info=True)
        if conn:
            try:
                # NOTE(review): depending on the redis-py version, closing a
                # client built with an explicit connection_pool may also
                # disconnect that shared pool — confirm for the pinned version.
                await conn.close()
            except Exception:
                logger.debug("Failed to close Redis client for channel %s", self.channel, exc_info=True)

    async def get_message(self) -> Optional[Dict[str, Any]]:
        """Return the next queued message, starting the subscription if needed.

        Returns:
            The next item from the queue (a decoded message, an error dict,
            or the ``None`` end-of-stream marker). ``None`` is also returned
            when the subscriber is closed or the queue read raises.
        """
        if self.is_closed:
            return None
        if not self._task:
            await self.start()
        try:
            return await self._queue.get()
        except Exception as e:
            logger.error(f"获取消息错误: {str(e)}")
            return None

    async def close(self):
        """Mark the subscriber closed, cancel the receive task and release resources."""
        if self.is_closed:
            return
        self.is_closed = True
        if self._task:
            self._task.cancel()
        await self._cleanup()
|
||||||
|
|
||||||
|
class RedisPubSubManager:
    """Registry of per-channel subscribers plus a thin publish helper."""

    def __init__(self):
        """Create an empty registry."""
        self.subscribers = {}  # channel name -> RedisSubscriber
        # Strong references to in-flight close() tasks. The event loop keeps only
        # weak references, so an unreferenced task may be garbage-collected
        # before it has finished.
        self._close_tasks = set()

    async def publish(self, channel: str, message: Dict[str, Any]) -> bool:
        """Publish *message* on *channel*; returns the result of ``aio_redis_publish``."""
        return await aio_redis_publish(channel, message)

    def get_subscriber(self, channel: str) -> "RedisSubscriber":
        """Return the live subscriber for *channel*, creating a new one if needed.

        A registered subscriber that is already closed is replaced.
        """
        if channel in self.subscribers:
            subscriber = self.subscribers[channel]
            if not subscriber.is_closed:
                return subscriber

        subscriber = RedisSubscriber(channel)
        self.subscribers[channel] = subscriber
        return subscriber

    def _schedule_close(self, subscriber) -> None:
        """Start ``subscriber.close()`` as a task and keep it referenced until done."""
        task = asyncio.create_task(subscriber.close())
        self._close_tasks.add(task)
        task.add_done_callback(self._close_tasks.discard)

    def cancel_subscription(self, channel: str) -> bool:
        """Close and forget the subscriber of *channel*.

        Must be called while an event loop is running, because the close is
        scheduled as a task.

        Returns:
            ``True`` if a subscriber was registered for the channel, else ``False``.
        """
        if channel in self.subscribers:
            self._schedule_close(self.subscribers[channel])
            del self.subscribers[channel]
            return True
        return False

    def cancel_all_subscriptions(self) -> int:
        """Close and forget every subscriber; returns how many were registered."""
        count = len(self.subscribers)
        for subscriber in self.subscribers.values():
            self._schedule_close(subscriber)
        self.subscribers.clear()
        return count


# Global instance
pubsub_manager = RedisPubSubManager()
|
||||||
|
|
||||||
109
app/celery_app.py
Normal file
109
app/celery_app.py
Normal file
@@ -0,0 +1,109 @@
|
|||||||
|
import os
|
||||||
|
from datetime import timedelta
|
||||||
|
from urllib.parse import quote
|
||||||
|
from celery import Celery
|
||||||
|
from app.core.config import settings
|
||||||
|
|
||||||
|
# Create the Celery application instance.
# broker:  task queue   -> Redis database number settings.CELERY_BROKER
# backend: result store -> Redis database number settings.CELERY_BACKEND
# The password is percent-encoded with safe="" because quote() leaves "/"
# untouched by default, which would corrupt the URL for passwords containing it.
# NOTE(review): "or ''" avoids a TypeError when REDIS_PASSWORD is None — confirm
# that an empty password is the intended meaning in that case.
_redis_password = quote(settings.REDIS_PASSWORD or "", safe="")
_redis_base_url = f"redis://:{_redis_password}@{settings.REDIS_HOST}:{settings.REDIS_PORT}"

celery_app = Celery(
    "redbear_tasks",
    broker=f"{_redis_base_url}/{settings.CELERY_BROKER}",
    backend=f"{_redis_base_url}/{settings.CELERY_BACKEND}",
)

# Use a local queue so tasks do not collide with remote workers.
# NOTE(review): the queue name looks developer-specific — confirm before deploying.
celery_app.conf.task_default_queue = 'localhost_test_wyl'
celery_app.conf.task_default_exchange = 'localhost_test_wyl'
celery_app.conf.task_default_routing_key = 'localhost_test_wyl'

# macOS compatibility settings
import platform

if platform.system() == 'Darwin':  # macOS
    # Environment variable that works around the fork-safety problem.
    os.environ.setdefault('OBJC_DISABLE_INITIALIZE_FORK_SAFETY', 'YES')

    # The solo pool avoids multi-process issues.
    celery_app.conf.worker_pool = 'solo'

    # Give the node a unique name.
    # NOTE(review): presumably meant to name the worker node; verify that Celery
    # actually reads a ``worker_name`` setting.
    import socket
    import time

    hostname = socket.gethostname()
    timestamp = int(time.time())
    celery_app.conf.worker_name = f"celery@{hostname}-{timestamp}"

# Celery configuration
celery_app.conf.update(
    # Serialization
    task_serializer='json',
    accept_content=['json'],
    result_serializer='json',

    # Time zone
    timezone='Asia/Shanghai',
    enable_utc=True,

    # Task tracking
    task_track_started=True,
    task_ignore_result=False,

    # Time limits
    task_time_limit=30 * 60,  # 30-minute hard limit
    task_soft_time_limit=25 * 60,  # 25-minute soft limit

    # Worker settings
    worker_prefetch_multiplier=1,  # prefetch less to avoid piling tasks up in memory
    worker_max_tasks_per_child=10,  # recycle workers often to contain memory leaks
    worker_max_memory_per_child=200000,  # memory cap (about 200 MB per the original comment) before a worker is replaced

    # Result expiry
    result_expires=3600,  # keep results for one hour

    # Acknowledgement
    task_acks_late=True,  # acknowledge only after completion so tasks are not lost
    worker_disable_rate_limits=True,  # disable rate limits

    # Task routing (optional, for dedicated queues)
    # task_routes={
    #     'app.core.rag.tasks.parse_document': {'queue': 'document_processing'},
    #     'app.core.memory.agent.read_message': {'queue': 'memory_processing'},
    #     'app.core.memory.agent.write_message': {'queue': 'memory_processing'},
    #     'tasks.process_item': {'queue': 'default'},
    # },
)

# Auto-discover task modules
celery_app.autodiscover_tasks(['app'])

# Celery Beat schedule for periodic tasks
reflection_schedule = timedelta(seconds=settings.REFLECTION_INTERVAL_SECONDS)
health_schedule = timedelta(seconds=settings.HEALTH_CHECK_SECONDS)
memory_increment_schedule = timedelta(hours=settings.MEMORY_INCREMENT_INTERVAL_HOURS)

# Build the periodic-task configuration
beat_schedule_config = {
    "run-reflection-engine": {
        "task": "app.core.memory.agent.reflection.timer",
        "schedule": reflection_schedule,
        "args": (),
    },
    "check-read-service": {
        "task": "app.core.memory.agent.health.check_read_service",
        "schedule": health_schedule,
        "args": (),
    },
}

# Add the total-memory statistics task only when a default workspace ID is configured
if settings.DEFAULT_WORKSPACE_ID:
    beat_schedule_config["write-total-memory"] = {
        "task": "app.controllers.memory_storage_controller.search_all",
        "schedule": memory_increment_schedule,
        "kwargs": {
            "workspace_id": settings.DEFAULT_WORKSPACE_ID,
        },
    }

celery_app.conf.beat_schedule = beat_schedule_config
|
||||||
10
app/celery_worker.py
Normal file
10
app/celery_worker.py
Normal file
@@ -0,0 +1,10 @@
|
|||||||
|
"""
|
||||||
|
Celery Worker 入口点
|
||||||
|
用于启动 Celery Worker: celery -A app.celery_worker worker --loglevel=info
|
||||||
|
"""
|
||||||
|
from app.celery_app import celery_app
|
||||||
|
|
||||||
|
# 导入任务模块以注册任务
|
||||||
|
import app.tasks
|
||||||
|
|
||||||
|
__all__ = ['celery_app']
|
||||||
60
app/controllers/__init__.py
Normal file
60
app/controllers/__init__.py
Normal file
@@ -0,0 +1,60 @@
|
|||||||
|
"""管理端接口 - 基于 JWT 认证
|
||||||
|
|
||||||
|
路由前缀: /
|
||||||
|
认证方式: JWT Token
|
||||||
|
"""
|
||||||
|
from fastapi import APIRouter
|
||||||
|
from . import (
|
||||||
|
model_controller,
|
||||||
|
task_controller,
|
||||||
|
test_controller,
|
||||||
|
user_controller,
|
||||||
|
auth_controller,
|
||||||
|
workspace_controller,
|
||||||
|
setup_controller,
|
||||||
|
file_controller,
|
||||||
|
document_controller,
|
||||||
|
knowledge_controller,
|
||||||
|
chunk_controller,
|
||||||
|
knowledgeshare_controller,
|
||||||
|
app_controller,
|
||||||
|
upload_controller,
|
||||||
|
memory_agent_controller,
|
||||||
|
memory_dashboard_controller,
|
||||||
|
memory_storage_controller,
|
||||||
|
memory_dashboard_controller,
|
||||||
|
api_key_controller,
|
||||||
|
release_share_controller,
|
||||||
|
public_share_controller,
|
||||||
|
multi_agent_controller,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Create the management API router
manager_router = APIRouter()

# Register every management sub-router exactly once.
# (memory_dashboard_controller.router used to be included twice, which added
# each of its routes to the router a second time.)
manager_router.include_router(task_controller.router)
manager_router.include_router(user_controller.router)
manager_router.include_router(auth_controller.router)
manager_router.include_router(workspace_controller.router)
manager_router.include_router(workspace_controller.public_router)  # public routes (no authentication)
manager_router.include_router(setup_controller.router)
manager_router.include_router(model_controller.router)
manager_router.include_router(file_controller.router)
manager_router.include_router(document_controller.router)
manager_router.include_router(knowledge_controller.router)
manager_router.include_router(chunk_controller.router)
manager_router.include_router(test_controller.router)
manager_router.include_router(knowledgeshare_controller.router)
manager_router.include_router(app_controller.router)
manager_router.include_router(upload_controller.router)
manager_router.include_router(memory_agent_controller.router)
manager_router.include_router(memory_dashboard_controller.router)
manager_router.include_router(memory_storage_controller.router)
manager_router.include_router(api_key_controller.router)
manager_router.include_router(release_share_controller.router)
manager_router.include_router(public_share_controller.router)  # public routes (no authentication)
manager_router.include_router(multi_agent_controller.router)

__all__ = ["manager_router"]
|
||||||
151
app/controllers/api_key_controller.py
Normal file
151
app/controllers/api_key_controller.py
Normal file
@@ -0,0 +1,151 @@
|
|||||||
|
"""API Key 管理接口 - 基于 JWT 认证"""
|
||||||
|
from fastapi import APIRouter, Depends, Query
|
||||||
|
from sqlalchemy.orm import Session
|
||||||
|
import uuid
|
||||||
|
|
||||||
|
from app.db import get_db
|
||||||
|
from app.dependencies import get_current_user, cur_workspace_access_guard
|
||||||
|
from app.models.user_model import User
|
||||||
|
from app.core.response_utils import success
|
||||||
|
from app.schemas import api_key_schema
|
||||||
|
from app.schemas.response_schema import ApiResponse
|
||||||
|
from app.services.api_key_service import ApiKeyService
|
||||||
|
from app.core.logging_config import get_business_logger
|
||||||
|
|
||||||
|
router = APIRouter(prefix="/apikeys", tags=["API Keys"])
logger = get_business_logger()


@router.post("", response_model=ApiResponse)
@cur_workspace_access_guard()
def create_api_key(
    data: api_key_schema.ApiKeyCreate,
    db: Session = Depends(get_db),
    current_user: User = Depends(get_current_user)
):
    """Create an API key in the caller's current workspace.

    - The plaintext key is returned by this call only.
    - Permission scope, rate limits, quotas, etc. can be set.
    """
    ws_id = current_user.current_workspace_id
    created, plaintext_key = ApiKeyService.create_api_key(
        db,
        workspace_id=ws_id,
        user_id=current_user.id,
        data=data
    )

    # Response that carries the plaintext key.
    # NOTE(review): presumably ``created`` is an ORM instance; if so its
    # __dict__ may hold internal state and miss expired attributes — verify.
    fields = created.__dict__
    body = api_key_schema.ApiKeyResponse(**fields, api_key=plaintext_key)
    return success(data=body, msg="API Key 创建成功")
|
||||||
|
|
||||||
|
|
||||||
|
@router.get("", response_model=ApiResponse)
@cur_workspace_access_guard()
def list_api_keys(
    type: api_key_schema.ApiKeyType = Query(None),
    is_active: bool = Query(None),
    resource_id: uuid.UUID = Query(None),
    page: int = Query(1, ge=1),
    pagesize: int = Query(10, ge=1, le=100),
    db: Session = Depends(get_db),
    current_user: User = Depends(get_current_user)
):
    """List the API keys of the caller's current workspace."""
    ws_id = current_user.current_workspace_id
    # Bundle the filter and paging parameters into one query object.
    filters = api_key_schema.ApiKeyQuery(
        type=type,
        is_active=is_active,
        resource_id=resource_id,
        page=page,
        pagesize=pagesize
    )
    return success(data=ApiKeyService.list_api_keys(db, ws_id, filters))
|
||||||
|
|
||||||
|
|
||||||
|
@router.get("/{api_key_id}", response_model=ApiResponse)
@cur_workspace_access_guard()
def get_api_key(
    api_key_id: uuid.UUID,
    db: Session = Depends(get_db),
    current_user: User = Depends(get_current_user)
):
    """Return the details of one API key from the caller's current workspace."""
    ws_id = current_user.current_workspace_id
    record = ApiKeyService.get_api_key(db, api_key_id, ws_id)
    details = api_key_schema.ApiKey.model_validate(record)
    return success(data=details)
|
||||||
|
|
||||||
|
|
||||||
|
@router.put("/{api_key_id}", response_model=ApiResponse)
@cur_workspace_access_guard()
def update_api_key(
    api_key_id: uuid.UUID,
    data: api_key_schema.ApiKeyUpdate,
    db: Session = Depends(get_db),
    current_user: User = Depends(get_current_user)
):
    """Update an API key of the caller's current workspace."""
    ws_id = current_user.current_workspace_id
    updated = ApiKeyService.update_api_key(db, api_key_id, ws_id, data)
    details = api_key_schema.ApiKey.model_validate(updated)
    return success(data=details, msg="API Key 更新成功")
|
||||||
|
|
||||||
|
|
||||||
|
@router.delete("/{api_key_id}", response_model=ApiResponse)
@cur_workspace_access_guard()
def delete_api_key(
    api_key_id: uuid.UUID,
    db: Session = Depends(get_db),
    current_user: User = Depends(get_current_user)
):
    """Delete an API key of the caller's current workspace."""
    ws_id = current_user.current_workspace_id
    ApiKeyService.delete_api_key(db, api_key_id, ws_id)
    return success(msg="API Key 删除成功")
|
||||||
|
|
||||||
|
|
||||||
|
@router.post("/{api_key_id}/regenerate", response_model=ApiResponse)
@cur_workspace_access_guard()
def regenerate_api_key(
    api_key_id: uuid.UUID,
    db: Session = Depends(get_db),
    current_user: User = Depends(get_current_user)
):
    """Regenerate an API key.

    - A new key is generated and its plaintext is returned by this call only.
    - The old key becomes invalid immediately.
    """
    ws_id = current_user.current_workspace_id
    regenerated, plaintext_key = ApiKeyService.regenerate_api_key(db, api_key_id, ws_id)

    # Response that carries the plaintext key.
    # NOTE(review): presumably ``regenerated`` is an ORM instance; if so its
    # __dict__ may hold internal state and miss expired attributes — verify.
    fields = regenerated.__dict__
    body = api_key_schema.ApiKeyResponse(**fields, api_key=plaintext_key)
    return success(data=body, msg="API Key 重新生成成功")
|
||||||
|
|
||||||
|
|
||||||
|
@router.get("/{api_key_id}/stats", response_model=ApiResponse)
@cur_workspace_access_guard()
def get_api_key_stats(
    api_key_id: uuid.UUID,
    db: Session = Depends(get_db),
    current_user: User = Depends(get_current_user)
):
    """Return usage statistics for one API key."""
    ws_id = current_user.current_workspace_id
    usage = ApiKeyService.get_stats(db, api_key_id, ws_id)
    return success(data=usage)
|
||||||
716
app/controllers/app_controller.py
Normal file
716
app/controllers/app_controller.py
Normal file
@@ -0,0 +1,716 @@
|
|||||||
|
import uuid
|
||||||
|
from typing import Optional
|
||||||
|
from fastapi import APIRouter, Depends
|
||||||
|
from sqlalchemy.orm import Session
|
||||||
|
|
||||||
|
from app.db import get_db
|
||||||
|
from app.core.response_utils import success
|
||||||
|
from app.core.logging_config import get_business_logger
|
||||||
|
from app.models import User
|
||||||
|
from app.repositories import knowledge_repository
|
||||||
|
from app.schemas import app_schema
|
||||||
|
from app.schemas.response_schema import PageData, PageMeta
|
||||||
|
from app.services import app_service, workspace_service
|
||||||
|
from app.services.app_service import AppService
|
||||||
|
from app.services.agent_config_helper import enrich_agent_config
|
||||||
|
from app.dependencies import get_current_user, cur_workspace_access_guard, workspace_access_guard
|
||||||
|
from fastapi.responses import StreamingResponse
|
||||||
|
from app.models.app_model import AppType
|
||||||
|
from app.core.error_codes import BizCode
|
||||||
|
|
||||||
|
router = APIRouter(prefix="/apps", tags=["Apps"])
logger = get_business_logger()


# Creates an app in the caller's current workspace and returns it serialized
# through the App schema.
@router.post("", summary="创建应用(可选创建 Agent 配置)")
@cur_workspace_access_guard()
def create_app(
    payload: app_schema.AppCreate,
    db: Session = Depends(get_db),
    current_user=Depends(get_current_user),
):
    ws_id = current_user.current_workspace_id
    created = app_service.create_app(
        db,
        user_id=current_user.id,
        workspace_id=ws_id,
        data=payload,
    )
    serialized = app_schema.App.model_validate(created)
    return success(data=serialized)
|
||||||
|
|
||||||
|
|
||||||
|
@router.get("", summary="应用列表(分页)")
@cur_workspace_access_guard()
def list_apps(
    type: str | None = None,
    visibility: str | None = None,
    status: str | None = None,
    search: str | None = None,
    include_shared: bool = True,
    page: int = 1,
    pagesize: int = 10,
    db: Session = Depends(get_db),
    current_user=Depends(get_current_user),
):
    """List apps.

    - By default includes apps of this workspace and apps shared with it.
    - Pass include_shared=false to see only this workspace's own apps.
    """
    ws_id = current_user.current_workspace_id
    rows, total = app_service.list_apps(
        db,
        workspace_id=ws_id,
        type=type,
        visibility=visibility,
        status=status,
        search=search,
        include_shared=include_shared,
        page=page,
        pagesize=pagesize,
    )

    # AppService's converter fills in the is_shared field.
    converter = app_service.AppService(db)
    items = []
    for row in rows:
        items.append(converter._convert_to_schema(row, ws_id))

    # There is a next page while the rows covered so far are fewer than the total.
    has_next = (page * pagesize) < total
    meta = PageMeta(page=page, pagesize=pagesize, total=total, hasnext=has_next)
    return success(data=PageData(page=meta, items=items))
|
||||||
|
|
||||||
|
@router.get("/{app_id}", summary="获取应用详情")
@cur_workspace_access_guard()
def get_app(
    app_id: uuid.UUID,
    db: Session = Depends(get_db),
    current_user=Depends(get_current_user),
):
    """Return the details of an app.

    - Works for apps of this workspace.
    - Works for apps shared with this workspace.
    """
    ws_id = current_user.current_workspace_id
    svc = app_service.AppService(db)
    found = svc.get_app(app_id, ws_id)
    # Convert to the schema, which also sets the is_shared field.
    return success(data=svc._convert_to_schema(found, ws_id))
|
||||||
|
|
||||||
|
|
||||||
|
# Updates an app's basic information and returns it serialized through the
# App schema.
@router.put("/{app_id}", summary="更新应用基本信息")
@cur_workspace_access_guard()
def update_app(
    app_id: uuid.UUID,
    payload: app_schema.AppUpdate,
    db: Session = Depends(get_db),
    current_user=Depends(get_current_user),
):
    ws_id = current_user.current_workspace_id
    updated = app_service.update_app(
        db,
        app_id=app_id,
        data=payload,
        workspace_id=ws_id,
    )
    serialized = app_schema.App.model_validate(updated)
    return success(data=serialized)
|
||||||
|
|
||||||
|
|
||||||
|
@router.delete("/{app_id}", summary="删除应用")
@cur_workspace_access_guard()
def delete_app(
    app_id: uuid.UUID,
    db: Session = Depends(get_db),
    current_user=Depends(get_current_user),
):
    """Delete an app.

    Cascades to:
    - the Agent configuration
    - published releases
    - conversations and messages
    """
    ws_id = current_user.current_workspace_id

    # Record who asked for the deletion before performing it.
    audit_fields = {
        "app_id": str(app_id),
        "user_id": str(current_user.id),
        "workspace_id": str(ws_id)
    }
    logger.info("用户请求删除应用", extra=audit_fields)

    app_service.delete_app(db, app_id=app_id, workspace_id=ws_id)
    return success(msg="应用删除成功")
|
||||||
|
|
||||||
|
|
||||||
|
@router.post("/{app_id}/copy", summary="复制应用")
@cur_workspace_access_guard()
def copy_app(
    app_id: uuid.UUID,
    new_name: Optional[str] = None,
    db: Session = Depends(get_db),
    current_user=Depends(get_current_user),
):
    """Copy an app, including its basic information and configuration.

    - Copies the basic information (name, description, icon, ...).
    - Copies the Agent configuration (for agent-type apps).
    - The new app starts in draft status.
    - The source app is not affected.
    """
    ws_id = current_user.current_workspace_id

    # Record the copy request before performing it.
    audit_fields = {
        "source_app_id": str(app_id),
        "user_id": str(current_user.id),
        "workspace_id": str(ws_id),
        "new_name": new_name
    }
    logger.info("用户请求复制应用", extra=audit_fields)

    duplicate = AppService(db).copy_app(
        app_id=app_id,
        user_id=current_user.id,
        workspace_id=ws_id,
        new_name=new_name
    )
    serialized = app_schema.App.model_validate(duplicate)
    return success(data=serialized, msg="应用复制成功")
|
||||||
|
|
||||||
|
|
||||||
|
# Updates the Agent configuration of an app, passes the result through
# enrich_agent_config and returns it via the AgentConfig schema.
@router.put("/{app_id}/config", summary="更新 Agent 配置")
@cur_workspace_access_guard()
def update_agent_config(
    app_id: uuid.UUID,
    payload: app_schema.AgentConfigUpdate,
    db: Session = Depends(get_db),
    current_user=Depends(get_current_user),
):
    ws_id = current_user.current_workspace_id
    saved = app_service.update_agent_config(
        db,
        app_id=app_id,
        data=payload,
        workspace_id=ws_id,
    )
    enriched = enrich_agent_config(saved)
    return success(data=app_schema.AgentConfig.model_validate(enriched))
|
||||||
|
|
||||||
|
|
||||||
|
# Fetches the Agent configuration of an app, passes it through
# enrich_agent_config and returns it via the AgentConfig schema.
@router.get("/{app_id}/config", summary="获取 Agent 配置")
@cur_workspace_access_guard()
def get_agent_config(
    app_id: uuid.UUID,
    db: Session = Depends(get_db),
    current_user=Depends(get_current_user),
):
    ws_id = current_user.current_workspace_id
    # Per the original note: a configuration always exists (a default template
    # is returned when none is stored).
    stored = app_service.get_agent_config(db, app_id=app_id, workspace_id=ws_id)
    enriched = enrich_agent_config(stored)
    return success(data=app_schema.AgentConfig.model_validate(enriched))
|
||||||
|
|
||||||
|
|
||||||
|
# Publishes an app through app_service.publish and returns the resulting
# release via the AppRelease schema.
@router.post("/{app_id}/publish", summary="发布应用(生成不可变快照)")
@cur_workspace_access_guard()
def publish_app(
    app_id: uuid.UUID,
    payload: app_schema.PublishRequest,
    db: Session = Depends(get_db),
    current_user=Depends(get_current_user),
):
    ws_id = current_user.current_workspace_id
    published = app_service.publish(
        db,
        app_id=app_id,
        publisher_id=current_user.id,
        workspace_id=ws_id,
        version_name=payload.version_name,
        release_notes=payload.release_notes,
    )
    serialized = app_schema.AppRelease.model_validate(published)
    return success(data=serialized)
|
||||||
|
|
||||||
|
|
||||||
|
# Returns the app's current release via the AppRelease schema, or data=None
# when the service reports no release.
@router.get("/{app_id}/release", summary="获取当前发布版本")
@cur_workspace_access_guard()
def get_current_release(
    app_id: uuid.UUID,
    db: Session = Depends(get_db),
    current_user=Depends(get_current_user),
):
    ws_id = current_user.current_workspace_id
    current = app_service.get_current_release(db, app_id=app_id, workspace_id=ws_id)
    serialized = app_schema.AppRelease.model_validate(current) if current else None
    return success(data=serialized)
|
||||||
|
|
||||||
|
|
||||||
|
# Lists the app's releases, each serialized through the AppRelease schema, in
# the order the service returns them.
@router.get("/{app_id}/releases", summary="列出历史发布版本(倒序)")
@cur_workspace_access_guard()
def list_releases(
    app_id: uuid.UUID,
    db: Session = Depends(get_db),
    current_user=Depends(get_current_user),
):
    ws_id = current_user.current_workspace_id
    history = app_service.list_releases(db, app_id=app_id, workspace_id=ws_id)
    serialized = []
    for release in history:
        serialized.append(app_schema.AppRelease.model_validate(release))
    return success(data=serialized)
|
||||||
|
|
||||||
|
|
||||||
|
# Rolls the app back to the given version through app_service.rollback and
# returns the resulting release via the AppRelease schema.
@router.post("/{app_id}/rollback/{version}", summary="回滚到指定版本")
@cur_workspace_access_guard()
def rollback(
    app_id: uuid.UUID,
    version: int,
    db: Session = Depends(get_db),
    current_user=Depends(get_current_user),
):
    ws_id = current_user.current_workspace_id
    restored = app_service.rollback(
        db,
        app_id=app_id,
        version=version,
        workspace_id=ws_id,
    )
    serialized = app_schema.AppRelease.model_validate(restored)
    return success(data=serialized)
|
||||||
|
|
||||||
|
|
||||||
|
@router.post("/{app_id}/share", summary="分享应用到其他工作空间")
@cur_workspace_access_guard()
def share_app(
    app_id: uuid.UUID,
    payload: app_schema.AppShareCreate,
    db: Session = Depends(get_db),
    current_user=Depends(get_current_user),
):
    """Share an app with other workspaces.

    - Only apps of your own workspace can be shared.
    - An app cannot be shared with its own workspace.
    - The same app cannot be shared with the same workspace twice.
    """
    ws_id = current_user.current_workspace_id

    created_shares = app_service.AppService(db).share_app(
        app_id=app_id,
        target_workspace_ids=payload.target_workspace_ids,
        user_id=current_user.id,
        workspace_id=ws_id
    )

    serialized = []
    for share in created_shares:
        serialized.append(app_schema.AppShare.model_validate(share))
    return success(data=serialized, msg=f"应用已分享到 {len(created_shares)} 个工作空间")
|
||||||
|
|
||||||
|
|
||||||
|
@router.delete("/{app_id}/share/{target_workspace_id}", summary="取消应用分享")
@cur_workspace_access_guard()
def unshare_app(
    app_id: uuid.UUID,
    target_workspace_id: uuid.UUID,
    db: Session = Depends(get_db),
    current_user=Depends(get_current_user),
):
    """Cancel the sharing of an app.

    - Only shares of apps in your own workspace can be cancelled.
    """
    ws_id = current_user.current_workspace_id

    svc = app_service.AppService(db)
    svc.unshare_app(
        app_id=app_id,
        target_workspace_id=target_workspace_id,
        workspace_id=ws_id
    )
    return success(msg="应用分享已取消")
|
||||||
|
|
||||||
|
|
||||||
|
@router.get("/{app_id}/shares", summary="列出应用的分享记录")
@cur_workspace_access_guard()
def list_app_shares(
    app_id: uuid.UUID,
    db: Session = Depends(get_db),
    current_user=Depends(get_current_user),
):
    """List all share records of an app.

    - Only share records of apps in your own workspace can be viewed.
    """
    ws_id = current_user.current_workspace_id

    records = app_service.AppService(db).list_app_shares(
        app_id=app_id,
        workspace_id=ws_id
    )

    serialized = []
    for record in records:
        serialized.append(app_schema.AppShare.model_validate(record))
    return success(data=serialized)
|
||||||
|
|
||||||
|
@router.post("/{app_id}/draft/run", summary="试运行 Agent(使用当前草稿配置)")
|
||||||
|
@cur_workspace_access_guard()
|
||||||
|
async def draft_run(
|
||||||
|
app_id: uuid.UUID,
|
||||||
|
payload: app_schema.DraftRunRequest,
|
||||||
|
db: Session = Depends(get_db),
|
||||||
|
current_user=Depends(get_current_user),
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
试运行 Agent,使用当前的草稿配置(未发布的配置)
|
||||||
|
|
||||||
|
- 不需要发布应用即可测试
|
||||||
|
- 使用当前的 AgentConfig 配置
|
||||||
|
- 支持流式和非流式返回
|
||||||
|
"""
|
||||||
|
workspace_id = current_user.current_workspace_id
|
||||||
|
|
||||||
|
# 获取 storage_type,如果为 None 则使用默认值
|
||||||
|
storage_type = workspace_service.get_workspace_storage_type(
|
||||||
|
db=db,
|
||||||
|
workspace_id=workspace_id,
|
||||||
|
user=current_user
|
||||||
|
)
|
||||||
|
if storage_type is None: storage_type = 'neo4j'
|
||||||
|
user_rag_memory_id = ''
|
||||||
|
if workspace_id:
|
||||||
|
|
||||||
|
knowledge = knowledge_repository.get_knowledge_by_name(
|
||||||
|
db=db,
|
||||||
|
name="USER_RAG_MERORY",
|
||||||
|
workspace_id=workspace_id
|
||||||
|
)
|
||||||
|
if knowledge: user_rag_memory_id = str(knowledge.id)
|
||||||
|
|
||||||
|
|
||||||
|
# 提前验证和准备(在流式响应开始前完成)
|
||||||
|
from app.services.app_service import AppService
|
||||||
|
from app.services.multi_agent_service import MultiAgentService
|
||||||
|
from app.models import AgentConfig, ModelConfig
|
||||||
|
from sqlalchemy import select
|
||||||
|
from app.core.exceptions import BusinessException
|
||||||
|
|
||||||
|
|
||||||
|
service = AppService(db)
|
||||||
|
|
||||||
|
# 1. 验证应用
|
||||||
|
app = service._get_app_or_404(app_id)
|
||||||
|
if app.type != AppType.AGENT and app.type != AppType.MULTI_AGENT:
|
||||||
|
raise BusinessException("只有 Agent 类型应用支持试运行", BizCode.APP_TYPE_NOT_SUPPORTED)
|
||||||
|
|
||||||
|
# 只读操作,允许访问共享应用
|
||||||
|
service._validate_app_accessible(app, workspace_id)
|
||||||
|
if app.type == AppType.AGENT:
|
||||||
|
service._check_agent_config(app_id)
|
||||||
|
|
||||||
|
# 2. 获取 Agent 配置
|
||||||
|
stmt = select(AgentConfig).where(AgentConfig.app_id == app_id)
|
||||||
|
agent_cfg = db.scalars(stmt).first()
|
||||||
|
if not agent_cfg:
|
||||||
|
raise BusinessException("Agent 配置不存在", BizCode.AGENT_CONFIG_MISSING)
|
||||||
|
|
||||||
|
# 3. 获取模型配置
|
||||||
|
model_config = None
|
||||||
|
if agent_cfg.default_model_config_id:
|
||||||
|
model_config = db.get(ModelConfig, agent_cfg.default_model_config_id)
|
||||||
|
if not model_config:
|
||||||
|
from app.core.exceptions import ResourceNotFoundException
|
||||||
|
raise ResourceNotFoundException("模型配置", str(agent_cfg.default_model_config_id))
|
||||||
|
|
||||||
|
# 流式返回
|
||||||
|
if payload.stream:
|
||||||
|
async def event_generator():
|
||||||
|
from app.services.draft_run_service import DraftRunService
|
||||||
|
draft_service = DraftRunService(db)
|
||||||
|
async for event in draft_service.run_stream(
|
||||||
|
agent_config=agent_cfg,
|
||||||
|
model_config=model_config,
|
||||||
|
message=payload.message,
|
||||||
|
workspace_id=workspace_id,
|
||||||
|
conversation_id=payload.conversation_id,
|
||||||
|
user_id=payload.user_id or str(current_user.id),
|
||||||
|
variables=payload.variables,
|
||||||
|
storage_type=storage_type,
|
||||||
|
user_rag_memory_id=user_rag_memory_id
|
||||||
|
):
|
||||||
|
yield event
|
||||||
|
|
||||||
|
return StreamingResponse(
|
||||||
|
event_generator(),
|
||||||
|
media_type="text/event-stream",
|
||||||
|
headers={
|
||||||
|
"Cache-Control": "no-cache",
|
||||||
|
"Connection": "keep-alive",
|
||||||
|
"X-Accel-Buffering": "no"
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
# 非流式返回
|
||||||
|
logger.debug(
|
||||||
|
f"开始非流式试运行",
|
||||||
|
extra={
|
||||||
|
"app_id": str(app_id),
|
||||||
|
"message_length": len(payload.message),
|
||||||
|
"has_conversation_id": bool(payload.conversation_id),
|
||||||
|
"has_variables": bool(payload.variables)
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
from app.services.draft_run_service import DraftRunService
|
||||||
|
draft_service = DraftRunService(db)
|
||||||
|
result = await draft_service.run(
|
||||||
|
agent_config=agent_cfg,
|
||||||
|
model_config=model_config,
|
||||||
|
message=payload.message,
|
||||||
|
workspace_id=workspace_id,
|
||||||
|
conversation_id=payload.conversation_id,
|
||||||
|
user_id=payload.user_id or str(current_user.id),
|
||||||
|
variables=payload.variables,
|
||||||
|
storage_type=storage_type,
|
||||||
|
user_rag_memory_id=user_rag_memory_id
|
||||||
|
)
|
||||||
|
|
||||||
|
logger.debug(
|
||||||
|
f"试运行返回结果",
|
||||||
|
extra={
|
||||||
|
"result_type": str(type(result)),
|
||||||
|
"result_keys": list(result.keys()) if isinstance(result, dict) else "not_dict"
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
# 验证结果
|
||||||
|
try:
|
||||||
|
validated_result = app_schema.DraftRunResponse.model_validate(result)
|
||||||
|
logger.debug(f"结果验证成功")
|
||||||
|
return success(data=validated_result)
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(
|
||||||
|
f"结果验证失败",
|
||||||
|
extra={
|
||||||
|
"error": str(e),
|
||||||
|
"error_type": str(type(e)),
|
||||||
|
"result": str(result)[:200]
|
||||||
|
}
|
||||||
|
)
|
||||||
|
raise
|
||||||
|
elif app.type == AppType.MULTI_AGENT:
|
||||||
|
# 1. 检查多智能体配置完整性
|
||||||
|
service._check_multi_agent_config(app_id)
|
||||||
|
|
||||||
|
# 2. 构建多智能体运行请求
|
||||||
|
from app.schemas.multi_agent_schema import MultiAgentRunRequest
|
||||||
|
|
||||||
|
multi_agent_request = MultiAgentRunRequest(
|
||||||
|
message=payload.message,
|
||||||
|
conversation_id=payload.conversation_id,
|
||||||
|
user_id=payload.user_id,
|
||||||
|
variables=payload.variables or {},
|
||||||
|
use_llm_routing=True # 默认启用 LLM 路由
|
||||||
|
)
|
||||||
|
|
||||||
|
# 3. 流式返回
|
||||||
|
if payload.stream:
|
||||||
|
logger.debug(
|
||||||
|
f"开始多智能体流式试运行",
|
||||||
|
extra={
|
||||||
|
"app_id": str(app_id),
|
||||||
|
"message_length": len(payload.message),
|
||||||
|
"has_conversation_id": bool(payload.conversation_id)
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
async def event_generator():
|
||||||
|
"""多智能体流式事件生成器"""
|
||||||
|
multiservice = MultiAgentService(db)
|
||||||
|
|
||||||
|
# 调用多智能体服务的流式方法
|
||||||
|
async for event in multiservice.run_stream(
|
||||||
|
app_id=app_id,
|
||||||
|
request=multi_agent_request,
|
||||||
|
storage_type=storage_type,
|
||||||
|
user_rag_memory_id=user_rag_memory_id
|
||||||
|
|
||||||
|
):
|
||||||
|
yield event
|
||||||
|
|
||||||
|
return StreamingResponse(
|
||||||
|
event_generator(),
|
||||||
|
media_type="text/event-stream",
|
||||||
|
headers={
|
||||||
|
"Cache-Control": "no-cache",
|
||||||
|
"Connection": "keep-alive",
|
||||||
|
"X-Accel-Buffering": "no"
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
# 4. 非流式返回
|
||||||
|
logger.debug(
|
||||||
|
f"开始多智能体非流式试运行",
|
||||||
|
extra={
|
||||||
|
"app_id": str(app_id),
|
||||||
|
"message_length": len(payload.message),
|
||||||
|
"has_conversation_id": bool(payload.conversation_id)
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
multiservice = MultiAgentService(db)
|
||||||
|
result = await multiservice.run(app_id, multi_agent_request)
|
||||||
|
|
||||||
|
logger.debug(
|
||||||
|
f"多智能体试运行返回结果",
|
||||||
|
extra={
|
||||||
|
"result_type": str(type(result)),
|
||||||
|
"has_response": "response" in result if isinstance(result, dict) else False
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
return success(
|
||||||
|
data=result,
|
||||||
|
msg="多 Agent 任务执行成功"
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
@router.post("/{app_id}/draft/run/compare", summary="多模型对比试运行")
@cur_workspace_access_guard()
async def draft_run_compare(
    app_id: uuid.UUID,
    payload: app_schema.DraftRunCompareRequest,
    db: Session = Depends(get_db),
    current_user=Depends(get_current_user),
):
    """Run one draft message against several model configurations.

    Documented capabilities of this endpoint:

    - Compare 1-5 models.
    - Entries may be different models, or the same model with different
      parameter settings.
    - ``model_parameters`` on each entry override the agent's defaults.
    - Parallel or serial execution (non-streaming).
    - Streaming responses (serial execution).
    - Returns each model's result together with a performance comparison.

    Typical uses: comparing model quality, tuning parameters such as
    temperature, and analysing performance/cost.

    Raises:
        BusinessException: the app is not of type ``"agent"`` or has no
            agent configuration.
        ResourceNotFoundException: a requested model configuration does
            not exist.
    """
    from sqlalchemy import select

    from app.core.error_codes import BizCode
    from app.core.exceptions import BusinessException, ResourceNotFoundException
    from app.models import AgentConfig, ModelConfig
    from app.services.app_service import AppService
    from app.services.draft_run_service import DraftRunService

    workspace_id = current_user.current_workspace_id

    # Fall back to the default storage backend when the workspace has none.
    storage_type = workspace_service.get_workspace_storage_type(
        db=db,
        workspace_id=workspace_id,
        user=current_user,
    )
    if storage_type is None:
        storage_type = 'neo4j'

    # Resolve the id of the workspace's user-RAG-memory knowledge base, if any.
    user_rag_memory_id = ''
    if workspace_id:
        rag_knowledge = knowledge_repository.get_knowledge_by_name(
            db=db,
            name="USER_RAG_MERORY",
            workspace_id=workspace_id,
        )
        if rag_knowledge:
            user_rag_memory_id = str(rag_knowledge.id)

    logger.info(
        "多模型对比试运行",
        extra={
            "app_id": str(app_id),
            "model_count": len(payload.models),
            "parallel": payload.parallel,
            "stream": payload.stream,
        },
    )

    # All validation happens up front so that errors are raised before a
    # streaming response has started.
    app_service = AppService(db)

    # 1. Validate the app and access rights.
    target_app = app_service._get_app_or_404(app_id)
    if target_app.type != "agent":
        raise BusinessException("只有 Agent 类型应用支持试运行", BizCode.APP_TYPE_NOT_SUPPORTED)
    app_service._validate_app_accessible(target_app, workspace_id)

    # 2. Load the agent configuration.
    agent_cfg = db.scalars(
        select(AgentConfig).where(AgentConfig.app_id == app_id)
    ).first()
    if not agent_cfg:
        raise BusinessException("Agent 配置不存在", BizCode.AGENT_CONFIG_MISSING)

    # 3. Validate every requested model configuration.
    model_configs = []
    for entry in payload.models:
        model_config = db.get(ModelConfig, entry.model_config_id)
        if not model_config:
            raise ResourceNotFoundException("模型配置", str(entry.model_config_id))

        model_configs.append({
            "model_config": model_config,
            # Per-entry parameters win over the agent-level defaults.
            "parameters": {
                **(agent_cfg.model_parameters or {}),
                **(entry.model_parameters or {}),
            },
            "label": entry.label or model_config.name,
            "model_config_id": entry.model_config_id,
            # Each compared model carries its own conversation id.
            "conversation_id": entry.conversation_id,
        })

    # Arguments shared by the streaming and non-streaming service calls.
    run_kwargs = {
        "agent_config": agent_cfg,
        "models": model_configs,
        "message": payload.message,
        "workspace_id": workspace_id,
        "conversation_id": payload.conversation_id,
        "user_id": payload.user_id or str(current_user.id),
        "variables": payload.variables,
        "storage_type": storage_type,
        "user_rag_memory_id": user_rag_memory_id,
        "web_search": True,
        "memory": True,
        "parallel": payload.parallel,
        "timeout": payload.timeout or 60,
    }

    # Streaming response.
    if payload.stream:
        async def event_generator():
            stream_service = DraftRunService(db)
            async for event in stream_service.run_compare_stream(**run_kwargs):
                yield event

        return StreamingResponse(
            event_generator(),
            media_type="text/event-stream",
            headers={
                "Cache-Control": "no-cache",
                "Connection": "keep-alive",
                "X-Accel-Buffering": "no",
            },
        )

    # Non-streaming response.
    draft_service = DraftRunService(db)
    result = await draft_service.run_compare(**run_kwargs)

    logger.info(
        "多模型对比完成",
        extra={
            "app_id": str(app_id),
            "successful": result["successful_count"],
            "failed": result["failed_count"],
        },
    )

    return success(data=app_schema.DraftRunCompareResponse(**result))
|
||||||
195
app/controllers/auth_controller.py
Normal file
195
app/controllers/auth_controller.py
Normal file
@@ -0,0 +1,195 @@
|
|||||||
|
from datetime import datetime, timedelta, timezone
|
||||||
|
from fastapi import APIRouter, Depends
|
||||||
|
from sqlalchemy.orm import Session
|
||||||
|
|
||||||
|
from app.core.response_utils import success
|
||||||
|
from app.db import get_db
|
||||||
|
from app.schemas.response_schema import ApiResponse
|
||||||
|
from app.schemas.token_schema import Token, RefreshTokenRequest, TokenRequest
|
||||||
|
from app.schemas.workspace_schema import InviteAcceptRequest
|
||||||
|
from app.services import auth_service, user_service, workspace_service
|
||||||
|
from app.core import security
|
||||||
|
from app.core.config import settings
|
||||||
|
from app.services.session_service import SessionService
|
||||||
|
from app.core.logging_config import get_auth_logger, get_security_logger
|
||||||
|
from app.core.exceptions import BusinessException
|
||||||
|
from app.core.error_codes import BizCode
|
||||||
|
from app.dependencies import get_current_user, oauth2_scheme
|
||||||
|
from app.models.user_model import User
|
||||||
|
|
||||||
|
# Obtain the dedicated auth and security loggers
|
||||||
|
auth_logger = get_auth_logger()
|
||||||
|
security_logger = get_security_logger()
|
||||||
|
|
||||||
|
router = APIRouter(tags=["Authentication"])
|
||||||
|
|
||||||
|
def _login_with_invite(db: Session, form_data: TokenRequest):
    """Validate the invite token, then authenticate or register the user.

    An existing user is authenticated and bound to the invited workspace;
    when authentication reports ``USER_NOT_FOUND`` the user is registered
    through the invite instead.

    Raises:
        BusinessException: invalid/expired invite, e-mail mismatch, wrong
            password (re-raised as ``LOGIN_FAILED``), or any other
            authentication failure (re-raised unchanged).
    """
    auth_logger.info(f"检测到邀请码: {form_data.invite[:8]}...")
    invite_info = workspace_service.validate_invite_token(db, form_data.invite)

    if not invite_info.is_valid:
        raise BusinessException("邀请码无效或已过期", code=BizCode.BAD_REQUEST)

    if invite_info.email != form_data.email:
        raise BusinessException("邀请邮箱与登录邮箱不匹配", code=BizCode.BAD_REQUEST)
    auth_logger.info(f"邀请码验证成功: workspace={invite_info.workspace_name}")

    try:
        user = auth_service.authenticate_user_or_raise(db, form_data.email, form_data.password)
        auth_logger.info(f"用户认证成功: {user.email} (ID: {user.id})")
        # Binding stays inside the try so its BusinessExceptions are routed
        # through the same handling as authentication failures.
        auth_service.bind_workspace_with_invite(
            db=db,
            user=user,
            invite_token=form_data.invite,
            workspace_id=invite_info.workspace_id,
        )
    except BusinessException as e:
        if e.code == BizCode.USER_NOT_FOUND:
            # Unknown user holding a valid invite: register them.
            auth_logger.info(f"用户不存在,使用邀请码注册: {form_data.email}")
            return auth_service.register_user_with_invite(
                db=db,
                email=form_data.email,
                password=form_data.password,
                invite_token=form_data.invite,
                workspace_id=invite_info.workspace_id,
            )
        if e.code == BizCode.PASSWORD_ERROR:
            # The user exists but supplied the wrong password.
            auth_logger.warning(f"接受邀请失败,密码验证错误: {form_data.email}")
            raise BusinessException("接受邀请失败,密码验证错误", BizCode.LOGIN_FAILED) from e
        # Any other authentication failure is propagated unchanged.
        raise
    return user


def _login_with_password(db: Session, form_data: TokenRequest):
    """Authenticate with e-mail and password only.

    Raises:
        BusinessException: every authentication failure is re-raised with
            code ``LOGIN_FAILED`` and the original message.
    """
    try:
        user = auth_service.authenticate_user_or_raise(db, form_data.email, form_data.password)
    except BusinessException as e:
        raise BusinessException(e.message, BizCode.LOGIN_FAILED) from e
    auth_logger.info(f"用户认证成功: {user.email} (ID: {user.id})")
    return user


@router.post("/token", response_model=ApiResponse)
async def login_for_access_token(
    form_data: TokenRequest,
    db: Session = Depends(get_db)
):
    """Log a user in and return an access/refresh token pair.

    When ``form_data.invite`` is set the invite flow is used (which may
    register a new user); otherwise plain e-mail/password authentication.
    With ``settings.ENABLE_SINGLE_SESSION`` on, the previous session is
    invalidated and the new access token becomes the active session.
    """
    auth_logger.info(f"用户登录请求: {form_data.email}")

    if form_data.invite:
        user = _login_with_invite(db, form_data)
    else:
        user = _login_with_password(db, form_data)

    # Create the tokens; the refresh token's id is not needed here.
    access_token, access_token_id = security.create_access_token(subject=user.id)
    refresh_token_value, _ = security.create_refresh_token(subject=user.id)

    # Timezone-aware (UTC) expiry timestamps derived from one instant.
    issued_at = datetime.now(timezone.utc)
    access_expires_at = issued_at + timedelta(minutes=settings.ACCESS_TOKEN_EXPIRE_MINUTES)
    refresh_expires_at = issued_at + timedelta(days=settings.REFRESH_TOKEN_EXPIRE_DAYS)

    # Single-session management.
    if settings.ENABLE_SINGLE_SESSION:
        await SessionService.invalidate_old_session(user.id, access_token_id)
        await SessionService.set_user_active_session(user.id, access_token_id, access_expires_at)

    # Record the last login time.
    user_service.update_last_login_time(db, user.id)

    auth_logger.info(f"用户 {user.username} 登录成功")

    return success(
        data=Token(
            access_token=access_token,
            refresh_token=refresh_token_value,
            token_type="bearer",
            expires_at=access_expires_at,
            refresh_expires_at=refresh_expires_at,
        ),
        msg="登录成功"
    )
|
||||||
|
|
||||||
|
|
||||||
|
@router.post("/refresh", response_model=ApiResponse)
async def refresh_token(
    refresh_request: RefreshTokenRequest,
    db: Session = Depends(get_db)
):
    """Exchange a valid refresh token for a new access/refresh token pair.

    With ``settings.ENABLE_SINGLE_SESSION`` on, a blacklisted refresh token
    is rejected, the presented refresh token is blacklisted once new tokens
    are issued, and the new access token becomes the active session.

    Raises:
        BusinessException: ``TOKEN_INVALID`` when verification fails,
            ``USER_NOT_FOUND`` when the user no longer exists,
            ``TOKEN_BLACKLISTED`` when the refresh token was revoked.
    """
    auth_logger.info("收到token刷新请求")

    # Verify the refresh token.
    user_id = security.verify_token(refresh_request.refresh_token, "refresh")
    if not user_id:
        raise BusinessException("无效的refresh token", code=BizCode.TOKEN_INVALID)

    # Make sure the user still exists.
    user = auth_service.get_user_by_id(db, user_id)
    if not user:
        raise BusinessException("用户不存在", code=BizCode.USER_NOT_FOUND)

    # Look the token id up once; it serves both the blacklist check and the
    # later blacklisting of the old token.
    old_refresh_token_id = None
    if settings.ENABLE_SINGLE_SESSION:
        old_refresh_token_id = security.get_token_id(refresh_request.refresh_token)
        if old_refresh_token_id and await SessionService.is_token_blacklisted(old_refresh_token_id):
            raise BusinessException("Refresh token已失效", code=BizCode.TOKEN_BLACKLISTED)

    # Issue new tokens; the new refresh token's id is not needed here.
    new_access_token, new_access_token_id = security.create_access_token(subject=user.id)
    new_refresh_token, _ = security.create_refresh_token(subject=user.id)

    # Timezone-aware UTC timestamps, matching the login endpoint; naive
    # local time would make the two endpoints report inconsistent expiries.
    issued_at = datetime.now(timezone.utc)
    access_expires_at = issued_at + timedelta(minutes=settings.ACCESS_TOKEN_EXPIRE_MINUTES)
    refresh_expires_at = issued_at + timedelta(days=settings.REFRESH_TOKEN_EXPIRE_DAYS)

    # Single-session management.
    if settings.ENABLE_SINGLE_SESSION:
        # Revoke the refresh token that was just consumed.
        if old_refresh_token_id:
            await SessionService.blacklist_token(old_refresh_token_id)

        # Replace the active session with the new access token.
        await SessionService.invalidate_old_session(user.id, new_access_token_id)
        await SessionService.set_user_active_session(user.id, new_access_token_id, access_expires_at)

    auth_logger.info(f"用户 {user.id} token刷新成功")

    return success(
        data=Token(
            access_token=new_access_token,
            refresh_token=new_refresh_token,
            token_type="bearer",
            expires_at=access_expires_at,
            refresh_expires_at=refresh_expires_at,
        ),
        msg="token刷新成功"
    )
|
||||||
|
|
||||||
|
|
||||||
|
@router.post("/logout", response_model=ApiResponse)
async def logout(
    token: str = Depends(oauth2_scheme),
    current_user: User = Depends(get_current_user),
    db: Session = Depends(get_db)
):
    """Log the current user out.

    Blacklists the presented access token and, when single-session mode is
    enabled, clears the user's stored session.

    Raises:
        BusinessException: ``TOKEN_INVALID`` when no token id can be
            extracted from the access token.
    """
    username = current_user.username
    auth_logger.info(f"用户 {username} 请求登出")

    access_token_id = security.get_token_id(token)
    if not access_token_id:
        raise BusinessException("无效的access token", code=BizCode.TOKEN_INVALID)

    # Revoke the access token.
    await SessionService.blacklist_token(access_token_id)

    # Drop the stored session.
    # NOTE(review): the login/refresh endpoints register sessions with
    # user.id, while this call passes the username — confirm which key
    # clear_user_session expects.
    if settings.ENABLE_SINGLE_SESSION:
        await SessionService.clear_user_session(username)

    auth_logger.info(f"用户 {username} 登出成功")
    return success(msg="登出成功")
|
||||||
|
|
||||||
447
app/controllers/chunk_controller.py
Normal file
447
app/controllers/chunk_controller.py
Normal file
@@ -0,0 +1,447 @@
|
|||||||
|
import os
|
||||||
|
from typing import Any, Optional
|
||||||
|
import uuid
|
||||||
|
from fastapi import APIRouter, Depends, HTTPException, status, Query
|
||||||
|
from sqlalchemy.orm import Session
|
||||||
|
from sqlalchemy import func
|
||||||
|
|
||||||
|
from app.core.config import settings
|
||||||
|
from app.db import get_db
|
||||||
|
from app.core.rag.llm.cv_model import QWenCV
|
||||||
|
from app.dependencies import get_current_user
|
||||||
|
from app.models.user_model import User
|
||||||
|
from app.models.document_model import Document
|
||||||
|
from app.models import knowledge_model, knowledgeshare_model
|
||||||
|
from app.core.rag.models.chunk import DocumentChunk
|
||||||
|
from app.schemas import chunk_schema
|
||||||
|
from app.schemas.response_schema import ApiResponse
|
||||||
|
from app.core.response_utils import success
|
||||||
|
from app.services import knowledge_service, document_service, file_service, knowledgeshare_service
|
||||||
|
from app.core.rag.vdb.elasticsearch.elasticsearch_vector import ElasticSearchVectorFactory
|
||||||
|
from app.core.logging_config import get_api_logger
|
||||||
|
|
||||||
|
# Module-level logger obtained from get_api_logger(), used by every endpoint below
|
||||||
|
api_logger = get_api_logger()
|
||||||
|
|
||||||
|
router = APIRouter(
|
||||||
|
prefix="/chunks",
|
||||||
|
tags=["chunks"],
|
||||||
|
dependencies=[Depends(get_current_user)] # Every route on this router depends on get_current_user (auth)
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@router.get("/{kb_id}/{document_id}/previewchunks", response_model=ApiResponse)
async def get_preview_chunks(
    kb_id: uuid.UUID,
    document_id: uuid.UUID,
    page: int = Query(1, gt=0),  # Default: 1, which must be greater than 0
    pagesize: int = Query(20, gt=0, le=100),  # Default: 20 items per page, maximum: 100 items
    keywords: Optional[str] = Query(None, description="The keywords used to match chunk content"),
    db: Session = Depends(get_db),
    current_user: User = Depends(get_current_user)
):
    """Return a paged preview of how a document would be chunked.

    The stored file is parsed on the fly with the naive chunker and the
    result is paginated in memory; nothing is written to the vector store.

    Note: ``keywords`` is accepted and logged but not applied as a filter
    by this endpoint.

    Raises:
        HTTPException: 400 for invalid paging values; 404 when the knowledge
            base, document, file record, or file on disk cannot be found.
    """
    api_logger.info(f"Paged query document block preview list: kb_id={kb_id}, document_id={document_id}, page={page}, pagesize={pagesize}, keywords={keywords}, username: {current_user.username}")

    # 1. Parameter validation (also guards direct, non-HTTP callers).
    if page < 1 or pagesize < 1:
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail="The paging parameter must be greater than 0"
        )

    # 2. Obtain knowledge base information
    db_knowledge = knowledge_service.get_knowledge_by_id(db, knowledge_id=kb_id, current_user=current_user)
    if not db_knowledge:
        raise HTTPException(
            status_code=status.HTTP_404_NOT_FOUND,
            detail="The knowledge base does not exist or access is denied"
        )

    # 3. Check if the document exists
    db_document = document_service.get_document_by_id(db, document_id=document_id, current_user=current_user)
    if not db_document:
        raise HTTPException(
            status_code=status.HTTP_404_NOT_FOUND,
            detail="The document does not exist or you do not have permission to access it"
        )

    # 4. Check if the file record exists
    db_file = file_service.get_file_by_id(db, file_id=db_document.file_id)
    if not db_file:
        raise HTTPException(
            status_code=status.HTTP_404_NOT_FOUND,
            detail="The file does not exist or you do not have permission to access it"
        )

    # 5. Construct file path: /files/{kb_id}/{parent_id}/{file.id}{file.file_ext}
    file_path = os.path.join(
        settings.FILE_PATH,
        str(db_file.kb_id),
        str(db_file.parent_id),
        f"{db_file.id}{db_file.file_ext}"
    )

    # 6. Check that the file is actually present on disk
    if not os.path.exists(file_path):
        raise HTTPException(
            status_code=status.HTTP_404_NOT_FOUND,
            detail="File not found (possibly deleted)"
        )

    # 7. Document parsing & segmentation
    def progress_callback(prog=None, msg=None):
        # Report parser progress through the API logger rather than stdout.
        api_logger.debug("Preview chunking progress: prog=%s msg=%s", prog, msg)

    # The vision model is configured from the first API-key entry of the
    # knowledge base's image2text model.
    image2text_key = db_knowledge.image2text.api_keys[0]
    vision_model = QWenCV(
        key=image2text_key.api_key,
        model_name=image2text_key.model_name,
        lang="Chinese",  # Default to Chinese
        base_url=image2text_key.api_base
    )
    from app.core.rag.app.naive import chunk
    # NOTE(review): from_page/to_page presumably restrict parsing to the
    # leading pages of the document — confirm against the chunker.
    res = chunk(filename=file_path,
                from_page=0,
                to_page=5,
                callback=progress_callback,
                vision_model=vision_model,
                parser_config=db_document.parser_config,
                is_root=False)

    # Slice out the requested page.
    start_index = (page - 1) * pagesize
    page_items = res[start_index:start_index + pagesize]

    # Metadata fields that are identical for every chunk of this document.
    shared_metadata = {
        "file_id": str(db_document.file_id),
        "file_name": db_document.file_name,
        "file_created_at": int(db_document.created_at.timestamp() * 1000),
        "document_id": str(db_document.id),
        "knowledge_id": str(db_document.kb_id),
    }
    # sort_id is the position within the returned page (restarts at 0 on
    # every page), as in the original implementation.
    chunks = [
        DocumentChunk(
            page_content=item["content_with_weight"],
            metadata={
                "doc_id": uuid.uuid4().hex,
                **shared_metadata,
                "sort_id": idx,
                "status": 1,
            },
        )
        for idx, item in enumerate(page_items)
    ]

    # 8. Return structured response
    total = len(res)
    result = {
        "items": chunks,
        "page": {
            "page": page,
            "pagesize": pagesize,
            "total": total,
            "has_next": page * pagesize < total
        }
    }
    api_logger.info(f"Querying the document block preview list successful: total={total}, returned={len(chunks)} records")
    return success(data=result, msg="Querying the document block preview list succeeded")
|
||||||
|
|
||||||
|
|
||||||
|
@router.get("/{kb_id}/{document_id}/chunks", response_model=ApiResponse)
async def get_chunks(
    kb_id: uuid.UUID,
    document_id: uuid.UUID,
    page: int = Query(1, gt=0),  # Default: 1, which must be greater than 0
    pagesize: int = Query(20, gt=0, le=100),  # Default: 20 items per page, maximum: 100 items
    keywords: Optional[str] = Query(None, description="The keywords used to match chunk content"),
    db: Session = Depends(get_db),
    current_user: User = Depends(get_current_user)
):
    """Return one page of a document's stored chunks.

    Chunks are fetched from the knowledge base's vector store for
    ``document_id`` in ascending order, with ``keywords`` passed through as
    the search query.

    Raises:
        HTTPException: 400 for invalid paging values, 404 when the knowledge
            base is missing or inaccessible, 500 when the query fails.
    """
    api_logger.info(f"Paged query document chunk list: kb_id={kb_id}, document_id={document_id}, page={page}, pagesize={pagesize}, keywords={keywords}, username: {current_user.username}")

    # 1. Reject non-positive paging values.
    if page < 1 or pagesize < 1:
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail="The paging parameter must be greater than 0"
        )

    # 2. Resolve the knowledge base.
    db_knowledge = knowledge_service.get_knowledge_by_id(db, knowledge_id=kb_id, current_user=current_user)
    if not db_knowledge:
        raise HTTPException(
            status_code=status.HTTP_404_NOT_FOUND,
            detail="The knowledge base does not exist or access is denied"
        )

    # 3. Run the paged query against the vector store.
    try:
        api_logger.debug("Start executing document chunk query")
        vector_service = ElasticSearchVectorFactory().init_vector(knowledge=db_knowledge)
        total, items = vector_service.search_by_segment(
            document_id=str(document_id),
            query=keywords,
            pagesize=pagesize,
            page=page,
            asc=True,
        )
        api_logger.info(f"Document chunk query successful: total={total}, returned={len(items)} records")
    except Exception as e:
        api_logger.error(f"Document chunk query failed: {e}")
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail=f"Query failed: {e}"
        )

    # 4. Assemble the structured response.
    paging = {
        "page": page,
        "pagesize": pagesize,
        "total": total,
        "has_next": bool(page * pagesize < total),
    }
    return success(data={"items": items, "page": paging}, msg="Query of document chunk list succeeded")
|
||||||
|
|
||||||
|
|
||||||
|
@router.post("/{kb_id}/{document_id}/chunk", response_model=ApiResponse)
async def create_chunk(
    kb_id: uuid.UUID,
    document_id: uuid.UUID,
    create_data: chunk_schema.ChunkCreate,
    db: Session = Depends(get_db),
    current_user: User = Depends(get_current_user)
):
    """Create a chunk for a document and store it in the vector index.

    The new chunk's ``sort_id`` is one greater than the highest existing
    ``sort_id`` of the document (or 1 when the document has no chunks).
    The document's ``chunk_num`` counter is incremented and committed.

    Raises:
        HTTPException: 404 when the knowledge base or the document does not
            exist or is not accessible to the current user.
    """
    api_logger.info(f"Create chunk request: kb_id={kb_id}, document_id={document_id}, content={create_data.content}, username: {current_user.username}")

    # 1. Obtain knowledge base information
    db_knowledge = knowledge_service.get_knowledge_by_id(db, knowledge_id=kb_id, current_user=current_user)
    if not db_knowledge:
        raise HTTPException(
            status_code=status.HTTP_404_NOT_FOUND,
            detail="The knowledge base does not exist or access is denied"
        )

    # 2. Obtain document information through the same user-aware service
    #    accessor the preview endpoint uses, so the "no permission" case in
    #    the 404 message below is actually enforced.
    db_document = document_service.get_document_by_id(db, document_id=document_id, current_user=current_user)
    if not db_document:
        raise HTTPException(
            status_code=status.HTTP_404_NOT_FOUND,
            detail="The document does not exist or you do not have permission to access it"
        )

    vector_service = ElasticSearchVectorFactory().init_vector(knowledge=db_knowledge)

    # 3. Determine the sort ID: fetch the single chunk with the highest
    #    sort_id (descending order, page size 1) and add one.
    total, items = vector_service.search_by_segment(document_id=str(document_id), pagesize=1, page=1, asc=False)
    last_sort_id = items[0].metadata["sort_id"] if items else 0
    sort_id = last_sort_id + 1

    doc_id = uuid.uuid4().hex
    metadata = {
        "doc_id": doc_id,
        "file_id": str(db_document.file_id),
        "file_name": db_document.file_name,
        "file_created_at": int(db_document.created_at.timestamp() * 1000),
        "document_id": str(document_id),
        "knowledge_id": str(kb_id),
        "sort_id": sort_id,
        "status": 1,
    }
    chunk = DocumentChunk(page_content=create_data.content, metadata=metadata)

    # 4. Segmented vector storage
    vector_service.add_chunks([chunk])

    # 5. Update chunk_num; treat a missing counter as zero.
    db_document.chunk_num = (db_document.chunk_num or 0) + 1
    db.commit()

    return success(data=chunk, msg="Document chunk creation successful")
|
||||||
|
|
||||||
|
|
||||||
|
@router.get("/{kb_id}/{document_id}/{doc_id}", response_model=ApiResponse)
async def get_chunk(
    kb_id: uuid.UUID,
    document_id: uuid.UUID,
    doc_id: str,
    db: Session = Depends(get_db),
    current_user: User = Depends(get_current_user)
):
    """Fetch a single document chunk by its ``doc_id``.

    Raises:
        HTTPException: 404 when the knowledge base is missing or
            inaccessible, or when no chunk matches ``doc_id``.
    """
    api_logger.info(f"Obtain document chunk information: kb_id={kb_id}, document_id={document_id}, doc_id={doc_id}, username: {current_user.username}")

    # 1. Resolve the knowledge base.
    db_knowledge = knowledge_service.get_knowledge_by_id(db, knowledge_id=kb_id, current_user=current_user)
    if not db_knowledge:
        raise HTTPException(
            status_code=status.HTTP_404_NOT_FOUND,
            detail="The knowledge base does not exist or access is denied"
        )

    # 2. Look the chunk up in that knowledge base's vector store.
    vector_service = ElasticSearchVectorFactory().init_vector(knowledge=db_knowledge)
    total, matches = vector_service.get_by_segment(doc_id=doc_id)
    if not total:
        raise HTTPException(
            status_code=status.HTTP_404_NOT_FOUND,
            detail="The document chunk does not exist or you do not have access"
        )

    return success(data=matches[0], msg="Document chunk query successful")
|
||||||
|
|
||||||
|
|
||||||
|
@router.put("/{kb_id}/{document_id}/{doc_id}", response_model=ApiResponse)
async def update_chunk(
    kb_id: uuid.UUID,
    document_id: uuid.UUID,
    doc_id: str,
    update_data: chunk_schema.ChunkUpdate,
    db: Session = Depends(get_db),
    current_user: User = Depends(get_current_user)
):
    """Replace the content of an existing document chunk.

    The chunk is looked up by ``doc_id``, its ``page_content`` is set to
    ``update_data.content``, and it is written back to the vector store.

    Raises:
        HTTPException: 404 when the knowledge base is missing or
            inaccessible, or when no chunk matches ``doc_id``.
    """
    api_logger.info(f"Update document chunk content: kb_id={kb_id}, document_id={document_id}, doc_id={doc_id}, content={update_data.content}, username: {current_user.username}")

    # Resolve the knowledge base.
    db_knowledge = knowledge_service.get_knowledge_by_id(db, knowledge_id=kb_id, current_user=current_user)
    if not db_knowledge:
        raise HTTPException(
            status_code=status.HTTP_404_NOT_FOUND,
            detail="The knowledge base does not exist or access is denied"
        )

    # Find the chunk in that knowledge base's vector store.
    vector_service = ElasticSearchVectorFactory().init_vector(knowledge=db_knowledge)
    total, matches = vector_service.get_by_segment(doc_id=doc_id)
    if not total:
        raise HTTPException(
            status_code=status.HTTP_404_NOT_FOUND,
            detail="The document chunk does not exist or you do not have access to it"
        )

    # Apply the new content and persist it.
    target_chunk = matches[0]
    target_chunk.page_content = update_data.content
    vector_service.update_by_segment(target_chunk)
    return success(data=target_chunk, msg="The document chunk has been successfully updated")
|
||||||
|
|
||||||
|
|
||||||
|
@router.delete("/{kb_id}/{document_id}/{doc_id}", response_model=ApiResponse)
async def delete_chunk(
    kb_id: uuid.UUID,
    document_id: uuid.UUID,
    doc_id: str,
    db: Session = Depends(get_db),
    current_user: User = Depends(get_current_user)
):
    """
    Delete a document chunk from the vector store and decrement the owning
    document's chunk_num.

    Raises:
        HTTPException: 404 when the knowledge base lookup returns nothing, or
            when the vector store has no entry for doc_id.
    """
    api_logger.info(f"Request to delete document chunk: kb_id={kb_id}, document_id={document_id}, doc_id={doc_id}, username: {current_user.username}")

    db_knowledge = knowledge_service.get_knowledge_by_id(db, knowledge_id=kb_id, current_user=current_user)
    if not db_knowledge:
        raise HTTPException(
            status_code=status.HTTP_404_NOT_FOUND,
            detail="The knowledge base does not exist or access is denied"
        )

    vector_service = ElasticSearchVectorFactory().init_vector(knowledge=db_knowledge)
    if not vector_service.text_exists(doc_id):
        raise HTTPException(
            status_code=status.HTTP_404_NOT_FOUND,
            detail="The document chunk does not exist or you do not have access to it"
        )

    vector_service.delete_by_ids([doc_id])

    # Update chunk_num. The document row may be missing (e.g. a wrong
    # document_id in the URL); the chunk is already gone at this point, so
    # skip the counter update instead of failing on a None row.
    db_document = db.query(Document).filter(Document.id == document_id).first()
    if db_document is not None:
        db_document.chunk_num -= 1
        db.commit()
    else:
        api_logger.warning(f"Document not found while updating chunk_num: document_id={document_id}")

    return success(msg="The document chunk has been successfully deleted")
|
||||||
|
|
||||||
|
|
||||||
|
@router.get("/retrieve_type", response_model=ApiResponse)
def get_retrieve_types():
    """
    Return every member of chunk_schema.RetrieveType as a list.
    """
    retrieve_types = list(chunk_schema.RetrieveType)
    return success(msg="Successfully obtained the retrieval type", data=retrieve_types)
|
||||||
|
|
||||||
|
|
||||||
|
@router.post("/retrieval", response_model=Any, status_code=status.HTTP_200_OK)
async def retrieve_chunks(
    retrieve_data: chunk_schema.ChunkRetrieve,
    db: Session = Depends(get_db),
    current_user: User = Depends(get_current_user)
):
    """
    Search chunks across the requested knowledge bases.

    Private knowledge bases from retrieve_data.kb_ids are searched directly.
    When any requested knowledge base has the Share permission, the source
    knowledge bases returned by the share service are searched as well. The
    search runs over one comma-joined list of index names, using the vector
    service of the first resolved knowledge base.

    Returns an empty result when no knowledge base qualifies.

    Raises:
        HTTPException: 404 when the first resolved knowledge base cannot be loaded.
    """
    api_logger.info(f"retrieve chunk: query={retrieve_data.query}, username: {current_user.username}")

    knowledge_cls = knowledge_model.Knowledge

    def _chunked_ids(permission):
        # Requested knowledge bases with this permission, at least one chunk and status 1
        return knowledge_service.get_chunded_knowledgeids(
            db=db,
            filters=[
                knowledge_cls.id.in_(retrieve_data.kb_ids),
                knowledge_cls.permission_id == permission,
                knowledge_cls.chunk_num > 0,
                knowledge_cls.status == 1
            ],
            current_user=current_user
        )

    existing_ids = _chunked_ids(knowledge_model.PermissionType.Private)
    share_ids = _chunked_ids(knowledge_model.PermissionType.Share)
    if share_ids:
        # Shared knowledge bases are searched through their source knowledge bases
        source_ids = knowledgeshare_service.get_source_kb_ids_by_target_kb_id(
            db=db,
            filters=[
                knowledgeshare_model.KnowledgeShare.target_kb_id.in_(retrieve_data.kb_ids)
            ],
            current_user=current_user
        )
        existing_ids.extend(source_ids)

    if not existing_ids:
        return success(data=[], msg="retrieval successful")

    # The vector service is initialised from the first id only; all ids go into the index list
    first_kb_id = existing_ids[0]
    indices = ",".join([f"Vector_index_{item}_Node".lower() for item in existing_ids])

    db_knowledge = knowledge_service.get_knowledge_by_id(db, knowledge_id=first_kb_id, current_user=current_user)
    if not db_knowledge:
        raise HTTPException(
            status_code=status.HTTP_404_NOT_FOUND,
            detail="The knowledge base does not exist or access is denied"
        )

    store = ElasticSearchVectorFactory().init_vector(knowledge=db_knowledge)
    retrieve_type = retrieve_data.retrieve_type

    # 1 participle search, 2 semantic search, 3 hybrid search
    if retrieve_type == chunk_schema.RetrieveType.PARTICIPLE:
        rs = store.search_by_full_text(query=retrieve_data.query, top_k=retrieve_data.top_k, indices=indices, score_threshold=retrieve_data.similarity_threshold)
        return success(data=rs, msg="retrieval successful")

    if retrieve_type == chunk_schema.RetrieveType.SEMANTIC:
        # NOTE(review): vector_similarity_weight is passed as score_threshold here - confirm this is intended
        rs = store.search_by_vector(query=retrieve_data.query, top_k=retrieve_data.top_k, indices=indices, score_threshold=retrieve_data.vector_similarity_weight)
        return success(data=rs, msg="retrieval successful")

    # Hybrid: run both searches, keep the first occurrence of each doc_id, then rerank
    vector_hits = store.search_by_vector(query=retrieve_data.query, top_k=retrieve_data.top_k, indices=indices, score_threshold=retrieve_data.vector_similarity_weight)
    text_hits = store.search_by_full_text(query=retrieve_data.query, top_k=retrieve_data.top_k, indices=indices, score_threshold=retrieve_data.similarity_threshold)
    first_by_doc_id = {}
    for doc in vector_hits + text_hits:
        first_by_doc_id.setdefault(doc.metadata["doc_id"], doc)
    rs = store.rerank(query=retrieve_data.query, docs=list(first_by_doc_id.values()), top_k=retrieve_data.top_k)
    return success(data=rs, msg="retrieval successful")
|
||||||
341
app/controllers/document_controller.py
Normal file
341
app/controllers/document_controller.py
Normal file
@@ -0,0 +1,341 @@
|
|||||||
|
import os
|
||||||
|
from typing import Optional
|
||||||
|
import datetime
|
||||||
|
import uuid
|
||||||
|
from fastapi import APIRouter, Depends, HTTPException, status, Query
|
||||||
|
from sqlalchemy.orm import Session
|
||||||
|
|
||||||
|
from app.core.config import settings
|
||||||
|
from app.db import get_db
|
||||||
|
from app.dependencies import get_current_user
|
||||||
|
from app.models.user_model import User
|
||||||
|
from app.models import document_model
|
||||||
|
from app.schemas import document_schema
|
||||||
|
from app.schemas.response_schema import ApiResponse
|
||||||
|
from app.core.response_utils import success
|
||||||
|
from app.services import document_service, file_service, knowledge_service
|
||||||
|
from app.controllers import file_controller
|
||||||
|
from app.celery_app import celery_app
|
||||||
|
from app.core.rag.vdb.elasticsearch.elasticsearch_vector import ElasticSearchVectorFactory
|
||||||
|
from app.core.logging_config import get_api_logger
|
||||||
|
|
||||||
|
# Dedicated API logger used by the endpoints in this controller
api_logger = get_api_logger()

# Router for the /documents endpoints; get_current_user is applied to every route
router = APIRouter(prefix="/documents", tags=["documents"], dependencies=[Depends(get_current_user)])
|
||||||
|
|
||||||
|
|
||||||
|
@router.get("/{kb_id}/{parent_id}/documents", response_model=ApiResponse)
async def get_documents(
    kb_id: uuid.UUID,
    parent_id: uuid.UUID,
    page: int = Query(1, gt=0),  # Default: 1, which must be greater than 0
    pagesize: int = Query(20, gt=0, le=100),  # Default: 20 items per page, maximum: 100 items
    orderby: Optional[str] = Query(None, description="Sort fields, such as: created_at,updated_at"),
    desc: Optional[bool] = Query(False, description="Is it descending order"),
    keywords: Optional[str] = Query(None, description="Search keywords (file name)"),
    document_ids: Optional[str] = Query(None, description="document ids, separated by commas"),
    db: Session = Depends(get_db),
    current_user: User = Depends(get_current_user)
):
    """
    Paged query of the document list.

    - Filters by kb_id and by the files found under parent_id
    - Optional case-insensitive keyword match on the file name
    - Optional restriction to a comma-separated list of document ids
    - Dynamic sorting via orderby / desc
    - Returns paging metadata plus the matching items

    Raises:
        HTTPException: 400 for non-positive paging values, 500 when the
            paginated query raises.
    """
    api_logger.info(f"Query document list: kb_id={kb_id}, page={page}, pagesize={pagesize}, keywords={keywords}, document_ids={document_ids}, username: {current_user.username}")

    # 1. Parameter validation (defensive; the Query constraints already require values > 0)
    if page < 1 or pagesize < 1:
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail="The paging parameter must be greater than 0"
        )

    # 2. Construct query conditions
    filters = [
        document_model.Document.kb_id == kb_id,
        document_model.Document.status == 1
    ]

    if parent_id:
        # Documents are tied to a folder through their file, so filter on the folder's file ids
        files = file_service.get_files_by_parent_id(db=db, parent_id=parent_id, current_user=current_user)
        files_ids = [item.id for item in files]
        filters.append(document_model.Document.file_id.in_(files_ids))

    # Keyword search (fuzzy matching of file name)
    if keywords:
        api_logger.debug(f"Add keyword search criteria: {keywords}")
        filters.append(document_model.Document.file_name.ilike(f"%{keywords}%"))

    # Document ids: trim whitespace and drop empty tokens so that input such as
    # "a, b" or a trailing comma does not produce ids that cannot match
    if document_ids:
        requested_ids = [token.strip() for token in document_ids.split(',') if token.strip()]
        filters.append(document_model.Document.id.in_(requested_ids))

    # 3. Execute paged query
    try:
        api_logger.debug("Start executing document paging query")
        total, items = document_service.get_documents_paginated(
            db=db,
            filters=filters,
            page=page,
            pagesize=pagesize,
            orderby=orderby,
            desc=desc,
            current_user=current_user
        )
        api_logger.info(f"Document query successful: total={total}, returned={len(items)} records")
    except Exception as e:
        api_logger.error(f"Document query failed: {str(e)}")
        # Chain the original exception so the root cause stays in the traceback
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail=f"Query failed: {str(e)}"
        ) from e

    # 4. Return structured response
    result = {
        "items": items,
        "page": {
            "page": page,
            "pagesize": pagesize,
            "total": total,
            "has_next": page * pagesize < total
        }
    }
    return success(data=result, msg="Query of document list succeeded")
|
||||||
|
|
||||||
|
|
||||||
|
@router.post("/document", response_model=ApiResponse)
async def create_document(
    create_data: document_schema.DocumentCreate,
    db: Session = Depends(get_db),
    current_user: User = Depends(get_current_user)
):
    """
    Create a document record from the submitted payload.

    Any exception raised while creating or serialising the document is logged
    and re-raised unchanged.
    """
    api_logger.info(f"Create document request: file_name={create_data.file_name}, kb_id={create_data.kb_id}, username: {current_user.username}")

    try:
        api_logger.debug(f"Start creating a document: {create_data.file_name}")
        created = document_service.create_document(db=db, document=create_data, current_user=current_user)
        api_logger.info(f"Document created successfully: {created.file_name} (ID: {created.id})")
        payload = document_schema.Document.model_validate(created)
        return success(data=payload, msg="Document creation successful")
    except Exception as exc:
        api_logger.error(f"Document creation failed: {create_data.file_name} - {str(exc)}")
        raise
|
||||||
|
|
||||||
|
|
||||||
|
@router.get("/{document_id}", response_model=ApiResponse)
async def get_document(
    document_id: uuid.UUID,
    db: Session = Depends(get_db),
    current_user: User = Depends(get_current_user)
):
    """
    Return the document identified by document_id.

    Raises:
        HTTPException: 404 when the document lookup returns nothing.
    """
    api_logger.info(f"Obtain document information: document_id={document_id}, username: {current_user.username}")

    try:
        # 1. Load the document from the database
        api_logger.debug(f"query documentation: {document_id}")
        found = document_service.get_document_by_id(db, document_id=document_id, current_user=current_user)
        if not found:
            api_logger.warning(f"The document does not exist or you do not have access: document_id={document_id}")
            raise HTTPException(
                status_code=status.HTTP_404_NOT_FOUND,
                detail="The document does not exist or you do not have access"
            )

        api_logger.info(f"Document query successful: {found.file_name} (ID: {found.id})")
        payload = document_schema.Document.model_validate(found)
        return success(data=payload, msg="Successfully obtained document information")
    except Exception as exc:
        # HTTP errors raised on purpose pass through silently; anything else is logged first
        if not isinstance(exc, HTTPException):
            api_logger.error(f"Document query failed: document_id={document_id} - {str(exc)}")
        raise
|
||||||
|
|
||||||
|
|
||||||
|
@router.put("/{document_id}", response_model=ApiResponse)
async def update_document(
    document_id: uuid.UUID,
    update_data: document_schema.DocumentUpdate,
    db: Session = Depends(get_db),
    current_user: User = Depends(get_current_user)
):
    """
    Update document information.

    Only fields explicitly present in the request body are applied. When the
    status changes, the new status is pushed to the vector store first so that
    retrieval visibility follows the document switch.

    Raises:
        HTTPException: 404 when the document (or, for a status change, its
            knowledge base) cannot be loaded; 500 when the commit fails.
    """
    # 1. Check if the document exists
    api_logger.debug(f"Query the document to be updated: {document_id}")
    db_document = document_service.get_document_by_id(db, document_id=document_id, current_user=current_user)

    if not db_document:
        api_logger.warning(f"The document does not exist or you do not have permission to access it: document_id={document_id}")
        raise HTTPException(
            status_code=status.HTTP_404_NOT_FOUND,
            detail="The document does not exist or you do not have permission to access it"
        )

    # 2. If updating the status, synchronize the document status switch to whether it can be retrieved from the vector database
    # model_dump replaces the deprecated Pydantic v1 .dict(); this file already relies on the v2 API (model_validate)
    update_dict = update_data.model_dump(exclude_unset=True)
    if "status" in update_dict:
        new_status = update_dict["status"]
        if new_status != db_document.status:
            db_knowledge = knowledge_service.get_knowledge_by_id(db, knowledge_id=db_document.kb_id, current_user=current_user)
            if not db_knowledge:
                # Without this check a missing knowledge base would be passed to init_vector as None
                api_logger.warning(f"The knowledge base does not exist or access is denied: knowledge_id={db_document.kb_id}")
                raise HTTPException(
                    status_code=status.HTTP_404_NOT_FOUND,
                    detail="The knowledge base does not exist or access is denied"
                )
            vector_service = ElasticSearchVectorFactory().init_vector(knowledge=db_knowledge)
            # NOTE: this runs before the database commit below; a failed commit leaves the vector store ahead of the database
            vector_service.change_status_by_document_id(document_id=str(document_id), status=new_status)

    # 3. Update fields (only fields present in the request whose value actually changes)
    api_logger.debug(f"Start updating the document fields: {document_id}")
    updated_fields = []
    for field, value in update_dict.items():
        if hasattr(db_document, field):
            old_value = getattr(db_document, field)
            if old_value != value:
                # update value
                setattr(db_document, field, value)
                updated_fields.append(f"{field}: {old_value} -> {value}")

    if updated_fields:
        api_logger.debug(f"updated fields: {', '.join(updated_fields)}")

    db_document.updated_at = datetime.datetime.now()

    # 4. Save to database
    try:
        db.commit()
        db.refresh(db_document)
        api_logger.info(f"The document has been successfully updated: {db_document.file_name} (ID: {db_document.id})")
    except Exception as e:
        db.rollback()
        api_logger.error(f"Document update failed: document_id={document_id} - {str(e)}")
        # Chain the original exception so the root cause stays in the traceback
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail=f"Document update failed: {str(e)}"
        ) from e

    # 5. Return the updated document
    return success(data=document_schema.Document.model_validate(db_document), msg="Document information updated successfully")
|
||||||
|
|
||||||
|
|
||||||
|
@router.delete("/{document_id}", response_model=ApiResponse)
async def delete_document(
    document_id: uuid.UUID,
    db: Session = Depends(get_db),
    current_user: User = Depends(get_current_user)
):
    """
    Delete a document together with its file and its vectors.

    Order of operations: delete the document row and commit, delete the file
    via file_controller._delete_file, then remove the vectors whose
    document_id metadata matches.

    Raises:
        HTTPException: 404 when the document lookup returns nothing.
    """
    api_logger.info(f"Request to delete document: document_id={document_id}, username: {current_user.username}")

    try:
        # 1. Check if the document exists
        api_logger.debug(f"Check whether the document exists: {document_id}")
        db_document = document_service.get_document_by_id(db, document_id=document_id, current_user=current_user)

        if not db_document:
            api_logger.warning(f"The document does not exist or you do not have permission to access it: document_id={document_id}")
            raise HTTPException(
                status_code=status.HTTP_404_NOT_FOUND,
                detail="The document does not exist or you do not have permission to access it"
            )

        # Copy the needed values now: after delete + commit the ORM instance is
        # expired and its row is gone, so reading attributes from it can fail
        file_id = db_document.file_id
        kb_id = db_document.kb_id
        file_name = db_document.file_name

        # 2. Delete document
        api_logger.debug(f"Perform document delete: {file_name} (ID: {document_id})")
        db.delete(db_document)
        db.commit()

        # 3. Delete file
        await file_controller._delete_file(db=db, file_id=file_id, current_user=current_user)

        # 4. Delete vector index
        db_knowledge = knowledge_service.get_knowledge_by_id(db, knowledge_id=kb_id, current_user=current_user)
        if db_knowledge:
            vector_service = ElasticSearchVectorFactory().init_vector(knowledge=db_knowledge)
            vector_service.delete_by_metadata_field(key="document_id", value=str(document_id))
        else:
            # The document row is already deleted; without a knowledge base there is no index to clean up
            api_logger.warning(f"The knowledge base does not exist or access is denied: knowledge_id={kb_id}")

        api_logger.info(f"The document has been successfully deleted: {file_name} (ID: {document_id})")
        return success(msg="The document has been successfully deleted")
    except HTTPException:
        # Intentional HTTP errors are not failures of the delete itself
        raise
    except Exception as e:
        api_logger.error(f"Failed to delete from the document: document_id={document_id} - {str(e)}")
        raise
|
||||||
|
|
||||||
|
|
||||||
|
@router.post("/{document_id}/chunks", response_model=ApiResponse)
async def parse_documents(
    document_id: uuid.UUID,
    db: Session = Depends(get_db),
    current_user: User = Depends(get_current_user)
):
    """
    Queue a background task that parses the document into chunks.

    Checks that the document, its file record, the file on disk and the
    knowledge base all exist, then sends the parse task to Celery by name and
    returns the task id.

    Raises:
        HTTPException: 404 when any of the checks above fails.
    """
    api_logger.info(f"Request to parse document: document_id={document_id}, username: {current_user.username}")

    try:
        # 1. Check if the document exists
        api_logger.debug(f"Check whether the document exists: {document_id}")
        db_document = document_service.get_document_by_id(db, document_id=document_id, current_user=current_user)

        if not db_document:
            api_logger.warning(f"The document does not exist or you do not have permission to access it: document_id={document_id}")
            raise HTTPException(
                status_code=status.HTTP_404_NOT_FOUND,
                detail="The document does not exist or you do not have permission to access it"
            )

        # 2. Check if the file record exists
        api_logger.debug(f"Check whether the file exists: {db_document.file_id}")
        db_file = file_service.get_file_by_id(db, file_id=db_document.file_id)

        if not db_file:
            api_logger.warning(f"The file does not exist or you do not have permission to access it: file_id={db_document.file_id}")
            raise HTTPException(
                status_code=status.HTTP_404_NOT_FOUND,
                detail="The file does not exist or you do not have permission to access it"
            )

        # 3. Construct file path:/files/{kb_id}/{parent_id}/{file.id}{file.file_ext}
        file_path = os.path.join(
            settings.FILE_PATH,
            str(db_file.kb_id),
            str(db_file.parent_id),
            f"{db_file.id}{db_file.file_ext}"
        )

        # 4. Check if the file exists on disk
        if not os.path.exists(file_path):
            api_logger.warning(f"File not found (possibly deleted): file_path={file_path}")
            raise HTTPException(
                status_code=status.HTTP_404_NOT_FOUND,
                detail="File not found (possibly deleted)"
            )

        # 5. Obtain knowledge base information
        api_logger.info(f"Obtain details of the knowledge base: knowledge_id={db_document.kb_id}")
        db_knowledge = knowledge_service.get_knowledge_by_id(db, knowledge_id=db_document.kb_id, current_user=current_user)
        if not db_knowledge:
            api_logger.warning(f"The knowledge base does not exist or access is denied: knowledge_id={db_document.kb_id}")
            raise HTTPException(
                status_code=status.HTTP_404_NOT_FOUND,
                detail="The knowledge base does not exist or access is denied"
            )

        # 6. Task: Document parsing, vectorization, and storage
        # The task is sent by its registered name rather than imported and called directly
        task = celery_app.send_task("app.core.rag.tasks.parse_document", args=[file_path, document_id])
        result = {
            "task_id": task.id
        }
        return success(data=result, msg="Task accepted. The document is being processed in the background.")
    except HTTPException:
        # The 404s above are already logged as warnings; do not report them again as errors
        raise
    except Exception as e:
        api_logger.error(f"Failed to parse document: document_id={document_id} - {str(e)}")
        raise
|
||||||
453
app/controllers/file_controller.py
Normal file
453
app/controllers/file_controller.py
Normal file
@@ -0,0 +1,453 @@
|
|||||||
|
import os
|
||||||
|
from typing import Any, Optional
|
||||||
|
from pathlib import Path
|
||||||
|
import shutil
|
||||||
|
import uuid
|
||||||
|
from fastapi import APIRouter, Depends, HTTPException, status, File, UploadFile, Query
|
||||||
|
from fastapi.responses import FileResponse
|
||||||
|
from sqlalchemy.orm import Session
|
||||||
|
|
||||||
|
from app.core.config import settings
|
||||||
|
from app.db import get_db
|
||||||
|
from app.dependencies import get_current_user
|
||||||
|
from app.models.user_model import User
|
||||||
|
from app.models import file_model
|
||||||
|
from app.schemas import file_schema, document_schema
|
||||||
|
from app.schemas.response_schema import ApiResponse
|
||||||
|
from app.core.response_utils import success
|
||||||
|
from app.services import file_service, document_service
|
||||||
|
from app.core.logging_config import get_api_logger
|
||||||
|
|
||||||
|
# Dedicated API logger used by the endpoints in this controller
api_logger = get_api_logger()

# Router for the /files endpoints
router = APIRouter(prefix="/files", tags=["files"])
|
||||||
|
|
||||||
|
|
||||||
|
@router.get("/{kb_id}/{parent_id}/files", response_model=ApiResponse)
async def get_files(
    kb_id: uuid.UUID,
    parent_id: uuid.UUID,
    page: int = Query(1, gt=0),  # Default: 1, which must be greater than 0
    pagesize: int = Query(20, gt=0, le=100),  # Default: 20 items per page, maximum: 100 items
    orderby: Optional[str] = Query(None, description="Sort fields, such as: created_at"),
    desc: Optional[bool] = Query(False, description="Is it descending order"),
    keywords: Optional[str] = Query(None, description="Search keywords (file name)"),
    db: Session = Depends(get_db),
    current_user: User = Depends(get_current_user)
):
    """
    Paged query of the file list.

    - Filters by kb_id and parent_id
    - Optional case-insensitive keyword match on the file name
    - Dynamic sorting via orderby / desc
    - Returns paging metadata plus the matching items

    Raises:
        HTTPException: 400 for non-positive paging values, 500 when the
            paginated query raises.
    """
    api_logger.info(f"Query file list: kb_id={kb_id}, parent_id={parent_id}, page={page}, pagesize={pagesize}, keywords={keywords}, username: {current_user.username}")

    # 1. Parameter validation (defensive; the Query constraints already require values > 0)
    if page < 1 or pagesize < 1:
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail="The paging parameter must be greater than 0"
        )

    # 2. Construct query conditions
    filters = [
        file_model.File.kb_id == kb_id
    ]
    if parent_id:
        filters.append(file_model.File.parent_id == parent_id)
    # Keyword search (fuzzy matching of file name)
    if keywords:
        filters.append(file_model.File.file_name.ilike(f"%{keywords}%"))

    # 3. Execute paged query
    try:
        api_logger.debug("Start executing file paging query")
        total, items = file_service.get_files_paginated(
            db=db,
            filters=filters,
            page=page,
            pagesize=pagesize,
            orderby=orderby,
            desc=desc,
            current_user=current_user
        )
        api_logger.info(f"File query successful: total={total}, returned={len(items)} records")
    except Exception as e:
        # Log before converting to a 500 so the failure is visible server-side
        api_logger.error(f"File query failed: {str(e)}")
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail=f"Query failed: {str(e)}"
        ) from e

    # 4. Return structured response
    result = {
        "items": items,
        "page": {
            "page": page,
            "pagesize": pagesize,
            "total": total,
            "has_next": page * pagesize < total
        }
    }
    return success(data=result, msg="Query of file list succeeded")
|
||||||
|
|
||||||
|
|
||||||
|
@router.post("/folder", response_model=ApiResponse)
def create_folder(
    kb_id: uuid.UUID,
    parent_id: uuid.UUID,
    folder_name: str = '/',
    db: Session = Depends(get_db),
    current_user: User = Depends(get_current_user),
):
    """
    Create a folder entry under parent_id in the given knowledge base.

    A folder is stored as a file record with file_ext 'folder' and size 0.
    Any exception is logged and re-raised unchanged.
    """
    api_logger.info(f"Create folder request: kb_id={kb_id}, parent_id={parent_id}, folder_name={folder_name}, username: {current_user.username}")

    try:
        api_logger.debug(f"Start creating a folder: {folder_name}")
        folder_payload = file_schema.FileCreate(
            kb_id=kb_id,
            created_by=current_user.id,
            parent_id=parent_id,
            file_name=folder_name,
            file_ext='folder',
            file_size=0,
        )
        created = file_service.create_file(db=db, file=folder_payload, current_user=current_user)
        api_logger.info(f"Folder created successfully: {created.file_name} (ID: {created.id})")
        response_data = file_schema.File.model_validate(created)
        return success(data=response_data, msg="Folder creation successful")
    except Exception as exc:
        api_logger.error(f"Folder creation failed: {folder_name} - {str(exc)}")
        raise
|
||||||
|
|
||||||
|
|
||||||
|
@router.post("/file", response_model=ApiResponse)
async def upload_file(
    kb_id: uuid.UUID,
    parent_id: uuid.UUID,
    file: UploadFile = File(...),
    db: Session = Depends(get_db),
    current_user: User = Depends(get_current_user)
):
    """
    Upload a file, store it on disk and create its document record.

    The file is saved as {FILE_PATH}/{kb_id}/{parent_id}/{file.id}{file_ext},
    where file_ext is the lowercased extension stored on the file record. A
    document with the default "naive" parser configuration is then created.

    Raises:
        HTTPException: 400 when the file is empty or larger than
            settings.MAX_FILE_SIZE; 500 when the saved file cannot be found.
    """
    api_logger.info(f"upload file request: kb_id={kb_id}, parent_id={parent_id}, filename={file.filename}, username: {current_user.username}")

    # Read the contents of the file
    contents = await file.read()
    # Check file size
    file_size = len(contents)
    api_logger.debug(f"file size: {file_size} byte")
    if file_size == 0:
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail="The file is empty."
        )
    # Reject files larger than the configured limit
    if file_size > settings.MAX_FILE_SIZE:
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail=f"The file size exceeds the {settings.MAX_FILE_SIZE}byte limit"
        )

    # Extract the extension using `os.path.splitext`
    _, file_extension = os.path.splitext(file.filename)
    file_record = file_schema.FileCreate(
        kb_id=kb_id,
        created_by=current_user.id,
        parent_id=parent_id,
        file_name=file.filename,
        file_ext=file_extension.lower(),
        file_size=file_size,
    )
    db_file = file_service.create_file(db=db, file=file_record, current_user=current_user)

    # Construct a save path:/files/{kb_id}/{parent_id}/{file.id}{file.file_ext}
    # Use the stored (lowercased) extension: parse_documents rebuilds this path from
    # db_file.file_ext, so saving "X.PDF" with its original case would not be found
    # on a case-sensitive filesystem
    save_dir = os.path.join(settings.FILE_PATH, str(kb_id), str(parent_id))
    Path(save_dir).mkdir(parents=True, exist_ok=True)  # Ensure that the directory exists
    save_path = os.path.join(save_dir, f"{db_file.id}{db_file.file_ext}")

    # Save file
    with open(save_path, "wb") as f:
        f.write(contents)

    # Verify whether the file has been saved successfully
    if not os.path.exists(save_path):
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail="File save failed"
        )

    # Create a document
    create_data = document_schema.DocumentCreate(
        kb_id=kb_id,
        created_by=current_user.id,
        file_id=db_file.id,
        file_name=db_file.file_name,
        file_ext=db_file.file_ext,
        file_size=db_file.file_size,
        file_meta={},
        parser_id="naive",
        parser_config={
            "layout_recognize": "DeepDOC",
            "chunk_token_num": 128,
            "delimiter": "\n",
            "auto_keywords": 0,
            "auto_questions": 0,
            "html4excel": "false"
        }
    )
    db_document = document_service.create_document(db=db, document=create_data, current_user=current_user)

    api_logger.info(f"File upload successfully: {file.filename} (file_id: {db_file.id}, document_id: {db_document.id})")
    return success(data=document_schema.Document.model_validate(db_document), msg="File upload successful")
|
||||||
|
|
||||||
|
|
||||||
|
@router.post("/customtext", response_model=ApiResponse)
async def custom_text(
    kb_id: uuid.UUID,
    parent_id: uuid.UUID,
    create_data: file_schema.CustomTextFileCreate,
    db: Session = Depends(get_db),
    current_user: User = Depends(get_current_user)
):
    """
    Store user-supplied text as a ``.txt`` file and register a document for it.

    Steps:
    - Encode the text as UTF-8 and validate its byte size
    - Create the file record, then write the bytes to
      ``{FILE_PATH}/{kb_id}/{parent_id}/{file.id}.txt``
    - Create the document record that points at the stored file

    Args:
        kb_id: Knowledge base the text belongs to.
        parent_id: Folder the file is placed in.
        create_data: Payload carrying ``title`` and ``content``.
        db: Database session.
        current_user: Authenticated user, recorded as the creator.

    Raises:
        HTTPException: 400 when the content is empty or larger than
            ``settings.MAX_FILE_SIZE``; 500 when the file is missing after the write.
    """
    # Encode the content to bytes (UTF-8); the byte length is what is validated and stored
    content_bytes = create_data.content.encode('utf-8')
    file_size = len(content_bytes)

    # Log the size rather than the text itself: the content can be as large as
    # MAX_FILE_SIZE and may be sensitive, so it must not end up in the API log.
    api_logger.info(f"custom text upload request: kb_id={kb_id}, parent_id={parent_id}, title={create_data.title}, content_size={file_size} byte, username: {current_user.username}")
    api_logger.debug(f"file size: {file_size} byte")

    # Check file content size
    if file_size == 0:
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail="The content is empty."
        )
    # Reject content above the configured limit
    if file_size > settings.MAX_FILE_SIZE:
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail=f"The content size exceeds the {settings.MAX_FILE_SIZE}byte limit"
        )

    upload_file = file_schema.FileCreate(
        kb_id=kb_id,
        created_by=current_user.id,
        parent_id=parent_id,
        file_name=f"{create_data.title}.txt",
        file_ext=".txt",
        file_size=file_size,
    )
    db_file = file_service.create_file(db=db, file=upload_file, current_user=current_user)

    # Construct a save path:/files/{kb_id}/{parent_id}/{file.id}{file_extension}
    save_dir = os.path.join(settings.FILE_PATH, str(kb_id), str(parent_id))
    Path(save_dir).mkdir(parents=True, exist_ok=True)  # Ensure that the directory exists
    save_path = os.path.join(save_dir, f"{db_file.id}.txt")

    # Save file
    with open(save_path, "wb") as f:
        f.write(content_bytes)

    # Verify whether the file has been saved successfully
    if not os.path.exists(save_path):
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail="File save failed"
        )

    # Create a document
    create_document_data = document_schema.DocumentCreate(
        kb_id=kb_id,
        created_by=current_user.id,
        file_id=db_file.id,
        file_name=db_file.file_name,
        file_ext=db_file.file_ext,
        file_size=db_file.file_size,
        file_meta={},
        parser_id="naive",
        parser_config={
            "layout_recognize": "DeepDOC",
            "chunk_token_num": 128,
            "delimiter": "\n",
            "auto_keywords": 0,
            "auto_questions": 0,
            "html4excel": "false"
        }
    )
    db_document = document_service.create_document(db=db, document=create_document_data, current_user=current_user)

    api_logger.info(f"custom text upload successfully: {create_data.title} (file_id: {db_file.id}, document_id: {db_document.id})")
    return success(data=document_schema.Document.model_validate(db_document), msg="custom text upload successful")
|
||||||
|
|
||||||
|
|
||||||
|
@router.get("/{file_id}", response_model=Any)
async def get_file(
    file_id: uuid.UUID,
    db: Session = Depends(get_db)
) -> Any:
    """
    Serve a stored file as a download.

    The file record is looked up by ``file_id``; the on-disk location is
    ``{FILE_PATH}/{kb_id}/{parent_id}/{file.id}{file.file_ext}``.

    Raises:
        HTTPException: 404 when the record is not found or the file is missing on disk.
    """
    api_logger.info(f"Download the file based on the file_id: file_id={file_id}")

    # Look up the file record
    db_file = file_service.get_file_by_id(db, file_id=file_id)
    if not db_file:
        missing_msg = "The file does not exist or you do not have permission to access it"
        api_logger.warning(f"{missing_msg}: file_id={file_id}")
        raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail=missing_msg)

    # Resolve the physical location: directory first, then the stored file name
    storage_dir = os.path.join(settings.FILE_PATH, str(db_file.kb_id), str(db_file.parent_id))
    file_path = os.path.join(storage_dir, f"{db_file.id}{db_file.file_ext}")

    # The record may outlive the physical file
    if not os.path.exists(file_path):
        raise HTTPException(
            status_code=status.HTTP_404_NOT_FOUND,
            detail="File not found (possibly deleted)"
        )

    # Generic binary media type; the download is named after the original file name
    return FileResponse(
        path=file_path,
        filename=db_file.file_name,
        media_type="application/octet-stream"
    )
|
||||||
|
|
||||||
|
|
||||||
|
@router.put("/{file_id}", response_model=ApiResponse)
async def update_file(
    file_id: uuid.UUID,
    update_data: file_schema.FileUpdate,
    db: Session = Depends(get_db),
    current_user: User = Depends(get_current_user)
):
    """
    Update file information (such as file name)
    - Only fields explicitly sent by the client are applied
    - Fields that do not exist on the file record are ignored

    Args:
        file_id: Identifier of the file record to update.
        update_data: Partial update payload.
        db: Database session.
        current_user: Authenticated user (not otherwise used in this handler).

    Raises:
        HTTPException: 404 when the file record is not found; 500 when the
            database commit fails (the transaction is rolled back first).
    """
    api_logger.debug(f"Query the file to be updated: {file_id}")

    # 1. Check if the file exists
    db_file = file_service.get_file_by_id(db, file_id=file_id)

    if not db_file:
        api_logger.warning(f"The file does not exist or you do not have permission to access it: file_id={file_id}")
        raise HTTPException(
            status_code=status.HTTP_404_NOT_FOUND,
            detail="The file does not exist or you do not have permission to access it"
        )

    # 2. Update fields (only fields the client actually set)
    api_logger.debug(f"Start updating the file fields: {file_id}")
    updated_fields = []
    # A Pydantic model has no .items(); dump it to a dict first. exclude_unset keeps
    # omitted fields from overwriting stored values with their defaults.
    for field, value in update_data.model_dump(exclude_unset=True).items():
        if hasattr(db_file, field):
            old_value = getattr(db_file, field)
            if old_value != value:
                # update value
                setattr(db_file, field, value)
                updated_fields.append(f"{field}: {old_value} -> {value}")

    if updated_fields:
        api_logger.debug(f"updated fields: {', '.join(updated_fields)}")

    # 3. Save to database
    try:
        db.commit()
        db.refresh(db_file)
        api_logger.info(f"The file has been successfully updated: {db_file.file_name} (ID: {db_file.id})")
    except Exception as e:
        db.rollback()
        api_logger.error(f"File update failed: file_id={file_id} - {str(e)}")
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail=f"File update failed: {str(e)}"
        )

    # 4. Return the updated file
    return success(data=file_schema.File.model_validate(db_file), msg="File information updated successfully")
|
||||||
|
|
||||||
|
|
||||||
|
@router.delete("/{file_id}", response_model=ApiResponse)
async def delete_file(
    file_id: uuid.UUID,
    db: Session = Depends(get_db),
    current_user: User = Depends(get_current_user)
):
    """
    Delete a file or folder.

    Logs the request and delegates the actual removal to ``_delete_file``;
    any HTTPException raised there propagates to the client.
    """
    requester = current_user.username
    api_logger.info(f"Request to delete file: file_id={file_id}, username: {requester}")

    await _delete_file(file_id=file_id, db=db, current_user=current_user)

    return success(msg="File deleted successfully")
|
||||||
|
|
||||||
|
async def _delete_file(
    file_id: uuid.UUID,
    db: Session = Depends(get_db),
    current_user: User = Depends(get_current_user)
) -> None:
    """
    Remove a file or folder from disk and from the database.

    A folder lives at ``{FILE_PATH}/{kb_id}/{folder.id}``; a regular file lives at
    ``{FILE_PATH}/{kb_id}/{parent_id}/{file.id}{file_ext}``. For a folder, the
    rows whose ``parent_id`` equals the folder id are deleted as well.

    Args:
        file_id: Identifier of the file or folder record.
        db: Database session.
        current_user: Accepted for call-site symmetry; not read in this function.

    Raises:
        HTTPException: 404 when the record is not found; 500 when removing
            the physical file/folder fails.
    """
    # Locate the record
    db_file = file_service.get_file_by_id(db, file_id=file_id)
    if not db_file:
        api_logger.warning(f"The file does not exist or you do not have permission to access it: file_id={file_id}")
        raise HTTPException(
            status_code=status.HTTP_404_NOT_FOUND,
            detail="The file does not exist or you do not have permission to access it"
        )

    # Work out the physical location
    if db_file.file_ext == 'folder':
        file_path = Path(settings.FILE_PATH, str(db_file.kb_id), str(db_file.id))
    else:
        file_path = Path(
            settings.FILE_PATH,
            str(db_file.kb_id),
            str(db_file.parent_id),
            f"{db_file.id}{db_file.file_ext}"
        )

    # Remove it from disk; a path that is already gone is not an error
    try:
        if file_path.exists():
            if db_file.file_ext != 'folder':
                file_path.unlink()  # Single file
            else:
                shutil.rmtree(file_path)  # Folder and everything inside it
    except Exception as e:
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail=f"Failed to delete physical file/folder: {str(e)}"
        )

    # Remove the database rows: the folder's children first, then the record itself
    if db_file.file_ext == 'folder':
        children = db.query(file_model.File).filter(file_model.File.parent_id == db_file.id)
        children.delete()
    db.delete(db_file)
    db.commit()
|
||||||
305
app/controllers/knowledge_controller.py
Normal file
305
app/controllers/knowledge_controller.py
Normal file
@@ -0,0 +1,305 @@
|
|||||||
|
from typing import Optional
|
||||||
|
import datetime
|
||||||
|
import uuid
|
||||||
|
from fastapi import APIRouter, Depends, HTTPException, status, Query
|
||||||
|
from sqlalchemy import or_
|
||||||
|
from sqlalchemy.orm import Session
|
||||||
|
|
||||||
|
from app.db import get_db
|
||||||
|
from app.dependencies import get_current_user
|
||||||
|
from app.models.user_model import User
|
||||||
|
from app.models import knowledge_model, document_model, file_model
|
||||||
|
from app.schemas import knowledge_schema
|
||||||
|
from app.schemas.response_schema import ApiResponse
|
||||||
|
from app.core.response_utils import success
|
||||||
|
from app.services import knowledge_service, document_service
|
||||||
|
from app.core.rag.vdb.elasticsearch.elasticsearch_vector import ElasticSearchVectorFactory
|
||||||
|
from app.core.logging_config import get_api_logger
|
||||||
|
|
||||||
|
# Logger shared by every endpoint in this controller
api_logger = get_api_logger()

# All knowledge base routes live under /knowledges and require an authenticated user
router = APIRouter(
    dependencies=[Depends(get_current_user)],  # Apply auth to all routes in this controller
    tags=["knowledges"],
    prefix="/knowledges",
)
|
||||||
|
|
||||||
|
|
||||||
|
@router.get("/knowledgetype", response_model=ApiResponse)
def get_knowledge_types():
    """Return every member of ``knowledge_model.KnowledgeType``."""
    knowledge_types = list(knowledge_model.KnowledgeType)
    return success(data=knowledge_types, msg="Successfully obtained the knowledge type")
|
||||||
|
|
||||||
|
|
||||||
|
@router.get("/permissiontype", response_model=ApiResponse)
def get_permission_types():
    """Return every member of ``knowledge_model.PermissionType``."""
    permission_types = list(knowledge_model.PermissionType)
    return success(data=permission_types, msg="Successfully obtained the knowledge permission type")
|
||||||
|
|
||||||
|
|
||||||
|
@router.get("/parsertype", response_model=ApiResponse)
def get_parser_types():
    """Return every member of ``knowledge_model.ParserType``."""
    parser_types = list(knowledge_model.ParserType)
    return success(data=parser_types, msg="Successfully obtained the knowledge parser type")
|
||||||
|
|
||||||
|
|
||||||
|
@router.get("/knowledges", response_model=ApiResponse)
async def get_knowledges(
    parent_id: Optional[uuid.UUID] = Query(None, description="parent folder id"),
    page: int = Query(1, gt=0),  # Default: 1, which must be greater than 0
    pagesize: int = Query(20, gt=0, le=100),  # Default: 20 items per page, maximum: 100 items
    orderby: Optional[str] = Query(None, description="Sort fields, such as: created_at,updated_at"),
    desc: Optional[bool] = Query(False, description="Is it descending order"),
    keywords: Optional[str] = Query(None, description="Search keywords (knowledge base name)"),
    kb_ids: Optional[str] = Query(None, description="Knowledge base ids, separated by commas"),
    db: Session = Depends(get_db),
    current_user: User = Depends(get_current_user)
):
    """
    Query the knowledge base list in pages
    - Always restricted to the current user's workspace
    - Support filtering by parent_id
    - Support keyword search on knowledge base name and description
    - Support selecting explicit ids through kb_ids; when kb_ids is absent,
      knowledge bases with status 2 (soft-deleted) are excluded
    - Support dynamic sorting
    - Return paging metadata + item list

    Raises:
        HTTPException: 400 for invalid paging parameters or a malformed
            ``kb_ids`` value; 500 when the query itself fails.
    """
    api_logger.info(f"Query knowledge base list: workspace_id={current_user.current_workspace_id}, page={page}, pagesize={pagesize}, keywords={keywords}, kb_ids={kb_ids}, username: {current_user.username}")

    # 1. parameter validation (Query() enforces this for HTTP requests; kept for direct callers)
    if page < 1 or pagesize < 1:
        api_logger.warning(f"Error in paging parameters: page={page}, pagesize={pagesize}")
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail="The paging parameter must be greater than 0"
        )

    # 2. Construct query conditions
    filters = [
        knowledge_model.Knowledge.workspace_id == current_user.current_workspace_id
    ]
    if parent_id:
        filters.append(knowledge_model.Knowledge.parent_id == parent_id)

    # Keyword search (fuzzy matching of knowledge base name or description)
    if keywords:
        api_logger.debug(f"Add keyword search criteria: {keywords}")
        filters.append(
            or_(
                knowledge_model.Knowledge.name.ilike(f"%{keywords}%"),
                knowledge_model.Knowledge.description.ilike(f"%{keywords}%")
            )
        )

    # Knowledge base ids
    if kb_ids:
        # Validate here so that a malformed id is reported as a client error
        # instead of surfacing later as a database failure (HTTP 500).
        # Surrounding whitespace and empty segments (e.g. a trailing comma) are ignored.
        try:
            kb_id_list = [uuid.UUID(part.strip()) for part in kb_ids.split(',') if part.strip()]
        except ValueError:
            api_logger.warning(f"Invalid kb_ids parameter: {kb_ids}")
            raise HTTPException(
                status_code=status.HTTP_400_BAD_REQUEST,
                detail="kb_ids must be a comma-separated list of valid UUIDs"
            )
        filters.append(knowledge_model.Knowledge.id.in_(kb_id_list))
    else:
        filters.append(knowledge_model.Knowledge.status != 2)

    # 3. Execute paged query
    try:
        api_logger.debug("Start executing knowledge base paging query")
        total, items = knowledge_service.get_knowledges_paginated(
            db=db,
            filters=filters,
            page=page,
            pagesize=pagesize,
            orderby=orderby,
            desc=desc,
            current_user=current_user
        )
        api_logger.info(f"Knowledge base query successful: total={total}, returned={len(items)} records")
    except Exception as e:
        api_logger.error(f"Knowledge base query failed: {str(e)}")
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail=f"Query failed: {str(e)}"
        )

    # 4. Return structured response
    result = {
        "items": items,
        "page": {
            "page": page,
            "pagesize": pagesize,
            "total": total,
            "has_next": page * pagesize < total
        }
    }
    return success(data=result, msg="Query of knowledge base list successful")
|
||||||
|
|
||||||
|
|
||||||
|
@router.post("/knowledge", response_model=ApiResponse)
async def create_knowledge(
    create_data: knowledge_schema.KnowledgeCreate,
    db: Session = Depends(get_db),
    current_user: User = Depends(get_current_user)
):
    """
    Create a knowledge base.

    The name is first looked up through ``knowledge_service.get_knowledge_by_name``;
    if a match is found the request is rejected.

    Args:
        create_data: Attributes of the knowledge base to create.
        db: Database session.
        current_user: Authenticated user.

    Raises:
        HTTPException: 400 when a knowledge base with the same name already exists.
    """
    api_logger.info(f"Request to create a knowledge base: name={create_data.name}, workspace_id={current_user.current_workspace_id}, username: {current_user.username}")

    try:
        api_logger.debug(f"Start creating the knowledge base: {create_data.name}")
        # 1. Check if the knowledge base name already exists
        db_knowledge_exist = knowledge_service.get_knowledge_by_name(db, name=create_data.name, current_user=current_user)
        if db_knowledge_exist:
            api_logger.warning(f"The knowledge base name already exists: {create_data.name}")
            raise HTTPException(
                status_code=status.HTTP_400_BAD_REQUEST,
                detail=f"The knowledge base name already exists: {create_data.name}"
            )
        # 2. Create the knowledge base
        db_knowledge = knowledge_service.create_knowledge(db=db, knowledge=create_data, current_user=current_user)
        api_logger.info(f"The knowledge base has been successfully created: {db_knowledge.name} (ID: {db_knowledge.id})")
        return success(data=knowledge_schema.Knowledge.model_validate(db_knowledge), msg="The knowledge base has been successfully created")
    except HTTPException:
        # Expected client errors were already logged as warnings above;
        # do not report them again as creation failures.
        raise
    except Exception as e:
        api_logger.error(f"The creation of the knowledge base failed: {create_data.name} - {str(e)}")
        raise
|
||||||
|
|
||||||
|
|
||||||
|
@router.get("/{knowledge_id}", response_model=ApiResponse)
async def get_knowledge(
    knowledge_id: uuid.UUID,
    db: Session = Depends(get_db),
    current_user: User = Depends(get_current_user)
):
    """
    Return the details of one knowledge base.

    Raises:
        HTTPException: 404 when the lookup returns nothing for ``knowledge_id``.
    """
    api_logger.info(f"Obtain details of the knowledge base: knowledge_id={knowledge_id}, username: {current_user.username}")

    try:
        api_logger.debug(f"Query knowledge base: {knowledge_id}")
        record = knowledge_service.get_knowledge_by_id(db, knowledge_id=knowledge_id, current_user=current_user)

        # Happy path first: serialise and return the record
        if record:
            api_logger.info(f"Knowledge base query successful: {record.name} (ID: {record.id})")
            payload = knowledge_schema.Knowledge.model_validate(record)
            return success(data=payload, msg="Successfully obtained knowledge base information")

        # Nothing found
        api_logger.warning(f"The knowledge base does not exist or access is denied: knowledge_id={knowledge_id}")
        raise HTTPException(
            status_code=status.HTTP_404_NOT_FOUND,
            detail="The knowledge base does not exist or access is denied"
        )
    except HTTPException:
        # Already logged; pass through unchanged
        raise
    except Exception as err:
        api_logger.error(f"Knowledge base query failed: knowledge_id={knowledge_id} - {str(err)}")
        raise
|
||||||
|
|
||||||
|
|
||||||
|
@router.put("/{knowledge_id}", response_model=ApiResponse)
async def update_knowledge(
    knowledge_id: uuid.UUID,
    update_data: knowledge_schema.KnowledgeUpdate,
    db: Session = Depends(get_db),
    current_user: User = Depends(get_current_user)
):
    """
    Update a knowledge base.

    Logs the request, delegates the work to ``_update_knowledge`` and wraps the
    updated record in the standard success response.
    """
    api_logger.info(f"Update knowledge base request: knowledge_id={knowledge_id}, username: {current_user.username}")

    updated = await _update_knowledge(
        knowledge_id=knowledge_id,
        update_data=update_data,
        db=db,
        current_user=current_user,
    )
    payload = knowledge_schema.Knowledge.model_validate(updated)
    return success(data=payload, msg="The knowledge base information has been successfully updated")
|
||||||
|
|
||||||
|
|
||||||
|
async def _update_knowledge(
    knowledge_id: uuid.UUID,
    update_data: knowledge_schema.KnowledgeUpdate,
    db: Session = Depends(get_db),
    current_user: User = Depends(get_current_user)
) -> knowledge_schema.Knowledge:
    """
    Update knowledge base information

    - Only fields explicitly set in ``update_data`` are applied
    - Renaming is rejected when the new name is already taken
    - Changing ``embedding_id`` deletes the knowledge base's vector index and
      resets the parsing progress of its documents

    Args:
        knowledge_id: Identifier of the knowledge base to update.
        update_data: Partial update payload.
        db: Database session.
        current_user: Authenticated user.

    Returns:
        The refreshed knowledge base record (the caller converts it to the
        response schema with ``model_validate``).

    Raises:
        HTTPException: 404 when the knowledge base is not found; 400 when the
            new name already exists; 500 for any other failure (the
            transaction is rolled back first).
    """
    try:
        # 1. Check whether the knowledge base exists
        api_logger.debug(f"Query the knowledge base to be updated: {knowledge_id}")
        db_knowledge = knowledge_service.get_knowledge_by_id(db, knowledge_id=knowledge_id, current_user=current_user)

        if not db_knowledge:
            api_logger.warning(f"The knowledge base does not exist or you do not have permission to access it: knowledge_id={knowledge_id}")
            raise HTTPException(
                status_code=status.HTTP_404_NOT_FOUND,
                detail="The knowledge base does not exist or you do not have permission to access it"
            )

        # Dump the payload once; exclude_unset keeps omitted fields from
        # overwriting stored values with their defaults.
        update_dict = update_data.model_dump(exclude_unset=True)

        # 2. Renaming: the new name must not already be in use
        if "name" in update_dict:
            name = update_dict["name"]
            if name != db_knowledge.name:
                # Check if the knowledge base name already exists
                db_knowledge_exist = knowledge_service.get_knowledge_by_name(db, name=name, current_user=current_user)
                if db_knowledge_exist:
                    api_logger.warning(f"The knowledge base name already exists: {name}")
                    raise HTTPException(
                        status_code=status.HTTP_400_BAD_REQUEST,
                        detail=f"The knowledge base name already exists: {name}"
                    )

        # 3. If updating the embedding_id, delete the knowledge base vector index
        #    and reset the parsing progress of all its documents
        if "embedding_id" in update_dict:
            embedding_id = update_dict["embedding_id"]
            if embedding_id != db_knowledge.embedding_id:
                vector_service = ElasticSearchVectorFactory().init_vector(knowledge=db_knowledge)
                vector_service.delete()
                document_service.reset_documents_progress_by_kb_id(db, kb_id=db_knowledge.id, current_user=current_user)

        # 4. Update fields (only fields that exist on the record and actually changed)
        api_logger.debug(f"Start updating the knowledge base fields: {knowledge_id}")
        updated_fields = []
        for field, value in update_dict.items():
            if hasattr(db_knowledge, field):
                old_value = getattr(db_knowledge, field)
                if old_value != value:
                    # update value
                    setattr(db_knowledge, field, value)
                    updated_fields.append(f"{field}: {old_value} -> {value}")

        if updated_fields:
            api_logger.debug(f"updated fields: {', '.join(updated_fields)}")

        db_knowledge.updated_at = datetime.datetime.now()

        # 5. Save to database
        db.commit()
        db.refresh(db_knowledge)
        api_logger.info(f"The knowledge base has been successfully updated: {db_knowledge.name} (ID: {db_knowledge.id})")

        # 6. Return the updated knowledge base
        return db_knowledge
    except HTTPException:
        raise
    except Exception as e:
        db.rollback()
        api_logger.error(f"Knowledge base update failed: knowledge_id={knowledge_id} - {str(e)}")
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail=f"Knowledge base update failed: {str(e)}"
        )
|
||||||
|
|
||||||
|
|
||||||
|
@router.delete("/{knowledge_id}", response_model=ApiResponse)
async def delete_knowledge(
    knowledge_id: uuid.UUID,
    db: Session = Depends(get_db),
    current_user: User = Depends(get_current_user)
):
    """
    Soft-delete knowledge base

    The record is kept; its ``status`` is set to 2, the value the list
    endpoint filters out when no explicit ids are requested.

    Args:
        knowledge_id: Identifier of the knowledge base to delete.
        db: Database session.
        current_user: Authenticated user.

    Raises:
        HTTPException: 404 when the knowledge base is not found.
    """
    api_logger.info(f"Request to delete knowledge base: knowledge_id={knowledge_id}, username: {current_user.username}")

    try:
        # 1. Check whether the knowledge base exists
        api_logger.debug(f"Check whether the knowledge base exists: {knowledge_id}")
        db_knowledge = knowledge_service.get_knowledge_by_id(db, knowledge_id=knowledge_id, current_user=current_user)

        if not db_knowledge:
            api_logger.warning(f"The knowledge base does not exist or you do not have permission to access it: knowledge_id={knowledge_id}")
            raise HTTPException(
                status_code=status.HTTP_404_NOT_FOUND,
                detail="The knowledge base does not exist or you do not have permission to access it"
            )

        # 2. Soft-delete knowledge base
        api_logger.debug(f"Perform a soft delete: {db_knowledge.name} (ID: {knowledge_id})")
        db_knowledge.status = 2
        db.commit()
        api_logger.info(f"The knowledge base has been successfully deleted: {db_knowledge.name} (ID: {knowledge_id})")
        return success(msg="The knowledge base has been successfully deleted")
    except HTTPException:
        # Not-found was already logged as a warning; it is not a delete failure
        raise
    except Exception as e:
        # Leave the session usable after a failed commit, then propagate
        db.rollback()
        api_logger.error(f"Failed to delete from the knowledge base: knowledge_id={knowledge_id} - {str(e)}")
        raise
|
||||||
199
app/controllers/knowledgeshare_controller.py
Normal file
199
app/controllers/knowledgeshare_controller.py
Normal file
@@ -0,0 +1,199 @@
|
|||||||
|
from typing import Optional
|
||||||
|
import uuid
|
||||||
|
from fastapi import APIRouter, Depends, HTTPException, status, Query
|
||||||
|
from sqlalchemy.orm import Session
|
||||||
|
|
||||||
|
from app.db import get_db
|
||||||
|
from app.dependencies import get_current_user
|
||||||
|
from app.models.user_model import User
|
||||||
|
from app.models import knowledgeshare_model, knowledge_model
|
||||||
|
from app.schemas import knowledgeshare_schema, knowledge_schema
|
||||||
|
from app.schemas.response_schema import ApiResponse
|
||||||
|
from app.core.response_utils import success
|
||||||
|
from app.services import knowledgeshare_service, knowledge_service
|
||||||
|
from app.core.logging_config import get_api_logger
|
||||||
|
|
||||||
|
# Logger shared by every endpoint in this controller
api_logger = get_api_logger()

# All knowledge base sharing routes live under /knowledgeshares and require an authenticated user
router = APIRouter(
    dependencies=[Depends(get_current_user)],  # Apply auth to all routes in this controller
    tags=["knowledgeshares"],
    prefix="/knowledgeshares",
)
|
||||||
|
|
||||||
|
|
||||||
|
@router.get("/{kb_id}/knowledgeshares", response_model=ApiResponse)
async def get_knowledgeshares(
    kb_id: uuid.UUID,
    page: int = Query(1, gt=0),  # Default: 1, which must be greater than 0
    pagesize: int = Query(20, gt=0, le=100),  # Default: 20 items per page, maximum: 100 items
    orderby: Optional[str] = Query(None, description="Sort fields, such as: created_at,updated_at"),
    desc: Optional[bool] = Query(False, description="Is it descending order"),
    db: Session = Depends(get_db),
    current_user: User = Depends(get_current_user)
):
    """
    List the shares of one knowledge base, page by page.

    Only shares whose source workspace is the caller's current workspace and
    whose source knowledge base is ``kb_id`` are returned. The response holds
    the items plus paging metadata (page, pagesize, total, has_next).

    Raises:
        HTTPException: 400 for invalid paging parameters; 500 when the query fails.
    """
    api_logger.info(
        f"Query knowledge base sharing list: workspace_id={current_user.current_workspace_id}, kb_id={kb_id}, page={page}, pagesize={pagesize}, username: {current_user.username}")

    # Guard against invalid paging values
    if page < 1 or pagesize < 1:
        api_logger.warning(f"Error in paging parameters: page={page}, pagesize={pagesize}")
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail="The paging parameter must be greater than 0"
        )

    # Restrict to shares originating from this workspace and knowledge base
    share_model = knowledgeshare_model.KnowledgeShare
    filters = [
        share_model.source_workspace_id == current_user.current_workspace_id,
        share_model.source_kb_id == kb_id,
    ]

    # Run the paged query
    try:
        api_logger.debug("Start executing knowledge base sharing and paging query")
        total, items = knowledgeshare_service.get_knowledgeshares_paginated(
            db=db,
            filters=filters,
            page=page,
            pagesize=pagesize,
            orderby=orderby,
            desc=desc,
            current_user=current_user
        )
        api_logger.info(f"Knowledge base sharing query successful: total={total}, returned={len(items)} records")
    except Exception as err:
        api_logger.error(f"Knowledge base sharing query failed: {str(err)}")
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail=f"Query failed: {str(err)}"
        )

    # Assemble the response: items plus paging metadata
    paging = {
        "page": page,
        "pagesize": pagesize,
        "total": total,
        "has_next": page * pagesize < total,
    }
    return success(data={"items": items, "page": paging}, msg="Query of knowledge base sharing list successful")
|
||||||
|
|
||||||
|
|
||||||
|
@router.post("/knowledgeshare", response_model=ApiResponse)
async def create_knowledgeshare(
    create_data: knowledgeshare_schema.KnowledgeShareCreate,
    db: Session = Depends(get_db),
    current_user: User = Depends(get_current_user)
):
    """
    Share a knowledge base with another workspace.

    A knowledge base with ``permission_id=PermissionType.Share`` is created in
    the target workspace, copying the source knowledge base's attributes. A
    share record linking source and target is then created.

    Args:
        create_data: Share request; ``target_kb_id`` is filled in by this
            handler with the id of the newly created knowledge base.
        db: Database session.
        current_user: Authenticated user.

    Raises:
        HTTPException: 404 when the source knowledge base is not found.
    """
    api_logger.info(
        f"Create a knowledge base sharing request: source_kb_id={create_data.source_kb_id}, source_workspace_id={current_user.current_workspace_id}, username: {current_user.username}")

    try:
        # 1. Load the source knowledge base; without this check a missing source
        #    would fail below with an AttributeError and surface as HTTP 500.
        db_knowledge = knowledge_service.get_knowledge_by_id(db, knowledge_id=create_data.source_kb_id, current_user=current_user)
        if not db_knowledge:
            api_logger.warning(f"The knowledge base does not exist or you do not have permission to access it: knowledge_id={create_data.source_kb_id}")
            raise HTTPException(
                status_code=status.HTTP_404_NOT_FOUND,
                detail="The knowledge base does not exist or you do not have permission to access it"
            )

        # 2. Create a knowledge base with permission_id=knowledge_model.PermissionType.Share
        knowledge = knowledge_schema.KnowledgeCreate(
            workspace_id=create_data.target_workspace_id,
            created_by=current_user.id,
            # NOTE(review): parent_id is set to the target workspace id — presumably the
            # convention for a top-level knowledge base; confirm.
            parent_id=create_data.target_workspace_id,
            name=db_knowledge.name,
            description=db_knowledge.description,
            avatar=db_knowledge.avatar,
            type=db_knowledge.type,
            permission_id=knowledge_model.PermissionType.Share,
            embedding_id=db_knowledge.embedding_id,
            reranker_id=db_knowledge.reranker_id,
            llm_id=db_knowledge.llm_id,
            image2text_id=db_knowledge.image2text_id,
            doc_num=db_knowledge.doc_num,
            chunk_num=db_knowledge.chunk_num,
            parser_id=db_knowledge.parser_id,
            parser_config=db_knowledge.parser_config
        )
        db_knowledge = knowledge_service.create_knowledge(db=db, knowledge=knowledge, current_user=current_user)

        # 3. Create the share record pointing at the new knowledge base
        api_logger.debug(f"Start creating the knowledge base sharing: {db_knowledge.name}")
        create_data.target_kb_id = db_knowledge.id
        db_knowledgeshare = knowledgeshare_service.create_knowledgeshare(db=db, knowledgeshare=create_data, current_user=current_user)
        api_logger.info(f"The knowledge base sharing has been successfully created: (ID: {db_knowledgeshare.id})")
        return success(data=knowledgeshare_schema.KnowledgeShare.model_validate(db_knowledgeshare), msg="The knowledge base sharing has been successfully created")
    except HTTPException:
        # Already logged as a warning; not a creation failure
        raise
    except Exception as e:
        api_logger.error(f"The creation of the knowledge base sharing failed: {str(e)}")
        raise
|
||||||
|
|
||||||
|
|
||||||
|
@router.get("/{knowledgeshare_id}", response_model=ApiResponse)
async def get_knowledgeshare(
    knowledgeshare_id: uuid.UUID,
    db: Session = Depends(get_db),
    current_user: User = Depends(get_current_user)
):
    """
    Return the details of one knowledge base share.

    Raises:
        HTTPException: 404 when the lookup returns nothing for ``knowledgeshare_id``.
    """
    api_logger.info(f"Obtain details of the knowledge base sharing: knowledgeshare_id={knowledgeshare_id}, username: {current_user.username}")

    try:
        api_logger.debug(f"Query knowledge base sharing: {knowledgeshare_id}")
        share = knowledgeshare_service.get_knowledgeshare_by_id(db, knowledgeshare_id=knowledgeshare_id, current_user=current_user)

        # Happy path first: serialise and return the share
        if share:
            api_logger.info(f"Knowledge base sharing query successful: (ID: {share.id})")
            payload = knowledgeshare_schema.KnowledgeShare.model_validate(share)
            return success(data=payload, msg="Successfully obtained knowledge base sharing information")

        # Nothing found
        api_logger.warning(f"The knowledge base sharing does not exist or access is denied: knowledgeshare_id={knowledgeshare_id}")
        raise HTTPException(
            status_code=status.HTTP_404_NOT_FOUND,
            detail="The knowledge base sharing does not exist or access is denied"
        )
    except HTTPException:
        # Already logged; pass through unchanged
        raise
    except Exception as err:
        api_logger.error(f"Knowledge base sharing query failed: knowledgeshare_id={knowledgeshare_id} - {str(err)}")
        raise
|
||||||
|
|
||||||
|
|
||||||
|
@router.delete("/{knowledgeshare_id}", response_model=ApiResponse)
async def delete_knowledgeshare(
    knowledgeshare_id: uuid.UUID,
    db: Session = Depends(get_db),
    current_user: User = Depends(get_current_user)
):
    """
    Delete a knowledge base sharing record together with its target knowledge base.

    Args:
        knowledgeshare_id: Identifier of the sharing record to delete.
        db: Database session provided by the ``get_db`` dependency.
        current_user: Authenticated user; forwarded to the service calls.

    Returns:
        A success response with a confirmation message.

    Raises:
        HTTPException: 404 when the sharing record cannot be found.
    """
    api_logger.info(f"Delete knowledge base sharing request: knowledgeshare_id={knowledgeshare_id}, username: {current_user.username}")

    try:
        # 1. Query knowledge base sharing information from the database
        api_logger.debug(f"Query knowledge base sharing: {knowledgeshare_id}")
        db_knowledgeshare = knowledgeshare_service.get_knowledgeshare_by_id(db, knowledgeshare_id=knowledgeshare_id, current_user=current_user)
        if not db_knowledgeshare:
            api_logger.warning(f"The knowledge base sharing does not exist or access is denied: knowledgeshare_id={knowledgeshare_id}")
            raise HTTPException(
                status_code=status.HTTP_404_NOT_FOUND,
                detail="The knowledge base sharing does not exist or access is denied"
            )

        # 2. Delete the knowledge base that the share points at (target_kb_id)
        knowledge_service.delete_knowledge_by_id(db, knowledge_id=db_knowledgeshare.target_kb_id, current_user=current_user)

        # 3. Delete the sharing record itself
        api_logger.debug(f"perform knowledge base sharing delete: (ID: {knowledgeshare_id})")
        knowledgeshare_service.delete_knowledgeshare_by_id(db, knowledgeshare_id=knowledgeshare_id, current_user=current_user)
        api_logger.info(f"The knowledge base sharing has been successfully deleted: (ID: {knowledgeshare_id})")
        return success(msg="The knowledge base sharing has been successfully deleted")
    except HTTPException:
        # A 404 (already logged as a warning above) is not a deletion failure;
        # re-raise it untouched, matching get_knowledgeshare's handling.
        raise
    except Exception as e:
        api_logger.error(f"Failed to delete from the knowledge base sharing: knowledgeshare_id={knowledgeshare_id} - {str(e)}")
        raise
|
||||||
802
app/controllers/memory_agent_controller.py
Normal file
802
app/controllers/memory_agent_controller.py
Normal file
@@ -0,0 +1,802 @@
|
|||||||
|
import json
|
||||||
|
import time
|
||||||
|
from typing import Optional, List
|
||||||
|
from fastapi import APIRouter, Depends, Query, UploadFile
|
||||||
|
from sqlalchemy.orm import Session
|
||||||
|
from starlette.responses import StreamingResponse
|
||||||
|
from app.db import get_db
|
||||||
|
from app.core.memory.utils.config.config_utils import get_model_config
|
||||||
|
from app.core.rag.llm.cv_model import QWenCV
|
||||||
|
from app.models import ModelApiKey, Knowledge
|
||||||
|
from app.services.memory_agent_service import MemoryAgentService
|
||||||
|
from app.dependencies import get_current_superuser, get_current_user, get_current_tenant, workspace_access_guard, cur_workspace_access_guard
|
||||||
|
from app.celery_app import celery_app
|
||||||
|
from app.core.logging_config import get_api_logger
|
||||||
|
from app.core.response_utils import success, fail
|
||||||
|
from app.core.error_codes import BizCode
|
||||||
|
from app.services import task_service, workspace_service
|
||||||
|
from app.schemas.memory_agent_schema import UserInput, Write_UserInput
|
||||||
|
from app.schemas.response_schema import ApiResponse
|
||||||
|
from app.dependencies import get_current_user
|
||||||
|
from app.models.user_model import User
|
||||||
|
from fastapi import APIRouter, Depends, File, UploadFile, Form
|
||||||
|
from app.repositories import knowledge_repository
|
||||||
|
from app.services.model_service import ModelConfigService
|
||||||
|
from dotenv import load_dotenv
|
||||||
|
import os
|
||||||
|
|
||||||
|
# Load the .env file before anything below is initialised.
load_dotenv()

# Logger used by every endpoint in this module.
api_logger = get_api_logger()

# Service instance the endpoints delegate to.
memory_agent_service = MemoryAgentService()

# Router for this module: routes are prefixed with /memory and tagged "Memory".
router = APIRouter(prefix="/memory", tags=["Memory"])
|
||||||
|
|
||||||
|
|
||||||
|
def _validate_model_reference(db, model_cls, model_id, config_id, *, label, error_label, id_field):
    """
    Check that the model referenced by a configuration exists and is active.

    Args:
        db: Database session used for the lookup.
        model_cls: ORM class to query (queried by its ``id`` column).
        model_id: Identifier stored on the configuration; falsy means "not set".
        config_id: Owning configuration ID, used only in messages.
        label: Name used in the "does not exist" / "is not active" / debug messages.
        error_label: Name used in the "Error validating ... model" message.
        id_field: Configuration attribute name, used in messages.

    Raises:
        ValueError: If the reference is unset, missing, inactive, or the lookup fails.
    """
    if not model_id:
        error_msg = f"Config {config_id} has no {id_field} set"
        api_logger.error(error_msg)
        raise ValueError(error_msg)

    try:
        model_config = db.query(model_cls).filter(model_cls.id == model_id).first()
        if model_config is None:
            error_msg = f"{label} model with id={model_id} (from config_id={config_id}) does not exist"
            api_logger.error(error_msg)
            raise ValueError(error_msg)
        if not model_config.is_active:
            error_msg = f"{label} model with id={model_id} (from config_id={config_id}) is not active"
            api_logger.error(error_msg)
            raise ValueError(error_msg)
        api_logger.debug(f"{label} validation successful: {id_field}={model_id}, name={model_config.name}")
    except ValueError:
        # Validation failures raised above already carry the final message.
        raise
    except Exception as e:
        error_msg = f"Error validating {error_label} model: {str(e)}"
        api_logger.error(error_msg, exc_info=True)
        # Keep the original exception attached as the cause.
        raise ValueError(error_msg) from e


def validate_config_id(config_id: int, db: Session) -> int:
    """
    Validate and ensure config_id is available, valid, and exists in database.

    When ``config_id`` is None, the ``config_id`` environment variable is used
    as a fallback. The configuration row must exist, and both its ``llm_id``
    and ``embedding_id`` must reference existing, active models.

    Args:
        config_id: Configuration ID to validate
        db: Database session for checking existence

    Returns:
        The validated config_id (the environment fallback value when the
        argument was None).

    Raises:
        ValueError: If config_id is None, invalid, or doesn't exist in database
    """
    if config_id is None:
        api_logger.info("config_id is required but was not provided")
        # NOTE(review): os.getenv returns a str, so the fallback value is not an
        # int despite the annotation -- confirm the column type tolerates this.
        config_id = os.getenv('config_id')
        if config_id is None:
            raise ValueError("config_id is required but was not provided")

    # Check if config exists in database
    try:
        from app.models.data_config_model import DataConfig
        from app.models.models_model import ModelConfig

        config = db.query(DataConfig).filter(DataConfig.config_id == config_id).first()
        if config is None:
            error_msg = f"Configuration with config_id={config_id} does not exist in database"
            api_logger.error(error_msg)
            raise ValueError(error_msg)

        # Validate llm_id exists and is usable
        _validate_model_reference(
            db, ModelConfig, config.llm_id, config_id,
            label="LLM", error_label="LLM", id_field="llm_id",
        )

        # Validate embedding_id exists and is usable
        _validate_model_reference(
            db, ModelConfig, config.embedding_id, config_id,
            label="Embedding", error_label="embedding", id_field="embedding_id",
        )

        api_logger.info(f"Config validation successful: config_id={config_id}, config_name={config.config_name}, llm_id={config.llm_id}, embedding_id={config.embedding_id}")
        return config_id
    except ValueError:
        # Re-raise ValueError from above
        raise
    except Exception as e:
        error_msg = f"Database error while validating config_id={config_id}: {str(e)}"
        api_logger.error(error_msg, exc_info=True)
        raise ValueError(error_msg) from e
|
||||||
|
|
||||||
|
|
||||||
|
@router.get("/health/status", response_model=ApiResponse)
async def get_health_status(
    current_user: User = Depends(get_current_user),
):
    """
    Report the latest health status known to the memory agent service.

    Per the original author's note, the status is written by a Celery periodic
    task and read back from a Redis cache; this endpoint only asks the service
    for it and returns the ``"status"`` entry of the result.

    Returns:
        A success response carrying the status, or a SERVICE_UNAVAILABLE
        failure response if anything goes wrong.
    """
    api_logger.info("Health status check requested")
    try:
        health = await memory_agent_service.get_health_status()
        status_payload = health["status"]
        return success(data=status_payload)
    except Exception as exc:
        api_logger.error(f"Health status check failed: {str(exc)}")
        return fail(BizCode.SERVICE_UNAVAILABLE, "健康状态查询失败", str(exc))
|
||||||
|
|
||||||
|
|
||||||
|
@router.get("/download_log")
async def download_log(
    log_type: str = Query("file", pattern="^(file|transmission)$", description="日志类型: file=完整文件, transmission=实时流式传输"),
    current_user: User = Depends(get_current_user)
):
    """
    Download or stream agent service log file

    Args:
        log_type: Log retrieval mode
            - "file": Returns complete log file content in single response (default)
            - "transmission": Real-time streaming of log content using Server-Sent Events
        current_user: Authenticated user resolved by the dependency

    Returns:
        - file mode: ApiResponse with log content
        - transmission mode: StreamingResponse with media type text/event-stream
    """
    api_logger.info(f"Log download requested with log_type={log_type}")

    # The Query pattern already rejects other values for HTTP requests; this
    # explicit check still covers direct calls that bypass FastAPI validation.
    if log_type not in ["file", "transmission"]:
        api_logger.warning(f"Invalid log_type parameter: {log_type}")
        return fail(
            BizCode.BAD_REQUEST,
            "无效的log_type参数",
            "log_type必须是'file'或'transmission'"
        )

    # Route to appropriate mode
    if log_type == "file":
        # File mode: Return complete log file content
        try:
            log_content = memory_agent_service.get_log_content()
            return success(data=log_content)
        except ValueError as e:
            # The service signals a log content problem with ValueError.
            api_logger.warning(f"Log content issue: {str(e)}")
            return fail(BizCode.FILE_NOT_FOUND, str(e))
        except Exception as e:
            api_logger.error(f"Log reading failed: {str(e)}")
            return fail(BizCode.FILE_READ_ERROR, "日志读取失败", str(e))

    else:  # log_type == "transmission"
        # Transmission mode: Stream log content using SSE
        try:
            api_logger.info("Starting SSE log streaming")
            return StreamingResponse(
                memory_agent_service.stream_log_content(),
                media_type="text/event-stream",
                headers={
                    "Cache-Control": "no-cache",
                    "Connection": "keep-alive",
                    "X-Accel-Buffering": "no"  # Disable nginx buffering
                }
            )
        except Exception as e:
            api_logger.error(f"Failed to start log streaming: {str(e)}")
            return fail(BizCode.INTERNAL_ERROR, "启动日志流式传输失败", str(e))
|
||||||
|
|
||||||
|
|
||||||
|
@router.post("/writer_service", response_model=ApiResponse)
@cur_workspace_access_guard()
async def write_server(
    user_input: Write_UserInput,
    db: Session = Depends(get_db),
    current_user: User = Depends(get_current_user),
):
    """
    Handle a write request in-process and return its result.

    Args:
        user_input: Write request; its ``message``, ``group_id`` and
            ``config_id`` fields are used here.
        db: Database session provided by the ``get_db`` dependency.
        current_user: Authenticated user; supplies the current workspace.

    Returns:
        A success response with the write result, or a failure response when
        the config is invalid or the write raises.
    """
    # Reject the request early when the configuration cannot be validated.
    try:
        config_id = validate_config_id(user_input.config_id, db)
    except ValueError as exc:
        return fail(BizCode.INVALID_PARAMETER, "配置ID无效", str(exc))

    workspace_id = current_user.current_workspace_id
    api_logger.info(f"Write service: workspace_id={workspace_id}, config_id={config_id}")

    # Resolve the workspace storage backend, defaulting to neo4j when unset.
    configured_storage = workspace_service.get_workspace_storage_type(
        db=db,
        workspace_id=workspace_id,
        user=current_user,
    )
    storage_type = 'neo4j' if configured_storage is None else configured_storage
    user_rag_memory_id = ''

    # rag storage needs a usable knowledge base id; otherwise fall back to neo4j.
    if storage_type == 'rag':
        if not workspace_id:
            api_logger.warning("workspace_id 为空,无法使用 rag 存储,将使用 neo4j 存储")
            storage_type = 'neo4j'
        else:
            rag_kb = knowledge_repository.get_knowledge_by_name(
                db=db,
                name="USER_RAG_MERORY",
                workspace_id=workspace_id,
            )
            if not rag_kb:
                api_logger.warning(f"未找到名为 'USER_RAG_MERORY' 的知识库,workspace_id: {workspace_id},将使用 neo4j 存储")
                storage_type = 'neo4j'
            else:
                user_rag_memory_id = str(rag_kb.id)

    api_logger.info(f"Write service requested for group {user_input.group_id}, storage_type: {storage_type}, user_rag_memory_id: {user_rag_memory_id}")
    try:
        write_result = await memory_agent_service.write_memory(
            user_input.group_id,
            user_input.message,
            config_id,
            storage_type,
            user_rag_memory_id,
        )
        return success(data=write_result, msg="写入成功")
    except Exception as exc:
        api_logger.error(f"Write operation error: {str(exc)}")
        return fail(BizCode.INTERNAL_ERROR, "写入失败", str(exc))
|
||||||
|
|
||||||
|
|
||||||
|
@router.post("/writer_service_async", response_model=ApiResponse)
@cur_workspace_access_guard()
async def write_server_async(
    user_input: Write_UserInput,
    db: Session = Depends(get_db),
    current_user: User = Depends(get_current_user),
):
    """
    Queue a write request on Celery instead of processing it in-process.

    Args:
        user_input: Write request; its ``message``, ``group_id`` and
            ``config_id`` fields are used here.
        db: Database session provided by the ``get_db`` dependency.
        current_user: Authenticated user; supplies the current workspace.

    Returns:
        A success response containing the Celery ``task_id``. Per the original
        notes, the task can be polled via GET /memory/write_result/.
    """
    # Reject the request early when the configuration cannot be validated.
    try:
        config_id = validate_config_id(user_input.config_id, db)
    except ValueError as exc:
        return fail(BizCode.INVALID_PARAMETER, "配置ID无效", str(exc))

    workspace_id = current_user.current_workspace_id
    api_logger.info(f"Async write service: workspace_id={workspace_id}, config_id={config_id}")

    # Resolve the workspace storage backend, defaulting to neo4j when unset.
    configured_storage = workspace_service.get_workspace_storage_type(
        db=db,
        workspace_id=workspace_id,
        user=current_user,
    )
    storage_type = 'neo4j' if configured_storage is None else configured_storage

    # Look up the workspace's RAG memory knowledge base, if there is one.
    user_rag_memory_id = ''
    if workspace_id:
        rag_kb = knowledge_repository.get_knowledge_by_name(
            db=db,
            name="USER_RAG_MERORY",
            workspace_id=workspace_id,
        )
        if rag_kb:
            user_rag_memory_id = str(rag_kb.id)

    api_logger.info(f"Async write: storage_type={storage_type}, user_rag_memory_id={user_rag_memory_id}")
    try:
        task_args = [
            user_input.group_id,
            user_input.message,
            config_id,
            storage_type,
            user_rag_memory_id,
        ]
        queued = celery_app.send_task("app.core.memory.agent.write_message", args=task_args)
        api_logger.info(f"Write task queued: {queued.id}")
        return success(data={"task_id": queued.id}, msg="写入任务已提交")
    except Exception as exc:
        api_logger.error(f"Async write operation failed: {str(exc)}")
        return fail(BizCode.INTERNAL_ERROR, "写入失败", str(exc))
|
||||||
|
|
||||||
|
|
||||||
|
@router.post("/read_service", response_model=ApiResponse)
@cur_workspace_access_guard()
async def read_server(
    user_input: UserInput,
    db: Session = Depends(get_db),
    current_user: User = Depends(get_current_user),
):
    """
    Handle a read (query) request in-process and return the answer.

    search_switch values (as documented by the original author):
    - "0": Requires verification
    - "1": No verification, direct split
    - "2": Direct answer based on context

    Args:
        user_input: Read request; its ``message``, ``history``,
            ``search_switch``, ``group_id`` and ``config_id`` fields are used.
        db: Database session provided by the ``get_db`` dependency.
        current_user: Authenticated user; supplies the current workspace.

    Returns:
        A success response with the read result, or a failure response when
        the config is invalid or the read raises.
    """
    # Reject the request early when the configuration cannot be validated.
    try:
        config_id = validate_config_id(user_input.config_id, db)
    except ValueError as exc:
        return fail(BizCode.INVALID_PARAMETER, "配置ID无效", str(exc))

    workspace_id = current_user.current_workspace_id
    api_logger.info(f"Read service: workspace_id={workspace_id}, config_id={config_id}")

    # Resolve the workspace storage backend, defaulting to neo4j when unset.
    configured_storage = workspace_service.get_workspace_storage_type(
        db=db,
        workspace_id=workspace_id,
        user=current_user,
    )
    storage_type = 'neo4j' if configured_storage is None else configured_storage

    # Look up the workspace's RAG memory knowledge base, if there is one.
    user_rag_memory_id = ''
    if workspace_id:
        rag_kb = knowledge_repository.get_knowledge_by_name(
            db=db,
            name="USER_RAG_MERORY",
            workspace_id=workspace_id,
        )
        if rag_kb:
            user_rag_memory_id = str(rag_kb.id)

    api_logger.info(f"Read service: group={user_input.group_id}, storage_type={storage_type}, user_rag_memory_id={user_rag_memory_id}, workspace_id={workspace_id}")
    try:
        answer = await memory_agent_service.read_memory(
            user_input.group_id,
            user_input.message,
            user_input.history,
            user_input.search_switch,
            config_id,
            storage_type,
            user_rag_memory_id,
        )
        return success(data=answer, msg="回复对话消息成功")
    except Exception as exc:
        api_logger.error(f"Read operation error: {str(exc)}")
        return fail(BizCode.INTERNAL_ERROR, "回复对话消息失败", str(exc))
|
||||||
|
|
||||||
|
|
||||||
|
@router.post("/file", response_model=ApiResponse)
async def file_update(
    files: List[UploadFile] = File(..., description="要上传的文件"),
    model_id:str = Form(..., description="模型ID"),
    metadata: Optional[str] = Form(None, description="文件元数据 (JSON格式)"),
    current_user: User = Depends(get_current_user)
):
    """
    File upload endpoint that converts uploaded images to text.

    Each file whose content type starts with ``image/`` is passed to a QWenCV
    vision model and its description collected; any other file contributes an
    "unsupported file type" placeholder. The pieces are joined with ``;``.

    Args:
        files: Uploaded files.
        model_id: ID of the model config whose first API key is used for the
            vision model.
        metadata: Optional file metadata (JSON string); currently not used here.
        current_user: Authenticated user resolved by the dependency.

    Returns:
        A success response with the joined text, or a failure response when
        the model has no API key or processing raises.
    """
    # get_db is consumed manually as a generator here, so it must also be
    # closed manually; otherwise its cleanup never runs and the session leaks.
    db_gen = get_db()
    db = next(db_gen)
    try:
        api_logger.info(f"File upload requested, file count: {len(files)}")
        config = ModelConfigService.get_model_by_id(db=db, model_id=model_id)
        # Guard against a missing config or an empty key list instead of
        # crashing with AttributeError / IndexError.
        if config is None or not config.api_keys:
            api_logger.warning(f"No API key configured for model_id={model_id}")
            return fail(BizCode.INVALID_PARAMETER, "模型配置无效", f"model_id={model_id} has no API key configured")
        apiConfig: ModelApiKey = config.api_keys[0]
        file_content = []
        try:
            for file in files:
                api_logger.debug(f"Processing file: {file.filename}, content_type: {file.content_type}")
                content = await file.read()

                if file.content_type and file.content_type.startswith("image/"):
                    vision_model = QWenCV(
                        key=apiConfig.api_key,
                        model_name=apiConfig.model_name,
                        lang="Chinese",
                        base_url=apiConfig.api_base
                    )
                    description, token_count = vision_model.describe(content)
                    file_content.append(description)
                    api_logger.info(f"Image processed: {file.filename}, tokens: {token_count}")
                else:
                    api_logger.warning(f"Unsupported file type: {file.content_type}")
                    file_content.append(f"[不支持的文件类型: {file.content_type}]")

            result_text = ';'.join(file_content)
            api_logger.info(f"File processing completed, result length: {len(result_text)}")

            return success(data=result_text, msg="转换文本成功")

        except Exception as e:
            api_logger.error(f"File processing failed: {str(e)}", exc_info=True)
            return fail(BizCode.INTERNAL_ERROR, "转换文本失败", str(e))
    finally:
        # Closing the generator resumes get_db so its own cleanup can run.
        db_gen.close()
|
||||||
|
|
||||||
|
|
||||||
|
@router.post("/read_service_async", response_model=ApiResponse)
@cur_workspace_access_guard()
async def read_server_async(
    user_input: UserInput,
    db: Session = Depends(get_db),
    current_user: User = Depends(get_current_user),
):
    """
    Queue a read (query) request on Celery instead of processing it in-process.

    Args:
        user_input: Read request; its ``message``, ``history``,
            ``search_switch``, ``group_id`` and ``config_id`` fields are used.
        db: Database session provided by the ``get_db`` dependency.
        current_user: Authenticated user; supplies the current workspace.

    Returns:
        A success response containing the Celery ``task_id``, or a failure
        response when the config is invalid or queuing raises.
    """
    # Reject the request early when the configuration cannot be validated.
    try:
        config_id = validate_config_id(user_input.config_id, db)
    except ValueError as exc:
        return fail(BizCode.INVALID_PARAMETER, "配置ID无效", str(exc))

    workspace_id = current_user.current_workspace_id
    api_logger.info(f"Async read service: workspace_id={workspace_id}, config_id={config_id}")

    # Resolve the workspace storage backend, defaulting to neo4j when unset.
    configured_storage = workspace_service.get_workspace_storage_type(
        db=db,
        workspace_id=workspace_id,
        user=current_user,
    )
    storage_type = 'neo4j' if configured_storage is None else configured_storage

    # Look up the workspace's RAG memory knowledge base, if there is one.
    user_rag_memory_id = ''
    if workspace_id:
        rag_kb = knowledge_repository.get_knowledge_by_name(
            db=db,
            name="USER_RAG_MERORY",
            workspace_id=workspace_id,
        )
        if rag_kb:
            user_rag_memory_id = str(rag_kb.id)

    api_logger.info(f"Async read: storage_type={storage_type}, user_rag_memory_id={user_rag_memory_id}")
    try:
        task_args = [
            user_input.group_id,
            user_input.message,
            user_input.history,
            user_input.search_switch,
            config_id,
            storage_type,
            user_rag_memory_id,
        ]
        queued = celery_app.send_task("app.core.memory.agent.read_message", args=task_args)
        api_logger.info(f"Read task queued: {queued.id}")
        return success(data={"task_id": queued.id}, msg="查询任务已提交")
    except Exception as exc:
        api_logger.error(f"Async read operation failed: {str(exc)}")
        return fail(BizCode.INTERNAL_ERROR, "回复对话消息失败", str(exc))
|
||||||
|
|
||||||
|
|
||||||
|
@router.get("/read_result/", response_model=ApiResponse)
async def get_read_task_result(
    task_id: str,
    current_user: User = Depends(get_current_user),
):
    """
    Report the status, and the result once finished, of an async read task.

    Args:
        task_id: Celery task ID returned from /read_service_async
        current_user: Authenticated user resolved by the dependency

    Returns:
        Depending on the task status:
        - SUCCESS: success response with the task result
        - FAILURE: failure response with the error message
        - PENDING / STARTED: success response saying the task is in progress
        - anything else: success response echoing the raw status
    """
    api_logger.info(f"Read task status check requested for task {task_id}")
    try:
        task_info = task_service.get_task_memory_read_result(task_id)
        task_status = task_info.get("status")

        if task_status == "SUCCESS":
            outcome = task_info.get("result", {})
            if not isinstance(outcome, dict):
                # Legacy format: the result is returned as-is.
                return success(data=outcome, msg="查询任务已完成")
            # Current format: a dict with detailed fields.
            return success(
                data={
                    "result": outcome.get("result"),
                    "group_id": outcome.get("group_id"),
                    "elapsed_time": outcome.get("elapsed_time"),
                    "task_id": task_id,
                },
                msg="查询任务已完成",
            )

        if task_status == "FAILURE":
            failure = task_info.get("result", "Unknown error")
            detail = failure.get("error", str(failure)) if isinstance(failure, dict) else str(failure)
            return fail(BizCode.INTERNAL_ERROR, "查询任务失败", detail)

        progress = {"status": task_status, "task_id": task_id}
        if task_status in ("PENDING", "STARTED"):
            # Task is still queued or running.
            progress["message"] = "任务处理中,请稍后查询"
            return success(data=progress, msg="查询任务处理中")

        # Unrecognised status: report it verbatim.
        return success(data=progress, msg=f"任务状态: {task_status}")

    except Exception as exc:
        api_logger.error(f"Read task status check failed: {str(exc)}", exc_info=True)
        return fail(BizCode.INTERNAL_ERROR, "任务状态查询失败", str(exc))
|
||||||
|
|
||||||
|
|
||||||
|
@router.get("/write_result/", response_model=ApiResponse)
async def get_write_task_result(
    task_id: str,
    current_user: User = Depends(get_current_user),
):
    """
    Report the status, and the result once finished, of an async write task.

    Args:
        task_id: Celery task ID returned from /writer_service_async
        current_user: Authenticated user resolved by the dependency

    Returns:
        Depending on the task status:
        - SUCCESS: success response with the task result
        - FAILURE: failure response with the error message
        - PENDING / STARTED: success response saying the task is in progress
        - anything else: success response echoing the raw status
    """
    api_logger.info(f"Write task status check requested for task {task_id}")
    try:
        task_info = task_service.get_task_memory_write_result(task_id)
        task_status = task_info.get("status")

        if task_status == "SUCCESS":
            outcome = task_info.get("result", {})
            if not isinstance(outcome, dict):
                # Legacy format: the result is returned as-is.
                return success(data=outcome, msg="写入任务已完成")
            # Current format: a dict with detailed fields.
            return success(
                data={
                    "result": outcome.get("result"),
                    "group_id": outcome.get("group_id"),
                    "elapsed_time": outcome.get("elapsed_time"),
                    "task_id": task_id,
                },
                msg="写入任务已完成",
            )

        if task_status == "FAILURE":
            failure = task_info.get("result", "Unknown error")
            detail = failure.get("error", str(failure)) if isinstance(failure, dict) else str(failure)
            return fail(BizCode.INTERNAL_ERROR, "写入任务失败", detail)

        progress = {"status": task_status, "task_id": task_id}
        if task_status in ("PENDING", "STARTED"):
            # Task is still queued or running.
            progress["message"] = "任务处理中,请稍后查询"
            return success(data=progress, msg="写入任务处理中")

        # Unrecognised status: report it verbatim.
        return success(data=progress, msg=f"任务状态: {task_status}")

    except Exception as exc:
        api_logger.error(f"Write task status check failed: {str(exc)}", exc_info=True)
        return fail(BizCode.INTERNAL_ERROR, "任务状态查询失败", str(exc))
|
||||||
|
|
||||||
|
|
||||||
|
@router.post("/status_type", response_model=ApiResponse)
async def status_type(
    user_input: Write_UserInput,
    current_user: User = Depends(get_current_user),
):
    """
    Classify a user message (per the original notes: as a read or a write).

    Args:
        user_input: Request whose ``message`` is classified; ``group_id`` is
            only used for logging.
        current_user: Authenticated user resolved by the dependency.

    Returns:
        A success response with the classification result, or an
        INTERNAL_ERROR failure response if classification raises.
    """
    api_logger.info(f"Status type check requested for group {user_input.group_id}")
    try:
        classification = await memory_agent_service.classify_message_type(user_input.message)
        return success(data=classification)
    except Exception as exc:
        api_logger.error(f"Message type classification failed: {str(exc)}")
        return fail(BizCode.INTERNAL_ERROR, "类型判断失败", str(exc))
|
||||||
|
|
||||||
|
|
||||||
|
# ==================== Three newly added API routes ====================
|
||||||
|
|
||||||
|
@router.get("/stats/types", response_model=ApiResponse)
async def get_knowledge_type_stats_api(
    end_user_id: Optional[str] = Query(None, description="用户ID(可选)"),
    only_active: bool = Query(True, description="仅统计有效记录(status=1)"),
    current_user: User = Depends(get_current_user)
):
    """
    Count knowledge bases per type in the current workspace.

    As described by the original author, the service returns a dict covering
    General | Web | Third-party | Folder | memory, with missing types filled
    in as 0 and optional filtering by status:
    - knowledge base types are filtered by the user's current_workspace_id
    - "memory" is the number of Chunk nodes in Neo4j, filtered by
      end_user_id (group_id)
    - without a current workspace or end_user_id, the matching count is 0

    Args:
        end_user_id: Optional end-user ID forwarded to the service.
        only_active: Forwarded to the service to restrict the count to active records.
        current_user: Authenticated user; supplies ``current_workspace_id``.

    Returns:
        A success response with the statistics, or an INTERNAL_ERROR failure
        response if anything raises.
    """
    api_logger.info(f"Knowledge type stats requested for workspace_id: {current_user.current_workspace_id}, end_user_id: {end_user_id}")
    db_gen = None
    try:
        # Obtain a database session by driving the get_db generator manually.
        db_gen = get_db()
        db = next(db_gen)

        # Delegate the counting to the service layer.
        result = await memory_agent_service.get_knowledge_type_stats(
            end_user_id=end_user_id,
            only_active=only_active,
            current_workspace_id=current_user.current_workspace_id,
            db=db
        )

        return success(data=result, msg="获取知识库类型统计成功")
    except Exception as e:
        api_logger.error(f"Knowledge type stats failed: {str(e)}")
        return fail(BizCode.INTERNAL_ERROR, "获取知识库类型统计失败", str(e))
    finally:
        # Close the generator so get_db can run its own cleanup; without this
        # the manually obtained session is never released.
        if db_gen is not None:
            db_gen.close()
|
||||||
|
|
||||||
|
|
||||||
|
@router.get("/analytics/hot_memory_tags/by_user", response_model=ApiResponse)
async def get_hot_memory_tags_by_user_api(
    end_user_id: Optional[str] = Query(None, description="用户ID(可选)"),
    limit: int = Query(20, description="返回标签数量限制"),
    current_user: User = Depends(get_current_user)
):
    """
    Return the hot memory tags of the given end user.

    Documented response payload:
    [
        {"name": "<tag name>", "frequency": <count>},
        ...
    ]

    Args:
        end_user_id: Optional end-user id, forwarded to the service.
        limit: Maximum number of tags to return, forwarded to the service.
        current_user: Authenticated user resolved by ``get_current_user``.

    Returns:
        ``success(data=...)`` with the tag list, or
        ``fail(BizCode.INTERNAL_ERROR, ...)`` if the service raises.
    """
    api_logger.info(f"Hot memory tags by user requested: end_user_id={end_user_id}")
    try:
        result = await memory_agent_service.get_hot_memory_tags_by_user(
            end_user_id=end_user_id,
            limit=limit
        )
        return success(data=result, msg="获取热门记忆标签成功")
    except Exception as e:
        # Keep the traceback in the log so the failing query can be located.
        api_logger.error(f"Hot memory tags by user failed: {str(e)}", exc_info=True)
        return fail(BizCode.INTERNAL_ERROR, "获取热门记忆标签失败", str(e))
|
||||||
|
|
||||||
|
|
||||||
|
@router.get("/analytics/user_profile", response_model=ApiResponse)
async def get_user_profile_api(
    end_user_id: Optional[str] = Query(None, description="用户ID(可选)"),
    current_user: User = Depends(get_current_user)
):
    """
    Return the profile of an end user.

    Documented contents:
    - name: user name (the end_user_id is used directly)
    - tags: 3 characteristic tags (summarised by an LLM from statements and
      entities)
    - hot_tags: 4 hot memory tags

    Documented response payload:
    {
        "name": "<user name>",
        "tags": ["<tag>", "<tag>", "<tag>"],
        "hot_tags": [
            {"name": "<tag 1>", "frequency": 10},
            {"name": "<tag 2>", "frequency": 8},
            ...
        ]
    }

    Args:
        end_user_id: Optional end-user id, forwarded to the service.
        current_user: Authenticated user; its id is forwarded as a string.

    Returns:
        ``success(data=...)`` with the profile, or
        ``fail(BizCode.INTERNAL_ERROR, ...)`` if the service raises.
    """
    api_logger.info(f"User profile requested: end_user_id={end_user_id}, current_user={current_user.id}")
    try:
        result = await memory_agent_service.get_user_profile(
            end_user_id=end_user_id,
            current_user_id=str(current_user.id)
        )
        return success(data=result, msg="获取用户详情成功")
    except Exception as e:
        # Keep the traceback in the log so the failing step can be located.
        api_logger.error(f"User profile failed: {str(e)}", exc_info=True)
        return fail(BizCode.INTERNAL_ERROR, "获取用户详情失败", str(e))
|
||||||
|
|
||||||
|
|
||||||
|
# @router.get("/docs/api", response_model=ApiResponse)
|
||||||
|
# async def get_api_docs_api(
|
||||||
|
# file_path: Optional[str] = Query(None, description="API文档文件路径,不传则使用默认路径")
|
||||||
|
# ):
|
||||||
|
# """
|
||||||
|
# Get parsed API documentation (Public endpoint - no authentication required)
|
||||||
|
|
||||||
|
# Args:
|
||||||
|
# file_path: Optional path to API docs file. If None, uses default path.
|
||||||
|
|
||||||
|
# Returns:
|
||||||
|
# Parsed API documentation including title, meta info, and sections
|
||||||
|
# """
|
||||||
|
# api_logger.info(f"API docs requested, file_path: {file_path or 'default'}")
|
||||||
|
# try:
|
||||||
|
# result = await memory_agent_service.get_api_docs(file_path)
|
||||||
|
|
||||||
|
# if result.get("success"):
|
||||||
|
# return success(msg=result["msg"], data=result["data"])
|
||||||
|
# else:
|
||||||
|
# return fail(
|
||||||
|
# code=BizCode.BAD_REQUEST,
|
||||||
|
# msg=result["msg"],
|
||||||
|
# error=result.get("data", {}).get("error", result.get("error_code", ""))
|
||||||
|
# )
|
||||||
|
# except Exception as e:
|
||||||
|
# api_logger.error(f"API docs retrieval failed: {str(e)}")
|
||||||
|
# return fail(BizCode.INTERNAL_ERROR, "API文档获取失败", str(e))
|
||||||
516
app/controllers/memory_dashboard_controller.py
Normal file
516
app/controllers/memory_dashboard_controller.py
Normal file
@@ -0,0 +1,516 @@
|
|||||||
|
from fastapi import APIRouter, Depends, HTTPException, status, Query
|
||||||
|
from sqlalchemy.orm import Session
|
||||||
|
from typing import List, Optional
|
||||||
|
import uuid
|
||||||
|
|
||||||
|
from app.core.response_utils import success
|
||||||
|
from app.db import get_db
|
||||||
|
from app.dependencies import get_current_user
|
||||||
|
from app.models.user_model import User
|
||||||
|
from app.schemas.response_schema import ApiResponse
|
||||||
|
from app.schemas.app_schema import App as AppSchema
|
||||||
|
|
||||||
|
from app.services import memory_dashboard_service, memory_storage_service, workspace_service
|
||||||
|
from app.core.logging_config import get_api_logger
|
||||||
|
|
||||||
|
# Dedicated logger for API-layer messages.
api_logger = get_api_logger()

# Dashboard routes; authentication is applied to every route on this router.
_ROUTER_DEPENDENCIES = [Depends(get_current_user)]
router = APIRouter(prefix="/dashboard", tags=["Dashboard"], dependencies=_ROUTER_DEPENDENCIES)
|
||||||
|
|
||||||
|
|
||||||
|
@router.get("/total_end_users", response_model=ApiResponse)
def get_workspace_total_end_users(
    db: Session = Depends(get_db),
    current_user: User = Depends(get_current_user),
):
    """
    Return the total number of end users in the caller's current workspace.

    Args:
        db: Database session.
        current_user: Authenticated user; supplies the workspace id.

    Returns:
        ``success(data=...)`` wrapping the service result, which is read
        here as a mapping with a ``total_num`` entry.
    """
    workspace_id = current_user.current_workspace_id
    # This message used to be a copy of the /end_users one, which made the
    # two endpoints indistinguishable in the logs.
    api_logger.info(f"用户 {current_user.username} 请求获取工作空间 {workspace_id} 的宿主总数")
    total_end_users = memory_dashboard_service.get_workspace_total_end_users(
        db=db,
        workspace_id=workspace_id,
        current_user=current_user
    )
    api_logger.info(f"成功获取最新用户总数: total_num={total_end_users.get('total_num', 0)}")
    return success(data=total_end_users, msg="用户数量获取成功")
|
||||||
|
|
||||||
|
|
||||||
|
@router.get("/end_users", response_model=ApiResponse)
async def get_workspace_end_users(
    db: Session = Depends(get_db),
    current_user: User = Depends(get_current_user),
):
    """
    List the end users of the caller's current workspace with their memory counts.

    The payload has the same shape as the ``end_users`` field of the original
    memory_list endpoint: one ``{'end_user': ..., 'memory_num': ...}`` entry
    per end user.
    """
    ws_id = current_user.current_workspace_id
    api_logger.info(f"用户 {current_user.username} 请求获取工作空间 {ws_id} 的宿主列表")

    end_users = memory_dashboard_service.get_workspace_end_users(
        db=db, workspace_id=ws_id, current_user=current_user,
    )

    # End users are model objects, so ``id`` is read as an attribute.
    # The lookups are awaited one after another, in list order.
    rows = [
        {
            'end_user': user,
            'memory_num': await memory_storage_service.search_all(str(user.id)),
        }
        for user in end_users
    ]

    api_logger.info(f"成功获取 {len(end_users)} 个宿主记录")
    return success(data=rows, msg="宿主列表获取成功")
|
||||||
|
|
||||||
|
|
||||||
|
@router.get("/memory_increment", response_model=ApiResponse)
def get_workspace_memory_increment(
    limit: int = Query(7, description="返回记录数"),
    db: Session = Depends(get_db),
    current_user: User = Depends(get_current_user),
):
    """
    Return the memory-increment records of the caller's current workspace.

    Args:
        limit: Number of records to return, forwarded to the service.
        db: Database session.
        current_user: Authenticated user; supplies the workspace id.
    """
    ws_id = current_user.current_workspace_id
    api_logger.info(f"用户 {current_user.username} 请求获取工作空间 {ws_id} 的记忆增量")

    increments = memory_dashboard_service.get_workspace_memory_increment(
        db=db, workspace_id=ws_id, current_user=current_user, limit=limit,
    )

    api_logger.info(f"成功获取 {len(increments)} 条记忆增量记录")
    return success(data=increments, msg="记忆增量获取成功")
|
||||||
|
|
||||||
|
|
||||||
|
@router.get("/api_increment", response_model=ApiResponse)
def get_workspace_api_increment(
    db: Session = Depends(get_db),
    current_user: User = Depends(get_current_user),
):
    """
    Return the API-call trend figure for the caller's current workspace.

    Args:
        db: Database session.
        current_user: Authenticated user; supplies the workspace id.
    """
    ws_id = current_user.current_workspace_id
    api_logger.info(f"用户 {current_user.username} 请求获取工作空间 {ws_id} 的API调用增量")

    increment = memory_dashboard_service.get_workspace_api_increment(
        db=db, workspace_id=ws_id, current_user=current_user,
    )

    api_logger.info(f"成功获取 {increment} API调用增量")
    return success(data=increment, msg="API调用增量获取成功")
|
||||||
|
|
||||||
|
|
||||||
|
@router.post("/total_memory", response_model=ApiResponse)
def write_workspace_total_memory(
    db: Session = Depends(get_db),
    current_user: User = Depends(get_current_user),
):
    """
    Queue the asynchronous task that records the workspace's total memory.

    The handler only enqueues the Celery task and returns its id together
    with the workspace id; it does not wait for the task to finish.
    """
    ws_id = current_user.current_workspace_id
    api_logger.info(f"用户 {current_user.username} 请求写入工作空间 {ws_id} 的记忆总量")

    # Imported lazily, as in the rest of this module's Celery call sites.
    from app.celery_app import celery_app

    task_kwargs = {"workspace_id": str(ws_id)}
    queued = celery_app.send_task(
        "app.controllers.memory_storage_controller.search_all",
        kwargs=task_kwargs,
    )

    api_logger.info(f"已触发记忆总量统计任务,task_id: {queued.id}")
    payload = {"task_id": queued.id, "workspace_id": str(ws_id)}
    return success(data=payload, msg="记忆总量统计任务已启动")
|
||||||
|
|
||||||
|
|
||||||
|
@router.get("/task_status/{task_id}", response_model=ApiResponse)
def get_task_status(
    task_id: str,
    current_user: User = Depends(get_current_user),
):
    """
    Report the state of an asynchronous task and, once finished, its outcome.

    The payload always carries ``task_id`` and ``status`` (the Celery state:
    PENDING, STARTED, SUCCESS, FAILURE, RETRY, REVOKED). A finished task adds
    ``result`` on success or ``error`` (stringified result) on failure.
    All three outcomes are returned through ``success(...)``.
    """
    api_logger.info(f"用户 {current_user.username} 查询任务状态: task_id={task_id}")

    from app.celery_app import celery_app
    from celery.result import AsyncResult

    outcome = AsyncResult(task_id, app=celery_app)
    payload = {"task_id": task_id, "status": outcome.state}

    # Still queued or running: report the current state only.
    if not outcome.ready():
        api_logger.info(f"任务 {task_id} 状态: {outcome.state}")
        return success(data=payload, msg=f"任务状态: {outcome.state}")

    # Finished successfully.
    if outcome.successful():
        payload["result"] = outcome.result
        api_logger.info(f"任务 {task_id} 执行成功")
        return success(data=payload, msg="任务执行成功")

    # Finished with a failure.
    payload["error"] = str(outcome.result)
    api_logger.error(f"任务 {task_id} 执行失败: {outcome.result}")
    return success(data=payload, msg="任务执行失败")
|
||||||
|
|
||||||
|
|
||||||
|
@router.get("/memory_list", response_model=ApiResponse)
def get_workspace_memory_list(
    limit: int = Query(7, description="记忆增量返回记录数"),
    db: Session = Depends(get_db),
    current_user: User = Depends(get_current_user),
):
    """
    Combined memory-list endpoint for the caller's current workspace.

    Documented as merging three data sets:
    1. total_memory - total memory of the workspace
    2. memory_increment - memory increments of the workspace
    3. hosts - end-user (host) list of the workspace

    Documented response payload:
    {
        "total_memory": float,
        "memory_increment": [
            {"date": "2024-01-01", "count": 100},
            ...
        ],
        "hosts": [
            {"id": "uuid", "name": "<host name>", ...},
            ...
        ]
    }
    """
    ws_id = current_user.current_workspace_id
    api_logger.info(f"用户 {current_user.username} 请求获取工作空间 {ws_id} 的记忆列表")

    combined = memory_dashboard_service.get_workspace_memory_list(
        db=db, workspace_id=ws_id, current_user=current_user, limit=limit,
    )

    api_logger.info("成功获取记忆列表")
    return success(data=combined, msg="记忆列表获取成功")
|
||||||
|
|
||||||
|
|
||||||
|
@router.get("/total_memory_count", response_model=ApiResponse)
async def get_workspace_total_memory_count(
    end_user_id: Optional[str] = Query(None, description="可选的用户ID"),
    db: Session = Depends(get_db),
    current_user: User = Depends(get_current_user),
):
    """
    Return the workspace's total memory count, aggregated over its hosts.

    Documented service logic:
    1. take every host_id from memory_list
    2. call search_all for each host_id to get its total
    3. sum the totals

    Documented response payload:
    {
        "total_memory_count": int,
        "host_count": int,
        "details": [
            {"host_id": "uuid", "count": 100},
            ...
        ]
    }
    """
    ws_id = current_user.current_workspace_id
    api_logger.info(f"用户 {current_user.username} 请求获取工作空间 {ws_id} 的记忆总量")

    totals = await memory_dashboard_service.get_workspace_total_memory_count(
        db=db, workspace_id=ws_id, current_user=current_user, end_user_id=end_user_id,
    )

    api_logger.info(f"成功获取记忆总量: {totals.get('total_memory_count', 0)}")
    return success(data=totals, msg="记忆总量获取成功")
|
||||||
|
|
||||||
|
|
||||||
|
# ======== RAG 数据统计 ========
|
||||||
|
@router.get("/total_rag_count", response_model=ApiResponse)
def get_workspace_total_rag_count(
    db: Session = Depends(get_db),
    current_user: User = Depends(get_current_user)
):
    """
    Return the RAG totals: documents, chunks, knowledge bases and API calls.
    """
    # Values are evaluated top to bottom: documents, chunks, knowledge bases.
    stats = {
        'total_documents': memory_dashboard_service.get_rag_total_doc(db, current_user),
        'total_chunk': memory_dashboard_service.get_rag_total_chunk(db, current_user),
        'total_kb': memory_dashboard_service.get_rag_total_kb(db, current_user),
        # NOTE(review): constant value, not a measured count - confirm intent.
        'total_api': 1024,
    }
    return success(data=stats, msg="RAG相关数据获取成功")
|
||||||
|
|
||||||
|
@router.get("/current_user_rag_total_num", response_model=ApiResponse)
def get_current_user_rag_total_num(
    end_user_id: str = Query(..., description="宿主ID"),
    db: Session = Depends(get_db),
    current_user: User = Depends(get_current_user),
):
    """
    Return the total number of RAG chunks belonging to the given end user.

    Args:
        end_user_id: Required end-user (host) id.
        db: Database session.
        current_user: Authenticated user.
    """
    chunk_total = memory_dashboard_service.get_current_user_total_chunk(
        end_user_id, db, current_user
    )
    return success(data=chunk_total, msg="宿主RAG知识数据获取成功")
|
||||||
|
|
||||||
|
@router.get("/rag_content", response_model=ApiResponse)
def get_rag_content(
    end_user_id: str = Query(..., description="宿主ID"),
    limit: int = Query(15, description="返回记录数"),
    db: Session = Depends(get_db),
    current_user: User = Depends(get_current_user),
):
    """
    Return chunk contents from the given end user's knowledge base.

    Args:
        end_user_id: Required end-user (host) id.
        limit: Number of records to return, forwarded to the service.
        db: Database session.
        current_user: Authenticated user.
    """
    chunks = memory_dashboard_service.get_rag_content(
        end_user_id, limit, db, current_user
    )
    return success(data=chunks, msg="宿主RAGchunk数据获取成功")
|
||||||
|
|
||||||
|
|
||||||
|
@router.get("/chunk_summary_tag", response_model=ApiResponse)
async def get_chunk_summary_tag(
    end_user_id: str = Query(..., description="宿主ID"),
    limit: int = Query(15, description="返回记录数"),
    max_tags: int = Query(10, description="最大标签数量"),
    db: Session = Depends(get_db),
    current_user: User = Depends(get_current_user),
):
    """
    Return the chunk summary, extracted tags and personas for an end user.

    Documented response payload:
    {
        "summary": "<summary of the chunk contents>",
        "tags": [
            {"tag": "<tag 1>", "frequency": 5},
            {"tag": "<tag 2>", "frequency": 3},
            ...
        ],
        "personas": [
            "<persona>",
            ...
        ]
    }
    """
    api_logger.info(f"用户 {current_user.username} 请求获取宿主 {end_user_id} 的chunk摘要、标签和人物形象")

    summary = await memory_dashboard_service.get_chunk_summary_and_tags(
        end_user_id=end_user_id,
        limit=limit,
        max_tags=max_tags,
        db=db,
        current_user=current_user,
    )

    # Counts are only used for the log line; missing keys count as empty.
    tag_count = len(summary.get('tags', []))
    persona_count = len(summary.get('personas', []))
    api_logger.info(f"成功获取chunk摘要、{tag_count} 个标签和 {persona_count} 个人物形象")
    return success(data=summary, msg="chunk摘要、标签和人物形象获取成功")
|
||||||
|
|
||||||
|
|
||||||
|
@router.get("/chunk_insight", response_model=ApiResponse)
async def get_chunk_insight(
    end_user_id: str = Query(..., description="宿主ID"),
    limit: int = Query(15, description="返回记录数"),
    db: Session = Depends(get_db),
    current_user: User = Depends(get_current_user),
):
    """
    Return the insight text derived from an end user's chunks.

    Documented response payload:
    {
        "insight": "<in-depth analysis of the chunk contents>"
    }
    """
    api_logger.info(f"用户 {current_user.username} 请求获取宿主 {end_user_id} 的chunk洞察")

    insight = await memory_dashboard_service.get_chunk_insight(
        end_user_id=end_user_id, limit=limit, db=db, current_user=current_user,
    )

    api_logger.info("成功获取chunk洞察")
    return success(data=insight, msg="chunk洞察获取成功")
|
||||||
|
|
||||||
|
|
||||||
|
def _count_workspace_apps(db: Session, workspace_id) -> int:
    """Return the number of apps registered in ``workspace_id``."""
    # Imported lazily, as the handler originally did.
    from app.repositories import app_repository

    return len(app_repository.get_apps_by_workspace_id(db, workspace_id))


async def _collect_neo4j_dashboard_data(db: Session, workspace_id, current_user, end_user_id) -> dict:
    """Gather the four dashboard totals for a neo4j-backed workspace.

    Each metric is fetched in its own ``try`` so that one failing source
    leaves only its own entry as ``None``.
    """
    neo4j_data = {
        "total_memory": None,
        "total_app": None,
        "total_knowledge": None,
        "total_api_call": None
    }

    # 1. Total memory count.
    try:
        total_memory_data = await memory_dashboard_service.get_workspace_total_memory_count(
            db=db,
            workspace_id=workspace_id,
            current_user=current_user,
            end_user_id=end_user_id
        )
        neo4j_data["total_memory"] = total_memory_data.get("total_memory_count", 0)
        api_logger.info(f"成功获取记忆总量: {neo4j_data['total_memory']}")
    except Exception as e:
        api_logger.warning(f"获取记忆总量失败: {str(e)}")

    # 2. Number of apps in the workspace. This does not depend on the memory
    #    count, so it must not be blanked when that lookup fails.
    try:
        neo4j_data["total_app"] = _count_workspace_apps(db, workspace_id)
        api_logger.info(f"成功获取应用数量: {neo4j_data['total_app']}")
    except Exception as e:
        api_logger.warning(f"获取应用数量失败: {str(e)}")

    # 3. Knowledge-base type statistics (only the "total" entry is used).
    try:
        from app.services.memory_agent_service import MemoryAgentService

        knowledge_stats = await MemoryAgentService().get_knowledge_type_stats(
            end_user_id=end_user_id,
            only_active=True,
            current_workspace_id=workspace_id,
            db=db
        )
        neo4j_data["total_knowledge"] = knowledge_stats.get("total", 0)
        api_logger.info(f"成功获取知识库类型统计total: {neo4j_data['total_knowledge']}")
    except Exception as e:
        api_logger.warning(f"获取知识库类型统计失败: {str(e)}")

    # 4. API-call increment, stored exactly as the service returns it.
    try:
        neo4j_data["total_api_call"] = memory_dashboard_service.get_workspace_api_increment(
            db=db,
            workspace_id=workspace_id,
            current_user=current_user
        )
        api_logger.info(f"成功获取API调用增量: {neo4j_data['total_api_call']}")
    except Exception as e:
        api_logger.warning(f"获取API调用增量失败: {str(e)}")

    return neo4j_data


def _collect_rag_dashboard_data(db: Session, workspace_id, current_user) -> dict:
    """Gather the four dashboard totals for a rag-backed workspace."""
    rag_data = {
        "total_memory": None,
        "total_app": None,
        "total_knowledge": None,
        "total_api_call": None
    }

    try:
        # total_memory: total number of chunks.
        total_chunk = memory_dashboard_service.get_rag_total_chunk(db, current_user)
        rag_data["total_memory"] = total_chunk

        # total_app: number of apps in the workspace.
        total_app = _count_workspace_apps(db, workspace_id)
        rag_data["total_app"] = total_app

        # total_knowledge: total number of knowledge bases.
        total_kb = memory_dashboard_service.get_rag_total_kb(db, current_user)
        rag_data["total_knowledge"] = total_kb

        # total_api_call: fixed value.
        # NOTE(review): constant, not a measured count - confirm intent.
        rag_data["total_api_call"] = 1024

        api_logger.info(f"成功获取RAG相关数据: memory={total_chunk}, app={total_app}, knowledge={total_kb}")
    except Exception as e:
        api_logger.warning(f"获取RAG相关数据失败: {str(e)}")

    return rag_data


@router.get("/dashboard_data", response_model=ApiResponse)
async def dashboard_data(
    end_user_id: Optional[str] = Query(None, description="可选的用户ID"),
    db: Session = Depends(get_db),
    current_user: User = Depends(get_current_user),
):
    """
    Combined dashboard endpoint.

    Documented as merging the data of:
    1. /dashboard/total_memory_count - total memory
    2. /dashboard/api_increment - API-call increment
    3. /memory/stats/types - knowledge-base type statistics (total only)
    4. /dashboard/total_rag_count - RAG figures

    Which sources are queried depends on the workspace's storage type; a
    missing storage type is treated as 'neo4j'. Only the object matching the
    storage type is filled in; the other stays ``None``. Individual metrics
    that fail to load are left as ``None`` and logged as warnings.

    Response payload:
    {
        "storage_type": str,
        "neo4j_data": {
            "total_memory": int,
            "total_app": int,
            "total_knowledge": int,
            "total_api_call": int
        } | null,
        "rag_data": {
            "total_memory": int,
            "total_app": int,
            "total_knowledge": int,
            "total_api_call": int
        } | null
    }

    Raises:
        HTTPException: 500 if assembling the response fails unexpectedly.
    """
    workspace_id = current_user.current_workspace_id
    api_logger.info(f"用户 {current_user.username} 请求获取工作空间 {workspace_id} 的dashboard整合数据")

    # Resolve the storage type, falling back to the default when unset.
    storage_type = workspace_service.get_workspace_storage_type(
        db=db,
        workspace_id=workspace_id,
        user=current_user
    )
    if storage_type is None:
        storage_type = 'neo4j'

    result = {
        "storage_type": storage_type,
        "neo4j_data": None,
        "rag_data": None
    }

    try:
        if storage_type == 'neo4j':
            result["neo4j_data"] = await _collect_neo4j_dashboard_data(
                db, workspace_id, current_user, end_user_id
            )
            api_logger.info("成功获取neo4j_data")
        elif storage_type == 'rag':
            result["rag_data"] = _collect_rag_dashboard_data(db, workspace_id, current_user)
            api_logger.info("成功获取rag_data")

        api_logger.info("成功获取dashboard整合数据")
        return success(data=result, msg="Dashboard数据获取成功")

    except Exception as e:
        api_logger.error(f"获取dashboard整合数据失败: {str(e)}")
        # Chain the cause so the original traceback is preserved.
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail=f"获取dashboard整合数据失败: {str(e)}"
        ) from e
|
||||||
542
app/controllers/memory_storage_controller.py
Normal file
542
app/controllers/memory_storage_controller.py
Normal file
@@ -0,0 +1,542 @@
|
|||||||
|
from typing import Optional
|
||||||
|
import os
|
||||||
|
import uuid
|
||||||
|
from fastapi import APIRouter, Depends
|
||||||
|
|
||||||
|
from app.core.logging_config import get_api_logger
|
||||||
|
from app.core.response_utils import success, fail
|
||||||
|
from app.core.error_codes import BizCode
|
||||||
|
from app.services.memory_storage_service import (
|
||||||
|
MemoryStorageService,
|
||||||
|
DataConfigService,
|
||||||
|
kb_type_distribution,
|
||||||
|
search_dialogue,
|
||||||
|
search_chunk,
|
||||||
|
search_statement,
|
||||||
|
search_entity,
|
||||||
|
search_all,
|
||||||
|
search_detials,
|
||||||
|
search_edges,
|
||||||
|
search_entity_graph,
|
||||||
|
analytics_hot_memory_tags,
|
||||||
|
analytics_memory_insight_report,
|
||||||
|
analytics_recent_activity_stats,
|
||||||
|
analytics_user_summary,
|
||||||
|
)
|
||||||
|
from app.schemas.response_schema import ApiResponse
|
||||||
|
from app.schemas.memory_storage_schema import (
|
||||||
|
ConfigParamsCreate,
|
||||||
|
ConfigParamsDelete,
|
||||||
|
ConfigUpdate,
|
||||||
|
ConfigUpdateExtracted,
|
||||||
|
ConfigUpdateForget,
|
||||||
|
ConfigKey,
|
||||||
|
ConfigPilotRun,
|
||||||
|
)
|
||||||
|
from app.core.memory.utils.config.definitions import reload_configuration_from_database
|
||||||
|
from app.dependencies import get_current_user
|
||||||
|
from app.models.user_model import User
|
||||||
|
# Logger dedicated to API-layer messages.
api_logger = get_api_logger()

# Service instance shared by the handlers in this module.
memory_storage_service = MemoryStorageService()

# Routes below are mounted under /memory-storage.
router = APIRouter(prefix="/memory-storage", tags=["Memory Storage"])
|
||||||
|
|
||||||
|
|
||||||
|
@router.get("/info", response_model=ApiResponse)
async def get_storage_info(
    storage_id: str,
    current_user: User = Depends(get_current_user)
):
    """
    Example wrapper endpoint that retrieves storage information.

    Args:
        storage_id: Storage identifier. It is accepted as a query parameter
            but is not forwarded to the service call.
        current_user: Authenticated user resolved by ``get_current_user``.

    Returns:
        ``success(data=...)`` with the storage information, or
        ``fail(BizCode.INTERNAL_ERROR, ...)`` if the service raises.
    """
    api_logger.info("Storage info requested ")
    try:
        info = await memory_storage_service.get_storage_info()
        return success(data=info)
    except Exception as exc:
        api_logger.error(f"Storage info retrieval failed: {exc}")
        return fail(BizCode.INTERNAL_ERROR, "存储信息获取失败", str(exc))
|
||||||
|
|
||||||
|
|
||||||
|
# --- DB connection dependency ---
|
||||||
|
_CONN: Optional[object] = None
|
||||||
|
|
||||||
|
|
||||||
|
"""PostgreSQL 连接生成与管理(使用 psycopg2)。"""
|
||||||
|
# 这个可以转移,可能是已经有的
|
||||||
|
# PostgreSQL 数据库连接
|
||||||
|
def _make_pgsql_conn() -> Optional[object]:
    """Create a PostgreSQL connection from the ``DB_*`` environment variables.

    Reads ``DB_HOST`` (falls back to "localhost"), ``DB_PORT`` (falls back to
    5432), ``DB_USER``, ``DB_PASSWORD`` and ``DB_NAME``. The connection is
    switched to autocommit, and the session time zone is set to Asia/Shanghai
    on a best-effort basis.

    Returns:
        The psycopg2 connection, or ``None`` if the driver is missing, the
        port is not an integer, or connecting fails. The failure is logged.
    """
    host = os.getenv("DB_HOST")
    user = os.getenv("DB_USER")
    password = os.getenv("DB_PASSWORD")
    database = os.getenv("DB_NAME")
    port_str = os.getenv("DB_PORT")
    try:
        import psycopg2  # type: ignore

        port = int(port_str) if port_str else 5432
        conn = psycopg2.connect(
            host=host or "localhost",
            port=port,
            user=user,
            password=password,
            dbname=database,
        )
        # Autocommit avoids explicit transaction management by callers.
        conn.autocommit = True
        # Use China Standard Time for the session so timestamps can be shown
        # in local time directly.
        try:
            # The context manager closes the cursor even if execute() raises.
            with conn.cursor() as cur:
                cur.execute("SET TIME ZONE 'Asia/Shanghai'")
        except Exception as tz_err:
            # A time-zone failure must not invalidate the connection; record
            # it and carry on.
            api_logger.warning(f"[PostgreSQL] Failed to set session time zone: {tz_err}")
        return conn
    except Exception as e:
        api_logger.error(f"[PostgreSQL] 连接失败: {e}")
        return None
|
||||||
|
|
||||||
|
def get_db_conn() -> Optional[object]:
    """Return the shared PostgreSQL connection, creating it when needed.

    A new connection is made when none is cached yet, or when the cached one
    reports itself as closed (psycopg2 exposes a non-zero ``closed`` attribute
    for closed or broken connections). Objects without that attribute are
    treated as open.

    Returns:
        The cached connection, or ``None`` if a connection could not be made.
    """
    global _CONN
    if _CONN is None or getattr(_CONN, "closed", 0):
        _CONN = _make_pgsql_conn()
    return _CONN
|
||||||
|
|
||||||
|
|
||||||
|
def reset_db_conn() -> bool:
    """Close the global DB connection and create a fresh one.

    Returns:
        True when a new connection was obtained, False otherwise.
    """
    global _CONN
    try:
        stale = _CONN
        if stale:
            try:
                stale.close()
            except Exception:
                # Closing is best effort; the old handle is replaced anyway.
                pass
        _CONN = _make_pgsql_conn()
        reconnected = _CONN is not None
        return reconnected
    except Exception:
        _CONN = None
        return False
|
||||||
|
|
||||||
|
|
||||||
|
@router.post("/create_config", response_model=ApiResponse)
def create_config(
    payload: ConfigParamsCreate,
    current_user: User = Depends(get_current_user),
) -> dict:
    """Create a configuration; parameters not given keep their defaults.

    The caller's current workspace id is written onto ``payload`` before it
    is handed to the service. Without a current workspace the request is
    rejected with ``BizCode.INVALID_PARAMETER``.
    """
    ws_id = current_user.current_workspace_id

    # A workspace must be selected first.
    if ws_id is None:
        api_logger.warning(f"用户 {current_user.username} 尝试创建配置但未选择工作空间")
        return fail(BizCode.INVALID_PARAMETER, "请先切换到一个工作空间", "current_workspace_id is None")

    api_logger.info(f"用户 {current_user.username} 在工作空间 {ws_id} 请求创建配置: {payload.config_name}")
    try:
        # Inject the workspace id as-is (kept as a UUID).
        payload.workspace_id = ws_id
        created = DataConfigService(get_db_conn()).create(payload)
        return success(data=created, msg="创建成功")
    except Exception as exc:
        api_logger.error(f"Create config failed: {exc}")
        return fail(BizCode.INTERNAL_ERROR, "创建配置失败", str(exc))
|
||||||
|
|
||||||
|
|
||||||
|
@router.delete("/delete_config", response_model=ApiResponse)
def delete_config(
    config_id: str,
    current_user: User = Depends(get_current_user),
) -> dict:
    """Delete the stored configuration identified by ``config_id``.

    Without a current workspace the request is rejected with
    ``BizCode.INVALID_PARAMETER``.
    """
    ws_id = current_user.current_workspace_id

    # A workspace must be selected first.
    if ws_id is None:
        api_logger.warning(f"用户 {current_user.username} 尝试删除配置但未选择工作空间")
        return fail(BizCode.INVALID_PARAMETER, "请先切换到一个工作空间", "current_workspace_id is None")

    api_logger.info(f"用户 {current_user.username} 在工作空间 {ws_id} 请求删除配置: {config_id}")
    try:
        service = DataConfigService(get_db_conn())
        deletion = service.delete(ConfigParamsDelete(config_id=config_id))
        return success(data=deletion, msg="删除成功")
    except Exception as exc:
        api_logger.error(f"Delete config failed: {exc}")
        return fail(BizCode.INTERNAL_ERROR, "删除配置失败", str(exc))
|
||||||
|
|
||||||
|
@router.post("/update_config", response_model=ApiResponse)
def update_config(
    payload: ConfigUpdate,
    current_user: User = Depends(get_current_user),
) -> dict:
    """Update the name and description of a configuration.

    Without a current workspace the request is rejected with
    ``BizCode.INVALID_PARAMETER``.
    """
    ws_id = current_user.current_workspace_id

    # A workspace must be selected first.
    if ws_id is None:
        api_logger.warning(f"用户 {current_user.username} 尝试更新配置但未选择工作空间")
        return fail(BizCode.INVALID_PARAMETER, "请先切换到一个工作空间", "current_workspace_id is None")

    api_logger.info(f"用户 {current_user.username} 在工作空间 {ws_id} 请求更新配置: {payload.config_id}")
    try:
        updated = DataConfigService(get_db_conn()).update(payload)
        return success(data=updated, msg="更新成功")
    except Exception as exc:
        api_logger.error(f"Update config failed: {exc}")
        return fail(BizCode.INTERNAL_ERROR, "更新配置失败", str(exc))
|
||||||
|
|
||||||
|
|
||||||
|
@router.post("/update_config_extracted", response_model=ApiResponse)
def update_config_extracted(
    payload: ConfigUpdateExtracted,
    current_user: User = Depends(get_current_user),
) -> dict:
    """Partially update a stored configuration; every business field is optional.

    Without a current workspace the request is rejected with
    ``BizCode.INVALID_PARAMETER``.
    """
    ws_id = current_user.current_workspace_id

    # A workspace must be selected first.
    if ws_id is None:
        api_logger.warning(f"用户 {current_user.username} 尝试更新提取配置但未选择工作空间")
        return fail(BizCode.INVALID_PARAMETER, "请先切换到一个工作空间", "current_workspace_id is None")

    api_logger.info(f"用户 {current_user.username} 在工作空间 {ws_id} 请求更新提取配置: {payload.config_id}")
    try:
        updated = DataConfigService(get_db_conn()).update_extracted(payload)
        return success(data=updated, msg="更新成功")
    except Exception as exc:
        api_logger.error(f"Update config extracted failed: {exc}")
        return fail(BizCode.INTERNAL_ERROR, "更新配置失败", str(exc))
|
||||||
|
|
||||||
|
|
||||||
|
# --- Forget config params ---
|
||||||
|
@router.post("/update_config_forget", response_model=ApiResponse)  # Update the forgetting-engine config parameters (fixed path)
def update_config_forget(
    payload: ConfigUpdateForget,
    current_user: User = Depends(get_current_user),
) -> dict:
    """Update the forgetting-engine parameters of a configuration.

    Delegates to ``DataConfigService.update_forget``. Returns
    INVALID_PARAMETER when the user has no current workspace, and
    INTERNAL_ERROR (with the exception text) when the service call fails.
    """
    workspace_id = current_user.current_workspace_id
    username = current_user.username

    # Guard clause: a workspace must be selected before updating.
    if workspace_id is None:
        api_logger.warning(f"用户 {username} 尝试更新遗忘引擎配置但未选择工作空间")
        return fail(BizCode.INVALID_PARAMETER, "请先切换到一个工作空间", "current_workspace_id is None")

    api_logger.info(f"用户 {username} 在工作空间 {workspace_id} 请求更新遗忘引擎配置: {payload.config_id}")
    try:
        updated = DataConfigService(get_db_conn()).update_forget(payload)
        return success(data=updated, msg="更新成功")
    except Exception as exc:
        detail = str(exc)
        api_logger.error(f"Update config forget failed: {detail}")
        return fail(BizCode.INTERNAL_ERROR, "更新遗忘引擎配置失败", detail)
|
||||||
|
|
||||||
|
|
||||||
|
@router.get("/read_config_extracted", response_model=ApiResponse)  # Read one config by query parameter (fixed path); remove if it turns out to be pointless
def read_config_extracted(
    config_id: str,
    current_user: User = Depends(get_current_user),
) -> dict:
    """Read the extraction-related part of the configuration ``config_id``.

    Delegates to ``DataConfigService.get_extracted`` with a ``ConfigKey``
    built from ``config_id``. Returns INVALID_PARAMETER when the user has no
    current workspace, and INTERNAL_ERROR when the lookup raises.
    """
    workspace_id = current_user.current_workspace_id
    username = current_user.username

    # Guard clause: a workspace must be selected before reading.
    if workspace_id is None:
        api_logger.warning(f"用户 {username} 尝试读取提取配置但未选择工作空间")
        return fail(BizCode.INVALID_PARAMETER, "请先切换到一个工作空间", "current_workspace_id is None")

    api_logger.info(f"用户 {username} 在工作空间 {workspace_id} 请求读取提取配置: {config_id}")
    try:
        service = DataConfigService(get_db_conn())
        key = ConfigKey(config_id=config_id)
        return success(data=service.get_extracted(key), msg="查询成功")
    except Exception as exc:
        detail = str(exc)
        api_logger.error(f"Read config extracted failed: {detail}")
        return fail(BizCode.INTERNAL_ERROR, "查询配置失败", detail)
|
||||||
|
|
||||||
|
@router.get("/read_config_forget", response_model=ApiResponse)  # Read one config by query parameter (fixed path); remove if it turns out to be pointless
def read_config_forget(
    config_id: str,
    current_user: User = Depends(get_current_user),
) -> dict:
    """Read the forgetting-engine part of the configuration ``config_id``.

    Delegates to ``DataConfigService.get_forget`` with a ``ConfigKey`` built
    from ``config_id``. Returns INVALID_PARAMETER when the user has no current
    workspace, and INTERNAL_ERROR when the lookup raises.
    """
    workspace_id = current_user.current_workspace_id
    username = current_user.username

    # Guard clause: a workspace must be selected before reading.
    if workspace_id is None:
        api_logger.warning(f"用户 {username} 尝试读取遗忘引擎配置但未选择工作空间")
        return fail(BizCode.INVALID_PARAMETER, "请先切换到一个工作空间", "current_workspace_id is None")

    api_logger.info(f"用户 {username} 在工作空间 {workspace_id} 请求读取遗忘引擎配置: {config_id}")
    try:
        service = DataConfigService(get_db_conn())
        key = ConfigKey(config_id=config_id)
        return success(data=service.get_forget(key), msg="查询成功")
    except Exception as exc:
        detail = str(exc)
        api_logger.error(f"Read config forget failed: {detail}")
        return fail(BizCode.INTERNAL_ERROR, "查询遗忘引擎配置失败", detail)
|
||||||
|
|
||||||
|
@router.get("/read_all_config", response_model=ApiResponse)  # List every configuration
def read_all_config(
    current_user: User = Depends(get_current_user),
) -> dict:
    """List the configurations of the caller's current workspace.

    Delegates to ``DataConfigService.get_all`` and passes the workspace ID
    through unchanged so the service can filter by it. Returns
    INVALID_PARAMETER when no workspace is selected, and INTERNAL_ERROR when
    the query raises.
    """
    workspace_id = current_user.current_workspace_id
    username = current_user.username

    # Guard clause: a workspace must be selected before listing.
    if workspace_id is None:
        api_logger.warning(f"用户 {username} 尝试查询配置但未选择工作空间")
        return fail(BizCode.INVALID_PARAMETER, "请先切换到一个工作空间", "current_workspace_id is None")

    api_logger.info(f"用户 {username} 在工作空间 {workspace_id} 请求读取所有配置")
    try:
        service = DataConfigService(get_db_conn())
        # workspace_id is handed over as-is (no string conversion) for filtering.
        configs = service.get_all(workspace_id=workspace_id)
        return success(data=configs, msg="查询成功")
    except Exception as exc:
        detail = str(exc)
        api_logger.error(f"Read all config failed: {detail}")
        return fail(BizCode.INTERNAL_ERROR, "查询所有配置失败", detail)
|
||||||
|
|
||||||
|
|
||||||
|
@router.post("/pilot_run", response_model=ApiResponse)  # Pilot run: triggers the main pipeline, so POST is the more fitting verb
async def pilot_run(
    payload: ConfigPilotRun,
    current_user: User = Depends(get_current_user),
) -> dict:
    """Run a pilot execution for the configuration named in ``payload``.

    Two phases, each with its own error handling:

    1. Reload the configuration from the database; a falsy result or an
       exception ends the request with INTERNAL_ERROR.
    2. Await ``DataConfigService.pilot_run``; a ``ValueError`` is reported as
       INVALID_PARAMETER, any other exception as INTERNAL_ERROR.
    """
    config_id = payload.config_id
    api_logger.info(f"Pilot run requested: config_id={config_id}, dialogue_text_length={len(payload.dialogue_text)}")

    # Phase 1: the configuration has to be loaded before anything is executed.
    try:
        if not reload_configuration_from_database(str(config_id)):
            api_logger.error(f"Failed to load configuration for config_id: {config_id}")
            return fail(BizCode.INTERNAL_ERROR, "配置加载失败", f"无法加载 config_id={config_id} 的配置")
        api_logger.info(f"Configuration loaded successfully for config_id: {config_id}")
    except Exception as exc:
        detail = str(exc)
        api_logger.error(f"Exception while loading configuration: {detail}")
        return fail(BizCode.INTERNAL_ERROR, "配置加载异常", detail)

    # Phase 2: execute the pilot run itself.
    try:
        outcome = await DataConfigService(get_db_conn()).pilot_run(payload)
        return success(data=outcome, msg="试运行完成")
    except ValueError as exc:
        # ValueError is treated as a parameter-validation failure.
        detail = str(exc)
        api_logger.error(f"Pilot run parameter validation failed: {detail}")
        return fail(BizCode.INVALID_PARAMETER, "参数验证失败", detail)
    except Exception as exc:
        detail = str(exc)
        api_logger.error(f"Pilot run failed: {detail}")
        return fail(BizCode.INTERNAL_ERROR, "试运行失败", detail)
|
||||||
|
|
||||||
|
"""
|
||||||
|
以下为搜索与分析接口,直接挂载到同一 router,统一响应为 ApiResponse。
|
||||||
|
"""
|
||||||
|
|
||||||
|
@router.get("/search/kb_type_distribution", response_model=ApiResponse)
async def get_kb_type_distribution(
    end_user_id: Optional[str] = None,
    current_user: User = Depends(get_current_user),
) -> dict:
    """Return the result of ``kb_type_distribution`` for ``end_user_id``.

    ``end_user_id`` is forwarded unchanged and may be ``None``.
    ``current_user`` is not read here; it only routes the request through the
    ``get_current_user`` dependency. Failures are logged and reported as
    INTERNAL_ERROR.
    """
    api_logger.info(f"KB type distribution requested for end_user_id: {end_user_id}")
    try:
        distribution = await kb_type_distribution(end_user_id)
        return success(data=distribution, msg="查询成功")
    except Exception as exc:
        reason = str(exc)
        api_logger.error(f"KB type distribution failed: {reason}")
        return fail(BizCode.INTERNAL_ERROR, "知识库类型分布查询失败", reason)
|
||||||
|
|
||||||
|
|
||||||
|
@router.get("/search/dialogue", response_model=ApiResponse)
async def search_dialogues_num(
    end_user_id: Optional[str] = None,
    current_user: User = Depends(get_current_user),
) -> dict:
    """Return the result of ``search_dialogue`` for ``end_user_id``.

    ``end_user_id`` is forwarded unchanged and may be ``None``.
    ``current_user`` is not read here; it only routes the request through the
    ``get_current_user`` dependency. Failures are logged and reported as
    INTERNAL_ERROR.
    """
    api_logger.info(f"Search dialogue requested for end_user_id: {end_user_id}")
    try:
        found = await search_dialogue(end_user_id)
        return success(data=found, msg="查询成功")
    except Exception as exc:
        reason = str(exc)
        api_logger.error(f"Search dialogue failed: {reason}")
        return fail(BizCode.INTERNAL_ERROR, "对话查询失败", reason)
|
||||||
|
|
||||||
|
|
||||||
|
@router.get("/search/chunk", response_model=ApiResponse)
async def search_chunks_num(
    end_user_id: Optional[str] = None,
    current_user: User = Depends(get_current_user),
) -> dict:
    """Return the result of ``search_chunk`` for ``end_user_id``.

    ``end_user_id`` is forwarded unchanged and may be ``None``.
    ``current_user`` is not read here; it only routes the request through the
    ``get_current_user`` dependency. Failures are logged and reported as
    INTERNAL_ERROR.
    """
    api_logger.info(f"Search chunk requested for end_user_id: {end_user_id}")
    try:
        found = await search_chunk(end_user_id)
        return success(data=found, msg="查询成功")
    except Exception as exc:
        reason = str(exc)
        api_logger.error(f"Search chunk failed: {reason}")
        return fail(BizCode.INTERNAL_ERROR, "分块查询失败", reason)
|
||||||
|
|
||||||
|
|
||||||
|
@router.get("/search/statement", response_model=ApiResponse)
async def search_statements_num(
    end_user_id: Optional[str] = None,
    current_user: User = Depends(get_current_user),
) -> dict:
    """Return the result of ``search_statement`` for ``end_user_id``.

    ``end_user_id`` is forwarded unchanged and may be ``None``.
    ``current_user`` is not read here; it only routes the request through the
    ``get_current_user`` dependency. Failures are logged and reported as
    INTERNAL_ERROR.
    """
    api_logger.info(f"Search statement requested for end_user_id: {end_user_id}")
    try:
        found = await search_statement(end_user_id)
        return success(data=found, msg="查询成功")
    except Exception as exc:
        reason = str(exc)
        api_logger.error(f"Search statement failed: {reason}")
        return fail(BizCode.INTERNAL_ERROR, "语句查询失败", reason)
|
||||||
|
|
||||||
|
|
||||||
|
@router.get("/search/entity", response_model=ApiResponse)
async def search_entities_num(
    end_user_id: Optional[str] = None,
    current_user: User = Depends(get_current_user),
) -> dict:
    """Return the result of ``search_entity`` for ``end_user_id``.

    ``end_user_id`` is forwarded unchanged and may be ``None``.
    ``current_user`` is not read here; it only routes the request through the
    ``get_current_user`` dependency. Failures are logged and reported as
    INTERNAL_ERROR.
    """
    api_logger.info(f"Search entity requested for end_user_id: {end_user_id}")
    try:
        found = await search_entity(end_user_id)
        return success(data=found, msg="查询成功")
    except Exception as exc:
        reason = str(exc)
        api_logger.error(f"Search entity failed: {reason}")
        return fail(BizCode.INTERNAL_ERROR, "实体查询失败", reason)
|
||||||
|
|
||||||
|
|
||||||
|
@router.get("/search", response_model=ApiResponse)
async def search_all_num(
    end_user_id: Optional[str] = None,
    current_user: User = Depends(get_current_user),
) -> dict:
    """Return the result of ``search_all`` for ``end_user_id``.

    ``end_user_id`` is forwarded unchanged and may be ``None``.
    ``current_user`` is not read here; it only routes the request through the
    ``get_current_user`` dependency. Failures are logged and reported as
    INTERNAL_ERROR.
    """
    api_logger.info(f"Search all requested for end_user_id: {end_user_id}")
    try:
        found = await search_all(end_user_id)
        return success(data=found, msg="查询成功")
    except Exception as exc:
        reason = str(exc)
        api_logger.error(f"Search all failed: {reason}")
        return fail(BizCode.INTERNAL_ERROR, "全部查询失败", reason)
|
||||||
|
|
||||||
|
|
||||||
|
@router.get("/search/detials", response_model=ApiResponse)
async def search_entities_detials(
    end_user_id: Optional[str] = None,
    current_user: User = Depends(get_current_user),
) -> dict:
    """Return the result of ``search_detials`` for ``end_user_id``.

    The "detials" spelling in the route path and function names is kept
    as-is because callers depend on it. ``end_user_id`` is forwarded
    unchanged and may be ``None``. ``current_user`` is not read here; it only
    routes the request through the ``get_current_user`` dependency. Failures
    are logged and reported as INTERNAL_ERROR.
    """
    api_logger.info(f"Search details requested for end_user_id: {end_user_id}")
    try:
        found = await search_detials(end_user_id)
        return success(data=found, msg="查询成功")
    except Exception as exc:
        reason = str(exc)
        api_logger.error(f"Search details failed: {reason}")
        return fail(BizCode.INTERNAL_ERROR, "详情查询失败", reason)
|
||||||
|
|
||||||
|
|
||||||
|
@router.get("/search/edges", response_model=ApiResponse)
async def search_entity_edges(
    end_user_id: Optional[str] = None,
    current_user: User = Depends(get_current_user),
) -> dict:
    """Return the result of ``search_edges`` for ``end_user_id``.

    ``end_user_id`` is forwarded unchanged and may be ``None``.
    ``current_user`` is not read here; it only routes the request through the
    ``get_current_user`` dependency. Failures are logged and reported as
    INTERNAL_ERROR.
    """
    api_logger.info(f"Search edges requested for end_user_id: {end_user_id}")
    try:
        found = await search_edges(end_user_id)
        return success(data=found, msg="查询成功")
    except Exception as exc:
        reason = str(exc)
        api_logger.error(f"Search edges failed: {reason}")
        return fail(BizCode.INTERNAL_ERROR, "边查询失败", reason)
|
||||||
|
|
||||||
|
@router.get("/search/entity_graph", response_model=ApiResponse)
async def search_for_entity_graph(
    end_user_id: Optional[str] = None,
    current_user: User = Depends(get_current_user),
) -> dict:
    """Search the relationship network between all entities.

    Delegates to ``search_entity_graph`` with ``end_user_id`` forwarded
    unchanged (it may be ``None``). ``current_user`` is not read here; it
    only routes the request through the ``get_current_user`` dependency.
    Failures are logged and reported as INTERNAL_ERROR.
    """
    api_logger.info(f"Search entity graph requested for end_user_id: {end_user_id}")
    try:
        graph = await search_entity_graph(end_user_id)
        return success(data=graph, msg="查询成功")
    except Exception as exc:
        reason = str(exc)
        api_logger.error(f"Search entity graph failed: {reason}")
        return fail(BizCode.INTERNAL_ERROR, "实体图查询失败", reason)
|
||||||
|
|
||||||
|
|
||||||
|
@router.get("/analytics/hot_memory_tags", response_model=ApiResponse)
async def get_hot_memory_tags_api(
    end_user_id: Optional[str] = None,
    limit: int = 10,
    current_user: User = Depends(get_current_user),
) -> dict:
    """Return the result of ``analytics_hot_memory_tags``.

    ``end_user_id`` (possibly ``None``) and ``limit`` (default 10) are passed
    to the analytics function positionally, in that order. ``current_user``
    is not read here; it only routes the request through the
    ``get_current_user`` dependency. Failures are logged and reported as
    INTERNAL_ERROR.
    """
    api_logger.info(f"Hot memory tags requested for end_user_id: {end_user_id}")
    try:
        tags = await analytics_hot_memory_tags(end_user_id, limit)
        return success(data=tags, msg="查询成功")
    except Exception as exc:
        reason = str(exc)
        api_logger.error(f"Hot memory tags failed: {reason}")
        return fail(BizCode.INTERNAL_ERROR, "热门标签查询失败", reason)
|
||||||
|
|
||||||
|
|
||||||
|
@router.get("/analytics/memory_insight/report", response_model=ApiResponse)
async def get_memory_insight_report_api(
    end_user_id: Optional[str] = None,
    current_user: User = Depends(get_current_user),
) -> dict:
    """Return the result of ``analytics_memory_insight_report``.

    ``end_user_id`` is forwarded unchanged and may be ``None``.
    ``current_user`` is not read here; it only routes the request through the
    ``get_current_user`` dependency. Failures are logged and reported as
    INTERNAL_ERROR.
    """
    api_logger.info(f"Memory insight report requested for end_user_id: {end_user_id}")
    try:
        report = await analytics_memory_insight_report(end_user_id)
        return success(data=report, msg="查询成功")
    except Exception as exc:
        reason = str(exc)
        api_logger.error(f"Memory insight report failed: {reason}")
        return fail(BizCode.INTERNAL_ERROR, "记忆洞察报告生成失败", reason)
|
||||||
|
|
||||||
|
|
||||||
|
@router.get("/analytics/recent_activity_stats", response_model=ApiResponse)
async def get_recent_activity_stats_api(
    current_user: User = Depends(get_current_user),
) -> dict:
    """Return the result of ``analytics_recent_activity_stats``.

    The analytics function is called without arguments. ``current_user`` is
    not read here; it only routes the request through the
    ``get_current_user`` dependency. Failures are logged and reported as
    INTERNAL_ERROR.
    """
    api_logger.info("Recent activity stats requested")
    try:
        stats = await analytics_recent_activity_stats()
        return success(data=stats, msg="查询成功")
    except Exception as exc:
        reason = str(exc)
        api_logger.error(f"Recent activity stats failed: {reason}")
        return fail(BizCode.INTERNAL_ERROR, "最近活动统计失败", reason)
|
||||||
|
|
||||||
|
|
||||||
|
@router.get("/analytics/user_summary", response_model=ApiResponse)
async def get_user_summary_api(
    end_user_id: Optional[str] = None,
    current_user: User = Depends(get_current_user),
) -> dict:
    """Return the result of ``analytics_user_summary`` for ``end_user_id``.

    ``end_user_id`` is forwarded unchanged and may be ``None``.
    ``current_user`` is not read here; it only routes the request through the
    ``get_current_user`` dependency. Failures are logged and reported as
    INTERNAL_ERROR.
    """
    api_logger.info(f"User summary requested for end_user_id: {end_user_id}")
    try:
        summary = await analytics_user_summary(end_user_id)
        return success(data=summary, msg="查询成功")
    except Exception as exc:
        reason = str(exc)
        api_logger.error(f"User summary failed: {reason}")
        return fail(BizCode.INTERNAL_ERROR, "用户摘要生成失败", reason)
|
||||||
|
|
||||||
|
from app.core.memory.utils.self_reflexion_utils import self_reflexion
|
||||||
|
@router.get("/self_reflexion")
async def self_reflexion_endpoint(host_id: uuid.UUID) -> str:
    """Run self-reflexion for ``host_id`` and return its result.

    Args:
        host_id: UUID passed straight through to ``self_reflexion``.

    Returns:
        Whatever ``self_reflexion`` returns (annotated as ``str``).

    NOTE(review): unlike the sibling endpoints this route declares no
    ``get_current_user`` dependency and no ``response_model`` -- confirm that
    this is intentional.
    """
    reflexion_result = await self_reflexion(host_id)
    return reflexion_result
|
||||||
332
app/controllers/model_controller.py
Normal file
332
app/controllers/model_controller.py
Normal file
@@ -0,0 +1,332 @@
|
|||||||
|
from fastapi import APIRouter, Depends, status, Query
|
||||||
|
from langchain_core.messages import HumanMessage, SystemMessage
|
||||||
|
from langchain_core.prompts import ChatPromptTemplate
|
||||||
|
from sqlalchemy.orm import Session
|
||||||
|
from typing import List, Optional
|
||||||
|
import uuid
|
||||||
|
|
||||||
|
|
||||||
|
from app.core.models import RedBearLLM
|
||||||
|
from app.core.models.base import RedBearModelConfig
|
||||||
|
from app.db import get_db
|
||||||
|
from app.dependencies import get_current_user
|
||||||
|
from app.models.models_model import ModelProvider, ModelType
|
||||||
|
from app.models.user_model import User
|
||||||
|
from app.schemas import model_schema
|
||||||
|
from app.core.response_utils import success
|
||||||
|
from app.schemas.response_schema import ApiResponse, PageData
|
||||||
|
from app.services.model_service import ModelConfigService, ModelApiKeyService
|
||||||
|
from app.core.logging_config import get_api_logger
|
||||||
|
|
||||||
|
# Logger dedicated to the API layer.
api_logger = get_api_logger()

# Every route declared in this module is mounted under /models.
router = APIRouter(prefix="/models", tags=["Models"])
|
||||||
|
|
||||||
|
@router.get("/type", response_model=ApiResponse)
def get_model_types():
    """Return every member of the ``ModelType`` enum."""
    model_types = list(ModelType)
    return success(data=model_types, msg="获取模型类型成功")
|
||||||
|
|
||||||
|
|
||||||
|
@router.get("/provider", response_model=ApiResponse)
def get_model_providers():
    """Return every member of the ``ModelProvider`` enum."""
    providers = list(ModelProvider)
    return success(data=providers, msg="获取模型提供商成功")
|
||||||
|
|
||||||
|
|
||||||
|
@router.get("", response_model=ApiResponse)
def get_model_list(
    # ``type`` shadows the builtin, but it is the public query-parameter name
    # and therefore must not be renamed.
    type: Optional[List[model_schema.ModelType]] = Query(None, description="模型类型筛选(支持多个,如 ?type=LLM&type=EMBEDDING)"),
    provider: Optional[model_schema.ModelProvider] = Query(None, description="提供商筛选(基于API Key)"),
    is_active: Optional[bool] = Query(None, description="激活状态筛选"),
    is_public: Optional[bool] = Query(None, description="公开状态筛选"),
    search: Optional[str] = Query(None, description="搜索关键词"),
    page: int = Query(1, ge=1, description="页码"),
    pagesize: int = Query(10, ge=1, le=100, description="每页数量"),
    db: Session = Depends(get_db)
):
    """Return a page of model configurations.

    The ``type`` filter may be repeated:

    - single:   ``?type=LLM``
    - multiple: ``?type=LLM&type=EMBEDDING``

    All filters are bundled into a ``ModelConfigQuery`` and handed to
    ``ModelConfigService.get_model_list``; the result is validated into
    ``PageData`` before being wrapped in the success envelope.

    Raises:
        Exception: anything raised while building the query, querying or
            validating is logged and re-raised unchanged.
    """
    api_logger.info(f"获取模型配置列表请求: type={type}, provider={provider}, page={page}, pagesize={pagesize}")

    try:
        query = model_schema.ModelConfigQuery(
            type=type,
            provider=provider,
            is_active=is_active,
            is_public=is_public,
            search=search,
            page=page,
            pagesize=pagesize
        )

        # model_dump() is the Pydantic v2 replacement for the deprecated
        # .dict(); the rest of this module already uses the v2 API
        # (model_validate).
        api_logger.debug(f"开始获取模型配置列表: {query.model_dump()}")
        result_orm = ModelConfigService.get_model_list(db=db, query=query)
        result = PageData.model_validate(result_orm)
        api_logger.info(f"模型配置列表获取成功: 总数={result.page.total}, 当前页={len(result.items)}")
        return success(data=result, msg="模型配置列表获取成功")
    except Exception as e:
        api_logger.error(f"获取模型配置列表失败: {str(e)}")
        raise
|
||||||
|
|
||||||
|
|
||||||
|
@router.get("/{model_id}", response_model=ApiResponse)
def get_model_by_id(
    model_id: uuid.UUID,
    db: Session = Depends(get_db)
):
    """Fetch a single model configuration by its ID.

    The row returned by ``ModelConfigService.get_model_by_id`` is converted
    to ``model_schema.ModelConfig`` before being returned. Any exception is
    logged and re-raised unchanged.
    """
    api_logger.info(f"获取模型配置请求: model_id={model_id}")

    try:
        api_logger.debug(f"开始获取模型配置: model_id={model_id}")
        model_row = ModelConfigService.get_model_by_id(db=db, model_id=model_id)
        api_logger.info(f"模型配置获取成功: {model_row.name}")

        # Convert the ORM object into its Pydantic representation.
        model_out = model_schema.ModelConfig.model_validate(model_row)
        return success(data=model_out, msg="模型配置获取成功")
    except Exception as exc:
        api_logger.error(f"获取模型配置失败: model_id={model_id} - {exc}")
        raise
|
||||||
|
|
||||||
|
|
||||||
|
@router.post("", response_model=ApiResponse)
async def create_model(
    model_data: model_schema.ModelConfigCreate,
    db: Session = Depends(get_db),
    current_user: User = Depends(get_current_user)
):
    """Create a model configuration.

    Awaits ``ModelConfigService.create_model`` and returns the created row
    converted to ``model_schema.ModelConfig``. Any exception is logged and
    re-raised unchanged.

    Notes carried over from the original description (behaviour implemented
    in the service, not visible here -- verify there):

    - the basic model configuration is created;
    - when an API key is included, the configuration is validated first;
    - a failed validation raises and nothing is created;
    - validation can be skipped with ``skip_validation=true``.
    """
    model_name = model_data.name
    api_logger.info(f"创建模型配置请求: {model_name}, 用户: {current_user.username}")

    try:
        api_logger.debug(f"开始创建模型配置: {model_data.name}")
        created = await ModelConfigService.create_model(db=db, model_data=model_data)
        api_logger.info(f"模型配置创建成功: {created.name} (ID: {created.id})")

        # Convert the ORM object into its Pydantic representation.
        created_out = model_schema.ModelConfig.model_validate(created)
        return success(data=created_out, msg="模型配置创建成功")
    except Exception as exc:
        api_logger.error(f"创建模型配置失败: {model_data.name} - {exc}")
        raise
|
||||||
|
|
||||||
|
|
||||||
|
@router.put("/{model_id}", response_model=ApiResponse)
def update_model(
    model_id: uuid.UUID,
    model_data: model_schema.ModelConfigUpdate,
    db: Session = Depends(get_db),
    current_user: User = Depends(get_current_user)
):
    """Update the model configuration identified by ``model_id``.

    Delegates to ``ModelConfigService.update_model`` and returns the updated
    row converted to ``model_schema.ModelConfig``. Any exception is logged
    and re-raised unchanged.
    """
    api_logger.info(f"更新模型配置请求: model_id={model_id}, 用户: {current_user.username}")

    try:
        api_logger.debug(f"开始更新模型配置: model_id={model_id}")
        updated = ModelConfigService.update_model(db=db, model_id=model_id, model_data=model_data)
        api_logger.info(f"模型配置更新成功: {updated.name} (ID: {model_id})")

        # Convert the ORM object into its Pydantic representation.
        updated_out = model_schema.ModelConfig.model_validate(updated)
        return success(data=updated_out, msg="模型配置更新成功")
    except Exception as exc:
        api_logger.error(f"更新模型配置失败: model_id={model_id} - {exc}")
        raise
|
||||||
|
|
||||||
|
|
||||||
|
@router.delete("/{model_id}", response_model=ApiResponse)
def delete_model(
    model_id: uuid.UUID,
    db: Session = Depends(get_db),
    current_user: User = Depends(get_current_user)
):
    """Delete the model configuration identified by ``model_id``.

    Delegates to ``ModelConfigService.delete_model`` and returns a success
    envelope without data. Any exception is logged and re-raised unchanged.
    """
    api_logger.info(f"删除模型配置请求: model_id={model_id}, 用户: {current_user.username}")

    try:
        api_logger.debug(f"开始删除模型配置: model_id={model_id}")
        ModelConfigService.delete_model(db=db, model_id=model_id)
        api_logger.info(f"模型配置删除成功: model_id={model_id}")
        return success(msg="模型配置删除成功")
    except Exception as exc:
        api_logger.error(f"删除模型配置失败: model_id={model_id} - {exc}")
        raise
|
||||||
|
|
||||||
|
|
||||||
|
# API Key endpoints
|
||||||
|
@router.get("/{model_id}/apikeys", response_model=ApiResponse)
def get_model_api_keys(
    model_id: uuid.UUID,
    is_active: bool = Query(True, description="是否只获取活跃的API Key"),
    db: Session = Depends(get_db),
    current_user: User = Depends(get_current_user)
):
    """List the API keys attached to the model ``model_id``.

    ``is_active`` (default ``True``) is forwarded to
    ``ModelApiKeyService.get_api_keys_by_model``. Each returned row is
    converted to ``model_schema.ModelApiKey``. Any exception is logged and
    re-raised unchanged.
    """
    api_logger.info(f"获取模型API Key列表请求: model_id={model_id}, 用户: {current_user.username}")

    try:
        api_logger.debug(f"开始获取模型API Key列表: model_id={model_id}")
        key_rows = ModelApiKeyService.get_api_keys_by_model(
            db=db, model_config_id=model_id, is_active=is_active
        )

        # Convert every ORM row into its Pydantic representation.
        keys_out = []
        for row in key_rows:
            keys_out.append(model_schema.ModelApiKey.model_validate(row))

        api_logger.info(f"模型API Key列表获取成功: 数量={len(keys_out)}")
        return success(data=keys_out, msg="模型API Key列表获取成功")
    except Exception as exc:
        api_logger.error(f"获取模型API Key列表失败: model_id={model_id} - {exc}")
        raise
|
||||||
|
|
||||||
|
|
||||||
|
@router.post("/{model_id}/apikeys", response_model=ApiResponse, status_code=status.HTTP_201_CREATED)
async def create_model_api_key(
    model_id: uuid.UUID,
    api_key_data: model_schema.ModelApiKeyCreate,
    db: Session = Depends(get_db),
    current_user: User = Depends(get_current_user)
):
    """Create an API key for the model ``model_id``.

    The path parameter is written into ``api_key_data.model_config_id``
    (overriding any value sent in the body) before
    ``ModelApiKeyService.create_api_key`` is awaited. The created key is
    converted to ``model_schema.ModelApiKey`` so the response has the same
    shape as the one produced by ``update_api_key`` and
    ``get_model_api_keys``.

    Raises:
        Exception: anything raised while creating or converting the key is
            logged and re-raised unchanged.
    """
    api_logger.info(f"创建模型API Key请求: model_id={model_id}, model_name={api_key_data.model_name}, 用户: {current_user.username}")

    try:
        # The model the key belongs to always comes from the URL.
        api_key_data.model_config_id = model_id

        api_logger.debug(f"开始创建模型API Key: {api_key_data.model_name}")
        result = await ModelApiKeyService.create_api_key(db=db, api_key_data=api_key_data)
        api_logger.info(f"模型API Key创建成功: {result.model_name} (ID: {result.id})")

        # Convert to the Pydantic schema, consistent with the sibling
        # API-key endpoints, instead of returning the service object as-is.
        result_pydantic = model_schema.ModelApiKey.model_validate(result)
        return success(data=result_pydantic, msg="模型API Key创建成功")
    except Exception as e:
        api_logger.error(f"创建模型API Key失败: {api_key_data.model_name} - {str(e)}")
        raise
|
||||||
|
|
||||||
|
|
||||||
|
@router.get("/apikeys/{api_key_id}", response_model=ApiResponse)
def get_api_key_by_id(
    api_key_id: uuid.UUID,
    db: Session = Depends(get_db),
    current_user: User = Depends(get_current_user)
):
    """Fetch a single API key by its ID.

    The object returned by ``ModelApiKeyService.get_api_key_by_id`` is
    converted to ``model_schema.ModelApiKey`` so the response has the same
    shape as the one produced by ``update_api_key`` and
    ``get_model_api_keys``.

    Raises:
        Exception: anything raised while loading or converting the key is
            logged and re-raised unchanged.
    """
    api_logger.info(f"获取API Key请求: api_key_id={api_key_id}, 用户: {current_user.username}")

    try:
        api_logger.debug(f"开始获取API Key: api_key_id={api_key_id}")
        result = ModelApiKeyService.get_api_key_by_id(db=db, api_key_id=api_key_id)
        api_logger.info(f"API Key获取成功: {result.model_name}")

        # Convert to the Pydantic schema, consistent with the sibling
        # API-key endpoints, instead of returning the service object as-is.
        result_pydantic = model_schema.ModelApiKey.model_validate(result)
        return success(data=result_pydantic, msg="API Key获取成功")
    except Exception as e:
        api_logger.error(f"获取API Key失败: api_key_id={api_key_id} - {str(e)}")
        raise
|
||||||
|
|
||||||
|
|
||||||
|
@router.put("/apikeys/{api_key_id}", response_model=ApiResponse)
async def update_api_key(
    api_key_id: uuid.UUID,
    api_key_data: model_schema.ModelApiKeyUpdate,
    db: Session = Depends(get_db),
    current_user: User = Depends(get_current_user)
):
    """Update the API key identified by ``api_key_id``.

    Awaits ``ModelApiKeyService.update_api_key`` and returns the updated key
    converted to ``model_schema.ModelApiKey``. Any exception is logged and
    re-raised unchanged.
    """
    api_logger.info(f"更新API Key请求: api_key_id={api_key_id}, 用户: {current_user.username}")

    try:
        api_logger.debug(f"开始更新API Key: api_key_id={api_key_id}")
        updated_key = await ModelApiKeyService.update_api_key(
            db=db, api_key_id=api_key_id, api_key_data=api_key_data
        )
        api_logger.info(f"API Key更新成功: {updated_key.model_name} (ID: {api_key_id})")
        return success(
            data=model_schema.ModelApiKey.model_validate(updated_key),
            msg="API Key更新成功",
        )
    except Exception as exc:
        api_logger.error(f"更新API Key失败: api_key_id={api_key_id} - {exc}")
        raise
|
||||||
|
|
||||||
|
|
||||||
|
@router.delete("/apikeys/{api_key_id}", response_model=ApiResponse)
def delete_api_key(
    api_key_id: uuid.UUID,
    db: Session = Depends(get_db),
    current_user: User = Depends(get_current_user)
):
    """Delete the API key identified by ``api_key_id``.

    Delegates to ``ModelApiKeyService.delete_api_key`` and returns a success
    envelope without data. Any exception is logged and re-raised unchanged.
    """
    api_logger.info(f"删除API Key请求: api_key_id={api_key_id}, 用户: {current_user.username}")

    try:
        api_logger.debug(f"开始删除API Key: api_key_id={api_key_id}")
        ModelApiKeyService.delete_api_key(db=db, api_key_id=api_key_id)
        api_logger.info(f"API Key删除成功: api_key_id={api_key_id}")
        return success(msg="API Key删除成功")
    except Exception as exc:
        api_logger.error(f"删除API Key失败: api_key_id={api_key_id} - {exc}")
        raise
|
||||||
|
|
||||||
|
|
||||||
|
@router.post("/validate", response_model=ApiResponse)
async def validate_model_config(
    validate_data: model_schema.ModelValidateRequest,
    db: Session = Depends(get_db),
    current_user: User = Depends(get_current_user)
):
    """Check whether a model configuration is usable.

    The request fields are forwarded to
    ``ModelConfigService.validate_model_config`` and the mapping it returns
    is expanded into ``model_schema.ModelValidateResponse``.

    Model types listed by the original description: llm (large language
    model), chat (chat model), embedding (vector model), rerank (re-ranking
    model).

    No exception handling happens here; errors propagate to the caller.
    """
    api_logger.info(f"验证模型配置请求: {validate_data.model_name} ({validate_data.model_type}), 用户: {current_user.username}")

    outcome = await ModelConfigService.validate_model_config(
        db=db,
        model_name=validate_data.model_name,
        provider=validate_data.provider,
        api_key=validate_data.api_key,
        api_base=validate_data.api_base,
        model_type=validate_data.model_type,
        test_message=validate_data.test_message
    )

    response_body = model_schema.ModelValidateResponse(**outcome)
    return success(data=response_body, msg="验证完成")
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
404
app/controllers/multi_agent_controller.py
Normal file
404
app/controllers/multi_agent_controller.py
Normal file
@@ -0,0 +1,404 @@
|
|||||||
|
"""多 Agent 控制器"""
|
||||||
|
import uuid
|
||||||
|
from fastapi import APIRouter, Depends, Query, Path
|
||||||
|
from sqlalchemy.orm import Session
|
||||||
|
|
||||||
|
from app.db import get_db
|
||||||
|
from app.dependencies import get_current_user
|
||||||
|
from app.core.response_utils import success
|
||||||
|
from app.core.logging_config import get_business_logger
|
||||||
|
from app.schemas import multi_agent_schema
|
||||||
|
from app.schemas.response_schema import PageData, PageMeta
|
||||||
|
from app.services.multi_agent_service import MultiAgentService
|
||||||
|
from app.models import User
|
||||||
|
|
||||||
|
# Logger obtained from get_business_logger() for this controller.
logger = get_business_logger()
# Every route declared in this module is mounted under /apps.
router = APIRouter(prefix="/apps", tags=["Multi-Agent"])
|
||||||
|
|
||||||
|
|
||||||
|
# ==================== Multi-Agent configuration management ====================
|
||||||
|
|
||||||
|
@router.post(
    "/{app_id}/multi-agent",
    summary="创建多 Agent 配置"
)
def create_multi_agent_config(
    app_id: uuid.UUID = Path(..., description="应用 ID"),
    data: multi_agent_schema.MultiAgentConfigCreate = ...,
    current_user: User = Depends(get_current_user),
    db: Session = Depends(get_db),
):
    """Create a multi-agent configuration for the application ``app_id``.

    Four orchestration modes are supported (per the original description):

    - sequential: run in order
    - parallel: run in parallel
    - conditional: conditional routing
    - loop: run in a loop

    The current user's ID is recorded as ``created_by``. The created
    configuration is returned as ``MultiAgentConfigSchema``.
    """
    created = MultiAgentService(db).create_config(
        app_id=app_id,
        data=data,
        created_by=current_user.id
    )
    created_out = multi_agent_schema.MultiAgentConfigSchema.model_validate(created)
    return success(data=created_out, msg="多 Agent 配置创建成功")
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
@router.get(
    "/{app_id}/multi-agent",
    summary="获取当前应用的最新有效多 Agent 配置"
)
def get_multi_agent_configs(
    app_id: uuid.UUID = Path(..., description="应用 ID"),
    current_user: User = Depends(get_current_user),
    db: Session = Depends(get_db),
):
    """Return the latest valid multi-agent configuration of ``app_id``.

    When the service returns nothing (a falsy value), a default template is
    returned instead so the client always receives a usable structure.
    """
    # Look up the latest valid configuration by app_id (per the original note,
    # the service has already translated agent_id into app_id).
    existing = MultiAgentService(db).get_multi_agent_configs(app_id)

    if existing:
        # The service result is already a dict, so it is returned directly.
        return success(data=existing)

    # Nothing stored yet: answer with the default template.
    execution_defaults = {
        "max_iterations": 10,
        "timeout": 300,
        "enable_parallel": False,
        "error_handling": "stop"
    }
    default_template = {
        "app_id": str(app_id),
        "master_agent_id": None,
        "master_agent_name": None,
        "orchestration_mode": "conditional",
        "sub_agents": [],
        "routing_rules": [],
        "execution_config": execution_defaults,
        "aggregation_strategy": "merge",
    }
    return success(
        data=default_template,
        msg="该应用暂无配置,返回默认模板"
    )
|
||||||
|
|
||||||
|
@router.put(
    "/{app_id}/multi-agent",
    summary="更新多 Agent 配置"
)
def update_multi_agent_config(
    app_id: uuid.UUID = Path(..., description="应用 ID"),
    data: multi_agent_schema.MultiAgentConfigUpdate = ...,
    current_user: User = Depends(get_current_user),
    db: Session = Depends(get_db),
):
    """Update the multi-agent configuration of an app.

    Delegates to ``MultiAgentService.update_config`` and returns the result
    validated through ``MultiAgentConfigSchema``.
    """
    updated = MultiAgentService(db).update_config(app_id, data)
    body = multi_agent_schema.MultiAgentConfigSchema.model_validate(updated)
    return success(data=body, msg="多 Agent 配置更新成功")
|
||||||
|
|
||||||
|
|
||||||
|
@router.delete(
    "/{app_id}/multi-agent",
    summary="删除多 Agent 配置"
)
def delete_multi_agent_config(
    app_id: uuid.UUID = Path(..., description="应用 ID"),
    current_user: User = Depends(get_current_user),
    db: Session = Depends(get_db),
):
    """Delete the multi-agent configuration of an app.

    Delegates to ``MultiAgentService.delete_config`` and returns a
    confirmation message without a data payload.
    """
    MultiAgentService(db).delete_config(app_id)
    return success(msg="多 Agent 配置删除成功")
|
||||||
|
|
||||||
|
# ==================== 多 Agent 运行 ====================
|
||||||
|
|
||||||
|
@router.post(
    "/{app_id}/multi-agent/run",
    summary="运行多 Agent 任务"
)
async def run_multi_agent(
    app_id: uuid.UUID = Path(..., description="应用 ID"),
    request: multi_agent_schema.MultiAgentRunRequest = ...,
    current_user: User = Depends(get_current_user),
    db: Session = Depends(get_db),
):
    """Run a multi-agent task for an app.

    The work is delegated to ``MultiAgentService.run``; its result mapping is
    expanded into a ``MultiAgentRunResponse``.

    Orchestration modes listed by the original documentation:
    - sequential: run in priority order
    - parallel: run all agents in parallel
    - conditional: pick an agent based on conditions
    - loop: repeat until a condition is met
    """
    outcome = await MultiAgentService(db).run(app_id, request)
    response_body = multi_agent_schema.MultiAgentRunResponse(**outcome)
    return success(data=response_body, msg="多 Agent 任务执行成功")
|
||||||
|
|
||||||
|
|
||||||
|
# ==================== 智能路由测试 ====================
|
||||||
|
|
||||||
|
@router.post(
    "/{app_id}/multi-agent/test-routing",
    summary="测试智能路由"
)
async def test_routing(
    app_id: uuid.UUID = Path(..., description="应用 ID"),
    request: multi_agent_schema.RoutingTestRequest = ...,
    current_user: User = Depends(get_current_user),
    db: Session = Depends(get_db),
):
    """Route one test message and report which sub-agent was selected.

    Routing modes listed by the original documentation:
    - keyword: keyword routing only
    - llm: LLM routing (needs ``routing_model_id``)
    - hybrid: keyword routing combined with an LLM

    Request fields read here: ``message``, ``conversation_id`` (optional),
    ``routing_model_id`` (optional), ``use_llm``, ``keyword_threshold`` and
    ``force_new``.

    Returns:
        A ``success`` response. ``data`` is ``None`` (with an explanatory
        message) when the app has no multi-agent configuration, when the
        routing model cannot be found, or when routing raises. Otherwise it
        holds the message, the ``routing_result`` and a ``config_info``
        summary.
    """
    from app.services.conversation_state_manager import ConversationStateManager
    from app.services.llm_router import LLMRouter
    from app.models import ModelConfig

    # 1. Load the multi-agent configuration.
    config = MultiAgentService(db).get_config(app_id)
    if not config:
        return success(
            data=None,
            msg="应用未配置多 Agent,无法测试路由"
        )

    # 2. Index the sub-agents by agent id.
    sub_agents = {}
    for sub_agent_info in config.sub_agents:
        agent_id = sub_agent_info["agent_id"]
        sub_agents[agent_id] = {
            "name": sub_agent_info.get("name", agent_id),
            "role": sub_agent_info.get("role", "")
        }

    # 3. Resolve the routing model when one was requested.
    routing_model = None
    if request.routing_model_id:
        routing_model = db.get(ModelConfig, request.routing_model_id)
        if not routing_model:
            return success(
                data=None,
                msg=f"路由模型不存在: {request.routing_model_id}"
            )

    # The LLM is only used when it was requested and a model is available.
    use_llm = request.use_llm and routing_model is not None

    # 4. Build the router. It is named ``llm_router`` so that it does not
    #    shadow the module-level APIRouter called ``router``.
    llm_router = LLMRouter(
        db=db,
        state_manager=ConversationStateManager(),
        routing_rules=config.routing_rules or [],
        sub_agents=sub_agents,
        routing_model_config=routing_model,
        use_llm=use_llm
    )

    # 5. Apply the threshold override. ``is not None`` keeps an explicit 0
    #    from being silently ignored.
    if request.keyword_threshold is not None:
        llm_router.keyword_high_confidence_threshold = request.keyword_threshold

    # 6. Route the message and build the response.
    try:
        routing_result = await llm_router.route(
            message=request.message,
            conversation_id=str(request.conversation_id) if request.conversation_id else None,
            force_new=request.force_new
        )

        agent_id = routing_result["agent_id"]
        agent_info = sub_agents.get(agent_id, {})

        response_data = {
            "message": request.message,
            "routing_result": {
                "agent_id": agent_id,
                "agent_name": agent_info.get("name", agent_id),
                "agent_role": agent_info.get("role", ""),
                "confidence": routing_result["confidence"],
                "strategy": routing_result["strategy"],
                "topic": routing_result["topic"],
                "topic_changed": routing_result["topic_changed"],
                "reason": routing_result["reason"],
                "routing_method": routing_result["routing_method"]
            },
            # Same key as the batch endpoint; the original key was garbled.
            "config_info": {
                "use_llm": use_llm,
                "routing_model": routing_model.name if routing_model else None,
                "keyword_threshold": llm_router.keyword_high_confidence_threshold,
                "total_sub_agents": len(sub_agents)
            }
        }

        return success(
            data=response_data,
            msg="路由测试成功"
        )

    except Exception as e:
        # Best-effort test endpoint: the failure is reported in the message.
        logger.error(f"路由测试失败: {str(e)}")
        return success(
            data=None,
            msg=f"路由测试失败: {str(e)}"
        )
|
||||||
|
|
||||||
|
|
||||||
|
@router.post(
    # The original path was "/{app_id}/"; the missing suffix had been spliced
    # into a response key of the single-message routing test.
    "/{app_id}/multi-agent/batch-test-routing",
    summary="批量测试智能路由"
)
async def batch_test_routing(
    app_id: uuid.UUID = Path(..., description="应用 ID"),
    request: multi_agent_schema.BatchRoutingTestRequest = ...,
    current_user: User = Depends(get_current_user),
    db: Session = Depends(get_db),
):
    """Route a batch of test messages and report the routing accuracy.

    Request fields read here: ``test_cases`` (each with ``message``,
    ``description`` and optional ``expected_agent_id``), ``routing_model_id``
    (optional), ``use_llm`` and ``keyword_threshold``.

    Returns:
        A ``success`` response. ``data`` is ``None`` when the app has no
        multi-agent configuration. Otherwise it holds ``total_count``,
        ``correct_count``, ``accuracy`` (a percentage, or ``None`` when no
        routed case declared an expected agent), the per-case ``results``
        and a ``config_info`` summary. A case whose routing raises is
        recorded with an ``error`` entry and the batch continues.
    """
    from app.services.conversation_state_manager import ConversationStateManager
    from app.services.llm_router import LLMRouter
    from app.models import ModelConfig

    # 1. Load the multi-agent configuration.
    config = MultiAgentService(db).get_config(app_id)
    if not config:
        return success(
            data=None,
            msg="应用未配置多 Agent,无法测试路由"
        )

    # 2. Index the sub-agents by agent id.
    sub_agents = {}
    for sub_agent_info in config.sub_agents:
        agent_id = sub_agent_info["agent_id"]
        sub_agents[agent_id] = {
            "name": sub_agent_info.get("name", agent_id),
            "role": sub_agent_info.get("role", "")
        }

    # 3. Resolve the routing model; a missing model disables the LLM below.
    routing_model = None
    if request.routing_model_id:
        routing_model = db.get(ModelConfig, request.routing_model_id)

    use_llm = request.use_llm and routing_model is not None

    # 4. Build the router. It is named ``llm_router`` so that it does not
    #    shadow the module-level APIRouter called ``router``.
    llm_router = LLMRouter(
        db=db,
        state_manager=ConversationStateManager(),
        routing_rules=config.routing_rules or [],
        sub_agents=sub_agents,
        routing_model_config=routing_model,
        use_llm=use_llm
    )

    # ``is not None`` keeps an explicit 0 from being silently ignored.
    if request.keyword_threshold is not None:
        llm_router.keyword_high_confidence_threshold = request.keyword_threshold

    # 5. Run every test case.
    results = []
    correct_count = 0
    total_count = len(request.test_cases)

    for test_case in request.test_cases:
        try:
            routing_result = await llm_router.route(
                message=test_case.message,
                conversation_id=str(uuid.uuid4())  # a fresh conversation per case
            )

            agent_id = routing_result["agent_id"]
            agent_info = sub_agents.get(agent_id, {})

            # Correctness is only judged when an expected agent was given.
            is_correct = None
            if test_case.expected_agent_id:
                is_correct = (agent_id == str(test_case.expected_agent_id))
                if is_correct:
                    correct_count += 1

            results.append({
                "message": test_case.message,
                "description": test_case.description,
                "routed_agent_id": agent_id,
                "routed_agent_name": agent_info.get("name"),
                "expected_agent_id": str(test_case.expected_agent_id) if test_case.expected_agent_id else None,
                "is_correct": is_correct,
                "confidence": routing_result["confidence"],
                "routing_method": routing_result["routing_method"],
                "strategy": routing_result["strategy"]
            })

        except Exception as e:
            logger.error(f"测试用例失败: {test_case.message}, 错误: {str(e)}")
            results.append({
                "message": test_case.message,
                "description": test_case.description,
                "error": str(e)
            })

    # 6. Statistics. The denominator counts routed cases that declared an
    #    expected agent (errored cases carry no "expected_agent_id" entry).
    #    Accuracy is reported whenever that denominator is non-zero, so an
    #    all-wrong run yields 0.0 rather than None.
    total_with_expected = sum(1 for r in results if r.get("expected_agent_id"))
    accuracy = None
    if total_with_expected > 0:
        accuracy = correct_count / total_with_expected * 100

    response_data = {
        "total_count": total_count,
        "correct_count": correct_count,
        "accuracy": accuracy,
        "results": results,
        "config_info": {
            "use_llm": use_llm,
            "routing_model": routing_model.name if routing_model else None,
            "keyword_threshold": llm_router.keyword_high_confidence_threshold
        }
    }

    return success(
        data=response_data,
        # ``is not None`` so that a 0.0% accuracy still appears in the message.
        msg=f"批量测试完成,准确率: {accuracy:.1f}%" if accuracy is not None else "批量测试完成"
    )
|
||||||
437
app/controllers/public_share_controller.py
Normal file
437
app/controllers/public_share_controller.py
Normal file
@@ -0,0 +1,437 @@
|
|||||||
|
from fastapi import APIRouter, Depends, Query, Request, Header
|
||||||
|
from fastapi.responses import StreamingResponse
|
||||||
|
from sqlalchemy.orm import Session
|
||||||
|
import uuid
|
||||||
|
import hashlib
|
||||||
|
import time
|
||||||
|
import jwt
|
||||||
|
from typing import Optional, Dict
|
||||||
|
from functools import wraps
|
||||||
|
|
||||||
|
from app.db import get_db
|
||||||
|
from app.core.response_utils import success
|
||||||
|
from app.core.logging_config import get_business_logger
|
||||||
|
from app.core.exceptions import BusinessException
|
||||||
|
from app.core.error_codes import BizCode
|
||||||
|
from app.core.config import settings
|
||||||
|
from app.schemas import release_share_schema, conversation_schema
|
||||||
|
from app.schemas.response_schema import PageData, PageMeta
|
||||||
|
from app.services.release_share_service import ReleaseShareService
|
||||||
|
from app.services.shared_chat_service import SharedChatService
|
||||||
|
from app.services.conversation_service import ConversationService
|
||||||
|
from app.services.auth_service import create_access_token
|
||||||
|
from app.dependencies import get_share_user_id, ShareTokenData
|
||||||
|
|
||||||
|
|
||||||
|
# Router for the public share endpoints; every route below is registered
# under the "/public/share" prefix.
router = APIRouter(prefix="/public/share", tags=["Public Share"])
|
||||||
|
# Module-level logger obtained from get_business_logger(), used by the
# handlers in this file.
logger = get_business_logger()
|
||||||
|
|
||||||
|
|
||||||
|
def get_base_url(request: Request) -> str:
    """Build the ``scheme://netloc`` base URL of the incoming request."""
    url = request.url
    return f"{url.scheme}://{url.netloc}"
|
||||||
|
|
||||||
|
|
||||||
|
def get_or_generate_user_id(payload_user_id: Optional[str], request: Request) -> str:
    """Return the caller-supplied user id, or derive a guest id from the request.

    Priority:
    1. ``payload_user_id`` when it is non-empty.
    2. A fingerprint built from the client IP and the User-Agent header.

    Args:
        payload_user_id: User id sent by the front end; may be empty or None.
        request: The incoming request, read for ``client.host`` and the
            ``user-agent`` header.

    Returns:
        ``payload_user_id`` unchanged, or ``"guest_"`` followed by the first
        16 hex characters of the MD5 digest of ``"<ip>_<user-agent>"``.
    """
    if payload_user_id:
        return payload_user_id

    # Either part falls back to "unknown" when the request does not carry it.
    client_ip = request.client.host if request.client else "unknown"
    user_agent = request.headers.get("user-agent", "unknown")

    fingerprint = f"{client_ip}_{user_agent}".encode()
    # MD5 is used here only to shorten an identifier, not as a security
    # control. Declaring that keeps the call working on FIPS-restricted
    # builds, and the digest (and therefore existing guest ids) is unchanged.
    digest = hashlib.md5(fingerprint, usedforsecurity=False).hexdigest()[:16]

    return f"guest_{digest}"
|
||||||
|
|
||||||
|
|
||||||
|
@router.post(
    "/{share_token}/token",
    summary="获取访问 token"
)
def get_access_token(
    share_token: str,
    payload: release_share_schema.TokenRequest,
    request: Request,
    db: Session = Depends(get_db),
):
    """Exchange a share token for an access token.

    The user id comes from ``payload.user_id`` or is generated from the
    request by ``get_or_generate_user_id``. The share link is checked through
    ``ReleaseShareService.get_shared_release_info`` (with ``payload.password``)
    before the token is created; any exception from that check is logged and
    re-raised.

    Returns:
        A ``success`` response carrying ``access_token``, ``token_type`` and
        the resolved ``user_id``.
    """
    visitor_id = get_or_generate_user_id(payload.user_id, request)

    # Check the share link first; the return value itself is not needed.
    share_service = ReleaseShareService(db)
    try:
        share_service.get_shared_release_info(
            share_token=share_token,
            password=payload.password,
        )
    except Exception as exc:
        logger.error(f"获取分享信息失败: {str(exc)}")
        raise

    issued_token = create_access_token(visitor_id, share_token)

    log_context = {
        "share_token": share_token,
        "user_id": visitor_id
    }
    logger.info("生成访问 token", extra=log_context)

    token_payload = {
        "access_token": issued_token,
        "token_type": "Bearer",
        "user_id": visitor_id
    }
    return success(data=token_payload)
|
||||||
|
|
||||||
|
|
||||||
|
@router.get(
    "",
    summary="获取公开分享的应用信息",
    response_model=None
)
def get_shared_release(
    password: str = Query(None, description="访问密码(如果需要)"),
    share_data: ShareTokenData = Depends(get_share_user_id),
    db: Session = Depends(get_db),
):
    """Return the information of the publicly shared release.

    The share token is taken from the ``get_share_user_id`` dependency and
    passed, together with the optional ``password`` query parameter, to
    ``ReleaseShareService.get_shared_release_info``.
    """
    release_info = ReleaseShareService(db).get_shared_release_info(
        share_token=share_data.share_token,
        password=password,
    )
    return success(data=release_info)
|
||||||
|
|
||||||
|
|
||||||
|
@router.post(
    "/verify",
    summary="验证访问密码"
)
def verify_password(
    payload: release_share_schema.PasswordVerifyRequest,
    share_data: ShareTokenData = Depends(get_share_user_id),
    db: Session = Depends(get_db),
):
    """Check the access password of a share.

    Lets the front end validate a password before it requests the full share
    information. The result of ``ReleaseShareService.verify_password`` is
    returned under the ``valid`` key.
    """
    password_ok = ReleaseShareService(db).verify_password(
        share_token=share_data.share_token,
        password=payload.password,
    )
    return success(data={"valid": password_ok})
|
||||||
|
|
||||||
|
|
||||||
|
@router.get(
    "/embed",
    summary="获取嵌入代码"
)
def get_embed_code(
    width: str = Query("100%", description="iframe 宽度"),
    height: str = Query("600px", description="iframe 高度"),
    request: Request = None,
    share_data: ShareTokenData = Depends(get_share_user_id),
    db: Session = Depends(get_db),
):
    """Return the embed code for a share.

    ``width`` and ``height`` are forwarded to
    ``ReleaseShareService.get_embed_code`` together with the base URL of the
    request (``None`` when no request object is available).
    """
    base_url = None
    if request:
        base_url = get_base_url(request)

    snippet = ReleaseShareService(db).get_embed_code(
        share_token=share_data.share_token,
        width=width,
        height=height,
        base_url=base_url,
    )
    return success(data=snippet)
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
# ---------- 会话管理接口 ----------
|
||||||
|
|
||||||
|
@router.get(
    "/conversations",
    summary="获取会话列表"
)
def list_conversations(
    password: str = Query(None, description="访问密码"),
    page: int = Query(1, ge=1),
    pagesize: int = Query(20, ge=1, le=100),
    share_data: ShareTokenData = Depends(get_share_user_id),
    db: Session = Depends(get_db),
):
    """List the conversations of the current share user, paginated.

    The user id from the share token is mapped to an end-user record via
    ``EndUserRepository.get_or_create_end_user``; conversations are then
    listed for that end user's id.

    Args:
        password: Optional access password of the share.
        page: 1-based page number.
        pagesize: Page size (1-100).

    Returns:
        A ``success`` response wrapping a ``PageData`` of ``Conversation``
        items plus paging metadata.
    """
    from app.repositories.end_user_repository import EndUserRepository

    logger.debug(f"share_data:{share_data.user_id}")

    # One service instance serves both the share lookup and the listing.
    service = SharedChatService(db)
    # Only the share record is needed here; the release is not used.
    share, _release = service._get_release_by_share_token(share_data.share_token, password)

    end_user = EndUserRepository(db).get_or_create_end_user(
        app_id=share.app_id,
        other_id=share_data.user_id
    )
    logger.debug(end_user.id)

    conversations, total = service.list_conversations(
        share_token=share_data.share_token,
        user_id=str(end_user.id),
        password=password,
        page=page,
        pagesize=pagesize
    )

    items = [conversation_schema.Conversation.model_validate(c) for c in conversations]
    # A next page exists while the items covered so far are fewer than total.
    meta = PageMeta(page=page, pagesize=pagesize, total=total, hasnext=(page * pagesize) < total)

    return success(data=PageData(page=meta, items=items))
|
||||||
|
|
||||||
|
|
||||||
|
@router.get(
    "/conversations/{conversation_id}",
    summary="获取会话详情(含消息)"
)
def get_conversation(
    conversation_id: uuid.UUID,
    password: str = Query(None, description="访问密码"),
    share_data: ShareTokenData = Depends(get_share_user_id),
    db: Session = Depends(get_db),
):
    """Return one conversation together with its message history.

    The conversation is fetched through
    ``SharedChatService.get_conversation_messages`` and the messages through
    ``ConversationService.get_messages``. The response is the dumped
    ``Conversation`` schema with an added ``messages`` list.
    """
    conversation = SharedChatService(db).get_conversation_messages(
        share_token=share_data.share_token,
        conversation_id=conversation_id,
        password=password,
    )

    history = ConversationService(db).get_messages(conversation_id)

    body = conversation_schema.Conversation.model_validate(conversation).model_dump()
    body["messages"] = [
        conversation_schema.Message.model_validate(message) for message in history
    ]
    return success(data=body)
|
||||||
|
|
||||||
|
|
||||||
|
# ---------- 聊天接口 ----------
|
||||||
|
|
||||||
|
@router.post(
    "/chat",
    summary="发送消息(支持流式和非流式)"
)
async def chat(
    payload: conversation_schema.ChatRequest,
    share_data: ShareTokenData = Depends(get_share_user_id),
    db: Session = Depends(get_db),
):
    """Send a message to the shared app and return its reply.

    ``user_id`` and ``share_token`` come from the ``get_share_user_id``
    dependency. The original documentation describes this as Bearer-token
    authentication (``Authorization: Bearer {token}``).

    - Multi-turn chat: pass ``conversation_id``. When it is omitted, the
      conversation comes from ``create_or_get_conversation``.
    - Streaming: set ``stream`` to get a ``text/event-stream`` response.

    Raises:
        BusinessException: When the app type is unsupported or the release
            lacks the configuration its type requires. Any other exception
            raised during validation is logged and re-raised.
    """
    from app.models.app_model import AppType
    from app.repositories.end_user_repository import EndUserRepository

    service = SharedChatService(db)

    user_id = share_data.user_id
    share_token = share_data.share_token
    password = None  # token-authenticated requests carry no share password

    # Validate and prepare everything before a streaming response starts, so
    # that failures surface as regular errors rather than mid-stream.
    try:
        share, release = service._get_release_by_share_token(share_token, password)

        # Map the share user onto an end-user record of this app.
        new_end_user = EndUserRepository(db).get_or_create_end_user(
            app_id=share.app_id,
            other_id=user_id,
            original_user_id=user_id
        )

        app_type = release.app.type if release.app else None

        # Each app type needs its own piece of configuration.
        if app_type == "agent":
            if not release.default_model_config_id:
                raise BusinessException("Agent 应用未配置模型", BizCode.AGENT_CONFIG_MISSING)
        elif app_type == "multi_agent":
            config = release.config or {}
            if not config.get("sub_agents"):
                raise BusinessException("多 Agent 应用未配置子 Agent", BizCode.AGENT_CONFIG_MISSING)
        else:
            raise BusinessException(f"不支持的应用类型: {app_type}", BizCode.APP_TYPE_NOT_SUPPORTED)

        conversation = service.create_or_get_conversation(
            share_token=share_data.share_token,
            conversation_id=payload.conversation_id,
            user_id=str(new_end_user.id),
            password=password
        )

        logger.debug(
            "参数验证完成",
            extra={
                "share_token": share_token,
                "app_type": app_type,
                "conversation_id": str(conversation.id),
                "stream": payload.stream
            }
        )

    except Exception as e:
        # Log and propagate; handling is left to the caller's error handlers.
        logger.error(f"参数验证失败: {str(e)}")
        raise

    # Pick the streaming and non-streaming service methods for the app type.
    if app_type == AppType.AGENT:
        stream_handler, reply_handler = service.chat_stream, service.chat
    elif app_type == AppType.MULTI_AGENT:
        stream_handler, reply_handler = service.multi_agent_chat_stream, service.multi_agent_chat
    else:
        raise BusinessException(f"不支持的应用类型: {app_type}", BizCode.APP_TYPE_NOT_SUPPORTED)

    # Both app types take the same arguments.
    chat_kwargs = {
        "share_token": share_token,
        "message": payload.message,
        "conversation_id": conversation.id,  # the conversation resolved above
        "user_id": str(new_end_user.id),
        "variables": payload.variables,
        "password": password,
        "web_search": payload.web_search,
        "memory": payload.memory,
    }

    if payload.stream:
        async def event_generator():
            async for event in stream_handler(**chat_kwargs):
                yield event

        return StreamingResponse(
            event_generator(),
            media_type="text/event-stream",
            headers={
                "Cache-Control": "no-cache",
                "Connection": "keep-alive",
                "X-Accel-Buffering": "no"
            }
        )

    result = await reply_handler(**chat_kwargs)
    return success(data=conversation_schema.ChatResponse(**result))
|
||||||
170
app/controllers/release_share_controller.py
Normal file
170
app/controllers/release_share_controller.py
Normal file
@@ -0,0 +1,170 @@
|
|||||||
|
import uuid
|
||||||
|
from fastapi import APIRouter, Depends, Request
|
||||||
|
from sqlalchemy.orm import Session
|
||||||
|
|
||||||
|
from app.db import get_db
|
||||||
|
from app.core.response_utils import success
|
||||||
|
from app.core.logging_config import get_business_logger
|
||||||
|
from app.schemas import release_share_schema
|
||||||
|
from app.services.release_share_service import ReleaseShareService
|
||||||
|
from app.dependencies import get_current_user, cur_workspace_access_guard
|
||||||
|
|
||||||
|
# Router for the release share management endpoints (no path prefix is set
# here; each route below spells out its full "/apps/..." path).
router = APIRouter(tags=["Release Share"])
|
||||||
|
# Module-level logger obtained from get_business_logger().
logger = get_business_logger()
|
||||||
|
|
||||||
|
|
||||||
|
def get_base_url(request: Request) -> str:
    """Build the ``scheme://netloc`` base URL of the incoming request."""
    url = request.url
    return f"{url.scheme}://{url.netloc}"
|
||||||
|
|
||||||
|
|
||||||
|
@router.post(
    "/apps/{app_id}/releases/{release_id}/share",
    summary="创建/启用分享配置"
)
@cur_workspace_access_guard()
def create_share(
    app_id: uuid.UUID,
    release_id: uuid.UUID,
    payload: release_share_schema.ReleaseShareCreate,
    request: Request,
    db: Session = Depends(get_db),
    current_user=Depends(get_current_user),
):
    """Create or update the share configuration of a release.

    Delegates to ``ReleaseShareService.create_or_update_share`` using the
    caller's current workspace and the base URL of the request, then returns
    the share converted by the service's ``_convert_to_schema``.
    """
    base_url = get_base_url(request)
    share_service = ReleaseShareService(db)

    share_record = share_service.create_or_update_share(
        release_id=release_id,
        user_id=current_user.id,
        workspace_id=current_user.current_workspace_id,
        data=payload,
        base_url=base_url,
    )

    return success(
        data=share_service._convert_to_schema(share_record, base_url),
        msg="分享配置已创建",
    )
|
||||||
|
|
||||||
|
|
||||||
|
@router.put(
    "/apps/{app_id}/releases/{release_id}/share",
    summary="更新分享配置"
)
@cur_workspace_access_guard()
def update_share(
    app_id: uuid.UUID,
    release_id: uuid.UUID,
    payload: release_share_schema.ReleaseShareUpdate,
    request: Request,
    db: Session = Depends(get_db),
    current_user=Depends(get_current_user),
):
    """Update the share configuration of a release.

    Delegates to ``ReleaseShareService.update_share`` for the caller's current
    workspace and returns the share converted by ``_convert_to_schema``.
    Per the original documentation, the share token is not changed.
    """
    base_url = get_base_url(request)
    share_service = ReleaseShareService(db)

    share_record = share_service.update_share(
        release_id=release_id,
        workspace_id=current_user.current_workspace_id,
        data=payload,
    )

    return success(
        data=share_service._convert_to_schema(share_record, base_url),
        msg="分享配置已更新",
    )
|
||||||
|
|
||||||
|
|
||||||
|
@router.get(
    "/apps/{app_id}/releases/{release_id}/share",
    summary="获取分享配置"
)
@cur_workspace_access_guard()
def get_share(
    app_id: uuid.UUID,
    release_id: uuid.UUID,
    request: Request,
    db: Session = Depends(get_db),
    current_user=Depends(get_current_user),
):
    """Return the share configuration of a release.

    The result of ``ReleaseShareService.get_share`` is returned as-is; the
    original documentation states it is null when no share exists.
    """
    share_view = ReleaseShareService(db).get_share(
        release_id=release_id,
        workspace_id=current_user.current_workspace_id,
        base_url=get_base_url(request),
    )
    return success(data=share_view)
|
||||||
|
|
||||||
|
|
||||||
|
@router.delete(
    "/apps/{app_id}/releases/{release_id}/share",
    summary="删除分享配置"
)
@cur_workspace_access_guard()
def delete_share(
    app_id: uuid.UUID,
    release_id: uuid.UUID,
    db: Session = Depends(get_db),
    current_user=Depends(get_current_user),
):
    """Delete the share configuration of a release.

    Delegates to ``ReleaseShareService.delete_share`` for the caller's current
    workspace. Per the original documentation, the public link stops working
    afterwards.
    """
    ReleaseShareService(db).delete_share(
        release_id=release_id,
        workspace_id=current_user.current_workspace_id,
    )
    return success(msg="分享配置已删除")
|
||||||
|
|
||||||
|
|
||||||
|
@router.post(
    "/apps/{app_id}/releases/{release_id}/share/regenerate-token",
    summary="重新生成分享链接"
)
@cur_workspace_access_guard()
def regenerate_token(
    app_id: uuid.UUID,
    release_id: uuid.UUID,
    request: Request,
    db: Session = Depends(get_db),
    current_user=Depends(get_current_user),
):
    """Regenerate the share token of a release.

    Delegates to ``ReleaseShareService.regenerate_token`` and returns the
    share converted by ``_convert_to_schema``. Per the original
    documentation, the previous share link becomes invalid.
    """
    base_url = get_base_url(request)
    share_service = ReleaseShareService(db)

    share_record = share_service.regenerate_token(
        release_id=release_id,
        workspace_id=current_user.current_workspace_id,
    )

    return success(
        data=share_service._convert_to_schema(share_record, base_url),
        msg="分享链接已重新生成",
    )
|
||||||
17
app/controllers/service/__init__.py
Normal file
17
app/controllers/service/__init__.py
Normal file
@@ -0,0 +1,17 @@
|
|||||||
|
"""Service API Controllers - 基于 API Key 认证的服务接口
|
||||||
|
|
||||||
|
路由前缀: /v1
|
||||||
|
认证方式: API Key
|
||||||
|
"""
|
||||||
|
from fastapi import APIRouter
|
||||||
|
from . import app_api_controller, rag_api_controller, memory_api_controller
|
||||||
|
|
||||||
|
# Create the V1 API router
service_router = APIRouter()

# Register the sub-routers (same order as before: app, rag, memory)
for _controller in (app_api_controller, rag_api_controller, memory_api_controller):
    service_router.include_router(_controller.router)
del _controller  # keep the module namespace free of the loop variable

__all__ = ["service_router"]
|
||||||
16
app/controllers/service/app_api_controller.py
Normal file
16
app/controllers/service/app_api_controller.py
Normal file
@@ -0,0 +1,16 @@
|
|||||||
|
"""App 服务接口 - 基于 API Key 认证"""
|
||||||
|
from fastapi import APIRouter, Depends
|
||||||
|
from sqlalchemy.orm import Session
|
||||||
|
|
||||||
|
from app.db import get_db
|
||||||
|
from app.core.response_utils import success
|
||||||
|
from app.core.logging_config import get_business_logger
|
||||||
|
|
||||||
|
# Router for the App service endpoints.
# NOTE(review): this prefix carries "/v1" while the sibling memory and RAG
# routers use "/memory" and "/knowledge" — confirm which is intended.
router = APIRouter(
    prefix="/v1/apps",
    tags=["V1 - App API"],
)
# Business logger for this module
logger = get_business_logger()
|
||||||
|
|
||||||
|
|
||||||
|
@router.get("")
async def list_apps():
    """List the accessible applications (placeholder).

    Currently always returns an empty list with a "coming soon" message.
    """
    no_apps_yet = []
    return success(data=no_apps_yet, msg="App API - Coming Soon")
|
||||||
16
app/controllers/service/memory_api_controller.py
Normal file
16
app/controllers/service/memory_api_controller.py
Normal file
@@ -0,0 +1,16 @@
|
|||||||
|
"""Memory 服务接口 - 基于 API Key 认证"""
|
||||||
|
from fastapi import APIRouter, Depends
|
||||||
|
from sqlalchemy.orm import Session
|
||||||
|
|
||||||
|
from app.db import get_db
|
||||||
|
from app.core.response_utils import success
|
||||||
|
from app.core.logging_config import get_business_logger
|
||||||
|
|
||||||
|
# Router for the Memory service endpoints
router = APIRouter(
    prefix="/memory",
    tags=["V1 - Memory API"],
)
# Business logger for this module
logger = get_business_logger()
|
||||||
|
|
||||||
|
|
||||||
|
@router.get("")
async def get_memory_info():
    """Return memory service information (placeholder).

    Currently always returns an empty mapping with a "coming soon" message.
    """
    no_info_yet = {}
    return success(data=no_info_yet, msg="Memory API - Coming Soon")
|
||||||
16
app/controllers/service/rag_api_controller.py
Normal file
16
app/controllers/service/rag_api_controller.py
Normal file
@@ -0,0 +1,16 @@
|
|||||||
|
"""RAG 服务接口 - 基于 API Key 认证"""
|
||||||
|
from fastapi import APIRouter, Depends
|
||||||
|
from sqlalchemy.orm import Session
|
||||||
|
|
||||||
|
from app.db import get_db
|
||||||
|
from app.core.response_utils import success
|
||||||
|
from app.core.logging_config import get_business_logger
|
||||||
|
|
||||||
|
# Router for the RAG (knowledge) service endpoints
router = APIRouter(
    prefix="/knowledge",
    tags=["V1 - RAG API"],
)
# Business logger for this module
logger = get_business_logger()
|
||||||
|
|
||||||
|
|
||||||
|
@router.get("")
async def list_knowledge():
    """List the accessible knowledge bases (placeholder).

    Currently always returns an empty list with a "coming soon" message.
    """
    no_knowledge_yet = []
    return success(data=no_knowledge_yet, msg="RAG API - Coming Soon")
|
||||||
23
app/controllers/setup_controller.py
Normal file
23
app/controllers/setup_controller.py
Normal file
@@ -0,0 +1,23 @@
|
|||||||
|
from fastapi import APIRouter, Depends
|
||||||
|
from sqlalchemy.orm import Session
|
||||||
|
|
||||||
|
from app.core.response_utils import success
|
||||||
|
from app.db import get_db
|
||||||
|
from app.schemas.response_schema import ApiResponse
|
||||||
|
from app.services import user_service
|
||||||
|
|
||||||
|
# Router for the one-off initial setup endpoint
router = APIRouter(prefix="/setup", tags=["Setup"])
|
||||||
|
|
||||||
|
@router.post("", summary="Create the first superuser", response_model=ApiResponse)
def setup_initial_user(db: Session = Depends(get_db)):
    """Create the initial superuser.

    Delegates to ``user_service.create_initial_superuser``; a falsy result is
    reported as "already exists", anything else as "created".
    Per the original notes, this can only be run once and the credentials
    come from environment variables.
    """
    created = user_service.create_initial_superuser(db)
    outcome = "Superuser created successfully." if created else "Superuser already exists."
    return success(msg=outcome)
|
||||||
25
app/controllers/task_controller.py
Normal file
25
app/controllers/task_controller.py
Normal file
@@ -0,0 +1,25 @@
|
|||||||
|
from fastapi import APIRouter, status
|
||||||
|
from app.schemas.item_schema import Item
|
||||||
|
from app.services import task_service
|
||||||
|
|
||||||
|
# Router for background-task submission and result lookup
router = APIRouter(prefix="/tasks", tags=["Tasks"])
|
||||||
|
|
||||||
|
@router.post("/process_item", status_code=status.HTTP_202_ACCEPTED)
def process_item_task(item: Item):
    """Queue an item for background processing.

    The item is not processed in the request; it is handed to the task
    service (which, per the original notes, sends it to the Celery queue)
    and the resulting task id is returned with a 202 status.

    Args:
        item: Request payload to process.

    Returns:
        A dict with a human-readable ``message`` and the ``task_id``.
    """
    # ``model_dump()`` is the Pydantic v2 replacement for the deprecated ``dict()``.
    task_id = task_service.create_processing_task(item.model_dump())
    return {"message": "Task accepted. The item is being processed in the background.", "task_id": task_id}
|
||||||
|
|
||||||
|
@router.get("/result/{task_id}")
def get_task_result_controller(task_id: str):
    """Look up the status/result of a previously submitted task.

    Args:
        task_id: Identifier returned when the task was submitted.

    Returns:
        Whatever ``task_service.get_task_result`` returns for that id.
    """
    task_outcome = task_service.get_task_result(task_id)
    return task_outcome
|
||||||
126
app/controllers/test_controller.py
Normal file
126
app/controllers/test_controller.py
Normal file
@@ -0,0 +1,126 @@
|
|||||||
|
from fastapi import APIRouter, Depends, status, Query, HTTPException
|
||||||
|
from langchain_core.messages import HumanMessage, SystemMessage
|
||||||
|
from langchain_core.prompts import ChatPromptTemplate
|
||||||
|
from sqlalchemy.orm import Session
|
||||||
|
from typing import List, Optional
|
||||||
|
import uuid
|
||||||
|
|
||||||
|
|
||||||
|
from app.core.models import RedBearLLM, RedBearRerank
|
||||||
|
from app.core.models.base import RedBearModelConfig
|
||||||
|
from app.core.models.embedding import RedBearEmbeddings
|
||||||
|
from app.db import get_db
|
||||||
|
from app.dependencies import get_current_user
|
||||||
|
from app.models.models_model import ModelApiKey, ModelProvider, ModelType
|
||||||
|
from app.models.user_model import User
|
||||||
|
from app.schemas import model_schema
|
||||||
|
from app.core.response_utils import success
|
||||||
|
from app.schemas.response_schema import ApiResponse, PageData
|
||||||
|
from app.services.model_service import ModelConfigService, ModelApiKeyService
|
||||||
|
from app.core.logging_config import get_api_logger
|
||||||
|
|
||||||
|
# API-specific logger
api_logger = get_api_logger()

# Router for the model smoke-test endpoints
router = APIRouter(prefix="/test", tags=["test"])
|
||||||
|
|
||||||
|
|
||||||
|
@router.get("/llm/{model_id}", response_model=ApiResponse)
def test_llm(
    model_id: uuid.UUID,
    db: Session = Depends(get_db)
):
    """Smoke-test a configured LLM by asking it one fixed question.

    Loads the model configuration for ``model_id``, builds a ``RedBearLLM``
    from its first API key entry, pipes a one-question prompt template into
    it and returns the question together with the answer.

    Raises:
        HTTPException: 404 when no configuration exists for ``model_id``;
            400 when the configuration has no API key entries.
        Exception: anything raised while building or invoking the model is
            logged and re-raised unchanged.
    """
    # NOTE(review): neither this handler nor its router declares an auth
    # dependency — confirm one is attached where the router is included.
    config = ModelConfigService.get_model_by_id(db=db, model_id=model_id)
    if not config:
        api_logger.error(f"模型ID {model_id} 不存在")
        raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="模型ID不存在")

    # Indexing an empty key list used to surface as an unhandled IndexError.
    if not config.api_keys:
        api_logger.error(f"模型ID {model_id} 未配置API Key")
        raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="模型未配置API Key")

    question = "What is LangChain?"
    try:
        api_config: ModelApiKey = config.api_keys[0]
        llm = RedBearLLM(RedBearModelConfig(
            model_name=api_config.model_name,
            provider=api_config.provider,
            api_key=api_config.api_key,
            base_url=api_config.api_base
        ), type=config.type)
        # The model object is deliberately not dumped to stdout: it is built
        # from a config that carries the API key and may expose it.

        template = """Question: {question}

Answer: Let's think step by step."""
        # ChatPromptTemplate
        prompt = ChatPromptTemplate.from_template(template)
        chain = prompt | llm
        answer = chain.invoke({"question": question})
        return success(msg="测试LLM成功", data={"question": question, "answer": answer})

    except Exception as e:
        api_logger.error(f"测试LLM失败: {str(e)}")
        raise
|
||||||
|
|
||||||
|
|
||||||
|
@router.get("/embedding/{model_id}", response_model=ApiResponse)
def test_embedding(
    model_id: uuid.UUID,
    db: Session = Depends(get_db)
):
    """Smoke-test a configured embedding model with fixed sample sentences.

    Builds ``RedBearEmbeddings`` from the first API key entry of the model
    configuration, embeds five sample documents and one query, and reports
    success. The vectors themselves are not returned.

    Raises:
        HTTPException: 404 when no configuration exists for ``model_id``;
            400 when the configuration has no API key entries.
    """
    config = ModelConfigService.get_model_by_id(db=db, model_id=model_id)
    if not config:
        api_logger.error(f"模型ID {model_id} 不存在")
        raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="模型ID不存在")

    # Indexing an empty key list used to surface as an unhandled IndexError.
    if not config.api_keys:
        api_logger.error(f"模型ID {model_id} 未配置API Key")
        raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="模型未配置API Key")

    api_config: ModelApiKey = config.api_keys[0]
    model = RedBearEmbeddings(RedBearModelConfig(
        model_name=api_config.model_name,
        provider=api_config.provider,
        api_key=api_config.api_key,
        base_url=api_config.api_base
    ))

    data = [
        "最近哪家咖啡店评价最好?",
        "附近有没有推荐的咖啡厅?",
        "明天天气预报说会下雨。",
        "北京是中国的首都。",
        "我想找一个适合学习的地方。"
    ]
    # Both calls are kept so each embedding path is exercised. The raw
    # vectors are no longer printed to stdout.
    model.embed_documents(data)
    query = "我想找一个适合学习的地方。"
    model.embed_query(query)

    # The message previously said "LLM" (copy-paste from test_llm).
    return success(msg="测试Embedding成功")
|
||||||
|
|
||||||
|
|
||||||
|
@router.get("/rerank/{model_id}", response_model=ApiResponse)
def test_rerank(
    model_id: uuid.UUID,
    db: Session = Depends(get_db)
):
    """Smoke-test a configured rerank model with fixed sample sentences.

    Builds ``RedBearRerank`` from the first API key entry of the model
    configuration, reranks five sample documents against one query with
    ``top_n=3`` and returns the query, documents and scores.

    Raises:
        HTTPException: 404 when no configuration exists for ``model_id``;
            400 when the configuration has no API key entries.
    """
    config = ModelConfigService.get_model_by_id(db=db, model_id=model_id)
    if not config:
        api_logger.error(f"模型ID {model_id} 不存在")
        raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="模型ID不存在")

    # Indexing an empty key list used to surface as an unhandled IndexError.
    if not config.api_keys:
        api_logger.error(f"模型ID {model_id} 未配置API Key")
        raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="模型未配置API Key")

    api_config: ModelApiKey = config.api_keys[0]
    model = RedBearRerank(RedBearModelConfig(
        model_name=api_config.model_name,
        provider=api_config.provider,
        api_key=api_config.api_key,
        base_url=api_config.api_base
    ))
    query = "最近哪家咖啡店评价最好?"
    data = [
        "最近哪家咖啡店评价最好?",
        "附近有没有推荐的咖啡厅?",
        "明天天气预报说会下雨。",
        "北京是中国的首都。",
        "我想找一个适合学习的地方。"
    ]
    # The scores are returned in the response, so the stdout print was dropped.
    scores = model.rerank(query=query, documents=data, top_n=3)
    return success(msg="测试Rerank成功", data={"query": query, "documents": data, "scores": scores})
|
||||||
376
app/controllers/upload_controller.py
Normal file
376
app/controllers/upload_controller.py
Normal file
@@ -0,0 +1,376 @@
|
|||||||
|
"""
|
||||||
|
Upload Controller for Generic File Upload System
|
||||||
|
Handles HTTP requests for file upload, download, deletion, and metadata updates.
|
||||||
|
"""
|
||||||
|
import os
|
||||||
|
import json
|
||||||
|
from typing import List, Optional, Any
|
||||||
|
from pathlib import Path
|
||||||
|
from fastapi import APIRouter, Depends, File, UploadFile, Form
|
||||||
|
from fastapi.responses import FileResponse
|
||||||
|
from sqlalchemy.orm import Session
|
||||||
|
|
||||||
|
from app.db import get_db
|
||||||
|
from app.dependencies import get_current_user
|
||||||
|
from app.models.user_model import User
|
||||||
|
from app.schemas.response_schema import ApiResponse
|
||||||
|
from app.schemas.generic_file_schema import (
|
||||||
|
GenericFileResponse,
|
||||||
|
FileMetadataUpdate,
|
||||||
|
UploadResultSchema,
|
||||||
|
BatchUploadResponse
|
||||||
|
)
|
||||||
|
from app.core.response_utils import success, fail
|
||||||
|
from app.core.upload_enums import UploadContext
|
||||||
|
from app.services.upload_service import UploadService
|
||||||
|
from app.core.logging_config import get_logger
|
||||||
|
from app.core.exceptions import (
|
||||||
|
ValidationException,
|
||||||
|
ResourceNotFoundException,
|
||||||
|
FileUploadException,
|
||||||
|
BusinessException
|
||||||
|
)
|
||||||
|
|
||||||
|
# Module logger
logger = get_logger(__name__)

# Router; every route on it depends on an authenticated user
router = APIRouter(prefix="/api", tags=["upload"], dependencies=[Depends(get_current_user)])

# Upload service instance shared by all handlers in this module
upload_service = UploadService()
|
||||||
|
|
||||||
|
|
||||||
|
@router.post("/upload", response_model=ApiResponse)
async def upload_file(
    file: UploadFile = File(..., description="要上传的文件"),
    context: str = Form(..., description="上传上下文 (avatar, app_icon, knowledge_base, temp, attachment)"),
    metadata: Optional[str] = Form(None, description="文件元数据 (JSON格式)"),
    db: Session = Depends(get_db),
    current_user: User = Depends(get_current_user)
) -> ApiResponse:
    """
    Single-file upload endpoint.

    - **file**: the file to upload
    - **context**: upload context; must be a valid ``UploadContext`` value
    - **metadata**: optional file metadata as a JSON string

    Returns the information of the uploaded file.

    Raises:
        ValidationException: ``context`` is not a valid ``UploadContext``
            value, or ``metadata`` is not valid JSON.
        FileUploadException: any non-business error raised while uploading.
    """
    logger.info(f"Upload request: filename={file.filename}, context={context}, user={current_user.id}")

    try:
        # Validate and parse context
        try:
            upload_context = UploadContext(context)
        except ValueError:
            logger.warning(f"Invalid upload context: {context}")
            raise ValidationException(
                f"无效的上传上下文: {context}. 允许的值: {', '.join([c.value for c in UploadContext])}",
                field="context"
            )

        # Parse metadata if provided
        file_metadata = {}
        if metadata:
            try:
                file_metadata = json.loads(metadata)
            except json.JSONDecodeError:
                logger.warning(f"Invalid metadata JSON: {metadata}")
                raise ValidationException(
                    "元数据必须是有效的JSON格式",
                    field="metadata"
                )

        # Upload file
        db_file = upload_service.upload_file(
            file=file,
            context=upload_context,
            metadata=file_metadata,
            current_user=current_user,
            db=db
        )

        # Convert to response schema
        file_response = GenericFileResponse.model_validate(db_file)

        logger.info(f"Upload successful: {file.filename} (ID: {db_file.id})")
        # model_dump() is the Pydantic v2 replacement for the deprecated dict().
        return success(data=file_response.model_dump(), msg="文件上传成功")

    except BusinessException:
        # Business exceptions are handled by global exception handlers
        raise
    except Exception as e:
        logger.error(f"Upload failed: {str(e)}")
        # Wrap unknown exceptions as FileUploadException, keeping the cause chained
        raise FileUploadException(
            f"文件上传失败: {str(e)}",
            cause=e
        ) from e
|
||||||
|
|
||||||
|
|
||||||
|
@router.post("/upload/batch", response_model=ApiResponse)
async def upload_files_batch(
    files: List[UploadFile] = File(..., description="要上传的文件列表"),
    context: str = Form(..., description="上传上下文 (avatar, app_icon, knowledge_base, temp, attachment)"),
    metadata: Optional[str] = Form(None, description="文件元数据 (JSON格式)"),
    db: Session = Depends(get_db),
    current_user: User = Depends(get_current_user)
) -> ApiResponse:
    """
    Batch file upload endpoint.

    - **files**: the files to upload (the original notes state a maximum of
      20; no limit is enforced in this handler)
    - **context**: upload context; must be a valid ``UploadContext`` value
    - **metadata**: optional metadata as a JSON string, applied to all files

    Returns the upload result of every file.

    Raises:
        ValidationException: ``context`` is not a valid ``UploadContext``
            value, or ``metadata`` is not valid JSON.
        FileUploadException: any non-business error raised while uploading.
    """
    logger.info(f"Batch upload request: {len(files)} files, context={context}, user={current_user.id}")

    try:
        # Validate and parse context
        try:
            upload_context = UploadContext(context)
        except ValueError:
            logger.warning(f"Invalid upload context: {context}")
            raise ValidationException(
                f"无效的上传上下文: {context}. 允许的值: {', '.join([c.value for c in UploadContext])}",
                field="context"
            )

        # Parse metadata if provided
        file_metadata = {}
        if metadata:
            try:
                file_metadata = json.loads(metadata)
            except json.JSONDecodeError:
                logger.warning(f"Invalid metadata JSON: {metadata}")
                raise ValidationException(
                    "元数据必须是有效的JSON格式",
                    field="metadata"
                )

        # Upload files in batch
        upload_results = upload_service.upload_files_batch(
            files=files,
            context=upload_context,
            metadata=file_metadata,
            current_user=current_user,
            db=db
        )

        # Convert results to response schemas
        result_schemas = []
        for result in upload_results:
            result_schema = UploadResultSchema(
                success=result.success,
                file_id=result.file_id,
                file_name=result.file_name,
                error=result.error,
                file_info=None
            )

            # If upload was successful, get file info (best effort: a lookup
            # failure is logged and leaves file_info as None)
            if result.success and result.file_id:
                try:
                    db_file = upload_service.get_file(result.file_id, current_user, db)
                    result_schema.file_info = GenericFileResponse.model_validate(db_file)
                except Exception as e:
                    logger.warning(f"Failed to get file info for {result.file_id}: {str(e)}")

            result_schemas.append(result_schema)

        # Create batch response
        batch_response = BatchUploadResponse(
            total=len(files),
            success_count=sum(1 for r in upload_results if r.success),
            failed_count=sum(1 for r in upload_results if not r.success),
            results=result_schemas
        )

        logger.info(f"Batch upload completed: {batch_response.success_count}/{batch_response.total} successful")
        # model_dump() is the Pydantic v2 replacement for the deprecated dict().
        return success(data=batch_response.model_dump(), msg="批量上传完成")

    except BusinessException:
        # Business exceptions are handled by global exception handlers
        raise
    except Exception as e:
        logger.error(f"Batch upload failed: {str(e)}")
        # Wrap unknown exceptions as FileUploadException, keeping the cause chained
        raise FileUploadException(
            f"批量上传失败: {str(e)}",
            cause=e
        ) from e
|
||||||
|
|
||||||
|
|
||||||
|
@router.get("/files/{file_id}", response_model=Any)
async def download_file(
    file_id: str,
    db: Session = Depends(get_db),
    current_user: User = Depends(get_current_user)
) -> Any:
    """
    File download endpoint.

    - **file_id**: file ID (must parse as a UUID)

    Returns the file content for download.

    Raises:
        ValidationException: ``file_id`` is not a valid UUID string.
        ResourceNotFoundException: the stored path is not an existing
            regular file.
        FileUploadException: any non-business error raised while downloading.
    """
    logger.info(f"Download request: file_id={file_id}, user={current_user.id}")

    try:
        # Parse file_id
        import uuid
        try:
            file_uuid = uuid.UUID(file_id)
        except ValueError:
            logger.warning(f"Invalid file ID format: {file_id}")
            raise ValidationException(
                "无效的文件ID格式",
                field="file_id"
            )

        # Get file from database
        db_file = upload_service.get_file(file_uuid, current_user, db)

        # Check that the physical file exists. is_file() rather than exists():
        # a directory at that path cannot be served by FileResponse.
        storage_path = Path(db_file.storage_path)
        if not storage_path.is_file():
            logger.error(f"Physical file not found: {storage_path}")
            raise ResourceNotFoundException(
                "文件",
                str(file_uuid),
                context={"detail": "文件未找到(可能已被删除)"}
            )

        # Return file response
        logger.info(f"Download successful: {db_file.file_name} (ID: {file_id})")
        return FileResponse(
            path=str(storage_path),
            filename=db_file.file_name,
            media_type=db_file.mime_type or "application/octet-stream"
        )

    except BusinessException:
        # Business exceptions are handled by global exception handlers
        raise
    except Exception as e:
        logger.error(f"Download failed: {str(e)}")
        # Wrap unknown exceptions, keeping the cause chained
        raise FileUploadException(
            f"文件下载失败: {str(e)}",
            cause=e
        ) from e
|
||||||
|
|
||||||
|
|
||||||
|
@router.delete("/files/{file_id}", response_model=ApiResponse)
async def delete_file(
    file_id: str,
    db: Session = Depends(get_db),
    current_user: User = Depends(get_current_user)
) -> ApiResponse:
    """
    File deletion endpoint.

    - **file_id**: file ID (must parse as a UUID)

    Delegates to ``upload_service.delete_file``; per the original notes this
    removes both the physical file and the database record.

    Raises:
        ValidationException: ``file_id`` is not a valid UUID string.
        FileUploadException: any non-business error raised while deleting.
    """
    logger.info(f"Delete request: file_id={file_id}, user={current_user.id}")

    try:
        # Parse file_id
        import uuid
        try:
            file_uuid = uuid.UUID(file_id)
        except ValueError:
            logger.warning(f"Invalid file ID format: {file_id}")
            raise ValidationException(
                "无效的文件ID格式",
                field="file_id"
            )

        # Delete file
        upload_service.delete_file(file_uuid, current_user, db)

        logger.info(f"Delete successful: file_id={file_id}")
        return success(msg="文件删除成功")

    except BusinessException:
        # Business exceptions are handled by global exception handlers
        raise
    except Exception as e:
        logger.error(f"Delete failed: {str(e)}")
        # Wrap unknown exceptions, keeping the cause chained
        raise FileUploadException(
            f"文件删除失败: {str(e)}",
            cause=e
        ) from e
|
||||||
|
|
||||||
|
|
||||||
|
@router.put("/files/{file_id}", response_model=ApiResponse)
async def update_file_metadata(
    file_id: str,
    update_data: FileMetadataUpdate,
    db: Session = Depends(get_db),
    current_user: User = Depends(get_current_user)
) -> ApiResponse:
    """
    File metadata update endpoint.

    - **file_id**: file ID (must parse as a UUID)
    - **update_data**: the metadata fields to update

    Only fields explicitly set in the request are forwarded to the service
    (per the original notes: file name, custom metadata, public flag).

    Raises:
        ValidationException: ``file_id`` is not a valid UUID string, or no
            field was provided to update.
        FileUploadException: any non-business error raised while updating.
    """
    logger.info(f"Update metadata request: file_id={file_id}, user={current_user.id}")

    try:
        # Parse file_id
        import uuid
        try:
            file_uuid = uuid.UUID(file_id)
        except ValueError:
            logger.warning(f"Invalid file ID format: {file_id}")
            raise ValidationException(
                "无效的文件ID格式",
                field="file_id"
            )

        # Convert update data to dict, excluding unset fields
        # (model_dump() is the Pydantic v2 replacement for the deprecated dict())
        update_dict = update_data.model_dump(exclude_unset=True)

        if not update_dict:
            logger.warning(f"No fields to update for file: {file_id}")
            raise ValidationException(
                "没有提供要更新的字段",
                field="update_data"
            )

        # Update file metadata
        updated_file = upload_service.update_file_metadata(
            file_uuid, update_dict, current_user, db
        )

        # Convert to response schema
        file_response = GenericFileResponse.model_validate(updated_file)

        logger.info(f"Update metadata successful: file_id={file_id}")
        return success(data=file_response.model_dump(), msg="文件元数据更新成功")

    except BusinessException:
        # Business exceptions are handled by global exception handlers
        raise
    except Exception as e:
        logger.error(f"Update metadata failed: {str(e)}")
        # Wrap unknown exceptions, keeping the cause chained
        raise FileUploadException(
            f"文件元数据更新失败: {str(e)}",
            cause=e
        ) from e
|
||||||
183
app/controllers/user_controller.py
Normal file
183
app/controllers/user_controller.py
Normal file
@@ -0,0 +1,183 @@
|
|||||||
|
from fastapi import APIRouter, Depends, status
|
||||||
|
from sqlalchemy.orm import Session
|
||||||
|
import uuid
|
||||||
|
|
||||||
|
from app.db import get_db
|
||||||
|
from app.dependencies import get_current_user, get_current_superuser
|
||||||
|
from app.models.user_model import User
|
||||||
|
from app.schemas import user_schema
|
||||||
|
from app.schemas.user_schema import ChangePasswordRequest, AdminChangePasswordRequest
|
||||||
|
from app.schemas.response_schema import ApiResponse
|
||||||
|
from app.services import user_service
|
||||||
|
from app.core.logging_config import get_api_logger
|
||||||
|
from app.core.response_utils import success
|
||||||
|
|
||||||
|
# API-specific logger
api_logger = get_api_logger()

# Router for the user management endpoints
router = APIRouter(prefix="/users", tags=["Users"])
|
||||||
|
|
||||||
|
|
||||||
|
@router.post("/superuser", response_model=ApiResponse)
def create_superuser(
    user: user_schema.UserCreate,
    db: Session = Depends(get_db),
    current_superuser: User = Depends(get_current_superuser)
):
    """Create a superuser (superuser-only endpoint).

    Args:
        user: Data of the account to create.
        db: Database session.
        current_superuser: The superuser performing the request.

    Returns:
        Success response carrying the created user as ``user_schema.User``.
    """
    api_logger.info(f"超级管理员创建请求: {user.username}, email: {user.email}")

    result = user_service.create_superuser(
        db=db,
        user=user,
        current_user=current_superuser,
    )
    api_logger.info(f"超级管理员创建成功: {result.username} (ID: {result.id})")

    return success(
        data=user_schema.User.model_validate(result),
        msg="超级管理员创建成功",
    )
|
||||||
|
|
||||||
|
|
||||||
|
@router.delete("/{user_id}", response_model=ApiResponse)
def delete_user(
    user_id: uuid.UUID,
    db: Session = Depends(get_db),
    current_user: User = Depends(get_current_user),
):
    """Deactivate a user (soft delete).

    Args:
        user_id: Id of the user to deactivate.
        db: Database session.
        current_user: The user performing the operation.
    """
    api_logger.info(f"用户停用请求: user_id={user_id}, 操作者: {current_user.username}")
    result = user_service.deactivate_user(db=db, user_id_to_deactivate=user_id, current_user=current_user)
    api_logger.info(f"用户停用成功: {result.username} (ID: {result.id})")
    return success(msg="用户停用成功")
|
||||||
|
|
||||||
|
@router.post("/{user_id}/activate", response_model=ApiResponse)
def activate_user(
    user_id: uuid.UUID,
    db: Session = Depends(get_db),
    current_user: User = Depends(get_current_user),
):
    """Activate a user.

    Args:
        user_id: Id of the user to activate.
        db: Database session.
        current_user: The user performing the operation.

    Returns:
        Success response carrying the activated user as ``user_schema.User``.
    """
    api_logger.info(f"用户激活请求: user_id={user_id}, 操作者: {current_user.username}")

    result = user_service.activate_user(db=db, user_id_to_activate=user_id, current_user=current_user)
    api_logger.info(f"用户激活成功: {result.username} (ID: {result.id})")

    return success(
        data=user_schema.User.model_validate(result),
        msg="用户激活成功",
    )
|
||||||
|
|
||||||
|
|
||||||
|
@router.get("", response_model=ApiResponse)
def get_current_user_info(
    db: Session = Depends(get_db),
    current_user: User = Depends(get_current_user),
):
    """Return the current user's information.

    When the user has a current workspace, the response is enriched with
    that workspace's name (if the workspace is found) and with the user's
    role in it (if a matching membership exists).
    """
    api_logger.info(f"当前用户信息请求: {current_user.username}")

    result = user_service.get_user(db=db, user_id=current_user.id, current_user=current_user)
    result_schema = user_schema.User.model_validate(result)

    # Fill in the role and name for the current workspace
    active_workspace_id = current_user.current_workspace_id
    if active_workspace_id:
        from app.repositories.workspace_repository import WorkspaceRepository
        current_workspace = WorkspaceRepository(db).get_workspace_by_id(active_workspace_id)
        if current_workspace:
            result_schema.current_workspace_name = current_workspace.name

        # First membership whose workspace matches the current one, if any
        membership = next(
            (ws for ws in result.workspaces if ws.workspace_id == active_workspace_id),
            None,
        )
        if membership is not None:
            result_schema.role = membership.role

    api_logger.info(f"当前用户信息获取成功: {result.username}, 角色: {result_schema.role}, 工作空间: {result_schema.current_workspace_name}")
    return success(data=result_schema, msg="用户信息获取成功")
|
||||||
|
|
||||||
|
|
||||||
|
@router.get("/superusers", response_model=ApiResponse)
def get_tenant_superusers(
    include_inactive: bool = False,
    db: Session = Depends(get_db),
    current_user: User = Depends(get_current_superuser),
):
    """List the superuser accounts of the current tenant (superuser-only).

    Args:
        include_inactive: Forwarded to the service; defaults to ``False``.
        db: Database session.
        current_user: The superuser performing the request.

    Returns:
        Success response carrying a list of ``user_schema.User``.
    """
    api_logger.info(f"获取租户超管列表请求: {current_user.username}")

    superusers = user_service.get_tenant_superusers(
        db=db, current_user=current_user, include_inactive=include_inactive
    )
    api_logger.info(f"租户超管列表获取成功: count={len(superusers)}")

    # Validate each record into the response schema, preserving order
    superusers_schema = []
    for account in superusers:
        superusers_schema.append(user_schema.User.model_validate(account))
    return success(data=superusers_schema, msg="租户超管列表获取成功")
|
||||||
|
|
||||||
|
|
||||||
|
@router.get("/{user_id}", response_model=ApiResponse)
def get_user_info_by_id(
    user_id: uuid.UUID,
    db: Session = Depends(get_db),
    current_user: User = Depends(get_current_user),
):
    """Return the information of the user with the given id.

    Args:
        user_id: Id of the user to look up.
        db: Database session.
        current_user: The user performing the request.

    Returns:
        Success response carrying the user as ``user_schema.User``.
    """
    api_logger.info(f"获取用户信息请求: user_id={user_id}, 操作者: {current_user.username}")

    result = user_service.get_user(db=db, user_id=user_id, current_user=current_user)
    api_logger.info(f"用户信息获取成功: {result.username}")

    return success(
        data=user_schema.User.model_validate(result),
        msg="用户信息获取成功",
    )
|
||||||
|
|
||||||
|
|
||||||
|
@router.put("/change-password", response_model=ApiResponse)
async def change_password(
    request: ChangePasswordRequest,
    db: Session = Depends(get_db),
    current_user: User = Depends(get_current_user),
):
    """Change the current user's password.

    Args:
        request: Carries the old and the new password.
        db: Database session.
        current_user: The authenticated user whose password is changed.
    """
    api_logger.info(f"用户密码修改请求: {current_user.username}")

    password_change = dict(
        db=db,
        user_id=current_user.id,
        old_password=request.old_password,
        new_password=request.new_password,
        current_user=current_user,
    )
    await user_service.change_password(**password_change)

    api_logger.info(f"用户密码修改成功: {current_user.username}")
    return success(msg="密码修改成功")
|
||||||
|
|
||||||
|
|
||||||
|
@router.put("/admin/change-password", response_model=ApiResponse)
async def admin_change_password(
    request: AdminChangePasswordRequest,
    db: Session = Depends(get_db),
    current_user: User = Depends(get_current_superuser),
):
    """Change a given user's password as a superuser.

    When the request carries no new password, the generated password
    returned by the service is sent back in the response data; otherwise
    only a confirmation message is returned.

    Args:
        request: Target user id and optional new password.
        db: Database session.
        current_user: The superuser performing the change.
    """
    api_logger.info(f"管理员密码修改请求: 管理员 {current_user.username} 修改用户 {request.user_id}")

    # The service returns (user, generated_password); only the password is used here.
    _, generated_password = await user_service.admin_change_password(
        db=db,
        target_user_id=request.user_id,
        new_password=request.new_password,
        current_user=current_user,
    )

    # Reset path: no password supplied, so hand back the generated one
    if not request.new_password:
        api_logger.info(f"管理员密码重置成功: 用户 {request.user_id}, 随机密码已生成")
        return success(data=generated_password, msg="密码重置成功")

    api_logger.info(f"管理员密码修改成功: 用户 {request.user_id}")
    return success(msg="密码修改成功")
|
||||||
342
app/controllers/workspace_controller.py
Normal file
342
app/controllers/workspace_controller.py
Normal file
@@ -0,0 +1,342 @@
|
|||||||
|
from fastapi import APIRouter, Depends, HTTPException, status, Query
|
||||||
|
from sqlalchemy.orm import Session
|
||||||
|
from typing import List, Optional
|
||||||
|
import uuid
|
||||||
|
|
||||||
|
from app.core.response_utils import success
|
||||||
|
from app.db import get_db
|
||||||
|
from app.dependencies import get_current_superuser, get_current_user, get_current_tenant, workspace_access_guard, cur_workspace_access_guard
|
||||||
|
from app.models.user_model import User
|
||||||
|
from app.models.tenant_model import Tenants
|
||||||
|
from app.models.workspace_model import Workspace, InviteStatus
|
||||||
|
from app.schemas.response_schema import ApiResponse
|
||||||
|
from app.schemas.workspace_schema import (
|
||||||
|
WorkspaceCreate, WorkspaceUpdate, WorkspaceResponse,
|
||||||
|
WorkspaceInviteCreate, WorkspaceInviteResponse,
|
||||||
|
InviteValidateResponse, InviteAcceptRequest,
|
||||||
|
WorkspaceMemberUpdate, WorkspaceMemberItem
|
||||||
|
)
|
||||||
|
from app.schemas import knowledge_schema
|
||||||
|
from app.services import workspace_service
|
||||||
|
from app.core.logging_config import get_api_logger
|
||||||
|
from app.services import knowledge_service, document_service
|
||||||
|
# Logger dedicated to API-layer messages.
api_logger = get_api_logger()

# Public router: no authentication dependency is attached.
public_router = APIRouter(prefix="/workspaces", tags=["Workspaces"])

# Authenticated router: every route depends on ``get_current_user``.
router = APIRouter(
    prefix="/workspaces",
    tags=["Workspaces"],
    dependencies=[Depends(get_current_user)],
)
|
||||||
|
|
||||||
|
|
||||||
|
def _convert_members_to_table_items(members):
    """Convert workspace member records into ``WorkspaceMemberItem`` rows.

    Each member contributes its own ``id`` and ``role`` plus the
    ``username``, ``email`` (exposed as ``account``) and ``last_login_at``
    of its related ``user``.
    """
    table_items = []
    for member in members:
        row = WorkspaceMemberItem(
            id=member.id,
            username=member.user.username,
            account=member.user.email,
            role=member.role,
            last_login_at=member.user.last_login_at,
        )
        table_items.append(row)
    return table_items
|
||||||
|
|
||||||
|
|
||||||
|
@router.get("", response_model=ApiResponse)
def get_workspaces(
    include_current: bool = Query(True, description="是否包含当前工作空间"),
    db: Session = Depends(get_db),
    current_user: User = Depends(get_current_user),
    current_tenant: Tenants = Depends(get_current_tenant)
):
    """List the workspaces the current user participates in.

    Args:
        include_current: Whether to keep the user's current workspace in
            the result (defaults to True).
    """
    api_logger.info(
        f"用户 {current_user.username} 在租户 {current_tenant.name} 中请求获取工作空间列表",
        extra={"include_current": include_current}
    )

    workspaces = workspace_service.get_user_workspaces(db, current_user)

    # Drop the current workspace only when asked to and when one is set.
    if not include_current:
        active_workspace_id = current_user.current_workspace_id
        if active_workspace_id:
            remaining = []
            for candidate in workspaces:
                if candidate.id != active_workspace_id:
                    remaining.append(candidate)
            workspaces = remaining
            api_logger.debug(
                "过滤掉当前工作空间",
                extra={"current_workspace_id": str(active_workspace_id)}
            )

    api_logger.info(f"成功获取 {len(workspaces)} 个工作空间")
    serialized = []
    for workspace in workspaces:
        serialized.append(WorkspaceResponse.model_validate(workspace))
    return success(data=serialized, msg="工作空间列表获取成功")
|
||||||
|
|
||||||
|
|
||||||
|
@router.post("", response_model=ApiResponse)
def create_workspace(
    workspace: WorkspaceCreate,
    db: Session = Depends(get_db),
    current_user: User = Depends(get_current_superuser),
):
    """Create a new workspace (superuser dependency required)."""
    api_logger.info(f"用户 {current_user.username} 请求创建工作空间: {workspace.name}")

    created = workspace_service.create_workspace(
        db=db,
        workspace=workspace,
        user=current_user,
    )

    api_logger.info(f"工作空间创建成功 - 名称: {workspace.name}, ID: {created.id}, 创建者: {current_user.username}")
    return success(
        data=WorkspaceResponse.model_validate(created),
        msg="工作空间创建成功",
    )
|
||||||
|
|
||||||
|
@router.put("", response_model=ApiResponse)
@cur_workspace_access_guard()
def update_workspace(
    workspace: WorkspaceUpdate,
    db: Session = Depends(get_db),
    current_user: User = Depends(get_current_user),
):
    """Update the caller's current workspace."""
    target_id = current_user.current_workspace_id
    api_logger.info(f"用户 {current_user.username} 请求更新工作空间 ID: {target_id}")

    service_kwargs = {
        "db": db,
        "workspace_id": target_id,
        "workspace_in": workspace,
        "user": current_user,
    }
    updated = workspace_service.update_workspace(**service_kwargs)

    api_logger.info(f"工作空间更新成功 - ID: {target_id}, 用户: {current_user.username}")
    return success(
        data=WorkspaceResponse.model_validate(updated),
        msg="工作空间更新成功",
    )
|
||||||
|
|
||||||
|
@router.get("/members", response_model=ApiResponse)
@cur_workspace_access_guard()
def get_cur_workspace_members(
    db: Session = Depends(get_db),
    current_user: User = Depends(get_current_user),
):
    """List the members of the caller's current workspace as table rows."""
    api_logger.info(f"用户 {current_user.username} 请求获取工作空间 {current_user.current_workspace_id} 的成员列表")

    lookup_kwargs = {
        "db": db,
        "workspace_id": current_user.current_workspace_id,
        "user": current_user,
    }
    members = workspace_service.get_workspace_members(**lookup_kwargs)

    api_logger.info(f"工作空间成员列表获取成功 - ID: {current_user.current_workspace_id}, 数量: {len(members)}")
    return success(
        data=_convert_members_to_table_items(members),
        msg="工作空间成员列表获取成功",
    )
|
||||||
|
|
||||||
|
|
||||||
|
@router.put("/members", response_model=ApiResponse)
@cur_workspace_access_guard()
def update_workspace_members(
    updates: List[WorkspaceMemberUpdate],
    db: Session = Depends(get_db),
    current_user: User = Depends(get_current_user),
):
    """Apply role updates to members of the caller's current workspace."""
    target_id = current_user.current_workspace_id
    api_logger.info(f"用户 {current_user.username} 请求更新工作空间 {target_id} 的成员角色")

    service_kwargs = {
        "db": db,
        "workspace_id": target_id,
        "updates": updates,
        "user": current_user,
    }
    updated_members = workspace_service.update_workspace_member_roles(**service_kwargs)

    # Only the count is reported; the response carries no member data.
    api_logger.info(f"工作空间成员角色更新成功 - ID: {target_id}, 数量: {len(updated_members)}")
    return success(msg="成员角色更新成功")
|
||||||
|
|
||||||
|
|
||||||
|
@router.delete("/members/{member_id}", response_model=ApiResponse)
@cur_workspace_access_guard()
def delete_workspace_member(
    member_id: uuid.UUID,
    db: Session = Depends(get_db),
    current_user: User = Depends(get_current_user),
):
    """Remove a member from the caller's current workspace."""
    target_id = current_user.current_workspace_id
    api_logger.info(f"用户 {current_user.username} 请求删除工作空间 {target_id} 的成员 {member_id}")

    service_kwargs = {
        "db": db,
        "workspace_id": target_id,
        "member_id": member_id,
        "user": current_user,
    }
    workspace_service.delete_workspace_member(**service_kwargs)

    api_logger.info(f"工作空间成员删除成功 - ID: {target_id}, 成员: {member_id}")
    return success(msg="成员删除成功")
|
||||||
|
|
||||||
|
|
||||||
|
# Create a workspace collaboration invite
@router.post("/invites", response_model=ApiResponse)
@cur_workspace_access_guard()
def create_workspace_invite(
    invite_data: WorkspaceInviteCreate,
    db: Session = Depends(get_db),
    current_user: User = Depends(get_current_user),
):
    """Create an invite for the caller's current workspace."""
    target_id = current_user.current_workspace_id
    api_logger.info(f"用户 {current_user.username} 请求为工作空间 {target_id} 创建邀请: {invite_data.email}")

    service_kwargs = {
        "db": db,
        "workspace_id": target_id,
        "invite_data": invite_data,
        "user": current_user,
    }
    invite = workspace_service.create_workspace_invite(**service_kwargs)

    api_logger.info(f"工作空间邀请创建成功 - 工作空间: {target_id}, 邮箱: {invite_data.email}")
    return success(data=invite, msg="邀请创建成功")
|
||||||
|
|
||||||
|
|
||||||
|
@router.get("/invites", response_model=ApiResponse)
@cur_workspace_access_guard()
def get_workspace_invites(
    status_filter: Optional[InviteStatus] = Query(None, alias="status"),
    limit: int = Query(50, ge=1, le=100),
    offset: int = Query(0, ge=0),
    db: Session = Depends(get_db),
    current_user: User = Depends(get_current_user),
):
    """List invites of the caller's current workspace.

    Args:
        status_filter: Optional invite status, read from the ``status``
            query parameter.
        limit: Page size (1-100, default 50).
        offset: Number of records to skip (>= 0).
    """
    target_id = current_user.current_workspace_id
    api_logger.info(f"用户 {current_user.username} 请求获取工作空间 {target_id} 的邀请列表")

    query_kwargs = {
        "db": db,
        "workspace_id": target_id,
        "user": current_user,
        "status": status_filter,
        "limit": limit,
        "offset": offset,
    }
    invite_records = workspace_service.get_workspace_invites(**query_kwargs)

    api_logger.info(f"成功获取 {len(invite_records)} 个邀请记录")
    return success(data=invite_records, msg="邀请列表获取成功")
|
||||||
|
|
||||||
|
|
||||||
|
@public_router.get("/invites/validate/{token}", response_model=ApiResponse)
def get_workspace_invite_info(
    token: str,
    db: Session = Depends(get_db),
):
    """Validate an invite token and return the invite information.

    This route is registered on ``public_router``, which carries no
    authentication dependency.

    Args:
        token: Invite token taken from the URL path.
        db: Database session.

    Returns:
        A ``success`` envelope wrapping the result of
        ``workspace_service.validate_invite_token``.
    """
    import hashlib  # local import: only needed for the log fingerprint

    result = workspace_service.validate_invite_token(db=db, token=token)

    # The token is a secret that grants access to the invite, so the raw
    # value is kept out of the logs. A short SHA-256 fingerprint still lets
    # operators correlate log lines for the same invite.
    token_fingerprint = hashlib.sha256(token.encode()).hexdigest()[:12]
    api_logger.info(f"工作空间邀请验证成功 - 邀请: sha256:{token_fingerprint}")
    return success(data=result, msg="邀请验证成功")
|
||||||
|
|
||||||
|
|
||||||
|
@router.delete("/invites/{invite_id}", response_model=ApiResponse)
@cur_workspace_access_guard()
def revoke_workspace_invite(
    invite_id: uuid.UUID,
    db: Session = Depends(get_db),
    current_user: User = Depends(get_current_user),
):
    """Revoke an invite that belongs to the caller's current workspace."""
    target_id = current_user.current_workspace_id
    api_logger.info(f"用户 {current_user.username} 请求撤销工作空间 {target_id} 的邀请 {invite_id}")

    service_kwargs = {
        "db": db,
        "workspace_id": target_id,
        "invite_id": invite_id,
        "user": current_user,
    }
    revoked = workspace_service.revoke_workspace_invite(**service_kwargs)

    api_logger.info(f"工作空间邀请撤销成功 - 邀请: {invite_id}")
    return success(data=revoked, msg="邀请撤销成功")
|
||||||
|
|
||||||
|
# ==================== 公开邀请接口(无需认证) ====================
|
||||||
|
|
||||||
|
# # 创建一个新的路由器用于公开接口
|
||||||
|
# public_router = APIRouter(
|
||||||
|
# prefix="/invites",
|
||||||
|
# tags=["Public Invites"]
|
||||||
|
# )
|
||||||
|
|
||||||
|
# @public_router.get("/validate", response_model=ApiResponse)
|
||||||
|
# def validate_invite_token(
|
||||||
|
# token: str = Query(..., description="邀请令牌"),
|
||||||
|
# db: Session = Depends(get_db),
|
||||||
|
# ):
|
||||||
|
# """验证邀请令牌(公开接口)"""
|
||||||
|
# api_logger.info(f"验证邀请令牌请求")
|
||||||
|
@router.put("/{workspace_id}/switch", response_model=ApiResponse)
@workspace_access_guard()
def switch_workspace(
    workspace_id: uuid.UUID,
    db: Session = Depends(get_db),
    current_user: User = Depends(get_current_user),
):
    """Switch the caller to the workspace given in the path."""
    api_logger.info(f"用户 {current_user.username} 请求切换工作空间为 {workspace_id}")

    service_kwargs = {
        "db": db,
        "workspace_id": workspace_id,
        "user": current_user,
    }
    workspace_service.switch_workspace(**service_kwargs)

    api_logger.info(f"成功切换工作空间为 {workspace_id}")
    return success(msg="工作空间切换成功")
|
||||||
|
|
||||||
|
|
||||||
|
@router.get("/storage", response_model=ApiResponse)
@cur_workspace_access_guard()
def get_workspace_storage_type(
    db: Session = Depends(get_db),
    current_user: User = Depends(get_current_user),
):
    """Return the storage type of the caller's current workspace."""
    target_id = current_user.current_workspace_id
    api_logger.info(f"用户 {current_user.username} 请求获取工作空间 {target_id} 的存储类型")

    lookup_kwargs = {
        "db": db,
        "workspace_id": target_id,
        "user": current_user,
    }
    storage_type = workspace_service.get_workspace_storage_type(**lookup_kwargs)

    api_logger.info(f"成功获取工作空间 {target_id} 的存储类型: {storage_type}")
    payload = {"storage_type": storage_type}
    return success(data=payload, msg="存储类型获取成功")
|
||||||
|
|
||||||
|
|
||||||
|
@router.get("/workspace_models", response_model=ApiResponse)
@cur_workspace_access_guard()
def workspace_models_configs(
    db: Session = Depends(get_db),
    current_user: User = Depends(get_current_user),
):
    """Return the model configuration (llm, embedding, rerank) of the current workspace.

    Raises:
        HTTPException: 404 when the service returns ``None`` for the
            workspace.
    """
    target_id = current_user.current_workspace_id
    api_logger.info(f"用户 {current_user.username} 请求获取工作空间 {target_id} 的模型配置")

    lookup_kwargs = {
        "db": db,
        "workspace_id": target_id,
        "user": current_user,
    }
    configs = workspace_service.get_workspace_models_configs(**lookup_kwargs)

    if configs is None:
        api_logger.warning(f"工作空间 {target_id} 不存在或无权访问")
        raise HTTPException(
            status_code=status.HTTP_404_NOT_FOUND,
            detail="工作空间不存在或无权访问",
        )

    # Summarise the three model slots in a fixed order for the log line.
    slot_summary = ", ".join(
        f"{slot}={configs.get(slot)}" for slot in ("llm", "embedding", "rerank")
    )
    api_logger.info(f"成功获取工作空间 {target_id} 的模型配置: " + slot_summary)
    return success(data=configs, msg="模型配置获取成功")
|
||||||
|
|
||||||
0
app/core/agent/__init__.py
Normal file
0
app/core/agent/__init__.py
Normal file
35
app/core/agent/agent_api_text.py
Normal file
35
app/core/agent/agent_api_text.py
Normal file
@@ -0,0 +1,35 @@
|
|||||||
|
from pydantic import BaseModel
|
||||||
|
|
||||||
|
from app.core.agent.agent_chat import Agent_chat
|
||||||
|
from app.core.logging_config import get_business_logger
|
||||||
|
from fastapi import APIRouter, Depends, HTTPException
|
||||||
|
|
||||||
|
from app.dependencies import workspace_access_guard
|
||||||
|
from app.services.agent_server import config,ChatRequest
|
||||||
|
# Business-level logger for this module.
logger = get_business_logger()

# Router for the test chat endpoint, mounted under "/Test".
router = APIRouter(
    prefix="/Test",
    tags=["Apps"],
)
|
||||||
|
# Request body that bundles the base configuration with the chat request.
class CombinedRequest(BaseModel):
    # Base configuration handed to ``Agent_chat``.
    config_base: config

    # Per-request chat parameters.
    agent_config: ChatRequest
|
||||||
|
|
||||||
|
@router.post("", summary="uuid")
async def agent_chat(
    config_base: CombinedRequest
):
    """Run a chat request through ``Agent_chat`` and return its result.

    The incoming ``agent_config`` is copied field by field into a fresh
    ``ChatRequest`` before being handed to the agent.
    """
    incoming = config_base.agent_config
    base_settings = config_base.config_base

    # Fields copied verbatim from the incoming chat configuration.
    copied_fields = (
        "end_user_id",
        "message",
        "search_switch",
        "kb_ids",
        "similarity_threshold",
        "vector_similarity_weight",
        "top_k",
        "hybrid",
        "token",
    )
    request = ChatRequest(**{name: getattr(incoming, name) for name in copied_fields})

    return await Agent_chat(base_settings).chat(request)
|
||||||
109
app/core/agent/agent_chat.py
Normal file
109
app/core/agent/agent_chat.py
Normal file
@@ -0,0 +1,109 @@
|
|||||||
|
import asyncio
|
||||||
|
import os
|
||||||
|
import time
|
||||||
|
|
||||||
|
from typing import Dict, Any, List
|
||||||
|
|
||||||
|
from app.core.logging_config import get_business_logger
|
||||||
|
from app.schemas.prompt_schema import render_prompt_message, PromptMessageRole
|
||||||
|
from app.services.api_resquests_server import messages_type, write_messages
|
||||||
|
from app.services.agent_server import ChatRequest, tool_memory, create_dynamic_agent, tool_Retrieval
|
||||||
|
|
||||||
|
logger = get_business_logger()
|
||||||
|
class Agent_chat:
    """Fan one chat request out to several dynamically created agents.

    The constructor renders the prompt template. ``chat`` optionally
    enriches that prompt with history and knowledge-base results, builds one
    agent per entry in ``model_configs`` and runs them concurrently.
    """

    def __init__(self, config_data: Any):
        """Render the prompt and cache the agent settings.

        Args:
            config_data: Configuration object exposing ``template_str``,
                ``params``, ``model_configs``, ``history_memory`` and
                ``knowledge_base`` as attributes. Attribute access is used,
                so a plain ``dict`` does not work here.
        """
        self.prompt_message = render_prompt_message(
            config_data.template_str,
            PromptMessageRole.USER,
            config_data.params
        )
        self.prompt = self.prompt_message.get_text_content()
        self.model_configs = config_data.model_configs
        self.history_memory = config_data.history_memory
        self.knowledge_base = config_data.knowledge_base
        logger.info(f"渲染结果:{self.prompt_message.get_text_content()}")

    async def run_agent(self, agent, end_user_id: str, user_prompt: str, model_name: str):
        """Invoke one agent and collect the text of its AI messages.

        Args:
            agent: Object exposing a synchronous ``invoke(payload, config)``.
            end_user_id: End-user identifier; part of the thread id.
            user_prompt: User message sent to the agent.
            model_name: Model label; part of the thread id and of the result.

        Returns:
            dict with ``model_name``, ``end_user_id`` and ``response``, the
            list of contents of messages whose class name maps to ``"ai"``
            and that carry no tool calls.
        """
        payload = {"messages": [{"role": "user", "content": user_prompt}]}
        run_config = {"configurable": {"thread_id": f'{model_name}_{end_user_id}'}}

        # ``agent.invoke`` is blocking. Calling it directly inside this
        # coroutine would stall the event loop and make ``asyncio.gather`` in
        # ``chat`` run the agents one after another, so it runs in a worker
        # thread instead.
        response = await asyncio.to_thread(agent.invoke, payload, run_config)

        ai_messages = []
        for msg in response["messages"]:
            if getattr(msg, "tool_calls", None):
                # Tool-call turns carry no user-facing answer.
                continue
            content = getattr(msg, "content", None)
            if not content:
                continue
            # e.g. "AIMessage" -> "ai", "HumanMessage" -> "human"
            role = msg.__class__.__name__.lower().replace("message", "")
            if role == "ai":
                ai_messages.append(content)
        return {"model_name": model_name, "end_user_id": end_user_id, "response": ai_messages}

    async def chat(self, req: ChatRequest) -> str:
        """Answer ``req`` with every configured model and return a summary string.

        Note that ``self.prompt`` is extended in place, so an instance is
        meant to serve a single request.

        Args:
            req: Chat request; ``end_user_id``, ``message`` and ``token`` are
                read here, and the whole object is passed to the memory and
                retrieval tools.

        Returns:
            A string holding the elapsed time followed by the list of
            per-model results from ``run_agent``.
        """
        end_user_id = req.end_user_id  # also used as the conversation thread id
        start = time.time()
        user_prompt = req.message

        # Classify the message; only a 'question' is written via write_messages
        # (the original note says this decides whether to write to Redis).
        message_kind = (await messages_type(req.message, end_user_id))['data']
        if message_kind == 'question':
            writer_result = await write_messages(f'{end_user_id}', req.message)
            logger.info(f'判断类型写入耗时:{time.time() - start},{writer_result}')

        # --- history memory ---
        if self.history_memory == True:  # noqa: E712 - keep the original equality semantics
            tool_result = await tool_memory(req)
            if tool_result != '':
                tool_result = tool_result['data']
            if tool_result != '':
                self.prompt = self.prompt + f''',历史消息:{tool_result},结合历史消息'''
            logger.info(f"记忆科学消耗时间:{time.time()-start},工具调用结果:{tool_result}")

        # (The original had a placeholder labelled 'baidu' here with no code.)

        # --- knowledge base ---
        if self.knowledge_base == True:  # noqa: E712 - keep the original equality semantics
            retrieval_result = await tool_Retrieval(req)
            retrieval_knowledge = ','.join(
                [item['page_content'] for item in retrieval_result['data']]
            )
            logger.info(f"检索消耗时间:{time.time()-start},{retrieval_knowledge}")
            if retrieval_knowledge != '':
                self.prompt = self.prompt + f",知识库检索内容:{retrieval_knowledge},结合检索结果"

        self.prompt = self.prompt + '给出最合适的答案,确保答案的完整性,只保留用户的问题的回答,不额外输出提示语'
        logger.info(f"用户输入:{user_prompt}")
        logger.info(f"系统prompt:{self.prompt}")

        # One agent per configured model, keyed by the model's display name.
        agents = {
            cfg["name"]: await create_dynamic_agent(cfg["name"], cfg["moder_id"], self.prompt, req.token)
            for cfg in self.model_configs
        }

        # Run all agents concurrently; gather keeps the input order.
        result = await asyncio.gather(*[
            self.run_agent(agent, end_user_id, user_prompt, model_name)
            for model_name, agent in agents.items()
        ])

        return f"最终耗时:{time.time()-start},{result}"
|
||||||
347
app/core/agent/langchain_agent.py
Normal file
347
app/core/agent/langchain_agent.py
Normal file
@@ -0,0 +1,347 @@
|
|||||||
|
"""
|
||||||
|
LangChain Agent 封装
|
||||||
|
|
||||||
|
使用 LangChain 1.x 标准方式
|
||||||
|
- 使用 create_agent 创建 agent graph
|
||||||
|
- 支持工具调用循环
|
||||||
|
- 支持流式输出
|
||||||
|
- 使用 RedBearLLM 支持多提供商
|
||||||
|
"""
|
||||||
|
import os
|
||||||
|
import time
|
||||||
|
import asyncio
|
||||||
|
from typing import Dict, Any, List, Optional, AsyncGenerator, Sequence
|
||||||
|
from langchain_core.messages import SystemMessage, HumanMessage, AIMessage, BaseMessage
|
||||||
|
from langchain_core.tools import BaseTool
|
||||||
|
from langchain.agents import create_agent
|
||||||
|
|
||||||
|
from app.core.models import RedBearLLM, RedBearModelConfig
|
||||||
|
from app.models.models_model import ModelType
|
||||||
|
from app.core.logging_config import get_business_logger
|
||||||
|
from app.services.memory_agent_service import MemoryAgentService
|
||||||
|
from app.services.memory_konwledges_server import write_rag
|
||||||
|
from app.services.task_service import get_task_memory_write_result
|
||||||
|
from app.tasks import write_message_task
|
||||||
|
logger = get_business_logger()
|
||||||
|
|
||||||
|
|
||||||
|
class LangChainAgent:
    """Chat agent built on the graph returned by LangChain's ``create_agent``.

    Wraps a ``RedBearLLM`` chat model plus optional tools, and exposes a
    one-shot ``chat`` coroutine and a streaming ``chat_stream`` generator.
    Both entry points record the user message through ``write_rag`` or
    ``write_message_task.delay`` before calling the model.
    """

    def __init__(
        self,
        model_name: str,
        api_key: str,
        provider: str = "openai",
        api_base: Optional[str] = None,
        temperature: float = 0.7,
        max_tokens: int = 2000,
        system_prompt: Optional[str] = None,
        tools: Optional[Sequence[BaseTool]] = None,
        streaming: bool = False
    ):
        """Initialise the LangChain agent.

        Args:
            model_name: Model name.
            api_key: API key.
            provider: Provider (openai, xinference, gpustack, ollama, dashscope).
            api_base: API base URL.
            temperature: Sampling temperature.
            max_tokens: Maximum number of tokens.
            system_prompt: System prompt; a default assistant prompt is used
                when empty.
            tools: Optional tool list handed to ``create_agent``.
            streaming: Whether to enable streaming output (defaults to False).
        """
        self.model_name = model_name
        self.provider = provider
        self.system_prompt = system_prompt or "你是一个专业的AI助手"
        self.tools = tools or []
        self.streaming = streaming

        # Build the RedBearLLM (multi-provider) configuration.
        model_config = RedBearModelConfig(
            model_name=model_name,
            provider=provider,
            api_key=api_key,
            base_url=api_base,
            extra_params={
                "temperature": temperature,
                "max_tokens": max_tokens,
                "streaming": streaming  # streaming is controlled by the argument
            }
        )

        self.llm = RedBearLLM(model_config, type=ModelType.CHAT)

        # Keep a handle on the wrapped model, if the wrapper exposes one.
        self._underlying_llm = self.llm._model if hasattr(self.llm, '_model') else self.llm

        # Make sure the wrapped model streams too when streaming is requested.
        if streaming and hasattr(self._underlying_llm, 'streaming'):
            self._underlying_llm.streaming = True

        # The agent graph is used for every call, with or without tools.
        self.agent = create_agent(
            model=self.llm,
            tools=self.tools if self.tools else None,
            system_prompt=self.system_prompt
        )

        logger.info(
            "LangChain Agent 初始化完成",
            extra={
                "model": model_name,
                "provider": provider,
                "has_api_base": bool(api_base),
                "temperature": temperature,
                "streaming": streaming,
                "tool_count": len(self.tools),
                "tool_names": [tool.name for tool in self.tools],
            }
        )

    @staticmethod
    def _resolve_memory_target(end_user_id: Optional[str], config_id: Optional[str]):
        """Return ``(end_user_id, config_id)`` with the fallbacks used for memory writes."""
        resolved_config_id = os.getenv("config_id") if config_id is None else config_id
        resolved_end_user_id = end_user_id if end_user_id is not None else "unknown"
        return resolved_end_user_id, resolved_config_id

    def _prepare_messages(
        self,
        message: str,
        history: Optional[List[Dict[str, str]]] = None,
        context: Optional[str] = None
    ) -> List[BaseMessage]:
        """Build the message list sent to the agent.

        Args:
            message: User message.
            history: Previous turns as ``{"role": ..., "content": ...}``
                dicts; roles other than ``user``/``assistant`` are skipped.
            context: Optional reference text prepended to the user message.

        Returns:
            List[BaseMessage]: system prompt, history, then the user message.
        """
        # NOTE(review): ``create_agent`` in ``__init__`` is also given
        # ``system_prompt``, so the model may receive the system prompt
        # twice. Confirm whether this explicit SystemMessage is still wanted.
        messages = [SystemMessage(content=self.system_prompt)]

        if history:
            for msg in history:
                if msg["role"] == "user":
                    messages.append(HumanMessage(content=msg["content"]))
                elif msg["role"] == "assistant":
                    messages.append(AIMessage(content=msg["content"]))

        user_content = message
        if context:
            user_content = f"参考信息:\n{context}\n\n用户问题:\n{user_content}"

        messages.append(HumanMessage(content=user_content))
        return messages

    async def chat(
        self,
        message: str,
        history: Optional[List[Dict[str, str]]] = None,
        context: Optional[str] = None,
        end_user_id: Optional[str] = None,
        config_id: Optional[str] = None,
        storage_type: Optional[str] = None,
        user_rag_memory_id: Optional[str] = None,
    ) -> Dict[str, Any]:
        """Run one non-streaming conversation turn.

        Args:
            message: User message.
            history: Previous turns ``[{"role": "user/assistant", "content": "..."}]``.
            context: Context text (e.g. knowledge-base retrieval results).
            end_user_id: End-user id for memory writes; ``"unknown"`` is
                used in the non-RAG path when it is None.
            config_id: Memory configuration id; falls back to the
                ``config_id`` environment variable when None.
            storage_type: ``"rag"`` selects ``write_rag``; any other value
                goes through ``write_message_task``.
            user_rag_memory_id: Passed through to the memory writers.

        Returns:
            Dict with ``content``, ``model``, ``elapsed_time`` and a
            zero-filled ``usage`` dict.

        Raises:
            Exception: Whatever the agent call raises, logged and re-raised.
        """
        start_time = time.time()
        # Resolved up front so both the pre- and post-call writes can use them.
        actual_end_user_id, actual_config_id = self._resolve_memory_target(end_user_id, config_id)

        logger.info(f'写入类型{storage_type,str(end_user_id), message, str(user_rag_memory_id)}')
        if storage_type == "rag":
            await write_rag(end_user_id, message, user_rag_memory_id)
            logger.info(f'RAG_Agent:{end_user_id};{user_rag_memory_id}')
        else:
            write_id = write_message_task.delay(actual_end_user_id, message, actual_config_id, storage_type, user_rag_memory_id)
            write_status = get_task_memory_write_result(str(write_id))
            logger.info(f'Agent:{actual_end_user_id};{write_status}')

        try:
            messages = self._prepare_messages(message, history, context)

            logger.debug(
                "准备调用 LangChain Agent",
                extra={
                    "has_context": bool(context),
                    "has_history": bool(history),
                    "has_tools": bool(self.tools),
                    "message_count": len(messages)
                }
            )

            result = await self.agent.ainvoke({"messages": messages})

            # Pick the content of the last AI message in the output.
            content = ""
            for msg in reversed(result.get("messages", [])):
                if isinstance(msg, AIMessage):
                    content = msg.content
                    break

            elapsed_time = time.time() - start_time

            if storage_type == "rag":
                # NOTE(review): this second RAG write stores the user
                # ``message`` again, while the non-RAG branch below stores
                # the model reply (``content``). Looks like a copy-paste
                # slip; confirm the intended payload before changing it.
                await write_rag(end_user_id, message, user_rag_memory_id)
                logger.info(f'RAG_Agent:{end_user_id};{user_rag_memory_id}')
            else:
                write_id = write_message_task.delay(actual_end_user_id, content, actual_config_id, storage_type, user_rag_memory_id)
                write_status = get_task_memory_write_result(str(write_id))
                logger.info(f'Agent:{actual_end_user_id};{write_status}')

            response = {
                "content": content,
                "model": self.model_name,
                "elapsed_time": elapsed_time,
                # Token usage is not collected here; zeros are placeholders.
                "usage": {
                    "prompt_tokens": 0,
                    "completion_tokens": 0,
                    "total_tokens": 0
                }
            }

            logger.debug(
                "Agent 调用完成",
                extra={
                    "elapsed_time": elapsed_time,
                    "content_length": len(response["content"])
                }
            )

            return response

        except Exception as e:
            logger.error("Agent 调用失败", extra={"error": str(e)})
            raise

    async def chat_stream(
        self,
        message: str,
        history: Optional[List[Dict[str, str]]] = None,
        context: Optional[str] = None,
        end_user_id: Optional[str] = None,
        config_id: Optional[str] = None,
        storage_type: Optional[str] = None,
        user_rag_memory_id: Optional[str] = None,
    ) -> AsyncGenerator[str, None]:
        """Run one streaming conversation turn.

        Args:
            message: User message.
            history: Previous turns.
            context: Context text.
            end_user_id: End-user id for the memory write.
            config_id: Memory configuration id; falls back to the
                ``config_id`` environment variable when None.
            storage_type: ``"rag"`` selects ``write_rag``; any other value
                goes through ``write_message_task``.
            user_rag_memory_id: Passed through to the memory writers.

        Yields:
            str: Chunks of model output as they arrive.

        Raises:
            Exception: Whatever the agent stream raises, logged and re-raised.
        """
        logger.info("=" * 80)
        logger.info(" chat_stream 方法开始执行")
        logger.info(f" Message: {message[:100]}")
        logger.info(f" Has tools: {bool(self.tools)}")
        logger.info(f" Tool count: {len(self.tools) if self.tools else 0}")
        logger.info("=" * 80)

        if storage_type == "rag":
            await write_rag(end_user_id, message, user_rag_memory_id)
        else:
            actual_end_user_id, actual_config_id = self._resolve_memory_target(end_user_id, config_id)
            write_id = write_message_task.delay(actual_end_user_id, message, actual_config_id, storage_type, user_rag_memory_id)

            # A failed status lookup must not prevent the reply from streaming.
            try:
                write_status = get_task_memory_write_result(str(write_id))
                logger.info(f'Agent:{actual_end_user_id};{write_status}')
            except Exception as e:
                logger.error("Agent 记忆用户输入出错", extra={"error": str(e)})

        try:
            messages = self._prepare_messages(message, history, context)

            logger.debug(
                f"准备流式调用,has_tools={bool(self.tools)}, message_count={len(messages)}"
            )

            chunk_count = 0

            logger.debug("使用 Agent astream_events 实现流式输出")

            try:
                async for event in self.agent.astream_events(
                    {"messages": messages},
                    version="v2"
                ):
                    chunk_count += 1
                    kind = event.get("event")

                    if kind == "on_chat_model_stream":
                        # Chat-model streaming chunk.
                        chunk = event.get("data", {}).get("chunk")
                        if chunk and hasattr(chunk, "content") and chunk.content:
                            yield chunk.content

                    elif kind == "on_llm_stream":
                        # Alternative LLM streaming event; the chunk may be
                        # an object with ``content`` or a plain string.
                        chunk = event.get("data", {}).get("chunk")
                        if chunk:
                            if hasattr(chunk, "content") and chunk.content:
                                yield chunk.content
                            elif isinstance(chunk, str):
                                yield chunk

                    # Tool calls are only logged.
                    elif kind == "on_tool_start":
                        logger.debug(f"工具调用开始: {event.get('name')}")
                    elif kind == "on_tool_end":
                        logger.debug(f"工具调用结束: {event.get('name')}")

                logger.debug(f"Agent 流式完成,共 {chunk_count} 个事件")

            except Exception as e:
                logger.error(f"Agent astream_events 失败: {str(e)}", exc_info=True)
                raise

        except Exception as e:
            logger.error("=" * 80)
            logger.error(f"chat_stream 异常: {str(e)}")
            logger.error("=" * 80, exc_info=True)
            raise

        finally:
            logger.info("=" * 80)
            logger.info("chat_stream 方法执行结束")
            logger.info("=" * 80)
|
||||||
|
|
||||||
|
|
||||||
56
app/core/api_key_utils.py
Normal file
56
app/core/api_key_utils.py
Normal file
@@ -0,0 +1,56 @@
|
|||||||
|
"""API Key 工具函数"""
|
||||||
|
import secrets
|
||||||
|
import hashlib
|
||||||
|
from app.models.api_key_model import ApiKeyType
|
||||||
|
|
||||||
|
|
||||||
|
def generate_api_key(key_type: ApiKeyType) -> tuple[str, str, str]:
    """Create a new API key for the given key type.

    Args:
        key_type: Kind of key to create; it selects the textual prefix.

    Returns:
        A ``(api_key, key_hash, key_prefix)`` tuple: the plaintext key, the
        value ``hash_api_key`` produces for it, and the prefix that was used.

    Raises:
        KeyError: If ``key_type`` has no entry in the prefix table.
    """
    # Prefix per key type.
    prefixes_by_type = {
        ApiKeyType.APP: "sk-app-",
        ApiKeyType.RAG: "sk-rag-",
        ApiKeyType.MEMORY: "sk-mem-",
        ApiKeyType.GENERAL: "sk-gen-",
    }
    key_prefix = prefixes_by_type[key_type]

    # token_urlsafe(32) yields more than 32 characters; keep the first 32.
    token = secrets.token_urlsafe(32)
    plaintext_key = f"{key_prefix}{token[:32]}"

    # Only the hash is meant to be stored; the plaintext is returned to the caller.
    return plaintext_key, hash_api_key(plaintext_key), key_prefix
|
||||||
|
|
||||||
|
|
||||||
|
def hash_api_key(api_key: str) -> str:
    """Return the SHA-256 hex digest of an API key.

    Args:
        api_key: Plaintext API key.

    Returns:
        The lowercase hexadecimal SHA-256 digest of the encoded key.
    """
    hasher = hashlib.sha256()
    hasher.update(api_key.encode())
    return hasher.hexdigest()
|
||||||
|
|
||||||
|
|
||||||
|
def verify_api_key(api_key: str, key_hash: str) -> bool:
    """Check whether a plaintext API key matches a stored hash.

    The comparison uses ``secrets.compare_digest`` so that the time taken
    does not depend on how many leading characters of the two hashes agree.

    Args:
        api_key: Plaintext API key supplied by the caller.
        key_hash: Stored hash to compare against.

    Returns:
        True if hashing ``api_key`` yields ``key_hash``, False otherwise
        (including when ``key_hash`` is not an ASCII string).
    """
    candidate_hash = hash_api_key(api_key)
    # compare_digest only accepts ASCII-only str values; anything else can
    # never equal a hex digest, so report a mismatch instead of raising.
    if not isinstance(key_hash, str) or not key_hash.isascii():
        return False
    return secrets.compare_digest(candidate_hash, key_hash)
|
||||||
47
app/core/compensation.py
Normal file
47
app/core/compensation.py
Normal file
@@ -0,0 +1,47 @@
|
|||||||
|
"""
|
||||||
|
Compensation Transaction Handler
|
||||||
|
Handles operations that cannot be rolled back (like file system operations).
|
||||||
|
"""
|
||||||
|
from typing import List, Callable
|
||||||
|
from app.core.logging_config import get_logger
|
||||||
|
|
||||||
|
# Module-level logger named after this module.
logger = get_logger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
class CompensationHandler:
    """Compensating-transaction handler for operations that cannot be rolled back.

    Callables are collected with ``register`` and later run by ``execute``
    in reverse registration order.
    """

    def __init__(self):
        # Registered compensation callables, in registration order.
        self._compensations: List[Callable] = []

    def register(self, compensation: Callable):
        """Register a compensation operation.

        Args:
            compensation: Zero-argument callable that undoes an operation.
        """
        self._compensations.append(compensation)
        # Callables without a __name__ are reported as 'lambda'.
        label = getattr(compensation, '__name__', 'lambda')
        logger.debug(f"Registered compensation operation: {label}")

    def execute(self):
        """Run every registered compensation, most recently registered first.

        A compensation that raises is logged and does not prevent the
        remaining ones from running. The registrations are not removed.
        """
        pending = self._compensations
        if len(pending) == 0:
            logger.debug("No compensation operations to execute")
            return

        logger.info(f"Executing {len(pending)} compensation operations")

        for undo in pending[::-1]:
            try:
                undo()
                logger.debug("Compensation operation executed successfully")
            except Exception as e:
                logger.error(f"补偿操作失败: {e}", exc_info=True)

    def clear(self):
        """Drop all registered compensation operations."""
        removed = len(self._compensations)
        self._compensations.clear()
        if removed <= 0:
            return
        logger.debug(f"Cleared {removed} compensation operations")
|
||||||
237
app/core/config.py
Normal file
237
app/core/config.py
Normal file
@@ -0,0 +1,237 @@
|
|||||||
|
import os
|
||||||
|
import json
|
||||||
|
from pathlib import Path
|
||||||
|
from typing import Dict, Any, Optional
|
||||||
|
from dotenv import load_dotenv
|
||||||
|
|
||||||
|
load_dotenv()
|
||||||
|
|
||||||
|
class Settings:
    """Application settings, read from environment variables at class-definition time.

    Every attribute falls back to the literal default shown when its
    environment variable is unset. Boolean settings are true only when the
    variable equals "true" (case-insensitive).
    """

    ENABLE_SINGLE_WORKSPACE: bool = os.getenv("ENABLE_SINGLE_WORKSPACE", "true").lower() == "true"
    # API Keys Configuration
    OPENAI_API_KEY: str = os.getenv("OPENAI_API_KEY", "")
    DASHSCOPE_API_KEY: str = os.getenv("DASHSCOPE_API_KEY", "")

    # Neo4j Configuration (memory-system database)
    # NOTE(review): the fallback URI points at a hard-coded public IP address;
    # consider a localhost default so an unset variable never reaches a remote host.
    NEO4J_URI: str = os.getenv("NEO4J_URI", "bolt://1.94.111.67:7687")
    NEO4J_USERNAME: str = os.getenv("NEO4J_USERNAME", "neo4j")
    NEO4J_PASSWORD: str = os.getenv("NEO4J_PASSWORD", "")

    # Database configuration (Postgres)
    DB_HOST: str = os.getenv("DB_HOST", "127.0.0.1")
    DB_PORT: int = int(os.getenv("DB_PORT", "5432"))
    DB_USER: str = os.getenv("DB_USER", "postgres")
    DB_PASSWORD: str = os.getenv("DB_PASSWORD", "password")
    DB_NAME: str = os.getenv("DB_NAME", "redbear-mem")

    DB_AUTO_UPGRADE: bool = os.getenv("DB_AUTO_UPGRADE", "false").lower() == "true"

    # Redis configuration
    REDIS_HOST: str = os.getenv("REDIS_HOST", "127.0.0.1")
    REDIS_PORT: int = int(os.getenv("REDIS_PORT", "6379"))
    REDIS_DB: int = int(os.getenv("REDIS_DB", "1"))
    REDIS_PASSWORD: str = os.getenv("REDIS_PASSWORD", "")

    # ElasticSearch configuration
    ELASTICSEARCH_HOST: str = os.getenv("ELASTICSEARCH_HOST", "https://127.0.0.1")
    ELASTICSEARCH_PORT: int = int(os.getenv("ELASTICSEARCH_PORT", "9200"))
    ELASTICSEARCH_USERNAME: str = os.getenv("ELASTICSEARCH_USERNAME", "elastic")
    ELASTICSEARCH_PASSWORD: str = os.getenv("ELASTICSEARCH_PASSWORD", "")
    ELASTICSEARCH_VERIFY_CERTS: bool = os.getenv("ELASTICSEARCH_VERIFY_CERTS", "False").lower() == "true"
    ELASTICSEARCH_CA_CERTS: str = os.getenv("ELASTICSEARCH_CA_CERTS", "")
    ELASTICSEARCH_REQUEST_TIMEOUT: int = int(os.getenv("ELASTICSEARCH_REQUEST_TIMEOUT", "100000"))
    ELASTICSEARCH_RETRY_ON_TIMEOUT: bool = os.getenv("ELASTICSEARCH_RETRY_ON_TIMEOUT", "True").lower() == "true"
    ELASTICSEARCH_MAX_RETRIES: int = int(os.getenv("ELASTICSEARCH_MAX_RETRIES", "10"))

    # Xinference configuration
    XINFERENCE_URL: str = os.getenv("XINFERENCE_URL", "http://127.0.0.1")

    # LangSmith configuration
    LANGCHAIN_TRACING_V2: bool = os.getenv("LANGCHAIN_TRACING_V2", "false").lower() == "true"
    LANGCHAIN_TRACING: bool = os.getenv("LANGCHAIN_TRACING", "false").lower() == "true"
    LANGCHAIN_API_KEY: str = os.getenv("LANGCHAIN_API_KEY", "")
    LANGCHAIN_ENDPOINT: str = os.getenv("LANGCHAIN_ENDPOINT", "")

    # LLM Request Configuration
    LLM_TIMEOUT: float = float(os.getenv("LLM_TIMEOUT", "120.0"))
    LLM_MAX_RETRIES: int = int(os.getenv("LLM_MAX_RETRIES", "2"))

    # JWT Token Configuration
    # NOTE(review): the fallback SECRET_KEY is a fixed string committed to the
    # repository; deployments must override it via the environment.
    SECRET_KEY: str = os.getenv("SECRET_KEY", "a_default_secret_key_that_is_long_and_random")
    ALGORITHM: str = "HS256"
    ACCESS_TOKEN_EXPIRE_MINUTES: int = int(os.getenv("ACCESS_TOKEN_EXPIRE_MINUTES", "30"))
    REFRESH_TOKEN_EXPIRE_DAYS: int = int(os.getenv("REFRESH_TOKEN_EXPIRE_DAYS", "7"))

    # Single Sign-On configuration
    ENABLE_SINGLE_SESSION: bool = os.getenv("ENABLE_SINGLE_SESSION", "false").lower() == "true"

    # File Upload
    MAX_FILE_SIZE: int = int(os.getenv("MAX_FILE_SIZE", "52428800"))
    FILE_PATH: str = os.getenv("FILE_PATH", "/files")

    # VOLC ASR settings
    VOLC_APP_KEY: str = os.getenv("VOLC_APP_KEY", "")
    VOLC_ACCESS_KEY: str = os.getenv("VOLC_ACCESS_KEY", "")
    VOLC_SUBMIT_URL: str = os.getenv("VOLC_SUBMIT_URL", "https://openspeech.bytedance.com/api/v3/auc/bigmodel/submit")
    VOLC_QUERY_URL: str = os.getenv("VOLC_QUERY_URL", "https://openspeech.bytedance.com/api/v3/auc/bigmodel/query")

    # Langfuse configuration
    LANGFUSE_PUBLIC_KEY: str = os.getenv("LANGFUSE_PUBLIC_KEY", "")
    LANGFUSE_SECRET_KEY: str = os.getenv("LANGFUSE_SECRET_KEY", "")
    LANGFUSE_HOST: str = os.getenv("LANGFUSE_HOST", "")

    # Server Configuration
    SERVER_IP: str = os.getenv("SERVER_IP", "127.0.0.1")

    # ========================================================================
    # Internal Configuration (not in .env, used by application code)
    # ========================================================================

    # Superuser settings (internal defaults)
    # NOTE(review): the default superuser password is a fixed string; override it.
    FIRST_SUPERUSER_EMAIL: str = os.getenv("FIRST_SUPERUSER_EMAIL", "admin@example.com")
    FIRST_SUPERUSER_USERNAME: str = os.getenv("FIRST_SUPERUSER_USERNAME", "admin")
    FIRST_SUPERUSER_PASSWORD: str = os.getenv("FIRST_SUPERUSER_PASSWORD", "admin_password")

    # Generic File Upload (internal)
    GENERIC_FILE_PATH: str = os.getenv("GENERIC_FILE_PATH", "/uploads")
    ENABLE_FILE_COMPRESSION: bool = os.getenv("ENABLE_FILE_COMPRESSION", "false").lower() == "true"
    ENABLE_VIRUS_SCAN: bool = os.getenv("ENABLE_VIRUS_SCAN", "false").lower() == "true"
    FILE_ACCESS_URL_PREFIX: str = os.getenv("FILE_ACCESS_URL_PREFIX", "http://localhost:8000/api/files")

    # Frontend URL for workspace invitations (internal)
    WEB_URL: str = os.getenv("WEB_URL", "http://localhost:3000")

    # CORS configuration (internal): comma-separated origins, blanks dropped
    CORS_ORIGINS: list[str] = [
        origin.strip()
        for origin in os.getenv("CORS_ORIGINS", "").split(",")
        if origin.strip()
    ]

    # Logging settings
    LOG_LEVEL: str = os.getenv("LOG_LEVEL", "INFO")
    LOG_FORMAT: str = os.getenv("LOG_FORMAT", "%(asctime)s - %(name)s - %(levelname)s - %(message)s")
    LOG_FILE_PATH: str = os.getenv("LOG_FILE_PATH", "logs/app.log")
    LOG_MAX_SIZE: int = int(os.getenv("LOG_MAX_SIZE", "10485760"))  # 10MB
    LOG_BACKUP_COUNT: int = int(os.getenv("LOG_BACKUP_COUNT", "5"))
    LOG_TO_CONSOLE: bool = os.getenv("LOG_TO_CONSOLE", "true").lower() == "true"
    LOG_TO_FILE: bool = os.getenv("LOG_TO_FILE", "true").lower() == "true"

    # Sensitive Data Filtering
    ENABLE_SENSITIVE_DATA_FILTER: bool = os.getenv("ENABLE_SENSITIVE_DATA_FILTER", "true").lower() == "true"

    # Memory Module Logging
    PROMPT_LOG_LEVEL: str = os.getenv("PROMPT_LOG_LEVEL", "INFO")
    ENABLE_TEMPLATE_LOGGING: bool = os.getenv("ENABLE_TEMPLATE_LOGGING", "false").lower() == "true"
    TIMING_LOG_FILE: str = os.getenv("TIMING_LOG_FILE", "logs/time.log")
    TIMING_LOG_TO_CONSOLE: bool = os.getenv("TIMING_LOG_TO_CONSOLE", "true").lower() == "true"
    AGENT_LOG_FILE: str = os.getenv("AGENT_LOG_FILE", "logs/agent_service.log")
    AGENT_LOG_MAX_SIZE: int = int(os.getenv("AGENT_LOG_MAX_SIZE", "5242880"))  # 5MB
    AGENT_LOG_BACKUP_COUNT: int = int(os.getenv("AGENT_LOG_BACKUP_COUNT", "20"))

    # Log Streaming Configuration
    LOG_STREAM_KEEPALIVE_INTERVAL: int = int(os.getenv("LOG_STREAM_KEEPALIVE_INTERVAL", "300"))  # 5 minutes
    LOG_STREAM_MAX_CONNECTIONS: int = int(os.getenv("LOG_STREAM_MAX_CONNECTIONS", "10"))
    LOG_STREAM_BUFFER_SIZE: int = int(os.getenv("LOG_STREAM_BUFFER_SIZE", "8192"))  # 8KB
    LOG_FILE_MAX_SIZE_MB: int = int(os.getenv("LOG_FILE_MAX_SIZE_MB", "10"))  # 10MB

    # Celery configuration (internal)
    CELERY_BROKER: int = int(os.getenv("CELERY_BROKER", "1"))
    CELERY_BACKEND: int = int(os.getenv("CELERY_BACKEND", "2"))
    REFLECTION_INTERVAL_SECONDS: float = float(os.getenv("REFLECTION_INTERVAL_SECONDS", "300"))
    HEALTH_CHECK_SECONDS: float = float(os.getenv("HEALTH_CHECK_SECONDS", "600"))
    MEMORY_INCREMENT_INTERVAL_HOURS: float = float(os.getenv("MEMORY_INCREMENT_INTERVAL_HOURS", "24"))
    DEFAULT_WORKSPACE_ID: Optional[str] = os.getenv("DEFAULT_WORKSPACE_ID", None)

    # Memory Module Configuration (internal)
    MEMORY_OUTPUT_DIR: str = os.getenv("MEMORY_OUTPUT_DIR", "logs/memory-output")
    MEMORY_CONFIG_DIR: str = os.getenv("MEMORY_CONFIG_DIR", "app/core/memory")
    MEMORY_CONFIG_FILE: str = os.getenv("MEMORY_CONFIG_FILE", "config.json")
    MEMORY_RUNTIME_FILE: str = os.getenv("MEMORY_RUNTIME_FILE", "runtime.json")
    MEMORY_DBRUN_FILE: str = os.getenv("MEMORY_DBRUN_FILE", "dbrun.json")

    def get_memory_output_path(self, filename: str = "") -> str:
        """
        Get the full path for memory module output files.

        Args:
            filename: Optional filename to append to the output directory

        Returns:
            Full path to the output file or directory
        """
        base_path = Path(self.MEMORY_OUTPUT_DIR)
        if filename:
            return str(base_path / filename)
        return str(base_path)

    def get_memory_config_path(self, config_file: str = "") -> str:
        """
        Get the full path for memory module configuration files.

        Args:
            config_file: Optional config filename (defaults to MEMORY_CONFIG_FILE)

        Returns:
            Full path to the config file
        """
        if not config_file:
            config_file = self.MEMORY_CONFIG_FILE
        return str(Path(self.MEMORY_CONFIG_DIR) / config_file)

    @staticmethod
    def _load_json_config(path: str, label: str, fallback: Dict[str, Any]) -> Dict[str, Any]:
        """Read a JSON file, returning ``fallback`` when it is missing or malformed.

        Args:
            path: File to read.
            label: Human-readable name of the file, used in the warning text.
            fallback: Value returned when the file cannot be loaded.

        Returns:
            The parsed JSON content, or ``fallback`` after printing a warning.
        """
        try:
            with open(path, "r", encoding="utf-8") as f:
                return json.load(f)
        except (FileNotFoundError, json.JSONDecodeError) as e:
            # print() rather than the logging module: the logging setup in this
            # project imports this module's settings.
            print(f"Warning: {label} not found or malformed at {path}. Error: {e}")
            return fallback

    def load_memory_config(self) -> Dict[str, Any]:
        """
        Load memory module configuration from config.json.

        Returns:
            Dictionary containing memory configuration; empty when the file
            is missing or malformed
        """
        config_path = self.get_memory_config_path(self.MEMORY_CONFIG_FILE)
        return self._load_json_config(config_path, "Memory config file", {})

    def load_memory_runtime_config(self) -> Dict[str, Any]:
        """
        Load memory module runtime configuration from runtime.json.

        Returns:
            Dictionary containing runtime configuration;
            ``{"selections": {}}`` when the file is missing or malformed
        """
        runtime_path = self.get_memory_config_path(self.MEMORY_RUNTIME_FILE)
        return self._load_json_config(runtime_path, "Memory runtime config", {"selections": {}})

    def load_memory_dbrun_config(self) -> Dict[str, Any]:
        """
        Load memory module database run configuration from dbrun.json.

        Returns:
            Dictionary containing dbrun configuration;
            ``{"selections": {}}`` when the file is missing or malformed
        """
        dbrun_path = self.get_memory_config_path(self.MEMORY_DBRUN_FILE)
        return self._load_json_config(dbrun_path, "Memory dbrun config", {"selections": {}})

    def ensure_memory_output_dir(self) -> None:
        """
        Ensure the memory output directory exists.
        Creates the directory (and any missing parents) if it doesn't exist.
        """
        output_dir = Path(self.MEMORY_OUTPUT_DIR)
        output_dir.mkdir(parents=True, exist_ok=True)
|
||||||
|
|
||||||
|
|
||||||
|
# Module-level Settings instance, created when this module is imported.
settings = Settings()
|
||||||
130
app/core/error_codes.py
Normal file
130
app/core/error_codes.py
Normal file
@@ -0,0 +1,130 @@
|
|||||||
|
from enum import IntEnum
|
||||||
|
|
||||||
|
|
||||||
|
class BizCode(IntEnum):
    """Business error codes, grouped into numeric ranges by category."""

    # General (1xxx); OK is 0
    OK = 0
    BAD_REQUEST = 1000
    VALIDATION_FAILED = 1001
    MISSING_PARAMETER = 1002
    INVALID_PARAMETER = 1003
    # Authentication / authorization (2xxx / 3xxx)
    UNAUTHORIZED = 2001
    TOKEN_INVALID = 2002
    TOKEN_EXPIRED = 2003
    TOKEN_BLACKLISTED = 2004
    PASSWORD_ERROR = 2005
    LOGIN_FAILED = 2006
    FORBIDDEN = 3001
    TENANT_NOT_FOUND = 3002
    WORKSPACE_NO_ACCESS = 3003
    WORKSPACE_INVITE_NOT_FOUND = 3004
    # Resources (4xxx)
    NOT_FOUND = 4000
    USER_NOT_FOUND = 4001
    WORKSPACE_NOT_FOUND = 4002
    MODEL_NOT_FOUND = 4003
    KNOWLEDGE_NOT_FOUND = 4004
    DOCUMENT_NOT_FOUND = 4005
    FILE_NOT_FOUND = 4006
    APP_NOT_FOUND = 4007
    RELEASE_NOT_FOUND = 4008

    # Conflict / state (5xxx)
    DUPLICATE_NAME = 5001
    RESOURCE_ALREADY_EXISTS = 5002
    VERSION_ALREADY_EXISTS = 5003
    STATE_CONFLICT = 5004

    # Application publishing (6xxx)
    PUBLISH_FAILED = 6001
    NO_DRAFT_TO_PUBLISH = 6002
    ROLLBACK_TARGET_NOT_FOUND = 6003
    APP_TYPE_NOT_SUPPORTED = 6004
    AGENT_CONFIG_MISSING = 6005
    SHARE_DISABLED = 6006
    INVALID_PASSWORD = 6007
    PASSWORD_REQUIRED = 6008
    EMBED_NOT_ALLOWED = 6009
    PERMISSION_DENIED = 6010
    INVALID_CONVERSATION = 6011

    # Models (7xxx)
    MODEL_CONFIG_INVALID = 7001
    API_KEY_MISSING = 7002
    PROVIDER_NOT_SUPPORTED = 7003
    LLM_ERROR = 7004
    EMBEDDING_ERROR = 7005

    # Files / parsing (8xxx)
    FILE_READ_ERROR = 8001
    PARSER_NOT_SUPPORTED = 8002
    CHUNKING_FAILED = 8003

    # RAG / knowledge (9xxx)
    INDEX_BUILD_FAILED = 9001
    EMBEDDING_FAILED = 9002
    SEARCH_FAILED = 9003

    # System (100xx)
    INTERNAL_ERROR = 10001
    DB_ERROR = 10002
    SERVICE_UNAVAILABLE = 10003
    RATE_LIMITED = 10004
|
||||||
|
|
||||||
|
|
||||||
|
# Suggested HTTP status mapping (for use in exception handlers if needed).
# NOTE(review): BizCode.PASSWORD_ERROR and BizCode.WORKSPACE_INVITE_NOT_FOUND
# have no entry here, so consumers need a fallback for missing keys — confirm
# whether that omission is intentional.
# NOTE(review): LOGIN_FAILED and USER_NOT_FOUND map to 200 although they are
# error codes — presumably deliberate for the login flow; verify with callers.
HTTP_MAPPING = {
    BizCode.OK: 200,
    BizCode.LOGIN_FAILED: 200,
    BizCode.BAD_REQUEST: 400,
    BizCode.VALIDATION_FAILED: 400,
    BizCode.MISSING_PARAMETER: 400,
    BizCode.INVALID_PARAMETER: 400,
    BizCode.UNAUTHORIZED: 401,
    BizCode.TOKEN_INVALID: 401,
    BizCode.TOKEN_EXPIRED: 401,
    BizCode.TOKEN_BLACKLISTED: 401,
    BizCode.FORBIDDEN: 403,
    BizCode.TENANT_NOT_FOUND: 404,
    BizCode.WORKSPACE_NO_ACCESS: 403,
    BizCode.NOT_FOUND: 404,
    BizCode.USER_NOT_FOUND: 200,
    BizCode.WORKSPACE_NOT_FOUND: 404,
    BizCode.MODEL_NOT_FOUND: 404,
    BizCode.KNOWLEDGE_NOT_FOUND: 404,
    BizCode.DOCUMENT_NOT_FOUND: 404,
    BizCode.FILE_NOT_FOUND: 404,
    BizCode.APP_NOT_FOUND: 404,
    BizCode.RELEASE_NOT_FOUND: 404,
    BizCode.DUPLICATE_NAME: 409,
    BizCode.RESOURCE_ALREADY_EXISTS: 409,
    BizCode.VERSION_ALREADY_EXISTS: 409,
    BizCode.STATE_CONFLICT: 409,
    BizCode.PUBLISH_FAILED: 500,
    BizCode.NO_DRAFT_TO_PUBLISH: 400,
    BizCode.ROLLBACK_TARGET_NOT_FOUND: 404,
    BizCode.APP_TYPE_NOT_SUPPORTED: 400,
    BizCode.AGENT_CONFIG_MISSING: 400,
    BizCode.SHARE_DISABLED: 403,
    BizCode.INVALID_PASSWORD: 401,
    BizCode.PASSWORD_REQUIRED: 401,
    BizCode.EMBED_NOT_ALLOWED: 403,
    BizCode.PERMISSION_DENIED: 403,
    BizCode.INVALID_CONVERSATION: 400,
    BizCode.MODEL_CONFIG_INVALID: 400,
    BizCode.API_KEY_MISSING: 400,
    BizCode.PROVIDER_NOT_SUPPORTED: 400,
    BizCode.LLM_ERROR: 500,
    BizCode.EMBEDDING_ERROR: 500,
    BizCode.FILE_READ_ERROR: 500,
    BizCode.PARSER_NOT_SUPPORTED: 400,
    BizCode.CHUNKING_FAILED: 500,
    BizCode.INDEX_BUILD_FAILED: 500,
    BizCode.EMBEDDING_FAILED: 500,
    BizCode.SEARCH_FAILED: 500,
    BizCode.INTERNAL_ERROR: 500,
    BizCode.DB_ERROR: 500,
    BizCode.SERVICE_UNAVAILABLE: 503,
    BizCode.RATE_LIMITED: 429,
}
|
||||||
86
app/core/exceptions.py
Normal file
86
app/core/exceptions.py
Normal file
@@ -0,0 +1,86 @@
|
|||||||
|
"""
|
||||||
|
业务异常定义
|
||||||
|
"""
|
||||||
|
from typing import Any, Dict, Optional
|
||||||
|
from app.core.error_codes import BizCode
|
||||||
|
|
||||||
|
|
||||||
|
class BusinessException(Exception):
    """Base class for business-logic errors.

    Carries a message, a business code, an optional context mapping and an
    optional underlying cause.
    """

    def __init__(
        self,
        message: str,
        code: BizCode | int | None = None,
        context: Optional[Dict[str, Any]] = None,
        cause: Optional[Exception] = None
    ):
        """Build the exception.

        Args:
            message: Human-readable error message.
            code: Business code; ``BizCode.BAD_REQUEST`` when omitted.
            context: Extra key/value details; copied, never stored by reference.
            cause: Underlying exception, if any. It is kept on ``self.cause``
                and also chained as ``__cause__`` so tracebacks show it.
        """
        self.message = message
        self.code = code if code is not None else BizCode.BAD_REQUEST
        # Make a copy of context to avoid modifying the original dict
        self.context = dict(context) if context else {}
        self.cause = cause
        super().__init__(self.message)
        # Chain the cause so logging/traceback output includes the original
        # error; non-exception values cannot be assigned to __cause__.
        if isinstance(cause, BaseException):
            self.__cause__ = cause

    def __str__(self) -> str:
        """Format as ``CODE: message[, context={...}]``."""
        ctx = f", context={self.context}" if self.context else ""
        code_name = self.code.name if isinstance(self.code, BizCode) else str(self.code)
        return f"{code_name}: {self.message}{ctx}"
|
||||||
|
|
||||||
|
|
||||||
|
class ValidationException(BusinessException):
    """Data validation error; always uses ``BizCode.VALIDATION_FAILED``."""

    def __init__(self, message: str, field: Optional[str] = None, **kwargs):
        """Build the exception.

        Args:
            message: Human-readable error message.
            field: Name of the offending field; added to the context as
                ``"field"`` when truthy.
            **kwargs: Forwarded to ``BusinessException``. A ``context``
                mapping, if given, is merged over the ``field`` entry;
                ``context=None`` is accepted and treated as empty.
        """
        context: Dict[str, Any] = {"field": field} if field else {}
        # The base class accepts context=None, so tolerate it here too
        # instead of failing inside dict.update().
        extra_context = kwargs.pop("context", None)
        if extra_context:
            context.update(extra_context)
        super().__init__(message, BizCode.VALIDATION_FAILED, context, **kwargs)
|
||||||
|
|
||||||
|
|
||||||
|
class AuthenticationException(BusinessException):
    """Authentication failure; always uses ``BizCode.UNAUTHORIZED``."""

    def __init__(self, message: str = "认证失败", **kwargs):
        # Remaining keyword arguments go to BusinessException unchanged.
        biz_code = BizCode.UNAUTHORIZED
        super().__init__(message, biz_code, **kwargs)
|
||||||
|
|
||||||
|
|
||||||
|
class AuthorizationException(BusinessException):
    """Authorization failure; always uses ``BizCode.FORBIDDEN``."""

    def __init__(self, message: str = "权限不足", **kwargs):
        # Remaining keyword arguments go to BusinessException unchanged.
        biz_code = BizCode.FORBIDDEN
        super().__init__(message, biz_code, **kwargs)
|
||||||
|
|
||||||
|
|
||||||
|
class ResourceNotFoundException(BusinessException):
    """Raised when a requested resource does not exist."""

    def __init__(self, resource_type: str, resource_id: Optional[str] = None, **kwargs):
        """Build the exception.

        Args:
            resource_type: Kind of resource; used in the message and stored in
                the context as ``"resource_type"``.
            resource_id: Identifier of the missing resource; stored in the
                context as ``"resource_id"`` when truthy.
            **kwargs: Forwarded to ``BusinessException``. ``context`` (may be
                None) is merged over the entries above. ``code`` may be given
                to override the default business code.
        """
        message = f"{resource_type} 不存在"
        context: Dict[str, Any] = {"resource_type": resource_type}
        if resource_id:
            context["resource_id"] = resource_id
        # The base class accepts context=None, so tolerate it here too
        # instead of failing inside dict.update().
        extra_context = kwargs.pop("context", None)
        if extra_context:
            context.update(extra_context)
        # NOTE(review): the default is FILE_NOT_FOUND even for non-file
        # resources; BizCode.NOT_FOUND looks like the intended generic code.
        # Kept as-is for compatibility; callers can now pass code=... instead.
        code = kwargs.pop("code", None)
        if code is None:
            code = BizCode.FILE_NOT_FOUND
        super().__init__(message, code, context, **kwargs)
|
||||||
|
|
||||||
|
|
||||||
|
class DuplicateResourceException(BusinessException):
    """Duplicate resource error; always uses ``BizCode.DUPLICATE_NAME``."""

    def __init__(self, message: str = "资源已存在", **kwargs):
        # Remaining keyword arguments go to BusinessException unchanged.
        biz_code = BizCode.DUPLICATE_NAME
        super().__init__(message, biz_code, **kwargs)
|
||||||
|
|
||||||
|
|
||||||
|
class FileUploadException(BusinessException):
    """File upload error; always uses ``BizCode.FILE_READ_ERROR``."""

    def __init__(self, message: str, **kwargs):
        # Remaining keyword arguments go to BusinessException unchanged.
        biz_code = BizCode.FILE_READ_ERROR
        super().__init__(message, biz_code, **kwargs)
|
||||||
|
|
||||||
|
|
||||||
|
class PermissionDeniedException(BusinessException):
    """Permission denied error; always uses ``BizCode.FORBIDDEN``."""

    def __init__(self, message: str = "权限不足", **kwargs):
        # Remaining keyword arguments go to BusinessException unchanged.
        biz_code = BizCode.FORBIDDEN
        super().__init__(message, biz_code, **kwargs)
|
||||||
633
app/core/logging_config.py
Normal file
633
app/core/logging_config.py
Normal file
@@ -0,0 +1,633 @@
|
|||||||
|
import logging
|
||||||
|
import logging.handlers
|
||||||
|
import os
|
||||||
|
from pathlib import Path
|
||||||
|
from typing import Optional
|
||||||
|
|
||||||
|
from app.core.config import settings
|
||||||
|
from app.core.sensitive_filter import SensitiveDataFilter
|
||||||
|
|
||||||
|
|
||||||
|
class SensitiveDataLoggingFilter(logging.Filter):
    """Logging filter that passes record text through ``SensitiveDataFilter``."""

    def filter(self, record: logging.LogRecord) -> bool:
        """Scrub the record's message and arguments in place.

        String messages, dict arguments and string items of list/tuple
        arguments are passed through ``SensitiveDataFilter``; everything
        else is left as it is.

        Args:
            record: Log record to scrub.

        Returns:
            Always True, so the record is never dropped.
        """
        # Message template
        message = getattr(record, 'msg', None)
        if isinstance(message, str):
            record.msg = SensitiveDataFilter.filter_string(message)

        # Formatting arguments
        log_args = getattr(record, 'args', None)
        if not log_args:
            return True

        if isinstance(log_args, dict):
            record.args = SensitiveDataFilter.filter_dict(log_args)
        elif isinstance(log_args, (list, tuple)):
            scrubbed = []
            for item in log_args:
                if isinstance(item, str):
                    item = SensitiveDataFilter.filter_string(str(item))
                scrubbed.append(item)
            # Always stored back as a tuple, even when a list was supplied.
            record.args = tuple(scrubbed)

        return True
|
||||||
|
|
||||||
|
|
||||||
|
class LoggingConfig:
    """Global logging configuration."""

    # One-shot guards so each setup routine runs at most once.
    _initialized = False
    _memory_loggers_initialized = False
    # Cache slots for memory-module loggers; nothing in this class fills them.
    _prompt_logger = None
    _template_logger = None
    _timing_logger = None
    _agent_loggers = {}

    @classmethod
    def setup_logging(cls) -> None:
        """Initialise the root logger from ``settings``; later calls are no-ops.

        Existing root handlers are removed, then console and/or rotating-file
        handlers are attached according to ``LOG_TO_CONSOLE`` / ``LOG_TO_FILE``.

        Raises:
            AttributeError: If ``settings.LOG_LEVEL`` is not a level name
                defined by the ``logging`` module.
        """
        if cls._initialized:
            return

        # Create the log directory
        log_dir = Path(settings.LOG_FILE_PATH).parent
        log_dir.mkdir(parents=True, exist_ok=True)

        # Resolve the configured level once; root and handlers share it.
        level = getattr(logging, settings.LOG_LEVEL.upper())

        # Configure the root logger
        root_logger = logging.getLogger()
        root_logger.setLevel(level)

        # Remove existing handlers
        root_logger.handlers.clear()

        # Create the formatter
        formatter = logging.Formatter(
            fmt=settings.LOG_FORMAT,
            datefmt='%Y-%m-%d %H:%M:%S'
        )

        # Create the sensitive-data filter
        sensitive_filter = SensitiveDataLoggingFilter()

        # Console handler
        if settings.LOG_TO_CONSOLE:
            console_handler = logging.StreamHandler()
            console_handler.setFormatter(formatter)
            console_handler.setLevel(level)
            console_handler.addFilter(sensitive_filter)
            root_logger.addHandler(console_handler)

        # File handler (with rotation)
        if settings.LOG_TO_FILE:
            file_handler = logging.handlers.RotatingFileHandler(
                filename=settings.LOG_FILE_PATH,
                maxBytes=settings.LOG_MAX_SIZE,
                # Honour the configured backup count instead of a fixed 5.
                backupCount=settings.LOG_BACKUP_COUNT,
                encoding='utf-8'
            )
            file_handler.setFormatter(formatter)
            file_handler.setLevel(level)
            file_handler.addFilter(sensitive_filter)
            root_logger.addHandler(file_handler)

        cls._initialized = True

        # Initialize memory module logging
        cls.setup_memory_logging()

        # Record that initialisation finished
        logger = logging.getLogger(__name__)
        logger.info("全局日志系统初始化完成")

    @classmethod
    def setup_memory_logging(cls) -> None:
        """Prepare for memory-module logging; later calls are no-ops.

        Called by ``setup_logging()`` and may also be called independently.
        It only ensures the ``logs`` directory exists and marks the memory
        logging system as ready; this method creates no loggers itself.
        """
        if cls._memory_loggers_initialized:
            return

        # Create logs directory if it doesn't exist
        log_dir = Path("logs")
        try:
            log_dir.mkdir(parents=True, exist_ok=True)
        except OSError as e:
            # Best effort: report and carry on without the directory.
            print(f"Warning: Could not create log directory: {e}")

        cls._memory_loggers_initialized = True
|
||||||
|
|
||||||
|
|
||||||
|
def get_logger(name: Optional[str] = None) -> logging.Logger:
    """Return the logger registered under ``name``.

    Args:
        name: Logger name. ``None`` (the default) yields the root logger,
            as ``logging.getLogger`` does.

    Returns:
        The ``logging.Logger`` instance for ``name``.
    """
    requested_logger = logging.getLogger(name)
    return requested_logger
|
||||||
|
|
||||||
|
|
||||||
|
def get_auth_logger() -> logging.Logger:
    """Return the dedicated authentication logger (named ``auth``)."""
    auth_logger = logging.getLogger("auth")
    return auth_logger
|
||||||
|
|
||||||
|
|
||||||
|
def get_security_logger() -> logging.Logger:
    """Return the dedicated security logger (named ``security``)."""
    security_logger = logging.getLogger("security")
    return security_logger
|
||||||
|
|
||||||
|
|
||||||
|
def get_api_logger() -> logging.Logger:
    """Return the dedicated API logger (named ``api``)."""
    api_logger = logging.getLogger("api")
    return api_logger
|
||||||
|
|
||||||
|
|
||||||
|
def get_db_logger() -> logging.Logger:
    """Return the dedicated database logger (named ``database``)."""
    db_logger = logging.getLogger("database")
    return db_logger
|
||||||
|
|
||||||
|
|
||||||
|
def get_business_logger() -> logging.Logger:
    """Return the dedicated business-logic logger (named ``business``)."""
    business_logger = logging.getLogger("business")
    return business_logger
|
||||||
|
|
||||||
|
|
||||||
|
def get_prompt_logger() -> logging.Logger:
    """Get the prompt logger for memory module.

    Returns a logger configured for prompt rendering output with:
    - Logger name: memory.prompts
    - Output: logs/prompts/prompt_logs-{timestamp}.log
    - Level: Configurable via PROMPT_LOG_LEVEL setting (default: INFO)
    - Handler: FileHandler (no console output)

    The logger is cached on ``LoggingConfig`` after first creation, so the
    timestamp in the file name is the time of the first call.

    Returns:
        Logger configured for prompt rendering output

    Example:
        >>> logger = get_prompt_logger()
        >>> logger.info("=== RENDERED EXTRACTION PROMPT ===\\n%s", prompt_content)
    """
    # Return cached logger if already initialized
    if LoggingConfig._prompt_logger is not None:
        return LoggingConfig._prompt_logger

    # Ensure memory logging is initialized
    if not LoggingConfig._memory_loggers_initialized:
        LoggingConfig.setup_memory_logging()

    # Create prompt logger
    logger = logging.getLogger("memory.prompts")
    logger.setLevel(getattr(logging, settings.PROMPT_LOG_LEVEL.upper()))
    logger.propagate = False  # Don't propagate to root logger (no console output)

    # Only attach a file handler when the logger has none yet; otherwise a
    # second pass through this code (cache cleared, or two callers racing
    # past the cache check) would write every record twice.
    if not logger.handlers:
        # Create timestamped log file
        from datetime import datetime
        timestamp = datetime.now().strftime("%Y%m%d-%H%M%S")
        log_file = Path("logs/prompts/") / f"prompt_logs-{timestamp}.log"

        # Ensure log directory exists
        log_file.parent.mkdir(parents=True, exist_ok=True)

        # Create file handler
        file_handler = logging.FileHandler(
            filename=str(log_file),
            encoding='utf-8'
        )

        # Create formatter
        formatter = logging.Formatter(
            fmt='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
            datefmt='%Y-%m-%d %H:%M:%S'
        )
        file_handler.setFormatter(formatter)

        # Add handler to logger
        logger.addHandler(file_handler)

    # Cache the logger
    LoggingConfig._prompt_logger = logger

    return logger
|
||||||
|
|
||||||
|
|
||||||
|
def get_template_logger() -> logging.Logger:
    """Return the ``memory.templates`` logger, building it on the first call.

    The logger is created once and then served from
    ``LoggingConfig._template_logger``. On first use it is set to INFO,
    detached from its parent loggers (``propagate = False``) and given exactly
    one handler:

    - ``settings.ENABLE_TEMPLATE_LOGGING`` truthy: a UTF-8 ``FileHandler``
      writing to ``logs/prompt_templates.log`` (relative to the working
      directory; the directory is created if missing).
    - otherwise: a ``NullHandler`` so records are discarded.

    Returns:
        The configured (and cached) template logger.

    Example:
        >>> logger = get_template_logger()
        >>> logger.info("Rendering template: %s with context keys: %s",
        ...             template_name, list(context.keys()))
    """
    cached = LoggingConfig._template_logger
    if cached is not None:
        return cached

    # Memory logging must be set up before any memory.* logger is configured.
    if not LoggingConfig._memory_loggers_initialized:
        LoggingConfig.setup_memory_logging()

    template_logger = logging.getLogger("memory.templates")
    template_logger.setLevel(logging.INFO)
    template_logger.propagate = False  # keep template records away from the root logger

    if not settings.ENABLE_TEMPLATE_LOGGING:
        # Template logging disabled: swallow records silently.
        handler = logging.NullHandler()
    else:
        destination = Path("logs") / "prompt_templates.log"
        destination.parent.mkdir(parents=True, exist_ok=True)
        handler = logging.FileHandler(
            filename=str(destination),
            encoding='utf-8',
        )
        handler.setFormatter(
            logging.Formatter(
                fmt='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
                datefmt='%Y-%m-%d %H:%M:%S',
            )
        )

    template_logger.addHandler(handler)

    LoggingConfig._template_logger = template_logger
    return template_logger
|
||||||
|
|
||||||
|
|
||||||
|
def log_prompt_rendering(prompt_type: str, content: str) -> None:
    """Write a rendered prompt to the prompt logger at INFO level.

    The entry is a single multi-line message: a header naming the prompt type
    (upper-cased), the prompt text, and a closing rule of 50 ``=`` characters,
    each on its own line and surrounded by blank lines.

    Args:
        prompt_type: Type of prompt (e.g. 'statement_extraction',
            'triplet_extraction'); upper-cased for the header.
        content: The rendered prompt text.

    Example:
        >>> log_prompt_rendering("extraction", "Extract entities from: Hello world")
        # Logs:
        # === RENDERED EXTRACTION PROMPT ===
        # Extract entities from: Hello world
        # ==================================================
    """
    prompt_logger = get_prompt_logger()

    header = f"=== RENDERED {prompt_type.upper()} PROMPT ==="
    separator = "=" * 50

    prompt_logger.info(f"\n{header}\n{content}\n{separator}\n")
|
||||||
|
|
||||||
|
|
||||||
|
def log_template_rendering(template_name: str, context: dict | None = None) -> None:
    """Record which template is being rendered and with which context keys.

    Only the *keys* of ``context`` are logged, never the values. Logging here
    is strictly best effort: any exception raised while obtaining the logger
    or emitting the record is swallowed so that template rendering is never
    interrupted by a logging problem.

    Args:
        template_name: Name of the template being rendered.
        context: Optional mapping of template variables. ``None`` is reported
            as "no context".

    Example:
        >>> log_template_rendering("extract_triplet.jinja2", {"text": "...", "ontology": "..."})
        # Logs: Rendering template: extract_triplet.jinja2 with context keys: ['text', 'ontology']

        >>> log_template_rendering("system.jinja2")
        # Logs: Rendering template: system.jinja2 with no context
    """
    try:
        template_logger = get_template_logger()

        if context is None:
            template_logger.info(f"Rendering template: {template_name} with no context")
            return

        context_keys = list(context.keys())
        template_logger.info(f"Rendering template: {template_name} with context keys: {context_keys}")
    except Exception:
        # Deliberately ignored: a logging failure must never break the caller.
        return
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
def get_timing_logger() -> logging.Logger:
    """Return the ``memory.timing`` logger, building it on the first call.

    The logger is created once and then served from
    ``LoggingConfig._timing_logger``. On first use it is set to INFO, detached
    from its parent loggers (``propagate = False``) and given:

    - a UTF-8 ``FileHandler`` writing to ``settings.TIMING_LOG_FILE`` (the
      parent directory is created if missing), and
    - a ``StreamHandler`` as well when ``settings.TIMING_LOG_TO_CONSOLE`` is
      truthy.

    Both handlers share one formatter
    (``asctime - name - levelname - message``).

    Returns:
        The configured (and cached) timing logger.

    Example:
        >>> logger = get_timing_logger()
        >>> logger.info("[2025-11-18 10:30:45] Extraction: 2.34 seconds")
    """
    cached = LoggingConfig._timing_logger
    if cached is not None:
        return cached

    # Memory logging must be set up before any memory.* logger is configured.
    if not LoggingConfig._memory_loggers_initialized:
        LoggingConfig.setup_memory_logging()

    timing_logger = logging.getLogger("memory.timing")
    timing_logger.setLevel(logging.INFO)
    timing_logger.propagate = False  # keep timing records away from the root logger

    shared_format = logging.Formatter(
        fmt='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
        datefmt='%Y-%m-%d %H:%M:%S',
    )

    timing_path = Path(settings.TIMING_LOG_FILE)
    timing_path.parent.mkdir(parents=True, exist_ok=True)

    # File output is always on; console output is opt-in via settings.
    handlers = [logging.FileHandler(filename=str(timing_path), encoding='utf-8')]
    if settings.TIMING_LOG_TO_CONSOLE:
        handlers.append(logging.StreamHandler())

    for handler in handlers:
        handler.setFormatter(shared_format)
        timing_logger.addHandler(handler)

    LoggingConfig._timing_logger = timing_logger
    return timing_logger
|
||||||
|
|
||||||
|
|
||||||
|
def log_time(step_name: str, duration: float, log_file: str = "logs/time.log") -> None:
    """Append a timing entry to a log file and echo a short summary to stdout.

    This function writes the file directly (it does not go through the
    ``logging`` machinery). The file line carries a local timestamp and the
    duration rounded to two decimals; the console line is a compact
    check-mark summary and is printed on every call, whether or not the file
    write succeeded.

    Args:
        step_name: Name of the operation being timed.
        duration: Duration in seconds.
        log_file: Path of the file to append to (default: logs/time.log).
            Missing parent directories are created.

    Example:
        >>> log_time("Knowledge Extraction", 2.34)
        # File logs: [2025-11-18 10:30:45] Knowledge Extraction: 2.34 seconds
        # Console prints: ✓ Knowledge Extraction: 2.34s

        >>> log_time("Database Query", 0.15, "logs/custom_time.log")
        # Logs to custom file and console
    """
    from datetime import datetime

    timestamp = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
    log_entry = f"[{timestamp}] {step_name}: {duration:.2f} seconds\n"

    target = Path(log_file)
    try:
        target.parent.mkdir(parents=True, exist_ok=True)
        with target.open("a", encoding="utf-8") as handle:
            handle.write(log_entry)
    except IOError as e:
        # The file is optional; report the problem and carry on with console output.
        print(f"Warning: Could not write to timing log: {e}")

    # Console summary is unconditional (backward compatible behavior).
    print(f"✓ {step_name}: {duration:.2f}s")
|
||||||
|
|
||||||
|
|
||||||
|
def get_agent_logger(name: str = "agent_service",
                     console_level: str = "INFO",
                     file_level: str = "DEBUG") -> logging.Logger:
    """Get an agent logger with console output and a rotating log file.

    Returns the logger named ``memory.agent.{name}``. On first creation the
    logger is set to DEBUG (so each handler does its own filtering), detached
    from its parent loggers (``propagate = False``) and given two handlers that
    share one formatter:

    - a ``StreamHandler`` filtered at ``console_level``;
    - a rotating file handler filtered at ``file_level``, writing to
      ``settings.AGENT_LOG_FILE`` with ``settings.AGENT_LOG_MAX_SIZE`` and
      ``settings.AGENT_LOG_BACKUP_COUNT``. ``ConcurrentRotatingFileHandler``
      is used when the ``concurrent_log_handler`` package is importable;
      otherwise the standard ``RotatingFileHandler`` is used and a warning is
      printed.

    Loggers are cached by ``name`` in ``LoggingConfig._agent_loggers``. A
    cache hit returns the existing logger unchanged, so ``console_level`` and
    ``file_level`` only take effect on the first call for a given ``name``.

    Both handlers are fully constructed before either one is attached. If a
    level name is invalid or the log file cannot be opened, the logger is
    therefore left without a stray console handler, and a later retry cannot
    produce duplicated console output.

    Args:
        name: Logger name for namespacing (default: "agent_service").
        console_level: Level name for console output (default: "INFO").
        file_level: Level name for file output (default: "DEBUG").

    Returns:
        Logger configured for agent operations.

    Raises:
        AttributeError: If ``console_level`` or ``file_level`` is not the name
            of an attribute of the ``logging`` module.

    Example:
        >>> logger = get_agent_logger("my_agent")
        >>> logger.info("Agent operation started")
        >>> logger.debug("Detailed agent state information")

        >>> logger = get_agent_logger("custom_agent", console_level="WARNING", file_level="INFO")
        >>> logger.warning("This appears in console and file")
        >>> logger.info("This only appears in file")
    """
    # Return cached logger if already initialized
    if name in LoggingConfig._agent_loggers:
        return LoggingConfig._agent_loggers[name]

    # Ensure memory logging is initialized
    if not LoggingConfig._memory_loggers_initialized:
        LoggingConfig.setup_memory_logging()

    # Resolve both level names up front so a bad name fails before any
    # handler exists.
    console_levelno = getattr(logging, console_level.upper())
    file_levelno = getattr(logging, file_level.upper())

    # Prefer the multi-process-safe handler; fall back to the stdlib one.
    try:
        from concurrent_log_handler import ConcurrentRotatingFileHandler
    except ImportError:
        from logging.handlers import RotatingFileHandler as ConcurrentRotatingFileHandler
        print("Warning: concurrent-log-handler not available, using standard RotatingFileHandler")

    formatter = logging.Formatter(
        fmt='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
        datefmt='%Y-%m-%d %H:%M:%S'
    )

    # Build the console handler (not attached yet).
    console_handler = logging.StreamHandler()
    console_handler.setLevel(console_levelno)
    console_handler.setFormatter(formatter)

    # Build the rotating file handler (not attached yet); this is the step
    # most likely to fail, e.g. on an unwritable path.
    log_file = Path(settings.AGENT_LOG_FILE)
    log_file.parent.mkdir(parents=True, exist_ok=True)
    file_handler = ConcurrentRotatingFileHandler(
        filename=str(log_file),
        maxBytes=settings.AGENT_LOG_MAX_SIZE,
        backupCount=settings.AGENT_LOG_BACKUP_COUNT,
        encoding='utf-8'
    )
    file_handler.setLevel(file_levelno)
    file_handler.setFormatter(formatter)

    # Everything that can fail has succeeded: configure the logger atomically.
    logger = logging.getLogger(f"memory.agent.{name}")
    logger.setLevel(logging.DEBUG)  # Set to DEBUG to allow both handlers to filter
    logger.propagate = False  # Don't propagate to root logger
    logger.addHandler(console_handler)
    logger.addHandler(file_handler)

    # Cache the logger
    LoggingConfig._agent_loggers[name] = logger

    return logger
|
||||||
|
|
||||||
|
|
||||||
|
def get_named_logger(name: str) -> logging.Logger:
    """Alias of :func:`get_agent_logger` kept for older call sites.

    Delegates to ``get_agent_logger(name)``, leaving the console and file
    levels at that function's defaults.

    Args:
        name: Logger name for namespacing.

    Returns:
        The logger returned by ``get_agent_logger(name)``.

    Example:
        >>> logger = get_named_logger("my_agent")
        >>> logger.info("Agent operation started")
    """
    agent_logger = get_agent_logger(name)
    return agent_logger
|
||||||
|
|
||||||
|
|
||||||
|
def get_memory_logger(name: Optional[str] = None) -> logging.Logger:
    """Return a logger in the ``memory`` namespace.

    Memory logging is initialized first if that has not happened yet
    (``LoggingConfig.setup_memory_logging()``). The logger is then looked up
    by name and returned as-is: this function attaches no handlers and sets
    no level or formatter, so output behavior comes entirely from whatever is
    configured on the logger's ancestors in the ``logging`` hierarchy.

    Naming:
        - ``name`` given: ``memory.{name}``
        - ``name`` is ``None``: ``memory``

    Args:
        name: Optional suffix, typically ``__name__`` of the calling module.

    Returns:
        The ``logging.Logger`` registered under the resulting name.

    Example:
        >>> # In app/core/memory/src/search.py
        >>> logger = get_memory_logger(__name__)
        >>> logger.info("Starting search operation")

        >>> # Base memory logger
        >>> logger = get_memory_logger()
        >>> logger.debug("Memory module initialized")
    """
    if not LoggingConfig._memory_loggers_initialized:
        LoggingConfig.setup_memory_logging()

    qualified_name = "memory" if name is None else f"memory.{name}"

    # No further configuration: the logging hierarchy supplies handlers and level.
    return logging.getLogger(qualified_name)
|
||||||
0
app/core/memory/__init__.py
Normal file
0
app/core/memory/__init__.py
Normal file
0
app/core/memory/agent/__init__.py
Normal file
0
app/core/memory/agent/__init__.py
Normal file
16
app/core/memory/agent/langgraph_graph/__init__.py
Normal file
16
app/core/memory/agent/langgraph_graph/__init__.py
Normal file
@@ -0,0 +1,16 @@
|
|||||||
|
"""
|
||||||
|
LangGraph Graph package for memory agent.
|
||||||
|
|
||||||
|
This package provides the LangGraph workflow orchestrator with modular
|
||||||
|
node implementations, routing logic, and state management.
|
||||||
|
|
||||||
|
Package structure:
|
||||||
|
- read_graph: Main graph factory for read operations
|
||||||
|
- write_graph: Main graph factory for write operations
|
||||||
|
- nodes: LangGraph node implementations
|
||||||
|
- routing: State routing logic
|
||||||
|
- state: State management utilities
|
||||||
|
"""
|
||||||
|
from app.core.memory.agent.langgraph_graph.read_graph import make_read_graph
|
||||||
|
|
||||||
|
__all__ = ['make_read_graph']
|
||||||
10
app/core/memory/agent/langgraph_graph/nodes/__init__.py
Normal file
10
app/core/memory/agent/langgraph_graph/nodes/__init__.py
Normal file
@@ -0,0 +1,10 @@
|
|||||||
|
"""
|
||||||
|
LangGraph node implementations.
|
||||||
|
|
||||||
|
This module contains custom node implementations for the LangGraph workflow.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from app.core.memory.agent.langgraph_graph.nodes.tool_node import ToolExecutionNode
|
||||||
|
from app.core.memory.agent.langgraph_graph.nodes.input_node import create_input_message
|
||||||
|
|
||||||
|
__all__ = ["ToolExecutionNode", "create_input_message"]
|
||||||
144
app/core/memory/agent/langgraph_graph/nodes/input_node.py
Normal file
144
app/core/memory/agent/langgraph_graph/nodes/input_node.py
Normal file
@@ -0,0 +1,144 @@
|
|||||||
|
"""
|
||||||
|
Input node for LangGraph workflow entry point.
|
||||||
|
|
||||||
|
This module provides the create_input_message function which processes initial
|
||||||
|
user input with multimodal support and creates the first tool call message.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import logging
|
||||||
|
import re
|
||||||
|
import uuid
|
||||||
|
from datetime import datetime
|
||||||
|
from typing import Dict, Any
|
||||||
|
|
||||||
|
from langchain_core.messages import AIMessage
|
||||||
|
|
||||||
|
from app.core.memory.agent.utils.multimodal import MultimodalProcessor
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
async def create_input_message(
    state: Dict[str, Any],
    tool_name: str,
    session_id: str,
    search_switch: str,
    apply_id: str,
    group_id: str,
    multimodal_processor: MultimodalProcessor
) -> Dict[str, Any]:
    """
    Create the initial tool call message from user input.

    This function:
    1. Takes the content of the last message in ``state["messages"]`` (or an
       empty string when there are no messages).
    2. Passes it through ``multimodal_processor.process_input``; if that
       raises, the error is logged and the original content is kept.
    3. Generates a fresh UUID used as the message ID.
    4. Logs a warning when ``session_id`` does not contain the ``_id_``
       marker. The namespace part is not used here; only the warning is kept.
    5. If the text mentions ``verified_data``, replaces it with the first
       ``"query": "..."`` value found by a regex (backward compatibility).
    6. Returns an AIMessage carrying a single tool call.

    Args:
        state: LangGraph state dictionary containing messages
        tool_name: Name of the tool to invoke (typically "Split_The_Problem")
        session_id: Session identifier (expected to contain "_id_")
        search_switch: Search routing parameter, forwarded in the tool args
        apply_id: Application identifier, forwarded in the tool args
        group_id: Group identifier, forwarded in the tool args
        multimodal_processor: Processor exposing an awaitable ``process_input``

    Returns:
        State update ``{"messages": [AIMessage]}`` whose tool call has
        ``id == f"{session_id}_{uuid}"`` and args ``sentence``, ``sessionid``,
        ``messages_id``, ``search_switch``, ``apply_id``, ``group_id``.

    Examples:
        >>> state = {"messages": [HumanMessage(content="What is AI?")]}
        >>> result = await create_input_message(
        ...     state, "Split_The_Problem", "call_id_user123", "0", "app1", "group1", processor
        ... )
        >>> result["messages"][0].tool_calls[0]["name"]
        'Split_The_Problem'
    """
    messages = state.get("messages", [])

    # Extract last message content
    if messages:
        last_message = messages[-1].content if hasattr(messages[-1], 'content') else str(messages[-1])
    else:
        logger.warning("[create_input_message] No messages in state, using empty string")
        last_message = ""

    logger.debug(f"[create_input_message] Original input: {last_message[:100]}...")

    # Process multimodal input (images/audio)
    try:
        processed_content = await multimodal_processor.process_input(last_message)
        if processed_content != last_message:
            logger.info(
                f"[create_input_message] Multimodal processing converted input "
                f"from {len(last_message)} to {len(processed_content)} chars"
            )
            last_message = processed_content
    except Exception as e:
        # Best effort: fall through with the unprocessed content.
        logger.error(
            f"[create_input_message] Multimodal processing failed: {e}",
            exc_info=True
        )

    # Generate unique message ID
    message_uuid = uuid.uuid4()

    # Sanity-check the session_id shape ("..._id_{namespace}"). The namespace is
    # not consumed by this function, so only the diagnostic warning is emitted.
    if '_id_' not in str(session_id):
        logger.warning(
            f"[create_input_message] Could not extract namespace from session_id: {session_id}"
        )

    # Handle verified_data extraction (backward compatibility)
    # This regex-based extraction is kept for compatibility with existing data formats
    if 'verified_data' in str(last_message):
        try:
            # Strip literal "\n" sequences and remaining backslashes before matching.
            messages_last = str(last_message).replace('\\n', '').replace('\\', '')
            query_match = re.findall(r'"query": "(.*?)",', messages_last)
            if query_match:
                last_message = query_match[0]
                logger.debug(
                    f"[create_input_message] Extracted query from verified_data: {last_message}"
                )
        except Exception as e:
            logger.warning(
                f"[create_input_message] Failed to extract query from verified_data: {e}"
            )

    # Construct tool call message
    tool_call_id = f"{session_id}_{message_uuid}"

    logger.info(
        f"[create_input_message] Creating tool call for '{tool_name}' "
        f"with ID: {tool_call_id}"
    )

    return {
        "messages": [
            AIMessage(
                content="",
                tool_calls=[{
                    "name": tool_name,
                    "args": {
                        "sentence": last_message,
                        "sessionid": session_id,
                        "messages_id": str(message_uuid),
                        "search_switch": search_switch,
                        "apply_id": apply_id,
                        "group_id": group_id
                    },
                    "id": tool_call_id
                }]
            )
        ]
    }
|
||||||
199
app/core/memory/agent/langgraph_graph/nodes/tool_node.py
Normal file
199
app/core/memory/agent/langgraph_graph/nodes/tool_node.py
Normal file
@@ -0,0 +1,199 @@
|
|||||||
|
"""
|
||||||
|
Tool execution node for LangGraph workflow.
|
||||||
|
|
||||||
|
This module provides the ToolExecutionNode class which wraps tool execution
|
||||||
|
with parameter transformation logic using the ParameterBuilder service.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import logging
|
||||||
|
import time
|
||||||
|
from typing import Any, Callable, Dict
|
||||||
|
|
||||||
|
from langchain_core.messages import AIMessage
|
||||||
|
from langgraph.prebuilt import ToolNode
|
||||||
|
|
||||||
|
from app.core.memory.agent.langgraph_graph.state.extractors import (
|
||||||
|
extract_tool_call_id,
|
||||||
|
extract_content_payload
|
||||||
|
)
|
||||||
|
from app.core.memory.agent.mcp_server.services.parameter_builder import ParameterBuilder
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
class ToolExecutionNode:
    """
    LangGraph node that runs one tool after rebuilding its arguments.

    On each call the node reads the last message in the state, pulls the tool
    call ID and content payload out of it with the state extractors, asks the
    ParameterBuilder for the tool's arguments, and then invokes the wrapped
    ToolNode with a synthetic AIMessage carrying that tool call.

    Attributes:
        tool_node: LangGraph ToolNode wrapping the actual tool
        id: Node identifier, used as a prefix for tool call IDs
        tool_name: ``tool.name`` when present, otherwise ``str(tool)``
        namespace: Namespace value supplied by the caller (stored only; not
            read by this class)
        search_switch: Search routing parameter passed to the parameter builder
        apply_id: Application identifier passed to the parameter builder
        group_id: Group identifier passed to the parameter builder
        parameter_builder: Service for building tool-specific arguments
        storage_type: Passed through unchanged to the parameter builder
        user_rag_memory_id: Passed through unchanged to the parameter builder
    """

    def __init__(
        self,
        tool: Callable,
        node_id: str,
        namespace: str,
        search_switch: str,
        apply_id: str,
        group_id: str,
        parameter_builder: ParameterBuilder,
        storage_type: str,
        user_rag_memory_id: str
    ):
        """
        Initialize the tool execution node.

        Args:
            tool: The tool function to execute
            node_id: Identifier for this node (used in message IDs)
            namespace: Namespace for session management
            search_switch: Search routing parameter
            apply_id: Application identifier
            group_id: Group identifier
            parameter_builder: Service for building tool-specific arguments
            storage_type: Forwarded to ``parameter_builder.build_tool_args``
            user_rag_memory_id: Forwarded to ``parameter_builder.build_tool_args``
        """
        self.tool_node = ToolNode([tool])
        self.id = node_id
        self.tool_name = str(tool) if not hasattr(tool, 'name') else tool.name

        # Routing / identity values handed to the parameter builder on each call.
        self.namespace = namespace
        self.search_switch = search_switch
        self.apply_id = apply_id
        self.group_id = group_id
        self.storage_type = storage_type
        self.user_rag_memory_id = user_rag_memory_id

        self.parameter_builder = parameter_builder

        logger.info(
            f"[ToolExecutionNode] Initialized node '{self.id}' for tool '{self.tool_name}'"
        )

    @staticmethod
    def _ai_error(text: str) -> Dict[str, Any]:
        """Wrap an error text in a state update holding a single AIMessage."""
        return {"messages": [AIMessage(content=text)]}

    async def __call__(self, state: Dict[str, Any]) -> Dict[str, Any]:
        """
        Execute the tool with transformed parameters.

        Steps:
        1. Take the last message from ``state["messages"]``.
        2. Extract the tool call ID (a ``ValueError`` here ends the call with
           an error AIMessage).
        3. Extract the content payload (any failure is logged and replaced by
           an empty dict).
        4. Build the tool arguments via the parameter builder (a failure ends
           the call with an error AIMessage).
        5. Invoke the wrapped ToolNode with an AIMessage carrying the tool
           call; a failure is returned as a ToolMessage with the same call ID.

        Args:
            state: LangGraph state dictionary

        Returns:
            The ToolNode result on success, otherwise a ``{"messages": [...]}``
            update containing a single error message.
        """
        messages = state.get("messages", [])
        logger.debug(self.tool_name)

        # Guard: nothing to work with.
        if not messages:
            logger.warning(f"[ToolExecutionNode] {self.id} - No messages in state")
            return self._ai_error("Error: No messages in state")

        last_message = messages[-1]
        logger.debug(
            f"[ToolExecutionNode] {self.id} - Processing message at {time.time()}"
        )

        # Step 2: tool call ID (only ValueError is treated as a handled failure).
        try:
            tool_call_id = extract_tool_call_id(last_message)
            logger.debug(f"[ToolExecutionNode] {self.id} - Extracted tool_call_id: {tool_call_id}")
        except ValueError as e:
            logger.error(
                f"[ToolExecutionNode] {self.id} - Failed to extract tool call ID: {e}"
            )
            return self._ai_error(f"Error: {str(e)}")

        # Step 3: content payload; degrade to an empty payload on any error.
        try:
            content = extract_content_payload(last_message)
            logger.debug(
                f"[ToolExecutionNode] {self.id} - Extracted content type: {type(content)}"
            )
        except Exception as e:
            logger.error(
                f"[ToolExecutionNode] {self.id} - Failed to extract content: {e}",
                exc_info=True
            )
            content = {}

        # Step 4: tool-specific arguments.
        try:
            tool_args = self.parameter_builder.build_tool_args(
                tool_name=self.tool_name,
                content=content,
                tool_call_id=tool_call_id,
                search_switch=self.search_switch,
                apply_id=self.apply_id,
                group_id=self.group_id,
                storage_type=self.storage_type,
                user_rag_memory_id=self.user_rag_memory_id
            )
            logger.debug(
                f"[ToolExecutionNode] {self.id} - Built tool args with keys: {list(tool_args.keys())}"
            )
        except Exception as e:
            logger.error(
                f"[ToolExecutionNode] {self.id} - Failed to build tool args: {e}",
                exc_info=True
            )
            return self._ai_error(f"Error building arguments: {str(e)}")

        # The same prefixed ID is used for the outgoing call and any error reply.
        prefixed_call_id = f"{self.id}_{tool_call_id}"
        tool_call = {
            "name": self.tool_name,
            "args": tool_args,
            "id": prefixed_call_id,
        }
        tool_input = {"messages": [AIMessage(content="", tool_calls=[tool_call])]}

        # Step 5: run the tool.
        try:
            result = await self.tool_node.ainvoke(tool_input)
            logger.debug(
                f"[ToolExecutionNode] {self.id} - Tool execution completed"
            )
            # The ToolNode result already contains the messages list.
            return result
        except Exception as e:
            logger.error(
                f"[ToolExecutionNode] {self.id} - Tool execution failed: {e}",
                exc_info=True
            )
            # Reply with a ToolMessage so the tool call still has a matching response.
            from langchain_core.messages import ToolMessage
            return {
                "messages": [
                    ToolMessage(
                        content=f"Error executing tool: {str(e)}",
                        tool_call_id=prefixed_call_id
                    )
                ]
            }
|
||||||
508
app/core/memory/agent/langgraph_graph/read_graph.py
Normal file
508
app/core/memory/agent/langgraph_graph/read_graph.py
Normal file
@@ -0,0 +1,508 @@
|
|||||||
|
import asyncio
|
||||||
|
import io
|
||||||
|
import json
|
||||||
|
import logging
|
||||||
|
import os
|
||||||
|
import re
|
||||||
|
import time
|
||||||
|
import uuid
|
||||||
|
import warnings
|
||||||
|
from contextlib import asynccontextmanager
|
||||||
|
from datetime import datetime
|
||||||
|
from typing import Literal
|
||||||
|
|
||||||
|
from dotenv import load_dotenv
|
||||||
|
from langchain_core.messages import AIMessage
|
||||||
|
from langgraph.constants import START, END
|
||||||
|
from langgraph.graph import StateGraph
|
||||||
|
from langgraph.prebuilt import ToolNode
|
||||||
|
from functools import partial
|
||||||
|
|
||||||
|
from app.core.memory.agent.utils.llm_tools import ReadState, COUNTState
|
||||||
|
from langgraph.checkpoint.memory import InMemorySaver
|
||||||
|
|
||||||
|
from app.core.memory.agent.utils.redis_tool import store
|
||||||
|
from app.core.logging_config import get_agent_logger
|
||||||
|
|
||||||
|
# Import new modular components
|
||||||
|
from app.core.memory.agent.langgraph_graph.nodes import ToolExecutionNode, create_input_message
|
||||||
|
from app.core.memory.agent.langgraph_graph.routing.routers import (
|
||||||
|
Verify_continue,
|
||||||
|
Retrieve_continue,
|
||||||
|
Split_continue
|
||||||
|
)
|
||||||
|
from app.core.memory.agent.mcp_server.services.parameter_builder import ParameterBuilder
|
||||||
|
from app.core.memory.agent.utils.multimodal import MultimodalProcessor
|
||||||
|
|
||||||
|
# Module-level logger for the read-graph workflow.
logger = get_agent_logger(__name__)

# Suppress every RuntimeWarning for the whole process.
warnings.filterwarnings("ignore", category=RuntimeWarning)

# Load variables from .env before reading the Redis settings below.
load_dotenv()
redishost = os.getenv("REDISHOST")
redisport = os.getenv("REDISPORT")
redisdb = os.getenv("REDISDB")
redispassword = os.getenv("REDISPASSWORD")

# Shared counter used by Verify_continue to bound verification retries.
counter = COUNTState(limit=3)
|
||||||
|
|
||||||
|
# Loop-count update step for the workflow
|
||||||
|
async def update_loop_count(state):
    """Return a state update that increments ``loop_count`` by one.

    Args:
        state: Mapping that may contain a ``loop_count`` entry; a missing
            entry is treated as 0.

    Returns:
        A dict with the single key ``loop_count`` holding the new value.
    """
    return {"loop_count": state.get("loop_count", 0) + 1}
|
||||||
|
|
||||||
|
|
||||||
|
def Verify_continue(state: ReadState) -> Literal["Summary", "Summary_fails", "content_input"]:
    """Choose the node that follows ``Verify`` from the latest verification result.

    The last message is stringified and searched for ``"split_result"``
    values:

    - ``"success"`` found: reset the counter and go to ``Summary``.
    - ``"failed"`` found: go back to ``content_input`` while the counter total
      is below 2, otherwise reset the counter and go to ``Summary_fails``.
    - neither found: reset the counter and go to ``Summary``.

    Args:
        state: Graph state; its ``messages`` entry is inspected.

    Returns:
        The name of the next node.
    """
    messages = state.get("messages", [])

    # Boundary check. This used to return END, which is not one of the
    # destinations declared in the return annotation; fall back to "Summary"
    # (and reset the shared counter) like the router in routing/routers.py.
    # NOTE(review): langgraph appears to build its routing table from this
    # Literal annotation when no path_map is passed - confirm for the pinned
    # version.
    if not messages:
        logger.warning("[Verify_continue] No messages in state, defaulting to Summary")
        counter.reset()
        return "Summary"

    counter.add(1)
    loop_count = counter.get_total()
    logger.debug(f"[should_continue] 当前循环次数: {loop_count}")

    # Strip backslashes so escaped JSON inside the message repr can be matched.
    last_message_str = str(messages[-1]).replace('\\', '')
    status_tools = re.findall(r'"split_result": "(.*?)"', last_message_str)
    logger.debug(f"Status tools: {status_tools}")

    if "success" in status_tools:
        counter.reset()
        return "Summary"

    if "failed" in status_tools:
        # NOTE(review): the original comment claimed a maximum of 3 loops while
        # the code compares against 2 - confirm the intended limit.
        if loop_count < 2:
            return "content_input"
        counter.reset()
        return "Summary_fails"

    # Unknown status: default to "Summary" so the router never returns None.
    counter.reset()
    return "Summary"
|
||||||
|
|
||||||
|
|
||||||
|
def Retrieve_continue(state) -> Literal["Verify", "Retrieve_Summary"]:
    """Route after ``Retrieve`` according to ``search_switch``.

    The switch is read from ``state["search_switch"]``. When that is absent,
    the ``args`` of the first dict-shaped tool call on the last message that
    has an ``args`` key is consulted instead.

    Args:
        state: State mapping, optionally holding ``search_switch`` and
            ``messages``.

    Returns:
        ``'Verify'`` when the switch stringifies to ``'0'``; otherwise
        ``'Retrieve_Summary'`` (also the default when no switch is found).
    """
    switch = state.get("search_switch")

    if switch is None and "messages" in state:
        history = state.get("messages", [])
        calls = getattr(history[-1], "tool_calls", None) if history else None
        if calls:
            # Only the first tool call carrying "args" is examined.
            first_with_args = next(
                (call for call in calls if isinstance(call, dict) and "args" in call),
                None,
            )
            if first_with_args is not None:
                switch = first_with_args["args"].get("search_switch")

    if switch is not None and str(switch) == '0':
        return 'Verify'

    # '1' and every unrecognised or missing value share the same destination.
    return 'Retrieve_Summary'
|
||||||
|
|
||||||
|
|
||||||
|
def Split_continue(state) -> Literal["Split_The_Problem", "Input_Summary"]:
    """Route after ``content_input`` according to ``search_switch``.

    The switch is read from ``state["search_switch"]``. When that is absent,
    the ``args`` of the first dict-shaped tool call on the last message that
    has an ``args`` key is consulted instead.

    Args:
        state: State mapping, optionally holding ``search_switch`` and
            ``messages``.

    Returns:
        ``'Input_Summary'`` when the switch stringifies to ``'2'``; otherwise
        ``'Split_The_Problem'``.
    """
    logger.debug(f"Split_continue state: {state}")

    switch = state.get("search_switch")

    if switch is None and "messages" in state:
        history = state.get("messages", [])
        calls = getattr(history[-1], "tool_calls", None) if history else None
        if calls:
            # Only the first tool call carrying "args" is examined.
            first_with_args = next(
                (call for call in calls if isinstance(call, dict) and "args" in call),
                None,
            )
            if first_with_args is not None:
                switch = first_with_args["args"].get("search_switch")

    if switch is not None and str(switch) == '2':
        return 'Input_Summary'
    return 'Split_The_Problem'  # default route
|
||||||
|
|
||||||
|
# Parameter names were changed in the input_sentence function
|
||||||
|
async def input_sentence(state, name, id, search_switch, apply_id, group_id):
    """Wrap the latest message content in an ``AIMessage`` tool call.

    Args:
        state: Graph state; the content of the last entry in ``messages`` is
            used as the sentence (empty string when there are no messages).
        name: Name of the tool to call.
        id: Session identifier; stored as ``sessionid`` and used as the prefix
            of the tool-call id.
        search_switch: Forwarded unchanged in the tool-call args.
        apply_id: Forwarded unchanged in the tool-call args.
        group_id: Forwarded unchanged in the tool-call args.

    Returns:
        ``{"messages": [AIMessage]}`` where the message carries one tool call.
    """
    messages = state["messages"]
    last_message = messages[-1].content if messages else ""

    # NOTE(review): picture_model_requests, audio_extensions and
    # Vico_recognition are neither defined nor imported in the visible part of
    # this module, so this function raises NameError as soon as the audio
    # check below is evaluated. Confirm where they are meant to come from (the
    # graph itself uses create_input_message with a MultimodalProcessor).
    if last_message.endswith('.jpg') or last_message.endswith('.png'):
        last_message = await picture_model_requests(last_message)
    if any(last_message.endswith(ext) for ext in audio_extensions):
        last_message = await Vico_recognition([last_message]).run()
        logger.debug(f"Audio recognition result: {last_message}")

    uuid_str = uuid.uuid4()

    # A verification payload is reduced to its first "query" value.
    if 'verified_data' in str(last_message):
        messages_last = str(last_message).replace('\\n', '').replace('\\', '')
        queries = re.findall(r'"query": "(.*?)",', messages_last)
        if queries:
            last_message = queries[0]
        else:
            # Previously an IndexError; keep the message unchanged instead.
            logger.warning("input_sentence: 'verified_data' present but no query found")

    return {
        "messages": [
            AIMessage(
                content="",
                tool_calls=[{
                    "name": name,
                    "args": {
                        "sentence": last_message,
                        'sessionid': id,
                        'messages_id': str(uuid_str),
                        "search_switch": search_switch,
                        "apply_id": apply_id,
                        "group_id": group_id,
                    },
                    "id": id + f'_{uuid_str}',
                }],
            )
        ]
    }
|
||||||
|
|
||||||
|
|
||||||
|
class ProblemExtensionNode:
    """Callable graph node that feeds the previous message into a single tool.

    On each call the node derives a call reference and a content payload from
    the last message in the state, builds the argument dict for the wrapped
    tool, invokes it through a ``ToolNode`` and returns the stringified result
    wrapped in an ``AIMessage``.
    """

    def __init__(self, tool, id, namespace, search_switch, apply_id, group_id, storage_type="", user_rag_memory_id=""):
        self.tool_node = ToolNode([tool])
        self.id = id
        # Fall back to the string form when the tool has no ``name`` attribute.
        self.tool_name = tool.name if hasattr(tool, 'name') else str(tool)
        self.namespace = namespace
        self.search_switch = search_switch
        self.apply_id = apply_id
        self.group_id = group_id
        self.storage_type = storage_type
        self.user_rag_memory_id = user_rag_memory_id

    async def __call__(self, state):
        """Run the wrapped tool on the last message of ``state["messages"]``."""
        history = state["messages"]
        previous = history[-1] if history else ""
        logger.debug(f"ProblemExtensionNode {self.id} - 当前时间: {time.time()} - Message: {previous}")

        call_ref = self._extract_call_reference(previous)
        content = self._extract_payload(previous)
        arguments = self._build_arguments(content, previous, call_ref)

        request = {
            "messages": [
                AIMessage(
                    content="",
                    tool_calls=[{
                        "name": self.tool_name,
                        "args": arguments,
                        "id": self.id + f"{call_ref}",
                    }],
                )
            ]
        }
        outcome = await self.tool_node.ainvoke(request)
        return {"messages": [AIMessage(content=str(outcome))]}

    def _extract_call_reference(self, message):
        """Pull the call reference out of the message's string form.

        Raises IndexError when the expected pattern is not present.
        """
        text = str(message)
        if self.tool_name == 'Input_Summary':
            return re.findall(r"'id': '(.*?)'", text)[0]
        raw_id = re.findall(r"tool_call_id=.*?'(.*?)'", text)[0]
        # Keep only the part between the first and second "_id" separators.
        return str(raw_id).replace('\\', '').split('_id')[1]

    @staticmethod
    def _extract_payload(message):
        """Return the message payload, decoded from JSON when possible."""
        raw = message.content if hasattr(message, 'content') else str(message)
        # Capture a ToolMessage-style content field (single or double quoted),
        # non-greedily; otherwise use the raw text as-is.
        found = re.search(r"content=(?:\"|\')(.*?)(?:\"|\'),\s*name=", raw, flags=re.S)
        payload = found.group(1) if found else raw

        try:
            return json.loads(payload)
        except Exception:
            # Try JSON-looking fragments embedded in the text.
            for fragment in re.findall(r"[\[{].*[\]}]", payload, flags=re.S):
                try:
                    return json.loads(fragment)
                except Exception:
                    continue
            # Nothing decodable: keep the raw payload.
            return payload

    @staticmethod
    def _context_for(tool_name, content, message):
        """Shape the ``context`` argument the way the named tool expects it."""
        if tool_name == 'Input_Summary':
            return str(message)

        if tool_name in ("Summary", "Summary_fails"):
            # These tools take the context as a string.
            if isinstance(content, dict):
                return json.dumps(content, ensure_ascii=False)
            return str(content)

        if tool_name == 'Retrieve_Summary':
            # Takes a dict; unwrap a nested "context" key when present.
            if isinstance(content, str):
                try:
                    decoded = json.loads(content)
                except json.JSONDecodeError:
                    return {"content": content}
                if isinstance(decoded, dict) and "context" in decoded:
                    return decoded["context"]
                return decoded
            if isinstance(content, dict):
                return content["context"] if "context" in content else content
            return {"content": str(content)}

        # Verify, Retrieve and any other tool: dict as-is, anything else wrapped.
        return content if isinstance(content, dict) else {"content": content}

    def _build_arguments(self, content, message, call_ref):
        """Assemble the argument dict passed to the wrapped tool."""
        tool_name = self.tool_name
        arguments = {
            "context": self._context_for(tool_name, content, message),
            "usermessages": str(call_ref),
        }
        if tool_name in ("Retrieve", "Input_Summary"):
            arguments["search_switch"] = str(self.search_switch)
        arguments["apply_id"] = str(self.apply_id)
        arguments["group_id"] = str(self.group_id)
        if tool_name == 'Input_Summary':
            arguments["storage_type"] = getattr(self, 'storage_type', "")
            arguments["user_rag_memory_id"] = getattr(self, 'user_rag_memory_id', "")
        return arguments
|
||||||
|
|
||||||
|
|
||||||
|
@asynccontextmanager
async def make_read_graph(namespace, tools, search_switch, apply_id, group_id, config_id=None, storage_type=None, user_rag_memory_id=None):
    """Build and yield the compiled read workflow graph.

    Args:
        namespace: Suffix used in the session id of the input message and
            passed to every tool-execution node.
        tools: Iterable of tool objects exposing ``name``; looked up by name.
        search_switch: Default switch value; ``'2'`` selects the
            ``Input_Summary`` path in the input node.
        apply_id: Forwarded to every node.
        group_id: Forwarded to every node.
        config_id: Only logged when truthy.
        storage_type: Forwarded to every tool-execution node.
        user_rag_memory_id: Forwarded to every tool-execution node.

    Yields:
        The ``StateGraph`` compiled with an ``InMemorySaver`` checkpointer.
    """
    memory = InMemorySaver()
    tool = [i.name for i in tools]
    logger.info(f"Initializing read graph with tools: {tool}")
    if config_id:
        logger.info(f"使用配置 ID: {config_id}")

    # First tool wins for each name, matching a next(...) lookup per name.
    tools_by_name = {}
    for candidate in tools:
        tools_by_name.setdefault(candidate.name, candidate)

    # Services shared by all nodes.
    parameter_builder = ParameterBuilder()
    multimodal_processor = MultimodalProcessor()

    def build_execution_node(tool_name):
        """Create the ToolExecutionNode for ``tool_name`` (None if missing)."""
        return ToolExecutionNode(
            tool=tools_by_name.get(tool_name),
            node_id=f"{tool_name}_id",
            namespace=namespace,
            search_switch=search_switch,
            apply_id=apply_id,
            group_id=group_id,
            parameter_builder=parameter_builder,
            storage_type=storage_type,
            user_rag_memory_id=user_rag_memory_id,
        )

    async def content_input_node(state):
        state_search_switch = state.get("search_switch", search_switch)

        # Compare as a string: the routers stringify search_switch before
        # comparing, so an int 2 must select the same path here.
        direct_summary = str(state_search_switch) == '2'
        tool_name = "Input_Summary" if direct_summary else "Split_The_Problem"
        session_prefix = "input_summary_call_id" if direct_summary else "split_call_id"

        # NOTE(review): the graph-level search_switch is forwarded here, not
        # the per-state value read above - confirm this is intended.
        return await create_input_message(
            state=state,
            tool_name=tool_name,
            session_id=f"{session_prefix}_{namespace}",
            search_switch=search_switch,
            apply_id=apply_id,
            group_id=group_id,
            multimodal_processor=multimodal_processor,
        )

    # Register nodes. Split_The_Problem is a plain ToolNode; the rest are
    # ToolExecutionNode instances.
    workflow = StateGraph(ReadState)
    workflow.add_node("content_input", content_input_node)
    workflow.add_node("Split_The_Problem", ToolNode([tools_by_name.get("Split_The_Problem")]))
    for node_name in (
        "Problem_Extension",
        "Retrieve",
        "Verify",
        "Summary",
        "Summary_fails",
        "Retrieve_Summary",
        "Input_Summary",
    ):
        workflow.add_node(node_name, build_execution_node(node_name))

    # Wire edges; conditional edges use the router functions.
    workflow.add_edge(START, "content_input")
    workflow.add_conditional_edges("content_input", Split_continue)
    workflow.add_edge("Input_Summary", END)
    workflow.add_edge("Split_The_Problem", "Problem_Extension")
    workflow.add_edge("Problem_Extension", "Retrieve")
    workflow.add_conditional_edges("Retrieve", Retrieve_continue)
    workflow.add_edge("Retrieve_Summary", END)
    workflow.add_conditional_edges("Verify", Verify_continue)
    workflow.add_edge("Summary_fails", END)
    workflow.add_edge("Summary", END)

    graph = workflow.compile(checkpointer=memory)
    yield graph
|
||||||
|
|
||||||
|
|
||||||
|
# Add to the end of this file, or create a new execution script
# Add the following function to memory_agent_service.py
|
||||||
|
|
||||||
13
app/core/memory/agent/langgraph_graph/routing/__init__.py
Normal file
13
app/core/memory/agent/langgraph_graph/routing/__init__.py
Normal file
@@ -0,0 +1,13 @@
|
|||||||
|
"""LangGraph routing logic."""
|
||||||
|
|
||||||
|
from app.core.memory.agent.langgraph_graph.routing.routers import (
|
||||||
|
Verify_continue,
|
||||||
|
Retrieve_continue,
|
||||||
|
Split_continue,
|
||||||
|
)
|
||||||
|
|
||||||
|
__all__ = [
|
||||||
|
"Verify_continue",
|
||||||
|
"Retrieve_continue",
|
||||||
|
"Split_continue",
|
||||||
|
]
|
||||||
123
app/core/memory/agent/langgraph_graph/routing/routers.py
Normal file
123
app/core/memory/agent/langgraph_graph/routing/routers.py
Normal file
@@ -0,0 +1,123 @@
|
|||||||
|
"""
|
||||||
|
Routing functions for LangGraph conditional edges.
|
||||||
|
|
||||||
|
This module provides routing functions that determine the next node to execute
|
||||||
|
based on state values. All functions return Literal types for type safety.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import logging
|
||||||
|
import re
|
||||||
|
from typing import Literal
|
||||||
|
|
||||||
|
from app.core.memory.agent.langgraph_graph.state.extractors import extract_search_switch
|
||||||
|
from app.core.memory.agent.utils.llm_tools import ReadState, COUNTState
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
# Global counter for Verify routing
|
||||||
|
counter = COUNTState(limit=3)
|
||||||
|
|
||||||
|
|
||||||
|
def Verify_continue(state: ReadState) -> Literal["Summary", "Summary_fails", "content_input"]:
    """Choose the node that follows ``Verify`` from the verification result.

    The last message is stringified and searched for ``"split_result"``
    values:

    - ``Summary``: a ``"success"`` value was found, or no recognised value.
    - ``content_input``: a ``"failed"`` value was found and the shared counter
      total is still below 2.
    - ``Summary_fails``: a ``"failed"`` value was found and the counter total
      has reached 2.

    The shared counter is reset on every outcome except ``content_input``.

    Args:
        state: LangGraph state containing ``messages``.

    Returns:
        Next node name.
    """
    history = state.get("messages", [])

    # Nothing to inspect: reset and take the default route.
    if not history:
        logger.warning("[Verify_continue] No messages in state, defaulting to Summary")
        counter.reset()
        return "Summary"

    counter.add(1)
    attempts = counter.get_total()
    logger.debug(f"[Verify_continue] Current loop count: {attempts}")

    # Drop backslashes so escaped JSON in the message repr can be matched.
    flattened = str(history[-1]).replace('\\', '')
    statuses = re.findall(r'"split_result": "(.*?)"', flattened)
    logger.debug(f"[Verify_continue] Status tools: {statuses}")

    # "success" takes precedence when both values are present.
    verification_failed = "success" not in statuses and "failed" in statuses

    if verification_failed and attempts < 2:
        # Retry; the counter keeps its value for the next pass.
        return "content_input"

    counter.reset()
    return "Summary_fails" if verification_failed else "Summary"
|
||||||
|
|
||||||
|
|
||||||
|
def Retrieve_continue(state: dict) -> Literal["Verify", "Retrieve_Summary"]:
    """Route after ``Retrieve`` based on the extracted ``search_switch``.

    - ``'0'`` routes to ``Verify``.
    - ``'1'`` routes to ``Retrieve_Summary``.
    - anything else also routes to ``Retrieve_Summary`` (logged as a default).

    Args:
        state: LangGraph state dictionary.

    Returns:
        Next node name.
    """
    switch = extract_search_switch(state)
    logger.debug(f"[Retrieve_continue] search_switch: {switch}")

    if switch == '0':
        return 'Verify'

    if switch != '1':
        # No recognised value: fall through to the default destination.
        logger.debug("[Retrieve_continue] No valid search_switch, defaulting to Retrieve_Summary")
    return 'Retrieve_Summary'
|
||||||
|
|
||||||
|
|
||||||
|
def Split_continue(state: dict) -> Literal["Split_The_Problem", "Input_Summary"]:
    """Route after ``content_input`` based on the extracted ``search_switch``.

    - ``'2'`` routes to ``Input_Summary``.
    - any other value routes to ``Split_The_Problem``.

    Args:
        state: LangGraph state dictionary.

    Returns:
        Next node name.
    """
    logger.debug(f"[Split_continue] state keys: {state.keys()}")

    switch = extract_search_switch(state)
    logger.debug(f"[Split_continue] search_switch: {switch}")

    return 'Input_Summary' if switch == '2' else 'Split_The_Problem'
|
||||||
13
app/core/memory/agent/langgraph_graph/state/__init__.py
Normal file
13
app/core/memory/agent/langgraph_graph/state/__init__.py
Normal file
@@ -0,0 +1,13 @@
|
|||||||
|
"""LangGraph state management utilities."""
|
||||||
|
|
||||||
|
from app.core.memory.agent.langgraph_graph.state.extractors import (
|
||||||
|
extract_search_switch,
|
||||||
|
extract_tool_call_id,
|
||||||
|
extract_content_payload,
|
||||||
|
)
|
||||||
|
|
||||||
|
__all__ = [
|
||||||
|
"extract_search_switch",
|
||||||
|
"extract_tool_call_id",
|
||||||
|
"extract_content_payload",
|
||||||
|
]
|
||||||
164
app/core/memory/agent/langgraph_graph/state/extractors.py
Normal file
164
app/core/memory/agent/langgraph_graph/state/extractors.py
Normal file
@@ -0,0 +1,164 @@
|
|||||||
|
"""
|
||||||
|
State extraction utilities for type-safe access to LangGraph state values.
|
||||||
|
|
||||||
|
This module provides utility functions for extracting values from LangGraph state
|
||||||
|
dictionaries with proper error handling and sensible defaults.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import json
|
||||||
|
import logging
|
||||||
|
from typing import Any, Optional
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
def extract_search_switch(state: dict) -> Optional[str]:
    """Return ``search_switch`` as a string, or ``None`` when it is not found.

    Lookup order:

    1. ``state["search_switch"]``.
    2. Messages from newest to oldest. For each message, every dict-shaped
       tool call is checked (its ``args`` dict first, then the call itself),
       then the message ``content`` if it is a JSON object string.

    Args:
        state: State mapping that may hold ``search_switch`` and ``messages``.

    Returns:
        The first non-``None`` value found, converted with ``str``.
    """
    direct = state.get("search_switch")
    if direct is not None:
        return str(direct)

    for message in reversed(state.get("messages", []) or []):
        # Tool calls: look in the nested args, then on the call itself.
        for call in getattr(message, "tool_calls", None) or ():
            if not isinstance(call, dict):
                continue
            for source in (call.get("args"), call):
                if isinstance(source, dict):
                    value = source.get("search_switch")
                    if value is not None:
                        return str(value)

        # Content: only string content that decodes to a JSON object counts.
        body = getattr(message, "content", None)
        if not isinstance(body, str):
            continue
        try:
            decoded = json.loads(body)
        except ValueError:  # includes json.JSONDecodeError
            continue
        if isinstance(decoded, dict):
            value = decoded.get("search_switch")
            if value is not None:
                return str(value)

    return None
|
||||||
|
|
||||||
|
|
||||||
|
def extract_tool_call_id(message: Any) -> str:
    """Return the tool call ID carried by ``message``.

    Sources are tried in order and the first usable one wins:

    1. a truthy ``tool_call_id`` attribute;
    2. the ``id`` key of the first entry in ``tool_calls`` (when that entry
       is a dict containing ``id``);
    3. a truthy ``id`` attribute.

    Args:
        message: Message object (typically ToolMessage or AIMessage).

    Returns:
        Tool call ID as string.

    Raises:
        ValueError: If none of the sources yields an ID.
    """
    direct_id = getattr(message, "tool_call_id", None)
    if direct_id:
        return str(direct_id)

    calls = getattr(message, "tool_calls", None)
    if calls:
        first_call = calls[0]
        if isinstance(first_call, dict) and "id" in first_call:
            return str(first_call["id"])

    fallback_id = getattr(message, "id", None)
    if fallback_id:
        return str(fallback_id)

    raise ValueError(f"Could not extract tool call ID from message: {type(message)}")
|
||||||
|
|
||||||
|
|
||||||
|
def extract_content_payload(message: Any) -> Any:
    """Return the payload of ``message``, decoding JSON text when possible.

    - Without a ``content`` attribute, ``str(message)`` is used as the text.
    - With empty content and a first tool call that is a dict with ``args``,
      those ``args`` are returned directly.
    - Non-string content is returned unchanged.
    - String content is parsed as JSON. If that fails, the span from the first
      ``[``/``{`` to the last ``]``/``}`` is tried. If nothing parses, the
      original string is returned.

    Args:
        message: Message object (typically ToolMessage).

    Returns:
        Parsed content, tool-call args, or the raw content.

    Examples:
        >>> message = ToolMessage(content='{"key": "value"}')
        >>> extract_content_payload(message)
        {'key': 'value'}

        >>> message = ToolMessage(content='plain text')
        >>> extract_content_payload(message)
        'plain text'
    """
    if not hasattr(message, "content"):
        payload = str(message)
    else:
        payload = message.content
        if not payload:
            # Empty content: an AIMessage-style tool call may carry the data.
            calls = getattr(message, "tool_calls", None)
            if calls:
                first_call = calls[0]
                if isinstance(first_call, dict) and "args" in first_call:
                    return first_call["args"]

    # Dicts, lists and any other non-string value pass through untouched.
    if not isinstance(payload, str):
        return payload

    try:
        return json.loads(payload)
    except ValueError:  # includes json.JSONDecodeError
        pass

    # Fall back to JSON embedded inside a larger string.
    import re

    for fragment in re.findall(r'[\[{].*[\]}]', payload, flags=re.DOTALL):
        try:
            return json.loads(fragment)
        except ValueError:
            continue

    return payload
|
||||||
78
app/core/memory/agent/langgraph_graph/write_graph.py
Normal file
78
app/core/memory/agent/langgraph_graph/write_graph.py
Normal file
@@ -0,0 +1,78 @@
|
|||||||
|
import asyncio
|
||||||
|
import json
|
||||||
|
from contextlib import asynccontextmanager
|
||||||
|
from langgraph.constants import START, END
|
||||||
|
from langgraph.graph import add_messages, StateGraph
|
||||||
|
|
||||||
|
from langgraph.prebuilt import ToolNode
|
||||||
|
from app.core.memory.agent.utils.llm_tools import WriteState
|
||||||
|
import warnings
|
||||||
|
import sys
|
||||||
|
from langchain_core.messages import AIMessage
|
||||||
|
from app.core.logging_config import get_agent_logger
|
||||||
|
|
||||||
|
# Install a global filter that ignores every RuntimeWarning once this module is imported.
warnings.filterwarnings(action="ignore", category=RuntimeWarning)

# Module-level logger obtained from the project's agent logging helper.
logger = get_agent_logger(__name__)

# On Windows, switch asyncio to the selector event loop policy.
# asyncio is already imported at the top of this module, so no re-import is needed.
if sys.platform.startswith("win"):
    asyncio.set_event_loop_policy(asyncio.WindowsSelectorEventLoopPolicy())
|
||||||
|
@asynccontextmanager
async def make_write_graph(user_id, tools, apply_id, group_id, config_id=None):
    """Build the memory-write LangGraph workflow and yield the compiled graph.

    The graph is wired START -> "content_input" -> "save_neo4j" -> END.
    "content_input" calls the ``Data_type_differentiation`` tool on the latest
    message and then invokes the ``Data_write`` tool directly with the result;
    "save_neo4j" is a ``ToolNode`` wrapping the same ``Data_write`` tool.

    Args:
        user_id: Forwarded unchanged to ``Data_write`` as ``user_id``.
        tools: Re-iterable collection of tool objects exposing ``name`` and
            ``ainvoke``. It must contain tools named
            ``Data_type_differentiation`` and ``Data_write``.
        apply_id: Forwarded unchanged to ``Data_write`` as ``apply_id``.
        group_id: Forwarded unchanged to ``Data_write`` as ``group_id``.
        config_id: Optional configuration identifier. When truthy it is added
            to the ``Data_write`` arguments as ``config_id``.

    Yields:
        The compiled workflow graph.

    Raises:
        ValueError: If either required tool is missing from ``tools``.
    """
    logger.info("加载 MCP 工具: %s", [t.name for t in tools])
    if config_id:
        logger.info(f"使用配置 ID: {config_id}")

    data_type_tool = next((t for t in tools if t.name == "Data_type_differentiation"), None)
    data_write_tool = next((t for t in tools if t.name == "Data_write"), None)

    if not data_type_tool or not data_write_tool:
        # No exception is being handled here, so exc_info=True would only
        # append a meaningless "NoneType: None" to the log record.
        logger.error('不存在数据存储工具')
        raise ValueError('不存在数据存储工具')

    # ToolNode wrapping the write tool; used as the "save_neo4j" graph node.
    write_node = ToolNode([data_write_tool])

    async def call_model(state):
        """Classify the latest message, write it, and return the write result as an AIMessage."""
        messages = state["messages"]
        last_message = messages[-1]

        # The latest message may be a tuple (text at index 1) or an object with ``content``.
        context = last_message[1] if isinstance(last_message, tuple) else last_message.content
        result = await data_type_tool.ainvoke({"context": context})

        # The tool output is decoded as JSON; an already-decoded mapping is
        # accepted as-is instead of failing inside json.loads.
        if not isinstance(result, dict):
            result = json.loads(result)

        # Arguments for Data_write; config_id is only included when provided.
        write_params = {
            "content": result["context"],
            "apply_id": apply_id,
            "group_id": group_id,
            "user_id": user_id,
        }
        if config_id:
            write_params["config_id"] = config_id
            logger.debug(f"传递 config_id 到 Data_write: {config_id}")

        write_result = await data_write_tool.ainvoke(write_params)

        # Prefer the "data" entry of a dict result; otherwise use the string form.
        if isinstance(write_result, dict):
            content = write_result.get("data", str(write_result))
        else:
            content = str(write_result)
        logger.info("写入内容: %s", content)
        return {"messages": [AIMessage(content=content)]}

    workflow = StateGraph(WriteState)
    workflow.add_node("content_input", call_model)
    workflow.add_node("save_neo4j", write_node)
    workflow.add_edge(START, "content_input")
    workflow.add_edge("content_input", "save_neo4j")
    workflow.add_edge("save_neo4j", END)

    graph = workflow.compile()

    yield graph
|
||||||
285
app/core/memory/agent/logger_file/log_streamer.py
Normal file
285
app/core/memory/agent/logger_file/log_streamer.py
Normal file
@@ -0,0 +1,285 @@
|
|||||||
|
"""
|
||||||
|
Log Streamer Module
|
||||||
|
|
||||||
|
Manages streaming of log file content with file watching and real-time transmission.
|
||||||
|
"""
|
||||||
|
import os
|
||||||
|
import re
|
||||||
|
import time
|
||||||
|
import asyncio
|
||||||
|
from typing import AsyncGenerator, Optional
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
from app.core.logging_config import get_logger
|
||||||
|
|
||||||
|
logger = get_logger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
class LogStreamer:
    """Stream the contents of a log file as event dictionaries.

    Two public async generators are offered: ``read_existing_and_stream``
    replays what is already in the file and then follows it, while
    ``watch_and_stream`` starts at the end of the file and only follows new
    lines. File access uses the ordinary synchronous ``open``/``readline``
    calls; only the idle wait between polls is awaited.
    """

    # Prefix removed from each line: "YYYY-MM-DD HH:MM:SS,mmm - [LEVEL] - module_name - ".
    # Compiled once for the class instead of on every cleaning call.
    _PREFIX_PATTERN = re.compile(
        r'^\d{4}-\d{2}-\d{2} \d{2}:\d{2}:\d{2},\d{3} - \[(?:INFO|DEBUG|WARNING|ERROR|CRITICAL)\] - \S+ - '
    )

    # Seconds to wait before polling the file again when no new line is available.
    _POLL_INTERVAL = 0.1

    def __init__(self, log_path: str, keepalive_interval: int = 300):
        """Initialize the streamer.

        Args:
            log_path: Path to the log file to stream.
            keepalive_interval: Seconds without any log event after which a
                keepalive event is emitted (default: 300).
        """
        self.log_path = log_path
        self.keepalive_interval = keepalive_interval
        # Text-stream position recorded when the follow phase begins.
        self.last_position = 0
        # Kept as an instance attribute for callers that read ``pattern``.
        self.pattern = self._PREFIX_PATTERN

        logger.info(f"LogStreamer initialized for {log_path}")

    @staticmethod
    def clean_log_line(line: str) -> str:
        """Strip the timestamp / level / module prefix from a raw log line.

        Args:
            line: Raw log line.

        Returns:
            The line without its leading prefix; lines that do not start with
            the expected prefix are returned unchanged.
        """
        return LogStreamer._PREFIX_PATTERN.sub('', line)

    def clean_log_entry(self, line: str) -> str:
        """Instance-level wrapper that delegates to ``clean_log_line``.

        Args:
            line: Raw log line.

        Returns:
            The line without its leading prefix.
        """
        return LogStreamer.clean_log_line(line)

    async def send_keepalive(self) -> dict:
        """Build a keepalive event.

        Returns:
            ``{"event": "keepalive", "data": {"timestamp": <int epoch seconds>}}``.
        """
        return {
            "event": "keepalive",
            "data": {
                "timestamp": int(time.time())
            }
        }

    def _log_event(self, line: str) -> dict:
        """Build a log event from a raw line (prefix and trailing newline removed)."""
        return {
            "event": "log",
            "data": {
                "content": self.clean_log_entry(line).rstrip('\n'),
                "timestamp": int(time.time())
            }
        }

    @staticmethod
    def _error_event(code: int, message: str, error: str) -> dict:
        """Build an error event with the given code, user message and detail."""
        return {
            "event": "error",
            "data": {
                "code": code,
                "message": message,
                "error": error
            }
        }

    async def _stream(self, *, replay_existing: bool) -> AsyncGenerator[dict, None]:
        """Shared implementation behind both public streaming generators.

        Args:
            replay_existing: When True, emit every non-blank line already in
                the file before following it; when False, seek to the end and
                emit only lines appended afterwards.

        Yields:
            Log, keepalive, error and done event dictionaries.
        """
        if not os.path.exists(self.log_path):
            logger.error(f"Log file not found: {self.log_path}")
            yield self._error_event(4006, "日志文件不存在", f"File not found: {self.log_path}")
            return

        try:
            with open(self.log_path, 'r', encoding='utf-8') as f:
                if replay_existing:
                    # Replay phase: blank lines are skipped here (and only here).
                    for line in f:
                        if line.strip():
                            yield self._log_event(line)
                else:
                    f.seek(0, os.SEEK_END)

                self.last_position = f.tell()
                last_keepalive = time.time()

                # Follow phase: poll for appended lines until the consumer
                # stops iterating or an error occurs.
                while True:
                    line = f.readline()
                    if line:
                        yield self._log_event(line)
                        # A log event counts as activity, so restart the keepalive timer.
                        last_keepalive = time.time()
                    else:
                        current_time = time.time()
                        if current_time - last_keepalive >= self.keepalive_interval:
                            yield await self.send_keepalive()
                            last_keepalive = current_time

                        await asyncio.sleep(self._POLL_INTERVAL)

        except FileNotFoundError:
            logger.error(f"Log file disappeared during streaming: {self.log_path}")
            yield self._error_event(
                4006, "日志文件在流式传输期间变得不可用", "File not found during streaming"
            )
        except Exception as e:
            logger.error(f"Error during log streaming: {e}", exc_info=True)
            yield self._error_event(8001, "流式传输期间发生错误", str(e))
        finally:
            # Only log here. Yielding inside ``finally`` would raise
            # "async generator ignored GeneratorExit" when the consumer closes
            # the stream, and would emit an event while being cancelled.
            logger.info(f"Log stream ended for {self.log_path}")

        # Reached only after a handled error; closing or cancelling the
        # generator propagates past this point without emitting anything.
        yield {
            "event": "done",
            "data": {
                "message": "流式传输完成"
            }
        }

    async def read_existing_and_stream(self) -> AsyncGenerator[dict, None]:
        """Replay the existing file content, then follow new content.

        Yields:
            Dict messages with event type and data:
            - log events: {"event": "log", "data": {"content": "...", "timestamp": ...}}
            - keepalive events: {"event": "keepalive", "data": {"timestamp": ...}}
            - error events: {"event": "error", "data": {"code": ..., "message": "...", "error": "..."}}
            - done events: {"event": "done", "data": {"message": "..."}}
        """
        logger.info(f"Starting log stream (read existing) for {self.log_path}")

        stream = self._stream(replay_existing=True)
        try:
            async for message in stream:
                yield message
        finally:
            # Close the inner generator promptly so the file handle is released
            # as soon as the consumer stops iterating.
            await stream.aclose()

    async def watch_and_stream(self) -> AsyncGenerator[dict, None]:
        """Follow the file from its current end, emitting only new content.

        Yields:
            Dict messages with event type and data:
            - log events: {"event": "log", "data": {"content": "...", "timestamp": ...}}
            - keepalive events: {"event": "keepalive", "data": {"timestamp": ...}}
            - error events: {"event": "error", "data": {"code": ..., "message": "...", "error": "..."}}
            - done events: {"event": "done", "data": {"message": "..."}}
        """
        logger.info(f"Starting log stream (new content only) for {self.log_path}")

        stream = self._stream(replay_existing=False)
        try:
            async for message in stream:
                yield message
        finally:
            # Close the inner generator promptly so the file handle is released
            # as soon as the consumer stops iterating.
            await stream.aclose()
|
||||||
32
app/core/memory/agent/logger_file/logger_data.py
Normal file
32
app/core/memory/agent/logger_file/logger_data.py
Normal file
@@ -0,0 +1,32 @@
|
|||||||
|
"""
|
||||||
|
Agent logger module for backward compatibility.
|
||||||
|
|
||||||
|
This module maintains the get_named_logger() function for backward compatibility
|
||||||
|
while delegating to the centralized logging configuration.
|
||||||
|
|
||||||
|
All new code should import directly from app.core.logging_config instead.
|
||||||
|
"""
|
||||||
|
|
||||||
|
__version__ = "0.1.0"
|
||||||
|
__author__ = "RED_BEAR"
|
||||||
|
|
||||||
|
from app.core.logging_config import get_agent_logger
|
||||||
|
|
||||||
|
|
||||||
|
def get_named_logger(name):
    """Return the agent logger for ``name``.

    Backward-compatibility shim: the call is forwarded unchanged to
    ``get_agent_logger`` from ``app.core.logging_config``.

    Args:
        name: Logger name passed through to ``get_agent_logger``.

    Returns:
        Whatever ``get_agent_logger(name)`` returns.

    Example:
        >>> logger = get_named_logger("my_agent")
        >>> logger.info("Agent operation started")
    """
    agent_logger = get_agent_logger(name)
    return agent_logger
|
||||||
28
app/core/memory/agent/mcp_server/__init__.py
Normal file
28
app/core/memory/agent/mcp_server/__init__.py
Normal file
@@ -0,0 +1,28 @@
|
|||||||
|
"""
|
||||||
|
MCP Server package for memory agent.
|
||||||
|
|
||||||
|
This package provides the FastMCP server implementation with context-based
|
||||||
|
dependency injection for tool functions.
|
||||||
|
|
||||||
|
Package structure:
|
||||||
|
- server: FastMCP server initialization and context setup
|
||||||
|
- tools: MCP tool implementations
|
||||||
|
- models: Pydantic response models
|
||||||
|
- services: Business logic services
|
||||||
|
"""
|
||||||
|
from app.core.memory.agent.mcp_server.server import (
|
||||||
|
mcp,
|
||||||
|
initialize_context,
|
||||||
|
main,
|
||||||
|
get_context_resource
|
||||||
|
)
|
||||||
|
|
||||||
|
# Import tools to register them (but don't export them)
|
||||||
|
from app.core.memory.agent.mcp_server import tools
|
||||||
|
|
||||||
|
__all__ = [
|
||||||
|
'mcp',
|
||||||
|
'initialize_context',
|
||||||
|
'main',
|
||||||
|
'get_context_resource',
|
||||||
|
]
|
||||||
11
app/core/memory/agent/mcp_server/mcp_instance.py
Normal file
11
app/core/memory/agent/mcp_server/mcp_instance.py
Normal file
@@ -0,0 +1,11 @@
|
|||||||
|
"""
|
||||||
|
MCP Server Instance
|
||||||
|
|
||||||
|
This module contains the FastMCP server instance that is shared across all modules.
|
||||||
|
It's in a separate file to avoid circular import issues.
|
||||||
|
"""
|
||||||
|
from mcp.server.fastmcp import FastMCP
|
||||||
|
|
||||||
|
# Initialize FastMCP server instance
|
||||||
|
# This instance is shared across all tool modules
|
||||||
|
mcp = FastMCP('data_flow')
|
||||||
30
app/core/memory/agent/mcp_server/models/__init__.py
Normal file
30
app/core/memory/agent/mcp_server/models/__init__.py
Normal file
@@ -0,0 +1,30 @@
|
|||||||
|
"""Pydantic models for MCP server responses."""
|
||||||
|
|
||||||
|
from .problem_models import (
|
||||||
|
ProblemBreakdownItem,
|
||||||
|
ProblemBreakdownResponse,
|
||||||
|
ExtendedQuestionItem,
|
||||||
|
ProblemExtensionResponse,
|
||||||
|
)
|
||||||
|
from .summary_models import (
|
||||||
|
SummaryData,
|
||||||
|
SummaryResponse,
|
||||||
|
RetrieveSummaryData,
|
||||||
|
RetrieveSummaryResponse,
|
||||||
|
)
|
||||||
|
from .verification_models import VerificationResult
|
||||||
|
from .retrieval_models import RetrievalResult, DistinguishTypeResponse
|
||||||
|
|
||||||
|
__all__ = [
|
||||||
|
"ProblemBreakdownItem",
|
||||||
|
"ProblemBreakdownResponse",
|
||||||
|
"ExtendedQuestionItem",
|
||||||
|
"ProblemExtensionResponse",
|
||||||
|
"SummaryData",
|
||||||
|
"SummaryResponse",
|
||||||
|
"RetrieveSummaryData",
|
||||||
|
"RetrieveSummaryResponse",
|
||||||
|
"VerificationResult",
|
||||||
|
"RetrievalResult",
|
||||||
|
"DistinguishTypeResponse",
|
||||||
|
]
|
||||||
34
app/core/memory/agent/mcp_server/models/problem_models.py
Normal file
34
app/core/memory/agent/mcp_server/models/problem_models.py
Normal file
@@ -0,0 +1,34 @@
|
|||||||
|
"""Pydantic models for problem breakdown and extension operations."""
|
||||||
|
|
||||||
|
from typing import List, Optional
|
||||||
|
from pydantic import BaseModel, Field, RootModel
|
||||||
|
|
||||||
|
|
||||||
|
class ProblemBreakdownItem(BaseModel):
    """One entry of a problem breakdown response."""

    # Required string fields.
    id: str
    question: str
    type: str
    # Optional; omitted values default to None.
    reason: Optional[str] = None
|
||||||
|
|
||||||
|
|
||||||
|
class ProblemBreakdownResponse(RootModel[List[ProblemBreakdownItem]]):
    """Root model whose value is a list of ``ProblemBreakdownItem`` entries."""
|
||||||
|
|
||||||
|
|
||||||
|
class ExtendedQuestionItem(BaseModel):
    """One extended question together with the reasoning behind it.

    All four fields are required; their descriptions are kept in the
    original language because they are part of the generated schema.
    """

    # Description: the original preliminary question.
    original_question: str = Field(
        ...,
        description="原始初步问题",
    )
    # Description: the question after extension.
    extended_question: str = Field(
        ...,
        description="扩展后的问题",
    )
    # Description: type (fact retrieval / clarification / definition / comparison / action suggestion, etc.).
    type: str = Field(
        ...,
        description="类型(事实检索 / 澄清 / 定义 / 比较 / 行动建议等)",
    )
    # Description: the reason this extended question was generated.
    reason: str = Field(
        ...,
        description="生成该扩展问题的理由",
    )
|
||||||
|
|
||||||
|
|
||||||
|
class ProblemExtensionResponse(RootModel[List[ExtendedQuestionItem]]):
    """Root model whose value is a list of ``ExtendedQuestionItem`` entries."""
|
||||||
17
app/core/memory/agent/mcp_server/models/retrieval_models.py
Normal file
17
app/core/memory/agent/mcp_server/models/retrieval_models.py
Normal file
@@ -0,0 +1,17 @@
|
|||||||
|
"""Pydantic models for retrieval operations."""
|
||||||
|
|
||||||
|
from typing import List, Dict, Any
|
||||||
|
from pydantic import BaseModel
|
||||||
|
|
||||||
|
|
||||||
|
class RetrievalResult(BaseModel):
    """Result of a retrieval operation.

    Field names are capitalised in the original model and are part of its
    serialized shape, so they are kept exactly as declared.
    """

    # Required string field.
    Query: str
    # Required list of free-form dictionaries.
    Expansion_issue: List[Dict[str, Any]]
|
||||||
|
|
||||||
|
|
||||||
|
class DistinguishTypeResponse(BaseModel):
    """Response of the data type differentiation step."""

    # Required string field.
    type: str
|
||||||
31
app/core/memory/agent/mcp_server/models/summary_models.py
Normal file
31
app/core/memory/agent/mcp_server/models/summary_models.py
Normal file
@@ -0,0 +1,31 @@
|
|||||||
|
"""Pydantic models for summary operations."""
|
||||||
|
|
||||||
|
from typing import List
|
||||||
|
from pydantic import BaseModel, Field
|
||||||
|
|
||||||
|
|
||||||
|
class SummaryData(BaseModel):
    """Input data for a summary operation."""

    # Required string field.
    query: str
    # Both lists default to a fresh empty list per instance.
    history: List[str] = Field(default_factory=list)
    retrieve_info: List[str] = Field(default_factory=list)
|
||||||
|
|
||||||
|
|
||||||
|
class SummaryResponse(BaseModel):
    """Response of a summary operation: the input data plus the answer text."""

    # Nested ``SummaryData`` payload (required).
    data: SummaryData
    # Required string field.
    query_answer: str
|
||||||
|
|
||||||
|
|
||||||
|
class RetrieveSummaryData(BaseModel):
    """Payload of a retrieve-summary response."""

    # Defaults to the empty string when not supplied.
    query_answer: str = ""
|
||||||
|
|
||||||
|
|
||||||
|
class RetrieveSummaryResponse(BaseModel):
    """Response of a retrieve-summary operation."""

    # Nested ``RetrieveSummaryData`` payload (required).
    data: RetrieveSummaryData
|
||||||
@@ -0,0 +1,14 @@
|
|||||||
|
"""Pydantic models for verification operations."""
|
||||||
|
|
||||||
|
from typing import List, Optional, Dict, Any
|
||||||
|
from pydantic import BaseModel, Field
|
||||||
|
|
||||||
|
|
||||||
|
class VerificationResult(BaseModel):
    """Result of a verification operation."""

    # Required fields.
    query: str
    expansion_issue: List[Dict[str, Any]]
    split_result: str
    # Optional; omitted values default to None.
    reason: Optional[str] = None
    # Defaults to a fresh empty list per instance.
    history: List[Dict[str, Any]] = Field(default_factory=list)
|
||||||
161
app/core/memory/agent/mcp_server/server.py
Normal file
161
app/core/memory/agent/mcp_server/server.py
Normal file
@@ -0,0 +1,161 @@
|
|||||||
|
"""
|
||||||
|
MCP Server initialization with FastMCP context setup.
|
||||||
|
|
||||||
|
This module initializes the FastMCP server and registers shared resources
|
||||||
|
in the context for dependency injection into tool functions.
|
||||||
|
"""
|
||||||
|
import os
|
||||||
|
import sys
|
||||||
|
from mcp.server.fastmcp import FastMCP
|
||||||
|
|
||||||
|
from app.core.config import settings
|
||||||
|
from app.core.logging_config import get_agent_logger
|
||||||
|
from app.core.memory.agent.utils.redis_tool import RedisSessionStore, store
|
||||||
|
from app.core.memory.agent.utils.llm_tools import PROJECT_ROOT_
|
||||||
|
from app.core.memory.utils.config.definitions import SELECTED_LLM_ID,reload_configuration_from_database
|
||||||
|
from app.core.memory.utils.llm.llm_utils import get_llm_client
|
||||||
|
from app.core.memory.agent.mcp_server.services.template_service import TemplateService
|
||||||
|
from app.core.memory.agent.mcp_server.services.search_service import SearchService
|
||||||
|
from app.core.memory.agent.mcp_server.services.session_service import SessionService
|
||||||
|
from app.core.memory.agent.mcp_server.mcp_instance import mcp
|
||||||
|
|
||||||
|
|
||||||
|
logger = get_agent_logger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
def get_context_resource(ctx, resource_name: str):
    """Look up a named resource stored on the context's ``fastmcp`` object.

    Args:
        ctx: Context object passed to tool functions; its ``fastmcp``
            attribute holds the registered resources.
        resource_name: Attribute name of the resource to fetch.

    Returns:
        The attribute ``resource_name`` of ``ctx.fastmcp``.

    Raises:
        RuntimeError: If ``ctx`` has no ``fastmcp`` attribute or it is None.
        AttributeError: If ``ctx.fastmcp`` has no attribute ``resource_name``;
            the message lists the public attribute names that do exist.

    Example:
        @mcp.tool()
        async def my_tool(ctx: Context):
            template_service = get_context_resource(ctx, 'template_service')
            llm_client = get_context_resource(ctx, 'llm_client')
    """
    server = getattr(ctx, 'fastmcp', None)
    if server is None:
        raise RuntimeError("Context does not have fastmcp attribute")

    # A private sentinel separates "attribute missing" from "attribute is None".
    missing = object()
    resource = getattr(server, resource_name, missing)
    if resource is missing:
        available = [k for k in dir(server) if not k.startswith('_')]
        raise AttributeError(
            f"Resource '{resource_name}' not found in context. "
            f"Available resources: {available}"
        )

    return resource
|
||||||
|
|
||||||
|
|
||||||
|
def initialize_context():
    """Attach the shared resources to the module-level ``mcp`` instance.

    Each resource is stored as an attribute on ``mcp`` so that tool functions
    can reach it through ``ctx.fastmcp``.

    Attributes set, in order:
        - session_store: the imported ``store`` object
        - llm_client: result of ``get_llm_client(SELECTED_LLM_ID)``, or None
          when that call raises (the failure is logged, not propagated)
        - app_settings: the application ``settings`` object
        - template_service: ``TemplateService`` rooted at
          ``PROJECT_ROOT_ + '/agent/utils/prompt'``
        - search_service: a new ``SearchService``
        - session_service: ``SessionService`` built on ``store``

    Raises:
        Exception: Any error outside the LLM client step is logged and re-raised.
    """
    try:
        logger.info("Registering session_store in context")
        mcp.session_store = store

        # LLM client: on failure, register None so the attribute still exists
        # when tools look it up.
        try:
            logger.info(f"Registering llm_client in context with model ID: {SELECTED_LLM_ID}")
            mcp.llm_client = get_llm_client(SELECTED_LLM_ID)
            logger.info("llm_client registered successfully")
        except Exception as llm_error:
            logger.error(f"Failed to register llm_client: {llm_error}", exc_info=True)
            mcp.llm_client = None
            logger.warning("llm_client set to None due to initialization failure")

        # Stored as app_settings to avoid clashing with FastMCP's own settings.
        logger.info("Registering app_settings in context")
        mcp.app_settings = settings

        # Service objects, created in the same order as before.
        prompt_root = PROJECT_ROOT_ + '/agent/utils/prompt'
        mcp.template_service = TemplateService(prompt_root)
        mcp.search_service = SearchService()
        mcp.session_service = SessionService(store)

    except Exception as init_error:
        logger.error(f"Failed to initialize context: {init_error}", exc_info=True)
        raise
|
||||||
|
|
||||||
|
|
||||||
|
def main():
    """Entry point: load configuration, set up the context and serve over SSE.

    Steps, in order: reload configuration from the database using the
    ``config_id`` environment variable, initialize the context resources,
    import the tool modules, list the tools once, then run the SSE app with
    uvicorn on ``settings.SERVER_IP`` port 8081. Any ``Exception`` along the
    way is logged and the process exits with status 1.
    """
    try:
        config_id = os.getenv("config_id")
        reload_configuration_from_database(config_id=config_id, force_reload=True)

        initialize_context()

        # Imported for registration; the names themselves are not used below.
        from app.core.memory.agent.mcp_server.tools import (
            problem_tools,
            retrieval_tools,
            verification_tools,
            summary_tools,
            data_tools,
        )

        # List the tools once before serving; the result is not used further.
        import asyncio
        asyncio.run(mcp.list_tools())

        # Serve the SSE application over HTTP.
        import uvicorn
        sse_app = mcp.sse_app()
        uvicorn.run(sse_app, host=settings.SERVER_IP, port=8081, log_level="info")

    except Exception as startup_error:
        logger.error(f"Failed to start MCP server: {startup_error}", exc_info=True)
        sys.exit(1)
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
main()
|
||||||
23
app/core/memory/agent/mcp_server/services/__init__.py
Normal file
23
app/core/memory/agent/mcp_server/services/__init__.py
Normal file
@@ -0,0 +1,23 @@
|
|||||||
|
"""
|
||||||
|
MCP Server Services
|
||||||
|
|
||||||
|
This module provides business logic services for the MCP server:
|
||||||
|
- TemplateService: Template loading and rendering
|
||||||
|
- SearchService: Search result processing
|
||||||
|
- SessionService: Session and history management
|
||||||
|
- ParameterBuilder: Tool parameter construction
|
||||||
|
"""
|
||||||
|
|
||||||
|
from .template_service import TemplateService, TemplateRenderError
|
||||||
|
from .search_service import SearchService
|
||||||
|
from .session_service import SessionService
|
||||||
|
from .parameter_builder import ParameterBuilder
|
||||||
|
|
||||||
|
|
||||||
|
__all__ = [
|
||||||
|
"TemplateService",
|
||||||
|
"TemplateRenderError",
|
||||||
|
"SearchService",
|
||||||
|
"SessionService",
|
||||||
|
"ParameterBuilder",
|
||||||
|
]
|
||||||
157
app/core/memory/agent/mcp_server/services/parameter_builder.py
Normal file
157
app/core/memory/agent/mcp_server/services/parameter_builder.py
Normal file
@@ -0,0 +1,157 @@
|
|||||||
|
"""
|
||||||
|
Parameter Builder for constructing tool call arguments.
|
||||||
|
|
||||||
|
This service provides tool-specific parameter transformation logic
|
||||||
|
to build correct arguments for each tool type.
|
||||||
|
"""
|
||||||
|
import json
|
||||||
|
from typing import Any, Dict, Optional
|
||||||
|
|
||||||
|
from app.core.logging_config import get_agent_logger
|
||||||
|
|
||||||
|
|
||||||
|
logger = get_agent_logger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
class ParameterBuilder:
    """Translate a previous tool result into the keyword arguments of the next tool call."""

    def __init__(self):
        """Create the builder; it holds no state and only logs its creation."""
        logger.info("ParameterBuilder initialized")

    def build_tool_args(
        self,
        tool_name: str,
        content: Any,
        tool_call_id: str,
        search_switch: str,
        apply_id: str,
        group_id: str,
        storage_type: Optional[str] = None,
        user_rag_memory_id: Optional[str] = None
    ) -> Dict[str, Any]:
        """
        Assemble the argument dictionary for ``tool_name``.

        The shape of the ``context`` entry depends on the tool:
        - Verify: ``content`` if it is a dict, otherwise ``{}``
        - Retrieve: same as Verify, plus ``search_switch``
        - Summary / Summary_fails: a JSON string (strings pass through unchanged)
        - Retrieve_Summary: the innermost context after unwrapping
          ``content`` / ``context`` wrappers
        - Input_Summary: a message string, plus ``search_switch``
        - anything else: ``content`` unchanged (a warning is logged)

        Args:
            tool_name: Name of the tool being invoked
            content: Parsed content from the previous tool result
            tool_call_id: Tool call identifier, sent as ``usermessages``
            search_switch: Search routing parameter
            apply_id: Application identifier
            group_id: Group identifier
            storage_type: Storage type; ``None`` is sent as ``""``
            user_rag_memory_id: User RAG memory ID; ``None`` is sent as ``""``

        Returns:
            Dictionary of tool arguments ready for invocation
        """
        # Arguments every tool receives; the optional ones are always present,
        # defaulting to an empty string.
        shared = {
            "usermessages": tool_call_id,
            "apply_id": apply_id,
            "group_id": group_id,
            "storage_type": "" if storage_type is None else storage_type,
            "user_rag_memory_id": "" if user_rag_memory_id is None else user_rag_memory_id,
        }

        if tool_name == "Verify":
            dict_context = content if isinstance(content, dict) else {}
            return {"context": dict_context, **shared}

        if tool_name == "Retrieve":
            dict_context = content if isinstance(content, dict) else {}
            return {
                "context": dict_context,
                "search_switch": search_switch,
                **shared
            }

        if tool_name in ("Summary", "Summary_fails"):
            if isinstance(content, str):
                serialized = content
            else:
                # Non-dict, non-string values are wrapped under a "data" key.
                payload = content if isinstance(content, dict) else {"data": content}
                serialized = json.dumps(payload, ensure_ascii=False)
            return {"context": serialized, **shared}

        if tool_name == "Retrieve_Summary":
            return {
                "context": self._unwrap_retrieve_summary_context(tool_name, content),
                **shared
            }

        if tool_name == "Input_Summary":
            if isinstance(content, dict):
                # Prefer the "sentence" entry; fall back to the dict's repr.
                message = content.get("sentence", str(content))
            else:
                message = str(content)
            return {
                "context": message,
                "search_switch": search_switch,
                **shared
            }

        logger.warning(
            f"Unknown tool name '{tool_name}', using default argument structure"
        )
        return {"context": content, **shared}

    def _unwrap_retrieve_summary_context(self, tool_name: str, content: Any) -> Any:
        """Peel ``content`` / ``context`` wrappers off a Retrieve_Summary payload."""
        if not isinstance(content, dict):
            return content

        if "content" in content:
            inner = content["content"]
            if isinstance(inner, dict):
                return inner
            if not isinstance(inner, str):
                # Neither a JSON string nor a dict: keep the wrapper as it is.
                return content
            try:
                parsed = json.loads(inner)
            except json.JSONDecodeError:
                logger.warning(
                    f"Failed to parse JSON content for {tool_name}: {inner[:100]}"
                )
                return {"Query": "", "Expansion_issue": []}
            if isinstance(parsed, dict) and "context" in parsed:
                return parsed["context"]
            return parsed

        if "context" in content:
            nested = content["context"]
            return nested if isinstance(nested, dict) else content

        return content
|
||||||
193
app/core/memory/agent/mcp_server/services/search_service.py
Normal file
193
app/core/memory/agent/mcp_server/services/search_service.py
Normal file
@@ -0,0 +1,193 @@
|
|||||||
|
"""
|
||||||
|
Search Service for executing hybrid search and processing results.
|
||||||
|
|
||||||
|
This service provides clean search result processing with content extraction
|
||||||
|
and deduplication.
|
||||||
|
"""
|
||||||
|
from typing import List, Tuple, Optional
|
||||||
|
|
||||||
|
from app.core.logging_config import get_agent_logger
|
||||||
|
from app.core.memory.src.search import run_hybrid_search
|
||||||
|
from app.core.memory.utils.data.text_utils import escape_lucene_query
|
||||||
|
|
||||||
|
|
||||||
|
logger = get_agent_logger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
class SearchService:
    """Service for executing hybrid search and flattening the results to plain text."""

    # Order in which result categories are emitted: summaries first because
    # they carry the most synthesized context, then statements, chunks, entities.
    _PRIORITY_ORDER = ('summaries', 'statements', 'chunks', 'entities')

    # Fields of a single result that carry user-facing text, in output order.
    # Entity fields ('name', 'fact_summary') are intentionally not extracted.
    _CONTENT_FIELDS = ('statement', 'content')

    def __init__(self):
        """Initialize the search service."""
        logger.info("SearchService initialized")

    def extract_content_from_result(self, result: dict) -> str:
        """
        Extract only the textual content of a search result, dropping metadata.

        The non-empty 'statement' and 'content' fields are joined with a
        newline, in that order. Field values are coerced to ``str`` so that an
        unexpected non-string value cannot make the join fail.

        Args:
            result: Search result dictionary (any other value is stringified)

        Returns:
            Clean content string, or "" when no content field is populated
        """
        if not isinstance(result, dict):
            return str(result)

        parts = [
            str(result[field])
            for field in self._CONTENT_FIELDS
            if result.get(field)
        ]
        return '\n'.join(parts)

    def clean_query(self, query: str) -> str:
        """
        Clean and escape query text for Lucene.

        - Removes one pair of wrapping quotes (single or double)
        - Replaces newlines and carriage returns with spaces
        - Applies Lucene escaping

        Args:
            query: Raw query string

        Returns:
            Cleaned and escaped query string
        """
        q = str(query).strip()

        # Remove wrapping quotes
        if (q.startswith("'") and q.endswith("'")) or (
            q.startswith('"') and q.endswith('"')
        ):
            q = q[1:-1]

        # Remove newlines and carriage returns
        q = q.replace('\r', ' ').replace('\n', ' ').strip()

        return escape_lucene_query(q)

    async def execute_hybrid_search(
        self,
        group_id: str,
        question: str,
        limit: int = 5,
        search_type: str = "hybrid",
        include: Optional[List[str]] = None,
        rerank_alpha: float = 0.4,
        output_path: str = "search_results.json",
        return_raw_results: bool = False
    ) -> Tuple[str, str, Optional[dict]]:
        """
        Execute a search and return its content as newline-joined text.

        Args:
            group_id: Group identifier for filtering results
            question: Search query text
            limit: Maximum number of results to return (default: 5)
            search_type: "hybrid", "keyword", or "embedding" (default: "hybrid")
            include: Result categories to include
                (default: ["statements", "chunks", "entities", "summaries"])
            rerank_alpha: Reranking weight forwarded to the search (default: 0.4)
            output_path: Path forwarded to the search for saving results
            return_raw_results: If True, return the raw search result as the
                third tuple element

        Returns:
            Tuple of (clean_content, cleaned_query, raw_results).
            raw_results is None when return_raw_results is False. On failure
            clean_content is "" and raw_results is {} (or None).
        """
        if include is None:
            include = ["statements", "chunks", "entities", "summaries"]

        cleaned_query = self.clean_query(question)

        try:
            answer = await run_hybrid_search(
                query_text=cleaned_query,
                search_type=search_type,
                group_id=group_id,
                limit=limit,
                include=include,
                output_path=output_path,
                rerank_alpha=rerank_alpha
            )

            # Hybrid search nests its categories under 'reranked_results';
            # keyword/embedding search exposes them directly on the answer.
            if search_type == "hybrid":
                results_by_category = answer.get('reranked_results', {})
            else:
                results_by_category = answer

            answer_list = []
            for category in self._PRIORITY_ORDER:
                if category in include and category in results_by_category:
                    category_results = results_by_category[category]
                    if isinstance(category_results, list):
                        answer_list.extend(category_results)

            # Extract clean content and drop results that produced no text
            extracted = (self.extract_content_from_result(ans) for ans in answer_list)
            clean_content = '\n'.join(text for text in extracted if text)

            # Log first 200 chars
            logger.info(f"检索接口搜索结果==>>:{clean_content[:200]}...")

            return clean_content, cleaned_query, (answer if return_raw_results else None)

        except Exception as e:
            # Best-effort: a failed search yields empty content instead of
            # propagating, so the calling workflow can continue.
            logger.error(
                f"Search failed for query '{question}' in group '{group_id}': {e}",
                exc_info=True
            )
            return "", cleaned_query, ({} if return_raw_results else None)
|
||||||
169
app/core/memory/agent/mcp_server/services/session_service.py
Normal file
169
app/core/memory/agent/mcp_server/services/session_service.py
Normal file
@@ -0,0 +1,169 @@
|
|||||||
|
"""
|
||||||
|
Session Service for managing user sessions and conversation history.
|
||||||
|
|
||||||
|
This service provides clean Redis interactions with error handling and
|
||||||
|
session management utilities.
|
||||||
|
"""
|
||||||
|
from typing import List, Optional
|
||||||
|
|
||||||
|
from app.core.logging_config import get_agent_logger
|
||||||
|
from app.core.memory.agent.utils.redis_tool import RedisSessionStore
|
||||||
|
|
||||||
|
|
||||||
|
logger = get_agent_logger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
class SessionService:
    """Service for managing user sessions and conversation history."""

    def __init__(self, store: "RedisSessionStore"):
        """
        Initialize the session service.

        Args:
            store: Redis session store instance used for all reads and writes
        """
        self.store = store
        logger.info("SessionService initialized")

    def resolve_user_id(self, session_string: str) -> str:
        """
        Extract user ID from session string.

        Everything after the FIRST '_id_' marker is the user ID:
        - 'call_id_user123' -> 'user123'
        - 'prefix_id_user456_suffix' -> 'user456_suffix'
        - 'call_id_user_id_5' -> 'user_id_5'

        Args:
            session_string: Session identifier string

        Returns:
            Extracted user ID, or the input unchanged when it contains no
            '_id_' marker or cannot be parsed
        """
        try:
            # maxsplit=1 keeps any later '_id_' occurrences inside the user ID
            # instead of truncating at them.
            parts = session_string.split('_id_', 1)
            if len(parts) > 1:
                return parts[1]

            # Fallback: return original string
            return session_string

        except Exception as e:
            # Reached for non-string input (e.g. None has no .split).
            logger.warning(
                f"Failed to parse user ID from session string '{session_string}': {e}"
            )
            return session_string

    async def get_history(
        self,
        user_id: str,
        apply_id: str,
        group_id: str
    ) -> List[dict]:
        """
        Retrieve conversation history from the session store.

        Args:
            user_id: User identifier
            apply_id: Application identifier
            group_id: Group identifier

        Returns:
            The list returned by the store; an empty list if the store returns
            a non-list value or raises
        """
        try:
            history = self.store.find_user_apply_group(user_id, apply_id, group_id)

            # Validate history structure
            if not isinstance(history, list):
                logger.warning(
                    f"Invalid history format for user {user_id}, "
                    f"apply {apply_id}, group {group_id}: expected list, got {type(history)}"
                )
                return []

            return history

        except Exception as e:
            logger.error(
                f"Failed to retrieve history for user {user_id}, "
                f"apply {apply_id}, group {group_id}: {e}",
                exc_info=True
            )
            # Return empty list on error to allow execution to continue
            return []

    async def save_session(
        self,
        user_id: str,
        query: str,
        apply_id: str,
        group_id: str,
        ai_response: str
    ) -> Optional[str]:
        """
        Save one conversation turn to the session store.

        Args:
            user_id: User identifier
            query: User query/message
            apply_id: Application identifier
            group_id: Group identifier
            ai_response: AI response/answer

        Returns:
            Session ID returned by the store, or None when user_id or query is
            empty or the store raises
        """
        try:
            # Validate required fields
            if not user_id:
                logger.warning("Cannot save session: user_id is empty")
                return None

            if not query:
                logger.warning("Cannot save session: query is empty")
                return None

            session_id = self.store.save_session(
                userid=user_id,
                messages=query,
                apply_id=apply_id,
                group_id=group_id,
                aimessages=ai_response
            )

            logger.info(f"Session saved successfully: {session_id}")
            return session_id

        except Exception as e:
            logger.error(
                f"Failed to save session for user {user_id}: {e}",
                exc_info=True
            )
            return None

    async def cleanup_duplicates(self) -> int:
        """
        Remove duplicate session entries via the store.

        Returns:
            Number of duplicate sessions deleted as reported by the store,
            or 0 when the store raises
        """
        try:
            deleted_count = self.store.delete_duplicate_sessions()
            logger.info(f"Cleaned up {deleted_count} duplicate sessions")
            return deleted_count

        except Exception as e:
            logger.error(f"Failed to cleanup duplicate sessions: {e}", exc_info=True)
            return 0
|
||||||
116
app/core/memory/agent/mcp_server/services/template_service.py
Normal file
116
app/core/memory/agent/mcp_server/services/template_service.py
Normal file
@@ -0,0 +1,116 @@
|
|||||||
|
"""
|
||||||
|
Template Service for loading and rendering Jinja2 templates.
|
||||||
|
|
||||||
|
This service provides centralized template management with caching and error handling.
|
||||||
|
"""
|
||||||
|
import os
|
||||||
|
from functools import lru_cache
|
||||||
|
from typing import Optional
|
||||||
|
from jinja2 import Environment, FileSystemLoader, Template, TemplateNotFound
|
||||||
|
|
||||||
|
from app.core.logging_config import get_agent_logger, log_prompt_rendering
|
||||||
|
|
||||||
|
|
||||||
|
logger = get_agent_logger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
class TemplateRenderError(Exception):
    """Raised when a template cannot be loaded or rendered.

    The template name, the underlying exception and the variables supplied
    for rendering are kept as attributes so that handlers can inspect them.
    """

    def __init__(self, template_name: str, error: Exception, variables: dict):
        detail = str(error)
        message = f"Failed to render template '{template_name}': {detail}"
        super().__init__(message)
        self.template_name = template_name
        self.error = error
        self.variables = variables
|
||||||
|
|
||||||
|
|
||||||
|
class TemplateService:
    """Service for loading and rendering Jinja2 templates with caching."""

    def __init__(self, template_root: str):
        """
        Initialize the template service.

        Args:
            template_root: Root directory containing template files
        """
        self.template_root = template_root
        self.env = Environment(
            loader=FileSystemLoader(template_root),
            autoescape=False  # Disable autoescape for prompt templates
        )
        # Per-instance cache of loaded templates, keyed by template name.
        # A dict on the instance is used instead of functools.lru_cache on the
        # method: lru_cache would key on ``self`` and keep every service
        # instance alive in a class-level cache.
        self._template_cache = {}
        logger.info(f"TemplateService initialized with root: {template_root}")

    def _load_template(self, template_name: str) -> Template:
        """
        Load a template from disk, reusing a previously loaded one if present.

        Failed lookups are not cached, so a template that is added later can
        still be found.

        Args:
            template_name: Relative path to template file

        Returns:
            Loaded Jinja2 Template object

        Raises:
            TemplateNotFound: If template file doesn't exist
        """
        cached = self._template_cache.get(template_name)
        if cached is not None:
            return cached

        try:
            template = self.env.get_template(template_name)
        except TemplateNotFound:
            expected_path = os.path.join(self.template_root, template_name)
            logger.error(
                f"Template not found: {template_name}. "
                f"Expected path: {expected_path}"
            )
            raise

        self._template_cache[template_name] = template
        return template

    async def render_template(
        self,
        template_name: str,
        operation_name: str,
        **variables
    ) -> str:
        """
        Load and render a Jinja2 template.

        Args:
            template_name: Relative path to template file
            operation_name: Name for logging (e.g., "split_the_problem")
            **variables: Template variables to render

        Returns:
            Rendered template string

        Raises:
            TemplateRenderError: If template loading or rendering fails; the
                original exception is attached as ``__cause__`` and ``error``
        """
        try:
            # Load template (cached)
            template = self._load_template(template_name)

            # Render template
            rendered = template.render(**variables)

            # Log rendered prompt
            log_prompt_rendering(operation_name, rendered)

            return rendered

        except TemplateNotFound as e:
            logger.error(
                f"Template rendering failed for {operation_name} "
                f"({template_name}): Template not found",
                exc_info=True
            )
            raise TemplateRenderError(template_name, e, variables) from e

        except Exception as e:
            logger.error(
                f"Template rendering failed for {operation_name} "
                f"({template_name}): {e}",
                exc_info=True
            )
            raise TemplateRenderError(template_name, e, variables) from e
|
||||||
27
app/core/memory/agent/mcp_server/tools/__init__.py
Normal file
27
app/core/memory/agent/mcp_server/tools/__init__.py
Normal file
@@ -0,0 +1,27 @@
|
|||||||
|
"""
|
||||||
|
MCP Tools module.
|
||||||
|
|
||||||
|
This module contains all MCP tool implementations organized by functionality.
|
||||||
|
|
||||||
|
Tools are organized into the following modules:
|
||||||
|
- problem_tools: Question segmentation and extension
|
||||||
|
- retrieval_tools: Database and context retrieval
|
||||||
|
- verification_tools: Data verification
|
||||||
|
- summary_tools: Summarization and summary retrieval
|
||||||
|
- data_tools: Data type differentiation and writing
|
||||||
|
"""
|
||||||
|
|
||||||
|
# Import all tool modules to register them with the MCP server
|
||||||
|
from . import problem_tools
|
||||||
|
from . import retrieval_tools
|
||||||
|
from . import verification_tools
|
||||||
|
from . import summary_tools
|
||||||
|
from . import data_tools
|
||||||
|
|
||||||
|
__all__ = [
|
||||||
|
'problem_tools',
|
||||||
|
'retrieval_tools',
|
||||||
|
'verification_tools',
|
||||||
|
'summary_tools',
|
||||||
|
'data_tools',
|
||||||
|
]
|
||||||
149
app/core/memory/agent/mcp_server/tools/data_tools.py
Normal file
149
app/core/memory/agent/mcp_server/tools/data_tools.py
Normal file
@@ -0,0 +1,149 @@
|
|||||||
|
"""
|
||||||
|
Data Tools for data type differentiation and writing.
|
||||||
|
|
||||||
|
This module contains MCP tools for distinguishing data types and writing data.
|
||||||
|
"""
|
||||||
|
import os
|
||||||
|
|
||||||
|
from mcp.server.fastmcp import Context
|
||||||
|
|
||||||
|
from app.core.logging_config import get_agent_logger
|
||||||
|
from app.core.memory.agent.mcp_server.mcp_instance import mcp
|
||||||
|
from app.core.memory.agent.mcp_server.server import get_context_resource
|
||||||
|
from app.core.memory.agent.mcp_server.models.retrieval_models import DistinguishTypeResponse
|
||||||
|
from app.core.memory.agent.utils.write_tools import write
|
||||||
|
|
||||||
|
|
||||||
|
logger = get_agent_logger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
@mcp.tool()
async def Data_type_differentiation(
    ctx: Context,
    context: str
) -> dict:
    """
    Distinguish the type of data (read or write).

    Args:
        ctx: FastMCP context for dependency injection
        context: Text to analyze for type differentiation

    Returns:
        dict: Contains 'context' with the original text and 'type' field
    """
    # NOTE: the docstring above is left untouched on purpose; it decorates an
    # MCP tool and may be exposed to clients as the tool description.
    try:
        # Extract services from context
        template_service = get_context_resource(ctx, 'template_service')
        llm_client = get_context_resource(ctx, 'llm_client')

        # Render the system prompt; the input text is passed as `user_query`.
        # NOTE(review): 'status_typle' looks like a typo of 'status_type' but is
        # a runtime logging label, so it is kept as-is.
        try:
            system_prompt = await template_service.render_template(
                template_name='distinguish_types_prompt.jinja2',
                operation_name='status_typle',
                user_query=context
            )
        except Exception as e:
            logger.error(
                f"Template rendering failed for Data_type_differentiation: {e}",
                exc_info=True
            )
            # Every return path carries 'context', matching the documented
            # return shape and the other error branches below.
            return {
                "context": context,
                "type": "error",
                "message": f"Prompt rendering failed: {str(e)}"
            }

        # Call LLM with structured response
        try:
            structured = await llm_client.response_structured(
                messages=[{"role": "system", "content": system_prompt}],
                response_model=DistinguishTypeResponse
            )

            result = structured.model_dump()

            # Add the original text to the structured result
            result["context"] = context

            return result

        except Exception as e:
            logger.error(
                f"LLM call failed for Data_type_differentiation: {e}",
                exc_info=True
            )
            return {
                "context": context,
                "type": "error",
                "message": f"LLM call failed: {str(e)}"
            }

    except Exception as e:
        # Reached when resolving the services from the context fails.
        logger.error(
            f"Data_type_differentiation failed: {e}",
            exc_info=True
        )
        return {
            "context": context,
            "type": "error",
            "message": str(e)
        }
|
||||||
|
|
||||||
|
|
||||||
|
@mcp.tool()
async def Data_write(
    ctx: Context,
    content: str,
    user_id: str,
    apply_id: str,
    group_id: str,
    config_id: str
) -> dict:
    """
    Write data to the database/file system.

    Args:
        ctx: FastMCP context for dependency injection
        content: Data content to write
        user_id: User identifier
        apply_id: Application identifier
        group_id: Group identifier
        config_id: Configuration ID for processing (optional, integer)

    Returns:
        dict: Contains 'status', 'saved_to', and 'data' fields
    """
    # NOTE: the docstring above is left untouched on purpose; it decorates an
    # MCP tool and may be exposed to clients as the tool description.
    # NOTE(review): it describes config_id as "optional, integer" while the
    # signature declares a required str - confirm which is intended.
    output_dir = "data_output"
    output_name = "user_data.csv"

    try:
        # Make sure the output directory exists before reporting a path in it.
        os.makedirs(output_dir, exist_ok=True)
        # NOTE(review): this path is only reported in the response; it is not
        # passed to write(), so whether write() saves there cannot be told from
        # this function - confirm.
        target_path = os.path.join(output_dir, output_name)

        try:
            await write(content, user_id, apply_id, group_id, config_id=config_id)
            config_label = config_id or 'None'
            logger.info(f"写入成功!Config ID: {config_label}")

            response = {
                "status": "success",
                "saved_to": target_path,
                "data": content,
                "config_id": config_id
            }
            return response

        except Exception as write_error:
            logger.error(f"写入失败: {write_error}", exc_info=True)
            failure = {
                "status": "error",
                "message": str(write_error)
            }
            return failure

    except Exception as outer_error:
        # Reached when preparing the output directory/path fails.
        logger.error(
            f"Data_write failed: {outer_error}",
            exc_info=True
        )
        failure = {
            "status": "error",
            "message": str(outer_error)
        }
        return failure
|
||||||
293
app/core/memory/agent/mcp_server/tools/problem_tools.py
Normal file
293
app/core/memory/agent/mcp_server/tools/problem_tools.py
Normal file
@@ -0,0 +1,293 @@
|
|||||||
|
"""
|
||||||
|
Problem Tools for question segmentation and extension.
|
||||||
|
|
||||||
|
This module contains MCP tools for breaking down and extending user questions.
|
||||||
|
"""
|
||||||
|
import json
|
||||||
|
import time
|
||||||
|
from typing import List
|
||||||
|
|
||||||
|
from pydantic import BaseModel, Field, RootModel
|
||||||
|
from mcp.server.fastmcp import Context
|
||||||
|
|
||||||
|
from app.core.logging_config import get_agent_logger, log_time
|
||||||
|
from app.core.memory.agent.mcp_server.mcp_instance import mcp
|
||||||
|
from app.core.memory.agent.mcp_server.server import get_context_resource
|
||||||
|
from app.core.memory.agent.mcp_server.models.problem_models import (
|
||||||
|
ProblemBreakdownItem,
|
||||||
|
ProblemBreakdownResponse,
|
||||||
|
ExtendedQuestionItem,
|
||||||
|
ProblemExtensionResponse
|
||||||
|
)
|
||||||
|
from app.core.memory.agent.utils.messages_tool import Problem_Extension_messages_deal
|
||||||
|
|
||||||
|
|
||||||
|
logger = get_agent_logger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
@mcp.tool()
async def Split_The_Problem(
    ctx: Context,
    sentence: str,
    sessionid: str,
    messages_id: str,
    apply_id: str,
    group_id: str
) -> dict:
    """
    Break a sentence (or dialogue turn) into sub-problems using the LLM.

    Args:
        ctx: FastMCP context used to look up shared services
        sentence: Original sentence to split
        sessionid: Session identifier, resolved to a user id
        messages_id: Message identifier (not used by this tool)
        apply_id: Application identifier
        group_id: Group identifier

    Returns:
        dict: 'context' holds a JSON string with the split results (``[]`` on
        any failure) and 'original' echoes the sentence. On success an
        '_intermediate' entry is added for the frontend; on failure an
        'error' entry describes what went wrong.
    """
    started_at = time.time()
    # Every failure path reports the same serialized empty list.
    empty_json = json.dumps([], ensure_ascii=False)

    try:
        # Shared services injected through the MCP context
        template_service = get_context_resource(ctx, 'template_service')
        session_service = get_context_resource(ctx, 'session_service')
        llm_client = get_context_resource(ctx, 'llm_client')

        user_id = session_service.resolve_user_id(sessionid)

        # The history is still requested, but the prompt is rendered with an
        # empty history for now, so the fetched value is discarded.
        await session_service.get_history(user_id, apply_id, group_id)
        history = []

        # Build the system prompt
        try:
            system_prompt = await template_service.render_template(
                template_name='problem_breakdown_prompt.jinja2',
                operation_name='split_the_problem',
                history=history,
                sentence=sentence
            )
        except Exception as render_error:
            logger.error(
                f"Template rendering failed for Split_The_Problem: {render_error}",
                exc_info=True
            )
            return {
                "context": empty_json,
                "original": sentence,
                "error": f"Prompt rendering failed: {str(render_error)}"
            }

        # Ask the LLM for a structured breakdown
        try:
            structured = await llm_client.response_structured(
                messages=[{"role": "system", "content": system_prompt}],
                response_model=ProblemBreakdownResponse
            )

            # Prefer the RootModel payload; fall back to a bare list response.
            breakdown = getattr(structured, 'root', None)
            if breakdown is None and isinstance(structured, list):
                breakdown = structured

            if breakdown is None:
                split_result = empty_json
            else:
                split_result = json.dumps(
                    [entry.model_dump() for entry in breakdown],
                    ensure_ascii=False
                )
        except Exception as llm_error:
            logger.error(
                f"LLM call failed for Split_The_Problem: {llm_error}",
                exc_info=True
            )
            split_result = empty_json

        logger.info("问题拆分")
        logger.info(f"问题拆分结果==>>:{split_result}")

        # '_intermediate' carries the decoded result for the frontend
        return {
            "context": split_result,
            "original": sentence,
            "_intermediate": {
                "type": "problem_split",
                "data": json.loads(split_result) if split_result else [],
                "original_query": sentence
            }
        }

    except Exception as failure:
        logger.error(
            f"Split_The_Problem failed: {failure}",
            exc_info=True
        )
        return {
            "context": empty_json,
            "original": sentence,
            "error": str(failure)
        }

    finally:
        # Record the elapsed time on every exit path
        log_time('问题拆分', time.time() - started_at)
|
||||||
|
|
||||||
|
|
||||||
|
@mcp.tool()
async def Problem_Extension(
    ctx: Context,
    context: dict,
    usermessages: str,
    apply_id: str,
    group_id: str,
    storage_type: str = "",
    user_rag_memory_id: str = ""
) -> dict:
    """
    Ask the LLM for extended sub-questions for each split problem.

    Args:
        ctx: FastMCP context used to look up shared services
        context: Dictionary containing split problem results
        usermessages: User messages identifier, resolved to a session id
        apply_id: Application identifier
        group_id: Group identifier
        storage_type: Storage type for the workspace (optional, passed through)
        user_rag_memory_id: User RAG memory identifier (optional, passed through)

    Returns:
        dict: 'context' maps each original question to its list of extended
        questions (empty on failure) and 'original' holds the original
        question. Successful results also echo the storage fields and carry an
        '_intermediate' entry for the frontend; failures carry 'error'.
    """
    started_at = time.time()

    try:
        # Shared services injected through the MCP context
        template_service = get_context_resource(ctx, 'template_service')
        session_service = get_context_resource(ctx, 'session_service')
        llm_client = get_context_resource(ctx, 'llm_client')

        # Derive the session id from the user message identifier
        from app.core.memory.agent.utils.messages_tool import Resolve_username
        sessionid = Resolve_username(usermessages)

        # The history is still requested, but the prompt is rendered with an
        # empty history for now, so the fetched value is discarded.
        await session_service.get_history(sessionid, apply_id, group_id)
        history = []

        extent_quest, original = await Problem_Extension_messages_deal(context)

        # Only user-role messages are handed to the template
        questions_formatted = [
            message.get("content", "")
            for message in extent_quest
            if message.get("role") == "user"
        ]

        # Build the system prompt
        try:
            system_prompt = await template_service.render_template(
                template_name='Problem_Extension_prompt.jinja2',
                operation_name='problem_extension',
                history=history,
                questions=questions_formatted
            )
        except Exception as render_error:
            logger.error(
                f"Template rendering failed for Problem_Extension: {render_error}",
                exc_info=True
            )
            return {
                "context": {},
                "original": original,
                "error": f"Prompt rendering failed: {str(render_error)}"
            }

        # Ask the LLM and group extended questions by their original question
        try:
            response_content = await llm_client.response_structured(
                messages=[{"role": "system", "content": system_prompt}],
                response_model=ProblemExtensionResponse
            )

            grouped = {}
            for entry in response_content.root:
                # Entries may be model objects or plain dicts
                is_mapping = isinstance(entry, dict)
                source_question = getattr(entry, "original_question", None) or (
                    entry.get("original_question") if is_mapping else None
                )
                extended_question = getattr(entry, "extended_question", None) or (
                    entry.get("extended_question") if is_mapping else None
                )
                if source_question and extended_question:
                    grouped.setdefault(source_question, []).append(extended_question)
            aggregated_dict = grouped
        except Exception as llm_error:
            logger.error(
                f"LLM call failed for Problem_Extension: {llm_error}",
                exc_info=True
            )
            # A failure part-way through discards any partial grouping
            aggregated_dict = {}

        logger.info("问题扩展")
        logger.info(f"问题扩展==>>:{aggregated_dict}")

        # '_intermediate' mirrors the result for the frontend
        return {
            "context": aggregated_dict,
            "original": original,
            "storage_type": storage_type,
            "user_rag_memory_id": user_rag_memory_id,
            "_intermediate": {
                "type": "problem_extension",
                "data": aggregated_dict,
                "original_query": original,
                "storage_type": storage_type,
                "user_rag_memory_id": user_rag_memory_id
            }
        }

    except Exception as failure:
        logger.error(
            f"Problem_Extension failed: {failure}",
            exc_info=True
        )
        return {
            "context": {},
            "original": context.get("original", ""),
            "storage_type": storage_type,
            "user_rag_memory_id": user_rag_memory_id,
            "error": str(failure)
        }

    finally:
        # Record the elapsed time on every exit path
        log_time('问题扩展', time.time() - started_at)
|
||||||
282
app/core/memory/agent/mcp_server/tools/retrieval_tools.py
Normal file
282
app/core/memory/agent/mcp_server/tools/retrieval_tools.py
Normal file
@@ -0,0 +1,282 @@
|
|||||||
|
"""
|
||||||
|
Retrieval Tools for database and context retrieval.
|
||||||
|
|
||||||
|
This module contains MCP tools for retrieving data using hybrid search.
|
||||||
|
"""
|
||||||
|
from dotenv import load_dotenv
|
||||||
|
import os
|
||||||
|
|
||||||
|
from app.core.rag.nlp.search import knowledge_retrieval
|
||||||
|
|
||||||
|
# 加载.env文件
|
||||||
|
load_dotenv()
|
||||||
|
import time
|
||||||
|
from typing import List
|
||||||
|
|
||||||
|
from mcp.server.fastmcp import Context
|
||||||
|
|
||||||
|
from app.core.logging_config import get_agent_logger, log_time
|
||||||
|
from app.core.memory.agent.mcp_server.mcp_instance import mcp
|
||||||
|
from app.core.memory.agent.mcp_server.server import get_context_resource
|
||||||
|
from app.core.memory.agent.utils.llm_tools import deduplicate_entries, merge_to_key_value_pairs
|
||||||
|
from app.core.memory.agent.utils.messages_tool import Retriev_messages_deal
|
||||||
|
|
||||||
|
|
||||||
|
logger = get_agent_logger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
async def _search_single_question(
    search_service,
    question: str,
    group_id: str,
    storage_type: str,
    user_rag_memory_id: str,
    kb_config: dict
):
    """
    Run one retrieval for ``question``.

    Uses the RAG knowledge base when ``storage_type`` is "rag" and a memory id
    is given; otherwise delegates to the hybrid search service.

    Returns:
        tuple: (clean_content, cleaned_query, raw_results)
    """
    if storage_type == "rag" and user_rag_memory_id:
        # NOTE(review): knowledge_retrieval is called synchronously inside this
        # coroutine; if it performs blocking I/O it will stall the event loop.
        chunks = knowledge_retrieval(question, kb_config, [str(group_id)])
        try:
            clean_content = '\n\n'.join([chunk.page_content for chunk in chunks])
        except (TypeError, AttributeError):
            # No usable chunks (e.g. a None result or items without text):
            # treat as "nothing retrieved" rather than failing the question.
            logger.info(f"知识库没有检索的内容{user_rag_memory_id}")
            return '', question, ''
        logger.info(f" Using RAG storage with memory_id={user_rag_memory_id}")
        return clean_content, question, clean_content

    clean_content, cleaned_query, raw_results = await search_service.execute_hybrid_search(
        group_id=group_id,
        question=question,
        return_raw_results=True
    )
    return clean_content, cleaned_query, raw_results


@mcp.tool()
async def Retrieve(
    ctx: Context,
    context,
    usermessages: str,
    apply_id: str,
    group_id: str,
    storage_type: str = "",
    user_rag_memory_id: str = ""
) -> dict:
    """
    Retrieve data for one query or a set of extended questions.

    Args:
        ctx: FastMCP context used to look up shared services
        context: Dictionary of extended questions, or any other value that is
            stringified and used as a single query
        usermessages: User messages identifier (not used by this tool)
        apply_id: Application identifier (not used by this tool)
        group_id: Group identifier passed to the search backends
        storage_type: Storage type for the workspace; "rag" selects the
            knowledge-base path when user_rag_memory_id is set
        user_rag_memory_id: User RAG memory identifier

    Returns:
        dict: 'context' holds 'Query' and 'Expansion_issue' (a list of
        Query_small / Answer_Small entries), plus the echoed storage fields.
        '_intermediates' is added when per-question intermediate outputs were
        collected; 'error' is added when the tool failed as a whole.
    """
    kb_config = {
        "knowledge_bases": [
            {
                "kb_id": user_rag_memory_id,
                "similarity_threshold": 0.7,
                "vector_similarity_weight": 0.5,
                "top_k": 10,
                "retrieve_type": "participle"
            }
        ],
        "merge_strategy": "weight",
        "reranker_id": os.getenv('reranker_id'),
        "reranker_top_k": 10
    }
    start = time.time()
    logger.info(f"Retrieve: storage_type={storage_type}, user_rag_memory_id={user_rag_memory_id}")

    try:
        search_service = get_context_resource(ctx, 'search_service')

        if isinstance(context, dict):
            # content looks like {original_question: [extended_questions...], ...}
            content, original = await Retriev_messages_deal(context)

            all_items = []
            for values in content.values():
                if isinstance(values, list):
                    all_items.extend(values)
                elif isinstance(values, str):
                    all_items.append(values)
                elif values is not None:
                    # Fallback: stringify any other non-None value
                    all_items.append(str(values))

            # One search per question; a failing question yields an empty result
            databases_anser = []
            total = len(all_items)
            for idx, question in enumerate(all_items):
                try:
                    clean_content, cleaned_query, raw_results = await _search_single_question(
                        search_service,
                        question,
                        group_id,
                        storage_type,
                        user_rag_memory_id,
                        kb_config
                    )
                    databases_anser.append({
                        "Query_small": cleaned_query,
                        "Result_small": clean_content,
                        "_intermediate": {
                            "type": "search_result",
                            "query": cleaned_query,
                            "raw_results": raw_results,
                            "index": idx + 1,
                            "total": total
                        }
                    })
                except Exception as e:
                    logger.error(
                        f"Retrieve: hybrid_search failed for question '{question}': {e}",
                        exc_info=True
                    )
                    databases_anser.append({
                        "Query_small": question,
                        "Result_small": ""
                    })

            # Capture intermediate outputs before deduplication drops entries
            intermediate_outputs = [
                item['_intermediate']
                for item in databases_anser
                if '_intermediate' in item
            ]

            deduplicated_data = deduplicate_entries(databases_anser)
            deduplicated_data_merged = merge_to_key_value_pairs(
                deduplicated_data,
                'Query_small',
                'Result_small'
            )

            # Restructure for Verify/Retrieve_Summary compatibility
            send_verify = [
                {"Query_small": query_small, "Answer_Small": answer_small}
                for merged in deduplicated_data_merged
                for query_small, answer_small in merged.items()
            ]

            dup_databases = {
                "Query": original,
                "Expansion_issue": send_verify,
                "_intermediate_outputs": intermediate_outputs
            }

            logger.info(f"Collected {len(intermediate_outputs)} intermediate outputs from search results")

        else:
            # Any non-dict context is treated as a single plain query
            query = str(context).strip()

            try:
                clean_content, cleaned_query, raw_results = await _search_single_question(
                    search_service,
                    query,
                    group_id,
                    storage_type,
                    user_rag_memory_id,
                    kb_config
                )
                # Keep structure for Verify/Retrieve_Summary compatibility
                dup_databases = {
                    "Query": cleaned_query,
                    "Expansion_issue": [{
                        "Query_small": cleaned_query,
                        "Answer_Small": clean_content,
                        "_intermediate": {
                            "type": "search_result",
                            "query": cleaned_query,
                            "raw_results": raw_results,
                            "index": 1,
                            "total": 1
                        }
                    }]
                }
            except Exception as e:
                logger.error(
                    f"Retrieve: hybrid_search failed for query '{query}': {e}",
                    exc_info=True
                )
                dup_databases = {
                    "Query": query,
                    "Expansion_issue": []
                }

        logger.info(
            f"检索==>>:{storage_type}--{user_rag_memory_id}--Query={dup_databases.get('Query', '')}, "
            f"Expansion_issue count={len(dup_databases.get('Expansion_issue', []))}"
        )

        result = {
            "context": dup_databases,
            "storage_type": storage_type,
            "user_rag_memory_id": user_rag_memory_id
        }

        # Surface the collected intermediate outputs at the top level
        intermediate_outputs = dup_databases.get('_intermediate_outputs', [])
        if intermediate_outputs:
            result['_intermediates'] = intermediate_outputs
            logger.info(f"Adding {len(intermediate_outputs)} intermediate outputs to result")
        else:
            logger.warning("No intermediate outputs found in dup_databases")

        return result

    except Exception as e:
        logger.error(
            f"Retrieve failed: {e}",
            exc_info=True
        )
        return {
            "context": {
                "Query": "",
                "Expansion_issue": []
            },
            "storage_type": storage_type,
            "user_rag_memory_id": user_rag_memory_id,
            "error": str(e)
        }

    finally:
        # Record the elapsed time on every exit path
        log_time('检索', time.time() - start)
|
||||||
647
app/core/memory/agent/mcp_server/tools/summary_tools.py
Normal file
647
app/core/memory/agent/mcp_server/tools/summary_tools.py
Normal file
@@ -0,0 +1,647 @@
|
|||||||
|
"""
|
||||||
|
Summary Tools for data summarization.
|
||||||
|
|
||||||
|
This module contains MCP tools for summarizing retrieved data and generating responses.
|
||||||
|
"""
|
||||||
|
import json
|
||||||
|
import re
|
||||||
|
import time
|
||||||
|
from typing import List
|
||||||
|
|
||||||
|
from pydantic import BaseModel, Field
|
||||||
|
from mcp.server.fastmcp import Context
|
||||||
|
|
||||||
|
from app.core.logging_config import get_agent_logger, log_time
|
||||||
|
from app.core.memory.agent.mcp_server.mcp_instance import mcp
|
||||||
|
from app.core.memory.agent.mcp_server.server import get_context_resource
|
||||||
|
from app.core.memory.agent.mcp_server.models.summary_models import (
|
||||||
|
SummaryData,
|
||||||
|
SummaryResponse,
|
||||||
|
RetrieveSummaryData,
|
||||||
|
RetrieveSummaryResponse
|
||||||
|
)
|
||||||
|
from app.core.memory.agent.utils.messages_tool import (
|
||||||
|
Summary_messages_deal,
|
||||||
|
Resolve_username
|
||||||
|
)
|
||||||
|
from app.core.rag.nlp.search import knowledge_retrieval
|
||||||
|
from dotenv import load_dotenv
|
||||||
|
import os
|
||||||
|
|
||||||
|
# 加载.env文件
|
||||||
|
load_dotenv()
|
||||||
|
logger = get_agent_logger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
@mcp.tool()
async def Summary(
    ctx: Context,
    context: str,
    usermessages: str,
    apply_id: str,
    group_id: str,
    storage_type: str = "",
    user_rag_memory_id: str = ""
) -> dict:
    """
    Summarize the verified data into a final answer and persist it.

    Args:
        ctx: FastMCP context used to look up shared services
        context: JSON string containing verified data
        usermessages: User messages identifier, resolved to a session id
        apply_id: Application identifier
        group_id: Group identifier
        storage_type: Storage type for the workspace (optional, passed through)
        user_rag_memory_id: User RAG memory identifier (optional, passed through)

    Returns:
        dict: Always contains 'status' ("success" or "error") and
        'summary_result' (the answer, or the fallback text when no answer is
        available). Error results from template rendering or session saving
        also carry 'message'; successful results echo the storage fields.
    """
    start = time.time()
    fallback_answer = '信息不足,无法回答'

    try:
        try:
            # Shared services injected through the MCP context
            template_service = get_context_resource(ctx, 'template_service')
            session_service = get_context_resource(ctx, 'session_service')
            llm_client = get_context_resource(ctx, 'llm_client')

            sessionid = Resolve_username(usermessages)

            # Split the verified context into retrieved answers and the query
            answer_small, query = await Summary_messages_deal(context)

            history = await session_service.get_history(sessionid, apply_id, group_id)

            data = {
                "query": query,
                "history": history,
                "retrieve_info": answer_small
            }
        except Exception as e:
            logger.error(
                f"Summary: initialization failed: {e}",
                exc_info=True
            )
            return {
                "status": "error",
                "summary_result": fallback_answer
            }

        try:
            system_prompt = await template_service.render_template(
                template_name='summary_prompt.jinja2',
                operation_name='summary',
                data=data,
                query=query
            )
        except Exception as e:
            logger.error(
                f"Template rendering failed for Summary: {e}",
                exc_info=True
            )
            # 'summary_result' is included so every return has the same keys
            return {
                "status": "error",
                "summary_result": fallback_answer,
                "message": f"Prompt rendering failed: {str(e)}"
            }

        try:
            structured = await llm_client.response_structured(
                messages=[{"role": "system", "content": system_prompt}],
                response_model=SummaryResponse
            )
            aimessages = structured.query_answer or ""
        except Exception as e:
            logger.error(
                f"LLM call failed for Summary: {e}",
                exc_info=True
            )
            aimessages = ""

        try:
            # Only persist real answers; an empty answer is never saved
            if aimessages != "":
                await session_service.save_session(
                    user_id=sessionid,
                    query=query,
                    apply_id=apply_id,
                    group_id=group_id,
                    ai_response=aimessages
                )
                logger.info(f"sessionid: {sessionid} 写入成功")
        except Exception as e:
            logger.error(
                f"sessionid: {sessionid} 写入失败,错误信息:{str(e)}",
                exc_info=True
            )
            return {
                "status": "error",
                "summary_result": fallback_answer,
                "message": str(e)
            }

        await session_service.cleanup_duplicates()

        if aimessages == '':
            aimessages = fallback_answer

        logger.info(f"验证之后的总结==>>:{aimessages}")

        return {
            "status": "success",
            "summary_result": aimessages,
            "storage_type": storage_type,
            "user_rag_memory_id": user_rag_memory_id
        }

    finally:
        # Record the elapsed time on every exit path, including early returns
        log_time('总结', time.time() - start)
|
||||||
|
|
||||||
|
|
||||||
|
def _load_retrieve_json(text: str):
    """Parse ``text`` as JSON, retrying once after unescaping backslash sequences.

    Returns the parsed value, or None when both attempts fail.
    """
    try:
        parsed = json.loads(text)
        logger.info("Retrieve_Summary: successfully parsed JSON")
        return parsed
    except json.JSONDecodeError:
        pass

    try:
        # 'backslashreplace' turns non-Latin-1 characters into \uXXXX escapes
        # first, so 'unicode_escape' restores them instead of producing
        # mojibake (which is what decoding raw UTF-8 bytes would do).
        unescaped = text.encode('latin-1', 'backslashreplace').decode('unicode_escape')
        parsed = json.loads(unescaped)
        logger.info("Retrieve_Summary: parsed after unescaping")
        return parsed
    except (json.JSONDecodeError, UnicodeDecodeError) as e:
        logger.error(
            f"Retrieve_Summary: parsing failed even after unescape: {e}"
        )
        return None


def _unwrap_retrieve_context(context) -> dict:
    """Normalise the tool input to a dict with 'Query' / 'Expansion_issue'.

    Accepts the payload directly, under a 'context' key, or under a 'content'
    key (as a dict or a JSON string). Anything unusable yields an empty context.
    """
    if not isinstance(context, dict):
        return {"Query": "", "Expansion_issue": []}

    if "content" in context:
        candidate = context["content"]
        if isinstance(candidate, str):
            candidate = _load_retrieve_json(candidate)
            # The parsed payload may itself be wrapped in a 'context' key
            if isinstance(candidate, dict) and "context" in candidate:
                candidate = candidate["context"]
    elif "context" in context:
        candidate = context["context"] if isinstance(context["context"], dict) else context
    else:
        candidate = context

    if isinstance(candidate, dict):
        return candidate
    return {"Query": "", "Expansion_issue": []}


@mcp.tool()
async def Retrieve_Summary(
    ctx: Context,
    context: dict,
    usermessages: str,
    apply_id: str,
    group_id: str,
    storage_type: str = "",
    user_rag_memory_id: str = ""
) -> dict:
    """
    Summarize data directly from retrieval results.

    Args:
        ctx: FastMCP context used to look up shared services
        context: Dictionary containing Query and Expansion_issue from Retrieve
            (optionally wrapped under 'content' or 'context')
        usermessages: User messages identifier, resolved to a session id
        apply_id: Application identifier
        group_id: Group identifier
        storage_type: Storage type for the workspace (optional, passed through)
        user_rag_memory_id: User RAG memory identifier (optional, passed through)

    Returns:
        dict: Always contains 'status' ("success" or "error") and
        'summary_result' (the answer, or the fallback text when no answer is
        available). Successful results echo the storage fields and carry an
        '_intermediate' entry for the frontend; a template failure also
        carries 'message'.
    """
    start = time.time()
    fallback_answer = '信息不足,无法回答'

    try:
        try:
            # Shared services injected through the MCP context
            template_service = get_context_resource(ctx, 'template_service')
            session_service = get_context_resource(ctx, 'session_service')
            llm_client = get_context_resource(ctx, 'llm_client')

            sessionid = Resolve_username(usermessages)

            context_dict = _unwrap_retrieve_context(context)
            query = context_dict.get("Query", "")
            expansion_issue = context_dict.get("Expansion_issue", []) or []

            # Collect the answer text of every expansion entry
            retrieve_info = []
            for item in expansion_issue:
                if not isinstance(item, dict):
                    continue
                # 'Answer_Samll' is a historical misspelling kept for
                # backward compatibility
                if "Answer_Small" in item:
                    answer = item["Answer_Small"]
                elif "Answer_Samll" in item:
                    answer = item["Answer_Samll"]
                else:
                    continue
                if answer is None:
                    continue
                if isinstance(answer, list):
                    # Join list of characters/strings into a single string
                    retrieve_info.append(''.join(str(part) for part in answer))
                else:
                    retrieve_info.append(str(answer))

            retrieve_info_str = '\n\n'.join(retrieve_info)

            history = await session_service.get_history(sessionid, apply_id, group_id)

        except Exception as e:
            logger.error(
                f"Retrieve_Summary: initialization failed: {e}",
                exc_info=True
            )
            return {
                "status": "error",
                "summary_result": fallback_answer
            }

        try:
            system_prompt = await template_service.render_template(
                template_name='Retrieve_Summary_prompt.jinja2',
                operation_name='retrieve_summary',
                query=query,
                history=history,
                retrieve_info=retrieve_info_str
            )
        except Exception as e:
            logger.error(
                f"Template rendering failed for Retrieve_Summary: {e}",
                exc_info=True
            )
            # 'summary_result' is included so every return has the same keys
            return {
                "status": "error",
                "summary_result": fallback_answer,
                "message": f"Prompt rendering failed: {str(e)}"
            }

        try:
            structured = await llm_client.response_structured(
                messages=[{"role": "system", "content": system_prompt}],
                response_model=RetrieveSummaryResponse
            )

            # The structured response may be None or lack a usable 'data' part
            if structured and hasattr(structured, 'data') and structured.data:
                aimessages = structured.data.query_answer or ""
            else:
                logger.warning("Structured response is None or incomplete, using default message")
                aimessages = fallback_answer

            # Persist only real answers: skip empty text and the
            # "insufficient information" fallback. (The previous `or` made
            # this condition always true.)
            answer_text = str(aimessages)
            if answer_text != "" and fallback_answer not in answer_text:
                await session_service.save_session(
                    user_id=sessionid,
                    query=query,
                    apply_id=apply_id,
                    group_id=group_id,
                    ai_response=aimessages
                )
                logger.info(f"sessionid: {sessionid} 写入成功")
        except Exception as e:
            logger.error(
                f"Retrieve_Summary: LLM call failed: {e}",
                exc_info=True
            )
            aimessages = ""

        await session_service.cleanup_duplicates()

        if aimessages == '':
            aimessages = fallback_answer

        logger.info(f"检索之后的总结==>>:{aimessages}")

        # '_intermediate' mirrors the result for the frontend
        return {
            "status": "success",
            "summary_result": aimessages,
            "storage_type": storage_type,
            "user_rag_memory_id": user_rag_memory_id,
            "_intermediate": {
                "type": "retrieval_summary",
                "summary": aimessages,
                "query": query,
                "storage_type": storage_type,
                "user_rag_memory_id": user_rag_memory_id
            }
        }

    finally:
        # Record the elapsed time on every exit path, including early returns
        log_time('检索总结', time.time() - start)
|
||||||
|
|
||||||
|
|
||||||
|
@mcp.tool()
async def Input_Summary(
    ctx: Context,
    context: str,
    usermessages: str,
    search_switch: str,
    apply_id: str,
    group_id: str,
    storage_type: str = "",
    user_rag_memory_id: str = ""
) -> dict:
    """
    Generate a quick summary for direct input without verification.

    Args:
        ctx: FastMCP context for dependency injection
        context: String containing the input sentence (plain text, or a
            JSON / dict-like string carrying a 'sentence' or 'content' field)
        usermessages: User messages identifier
        search_switch: Search switch value for routing ('2' for summaries only)
        apply_id: Application identifier
        group_id: Group identifier
        storage_type: Storage type for the workspace (e.g., 'rag', 'vector')
        user_rag_memory_id: User RAG memory identifier

    Returns:
        dict: Always a dict with 'status' ("success" or "fail"),
        'summary_result', 'storage_type' and 'user_rag_memory_id'.
        On success it also carries '_intermediate' (payload for the
        frontend); on failure it carries 'error'.
    """
    # Function-scope import, as in the original implementation.
    import json

    default_answer = "信息不足,无法回答"
    start = time.time()
    logger.info(f"Input_Summary: storage_type={storage_type}, user_rag_memory_id={user_rag_memory_id}")

    try:
        # Extract services from context
        template_service = get_context_resource(ctx, 'template_service')
        session_service = get_context_resource(ctx, 'session_service')
        llm_client = get_context_resource(ctx, 'llm_client')
        search_service = get_context_resource(ctx, 'search_service')

        # Without an LLM client nothing can be summarized. Report the failure
        # in the same dict shape as every other exit path instead of a bare
        # string, so the declared ``-> dict`` contract holds.
        if llm_client is None:
            error_msg = "LLM client is not available. Please check server configuration and SELECTED_LLM_ID environment variable."
            logger.error(error_msg)
            return {
                "status": "fail",
                "summary_result": default_answer,
                "storage_type": storage_type,
                "user_rag_memory_id": user_rag_memory_id,
                "error": error_msg
            }

        # Resolve session ID
        sessionid = Resolve_username(usermessages) or ""
        sessionid = sessionid.replace('call_id_', '')

        # Get conversation history
        history = await session_service.get_history(
            str(sessionid),
            str(apply_id),
            str(group_id)
        )

        # Log the raw context for debugging
        logger.info(f"Input_Summary: Received context type={type(context)}, value={context[:200] if isinstance(context, str) else context}")

        # Extract the sentence from context: try JSON first, then a regex for
        # dict-repr style strings, and finally fall back to the raw value.
        try:
            if isinstance(context, str) and (context.startswith('{') or context.startswith('[')):
                try:
                    context_dict = json.loads(context)
                    if isinstance(context_dict, dict):
                        query = context_dict.get('sentence', context_dict.get('content', context))
                    else:
                        query = context
                except json.JSONDecodeError:
                    # Not valid JSON, try regex
                    match = re.search(r"'sentence':\s*['\"]?(.*?)['\"]?\s*,", context)
                    query = match.group(1) if match else context
            else:
                query = context
        except Exception as e:
            logger.warning(f"Failed to extract query from context: {e}")
            query = context

        # Clean query
        query = str(query).strip().strip("\"'")
        logger.debug(f"Input_Summary: Extracted query='{query}' from context type={type(context)}")

        # Execute search based on search_switch and storage_type
        try:
            logger.info(f"search_switch: {search_switch}, storage_type: {storage_type}")

            search_params = {
                "group_id": group_id,
                "question": query,
                "return_raw_results": True
            }

            # Retrieval
            if search_switch == '2':
                search_params["include"] = ["summaries"]
                if storage_type == "rag" and user_rag_memory_id:
                    kb_config = {
                        "knowledge_bases": [
                            {
                                "kb_id": user_rag_memory_id,
                                "similarity_threshold": 0.7,
                                "vector_similarity_weight": 0.5,
                                "top_k": 10,
                                "retrieve_type": "participle"
                            }
                        ],
                        "merge_strategy": "weight",
                        "reranker_id": os.getenv('reranker_id'),
                        "reranker_top_k": 10
                    }

                    retrieve_chunks_result = knowledge_retrieval(query, kb_config, [str(group_id)])
                    try:
                        retrieval_knowledge = [i.page_content for i in retrieve_chunks_result]
                        retrieve_info = '\n\n'.join(retrieval_knowledge)
                        raw_results = [retrieve_info]
                        logger.info(f"Input_Summary: Using RAG storage with memory_id={user_rag_memory_id}")
                    except Exception:
                        # Best effort: treat an unusable retrieval result as
                        # "nothing found", but keep the traceback in the log.
                        retrieve_info = ''
                        raw_results = ['']
                        logger.info(f"知识库没有检索的内容{user_rag_memory_id}", exc_info=True)
                else:
                    retrieve_info, _, raw_results = await search_service.execute_hybrid_search(**search_params)
                    logger.info("Input_Summary: 使用 summary 进行检索")
            else:
                retrieve_info, _, raw_results = await search_service.execute_hybrid_search(**search_params)

        except Exception as e:
            logger.error(
                f"Input_Summary: hybrid_search failed, using empty results: {e}",
                exc_info=True
            )
            retrieve_info, raw_results = "", []

        # Render template
        system_prompt = await template_service.render_template(
            template_name='Retrieve_Summary_prompt.jinja2',
            operation_name='input_summary',
            query=query,
            history=history,
            retrieve_info=retrieve_info
        )

        # Call LLM with structured response
        try:
            structured = await llm_client.response_structured(
                messages=[{"role": "system", "content": system_prompt}],
                response_model=RetrieveSummaryResponse
            )
            aimessages = structured.data.query_answer or default_answer
        except Exception as e:
            logger.error(
                f"Input_Summary: response_structured failed, using default answer: {e}",
                exc_info=True
            )
            aimessages = default_answer

        logger.info(f"快速答案总结==>>:{storage_type}--{user_rag_memory_id}--{aimessages}")

        # Emit intermediate output for frontend
        return {
            "status": "success",
            "summary_result": aimessages,
            "storage_type": storage_type,
            "user_rag_memory_id": user_rag_memory_id,
            "_intermediate": {
                "type": "input_summary",
                "title": "快速答案",
                "summary": aimessages,
                "query": query,
                "raw_results": raw_results,
                "search_mode": "quick_search",
                "storage_type": storage_type,
                "user_rag_memory_id": user_rag_memory_id
            }
        }

    except Exception as e:
        logger.error(
            f"Input_Summary failed: {e}",
            exc_info=True
        )
        return {
            "status": "fail",
            "summary_result": default_answer,
            "storage_type": storage_type,
            "user_rag_memory_id": user_rag_memory_id,
            "error": str(e)
        }

    finally:
        # Log execution time on every exit path
        log_time('检索', time.time() - start)
|
||||||
|
|
||||||
|
|
||||||
|
@mcp.tool()
async def Summary_fails(
    ctx: Context,
    context: str,
    usermessages: str,
    apply_id: str,
    group_id: str,
    storage_type: str = "",
    user_rag_memory_id: str = ""
) -> dict:
    """
    Handle workflow failure when summary cannot be generated.

    Args:
        ctx: FastMCP context for dependency injection
        context: Failure context string (not read by this tool)
        usermessages: User messages identifier (not read by this tool)
        apply_id: Application identifier
        group_id: Group identifier
        storage_type: Storage type for the workspace (optional)
        user_rag_memory_id: User RAG memory identifier (optional)

    Returns:
        dict: Contains 'status', a fixed 'summary_result' failure message,
        'storage_type' and 'user_rag_memory_id'; when the cleanup itself
        fails, 'status' is "fail" and 'error' holds the exception text.
    """
    try:
        # Extract services from context
        session_service = get_context_resource(ctx, 'session_service')

        # Cleanup duplicate sessions
        await session_service.cleanup_duplicates()

        logger.info("没有相关数据")
        logger.debug(f"Summary_fails called with apply_id: {apply_id}, group_id: {group_id}")

        return {
            "status": "success",
            "summary_result": "没有相关数据",
            "storage_type": storage_type,
            "user_rag_memory_id": user_rag_memory_id
        }

    except Exception as e:
        logger.error(
            f"Summary_fails failed: {e}",
            exc_info=True
        )
        return {
            "status": "fail",
            "summary_result": "没有相关数据",
            "storage_type": storage_type,
            "user_rag_memory_id": user_rag_memory_id,
            "error": str(e)
        }
|
||||||
169
app/core/memory/agent/mcp_server/tools/verification_tools.py
Normal file
169
app/core/memory/agent/mcp_server/tools/verification_tools.py
Normal file
@@ -0,0 +1,169 @@
|
|||||||
|
"""
|
||||||
|
Verification Tools for data verification.
|
||||||
|
|
||||||
|
This module contains MCP tools for verifying retrieved data.
|
||||||
|
"""
|
||||||
|
import time
|
||||||
|
|
||||||
|
from jinja2 import Template
|
||||||
|
from mcp.server.fastmcp import Context
|
||||||
|
|
||||||
|
from app.core.logging_config import get_agent_logger, log_time
|
||||||
|
from app.core.memory.agent.mcp_server.mcp_instance import mcp
|
||||||
|
from app.core.memory.agent.mcp_server.server import get_context_resource
|
||||||
|
from app.core.memory.agent.utils.verify_tool import VerifyTool
|
||||||
|
from app.core.memory.agent.utils.messages_tool import (
|
||||||
|
Verify_messages_deal,
|
||||||
|
Retrieve_verify_tool_messages_deal,
|
||||||
|
Resolve_username
|
||||||
|
)
|
||||||
|
from app.core.memory.agent.utils.llm_tools import PROJECT_ROOT_
|
||||||
|
|
||||||
|
|
||||||
|
logger = get_agent_logger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
@mcp.tool()
async def Verify(
    ctx: Context,
    context: dict,
    usermessages: str,
    apply_id: str,
    group_id: str,
    storage_type: str = "",
    user_rag_memory_id: str = ""
) -> dict:
    """
    Verify the retrieved data.

    Renders the split-verification prompt with the session history, runs
    ``VerifyTool`` over the query / sub-answer pairs taken from ``context``
    and post-processes the verdict.

    Args:
        ctx: FastMCP context for dependency injection
        context: Dictionary containing query and expansion issues
        usermessages: User messages identifier
        apply_id: Application identifier
        group_id: Group identifier
        storage_type: Storage type for the workspace (optional)
        user_rag_memory_id: User RAG memory identifier (optional)

    Returns:
        dict: Contains 'status' and 'verified_data' with verification results
    """
    started_at = time.time()

    try:
        session_service = get_context_resource(ctx, 'session_service')

        # VerifyTool expects the rendered prompt text, so the template file
        # is read raw and rendered here.
        from app.core.memory.agent.utils.messages_tool import read_template_file
        template_path = PROJECT_ROOT_ + '/agent/utils/prompt/split_verify_prompt.jinja2'
        raw_template = await read_template_file(template_path)

        # Conversation history for the resolved session
        sessionid = Resolve_username(usermessages)
        history = await session_service.get_history(sessionid, apply_id, group_id)

        system_prompt = Template(raw_template).render(history=history, sentence=context)

        # Split the context into sub-questions, their answers and the main query
        Query_small, Result_small, query = await Verify_messages_deal(context)

        query_list = [
            {'Query_small': sub_question, 'Answer_Small': sub_answer}
            for sub_question, sub_answer in zip(Query_small, Result_small)
        ]
        messages = {"Query": query, "Expansion_issue": query_list}

        # Run the verification workflow
        verify_result = await VerifyTool(system_prompt, messages).verify()

        # Parse the LLM verdict; a parsing failure degrades to a "failed"
        # result instead of surfacing as a server error.
        try:
            messages_deal = await Retrieve_verify_tool_messages_deal(
                verify_result,
                history,
                query
            )
        except Exception as e:
            logger.error(
                f"Retrieve_verify_tool_messages_deal parsing failed: {e}",
                exc_info=True
            )
            messages_deal = {
                "data": {"query": query, "expansion_issue": []},
                "split_result": "failed",
                "reason": str(e),
                "history": history,
            }

        logger.info(f"验证==>>:{messages_deal}")

        # Payload for the frontend's intermediate output
        intermediate = {
            "type": "verification",
            "title": "数据验证",
            "result": messages_deal.get("split_result", "unknown"),
            "reason": messages_deal.get("reason", ""),
            "query": query,
            "verified_count": len(query_list),
            "storage_type": storage_type,
            "user_rag_memory_id": user_rag_memory_id
        }
        return {
            "status": "success",
            "verified_data": messages_deal,
            "storage_type": storage_type,
            "user_rag_memory_id": user_rag_memory_id,
            "_intermediate": intermediate
        }

    except Exception as e:
        logger.error(
            f"Verify failed: {e}",
            exc_info=True
        )
        failed_payload = {
            "data": {"query": "", "expansion_issue": []},
            "split_result": "failed",
            "reason": str(e),
            "history": [],
        }
        return {
            "status": "error",
            "message": str(e),
            "storage_type": storage_type,
            "user_rag_memory_id": user_rag_memory_id,
            "verified_data": failed_payload
        }

    finally:
        # Log execution time on every exit path
        log_time('验证', time.time() - started_at)
|
||||||
7
app/core/memory/agent/utils/__init__.py
Normal file
7
app/core/memory/agent/utils/__init__.py
Normal file
@@ -0,0 +1,7 @@
|
|||||||
|
"""Agent utilities."""
|
||||||
|
|
||||||
|
from app.core.memory.agent.utils.multimodal import MultimodalProcessor
|
||||||
|
|
||||||
|
__all__ = [
|
||||||
|
"MultimodalProcessor",
|
||||||
|
]
|
||||||
70
app/core/memory/agent/utils/get_dialogs.py
Normal file
70
app/core/memory/agent/utils/get_dialogs.py
Normal file
@@ -0,0 +1,70 @@
|
|||||||
|
import os
|
||||||
|
import json
|
||||||
|
from typing import List
|
||||||
|
from datetime import datetime
|
||||||
|
|
||||||
|
from app.core.memory.storage_services.extraction_engine.knowledge_extraction.chunk_extraction import DialogueChunker
|
||||||
|
from app.core.memory.models.message_models import DialogData, ConversationContext, ConversationMessage
|
||||||
|
|
||||||
|
|
||||||
|
async def get_chunked_dialogs(
    chunker_strategy: str = "RecursiveChunker",
    group_id: str = "group_1",
    user_id: str = "user1",
    apply_id: str = "applyid",
    content: str = "这是用户的输入",
    ref_id: str = "wyl_20251027",
    config_id: str = None
) -> List[DialogData]:
    """Wrap one user message in a DialogData and chunk it.

    Args:
        chunker_strategy: The chunking strategy to use (default: RecursiveChunker)
        group_id: Group identifier
        user_id: User identifier
        apply_id: Application identifier
        content: Dialog content, stored as a single message with role "用户"
        ref_id: Reference identifier
        config_id: Configuration ID for processing

    Returns:
        A single-element list holding the DialogData whose ``chunks`` were
        produced by ``DialogueChunker.process_dialogue``.
    """
    conversation_context = ConversationContext(
        msgs=[ConversationMessage(role="用户", msg=content)]
    )
    dialog_data = DialogData(
        context=conversation_context,
        ref_id=ref_id,
        group_id=group_id,
        user_id=user_id,
        apply_id=apply_id,
        config_id=config_id
    )

    # Chunk the dialogue and attach the result to the dialog itself
    chunker = DialogueChunker(chunker_strategy)
    dialog_data.chunks = await chunker.process_dialogue(dialog_data)

    return [dialog_data]
|
||||||
204
app/core/memory/agent/utils/llm_tools.py
Normal file
204
app/core/memory/agent/utils/llm_tools.py
Normal file
@@ -0,0 +1,204 @@
|
|||||||
|
import asyncio
|
||||||
|
import json
|
||||||
|
from collections import defaultdict
|
||||||
|
from typing import TypedDict, Annotated
|
||||||
|
import os
|
||||||
|
import logging
|
||||||
|
|
||||||
|
from jinja2 import Template
|
||||||
|
from langchain_core.messages import AnyMessage
|
||||||
|
from dotenv import load_dotenv
|
||||||
|
from langgraph.graph import add_messages
|
||||||
|
from openai import OpenAI
|
||||||
|
|
||||||
|
from app.core.memory.agent.utils.messages_tool import read_template_file
|
||||||
|
from app.core.memory.utils.config.config_utils import get_picture_config, get_voice_config
|
||||||
|
from app.core.memory.utils.llm.llm_utils import get_llm_client
|
||||||
|
from app.core.memory.utils.config.definitions import SELECTED_LLM_ID, SELECTED_LLM_PICTURE_NAME, SELECTED_LLM_VOICE_NAME
|
||||||
|
from app.core.models.base import RedBearModelConfig
|
||||||
|
from app.core.memory.src.llm_tools.openai_client import OpenAIClient
|
||||||
|
|
||||||
|
PROJECT_ROOT_ = os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
load_dotenv()
|
||||||
|
|
||||||
|
#TODO: Refactor entire picture/voice
|
||||||
|
# async def LLM_model_request(context,data,query):
|
||||||
|
# '''
|
||||||
|
# Agent model request
|
||||||
|
# Args:
|
||||||
|
# context:Input request
|
||||||
|
# data: template parameters
|
||||||
|
# query:request content
|
||||||
|
# Returns:
|
||||||
|
|
||||||
|
# '''
|
||||||
|
# template = Template(context)
|
||||||
|
# system_prompt = template.render(**data)
|
||||||
|
# llm_client = get_llm_client(SELECTED_LLM_ID)
|
||||||
|
# result = await llm_client.chat(
|
||||||
|
# messages=[{"role": "system", "content": system_prompt}] + [{"role": "user", "content": query}]
|
||||||
|
# )
|
||||||
|
# return result
|
||||||
|
|
||||||
|
async def picture_model_requests(image_url):
    '''
    Run image recognition on ``image_url`` with the project's prompt template.

    Args:
        image_url: Image reference, forwarded unchanged to ``Picture_recognize``.
    Returns:
        Whatever ``Picture_recognize`` returns for the image and the prompt.
    '''
    # NOTE: the original literal ended with a trailing space, which made the
    # path point at a differently named (non-existent) file.
    file_path = PROJECT_ROOT_ + '/agent/utils/prompt/Template_for_image_recognition_prompt.jinja2'
    system_prompt = await read_template_file(file_path)
    result = await Picture_recognize(image_url, system_prompt)
    return result
|
||||||
|
class WriteState(TypedDict):
    """LangGraph state schema for the writing workflow."""

    # Message list annotated with the ``add_messages`` reducer.
    messages: Annotated[list[AnyMessage], add_messages]
    user_id: str
    apply_id: str
    group_id: str
|
||||||
|
|
||||||
|
class ReadState(TypedDict):
    """LangGraph state schema for the reading workflow.

    Field notes carried over from the original description:
        id: user id
        loop_count: number of traversals
        search_switch: search type
        config_id: configuration id for filtering results
    """

    # Messages are added in append mode (``add_messages`` reducer).
    messages: Annotated[list[AnyMessage], add_messages]
    name: str
    id: str
    loop_count: int
    search_switch: str
    user_id: str
    apply_id: str
    group_id: str
    config_id: str
|
||||||
|
|
||||||
|
|
||||||
|
class COUNTState:
    """Saturating counter.

    Per the original description, it counts how many times the workflow
    traversed dialogue retrieval without recalling a correct message.
    """

    def __init__(self, limit: int = 5):
        self.total: int = 0       # current accumulated value
        self.limit: int = limit   # upper bound for ``total``

    def add(self, value: int = 1):
        """Add ``value`` to the counter; once the limit is reached, stay there."""
        raised = self.total + value
        self.total = raised
        # The progress line reports the value before any clamping.
        print(f"[COUNTState] 当前值: {raised}")
        if raised < self.limit:
            return
        print(f"[COUNTState] 达到上限 {self.limit}")
        self.total = self.limit

    def get_total(self) -> int:
        """Return the current accumulated value."""
        return self.total

    def reset(self):
        """Manually reset the counter to zero."""
        self.total = 0
        print("[COUNTState] 已重置为 0")
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
# def embed(texts: list[str]) -> list[list[float]]:
|
||||||
|
# # 这里可以换成 LangChain Embeddings
|
||||||
|
# return [[float(len(t) % 5), float(len(t) % 3)] for t in texts]
|
||||||
|
|
||||||
|
|
||||||
|
# def export_store_to_json(store, namespace):
|
||||||
|
# """Export the entire storage content to a JSON file"""
|
||||||
|
# # 搜索所有存储项
|
||||||
|
# all_items = store.search(namespace)
|
||||||
|
|
||||||
|
# # 整理数据
|
||||||
|
# export_data = {}
|
||||||
|
# for item in all_items:
|
||||||
|
# if hasattr(item, 'key') and hasattr(item, 'value'):
|
||||||
|
# export_data[item.key] = item.value
|
||||||
|
|
||||||
|
# # 保存到文件
|
||||||
|
# os.makedirs("memory_logs", exist_ok=True)
|
||||||
|
# with open("memory_logs/full_memory_export.json", "w", encoding="utf-8") as f:
|
||||||
|
# json.dump(export_data, f, ensure_ascii=False, indent=2)
|
||||||
|
|
||||||
|
# print(f"{len(export_data)} 条记忆到 JSON 文件")
|
||||||
|
|
||||||
|
def merge_to_key_value_pairs(data, query_key, result_key):
    """Group ``result_key`` values by their ``query_key`` value.

    Returns a list of single-entry dicts ``{query_value: [results...]}``,
    ordered by the first appearance of each query value in ``data``.
    """
    buckets = {}
    for record in data:
        buckets.setdefault(record[query_key], []).append(record[result_key])
    return [{name: collected} for name, collected in buckets.items()]
|
||||||
|
|
||||||
|
def deduplicate_entries(entries):
    """Drop entries repeating an earlier ('Query_small', 'Result_small') pair.

    The first occurrence of each pair is kept and input order is preserved.
    """
    fingerprints = set()
    unique = []
    for record in entries:
        fingerprint = (record['Query_small'], record['Result_small'])
        if fingerprint in fingerprints:
            continue
        fingerprints.add(fingerprint)
        unique.append(record)
    return unique
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
async def Picture_recognize(image_path, PROMPT_TICKET_EXTRACTION) -> str:
    """Send an image plus a text prompt to the configured picture model.

    Args:
        image_path: Value placed in the request's ``image_url`` field.
        PROMPT_TICKET_EXTRACTION: Text prompt sent alongside the image.

    Returns:
        The ``statement`` field of the model's JSON reply, or an error
        message string when the picture model configuration cannot be loaded.

    Raises:
        json.JSONDecodeError / KeyError: if the reply is not JSON or lacks
        ``statement`` (unchanged from the original behaviour).
    """
    try:
        model_config = get_picture_config(SELECTED_LLM_PICTURE_NAME)
    except Exception as e:
        err = f"LLM配置不可用:{str(e)}。请检查 config.json 和 runtime.json。"
        logger.error(err)
        return err

    # The config stores the NAME of the environment variable holding the key.
    api_key = os.getenv(model_config["api_key"])
    backend_model_name = model_config["llm_name"].split("/")[-1]
    api_base = model_config['api_base']

    logger.info(f"model_name: {backend_model_name}")
    logger.info(f"api_key set: {'yes' if api_key else 'no'}")
    logger.info(f"base_url: {model_config['api_base']}")

    client = OpenAI(
        api_key=api_key, base_url=api_base,
    )
    request_messages = [
        {
            "role": "user",
            "content": [
                {
                    "type": "image_url",
                    "image_url": image_path,
                },
                {"type": "text",
                 "text": PROMPT_TICKET_EXTRACTION}
            ]
        }
    ]
    # The OpenAI client used here is synchronous; run it in a worker thread
    # so this coroutine does not block the event loop during the HTTP call.
    completion = await asyncio.to_thread(
        client.chat.completions.create,
        model=backend_model_name,
        messages=request_messages,
    )

    picture_text = completion.choices[0].message.content
    # Strip Markdown code fences before parsing the JSON payload.
    picture_text = picture_text.replace('```json', '').replace('```', '')
    picture_text = json.loads(picture_text)
    return picture_text['statement']
|
||||||
|
|
||||||
|
async def Voice_recognize():
    """Resolve connection settings for the configured voice model.

    Returns:
        ``(api_key, backend_model_name, api_base)`` on success, or an error
        message string when the voice model configuration cannot be loaded.
    """
    try:
        voice_cfg = get_voice_config(SELECTED_LLM_VOICE_NAME)
    except Exception as exc:
        failure = f"LLM配置不可用:{str(exc)}。请检查 config.json 和 runtime.json。"
        logger.error(failure)
        return failure

    return (
        # The config stores the NAME of the environment variable holding the key.
        os.getenv(voice_cfg["api_key"]),
        voice_cfg["llm_name"].split("/")[-1],
        voice_cfg['api_base'],
    )
|
||||||
|
|
||||||
|
|
||||||
15
app/core/memory/agent/utils/mcp_tools.py
Normal file
15
app/core/memory/agent/utils/mcp_tools.py
Normal file
@@ -0,0 +1,15 @@
|
|||||||
|
from app.core.config import settings
|
||||||
|
|
||||||
|
def get_mcp_server_config():
    """
    Build the MCP server configuration.

    Returns a dict with a single "data_flow" entry describing an SSE
    endpoint on ``settings.SERVER_IP``, port 8081.
    """
    data_flow = {
        # Port of the FastMCP service (per the original note).
        "url": f"http://{settings.SERVER_IP}:8081/sse",
        "transport": "sse",
        "timeout": 15000,
        "sse_read_timeout": 15000,
    }
    return {"data_flow": data_flow}
|
||||||
239
app/core/memory/agent/utils/messages_tool.py
Normal file
239
app/core/memory/agent/utils/messages_tool.py
Normal file
@@ -0,0 +1,239 @@
|
|||||||
|
import json
|
||||||
|
import logging
|
||||||
|
import re
|
||||||
|
from typing import List, Any
|
||||||
|
|
||||||
|
from langchain_core.messages import AnyMessage
|
||||||
|
from app.core.logging_config import get_agent_logger
|
||||||
|
|
||||||
|
logger = get_agent_logger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
def _to_openai_messages(msgs: List[AnyMessage]) -> List[dict]:
    """Convert a message list into ``{"role", "content"}`` dicts.

    Objects exposing ``content`` become user messages; dicts that already
    carry ``role`` and ``content`` pass through untouched; anything else is
    stringified into a user message.
    """

    def _convert(item):
        if hasattr(item, "content"):
            return {"role": "user", "content": getattr(item, "content", "")}
        if isinstance(item, dict) and "role" in item and "content" in item:
            return item
        return {"role": "user", "content": str(item)}

    return [_convert(item) for item in msgs]
|
||||||
|
|
||||||
|
|
||||||
|
def _extract_content(resp: Any) -> str:
|
||||||
|
"""Extract LLM content and sanitize to raw JSON/text.
|
||||||
|
|
||||||
|
- Supports both object and dict response shapes.
|
||||||
|
- Removes leading role labels (e.g., "Assistant:").
|
||||||
|
- Strips Markdown code fences like ```json ... ```.
|
||||||
|
- Attempts to isolate the first valid JSON array/object block when extra text is present.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def _to_text(r: Any) -> str:
|
||||||
|
try:
|
||||||
|
# 对象形式: resp.choices[0].message.content
|
||||||
|
if hasattr(r, "choices") and getattr(r, "choices", None):
|
||||||
|
msg = r.choices[0].message
|
||||||
|
if hasattr(msg, "content"):
|
||||||
|
return msg.content
|
||||||
|
if isinstance(msg, dict) and "content" in msg:
|
||||||
|
return msg["content"]
|
||||||
|
# 字典形式: resp["choices"][0]["message"]["content"]
|
||||||
|
if isinstance(r, dict):
|
||||||
|
return r.get("choices", [{}])[0].get("message", {}).get("content", "")
|
||||||
|
except Exception:
|
||||||
|
pass
|
||||||
|
return str(r)
|
||||||
|
|
||||||
|
def _clean_text(text: str) -> str:
|
||||||
|
s = str(text).strip()
|
||||||
|
# 移除可能的角色前缀
|
||||||
|
s = re.sub(r"^\s*(Assistant|assistant)\s*:\s*", "", s)
|
||||||
|
# 提取 ```json ... ``` 代码块
|
||||||
|
m = re.search(r"```json\s*(.*?)\s*```", s, flags=re.S | re.I)
|
||||||
|
if m:
|
||||||
|
s = m.group(1).strip()
|
||||||
|
# 如果仍然包含多余文本,尝试截取第一个 JSON 数组/对象片段
|
||||||
|
if not (s.startswith("{") or s.startswith("[")):
|
||||||
|
left = s.find("[")
|
||||||
|
right = s.rfind("]")
|
||||||
|
if left != -1 and right != -1 and right > left:
|
||||||
|
s = s[left:right + 1].strip()
|
||||||
|
else:
|
||||||
|
left = s.find("{")
|
||||||
|
right = s.rfind("}")
|
||||||
|
if left != -1 and right != -1 and right > left:
|
||||||
|
s = s[left:right + 1].strip()
|
||||||
|
return s
|
||||||
|
|
||||||
|
raw = _to_text(resp)
|
||||||
|
return _clean_text(raw)
|
||||||
|
|
||||||
|
def Resolve_username(usermessages):
    '''
    Extract the session id embedded in a message identifier.

    The identifier is split on '_'; the first and the last segment are
    dropped and the remaining segments are joined back with '_'.

    Args:
        usermessages: identifier string of the form ``<head>_<id...>_<tail>``

    Returns:
        The middle part of the identifier, or '' when there is none.
    '''
    return '_'.join(usermessages.split('_')[1:-1])
|
||||||
|
|
||||||
|
|
||||||
|
# TODO: USE app.core.memory.src.utils.render_template instead
|
||||||
|
async def read_template_file(template_path: str) -> str:
    """
    Read a template file and return its text.

    Args:
        template_path: Path of the template file (read as UTF-8)

    Returns:
        The template content as a string

    Raises:
        FileNotFoundError: the file does not exist (logged, then re-raised)
        OSError: any other I/O failure (logged with traceback, then re-raised)

    Note:
        The original note recommends the unified template rendering in
        app.core.memory.utils.template_render instead.
    """
    try:
        with open(template_path, "r", encoding="utf-8") as handle:
            content = handle.read()
    except OSError as exc:
        if isinstance(exc, FileNotFoundError):
            logger.error(f"模板文件未找到: {template_path}")
        else:
            logger.error(f"读取模板文件失败: {template_path}, 错误: {str(exc)}", exc_info=True)
        raise
    return content
|
||||||
|
|
||||||
|
|
||||||
|
async def Problem_Extension_messages_deal(context):
    '''
    Turn the extended questions in ``context`` into user chat messages.

    Args:
        context: mapping with 'original' (passed through) and 'context',
            a JSON string encoding a list of objects with 'question' and
            'type' fields
    Returns:
        (messages, original): one ``{"role": "user", "content": ...}`` dict
        per question, and the untouched 'original' value ('' when absent)
    '''
    original = context.get('original', '')
    parsed_items = json.loads(context.get('context', ''))

    extent_quest = []
    for entry in parsed_items:
        question_text = entry.get('question', '')
        question_kind = entry.get('type', '')
        extent_quest.append(
            {"role": "user", "content": f"问题:{question_text};问题类型:{question_kind}"}
        )
    return extent_quest, original
|
||||||
|
|
||||||
|
|
||||||
|
async def Retriev_messages_deal(context):
    """
    Pull the ``context`` and ``original`` entries out of a retrieval payload.

    Args:
        context: Normally a dict carrying ``context`` and/or ``original``.

    Returns:
        A ``(context_value, original_value)`` tuple. A missing ``context``
        yields ``{}`` and a missing ``original`` yields ``''``. The same
        defaults are returned when the argument is not a dict or carries
        neither key.
    """
    # The previous fall-through returned two undefined names and raised
    # NameError; return the same defaults the dict branch uses instead.
    if not isinstance(context, dict):
        return {}, ''
    return context.get('context', {}), context.get('original', '')
|
||||||
|
|
||||||
|
async def Verify_messages_deal(context):
    """
    Collect sub-questions and their first answers for verification.

    Args:
        context: Mapping whose ``context`` entry holds ``Query`` and
            ``Expansion_issue``; each expansion item carries ``Query_small``
            and a non-empty, indexable ``Answer_Small``.

    Returns:
        A ``(sub_questions, first_answers, query)`` tuple, with the two lists
        aligned by position.

    Raises:
        KeyError / IndexError: A required key is missing or ``Answer_Small``
            is empty.
    """
    payload = context['context']
    query = payload['Query']
    sub_questions, first_answers = [], []
    for entry in payload['Expansion_issue']:
        # Only the first answer of each sub-question is kept.
        first_answers.append(entry['Answer_Small'][0])
        sub_questions.append(entry['Query_small'])
    return sub_questions, first_answers, query
|
||||||
|
|
||||||
|
|
||||||
|
async def Summary_messages_deal(context):
    """
    Scrape the query and the ``answer_small`` texts out of a stringified payload.

    The payload is converted with ``str()``, newlines and backslashes are
    stripped, and the values are located with regular expressions rather than
    by structured parsing.

    Args:
        context: Any object whose string form contains ``"query": ...,`` and
            zero or more ``"answer_small": "[...]"`` fragments.

    Returns:
        A ``(answer_small_texts, query)`` tuple.

    Raises:
        IndexError: No ``"query": ...,`` fragment is present.
    """
    def _scrub(text):
        # Trim whitespace, then drop backslashes and square brackets.
        return text.strip().replace('\\', '').replace('[', '').replace(']', '')

    flattened = str(context).replace('\\n', '').replace('\n', '').replace('\\', '')
    query = re.findall(r'"query": (.*?),', flattened)[0]
    query = query.replace('[', '').replace(']', '').strip()

    answer_small_texts = []
    for raw_list in re.findall(r'"answer_small"\s*:\s*"(\[.*?\])"', flattened):
        try:
            for item in json.loads(raw_list):
                answer_small_texts.append(_scrub(item))
        except Exception:
            # Not valid JSON (or an item was not a string): keep the raw
            # fragment, cleaned the same way.
            answer_small_texts.append(_scrub(raw_list))
    return answer_small_texts, query
|
||||||
|
|
||||||
|
|
||||||
|
async def VerifyTool_messages_deal(context):
    """
    Scrape the main query, sub-queries and sub-results from a stringified payload.

    Only the text after the first ``"context":`` marker and before the first
    ``name='Retrieve'`` marker is searched.

    Args:
        context: Any object whose string form contains a ``"context":``
            marker followed by ``"Query"``, ``"Query_small"`` and
            ``"Result_small"`` string fields.

    Returns:
        A ``(query_small_list, result_small_list, query)`` tuple.

    Raises:
        IndexError: The ``"context":`` marker or the ``"Query"`` field is absent.
    """
    flattened = str(context).replace('\\n', '').replace('\n', '').replace('\\', '')
    after_marker = flattened.split('"context":')[1].replace('""', '"')
    searchable = after_marker.split("name='Retrieve'")[0]

    query = re.findall(r'"Query": "(.*?)"', searchable)[0]
    sub_queries = re.findall(r'"Query_small": "(.*?)"', searchable)
    sub_results = re.findall(r'"Result_small": "(.*?)"', searchable)
    return sub_queries, sub_results, query
|
||||||
|
|
||||||
|
|
||||||
|
async def Retrieve_Summary_messages_deal(context):
    """Placeholder without an implementation; ignores ``context`` and returns ``None``."""
    return None
|
||||||
|
|
||||||
|
|
||||||
|
async def Retrieve_verify_tool_messages_deal(context, history, query):
    """
    Extract verified sub-question records from a stringified payload.

    Every ``{...}`` fragment in ``str(context)`` is scanned for the fields
    ``Query_small``, ``Answer_Small``, ``status`` and ``Query_answer``.
    Records whose status equals ``false`` (case-insensitive, surrounding
    whitespace ignored) are dropped; a missing status counts as kept.

    Args:
        context: Any object; it is converted with ``str()`` first so ``None``
            and non-string values do not break the regex search.
        history: Passed through unchanged into the result.
        query: Passed through unchanged into the result.

    Returns:
        ``{"data": {"query", "expansion_issue"}, "split_result", "reason",
        "history"}`` where ``split_result`` is ``'success'`` when at least one
        record was kept and ``'failed'`` otherwise. ``answer_small`` is the
        raw bracketed text, not a parsed list.
    """
    def _first_group(pattern, haystack, missing=None):
        found = re.search(pattern, haystack)
        return found.group(1) if found else missing

    kept = []
    for block in re.findall(r'\{(.*?)\}', str(context), flags=re.S):
        record = {
            "query_small": _first_group(r'"Query_small"\s*:\s*"([^"]*)"', block),
            "answer_small": _first_group(r'"Answer_Small"\s*:\s*(\[[^\]]*\])', block),
            # A missing status becomes '' so the comparison below stays string-based.
            "status": _first_group(r'"status"\s*:\s*"([^"]*)"', block, ""),
            "query_answer": _first_group(r'"Query_answer"\s*:\s*"([^"]*)"', block),
        }
        if str(record["status"]).strip().lower() != 'false':
            kept.append(record)

    return {
        "data": {"query": query, "expansion_issue": kept},
        "split_result": 'success' if kept else 'failed',
        "reason": "",
        "history": history,
    }
|
||||||
38
app/core/memory/agent/utils/model_tool.py
Normal file
38
app/core/memory/agent/utils/model_tool.py
Normal file
@@ -0,0 +1,38 @@
|
|||||||
|
|
||||||
|
|
||||||
|
# project_root = os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
|
||||||
|
# sys.path.insert(0, project_root)
|
||||||
|
|
||||||
|
# load_dotenv()
|
||||||
|
|
||||||
|
# async def llm_client_chat(messages: List[dict]) -> str:
|
||||||
|
# """使用 OpenAI 兼容接口进行对话,返回内容字符串。"""
|
||||||
|
# try:
|
||||||
|
# cfg = get_model_config(SELECTED_LLM_ID)
|
||||||
|
# rb_config = RedBearModelConfig(
|
||||||
|
# model_name=cfg["model_name"],
|
||||||
|
# provider=cfg["provider"],
|
||||||
|
# api_key=cfg["api_key"],
|
||||||
|
# base_url=cfg["base_url"],
|
||||||
|
# )
|
||||||
|
# client = OpenAIClient(model_config=rb_config, type_="chat")
|
||||||
|
|
||||||
|
# except Exception as e:
|
||||||
|
# logger.error(f"获取模型配置失败:{e}")
|
||||||
|
# err = f"获取模型配置失败:{str(e)}。请检查!!!"
|
||||||
|
# return err
|
||||||
|
# try:
|
||||||
|
# response = await client.chat(messages)
|
||||||
|
# print(f"model_tool's llm_client_chat response ======>:\n {response}")
|
||||||
|
# return _extract_content(response)
|
||||||
|
# # return _extract_content(result)
|
||||||
|
# except Exception as e:
|
||||||
|
# logger.error(f"LLM调用失败:{str(e)}。请检查 model_name、api_key、api_base 是否正确。")
|
||||||
|
# return f"LLM调用失败:{str(e)}。请检查 model_name、api_key、api_base 是否正确。"
|
||||||
|
|
||||||
|
# async def main(image_url):
|
||||||
|
# await llm_client_chat(image_url)
|
||||||
|
#
|
||||||
|
# # 运行主函数
|
||||||
|
# asyncio.run(main(['https://dashscope.oss-cn-beijing.aliyuncs.com/samples/audio/paraformer/hello_world_male2.wav']))
|
||||||
|
#
|
||||||
131
app/core/memory/agent/utils/multimodal.py
Normal file
131
app/core/memory/agent/utils/multimodal.py
Normal file
@@ -0,0 +1,131 @@
|
|||||||
|
"""
|
||||||
|
Multimodal input processor for handling image and audio content.
|
||||||
|
|
||||||
|
This module provides utilities for detecting and processing multimodal inputs
|
||||||
|
(images and audio files) by converting them to text using appropriate models.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import logging
|
||||||
|
from typing import List
|
||||||
|
|
||||||
|
from app.core.memory.agent.multimodal.speech_model import Vico_recognition
|
||||||
|
from app.core.memory.agent.utils.llm_tools import picture_model_requests
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
class MultimodalProcessor:
    """
    Convert image or audio references into text.

    Inputs are classified purely by file extension. Images are handed to
    ``picture_model_requests`` and audio to ``Vico_recognition``; anything
    else is returned untouched.
    """

    # Image suffixes, stored WITH the leading dot.
    IMAGE_EXTENSIONS = ['.jpg', '.png']
    # Audio/video suffixes, stored WITHOUT the leading dot.
    AUDIO_EXTENSIONS = [
        'aac', 'amr', 'avi', 'flac', 'flv', 'm4a', 'mkv', 'mov',
        'mp3', 'mp4', 'mpeg', 'ogg', 'opus', 'wav', 'webm', 'wma', 'wmv'
    ]

    def __init__(self):
        """Create a processor; no instance state is kept."""

    def is_image(self, content: str) -> bool:
        """
        Tell whether ``content`` ends with a supported image extension.

        The comparison is case-insensitive. Non-string input yields ``False``.

        Examples:
            >>> MultimodalProcessor().is_image("photo.jpg")
            True
            >>> MultimodalProcessor().is_image("document.pdf")
            False
        """
        if isinstance(content, str):
            return content.lower().endswith(tuple(self.IMAGE_EXTENSIONS))
        return False

    def is_audio(self, content: str) -> bool:
        """
        Tell whether ``content`` ends with ``.<ext>`` for a supported audio extension.

        The comparison is case-insensitive. Non-string input yields ``False``.

        Examples:
            >>> MultimodalProcessor().is_audio("recording.mp3")
            True
            >>> MultimodalProcessor().is_audio("document.txt")
            False
        """
        if isinstance(content, str):
            dotted = tuple(f'.{ext}' for ext in self.AUDIO_EXTENSIONS)
            return content.lower().endswith(dotted)
        return False

    async def process_input(self, content: str) -> str:
        """
        Return text for ``content``, running recognition when it names an image or audio file.

        Args:
            content: A file path / URL, or ordinary text.

        Returns:
            The recognition result for image or audio input. The original
            ``content`` when it is neither, or when recognition raised.
            ``str(content)`` when the argument is not a string.
        """
        if not isinstance(content, str):
            logger.warning(f"[MultimodalProcessor] Content is not a string: {type(content)}")
            return str(content)

        try:
            # Images are checked before audio.
            if self.is_image(content):
                logger.info(f"[MultimodalProcessor] Detected image input: {content}")
                image_text = await picture_model_requests(content)
                logger.info(f"[MultimodalProcessor] Image recognition result: {image_text[:100]}...")
                return image_text

            if self.is_audio(content):
                logger.info(f"[MultimodalProcessor] Detected audio input: {content}")
                audio_text = await Vico_recognition([content]).run()
                logger.info(f"[MultimodalProcessor] Audio recognition result: {audio_text[:100]}...")
                return audio_text
        except Exception as e:
            # Best effort: any recognition failure falls back to the raw input.
            logger.error(f"[MultimodalProcessor] Error processing multimodal input: {e}", exc_info=True)
            logger.info("[MultimodalProcessor] Falling back to original content")

        # Not multimodal, or recognition failed.
        return content
|
||||||
@@ -0,0 +1,81 @@
|
|||||||
|
|
||||||
|
你是一个高效的问题拆分助手,任务是根据用户提供的原始问题和问题类型,生成可操作的扩展问题,用于精确回答原问题。请严格遵循以下规则:
|
||||||
|
|
||||||
|
角色:
|
||||||
|
- 你是“问题拆分专家”,专注于逻辑、信息完整性和可操作性。
|
||||||
|
- 你能够结合【历史信息】、【上下文】、【背景知识】进行分析,以保持问题拆分的连贯性和相关性。
|
||||||
|
- 如果历史信息或上下文与当前问题无关,可忽略。
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
### 历史信息参考
|
||||||
|
在生成扩展问题时,你可以参考以下历史数据(如果提供):
|
||||||
|
- 历史对话或任务的主题;
|
||||||
|
- 历史中出现的关键实体(时间、人物、地点、研究主题等);
|
||||||
|
- 历史中已解答的问题(避免重复);
|
||||||
|
- 历史推理链(保持逻辑一致性)。
|
||||||
|
|
||||||
|
> 如果没有提供历史信息,则仅根据当前输入问题进行分析。
|
||||||
|
输入历史信息内容:{{history}}
|
||||||
|
|
||||||
|
## User Input
|
||||||
|
{% if questions is string %}
|
||||||
|
{{ questions }}
|
||||||
|
{% else %}
|
||||||
|
{% for question in questions %}
|
||||||
|
- {{ question }}
|
||||||
|
{% endfor %}
|
||||||
|
{% endif %}
|
||||||
|
|
||||||
|
需求:
|
||||||
|
- 如果问题是单跳问题(单步可答),直接保留原问题提取重要提问部分作为拆分/扩展问题。
|
||||||
|
- 如果问题是多跳问题(需多个信息点才能回答),对问题进行扩展拆分。
|
||||||
|
- 扩展问题必须完整覆盖原问题的所有关键要素,包括时间、主体、动作、目标等,不得遗漏。
|
||||||
|
- 扩展问题不得冗余:避免重复询问相同信息或过度拆分同一主题。
|
||||||
|
- 扩展问题必须高度相关:每个子问题直接服务于原问题,不引入未提及的新概念、人物或细节。
|
||||||
|
- 扩展问题必须可操作:每个子问题能在有限资源下独立解答。
|
||||||
|
- 子问题数量不超过4个。
|
||||||
|
- 拆分问题的时候可以考虑输入的历史内容,以保持逻辑连贯。
|
||||||
|
比如:输入历史信息内容:[{'Query': '4月27日,我和你推荐过一本书,书名是什么?', 'Answer': '张曼玉推荐了《小王子》'}]
|
||||||
|
拆分问题:4月27日,我和你推荐过一本书,书名是什么?,可以拆分为:4月27日,张曼玉推荐过一本书,书名是什么?
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
输出要求:
|
||||||
|
- 仅输出 JSON 数组,不要包含任何解释或代码块。
|
||||||
|
- 每个元素包含:
|
||||||
|
- `original_question`: 原始问题
|
||||||
|
- `extended_question`: 扩展后的问题
|
||||||
|
- `type`: 类型(事实检索/澄清/定义/比较/行动建议)
|
||||||
|
- `reason`: 生成该扩展问题的简短理由
|
||||||
|
- 使用标准 ASCII 双引号,无换行;确保字符串正确关闭并以逗号分隔。
|
||||||
|
|
||||||
|
示例:
|
||||||
|
输入:
|
||||||
|
[
|
||||||
|
"问题:今年诺贝尔物理学奖的获奖者是谁,他们因为什么贡献获奖?;问题类型:多跳",
|
||||||
|
]
|
||||||
|
|
||||||
|
输出:
|
||||||
|
[
|
||||||
|
{
|
||||||
|
"original_question": "今年诺贝尔物理学奖的获奖者是谁,他们因为什么贡献获奖?",
|
||||||
|
"extended_question": "今年诺贝尔物理学奖的获奖者有哪些人?",
|
||||||
|
"type": "多跳",
|
||||||
|
"reason": "输出原问题的关键要素"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"original_question": "今年诺贝尔物理学奖的获奖者是谁,他们因为什么贡献获奖?",
|
||||||
|
"extended_question": "今年诺贝尔物理学奖的获奖者是因哪些具体贡献获奖的?",
|
||||||
|
"type": "多跳",
|
||||||
|
"reason": "输出原问题的关键要素"
|
||||||
|
}
|
||||||
|
]
|
||||||
|
**Output format**
|
||||||
|
**CRITICAL JSON FORMATTING REQUIREMENTS:**
|
||||||
|
1. Use only standard ASCII double quotes (") for JSON structure - never use Chinese quotation marks ("") or other Unicode quotes
|
||||||
|
2. If the extracted statement text contains quotation marks, escape them properly using backslashes (\")
|
||||||
|
3. Ensure all JSON strings are properly closed and comma-separated
|
||||||
|
4. Do not include line breaks within JSON string values
|
||||||
|
|
||||||
|
The output language should always be the same as the input language.{{ json_schema }}
|
||||||
@@ -0,0 +1,37 @@
|
|||||||
|
# 角色
|
||||||
|
你是一个专业的问答助手,擅长基于检索信息和历史对话回答用户问题。
|
||||||
|
|
||||||
|
# 任务
|
||||||
|
根据提供的上下文信息回答用户的问题。
|
||||||
|
|
||||||
|
# 输入信息
|
||||||
|
- 历史对话:{{history}}
|
||||||
|
- 检索信息:{{retrieve_info}}
|
||||||
|
|
||||||
|
## User Query
|
||||||
|
{{query}}
|
||||||
|
|
||||||
|
# 回答指南
|
||||||
|
1. 仔细分析用户的问题
|
||||||
|
2. 优先使用检索信息中的相关内容回答
|
||||||
|
3. 结合历史对话提供连贯的回复
|
||||||
|
4. 如果信息不足:
|
||||||
|
- 对于简单问候或日常对话,给出自然简短的回复
|
||||||
|
- 对于复杂问题,诚实说明信息不足
|
||||||
|
5. 保持回答简洁、相关、自然
|
||||||
|
6. 使用与问题相同的语言回答
|
||||||
|
|
||||||
|
**Output format**
|
||||||
|
- 直接回答问题,像人类对话一样自然流畅
|
||||||
|
- 不要提及"检索信息"、"搜索结果"、"根据资料"等技术术语
|
||||||
|
- 不要解释推理过程或评论信息来源
|
||||||
|
- 如果只能部分回答问题,先回答能回答的部分,然后说明哪些方面信息不足
|
||||||
|
- 如果完全无法回答,简洁地说明:"信息不足,无法回答。"
|
||||||
|
|
||||||
|
**CRITICAL JSON FORMATTING REQUIREMENTS:**
|
||||||
|
1. Use only standard ASCII double quotes (") for JSON structure - never use Chinese quotation marks ("") or other Unicode quotes
|
||||||
|
2. If the extracted statement text contains quotation marks, escape them properly using backslashes (\")
|
||||||
|
3. Ensure all JSON strings are properly closed and comma-separated
|
||||||
|
4. Do not include line breaks within JSON string values
|
||||||
|
|
||||||
|
The output language should always be the same as the input language.{{ json_schema }}
|
||||||
29
app/core/memory/agent/utils/prompt/Retrieve_prompt.jinja2
Normal file
29
app/core/memory/agent/utils/prompt/Retrieve_prompt.jinja2
Normal file
@@ -0,0 +1,29 @@
|
|||||||
|
|
||||||
|
# 角色:{#InputSlot placeholder="角色名称" mode="input"#}{#/InputSlot#}
|
||||||
|
你是一个智能问答助手,任务如下
|
||||||
|
## 目标:
|
||||||
|
|
||||||
|
1. 接收一个字典,格式为 {'问题': [答案列表]}。
|
||||||
|
2. 接收一个问题(字典中的 key)。
|
||||||
|
3. 找到与问题匹配的答案列表。
|
||||||
|
4. 将答案列表合并成一句自然流畅的话:
|
||||||
|
- 如果答案有两条,使用“是”连接,例如:“A,是B”。
|
||||||
|
- 如果答案有三条或以上,使用“,并且”“另外”等自然连词,保证句子流畅。
|
||||||
|
5. 输出内容时只输出合并后的答案,不输出关键点或其他文字。
|
||||||
|
6. 如果问题未在字典中找到对应答案,请输出:
|
||||||
|
对不起,我没有找到相关信息。
|
||||||
|
|
||||||
|
|
||||||
|
输出要求:
|
||||||
|
- 文本形式
|
||||||
|
---
|
||||||
|
|
||||||
|
字典示例:
|
||||||
|
{
|
||||||
|
'今天的天气怎么样': ['今天天气很好', '今天是晴天']
|
||||||
|
}
|
||||||
|
|
||||||
|
问题示例:
|
||||||
|
今天的天气怎么样
|
||||||
|
输出要求:
|
||||||
|
今天天气很好,是晴天
|
||||||
@@ -0,0 +1,10 @@
|
|||||||
|
请提取图像内的文本
|
||||||
|
返回数据格式以json方式输出,
|
||||||
|
- 必须通过json.loads()的格式支持的形式输出,响应必须是与此确切模式匹配的有效JSON对象。不要在JSON之前或之后包含任何文本。
|
||||||
|
- 关键的JSON格式要求{"statement":识别出的文本内容}
|
||||||
|
1.JSON结构仅使用标准ASCII双引号(")-切勿使用中文引号(“”)或其他Unicode引号
|
||||||
|
2.如果提取的语句文本包含引号,请使用反斜杠(\")正确转义它们
|
||||||
|
3.确保所有JSON字符串都正确关闭并以逗号分隔
|
||||||
|
4.JSON字符串值中不包括换行符
|
||||||
|
5.正确转义的例子:"statement":"Zhang Xinhua said:\"我非常喜欢这本书\""
|
||||||
|
6.不允许输出```json```相关符号,如```json```、``````、```python```、```javascript```、```html```、```css```、```sql```、```java```、```c```、```c++```、```c#```、```ruby```
|
||||||
@@ -0,0 +1,34 @@
|
|||||||
|
你是一个输入分类助手,负责判断用户输入的意图类型。
|
||||||
|
|
||||||
|
## User Input
|
||||||
|
{{ user_query }}
|
||||||
|
|
||||||
|
请你根据以下规则判断:
|
||||||
|
1. 如果输入是在寻求信息、提问、请求解释、或疑问句(包括隐含的问题),则分类为 "question"。
|
||||||
|
2. 如果输入是命令、陈述、描述、感叹、或其他类型,不在寻求答案,则分类为 "other"。
|
||||||
|
只输出:
|
||||||
|
{
|
||||||
|
"type": "question"
|
||||||
|
}
|
||||||
|
或
|
||||||
|
{
|
||||||
|
"type": "other"
|
||||||
|
}
|
||||||
|
示例:
|
||||||
|
输入:"Python怎么读取文件?"
|
||||||
|
输出:{"type": "question"}
|
||||||
|
|
||||||
|
输入:"帮我写个读取文件的函数"
|
||||||
|
输出:{"type": "other"}
|
||||||
|
|
||||||
|
输入:"今天是星期几?"
|
||||||
|
输出:{"type": "question"}
|
||||||
|
返回数据格式以json方式输出,
|
||||||
|
- 必须通过json.loads()的格式支持的形式输出,响应必须是与此确切模式匹配的有效JSON对象。不要在JSON之前或之后包含任何文本。
|
||||||
|
- 关键的JSON格式要求{"type":"question"或"other"}
|
||||||
|
1.JSON结构仅使用标准ASCII双引号(")-切勿使用中文引号(“”)或其他Unicode引号
|
||||||
|
2.如果提取的语句文本包含引号,请使用反斜杠(\")正确转义它们
|
||||||
|
3.确保所有JSON字符串都正确关闭并以逗号分隔
|
||||||
|
4.JSON字符串值中不包括换行符
|
||||||
|
5.正确转义的例子:"statement":"Zhang Xinhua said:\"我非常喜欢这本书\""
|
||||||
|
6.不允许输出```json```相关符号,如```json```、``````、```python```、```javascript```、```html```、```css```、```sql```、```java```、```c```、```c++```、```c#```、```ruby```
|
||||||
@@ -0,0 +1,160 @@
|
|||||||
|
|
||||||
|
# 角色:{#InputSlot placeholder="角色名称" mode="input"#}{#/InputSlot#}
|
||||||
|
你是一个智能数据拆分助手,请根据数据特性判断输入属于哪种类型:
|
||||||
|
## 目标:
|
||||||
|
你需要根据以下类型对输入数据进行分类,并生成相应的拆分策略和示例。
|
||||||
|
---
|
||||||
|
|
||||||
|
### 历史信息参考
|
||||||
|
在生成扩展问题时,你可以参考以下历史数据(如果提供):
|
||||||
|
- 历史对话或任务的主题;
|
||||||
|
- 历史中出现的关键实体(时间、人物、地点、研究主题等);
|
||||||
|
- 历史中已解答的问题(避免重复);
|
||||||
|
- 历史推理链(保持逻辑一致性)。
|
||||||
|
|
||||||
|
> 如果没有提供历史信息,则仅根据当前输入问题进行分析。
|
||||||
|
输入历史信息内容:{{history}}
|
||||||
|
|
||||||
|
## User Input
|
||||||
|
{{ sentence }}
|
||||||
|
|
||||||
|
## 需求:
|
||||||
|
1:首先判断类型(单跳、多跳、开放域、时间)。
|
||||||
|
2:根据类型进行拆分。
|
||||||
|
3:拆分后的内容需保证信息完整且可独立处理。
|
||||||
|
4:对每个拆分条目,可附加示例或说明。
|
||||||
|
5:拆分问题的时候可以考虑输入的历史内容,以保持逻辑连贯。
|
||||||
|
比如:输入历史信息内容:[{'Query': '4月27日,我和你推荐过一本书,书名是什么?', 'Answer': '张曼玉推荐了《小王子》'}]
|
||||||
|
拆分问题:4月27日,我和你推荐过一本书,书名是什么?,可以拆分为:4月27日,张曼玉推荐过一本书,书名是什么?
|
||||||
|
|
||||||
|
## 指令:
|
||||||
|
你是一个智能数据拆分助手,请根据数据特性判断输入属于哪种类型:
|
||||||
|
单跳(Single-hop)
|
||||||
|
描述:问题或数据只需要通过一步即可得到答案或完成拆分,不依赖其他信息。
|
||||||
|
拆分策略:直接识别核心信息或关键字段,生成可独立处理的片段。
|
||||||
|
示例:
|
||||||
|
输入数据:"请列出今年诺贝尔物理学奖的得主"
|
||||||
|
拆分结果:[
|
||||||
|
{
|
||||||
|
"id": "Q1",
|
||||||
|
"question": "今年诺贝尔物理学奖得主是谁",
|
||||||
|
"type": "单跳’",
|
||||||
|
}
|
||||||
|
]
|
||||||
|
注意: 当遇到上下文依赖问题时,明确指出缺失的信息类型并且,question可填写输入问题
|
||||||
|
多跳(Multi-hop):
|
||||||
|
描述:问题或数据需要通过多步推理或跨多个信息源才能得到答案。
|
||||||
|
拆分策略:将问题拆解为多个子问题,每个子问题对应一个独立处理步骤,需要具备推理链条与逻辑连接数量。
|
||||||
|
示例:
|
||||||
|
输入数据:"今年诺贝尔物理学奖得主的研究领域及代表性成果"
|
||||||
|
拆分结果:
|
||||||
|
[
|
||||||
|
{
|
||||||
|
"id": "Q1",
|
||||||
|
"question": 今年诺贝尔物理学奖得主是谁?",
|
||||||
|
"type": "多跳’",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": "Q2",
|
||||||
|
"question": "该得主的研究领域是什么?",
|
||||||
|
"type": "多跳’",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": "Q3",
|
||||||
|
"question": "该得主的代表性成果有哪些?",
|
||||||
|
"type": "多跳’"
|
||||||
|
}
|
||||||
|
]
|
||||||
|
开放域(Open-domain):
|
||||||
|
描述:问题或数据不局限于特定知识库,需要从大范围信息中检索和生成答案,而不是从一个已知的小范围数据源中查找。
|
||||||
|
拆分策略:根据主题或关键实体拆分,同时保留上下文以便检索外部知识,问题涉及一般性、常识性、跨学科内容,可能是开放式回答(描述性、推理性、综合性)
|
||||||
|
需要外部知识检索或推理才能确定,比如:“为什么人类需要睡眠?”、“量子计算与经典计算的主要区别是什么?”。
|
||||||
|
示例:
|
||||||
|
输入数据:"介绍量子计算的最新研究进展"
|
||||||
|
拆分结果:
|
||||||
|
[
|
||||||
|
{
|
||||||
|
"id": "Q1",
|
||||||
|
"question": 量子计算的基本概念是什么?",
|
||||||
|
"type": "开放域’",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": "Q2",
|
||||||
|
"question": "当前量子计算的主要研究方向有哪些?",
|
||||||
|
"type": "开放域’",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": "Q3",
|
||||||
|
"question": "近期在量子计算领域有哪些重大进展?",
|
||||||
|
"type": "开放域’",
|
||||||
|
}
|
||||||
|
]
|
||||||
|
|
||||||
|
时间(Temporal):
|
||||||
|
描述:问题或数据涉及时间维度,需要按时间顺序或时间点拆分。
|
||||||
|
拆分策略:根据事件时间或时间段拆分为独立条目或问题。
|
||||||
|
示例:
|
||||||
|
输入数据:"列出苹果公司过去五年的重大事件"
|
||||||
|
拆分结果:
|
||||||
|
[
|
||||||
|
{
|
||||||
|
"id": "Q1",
|
||||||
|
"question": 苹果公司2019年的重大事件有哪些?",
|
||||||
|
"type": "时间’",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": "Q2",
|
||||||
|
"question": "苹果公司2020年的重大事件有哪些?",
|
||||||
|
"type": "时间’",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": "Q3",
|
||||||
|
"question": "苹果公司2021年的重大事件有哪些?",
|
||||||
|
"type": "时间’",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": "Q3",
|
||||||
|
"question": "苹果公司2022年的重大事件有哪些?",
|
||||||
|
"type": "时间’",
|
||||||
|
}
|
||||||
|
,
|
||||||
|
{
|
||||||
|
"id": "Q4",
|
||||||
|
"question": "苹果公司2023年的重大事件有哪些?",
|
||||||
|
"type": "时间’",
|
||||||
|
}
|
||||||
|
]
|
||||||
|
|
||||||
|
输出要求:
|
||||||
|
- 每个子问题包括:
|
||||||
|
- `id`: 子问题编号(Q1, Q2...)
|
||||||
|
- `question`: 子问题内容
|
||||||
|
- `type`: 类型(事实检索 / 澄清 / 定义 / 比较 / 行动建议等)
|
||||||
|
- `reason`: 拆分的理由(为什么要这样拆)
|
||||||
|
- 格式案例:
|
||||||
|
[
|
||||||
|
{
|
||||||
|
"id": "Q1",
|
||||||
|
"question": 量子计算的基本概念是什么?",
|
||||||
|
"type": "开放域’",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": "Q2",
|
||||||
|
"question": "当前量子计算的主要研究方向有哪些?",
|
||||||
|
"type": "开放域’",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": "Q3",
|
||||||
|
"question": "近期在量子计算领域有哪些重大进展?",
|
||||||
|
"type": "开放域’",
|
||||||
|
}
|
||||||
|
]
|
||||||
|
- 必须通过json.loads()的格式支持的形式输出
|
||||||
|
- 必须通过json.loads()的格式支持的形式输出,响应必须是与此确切模式匹配的有效JSON对象。不要在JSON之前或之后包含任何文本。
|
||||||
|
- 关键的JSON格式要求
|
||||||
|
1.JSON结构仅使用标准ASCII双引号(")-切勿使用中文引号(“”)或其他Unicode引号
|
||||||
|
2.如果提取的语句文本包含引号,请使用反斜杠(\")正确转义它们
|
||||||
|
3.确保所有JSON字符串都正确关闭并以逗号分隔
|
||||||
|
4.JSON字符串值中不包括换行符
|
||||||
|
5.正确转义的例子:"statement":"Zhang Xinhua said:\"我非常喜欢这本书\""
|
||||||
|
6.不允许输出```json```相关符号,如```json```、``````、```python```、```javascript```、```html```、```css```、```sql```、```java```、```c```、```c++```、```c#```、```ruby```
|
||||||
@@ -0,0 +1,60 @@
|
|||||||
|
# 角色
|
||||||
|
你是验证专家
|
||||||
|
你的目标是针对用户的输入Query_Samll字段的提问和Answer_Samll的回答分析,是不是回答Query_Samll这个字段的问题
|
||||||
|
|
||||||
|
{#以下可以采用先总括,再展开详细说明的方式,描述你希望智能体在每一个步骤如何进行工作,具体的工作步骤数量可以根据实际需求增删#}
|
||||||
|
## 工作步骤
|
||||||
|
1. 获取所有的Query_Samll字段和Answer_Samll字段
|
||||||
|
2. 分析Answer_Samll的回复是不是和Query_Samll有关系
|
||||||
|
3. 判断Answer_Samll和Query_Samll之间分析出来的关系状态
|
||||||
|
4. 如果是True保留,否则不要相对应的问题和回答
|
||||||
|
5. 输出,需要严格按照模版
|
||||||
|
输入:{{history}}
|
||||||
|
历史消息:{"history":{{sentence}}}
|
||||||
|
### 第一步 获取用户的输入
|
||||||
|
获取用户的输入提取对应的Query_Samll和Answer_Samll
|
||||||
|
### 第二步 分析验证
|
||||||
|
需要分析Query_Samll和Answer_Samll之间的关系可以参考history字段的内容,如果有关系不是答非所问
|
||||||
|
## 核心验证标准
|
||||||
|
在评估子问题拆分时,必须严格遵循以下标准,且验证过程中完全不依赖于子问题的相关信息(Answer_Samll):
|
||||||
|
1. 合理性标准(必须全部满足):
|
||||||
|
- 完整性:每个不同的子问题必须完整覆盖原问题的所有关键要素(如时间、主体、动作、目标等),无遗漏。
|
||||||
|
- 最小化:每个不同的子问题数量应尽可能少,通常不超过原问题关键要素数量的2倍(建议2-4个),避免冗余和不必要拆分。
|
||||||
|
- 相关性:每个不同的子问题必须直接服务于原问题的解答,不引入无关内容或扩展原问题未提及的主题。
|
||||||
|
- 可操作性:每个不同的子问题应能在有限资源(如标准工具或合理时间)内独立解答,且难度适中。
|
||||||
|
- 逻辑性:每个不同的子问题间应有清晰的逻辑关系(如并列、递进、因果),共同构成原问题的解答路径。
|
||||||
|
|
||||||
|
2. 不合理拆分的特征(出现任一特征即为不合理):
|
||||||
|
- 不同的子问题数量超过5个或明显多于必要数量。
|
||||||
|
- 引入原问题未提及的新主题、人物、细节或个人看法。
|
||||||
|
- 拆分过于细碎,失去实用价值,无法高效合成原问题答案。
|
||||||
|
|
||||||
|
3. 特殊情况说明:
|
||||||
|
- 每个不同的子问题与原问题相同,需进一步判断:
|
||||||
|
- 每个不同的子问题不可进一步拆分 → success(合理,最小化拆分)
|
||||||
|
- 每个不同的子问题能够进一步拆分为更小、更合理的问题 → failed(不合理,拆分没有最小化)
|
||||||
|
- 每个不同的子问题数量=原问题核心要素数量 → success(理想情况)
|
||||||
|
- 每个不同的子问题数量=核心要素数量+1 → success(通常合理)
|
||||||
|
|
||||||
|
### 第三步 添加状态
|
||||||
|
如果有相关性并且比较高给一个状态TRUE,否则给一个FALSE的状态
|
||||||
|
### 第四步 判断
|
||||||
|
如果状态是TRUE保留这条数据,否则不需要这条数据
|
||||||
|
### 第五步 输出格式
|
||||||
|
按照json的形式输出
|
||||||
|
{"data":"Query":原来Query的字段,"history":原来的history字段,
|
||||||
|
"expansion_issue":以为列表的形式存储验证之后的数据比如[
|
||||||
|
{"query_small": query_small,
|
||||||
|
"answer_small": answer_small,,
|
||||||
|
"status": 回答的结果是否符合query_small,填写状态,
|
||||||
|
"query_answer": answer_small},
|
||||||
|
{
|
||||||
|
"query_small": "张曼婷生日是什么时候?",
|
||||||
|
"answer_small": "张曼婷喜欢绘画。",
|
||||||
|
"status": "True",
|
||||||
|
"query_answer": "张曼 婷喜欢绘画。"
|
||||||
|
},{}......]
|
||||||
|
,
|
||||||
|
"split_result":如果expansion_issue是空的列表返回failed,不是空列表返回success,
|
||||||
|
"reason": 为以上分析完之后的结果给一个说明
|
||||||
|
}
|
||||||
57
app/core/memory/agent/utils/prompt/summary_prompt.jinja2
Normal file
57
app/core/memory/agent/utils/prompt/summary_prompt.jinja2
Normal file
@@ -0,0 +1,57 @@
|
|||||||
|
{# 角色定义 #}
|
||||||
|
你是专业的问题解答专家,负责根据上下文信息和检索到的所有信息准确回答用户的问题。
|
||||||
|
|
||||||
|
{# 输入数据展示 #}
|
||||||
|
{% if data %}
|
||||||
|
## 输入数据
|
||||||
|
上下文信息:
|
||||||
|
{% for item in data.history %}
|
||||||
|
- {{ item }}
|
||||||
|
{% endfor %}
|
||||||
|
检索到的所有信息:
|
||||||
|
{% for item in data.retrieve_info %}
|
||||||
|
- {{ item }}
|
||||||
|
{% endfor %}
|
||||||
|
{% endif %}
|
||||||
|
|
||||||
|
## User Query
|
||||||
|
{{ query }}
|
||||||
|
|
||||||
|
{# 问题回答标准 #}
|
||||||
|
## 问题回答核心标准
|
||||||
|
根据上下文信息(history)和检索到的所有信息(retrieve_info)准确回答用户的问题(query)。注意,若不能根据已有信息回答用户的问题,应直接回复“信息不足,无法回答。”,不能自己编造答案。
|
||||||
|
- 若能根据已有信息回答用户的问题,应根据上下文信息和检索到的所有信息提供简明扼要的答案。
|
||||||
|
- 若不能根据已有信息回答用户的问题,应直接回复“信息不足,无法回答。”,不能自己编造答案。
|
||||||
|
|
||||||
|
{# 重要提醒 #}
|
||||||
|
再次提醒,给出问题的答案时,仅根据已有的信息进行回答,不能自己编造答案。
|
||||||
|
|
||||||
|
{# 输出格式模板 #}
|
||||||
|
## 输出格式
|
||||||
|
严格按照以下JSON格式输出,不添加任何其他内容:
|
||||||
|
{
|
||||||
|
"data": {
|
||||||
|
"query": "{{ query }}",
|
||||||
|
"history": [
|
||||||
|
{% for item in data.history %}
|
||||||
|
"{{ item | replace('"', '\\"') }}"
|
||||||
|
{% if not loop.last %},{% endif %}
|
||||||
|
{% endfor %}
|
||||||
|
],
|
||||||
|
"retrieve_info": [
|
||||||
|
{% for item in data.retrieve_info %}
|
||||||
|
"{{ item | replace('"', '\\"') }}"
|
||||||
|
{% if not loop.last %},{% endif %}
|
||||||
|
{% endfor %}
|
||||||
|
]
|
||||||
|
},
|
||||||
|
"query_answer": "{% if not data.history and not data.retrieve_info %}信息不足,无法回答。{% endif %}"
|
||||||
|
}
|
||||||
|
**Output format**
|
||||||
|
**CRITICAL JSON FORMATTING REQUIREMENTS:**
|
||||||
|
1. Use only standard ASCII double quotes (") for JSON structure - never use Chinese quotation marks ("") or other Unicode quotes
|
||||||
|
2. If the extracted statement text contains quotation marks, escape them properly using backslashes (\")
|
||||||
|
3. Ensure all JSON strings are properly closed and comma-separated
|
||||||
|
4. Do not include line breaks within JSON string values
|
||||||
|
|
||||||
|
The output language should always be the same as the input language.{{ json_schema }}
|
||||||
203
app/core/memory/agent/utils/redis_tool.py
Normal file
203
app/core/memory/agent/utils/redis_tool.py
Normal file
@@ -0,0 +1,203 @@
|
|||||||
|
import redis
|
||||||
|
import uuid
|
||||||
|
from datetime import datetime
|
||||||
|
from app.core.config import settings
|
||||||
|
class RedisSessionStore:
|
||||||
|
def __init__(self, host='localhost', port=6379, db=0, password=None,session_id=''):
|
||||||
|
self.r = redis.Redis(host=host, port=port, db=db, password=password)
|
||||||
|
self.uudi=session_id
|
||||||
|
|
||||||
|
|
||||||
|
# 修改后的 save_session 方法
|
||||||
|
def save_session(self, userid, messages, aimessages, apply_id, group_id):
|
||||||
|
"""
|
||||||
|
写入一条会话数据,返回 session_id
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
session_id = str(uuid.uuid4()) # 为每次会话生成新的 ID
|
||||||
|
starttime = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
|
||||||
|
key = f"session:{session_id}" # 使用新生成的 session_id 作为 key
|
||||||
|
|
||||||
|
# 使用 Hash 存储结构化数据
|
||||||
|
result = self.r.hset(key, mapping={
|
||||||
|
"id": self.uudi,
|
||||||
|
"sessionid": userid,
|
||||||
|
"apply_id": apply_id,
|
||||||
|
"group_id": group_id,
|
||||||
|
"messages": messages,
|
||||||
|
"aimessages": aimessages,
|
||||||
|
"starttime": starttime
|
||||||
|
})
|
||||||
|
print(f"保存结果: {result}, session_id: {session_id}")
|
||||||
|
return session_id # 返回新生成的 session_id
|
||||||
|
except Exception as e:
|
||||||
|
print(f"保存会话失败: {e}")
|
||||||
|
raise e
|
||||||
|
|
||||||
|
# ---------------- 读取 ----------------
|
||||||
|
def get_session(self, session_id):
|
||||||
|
"""
|
||||||
|
读取一条会话数据
|
||||||
|
"""
|
||||||
|
key = f"session:{session_id}"
|
||||||
|
data = self.r.hgetall(key)
|
||||||
|
if data:
|
||||||
|
return {k.decode('utf-8'): v.decode('utf-8') for k, v in data.items()}
|
||||||
|
return None
|
||||||
|
|
||||||
|
def get_session_apply_group(self, sessionid, apply_id, group_id):
|
||||||
|
"""
|
||||||
|
根据 sessionid、apply_id 和 group_id 三个条件查询会话数据
|
||||||
|
"""
|
||||||
|
result_items = []
|
||||||
|
|
||||||
|
# 遍历所有会话数据
|
||||||
|
for key_bytes in self.r.keys('session:*'):
|
||||||
|
key = key_bytes.decode('utf-8')
|
||||||
|
data = self.r.hgetall(key)
|
||||||
|
|
||||||
|
if not data:
|
||||||
|
continue
|
||||||
|
|
||||||
|
# 解码数据
|
||||||
|
decoded_data = {k.decode('utf-8'): v.decode('utf-8') for k, v in data.items()}
|
||||||
|
|
||||||
|
# 检查三个条件是否都匹配
|
||||||
|
if (decoded_data.get('sessionid') == sessionid and
|
||||||
|
decoded_data.get('apply_id') == apply_id and
|
||||||
|
decoded_data.get('group_id') == group_id):
|
||||||
|
result_items.append(decoded_data)
|
||||||
|
|
||||||
|
return result_items
|
||||||
|
|
||||||
|
def get_all_sessions(self):
|
||||||
|
"""
|
||||||
|
获取所有会话数据
|
||||||
|
"""
|
||||||
|
sessions = {}
|
||||||
|
for key in self.r.keys('session:*'):
|
||||||
|
sid = key.decode('utf-8').split(':')[1]
|
||||||
|
sessions[sid] = self.get_session(sid)
|
||||||
|
return sessions
|
||||||
|
|
||||||
|
# ---------------- 更新 ----------------
|
||||||
|
def update_session(self, session_id, field, value):
|
||||||
|
"""
|
||||||
|
更新单个字段
|
||||||
|
"""
|
||||||
|
key = f"session:{session_id}"
|
||||||
|
if self.r.exists(key):
|
||||||
|
self.r.hset(key, field, value)
|
||||||
|
return True
|
||||||
|
return False
|
||||||
|
|
||||||
|
# ---------------- 删除 ----------------
|
||||||
|
def delete_session(self, session_id):
|
||||||
|
"""
|
||||||
|
删除单条会话
|
||||||
|
"""
|
||||||
|
key = f"session:{session_id}"
|
||||||
|
return self.r.delete(key)
|
||||||
|
|
||||||
|
def delete_all_sessions(self):
|
||||||
|
"""
|
||||||
|
删除所有会话
|
||||||
|
"""
|
||||||
|
keys = self.r.keys('session:*')
|
||||||
|
if keys:
|
||||||
|
return self.r.delete(*keys)
|
||||||
|
return 0
|
||||||
|
|
||||||
|
def delete_duplicate_sessions(self):
|
||||||
|
"""
|
||||||
|
删除重复会话数据,条件:
|
||||||
|
"sessionid"、"user_id"、"group_id"、"messages"、"aimessages" 五个字段都相同的只保留一个,其他删除
|
||||||
|
"""
|
||||||
|
seen = set() # 用来记录已出现的唯一组合
|
||||||
|
deleted_count = 0
|
||||||
|
|
||||||
|
for key_bytes in self.r.keys('session:*'):
|
||||||
|
key = key_bytes.decode('utf-8')
|
||||||
|
data = self.r.hgetall(key)
|
||||||
|
if not data:
|
||||||
|
continue
|
||||||
|
|
||||||
|
# 获取五个字段的值并解码
|
||||||
|
sessionid = data.get(b'sessionid', b'').decode('utf-8')
|
||||||
|
user_id = data.get(b'id', b'').decode('utf-8') # 对应user_id
|
||||||
|
group_id = data.get(b'group_id', b'').decode('utf-8')
|
||||||
|
messages = data.get(b'messages', b'').decode('utf-8')
|
||||||
|
aimessages = data.get(b'aimessages', b'').decode('utf-8')
|
||||||
|
|
||||||
|
# 用五元组作为唯一标识
|
||||||
|
identifier = (sessionid, user_id, group_id, messages, aimessages)
|
||||||
|
|
||||||
|
if identifier in seen:
|
||||||
|
# 重复,删除该 key
|
||||||
|
self.r.delete(key)
|
||||||
|
deleted_count += 1
|
||||||
|
else:
|
||||||
|
# 第一次出现,加入 seen
|
||||||
|
seen.add(identifier)
|
||||||
|
|
||||||
|
print(f"[delete_duplicate_sessions] 删除重复会话数量: {deleted_count}")
|
||||||
|
return deleted_count
|
||||||
|
|
||||||
|
def find_user_session(self, sessionid):
    """
    Collect the question/answer history stored under one session id.

    Args:
        sessionid: Session identifier; compared with ``str(values['sessionid'])``
            of every stored session.

    Returns:
        list[dict]: ``{"Query": ..., "Answer": ...}`` entries for the matching
        sessions. An empty list is returned when at most one entry matches.
    """
    # Read from this instance instead of the module-level ``store`` singleton,
    # so the method works for any RedisSessionStore and does not depend on a
    # global that is only bound after the class body.
    history = [
        {"Query": values['messages'], "Answer": values['aimessages']}
        for values in self.get_all_sessions().values()
        if sessionid == str(values['sessionid'])
    ]

    # A single matching record is deliberately treated as "no history".
    return history if len(history) > 1 else []
|
||||||
|
|
||||||
|
def find_user_apply_group(self, sessionid, apply_id, group_id):
    """
    Look up session history by sessionid, apply_id and group_id.

    Every ``session:*`` hash is decoded from UTF-8 and kept when all three
    fields equal the given values.

    Returns a list of ``{"Query": ..., "Answer": ...}`` dicts, or an empty
    list when at most one record matches.
    """
    wanted = {'sessionid': sessionid, 'apply_id': apply_id, 'group_id': group_id}
    matches = []

    # Walk over every stored session.
    for raw_key in self.r.keys('session:*'):
        record = self.r.hgetall(raw_key.decode('utf-8'))
        if not record:
            continue

        fields = {
            name.decode('utf-8'): value.decode('utf-8')
            for name, value in record.items()
        }

        # All three criteria have to match.
        if all(fields.get(name) == expected for name, expected in wanted.items()):
            matches.append({
                "Query": fields.get('messages'),
                "Answer": fields.get('aimessages'),
            })

    # One result or fewer is reported as an empty list.
    return matches if len(matches) > 1 else []
|
||||||
|
|
||||||
|
# Module-level RedisSessionStore instance configured from the application settings.
# A falsy REDIS_PASSWORD is passed as None, and session_id is a random UUID
# string generated once, when this module is imported.
store = RedisSessionStore(
    host=settings.REDIS_HOST,
    port=settings.REDIS_PORT,
    db=settings.REDIS_DB,
    password=settings.REDIS_PASSWORD if settings.REDIS_PASSWORD else None,
    session_id=str(uuid.uuid4())
)
|
||||||
59
app/core/memory/agent/utils/type_classifier.py
Normal file
59
app/core/memory/agent/utils/type_classifier.py
Normal file
@@ -0,0 +1,59 @@
|
|||||||
|
"""
|
||||||
|
Type classification utility for distinguishing read/write operations.
|
||||||
|
"""
|
||||||
|
from jinja2 import Template
|
||||||
|
from pydantic import BaseModel
|
||||||
|
|
||||||
|
from app.core.logging_config import get_agent_logger, log_prompt_rendering
|
||||||
|
from app.core.memory.agent.utils.llm_tools import PROJECT_ROOT_
|
||||||
|
from app.core.memory.agent.utils.messages_tool import read_template_file
|
||||||
|
from app.core.memory.utils.llm.llm_utils import get_llm_client
|
||||||
|
from app.core.config import settings
|
||||||
|
|
||||||
|
|
||||||
|
logger = get_agent_logger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
class DistinguishTypeResponse(BaseModel):
    """Response model for type classification; used as the structured-output schema in status_typle."""
    # Classification label produced for the message.
    type: str
|
||||||
|
|
||||||
|
|
||||||
|
async def status_typle(messages: str) -> dict:
    """
    Classify a user message as a read or a write operation.

    The classification prompt is rendered from the
    ``distinguish_types_prompt.jinja2`` template and sent to the selected LLM
    as a system message, asking for a ``DistinguishTypeResponse``.

    Args:
        messages: User message to classify.

    Returns:
        dict: The dumped response model (with a ``type`` field) on success.
        On failure, ``{"type": "error", "message": ...}`` describing whether
        prompt rendering or the LLM call failed.
    """
    try:
        prompt_path = PROJECT_ROOT_ + '/agent/utils/prompt/distinguish_types_prompt.jinja2'
        rendered_prompt = Template(await read_template_file(prompt_path)).render(user_query=messages)
        log_prompt_rendering("status_typle", rendered_prompt)
    except Exception as render_error:
        logger.error(f"Template rendering failed for status_typle: {render_error}", exc_info=True)
        failure = {
            "type": "error",
            "message": f"Prompt rendering failed: {str(render_error)}"
        }
        return failure

    # Resolve the client for the LLM id currently held by the config definitions module.
    from app.core.memory.utils.config import definitions as config_defs
    client = get_llm_client(config_defs.SELECTED_LLM_ID)

    try:
        system_messages = [{"role": "system", "content": rendered_prompt}]
        classification = await client.response_structured(
            messages=system_messages,
            response_model=DistinguishTypeResponse,
        )
        return classification.model_dump()
    except Exception as llm_error:
        logger.error(f"LLM call failed for status_typle: {llm_error}", exc_info=True)
        failure = {
            "type": "error",
            "message": f"LLM call failed: {str(llm_error)}"
        }
        return failure
|
||||||
76
app/core/memory/agent/utils/verify_tool.py
Normal file
76
app/core/memory/agent/utils/verify_tool.py
Normal file
@@ -0,0 +1,76 @@
|
|||||||
|
from typing import TypedDict, Annotated, List, Any
|
||||||
|
from langchain_core.messages import AnyMessage
|
||||||
|
from langgraph.constants import START, END
|
||||||
|
from langgraph.graph import StateGraph, add_messages
|
||||||
|
import asyncio
|
||||||
|
import json
|
||||||
|
from dotenv import load_dotenv, find_dotenv
|
||||||
|
import os
|
||||||
|
from app.core.memory.agent.utils.llm_tools import PROJECT_ROOT_
|
||||||
|
from langchain_core.messages import HumanMessage
|
||||||
|
from jinja2 import Environment, FileSystemLoader
|
||||||
|
from app.core.memory.agent.utils.messages_tool import _to_openai_messages
|
||||||
|
from app.core.memory.utils.llm.llm_utils import get_llm_client
|
||||||
|
from app.core.memory.utils.config.definitions import SELECTED_LLM_ID
|
||||||
|
from app.core.logging_config import get_agent_logger
|
||||||
|
|
||||||
|
load_dotenv(find_dotenv())
|
||||||
|
|
||||||
|
logger = get_agent_logger(__name__)
|
||||||
|
|
||||||
|
def keep_last(_, right):
    """Reducer that ignores the previous value and returns the newest one (``right``)."""
    return right
|
||||||
|
class State(TypedDict):
    """State dictionary passed between the nodes of the verification graph."""
    # Input under verification; the keep_last reducer makes each update replace the old value.
    user_input: Annotated[dict, keep_last]
    # Conversation messages, merged with langgraph's add_messages reducer.
    messages: Annotated[List[AnyMessage], add_messages]
    # Per-agent outputs; VerifyTool.model_1 fills agent1_response.
    agent1_response: str
    agent2_response: str
    agent3_response: str
    # Final answer slot (initialised to "" by VerifyTool.verify).
    final_response: str
    # Processing status; the latest written value wins.
    status: Annotated[str, keep_last]
|
||||||
|
|
||||||
|
|
||||||
|
class VerifyTool:
    """Single-node LangGraph workflow that sends the data to verify to the selected LLM."""

    def __init__(self, system_prompt: str="", verify_data: Any=None):
        """
        Args:
            system_prompt: System prompt placed in front of the conversation.
            verify_data: Data to verify. Strings are used unchanged; anything
                else is JSON-encoded, falling back to ``str()`` when that fails.
        """
        self.system_prompt = system_prompt
        self.verify_data = self._as_text(verify_data)

    @staticmethod
    def _as_text(payload: Any) -> str:
        """Return ``payload`` as text: unchanged str, JSON dump, or ``str()`` fallback."""
        if isinstance(payload, str):
            return payload
        try:
            return json.dumps(payload, ensure_ascii=False)
        except Exception:
            return str(payload)

    async def model_1(self, state: State) -> State:
        """Graph node: ask the LLM with the system prompt plus the state's messages."""
        client = get_llm_client(SELECTED_LLM_ID)
        system_message = {"role": "system", "content": self.system_prompt}
        conversation = [system_message] + _to_openai_messages(state["messages"])
        reply = await client.chat(messages=conversation)
        return {
            "agent1_response": reply,
            "status": "processed",
        }

    def get_graph(self):
        """Build and compile the START -> model_1 -> END graph."""
        workflow = StateGraph(State)
        workflow.add_node("model_1", self.model_1)

        workflow.add_edge(START, "model_1")
        workflow.add_edge("model_1", END)

        return workflow.compile()

    async def verify(self):
        """Run the graph on ``verify_data`` and return the ``agent1_response`` value."""
        compiled = self.get_graph()
        outcome = await compiled.ainvoke({
            "user_input": self.verify_data,
            "messages": [HumanMessage(content=self.verify_data)],
            "final_response": "",
            "status": ""
        })
        return outcome["agent1_response"]
|
||||||
|
|
||||||
49
app/core/memory/agent/utils/write_to_database.py
Normal file
49
app/core/memory/agent/utils/write_to_database.py
Normal file
@@ -0,0 +1,49 @@
|
|||||||
|
import os
|
||||||
|
import uuid
|
||||||
|
from datetime import datetime
|
||||||
|
from typing import Any
|
||||||
|
from sqlalchemy.orm import Session
|
||||||
|
import logging
|
||||||
|
import json
|
||||||
|
|
||||||
|
from app.db import get_db
|
||||||
|
from app.models.retrieval_info import RetrievalInfo
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
async def write_to_database(host_id: uuid.UUID, data: Any) -> str:
    """
    Write one retrieval record to the database.

    Args:
        host_id: ID of the host the record belongs to.
        data: Payload to store. dict/list values are JSON-encoded (non-ASCII
            characters kept), strings are stored unchanged, and any other
            value is stored via ``str()``.

    Returns:
        A fixed success message.

    Raises:
        Exception: Any error raised while serializing or writing. The
            transaction is rolled back and the error logged before it is
            re-raised.
    """
    # Obtain a session from the database session generator.
    db: Session = next(get_db())
    try:
        if isinstance(data, (dict, list)):
            serialized = json.dumps(data, ensure_ascii=False)
        elif isinstance(data, str):
            serialized = data
        else:
            serialized = str(data)

        new_retrieval_info = RetrievalInfo(
            # Store the caller-supplied host. The previous code wrote a
            # hard-coded debugging UUID for every record and ignored host_id.
            host_id=host_id,
            retrieve_info=serialized,
            created_at=datetime.now()
        )
        db.add(new_retrieval_info)
        db.commit()
        logger.info(f"success to write data to database, host_id: {host_id}, retrieve_info: {serialized}")
        return "success to write data to database"
    except Exception as e:
        db.rollback()
        logger.error(f"failed to write data to database, host_id: {host_id}, retrieve_info: {data}, error: {e}")
        # Bare raise keeps the original traceback intact.
        raise
    finally:
        try:
            db.close()
        except Exception:
            # Closing is best-effort; record the failure instead of hiding it.
            logger.warning("failed to close database session", exc_info=True)
|
||||||
183
app/core/memory/agent/utils/write_tools.py
Normal file
183
app/core/memory/agent/utils/write_tools.py
Normal file
@@ -0,0 +1,183 @@
|
|||||||
|
import asyncio
|
||||||
|
from dotenv import load_dotenv
|
||||||
|
import time
|
||||||
|
from datetime import datetime
|
||||||
|
|
||||||
|
from app.repositories.neo4j.graph_saver import save_dialog_and_statements_to_neo4j
|
||||||
|
|
||||||
|
from app.core.memory.agent.utils.get_dialogs import get_chunked_dialogs
|
||||||
|
from app.core.logging_config import get_agent_logger
|
||||||
|
|
||||||
|
logger = get_agent_logger(__name__)
|
||||||
|
# 使用新的模块化架构
|
||||||
|
from app.core.memory.storage_services.extraction_engine.extraction_orchestrator import ExtractionOrchestrator
|
||||||
|
from app.core.memory.storage_services.extraction_engine.knowledge_extraction.embedding_generation import (
|
||||||
|
embedding_generation_all,
|
||||||
|
)
|
||||||
|
|
||||||
|
# 使用新的仓储层
|
||||||
|
from app.repositories.neo4j.neo4j_connector import Neo4jConnector
|
||||||
|
# 导入配置模块(而不是直接导入变量)
|
||||||
|
from app.core.memory.utils.config import definitions as config_defs
|
||||||
|
from app.core.memory.utils.llm.llm_utils import get_llm_client
|
||||||
|
from app.core.memory.utils.log.logging_utils import log_time
|
||||||
|
from app.core.memory.storage_services.extraction_engine.knowledge_extraction.memory_summary import Memory_summary_generation
|
||||||
|
from app.repositories.neo4j.add_nodes import add_memory_summary_nodes
|
||||||
|
from app.repositories.neo4j.add_edges import add_memory_summary_statement_edges
|
||||||
|
load_dotenv()
|
||||||
|
|
||||||
|
|
||||||
|
async def write(content: str, user_id: str, apply_id: str, group_id: str, ref_id: str = "wyl20251027", config_id: str = None) -> None:
    """
    Run the complete knowledge-extraction pipeline (using ExtractionOrchestrator).

    Steps: chunk the dialogue, run the extraction orchestrator, create the
    full-text indexes, save the graph to Neo4j, then generate memory
    summaries and store them in Neo4j. Step timings are appended to
    ``logs/time.log``.

    Args:
        content: Dialogue content.
        user_id: User ID.
        apply_id: Application ID.
        group_id: Group ID.
        ref_id: Reference ID, defaults to "wyl20251027".
        config_id: Config ID used to tag the data-processing configuration.

    Raises:
        Exception: Errors from chunking, extraction or the Neo4j save
            propagate to the caller (after the connector has been closed).
            Index-creation and memory-summary failures are only logged.
    """
    import os  # local import: only needed to create the timing-log directory

    logger.info("=== MemSci Knowledge Extraction Pipeline ===")
    logger.info(f"Using model: {config_defs.SELECTED_LLM_NAME}")
    logger.info(f"Using LLM ID: {config_defs.SELECTED_LLM_ID}")
    logger.info(f"Using chunker strategy: {config_defs.SELECTED_CHUNKER_STRATEGY}")
    logger.info(f"Using group ID: {config_defs.SELECTED_GROUP_ID}")
    logger.info(f"Using embedding ID: {config_defs.SELECTED_EMBEDDING_ID}")
    logger.info(f"Config ID: {config_id if config_id else 'None'}")
    logger.info(f"LANGFUSE_ENABLED: {config_defs.LANGFUSE_ENABLED}")
    logger.info(f"AGENTA_ENABLED: {config_defs.AGENTA_ENABLED}")

    # Initialize timing log
    log_file = "logs/time.log"
    # Opening in append mode fails when the directory is missing, so create it first.
    os.makedirs(os.path.dirname(log_file), exist_ok=True)
    timestamp = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
    with open(log_file, "a", encoding="utf-8") as f:
        f.write(f"\n=== Pipeline Run Started: {timestamp} ===\n")

    pipeline_start = time.time()

    # Initialize the clients
    llm_client = get_llm_client(config_defs.SELECTED_LLM_ID)

    # Build the embedder from its stored configuration
    from app.core.models.base import RedBearModelConfig
    from app.core.memory.utils.config.config_utils import get_embedder_config
    from app.core.memory.src.llm_tools.openai_embedder import OpenAIEmbedderClient

    embedder_config_dict = get_embedder_config(config_defs.SELECTED_EMBEDDING_ID)
    embedder_config = RedBearModelConfig(**embedder_config_dict)
    embedder_client = OpenAIEmbedderClient(embedder_config)

    neo4j_connector = Neo4jConnector()
    # Everything that uses the connector runs inside try/finally so it is
    # closed even when chunking or extraction raises.
    try:
        # Step 1: load and chunk the data
        step_start = time.time()
        chunked_dialogs = await get_chunked_dialogs(
            chunker_strategy=config_defs.SELECTED_CHUNKER_STRATEGY,
            group_id=group_id,
            user_id=user_id,
            apply_id=apply_id,
            content=content,
            ref_id=ref_id,
            config_id=config_id,
        )
        log_time("Data Loading & Chunking", time.time() - step_start, log_file)

        # Step 2: initialize and run the ExtractionOrchestrator
        step_start = time.time()
        from app.core.memory.utils.config.config_utils import get_pipeline_config
        config = get_pipeline_config()

        orchestrator = ExtractionOrchestrator(
            llm_client=llm_client,
            embedder_client=embedder_client,
            connector=neo4j_connector,
            config=config,
        )

        # Run the full extraction pipeline
        # orchestrator.run returns a flat tuple of 7 values after deduplication
        (
            all_dialogue_nodes,
            all_chunk_nodes,
            all_statement_nodes,
            all_entity_nodes,
            all_statement_chunk_edges,
            all_statement_entity_edges,
            all_entity_entity_edges,
        ) = await orchestrator.run(chunked_dialogs, is_pilot_run=False)

        log_time("Extraction Pipeline", time.time() - step_start, log_file)

        # Step 8: Save all data to Neo4j database using graph models
        step_start = time.time()
        # Create the indexes first; a failure here is logged and does not stop the save.
        from app.repositories.neo4j.create_indexes import create_fulltext_indexes
        try:
            await create_fulltext_indexes()
        except Exception as e:
            logger.error(f"Error creating indexes: {e}", exc_info=True)

        success = await save_dialog_and_statements_to_neo4j(
            dialogue_nodes=all_dialogue_nodes,
            chunk_nodes=all_chunk_nodes,
            statement_nodes=all_statement_nodes,
            entity_nodes=all_entity_nodes,
            statement_chunk_edges=all_statement_chunk_edges,
            statement_entity_edges=all_statement_entity_edges,
            entity_edges=all_entity_entity_edges,
            connector=neo4j_connector
        )
        if success:
            logger.info("Successfully saved all data to Neo4j")
        else:
            logger.warning("Failed to save some data to Neo4j")
    finally:
        await neo4j_connector.close()

    log_time("Neo4j Database Save", time.time() - step_start, log_file)

    # Step 9: Generate Memory summaries and save to local vector DB and Neo4j
    step_start = time.time()
    try:
        summaries = await Memory_summary_generation(
            chunked_dialogs, llm_client=llm_client, embedding_id=config_defs.SELECTED_EMBEDDING_ID
        )

        # Save memory summaries to Neo4j as nodes. The connector is created
        # before the inner try so the finally clause never sees an unbound name.
        ms_connector = Neo4jConnector()
        try:
            await add_memory_summary_nodes(summaries, ms_connector)
            # Link summaries to statements via chunks for summary→entity queries
            await add_memory_summary_statement_edges(summaries, ms_connector)
        finally:
            try:
                await ms_connector.close()
            except Exception:
                # Closing is best-effort; record the failure instead of hiding it.
                logger.warning("Failed to close memory summary Neo4j connector", exc_info=True)
    except Exception as e:
        logger.error(f"Memory summary step failed: {e}", exc_info=True)
    finally:
        log_time("Memory Summary (Local Vector DB & Neo4j)", time.time() - step_start, log_file)

    # Log total pipeline time
    total_time = time.time() - pipeline_start
    log_time("TOTAL PIPELINE TIME", total_time, log_file)

    # Add completion marker to log
    timestamp = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
    with open(log_file, "a", encoding="utf-8") as f:
        f.write(f"=== Pipeline Run Completed: {timestamp} ===\n\n")

    logger.info("=== Pipeline Complete ===")
    logger.info(f"Total execution time: {total_time:.2f} seconds")
    logger.info(f"Timing details saved to: {log_file}")
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
    # Manual smoke run. write() requires user_id, apply_id and group_id; the
    # previous call omitted them and failed with TypeError before doing any work.
    # NOTE(review): the three IDs below are placeholders — replace them with
    # real identifiers before running against a shared database.
    content = "你好,我是张三,是张曼婷的新朋友。请问张曼婷喜欢什么?"
    asyncio.run(
        write(
            content,
            user_id="demo_user",
            apply_id="demo_apply",
            group_id="demo_group",
            ref_id="wyl20251027",
        )
    )
|
||||||
19
app/core/memory/llm_tools/__init__.py
Normal file
19
app/core/memory/llm_tools/__init__.py
Normal file
@@ -0,0 +1,19 @@
|
|||||||
|
"""
|
||||||
|
LLM 工具模块
|
||||||
|
|
||||||
|
提供 LLM 和 Embedder 客户端的抽象基类和具体实现。
|
||||||
|
"""
|
||||||
|
|
||||||
|
from app.core.memory.llm_tools.llm_client import LLMClient
|
||||||
|
from app.core.memory.llm_tools.embedder_client import EmbedderClient
|
||||||
|
from app.core.memory.llm_tools.openai_client import OpenAIClient
|
||||||
|
from app.core.memory.llm_tools.openai_embedder import OpenAIEmbedderClient
|
||||||
|
from app.core.memory.llm_tools.chunker_client import ChunkerClient
|
||||||
|
|
||||||
|
# Public API of the llm_tools package: the client classes imported above.
__all__ = [
    "LLMClient",
    "EmbedderClient",
    "OpenAIClient",
    "OpenAIEmbedderClient",
    "ChunkerClient",
]
|
||||||
330
app/core/memory/llm_tools/chunker_client.py
Normal file
330
app/core/memory/llm_tools/chunker_client.py
Normal file
@@ -0,0 +1,330 @@
|
|||||||
|
from typing import Any, List
|
||||||
|
import re
|
||||||
|
import os
|
||||||
|
import asyncio
|
||||||
|
import json
|
||||||
|
import numpy as np
|
||||||
|
|
||||||
|
# Fix tokenizer parallelism warning
|
||||||
|
os.environ["TOKENIZERS_PARALLELISM"] = "false"
|
||||||
|
|
||||||
|
from chonkie import (
|
||||||
|
SemanticChunker,
|
||||||
|
RecursiveChunker,
|
||||||
|
RecursiveRules,
|
||||||
|
LateChunker,
|
||||||
|
NeuralChunker,
|
||||||
|
SentenceChunker,
|
||||||
|
TokenChunker,
|
||||||
|
)
|
||||||
|
|
||||||
|
from app.core.memory.models.config_models import ChunkerConfig
|
||||||
|
from app.core.memory.models.message_models import DialogData, Chunk
|
||||||
|
try:
|
||||||
|
from app.core.memory.llm_tools.openai_client import OpenAIClient
|
||||||
|
except Exception:
|
||||||
|
# 在测试或无可用依赖(如 langfuse)环境下,允许惰性导入
|
||||||
|
OpenAIClient = Any
|
||||||
|
|
||||||
|
|
||||||
|
class LLMChunker:
    """LLM-based chunking strategy.

    The chunker asks the LLM to split a text into semantically coherent
    paragraphs and parses the JSON answer (``{"chunks": [{"text": ...}]}``).
    """

    def __init__(self, llm_client: "OpenAIClient", chunk_size: int = 1000):
        """
        Args:
            llm_client: Client exposing ``achat(messages)`` or ``chat(messages)``;
                ``chat`` may be a plain function or a coroutine function.
            chunk_size: Target paragraph length (in characters) mentioned in the prompt.
        """
        self.llm_client = llm_client
        self.chunk_size = chunk_size

    async def __call__(self, text: str) -> List[Any]:
        """
        Split ``text`` with the help of the LLM.

        Only the first 5000 characters of ``text`` are sent to the model.

        Returns:
            A list of chunk objects with ``text``, ``start_index`` and
            ``end_index`` attributes (the indices are rough estimates, 100
            per chunk position). An empty list is returned when the call or
            the parsing fails, so the caller can fall back to another strategy.
        """
        # Ask the LLM to analyse the text structure and split it.
        prompt = f"""
        请将以下文本分割成语义连贯的段落。每个段落应该围绕一个主题,长度大约在{self.chunk_size}字符左右。
        请以JSON格式返回结果,包含chunks数组,每个chunk有text字段。

        文本内容:
        {text[:5000]}
        """

        messages = [
            {"role": "system", "content": "你是一个专业的文本分析助手,擅长将长文本分割成语义连贯的段落。"},
            {"role": "user", "content": prompt}
        ]

        try:
            if hasattr(self.llm_client, 'achat'):
                # Prefer the dedicated async method.
                response = await self.llm_client.achat(messages)
            elif asyncio.iscoroutinefunction(self.llm_client.chat):
                # chat() is itself a coroutine function: await it directly.
                # Passing it to to_thread would only produce an un-awaited
                # coroutine object, and chunking would silently return [].
                response = await self.llm_client.chat(messages)
            else:
                # Synchronous client: run the blocking call in a worker thread.
                response = await asyncio.to_thread(self.llm_client.chat, messages)

            # Extract the text content from whichever response shape was returned.
            if hasattr(response, 'choices') and len(response.choices) > 0:
                content = response.choices[0].message.content
            elif hasattr(response, 'content'):
                content = response.content
            else:
                content = str(response)

            # Strip an optional Markdown code fence around the JSON payload.
            if "```json" in content:
                json_str = content.split("```json")[1].split("```")[0].strip()
            elif "```" in content:
                json_str = content.split("```")[1].split("```")[0].strip()
            else:
                json_str = content

            result = json.loads(json_str)

            class SimpleChunk:
                def __init__(self, text, index):
                    self.text = text
                    self.start_index = index * 100  # approximate position
                    self.end_index = (index + 1) * 100

            return [SimpleChunk(chunk["text"], i) for i, chunk in enumerate(result.get("chunks", []))]

        except Exception as e:
            print(f"LLM分块失败: {e}")
            # Return an empty list on failure; the caller handles the fallback.
            return []
|
||||||
|
|
||||||
|
|
||||||
|
class HybridChunker:
    """Hybrid chunking strategy: split structurally first, then re-chunk semantically."""

    def __init__(self, semantic_threshold: float = 0.8, base_chunk_size: int = 300):
        """
        Args:
            semantic_threshold: Threshold handed to the SemanticChunker.
            base_chunk_size: Chunk size of the character-based TokenChunker.
        """
        self.semantic_threshold = semantic_threshold
        self.base_chunk_size = base_chunk_size
        self.base_chunker = TokenChunker(tokenizer="character", chunk_size=base_chunk_size)
        self.semantic_chunker = SemanticChunker(threshold=semantic_threshold)

    def __call__(self, text: str) -> List[Any]:
        """Return the base chunks for short texts, otherwise the semantic re-chunking."""
        pieces = self.base_chunker(text)

        # More than three base chunks: join them with spaces and chunk semantically.
        if len(pieces) > 3:
            return self.semantic_chunker(" ".join(piece.text for piece in pieces))

        # Short text: the base chunks are returned unchanged.
        return pieces
|
||||||
|
|
||||||
|
|
||||||
|
class ChunkerClient:
    """Builds the chunker named by ``chunker_config.chunker_strategy`` and applies it to dialogues."""

    def __init__(self, chunker_config: ChunkerConfig, llm_client: OpenAIClient = None):
        """
        Copy the chunking parameters from the config and instantiate the chunker.

        Args:
            chunker_config: Chunking configuration; ``chunker_strategy`` selects the implementation.
            llm_client: LLM client, required only for the "LLMChunker" strategy.

        Raises:
            ValueError: If the strategy is unknown, or "LLMChunker" is
                selected without an LLM client.
        """
        self.chunker_config = chunker_config
        self.embedding_model = chunker_config.embedding_model
        self.chunk_size = chunker_config.chunk_size
        self.threshold = chunker_config.threshold
        self.language = chunker_config.language
        self.skip_window = chunker_config.skip_window
        self.min_sentences = chunker_config.min_sentences
        self.min_characters_per_chunk = chunker_config.min_characters_per_chunk
        self.llm_client = llm_client

        # Optional parameters (read defensively from the config, with defaults)
        self.chunk_overlap = getattr(chunker_config, 'chunk_overlap', 0)
        self.min_sentences_per_chunk = getattr(chunker_config, 'min_sentences_per_chunk', 1)
        self.min_characters_per_sentence = getattr(chunker_config, 'min_characters_per_sentence', 12)
        self.delim = getattr(chunker_config, 'delim', [".", "!", "?", "\n"])
        self.include_delim = getattr(chunker_config, 'include_delim', "prev")
        self.tokenizer_or_token_counter = getattr(chunker_config, 'tokenizer_or_token_counter', "character")

        # Instantiate the concrete chunker for the selected strategy
        if chunker_config.chunker_strategy == "TokenChunker":
            self.chunker = TokenChunker(
                tokenizer=self.tokenizer_or_token_counter,
                chunk_size=self.chunk_size,
                chunk_overlap=self.chunk_overlap,
            )
        elif chunker_config.chunker_strategy == "SemanticChunker":
            self.chunker = SemanticChunker(
                embedding_model=self.embedding_model,
                threshold=self.threshold,
                chunk_size=self.chunk_size,
                min_sentences=self.min_sentences,
            )
        elif chunker_config.chunker_strategy == "RecursiveChunker":
            self.chunker = RecursiveChunker(
                rules=RecursiveRules(),
                min_characters_per_chunk=self.min_characters_per_chunk or 50,
                chunk_size=self.chunk_size,
            )
        elif chunker_config.chunker_strategy == "LateChunker":
            self.chunker = LateChunker(
                embedding_model=self.embedding_model,
                chunk_size=self.chunk_size,
                rules=RecursiveRules(),
                min_characters_per_chunk=self.min_characters_per_chunk,
            )
        elif chunker_config.chunker_strategy == "NeuralChunker":
            self.chunker = NeuralChunker(
                model=self.embedding_model,
                min_characters_per_chunk=self.min_characters_per_chunk,
            )
        elif chunker_config.chunker_strategy == "LLMChunker":
            if not llm_client:
                raise ValueError("LLMChunker requires an LLM client")
            self.chunker = LLMChunker(llm_client, self.chunk_size)
        elif chunker_config.chunker_strategy == "HybridChunker":
            self.chunker = HybridChunker(
                semantic_threshold=self.threshold,
                base_chunk_size=self.chunk_size,
            )
        elif chunker_config.chunker_strategy == "SentenceChunker":
            # Some chonkie versions' SentenceChunker do not accept tokenizer_or_token_counter.
            # To stay compatible across versions, only widely supported parameters are passed.
            self.chunker = SentenceChunker(
                chunk_size=self.chunk_size,
                chunk_overlap=self.chunk_overlap,
                min_sentences_per_chunk=self.min_sentences_per_chunk,
                min_characters_per_sentence=self.min_characters_per_sentence,
                delim=self.delim,
                include_delim=self.include_delim,
            )
        else:
            raise ValueError(f"Unknown chunker strategy: {chunker_config.chunker_strategy}")

    async def generate_chunks(self, dialogue: DialogData):
        """
        Chunk ``dialogue.content`` and store the result in ``dialogue.chunks``.

        Works with both synchronous chunkers and awaitable ones. If chunking
        raises, the content is split by dialogue turns instead
        ("DialogueTurnFallback"); if that also fails, the whole content
        becomes one chunk ("SingleChunkFallback").

        Returns:
            The same ``dialogue`` object with ``chunks`` populated.
        """
        try:
            # Pre-process the text so the dialogue markers have a uniform format
            content = dialogue.content
            # NOTE(review): as written, both replace() calls map a string to itself —
            # confirm whether a full-width colon was meant on the left-hand side.
            content = content.replace('AI:', 'AI:').replace('用户:', '用户:')  # normalize colons
            content = re.sub(r'(\n\s*)+\n', '\n\n', content)  # collapse runs of blank lines

            if hasattr(self.chunker, '__call__') and not asyncio.iscoroutinefunction(self.chunker.__call__):
                # Synchronous chunker
                chunks = self.chunker(content)
            else:
                # Asynchronous chunker (e.g. LLMChunker)
                chunks = await self.chunker(content)

            # Drop empty chunks and chunks below the minimum length (50 when not configured)
            valid_chunks = []
            for c in chunks:
                chunk_text = getattr(c, 'text', str(c)) if not isinstance(c, str) else c
                if isinstance(chunk_text, str) and len(chunk_text.strip()) >= (self.min_characters_per_chunk or 50):
                    valid_chunks.append(c)

            dialogue.chunks = [
                Chunk(
                    content=c.text if hasattr(c, 'text') else str(c),
                    metadata={
                        "start_index": getattr(c, "start_index", None),
                        "end_index": getattr(c, "end_index", None),
                        "chunker_strategy": self.chunker_config.chunker_strategy,
                    },
                )
                for c in valid_chunks
            ]
            return dialogue

        except Exception as e:
            print(f"分块失败: {e}")

            # Improved fallback: try to split by dialogue turns
            try:
                # Simple split on the speaker markers
                dialogue_pattern = r'(AI:|用户:)(.*?)(?=AI:|用户:|$)'
                matches = re.findall(dialogue_pattern, dialogue.content, re.DOTALL)

                class SimpleChunk:
                    def __init__(self, text, start_index, end_index):
                        self.text = text
                        self.start_index = start_index
                        self.end_index = end_index

                chunks = []
                current_chunk = ""
                current_start = 0

                for match in matches:
                    speaker, ct = match[0], match[1].strip()
                    turn_text = f"{speaker} {ct}"

                    # Start a new chunk once adding this turn would exceed the size limit (500 when not configured)
                    if len(current_chunk) + len(turn_text) > (self.chunk_size or 500):
                        if current_chunk:
                            chunks.append(SimpleChunk(current_chunk, current_start, current_start + len(current_chunk)))
                        current_chunk = turn_text
                        # NOTE(review): turn_text is rebuilt with a single space after the marker,
                        # so find() returns -1 when the original spacing differs — verify the indices.
                        current_start = dialogue.content.find(turn_text, current_start)
                    else:
                        current_chunk += ("\n" + turn_text) if current_chunk else turn_text

                if current_chunk:
                    chunks.append(SimpleChunk(current_chunk, current_start, current_start + len(current_chunk)))

                dialogue.chunks = [
                    Chunk(
                        content=c.text,
                        metadata={
                            "start_index": c.start_index,
                            "end_index": c.end_index,
                            "chunker_strategy": "DialogueTurnFallback",
                        },
                    )
                    for c in chunks
                ]

            except Exception:
                # Last resort: a single chunk holding the whole content
                dialogue.chunks = [Chunk(
                    content=dialogue.content,
                    metadata={"chunker_strategy": "SingleChunkFallback"},
                )]

            return dialogue

    def evaluate_chunking(self, dialogue: DialogData) -> dict:
        """
        Compute size statistics for the chunks of ``dialogue``.

        Returns:
            An empty dict when the dialogue has no chunks; otherwise a dict
            with the strategy name, chunk count, total/average/min/max chunk
            size in characters, the size standard deviation (0 for a single
            chunk) and the ratio of chunk characters to content characters.
        """
        if not getattr(dialogue, 'chunks', None):
            return {}

        chunks = dialogue.chunks
        total_chars = sum(len(chunk.content) for chunk in chunks)
        avg_chunk_size = total_chars / len(chunks)

        # Compute the individual metrics
        chunk_sizes = [len(chunk.content) for chunk in chunks]

        metrics = {
            "strategy": self.chunker_config.chunker_strategy,
            "num_chunks": len(chunks),
            "total_characters": total_chars,
            "avg_chunk_size": avg_chunk_size,
            "min_chunk_size": min(chunk_sizes),
            "max_chunk_size": max(chunk_sizes),
            "chunk_size_std": np.std(chunk_sizes) if len(chunk_sizes) > 1 else 0,
            "coverage_ratio": total_chars / len(dialogue.content) if dialogue.content else 0,
        }

        return metrics

    def save_chunking_results(self, dialogue: DialogData, output_path: str):
        """
        Write the chunks to a text file whose name includes the strategy name.

        The strategy name is inserted before the extension of ``output_path``
        (``<base>_<strategy><ext>``).

        Returns:
            The path of the file that was written.
        """
        strategy_name = self.chunker_config.chunker_strategy
        # Add the strategy name to the file name
        base_name, ext = os.path.splitext(output_path)
        strategy_output_path = f"{base_name}_{strategy_name}{ext}"

        with open(strategy_output_path, 'w', encoding='utf-8') as f:
            f.write(f"=== Chunking Strategy: {strategy_name} ===\n")
            f.write(f"Total chunks: {len(dialogue.chunks)}\n")
            f.write(f"Total characters: {sum(len(chunk.content) for chunk in dialogue.chunks)}\n")
            f.write("=" * 60 + "\n\n")

            for i, chunk in enumerate(dialogue.chunks):
                f.write(f"Chunk {i+1}:\n")
                f.write(f"Size: {len(chunk.content)} characters\n")
                if hasattr(chunk, 'metadata') and 'start_index' in chunk.metadata:
                    f.write(f"Position: {chunk.metadata.get('start_index')}-{chunk.metadata.get('end_index')}\n")
                f.write(f"Content: {chunk.content}\n")
                f.write("-" * 40 + "\n\n")

        print(f"Chunking results saved to: {strategy_output_path}")
        return strategy_output_path
|
||||||
176
app/core/memory/llm_tools/embedder_client.py
Normal file
176
app/core/memory/llm_tools/embedder_client.py
Normal file
@@ -0,0 +1,176 @@
|
|||||||
|
"""
|
||||||
|
Embedder 客户端抽象基类
|
||||||
|
|
||||||
|
提供统一的嵌入向量生成接口,支持重试机制和错误处理。
|
||||||
|
"""
|
||||||
|
|
||||||
|
from abc import ABC, abstractmethod
|
||||||
|
from typing import List, Optional
|
||||||
|
import asyncio
|
||||||
|
import logging
|
||||||
|
from tenacity import (
|
||||||
|
retry,
|
||||||
|
stop_after_attempt,
|
||||||
|
wait_exponential,
|
||||||
|
retry_if_exception_type,
|
||||||
|
before_sleep_log,
|
||||||
|
)
|
||||||
|
|
||||||
|
from app.core.models.base import RedBearModelConfig
|
||||||
|
from app.core.exceptions import BusinessException
|
||||||
|
from app.core.error_codes import BizCode
|
||||||
|
|
||||||
|
# Module-level logger; used for client init/error messages and handed to
# tenacity's before_sleep_log so retry waits are logged as warnings.
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
class EmbedderClientException(BusinessException):
    """Embedder client error; a BusinessException defaulting to the embedding error code."""

    def __init__(self, message: str, code: str = BizCode.EMBEDDING_ERROR):
        """
        Build the exception.

        Args:
            message: Description of the failure
            code: Business error code passed through to BusinessException
        """
        super().__init__(message, code=code)
|
||||||
|
|
||||||
|
|
||||||
|
class EmbedderClient(ABC):
    """
    Abstract base class for embedder clients.

    Provides a uniform interface for generating embedding vectors:
    - batch text embedding (``response``), implemented by subclasses
    - automatic retries (``response_with_retry``)
    - error handling (failures surface as ``EmbedderClientException``)
    """

    def __init__(self, model_config: RedBearModelConfig):
        """
        Initialise the embedder client.

        Args:
            model_config: Model configuration; the model name, provider,
                API key, base URL, retry count and timeout are read from it
        """
        self.config = model_config
        self.model_name = model_config.model_name
        self.provider = model_config.provider
        self.api_key = model_config.api_key
        self.base_url = model_config.base_url
        self.max_retries = model_config.max_retries
        self.timeout = model_config.timeout

        logger.info(
            f"初始化 Embedder 客户端: provider={self.provider}, "
            f"model={self.model_name}, max_retries={self.max_retries}"
        )

    @abstractmethod
    async def response(
        self,
        messages: List[str],
        **kwargs
    ) -> List[List[float]]:
        """
        Generate embedding vectors.

        Args:
            messages: Texts to embed
            **kwargs: Extra, implementation-specific parameters

        Returns:
            A list of embedding vectors, each a list of floats

        Raises:
            EmbedderClientException: Embedding generation failed
        """
        pass

    def _create_retry_decorator(self):
        """
        Build the retry decorator.

        Returns:
            A configured tenacity ``retry`` decorator: exponential back-off
            (2-10 s), a warning logged before each sleep, and the last
            exception re-raised once attempts are exhausted
        """
        # A missing/zero retry count would otherwise reach stop_after_attempt
        # as None (TypeError on comparison); fall back to a single attempt.
        attempts = self.max_retries if self.max_retries else 1
        return retry(
            stop=stop_after_attempt(attempts),
            wait=wait_exponential(multiplier=1, min=2, max=10),
            # Exception already covers asyncio.TimeoutError and ConnectionError;
            # narrow this type if non-transient errors should fail fast.
            retry=retry_if_exception_type(Exception),
            before_sleep=before_sleep_log(logger, logging.WARNING),
            reraise=True,
        )

    async def response_with_retry(
        self,
        messages: List[str],
        **kwargs
    ) -> List[List[float]]:
        """
        Generate embedding vectors with automatic retries.

        Args:
            messages: Texts to embed
            **kwargs: Extra parameters forwarded to ``response``

        Returns:
            A list of embedding vectors

        Raises:
            EmbedderClientException: Raised after the retries are exhausted
        """
        retry_decorator = self._create_retry_decorator()

        @retry_decorator
        async def _attempt() -> List[List[float]]:
            try:
                return await self.response(messages, **kwargs)
            except Exception as e:
                logger.error(f"嵌入向量生成失败: {e}")
                if isinstance(e, EmbedderClientException):
                    # Already the client's own exception type: re-raise as is
                    # so its message and error code are not wrapped a second time.
                    raise
                raise EmbedderClientException(f"嵌入向量生成失败: {e}") from e

        return await _attempt()

    async def embed_single(self, text: str, **kwargs) -> List[float]:
        """
        Generate the embedding vector for a single text.

        Args:
            text: The text to embed
            **kwargs: Extra parameters forwarded to ``response``

        Returns:
            The embedding vector (list of floats), or an empty list when the
            backend returned no embeddings

        Raises:
            EmbedderClientException: Embedding generation failed
        """
        embeddings = await self.response_with_retry([text], **kwargs)
        return embeddings[0] if embeddings else []

    async def embed_batch(
        self,
        texts: List[str],
        batch_size: int = 100,
        **kwargs
    ) -> List[List[float]]:
        """
        Generate embedding vectors in batches (supports large text lists).

        Batches are sent sequentially, each through ``response_with_retry``.

        Args:
            texts: Texts to embed
            batch_size: Number of texts per request; must be positive
            **kwargs: Extra parameters forwarded to ``response``

        Returns:
            A list of embedding vectors, in batch order

        Raises:
            ValueError: ``batch_size`` is not positive
            EmbedderClientException: Embedding generation failed
        """
        if batch_size <= 0:
            # A negative step would make range() empty and silently drop
            # every text; zero would fail with an unrelated range() error.
            raise ValueError(
                f"batch_size must be a positive integer, got {batch_size}"
            )

        all_embeddings: List[List[float]] = []

        for start in range(0, len(texts), batch_size):
            batch = texts[start:start + batch_size]
            batch_embeddings = await self.response_with_retry(batch, **kwargs)
            all_embeddings.extend(batch_embeddings)

        return all_embeddings
|
||||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user