Initial commit

This commit is contained in:
Ke Sun
2025-11-30 18:22:17 +08:00
commit aea2fe391e
449 changed files with 83030 additions and 0 deletions

39
.gitignore vendored Normal file
View File

@@ -0,0 +1,39 @@
# Python-generated files
__pycache__/
*.py[oc]
build/
dist/
wheels/
logs/
res/
*.egg-info
# Virtual environments
.venv
docs/
examples/
# Environment variables
.env
.kiro
.vscode/settings.json
.idea
# Temporary outputs
app/core/memory/agent/.DS_Store
app/core/memory/src/utils/.DS_Store
time.log
celerybeat-schedule.db
search_results.json
*.txt
*.json
migrations/versions
tmp
files
# Exclude dep files
huggingface.co/
nltk_data/
tika-server*.jar*
cl100k_base.tiktoken
libssl*.deb

97
Dockerfile Normal file
View File

@@ -0,0 +1,97 @@
FROM python:3.12-slim
USER root
SHELL ["/bin/bash", "-c"]
ARG NEED_MIRROR=1
WORKDIR /code
# 1. Download dependencies through download_deps.py: python download_deps.py --china-mirrors
# 2. Copy models
COPY huggingface.co/InfiniFlow/deepdoc/ /code/res/deepdoc/
COPY huggingface.co/InfiniFlow/text_concat_xgb_v1.0/ /code/res/text_concat_xgb_v1.0/
COPY huggingface.co/InfiniFlow/huqie/huqie.txt.trie /code/res/
# https://github.com/chrismattmann/tika-python
# 3. This is the only way to run python-tika without internet access. Without this set, the default is to check the tika version and pull latest every time from Apache.
COPY nltk_data/ /root/nltk_data/
COPY tika-server-standard-3.1.0.jar /tmp/tika-server.jar
COPY tika-server-standard-3.1.0.jar.md5 /tmp/tika-server.jar.md5
COPY cl100k_base.tiktoken /code/res/9b5ad71b2ce5302211f9c61530b329a4922fc6a4
ENV TIKA_SERVER_JAR="file:///tmp/tika-server.jar"
ENV DEBIAN_FRONTEND=noninteractive
# 4. Setup apt
# Python package and implicit dependencies:
# opencv-python: libglib2.0-0 libglx-mesa0 libgl1
# aspose-slides: pkg-config libicu-dev libgdiplus libssl1.1_1.1.1f-1ubuntu2_amd64.deb
# python-pptx: default-jdk tika-server-standard-3.0.0.jar
# Building C extensions: libpython3-dev libgtk-4-1 libnss3 xdg-utils libgbm-dev
RUN --mount=type=cache,id=mem_apt,target=/var/cache/apt,sharing=locked \
apt install -y libicu-dev && \
if [ "$NEED_MIRROR" == "1" ]; then \
rm -f /etc/apt/sources.list.d/debian.sources && \
echo "deb https://mirrors.tuna.tsinghua.edu.cn/debian/ bookworm main contrib non-free non-free-firmware" > /etc/apt/sources.list && \
echo "deb https://mirrors.tuna.tsinghua.edu.cn/debian/ bookworm-updates main contrib non-free non-free-firmware" >> /etc/apt/sources.list && \
echo "deb https://mirrors.tuna.tsinghua.edu.cn/debian/ bookworm-backports main contrib non-free non-free-firmware" >> /etc/apt/sources.list && \
echo "deb https://mirrors.tuna.tsinghua.edu.cn/debian-security bookworm-security main contrib non-free non-free-firmware" >> /etc/apt/sources.list; \
fi; \
rm -f /etc/apt/apt.conf.d/docker-clean && \
echo 'Binary::apt::APT::Keep-Downloaded-Packages "true";' > /etc/apt/apt.conf.d/keep-cache && \
chmod 1777 /tmp && \
apt update && \
apt --no-install-recommends install -y ca-certificates && \
apt update && \
apt install -y libglib2.0-0 libglx-mesa0 libgl1 && \
apt install -y pkg-config libgdiplus && \
apt install -y default-jdk && \
apt install -y libpython3-dev libgtk-4-1 libnss3 xdg-utils libgbm-dev && \
apt install -y libjemalloc-dev && \
apt install -y python3-pip pipx nginx unzip curl wget git vim less && \
apt install -y ghostscript
RUN if [ "$NEED_MIRROR" == "1" ]; then \
pip3 config set global.index-url https://pypi.tuna.tsinghua.edu.cn/simple && \
pip3 config set global.trusted-host pypi.tuna.tsinghua.edu.cn; \
mkdir -p /etc/uv && \
echo "[[index]]" > /etc/uv/uv.toml && \
echo 'url = "https://pypi.tuna.tsinghua.edu.cn/simple"' >> /etc/uv/uv.toml && \
echo "default = true" >> /etc/uv/uv.toml; \
fi; \
pipx install uv
ENV PYTHONDONTWRITEBYTECODE=1 DOTNET_SYSTEM_GLOBALIZATION_INVARIANT=1
ENV PATH=/root/.local/bin:$PATH
# https://forum.aspose.com/t/aspose-slides-for-net-no-usable-version-of-libssl-found-with-linux-server/271344/13
# 5. aspose-slides on linux/arm64 is unavailable
COPY libssl1.1_1.1.1f-1ubuntu2_amd64.deb libssl1.1_1.1.1f-1ubuntu2_arm64.deb /tmp/
RUN if [ "$(uname -m)" = "x86_64" ]; then \
dpkg -i /tmp/libssl1.1_1.1.1f-1ubuntu2_amd64.deb; \
elif [ "$(uname -m)" = "aarch64" ]; then \
dpkg -i /tmp/libssl1.1_1.1.1f-1ubuntu2_arm64.deb; \
fi && \
rm -f /tmp/libssl1.1_*.deb
# 6. install dependencies from uv.lock file
COPY ./pyproject.toml /code/pyproject.toml
COPY ./uv.lock /code/uv.lock
COPY ./app /code/app
# https://github.com/astral-sh/uv/issues/10462
# uv records index url into uv.lock but doesn't failover among multiple indexes
RUN --mount=type=cache,id=mem_uv,target=/root/.cache/uv,sharing=locked \
if [ "$NEED_MIRROR" == "1" ]; then \
sed -i 's|pypi.org|pypi.tuna.tsinghua.edu.cn|g' uv.lock; \
else \
sed -i 's|pypi.tuna.tsinghua.edu.cn|pypi.org|g' uv.lock; \
fi; \
uv lock && \
uv sync --locked --no-dev
ENV PATH=/code/.venv/bin:$PATH

201
LICENSE Normal file
View File

@@ -0,0 +1,201 @@
Apache License
Version 2.0, January 2004
http://www.apache.org/licenses/
TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
1. Definitions.
"License" shall mean the terms and conditions for use, reproduction,
and distribution as defined by Sections 1 through 9 of this document.
"Licensor" shall mean the copyright owner or entity authorized by
the copyright owner that is granting the License.
"Legal Entity" shall mean the union of the acting entity and all
other entities that control, are controlled by, or are under common
control with that entity. For the purposes of this definition,
"control" means (i) the power, direct or indirect, to cause the
direction or management of such entity, whether by contract or
otherwise, or (ii) ownership of fifty percent (50%) or more of the
outstanding shares, or (iii) beneficial ownership of such entity.
"You" (or "Your") shall mean an individual or Legal Entity
exercising permissions granted by this License.
"Source" form shall mean the preferred form for making modifications,
including but not limited to software source code, documentation
sources, and configuration files.
"Object" form shall mean any form resulting from mechanical
transformation or translation of a Source form, including but not
limited to compiled object code, generated documentation, and
conversions to other media types.
"Work" shall mean the work of authorship, whether in Source or
Object form, made available under the License, as indicated by a
copyright notice that is included in or attached to the work
(an example is provided in the Appendix below).
"Derivative Works" shall mean any work, whether in Source or Object
form, that is based on (or derived from) the Work and for which the
editorial revisions, annotations, elaborations, or other modifications
represent, as a whole, an original work of authorship. For the purposes
of this License, Derivative Works shall not include works that remain
separable from, or merely link (or bind by name) to the interfaces of,
the Work and Derivative Works thereof.
"Contribution" shall mean any work of authorship, including
the original version of the Work and any modifications or additions
to that Work or Derivative Works thereof, that is intentionally
submitted to Licensor for inclusion in the Work by the copyright owner
or by an individual or Legal Entity authorized to submit on behalf of
the copyright owner. For the purposes of this definition, "submitted"
means any form of electronic, verbal, or written communication sent
to the Licensor or its representatives, including but not limited to
communication on electronic mailing lists, source code control systems,
and issue tracking systems that are managed by, or on behalf of, the
Licensor for the purpose of discussing and improving the Work, but
excluding communication that is conspicuously marked or otherwise
designated in writing by the copyright owner as "Not a Contribution."
"Contributor" shall mean Licensor and any individual or Legal Entity
on behalf of whom a Contribution has been received by Licensor and
subsequently incorporated within the Work.
2. Grant of Copyright License. Subject to the terms and conditions of
this License, each Contributor hereby grants to You a perpetual,
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
copyright license to reproduce, prepare Derivative Works of,
publicly display, publicly perform, sublicense, and distribute the
Work and such Derivative Works in Source or Object form.
3. Grant of Patent License. Subject to the terms and conditions of
this License, each Contributor hereby grants to You a perpetual,
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
(except as stated in this section) patent license to make, have made,
use, offer to sell, sell, import, and otherwise transfer the Work,
where such license applies only to those patent claims licensable
by such Contributor that are necessarily infringed by their
Contribution(s) alone or by combination of their Contribution(s)
with the Work to which such Contribution(s) was submitted. If You
institute patent litigation against any entity (including a
cross-claim or counterclaim in a lawsuit) alleging that the Work
or a Contribution incorporated within the Work constitutes direct
or contributory patent infringement, then any patent licenses
granted to You under this License for that Work shall terminate
as of the date such litigation is filed.
4. Redistribution. You may reproduce and distribute copies of the
Work or Derivative Works thereof in any medium, with or without
modifications, and in Source or Object form, provided that You
meet the following conditions:
(a) You must give any other recipients of the Work or
Derivative Works a copy of this License; and
(b) You must cause any modified files to carry prominent notices
stating that You changed the files; and
(c) You must retain, in the Source form of any Derivative Works
that You distribute, all copyright, patent, trademark, and
attribution notices from the Source form of the Work,
excluding those notices that do not pertain to any part of
the Derivative Works; and
(d) If the Work includes a "NOTICE" text file as part of its
distribution, then any Derivative Works that You distribute must
include a readable copy of the attribution notices contained
within such NOTICE file, excluding those notices that do not
pertain to any part of the Derivative Works, in at least one
of the following places: within a NOTICE text file distributed
as part of the Derivative Works; within the Source form or
documentation, if provided along with the Derivative Works; or,
within a display generated by the Derivative Works, if and
wherever such third-party notices normally appear. The contents
of the NOTICE file are for informational purposes only and
do not modify the License. You may add Your own attribution
notices within Derivative Works that You distribute, alongside
or as an addendum to the NOTICE text from the Work, provided
that such additional attribution notices cannot be construed
as modifying the License.
You may add Your own copyright statement to Your modifications and
may provide additional or different license terms and conditions
for use, reproduction, or distribution of Your modifications, or
for any such Derivative Works as a whole, provided Your use,
reproduction, and distribution of the Work otherwise complies with
the conditions stated in this License.
5. Submission of Contributions. Unless You explicitly state otherwise,
any Contribution intentionally submitted for inclusion in the Work
by You to the Licensor shall be under the terms and conditions of
this License, without any additional terms or conditions.
Notwithstanding the above, nothing herein shall supersede or modify
the terms of any separate license agreement you may have executed
with Licensor regarding such Contributions.
6. Trademarks. This License does not grant permission to use the trade
names, trademarks, service marks, or product names of the Licensor,
except as required for reasonable and customary use in describing the
origin of the Work and reproducing the content of the NOTICE file.
7. Disclaimer of Warranty. Unless required by applicable law or
agreed to in writing, Licensor provides the Work (and each
Contributor provides its Contributions) on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
implied, including, without limitation, any warranties or conditions
of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
PARTICULAR PURPOSE. You are solely responsible for determining the
appropriateness of using or redistributing the Work and assume any
risks associated with Your exercise of permissions under this License.
8. Limitation of Liability. In no event and under no legal theory,
whether in tort (including negligence), contract, or otherwise,
unless required by applicable law (such as deliberate and grossly
negligent acts) or agreed to in writing, shall any Contributor be
liable to You for damages, including any direct, indirect, special,
incidental, or consequential damages of any character arising as a
result of this License or out of the use or inability to use the
Work (including but not limited to damages for loss of goodwill,
work stoppage, computer failure or malfunction, or any and all
other commercial damages or losses), even if such Contributor
has been advised of the possibility of such damages.
9. Accepting Warranty or Additional Liability. While redistributing
the Work or Derivative Works thereof, You may choose to offer,
and charge a fee for, acceptance of support, warranty, indemnity,
or other liability obligations and/or rights consistent with this
License. However, in accepting such obligations, You may act only
on Your own behalf and on Your sole responsibility, not on behalf
of any other Contributor, and only if You agree to indemnify,
defend, and hold each Contributor harmless for any liability
incurred by, or claims asserted against, such Contributor by reason
of your accepting any such warranty or additional liability.
END OF TERMS AND CONDITIONS
APPENDIX: How to apply the Apache License to your work.
To apply the Apache License to your work, attach the following
boilerplate notice, with the fields enclosed by brackets "[]"
replaced with your own identifying information. (Don't include
the brackets!) The text should be enclosed in the appropriate
comment syntax for the file format. We also recommend that a
file or class name and description of purpose be included on the
same "printed page" as the copyright notice for easier
identification within third-party archives.
Copyright [2025] [SuanmoSuanyangTechnology]
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.

137
README.md Normal file
View File

@@ -0,0 +1,137 @@
# MemoryBear
## 项目简介
MemoryBear 是一个面向智能体的记忆系统与知识管理服务。它支持从对话与文档中萃取结构化知识、生成嵌入向量、构建图谱,提供关键词与语义的混合搜索,并内置遗忘机制与自我反思流程,以维持长期记忆的有效性与可用性。
## 核心特性
- 知识萃取:陈述句、三元组、时间信息与摘要生成
- 图谱存储:对接 Neo4j 管理实体与关系
- 混合搜索:关键词检索 + 语义向量检索
- 遗忘机制:按记忆强度与时效做逐步衰减
- 自我反思:定期回顾并优化已有记忆
- FastAPI 服务:统一暴露管理端与服务端 API
## 架构总览
- 萃取引擎Extraction Engine预处理、去重、结构化提取
- 遗忘引擎Forgetting Engine记忆强度模型与衰减策略
- 自我反思引擎Reflection Engine评价与重写记忆
- 检索服务:关键词、语义与混合检索
- Agent 与 MCP提供多工具协作的智能体能力
## 快速开始
### 环境要求
- Python 3.12
- PostgreSQL 13+
- Neo4j 4.4+
- Redis 6.0+
### 安装依赖
```bash
python -m venv .venv
source .venv/bin/activate # Windows: .venv\Scripts\activate
# 方式一:基于 pyproject 安装
pip install .
# 方式二:使用 requirements.txt
pip install -r requirements.txt
```
### 配置环境变量
创建 `.env` 文件(示例):
```env
# Postgres
DB_HOST=127.0.0.1
DB_PORT=5432
DB_USER=postgres
DB_PASSWORD=your-password
DB_NAME=redbear-mem
DB_AUTO_UPGRADE=false
# Neo4j
NEO4J_URI=bolt://localhost:7687
NEO4J_USERNAME=neo4j
NEO4J_PASSWORD=your-password
# Redis
REDIS_HOST=127.0.0.1
REDIS_PORT=6379
REDIS_DB=1
# LLM / API Keys按需
OPENAI_API_KEY=your-openai-key
DASHSCOPE_API_KEY=your-dashscope-key
# 其他
WEB_URL=http://localhost:3000
LOG_LEVEL=INFO
```
### 初始化与启动
```bash
# 如需自动迁移数据库:设置 DB_AUTO_UPGRADE=true 或手动执行
alembic upgrade head
# 启动开发服务
uvicorn app.main:app --reload --port 8000
# 打开交互文档
# http://localhost:8000/docs
```
## 项目结构
```
app/
├── main.py # FastAPI 入口
├── controllers/ # 控制器与路由
├── core/ # 核心:配置、异常、日志等
│ └── memory/ # 记忆模块
│ ├── storage_services/ # 萃取/遗忘/反思/检索
│ ├── agent/ # Agent + MCP 服务
│ ├── utils/ # 工具与提示词
│ └── models/ # 领域模型
└── rag/ # RAG 能力与文档解析
logs/ # 日志与输出
LICENSE # 许可协议Apache-2.0
README.md # 项目说明
```
## API 与路由
- 管理端:`/api`JWT 认证)
- 服务端:`/v1`API Key 认证)
- 根路由健康检查:`GET /` 返回运行状态
- Swagger 文档:`/docs`
## 部署建议
- 使用 `gunicorn` + `uvicorn.workers.UvicornWorker` 作为生产入口
- 配置 `LOG_LEVEL=WARNING` 并启用文件日志
- 数据库与缓存请使用托管服务或独立实例
示例:
```bash
gunicorn app.main:app -w 4 -k uvicorn.workers.UvicornWorker
```
## 许可证
本项目采用 Apache License 2.0 开源协议,详情见 `LICENSE`
## 致谢与交流
- 问题反馈与讨论:请提交 Issue 到代码仓库
- 欢迎贡献:提交 PR 前请先创建功能分支并遵循常规提交信息格式

116
alembic.ini Normal file
View File

@@ -0,0 +1,116 @@
# A generic, single database configuration.
[alembic]
# path to migration scripts
script_location = migrations
# template used to generate migration file names; The default value is %%(rev)s_%%(slug)s
# Uncomment the line below if you want the files to be prepended with date and time
# see https://alembic.sqlalchemy.org/en/latest/tutorial.html#editing-the-ini-file
# for all available tokens
# file_template = %%(year)d_%%(month).2d_%%(day).2d_%%(hour).2d%%(minute).2d-%%(rev)s_%%(slug)s
# sys.path path, will be prepended to sys.path if present.
# defaults to the current working directory.
prepend_sys_path = .
# timezone to use when rendering the date within the migration file
# as well as the filename.
# If specified, requires the python>=3.9 or backports.zoneinfo library.
# Any required deps can installed by adding `alembic[tz]` to the pip requirements
# string value is passed to ZoneInfo()
# leave blank for localtime
# timezone =
# max length of characters to apply to the
# "slug" field
# truncate_slug_length = 40
# set to 'true' to run the environment during
# the 'revision' command, regardless of autogenerate
# revision_environment = false
# set to 'true' to allow .pyc and .pyo files without
# a source .py file to be detected as revisions in the
# versions/ directory
# sourceless = false
# version location specification; This defaults
# to migrations/versions. When using multiple version
# directories, initial revisions must be specified with --version-path.
# The path separator used here should be the separator specified by "version_path_separator" below.
# version_locations = %(here)s/bar:%(here)s/bat:migrations/versions
# version path separator; As mentioned above, this is the character used to split
# version_locations. The default within new alembic.ini files is "os", which uses os.pathsep.
# If this key is omitted entirely, it falls back to the legacy behavior of splitting on spaces and/or commas.
# Valid values for version_path_separator are:
#
# version_path_separator = :
# version_path_separator = ;
# version_path_separator = space
version_path_separator = os # Use os.pathsep. Default configuration used for new projects.
# set to 'true' to search source files recursively
# in each "version_locations" directory
# new in Alembic version 1.10
# recursive_version_locations = false
# the output encoding used when revision files
# are written from script.py.mako
# output_encoding = utf-8
sqlalchemy.url = postgresql://user:password@localhost/dbname
[post_write_hooks]
# post_write_hooks defines scripts or Python functions that are run
# on newly generated revision scripts. See the documentation for further
# detail and examples
# format using "black" - use the console_scripts runner, against the "black" entrypoint
# hooks = black
# black.type = console_scripts
# black.entrypoint = black
# black.options = -l 79 REVISION_SCRIPT_FILENAME
# lint with attempts to fix using "ruff" - use the exec runner, execute a binary
# hooks = ruff
# ruff.type = exec
# ruff.executable = %(here)s/.venv/bin/ruff
# ruff.options = --fix REVISION_SCRIPT_FILENAME
# Logging configuration
[loggers]
keys = root,sqlalchemy,alembic
[handlers]
keys = console
[formatters]
keys = generic
[logger_root]
level = WARN
handlers = console
qualname =
[logger_sqlalchemy]
level = WARN
handlers =
qualname = sqlalchemy.engine
[logger_alembic]
level = INFO
handlers =
qualname = alembic
[handler_console]
class = StreamHandler
args = (sys.stderr,)
level = NOTSET
formatter = generic
[formatter_generic]
format = %(levelname)-5.5s [%(name)s] %(message)s
datefmt = %H:%M:%S

201
app/aioRedis.py Normal file
View File

@@ -0,0 +1,201 @@
import os
import asyncio
import json
import logging
from typing import Dict, Any, Optional
import redis.asyncio as redis
from redis.asyncio import ConnectionPool
from app.core.config import settings
# 设置日志记录器
logger = logging.getLogger(__name__)
# 创建连接池
pool = ConnectionPool.from_url(
f"redis://{settings.REDIS_HOST}:{settings.REDIS_PORT}",
db=settings.REDIS_DB,
password=settings.REDIS_PASSWORD,
decode_responses=True,
max_connections=30
)
aio_redis = redis.StrictRedis(connection_pool=pool)
async def get_redis_connection():
"""获取Redis连接"""
try:
return redis.StrictRedis(connection_pool=pool)
except Exception as e:
logger.error(f"Redis连接失败: {str(e)}")
return None
async def aio_redis_set(key: str, val: str|dict, expire: int = None):
"""设置Redis键值
Args:
key: Redis键
val: 要存储的值(字符串或字典)
expire: 过期时间(秒)None表示永不过期
"""
try:
if isinstance(val, dict):
val = json.dumps(val, ensure_ascii=False)
if expire is not None:
# 设置带过期时间的键值
await aio_redis.set(key, val, ex=expire)
else:
# 设置永久键值
await aio_redis.set(key, val)
except Exception as e:
logger.error(f"Redis set错误: {str(e)}")
async def aio_redis_get(key: str):
"""获取Redis键值"""
try:
return await aio_redis.get(key)
except Exception as e:
logger.error(f"Redis get错误: {str(e)}")
return None
async def aio_redis_delete(key: str):
"""删除Redis键"""
try:
return await aio_redis.delete(key)
except Exception as e:
logger.error(f"Redis delete错误: {str(e)}")
return None
async def aio_redis_publish(channel: str, message: Dict[str, Any]) -> bool:
"""发布消息到Redis频道"""
try:
conn = await get_redis_connection()
if not conn:
return False
await conn.publish(channel, json.dumps(message, ensure_ascii=False))
return True
except Exception as e:
logger.error(f"Redis发布错误: {str(e)}")
return False
class RedisSubscriber:
"""Redis订阅器"""
def __init__(self, channel: str):
self.channel = channel
self.conn = None
self.pubsub = None
self.is_closed = False
self._queue = asyncio.Queue()
self._task = None
async def start(self):
"""开始订阅"""
if self.is_closed or self._task:
return
self._task = asyncio.create_task(self._receive_messages())
logger.info(f"开始订阅: {self.channel}")
async def _receive_messages(self):
"""接收消息"""
try:
self.conn = await get_redis_connection()
if not self.conn:
return
self.pubsub = self.conn.pubsub()
await self.pubsub.subscribe(self.channel)
while not self.is_closed:
try:
message = await self.pubsub.get_message(ignore_subscribe_messages=True, timeout=0.01)
if message and isinstance(message.get("data"), str):
try:
await self._queue.put(json.loads(message["data"]))
except json.JSONDecodeError:
logger.warning(f"消息解析失败: {message['data']}")
await asyncio.sleep(0.01)
except Exception as e:
if "closed" in str(e).lower():
break
logger.warning(f"接收消息错误: {str(e)}")
await asyncio.sleep(0.1)
except Exception as e:
logger.error(f"订阅错误: {str(e)}")
await self._queue.put({"type": "error", "data": {"message": str(e), "status": "error"}})
finally:
await self._queue.put(None)
await self._cleanup()
async def _cleanup(self):
"""清理资源"""
if self.pubsub:
try:
await self.pubsub.unsubscribe(self.channel)
await self.pubsub.close()
except Exception:
pass
if self.conn:
try:
await self.conn.close()
except Exception:
pass
async def get_message(self) -> Optional[Dict[str, Any]]:
"""获取消息"""
if self.is_closed:
return None
if not self._task:
await self.start()
try:
return await self._queue.get()
except Exception as e:
logger.error(f"获取消息错误: {str(e)}")
return None
async def close(self):
"""关闭订阅器"""
if self.is_closed:
return
self.is_closed = True
if self._task:
self._task.cancel()
await self._cleanup()
class RedisPubSubManager:
"""Redis发布订阅管理器"""
def __init__(self):
self.subscribers = {}
async def publish(self, channel: str, message: Dict[str, Any]) -> bool:
return await aio_redis_publish(channel, message)
def get_subscriber(self, channel: str) -> RedisSubscriber:
if channel in self.subscribers:
subscriber = self.subscribers[channel]
if not subscriber.is_closed:
return subscriber
subscriber = RedisSubscriber(channel)
self.subscribers[channel] = subscriber
return subscriber
def cancel_subscription(self, channel: str) -> bool:
if channel in self.subscribers:
asyncio.create_task(self.subscribers[channel].close())
del self.subscribers[channel]
return True
return False
def cancel_all_subscriptions(self) -> int:
count = len(self.subscribers)
for subscriber in self.subscribers.values():
asyncio.create_task(subscriber.close())
self.subscribers.clear()
return count
# 全局实例
pubsub_manager = RedisPubSubManager()

109
app/celery_app.py Normal file
View File

@@ -0,0 +1,109 @@
import os
from datetime import timedelta
from urllib.parse import quote
from celery import Celery
from app.core.config import settings
# 创建 Celery 应用实例
# broker: 任务队列(使用 Redis DB 0
# backend: 结果存储(使用 Redis DB 10
celery_app = Celery(
"redbear_tasks",
broker=f"redis://:{quote(settings.REDIS_PASSWORD)}@{settings.REDIS_HOST}:{settings.REDIS_PORT}/{settings.CELERY_BROKER}",
backend=f"redis://:{quote(settings.REDIS_PASSWORD)}@{settings.REDIS_HOST}:{settings.REDIS_PORT}/{settings.CELERY_BACKEND}",
)
# 配置使用本地队列,避免与远程 worker 冲突
celery_app.conf.task_default_queue = 'localhost_test_wyl'
celery_app.conf.task_default_exchange = 'localhost_test_wyl'
celery_app.conf.task_default_routing_key = 'localhost_test_wyl'
# macOS 兼容性配置
import platform
if platform.system() == 'Darwin': # macOS
# 设置环境变量解决 fork 问题
os.environ.setdefault('OBJC_DISABLE_INITIALIZE_FORK_SAFETY', 'YES')
# 使用 solo 池避免多进程问题
celery_app.conf.worker_pool = 'solo'
# 设置唯一的节点名称
import socket
import time
hostname = socket.gethostname()
timestamp = int(time.time())
celery_app.conf.worker_name = f"celery@{hostname}-{timestamp}"
# Celery 配置
celery_app.conf.update(
# 序列化
task_serializer='json',
accept_content=['json'],
result_serializer='json',
# 时区
timezone='Asia/Shanghai',
enable_utc=True,
# 任务追踪
task_track_started=True,
task_ignore_result=False,
# 超时设置
task_time_limit=30 * 60, # 30 分钟硬超时
task_soft_time_limit=25 * 60, # 25 分钟软超时
# Worker 设置 - 针对 macOS 优化
worker_prefetch_multiplier=1, # 减少预取任务数,避免内存堆积
worker_max_tasks_per_child=10, # 大幅减少每个 worker 执行的任务数,频繁重启防止内存泄漏
worker_max_memory_per_child=200000, # 200MB 内存限制,超过后重启 worker
# 结果过期时间
result_expires=3600, # 结果保存 1 小时
# 任务确认设置
task_acks_late=True, # 任务完成后才确认,避免任务丢失
worker_disable_rate_limits=True, # 禁用速率限制
# 任务路由(可选,用于不同队列)
# task_routes={
# 'app.core.rag.tasks.parse_document': {'queue': 'document_processing'},
# 'app.core.memory.agent.read_message': {'queue': 'memory_processing'},
# 'app.core.memory.agent.write_message': {'queue': 'memory_processing'},
# 'tasks.process_item': {'queue': 'default'},
# },
)
# 自动发现任务模块
celery_app.autodiscover_tasks(['app'])
# Celery Beat schedule for periodic tasks
reflection_schedule = timedelta(seconds=settings.REFLECTION_INTERVAL_SECONDS)
health_schedule = timedelta(seconds=settings.HEALTH_CHECK_SECONDS)
memory_increment_schedule = timedelta(hours=settings.MEMORY_INCREMENT_INTERVAL_HOURS)
# 构建定时任务配置
beat_schedule_config = {
"run-reflection-engine": {
"task": "app.core.memory.agent.reflection.timer",
"schedule": reflection_schedule,
"args": (),
},
"check-read-service": {
"task": "app.core.memory.agent.health.check_read_service",
"schedule": health_schedule,
"args": (),
},
}
# 如果配置了默认工作空间ID则添加记忆总量统计任务
if settings.DEFAULT_WORKSPACE_ID:
beat_schedule_config["write-total-memory"] = {
"task": "app.controllers.memory_storage_controller.search_all",
"schedule": memory_increment_schedule,
"kwargs": {
"workspace_id": settings.DEFAULT_WORKSPACE_ID,
},
}
celery_app.conf.beat_schedule = beat_schedule_config

10
app/celery_worker.py Normal file
View File

@@ -0,0 +1,10 @@
"""
Celery Worker 入口点
用于启动 Celery Worker: celery -A app.celery_worker worker --loglevel=info
"""
from app.celery_app import celery_app
# 导入任务模块以注册任务
import app.tasks
__all__ = ['celery_app']

View File

@@ -0,0 +1,60 @@
"""管理端接口 - 基于 JWT 认证
路由前缀: /
认证方式: JWT Token
"""
from fastapi import APIRouter
from . import (
model_controller,
task_controller,
test_controller,
user_controller,
auth_controller,
workspace_controller,
setup_controller,
file_controller,
document_controller,
knowledge_controller,
chunk_controller,
knowledgeshare_controller,
app_controller,
upload_controller,
memory_agent_controller,
memory_dashboard_controller,
memory_storage_controller,
memory_dashboard_controller,
api_key_controller,
release_share_controller,
public_share_controller,
multi_agent_controller,
)
# 创建管理端 API 路由器
manager_router = APIRouter()
# 注册所有管理端路由
manager_router.include_router(task_controller.router)
manager_router.include_router(user_controller.router)
manager_router.include_router(auth_controller.router)
manager_router.include_router(workspace_controller.router)
manager_router.include_router(workspace_controller.public_router) # 公开路由(无需认证)
manager_router.include_router(setup_controller.router)
manager_router.include_router(model_controller.router)
manager_router.include_router(file_controller.router)
manager_router.include_router(document_controller.router)
manager_router.include_router(knowledge_controller.router)
manager_router.include_router(chunk_controller.router)
manager_router.include_router(test_controller.router)
manager_router.include_router(knowledgeshare_controller.router)
manager_router.include_router(app_controller.router)
manager_router.include_router(upload_controller.router)
manager_router.include_router(memory_agent_controller.router)
manager_router.include_router(memory_dashboard_controller.router)
manager_router.include_router(memory_storage_controller.router)
manager_router.include_router(api_key_controller.router)
manager_router.include_router(release_share_controller.router)
manager_router.include_router(public_share_controller.router) # 公开路由(无需认证)
manager_router.include_router(memory_dashboard_controller.router)
manager_router.include_router(multi_agent_controller.router)
__all__ = ["manager_router"]

View File

@@ -0,0 +1,151 @@
"""API Key 管理接口 - 基于 JWT 认证"""
from fastapi import APIRouter, Depends, Query
from sqlalchemy.orm import Session
import uuid
from app.db import get_db
from app.dependencies import get_current_user, cur_workspace_access_guard
from app.models.user_model import User
from app.core.response_utils import success
from app.schemas import api_key_schema
from app.schemas.response_schema import ApiResponse
from app.services.api_key_service import ApiKeyService
from app.core.logging_config import get_business_logger
router = APIRouter(prefix="/apikeys", tags=["API Keys"])
logger = get_business_logger()
@router.post("", response_model=ApiResponse)
@cur_workspace_access_guard()
def create_api_key(
data: api_key_schema.ApiKeyCreate,
db: Session = Depends(get_db),
current_user: User = Depends(get_current_user)
):
"""创建 API Key
- 创建后返回明文 API Key仅此一次
- 支持设置权限范围、速率限制、配额等
"""
workspace_id = current_user.current_workspace_id
api_key_obj, api_key = ApiKeyService.create_api_key(
db,
workspace_id=workspace_id,
user_id=current_user.id,
data=data
)
# 返回包含明文 Key 的响应
response_data = api_key_schema.ApiKeyResponse(
**api_key_obj.__dict__,
api_key=api_key
)
return success(data=response_data, msg="API Key 创建成功")
@router.get("", response_model=ApiResponse)
@cur_workspace_access_guard()
def list_api_keys(
type: api_key_schema.ApiKeyType = Query(None),
is_active: bool = Query(None),
resource_id: uuid.UUID = Query(None),
page: int = Query(1, ge=1),
pagesize: int = Query(10, ge=1, le=100),
db: Session = Depends(get_db),
current_user: User = Depends(get_current_user)
):
"""列出 API Keys"""
workspace_id = current_user.current_workspace_id
query = api_key_schema.ApiKeyQuery(
type=type,
is_active=is_active,
resource_id=resource_id,
page=page,
pagesize=pagesize
)
result = ApiKeyService.list_api_keys(db, workspace_id, query)
return success(data=result)
@router.get("/{api_key_id}", response_model=ApiResponse)
@cur_workspace_access_guard()
def get_api_key(
api_key_id: uuid.UUID,
db: Session = Depends(get_db),
current_user: User = Depends(get_current_user)
):
"""获取 API Key 详情"""
workspace_id = current_user.current_workspace_id
api_key = ApiKeyService.get_api_key(db, api_key_id, workspace_id)
return success(data=api_key_schema.ApiKey.model_validate(api_key))
@router.put("/{api_key_id}", response_model=ApiResponse)
@cur_workspace_access_guard()
def update_api_key(
api_key_id: uuid.UUID,
data: api_key_schema.ApiKeyUpdate,
db: Session = Depends(get_db),
current_user: User = Depends(get_current_user)
):
"""更新 API Key"""
workspace_id = current_user.current_workspace_id
api_key = ApiKeyService.update_api_key(db, api_key_id, workspace_id, data)
return success(data=api_key_schema.ApiKey.model_validate(api_key), msg="API Key 更新成功")
@router.delete("/{api_key_id}", response_model=ApiResponse)
@cur_workspace_access_guard()
def delete_api_key(
api_key_id: uuid.UUID,
db: Session = Depends(get_db),
current_user: User = Depends(get_current_user)
):
"""删除 API Key"""
workspace_id = current_user.current_workspace_id
ApiKeyService.delete_api_key(db, api_key_id, workspace_id)
return success(msg="API Key 删除成功")
@router.post("/{api_key_id}/regenerate", response_model=ApiResponse)
@cur_workspace_access_guard()
def regenerate_api_key(
api_key_id: uuid.UUID,
db: Session = Depends(get_db),
current_user: User = Depends(get_current_user)
):
"""重新生成 API Key
- 生成新的 API Key 并返回明文(仅此一次)
- 旧的 API Key 立即失效
"""
workspace_id = current_user.current_workspace_id
api_key_obj, api_key = ApiKeyService.regenerate_api_key(db, api_key_id, workspace_id)
# 返回包含明文 Key 的响应
response_data = api_key_schema.ApiKeyResponse(
**api_key_obj.__dict__,
api_key=api_key
)
return success(data=response_data, msg="API Key 重新生成成功")
@router.get("/{api_key_id}/stats", response_model=ApiResponse)
@cur_workspace_access_guard()
def get_api_key_stats(
api_key_id: uuid.UUID,
db: Session = Depends(get_db),
current_user: User = Depends(get_current_user)
):
"""获取 API Key 使用统计"""
workspace_id = current_user.current_workspace_id
stats = ApiKeyService.get_stats(db, api_key_id, workspace_id)
return success(data=stats)

View File

@@ -0,0 +1,716 @@
import uuid
from typing import Optional
from fastapi import APIRouter, Depends
from sqlalchemy.orm import Session
from app.db import get_db
from app.core.response_utils import success
from app.core.logging_config import get_business_logger
from app.models import User
from app.repositories import knowledge_repository
from app.schemas import app_schema
from app.schemas.response_schema import PageData, PageMeta
from app.services import app_service, workspace_service
from app.services.app_service import AppService
from app.services.agent_config_helper import enrich_agent_config
from app.dependencies import get_current_user, cur_workspace_access_guard, workspace_access_guard
from fastapi.responses import StreamingResponse
from app.models.app_model import AppType
from app.core.error_codes import BizCode
router = APIRouter(prefix="/apps", tags=["Apps"])
logger = get_business_logger()
@router.post("", summary="创建应用(可选创建 Agent 配置)")
@cur_workspace_access_guard()
def create_app(
payload: app_schema.AppCreate,
db: Session = Depends(get_db),
current_user=Depends(get_current_user),
):
workspace_id = current_user.current_workspace_id
app = app_service.create_app(db, user_id=current_user.id, workspace_id=workspace_id, data=payload)
return success(data=app_schema.App.model_validate(app))
@router.get("", summary="应用列表(分页)")
@cur_workspace_access_guard()
def list_apps(
type: str | None = None,
visibility: str | None = None,
status: str | None = None,
search: str | None = None,
include_shared: bool = True,
page: int = 1,
pagesize: int = 10,
db: Session = Depends(get_db),
current_user=Depends(get_current_user),
):
"""列出应用
- 默认包含本工作空间的应用和分享给本工作空间的应用
- 设置 include_shared=false 可以只查看本工作空间的应用
"""
workspace_id = current_user.current_workspace_id
items_orm, total = app_service.list_apps(
db,
workspace_id=workspace_id,
type=type,
visibility=visibility,
status=status,
search=search,
include_shared=include_shared,
page=page,
pagesize=pagesize,
)
# 使用 AppService 的转换方法来设置 is_shared 字段
service = app_service.AppService(db)
items = [service._convert_to_schema(app, workspace_id) for app in items_orm]
meta = PageMeta(page=page, pagesize=pagesize, total=total, hasnext=(page * pagesize) < total)
return success(data=PageData(page=meta, items=items))
@router.get("/{app_id}", summary="获取应用详情")
@cur_workspace_access_guard()
def get_app(
app_id: uuid.UUID,
db: Session = Depends(get_db),
current_user=Depends(get_current_user),
):
"""获取应用详细信息
- 支持获取本工作空间的应用
- 支持获取分享给本工作空间的应用
"""
workspace_id = current_user.current_workspace_id
service = app_service.AppService(db)
app = service.get_app(app_id, workspace_id)
# 转换为 Schema 并设置 is_shared 字段
app_schema_obj = service._convert_to_schema(app, workspace_id)
return success(data=app_schema_obj)
@router.put("/{app_id}", summary="更新应用基本信息")
@cur_workspace_access_guard()
def update_app(
app_id: uuid.UUID,
payload: app_schema.AppUpdate,
db: Session = Depends(get_db),
current_user=Depends(get_current_user),
):
workspace_id = current_user.current_workspace_id
app = app_service.update_app(db, app_id=app_id, data=payload, workspace_id=workspace_id)
return success(data=app_schema.App.model_validate(app))
@router.delete("/{app_id}", summary="删除应用")
@cur_workspace_access_guard()
def delete_app(
app_id: uuid.UUID,
db: Session = Depends(get_db),
current_user=Depends(get_current_user),
):
"""删除应用
会级联删除:
- Agent 配置
- 发布版本
- 会话和消息
"""
workspace_id = current_user.current_workspace_id
logger.info(
f"用户请求删除应用",
extra={
"app_id": str(app_id),
"user_id": str(current_user.id),
"workspace_id": str(workspace_id)
}
)
app_service.delete_app(db, app_id=app_id, workspace_id=workspace_id)
return success(msg="应用删除成功")
@router.post("/{app_id}/copy", summary="复制应用")
@cur_workspace_access_guard()
def copy_app(
app_id: uuid.UUID,
new_name: Optional[str] = None,
db: Session = Depends(get_db),
current_user=Depends(get_current_user),
):
"""复制应用(包括基础信息和配置)
- 复制应用的基础信息(名称、描述、图标等)
- 复制 Agent 配置(如果是 agent 类型)
- 新应用默认为草稿状态
- 不影响原应用
"""
workspace_id = current_user.current_workspace_id
logger.info(
f"用户请求复制应用",
extra={
"source_app_id": str(app_id),
"user_id": str(current_user.id),
"workspace_id": str(workspace_id),
"new_name": new_name
}
)
service = AppService(db)
new_app = service.copy_app(
app_id=app_id,
user_id=current_user.id,
workspace_id=workspace_id,
new_name=new_name
)
return success(data=app_schema.App.model_validate(new_app), msg="应用复制成功")
@router.put("/{app_id}/config", summary="更新 Agent 配置")
@cur_workspace_access_guard()
def update_agent_config(
app_id: uuid.UUID,
payload: app_schema.AgentConfigUpdate,
db: Session = Depends(get_db),
current_user=Depends(get_current_user),
):
workspace_id = current_user.current_workspace_id
cfg = app_service.update_agent_config(db, app_id=app_id, data=payload, workspace_id=workspace_id)
cfg = enrich_agent_config(cfg)
return success(data=app_schema.AgentConfig.model_validate(cfg))
@router.get("/{app_id}/config", summary="获取 Agent 配置")
@cur_workspace_access_guard()
def get_agent_config(
app_id: uuid.UUID,
db: Session = Depends(get_db),
current_user=Depends(get_current_user),
):
workspace_id = current_user.current_workspace_id
cfg = app_service.get_agent_config(db, app_id=app_id, workspace_id=workspace_id)
# 配置总是存在(不存在时返回默认模板)
cfg = enrich_agent_config(cfg)
return success(data=app_schema.AgentConfig.model_validate(cfg))
@router.post("/{app_id}/publish", summary="发布应用(生成不可变快照)")
@cur_workspace_access_guard()
def publish_app(
app_id: uuid.UUID,
payload: app_schema.PublishRequest,
db: Session = Depends(get_db),
current_user=Depends(get_current_user),
):
workspace_id = current_user.current_workspace_id
release = app_service.publish(
db,
app_id=app_id,
publisher_id=current_user.id,
workspace_id=workspace_id,
version_name = payload.version_name,
release_notes=payload.release_notes
)
return success(data=app_schema.AppRelease.model_validate(release))
@router.get("/{app_id}/release", summary="获取当前发布版本")
@cur_workspace_access_guard()
def get_current_release(
app_id: uuid.UUID,
db: Session = Depends(get_db),
current_user=Depends(get_current_user),
):
workspace_id = current_user.current_workspace_id
release = app_service.get_current_release(db, app_id=app_id, workspace_id=workspace_id)
if not release:
return success(data=None)
return success(data=app_schema.AppRelease.model_validate(release))
@router.get("/{app_id}/releases", summary="列出历史发布版本(倒序)")
@cur_workspace_access_guard()
def list_releases(
app_id: uuid.UUID,
db: Session = Depends(get_db),
current_user=Depends(get_current_user),
):
workspace_id = current_user.current_workspace_id
releases = app_service.list_releases(db, app_id=app_id, workspace_id=workspace_id)
data = [app_schema.AppRelease.model_validate(r) for r in releases]
return success(data=data)
@router.post("/{app_id}/rollback/{version}", summary="回滚到指定版本")
@cur_workspace_access_guard()
def rollback(
app_id: uuid.UUID,
version: int,
db: Session = Depends(get_db),
current_user=Depends(get_current_user),
):
workspace_id = current_user.current_workspace_id
release = app_service.rollback(db, app_id=app_id, version=version, workspace_id=workspace_id)
return success(data=app_schema.AppRelease.model_validate(release))
@router.post("/{app_id}/share", summary="分享应用到其他工作空间")
@cur_workspace_access_guard()
def share_app(
app_id: uuid.UUID,
payload: app_schema.AppShareCreate,
db: Session = Depends(get_db),
current_user=Depends(get_current_user),
):
"""分享应用到其他工作空间
- 只能分享自己工作空间的应用
- 不能分享到自己的工作空间
- 同一个应用不能重复分享到同一个工作空间
"""
workspace_id = current_user.current_workspace_id
service = app_service.AppService(db)
shares = service.share_app(
app_id=app_id,
target_workspace_ids=payload.target_workspace_ids,
user_id=current_user.id,
workspace_id=workspace_id
)
data = [app_schema.AppShare.model_validate(s) for s in shares]
return success(data=data, msg=f"应用已分享到 {len(shares)} 个工作空间")
@router.delete("/{app_id}/share/{target_workspace_id}", summary="取消应用分享")
@cur_workspace_access_guard()
def unshare_app(
app_id: uuid.UUID,
target_workspace_id: uuid.UUID,
db: Session = Depends(get_db),
current_user=Depends(get_current_user),
):
"""取消应用分享
- 只能取消自己工作空间应用的分享
"""
workspace_id = current_user.current_workspace_id
service = app_service.AppService(db)
service.unshare_app(
app_id=app_id,
target_workspace_id=target_workspace_id,
workspace_id=workspace_id
)
return success(msg="应用分享已取消")
@router.get("/{app_id}/shares", summary="列出应用的分享记录")
@cur_workspace_access_guard()
def list_app_shares(
app_id: uuid.UUID,
db: Session = Depends(get_db),
current_user=Depends(get_current_user),
):
"""列出应用的所有分享记录
- 只能查看自己工作空间应用的分享记录
"""
workspace_id = current_user.current_workspace_id
service = app_service.AppService(db)
shares = service.list_app_shares(
app_id=app_id,
workspace_id=workspace_id
)
data = [app_schema.AppShare.model_validate(s) for s in shares]
return success(data=data)
@router.post("/{app_id}/draft/run", summary="试运行 Agent使用当前草稿配置")
@cur_workspace_access_guard()
async def draft_run(
app_id: uuid.UUID,
payload: app_schema.DraftRunRequest,
db: Session = Depends(get_db),
current_user=Depends(get_current_user),
):
"""
试运行 Agent使用当前的草稿配置未发布的配置
- 不需要发布应用即可测试
- 使用当前的 AgentConfig 配置
- 支持流式和非流式返回
"""
workspace_id = current_user.current_workspace_id
# 获取 storage_type如果为 None 则使用默认值
storage_type = workspace_service.get_workspace_storage_type(
db=db,
workspace_id=workspace_id,
user=current_user
)
if storage_type is None: storage_type = 'neo4j'
user_rag_memory_id = ''
if workspace_id:
knowledge = knowledge_repository.get_knowledge_by_name(
db=db,
name="USER_RAG_MERORY",
workspace_id=workspace_id
)
if knowledge: user_rag_memory_id = str(knowledge.id)
# 提前验证和准备(在流式响应开始前完成)
from app.services.app_service import AppService
from app.services.multi_agent_service import MultiAgentService
from app.models import AgentConfig, ModelConfig
from sqlalchemy import select
from app.core.exceptions import BusinessException
service = AppService(db)
# 1. 验证应用
app = service._get_app_or_404(app_id)
if app.type != AppType.AGENT and app.type != AppType.MULTI_AGENT:
raise BusinessException("只有 Agent 类型应用支持试运行", BizCode.APP_TYPE_NOT_SUPPORTED)
# 只读操作,允许访问共享应用
service._validate_app_accessible(app, workspace_id)
if app.type == AppType.AGENT:
service._check_agent_config(app_id)
# 2. 获取 Agent 配置
stmt = select(AgentConfig).where(AgentConfig.app_id == app_id)
agent_cfg = db.scalars(stmt).first()
if not agent_cfg:
raise BusinessException("Agent 配置不存在", BizCode.AGENT_CONFIG_MISSING)
# 3. 获取模型配置
model_config = None
if agent_cfg.default_model_config_id:
model_config = db.get(ModelConfig, agent_cfg.default_model_config_id)
if not model_config:
from app.core.exceptions import ResourceNotFoundException
raise ResourceNotFoundException("模型配置", str(agent_cfg.default_model_config_id))
# 流式返回
if payload.stream:
async def event_generator():
from app.services.draft_run_service import DraftRunService
draft_service = DraftRunService(db)
async for event in draft_service.run_stream(
agent_config=agent_cfg,
model_config=model_config,
message=payload.message,
workspace_id=workspace_id,
conversation_id=payload.conversation_id,
user_id=payload.user_id or str(current_user.id),
variables=payload.variables,
storage_type=storage_type,
user_rag_memory_id=user_rag_memory_id
):
yield event
return StreamingResponse(
event_generator(),
media_type="text/event-stream",
headers={
"Cache-Control": "no-cache",
"Connection": "keep-alive",
"X-Accel-Buffering": "no"
}
)
# 非流式返回
logger.debug(
f"开始非流式试运行",
extra={
"app_id": str(app_id),
"message_length": len(payload.message),
"has_conversation_id": bool(payload.conversation_id),
"has_variables": bool(payload.variables)
}
)
from app.services.draft_run_service import DraftRunService
draft_service = DraftRunService(db)
result = await draft_service.run(
agent_config=agent_cfg,
model_config=model_config,
message=payload.message,
workspace_id=workspace_id,
conversation_id=payload.conversation_id,
user_id=payload.user_id or str(current_user.id),
variables=payload.variables,
storage_type=storage_type,
user_rag_memory_id=user_rag_memory_id
)
logger.debug(
f"试运行返回结果",
extra={
"result_type": str(type(result)),
"result_keys": list(result.keys()) if isinstance(result, dict) else "not_dict"
}
)
# 验证结果
try:
validated_result = app_schema.DraftRunResponse.model_validate(result)
logger.debug(f"结果验证成功")
return success(data=validated_result)
except Exception as e:
logger.error(
f"结果验证失败",
extra={
"error": str(e),
"error_type": str(type(e)),
"result": str(result)[:200]
}
)
raise
elif app.type == AppType.MULTI_AGENT:
# 1. 检查多智能体配置完整性
service._check_multi_agent_config(app_id)
# 2. 构建多智能体运行请求
from app.schemas.multi_agent_schema import MultiAgentRunRequest
multi_agent_request = MultiAgentRunRequest(
message=payload.message,
conversation_id=payload.conversation_id,
user_id=payload.user_id,
variables=payload.variables or {},
use_llm_routing=True # 默认启用 LLM 路由
)
# 3. 流式返回
if payload.stream:
logger.debug(
f"开始多智能体流式试运行",
extra={
"app_id": str(app_id),
"message_length": len(payload.message),
"has_conversation_id": bool(payload.conversation_id)
}
)
async def event_generator():
"""多智能体流式事件生成器"""
multiservice = MultiAgentService(db)
# 调用多智能体服务的流式方法
async for event in multiservice.run_stream(
app_id=app_id,
request=multi_agent_request,
storage_type=storage_type,
user_rag_memory_id=user_rag_memory_id
):
yield event
return StreamingResponse(
event_generator(),
media_type="text/event-stream",
headers={
"Cache-Control": "no-cache",
"Connection": "keep-alive",
"X-Accel-Buffering": "no"
}
)
# 4. 非流式返回
logger.debug(
f"开始多智能体非流式试运行",
extra={
"app_id": str(app_id),
"message_length": len(payload.message),
"has_conversation_id": bool(payload.conversation_id)
}
)
multiservice = MultiAgentService(db)
result = await multiservice.run(app_id, multi_agent_request)
logger.debug(
f"多智能体试运行返回结果",
extra={
"result_type": str(type(result)),
"has_response": "response" in result if isinstance(result, dict) else False
}
)
return success(
data=result,
msg="多 Agent 任务执行成功"
)
@router.post("/{app_id}/draft/run/compare", summary="多模型对比试运行")
@cur_workspace_access_guard()
async def draft_run_compare(
app_id: uuid.UUID,
payload: app_schema.DraftRunCompareRequest,
db: Session = Depends(get_db),
current_user=Depends(get_current_user),
):
"""
多模型对比试运行
- 支持对比 1-5 个模型
- 可以是不同的模型,也可以是同一模型的不同参数配置
- 通过 model_parameters 覆盖默认参数
- 支持并行或串行执行(非流式)
- 支持流式返回(串行执行)
- 返回每个模型的运行结果和性能对比
使用场景:
1. 对比不同模型的效果GPT-4 vs Claude vs Gemini
2. 调优模型参数(不同 temperature 的效果对比)
3. 性能和成本分析
"""
workspace_id = current_user.current_workspace_id
# 获取 storage_type如果为 None 则使用默认值
storage_type = workspace_service.get_workspace_storage_type(
db=db,
workspace_id=workspace_id,
user=current_user
)
if storage_type is None: storage_type = 'neo4j'
user_rag_memory_id = ''
if workspace_id:
knowledge = knowledge_repository.get_knowledge_by_name(
db=db,
name="USER_RAG_MERORY",
workspace_id=workspace_id
)
if knowledge: user_rag_memory_id = str(knowledge.id)
logger.info(
f"多模型对比试运行",
extra={
"app_id": str(app_id),
"model_count": len(payload.models),
"parallel": payload.parallel,
"stream": payload.stream
}
)
# 提前验证和准备(在流式响应开始前完成)
from app.services.app_service import AppService
from app.models import ModelConfig
service = AppService(db)
# 1. 验证应用和权限
app = service._get_app_or_404(app_id)
if app.type != "agent":
from app.core.exceptions import BusinessException
from app.core.error_codes import BizCode
raise BusinessException("只有 Agent 类型应用支持试运行", BizCode.APP_TYPE_NOT_SUPPORTED)
service._validate_app_accessible(app, workspace_id)
# 2. 获取 Agent 配置
from sqlalchemy import select
from app.models import AgentConfig
stmt = select(AgentConfig).where(AgentConfig.app_id == app_id)
agent_cfg = db.scalars(stmt).first()
if not agent_cfg:
from app.core.exceptions import BusinessException
from app.core.error_codes import BizCode
raise BusinessException("Agent 配置不存在", BizCode.AGENT_CONFIG_MISSING)
# 3. 验证所有模型配置
model_configs = []
for model_item in payload.models:
model_config = db.get(ModelConfig, model_item.model_config_id)
if not model_config:
from app.core.exceptions import ResourceNotFoundException
raise ResourceNotFoundException("模型配置", str(model_item.model_config_id))
merged_parameters = {
**(agent_cfg.model_parameters or {}),
**(model_item.model_parameters or {})
}
model_configs.append({
"model_config": model_config,
"parameters": merged_parameters,
"label": model_item.label or model_config.name,
"model_config_id": model_item.model_config_id,
"conversation_id": model_item.conversation_id # 传递每个模型的 conversation_id
})
# 流式返回
if payload.stream:
async def event_generator():
from app.services.draft_run_service import DraftRunService
draft_service = DraftRunService(db)
async for event in draft_service.run_compare_stream(
agent_config=agent_cfg,
models=model_configs,
message=payload.message,
workspace_id=workspace_id,
conversation_id=payload.conversation_id,
user_id=payload.user_id or str(current_user.id),
variables=payload.variables,
storage_type=storage_type,
user_rag_memory_id=user_rag_memory_id,
web_search=True,
memory=True,
parallel=payload.parallel,
timeout=payload.timeout or 60
):
yield event
return StreamingResponse(
event_generator(),
media_type="text/event-stream",
headers={
"Cache-Control": "no-cache",
"Connection": "keep-alive",
"X-Accel-Buffering": "no"
}
)
# 非流式返回
from app.services.draft_run_service import DraftRunService
draft_service = DraftRunService(db)
result = await draft_service.run_compare(
agent_config=agent_cfg,
models=model_configs,
message=payload.message,
workspace_id=workspace_id,
conversation_id=payload.conversation_id,
user_id=payload.user_id or str(current_user.id),
variables=payload.variables,
storage_type=storage_type,
user_rag_memory_id=user_rag_memory_id,
web_search=True,
memory=True,
parallel=payload.parallel,
timeout=payload.timeout or 60
)
logger.info(
f"多模型对比完成",
extra={
"app_id": str(app_id),
"successful": result["successful_count"],
"failed": result["failed_count"]
}
)
return success(data=app_schema.DraftRunCompareResponse(**result))

View File

@@ -0,0 +1,195 @@
from datetime import datetime, timedelta, timezone
from fastapi import APIRouter, Depends
from sqlalchemy.orm import Session
from app.core.response_utils import success
from app.db import get_db
from app.schemas.response_schema import ApiResponse
from app.schemas.token_schema import Token, RefreshTokenRequest, TokenRequest
from app.schemas.workspace_schema import InviteAcceptRequest
from app.services import auth_service, user_service, workspace_service
from app.core import security
from app.core.config import settings
from app.services.session_service import SessionService
from app.core.logging_config import get_auth_logger, get_security_logger
from app.core.exceptions import BusinessException
from app.core.error_codes import BizCode
from app.dependencies import get_current_user, oauth2_scheme
from app.models.user_model import User
# 获取专用日志器
auth_logger = get_auth_logger()
security_logger = get_security_logger()
router = APIRouter(tags=["Authentication"])
@router.post("/token", response_model=ApiResponse)
async def login_for_access_token(
form_data: TokenRequest,
db: Session = Depends(get_db)
):
"""用户登录获取token"""
auth_logger.info(f"用户登录请求: {form_data.email}")
# 验证邀请码(如果提供)
invite_info = None
# 验证用户凭据或注册新用户
user = None
if form_data.invite:
auth_logger.info(f"检测到邀请码: {form_data.invite[:8]}...")
invite_info = workspace_service.validate_invite_token(db, form_data.invite)
if not invite_info.is_valid:
raise BusinessException("邀请码无效或已过期", code=BizCode.BAD_REQUEST)
if invite_info.email != form_data.email:
raise BusinessException("邀请邮箱与登录邮箱不匹配", code=BizCode.BAD_REQUEST)
auth_logger.info(f"邀请码验证成功: workspace={invite_info.workspace_name}")
try:
# 尝试认证用户
user = auth_service.authenticate_user_or_raise(db, form_data.email, form_data.password)
auth_logger.info(f"用户认证成功: {user.email} (ID: {user.id})")
if form_data.invite:
auth_service.bind_workspace_with_invite(db=db,
user=user,
invite_token=form_data.invite,
workspace_id=invite_info.workspace_id)
except BusinessException as e:
# 用户不存在且有邀请码,尝试注册
if e.code == BizCode.USER_NOT_FOUND:
auth_logger.info(f"用户不存在,使用邀请码注册: {form_data.email}")
user = auth_service.register_user_with_invite(
db=db,
email=form_data.email,
password=form_data.password,
invite_token=form_data.invite,
workspace_id=invite_info.workspace_id
)
elif e.code == BizCode.PASSWORD_ERROR:
# 用户存在但密码错误
auth_logger.warning(f"接受邀请失败,密码验证错误: {form_data.email}")
raise BusinessException("接受邀请失败,密码验证错误", BizCode.LOGIN_FAILED)
else:
# 其他认证失败情况,直接抛出
raise
else:
try:
# 尝试认证用户
user = auth_service.authenticate_user_or_raise(db, form_data.email, form_data.password)
auth_logger.info(f"用户认证成功: {user.email} (ID: {user.id})")
except BusinessException as e:
# 其他认证失败情况,直接抛出
raise BusinessException(e.message,BizCode.LOGIN_FAILED)
# 创建 tokens
access_token, access_token_id = security.create_access_token(subject=user.id)
refresh_token, refresh_token_id = security.create_refresh_token(subject=user.id)
# 计算过期时间
access_expires_at = datetime.now(timezone.utc) + timedelta(minutes=settings.ACCESS_TOKEN_EXPIRE_MINUTES)
refresh_expires_at = datetime.now(timezone.utc) + timedelta(days=settings.REFRESH_TOKEN_EXPIRE_DAYS)
# 单点登录会话管理
if settings.ENABLE_SINGLE_SESSION:
await SessionService.invalidate_old_session(user.id, access_token_id)
await SessionService.set_user_active_session(user.id, access_token_id, access_expires_at)
# 更新最后登录时间
user_service.update_last_login_time(db, user.id)
auth_logger.info(f"用户 {user.username} 登录成功")
return success(
data=Token(
access_token=access_token,
refresh_token=refresh_token,
token_type="bearer",
expires_at=access_expires_at,
refresh_expires_at=refresh_expires_at
),
msg="登录成功"
)
@router.post("/refresh", response_model=ApiResponse)
async def refresh_token(
refresh_request: RefreshTokenRequest,
db: Session = Depends(get_db)
):
"""刷新token"""
auth_logger.info("收到token刷新请求")
# 验证 refresh token
userId = security.verify_token(refresh_request.refresh_token, "refresh")
if not userId:
raise BusinessException("无效的refresh token", code=BizCode.TOKEN_INVALID)
# 检查用户是否存在
user = auth_service.get_user_by_id(db, userId)
if not user:
raise BusinessException("用户不存在", code=BizCode.USER_NOT_FOUND)
# 检查 refresh token 黑名单
if settings.ENABLE_SINGLE_SESSION:
refresh_token_id = security.get_token_id(refresh_request.refresh_token)
if refresh_token_id and await SessionService.is_token_blacklisted(refresh_token_id):
raise BusinessException("Refresh token已失效", code=BizCode.TOKEN_BLACKLISTED)
# 生成新 tokens
new_access_token, new_access_token_id = security.create_access_token(subject=user.id)
new_refresh_token, new_refresh_token_id = security.create_refresh_token(subject=user.id)
# 计算过期时间
access_expires_at = datetime.now() + timedelta(minutes=settings.ACCESS_TOKEN_EXPIRE_MINUTES)
refresh_expires_at = datetime.now() + timedelta(days=settings.REFRESH_TOKEN_EXPIRE_DAYS)
# 单点登录会话管理
if settings.ENABLE_SINGLE_SESSION:
# 将旧 refresh token 加入黑名单
old_refresh_token_id = security.get_token_id(refresh_request.refresh_token)
if old_refresh_token_id:
await SessionService.blacklist_token(old_refresh_token_id)
# 更新会话
await SessionService.invalidate_old_session(user.id, new_access_token_id)
await SessionService.set_user_active_session(user.id, new_access_token_id, access_expires_at)
auth_logger.info(f"用户 {user.id} token刷新成功")
return success(
data=Token(
access_token=new_access_token,
refresh_token=new_refresh_token,
token_type="bearer",
expires_at=access_expires_at,
refresh_expires_at=refresh_expires_at
),
msg="token刷新成功"
)
@router.post("/logout", response_model=ApiResponse)
async def logout(
token: str = Depends(oauth2_scheme),
current_user: User = Depends(get_current_user),
db: Session = Depends(get_db)
):
"""登出当前用户加入token黑名单并清理会话"""
auth_logger.info(f"用户 {current_user.username} 请求登出")
token_id = security.get_token_id(token)
if not token_id:
raise BusinessException("无效的access token", code=BizCode.TOKEN_INVALID)
# 加入黑名单
await SessionService.blacklist_token(token_id)
# 清理会话
if settings.ENABLE_SINGLE_SESSION:
await SessionService.clear_user_session(current_user.username)
auth_logger.info(f"用户 {current_user.username} 登出成功")
return success(msg="登出成功")

View File

@@ -0,0 +1,447 @@
import os
from typing import Any, Optional
import uuid
from fastapi import APIRouter, Depends, HTTPException, status, Query
from sqlalchemy.orm import Session
from sqlalchemy import func
from app.core.config import settings
from app.db import get_db
from app.core.rag.llm.cv_model import QWenCV
from app.dependencies import get_current_user
from app.models.user_model import User
from app.models.document_model import Document
from app.models import knowledge_model, knowledgeshare_model
from app.core.rag.models.chunk import DocumentChunk
from app.schemas import chunk_schema
from app.schemas.response_schema import ApiResponse
from app.core.response_utils import success
from app.services import knowledge_service, document_service, file_service, knowledgeshare_service
from app.core.rag.vdb.elasticsearch.elasticsearch_vector import ElasticSearchVectorFactory
from app.core.logging_config import get_api_logger
# Obtain a dedicated API logger
api_logger = get_api_logger()
router = APIRouter(
prefix="/chunks",
tags=["chunks"],
dependencies=[Depends(get_current_user)] # Apply auth to all routes in this controller
)
@router.get("/{kb_id}/{document_id}/previewchunks", response_model=ApiResponse)
async def get_preview_chunks(
kb_id: uuid.UUID,
document_id: uuid.UUID,
page: int = Query(1, gt=0), # Default: 1, which must be greater than 0
pagesize: int = Query(20, gt=0, le=100), # Default: 20 items per page, maximum: 100 items
keywords: Optional[str] = Query(None, description="The keywords used to match chunk content"),
db: Session = Depends(get_db),
current_user: User = Depends(get_current_user)
):
"""
Paged query document block preview list
- Support filtering by document_id
- Support keyword search for segmented content
- Return paging metadata + file list
"""
api_logger.info(f"Paged query document block preview list: kb_id={kb_id}, document_id={document_id}, page={page}, pagesize={pagesize}, keywords={keywords}, username: {current_user.username}")
# 1. parameter validation
if page < 1 or pagesize < 1:
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail="The paging parameter must be greater than 0"
)
# 2. Obtain knowledge base information
db_knowledge = knowledge_service.get_knowledge_by_id(db, knowledge_id=kb_id, current_user=current_user)
if not db_knowledge:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail="The knowledge base does not exist or access is denied"
)
# 3. Check if the document exists
db_document = document_service.get_document_by_id(db, document_id=document_id, current_user=current_user)
if not db_document:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail="The document does not exist or you do not have permission to access it"
)
# 4. Check if the file exists
db_file = file_service.get_file_by_id(db, file_id=db_document.file_id)
if not db_file:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail="The file does not exist or you do not have permission to access it"
)
# 5. Construct file path/files/{kb_id}/{parent_id}/{file.id}{file.file_ext}
file_path = os.path.join(
settings.FILE_PATH,
str(db_file.kb_id),
str(db_file.parent_id),
f"{db_file.id}{db_file.file_ext}"
)
# 6. Check if the file exists
if not os.path.exists(file_path):
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail="File not found (possibly deleted)"
)
# 7. Document parsing & segmentation
def progress_callback(prog=None, msg=None):
print(f"prog: {prog} msg: {msg}\n")
# Prepare to configure vision_model information
vision_model = QWenCV(
key=db_knowledge.image2text.api_keys[0].api_key,
model_name=db_knowledge.image2text.api_keys[0].model_name,
lang="Chinese", # Default to Chinese
base_url=db_knowledge.image2text.api_keys[0].api_base
)
from app.core.rag.app.naive import chunk
res = chunk(filename=file_path,
from_page=0,
to_page=5,
callback=progress_callback,
vision_model=vision_model,
parser_config=db_document.parser_config,
is_root=False)
start_index = (page - 1) * pagesize
end_index = start_index + pagesize
# Use slicing to obtain the data of the current page
paginated_chunk_str_list = res[start_index:end_index]
chunks = []
for idx, item in enumerate(paginated_chunk_str_list):
metadata = {
"doc_id": uuid.uuid4().hex,
"file_id": str(db_document.file_id),
"file_name": db_document.file_name,
"file_created_at": int(db_document.created_at.timestamp() * 1000),
"document_id": str(db_document.id),
"knowledge_id": str(db_document.kb_id),
"sort_id": idx,
"status": 1,
}
chunks.append(DocumentChunk(page_content=item["content_with_weight"], metadata=metadata))
# 8. Return structured response
total = len(res)
result = {
"items": chunks,
"page": {
"page": page,
"pagesize": pagesize,
"total": total,
"has_next": True if page * pagesize < total else False
}
}
api_logger.info(f"Querying the document block preview list successful: total={total}, returned={len(chunks)} records")
return success(data=result, msg="Querying the document block preview list succeeded")
@router.get("/{kb_id}/{document_id}/chunks", response_model=ApiResponse)
async def get_chunks(
kb_id: uuid.UUID,
document_id: uuid.UUID,
page: int = Query(1, gt=0), # Default: 1, which must be greater than 0
pagesize: int = Query(20, gt=0, le=100), # Default: 20 items per page, maximum: 100 items
keywords: Optional[str] = Query(None, description="The keywords used to match chunk content"),
db: Session = Depends(get_db),
current_user: User = Depends(get_current_user)
):
"""
Paged query document chunk list
- Support filtering by document_id
- Support keyword search for segmented content
- Return paging metadata + file list
"""
api_logger.info(f"Paged query document chunk list: kb_id={kb_id}, document_id={document_id}, page={page}, pagesize={pagesize}, keywords={keywords}, username: {current_user.username}")
# 1. parameter validation
if page < 1 or pagesize < 1:
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail="The paging parameter must be greater than 0"
)
# 2. Obtain knowledge base information
db_knowledge = knowledge_service.get_knowledge_by_id(db, knowledge_id=kb_id, current_user=current_user)
if not db_knowledge:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail="The knowledge base does not exist or access is denied"
)
# 3. Execute paged query
try:
api_logger.debug(f"Start executing document chunk query")
vector_service = ElasticSearchVectorFactory().init_vector(knowledge=db_knowledge)
total, items = vector_service.search_by_segment(document_id=str(document_id), query=keywords, pagesize=pagesize, page=page, asc=True)
api_logger.info(f"Document chunk query successful: total={total}, returned={len(items)} records")
except Exception as e:
api_logger.error(f"Document chunk query failed: {str(e)}")
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail=f"Query failed: {str(e)}"
)
# 4. Return structured response
result = {
"items": items,
"page": {
"page": page,
"pagesize": pagesize,
"total": total,
"has_next": True if page * pagesize < total else False
}
}
return success(data=result, msg="Query of document chunk list succeeded")
@router.post("/{kb_id}/{document_id}/chunk", response_model=ApiResponse)
async def create_chunk(
kb_id: uuid.UUID,
document_id: uuid.UUID,
create_data: chunk_schema.ChunkCreate,
db: Session = Depends(get_db),
current_user: User = Depends(get_current_user)
):
"""
create chunk
"""
api_logger.info(f"Create chunk request: kb_id={kb_id}, document_id={document_id}, content={create_data.content}, username: {current_user.username}")
# 1. Obtain knowledge base information
db_knowledge = knowledge_service.get_knowledge_by_id(db, knowledge_id=kb_id, current_user=current_user)
if not db_knowledge:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail="The knowledge base does not exist or access is denied"
)
# 1. Obtain document information
db_document = db.query(Document).filter(Document.id == document_id).first()
if not db_document:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail="The document does not exist or you do not have permission to access it"
)
vector_service = ElasticSearchVectorFactory().init_vector(knowledge=db_knowledge)
# 2. Get the sort ID
sort_id = 0
total, items = vector_service.search_by_segment(document_id=str(document_id), pagesize=1, page=1, asc=False)
if items:
sort_id = items[0].metadata["sort_id"]
sort_id = sort_id + 1
doc_id = uuid.uuid4().hex
metadata = {
"doc_id": doc_id,
"file_id": str(db_document.file_id),
"file_name": db_document.file_name,
"file_created_at": int(db_document.created_at.timestamp() * 1000),
"document_id": str(document_id),
"knowledge_id": str(kb_id),
"sort_id": sort_id,
"status": 1,
}
chunk = DocumentChunk(page_content=create_data.content, metadata=metadata)
# 3. Segmented vector storage
vector_service.add_chunks([chunk])
# 4.update chunk_num
db_document.chunk_num += 1
db.commit()
return success(data=chunk, msg="Document chunk creation successful")
@router.get("/{kb_id}/{document_id}/{doc_id}", response_model=ApiResponse)
async def get_chunk(
kb_id: uuid.UUID,
document_id: uuid.UUID,
doc_id: str,
db: Session = Depends(get_db),
current_user: User = Depends(get_current_user)
):
"""
Retrieve document chunk information based on doc_id
"""
api_logger.info(f"Obtain document chunk information: kb_id={kb_id}, document_id={document_id}, doc_id={doc_id}, username: {current_user.username}")
# 1. Obtain knowledge base information
db_knowledge = knowledge_service.get_knowledge_by_id(db, knowledge_id=kb_id, current_user=current_user)
if not db_knowledge:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail="The knowledge base does not exist or access is denied"
)
vector_service = ElasticSearchVectorFactory().init_vector(knowledge=db_knowledge)
total, items = vector_service.get_by_segment(doc_id=doc_id)
if total:
return success(data=items[0], msg="Document chunk query successful")
else:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail="The document chunk does not exist or you do not have access"
)
@router.put("/{kb_id}/{document_id}/{doc_id}", response_model=ApiResponse)
async def update_chunk(
kb_id: uuid.UUID,
document_id: uuid.UUID,
doc_id: str,
update_data: chunk_schema.ChunkUpdate,
db: Session = Depends(get_db),
current_user: User = Depends(get_current_user)
):
"""
Update document chunk content
"""
api_logger.info(f"Update document chunk content: kb_id={kb_id}, document_id={document_id}, doc_id={doc_id}, content={update_data.content}, username: {current_user.username}")
db_knowledge = knowledge_service.get_knowledge_by_id(db, knowledge_id=kb_id, current_user=current_user)
if not db_knowledge:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail="The knowledge base does not exist or access is denied"
)
vector_service = ElasticSearchVectorFactory().init_vector(knowledge=db_knowledge)
total, items = vector_service.get_by_segment(doc_id=doc_id)
if total:
chunk = items[0]
chunk.page_content = update_data.content
vector_service.update_by_segment(chunk)
return success(data=chunk, msg="The document chunk has been successfully updated")
else:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail="The document chunk does not exist or you do not have access to it"
)
@router.delete("/{kb_id}/{document_id}/{doc_id}", response_model=ApiResponse)
async def delete_chunk(
kb_id: uuid.UUID,
document_id: uuid.UUID,
doc_id: str,
db: Session = Depends(get_db),
current_user: User = Depends(get_current_user)
):
"""
delete document chunk
"""
api_logger.info(f"Request to delete document chunk: kb_id={kb_id}, document_id={document_id}, doc_id={doc_id}, username: {current_user.username}")
db_knowledge = knowledge_service.get_knowledge_by_id(db, knowledge_id=kb_id, current_user=current_user)
if not db_knowledge:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail="The knowledge base does not exist or access is denied"
)
vector_service = ElasticSearchVectorFactory().init_vector(knowledge=db_knowledge)
if vector_service.text_exists(doc_id):
vector_service.delete_by_ids([doc_id])
# 更新 chunk_num
db_document = db.query(Document).filter(Document.id == document_id).first()
db_document.chunk_num -= 1
db.commit()
return success(msg="The document chunk has been successfully deleted")
else:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail="The document chunk does not exist or you do not have access to it"
)
@router.get("/retrieve_type", response_model=ApiResponse)
def get_retrieve_types():
return success(msg="Successfully obtained the retrieval type", data=list(chunk_schema.RetrieveType))
@router.post("/retrieval", response_model=Any, status_code=status.HTTP_200_OK)
async def retrieve_chunks(
retrieve_data: chunk_schema.ChunkRetrieve,
db: Session = Depends(get_db),
current_user: User = Depends(get_current_user)
):
"""
retrieve chunk
"""
api_logger.info(f"retrieve chunk: query={retrieve_data.query}, username: {current_user.username}")
filters = [
knowledge_model.Knowledge.id.in_(retrieve_data.kb_ids),
knowledge_model.Knowledge.permission_id == knowledge_model.PermissionType.Private,
knowledge_model.Knowledge.chunk_num > 0,
knowledge_model.Knowledge.status == 1
]
existing_ids = knowledge_service.get_chunded_knowledgeids(
db=db,
filters=filters,
current_user=current_user
)
filters = [
knowledge_model.Knowledge.id.in_(retrieve_data.kb_ids),
knowledge_model.Knowledge.permission_id == knowledge_model.PermissionType.Share,
knowledge_model.Knowledge.chunk_num > 0,
knowledge_model.Knowledge.status == 1
]
share_ids = knowledge_service.get_chunded_knowledgeids(
db=db,
filters=filters,
current_user=current_user
)
if share_ids:
filters = [
knowledgeshare_model.KnowledgeShare.target_kb_id.in_(retrieve_data.kb_ids)
]
items = knowledgeshare_service.get_source_kb_ids_by_target_kb_id(
db=db,
filters=filters,
current_user=current_user
)
existing_ids.extend(items)
if not existing_ids:
return success(data=[], msg="retrieval successful")
kb_id = existing_ids[0]
uuid_strs = [f"Vector_index_{kb_id}_Node".lower() for kb_id in existing_ids]
indices = ",".join(uuid_strs)
db_knowledge = knowledge_service.get_knowledge_by_id(db, knowledge_id=kb_id, current_user=current_user)
if not db_knowledge:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail="The knowledge base does not exist or access is denied"
)
vector_service = ElasticSearchVectorFactory().init_vector(knowledge=db_knowledge)
# 1 participle search, 2 semantic search, 3 hybrid search
match retrieve_data.retrieve_type:
case chunk_schema.RetrieveType.PARTICIPLE:
rs = vector_service.search_by_full_text(query=retrieve_data.query, top_k=retrieve_data.top_k, indices=indices, score_threshold=retrieve_data.similarity_threshold)
return success(data=rs, msg="retrieval successful")
case chunk_schema.RetrieveType.SEMANTIC:
rs = vector_service.search_by_vector(query=retrieve_data.query, top_k=retrieve_data.top_k, indices=indices, score_threshold=retrieve_data.vector_similarity_weight)
return success(data=rs, msg="retrieval successful")
case _:
rs1 = vector_service.search_by_vector(query=retrieve_data.query, top_k=retrieve_data.top_k, indices=indices, score_threshold=retrieve_data.vector_similarity_weight)
rs2 = vector_service.search_by_full_text(query=retrieve_data.query, top_k=retrieve_data.top_k, indices=indices, score_threshold=retrieve_data.similarity_threshold)
# Efficient deduplication
seen_ids = set()
unique_rs = []
for doc in rs1 + rs2:
if doc.metadata["doc_id"] not in seen_ids:
seen_ids.add(doc.metadata["doc_id"])
unique_rs.append(doc)
rs = vector_service.rerank(query=retrieve_data.query, docs=unique_rs, top_k=retrieve_data.top_k)
return success(data=rs, msg="retrieval successful")

View File

@@ -0,0 +1,341 @@
import os
from typing import Optional
import datetime
import uuid
from fastapi import APIRouter, Depends, HTTPException, status, Query
from sqlalchemy.orm import Session
from app.core.config import settings
from app.db import get_db
from app.dependencies import get_current_user
from app.models.user_model import User
from app.models import document_model
from app.schemas import document_schema
from app.schemas.response_schema import ApiResponse
from app.core.response_utils import success
from app.services import document_service, file_service, knowledge_service
from app.controllers import file_controller
from app.celery_app import celery_app
from app.core.rag.vdb.elasticsearch.elasticsearch_vector import ElasticSearchVectorFactory
from app.core.logging_config import get_api_logger
# Obtain a dedicated API logger
api_logger = get_api_logger()
router = APIRouter(
prefix="/documents",
tags=["documents"],
dependencies=[Depends(get_current_user)] # Apply auth to all routes in this controller
)
@router.get("/{kb_id}/{parent_id}/documents", response_model=ApiResponse)
async def get_documents(
kb_id: uuid.UUID,
parent_id: uuid.UUID,
page: int = Query(1, gt=0), # Default: 1, which must be greater than 0
pagesize: int = Query(20, gt=0, le=100), # Default: 20 items per page, maximum: 100 items
orderby: Optional[str] = Query(None, description="Sort fields, such as: created_at,updated_at"),
desc: Optional[bool] = Query(False, description="Is it descending order"),
keywords: Optional[str] = Query(None, description="Search keywords (file name)"),
document_ids: Optional[str] = Query(None, description="document ids, separated by commas"),
db: Session = Depends(get_db),
current_user: User = Depends(get_current_user)
):
"""
Paged query document list
- Support filtering by kb_id and parent_id
- Support keyword search for file names
- Support dynamic sorting
- Return paging metadata + file list
"""
api_logger.info(f"Query document list: kb_id={kb_id}, page={page}, pagesize={pagesize}, keywords={keywords}, document_ids={document_ids}, username: {current_user.username}")
# 1. parameter validation
if page < 1 or pagesize < 1:
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail="The paging parameter must be greater than 0"
)
# 2. Construct query conditions
filters = [
document_model.Document.kb_id == kb_id,
document_model.Document.status == 1
]
if parent_id:
files = file_service.get_files_by_parent_id(db=db, parent_id=parent_id, current_user=current_user)
files_ids = [item.id for item in files]
filters.append(document_model.Document.file_id.in_(files_ids))
# Keyword search (fuzzy matching of file name)
if keywords:
api_logger.debug(f"Add keyword search criteria: {keywords}")
filters.append(document_model.Document.file_name.ilike(f"%{keywords}%"))
# document ids
if document_ids:
filters.append(document_model.Document.id.in_(document_ids.split(',')))
# 3. Execute paged query
try:
api_logger.debug(f"Start executing document paging query")
total, items = document_service.get_documents_paginated(
db=db,
filters=filters,
page=page,
pagesize=pagesize,
orderby=orderby,
desc=desc,
current_user=current_user
)
api_logger.info(f"Document query successful: total={total}, returned={len(items)} records")
except Exception as e:
api_logger.error(f"Document query failed: {str(e)}")
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail=f"Query failed: {str(e)}"
)
# 4. Return structured response
result = {
"items": items,
"page": {
"page": page,
"pagesize": pagesize,
"total": total,
"has_next": True if page * pagesize < total else False
}
}
return success(data=result, msg="Query of document list succeeded")
@router.post("/document", response_model=ApiResponse)
async def create_document(
create_data: document_schema.DocumentCreate,
db: Session = Depends(get_db),
current_user: User = Depends(get_current_user)
):
"""
create document
"""
api_logger.info(f"Create document request: file_name={create_data.file_name}, kb_id={create_data.kb_id}, username: {current_user.username}")
try:
api_logger.debug(f"Start creating a document: {create_data.file_name}")
db_document = document_service.create_document(db=db, document=create_data, current_user=current_user)
api_logger.info(f"Document created successfully: {db_document.file_name} (ID: {db_document.id})")
return success(data=document_schema.Document.model_validate(db_document), msg="Document creation successful")
except Exception as e:
api_logger.error(f"Document creation failed: {create_data.file_name} - {str(e)}")
raise
@router.get("/{document_id}", response_model=ApiResponse)
async def get_document(
document_id: uuid.UUID,
db: Session = Depends(get_db),
current_user: User = Depends(get_current_user)
):
"""
Retrieve document information based on document_id
"""
api_logger.info(f"Obtain document information: document_id={document_id}, username: {current_user.username}")
try:
# 1. Query document information from the database
api_logger.debug(f"query documentation: {document_id}")
db_document = document_service.get_document_by_id(db, document_id=document_id, current_user=current_user)
if not db_document:
api_logger.warning(f"The document does not exist or you do not have access: document_id={document_id}")
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail="The document does not exist or you do not have access"
)
api_logger.info(f"Document query successful: {db_document.file_name} (ID: {db_document.id})")
return success(data=document_schema.Document.model_validate(db_document), msg="Successfully obtained document information")
except HTTPException:
raise
except Exception as e:
api_logger.error(f"Document query failed: document_id={document_id} - {str(e)}")
raise
@router.put("/{document_id}", response_model=ApiResponse)
async def update_document(
document_id: uuid.UUID,
update_data: document_schema.DocumentUpdate,
db: Session = Depends(get_db),
current_user: User = Depends(get_current_user)
):
"""
Update document information
"""
# 1. Check if the document exists
api_logger.debug(f"Query the document to be updated: {document_id}")
db_document = document_service.get_document_by_id(db, document_id=document_id, current_user=current_user)
if not db_document:
api_logger.warning(f"The document does not exist or you do not have permission to access it: document_id={document_id}")
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail="The document does not exist or you do not have permission to access it"
)
# 2. If updating the status, synchronize the document status switch to whether it can be retrieved from the vector database
update_dict = update_data.dict(exclude_unset=True)
if "status" in update_dict:
new_status = update_dict["status"]
if new_status != db_document.status:
db_knowledge = knowledge_service.get_knowledge_by_id(db, knowledge_id=db_document.kb_id, current_user=current_user)
vector_service = ElasticSearchVectorFactory().init_vector(knowledge=db_knowledge)
vector_service.change_status_by_document_id(document_id=str(document_id), status=new_status)
# 3. Update fields (only update non-null fields)
api_logger.debug(f"Start updating the document fields: {document_id}")
updated_fields = []
for field, value in update_dict.items():
if hasattr(db_document, field):
old_value = getattr(db_document, field)
if old_value != value:
# update value
setattr(db_document, field, value)
updated_fields.append(f"{field}: {old_value} -> {value}")
if updated_fields:
api_logger.debug(f"updated fields: {', '.join(updated_fields)}")
db_document.updated_at = datetime.datetime.now()
# 4. Save to database
try:
db.commit()
db.refresh(db_document)
api_logger.info(f"The document has been successfully updated: {db_document.file_name} (ID: {db_document.id})")
except Exception as e:
db.rollback()
api_logger.error(f"Document update failed: document_id={document_id} - {str(e)}")
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail=f"Document update failed: {str(e)}"
)
# 5. Return the updated document
return success(data=document_schema.Document.model_validate(db_document), msg="Document information updated successfully")
@router.delete("/{document_id}", response_model=ApiResponse)
async def delete_document(
document_id: uuid.UUID,
db: Session = Depends(get_db),
current_user: User = Depends(get_current_user)
):
"""
Delete document
"""
api_logger.info(f"Request to delete document: document_id={document_id}, username: {current_user.username}")
try:
# 1. Check if the document exists
api_logger.debug(f"Check whether the document exists: {document_id}")
db_document = document_service.get_document_by_id(db, document_id=document_id, current_user=current_user)
if not db_document:
api_logger.warning(f"The document does not exist or you do not have permission to access it: document_id={document_id}")
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail="The document does not exist or you do not have permission to access it"
)
file_id = db_document.file_id
# 2. Delete document
api_logger.debug(f"Perform document delete: {db_document.file_name} (ID: {document_id})")
db.delete(db_document)
db.commit()
# 3. Delete file
await file_controller._delete_file(db=db, file_id=file_id, current_user=current_user)
# 4. Delete vector index
db_knowledge = knowledge_service.get_knowledge_by_id(db, knowledge_id=db_document.kb_id, current_user=current_user)
vector_service = ElasticSearchVectorFactory().init_vector(knowledge=db_knowledge)
vector_service.delete_by_metadata_field(key="document_id", value=str(document_id))
api_logger.info(f"The document has been successfully deleted: {db_document.file_name} (ID: {document_id})")
return success(msg="The document has been successfully deleted")
except Exception as e:
api_logger.error(f"Failed to delete from the document: document_id={document_id} - {str(e)}")
raise
@router.post("/{document_id}/chunks", response_model=ApiResponse)
async def parse_documents(
document_id: uuid.UUID,
db: Session = Depends(get_db),
current_user: User = Depends(get_current_user)
):
"""
parse document
"""
api_logger.info(f"Request to parse document: document_id={document_id}, username: {current_user.username}")
try:
# 1. Check if the document exists
api_logger.debug(f"Check whether the document exists: {document_id}")
db_document = document_service.get_document_by_id(db, document_id=document_id, current_user=current_user)
if not db_document:
api_logger.warning(f"The document does not exist or you do not have permission to access it: document_id={document_id}")
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail="The document does not exist or you do not have permission to access it"
)
# 2. Check if the file exists
api_logger.debug(f"Check whether the file exists: {db_document.file_id}")
db_file = file_service.get_file_by_id(db, file_id=db_document.file_id)
if not db_file:
api_logger.warning(f"The file does not exist or you do not have permission to access it: file_id={db_document.file_id}")
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail="The file does not exist or you do not have permission to access it"
)
# 3. Construct file path/files/{kb_id}/{parent_id}/{file.id}{file.file_ext}
file_path = os.path.join(
settings.FILE_PATH,
str(db_file.kb_id),
str(db_file.parent_id),
f"{db_file.id}{db_file.file_ext}"
)
# 4. Check if the file exists
if not os.path.exists(file_path):
api_logger.warning(f"File not found (possibly deleted): file_path={file_path}")
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail="File not found (possibly deleted)"
)
# 5. Obtain knowledge base information
api_logger.info( f"Obtain details of the knowledge base: knowledge_id={db_document.kb_id}")
db_knowledge = knowledge_service.get_knowledge_by_id(db, knowledge_id=db_document.kb_id, current_user=current_user)
if not db_knowledge:
api_logger.warning(f"The knowledge base does not exist or access is denied: knowledge_id={db_document.kb_id}")
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail="The knowledge base does not exist or access is denied"
)
# 6. Task: Document parsing, vectorization, and storage
# from app.tasks import parse_document
# parse_document(file_path, document_id)
task = celery_app.send_task("app.core.rag.tasks.parse_document", args=[file_path, document_id])
result = {
"task_id": task.id
}
return success(data=result, msg="Task accepted. The document is being processed in the background.")
except Exception as e:
api_logger.error(f"Failed to parse document: document_id={document_id} - {str(e)}")
raise

View File

@@ -0,0 +1,453 @@
import os
from typing import Any, Optional
from pathlib import Path
import shutil
import uuid
from fastapi import APIRouter, Depends, HTTPException, status, File, UploadFile, Query
from fastapi.responses import FileResponse
from sqlalchemy.orm import Session
from app.core.config import settings
from app.db import get_db
from app.dependencies import get_current_user
from app.models.user_model import User
from app.models import file_model
from app.schemas import file_schema, document_schema
from app.schemas.response_schema import ApiResponse
from app.core.response_utils import success
from app.services import file_service, document_service
from app.core.logging_config import get_api_logger
# Obtain a dedicated API logger
api_logger = get_api_logger()
router = APIRouter(
prefix="/files",
tags=["files"]
)
@router.get("/{kb_id}/{parent_id}/files", response_model=ApiResponse)
async def get_files(
kb_id: uuid.UUID,
parent_id: uuid.UUID,
page: int = Query(1, gt=0), # Default: 1, which must be greater than 0
pagesize: int = Query(20, gt=0, le=100), # Default: 20 items per page, maximum: 100 items
orderby: Optional[str] = Query(None, description="Sort fields, such as: created_at"),
desc: Optional[bool] = Query(False, description="Is it descending order"),
keywords: Optional[str] = Query(None, description="Search keywords (file name)"),
db: Session = Depends(get_db),
current_user: User = Depends(get_current_user)
):
"""
Paged query file list
- Support filtering by kb_id and parent_id
- Support keyword search for file names
- Support dynamic sorting
- Return paging metadata + file list
"""
api_logger.info(f"Query file list: kb_id={kb_id}, parent_id={parent_id}, page={page}, pagesize={pagesize}, keywords={keywords}, username: {current_user.username}")
# 1. parameter validation
if page < 1 or pagesize < 1:
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail="The paging parameter must be greater than 0"
)
# 2. Construct query conditions
filters = [
file_model.File.kb_id == kb_id
]
if parent_id:
filters.append(file_model.File.parent_id == parent_id)
# Keyword search (fuzzy matching of file name)
if keywords:
filters.append(file_model.File.file_name.ilike(f"%{keywords}%"))
# 3. Execute paged query
try:
api_logger.debug(f"Start executing file paging query")
total, items = file_service.get_files_paginated(
db=db,
filters=filters,
page=page,
pagesize=pagesize,
orderby=orderby,
desc=desc,
current_user=current_user
)
api_logger.info(f"File query successful: total={total}, returned={len(items)} records")
except Exception as e:
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail=f"Query failed: {str(e)}"
)
# 4. Return structured response
result = {
"items": items,
"page": {
"page": page,
"pagesize": pagesize,
"total": total,
"has_next": True if page * pagesize < total else False
}
}
return success(data=result, msg="Query of file list succeeded")
@router.post("/folder", response_model=ApiResponse)
def create_folder(
kb_id: uuid.UUID,
parent_id: uuid.UUID,
folder_name: str = '/',
db: Session = Depends(get_db),
current_user: User = Depends(get_current_user),
):
"""
Create a new folder
"""
api_logger.info(f"Create folder request: kb_id={kb_id}, parent_id={parent_id}, folder_name={folder_name}, username: {current_user.username}")
try:
api_logger.debug(f"Start creating a folder: {folder_name}")
create_folder = file_schema.FileCreate(
kb_id=kb_id,
created_by=current_user.id,
parent_id=parent_id,
file_name=folder_name,
file_ext='folder',
file_size=0,
)
db_file = file_service.create_file(db=db, file=create_folder, current_user=current_user)
api_logger.info(f"Folder created successfully: {db_file.file_name} (ID: {db_file.id})")
return success(data=file_schema.File.model_validate(db_file), msg="Folder creation successful")
except Exception as e:
api_logger.error(f"Folder creation failed: {folder_name} - {str(e)}")
raise
@router.post("/file", response_model=ApiResponse)
async def upload_file(
kb_id: uuid.UUID,
parent_id: uuid.UUID,
file: UploadFile = File(...),
db: Session = Depends(get_db),
current_user: User = Depends(get_current_user)
):
"""
upload file
"""
api_logger.info(f"upload file request: kb_id={kb_id}, parent_id={parent_id}, filename={file.filename}, username: {current_user.username}")
# Read the contents of the file
contents = await file.read()
# Check file size
file_size = len(contents)
print(f"file size: {file_size} byte")
if file_size == 0:
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail="The file is empty."
)
# If the file size exceeds 50MB (50 * 1024 * 1024 bytes)
if file_size > settings.MAX_FILE_SIZE:
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail=f"The file size exceeds the {settings.MAX_FILE_SIZE}byte limit"
)
# Extract the extension using `os.path.splitext`
_, file_extension = os.path.splitext(file.filename)
upload_file = file_schema.FileCreate(
kb_id=kb_id,
created_by=current_user.id,
parent_id=parent_id,
file_name=file.filename,
file_ext=file_extension.lower(),
file_size=file_size,
)
db_file = file_service.create_file(db=db, file=upload_file, current_user=current_user)
# Construct a save path/files/{kb_id}/{parent_id}/{file.id}{file_extension}
save_dir = os.path.join(settings.FILE_PATH, str(kb_id), str(parent_id))
Path(save_dir).mkdir(parents=True, exist_ok=True) # Ensure that the directory exists
save_path = os.path.join(save_dir, f"{db_file.id}{file_extension}")
# Save file
with open(save_path, "wb") as f:
f.write(contents)
# Verify whether the file has been saved successfully
if not os.path.exists(save_path):
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail="File save failed"
)
# Create a document
create_data = document_schema.DocumentCreate(
kb_id=kb_id,
created_by=current_user.id,
file_id=db_file.id,
file_name=db_file.file_name,
file_ext=db_file.file_ext,
file_size=db_file.file_size,
file_meta={},
parser_id="naive",
parser_config={
"layout_recognize": "DeepDOC",
"chunk_token_num": 128,
"delimiter": "\n",
"auto_keywords": 0,
"auto_questions": 0,
"html4excel": "false"
}
)
db_document = document_service.create_document(db=db, document=create_data, current_user=current_user)
api_logger.info(f"File upload successfully: {file.filename} (file_id: {db_file.id}, document_id: {db_document.id})")
return success(data=document_schema.Document.model_validate(db_document), msg="File upload successful")
@router.post("/customtext", response_model=ApiResponse)
async def custom_text(
kb_id: uuid.UUID,
parent_id: uuid.UUID,
create_data: file_schema.CustomTextFileCreate,
db: Session = Depends(get_db),
current_user: User = Depends(get_current_user)
):
"""
custom text
"""
api_logger.info(f"custom text upload request: kb_id={kb_id}, parent_id={parent_id}, title={create_data.title}, content={create_data.content}, username: {current_user.username}")
# Check file content size
# 将内容编码为字节UTF-8
content_bytes = create_data.content.encode('utf-8')
file_size = len(content_bytes)
print(f"file size: {file_size} byte")
if file_size == 0:
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail="The content is empty."
)
# If the file size exceeds 50MB (50 * 1024 * 1024 bytes)
if file_size > settings.MAX_FILE_SIZE:
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail=f"The content size exceeds the {settings.MAX_FILE_SIZE}byte limit"
)
upload_file = file_schema.FileCreate(
kb_id=kb_id,
created_by=current_user.id,
parent_id=parent_id,
file_name=f"{create_data.title}.txt",
file_ext=".txt",
file_size=file_size,
)
db_file = file_service.create_file(db=db, file=upload_file, current_user=current_user)
# Construct a save path/files/{kb_id}/{parent_id}/{file.id}{file_extension}
save_dir = os.path.join(settings.FILE_PATH, str(kb_id), str(parent_id))
Path(save_dir).mkdir(parents=True, exist_ok=True) # Ensure that the directory exists
save_path = os.path.join(save_dir, f"{db_file.id}.txt")
# Save file
with open(save_path, "wb") as f:
f.write(content_bytes)
# Verify whether the file has been saved successfully
if not os.path.exists(save_path):
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail="File save failed"
)
# Create a document
create_document_data = document_schema.DocumentCreate(
kb_id=kb_id,
created_by=current_user.id,
file_id=db_file.id,
file_name=db_file.file_name,
file_ext=db_file.file_ext,
file_size=db_file.file_size,
file_meta={},
parser_id="naive",
parser_config={
"layout_recognize": "DeepDOC",
"chunk_token_num": 128,
"delimiter": "\n",
"auto_keywords": 0,
"auto_questions": 0,
"html4excel": "false"
}
)
db_document = document_service.create_document(db=db, document=create_document_data, current_user=current_user)
api_logger.info(f"custom text upload successfully: {create_data.title} (file_id: {db_file.id}, document_id: {db_document.id})")
return success(data=document_schema.Document.model_validate(db_document), msg="custom text upload successful")
@router.get("/{file_id}", response_model=Any)
async def get_file(
file_id: uuid.UUID,
db: Session = Depends(get_db)
) -> Any:
"""
Download the file based on the file_id
- Query file information from the database
- Construct the file path and check if it exists
- Return a FileResponse to download the file
"""
api_logger.info(f"Download the file based on the file_id: file_id={file_id}")
# 1. Query file information from the database
db_file = file_service.get_file_by_id(db, file_id=file_id)
if not db_file:
api_logger.warning(f"The file does not exist or you do not have permission to access it: file_id={file_id}")
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail="The file does not exist or you do not have permission to access it"
)
# 2. Construct file path/files/{kb_id}/{parent_id}/{file.id}{file.file_ext}
file_path = os.path.join(
settings.FILE_PATH,
str(db_file.kb_id),
str(db_file.parent_id),
f"{db_file.id}{db_file.file_ext}"
)
# 3. Check if the file exists
if not os.path.exists(file_path):
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail="File not found (possibly deleted)"
)
# 4.Return FileResponse (automatically handle download)
return FileResponse(
path=file_path,
filename=db_file.file_name, # Use original file name
media_type="application/octet-stream" # Universal binary stream type
)
@router.put("/{file_id}", response_model=ApiResponse)
async def update_file(
file_id: uuid.UUID,
update_data: file_schema.FileUpdate,
db: Session = Depends(get_db),
current_user: User = Depends(get_current_user)
):
"""
Update file information (such as file name)
- Only specified fields such as file_name are allowed to be modified
"""
api_logger.debug(f"Query the file to be updated: {file_id}")
# 1. Check if the file exists
db_file = file_service.get_file_by_id(db, file_id=file_id)
if not db_file:
api_logger.warning(f"The file does not exist or you do not have permission to access it: file_id={file_id}")
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail="The file does not exist or you do not have permission to access it"
)
# 2. Update fields (only update non-null fields)
api_logger.debug(f"Start updating the file fields: {file_id}")
updated_fields = []
for field, value in update_data.items():
if hasattr(db_file, field):
old_value = getattr(db_file, field)
if old_value != value:
# update value
setattr(db_file, field, value)
updated_fields.append(f"{field}: {old_value} -> {value}")
if updated_fields:
api_logger.debug(f"updated fields: {', '.join(updated_fields)}")
# 3. Save to database
try:
db.commit()
db.refresh(db_file)
api_logger.info(f"The file has been successfully updated: {db_file.file_name} (ID: {db_file.id})")
except Exception as e:
db.rollback()
api_logger.error(f"File update failed: file_id={file_id} - {str(e)}")
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail=f"File update failed: {str(e)}"
)
# 4. Return the updated file
return success(data=file_schema.File.model_validate(db_file), msg="File information updated successfully")
@router.delete("/{file_id}", response_model=ApiResponse)
async def delete_file(
file_id: uuid.UUID,
db: Session = Depends(get_db),
current_user: User = Depends(get_current_user)
):
"""
Delete a file or folder
"""
api_logger.info(f"Request to delete file: file_id={file_id}, username: {current_user.username}")
await _delete_file(db=db, file_id=file_id, current_user=current_user)
return success(msg="File deleted successfully")
async def _delete_file(
file_id: uuid.UUID,
db: Session = Depends(get_db),
current_user: User = Depends(get_current_user)
) -> None:
"""
Delete a file or folder
"""
# 1. Check if the file exists
db_file = file_service.get_file_by_id(db, file_id=file_id)
if not db_file:
api_logger.warning(f"The file does not exist or you do not have permission to access it: file_id={file_id}")
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail="The file does not exist or you do not have permission to access it"
)
# 2. Construct physical path
file_path = Path(
settings.FILE_PATH,
str(db_file.kb_id),
str(db_file.id)
) if db_file.file_ext == 'folder' else Path(
settings.FILE_PATH,
str(db_file.kb_id),
str(db_file.parent_id),
f"{db_file.id}{db_file.file_ext}"
)
# 3. Delete physical files/folders
try:
if file_path.exists():
if db_file.file_ext == 'folder':
shutil.rmtree(file_path) # Recursively delete folders
else:
file_path.unlink() # Delete a single file
except Exception as e:
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail=f"Failed to delete physical file/folder: {str(e)}"
)
# 4.Delete db_file
if db_file.file_ext == 'folder':
db.query(file_model.File).filter(file_model.File.parent_id == db_file.id).delete()
db.delete(db_file)
db.commit()

View File

@@ -0,0 +1,305 @@
from typing import Optional
import datetime
import uuid
from fastapi import APIRouter, Depends, HTTPException, status, Query
from sqlalchemy import or_
from sqlalchemy.orm import Session
from app.db import get_db
from app.dependencies import get_current_user
from app.models.user_model import User
from app.models import knowledge_model, document_model, file_model
from app.schemas import knowledge_schema
from app.schemas.response_schema import ApiResponse
from app.core.response_utils import success
from app.services import knowledge_service, document_service
from app.core.rag.vdb.elasticsearch.elasticsearch_vector import ElasticSearchVectorFactory
from app.core.logging_config import get_api_logger
# Obtain a dedicated API logger
api_logger = get_api_logger()
router = APIRouter(
prefix="/knowledges",
tags=["knowledges"],
dependencies=[Depends(get_current_user)] # Apply auth to all routes in this controller
)
@router.get("/knowledgetype", response_model=ApiResponse)
def get_knowledge_types():
return success(msg="Successfully obtained the knowledge type", data=list(knowledge_model.KnowledgeType))
@router.get("/permissiontype", response_model=ApiResponse)
def get_permission_types():
return success(msg="Successfully obtained the knowledge permission type", data=list(knowledge_model.PermissionType))
@router.get("/parsertype", response_model=ApiResponse)
def get_parser_types():
return success(msg="Successfully obtained the knowledge parser type", data=list(knowledge_model.ParserType))
@router.get("/knowledges", response_model=ApiResponse)
async def get_knowledges(
parent_id: Optional[uuid.UUID] = Query(None, description="parent folder id"),
page: int = Query(1, gt=0), # Default: 1, which must be greater than 0
pagesize: int = Query(20, gt=0, le=100), # Default: 20 items per page, maximum: 100 items
orderby: Optional[str] = Query(None, description="Sort fields, such as: created_at,updated_at"),
desc: Optional[bool] = Query(False, description="Is it descending order"),
keywords: Optional[str] = Query(None, description="Search keywords (knowledge base name)"),
kb_ids: Optional[str] = Query(None, description="Knowledge base ids, separated by commas"),
db: Session = Depends(get_db),
current_user: User = Depends(get_current_user)
):
"""
Query the knowledge base list in pages
- Support filtering by parent_id
- Support keyword search for knowledge base names
- Support dynamic sorting
- Return paging metadata + file list
"""
api_logger.info(f"Query knowledge base list: workspace_id={current_user.current_workspace_id}, page={page}, pagesize={pagesize}, keywords={keywords}, kb_ids={kb_ids}, username: {current_user.username}")
# 1. parameter validation
if page < 1 or pagesize < 1:
api_logger.warning(f"Error in paging parameters: page={page}, pagesize={pagesize}")
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail="The paging parameter must be greater than 0"
)
# 2. Construct query conditions
filters = [
knowledge_model.Knowledge.workspace_id == current_user.current_workspace_id
]
if parent_id:
filters.append(knowledge_model.Knowledge.parent_id == parent_id)
# Keyword search (fuzzy matching of knowledge base name)
if keywords:
api_logger.debug(f"Add keyword search criteria: {keywords}")
filters.append(
or_(
knowledge_model.Knowledge.name.ilike(f"%{keywords}%"),
knowledge_model.Knowledge.description.ilike(f"%{keywords}%")
)
)
# Knowledge base ids
if kb_ids:
filters.append(knowledge_model.Knowledge.id.in_(kb_ids.split(',')))
else:
filters.append(knowledge_model.Knowledge.status != 2)
# 3. Execute paged query
try:
api_logger.debug(f"Start executing knowledge base paging query")
total, items = knowledge_service.get_knowledges_paginated(
db=db,
filters=filters,
page=page,
pagesize=pagesize,
orderby=orderby,
desc=desc,
current_user=current_user
)
api_logger.info(f"Knowledge base query successful: total={total}, returned={len(items)} records")
except Exception as e:
api_logger.error(f"Knowledge base query failed: {str(e)}")
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail=f"Query failed: {str(e)}"
)
# 4. Return structured response
result = {
"items": items,
"page": {
"page": page,
"pagesize": pagesize,
"total": total,
"has_next": True if page*pagesize < total else False
}
}
return success(data=result, msg="Query of knowledge base list successful")
@router.post("/knowledge", response_model=ApiResponse)
async def create_knowledge(
create_data: knowledge_schema.KnowledgeCreate,
db: Session = Depends(get_db),
current_user: User = Depends(get_current_user)
):
"""
create knowledge
"""
api_logger.info(f"Request to create a knowledge base: name={create_data.name}, workspace_id={current_user.current_workspace_id}, username: {current_user.username}")
try:
api_logger.debug(f"Start creating the knowledge base: {create_data.name}")
# 1. Check if the knowledge base name already exists
db_knowledge_exist = knowledge_service.get_knowledge_by_name(db, name=create_data.name, current_user=current_user)
if db_knowledge_exist:
api_logger.warning(f"The knowledge base name already exists: {create_data.name}")
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail=f"The knowledge base name already exists: {create_data.name}"
)
db_knowledge = knowledge_service.create_knowledge(db=db, knowledge=create_data, current_user=current_user)
api_logger.info(f"The knowledge base has been successfully created: {db_knowledge.name} (ID: {db_knowledge.id})")
return success(data=knowledge_schema.Knowledge.model_validate(db_knowledge), msg="The knowledge base has been successfully created")
except Exception as e:
api_logger.error(f"The creation of the knowledge base failed: {create_data.name} - {str(e)}")
raise
@router.get("/{knowledge_id}", response_model=ApiResponse)
async def get_knowledge(
knowledge_id: uuid.UUID,
db: Session = Depends(get_db),
current_user: User = Depends(get_current_user)
):
"""
Retrieve knowledge base information based on knowledge_id
"""
api_logger.info(f"Obtain details of the knowledge base: knowledge_id={knowledge_id}, username: {current_user.username}")
try:
# 1. Query knowledge base information from the database
api_logger.debug(f"Query knowledge base: {knowledge_id}")
db_knowledge = knowledge_service.get_knowledge_by_id(db, knowledge_id=knowledge_id, current_user=current_user)
if not db_knowledge:
api_logger.warning(f"The knowledge base does not exist or access is denied: knowledge_id={knowledge_id}")
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail="The knowledge base does not exist or access is denied"
)
api_logger.info(f"Knowledge base query successful: {db_knowledge.name} (ID: {db_knowledge.id})")
return success(data=knowledge_schema.Knowledge.model_validate(db_knowledge), msg="Successfully obtained knowledge base information")
except HTTPException:
raise
except Exception as e:
api_logger.error(f"Knowledge base query failed: knowledge_id={knowledge_id} - {str(e)}")
raise
@router.put("/{knowledge_id}", response_model=ApiResponse)
async def update_knowledge(
knowledge_id: uuid.UUID,
update_data: knowledge_schema.KnowledgeUpdate,
db: Session = Depends(get_db),
current_user: User = Depends(get_current_user)
):
api_logger.info(f"Update knowledge base request: knowledge_id={knowledge_id}, username: {current_user.username}")
db_knowledge = await _update_knowledge(knowledge_id=knowledge_id, update_data=update_data, db=db, current_user=current_user)
return success(data=knowledge_schema.Knowledge.model_validate(db_knowledge), msg="The knowledge base information has been successfully updated")
async def _update_knowledge(
knowledge_id: uuid.UUID,
update_data: knowledge_schema.KnowledgeUpdate,
db: Session = Depends(get_db),
current_user: User = Depends(get_current_user)
) -> knowledge_schema.Knowledge:
"""
Update knowledge base information
"""
try:
# 1. Check whether the knowledge base exists
api_logger.debug(f"Query the knowledge base to be updated: {knowledge_id}")
db_knowledge = knowledge_service.get_knowledge_by_id(db, knowledge_id=knowledge_id, current_user=current_user)
if not db_knowledge:
api_logger.warning(f"The knowledge base does not exist or you do not have permission to access it: knowledge_id={knowledge_id}")
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail="The knowledge base does not exist or you do not have permission to access it"
)
# 2. If updating the embedding_id, delete the knowledge base vector index, reset all document parsing progress to 0, and set chunk_num to 0
update_dict = update_data.dict(exclude_unset=True)
if "name" in update_dict:
name = update_dict["name"]
if name != db_knowledge.name:
# Check if the knowledge base name already exists
db_knowledge_exist = knowledge_service.get_knowledge_by_name(db, name=name, current_user=current_user)
if db_knowledge_exist:
api_logger.warning(f"The knowledge base name already exists: {name}")
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail=f"The knowledge base name already exists: {name}"
)
if "embedding_id" in update_dict:
embedding_id = update_dict["embedding_id"]
if embedding_id != db_knowledge.embedding_id:
vector_service = ElasticSearchVectorFactory().init_vector(knowledge=db_knowledge)
vector_service.delete()
document_service.reset_documents_progress_by_kb_id(db, kb_id=db_knowledge.id, current_user=current_user)
# 2. Update fields (only update non-null fields)
api_logger.debug(f"Start updating the knowledge base fields: {knowledge_id}")
updated_fields = []
for field, value in update_data.dict(exclude_unset=True).items():
if hasattr(db_knowledge, field):
old_value = getattr(db_knowledge, field)
if old_value != value:
# update value
setattr(db_knowledge, field, value)
updated_fields.append(f"{field}: {old_value} -> {value}")
if updated_fields:
api_logger.debug(f"updated fields: {', '.join(updated_fields)}")
db_knowledge.updated_at = datetime.datetime.now()
# 3. Save to database
db.commit()
db.refresh(db_knowledge)
api_logger.info(f"The knowledge base has been successfully updated: {db_knowledge.name} (ID: {db_knowledge.id})")
# 4. Return the updated knowledge base
return db_knowledge
except HTTPException:
raise
except Exception as e:
db.rollback()
api_logger.error(f"Knowledge base update failed: knowledge_id={knowledge_id} - {str(e)}")
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail=f"Knowledge base update failed: {str(e)}"
)
@router.delete("/{knowledge_id}", response_model=ApiResponse)
async def delete_knowledge(
knowledge_id: uuid.UUID,
db: Session = Depends(get_db),
current_user: User = Depends(get_current_user)
):
"""
Soft-delete knowledge base
"""
api_logger.info(f"Request to delete knowledge base: knowledge_id={knowledge_id}, username: {current_user.username}")
try:
# 1. Check whether the knowledge base exists
api_logger.debug(f"Check whether the knowledge base exists: {knowledge_id}")
db_knowledge = knowledge_service.get_knowledge_by_id(db, knowledge_id=knowledge_id, current_user=current_user)
if not db_knowledge:
api_logger.warning(f"The knowledge base does not exist or you do not have permission to access it: knowledge_id={knowledge_id}")
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail="The knowledge base does not exist or you do not have permission to access it"
)
# 2. Soft-delete knowledge base
api_logger.debug(f"Perform a soft delete: {db_knowledge.name} (ID: {knowledge_id})")
db_knowledge.status = 2
db.commit()
api_logger.info(f"The knowledge base has been successfully deleted: {db_knowledge.name} (ID: {knowledge_id})")
return success(msg="The knowledge base has been successfully deleted")
except Exception as e:
api_logger.error(f"Failed to delete from the knowledge base: knowledge_id={knowledge_id} - {str(e)}")
raise

View File

@@ -0,0 +1,199 @@
from typing import Optional
import uuid
from fastapi import APIRouter, Depends, HTTPException, status, Query
from sqlalchemy.orm import Session
from app.db import get_db
from app.dependencies import get_current_user
from app.models.user_model import User
from app.models import knowledgeshare_model, knowledge_model
from app.schemas import knowledgeshare_schema, knowledge_schema
from app.schemas.response_schema import ApiResponse
from app.core.response_utils import success
from app.services import knowledgeshare_service, knowledge_service
from app.core.logging_config import get_api_logger
# Obtain a dedicated API logger
api_logger = get_api_logger()
router = APIRouter(
prefix="/knowledgeshares",
tags=["knowledgeshares"],
dependencies=[Depends(get_current_user)] # Apply auth to all routes in this controller
)
@router.get("/{kb_id}/knowledgeshares", response_model=ApiResponse)
async def get_knowledgeshares(
kb_id: uuid.UUID,
page: int = Query(1, gt=0), # Default: 1, which must be greater than 0
pagesize: int = Query(20, gt=0, le=100), # Default: 20 items per page, maximum: 100 items
orderby: Optional[str] = Query(None, description="Sort fields, such as: created_at,updated_at"),
desc: Optional[bool] = Query(False, description="Is it descending order"),
db: Session = Depends(get_db),
current_user: User = Depends(get_current_user)
):
"""
Paged query knowledge base sharing list
- Support filtering by kb_id
- Support dynamic sorting
- Return paging metadata + share list
"""
api_logger.info(
f"Query knowledge base sharing list: workspace_id={current_user.current_workspace_id}, kb_id={kb_id}, page={page}, pagesize={pagesize}, username: {current_user.username}")
# 1. parameter validation
if page < 1 or pagesize < 1:
api_logger.warning(f"Error in paging parameters: page={page}, pagesize={pagesize}")
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail="The paging parameter must be greater than 0"
)
# 2. Construct query conditions
filters = [
knowledgeshare_model.KnowledgeShare.source_workspace_id == current_user.current_workspace_id,
knowledgeshare_model.KnowledgeShare.source_kb_id == kb_id
]
# 3. Execute paged query
try:
api_logger.debug(f"Start executing knowledge base sharing and paging query")
total, items = knowledgeshare_service.get_knowledgeshares_paginated(
db=db,
filters=filters,
page=page,
pagesize=pagesize,
orderby=orderby,
desc=desc,
current_user=current_user
)
api_logger.info(f"Knowledge base sharing query successful: total={total}, returned={len(items)} records")
except Exception as e:
api_logger.error(f"Knowledge base sharing query failed: {str(e)}")
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail=f"Query failed: {str(e)}"
)
# 4. Return structured response
result = {
"items": items,
"page": {
"page": page,
"pagesize": pagesize,
"total": total,
"has_next": True if page * pagesize < total else False
}
}
return success(data=result, msg="Query of knowledge base sharing list successful")
@router.post("/knowledgeshare", response_model=ApiResponse)
async def create_knowledgeshare(
create_data: knowledgeshare_schema.KnowledgeShareCreate,
db: Session = Depends(get_db),
current_user: User = Depends(get_current_user)
):
"""
create knowledgeshare
"""
api_logger.info(
f"Create a knowledge base sharing request: source_kb_id={create_data.source_kb_id}, source_workspace_id={current_user.current_workspace_id}, username: {current_user.username}")
try:
# 1.Create a knowledge base with permission_id=knowledge_model.PermissionType.Share
db_knowledge = knowledge_service.get_knowledge_by_id(db, knowledge_id=create_data.source_kb_id, current_user=current_user)
knowledge = knowledge_schema.KnowledgeCreate(
workspace_id=create_data.target_workspace_id,
created_by=current_user.id,
parent_id=create_data.target_workspace_id,
name=db_knowledge.name,
description=db_knowledge.description,
avatar=db_knowledge.avatar,
type=db_knowledge.type,
permission_id=knowledge_model.PermissionType.Share,
embedding_id=db_knowledge.embedding_id,
reranker_id=db_knowledge.reranker_id,
llm_id=db_knowledge.llm_id,
image2text_id=db_knowledge.image2text_id,
doc_num=db_knowledge.doc_num,
chunk_num=db_knowledge.chunk_num,
parser_id=db_knowledge.parser_id,
parser_config=db_knowledge.parser_config
)
db_knowledge = knowledge_service.create_knowledge(db=db, knowledge=knowledge, current_user=current_user)
# 2. Create a knowledge base for sharing
api_logger.debug(f"Start creating the knowledge base sharing: {db_knowledge.name}")
create_data.target_kb_id = db_knowledge.id
db_knowledgeshare = knowledgeshare_service.create_knowledgeshare(db=db, knowledgeshare=create_data, current_user=current_user)
api_logger.info(f"The knowledge base sharing has been successfully created: (ID: {db_knowledgeshare.id})")
return success(data=knowledgeshare_schema.KnowledgeShare.model_validate(db_knowledgeshare), msg="The knowledge base sharing has been successfully created")
except Exception as e:
api_logger.error(f"The creation of the knowledge base sharing failed: {str(e)}")
raise
@router.get("/{knowledgeshare_id}", response_model=ApiResponse)
async def get_knowledgeshare(
knowledgeshare_id: uuid.UUID,
db: Session = Depends(get_db),
current_user: User = Depends(get_current_user)
):
"""
Retrieve knowledge base sharing information based on knowledgeshare_id
"""
api_logger.info(f"Obtain details of the knowledge base sharing: knowledgeshare_id={knowledgeshare_id}, username: {current_user.username}")
try:
# 1. Query knowledge base sharing information from the database
api_logger.debug(f"Query knowledge base sharing: {knowledgeshare_id}")
db_knowledgeshare = knowledgeshare_service.get_knowledgeshare_by_id(db, knowledgeshare_id=knowledgeshare_id, current_user=current_user)
if not db_knowledgeshare:
api_logger.warning(f"The knowledge base sharing does not exist or access is denied: knowledgeshare_id={knowledgeshare_id}")
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail="The knowledge base sharing does not exist or access is denied"
)
api_logger.info(f"Knowledge base sharing query successful: (ID: {db_knowledgeshare.id})")
return success(data=knowledgeshare_schema.KnowledgeShare.model_validate(db_knowledgeshare), msg="Successfully obtained knowledge base sharing information")
except HTTPException:
raise
except Exception as e:
api_logger.error(f"Knowledge base sharing query failed: knowledgeshare_id={knowledgeshare_id} - {str(e)}")
raise
@router.delete("/{knowledgeshare_id}", response_model=ApiResponse)
async def delete_knowledgeshare(
knowledgeshare_id: uuid.UUID,
db: Session = Depends(get_db),
current_user: User = Depends(get_current_user)
):
"""
Delete knowledge base sharing
"""
api_logger.info(f"Delete knowledge base sharing request: knowledgeshare_id={knowledgeshare_id}, username: {current_user.username}")
try:
# 1. Query knowledge base sharing information from the database
api_logger.debug(f"Query knowledge base sharing: {knowledgeshare_id}")
db_knowledgeshare = knowledgeshare_service.get_knowledgeshare_by_id(db, knowledgeshare_id=knowledgeshare_id, current_user=current_user)
if not db_knowledgeshare:
api_logger.warning(f"The knowledge base sharing does not exist or access is denied: knowledgeshare_id={knowledgeshare_id}")
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail="The knowledge base sharing does not exist or access is denied"
)
# 2. Deleting shared knowledge base
knowledge_service.delete_knowledge_by_id(db, knowledge_id=db_knowledgeshare.target_kb_id ,current_user=current_user)
# 3. Delete knowledge base sharing
api_logger.debug(f"perform knowledge base sharing delete: (ID: {knowledgeshare_id})")
knowledgeshare_service.delete_knowledgeshare_by_id(db, knowledgeshare_id=knowledgeshare_id, current_user=current_user)
api_logger.info(f"The knowledge base sharing has been successfully deleted: (ID: {knowledgeshare_id})")
return success(msg="The knowledge base sharing has been successfully deleted")
except Exception as e:
api_logger.error(f"Failed to delete from the knowledge base sharing: knowledgeshare_id={knowledgeshare_id} - {str(e)}")
raise

View File

@@ -0,0 +1,802 @@
import json
import time
from typing import Optional, List
from fastapi import APIRouter, Depends, Query, UploadFile
from sqlalchemy.orm import Session
from starlette.responses import StreamingResponse
from app.db import get_db
from app.core.memory.utils.config.config_utils import get_model_config
from app.core.rag.llm.cv_model import QWenCV
from app.models import ModelApiKey, Knowledge
from app.services.memory_agent_service import MemoryAgentService
from app.dependencies import get_current_superuser, get_current_user, get_current_tenant, workspace_access_guard, cur_workspace_access_guard
from app.celery_app import celery_app
from app.core.logging_config import get_api_logger
from app.core.response_utils import success, fail
from app.core.error_codes import BizCode
from app.services import task_service, workspace_service
from app.schemas.memory_agent_schema import UserInput, Write_UserInput
from app.schemas.response_schema import ApiResponse
from app.dependencies import get_current_user
from app.models.user_model import User
from fastapi import APIRouter, Depends, File, UploadFile, Form
from app.repositories import knowledge_repository
from app.services.model_service import ModelConfigService
from dotenv import load_dotenv
import os
# 加载.env文件
load_dotenv()
# Get API logger
api_logger = get_api_logger()
# Initialize service
memory_agent_service = MemoryAgentService()
router = APIRouter(
prefix="/memory",
tags=["Memory"],
)
def validate_config_id(config_id: int, db: Session) -> int:
"""
Validate and ensure config_id is available, valid, and exists in database.
Args:
config_id: Configuration ID to validate
db: Database session for checking existence
Returns:
int: Validated config_id
Raises:
ValueError: If config_id is None, invalid, or doesn't exist in database
"""
if config_id is None:
api_logger.info(f"config_id is required but was not provided")
config_id = os.getenv('config_id')
if config_id is None:
raise ValueError("config_id is required but was not provided")
# Check if config exists in database
try:
from app.models.data_config_model import DataConfig
from app.models.models_model import ModelConfig
config = db.query(DataConfig).filter(DataConfig.config_id == config_id).first()
if config is None:
error_msg = f"Configuration with config_id={config_id} does not exist in database"
api_logger.error(error_msg)
raise ValueError(error_msg)
# Validate llm_id exists and is usable
if config.llm_id:
try:
llm_config = db.query(ModelConfig).filter(ModelConfig.id == config.llm_id).first()
if llm_config is None:
error_msg = f"LLM model with id={config.llm_id} (from config_id={config_id}) does not exist"
api_logger.error(error_msg)
raise ValueError(error_msg)
if not llm_config.is_active:
error_msg = f"LLM model with id={config.llm_id} (from config_id={config_id}) is not active"
api_logger.error(error_msg)
raise ValueError(error_msg)
api_logger.debug(f"LLM validation successful: llm_id={config.llm_id}, name={llm_config.name}")
except ValueError:
raise
except Exception as e:
error_msg = f"Error validating LLM model: {str(e)}"
api_logger.error(error_msg, exc_info=True)
raise ValueError(error_msg)
else:
api_logger.error(f"Config {config_id} has no llm_id set")
raise ValueError(f"Config {config_id} has no llm_id set")
# Validate embedding_id exists and is usable
if config.embedding_id:
try:
embedding_config = db.query(ModelConfig).filter(ModelConfig.id == config.embedding_id).first()
if embedding_config is None:
error_msg = f"Embedding model with id={config.embedding_id} (from config_id={config_id}) does not exist"
api_logger.error(error_msg)
raise ValueError(error_msg)
if not embedding_config.is_active:
error_msg = f"Embedding model with id={config.embedding_id} (from config_id={config_id}) is not active"
api_logger.error(error_msg)
raise ValueError(error_msg)
api_logger.debug(f"Embedding validation successful: embedding_id={config.embedding_id}, name={embedding_config.name}")
except ValueError:
raise
except Exception as e:
error_msg = f"Error validating embedding model: {str(e)}"
api_logger.error(error_msg, exc_info=True)
raise ValueError(error_msg)
else:
api_logger.error(f"Config {config_id} has no embedding_id set")
raise ValueError(f"Config {config_id} has no embedding_id set")
api_logger.info(f"Config validation successful: config_id={config_id}, config_name={config.config_name}, llm_id={config.llm_id}, embedding_id={config.embedding_id}")
return config_id
except ValueError:
# Re-raise ValueError from above
raise
except Exception as e:
error_msg = f"Database error while validating config_id={config_id}: {str(e)}"
api_logger.error(error_msg, exc_info=True)
raise ValueError(error_msg)
@router.get("/health/status", response_model=ApiResponse)
async def get_health_status(
current_user: User = Depends(get_current_user)
):
"""
Get latest health status written by Celery periodic task
Returns health status information from Redis cache
"""
api_logger.info("Health status check requested")
try:
result = await memory_agent_service.get_health_status()
return success(data=result["status"])
except Exception as e:
api_logger.error(f"Health status check failed: {str(e)}")
return fail(BizCode.SERVICE_UNAVAILABLE, "健康状态查询失败", str(e))
@router.get("/download_log")
async def download_log(
log_type: str = Query("file", regex="^(file|transmission)$", description="日志类型: file=完整文件, transmission=实时流式传输"),
current_user: User = Depends(get_current_user)
):
"""
Download or stream agent service log file
log_type: str = Query("file", regex="^(file|transmission)$",
description="日志类型: file=完整文件, transmission=实时流式传输"),
current_user: User = Depends(get_current_user)
Args:
log_type: Log retrieval mode
- "file": Returns complete log file content in single response (default)
- "transmission": Real-time streaming of log content using Server-Sent Events
Returns:
- file mode: ApiResponse with log content
- transmission mode: StreamingResponse with SSE
"""
api_logger.info(f"Log download requested with log_type={log_type}")
# Validate log_type parameter (FastAPI Query regex already validates, but explicit check for clarity)
if log_type not in ["file", "transmission"]:
api_logger.warning(f"Invalid log_type parameter: {log_type}")
return fail(
BizCode.BAD_REQUEST,
"无效的log_type参数",
"log_type必须是'file''transmission'"
)
# Route to appropriate mode
if log_type == "file":
# File mode: Return complete log file content
try:
log_content = memory_agent_service.get_log_content()
return success(data=log_content)
except ValueError as e:
api_logger.warning(f"Log content issue: {str(e)}")
return fail(BizCode.FILE_NOT_FOUND, str(e))
except Exception as e:
api_logger.error(f"Log reading failed: {str(e)}")
return fail(BizCode.FILE_READ_ERROR, "日志读取失败", str(e))
else: # log_type == "transmission"
# Transmission mode: Stream log content using SSE
try:
api_logger.info("Starting SSE log streaming")
return StreamingResponse(
memory_agent_service.stream_log_content(),
media_type="text/event-stream",
headers={
"Cache-Control": "no-cache",
"Connection": "keep-alive",
"X-Accel-Buffering": "no" # Disable nginx buffering
}
)
except Exception as e:
api_logger.error(f"Failed to start log streaming: {str(e)}")
return fail(BizCode.INTERNAL_ERROR, "启动日志流式传输失败", str(e))
@router.post("/writer_service", response_model=ApiResponse)
@cur_workspace_access_guard()
async def write_server(
user_input: Write_UserInput,
db: Session = Depends(get_db),
current_user: User = Depends(get_current_user)
):
"""
Write service endpoint - processes write operations synchronously
Args:
user_input: Write request containing message and group_id
Returns:
Response with write operation status
"""
# Validate config_id
try:
config_id = validate_config_id(user_input.config_id, db)
except ValueError as e:
return fail(BizCode.INVALID_PARAMETER, "配置ID无效", str(e))
workspace_id = current_user.current_workspace_id
api_logger.info(f"Write service: workspace_id={workspace_id}, config_id={config_id}")
# 获取 storage_type如果为 None 则使用默认值
storage_type = workspace_service.get_workspace_storage_type(
db=db,
workspace_id=workspace_id,
user=current_user
)
if storage_type is None: storage_type = 'neo4j'
user_rag_memory_id = ''
# 如果 storage_type 是 rag必须确保有有效的 user_rag_memory_id
if storage_type == 'rag':
if workspace_id:
knowledge = knowledge_repository.get_knowledge_by_name(
db=db,
name="USER_RAG_MERORY",
workspace_id=workspace_id
)
if knowledge:
user_rag_memory_id = str(knowledge.id)
else:
api_logger.warning(f"未找到名为 'USER_RAG_MERORY' 的知识库workspace_id: {workspace_id},将使用 neo4j 存储")
storage_type = 'neo4j'
else:
api_logger.warning(f"workspace_id 为空,无法使用 rag 存储,将使用 neo4j 存储")
storage_type = 'neo4j'
api_logger.info(f"Write service requested for group {user_input.group_id}, storage_type: {storage_type}, user_rag_memory_id: {user_rag_memory_id}")
try:
result = await memory_agent_service.write_memory(
user_input.group_id,
user_input.message,
config_id,
storage_type,
user_rag_memory_id
)
return success(data=result, msg="写入成功")
except Exception as e:
api_logger.error(f"Write operation error: {str(e)}")
return fail(BizCode.INTERNAL_ERROR, "写入失败", str(e))
@router.post("/writer_service_async", response_model=ApiResponse)
@cur_workspace_access_guard()
async def write_server_async(
user_input: Write_UserInput,
db: Session = Depends(get_db),
current_user: User = Depends(get_current_user)
):
"""
Async write service endpoint - enqueues write processing to Celery
Args:
user_input: Write request containing message and group_id
Returns:
Task ID for tracking async operation
Use GET /memory/write_result/{task_id} to check task status and get result
"""
# Validate config_id
try:
config_id = validate_config_id(user_input.config_id, db)
except ValueError as e:
return fail(BizCode.INVALID_PARAMETER, "配置ID无效", str(e))
workspace_id = current_user.current_workspace_id
api_logger.info(f"Async write service: workspace_id={workspace_id}, config_id={config_id}")
# 获取 storage_type如果为 None 则使用默认值
storage_type = workspace_service.get_workspace_storage_type(
db=db,
workspace_id=workspace_id,
user=current_user
)
if storage_type is None: storage_type = 'neo4j'
user_rag_memory_id = ''
if workspace_id:
knowledge = knowledge_repository.get_knowledge_by_name(
db=db,
name="USER_RAG_MERORY",
workspace_id=workspace_id
)
if knowledge: user_rag_memory_id = str(knowledge.id)
api_logger.info(f"Async write: storage_type={storage_type}, user_rag_memory_id={user_rag_memory_id}")
try:
task = celery_app.send_task(
"app.core.memory.agent.write_message",
args=[user_input.group_id, user_input.message, config_id, storage_type, user_rag_memory_id]
)
api_logger.info(f"Write task queued: {task.id}")
return success(data={"task_id": task.id}, msg="写入任务已提交")
except Exception as e:
api_logger.error(f"Async write operation failed: {str(e)}")
return fail(BizCode.INTERNAL_ERROR, "写入失败", str(e))
@router.post("/read_service", response_model=ApiResponse)
@cur_workspace_access_guard()
async def read_server(
user_input: UserInput,
db: Session = Depends(get_db),
current_user: User = Depends(get_current_user)
):
"""
Read service endpoint - processes read operations synchronously
search_switch values:
- "0": Requires verification
- "1": No verification, direct split
- "2": Direct answer based on context
Args:
user_input: Read request with message, history, search_switch, and group_id
Returns:
Response with query answer
"""
# Validate config_id
try:
config_id = validate_config_id(user_input.config_id, db)
except ValueError as e:
return fail(BizCode.INVALID_PARAMETER, "配置ID无效", str(e))
workspace_id = current_user.current_workspace_id
api_logger.info(f"Read service: workspace_id={workspace_id}, config_id={config_id}")
# 获取 storage_type如果为 None 则使用默认值
storage_type = workspace_service.get_workspace_storage_type(
db=db,
workspace_id=workspace_id,
user=current_user
)
if storage_type is None: storage_type = 'neo4j'
user_rag_memory_id = ''
if workspace_id:
knowledge = knowledge_repository.get_knowledge_by_name(
db=db,
name="USER_RAG_MERORY",
workspace_id=workspace_id
)
if knowledge: user_rag_memory_id = str(knowledge.id)
api_logger.info(f"Read service: group={user_input.group_id}, storage_type={storage_type}, user_rag_memory_id={user_rag_memory_id}, workspace_id={workspace_id}")
try:
result = await memory_agent_service.read_memory(
user_input.group_id,
user_input.message,
user_input.history,
user_input.search_switch,
config_id,
storage_type,
user_rag_memory_id
)
return success(data=result, msg="回复对话消息成功")
except Exception as e:
api_logger.error(f"Read operation error: {str(e)}")
return fail(BizCode.INTERNAL_ERROR, "回复对话消息失败", str(e))
@router.post("/file", response_model=ApiResponse)
async def file_update(
files: List[UploadFile] = File(..., description="要上传的文件"),
model_id:str = Form(..., description="模型ID"),
metadata: Optional[str] = Form(None, description="文件元数据 (JSON格式)"),
current_user: User = Depends(get_current_user)
):
"""
文件上传接口 - 支持图片识别
Args:
files: 上传的文件列表
metadata: 文件元数据(可选)
current_user: 当前用户
Returns:
文件处理结果
"""
db_gen = get_db() # get_db 通常是一个生成器
db = next(db_gen)
api_logger.info(f"File upload requested, file count: {len(files)}")
config = ModelConfigService.get_model_by_id(db=db, model_id=model_id)
apiConfig: ModelApiKey = config.api_keys[0]
file_content = []
try:
for file in files:
api_logger.debug(f"Processing file: {file.filename}, content_type: {file.content_type}")
content = await file.read()
if file.content_type and file.content_type.startswith("image/"):
vision_model = QWenCV(
key=apiConfig.api_key,
model_name=apiConfig.model_name,
lang="Chinese",
base_url=apiConfig.api_base
)
description, token_count = vision_model.describe(content)
file_content.append(description)
api_logger.info(f"Image processed: {file.filename}, tokens: {token_count}")
else:
api_logger.warning(f"Unsupported file type: {file.content_type}")
file_content.append(f"[不支持的文件类型: {file.content_type}]")
result_text = ';'.join(file_content)
api_logger.info(f"File processing completed, result length: {len(result_text)}")
return success(data=result_text, msg="转换文本成功")
except Exception as e:
api_logger.error(f"File processing failed: {str(e)}", exc_info=True)
return fail(BizCode.INTERNAL_ERROR, "转换文本失败", str(e))
@router.post("/read_service_async", response_model=ApiResponse)
@cur_workspace_access_guard()
async def read_server_async(
user_input: UserInput,
db: Session = Depends(get_db),
current_user: User = Depends(get_current_user)
):
# Validate config_id
try:
config_id = validate_config_id(user_input.config_id, db)
except ValueError as e:
return fail(BizCode.INVALID_PARAMETER, "配置ID无效", str(e))
workspace_id = current_user.current_workspace_id
api_logger.info(f"Async read service: workspace_id={workspace_id}, config_id={config_id}")
# 获取 storage_type如果为 None 则使用默认值
storage_type = workspace_service.get_workspace_storage_type(
db=db,
workspace_id=workspace_id,
user=current_user
)
if storage_type is None: storage_type = 'neo4j'
user_rag_memory_id = ''
if workspace_id:
knowledge = knowledge_repository.get_knowledge_by_name(
db=db,
name="USER_RAG_MERORY",
workspace_id=workspace_id
)
if knowledge: user_rag_memory_id = str(knowledge.id)
api_logger.info(f"Async read: storage_type={storage_type}, user_rag_memory_id={user_rag_memory_id}")
try:
task = celery_app.send_task(
"app.core.memory.agent.read_message",
args=[user_input.group_id, user_input.message, user_input.history, user_input.search_switch,
config_id, storage_type, user_rag_memory_id]
)
api_logger.info(f"Read task queued: {task.id}")
return success(data={"task_id": task.id}, msg="查询任务已提交")
except Exception as e:
api_logger.error(f"Async read operation failed: {str(e)}")
return fail(BizCode.INTERNAL_ERROR, "回复对话消息失败", str(e))
@router.get("/read_result/", response_model=ApiResponse)
async def get_read_task_result(
task_id: str,
current_user: User = Depends(get_current_user)
):
"""
Get the status and result of an async read task
Args:
task_id: Celery task ID returned from /read_service_async
Returns:
Task status and result if completed
Response format:
- PENDING: Task is waiting to be executed
- STARTED: Task has started
- SUCCESS: Task completed successfully, returns result
- FAILURE: Task failed, returns error message
"""
api_logger.info(f"Read task status check requested for task {task_id}")
try:
result = task_service.get_task_memory_read_result(task_id)
status = result.get("status")
if status == "SUCCESS":
# 任务成功完成
task_result = result.get("result", {})
if isinstance(task_result, dict):
# 新格式:包含详细信息
return success(
data={
"result": task_result.get("result"),
"group_id": task_result.get("group_id"),
"elapsed_time": task_result.get("elapsed_time"),
"task_id": task_id
},
msg="查询任务已完成"
)
else:
# 旧格式:直接返回结果
return success(data=task_result, msg="查询任务已完成")
elif status == "FAILURE":
# 任务失败
error_info = result.get("result", "Unknown error")
if isinstance(error_info, dict):
error_msg = error_info.get("error", str(error_info))
else:
error_msg = str(error_info)
return fail(BizCode.INTERNAL_ERROR, "查询任务失败", error_msg)
elif status in ["PENDING", "STARTED"]:
# 任务进行中
return success(
data={
"status": status,
"task_id": task_id,
"message": "任务处理中,请稍后查询"
},
msg="查询任务处理中"
)
else:
# 未知状态
return success(
data={
"status": status,
"task_id": task_id
},
msg=f"任务状态: {status}"
)
except Exception as e:
api_logger.error(f"Read task status check failed: {str(e)}", exc_info=True)
return fail(BizCode.INTERNAL_ERROR, "任务状态查询失败", str(e))
@router.get("/write_result/", response_model=ApiResponse)
async def get_write_task_result(
task_id: str,
current_user: User = Depends(get_current_user)
):
"""
Get the status and result of an async write task
Args:
task_id: Celery task ID returned from /writer_service_async
Returns:
Task status and result if completed
Response format:
- PENDING: Task is waiting to be executed
- STARTED: Task has started
- SUCCESS: Task completed successfully, returns result
- FAILURE: Task failed, returns error message
"""
api_logger.info(f"Write task status check requested for task {task_id}")
try:
result = task_service.get_task_memory_write_result(task_id)
status = result.get("status")
if status == "SUCCESS":
# 任务成功完成
task_result = result.get("result", {})
if isinstance(task_result, dict):
# 新格式:包含详细信息
return success(
data={
"result": task_result.get("result"),
"group_id": task_result.get("group_id"),
"elapsed_time": task_result.get("elapsed_time"),
"task_id": task_id
},
msg="写入任务已完成"
)
else:
# 旧格式:直接返回结果
return success(data=task_result, msg="写入任务已完成")
elif status == "FAILURE":
# 任务失败
error_info = result.get("result", "Unknown error")
if isinstance(error_info, dict):
error_msg = error_info.get("error", str(error_info))
else:
error_msg = str(error_info)
return fail(BizCode.INTERNAL_ERROR, "写入任务失败", error_msg)
elif status in ["PENDING", "STARTED"]:
# 任务进行中
return success(
data={
"status": status,
"task_id": task_id,
"message": "任务处理中,请稍后查询"
},
msg="写入任务处理中"
)
else:
# 未知状态
return success(
data={
"status": status,
"task_id": task_id
},
msg=f"任务状态: {status}"
)
except Exception as e:
api_logger.error(f"Write task status check failed: {str(e)}", exc_info=True)
return fail(BizCode.INTERNAL_ERROR, "任务状态查询失败", str(e))
@router.post("/status_type", response_model=ApiResponse)
async def status_type(
user_input: Write_UserInput,
current_user: User = Depends(get_current_user)
):
"""
Determine the type of user message (read or write)
Args:
user_input: Request containing user message and group_id
Returns:
Type classification result
"""
api_logger.info(f"Status type check requested for group {user_input.group_id}")
try:
result = await memory_agent_service.classify_message_type(user_input.message)
return success(data=result)
except Exception as e:
api_logger.error(f"Message type classification failed: {str(e)}")
return fail(BizCode.INTERNAL_ERROR, "类型判断失败", str(e))
# ==================== 新增的三个接口路由 ====================
@router.get("/stats/types", response_model=ApiResponse)
async def get_knowledge_type_stats_api(
end_user_id: Optional[str] = Query(None, description="用户ID可选"),
only_active: bool = Query(True, description="仅统计有效记录(status=1)"),
current_user: User = Depends(get_current_user)
):
"""
统计当前空间下各知识库类型的数量,包含 General | Web | Third-party | Folder | memory。
会对缺失类型补 0返回字典形式。
可选按状态过滤。
- 知识库类型根据当前用户的 current_workspace_id 过滤
- memory 是 Neo4j 中 Chunk 的数量,根据 end_user_id (group_id) 过滤
- 如果用户没有当前工作空间或未提供 end_user_id对应的统计返回 0
"""
api_logger.info(f"Knowledge type stats requested for workspace_id: {current_user.current_workspace_id}, end_user_id: {end_user_id}")
try:
from app.db import get_db
# 获取数据库会话
db_gen = get_db()
db = next(db_gen)
# 调用service层函数
result = await memory_agent_service.get_knowledge_type_stats(
end_user_id=end_user_id,
only_active=only_active,
current_workspace_id=current_user.current_workspace_id,
db=db
)
return success(data=result, msg="获取知识库类型统计成功")
except Exception as e:
api_logger.error(f"Knowledge type stats failed: {str(e)}")
return fail(BizCode.INTERNAL_ERROR, "获取知识库类型统计失败", str(e))
@router.get("/analytics/hot_memory_tags/by_user", response_model=ApiResponse)
async def get_hot_memory_tags_by_user_api(
end_user_id: Optional[str] = Query(None, description="用户ID可选"),
limit: int = Query(20, description="返回标签数量限制"),
current_user: User = Depends(get_current_user)
):
"""
获取指定用户的热门记忆标签
返回格式:
[
{"name": "标签名", "frequency": 频次},
...
]
"""
api_logger.info(f"Hot memory tags by user requested: end_user_id={end_user_id}")
try:
result = await memory_agent_service.get_hot_memory_tags_by_user(
end_user_id=end_user_id,
limit=limit
)
return success(data=result, msg="获取热门记忆标签成功")
except Exception as e:
api_logger.error(f"Hot memory tags by user failed: {str(e)}")
return fail(BizCode.INTERNAL_ERROR, "获取热门记忆标签失败", str(e))
@router.get("/analytics/user_profile", response_model=ApiResponse)
async def get_user_profile_api(
end_user_id: Optional[str] = Query(None, description="用户ID可选"),
current_user: User = Depends(get_current_user)
):
"""
获取用户详情,包含:
- name: 用户名字(直接使用 end_user_id
- tags: 3个用户特征标签从语句和实体中LLM总结
- hot_tags: 4个热门记忆标签
返回格式:
{
"name": "用户名",
"tags": ["产品设计师", "旅行爱好者", "摄影发烧友"],
"hot_tags": [
{"name": "标签1", "frequency": 10},
{"name": "标签2", "frequency": 8},
...
]
}
"""
api_logger.info(f"User profile requested: end_user_id={end_user_id}, current_user={current_user.id}")
try:
result = await memory_agent_service.get_user_profile(
end_user_id=end_user_id,
current_user_id=str(current_user.id)
)
return success(data=result, msg="获取用户详情成功")
except Exception as e:
api_logger.error(f"User profile failed: {str(e)}")
return fail(BizCode.INTERNAL_ERROR, "获取用户详情失败", str(e))
# @router.get("/docs/api", response_model=ApiResponse)
# async def get_api_docs_api(
# file_path: Optional[str] = Query(None, description="API文档文件路径不传则使用默认路径")
# ):
# """
# Get parsed API documentation (Public endpoint - no authentication required)
# Args:
# file_path: Optional path to API docs file. If None, uses default path.
# Returns:
# Parsed API documentation including title, meta info, and sections
# """
# api_logger.info(f"API docs requested, file_path: {file_path or 'default'}")
# try:
# result = await memory_agent_service.get_api_docs(file_path)
# if result.get("success"):
# return success(msg=result["msg"], data=result["data"])
# else:
# return fail(
# code=BizCode.BAD_REQUEST,
# msg=result["msg"],
# error=result.get("data", {}).get("error", result.get("error_code", ""))
# )
# except Exception as e:
# api_logger.error(f"API docs retrieval failed: {str(e)}")
# return fail(BizCode.INTERNAL_ERROR, "API文档获取失败", str(e))

View File

@@ -0,0 +1,516 @@
from fastapi import APIRouter, Depends, HTTPException, status, Query
from sqlalchemy.orm import Session
from typing import List, Optional
import uuid
from app.core.response_utils import success
from app.db import get_db
from app.dependencies import get_current_user
from app.models.user_model import User
from app.schemas.response_schema import ApiResponse
from app.schemas.app_schema import App as AppSchema
from app.services import memory_dashboard_service, memory_storage_service, workspace_service
from app.core.logging_config import get_api_logger
# 获取API专用日志器
api_logger = get_api_logger()
router = APIRouter(
prefix="/dashboard",
tags=["Dashboard"],
dependencies=[Depends(get_current_user)] # Apply auth to all routes in this controller
)
@router.get("/total_end_users", response_model=ApiResponse)
def get_workspace_total_end_users(
db: Session = Depends(get_db),
current_user: User = Depends(get_current_user),
):
"""
获取用户列表的总用户数
"""
workspace_id = current_user.current_workspace_id
api_logger.info(f"用户 {current_user.username} 请求获取工作空间 {workspace_id} 的宿主列表")
total_end_users = memory_dashboard_service.get_workspace_total_end_users(
db=db,
workspace_id=workspace_id,
current_user=current_user
)
api_logger.info(f"成功获取最新用户总数: total_num={total_end_users.get('total_num', 0)}")
return success(data=total_end_users, msg="用户数量获取成功")
@router.get("/end_users", response_model=ApiResponse)
async def get_workspace_end_users(
db: Session = Depends(get_db),
current_user: User = Depends(get_current_user),
):
"""
获取工作空间的宿主列表
返回格式与原 memory_list 接口中的 end_users 字段相同
"""
workspace_id = current_user.current_workspace_id
api_logger.info(f"用户 {current_user.username} 请求获取工作空间 {workspace_id} 的宿主列表")
end_users = memory_dashboard_service.get_workspace_end_users(
db=db,
workspace_id=workspace_id,
current_user=current_user
)
result = []
for end_user in end_users:
# EndUser 是 Pydantic 模型,直接访问属性而不是使用 .get()
memory_num = await memory_storage_service.search_all(str(end_user.id))
result.append(
{
'end_user':end_user,
'memory_num':memory_num
}
)
api_logger.info(f"成功获取 {len(end_users)} 个宿主记录")
return success(data=result, msg="宿主列表获取成功")
@router.get("/memory_increment", response_model=ApiResponse)
def get_workspace_memory_increment(
limit: int = Query(7, description="返回记录数"),
db: Session = Depends(get_db),
current_user: User = Depends(get_current_user),
):
"""获取工作空间的记忆增量"""
workspace_id = current_user.current_workspace_id
api_logger.info(f"用户 {current_user.username} 请求获取工作空间 {workspace_id} 的记忆增量")
memory_increment = memory_dashboard_service.get_workspace_memory_increment(
db=db,
workspace_id=workspace_id,
current_user=current_user,
limit=limit
)
api_logger.info(f"成功获取 {len(memory_increment)} 条记忆增量记录")
return success(data=memory_increment, msg="记忆增量获取成功")
@router.get("/api_increment", response_model=ApiResponse)
def get_workspace_api_increment(
db: Session = Depends(get_db),
current_user: User = Depends(get_current_user),
):
"""获取API调用趋势"""
workspace_id = current_user.current_workspace_id
api_logger.info(f"用户 {current_user.username} 请求获取工作空间 {workspace_id} 的API调用增量")
api_increment = memory_dashboard_service.get_workspace_api_increment(
db=db,
workspace_id=workspace_id,
current_user=current_user
)
api_logger.info(f"成功获取 {api_increment} API调用增量")
return success(data=api_increment, msg="API调用增量获取成功")
@router.post("/total_memory", response_model=ApiResponse)
def write_workspace_total_memory(
db: Session = Depends(get_db),
current_user: User = Depends(get_current_user),
):
"""工作空间记忆总量的写入(异步任务)"""
workspace_id = current_user.current_workspace_id
api_logger.info(f"用户 {current_user.username} 请求写入工作空间 {workspace_id} 的记忆总量")
# 触发 Celery 异步任务
from app.celery_app import celery_app
task = celery_app.send_task(
"app.controllers.memory_storage_controller.search_all",
kwargs={"workspace_id": str(workspace_id)}
)
api_logger.info(f"已触发记忆总量统计任务task_id: {task.id}")
return success(
data={"task_id": task.id, "workspace_id": str(workspace_id)},
msg="记忆总量统计任务已启动"
)
@router.get("/task_status/{task_id}", response_model=ApiResponse)
def get_task_status(
task_id: str,
current_user: User = Depends(get_current_user),
):
"""查询异步任务的执行状态和结果"""
api_logger.info(f"用户 {current_user.username} 查询任务状态: task_id={task_id}")
from app.celery_app import celery_app
from celery.result import AsyncResult
# 获取任务结果
task_result = AsyncResult(task_id, app=celery_app)
response_data = {
"task_id": task_id,
"status": task_result.state, # PENDING, STARTED, SUCCESS, FAILURE, RETRY, REVOKED
}
# 如果任务完成,返回结果
if task_result.ready():
if task_result.successful():
response_data["result"] = task_result.result
api_logger.info(f"任务 {task_id} 执行成功")
return success(data=response_data, msg="任务执行成功")
else:
# 任务失败
response_data["error"] = str(task_result.result)
api_logger.error(f"任务 {task_id} 执行失败: {task_result.result}")
return success(data=response_data, msg="任务执行失败")
else:
# 任务还在执行中
api_logger.info(f"任务 {task_id} 状态: {task_result.state}")
return success(data=response_data, msg=f"任务状态: {task_result.state}")
@router.get("/memory_list", response_model=ApiResponse)
def get_workspace_memory_list(
limit: int = Query(7, description="记忆增量返回记录数"),
db: Session = Depends(get_db),
current_user: User = Depends(get_current_user),
):
"""
用户记忆列表整合接口
整合以下三个接口的数据:
1. total_memory - 工作空间记忆总量
2. memory_increment - 工作空间记忆增量
3. hosts - 工作空间宿主列表
返回格式:
{
"total_memory": float,
"memory_increment": [
{"date": "2024-01-01", "count": 100},
...
],
"hosts": [
{"id": "uuid", "name": "宿主名", ...},
...
]
}
"""
workspace_id = current_user.current_workspace_id
api_logger.info(f"用户 {current_user.username} 请求获取工作空间 {workspace_id} 的记忆列表")
memory_list = memory_dashboard_service.get_workspace_memory_list(
db=db,
workspace_id=workspace_id,
current_user=current_user,
limit=limit
)
api_logger.info(f"成功获取记忆列表")
return success(data=memory_list, msg="记忆列表获取成功")
@router.get("/total_memory_count", response_model=ApiResponse)
async def get_workspace_total_memory_count(
end_user_id: Optional[str] = Query(None, description="可选的用户ID"),
db: Session = Depends(get_db),
current_user: User = Depends(get_current_user),
):
"""
获取工作空间的记忆总量通过聚合所有host的记忆数
逻辑:
1. 从 memory_list 获取所有 host_id
2. 对每个 host_id 调用 search_all 获取 total
3. 将所有 total 求和返回
返回格式:
{
"total_memory_count": int,
"host_count": int,
"details": [
{"host_id": "uuid", "count": 100},
...
]
}
"""
workspace_id = current_user.current_workspace_id
api_logger.info(f"用户 {current_user.username} 请求获取工作空间 {workspace_id} 的记忆总量")
total_memory_count = await memory_dashboard_service.get_workspace_total_memory_count(
db=db,
workspace_id=workspace_id,
current_user=current_user,
end_user_id=end_user_id
)
api_logger.info(f"成功获取记忆总量: {total_memory_count.get('total_memory_count', 0)}")
return success(data=total_memory_count, msg="记忆总量获取成功")
# ======== RAG 数据统计 ========
@router.get("/total_rag_count", response_model=ApiResponse)
def get_workspace_total_rag_count(
db: Session = Depends(get_db),
current_user: User = Depends(get_current_user)
):
"""
获取 rag 的总文档数、总chunk数、总知识库数量、总api调用数量
"""
total_documents = memory_dashboard_service.get_rag_total_doc(db, current_user)
total_chunk = memory_dashboard_service.get_rag_total_chunk(db, current_user)
total_kb = memory_dashboard_service.get_rag_total_kb(db, current_user)
data = {
'total_documents':total_documents,
'total_chunk':total_chunk,
'total_kb':total_kb,
'total_api':1024
}
return success(data=data, msg="RAG相关数据获取成功")
@router.get("/current_user_rag_total_num", response_model=ApiResponse)
def get_current_user_rag_total_num(
end_user_id: str = Query(..., description="宿主ID"),
db: Session = Depends(get_db),
current_user: User = Depends(get_current_user),
):
"""
获取当前宿主的 RAG 的总chunk数量
"""
total_chunk = memory_dashboard_service.get_current_user_total_chunk(end_user_id, db, current_user)
return success(data=total_chunk, msg="宿主RAG知识数据获取成功")
@router.get("/rag_content", response_model=ApiResponse)
def get_rag_content(
end_user_id: str = Query(..., description="宿主ID"),
limit: int = Query(15, description="返回记录数"),
db: Session = Depends(get_db),
current_user: User = Depends(get_current_user),
):
"""
获取当前宿主知识库中的chunk内容
"""
data = memory_dashboard_service.get_rag_content(end_user_id, limit, db, current_user)
return success(data=data, msg="宿主RAGchunk数据获取成功")
@router.get("/chunk_summary_tag", response_model=ApiResponse)
async def get_chunk_summary_tag(
end_user_id: str = Query(..., description="宿主ID"),
limit: int = Query(15, description="返回记录数"),
max_tags: int = Query(10, description="最大标签数量"),
db: Session = Depends(get_db),
current_user: User = Depends(get_current_user),
):
"""
获取chunk总结、提取的标签和人物形象
返回格式:
{
"summary": "chunk内容的总结",
"tags": [
{"tag": "标签1", "frequency": 5},
{"tag": "标签2", "frequency": 3},
...
],
"personas": [
"产品设计师",
"旅行爱好者",
"摄影发烧友",
...
]
}
"""
api_logger.info(f"用户 {current_user.username} 请求获取宿主 {end_user_id} 的chunk摘要、标签和人物形象")
data = await memory_dashboard_service.get_chunk_summary_and_tags(
end_user_id=end_user_id,
limit=limit,
max_tags=max_tags,
db=db,
current_user=current_user
)
api_logger.info(f"成功获取chunk摘要、{len(data.get('tags', []))} 个标签和 {len(data.get('personas', []))} 个人物形象")
return success(data=data, msg="chunk摘要、标签和人物形象获取成功")
@router.get("/chunk_insight", response_model=ApiResponse)
async def get_chunk_insight(
end_user_id: str = Query(..., description="宿主ID"),
limit: int = Query(15, description="返回记录数"),
db: Session = Depends(get_db),
current_user: User = Depends(get_current_user),
):
"""
获取chunk的洞察内容
返回格式:
{
"insight": "对chunk内容的深度洞察分析"
}
"""
api_logger.info(f"用户 {current_user.username} 请求获取宿主 {end_user_id} 的chunk洞察")
data = await memory_dashboard_service.get_chunk_insight(
end_user_id=end_user_id,
limit=limit,
db=db,
current_user=current_user
)
api_logger.info(f"成功获取chunk洞察")
return success(data=data, msg="chunk洞察获取成功")
@router.get("/dashboard_data", response_model=ApiResponse)
async def dashboard_data(
end_user_id: Optional[str] = Query(None, description="可选的用户ID"),
db: Session = Depends(get_db),
current_user: User = Depends(get_current_user),
):
"""
整合dashboard数据接口
整合以下接口的数据:
1. /dashboard/total_memory_count - 记忆总量
2. /dashboard/api_increment - API调用增量
3. /memory/stats/types - 知识库类型统计只要total数据
4. /dashboard/total_rag_count - RAG相关数据
根据 storage_type 判断调用不同的接口
返回格式:
{
"storage_type": str,
"neo4j_data": {
"total_memory": int,
"total_app": int,
"total_knowledge": int,
"total_api_call": int
} | null,
"rag_data": {
"total_memory": int,
"total_app": int,
"total_knowledge": int,
"total_api_call": int
} | null
}
"""
workspace_id = current_user.current_workspace_id
api_logger.info(f"用户 {current_user.username} 请求获取工作空间 {workspace_id} 的dashboard整合数据")
# 获取 storage_type如果为 None 则使用默认值
storage_type = workspace_service.get_workspace_storage_type(
db=db,
workspace_id=workspace_id,
user=current_user
)
if storage_type is None:
storage_type = 'neo4j'
user_rag_memory_id = None
# 根据 storage_type 决定返回哪个数据对象
# 如果是 'rag'neo4j_data 为 null否则 rag_data 为 null
result = {
"storage_type": storage_type,
"neo4j_data": None,
"rag_data": None
}
try:
# 如果 storage_type 为 'neo4j' 或空,获取 neo4j_data
if storage_type == 'neo4j':
neo4j_data = {
"total_memory": None,
"total_app": None,
"total_knowledge": None,
"total_api_call": None
}
# 1. 获取记忆总量total_memory
try:
total_memory_data = await memory_dashboard_service.get_workspace_total_memory_count(
db=db,
workspace_id=workspace_id,
current_user=current_user,
end_user_id=end_user_id
)
neo4j_data["total_memory"] = total_memory_data.get("total_memory_count", 0)
# total_app: 统计当前空间下的所有app数量
from app.repositories import app_repository
apps_orm = app_repository.get_apps_by_workspace_id(db, workspace_id)
neo4j_data["total_app"] = len(apps_orm)
api_logger.info(f"成功获取记忆总量: {neo4j_data['total_memory']}, 应用数量: {neo4j_data['total_app']}")
except Exception as e:
api_logger.warning(f"获取记忆总量失败: {str(e)}")
# 2. 获取知识库类型统计total_knowledge
try:
from app.services.memory_agent_service import MemoryAgentService
memory_agent_service = MemoryAgentService()
knowledge_stats = await memory_agent_service.get_knowledge_type_stats(
end_user_id=end_user_id,
only_active=True,
current_workspace_id=workspace_id,
db=db
)
neo4j_data["total_knowledge"] = knowledge_stats.get("total", 0)
api_logger.info(f"成功获取知识库类型统计total: {neo4j_data['total_knowledge']}")
except Exception as e:
api_logger.warning(f"获取知识库类型统计失败: {str(e)}")
# 3. 获取API调用增量total_api_call转换为整数
try:
api_increment = memory_dashboard_service.get_workspace_api_increment(
db=db,
workspace_id=workspace_id,
current_user=current_user
)
neo4j_data["total_api_call"] = api_increment
api_logger.info(f"成功获取API调用增量: {neo4j_data['total_api_call']}")
except Exception as e:
api_logger.warning(f"获取API调用增量失败: {str(e)}")
result["neo4j_data"] = neo4j_data
api_logger.info(f"成功获取neo4j_data")
# 如果 storage_type 为 'rag',获取 rag_data
elif storage_type == 'rag':
rag_data = {
"total_memory": None,
"total_app": None,
"total_knowledge": None,
"total_api_call": None
}
# 获取RAG相关数据
try:
# total_memory: 使用 total_chunk总chunk数
total_chunk = memory_dashboard_service.get_rag_total_chunk(db, current_user)
rag_data["total_memory"] = total_chunk
# total_app: 统计当前空间下的所有app数量
from app.repositories import app_repository
apps_orm = app_repository.get_apps_by_workspace_id(db, workspace_id)
rag_data["total_app"] = len(apps_orm)
# total_knowledge: 使用 total_kb总知识库数
total_kb = memory_dashboard_service.get_rag_total_kb(db, current_user)
rag_data["total_knowledge"] = total_kb
# total_api_call: 固定值
rag_data["total_api_call"] = 1024
api_logger.info(f"成功获取RAG相关数据: memory={total_chunk}, app={len(apps_orm)}, knowledge={total_kb}")
except Exception as e:
api_logger.warning(f"获取RAG相关数据失败: {str(e)}")
result["rag_data"] = rag_data
api_logger.info(f"成功获取rag_data")
api_logger.info(f"成功获取dashboard整合数据")
return success(data=result, msg="Dashboard数据获取成功")
except Exception as e:
api_logger.error(f"获取dashboard整合数据失败: {str(e)}")
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail=f"获取dashboard整合数据失败: {str(e)}"
)

View File

@@ -0,0 +1,542 @@
from typing import Optional
import os
import uuid
from fastapi import APIRouter, Depends
from app.core.logging_config import get_api_logger
from app.core.response_utils import success, fail
from app.core.error_codes import BizCode
from app.services.memory_storage_service import (
MemoryStorageService,
DataConfigService,
kb_type_distribution,
search_dialogue,
search_chunk,
search_statement,
search_entity,
search_all,
search_detials,
search_edges,
search_entity_graph,
analytics_hot_memory_tags,
analytics_memory_insight_report,
analytics_recent_activity_stats,
analytics_user_summary,
)
from app.schemas.response_schema import ApiResponse
from app.schemas.memory_storage_schema import (
ConfigParamsCreate,
ConfigParamsDelete,
ConfigUpdate,
ConfigUpdateExtracted,
ConfigUpdateForget,
ConfigKey,
ConfigPilotRun,
)
from app.core.memory.utils.config.definitions import reload_configuration_from_database
from app.dependencies import get_current_user
from app.models.user_model import User
# Get API logger
api_logger = get_api_logger()
# Initialize service
memory_storage_service = MemoryStorageService()
router = APIRouter(
prefix="/memory-storage",
tags=["Memory Storage"],
)
@router.get("/info", response_model=ApiResponse)
async def get_storage_info(
storage_id: str,
current_user: User = Depends(get_current_user)
):
"""
Example wrapper endpoint - retrieves storage information
Args:
storage_id: Storage identifier
Returns:
Storage information
"""
api_logger.info(f"Storage info requested ")
try:
result = await memory_storage_service.get_storage_info()
return success(data=result)
except Exception as e:
api_logger.error(f"Storage info retrieval failed: {str(e)}")
return fail(BizCode.INTERNAL_ERROR, "存储信息获取失败", str(e))
# --- DB connection dependency ---
_CONN: Optional[object] = None
"""PostgreSQL 连接生成与管理(使用 psycopg2"""
# 这个可以转移,可能是已经有的
# PostgreSQL 数据库连接
def _make_pgsql_conn() -> Optional[object]: # 创建 PostgreSQL 数据库连接
host = os.getenv("DB_HOST")
user = os.getenv("DB_USER")
password = os.getenv("DB_PASSWORD")
database = os.getenv("DB_NAME")
port_str = os.getenv("DB_PORT")
try:
import psycopg2 # type: ignore
port = int(port_str) if port_str else 5432
conn = psycopg2.connect(
host=host or "localhost",
port=port,
user=user,
password=password,
dbname=database,
)
# 设置自动提交,避免显式事务管理
conn.autocommit = True
# 设置会话时区为中国标准时间Asia/Shanghai便于直接以本地时区展示
try:
cur = conn.cursor()
cur.execute("SET TIME ZONE 'Asia/Shanghai'")
cur.close()
except Exception:
# 时区设置失败不影响连接,仅记录但不抛出
pass
return conn
except Exception as e:
try:
print(f"[PostgreSQL] 连接失败: {e}")
except Exception:
pass
return None
def get_db_conn() -> Optional[object]: # 获取 PostgreSQL 数据库连接
global _CONN
if _CONN is None:
_CONN = _make_pgsql_conn()
return _CONN
def reset_db_conn() -> bool: # 重置 PostgreSQL 数据库连接
"""Close and recreate the global DB connection."""
global _CONN
try:
if _CONN:
try:
_CONN.close()
except Exception:
pass
_CONN = _make_pgsql_conn()
return _CONN is not None
except Exception:
_CONN = None
return False
@router.post("/create_config", response_model=ApiResponse) # 创建配置文件,其他参数默认
def create_config(
payload: ConfigParamsCreate,
current_user: User = Depends(get_current_user),
) -> dict:
workspace_id = current_user.current_workspace_id
# 检查用户是否已选择工作空间
if workspace_id is None:
api_logger.warning(f"用户 {current_user.username} 尝试创建配置但未选择工作空间")
return fail(BizCode.INVALID_PARAMETER, "请先切换到一个工作空间", "current_workspace_id is None")
api_logger.info(f"用户 {current_user.username} 在工作空间 {workspace_id} 请求创建配置: {payload.config_name}")
try:
# 将 workspace_id 注入到 payload 中(保持为 UUID 类型)
payload.workspace_id = workspace_id
svc = DataConfigService(get_db_conn())
result = svc.create(payload)
return success(data=result, msg="创建成功")
except Exception as e:
api_logger.error(f"Create config failed: {str(e)}")
return fail(BizCode.INTERNAL_ERROR, "创建配置失败", str(e))
@router.delete("/delete_config", response_model=ApiResponse) # 删除数据库中的内容(按配置名称)
def delete_config(
config_id: str,
current_user: User = Depends(get_current_user),
) -> dict:
workspace_id = current_user.current_workspace_id
# 检查用户是否已选择工作空间
if workspace_id is None:
api_logger.warning(f"用户 {current_user.username} 尝试删除配置但未选择工作空间")
return fail(BizCode.INVALID_PARAMETER, "请先切换到一个工作空间", "current_workspace_id is None")
api_logger.info(f"用户 {current_user.username} 在工作空间 {workspace_id} 请求删除配置: {config_id}")
try:
svc = DataConfigService(get_db_conn())
result = svc.delete(ConfigParamsDelete(config_id=config_id))
return success(data=result, msg="删除成功")
except Exception as e:
api_logger.error(f"Delete config failed: {str(e)}")
return fail(BizCode.INTERNAL_ERROR, "删除配置失败", str(e))
@router.post("/update_config", response_model=ApiResponse) # 更新配置文件中name和desc
def update_config(
payload: ConfigUpdate,
current_user: User = Depends(get_current_user),
) -> dict:
workspace_id = current_user.current_workspace_id
# 检查用户是否已选择工作空间
if workspace_id is None:
api_logger.warning(f"用户 {current_user.username} 尝试更新配置但未选择工作空间")
return fail(BizCode.INVALID_PARAMETER, "请先切换到一个工作空间", "current_workspace_id is None")
api_logger.info(f"用户 {current_user.username} 在工作空间 {workspace_id} 请求更新配置: {payload.config_id}")
try:
svc = DataConfigService(get_db_conn())
result = svc.update(payload)
return success(data=result, msg="更新成功")
except Exception as e:
api_logger.error(f"Update config failed: {str(e)}")
return fail(BizCode.INTERNAL_ERROR, "更新配置失败", str(e))
@router.post("/update_config_extracted", response_model=ApiResponse) # 更新数据库中的部分内容 所有业务字段均可选
def update_config_extracted(
payload: ConfigUpdateExtracted,
current_user: User = Depends(get_current_user),
) -> dict:
workspace_id = current_user.current_workspace_id
# 检查用户是否已选择工作空间
if workspace_id is None:
api_logger.warning(f"用户 {current_user.username} 尝试更新提取配置但未选择工作空间")
return fail(BizCode.INVALID_PARAMETER, "请先切换到一个工作空间", "current_workspace_id is None")
api_logger.info(f"用户 {current_user.username} 在工作空间 {workspace_id} 请求更新提取配置: {payload.config_id}")
try:
svc = DataConfigService(get_db_conn())
result = svc.update_extracted(payload)
return success(data=result, msg="更新成功")
except Exception as e:
api_logger.error(f"Update config extracted failed: {str(e)}")
return fail(BizCode.INTERNAL_ERROR, "更新配置失败", str(e))
# --- Forget config params ---
@router.post("/update_config_forget", response_model=ApiResponse) # 更新遗忘引擎配置参数(固定路径)
def update_config_forget(
payload: ConfigUpdateForget,
current_user: User = Depends(get_current_user),
) -> dict:
workspace_id = current_user.current_workspace_id
# 检查用户是否已选择工作空间
if workspace_id is None:
api_logger.warning(f"用户 {current_user.username} 尝试更新遗忘引擎配置但未选择工作空间")
return fail(BizCode.INVALID_PARAMETER, "请先切换到一个工作空间", "current_workspace_id is None")
api_logger.info(f"用户 {current_user.username} 在工作空间 {workspace_id} 请求更新遗忘引擎配置: {payload.config_id}")
try:
svc = DataConfigService(get_db_conn())
result = svc.update_forget(payload)
return success(data=result, msg="更新成功")
except Exception as e:
api_logger.error(f"Update config forget failed: {str(e)}")
return fail(BizCode.INTERNAL_ERROR, "更新遗忘引擎配置失败", str(e))
@router.get("/read_config_extracted", response_model=ApiResponse) # 通过查询参数读取某条配置(固定路径) 没有意义的话就删除
def read_config_extracted(
config_id: str,
current_user: User = Depends(get_current_user),
) -> dict:
workspace_id = current_user.current_workspace_id
# 检查用户是否已选择工作空间
if workspace_id is None:
api_logger.warning(f"用户 {current_user.username} 尝试读取提取配置但未选择工作空间")
return fail(BizCode.INVALID_PARAMETER, "请先切换到一个工作空间", "current_workspace_id is None")
api_logger.info(f"用户 {current_user.username} 在工作空间 {workspace_id} 请求读取提取配置: {config_id}")
try:
svc = DataConfigService(get_db_conn())
result = svc.get_extracted(ConfigKey(config_id=config_id))
return success(data=result, msg="查询成功")
except Exception as e:
api_logger.error(f"Read config extracted failed: {str(e)}")
return fail(BizCode.INTERNAL_ERROR, "查询配置失败", str(e))
@router.get("/read_config_forget", response_model=ApiResponse) # 通过查询参数读取某条配置(固定路径) 没有意义的话就删除
def read_config_forget(
config_id: str,
current_user: User = Depends(get_current_user),
) -> dict:
workspace_id = current_user.current_workspace_id
# 检查用户是否已选择工作空间
if workspace_id is None:
api_logger.warning(f"用户 {current_user.username} 尝试读取遗忘引擎配置但未选择工作空间")
return fail(BizCode.INVALID_PARAMETER, "请先切换到一个工作空间", "current_workspace_id is None")
api_logger.info(f"用户 {current_user.username} 在工作空间 {workspace_id} 请求读取遗忘引擎配置: {config_id}")
try:
svc = DataConfigService(get_db_conn())
result = svc.get_forget(ConfigKey(config_id=config_id))
return success(data=result, msg="查询成功")
except Exception as e:
api_logger.error(f"Read config forget failed: {str(e)}")
return fail(BizCode.INTERNAL_ERROR, "查询遗忘引擎配置失败", str(e))
@router.get("/read_all_config", response_model=ApiResponse) # 读取所有配置文件列表
def read_all_config(
current_user: User = Depends(get_current_user),
) -> dict:
workspace_id = current_user.current_workspace_id
# 检查用户是否已选择工作空间
if workspace_id is None:
api_logger.warning(f"用户 {current_user.username} 尝试查询配置但未选择工作空间")
return fail(BizCode.INVALID_PARAMETER, "请先切换到一个工作空间", "current_workspace_id is None")
api_logger.info(f"用户 {current_user.username} 在工作空间 {workspace_id} 请求读取所有配置")
try:
svc = DataConfigService(get_db_conn())
# 传递 workspace_id 进行过滤(保持为 UUID 类型)
result = svc.get_all(workspace_id=workspace_id)
return success(data=result, msg="查询成功")
except Exception as e:
api_logger.error(f"Read all config failed: {str(e)}")
return fail(BizCode.INTERNAL_ERROR, "查询所有配置失败", str(e))
@router.post("/pilot_run", response_model=ApiResponse) # 试运行:触发执行主管线,使用 POST 更为合理
async def pilot_run(
payload: ConfigPilotRun,
current_user: User = Depends(get_current_user),
) -> dict:
api_logger.info(f"Pilot run requested: config_id={payload.config_id}, dialogue_text_length={len(payload.dialogue_text)}")
# 先尝试从数据库加载配置
try:
config_loaded = reload_configuration_from_database(str(payload.config_id))
if not config_loaded:
api_logger.error(f"Failed to load configuration for config_id: {payload.config_id}")
return fail(BizCode.INTERNAL_ERROR, "配置加载失败", f"无法加载 config_id={payload.config_id} 的配置")
api_logger.info(f"Configuration loaded successfully for config_id: {payload.config_id}")
except Exception as e:
api_logger.error(f"Exception while loading configuration: {str(e)}")
return fail(BizCode.INTERNAL_ERROR, "配置加载异常", str(e))
try:
svc = DataConfigService(get_db_conn())
result = await svc.pilot_run(payload)
return success(data=result, msg="试运行完成")
except ValueError as e:
# 捕获参数验证错误
api_logger.error(f"Pilot run parameter validation failed: {str(e)}")
return fail(BizCode.INVALID_PARAMETER, "参数验证失败", str(e))
except Exception as e:
api_logger.error(f"Pilot run failed: {str(e)}")
return fail(BizCode.INTERNAL_ERROR, "试运行失败", str(e))
"""
以下为搜索与分析接口,直接挂载到同一 router统一响应为 ApiResponse。
"""
@router.get("/search/kb_type_distribution", response_model=ApiResponse)
async def get_kb_type_distribution(
end_user_id: Optional[str] = None,
current_user: User = Depends(get_current_user),
) -> dict:
api_logger.info(f"KB type distribution requested for end_user_id: {end_user_id}")
try:
result = await kb_type_distribution(end_user_id)
return success(data=result, msg="查询成功")
except Exception as e:
api_logger.error(f"KB type distribution failed: {str(e)}")
return fail(BizCode.INTERNAL_ERROR, "知识库类型分布查询失败", str(e))
@router.get("/search/dialogue", response_model=ApiResponse)
async def search_dialogues_num(
end_user_id: Optional[str] = None,
current_user: User = Depends(get_current_user),
) -> dict:
api_logger.info(f"Search dialogue requested for end_user_id: {end_user_id}")
try:
result = await search_dialogue(end_user_id)
return success(data=result, msg="查询成功")
except Exception as e:
api_logger.error(f"Search dialogue failed: {str(e)}")
return fail(BizCode.INTERNAL_ERROR, "对话查询失败", str(e))
@router.get("/search/chunk", response_model=ApiResponse)
async def search_chunks_num(
end_user_id: Optional[str] = None,
current_user: User = Depends(get_current_user),
) -> dict:
api_logger.info(f"Search chunk requested for end_user_id: {end_user_id}")
try:
result = await search_chunk(end_user_id)
return success(data=result, msg="查询成功")
except Exception as e:
api_logger.error(f"Search chunk failed: {str(e)}")
return fail(BizCode.INTERNAL_ERROR, "分块查询失败", str(e))
@router.get("/search/statement", response_model=ApiResponse)
async def search_statements_num(
end_user_id: Optional[str] = None,
current_user: User = Depends(get_current_user),
) -> dict:
api_logger.info(f"Search statement requested for end_user_id: {end_user_id}")
try:
result = await search_statement(end_user_id)
return success(data=result, msg="查询成功")
except Exception as e:
api_logger.error(f"Search statement failed: {str(e)}")
return fail(BizCode.INTERNAL_ERROR, "语句查询失败", str(e))
@router.get("/search/entity", response_model=ApiResponse)
async def search_entities_num(
end_user_id: Optional[str] = None,
current_user: User = Depends(get_current_user),
) -> dict:
api_logger.info(f"Search entity requested for end_user_id: {end_user_id}")
try:
result = await search_entity(end_user_id)
return success(data=result, msg="查询成功")
except Exception as e:
api_logger.error(f"Search entity failed: {str(e)}")
return fail(BizCode.INTERNAL_ERROR, "实体查询失败", str(e))
@router.get("/search", response_model=ApiResponse)
async def search_all_num(
end_user_id: Optional[str] = None,
current_user: User = Depends(get_current_user),
) -> dict:
api_logger.info(f"Search all requested for end_user_id: {end_user_id}")
try:
result = await search_all(end_user_id)
return success(data=result, msg="查询成功")
except Exception as e:
api_logger.error(f"Search all failed: {str(e)}")
return fail(BizCode.INTERNAL_ERROR, "全部查询失败", str(e))
@router.get("/search/detials", response_model=ApiResponse)
async def search_entities_detials(
end_user_id: Optional[str] = None,
current_user: User = Depends(get_current_user),
) -> dict:
api_logger.info(f"Search details requested for end_user_id: {end_user_id}")
try:
result = await search_detials(end_user_id)
return success(data=result, msg="查询成功")
except Exception as e:
api_logger.error(f"Search details failed: {str(e)}")
return fail(BizCode.INTERNAL_ERROR, "详情查询失败", str(e))
@router.get("/search/edges", response_model=ApiResponse)
async def search_entity_edges(
end_user_id: Optional[str] = None,
current_user: User = Depends(get_current_user),
) -> dict:
api_logger.info(f"Search edges requested for end_user_id: {end_user_id}")
try:
result = await search_edges(end_user_id)
return success(data=result, msg="查询成功")
except Exception as e:
api_logger.error(f"Search edges failed: {str(e)}")
return fail(BizCode.INTERNAL_ERROR, "边查询失败", str(e))
@router.get("/search/entity_graph", response_model=ApiResponse)
async def search_for_entity_graph(
end_user_id: Optional[str] = None,
current_user: User = Depends(get_current_user),
) -> dict:
"""
搜索所有实体之间的关系网络
"""
api_logger.info(f"Search entity graph requested for end_user_id: {end_user_id}")
try:
result = await search_entity_graph(end_user_id)
return success(data=result, msg="查询成功")
except Exception as e:
api_logger.error(f"Search entity graph failed: {str(e)}")
return fail(BizCode.INTERNAL_ERROR, "实体图查询失败", str(e))
@router.get("/analytics/hot_memory_tags", response_model=ApiResponse)
async def get_hot_memory_tags_api(
end_user_id: Optional[str] = None,
limit: int = 10,
current_user: User = Depends(get_current_user),
) -> dict:
api_logger.info(f"Hot memory tags requested for end_user_id: {end_user_id}")
try:
result = await analytics_hot_memory_tags(end_user_id, limit)
return success(data=result, msg="查询成功")
except Exception as e:
api_logger.error(f"Hot memory tags failed: {str(e)}")
return fail(BizCode.INTERNAL_ERROR, "热门标签查询失败", str(e))
@router.get("/analytics/memory_insight/report", response_model=ApiResponse)
async def get_memory_insight_report_api(
end_user_id: Optional[str] = None,
current_user: User = Depends(get_current_user),
) -> dict:
api_logger.info(f"Memory insight report requested for end_user_id: {end_user_id}")
try:
result = await analytics_memory_insight_report(end_user_id)
return success(data=result, msg="查询成功")
except Exception as e:
api_logger.error(f"Memory insight report failed: {str(e)}")
return fail(BizCode.INTERNAL_ERROR, "记忆洞察报告生成失败", str(e))
@router.get("/analytics/recent_activity_stats", response_model=ApiResponse)
async def get_recent_activity_stats_api(
current_user: User = Depends(get_current_user),
) -> dict:
api_logger.info("Recent activity stats requested")
try:
result = await analytics_recent_activity_stats()
return success(data=result, msg="查询成功")
except Exception as e:
api_logger.error(f"Recent activity stats failed: {str(e)}")
return fail(BizCode.INTERNAL_ERROR, "最近活动统计失败", str(e))
@router.get("/analytics/user_summary", response_model=ApiResponse)
async def get_user_summary_api(
end_user_id: Optional[str] = None,
current_user: User = Depends(get_current_user),
) -> dict:
api_logger.info(f"User summary requested for end_user_id: {end_user_id}")
try:
result = await analytics_user_summary(end_user_id)
return success(data=result, msg="查询成功")
except Exception as e:
api_logger.error(f"User summary failed: {str(e)}")
return fail(BizCode.INTERNAL_ERROR, "用户摘要生成失败", str(e))
from app.core.memory.utils.self_reflexion_utils import self_reflexion
@router.get("/self_reflexion")
async def self_reflexion_endpoint(host_id: uuid.UUID) -> str:
"""
自我反思接口,自动对检索出的信息进行自我反思并返回自我反思结果。
Args:
None
Returns:
自我反思结果。
"""
return await self_reflexion(host_id)

View File

@@ -0,0 +1,332 @@
from fastapi import APIRouter, Depends, status, Query
from langchain_core.messages import HumanMessage, SystemMessage
from langchain_core.prompts import ChatPromptTemplate
from sqlalchemy.orm import Session
from typing import List, Optional
import uuid
from app.core.models import RedBearLLM
from app.core.models.base import RedBearModelConfig
from app.db import get_db
from app.dependencies import get_current_user
from app.models.models_model import ModelProvider, ModelType
from app.models.user_model import User
from app.schemas import model_schema
from app.core.response_utils import success
from app.schemas.response_schema import ApiResponse, PageData
from app.services.model_service import ModelConfigService, ModelApiKeyService
from app.core.logging_config import get_api_logger
# 获取API专用日志器
api_logger = get_api_logger()
router = APIRouter(
prefix="/models",
tags=["Models"],
)
@router.get("/type", response_model=ApiResponse)
def get_model_types():
return success(msg="获取模型类型成功", data=list(ModelType))
@router.get("/provider", response_model=ApiResponse)
def get_model_providers():
return success(msg="获取模型提供商成功", data=list(ModelProvider))
@router.get("", response_model=ApiResponse)
def get_model_list(
type: Optional[List[model_schema.ModelType]] = Query(None, description="模型类型筛选(支持多个,如 ?type=LLM&type=EMBEDDING"),
provider: Optional[model_schema.ModelProvider] = Query(None, description="提供商筛选(基于API Key)"),
is_active: Optional[bool] = Query(None, description="激活状态筛选"),
is_public: Optional[bool] = Query(None, description="公开状态筛选"),
search: Optional[str] = Query(None, description="搜索关键词"),
page: int = Query(1, ge=1, description="页码"),
pagesize: int = Query(10, ge=1, le=100, description="每页数量"),
db: Session = Depends(get_db)
):
"""
获取模型配置列表
支持多个 type 参数:
- 单个:?type=LLM
- 多个:?type=LLM&type=EMBEDDING
"""
api_logger.info(f"获取模型配置列表请求: type={type}, provider={provider}, page={page}, pagesize={pagesize}")
try:
query = model_schema.ModelConfigQuery(
type=type,
provider=provider,
is_active=is_active,
is_public=is_public,
search=search,
page=page,
pagesize=pagesize
)
api_logger.debug(f"开始获取模型配置列表: {query.dict()}")
result_orm = ModelConfigService.get_model_list(db=db, query=query)
result = PageData.model_validate(result_orm)
api_logger.info(f"模型配置列表获取成功: 总数={result.page.total}, 当前页={len(result.items)}")
return success(data=result, msg="模型配置列表获取成功")
except Exception as e:
api_logger.error(f"获取模型配置列表失败: {str(e)}")
raise
@router.get("/{model_id}", response_model=ApiResponse)
def get_model_by_id(
model_id: uuid.UUID,
db: Session = Depends(get_db)
):
"""
根据ID获取模型配置
"""
api_logger.info(f"获取模型配置请求: model_id={model_id}")
try:
api_logger.debug(f"开始获取模型配置: model_id={model_id}")
result_orm = ModelConfigService.get_model_by_id(db=db, model_id=model_id)
api_logger.info(f"模型配置获取成功: {result_orm.name}")
# 将ORM对象转换为Pydantic模型
result_pydantic = model_schema.ModelConfig.model_validate(result_orm)
return success(data=result_pydantic, msg="模型配置获取成功")
except Exception as e:
api_logger.error(f"获取模型配置失败: model_id={model_id} - {str(e)}")
raise
@router.post("", response_model=ApiResponse)
async def create_model(
model_data: model_schema.ModelConfigCreate,
db: Session = Depends(get_db),
current_user: User = Depends(get_current_user)
):
"""
创建模型配置
- 创建模型配置基础信息
- 如果包含 API Key会先验证配置有效性然后创建
- 验证失败时会抛出异常,不会创建配置
- 可通过 skip_validation=true 跳过验证
"""
api_logger.info(f"创建模型配置请求: {model_data.name}, 用户: {current_user.username}")
try:
api_logger.debug(f"开始创建模型配置: {model_data.name}")
result_orm = await ModelConfigService.create_model(db=db, model_data=model_data)
api_logger.info(f"模型配置创建成功: {result_orm.name} (ID: {result_orm.id})")
# 将ORM对象转换为Pydantic模型
result = model_schema.ModelConfig.model_validate(result_orm)
return success(data=result, msg="模型配置创建成功")
except Exception as e:
api_logger.error(f"创建模型配置失败: {model_data.name} - {str(e)}")
raise
@router.put("/{model_id}", response_model=ApiResponse)
def update_model(
model_id: uuid.UUID,
model_data: model_schema.ModelConfigUpdate,
db: Session = Depends(get_db),
current_user: User = Depends(get_current_user)
):
"""
更新模型配置
"""
api_logger.info(f"更新模型配置请求: model_id={model_id}, 用户: {current_user.username}")
try:
api_logger.debug(f"开始更新模型配置: model_id={model_id}")
result_orm = ModelConfigService.update_model(db=db, model_id=model_id, model_data=model_data)
api_logger.info(f"模型配置更新成功: {result_orm.name} (ID: {model_id})")
# 将ORM对象转换为Pydantic模型
result_pydantic = model_schema.ModelConfig.model_validate(result_orm)
return success(data=result_pydantic, msg="模型配置更新成功")
except Exception as e:
api_logger.error(f"更新模型配置失败: model_id={model_id} - {str(e)}")
raise
@router.delete("/{model_id}", response_model=ApiResponse)
def delete_model(
model_id: uuid.UUID,
db: Session = Depends(get_db),
current_user: User = Depends(get_current_user)
):
"""
删除模型配置
"""
api_logger.info(f"删除模型配置请求: model_id={model_id}, 用户: {current_user.username}")
try:
api_logger.debug(f"开始删除模型配置: model_id={model_id}")
ModelConfigService.delete_model(db=db, model_id=model_id)
api_logger.info(f"模型配置删除成功: model_id={model_id}")
return success(msg="模型配置删除成功")
except Exception as e:
api_logger.error(f"删除模型配置失败: model_id={model_id} - {str(e)}")
raise
# API Key 相关接口
@router.get("/{model_id}/apikeys", response_model=ApiResponse)
def get_model_api_keys(
model_id: uuid.UUID,
is_active: bool = Query(True, description="是否只获取活跃的API Key"),
db: Session = Depends(get_db),
current_user: User = Depends(get_current_user)
):
"""
获取模型的API Key列表
"""
api_logger.info(f"获取模型API Key列表请求: model_id={model_id}, 用户: {current_user.username}")
try:
api_logger.debug(f"开始获取模型API Key列表: model_id={model_id}")
result_orm = ModelApiKeyService.get_api_keys_by_model(
db=db, model_config_id=model_id, is_active=is_active
)
# 将ORM对象列表转换为Pydantic模型列表
result_pydantic = [model_schema.ModelApiKey.model_validate(item) for item in result_orm]
api_logger.info(f"模型API Key列表获取成功: 数量={len(result_pydantic)}")
return success(data=result_pydantic, msg="模型API Key列表获取成功")
except Exception as e:
api_logger.error(f"获取模型API Key列表失败: model_id={model_id} - {str(e)}")
raise
@router.post("/{model_id}/apikeys", response_model=ApiResponse, status_code=status.HTTP_201_CREATED)
async def create_model_api_key(
model_id: uuid.UUID,
api_key_data: model_schema.ModelApiKeyCreate,
db: Session = Depends(get_db),
current_user: User = Depends(get_current_user)
):
"""
为模型创建API Key
"""
api_logger.info(f"创建模型API Key请求: model_id={model_id}, model_name={api_key_data.model_name}, 用户: {current_user.username}")
try:
# 设置模型配置ID
api_key_data.model_config_id = model_id
api_logger.debug(f"开始创建模型API Key: {api_key_data.model_name}")
result = await ModelApiKeyService.create_api_key(db=db, api_key_data=api_key_data)
api_logger.info(f"模型API Key创建成功: {result.model_name} (ID: {result.id})")
return success(data=result, msg="模型API Key创建成功")
except Exception as e:
api_logger.error(f"创建模型API Key失败: {api_key_data.model_name} - {str(e)}")
raise
@router.get("/apikeys/{api_key_id}", response_model=ApiResponse)
def get_api_key_by_id(
api_key_id: uuid.UUID,
db: Session = Depends(get_db),
current_user: User = Depends(get_current_user)
):
"""
根据ID获取API Key
"""
api_logger.info(f"获取API Key请求: api_key_id={api_key_id}, 用户: {current_user.username}")
try:
api_logger.debug(f"开始获取API Key: api_key_id={api_key_id}")
result = ModelApiKeyService.get_api_key_by_id(db=db, api_key_id=api_key_id)
api_logger.info(f"API Key获取成功: {result.model_name}")
return success(data=result, msg="API Key获取成功")
except Exception as e:
api_logger.error(f"获取API Key失败: api_key_id={api_key_id} - {str(e)}")
raise
@router.put("/apikeys/{api_key_id}", response_model=ApiResponse)
async def update_api_key(
api_key_id: uuid.UUID,
api_key_data: model_schema.ModelApiKeyUpdate,
db: Session = Depends(get_db),
current_user: User = Depends(get_current_user)
):
"""
更新API Key
"""
api_logger.info(f"更新API Key请求: api_key_id={api_key_id}, 用户: {current_user.username}")
try:
api_logger.debug(f"开始更新API Key: api_key_id={api_key_id}")
result = await ModelApiKeyService.update_api_key(db=db, api_key_id=api_key_id, api_key_data=api_key_data)
api_logger.info(f"API Key更新成功: {result.model_name} (ID: {api_key_id})")
result_pydantic = model_schema.ModelApiKey.model_validate(result)
return success(data=result_pydantic, msg="API Key更新成功")
except Exception as e:
api_logger.error(f"更新API Key失败: api_key_id={api_key_id} - {str(e)}")
raise
@router.delete("/apikeys/{api_key_id}", response_model=ApiResponse)
def delete_api_key(
api_key_id: uuid.UUID,
db: Session = Depends(get_db),
current_user: User = Depends(get_current_user)
):
"""
删除API Key
"""
api_logger.info(f"删除API Key请求: api_key_id={api_key_id}, 用户: {current_user.username}")
try:
api_logger.debug(f"开始删除API Key: api_key_id={api_key_id}")
ModelApiKeyService.delete_api_key(db=db, api_key_id=api_key_id)
api_logger.info(f"API Key删除成功: api_key_id={api_key_id}")
return success(msg="API Key删除成功")
except Exception as e:
api_logger.error(f"删除API Key失败: api_key_id={api_key_id} - {str(e)}")
raise
@router.post("/validate", response_model=ApiResponse)
async def validate_model_config(
validate_data: model_schema.ModelValidateRequest,
db: Session = Depends(get_db),
current_user: User = Depends(get_current_user)
):
"""
验证模型配置是否有效
支持验证不同类型的模型:
- llm: 大语言模型
- chat: 对话模型
- embedding: 向量模型
- rerank: 重排序模型
"""
api_logger.info(f"验证模型配置请求: {validate_data.model_name} ({validate_data.model_type}), 用户: {current_user.username}")
result = await ModelConfigService.validate_model_config(
db=db,
model_name=validate_data.model_name,
provider=validate_data.provider,
api_key=validate_data.api_key,
api_base=validate_data.api_base,
model_type=validate_data.model_type,
test_message=validate_data.test_message
)
return success(data=model_schema.ModelValidateResponse(**result), msg="验证完成")

View File

@@ -0,0 +1,404 @@
"""多 Agent 控制器"""
import uuid
from fastapi import APIRouter, Depends, Query, Path
from sqlalchemy.orm import Session
from app.db import get_db
from app.dependencies import get_current_user
from app.core.response_utils import success
from app.core.logging_config import get_business_logger
from app.schemas import multi_agent_schema
from app.schemas.response_schema import PageData, PageMeta
from app.services.multi_agent_service import MultiAgentService
from app.models import User
router = APIRouter(prefix="/apps", tags=["Multi-Agent"])
logger = get_business_logger()
# ==================== 多 Agent 配置管理 ====================
@router.post(
"/{app_id}/multi-agent",
summary="创建多 Agent 配置"
)
def create_multi_agent_config(
app_id: uuid.UUID = Path(..., description="应用 ID"),
data: multi_agent_schema.MultiAgentConfigCreate = ...,
current_user: User = Depends(get_current_user),
db: Session = Depends(get_db),
):
"""创建多 Agent 配置
支持四种编排模式:
- sequential: 顺序执行
- parallel: 并行执行
- conditional: 条件路由
- loop: 循环执行
"""
service = MultiAgentService(db)
config = service.create_config(
app_id=app_id,
data=data,
created_by=current_user.id
)
return success(
data=multi_agent_schema.MultiAgentConfigSchema.model_validate(config),
msg="多 Agent 配置创建成功"
)
@router.get(
"/{app_id}/multi-agent",
summary="获取当前应用的最新有效多 Agent 配置"
)
def get_multi_agent_configs(
app_id: uuid.UUID = Path(..., description="应用 ID"),
current_user: User = Depends(get_current_user),
db: Session = Depends(get_db),
):
"""获取指定应用的最新有效多 Agent 配置,如果不存在则返回默认模板"""
service = MultiAgentService(db)
# 通过 app_id 获取最新有效配置(已转换 agent_id 为 app_id
config = service.get_multi_agent_configs(app_id)
if not config:
# 返回默认模板
default_template = {
"app_id": str(app_id),
"master_agent_id": None,
"master_agent_name": None,
"orchestration_mode": "conditional",
"sub_agents": [],
"routing_rules": [],
"execution_config": {
"max_iterations": 10,
"timeout": 300,
"enable_parallel": False,
"error_handling": "stop"
},
"aggregation_strategy": "merge",
}
return success(
data=default_template,
msg="该应用暂无配置,返回默认模板"
)
# config 已经是字典格式,直接返回
return success(data=config)
@router.put(
"/{app_id}/multi-agent",
summary="更新多 Agent 配置"
)
def update_multi_agent_config(
app_id: uuid.UUID = Path(..., description="应用 ID"),
data: multi_agent_schema.MultiAgentConfigUpdate = ...,
current_user: User = Depends(get_current_user),
db: Session = Depends(get_db),
):
"""更新多 Agent 配置"""
service = MultiAgentService(db)
config = service.update_config(app_id, data)
return success(
data=multi_agent_schema.MultiAgentConfigSchema.model_validate(config),
msg="多 Agent 配置更新成功"
)
@router.delete(
"/{app_id}/multi-agent",
summary="删除多 Agent 配置"
)
def delete_multi_agent_config(
app_id: uuid.UUID = Path(..., description="应用 ID"),
current_user: User = Depends(get_current_user),
db: Session = Depends(get_db),
):
"""删除多 Agent 配置"""
service = MultiAgentService(db)
service.delete_config(app_id)
return success(msg="多 Agent 配置删除成功")
# ==================== 多 Agent 运行 ====================
@router.post(
"/{app_id}/multi-agent/run",
summary="运行多 Agent 任务"
)
async def run_multi_agent(
app_id: uuid.UUID = Path(..., description="应用 ID"),
request: multi_agent_schema.MultiAgentRunRequest = ...,
current_user: User = Depends(get_current_user),
db: Session = Depends(get_db),
):
"""运行多 Agent 任务
根据配置的编排模式执行多个 Agent
- sequential: 按优先级顺序执行
- parallel: 并行执行所有 Agent
- conditional: 根据条件选择 Agent
- loop: 循环执行直到满足条件
"""
service = MultiAgentService(db)
result = await service.run(app_id, request)
return success(
data=multi_agent_schema.MultiAgentRunResponse(**result),
msg="多 Agent 任务执行成功"
)
# ==================== 智能路由测试 ====================
@router.post(
"/{app_id}/multi-agent/test-routing",
summary="测试智能路由"
)
async def test_routing(
app_id: uuid.UUID = Path(..., description="应用 ID"),
request: multi_agent_schema.RoutingTestRequest = ...,
current_user: User = Depends(get_current_user),
db: Session = Depends(get_db),
):
"""测试智能路由功能
支持三种路由模式:
- keyword: 仅使用关键词路由
- llm: 使用 LLM 路由(需要提供 routing_model_id
- hybrid: 混合路由(关键词 + LLM
参数:
- message: 测试消息
- conversation_id: 会话 ID可选
- routing_model_id: 路由模型 ID可选用于 LLM 路由)
- use_llm: 是否启用 LLM默认 False
- keyword_threshold: 关键词置信度阈值(默认 0.8
"""
from app.services.conversation_state_manager import ConversationStateManager
from app.services.llm_router import LLMRouter
from app.models import ModelConfig
# 1. 获取多 Agent 配置
service = MultiAgentService(db)
config = service.get_config(app_id)
if not config:
return success(
data=None,
msg="应用未配置多 Agent无法测试路由"
)
# 2. 准备子 Agent 信息
sub_agents = {}
for sub_agent_info in config.sub_agents:
agent_id = sub_agent_info["agent_id"]
sub_agents[agent_id] = {
"name": sub_agent_info.get("name", agent_id),
"role": sub_agent_info.get("role", "")
}
# 3. 获取路由模型(如果指定)
routing_model = None
if request.routing_model_id:
routing_model = db.get(ModelConfig, request.routing_model_id)
if not routing_model:
return success(
data=None,
msg=f"路由模型不存在: {request.routing_model_id}"
)
# 4. 初始化路由器
state_manager = ConversationStateManager()
router = LLMRouter(
db=db,
state_manager=state_manager,
routing_rules=config.routing_rules or [],
sub_agents=sub_agents,
routing_model_config=routing_model,
use_llm=request.use_llm and routing_model is not None
)
# 5. 设置阈值
if request.keyword_threshold:
router.keyword_high_confidence_threshold = request.keyword_threshold
# 6. 执行路由
try:
routing_result = await router.route(
message=request.message,
conversation_id=str(request.conversation_id) if request.conversation_id else None,
force_new=request.force_new
)
# 7. 获取 Agent 信息
agent_id = routing_result["agent_id"]
agent_info = sub_agents.get(agent_id, {})
# 8. 构建响应
response_data = {
"message": request.message,
"routing_result": {
"agent_id": agent_id,
"agent_name": agent_info.get("name", agent_id),
"agent_role": agent_info.get("role", ""),
"confidence": routing_result["confidence"],
"strategy": routing_result["strategy"],
"topic": routing_result["topic"],
"topic_changed": routing_result["topic_changed"],
"reason": routing_result["reason"],
"routing_method": routing_result["routing_method"]
},
"cmulti-agent/batch-test-routingonfig_info": {
"use_llm": request.use_llm and routing_model is not None,
"routing_model": routing_model.name if routing_model else None,
"keyword_threshold": router.keyword_high_confidence_threshold,
"total_sub_agents": len(sub_agents)
}
}
return success(
data=response_data,
msg="路由测试成功"
)
except Exception as e:
logger.error(f"路由测试失败: {str(e)}")
return success(
data=None,
msg=f"路由测试失败: {str(e)}"
)
@router.post(
"/{app_id}/",
summary="批量测试智能路由"
)
async def batch_test_routing(
app_id: uuid.UUID = Path(..., description="应用 ID"),
request: multi_agent_schema.BatchRoutingTestRequest = ...,
current_user: User = Depends(get_current_user),
db: Session = Depends(get_db),
):
"""批量测试智能路由功能
用于测试多条消息的路由效果,并统计准确率
参数:
- test_cases: 测试用例列表
- routing_model_id: 路由模型 ID可选
- use_llm: 是否启用 LLM
- keyword_threshold: 关键词置信度阈值
"""
from app.services.conversation_state_manager import ConversationStateManager
from app.services.llm_router import LLMRouter
from app.models import ModelConfig
# 1. 获取多 Agent 配置
service = MultiAgentService(db)
config = service.get_config(app_id)
if not config:
return success(
data=None,
msg="应用未配置多 Agent无法测试路由"
)
# 2. 准备子 Agent 信息
sub_agents = {}
for sub_agent_info in config.sub_agents:
agent_id = sub_agent_info["agent_id"]
sub_agents[agent_id] = {
"name": sub_agent_info.get("name", agent_id),
"role": sub_agent_info.get("role", "")
}
# 3. 获取路由模型
routing_model = None
if request.routing_model_id:
routing_model = db.get(ModelConfig, request.routing_model_id)
# 4. 初始化路由器
state_manager = ConversationStateManager()
router = LLMRouter(
db=db,
state_manager=state_manager,
routing_rules=config.routing_rules or [],
sub_agents=sub_agents,
routing_model_config=routing_model,
use_llm=request.use_llm and routing_model is not None
)
if request.keyword_threshold:
router.keyword_high_confidence_threshold = request.keyword_threshold
# 5. 批量测试
results = []
correct_count = 0
total_count = len(request.test_cases)
for test_case in request.test_cases:
try:
routing_result = await router.route(
message=test_case.message,
conversation_id=str(uuid.uuid4()) # 每个测试用例使用独立会话
)
agent_id = routing_result["agent_id"]
agent_info = sub_agents.get(agent_id, {})
# 判断是否正确
is_correct = None
if test_case.expected_agent_id:
is_correct = (agent_id == str(test_case.expected_agent_id))
if is_correct:
correct_count += 1
results.append({
"message": test_case.message,
"description": test_case.description,
"routed_agent_id": agent_id,
"routed_agent_name": agent_info.get("name"),
"expected_agent_id": str(test_case.expected_agent_id) if test_case.expected_agent_id else None,
"is_correct": is_correct,
"confidence": routing_result["confidence"],
"routing_method": routing_result["routing_method"],
"strategy": routing_result["strategy"]
})
except Exception as e:
logger.error(f"测试用例失败: {test_case.message}, 错误: {str(e)}")
results.append({
"message": test_case.message,
"description": test_case.description,
"error": str(e)
})
# 6. 统计
accuracy = None
if correct_count > 0:
total_with_expected = sum(1 for r in results if r.get("expected_agent_id"))
if total_with_expected > 0:
accuracy = correct_count / total_with_expected * 100
response_data = {
"total_count": total_count,
"correct_count": correct_count,
"accuracy": accuracy,
"results": results,
"config_info": {
"use_llm": request.use_llm and routing_model is not None,
"routing_model": routing_model.name if routing_model else None,
"keyword_threshold": router.keyword_high_confidence_threshold
}
}
return success(
data=response_data,
msg=f"批量测试完成,准确率: {accuracy:.1f}%" if accuracy else "批量测试完成"
)

View File

@@ -0,0 +1,437 @@
from fastapi import APIRouter, Depends, Query, Request, Header
from fastapi.responses import StreamingResponse
from sqlalchemy.orm import Session
import uuid
import hashlib
import time
import jwt
from typing import Optional, Dict
from functools import wraps
from app.db import get_db
from app.core.response_utils import success
from app.core.logging_config import get_business_logger
from app.core.exceptions import BusinessException
from app.core.error_codes import BizCode
from app.core.config import settings
from app.schemas import release_share_schema, conversation_schema
from app.schemas.response_schema import PageData, PageMeta
from app.services.release_share_service import ReleaseShareService
from app.services.shared_chat_service import SharedChatService
from app.services.conversation_service import ConversationService
from app.services.auth_service import create_access_token
from app.dependencies import get_share_user_id, ShareTokenData
router = APIRouter(prefix="/public/share", tags=["Public Share"])
logger = get_business_logger()
def get_base_url(request: Request) -> str:
"""从请求中获取基础 URL"""
return f"{request.url.scheme}://{request.url.netloc}"
def get_or_generate_user_id(payload_user_id: str, request: Request) -> str:
"""获取或生成用户 ID
优先级:
1. 使用前端传递的 user_id
2. 基于 IP + User-Agent 生成唯一 ID
Args:
payload_user_id: 前端传递的 user_id
request: FastAPI Request 对象
Returns:
用户 ID
"""
if payload_user_id:
return payload_user_id
# 获取客户端 IP
client_ip = request.client.host if request.client else "unknown"
# 获取 User-Agent
user_agent = request.headers.get("user-agent", "unknown")
# 生成唯一 ID基于 IP + User-Agent 的哈希
unique_string = f"{client_ip}_{user_agent}"
hash_value = hashlib.md5(unique_string.encode()).hexdigest()[:16]
return f"guest_{hash_value}"
@router.post(
"/{share_token}/token",
summary="获取访问 token"
)
def get_access_token(
share_token: str,
payload: release_share_schema.TokenRequest,
request: Request,
db: Session = Depends(get_db),
):
"""获取访问 token
- 用户通过 user_id + share_token 换取访问 token
- 后续请求需要携带此 token
"""
# 获取或生成 user_id
user_id = get_or_generate_user_id(payload.user_id, request)
# 验证分享链接(可选:验证密码)
service = ReleaseShareService(db)
try:
service.get_shared_release_info(
share_token=share_token,
password=payload.password
)
except Exception as e:
logger.error(f"获取分享信息失败: {str(e)}")
raise
# 生成 token
access_token = create_access_token(user_id, share_token)
logger.info(
f"生成访问 token",
extra={
"share_token": share_token,
"user_id": user_id
}
)
return success(data={
"access_token": access_token,
"token_type": "Bearer",
"user_id": user_id
})
@router.get(
"",
summary="获取公开分享的应用信息",
response_model=None
)
def get_shared_release(
password: str = Query(None, description="访问密码(如果需要)"),
share_data: ShareTokenData = Depends(get_share_user_id),
db: Session = Depends(get_db),
):
"""获取公开分享的发布版本信息
- 无需认证即可访问
- 如果设置了密码保护,需要提供正确的密码
- 如果密码错误或未提供密码,返回基本信息(不含配置详情)
"""
service = ReleaseShareService(db)
info = service.get_shared_release_info(
share_token=share_data.share_token,
password=password
)
return success(data=info)
@router.post(
"/verify",
summary="验证访问密码"
)
def verify_password(
payload: release_share_schema.PasswordVerifyRequest,
share_data: ShareTokenData = Depends(get_share_user_id),
db: Session = Depends(get_db),
):
"""验证分享的访问密码
- 用于前端先验证密码,再获取完整信息
"""
service = ReleaseShareService(db)
is_valid = service.verify_password(
share_token=share_data.share_token,
password=payload.password
)
return success(data={"valid": is_valid})
@router.get(
"/embed",
summary="获取嵌入代码"
)
def get_embed_code(
width: str = Query("100%", description="iframe 宽度"),
height: str = Query("600px", description="iframe 高度"),
request: Request = None,
share_data: ShareTokenData = Depends(get_share_user_id),
db: Session = Depends(get_db),
):
"""获取嵌入代码
- 返回 iframe 嵌入代码
- 可以自定义宽度和高度
"""
base_url = get_base_url(request) if request else None
service = ReleaseShareService(db)
embed_code = service.get_embed_code(
share_token=share_data.share_token,
width=width,
height=height,
base_url=base_url
)
return success(data=embed_code)
# ---------- 会话管理接口 ----------
@router.get(
"/conversations",
summary="获取会话列表"
)
def list_conversations(
password: str = Query(None, description="访问密码"),
page: int = Query(1, ge=1),
pagesize: int = Query(20, ge=1, le=100),
share_data: ShareTokenData = Depends(get_share_user_id),
db: Session = Depends(get_db),
):
"""获取分享应用的会话列表
- 可以按 user_id 筛选
- 支持分页
"""
logger.debug(f"share_data:{share_data.user_id}")
other_id = share_data.user_id
service = SharedChatService(db)
share, release = service._get_release_by_share_token(share_data.share_token, password)
from app.repositories.end_user_repository import EndUserRepository
end_user_repo = EndUserRepository(db)
new_end_user = end_user_repo.get_or_create_end_user(
app_id=share.app_id,
other_id=other_id
)
logger.debug(new_end_user.id)
service = SharedChatService(db)
conversations, total = service.list_conversations(
share_token=share_data.share_token,
user_id=str(new_end_user.id),
password=password,
page=page,
pagesize=pagesize
)
items = [conversation_schema.Conversation.model_validate(c) for c in conversations]
meta = PageMeta(page=page, pagesize=pagesize, total=total, hasnext=(page * pagesize) < total)
return success(data=PageData(page=meta, items=items))
@router.get(
"/conversations/{conversation_id}",
summary="获取会话详情(含消息)"
)
def get_conversation(
conversation_id: uuid.UUID,
password: str = Query(None, description="访问密码"),
share_data: ShareTokenData = Depends(get_share_user_id),
db: Session = Depends(get_db),
):
"""获取会话详情和消息历史"""
chat_service = SharedChatService(db)
conversation = chat_service.get_conversation_messages(
share_token=share_data.share_token,
conversation_id=conversation_id,
password=password
)
# 获取消息
conv_service = ConversationService(db)
messages = conv_service.get_messages(conversation_id)
# 构建响应
conv_dict = conversation_schema.Conversation.model_validate(conversation).model_dump()
conv_dict["messages"] = [
conversation_schema.Message.model_validate(m) for m in messages
]
return success(data=conv_dict)
# ---------- 聊天接口 ----------
@router.post(
"/chat",
summary="发送消息(支持流式和非流式)"
)
async def chat(
payload: conversation_schema.ChatRequest,
share_data: ShareTokenData = Depends(get_share_user_id),
db: Session = Depends(get_db),
):
"""发送消息并获取回复
使用 Bearer token 认证:
- Header: Authorization: Bearer {token}
- user_id 和 share_token 从 token 中解码
- 支持多轮对话(提供 conversation_id
- 支持流式返回(设置 stream=true
- 如果不提供 conversation_id会自动创建新会话
"""
service = SharedChatService(db)
# 从依赖中获取 user_id 和 share_token
user_id = share_data.user_id
share_token = share_data.share_token
password = None # Token 认证不需要密码
# end_user_id = user_id
other_id = user_id
# 提前验证和准备(在流式响应开始前完成)
# 这样可以确保错误能正确返回,而不是在流式响应中间出错
from app.models.app_model import AppType
try:
from app.core.exceptions import BusinessException
from app.core.error_codes import BizCode
from app.services.app_service import AppService
# 验证分享链接和密码
share, release = service._get_release_by_share_token(share_token, password)
# # Create end_user_id by concatenating app_id with user_id
# end_user_id = f"{share.app_id}_{user_id}"
# Store end_user_id in database with original user_id
from app.repositories.end_user_repository import EndUserRepository
end_user_repo = EndUserRepository(db)
new_end_user = end_user_repo.get_or_create_end_user(
app_id=share.app_id,
other_id=other_id,
original_user_id=user_id # Save original user_id to other_id
)
# 获取应用类型
app_type = release.app.type if release.app else None
# 根据应用类型验证配置
if app_type == "agent":
# Agent 类型:验证模型配置
model_config_id = release.default_model_config_id
if not model_config_id:
raise BusinessException("Agent 应用未配置模型", BizCode.AGENT_CONFIG_MISSING)
elif app_type == "multi_agent":
# Multi-Agent 类型:验证多 Agent 配置
config = release.config or {}
if not config.get("sub_agents"):
raise BusinessException("多 Agent 应用未配置子 Agent", BizCode.AGENT_CONFIG_MISSING)
else:
raise BusinessException(f"不支持的应用类型: {app_type}", BizCode.APP_TYPE_NOT_SUPPORTED)
# 获取或创建会话(提前验证)
conversation = service.create_or_get_conversation(
share_token=share_data.share_token,
conversation_id=payload.conversation_id,
user_id=str(new_end_user.id), # 转换为字符串
password=password
)
logger.debug(
f"参数验证完成",
extra={
"share_token": share_token,
"app_type": app_type,
"conversation_id": str(conversation.id),
"stream": payload.stream
}
)
except Exception as e:
# 验证失败,直接抛出异常(会被 FastAPI 的异常处理器捕获)
logger.error(f"参数验证失败: {str(e)}")
raise
if app_type == AppType.AGENT:
# 流式返回
if payload.stream:
async def event_generator():
async for event in service.chat_stream(
share_token=share_token,
message=payload.message,
conversation_id=conversation.id, # 使用已创建的会话 ID
user_id=str(new_end_user.id), # 转换为字符串
variables=payload.variables,
password=password,
web_search=payload.web_search,
memory=payload.memory
):
yield event
return StreamingResponse(
event_generator(),
media_type="text/event-stream",
headers={
"Cache-Control": "no-cache",
"Connection": "keep-alive",
"X-Accel-Buffering": "no"
}
)
# 非流式返回
result = await service.chat(
share_token=share_token,
message=payload.message,
conversation_id=conversation.id, # 使用已创建的会话 ID
user_id=str(new_end_user.id), # 转换为字符串
variables=payload.variables,
password=password,
web_search=payload.web_search,
memory=payload.memory
)
return success(data=conversation_schema.ChatResponse(**result))
elif app_type == AppType.MULTI_AGENT:
# 多 Agent 流式返回
if payload.stream:
async def event_generator():
async for event in service.multi_agent_chat_stream(
share_token=share_token,
message=payload.message,
conversation_id=conversation.id, # 使用已创建的会话 ID
user_id=str(new_end_user.id), # 转换为字符串
variables=payload.variables,
password=password,
web_search=payload.web_search,
memory=payload.memory
):
yield event
return StreamingResponse(
event_generator(),
media_type="text/event-stream",
headers={
"Cache-Control": "no-cache",
"Connection": "keep-alive",
"X-Accel-Buffering": "no"
}
)
# 多 Agent 非流式返回
result = await service.multi_agent_chat(
share_token=share_token,
message=payload.message,
conversation_id=conversation.id, # 使用已创建的会话 ID
user_id=str(new_end_user.id), # 转换为字符串
variables=payload.variables,
password=password,
web_search=payload.web_search,
memory=payload.memory
)
return success(data=conversation_schema.ChatResponse(**result))
else:
from app.core.exceptions import BusinessException
from app.core.error_codes import BizCode
raise BusinessException(f"不支持的应用类型: {app_type}", BizCode.APP_TYPE_NOT_SUPPORTED)
pass

View File

@@ -0,0 +1,170 @@
import uuid
from fastapi import APIRouter, Depends, Request
from sqlalchemy.orm import Session
from app.db import get_db
from app.core.response_utils import success
from app.core.logging_config import get_business_logger
from app.schemas import release_share_schema
from app.services.release_share_service import ReleaseShareService
from app.dependencies import get_current_user, cur_workspace_access_guard
router = APIRouter(tags=["Release Share"])
logger = get_business_logger()
def get_base_url(request: Request) -> str:
"""从请求中获取基础 URL"""
return f"{request.url.scheme}://{request.url.netloc}"
@router.post(
"/apps/{app_id}/releases/{release_id}/share",
summary="创建/启用分享配置"
)
@cur_workspace_access_guard()
def create_share(
app_id: uuid.UUID,
release_id: uuid.UUID,
payload: release_share_schema.ReleaseShareCreate,
request: Request,
db: Session = Depends(get_db),
current_user=Depends(get_current_user),
):
"""创建或更新发布版本的分享配置
- 如果已存在分享配置,则更新
- 自动生成唯一的分享 token
- 返回完整的分享 URL
"""
workspace_id = current_user.current_workspace_id
base_url = get_base_url(request)
service = ReleaseShareService(db)
share = service.create_or_update_share(
release_id=release_id,
user_id=current_user.id,
workspace_id=workspace_id,
data=payload,
base_url=base_url
)
share_schema = service._convert_to_schema(share, base_url)
return success(data=share_schema, msg="分享配置已创建")
@router.put(
"/apps/{app_id}/releases/{release_id}/share",
summary="更新分享配置"
)
@cur_workspace_access_guard()
def update_share(
app_id: uuid.UUID,
release_id: uuid.UUID,
payload: release_share_schema.ReleaseShareUpdate,
request: Request,
db: Session = Depends(get_db),
current_user=Depends(get_current_user),
):
"""更新分享配置
- 可以更新启用状态、密码、嵌入设置等
- 不会改变 share_token
"""
workspace_id = current_user.current_workspace_id
base_url = get_base_url(request)
service = ReleaseShareService(db)
share = service.update_share(
release_id=release_id,
workspace_id=workspace_id,
data=payload
)
share_schema = service._convert_to_schema(share, base_url)
return success(data=share_schema, msg="分享配置已更新")
@router.get(
"/apps/{app_id}/releases/{release_id}/share",
summary="获取分享配置"
)
@cur_workspace_access_guard()
def get_share(
app_id: uuid.UUID,
release_id: uuid.UUID,
request: Request,
db: Session = Depends(get_db),
current_user=Depends(get_current_user),
):
"""获取发布版本的分享配置
- 如果不存在分享配置,返回 null
"""
workspace_id = current_user.current_workspace_id
base_url = get_base_url(request)
service = ReleaseShareService(db)
share = service.get_share(
release_id=release_id,
workspace_id=workspace_id,
base_url=base_url
)
return success(data=share)
@router.delete(
"/apps/{app_id}/releases/{release_id}/share",
summary="删除分享配置"
)
@cur_workspace_access_guard()
def delete_share(
app_id: uuid.UUID,
release_id: uuid.UUID,
db: Session = Depends(get_db),
current_user=Depends(get_current_user),
):
"""删除分享配置
- 删除后,公开访问链接将失效
"""
workspace_id = current_user.current_workspace_id
service = ReleaseShareService(db)
service.delete_share(
release_id=release_id,
workspace_id=workspace_id
)
return success(msg="分享配置已删除")
@router.post(
"/apps/{app_id}/releases/{release_id}/share/regenerate-token",
summary="重新生成分享链接"
)
@cur_workspace_access_guard()
def regenerate_token(
app_id: uuid.UUID,
release_id: uuid.UUID,
request: Request,
db: Session = Depends(get_db),
current_user=Depends(get_current_user),
):
"""重新生成分享 token
- 旧的分享链接将失效
- 生成新的唯一 token
"""
workspace_id = current_user.current_workspace_id
base_url = get_base_url(request)
service = ReleaseShareService(db)
share = service.regenerate_token(
release_id=release_id,
workspace_id=workspace_id
)
share_schema = service._convert_to_schema(share, base_url)
return success(data=share_schema, msg="分享链接已重新生成")

View File

@@ -0,0 +1,17 @@
"""Service API Controllers - 基于 API Key 认证的服务接口
路由前缀: /v1
认证方式: API Key
"""
from fastapi import APIRouter
from . import app_api_controller, rag_api_controller, memory_api_controller
# 创建 V1 API 路由器
service_router = APIRouter()
# 注册子路由
service_router.include_router(app_api_controller.router)
service_router.include_router(rag_api_controller.router)
service_router.include_router(memory_api_controller.router)
__all__ = ["service_router"]

View File

@@ -0,0 +1,16 @@
"""App 服务接口 - 基于 API Key 认证"""
from fastapi import APIRouter, Depends
from sqlalchemy.orm import Session
from app.db import get_db
from app.core.response_utils import success
from app.core.logging_config import get_business_logger
router = APIRouter(prefix="/v1/apps", tags=["V1 - App API"])
logger = get_business_logger()
@router.get("")
async def list_apps():
"""列出可访问的应用(占位)"""
return success(data=[], msg="App API - Coming Soon")

View File

@@ -0,0 +1,16 @@
"""Memory 服务接口 - 基于 API Key 认证"""
from fastapi import APIRouter, Depends
from sqlalchemy.orm import Session
from app.db import get_db
from app.core.response_utils import success
from app.core.logging_config import get_business_logger
router = APIRouter(prefix="/memory", tags=["V1 - Memory API"])
logger = get_business_logger()
@router.get("")
async def get_memory_info():
"""获取记忆服务信息(占位)"""
return success(data={}, msg="Memory API - Coming Soon")

View File

@@ -0,0 +1,16 @@
"""RAG 服务接口 - 基于 API Key 认证"""
from fastapi import APIRouter, Depends
from sqlalchemy.orm import Session
from app.db import get_db
from app.core.response_utils import success
from app.core.logging_config import get_business_logger
router = APIRouter(prefix="/knowledge", tags=["V1 - RAG API"])
logger = get_business_logger()
@router.get("")
async def list_knowledge():
"""列出可访问的知识库(占位)"""
return success(data=[], msg="RAG API - Coming Soon")

View File

@@ -0,0 +1,23 @@
from fastapi import APIRouter, Depends
from sqlalchemy.orm import Session
from app.core.response_utils import success
from app.db import get_db
from app.schemas.response_schema import ApiResponse
from app.services import user_service
router = APIRouter(
prefix="/setup",
tags=["Setup"],
)
@router.post("", summary="Create the first superuser", response_model=ApiResponse)
def setup_initial_user(db: Session = Depends(get_db)):
"""
Create the initial superuser. This can only be run once.
Reads credentials from environment variables.
"""
user = user_service.create_initial_superuser(db)
if not user:
return success(msg="Superuser already exists.")
return success(msg="Superuser created successfully.")

View File

@@ -0,0 +1,25 @@
from fastapi import APIRouter, status
from app.schemas.item_schema import Item
from app.services import task_service
router = APIRouter(
prefix="/tasks",
tags=["Tasks"],
)
@router.post("/process_item", status_code=status.HTTP_202_ACCEPTED)
def process_item_task(item: Item):
"""
This endpoint receives an item, and instead of processing it directly,
it sends a task to the Celery queue via the task service.
"""
task_id = task_service.create_processing_task(item.dict())
return {"message": "Task accepted. The item is being processed in the background.", "task_id": task_id}
@router.get("/result/{task_id}")
def get_task_result_controller(task_id: str):
"""
This endpoint allows clients to check the status and result of a
previously submitted task using its ID, by calling the task service.
"""
return task_service.get_task_result(task_id)

View File

@@ -0,0 +1,126 @@
from fastapi import APIRouter, Depends, status, Query, HTTPException
from langchain_core.messages import HumanMessage, SystemMessage
from langchain_core.prompts import ChatPromptTemplate
from sqlalchemy.orm import Session
from typing import List, Optional
import uuid
from app.core.models import RedBearLLM, RedBearRerank
from app.core.models.base import RedBearModelConfig
from app.core.models.embedding import RedBearEmbeddings
from app.db import get_db
from app.dependencies import get_current_user
from app.models.models_model import ModelApiKey, ModelProvider, ModelType
from app.models.user_model import User
from app.schemas import model_schema
from app.core.response_utils import success
from app.schemas.response_schema import ApiResponse, PageData
from app.services.model_service import ModelConfigService, ModelApiKeyService
from app.core.logging_config import get_api_logger
# 获取API专用日志器
api_logger = get_api_logger()
router = APIRouter(
prefix="/test",
tags=["test"],
)
@router.get(f"/llm/{{model_id}}", response_model=ApiResponse)
def test_llm(
model_id: uuid.UUID,
db: Session = Depends(get_db)
):
config = ModelConfigService.get_model_by_id(db=db, model_id=model_id)
if not config:
api_logger.error(f"模型ID {model_id} 不存在")
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="模型ID不存在")
try:
apiConfig: ModelApiKey = config.api_keys[0]
llm = RedBearLLM(RedBearModelConfig(
model_name=apiConfig.model_name,
provider=apiConfig.provider,
api_key=apiConfig.api_key,
base_url=apiConfig.api_base
), type=config.type)
print(llm.dict())
template = """Question: {question}
Answer: Let's think step by step."""
# ChatPromptTemplate
prompt = ChatPromptTemplate.from_template(template)
chain = prompt | llm
answer = chain.invoke({"question": "What is LangChain?"})
print("Answer:", answer)
return success(msg="测试LLM成功", data={"question": "What is LangChain?", "answer": answer})
except Exception as e:
api_logger.error(f"测试LLM失败: {str(e)}")
raise
@router.get(f"/embedding/{{model_id}}", response_model=ApiResponse)
def test_embedding(
model_id: uuid.UUID,
db: Session = Depends(get_db)
):
config = ModelConfigService.get_model_by_id(db=db, model_id=model_id)
if not config:
api_logger.error(f"模型ID {model_id} 不存在")
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="模型ID不存在")
apiConfig: ModelApiKey = config.api_keys[0]
model = RedBearEmbeddings(RedBearModelConfig(
model_name=apiConfig.model_name,
provider=apiConfig.provider,
api_key=apiConfig.api_key,
base_url=apiConfig.api_base
))
data = [
"最近哪家咖啡店评价最好?",
"附近有没有推荐的咖啡厅?",
"明天天气预报说会下雨。",
"北京是中国的首都。",
"我想找一个适合学习的地方。"
]
embeddings = model.embed_documents(data)
print(embeddings)
query = "我想找一个适合学习的地方。"
query_embedding = model.embed_query(query)
print(query_embedding)
return success(msg="测试LLM成功")
@router.get(f"/rerank/{{model_id}}", response_model=ApiResponse)
def test_rerank(
model_id: uuid.UUID,
db: Session = Depends(get_db)
):
config = ModelConfigService.get_model_by_id(db=db, model_id=model_id)
if not config:
api_logger.error(f"模型ID {model_id} 不存在")
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="模型ID不存在")
apiConfig: ModelApiKey = config.api_keys[0]
model = RedBearRerank(RedBearModelConfig(
model_name=apiConfig.model_name,
provider=apiConfig.provider,
api_key=apiConfig.api_key,
base_url=apiConfig.api_base
))
query = "最近哪家咖啡店评价最好?"
data = [
"最近哪家咖啡店评价最好?",
"附近有没有推荐的咖啡厅?",
"明天天气预报说会下雨。",
"北京是中国的首都。",
"我想找一个适合学习的地方。"
]
scores = model.rerank(query=query, documents=data, top_n=3)
print(scores)
return success(msg="测试Rerank成功", data={"query": query, "documents": data, "scores": scores})

View File

@@ -0,0 +1,376 @@
"""
Upload Controller for Generic File Upload System
Handles HTTP requests for file upload, download, deletion, and metadata updates.
"""
import os
import json
from typing import List, Optional, Any
from pathlib import Path
from fastapi import APIRouter, Depends, File, UploadFile, Form
from fastapi.responses import FileResponse
from sqlalchemy.orm import Session
from app.db import get_db
from app.dependencies import get_current_user
from app.models.user_model import User
from app.schemas.response_schema import ApiResponse
from app.schemas.generic_file_schema import (
GenericFileResponse,
FileMetadataUpdate,
UploadResultSchema,
BatchUploadResponse
)
from app.core.response_utils import success, fail
from app.core.upload_enums import UploadContext
from app.services.upload_service import UploadService
from app.core.logging_config import get_logger
from app.core.exceptions import (
ValidationException,
ResourceNotFoundException,
FileUploadException,
BusinessException
)
# Get logger
logger = get_logger(__name__)
# Create router
router = APIRouter(
prefix="/api",
tags=["upload"],
dependencies=[Depends(get_current_user)]
)
# Initialize upload service
upload_service = UploadService()
@router.post("/upload", response_model=ApiResponse)
async def upload_file(
file: UploadFile = File(..., description="要上传的文件"),
context: str = Form(..., description="上传上下文 (avatar, app_icon, knowledge_base, temp, attachment)"),
metadata: Optional[str] = Form(None, description="文件元数据 (JSON格式)"),
db: Session = Depends(get_db),
current_user: User = Depends(get_current_user)
) -> ApiResponse:
"""
单文件上传接口
- **file**: 要上传的文件
- **context**: 上传上下文,决定文件存储位置和验证规则
- **metadata**: 可选的文件元数据JSON格式字符串
返回上传成功的文件信息
"""
logger.info(f"Upload request: filename={file.filename}, context={context}, user={current_user.id}")
try:
# Validate and parse context
try:
upload_context = UploadContext(context)
except ValueError:
logger.warning(f"Invalid upload context: {context}")
raise ValidationException(
f"无效的上传上下文: {context}. 允许的值: {', '.join([c.value for c in UploadContext])}",
field="context"
)
# Parse metadata if provided
file_metadata = {}
if metadata:
try:
file_metadata = json.loads(metadata)
except json.JSONDecodeError:
logger.warning(f"Invalid metadata JSON: {metadata}")
raise ValidationException(
"元数据必须是有效的JSON格式",
field="metadata"
)
# Upload file
db_file = upload_service.upload_file(
file=file,
context=upload_context,
metadata=file_metadata,
current_user=current_user,
db=db
)
# Convert to response schema
file_response = GenericFileResponse.model_validate(db_file)
logger.info(f"Upload successful: {file.filename} (ID: {db_file.id})")
return success(data=file_response.dict(), msg="文件上传成功")
except BusinessException:
# Business exceptions are handled by global exception handlers
raise
except Exception as e:
logger.error(f"Upload failed: {str(e)}")
# Wrap unknown exceptions as FileUploadException
raise FileUploadException(
f"文件上传失败: {str(e)}",
cause=e
)
@router.post("/upload/batch", response_model=ApiResponse)
async def upload_files_batch(
files: List[UploadFile] = File(..., description="要上传的文件列表"),
context: str = Form(..., description="上传上下文 (avatar, app_icon, knowledge_base, temp, attachment)"),
metadata: Optional[str] = Form(None, description="文件元数据 (JSON格式)"),
db: Session = Depends(get_db),
current_user: User = Depends(get_current_user)
) -> ApiResponse:
"""
批量文件上传接口
- **files**: 要上传的文件列表最多20个
- **context**: 上传上下文,决定文件存储位置和验证规则
- **metadata**: 可选的文件元数据JSON格式字符串应用于所有文件
返回每个文件的上传结果
"""
logger.info(f"Batch upload request: {len(files)} files, context={context}, user={current_user.id}")
try:
# Validate and parse context
try:
upload_context = UploadContext(context)
except ValueError:
logger.warning(f"Invalid upload context: {context}")
raise ValidationException(
f"无效的上传上下文: {context}. 允许的值: {', '.join([c.value for c in UploadContext])}",
field="context"
)
# Parse metadata if provided
file_metadata = {}
if metadata:
try:
file_metadata = json.loads(metadata)
except json.JSONDecodeError:
logger.warning(f"Invalid metadata JSON: {metadata}")
raise ValidationException(
"元数据必须是有效的JSON格式",
field="metadata"
)
# Upload files in batch
upload_results = upload_service.upload_files_batch(
files=files,
context=upload_context,
metadata=file_metadata,
current_user=current_user,
db=db
)
# Convert results to response schemas
result_schemas = []
for result in upload_results:
result_schema = UploadResultSchema(
success=result.success,
file_id=result.file_id,
file_name=result.file_name,
error=result.error,
file_info=None
)
# If upload was successful, get file info
if result.success and result.file_id:
try:
db_file = upload_service.get_file(result.file_id, current_user, db)
result_schema.file_info = GenericFileResponse.model_validate(db_file)
except Exception as e:
logger.warning(f"Failed to get file info for {result.file_id}: {str(e)}")
result_schemas.append(result_schema)
# Create batch response
batch_response = BatchUploadResponse(
total=len(files),
success_count=sum(1 for r in upload_results if r.success),
failed_count=sum(1 for r in upload_results if not r.success),
results=result_schemas
)
logger.info(f"Batch upload completed: {batch_response.success_count}/{batch_response.total} successful")
return success(data=batch_response.dict(), msg="批量上传完成")
except BusinessException:
# Business exceptions are handled by global exception handlers
raise
except Exception as e:
logger.error(f"Batch upload failed: {str(e)}")
# Wrap unknown exceptions as FileUploadException
raise FileUploadException(
f"批量上传失败: {str(e)}",
cause=e
)
@router.get("/files/{file_id}", response_model=Any)
async def download_file(
file_id: str,
db: Session = Depends(get_db),
current_user: User = Depends(get_current_user)
) -> Any:
"""
文件下载接口
- **file_id**: 文件ID
返回文件内容供下载
"""
logger.info(f"Download request: file_id={file_id}, user={current_user.id}")
try:
# Parse file_id
import uuid
try:
file_uuid = uuid.UUID(file_id)
except ValueError:
logger.warning(f"Invalid file ID format: {file_id}")
raise ValidationException(
"无效的文件ID格式",
field="file_id"
)
# Get file from database
db_file = upload_service.get_file(file_uuid, current_user, db)
# Check if physical file exists
storage_path = Path(db_file.storage_path)
if not storage_path.exists():
logger.error(f"Physical file not found: {storage_path}")
raise ResourceNotFoundException(
"文件",
str(file_uuid),
context={"detail": "文件未找到(可能已被删除)"}
)
# Return file response
logger.info(f"Download successful: {db_file.file_name} (ID: {file_id})")
return FileResponse(
path=str(storage_path),
filename=db_file.file_name,
media_type=db_file.mime_type or "application/octet-stream"
)
except BusinessException:
# Business exceptions are handled by global exception handlers
raise
except Exception as e:
logger.error(f"Download failed: {str(e)}")
# Wrap unknown exceptions
raise FileUploadException(
f"文件下载失败: {str(e)}",
cause=e
)
@router.delete("/files/{file_id}", response_model=ApiResponse)
async def delete_file(
file_id: str,
db: Session = Depends(get_db),
current_user: User = Depends(get_current_user)
) -> ApiResponse:
"""
文件删除接口
- **file_id**: 文件ID
删除文件(包括物理文件和数据库记录)
"""
logger.info(f"Delete request: file_id={file_id}, user={current_user.id}")
try:
# Parse file_id
import uuid
try:
file_uuid = uuid.UUID(file_id)
except ValueError:
logger.warning(f"Invalid file ID format: {file_id}")
raise ValidationException(
"无效的文件ID格式",
field="file_id"
)
# Delete file
upload_service.delete_file(file_uuid, current_user, db)
logger.info(f"Delete successful: file_id={file_id}")
return success(msg="文件删除成功")
except BusinessException:
# Business exceptions are handled by global exception handlers
raise
except Exception as e:
logger.error(f"Delete failed: {str(e)}")
# Wrap unknown exceptions
raise FileUploadException(
f"文件删除失败: {str(e)}",
cause=e
)
@router.put("/files/{file_id}", response_model=ApiResponse)
async def update_file_metadata(
file_id: str,
update_data: FileMetadataUpdate,
db: Session = Depends(get_db),
current_user: User = Depends(get_current_user)
) -> ApiResponse:
"""
文件元数据更新接口
- **file_id**: 文件ID
- **update_data**: 要更新的元数据
更新文件的元数据(文件名、自定义元数据、公开状态)
"""
logger.info(f"Update metadata request: file_id={file_id}, user={current_user.id}")
try:
# Parse file_id
import uuid
try:
file_uuid = uuid.UUID(file_id)
except ValueError:
logger.warning(f"Invalid file ID format: {file_id}")
raise ValidationException(
"无效的文件ID格式",
field="file_id"
)
# Convert update data to dict, excluding unset fields
update_dict = update_data.dict(exclude_unset=True)
if not update_dict:
logger.warning(f"No fields to update for file: {file_id}")
raise ValidationException(
"没有提供要更新的字段",
field="update_data"
)
# Update file metadata
updated_file = upload_service.update_file_metadata(
file_uuid, update_dict, current_user, db
)
# Convert to response schema
file_response = GenericFileResponse.model_validate(updated_file)
logger.info(f"Update metadata successful: file_id={file_id}")
return success(data=file_response.dict(), msg="文件元数据更新成功")
except BusinessException:
# Business exceptions are handled by global exception handlers
raise
except Exception as e:
logger.error(f"Update metadata failed: {str(e)}")
# Wrap unknown exceptions
raise FileUploadException(
f"文件元数据更新失败: {str(e)}",
cause=e
)

View File

@@ -0,0 +1,183 @@
from fastapi import APIRouter, Depends, status
from sqlalchemy.orm import Session
import uuid
from app.db import get_db
from app.dependencies import get_current_user, get_current_superuser
from app.models.user_model import User
from app.schemas import user_schema
from app.schemas.user_schema import ChangePasswordRequest, AdminChangePasswordRequest
from app.schemas.response_schema import ApiResponse
from app.services import user_service
from app.core.logging_config import get_api_logger
from app.core.response_utils import success
# 获取API专用日志器
api_logger = get_api_logger()
router = APIRouter(
prefix="/users",
tags=["Users"],
)
@router.post("/superuser", response_model=ApiResponse)
def create_superuser(
user: user_schema.UserCreate,
db: Session = Depends(get_db),
current_superuser: User = Depends(get_current_superuser)
):
"""创建超级管理员(仅超级管理员可访问)"""
api_logger.info(f"超级管理员创建请求: {user.username}, email: {user.email}")
result = user_service.create_superuser(db=db, user=user, current_user=current_superuser)
api_logger.info(f"超级管理员创建成功: {result.username} (ID: {result.id})")
result_schema = user_schema.User.model_validate(result)
return success(data=result_schema, msg="超级管理员创建成功")
@router.delete("/{user_id}", response_model=ApiResponse)
def delete_user(
user_id: uuid.UUID,
db: Session = Depends(get_db),
current_user: User = Depends(get_current_user),
):
"""停用用户(软删除)"""
api_logger.info(f"用户停用请求: user_id={user_id}, 操作者: {current_user.username}")
result = user_service.deactivate_user(
db=db, user_id_to_deactivate=user_id, current_user=current_user
)
api_logger.info(f"用户停用成功: {result.username} (ID: {result.id})")
return success(msg="用户停用成功")
@router.post("/{user_id}/activate", response_model=ApiResponse)
def activate_user(
user_id: uuid.UUID,
db: Session = Depends(get_db),
current_user: User = Depends(get_current_user),
):
"""激活用户"""
api_logger.info(f"用户激活请求: user_id={user_id}, 操作者: {current_user.username}")
result = user_service.activate_user(
db=db, user_id_to_activate=user_id, current_user=current_user
)
api_logger.info(f"用户激活成功: {result.username} (ID: {result.id})")
result_schema = user_schema.User.model_validate(result)
return success(data=result_schema, msg="用户激活成功")
@router.get("", response_model=ApiResponse)
def get_current_user_info(
db: Session = Depends(get_db),
current_user: User = Depends(get_current_user),
):
"""获取当前用户信息"""
api_logger.info(f"当前用户信息请求: {current_user.username}")
result = user_service.get_user(
db=db, user_id=current_user.id, current_user=current_user
)
result_schema = user_schema.User.model_validate(result)
# 设置当前工作空间的角色和名称
if current_user.current_workspace_id:
from app.repositories.workspace_repository import WorkspaceRepository
workspace_repo = WorkspaceRepository(db)
current_workspace = workspace_repo.get_workspace_by_id(current_user.current_workspace_id)
if current_workspace:
result_schema.current_workspace_name = current_workspace.name
for ws in result.workspaces:
if ws.workspace_id == current_user.current_workspace_id:
result_schema.role = ws.role
break
api_logger.info(f"当前用户信息获取成功: {result.username}, 角色: {result_schema.role}, 工作空间: {result_schema.current_workspace_name}")
return success(data=result_schema, msg="用户信息获取成功")
@router.get("/superusers", response_model=ApiResponse)
def get_tenant_superusers(
include_inactive: bool = False,
db: Session = Depends(get_db),
current_user: User = Depends(get_current_superuser),
):
"""获取当前租户下的超管账号列表(仅超级管理员可访问)"""
api_logger.info(f"获取租户超管列表请求: {current_user.username}")
superusers = user_service.get_tenant_superusers(
db=db,
current_user=current_user,
include_inactive=include_inactive
)
api_logger.info(f"租户超管列表获取成功: count={len(superusers)}")
superusers_schema = [user_schema.User.model_validate(u) for u in superusers]
return success(data=superusers_schema, msg="租户超管列表获取成功")
@router.get("/{user_id}", response_model=ApiResponse)
def get_user_info_by_id(
user_id: uuid.UUID,
db: Session = Depends(get_db),
current_user: User = Depends(get_current_user),
):
"""根据用户ID获取用户信息"""
api_logger.info(f"获取用户信息请求: user_id={user_id}, 操作者: {current_user.username}")
result = user_service.get_user(
db=db, user_id=user_id, current_user=current_user
)
api_logger.info(f"用户信息获取成功: {result.username}")
result_schema = user_schema.User.model_validate(result)
return success(data=result_schema, msg="用户信息获取成功")
@router.put("/change-password", response_model=ApiResponse)
async def change_password(
request: ChangePasswordRequest,
db: Session = Depends(get_db),
current_user: User = Depends(get_current_user),
):
"""修改当前用户密码"""
api_logger.info(f"用户密码修改请求: {current_user.username}")
await user_service.change_password(
db=db,
user_id=current_user.id,
old_password=request.old_password,
new_password=request.new_password,
current_user=current_user
)
api_logger.info(f"用户密码修改成功: {current_user.username}")
return success(msg="密码修改成功")
@router.put("/admin/change-password", response_model=ApiResponse)
async def admin_change_password(
request: AdminChangePasswordRequest,
db: Session = Depends(get_db),
current_user: User = Depends(get_current_superuser),
):
"""超级管理员修改指定用户的密码"""
api_logger.info(f"管理员密码修改请求: 管理员 {current_user.username} 修改用户 {request.user_id}")
user, generated_password = await user_service.admin_change_password(
db=db,
target_user_id=request.user_id,
new_password=request.new_password,
current_user=current_user
)
# 根据是否生成了随机密码来构造响应
if request.new_password:
api_logger.info(f"管理员密码修改成功: 用户 {request.user_id}")
return success(msg="密码修改成功")
else:
api_logger.info(f"管理员密码重置成功: 用户 {request.user_id}, 随机密码已生成")
return success(data=generated_password, msg="密码重置成功")

View File

@@ -0,0 +1,342 @@
from fastapi import APIRouter, Depends, HTTPException, status, Query
from sqlalchemy.orm import Session
from typing import List, Optional
import uuid
from app.core.response_utils import success
from app.db import get_db
from app.dependencies import get_current_superuser, get_current_user, get_current_tenant, workspace_access_guard, cur_workspace_access_guard
from app.models.user_model import User
from app.models.tenant_model import Tenants
from app.models.workspace_model import Workspace, InviteStatus
from app.schemas.response_schema import ApiResponse
from app.schemas.workspace_schema import (
WorkspaceCreate, WorkspaceUpdate, WorkspaceResponse,
WorkspaceInviteCreate, WorkspaceInviteResponse,
InviteValidateResponse, InviteAcceptRequest,
WorkspaceMemberUpdate, WorkspaceMemberItem
)
from app.schemas import knowledge_schema
from app.services import workspace_service
from app.core.logging_config import get_api_logger
from app.services import knowledge_service, document_service
# 获取API专用日志器
api_logger = get_api_logger()
# 需要认证的路由器
router = APIRouter(
prefix="/workspaces",
tags=["Workspaces"],
dependencies=[Depends(get_current_user)]
)
# 公开路由器(不需要认证)
public_router = APIRouter(
prefix="/workspaces",
tags=["Workspaces"]
)
def _convert_members_to_table_items(members):
"""将工作空间成员列表转换为表格项"""
return [
WorkspaceMemberItem(
id=m.id,
username=m.user.username,
account=m.user.email,
role=m.role,
last_login_at=m.user.last_login_at
)
for m in members
]
@router.get("", response_model=ApiResponse)
def get_workspaces(
include_current: bool = Query(True, description="是否包含当前工作空间"),
db: Session = Depends(get_db),
current_user: User = Depends(get_current_user),
current_tenant: Tenants = Depends(get_current_tenant)
):
"""获取当前租户下用户参与的所有工作空间
Args:
include_current: 是否包含当前工作空间(默认 True
"""
api_logger.info(
f"用户 {current_user.username} 在租户 {current_tenant.name} 中请求获取工作空间列表",
extra={"include_current": include_current}
)
workspaces = workspace_service.get_user_workspaces(db, current_user)
# 如果不包含当前工作空间,则过滤掉
if not include_current and current_user.current_workspace_id:
workspaces = [w for w in workspaces if w.id != current_user.current_workspace_id]
api_logger.debug(
f"过滤掉当前工作空间",
extra={"current_workspace_id": str(current_user.current_workspace_id)}
)
api_logger.info(f"成功获取 {len(workspaces)} 个工作空间")
workspaces_schema = [WorkspaceResponse.model_validate(w) for w in workspaces]
return success(data=workspaces_schema, msg="工作空间列表获取成功")
@router.post("", response_model=ApiResponse)
def create_workspace(
workspace: WorkspaceCreate,
db: Session = Depends(get_db),
current_user: User = Depends(get_current_superuser),
):
"""创建新的工作空间"""
api_logger.info(f"用户 {current_user.username} 请求创建工作空间: {workspace.name}")
result = workspace_service.create_workspace(
db=db, workspace=workspace, user=current_user)
api_logger.info(f"工作空间创建成功 - 名称: {workspace.name}, ID: {result.id}, 创建者: {current_user.username}")
result_schema = WorkspaceResponse.model_validate(result)
return success(data=result_schema, msg="工作空间创建成功")
@router.put("", response_model=ApiResponse)
@cur_workspace_access_guard()
def update_workspace(
workspace: WorkspaceUpdate,
db: Session = Depends(get_db),
current_user: User = Depends(get_current_user),
):
"""更新工作空间"""
workspace_id = current_user.current_workspace_id
api_logger.info(f"用户 {current_user.username} 请求更新工作空间 ID: {workspace_id}")
result = workspace_service.update_workspace(
db=db,
workspace_id=workspace_id,
workspace_in=workspace,
user=current_user,
)
api_logger.info(f"工作空间更新成功 - ID: {workspace_id}, 用户: {current_user.username}")
result_schema = WorkspaceResponse.model_validate(result)
return success(data=result_schema, msg="工作空间更新成功")
@router.get("/members", response_model=ApiResponse)
@cur_workspace_access_guard()
def get_cur_workspace_members(
db: Session = Depends(get_db),
current_user: User = Depends(get_current_user),
):
"""获取工作空间成员列表(关系序列化)"""
api_logger.info(f"用户 {current_user.username} 请求获取工作空间 {current_user.current_workspace_id} 的成员列表")
members = workspace_service.get_workspace_members(
db=db,
workspace_id=current_user.current_workspace_id,
user=current_user,
)
api_logger.info(f"工作空间成员列表获取成功 - ID: {current_user.current_workspace_id}, 数量: {len(members)}")
table_items = _convert_members_to_table_items(members)
return success(data=table_items, msg="工作空间成员列表获取成功")
@router.put("/members", response_model=ApiResponse)
@cur_workspace_access_guard()
def update_workspace_members(
updates: List[WorkspaceMemberUpdate],
db: Session = Depends(get_db),
current_user: User = Depends(get_current_user),
):
workspace_id = current_user.current_workspace_id
api_logger.info(f"用户 {current_user.username} 请求更新工作空间 {workspace_id} 的成员角色")
members = workspace_service.update_workspace_member_roles(
db=db,
workspace_id=workspace_id,
updates=updates,
user=current_user,
)
api_logger.info(f"工作空间成员角色更新成功 - ID: {workspace_id}, 数量: {len(members)}")
return success(msg="成员角色更新成功")
@router.delete("/members/{member_id}", response_model=ApiResponse)
@cur_workspace_access_guard()
def delete_workspace_member(
member_id: uuid.UUID,
db: Session = Depends(get_db),
current_user: User = Depends(get_current_user),
):
workspace_id = current_user.current_workspace_id
api_logger.info(f"用户 {current_user.username} 请求删除工作空间 {workspace_id} 的成员 {member_id}")
workspace_service.delete_workspace_member(
db=db,
workspace_id=workspace_id,
member_id=member_id,
user=current_user,
)
api_logger.info(f"工作空间成员删除成功 - ID: {workspace_id}, 成员: {member_id}")
return success(msg="成员删除成功")
# 创建空间协作邀请
@router.post("/invites", response_model=ApiResponse)
@cur_workspace_access_guard()
def create_workspace_invite(
invite_data: WorkspaceInviteCreate,
db: Session = Depends(get_db),
current_user: User = Depends(get_current_user),
):
"""创建工作空间邀请"""
workspace_id = current_user.current_workspace_id
api_logger.info(f"用户 {current_user.username} 请求为工作空间 {workspace_id} 创建邀请: {invite_data.email}")
result = workspace_service.create_workspace_invite(
db=db,
workspace_id=workspace_id,
invite_data=invite_data,
user=current_user
)
api_logger.info(f"工作空间邀请创建成功 - 工作空间: {workspace_id}, 邮箱: {invite_data.email}")
return success(data=result, msg="邀请创建成功")
@router.get("/invites", response_model=ApiResponse)
@cur_workspace_access_guard()
def get_workspace_invites(
status_filter: Optional[InviteStatus] = Query(None, alias="status"),
limit: int = Query(50, ge=1, le=100),
offset: int = Query(0, ge=0),
db: Session = Depends(get_db),
current_user: User = Depends(get_current_user),
):
"""获取工作空间邀请列表"""
workspace_id = current_user.current_workspace_id
api_logger.info(f"用户 {current_user.username} 请求获取工作空间 {workspace_id} 的邀请列表")
invites = workspace_service.get_workspace_invites(
db=db,
workspace_id=workspace_id,
user=current_user,
status=status_filter,
limit=limit,
offset=offset
)
api_logger.info(f"成功获取 {len(invites)} 个邀请记录")
return success(data=invites, msg="邀请列表获取成功")
@public_router.get("/invites/validate/{token}", response_model=ApiResponse)
def get_workspace_invite_info(
token: str,
db: Session = Depends(get_db),
):
"""获取工作空间邀请用户信息(无需认证)"""
result = workspace_service.validate_invite_token(db=db, token=token)
api_logger.info(f"工作空间邀请验证成功 - 邀请: {token}")
return success(data=result, msg="邀请验证成功")
@router.delete("/invites/{invite_id}", response_model=ApiResponse)
@cur_workspace_access_guard()
def revoke_workspace_invite(
invite_id: uuid.UUID,
db: Session = Depends(get_db),
current_user: User = Depends(get_current_user),
):
"""撤销工作空间邀请"""
workspace_id = current_user.current_workspace_id
api_logger.info(f"用户 {current_user.username} 请求撤销工作空间 {workspace_id} 的邀请 {invite_id}")
result = workspace_service.revoke_workspace_invite(
db=db,
workspace_id=workspace_id,
invite_id=invite_id,
user=current_user
)
api_logger.info(f"工作空间邀请撤销成功 - 邀请: {invite_id}")
return success(data=result, msg="邀请撤销成功")
# ==================== 公开邀请接口(无需认证) ====================
# # 创建一个新的路由器用于公开接口
# public_router = APIRouter(
# prefix="/invites",
# tags=["Public Invites"]
# )
# @public_router.get("/validate", response_model=ApiResponse)
# def validate_invite_token(
# token: str = Query(..., description="邀请令牌"),
# db: Session = Depends(get_db),
# ):
# """验证邀请令牌(公开接口)"""
# api_logger.info(f"验证邀请令牌请求")
@router.put("/{workspace_id}/switch", response_model=ApiResponse)
@workspace_access_guard()
def switch_workspace(
workspace_id: uuid.UUID,
db: Session = Depends(get_db),
current_user: User = Depends(get_current_user),
):
"""切换工作空间"""
api_logger.info(f"用户 {current_user.username} 请求切换工作空间为 {workspace_id}")
workspace_service.switch_workspace(
db=db,
workspace_id=workspace_id,
user=current_user,
)
api_logger.info(f"成功切换工作空间为 {workspace_id}")
return success(msg="工作空间切换成功")
@router.get("/storage", response_model=ApiResponse)
@cur_workspace_access_guard()
def get_workspace_storage_type(
db: Session = Depends(get_db),
current_user: User = Depends(get_current_user),
):
"""获取当前工作空间的存储类型"""
workspace_id = current_user.current_workspace_id
api_logger.info(f"用户 {current_user.username} 请求获取工作空间 {workspace_id} 的存储类型")
storage_type = workspace_service.get_workspace_storage_type(
db=db,
workspace_id=workspace_id,
user=current_user
)
api_logger.info(f"成功获取工作空间 {workspace_id} 的存储类型: {storage_type}")
return success(data={"storage_type": storage_type}, msg="存储类型获取成功")
@router.get("/workspace_models", response_model=ApiResponse)
@cur_workspace_access_guard()
def workspace_models_configs(
db: Session = Depends(get_db),
current_user: User = Depends(get_current_user),
):
"""获取当前工作空间的模型配置llm, embedding, rerank"""
workspace_id = current_user.current_workspace_id
api_logger.info(f"用户 {current_user.username} 请求获取工作空间 {workspace_id} 的模型配置")
configs = workspace_service.get_workspace_models_configs(
db=db,
workspace_id=workspace_id,
user=current_user
)
if configs is None:
api_logger.warning(f"工作空间 {workspace_id} 不存在或无权访问")
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail="工作空间不存在或无权访问"
)
api_logger.info(
f"成功获取工作空间 {workspace_id} 的模型配置: "
f"llm={configs.get('llm')}, embedding={configs.get('embedding')}, rerank={configs.get('rerank')}"
)
return success(data=configs, msg="模型配置获取成功")

View File

View File

@@ -0,0 +1,35 @@
from pydantic import BaseModel
from app.core.agent.agent_chat import Agent_chat
from app.core.logging_config import get_business_logger
from fastapi import APIRouter, Depends, HTTPException
from app.dependencies import workspace_access_guard
from app.services.agent_server import config,ChatRequest
router = APIRouter(prefix="/Test", tags=["Apps"])
logger = get_business_logger()
class CombinedRequest(BaseModel):
config_base: config
agent_config: ChatRequest
@router.post("", summary="uuid")
async def agent_chat(
config_base: CombinedRequest
):
chat_config=config_base.agent_config
chat_base=config_base.config_base
request = ChatRequest(
end_user_id=chat_config.end_user_id,
message=chat_config.message,
search_switch=chat_config.search_switch,
kb_ids=chat_config.kb_ids,
similarity_threshold=chat_config.similarity_threshold,
vector_similarity_weight=chat_config.vector_similarity_weight,
top_k=chat_config.top_k,
hybrid=chat_config.hybrid,
token=chat_config.token
)
chat_result=await Agent_chat(chat_base).chat(request)
return chat_result

View File

@@ -0,0 +1,109 @@
import asyncio
import os
import time
from typing import Dict, Any, List
from app.core.logging_config import get_business_logger
from app.schemas.prompt_schema import render_prompt_message, PromptMessageRole
from app.services.api_resquests_server import messages_type, write_messages
from app.services.agent_server import ChatRequest, tool_memory, create_dynamic_agent, tool_Retrieval
logger = get_business_logger()
class Agent_chat:
def __init__(self,config_data: dict):
self.prompt_message = render_prompt_message(
config_data.template_str,
PromptMessageRole.USER,
config_data.params
)
self.prompt = self.prompt_message.get_text_content()
self.model_configs = config_data.model_configs
self.history_memory = config_data.history_memory
self.knowledge_base = config_data.knowledge_base
logger.info(f"渲染结果:{self.prompt_message.get_text_content()}" )
async def run_agent(self,agent, end_user_id:str, user_prompt:str, model_name:str):
response = agent.invoke(
{
"messages": [
{
"role": "user",
"content": user_prompt
}
]
},
{"configurable": {"thread_id": f'{model_name}_{end_user_id}'}},
)
outputs = []
for msg in response["messages"]:
if hasattr(msg, "tool_calls") and msg.tool_calls:
outputs.append({
"role": "assistant",
"tool_calls": [
{"name": t["name"], "arguments": t["args"]}
for t in msg.tool_calls
]
})
elif hasattr(msg, "content") and msg.content:
outputs.append({
"role": msg.__class__.__name__.lower().replace("message", ""),
"content": msg.content
})
ai_messages=[msg['content'] for msg in outputs if msg["role"] == "ai"]
return {"model_name": model_name, "end_user_id": end_user_id, "response": ai_messages}
async def chat(self,req: ChatRequest) -> Dict[str, Any]:
end_user_id = req.end_user_id # 用 user_id 作为对话线程标识
start=time.time()
user_prompt = req.message
'''判断是都写入redis数据库'''
messags_type = await messages_type(req.message,end_user_id)
messags_type=messags_type['data']
if messags_type=='question':
writer_result=await write_messages(f'{end_user_id}', req.message)
logger.info(f'判断类型写入耗时:{time.time() - start},{writer_result}')
'''history_memory'''
if self.history_memory==True:
tool_result =await tool_memory(req)
if tool_result!='' :tool_result=tool_result['data']
if tool_result!='' :self.prompt=self.prompt+f''',历史消息:{tool_result},结合历史消息'''
logger.info(f"记忆科学消耗时间:{time.time()-start},工具调用结果:{tool_result}")
'''baidu'''
'''knowledge_base'''
if self.knowledge_base == True:
retrieval_result=await tool_Retrieval(req)
retrieval_knowledge = [i['page_content'] for i in retrieval_result['data']]
retrieval_knowledge=','.join(retrieval_knowledge)
logger.info(f"检索消耗时间:{time.time()-start},{retrieval_knowledge}")
if retrieval_knowledge!='' :self.prompt=self.prompt+f",知识库检索内容:{retrieval_knowledge},结合检索结果"
self.prompt=self.prompt+f'给出最合适的答案,确保答案的完整性,只保留用户的问题的回答,不额外输出提示语'
logger.info(f"用户输入:{user_prompt}")
logger.info(f"系统prompt{self.prompt}")
AGENTS = {
cfg["name"]: await create_dynamic_agent(cfg["name"], cfg["moder_id"], self.prompt, req.token)
for cfg in self.model_configs
}
tasks=[
self.run_agent(agent, end_user_id, user_prompt, model_name)
for model_name, agent in AGENTS.items()
]
# 并行运行
results = await asyncio.gather(*tasks)
result=[]
for i in results:
result.append(i)
chat_result=(f"最终耗时:{time.time()-start},{result}")
return chat_result

View File

@@ -0,0 +1,347 @@
"""
LangChain Agent 封装
使用 LangChain 1.x 标准方式
- 使用 create_agent 创建 agent graph
- 支持工具调用循环
- 支持流式输出
- 使用 RedBearLLM 支持多提供商
"""
import os
import time
import asyncio
from typing import Dict, Any, List, Optional, AsyncGenerator, Sequence
from langchain_core.messages import SystemMessage, HumanMessage, AIMessage, BaseMessage
from langchain_core.tools import BaseTool
from langchain.agents import create_agent
from app.core.models import RedBearLLM, RedBearModelConfig
from app.models.models_model import ModelType
from app.core.logging_config import get_business_logger
from app.services.memory_agent_service import MemoryAgentService
from app.services.memory_konwledges_server import write_rag
from app.services.task_service import get_task_memory_write_result
from app.tasks import write_message_task
logger = get_business_logger()
class LangChainAgent:
def __init__(
self,
model_name: str,
api_key: str,
provider: str = "openai",
api_base: Optional[str] = None,
temperature: float = 0.7,
max_tokens: int = 2000,
system_prompt: Optional[str] = None,
tools: Optional[Sequence[BaseTool]] = None,
streaming: bool = False
):
"""初始化 LangChain Agent
Args:
model_name: 模型名称
api_key: API Key
provider: 提供商openai, xinference, gpustack, ollama, dashscope
api_base: API 基础 URL
temperature: 温度参数
max_tokens: 最大 token 数
system_prompt: 系统提示词
tools: 工具列表(可选,框架自动走 ReAct 循环)
streaming: 是否启用流式输出(默认 True
"""
self.model_name = model_name
self.provider = provider
self.system_prompt = system_prompt or "你是一个专业的AI助手"
self.tools = tools or []
self.streaming = streaming
# 创建 RedBearLLM支持多提供商
model_config = RedBearModelConfig(
model_name=model_name,
provider=provider,
api_key=api_key,
base_url=api_base,
extra_params={
"temperature": temperature,
"max_tokens": max_tokens,
"streaming": streaming # 使用参数控制流式
}
)
self.llm = RedBearLLM(model_config, type=ModelType.CHAT)
# 获取底层模型用于真正的流式调用
self._underlying_llm = self.llm._model if hasattr(self.llm, '_model') else self.llm
# 确保底层模型也启用流式
if streaming and hasattr(self._underlying_llm, 'streaming'):
self._underlying_llm.streaming = True
# 使用 create_agent 创建 agent graphLangChain 1.x 标准方式)
# 无论是否有工具,都使用 agent 统一处理
self.agent = create_agent(
model=self.llm,
tools=self.tools if self.tools else None,
system_prompt=self.system_prompt
)
logger.info(
f"LangChain Agent 初始化完成",
extra={
"model": model_name,
"provider": provider,
"has_api_base": bool(api_base),
"temperature": temperature,
"streaming": streaming,
"tool_count": len(self.tools),
"tool_names": [tool.name for tool in self.tools] if self.tools else [],
"tool_count": len(self.tools)
}
)
def _prepare_messages(
self,
message: str,
history: Optional[List[Dict[str, str]]] = None,
context: Optional[str] = None
) -> List[BaseMessage]:
"""准备消息列表
Args:
message: 用户消息
history: 历史消息列表
context: 上下文信息
Returns:
List[BaseMessage]: 消息列表
"""
messages = []
# 添加系统提示词
messages.append(SystemMessage(content=self.system_prompt))
# 添加历史消息
if history:
for msg in history:
if msg["role"] == "user":
messages.append(HumanMessage(content=msg["content"]))
elif msg["role"] == "assistant":
messages.append(AIMessage(content=msg["content"]))
# 添加当前用户消息
user_content = message
if context:
user_content = f"参考信息:\n{context}\n\n用户问题:\n{user_content}"
messages.append(HumanMessage(content=user_content))
return messages
async def chat(
self,
message: str,
history: Optional[List[Dict[str, str]]] = None,
context: Optional[str] = None,
end_user_id: Optional[str] = None,
config_id: Optional[str] = None, # 添加这个参数
storage_type: Optional[str] = None,
user_rag_memory_id: Optional[str] = None,
) -> Dict[str, Any]:
"""执行对话
Args:
message: 用户消息
history: 历史消息列表 [{"role": "user/assistant", "content": "..."}]
context: 上下文信息(如知识库检索结果)
Returns:
Dict: 包含 content 和元数据的字典
"""
start_time = time.time()
logger.info(f'写入类型{storage_type,str(end_user_id), message, str(user_rag_memory_id)}')
print(f'写入类型{storage_type,str(end_user_id), message, str(user_rag_memory_id)}')
if storage_type == "rag":
await write_rag(end_user_id, message, user_rag_memory_id)
logger.info(f'RAG_Agent:{end_user_id};{user_rag_memory_id}')
else:
if config_id==None:
actual_config_id = os.getenv("config_id")
else:actual_config_id=config_id
actual_end_user_id = end_user_id if end_user_id is not None else "unknown"
write_id = write_message_task.delay(actual_end_user_id, message, actual_config_id,storage_type,user_rag_memory_id)
write_status = get_task_memory_write_result(str(write_id))
logger.info(f'Agent:{actual_end_user_id};{write_status}')
try:
# 准备消息列表
messages = self._prepare_messages(message, history, context)
logger.debug(
f"准备调用 LangChain Agent",
extra={
"has_context": bool(context),
"has_history": bool(history),
"has_tools": bool(self.tools),
"message_count": len(messages)
}
)
# 统一使用 agent.invoke 调用
result = await self.agent.ainvoke({"messages": messages})
# 获取最后的 AI 消息
output_messages = result.get("messages", [])
content = ""
for msg in reversed(output_messages):
if isinstance(msg, AIMessage):
content = msg.content
break
elapsed_time = time.time() - start_time
if storage_type == "rag":
await write_rag(end_user_id, message, user_rag_memory_id)
logger.info(f'RAG_Agent:{end_user_id};{user_rag_memory_id}')
else:
write_id = write_message_task.delay(actual_end_user_id, content, actual_config_id, storage_type, user_rag_memory_id)
write_status = get_task_memory_write_result(str(write_id))
logger.info(f'Agent:{actual_end_user_id};{write_status}')
response = {
"content": content,
"model": self.model_name,
"elapsed_time": elapsed_time,
"usage": {
"prompt_tokens": 0,
"completion_tokens": 0,
"total_tokens": 0
}
}
logger.debug(
f"Agent 调用完成",
extra={
"elapsed_time": elapsed_time,
"content_length": len(response["content"])
}
)
return response
except Exception as e:
logger.error(f"Agent 调用失败", extra={"error": str(e)})
raise
async def chat_stream(
self,
message: str,
history: Optional[List[Dict[str, str]]] = None,
context: Optional[str] = None,
end_user_id:Optional[str] = None,
config_id: Optional[str] = None,
storage_type:Optional[str] = None,
user_rag_memory_id:Optional[str] = None,
) -> AsyncGenerator[str, None]:
"""执行流式对话
Args:
message: 用户消息
history: 历史消息列表
context: 上下文信息
Yields:
str: 消息内容块
"""
logger.info("=" * 80)
logger.info(f" chat_stream 方法开始执行")
logger.info(f" Message: {message[:100]}")
logger.info(f" Has tools: {bool(self.tools)}")
logger.info(f" Tool count: {len(self.tools) if self.tools else 0}")
logger.info("=" * 80)
start_time = time.time()
if storage_type == "rag":
await write_rag(end_user_id, message, user_rag_memory_id)
else:
if config_id==None:
actual_config_id = os.getenv("config_id")
else:actual_config_id=config_id
actual_end_user_id = end_user_id if end_user_id is not None else "unknown"
write_id = write_message_task.delay(actual_end_user_id, message, actual_config_id,storage_type,user_rag_memory_id)
try:
write_status = get_task_memory_write_result(str(write_id))
logger.info(f'Agent:{actual_end_user_id};{write_status}')
except Exception as e:
logger.error(f"Agent 记忆用户输入出错", extra={"error": str(e)})
try:
# 准备消息列表
messages = self._prepare_messages(message, history, context)
logger.debug(
f"准备流式调用has_tools={bool(self.tools)}, message_count={len(messages)}"
)
chunk_count = 0
yielded_content = False
# 统一使用 agent 的 astream_events 实现流式输出
logger.debug("使用 Agent astream_events 实现流式输出")
try:
async for event in self.agent.astream_events(
{"messages": messages},
version="v2"
):
chunk_count += 1
kind = event.get("event")
# 处理所有可能的流式事件
if kind == "on_chat_model_stream":
# LLM 流式输出
chunk = event.get("data", {}).get("chunk")
if chunk and hasattr(chunk, "content") and chunk.content:
yield chunk.content
yielded_content = True
elif kind == "on_llm_stream":
# 另一种 LLM 流式事件
chunk = event.get("data", {}).get("chunk")
if chunk:
if hasattr(chunk, "content") and chunk.content:
yield chunk.content
yielded_content = True
elif isinstance(chunk, str):
yield chunk
yielded_content = True
# 记录工具调用(可选)
elif kind == "on_tool_start":
logger.debug(f"工具调用开始: {event.get('name')}")
elif kind == "on_tool_end":
logger.debug(f"工具调用结束: {event.get('name')}")
logger.debug(f"Agent 流式完成,共 {chunk_count} 个事件")
except Exception as e:
logger.error(f"Agent astream_events 失败: {str(e)}", exc_info=True)
raise
except Exception as e:
logger.error("=" * 80)
logger.error(f"chat_stream 异常: {str(e)}")
logger.error("=" * 80, exc_info=True)
raise
finally:
logger.info("=" * 80)
logger.info(f"chat_stream 方法执行结束")
logger.info("=" * 80)

56
app/core/api_key_utils.py Normal file
View File

@@ -0,0 +1,56 @@
"""API Key 工具函数"""
import secrets
import hashlib
from app.models.api_key_model import ApiKeyType
def generate_api_key(key_type: ApiKeyType) -> tuple[str, str, str]:
"""生成 API Key
Args:
key_type: API Key 类型
Returns:
tuple: (api_key, key_hash, key_prefix)
"""
# 前缀映射
prefix_map = {
ApiKeyType.APP: "sk-app-",
ApiKeyType.RAG: "sk-rag-",
ApiKeyType.MEMORY: "sk-mem-",
ApiKeyType.GENERAL: "sk-gen-",
}
prefix = prefix_map[key_type]
random_string = secrets.token_urlsafe(32)[:32] # 32 字符
api_key = f"{prefix}{random_string}"
# 生成哈希值存储
key_hash = hash_api_key(api_key)
return api_key, key_hash, prefix
def hash_api_key(api_key: str) -> str:
"""对 API Key 进行哈希
Args:
api_key: API Key 明文
Returns:
str: 哈希值
"""
return hashlib.sha256(api_key.encode()).hexdigest()
def verify_api_key(api_key: str, key_hash: str) -> bool:
"""验证 API Key
Args:
api_key: API Key 明文
key_hash: 存储的哈希值
Returns:
bool: 是否匹配
"""
return hash_api_key(api_key) == key_hash

47
app/core/compensation.py Normal file
View File

@@ -0,0 +1,47 @@
"""
Compensation Transaction Handler
Handles operations that cannot be rolled back (like file system operations).
"""
from typing import List, Callable
from app.core.logging_config import get_logger
logger = get_logger(__name__)
class CompensationHandler:
"""补偿事务处理器,用于处理无法回滚的操作"""
def __init__(self):
self._compensations: List[Callable] = []
def register(self, compensation: Callable):
"""
注册补偿操作
Args:
compensation: 补偿操作的可调用对象
"""
self._compensations.append(compensation)
logger.debug(f"Registered compensation operation: {compensation.__name__ if hasattr(compensation, '__name__') else 'lambda'}")
def execute(self):
"""执行所有补偿操作(按注册的逆序执行)"""
if not self._compensations:
logger.debug("No compensation operations to execute")
return
logger.info(f"Executing {len(self._compensations)} compensation operations")
for compensation in reversed(self._compensations):
try:
compensation()
logger.debug(f"Compensation operation executed successfully")
except Exception as e:
logger.error(f"补偿操作失败: {e}", exc_info=True)
def clear(self):
"""清空补偿操作"""
count = len(self._compensations)
self._compensations.clear()
if count > 0:
logger.debug(f"Cleared {count} compensation operations")

237
app/core/config.py Normal file
View File

@@ -0,0 +1,237 @@
import os
import json
from pathlib import Path
from typing import Dict, Any, Optional
from dotenv import load_dotenv
load_dotenv()
class Settings:
ENABLE_SINGLE_WORKSPACE: bool = os.getenv("ENABLE_SINGLE_WORKSPACE", "true").lower() == "true"
# API Keys Configuration
OPENAI_API_KEY: str = os.getenv("OPENAI_API_KEY", "")
DASHSCOPE_API_KEY: str = os.getenv("DASHSCOPE_API_KEY", "")
# Neo4j Configuration (记忆系统数据库)
NEO4J_URI: str = os.getenv("NEO4J_URI", "bolt://1.94.111.67:7687")
NEO4J_USERNAME: str = os.getenv("NEO4J_USERNAME", "neo4j")
NEO4J_PASSWORD: str = os.getenv("NEO4J_PASSWORD", "")
# Database configuration (Postgres)
DB_HOST: str = os.getenv("DB_HOST", "127.0.0.1")
DB_PORT: int = int(os.getenv("DB_PORT", "5432"))
DB_USER: str = os.getenv("DB_USER", "postgres")
DB_PASSWORD: str = os.getenv("DB_PASSWORD", "password")
DB_NAME: str = os.getenv("DB_NAME", "redbear-mem")
DB_AUTO_UPGRADE = os.getenv("DB_AUTO_UPGRADE", "false").lower() == "true"
# Redis configuration
REDIS_HOST: str = os.getenv("REDIS_HOST", "127.0.0.1")
REDIS_PORT: int = int(os.getenv("REDIS_PORT", "6379"))
REDIS_DB: int = int(os.getenv("REDIS_DB", "1"))
REDIS_PASSWORD: str = os.getenv("REDIS_PASSWORD", "")
# ElasticSearch configuration
ELASTICSEARCH_HOST: str = os.getenv("ELASTICSEARCH_HOST", "https://127.0.0.1")
ELASTICSEARCH_PORT: int = int(os.getenv("ELASTICSEARCH_PORT", "9200"))
ELASTICSEARCH_USERNAME: str = os.getenv("ELASTICSEARCH_USERNAME", "elastic")
ELASTICSEARCH_PASSWORD: str = os.getenv("ELASTICSEARCH_PASSWORD", "")
ELASTICSEARCH_VERIFY_CERTS: bool = os.getenv("ELASTICSEARCH_VERIFY_CERTS", "False").lower() == "true"
ELASTICSEARCH_CA_CERTS: str = os.getenv("ELASTICSEARCH_CA_CERTS", "")
ELASTICSEARCH_REQUEST_TIMEOUT: int = int(os.getenv("ELASTICSEARCH_REQUEST_TIMEOUT", "100000"))
ELASTICSEARCH_RETRY_ON_TIMEOUT: bool = os.getenv("ELASTICSEARCH_RETRY_ON_TIMEOUT", "True").lower() == "true"
ELASTICSEARCH_MAX_RETRIES: int = int(os.getenv("ELASTICSEARCH_MAX_RETRIES", "10"))
# Xinference configuration
XINFERENCE_URL: str = os.getenv("XINFERENCE_URL", "http://127.0.0.1")
# LangSmith configuration
LANGCHAIN_TRACING_V2: bool = os.getenv("LANGCHAIN_TRACING_V2", "false").lower() == "true"
LANGCHAIN_TRACING: bool = os.getenv("LANGCHAIN_TRACING", "false").lower() == "true"
LANGCHAIN_API_KEY: str = os.getenv("LANGCHAIN_API_KEY", "")
LANGCHAIN_ENDPOINT: str = os.getenv("LANGCHAIN_ENDPOINT", "")
# LLM Request Configuration
LLM_TIMEOUT: float = float(os.getenv("LLM_TIMEOUT", "120.0"))
LLM_MAX_RETRIES: int = int(os.getenv("LLM_MAX_RETRIES", "2"))
# JWT Token Configuration
SECRET_KEY: str = os.getenv("SECRET_KEY", "a_default_secret_key_that_is_long_and_random")
ALGORITHM: str = "HS256"
ACCESS_TOKEN_EXPIRE_MINUTES: int = int(os.getenv("ACCESS_TOKEN_EXPIRE_MINUTES", "30"))
REFRESH_TOKEN_EXPIRE_DAYS: int = int(os.getenv("REFRESH_TOKEN_EXPIRE_DAYS", "7"))
# Single Sign-On configuration
ENABLE_SINGLE_SESSION: bool = os.getenv("ENABLE_SINGLE_SESSION", "false").lower() == "true"
# File Upload
MAX_FILE_SIZE: int = int(os.getenv("MAX_FILE_SIZE", "52428800"))
FILE_PATH: str = os.getenv("FILE_PATH", "/files")
# VOLC ASR settings
VOLC_APP_KEY: str = os.getenv("VOLC_APP_KEY", "")
VOLC_ACCESS_KEY: str = os.getenv("VOLC_ACCESS_KEY", "")
VOLC_SUBMIT_URL: str = os.getenv("VOLC_SUBMIT_URL", "https://openspeech.bytedance.com/api/v3/auc/bigmodel/submit")
VOLC_QUERY_URL: str = os.getenv("VOLC_QUERY_URL", "https://openspeech.bytedance.com/api/v3/auc/bigmodel/query")
# Langfuse configuration
LANGFUSE_PUBLIC_KEY: str = os.getenv("LANGFUSE_PUBLIC_KEY", "")
LANGFUSE_SECRET_KEY: str = os.getenv("LANGFUSE_SECRET_KEY", "")
LANGFUSE_HOST: str = os.getenv("LANGFUSE_HOST", "")
# Server Configuration
SERVER_IP: str = os.getenv("SERVER_IP", "127.0.0.1")
# ========================================================================
# Internal Configuration (not in .env, used by application code)
# ========================================================================
# Superuser settings (internal defaults)
FIRST_SUPERUSER_EMAIL: str = os.getenv("FIRST_SUPERUSER_EMAIL", "admin@example.com")
FIRST_SUPERUSER_USERNAME: str = os.getenv("FIRST_SUPERUSER_USERNAME", "admin")
FIRST_SUPERUSER_PASSWORD: str = os.getenv("FIRST_SUPERUSER_PASSWORD", "admin_password")
# Generic File Upload (internal)
GENERIC_FILE_PATH: str = os.getenv("GENERIC_FILE_PATH", "/uploads")
ENABLE_FILE_COMPRESSION: bool = os.getenv("ENABLE_FILE_COMPRESSION", "false").lower() == "true"
ENABLE_VIRUS_SCAN: bool = os.getenv("ENABLE_VIRUS_SCAN", "false").lower() == "true"
FILE_ACCESS_URL_PREFIX: str = os.getenv("FILE_ACCESS_URL_PREFIX", "http://localhost:8000/api/files")
# Frontend URL for workspace invitations (internal)
WEB_URL: str = os.getenv("WEB_URL", "http://localhost:3000")
# CORS configuration (internal)
CORS_ORIGINS: list[str] = [
origin.strip()
for origin in os.getenv("CORS_ORIGINS", "").split(",")
if origin.strip()
]
# Logging settings
LOG_LEVEL: str = os.getenv("LOG_LEVEL", "INFO")
LOG_FORMAT: str = os.getenv("LOG_FORMAT", "%(asctime)s - %(name)s - %(levelname)s - %(message)s")
LOG_FILE_PATH: str = os.getenv("LOG_FILE_PATH", "logs/app.log")
LOG_MAX_SIZE: int = int(os.getenv("LOG_MAX_SIZE", "10485760")) # 10MB
LOG_BACKUP_COUNT: int = int(os.getenv("LOG_BACKUP_COUNT", "5"))
LOG_TO_CONSOLE: bool = os.getenv("LOG_TO_CONSOLE", "true").lower() == "true"
LOG_TO_FILE: bool = os.getenv("LOG_TO_FILE", "true").lower() == "true"
# Sensitive Data Filtering
ENABLE_SENSITIVE_DATA_FILTER: bool = os.getenv("ENABLE_SENSITIVE_DATA_FILTER", "true").lower() == "true"
# Memory Module Logging
PROMPT_LOG_LEVEL: str = os.getenv("PROMPT_LOG_LEVEL", "INFO")
ENABLE_TEMPLATE_LOGGING: bool = os.getenv("ENABLE_TEMPLATE_LOGGING", "false").lower() == "true"
TIMING_LOG_FILE: str = os.getenv("TIMING_LOG_FILE", "logs/time.log")
TIMING_LOG_TO_CONSOLE: bool = os.getenv("TIMING_LOG_TO_CONSOLE", "true").lower() == "true"
AGENT_LOG_FILE: str = os.getenv("AGENT_LOG_FILE", "logs/agent_service.log")
AGENT_LOG_MAX_SIZE: int = int(os.getenv("AGENT_LOG_MAX_SIZE", "5242880")) # 5MB
AGENT_LOG_BACKUP_COUNT: int = int(os.getenv("AGENT_LOG_BACKUP_COUNT", "20"))
# Log Streaming Configuration
LOG_STREAM_KEEPALIVE_INTERVAL: int = int(os.getenv("LOG_STREAM_KEEPALIVE_INTERVAL", "300")) # 5 minutes
LOG_STREAM_MAX_CONNECTIONS: int = int(os.getenv("LOG_STREAM_MAX_CONNECTIONS", "10"))
LOG_STREAM_BUFFER_SIZE: int = int(os.getenv("LOG_STREAM_BUFFER_SIZE", "8192")) # 8KB
LOG_FILE_MAX_SIZE_MB: int = int(os.getenv("LOG_FILE_MAX_SIZE_MB", "10")) # 10MB
# Celery configuration (internal)
CELERY_BROKER: int = int(os.getenv("CELERY_BROKER", "1"))
CELERY_BACKEND: int = int(os.getenv("CELERY_BACKEND", "2"))
REFLECTION_INTERVAL_SECONDS: float = float(os.getenv("REFLECTION_INTERVAL_SECONDS", "300"))
HEALTH_CHECK_SECONDS: float = float(os.getenv("HEALTH_CHECK_SECONDS", "600"))
MEMORY_INCREMENT_INTERVAL_HOURS: float = float(os.getenv("MEMORY_INCREMENT_INTERVAL_HOURS", "24"))
DEFAULT_WORKSPACE_ID: Optional[str] = os.getenv("DEFAULT_WORKSPACE_ID", None)
# Memory Module Configuration (internal)
MEMORY_OUTPUT_DIR: str = os.getenv("MEMORY_OUTPUT_DIR", "logs/memory-output")
MEMORY_CONFIG_DIR: str = os.getenv("MEMORY_CONFIG_DIR", "app/core/memory")
MEMORY_CONFIG_FILE: str = os.getenv("MEMORY_CONFIG_FILE", "config.json")
MEMORY_RUNTIME_FILE: str = os.getenv("MEMORY_RUNTIME_FILE", "runtime.json")
MEMORY_DBRUN_FILE: str = os.getenv("MEMORY_DBRUN_FILE", "dbrun.json")
def get_memory_output_path(self, filename: str = "") -> str:
"""
Get the full path for memory module output files.
Args:
filename: Optional filename to append to the output directory
Returns:
Full path to the output file or directory
"""
base_path = Path(self.MEMORY_OUTPUT_DIR)
if filename:
return str(base_path / filename)
return str(base_path)
def get_memory_config_path(self, config_file: str = "") -> str:
"""
Get the full path for memory module configuration files.
Args:
config_file: Optional config filename (defaults to MEMORY_CONFIG_FILE)
Returns:
Full path to the config file
"""
if not config_file:
config_file = self.MEMORY_CONFIG_FILE
return str(Path(self.MEMORY_CONFIG_DIR) / config_file)
def load_memory_config(self) -> Dict[str, Any]:
"""
Load memory module configuration from config.json.
Returns:
Dictionary containing memory configuration
"""
config_path = self.get_memory_config_path(self.MEMORY_CONFIG_FILE)
try:
with open(config_path, "r", encoding="utf-8") as f:
return json.load(f)
except (FileNotFoundError, json.JSONDecodeError) as e:
print(f"Warning: Memory config file not found or malformed at {config_path}. Error: {e}")
return {}
def load_memory_runtime_config(self) -> Dict[str, Any]:
"""
Load memory module runtime configuration from runtime.json.
Returns:
Dictionary containing runtime configuration
"""
runtime_path = self.get_memory_config_path(self.MEMORY_RUNTIME_FILE)
try:
with open(runtime_path, "r", encoding="utf-8") as f:
return json.load(f)
except (FileNotFoundError, json.JSONDecodeError) as e:
print(f"Warning: Memory runtime config not found or malformed at {runtime_path}. Error: {e}")
return {"selections": {}}
def load_memory_dbrun_config(self) -> Dict[str, Any]:
"""
Load memory module database run configuration from dbrun.json.
Returns:
Dictionary containing dbrun configuration
"""
dbrun_path = self.get_memory_config_path(self.MEMORY_DBRUN_FILE)
try:
with open(dbrun_path, "r", encoding="utf-8") as f:
return json.load(f)
except (FileNotFoundError, json.JSONDecodeError) as e:
print(f"Warning: Memory dbrun config not found or malformed at {dbrun_path}. Error: {e}")
return {"selections": {}}
def ensure_memory_output_dir(self) -> None:
"""
Ensure the memory output directory exists.
Creates the directory if it doesn't exist.
"""
output_dir = Path(self.MEMORY_OUTPUT_DIR)
output_dir.mkdir(parents=True, exist_ok=True)
settings = Settings()

130
app/core/error_codes.py Normal file
View File

@@ -0,0 +1,130 @@
from enum import IntEnum
class BizCode(IntEnum):
# 通用1xxx
OK = 0
BAD_REQUEST = 1000
VALIDATION_FAILED = 1001
MISSING_PARAMETER = 1002
INVALID_PARAMETER = 1003
# 认证/鉴权2xxx/3xxx
UNAUTHORIZED = 2001
TOKEN_INVALID = 2002
TOKEN_EXPIRED = 2003
TOKEN_BLACKLISTED = 2004
PASSWORD_ERROR = 2005
LOGIN_FAILED = 2006
FORBIDDEN = 3001
TENANT_NOT_FOUND = 3002
WORKSPACE_NO_ACCESS = 3003
WORKSPACE_INVITE_NOT_FOUND = 3004
# 资源4xxx
NOT_FOUND = 4000
USER_NOT_FOUND = 4001
WORKSPACE_NOT_FOUND = 4002
MODEL_NOT_FOUND = 4003
KNOWLEDGE_NOT_FOUND = 4004
DOCUMENT_NOT_FOUND = 4005
FILE_NOT_FOUND = 4006
APP_NOT_FOUND = 4007
RELEASE_NOT_FOUND = 4008
# 冲突/状态5xxx
DUPLICATE_NAME = 5001
RESOURCE_ALREADY_EXISTS = 5002
VERSION_ALREADY_EXISTS = 5003
STATE_CONFLICT = 5004
# 应用发布6xxx
PUBLISH_FAILED = 6001
NO_DRAFT_TO_PUBLISH = 6002
ROLLBACK_TARGET_NOT_FOUND = 6003
APP_TYPE_NOT_SUPPORTED = 6004
AGENT_CONFIG_MISSING = 6005
SHARE_DISABLED = 6006
INVALID_PASSWORD = 6007
PASSWORD_REQUIRED = 6008
EMBED_NOT_ALLOWED = 6009
PERMISSION_DENIED = 6010
INVALID_CONVERSATION = 6011
# 模型7xxx
MODEL_CONFIG_INVALID = 7001
API_KEY_MISSING = 7002
PROVIDER_NOT_SUPPORTED = 7003
LLM_ERROR = 7004
EMBEDDING_ERROR = 7005
# 文件/解析8xxx
FILE_READ_ERROR = 8001
PARSER_NOT_SUPPORTED = 8002
CHUNKING_FAILED = 8003
# RAG/知识9xxx
INDEX_BUILD_FAILED = 9001
EMBEDDING_FAILED = 9002
SEARCH_FAILED = 9003
# 系统100xx
INTERNAL_ERROR = 10001
DB_ERROR = 10002
SERVICE_UNAVAILABLE = 10003
RATE_LIMITED = 10004
# 建议的HTTP状态映射如需在异常处理器中使用
HTTP_MAPPING = {
BizCode.OK: 200,
BizCode.LOGIN_FAILED: 200,
BizCode.BAD_REQUEST: 400,
BizCode.VALIDATION_FAILED: 400,
BizCode.MISSING_PARAMETER: 400,
BizCode.INVALID_PARAMETER: 400,
BizCode.UNAUTHORIZED: 401,
BizCode.TOKEN_INVALID: 401,
BizCode.TOKEN_EXPIRED: 401,
BizCode.TOKEN_BLACKLISTED: 401,
BizCode.FORBIDDEN: 403,
BizCode.TENANT_NOT_FOUND: 404,
BizCode.WORKSPACE_NO_ACCESS: 403,
BizCode.NOT_FOUND: 404,
BizCode.USER_NOT_FOUND: 200,
BizCode.WORKSPACE_NOT_FOUND: 404,
BizCode.MODEL_NOT_FOUND: 404,
BizCode.KNOWLEDGE_NOT_FOUND: 404,
BizCode.DOCUMENT_NOT_FOUND: 404,
BizCode.FILE_NOT_FOUND: 404,
BizCode.APP_NOT_FOUND: 404,
BizCode.RELEASE_NOT_FOUND: 404,
BizCode.DUPLICATE_NAME: 409,
BizCode.RESOURCE_ALREADY_EXISTS: 409,
BizCode.VERSION_ALREADY_EXISTS: 409,
BizCode.STATE_CONFLICT: 409,
BizCode.PUBLISH_FAILED: 500,
BizCode.NO_DRAFT_TO_PUBLISH: 400,
BizCode.ROLLBACK_TARGET_NOT_FOUND: 404,
BizCode.APP_TYPE_NOT_SUPPORTED: 400,
BizCode.AGENT_CONFIG_MISSING: 400,
BizCode.SHARE_DISABLED: 403,
BizCode.INVALID_PASSWORD: 401,
BizCode.PASSWORD_REQUIRED: 401,
BizCode.EMBED_NOT_ALLOWED: 403,
BizCode.PERMISSION_DENIED: 403,
BizCode.INVALID_CONVERSATION: 400,
BizCode.MODEL_CONFIG_INVALID: 400,
BizCode.API_KEY_MISSING: 400,
BizCode.PROVIDER_NOT_SUPPORTED: 400,
BizCode.LLM_ERROR: 500,
BizCode.EMBEDDING_ERROR: 500,
BizCode.FILE_READ_ERROR: 500,
BizCode.PARSER_NOT_SUPPORTED: 400,
BizCode.CHUNKING_FAILED: 500,
BizCode.INDEX_BUILD_FAILED: 500,
BizCode.EMBEDDING_FAILED: 500,
BizCode.SEARCH_FAILED: 500,
BizCode.INTERNAL_ERROR: 500,
BizCode.DB_ERROR: 500,
BizCode.SERVICE_UNAVAILABLE: 503,
BizCode.RATE_LIMITED: 429,
}

86
app/core/exceptions.py Normal file
View File

@@ -0,0 +1,86 @@
"""
业务异常定义
"""
from typing import Any, Dict, Optional
from app.core.error_codes import BizCode
class BusinessException(Exception):
"""业务逻辑异常基类"""
def __init__(
self,
message: str,
code: BizCode | int | None = None,
context: Optional[Dict[str, Any]] = None,
cause: Optional[Exception] = None
):
self.message = message
self.code = code if code is not None else BizCode.BAD_REQUEST
# Make a copy of context to avoid modifying the original dict
self.context = dict(context) if context else {}
self.cause = cause
super().__init__(self.message)
def __str__(self) -> str:
ctx = f", context={self.context}" if self.context else ""
code_name = self.code.name if isinstance(self.code, BizCode) else str(self.code)
return f"{code_name}: {self.message}{ctx}"
class ValidationException(BusinessException):
"""数据验证异常"""
def __init__(self, message: str, field: str = None, **kwargs):
context = {"field": field} if field else {}
if "context" in kwargs:
context.update(kwargs.pop("context"))
super().__init__(message, BizCode.VALIDATION_FAILED, context, **kwargs)
class AuthenticationException(BusinessException):
"""认证异常"""
def __init__(self, message: str = "认证失败", **kwargs):
super().__init__(message, BizCode.UNAUTHORIZED, **kwargs)
class AuthorizationException(BusinessException):
"""授权异常"""
def __init__(self, message: str = "权限不足", **kwargs):
super().__init__(message, BizCode.FORBIDDEN, **kwargs)
class ResourceNotFoundException(BusinessException):
"""资源未找到异常"""
def __init__(self, resource_type: str, resource_id: str = None, **kwargs):
message = f"{resource_type} 不存在"
context = {"resource_type": resource_type}
if resource_id:
context["resource_id"] = resource_id
if "context" in kwargs:
context.update(kwargs.pop("context"))
super().__init__(message, BizCode.FILE_NOT_FOUND, context, **kwargs)
class DuplicateResourceException(BusinessException):
"""资源重复异常"""
def __init__(self, message: str = "资源已存在", **kwargs):
super().__init__(message, BizCode.DUPLICATE_NAME, **kwargs)
class FileUploadException(BusinessException):
"""文件上传异常"""
def __init__(self, message: str, **kwargs):
super().__init__(message, BizCode.FILE_READ_ERROR, **kwargs)
class PermissionDeniedException(BusinessException):
"""权限拒绝异常"""
def __init__(self, message: str = "权限不足", **kwargs):
super().__init__(message, BizCode.FORBIDDEN, **kwargs)

633
app/core/logging_config.py Normal file
View File

@@ -0,0 +1,633 @@
import logging
import logging.handlers
import os
from pathlib import Path
from typing import Optional
from app.core.config import settings
from app.core.sensitive_filter import SensitiveDataFilter
class SensitiveDataLoggingFilter(logging.Filter):
"""日志过滤器:自动过滤敏感信息"""
def filter(self, record: logging.LogRecord) -> bool:
"""
过滤日志记录中的敏感信息
Args:
record: 日志记录
Returns:
True表示允许记录False表示拒绝
"""
# 过滤消息中的敏感信息
if hasattr(record, 'msg') and isinstance(record.msg, str):
record.msg = SensitiveDataFilter.filter_string(record.msg)
# 过滤参数中的敏感信息
if hasattr(record, 'args') and record.args:
if isinstance(record.args, dict):
record.args = SensitiveDataFilter.filter_dict(record.args)
elif isinstance(record.args, (list, tuple)):
record.args = tuple(
SensitiveDataFilter.filter_string(str(arg)) if isinstance(arg, str) else arg
for arg in record.args
)
return True
class LoggingConfig:
"""全局日志配置类"""
_initialized = False
_memory_loggers_initialized = False
_prompt_logger = None
_template_logger = None
_timing_logger = None
_agent_loggers = {}
@classmethod
def setup_logging(cls) -> None:
"""初始化全局日志配置"""
if cls._initialized:
return
# 创建日志目录
log_dir = Path(settings.LOG_FILE_PATH).parent
log_dir.mkdir(parents=True, exist_ok=True)
# 配置根日志器
root_logger = logging.getLogger()
root_logger.setLevel(getattr(logging, settings.LOG_LEVEL.upper()))
# 清除现有处理器
root_logger.handlers.clear()
# 创建格式化器
formatter = logging.Formatter(
fmt=settings.LOG_FORMAT,
datefmt='%Y-%m-%d %H:%M:%S'
)
# 创建敏感信息过滤器
sensitive_filter = SensitiveDataLoggingFilter()
# 控制台处理器
if settings.LOG_TO_CONSOLE:
console_handler = logging.StreamHandler()
console_handler.setFormatter(formatter)
console_handler.setLevel(getattr(logging, settings.LOG_LEVEL.upper()))
console_handler.addFilter(sensitive_filter)
root_logger.addHandler(console_handler)
# 文件处理器(带轮转)
if settings.LOG_TO_FILE:
file_handler = logging.handlers.RotatingFileHandler(
filename=settings.LOG_FILE_PATH,
maxBytes=settings.LOG_MAX_SIZE,
backupCount=5,
encoding='utf-8'
)
file_handler.setFormatter(formatter)
file_handler.setLevel(getattr(logging, settings.LOG_LEVEL.upper()))
file_handler.addFilter(sensitive_filter)
root_logger.addHandler(file_handler)
cls._initialized = True
# Initialize memory module logging
cls.setup_memory_logging()
# 记录初始化完成
logger = logging.getLogger(__name__)
logger.info("全局日志系统初始化完成")
@classmethod
def setup_memory_logging(cls) -> None:
"""Initialize memory module specific loggers.
Called automatically by setup_logging() or can be called independently.
Sets up:
- Prompt logger with timestamped files
- Template logger with conditional file output
- Timing logger with dual output (file + console)
- Agent logger factory with concurrent handlers
"""
if cls._memory_loggers_initialized:
return
# Create logs directory if it doesn't exist
log_dir = Path("logs")
try:
log_dir.mkdir(parents=True, exist_ok=True)
except OSError as e:
print(f"Warning: Could not create log directory: {e}")
# Continue with console-only logging
# Initialize memory-specific loggers
# These will be created lazily when first requested via factory functions
# This method just marks the system as ready for memory logging
cls._memory_loggers_initialized = True
def get_logger(name: Optional[str] = None) -> logging.Logger:
"""获取日志器实例
Args:
name: 日志器名称,默认为调用模块名
Returns:
配置好的日志器实例
"""
return logging.getLogger(name)
def get_auth_logger() -> logging.Logger:
"""获取认证专用日志器"""
return logging.getLogger("auth")
def get_security_logger() -> logging.Logger:
"""获取安全专用日志器"""
return logging.getLogger("security")
def get_api_logger() -> logging.Logger:
"""获取API专用日志器"""
return logging.getLogger("api")
def get_db_logger() -> logging.Logger:
"""获取数据库专用日志器"""
return logging.getLogger("database")
def get_business_logger() -> logging.Logger:
"""获取业务逻辑专用日志器"""
return logging.getLogger("business")
def get_prompt_logger() -> logging.Logger:
"""Get the prompt logger for memory module.
Returns a logger configured for prompt rendering output with:
- Logger name: memory.prompts
- Output: logs/prompt_logs-{timestamp}.log
- Level: Configurable via PROMPT_LOG_LEVEL setting (default: INFO)
- Handler: FileHandler (no console output)
The logger is cached after first creation for performance.
Returns:
Logger configured for prompt rendering output
Example:
>>> logger = get_prompt_logger()
>>> logger.info("=== RENDERED EXTRACTION PROMPT ===\\n%s", prompt_content)
"""
# Return cached logger if already initialized
if LoggingConfig._prompt_logger is not None:
return LoggingConfig._prompt_logger
# Ensure memory logging is initialized
if not LoggingConfig._memory_loggers_initialized:
LoggingConfig.setup_memory_logging()
# Create prompt logger
logger = logging.getLogger("memory.prompts")
logger.setLevel(getattr(logging, settings.PROMPT_LOG_LEVEL.upper()))
logger.propagate = False # Don't propagate to root logger (no console output)
# Create timestamped log file
from datetime import datetime
timestamp = datetime.now().strftime("%Y%m%d-%H%M%S")
log_file = Path("logs/prompts/") / f"prompt_logs-{timestamp}.log"
# Ensure log directory exists
log_file.parent.mkdir(parents=True, exist_ok=True)
# Create file handler
file_handler = logging.FileHandler(
filename=str(log_file),
encoding='utf-8'
)
# Create formatter
formatter = logging.Formatter(
fmt='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
datefmt='%Y-%m-%d %H:%M:%S'
)
file_handler.setFormatter(formatter)
# Add handler to logger
logger.addHandler(file_handler)
# Cache the logger
LoggingConfig._prompt_logger = logger
return logger
def get_template_logger() -> logging.Logger:
"""Get the template logger for memory module.
Returns a logger configured for template rendering information with:
- Logger name: memory.templates
- Output: logs/prompt_templates.log (only when ENABLE_TEMPLATE_LOGGING is True)
- Level: INFO
- Handler: FileHandler when enabled, NullHandler when disabled
The logger is cached after first creation for performance.
Returns:
Logger configured for template rendering info
Example:
>>> logger = get_template_logger()
>>> logger.info("Rendering template: %s with context keys: %s",
... template_name, list(context.keys()))
"""
# Return cached logger if already initialized
if LoggingConfig._template_logger is not None:
return LoggingConfig._template_logger
# Ensure memory logging is initialized
if not LoggingConfig._memory_loggers_initialized:
LoggingConfig.setup_memory_logging()
# Create template logger
logger = logging.getLogger("memory.templates")
logger.setLevel(logging.INFO)
logger.propagate = False # Don't propagate to root logger
# Add appropriate handler based on configuration
if settings.ENABLE_TEMPLATE_LOGGING:
# Create log file path
log_file = Path("logs") / "prompt_templates.log"
# Ensure log directory exists
log_file.parent.mkdir(parents=True, exist_ok=True)
# Create file handler
file_handler = logging.FileHandler(
filename=str(log_file),
encoding='utf-8'
)
# Create formatter
formatter = logging.Formatter(
fmt='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
datefmt='%Y-%m-%d %H:%M:%S'
)
file_handler.setFormatter(formatter)
# Add handler to logger
logger.addHandler(file_handler)
else:
# Use NullHandler when template logging is disabled
null_handler = logging.NullHandler()
logger.addHandler(null_handler)
# Cache the logger
LoggingConfig._template_logger = logger
return logger
def log_prompt_rendering(prompt_type: str, content: str) -> None:
"""Log rendered prompt content.
Logs the rendered prompt with a formatted header and separator for easy
identification in log files. This is useful for debugging LLM interactions
and understanding what prompts are being sent.
Args:
prompt_type: Type of prompt (e.g., 'statement_extraction', 'triplet_extraction')
content: The rendered prompt text
Example:
>>> log_prompt_rendering("extraction", "Extract entities from: Hello world")
# Logs:
# === RENDERED EXTRACTION PROMPT ===
# Extract entities from: Hello world
# =====================================
"""
logger = get_prompt_logger()
# Format the log entry with header and separator
separator = "=" * 50
header = f"=== RENDERED {prompt_type.upper()} PROMPT ==="
log_message = f"\n{header}\n{content}\n{separator}\n"
logger.info(log_message)
def log_template_rendering(template_name: str, context: dict | None = None) -> None:
"""Log template rendering information.
Logs the template name and context keys for debugging template rendering.
This function is wrapped in try-except to ensure it never breaks application
flow, even if logging fails.
Args:
template_name: Name of the Jinja2 template being rendered
context: Optional context dictionary with template variables
Example:
>>> log_template_rendering("extract_triplet.jinja2", {"text": "...", "ontology": "..."})
# Logs: Rendering template: extract_triplet.jinja2 with context keys: ['text', 'ontology']
>>> log_template_rendering("system.jinja2")
# Logs: Rendering template: system.jinja2 with no context
"""
try:
logger = get_template_logger()
if context is not None:
context_keys = list(context.keys())
logger.info(f"Rendering template: {template_name} with context keys: {context_keys}")
else:
logger.info(f"Rendering template: {template_name} with no context")
except Exception:
# Never break application flow due to logging issues
# Silently ignore any logging errors
pass
def get_timing_logger() -> logging.Logger:
"""Get the timing logger for memory module.
Returns a logger configured for performance timing with:
- Logger name: memory.timing
- Output: Configurable via TIMING_LOG_FILE setting (default: logs/time.log)
- Level: INFO
- Handlers: FileHandler + optional StreamHandler for console output
- Console output: Controlled by TIMING_LOG_TO_CONSOLE setting (default: True)
The logger is cached after first creation for performance.
Returns:
Logger configured for performance timing
Example:
>>> logger = get_timing_logger()
>>> logger.info("[2025-11-18 10:30:45] Extraction: 2.34 seconds")
"""
# Return cached logger if already initialized
if LoggingConfig._timing_logger is not None:
return LoggingConfig._timing_logger
# Ensure memory logging is initialized
if not LoggingConfig._memory_loggers_initialized:
LoggingConfig.setup_memory_logging()
# Create timing logger
logger = logging.getLogger("memory.timing")
logger.setLevel(logging.INFO)
logger.propagate = False # Don't propagate to root logger
# Create formatter
formatter = logging.Formatter(
fmt='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
datefmt='%Y-%m-%d %H:%M:%S'
)
# Add file handler
log_file = Path(settings.TIMING_LOG_FILE)
# Ensure log directory exists
log_file.parent.mkdir(parents=True, exist_ok=True)
file_handler = logging.FileHandler(
filename=str(log_file),
encoding='utf-8'
)
file_handler.setFormatter(formatter)
logger.addHandler(file_handler)
# Add console handler if enabled
if settings.TIMING_LOG_TO_CONSOLE:
console_handler = logging.StreamHandler()
console_handler.setFormatter(formatter)
logger.addHandler(console_handler)
# Cache the logger
LoggingConfig._timing_logger = logger
return logger
def log_time(step_name: str, duration: float, log_file: str = "logs/time.log") -> None:
"""Log timing information for performance tracking.
Logs timing information to both file and console (console output is always shown
for backward compatibility). The file output includes a timestamp and full details,
while console output shows a concise checkmark format.
Args:
step_name: Name of the operation being timed
duration: Duration in seconds
log_file: Optional custom log file path (default: logs/time.log)
Example:
>>> log_time("Knowledge Extraction", 2.34)
# File logs: [2025-11-18 10:30:45] Knowledge Extraction: 2.34 seconds
# Console prints: ✓ Knowledge Extraction: 2.34s
>>> log_time("Database Query", 0.15, "logs/custom_time.log")
# Logs to custom file and console
"""
from datetime import datetime
# Format timestamp
timestamp = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
# Format timing entry for file
log_entry = f"[{timestamp}] {step_name}: {duration:.2f} seconds\n"
# Write to file with error handling
try:
log_path = Path(log_file)
log_path.parent.mkdir(parents=True, exist_ok=True)
with open(log_path, "a", encoding="utf-8") as f:
f.write(log_entry)
except IOError as e:
# Fallback to console only if file write fails
print(f"Warning: Could not write to timing log: {e}")
# Always print to console (backward compatible behavior)
print(f"{step_name}: {duration:.2f}s")
def get_agent_logger(name: str = "agent_service",
console_level: str = "INFO",
file_level: str = "DEBUG") -> logging.Logger:
"""Get an agent logger with concurrent file handling.
Returns a logger configured for agent operations with:
- Logger name: memory.agent.{name}
- Output: Configurable via AGENT_LOG_FILE setting (default: logs/agent_service.log)
- Console level: Configurable (default: INFO)
- File level: Configurable (default: DEBUG)
- Handler: ConcurrentRotatingFileHandler for multi-process support
- Rotation: Configurable via AGENT_LOG_MAX_SIZE (default: 5MB) and
AGENT_LOG_BACKUP_COUNT (default: 20)
The logger is cached by name after first creation for performance.
Supports concurrent writes from multiple processes.
Args:
name: Logger name for namespacing (default: "agent_service")
console_level: Log level for console output (default: "INFO")
file_level: Log level for file output (default: "DEBUG")
Returns:
Logger configured for agent operations
Example:
>>> logger = get_agent_logger("my_agent")
>>> logger.info("Agent operation started")
>>> logger.debug("Detailed agent state information")
>>> logger = get_agent_logger("custom_agent", console_level="WARNING", file_level="INFO")
>>> logger.warning("This appears in console and file")
>>> logger.info("This only appears in file")
"""
# Return cached logger if already initialized
if name in LoggingConfig._agent_loggers:
return LoggingConfig._agent_loggers[name]
# Ensure memory logging is initialized
if not LoggingConfig._memory_loggers_initialized:
LoggingConfig.setup_memory_logging()
# Create agent logger with namespaced name
logger_name = f"memory.agent.{name}"
logger = logging.getLogger(logger_name)
logger.setLevel(logging.DEBUG) # Set to DEBUG to allow both handlers to filter
logger.propagate = False # Don't propagate to root logger
# Create formatter
formatter = logging.Formatter(
fmt='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
datefmt='%Y-%m-%d %H:%M:%S'
)
# Add console handler
console_handler = logging.StreamHandler()
console_handler.setLevel(getattr(logging, console_level.upper()))
console_handler.setFormatter(formatter)
logger.addHandler(console_handler)
# Add concurrent rotating file handler
try:
from concurrent_log_handler import ConcurrentRotatingFileHandler
except ImportError:
# Fall back to standard RotatingFileHandler if concurrent handler not available
from logging.handlers import RotatingFileHandler as ConcurrentRotatingFileHandler
print("Warning: concurrent-log-handler not available, using standard RotatingFileHandler")
# Create log file path
log_file = Path(settings.AGENT_LOG_FILE)
# Ensure log directory exists
log_file.parent.mkdir(parents=True, exist_ok=True)
# Create file handler with rotation
file_handler = ConcurrentRotatingFileHandler(
filename=str(log_file),
maxBytes=settings.AGENT_LOG_MAX_SIZE,
backupCount=settings.AGENT_LOG_BACKUP_COUNT,
encoding='utf-8'
)
file_handler.setLevel(getattr(logging, file_level.upper()))
file_handler.setFormatter(formatter)
logger.addHandler(file_handler)
# Cache the logger
LoggingConfig._agent_loggers[name] = logger
return logger
def get_named_logger(name: str) -> logging.Logger:
"""Backward compatible alias for get_agent_logger.
This function maintains backward compatibility with existing code that uses
the get_named_logger pattern from the agent logger module.
Args:
name: Logger name for namespacing
Returns:
Logger configured for agent operations
Example:
>>> logger = get_named_logger("my_agent")
>>> logger.info("Agent operation started")
"""
return get_agent_logger(name)
def get_memory_logger(name: Optional[str] = None) -> logging.Logger:
"""Get a standard logger for memory module components.
Returns a logger configured for memory module components that inherits
the root logger's configuration (handlers, formatters, and level). This
provides consistent logging behavior across the memory module while
maintaining the ability to filter and identify memory-specific logs.
The logger uses the 'memory' namespace:
- If name is provided: logger name is 'memory.{module_name}'
- If name is None: logger name is 'memory'
The logger inherits all handlers and formatters from the root logger,
ensuring consistent output format and destinations (console, file, etc.).
Args:
name: Optional logger name, typically __name__ from the calling module.
If provided, creates a namespaced logger under 'memory.{name}'.
If None, returns the base 'memory' logger.
Returns:
Logger configured for memory module operations with root logger inheritance
Example:
>>> # In app/core/memory/src/search.py
>>> logger = get_memory_logger(__name__)
>>> logger.info("Starting search operation")
# Logs: [timestamp] - memory.app.core.memory.src.search - INFO - Starting search operation
>>> # Get base memory logger
>>> logger = get_memory_logger()
>>> logger.debug("Memory module initialized")
# Logs: [timestamp] - memory - DEBUG - Memory module initialized
>>> # In app/core/memory/src/knowledge_extraction/triplet_extraction.py
>>> logger = get_memory_logger(__name__)
>>> logger.error("Extraction failed", exc_info=True)
# Logs error with full traceback
"""
# Ensure memory logging is initialized
if not LoggingConfig._memory_loggers_initialized:
LoggingConfig.setup_memory_logging()
# Construct logger name with memory namespace
if name is not None:
logger_name = f"memory.{name}"
else:
logger_name = "memory"
# Get logger - it will inherit from root logger configuration
logger = logging.getLogger(logger_name)
# The logger automatically inherits handlers, formatters, and level from root logger
# through Python's logging hierarchy, so no additional configuration is needed
return logger

View File

View File

View File

@@ -0,0 +1,16 @@
"""
LangGraph Graph package for memory agent.
This package provides the LangGraph workflow orchestrator with modular
node implementations, routing logic, and state management.
Package structure:
- read_graph: Main graph factory for read operations
- write_graph: Main graph factory for write operations
- nodes: LangGraph node implementations
- routing: State routing logic
- state: State management utilities
"""
from app.core.memory.agent.langgraph_graph.read_graph import make_read_graph
__all__ = ['make_read_graph']

View File

@@ -0,0 +1,10 @@
"""
LangGraph node implementations.
This module contains custom node implementations for the LangGraph workflow.
"""
from app.core.memory.agent.langgraph_graph.nodes.tool_node import ToolExecutionNode
from app.core.memory.agent.langgraph_graph.nodes.input_node import create_input_message
__all__ = ["ToolExecutionNode", "create_input_message"]

View File

@@ -0,0 +1,144 @@
"""
Input node for LangGraph workflow entry point.
This module provides the create_input_message function which processes initial
user input with multimodal support and creates the first tool call message.
"""
import logging
import re
import uuid
from datetime import datetime
from typing import Dict, Any
from langchain_core.messages import AIMessage
from app.core.memory.agent.utils.multimodal import MultimodalProcessor
logger = logging.getLogger(__name__)
async def create_input_message(
state: Dict[str, Any],
tool_name: str,
session_id: str,
search_switch: str,
apply_id: str,
group_id: str,
multimodal_processor: MultimodalProcessor
) -> Dict[str, Any]:
"""
Create initial tool call message from user input.
This function:
1. Extracts the last message content from state
2. Processes multimodal inputs (images/audio) using the multimodal processor
3. Generates a unique message ID
4. Extracts namespace from session_id
5. Handles verified_data extraction for backward compatibility
6. Returns AIMessage with complete tool_calls structure
Args:
state: LangGraph state dictionary containing messages
tool_name: Name of the tool to invoke (typically "Split_The_Problem")
session_id: Session identifier (format: "call_id_{namespace}")
search_switch: Search routing parameter
apply_id: Application identifier
group_id: Group identifier
multimodal_processor: Processor for handling image/audio inputs
Returns:
State update with AIMessage containing tool_call
Examples:
>>> state = {"messages": [HumanMessage(content="What is AI?")]}
>>> result = await create_input_message(
... state, "Split_The_Problem", "call_id_user123", "0", "app1", "group1", processor
... )
>>> result["messages"][0].tool_calls[0]["name"]
'Split_The_Problem'
"""
messages = state.get("messages", [])
# Extract last message content
if messages:
last_message = messages[-1].content if hasattr(messages[-1], 'content') else str(messages[-1])
else:
logger.warning("[create_input_message] No messages in state, using empty string")
last_message = ""
logger.debug(f"[create_input_message] Original input: {last_message[:100]}...")
# Process multimodal input (images/audio)
try:
processed_content = await multimodal_processor.process_input(last_message)
if processed_content != last_message:
logger.info(
f"[create_input_message] Multimodal processing converted input "
f"from {len(last_message)} to {len(processed_content)} chars"
)
last_message = processed_content
except Exception as e:
logger.error(
f"[create_input_message] Multimodal processing failed: {e}",
exc_info=True
)
# Continue with original content
# Generate unique message ID
uuid_str = uuid.uuid4()
time_str = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
# Extract namespace from session_id
# Expected format: "call_id_{namespace}" or similar
try:
namespace = str(session_id).split('_id_')[1]
except (IndexError, AttributeError):
logger.warning(
f"[create_input_message] Could not extract namespace from session_id: {session_id}"
)
namespace = "unknown"
# Handle verified_data extraction (backward compatibility)
# This regex-based extraction is kept for compatibility with existing data formats
if 'verified_data' in str(last_message):
try:
messages_last = str(last_message).replace('\\n', '').replace('\\', '')
query_match = re.findall(r'"query": "(.*?)",', messages_last)
if query_match:
last_message = query_match[0]
logger.debug(
f"[create_input_message] Extracted query from verified_data: {last_message}"
)
except Exception as e:
logger.warning(
f"[create_input_message] Failed to extract query from verified_data: {e}"
)
# Construct tool call message
tool_call_id = f"{session_id}_{uuid_str}"
logger.info(
f"[create_input_message] Creating tool call for '{tool_name}' "
f"with ID: {tool_call_id}"
)
return {
"messages": [
AIMessage(
content="",
tool_calls=[{
"name": tool_name,
"args": {
"sentence": last_message,
"sessionid": session_id,
"messages_id": str(uuid_str),
"search_switch": search_switch,
"apply_id": apply_id,
"group_id": group_id
},
"id": tool_call_id
}]
)
]
}

View File

@@ -0,0 +1,199 @@
"""
Tool execution node for LangGraph workflow.
This module provides the ToolExecutionNode class which wraps tool execution
with parameter transformation logic using the ParameterBuilder service.
"""
import logging
import time
from typing import Any, Callable, Dict
from langchain_core.messages import AIMessage
from langgraph.prebuilt import ToolNode
from app.core.memory.agent.langgraph_graph.state.extractors import (
extract_tool_call_id,
extract_content_payload
)
from app.core.memory.agent.mcp_server.services.parameter_builder import ParameterBuilder
logger = logging.getLogger(__name__)
class ToolExecutionNode:
"""
Custom LangGraph node that wraps tool execution with parameter transformation.
This node extracts content from previous tool results, transforms parameters
based on tool type using ParameterBuilder, and invokes the tool with the
correct argument structure.
Attributes:
tool_node: LangGraph ToolNode wrapping the actual tool
id: Node identifier for message IDs
tool_name: Name of the tool being executed
namespace: Namespace for session management
search_switch: Search routing parameter
apply_id: Application identifier
group_id: Group identifier
parameter_builder: Service for building tool-specific arguments
"""
def __init__(
self,
tool: Callable,
node_id: str,
namespace: str,
search_switch: str,
apply_id: str,
group_id: str,
parameter_builder: ParameterBuilder,
storage_type:str,
user_rag_memory_id:str
):
"""
Initialize the tool execution node.
Args:
tool: The tool function to execute
node_id: Identifier for this node (used in message IDs)
namespace: Namespace for session management
search_switch: Search routing parameter
apply_id: Application identifier
group_id: Group identifier
parameter_builder: Service for building tool-specific arguments
"""
self.tool_node = ToolNode([tool])
self.id = node_id
self.tool_name = tool.name if hasattr(tool, 'name') else str(tool)
self.namespace = namespace
self.search_switch = search_switch
self.apply_id = apply_id
self.group_id = group_id
self.parameter_builder = parameter_builder
self.storage_type=storage_type
self.user_rag_memory_id=user_rag_memory_id
logger.info(
f"[ToolExecutionNode] Initialized node '{self.id}' for tool '{self.tool_name}'"
)
async def __call__(self, state: Dict[str, Any]) -> Dict[str, Any]:
"""
Execute the tool with transformed parameters.
This method:
1. Extracts the last message from state
2. Extracts tool call ID using state extractors
3. Extracts content payload using state extractors
4. Builds tool arguments using parameter builder
5. Constructs AIMessage with tool_calls
6. Invokes the tool and returns the result
Args:
state: LangGraph state dictionary
Returns:
Updated state with tool result in messages
"""
messages = state.get("messages", [])
logger.debug( self.tool_name)
if not messages:
logger.warning(f"[ToolExecutionNode] {self.id} - No messages in state")
return {"messages": [AIMessage(content="Error: No messages in state")]}
last_message = messages[-1]
logger.debug(
f"[ToolExecutionNode] {self.id} - Processing message at {time.time()}"
)
try:
# Extract tool call ID using state extractors
tool_call_id = extract_tool_call_id(last_message)
logger.debug(f"[ToolExecutionNode] {self.id} - Extracted tool_call_id: {tool_call_id}")
except ValueError as e:
logger.error(
f"[ToolExecutionNode] {self.id} - Failed to extract tool call ID: {e}"
)
return {"messages": [AIMessage(content=f"Error: {str(e)}")]}
try:
# Extract content payload using state extractors
content = extract_content_payload(last_message)
logger.debug(
f"[ToolExecutionNode] {self.id} - Extracted content type: {type(content)}"
)
except Exception as e:
logger.error(
f"[ToolExecutionNode] {self.id} - Failed to extract content: {e}",
exc_info=True
)
content = {}
try:
# Build tool arguments using parameter builder
tool_args = self.parameter_builder.build_tool_args(
tool_name=self.tool_name,
content=content,
tool_call_id=tool_call_id,
search_switch=self.search_switch,
apply_id=self.apply_id,
group_id=self.group_id,
storage_type=self.storage_type,
user_rag_memory_id=self.user_rag_memory_id
)
logger.debug(
f"[ToolExecutionNode] {self.id} - Built tool args with keys: {list(tool_args.keys())}"
)
except Exception as e:
logger.error(
f"[ToolExecutionNode] {self.id} - Failed to build tool args: {e}",
exc_info=True
)
return {"messages": [AIMessage(content=f"Error building arguments: {str(e)}")]}
# Construct tool input message
tool_input = {
"messages": [
AIMessage(
content="",
tool_calls=[{
"name": self.tool_name,
"args": tool_args,
"id": f"{self.id}_{tool_call_id}",
}]
)
]
}
try:
# Invoke the tool
result = await self.tool_node.ainvoke(tool_input)
logger.debug(
f"[ToolExecutionNode] {self.id} - Tool execution completed"
)
# Return the result directly - it already contains the messages list
return result
except Exception as e:
logger.error(
f"[ToolExecutionNode] {self.id} - Tool execution failed: {e}",
exc_info=True
)
# Return error as ToolMessage to maintain message chain consistency
from langchain_core.messages import ToolMessage
return {
"messages": [
ToolMessage(
content=f"Error executing tool: {str(e)}",
tool_call_id=f"{self.id}_{tool_call_id}"
)
]
}

View File

@@ -0,0 +1,508 @@
import asyncio
import io
import json
import logging
import os
import re
import time
import uuid
import warnings
from contextlib import asynccontextmanager
from datetime import datetime
from typing import Literal
from dotenv import load_dotenv
from langchain_core.messages import AIMessage
from langgraph.constants import START, END
from langgraph.graph import StateGraph
from langgraph.prebuilt import ToolNode
from functools import partial
from app.core.memory.agent.utils.llm_tools import ReadState, COUNTState
from langgraph.checkpoint.memory import InMemorySaver
from app.core.memory.agent.utils.redis_tool import store
from app.core.logging_config import get_agent_logger
# Import new modular components
from app.core.memory.agent.langgraph_graph.nodes import ToolExecutionNode, create_input_message
from app.core.memory.agent.langgraph_graph.routing.routers import (
Verify_continue,
Retrieve_continue,
Split_continue
)
from app.core.memory.agent.mcp_server.services.parameter_builder import ParameterBuilder
from app.core.memory.agent.utils.multimodal import MultimodalProcessor
logger = get_agent_logger(__name__)
warnings.filterwarnings("ignore", category=RuntimeWarning)
load_dotenv()
redishost=os.getenv("REDISHOST")
redisport=os.getenv('REDISPORT')
redisdb=os.getenv('REDISDB')
redispassword=os.getenv('REDISPASSWORD')
counter = COUNTState(limit=3)
# 在工作流中添加循环计数更新
async def update_loop_count(state):
"""更新循环计数器"""
current_count = state.get("loop_count", 0)
return {"loop_count": current_count + 1}
def Verify_continue(state: ReadState) -> Literal["Summary", "Summary_fails", "content_input"]:
messages = state["messages"]
# 添加边界检查
if not messages:
return END
counter.add(1) # 累加 1
loop_count = counter.get_total()
logger.debug(f"[should_continue] 当前循环次数: {loop_count}")
last_message = messages[-1]
last_message_str = str(last_message).replace('\\', '')
status_tools = re.findall(r'"split_result": "(.*?)"', last_message_str)
logger.debug(f"Status tools: {status_tools}")
if "success" in status_tools:
counter.reset()
return "Summary"
elif "failed" in status_tools:
if loop_count < 2: # 最大循环次数 3
return "content_input"
else:
counter.reset()
return "Summary_fails"
else:
# 添加默认返回值,避免返回 None
counter.reset()
return "Summary" # 或根据业务需求选择合适的默认值
def Retrieve_continue(state) -> Literal["Verify", "Retrieve_Summary"]:
"""
Determine routing based on search_switch value.
Args:
state: State dictionary containing search_switch
Returns:
Next node to execute
"""
# Direct dictionary access instead of regex parsing
search_switch = state.get("search_switch")
# Handle case where search_switch might be in messages
if search_switch is None and "messages" in state:
messages = state.get("messages", [])
if messages:
last_message = messages[-1]
# Try to extract from tool_calls args
if hasattr(last_message, "tool_calls") and last_message.tool_calls:
for tool_call in last_message.tool_calls:
if isinstance(tool_call, dict) and "args" in tool_call:
search_switch = tool_call["args"].get("search_switch")
break
# Convert to string for comparison if needed
if search_switch is not None:
search_switch = str(search_switch)
if search_switch == '0':
return 'Verify'
elif search_switch == '1':
return 'Retrieve_Summary'
# 添加默认返回值,避免返回 None
return 'Retrieve_Summary' # 或根据业务逻辑选择合适的默认值
def Split_continue(state) -> Literal["Split_The_Problem", "Input_Summary"]:
"""
Determine routing based on search_switch value.
Args:
state: State dictionary containing search_switch
Returns:
Next node to execute
"""
logger.debug(f"Split_continue state: {state}")
# Direct dictionary access instead of regex parsing
search_switch = state.get("search_switch")
# Handle case where search_switch might be in messages
if search_switch is None and "messages" in state:
messages = state.get("messages", [])
if messages:
last_message = messages[-1]
# Try to extract from tool_calls args
if hasattr(last_message, "tool_calls") and last_message.tool_calls:
for tool_call in last_message.tool_calls:
if isinstance(tool_call, dict) and "args" in tool_call:
search_switch = tool_call["args"].get("search_switch")
break
# Convert to string for comparison if needed
if search_switch is not None:
search_switch = str(search_switch)
if search_switch == '2':
return 'Input_Summary'
return 'Split_The_Problem' # 默认情况
# 在 input_sentence 函数中修改参数名称
async def input_sentence(state, name, id, search_switch,apply_id,group_id):
messages = state["messages"]
last_message = messages[-1].content if messages else ""
if last_message.endswith('.jpg') or last_message.endswith('.png'):
last_message=await picture_model_requests(last_message)
if any(last_message.endswith(ext) for ext in audio_extensions):
last_message=await Vico_recognition([last_message]).run()
logger.debug(f"Audio recognition result: {last_message}")
uuid_str = uuid.uuid4()
time_str = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
namespace = str(id).split('_id_')[1]
if 'verified_data' in str(last_message):
messages_last = str(last_message).replace('\\n', '').replace('\\', '')
last_message = re.findall(r'"query": "(.*?)",', str(messages_last))[0]
return {
"messages": [
AIMessage(
content="",
tool_calls=[{
"name": name,
"args": {
"sentence": last_message,
'sessionid': id,
'messages_id': str(uuid_str),
"search_switch": search_switch, # 正确地将 search_switch 放入 args 中
"apply_id":apply_id,
"group_id":group_id
},
"id": id + f'_{uuid_str}'
}]
)
]
}
class ProblemExtensionNode:
def __init__(self, tool, id, namespace, search_switch, apply_id, group_id, storage_type="", user_rag_memory_id=""):
self.tool_node = ToolNode([tool])
self.id = id
self.tool_name = tool.name if hasattr(tool, 'name') else str(tool)
self.namespace = namespace
self.search_switch = search_switch
self.apply_id = apply_id
self.group_id = group_id
self.storage_type = storage_type
self.user_rag_memory_id = user_rag_memory_id
async def __call__(self, state):
messages = state["messages"]
last_message = messages[-1] if messages else ""
logger.debug(f"ProblemExtensionNode {self.id} - 当前时间: {time.time()} - Message: {last_message}")
if self.tool_name=='Input_Summary':
tool_call =re.findall(f"'id': '(.*?)'",str(last_message))[0]
else:tool_call = str(re.findall(r"tool_call_id=.*?'(.*?)'", str(last_message))[0]).replace('\\', '').split('_id')[1]
# try:
# content = json.loads(last_message.content) if hasattr(last_message, 'content') else last_message
# except:
# content = last_message.content if hasattr(last_message, 'content') else str(last_message)
# 尝试从上一工具的结果中提取实际的内容载荷(而不是整个对象的字符串表示)
raw_msg = last_message.content if hasattr(last_message, 'content') else str(last_message)
extracted_payload = None
# 捕获 ToolMessage 的 content 字段(支持单/双引号),并避免贪婪匹配
m = re.search(r"content=(?:\"|\')(.*?)(?:\"|\'),\s*name=", raw_msg, flags=re.S)
if m:
extracted_payload = m.group(1)
else:
# 回退:直接尝试使用原始字符串
extracted_payload = raw_msg
# 优先尝试将内容解析为 JSON
try:
content = json.loads(extracted_payload)
except Exception:
# 尝试从文本中提取 JSON 片段再解析
parsed = None
candidates = re.findall(r"[\[{].*[\]}]", extracted_payload, flags=re.S)
for cand in candidates:
try:
parsed = json.loads(cand)
break
except Exception:
continue
# 如果仍然失败,则以原始字符串作为内容
content = parsed if parsed is not None else extracted_payload
# 根据工具名称构建正确的参数
tool_args = {}
if self.tool_name == "Verify":
# Verify工具需要context和usermessages参数
if isinstance(content, dict):
tool_args["context"] = content
else:
tool_args["context"] = {"content": content}
tool_args["usermessages"] = str(tool_call)
tool_args["apply_id"] = str(self.apply_id)
tool_args["group_id"] = str(self.group_id)
elif self.tool_name == "Retrieve":
# Retrieve工具需要context和usermessages参数
if isinstance(content, dict):
tool_args["context"] = content
else:
tool_args["context"] = {"content": content}
tool_args["usermessages"] = str(tool_call)
tool_args["search_switch"] = str(self.search_switch)
tool_args["apply_id"] = str(self.apply_id)
tool_args["group_id"] = str(self.group_id)
elif self.tool_name == "Summary":
# Summary工具需要字符串类型的context参数
if isinstance(content, dict):
# 将字典转换为JSON字符串
tool_args["context"] = json.dumps(content, ensure_ascii=False)
else:
tool_args["context"] = str(content)
tool_args["usermessages"] = str(tool_call)
tool_args["apply_id"] = str(self.apply_id)
tool_args["group_id"] = str(self.group_id)
elif self.tool_name == "Summary_fails":
# Summary工具需要字符串类型的context参数
if isinstance(content, dict):
# 将字典转换为JSON字符串
tool_args["context"] = json.dumps(content, ensure_ascii=False)
else:
tool_args["context"] = str(content)
tool_args["usermessages"] = str(tool_call)
tool_args["apply_id"] = str(self.apply_id)
tool_args["group_id"] = str(self.group_id)
elif self.tool_name=='Input_Summary':
tool_args["context"] =str(last_message)
tool_args["usermessages"] = str(tool_call)
tool_args["search_switch"] = str(self.search_switch)
tool_args["apply_id"] = str(self.apply_id)
tool_args["group_id"] = str(self.group_id)
tool_args["storage_type"] = getattr(self, 'storage_type', "")
tool_args["user_rag_memory_id"] = getattr(self, 'user_rag_memory_id', "")
elif self.tool_name=='Retrieve_Summary' :
# Retrieve_Summary expects dict directly, not JSON string
# content might be a JSON string, try to parse it
if isinstance(content, str):
try:
parsed_content = json.loads(content)
# Check if it has a "context" key
if isinstance(parsed_content, dict) and "context" in parsed_content:
tool_args["context"] = parsed_content["context"]
else:
tool_args["context"] = parsed_content
except json.JSONDecodeError:
# If parsing fails, wrap the string
tool_args["context"] = {"content": content}
elif isinstance(content, dict):
# Check if content has a "context" key that needs unwrapping
if "context" in content:
tool_args["context"] = content["context"]
else:
tool_args["context"] = content
else:
tool_args["context"] = {"content": str(content)}
tool_args["usermessages"] = str(tool_call)
tool_args["apply_id"] = str(self.apply_id)
tool_args["group_id"] = str(self.group_id)
else:
# 其他工具使用context参数
if isinstance(content, dict):
tool_args["context"] = content
else:
tool_args["context"] = {"content": content}
tool_args["usermessages"] = str(tool_call)
tool_args["apply_id"] = str(self.apply_id)
tool_args["group_id"] = str(self.group_id)
tool_input = {
"messages": [
AIMessage(
content="",
tool_calls=[{
"name": self.tool_name,
"args": tool_args,
"id": self.id + f"{tool_call}",
}]
)
]
}
result = await self.tool_node.ainvoke(tool_input)
result_text = str(result)
return {"messages": [AIMessage(content=result_text)]}
@asynccontextmanager
async def make_read_graph(namespace,tools,search_switch,apply_id,group_id,config_id=None,storage_type=None,user_rag_memory_id=None):
memory = InMemorySaver()
tool=[i.name for i in tools ]
logger.info(f"Initializing read graph with tools: {tool}")
if config_id:
logger.info(f"使用配置 ID: {config_id}")
# Extract tool functions
Split_The_Problem_ = next((t for t in tools if t.name == "Split_The_Problem"), None)
Problem_Extension_ = next((t for t in tools if t.name == "Problem_Extension"), None)
Retrieve_ = next((t for t in tools if t.name == "Retrieve"), None)
Verify_ = next((t for t in tools if t.name == "Verify"), None)
Summary_ = next((t for t in tools if t.name == "Summary"), None)
Summary_fails_ = next((t for t in tools if t.name == "Summary_fails"), None)
Retrieve_Summary_ = next((t for t in tools if t.name == "Retrieve_Summary"), None)
Input_Summary_ = next((t for t in tools if t.name == "Input_Summary"), None)
# Instantiate services
parameter_builder = ParameterBuilder()
multimodal_processor = MultimodalProcessor()
# Create nodes using new modular components
Split_The_Problem_node = ToolNode([Split_The_Problem_])
Problem_Extension_node = ToolExecutionNode(
tool=Problem_Extension_,
node_id="Problem_Extension_id",
namespace=namespace,
search_switch=search_switch,
apply_id=apply_id,
group_id=group_id,
parameter_builder=parameter_builder,
storage_type=storage_type,
user_rag_memory_id=user_rag_memory_id
)
Retrieve_node = ToolExecutionNode(
tool=Retrieve_,
node_id="Retrieve_id",
namespace=namespace,
search_switch=search_switch,
apply_id=apply_id,
group_id=group_id,
parameter_builder=parameter_builder,
storage_type=storage_type,
user_rag_memory_id=user_rag_memory_id
)
Verify_node = ToolExecutionNode(
tool=Verify_,
node_id="Verify_id",
namespace=namespace,
search_switch=search_switch,
apply_id=apply_id,
group_id=group_id,
parameter_builder=parameter_builder,
storage_type=storage_type,
user_rag_memory_id=user_rag_memory_id
)
Summary_node = ToolExecutionNode(
tool=Summary_,
node_id="Summary_id",
namespace=namespace,
search_switch=search_switch,
apply_id=apply_id,
group_id=group_id,
parameter_builder=parameter_builder,
storage_type=storage_type,
user_rag_memory_id=user_rag_memory_id
)
Summary_fails_node = ToolExecutionNode(
tool=Summary_fails_,
node_id="Summary_fails_id",
namespace=namespace,
search_switch=search_switch,
apply_id=apply_id,
group_id=group_id,
parameter_builder=parameter_builder,
storage_type=storage_type,
user_rag_memory_id=user_rag_memory_id
)
Retrieve_Summary_node = ToolExecutionNode(
tool=Retrieve_Summary_,
node_id="Retrieve_Summary_id",
namespace=namespace,
search_switch=search_switch,
apply_id=apply_id,
group_id=group_id,
parameter_builder=parameter_builder,
storage_type=storage_type,
user_rag_memory_id=user_rag_memory_id
)
Input_Summary_node = ToolExecutionNode(
tool=Input_Summary_,
node_id="Input_Summary_id",
namespace=namespace,
search_switch=search_switch,
apply_id=apply_id,
group_id=group_id,
parameter_builder=parameter_builder,
storage_type=storage_type,
user_rag_memory_id=user_rag_memory_id
)
async def content_input_node(state):
state_search_switch = state.get("search_switch", search_switch)
tool_name = "Input_Summary" if state_search_switch == '2' else "Split_The_Problem"
session_prefix = "input_summary_call_id" if state_search_switch == '2' else "split_call_id"
return await create_input_message(
state=state,
tool_name=tool_name,
session_id=f"{session_prefix}_{namespace}",
search_switch=search_switch,
apply_id=apply_id,
group_id=group_id,
multimodal_processor=multimodal_processor
)
# Build workflow graph
workflow = StateGraph(ReadState)
workflow.add_node("content_input", content_input_node)
workflow.add_node("Split_The_Problem", Split_The_Problem_node)
workflow.add_node("Problem_Extension", Problem_Extension_node)
workflow.add_node("Retrieve", Retrieve_node)
workflow.add_node("Verify", Verify_node)
workflow.add_node("Summary", Summary_node)
workflow.add_node("Summary_fails", Summary_fails_node)
workflow.add_node("Retrieve_Summary", Retrieve_Summary_node)
workflow.add_node("Input_Summary", Input_Summary_node)
# Add edges using imported routers
workflow.add_edge(START, "content_input")
workflow.add_conditional_edges("content_input", Split_continue)
workflow.add_edge("Input_Summary", END)
workflow.add_edge("Split_The_Problem", "Problem_Extension")
workflow.add_edge("Problem_Extension", "Retrieve")
workflow.add_conditional_edges("Retrieve", Retrieve_continue)
workflow.add_edge("Retrieve_Summary", END)
workflow.add_conditional_edges("Verify", Verify_continue)
workflow.add_edge("Summary_fails", END)
workflow.add_edge("Summary", END)
graph = workflow.compile(checkpointer=memory)
yield graph
# 添加到文件末尾或创建新的执行脚本
# 在 memory_agent_service.py 文件中添加以下函数

View File

@@ -0,0 +1,13 @@
"""LangGraph routing logic."""
from app.core.memory.agent.langgraph_graph.routing.routers import (
Verify_continue,
Retrieve_continue,
Split_continue,
)
__all__ = [
"Verify_continue",
"Retrieve_continue",
"Split_continue",
]

View File

@@ -0,0 +1,123 @@
"""
Routing functions for LangGraph conditional edges.
This module provides routing functions that determine the next node to execute
based on state values. All functions return Literal types for type safety.
"""
import logging
import re
from typing import Literal
from app.core.memory.agent.langgraph_graph.state.extractors import extract_search_switch
from app.core.memory.agent.utils.llm_tools import ReadState, COUNTState
logger = logging.getLogger(__name__)
# Global counter for Verify routing
counter = COUNTState(limit=3)
def Verify_continue(state: ReadState) -> Literal["Summary", "Summary_fails", "content_input"]:
"""
Determine routing after Verify node based on verification result.
This function checks the verification result in the last message and routes to:
- Summary: if verification succeeded
- content_input: if verification failed and retry limit not reached
- Summary_fails: if verification failed and retry limit reached
Args:
state: LangGraph state containing messages
Returns:
Next node name as Literal type
"""
messages = state.get("messages", [])
# Boundary check
if not messages:
logger.warning("[Verify_continue] No messages in state, defaulting to Summary")
counter.reset()
return "Summary"
# Increment counter
counter.add(1)
loop_count = counter.get_total()
logger.debug(f"[Verify_continue] Current loop count: {loop_count}")
# Extract verification result from last message
last_message = messages[-1]
last_message_str = str(last_message).replace('\\', '')
status_tools = re.findall(r'"split_result": "(.*?)"', last_message_str)
logger.debug(f"[Verify_continue] Status tools: {status_tools}")
# Route based on verification result
if "success" in status_tools:
counter.reset()
return "Summary"
elif "failed" in status_tools:
if loop_count < 2: # Max retry count is 2
return "content_input"
else:
counter.reset()
return "Summary_fails"
else:
# Default to Summary if status is unclear
counter.reset()
return "Summary"
def Retrieve_continue(state: dict) -> Literal["Verify", "Retrieve_Summary"]:
"""
Determine routing after Retrieve node based on search_switch value.
This function routes based on the search_switch parameter:
- search_switch == '0': Route to Verify (verification needed)
- search_switch == '1': Route to Retrieve_Summary (direct summary)
Args:
state: LangGraph state dictionary
Returns:
Next node name as Literal type
"""
search_switch = extract_search_switch(state)
logger.debug(f"[Retrieve_continue] search_switch: {search_switch}")
if search_switch == '0':
return 'Verify'
elif search_switch == '1':
return 'Retrieve_Summary'
# Default to Retrieve_Summary
logger.debug("[Retrieve_continue] No valid search_switch, defaulting to Retrieve_Summary")
return 'Retrieve_Summary'
def Split_continue(state: dict) -> Literal["Split_The_Problem", "Input_Summary"]:
"""
Determine routing after content_input node based on search_switch value.
This function routes based on the search_switch parameter:
- search_switch == '2': Route to Input_Summary (direct input summary)
- Otherwise: Route to Split_The_Problem (problem decomposition)
Args:
state: LangGraph state dictionary
Returns:
Next node name as Literal type
"""
logger.debug(f"[Split_continue] state keys: {state.keys()}")
search_switch = extract_search_switch(state)
logger.debug(f"[Split_continue] search_switch: {search_switch}")
if search_switch == '2':
return 'Input_Summary'
# Default to Split_The_Problem
return 'Split_The_Problem'

View File

@@ -0,0 +1,13 @@
"""LangGraph state management utilities."""
from app.core.memory.agent.langgraph_graph.state.extractors import (
extract_search_switch,
extract_tool_call_id,
extract_content_payload,
)
__all__ = [
"extract_search_switch",
"extract_tool_call_id",
"extract_content_payload",
]

View File

@@ -0,0 +1,164 @@
"""
State extraction utilities for type-safe access to LangGraph state values.
This module provides utility functions for extracting values from LangGraph state
dictionaries with proper error handling and sensible defaults.
"""
import json
import logging
from typing import Any, Optional
logger = logging.getLogger(__name__)
def extract_search_switch(state: dict) -> Optional[str]:
"""
Extract search_switch from state or messages.
"""
search_switch = state.get("search_switch")
if search_switch is not None:
return str(search_switch)
# Try to extract from messages
messages = state.get("messages", [])
if not messages:
return None
# 从最新的消息开始查找
for message in reversed(messages):
# 尝试从 tool_calls 中提取
if hasattr(message, "tool_calls") and message.tool_calls:
for tool_call in message.tool_calls:
if isinstance(tool_call, dict):
# 从 tool_call 的 args 中提取
if "args" in tool_call and isinstance(tool_call["args"], dict):
search_switch = tool_call["args"].get("search_switch")
if search_switch is not None:
return str(search_switch)
# 直接从 tool_call 中提取
search_switch = tool_call.get("search_switch")
if search_switch is not None:
return str(search_switch)
# 尝试从 content 中提取(如果是 JSON 格式)
if hasattr(message, "content"):
try:
import json
if isinstance(message.content, str):
content_data = json.loads(message.content)
if isinstance(content_data, dict):
search_switch = content_data.get("search_switch")
if search_switch is not None:
return str(search_switch)
except (json.JSONDecodeError, ValueError):
pass
return None
def extract_tool_call_id(message: Any) -> str:
"""
Extract tool call ID from message using structured attributes.
This function extracts the tool call ID from a message object, handling both
direct attribute access and tool_calls list structures.
Args:
message: Message object (typically ToolMessage or AIMessage)
Returns:
Tool call ID as string
Raises:
ValueError: If tool call ID cannot be extracted
Examples:
>>> message = ToolMessage(content="...", tool_call_id="call_123")
>>> extract_tool_call_id(message)
'call_123'
"""
# Try direct attribute access for ToolMessage
if hasattr(message, "tool_call_id"):
tool_call_id = message.tool_call_id
if tool_call_id:
return str(tool_call_id)
# Try extracting from tool_calls list for AIMessage
if hasattr(message, "tool_calls") and message.tool_calls:
tool_call = message.tool_calls[0]
if isinstance(tool_call, dict) and "id" in tool_call:
return str(tool_call["id"])
# Try extracting from id attribute
if hasattr(message, "id"):
message_id = message.id
if message_id:
return str(message_id)
# If all else fails, raise an error
raise ValueError(f"Could not extract tool call ID from message: {type(message)}")
def extract_content_payload(message: Any) -> Any:
"""
Extract content payload from ToolMessage, parsing JSON if needed.
This function extracts the content from a message and attempts to parse it as JSON
if it appears to be a JSON string. It handles various message formats and provides
sensible fallbacks.
Args:
message: Message object (typically ToolMessage)
Returns:
Parsed content (dict, list, or str)
Examples:
>>> message = ToolMessage(content='{"key": "value"}')
>>> extract_content_payload(message)
{'key': 'value'}
>>> message = ToolMessage(content='plain text')
>>> extract_content_payload(message)
'plain text'
"""
# Extract raw content
# For ToolMessages (responses from tools), extract from content
if hasattr(message, "content"):
raw_content = message.content
# If content is empty and this is an AIMessage with tool_calls,
# extract from args (this handles the initial tool call from content_input)
if not raw_content and hasattr(message, "tool_calls") and message.tool_calls:
tool_call = message.tool_calls[0]
if isinstance(tool_call, dict) and "args" in tool_call:
return tool_call["args"]
else:
raw_content = str(message)
# If content is already a dict or list, return it directly
if isinstance(raw_content, (dict, list)):
return raw_content
# Try to parse as JSON
if isinstance(raw_content, str):
# First, try direct JSON parsing
try:
return json.loads(raw_content)
except (json.JSONDecodeError, ValueError):
pass
# If that fails, try to extract JSON from the string
# This handles cases where the content is embedded in a larger string
import re
json_candidates = re.findall(r'[\[{].*[\]}]', raw_content, flags=re.DOTALL)
for candidate in json_candidates:
try:
return json.loads(candidate)
except (json.JSONDecodeError, ValueError):
continue
# If all parsing attempts fail, return the raw content
return raw_content

View File

@@ -0,0 +1,78 @@
import asyncio
import json
from contextlib import asynccontextmanager
from langgraph.constants import START, END
from langgraph.graph import add_messages, StateGraph
from langgraph.prebuilt import ToolNode
from app.core.memory.agent.utils.llm_tools import WriteState
import warnings
import sys
from langchain_core.messages import AIMessage
from app.core.logging_config import get_agent_logger
warnings.filterwarnings("ignore", category=RuntimeWarning)
logger = get_agent_logger(__name__)
if sys.platform.startswith("win"):
import asyncio
asyncio.set_event_loop_policy(asyncio.WindowsSelectorEventLoopPolicy())
@asynccontextmanager
async def make_write_graph(user_id, tools, apply_id, group_id, config_id=None):
logger.info("加载 MCP 工具: %s", [t.name for t in tools])
if config_id:
logger.info(f"使用配置 ID: {config_id}")
data_type_tool = next((t for t in tools if t.name == "Data_type_differentiation"), None)
data_write_tool = next((t for t in tools if t.name == "Data_write"), None)
if not data_type_tool or not data_write_tool:
logger.error('不存在数据存储工具', exc_info=True)
raise ValueError('不存在数据存储工具')
# ToolNode
write_node = ToolNode([data_write_tool])
async def call_model(state):
messages = state["messages"]
last_message = messages[-1]
result = await data_type_tool.ainvoke({
"context": last_message[1] if isinstance(last_message, tuple) else last_message.content
})
result=json.loads( result)
# 调用 Data_write传递 config_id
write_params = {
"content": result["context"],
"apply_id": apply_id,
"group_id": group_id,
"user_id": user_id
}
# 如果提供了 config_id添加到参数中
if config_id:
write_params["config_id"] = config_id
logger.debug(f"传递 config_id 到 Data_write: {config_id}")
write_result = await data_write_tool.ainvoke(write_params)
if isinstance(write_result, dict):
content = write_result.get("data", str(write_result))
else:
content = str(write_result)
logger.info("写入内容: %s", content)
return {"messages": [AIMessage(content=content)]}
workflow = StateGraph(WriteState)
workflow.add_node("content_input", call_model)
workflow.add_node("save_neo4j", write_node)
workflow.add_edge(START, "content_input")
workflow.add_edge("content_input", "save_neo4j")
workflow.add_edge("save_neo4j", END)
graph = workflow.compile()
yield graph

View File

@@ -0,0 +1,285 @@
"""
Log Streamer Module
Manages streaming of log file content with file watching and real-time transmission.
"""
import os
import re
import time
import asyncio
from typing import AsyncGenerator, Optional
from pathlib import Path
from app.core.logging_config import get_logger
logger = get_logger(__name__)
class LogStreamer:
"""Manages log file streaming with file watching and content transmission"""
def __init__(self, log_path: str, keepalive_interval: int = 300):
"""
Initialize LogStreamer
Args:
log_path: Path to the log file to stream
keepalive_interval: Interval in seconds for sending keepalive messages (default: 300)
"""
self.log_path = log_path
self.keepalive_interval = keepalive_interval
self.last_position = 0
# Pattern to match and remove timestamp and log level prefix
# Matches: "YYYY-MM-DD HH:MM:SS,mmm - [LEVEL] - module_name - "
# This pattern is comprehensive to handle various log formats
self.pattern = re.compile(
r'^\d{4}-\d{2}-\d{2} \d{2}:\d{2}:\d{2},\d{3} - \[(?:INFO|DEBUG|WARNING|ERROR|CRITICAL)\] - \S+ - '
)
logger.info(f"LogStreamer initialized for {log_path}")
@staticmethod
def clean_log_line(line: str) -> str:
"""
Static method to clean log entry by removing timestamp and log level prefix.
This is the canonical log cleaning method used by both file mode and transmission mode.
Args:
line: Raw log line
Returns:
Cleaned log line without timestamp and log level prefix
"""
# Pattern to match and remove timestamp and log level prefix
# Matches: "YYYY-MM-DD HH:MM:SS,mmm - [LEVEL] - module_name - "
pattern = re.compile(
r'^\d{4}-\d{2}-\d{2} \d{2}:\d{2}:\d{2},\d{3} - \[(?:INFO|DEBUG|WARNING|ERROR|CRITICAL)\] - \S+ - '
)
cleaned = re.sub(pattern, '', line)
return cleaned
def clean_log_entry(self, line: str) -> str:
"""
Clean log entry by removing timestamp and log level prefix.
This instance method delegates to the static method for consistency.
Args:
line: Raw log line
Returns:
Cleaned log line without timestamp and log level prefix
"""
return LogStreamer.clean_log_line(line)
async def send_keepalive(self) -> dict:
"""
Generate keepalive message
Returns:
Keepalive message dict with timestamp
"""
return {
"event": "keepalive",
"data": {
"timestamp": int(time.time())
}
}
async def read_existing_and_stream(self) -> AsyncGenerator[dict, None]:
"""
Read existing log content first, then watch for new content
This method reads all existing content in the file first,
then continues to watch for new content as it's written.
Yields:
Dict messages with event type and data:
- log events: {"event": "log", "data": {"content": "...", "timestamp": ...}}
- keepalive events: {"event": "keepalive", "data": {"timestamp": ...}}
- error events: {"event": "error", "data": {"code": ..., "message": "...", "error": "..."}}
- done events: {"event": "done", "data": {"message": "..."}}
"""
logger.info(f"Starting log stream (read existing) for {self.log_path}")
# Check if file exists
if not os.path.exists(self.log_path):
logger.error(f"Log file not found: {self.log_path}")
yield {
"event": "error",
"data": {
"code": 4006,
"message": "日志文件不存在",
"error": f"File not found: {self.log_path}"
}
}
return
try:
with open(self.log_path, 'r', encoding='utf-8') as f:
# First, read all existing content
for line in f:
if line.strip(): # Skip empty lines
cleaned_line = self.clean_log_entry(line)
yield {
"event": "log",
"data": {
"content": cleaned_line.rstrip('\n'),
"timestamp": int(time.time())
}
}
# Now watch for new content
self.last_position = f.tell()
last_keepalive = time.time()
while True:
line = f.readline()
if line:
cleaned_line = self.clean_log_entry(line)
yield {
"event": "log",
"data": {
"content": cleaned_line.rstrip('\n'),
"timestamp": int(time.time())
}
}
last_keepalive = time.time()
else:
# No new content, check if we need to send keepalive
current_time = time.time()
if current_time - last_keepalive >= self.keepalive_interval:
keepalive_msg = await self.send_keepalive()
yield keepalive_msg
last_keepalive = current_time
# Sleep briefly before checking again
await asyncio.sleep(0.1)
except FileNotFoundError:
logger.error(f"Log file disappeared during streaming: {self.log_path}")
yield {
"event": "error",
"data": {
"code": 4006,
"message": "日志文件在流式传输期间变得不可用",
"error": "File not found during streaming"
}
}
except Exception as e:
logger.error(f"Error during log streaming: {e}", exc_info=True)
yield {
"event": "error",
"data": {
"code": 8001,
"message": "流式传输期间发生错误",
"error": str(e)
}
}
finally:
logger.info(f"Log stream ended for {self.log_path}")
yield {
"event": "done",
"data": {
"message": "流式传输完成"
}
}
async def watch_and_stream(self) -> AsyncGenerator[dict, None]:
"""
Watch log file and stream only new content as it's written
This method starts from the end of the file and only streams
new content that is written after the stream starts.
Yields:
Dict messages with event type and data:
- log events: {"event": "log", "data": {"content": "...", "timestamp": ...}}
- keepalive events: {"event": "keepalive", "data": {"timestamp": ...}}
- error events: {"event": "error", "data": {"code": ..., "message": "...", "error": "..."}}
- done events: {"event": "done", "data": {"message": "..."}}
"""
logger.info(f"Starting log stream (new content only) for {self.log_path}")
# Check if file exists
if not os.path.exists(self.log_path):
logger.error(f"Log file not found: {self.log_path}")
yield {
"event": "error",
"data": {
"code": 4006,
"message": "日志文件不存在",
"error": f"File not found: {self.log_path}"
}
}
return
try:
# Open file and seek to end to start streaming new content
with open(self.log_path, 'r', encoding='utf-8') as f:
# Move to end of file
f.seek(0, os.SEEK_END)
self.last_position = f.tell()
last_keepalive = time.time()
while True:
# Check if file has new content
current_position = f.tell()
# Read new lines if available
line = f.readline()
if line:
# Clean the log entry
cleaned_line = self.clean_log_entry(line)
# Yield log event
yield {
"event": "log",
"data": {
"content": cleaned_line.rstrip('\n'),
"timestamp": int(time.time())
}
}
# Update last keepalive time since we sent data
last_keepalive = time.time()
else:
# No new content, check if we need to send keepalive
current_time = time.time()
if current_time - last_keepalive >= self.keepalive_interval:
keepalive_msg = await self.send_keepalive()
yield keepalive_msg
last_keepalive = current_time
# Sleep briefly before checking again
await asyncio.sleep(0.1)
except FileNotFoundError:
logger.error(f"Log file disappeared during streaming: {self.log_path}")
yield {
"event": "error",
"data": {
"code": 4006,
"message": "日志文件在流式传输期间变得不可用",
"error": "File not found during streaming"
}
}
except Exception as e:
logger.error(f"Error during log streaming: {e}", exc_info=True)
yield {
"event": "error",
"data": {
"code": 8001,
"message": "流式传输期间发生错误",
"error": str(e)
}
}
finally:
logger.info(f"Log stream ended for {self.log_path}")
yield {
"event": "done",
"data": {
"message": "流式传输完成"
}
}

View File

@@ -0,0 +1,32 @@
"""
Agent logger module for backward compatibility.
This module maintains the get_named_logger() function for backward compatibility
while delegating to the centralized logging configuration.
All new code should import directly from app.core.logging_config instead.
"""
__version__ = "0.1.0"
__author__ = "RED_BEAR"
from app.core.logging_config import get_agent_logger
def get_named_logger(name):
"""Get a named logger for agent operations.
This function maintains backward compatibility with existing code.
It delegates to the centralized get_agent_logger() function.
Args:
name: Logger name for namespacing
Returns:
Logger configured for agent operations
Example:
>>> logger = get_named_logger("my_agent")
>>> logger.info("Agent operation started")
"""
return get_agent_logger(name)

View File

@@ -0,0 +1,28 @@
"""
MCP Server package for memory agent.
This package provides the FastMCP server implementation with context-based
dependency injection for tool functions.
Package structure:
- server: FastMCP server initialization and context setup
- tools: MCP tool implementations
- models: Pydantic response models
- services: Business logic services
"""
from app.core.memory.agent.mcp_server.server import (
mcp,
initialize_context,
main,
get_context_resource
)
# Import tools to register them (but don't export them)
from app.core.memory.agent.mcp_server import tools
__all__ = [
'mcp',
'initialize_context',
'main',
'get_context_resource',
]

View File

@@ -0,0 +1,11 @@
"""
MCP Server Instance
This module contains the FastMCP server instance that is shared across all modules.
It's in a separate file to avoid circular import issues.
"""
from mcp.server.fastmcp import FastMCP
# Initialize FastMCP server instance
# This instance is shared across all tool modules
mcp = FastMCP('data_flow')

View File

@@ -0,0 +1,30 @@
"""Pydantic models for MCP server responses."""
from .problem_models import (
ProblemBreakdownItem,
ProblemBreakdownResponse,
ExtendedQuestionItem,
ProblemExtensionResponse,
)
from .summary_models import (
SummaryData,
SummaryResponse,
RetrieveSummaryData,
RetrieveSummaryResponse,
)
from .verification_models import VerificationResult
from .retrieval_models import RetrievalResult, DistinguishTypeResponse
__all__ = [
"ProblemBreakdownItem",
"ProblemBreakdownResponse",
"ExtendedQuestionItem",
"ProblemExtensionResponse",
"SummaryData",
"SummaryResponse",
"RetrieveSummaryData",
"RetrieveSummaryResponse",
"VerificationResult",
"RetrievalResult",
"DistinguishTypeResponse",
]

View File

@@ -0,0 +1,34 @@
"""Pydantic models for problem breakdown and extension operations."""
from typing import List, Optional
from pydantic import BaseModel, Field, RootModel
class ProblemBreakdownItem(BaseModel):
"""Individual item in problem breakdown response."""
id: str
question: str
type: str
reason: Optional[str] = None
class ProblemBreakdownResponse(RootModel[List[ProblemBreakdownItem]]):
"""Response model for problem breakdown containing list of breakdown items."""
pass
class ExtendedQuestionItem(BaseModel):
"""Individual extended question item with reasoning."""
original_question: str = Field(..., description="原始初步问题")
extended_question: str = Field(..., description="扩展后的问题")
type: str = Field(..., description="类型(事实检索 / 澄清 / 定义 / 比较 / 行动建议等)")
reason: str = Field(..., description="生成该扩展问题的理由")
class ProblemExtensionResponse(RootModel[List[ExtendedQuestionItem]]):
"""Response model for problem extension containing list of extended questions."""
pass

View File

@@ -0,0 +1,17 @@
"""Pydantic models for retrieval operations."""
from typing import List, Dict, Any
from pydantic import BaseModel
class RetrievalResult(BaseModel):
"""Result model for retrieval operation."""
Query: str
Expansion_issue: List[Dict[str, Any]]
class DistinguishTypeResponse(BaseModel):
"""Response model for data type differentiation."""
type: str

View File

@@ -0,0 +1,31 @@
"""Pydantic models for summary operations."""
from typing import List
from pydantic import BaseModel, Field
class SummaryData(BaseModel):
"""Data structure for summary input."""
query: str
history: List[str] = Field(default_factory=list)
retrieve_info: List[str] = Field(default_factory=list)
class SummaryResponse(BaseModel):
"""Response model for summary operation."""
data: SummaryData
query_answer: str
class RetrieveSummaryData(BaseModel):
"""Data structure for retrieve summary response."""
query_answer: str = Field(default="")
class RetrieveSummaryResponse(BaseModel):
"""Response model for retrieve summary operation."""
data: RetrieveSummaryData

View File

@@ -0,0 +1,14 @@
"""Pydantic models for verification operations."""
from typing import List, Optional, Dict, Any
from pydantic import BaseModel, Field
class VerificationResult(BaseModel):
"""Result model for verification operation."""
query: str
expansion_issue: List[Dict[str, Any]]
split_result: str
reason: Optional[str] = None
history: List[Dict[str, Any]] = Field(default_factory=list)

View File

@@ -0,0 +1,161 @@
"""
MCP Server initialization with FastMCP context setup.
This module initializes the FastMCP server and registers shared resources
in the context for dependency injection into tool functions.
"""
import os
import sys
from mcp.server.fastmcp import FastMCP
from app.core.config import settings
from app.core.logging_config import get_agent_logger
from app.core.memory.agent.utils.redis_tool import RedisSessionStore, store
from app.core.memory.agent.utils.llm_tools import PROJECT_ROOT_
from app.core.memory.utils.config.definitions import SELECTED_LLM_ID,reload_configuration_from_database
from app.core.memory.utils.llm.llm_utils import get_llm_client
from app.core.memory.agent.mcp_server.services.template_service import TemplateService
from app.core.memory.agent.mcp_server.services.search_service import SearchService
from app.core.memory.agent.mcp_server.services.session_service import SessionService
from app.core.memory.agent.mcp_server.mcp_instance import mcp
logger = get_agent_logger(__name__)
def get_context_resource(ctx, resource_name: str):
"""
Helper function to retrieve a resource from the FastMCP context.
Args:
ctx: FastMCP Context object (passed to tool functions)
resource_name: Name of the resource to retrieve
Returns:
The requested resource
Raises:
AttributeError: If the resource doesn't exist
Example:
@mcp.tool()
async def my_tool(ctx: Context):
template_service = get_context_resource(ctx, 'template_service')
llm_client = get_context_resource(ctx, 'llm_client')
"""
if not hasattr(ctx, 'fastmcp') or ctx.fastmcp is None:
raise RuntimeError("Context does not have fastmcp attribute")
if not hasattr(ctx.fastmcp, resource_name):
raise AttributeError(
f"Resource '{resource_name}' not found in context. "
f"Available resources: {[k for k in dir(ctx.fastmcp) if not k.startswith('_')]}"
)
return getattr(ctx.fastmcp, resource_name)
def initialize_context():
"""
Initialize and register shared resources in FastMCP context.
This function sets up all shared resources that will be available
to tool functions via dependency injection through the context parameter.
Resources are stored as attributes on the FastMCP instance and can be
accessed via ctx.fastmcp in tool functions.
Resources registered:
- session_store: RedisSessionStore for session management
- llm_client: LLM client for structured API calls
- app_settings: Application settings (renamed to avoid conflict with FastMCP settings)
- template_service: Service for template rendering
- search_service: Service for hybrid search
- session_service: Service for session operations
"""
try:
# Register Redis session store
logger.info("Registering session_store in context")
mcp.session_store = store
# Register LLM client
try:
logger.info(f"Registering llm_client in context with model ID: {SELECTED_LLM_ID}")
llm_client = get_llm_client(SELECTED_LLM_ID)
mcp.llm_client = llm_client
logger.info("llm_client registered successfully")
except Exception as e:
logger.error(f"Failed to register llm_client: {e}", exc_info=True)
# 注册一个 None 值,避免工具调用时找不到资源
mcp.llm_client = None
logger.warning("llm_client set to None due to initialization failure")
# Register application settings (renamed to avoid conflict with FastMCP's settings)
logger.info("Registering app_settings in context")
mcp.app_settings = settings
# Register template service
template_root = PROJECT_ROOT_ + '/agent/utils/prompt'
# logger.info(f"Registering template_service in context with root: {template_root}")
template_service = TemplateService(template_root)
mcp.template_service = template_service
# Register search service
# logger.info("Registering search_service in context")
search_service = SearchService()
mcp.search_service = search_service
# Register session service
# logger.info("Registering session_service in context")
session_service = SessionService(store)
mcp.session_service = session_service
# logger.info("All context resources registered successfully")
except Exception as e:
logger.error(f"Failed to initialize context: {e}", exc_info=True)
raise
def main():
"""
Main entry point for the MCP server.
Initializes context and starts the server with SSE transport.
"""
try:
# logger.info("Starting MCP server initialization")
reload_configuration_from_database(config_id=os.getenv("config_id"), force_reload=True)
# Initialize context resources
initialize_context()
# Import and register tools
# logger.info("Importing MCP tools")
from app.core.memory.agent.mcp_server.tools import (
problem_tools,
retrieval_tools,
verification_tools,
summary_tools,
data_tools
)
# logger.info("All MCP tools imported and registered")
# Log registered tools for debugging
import asyncio
tools_list = asyncio.run(mcp.list_tools())
# logger.info(f"Registered {len(tools_list)} MCP tools: {[t.name for t in tools_list]}")
# logger.info(f"Starting MCP server on {settings.SERVER_IP}:8081 with SSE transport")
# Run the server with SSE transport for HTTP connections
# The server will be available at http://127.0.0.1:8081
import uvicorn
app = mcp.sse_app()
uvicorn.run(app, host=settings.SERVER_IP, port=8081, log_level="info")
except Exception as e:
logger.error(f"Failed to start MCP server: {e}", exc_info=True)
sys.exit(1)
if __name__ == "__main__":
main()

View File

@@ -0,0 +1,23 @@
"""
MCP Server Services
This module provides business logic services for the MCP server:
- TemplateService: Template loading and rendering
- SearchService: Search result processing
- SessionService: Session and history management
- ParameterBuilder: Tool parameter construction
"""
from .template_service import TemplateService, TemplateRenderError
from .search_service import SearchService
from .session_service import SessionService
from .parameter_builder import ParameterBuilder
__all__ = [
"TemplateService",
"TemplateRenderError",
"SearchService",
"SessionService",
"ParameterBuilder",
]

View File

@@ -0,0 +1,157 @@
"""
Parameter Builder for constructing tool call arguments.
This service provides tool-specific parameter transformation logic
to build correct arguments for each tool type.
"""
import json
from typing import Any, Dict, Optional
from app.core.logging_config import get_agent_logger
logger = get_agent_logger(__name__)
class ParameterBuilder:
"""Service for building tool call arguments based on tool type."""
def __init__(self):
"""Initialize the parameter builder."""
logger.info("ParameterBuilder initialized")
def build_tool_args(
self,
tool_name: str,
content: Any,
tool_call_id: str,
search_switch: str,
apply_id: str,
group_id: str,
storage_type: Optional[str] = None,
user_rag_memory_id: Optional[str] = None
) -> Dict[str, Any]:
"""
Build tool arguments based on tool type.
Different tools expect different argument formats:
- Verify: dict context
- Retrieve: dict context + search_switch
- Summary/Summary_fails: JSON string context
- Retrieve_Summary: unwrap nested context structures
- Input_Summary: raw message string
Args:
tool_name: Name of the tool being invoked
content: Parsed content from previous tool result
tool_call_id: Extracted tool call identifier
search_switch: Search routing parameter
apply_id: Application identifier
group_id: Group identifier
storage_type: Storage type for the workspace (optional)
user_rag_memory_id: User RAG memory ID for knowledge base retrieval (optional)
Returns:
Dictionary of tool arguments ready for invocation
"""
# Base arguments common to most tools
base_args = {
"usermessages": tool_call_id,
"apply_id": apply_id,
"group_id": group_id
}
# Always add storage_type and user_rag_memory_id (with defaults if None)
base_args["storage_type"] = storage_type if storage_type is not None else ""
base_args["user_rag_memory_id"] = user_rag_memory_id if user_rag_memory_id is not None else ""
# Tool-specific argument construction
if tool_name == "Verify":
# Verify expects dict context
return {
"context": content if isinstance(content, dict) else {},
**base_args
}
elif tool_name == "Retrieve":
# Retrieve expects dict context + search_switch
return {
"context": content if isinstance(content, dict) else {},
"search_switch": search_switch,
**base_args
}
elif tool_name in ["Summary", "Summary_fails"]:
# Summary tools expect JSON string context
if isinstance(content, dict):
context_str = json.dumps(content, ensure_ascii=False)
elif isinstance(content, str):
context_str = content
else:
context_str = json.dumps({"data": content}, ensure_ascii=False)
return {
"context": context_str,
**base_args
}
elif tool_name == "Retrieve_Summary":
# Retrieve_Summary needs to unwrap nested context structures
# Handle both 'content' and 'context' keys
context_dict = content
if isinstance(content, dict):
# Check for nested 'content' wrapper
if "content" in content:
inner = content["content"]
# If it's a JSON string, parse it
if isinstance(inner, str):
try:
parsed = json.loads(inner)
# Check if parsed has 'context' wrapper
if isinstance(parsed, dict) and "context" in parsed:
context_dict = parsed["context"]
else:
context_dict = parsed
except json.JSONDecodeError:
logger.warning(
f"Failed to parse JSON content for {tool_name}: {inner[:100]}"
)
context_dict = {"Query": "", "Expansion_issue": []}
elif isinstance(inner, dict):
context_dict = inner
# Check for 'context' wrapper
elif "context" in content:
context_dict = content["context"] if isinstance(content["context"], dict) else content
return {
"context": context_dict,
**base_args
}
elif tool_name == "Input_Summary":
# Input_Summary expects raw message string + search_switch
# Content should be the raw message string
if isinstance(content, dict):
# Try to extract message from dict
message_str = content.get("sentence", str(content))
else:
message_str = str(content)
return {
"context": message_str,
"search_switch": search_switch,
**base_args
}
else:
# Default: pass content as context
logger.warning(
f"Unknown tool name '{tool_name}', using default argument structure"
)
return {
"context": content,
**base_args
}

View File

@@ -0,0 +1,193 @@
"""
Search Service for executing hybrid search and processing results.
This service provides clean search result processing with content extraction
and deduplication.
"""
from typing import List, Tuple, Optional
from app.core.logging_config import get_agent_logger
from app.core.memory.src.search import run_hybrid_search
from app.core.memory.utils.data.text_utils import escape_lucene_query
logger = get_agent_logger(__name__)
class SearchService:
"""Service for executing hybrid search and processing results."""
def __init__(self):
"""Initialize the search service."""
logger.info("SearchService initialized")
def extract_content_from_result(self, result: dict) -> str:
"""
Extract only meaningful content from search results, dropping all metadata.
Extraction rules by node type:
- Statements: extract 'statement' field
- Entities: extract 'name' and 'fact_summary' fields
- Summaries: extract 'content' field
- Chunks: extract 'content' field
Args:
result: Search result dictionary
Returns:
Clean content string without metadata
"""
if not isinstance(result, dict):
return str(result)
content_parts = []
# Statements: extract statement field
if 'statement' in result and result['statement']:
content_parts.append(result['statement'])
# Summaries/Chunks: extract content field
if 'content' in result and result['content']:
content_parts.append(result['content'])
# Entities: extract name and fact_summary (commented out in original)
# if 'name' in result and result['name']:
# content_parts.append(result['name'])
# if result.get('fact_summary'):
# content_parts.append(result['fact_summary'])
# Return concatenated content or empty string
return '\n'.join(content_parts) if content_parts else ""
def clean_query(self, query: str) -> str:
"""
Clean and escape query text for Lucene.
- Removes wrapping quotes
- Removes newlines and carriage returns
- Applies Lucene escaping
Args:
query: Raw query string
Returns:
Cleaned and escaped query string
"""
q = str(query).strip()
# Remove wrapping quotes
if (q.startswith("'") and q.endswith("'")) or (
q.startswith('"') and q.endswith('"')
):
q = q[1:-1]
# Remove newlines and carriage returns
q = q.replace('\r', ' ').replace('\n', ' ').strip()
# Apply Lucene escaping
q = escape_lucene_query(q)
return q
async def execute_hybrid_search(
self,
group_id: str,
question: str,
limit: int = 5,
search_type: str = "hybrid",
include: Optional[List[str]] = None,
rerank_alpha: float = 0.4,
output_path: str = "search_results.json",
return_raw_results: bool = False
) -> Tuple[str, str, Optional[dict]]:
"""
Execute hybrid search and return clean content.
Args:
group_id: Group identifier for filtering results
question: Search query text
limit: Maximum number of results to return (default: 5)
search_type: Type of search - "hybrid", "keyword", or "embedding" (default: "hybrid")
include: List of result types to include (default: ["statements", "chunks", "entities", "summaries"])
rerank_alpha: Weight for BM25 scores in reranking (default: 0.4)
output_path: Path to save search results (default: "search_results.json")
return_raw_results: If True, also return the raw search results as third element (default: False)
Returns:
Tuple of (clean_content, cleaned_query, raw_results)
raw_results is None if return_raw_results=False
"""
if include is None:
include = ["statements", "chunks", "entities", "summaries"]
# Clean query
cleaned_query = self.clean_query(question)
try:
# Execute search
answer = await run_hybrid_search(
query_text=cleaned_query,
search_type=search_type,
group_id=group_id,
limit=limit,
include=include,
output_path=output_path,
rerank_alpha=rerank_alpha
)
# Extract results based on search type and include parameter
# Prioritize summaries as they contain synthesized contextual information
answer_list = []
# For hybrid search, use reranked_results
if search_type == "hybrid":
reranked_results = answer.get('reranked_results', {})
# Priority order: summaries first (most contextual), then statements, chunks, entities
priority_order = ['summaries', 'statements', 'chunks', 'entities']
for category in priority_order:
if category in include and category in reranked_results:
category_results = reranked_results[category]
if isinstance(category_results, list):
answer_list.extend(category_results)
else:
# For keyword or embedding search, results are directly in answer dict
# Apply same priority order
priority_order = ['summaries', 'statements', 'chunks', 'entities']
for category in priority_order:
if category in include and category in answer:
category_results = answer[category]
if isinstance(category_results, list):
answer_list.extend(category_results)
# Extract clean content from all results
content_list = [
self.extract_content_from_result(ans)
for ans in answer_list
]
# Filter out empty strings and join with newlines
clean_content = '\n'.join([c for c in content_list if c])
# Log first 200 chars
logger.info(f"检索接口搜索结果==>>:{clean_content[:200]}...")
# Return raw results if requested
if return_raw_results:
return clean_content, cleaned_query, answer
else:
return clean_content, cleaned_query, None
except Exception as e:
logger.error(
f"Search failed for query '{question}' in group '{group_id}': {e}",
exc_info=True
)
# Return empty results on failure
if return_raw_results:
return "", cleaned_query, {}
else:
return "", cleaned_query, None

View File

@@ -0,0 +1,169 @@
"""
Session Service for managing user sessions and conversation history.
This service provides clean Redis interactions with error handling and
session management utilities.
"""
from typing import List, Optional
from app.core.logging_config import get_agent_logger
from app.core.memory.agent.utils.redis_tool import RedisSessionStore
logger = get_agent_logger(__name__)
class SessionService:
"""Service for managing user sessions and conversation history."""
def __init__(self, store: RedisSessionStore):
"""
Initialize the session service.
Args:
store: Redis session store instance
"""
self.store = store
logger.info("SessionService initialized")
def resolve_user_id(self, session_string: str) -> str:
"""
Extract user ID from session string.
Handles formats like:
- 'call_id_user123' -> 'user123'
- 'prefix_id_user456_suffix' -> 'user456_suffix'
Args:
session_string: Session identifier string
Returns:
Extracted user ID
"""
try:
# Split by '_id_' and take everything after it
parts = session_string.split('_id_')
if len(parts) > 1:
return parts[1]
# Fallback: return original string
return session_string
except Exception as e:
logger.warning(
f"Failed to parse user ID from session string '{session_string}': {e}"
)
return session_string
async def get_history(
self,
user_id: str,
apply_id: str,
group_id: str
) -> List[dict]:
"""
Retrieve conversation history from Redis.
Args:
user_id: User identifier
apply_id: Application identifier
group_id: Group identifier
Returns:
List of conversation history items with Query and Answer keys
Returns empty list if no history found or on error
"""
try:
history = self.store.find_user_apply_group(user_id, apply_id, group_id)
# Validate history structure
if not isinstance(history, list):
logger.warning(
f"Invalid history format for user {user_id}, "
f"apply {apply_id}, group {group_id}: expected list, got {type(history)}"
)
return []
return history
except Exception as e:
logger.error(
f"Failed to retrieve history for user {user_id}, "
f"apply {apply_id}, group {group_id}: {e}",
exc_info=True
)
# Return empty list on error to allow execution to continue
return []
async def save_session(
self,
user_id: str,
query: str,
apply_id: str,
group_id: str,
ai_response: str
) -> Optional[str]:
"""
Save conversation turn to Redis.
Args:
user_id: User identifier
query: User query/message
apply_id: Application identifier
group_id: Group identifier
ai_response: AI response/answer
Returns:
Session ID if successful, None on error
"""
try:
# Validate required fields
if not user_id:
logger.warning("Cannot save session: user_id is empty")
return None
if not query:
logger.warning("Cannot save session: query is empty")
return None
# Save session
session_id = self.store.save_session(
userid=user_id,
messages=query,
apply_id=apply_id,
group_id=group_id,
aimessages=ai_response
)
logger.info(f"Session saved successfully: {session_id}")
return session_id
except Exception as e:
logger.error(
f"Failed to save session for user {user_id}: {e}",
exc_info=True
)
return None
async def cleanup_duplicates(self) -> int:
"""
Remove duplicate session entries.
Duplicates are identified by matching:
- sessionid
- user_id (id field)
- group_id
- messages
- aimessages
Returns:
Number of duplicate sessions deleted
"""
try:
deleted_count = self.store.delete_duplicate_sessions()
logger.info(f"Cleaned up {deleted_count} duplicate sessions")
return deleted_count
except Exception as e:
logger.error(f"Failed to cleanup duplicate sessions: {e}", exc_info=True)
return 0

View File

@@ -0,0 +1,116 @@
"""
Template Service for loading and rendering Jinja2 templates.
This service provides centralized template management with caching and error handling.
"""
import os
from functools import lru_cache
from typing import Optional
from jinja2 import Environment, FileSystemLoader, Template, TemplateNotFound
from app.core.logging_config import get_agent_logger, log_prompt_rendering
logger = get_agent_logger(__name__)
class TemplateRenderError(Exception):
"""Exception raised when template rendering fails."""
def __init__(self, template_name: str, error: Exception, variables: dict):
self.template_name = template_name
self.error = error
self.variables = variables
super().__init__(
f"Failed to render template '{template_name}': {str(error)}"
)
class TemplateService:
"""Service for loading and rendering Jinja2 templates with caching."""
def __init__(self, template_root: str):
"""
Initialize the template service.
Args:
template_root: Root directory containing template files
"""
self.template_root = template_root
self.env = Environment(
loader=FileSystemLoader(template_root),
autoescape=False # Disable autoescape for prompt templates
)
logger.info(f"TemplateService initialized with root: {template_root}")
@lru_cache(maxsize=128)
def _load_template(self, template_name: str) -> Template:
"""
Load a template from disk with caching.
Args:
template_name: Relative path to template file
Returns:
Loaded Jinja2 Template object
Raises:
TemplateNotFound: If template file doesn't exist
"""
try:
return self.env.get_template(template_name)
except TemplateNotFound as e:
expected_path = os.path.join(self.template_root, template_name)
logger.error(
f"Template not found: {template_name}. "
f"Expected path: {expected_path}"
)
raise
async def render_template(
self,
template_name: str,
operation_name: str,
**variables
) -> str:
"""
Load and render a Jinja2 template.
Args:
template_name: Relative path to template file
operation_name: Name for logging (e.g., "split_the_problem")
**variables: Template variables to render
Returns:
Rendered template string
Raises:
TemplateRenderError: If template loading or rendering fails
"""
try:
# Load template (cached)
template = self._load_template(template_name)
# Render template
rendered = template.render(**variables)
# Log rendered prompt
log_prompt_rendering(operation_name, rendered)
return rendered
except TemplateNotFound as e:
logger.error(
f"Template rendering failed for {operation_name} "
f"({template_name}): Template not found",
exc_info=True
)
raise TemplateRenderError(template_name, e, variables)
except Exception as e:
logger.error(
f"Template rendering failed for {operation_name} "
f"({template_name}): {e}",
exc_info=True
)
raise TemplateRenderError(template_name, e, variables)

View File

@@ -0,0 +1,27 @@
"""
MCP Tools module.
This module contains all MCP tool implementations organized by functionality.
Tools are organized into the following modules:
- problem_tools: Question segmentation and extension
- retrieval_tools: Database and context retrieval
- verification_tools: Data verification
- summary_tools: Summarization and summary retrieval
- data_tools: Data type differentiation and writing
"""
# Import all tool modules to register them with the MCP server
from . import problem_tools
from . import retrieval_tools
from . import verification_tools
from . import summary_tools
from . import data_tools
__all__ = [
'problem_tools',
'retrieval_tools',
'verification_tools',
'summary_tools',
'data_tools',
]

View File

@@ -0,0 +1,149 @@
"""
Data Tools for data type differentiation and writing.
This module contains MCP tools for distinguishing data types and writing data.
"""
import os
from mcp.server.fastmcp import Context
from app.core.logging_config import get_agent_logger
from app.core.memory.agent.mcp_server.mcp_instance import mcp
from app.core.memory.agent.mcp_server.server import get_context_resource
from app.core.memory.agent.mcp_server.models.retrieval_models import DistinguishTypeResponse
from app.core.memory.agent.utils.write_tools import write
logger = get_agent_logger(__name__)
@mcp.tool()
async def Data_type_differentiation(
ctx: Context,
context: str
) -> dict:
"""
Distinguish the type of data (read or write).
Args:
ctx: FastMCP context for dependency injection
context: Text to analyze for type differentiation
Returns:
dict: Contains 'context' with the original text and 'type' field
"""
try:
# Extract services from context
template_service = get_context_resource(ctx, 'template_service')
llm_client = get_context_resource(ctx, 'llm_client')
# Render template
try:
system_prompt = await template_service.render_template(
template_name='distinguish_types_prompt.jinja2',
operation_name='status_typle',
user_query=context
)
except Exception as e:
logger.error(
f"Template rendering failed for Data_type_differentiation: {e}",
exc_info=True
)
return {
"type": "error",
"message": f"Prompt rendering failed: {str(e)}"
}
# Call LLM with structured response
try:
structured = await llm_client.response_structured(
messages=[{"role": "system", "content": system_prompt}],
response_model=DistinguishTypeResponse
)
result = structured.model_dump()
# Add context to result
result["context"] = context
return result
except Exception as e:
logger.error(
f"LLM call failed for Data_type_differentiation: {e}",
exc_info=True
)
return {
"context": context,
"type": "error",
"message": f"LLM call failed: {str(e)}"
}
except Exception as e:
logger.error(
f"Data_type_differentiation failed: {e}",
exc_info=True
)
return {
"context": context,
"type": "error",
"message": str(e)
}
@mcp.tool()
async def Data_write(
ctx: Context,
content: str,
user_id: str,
apply_id: str,
group_id: str,
config_id: str
) -> dict:
"""
Write data to the database/file system.
Args:
ctx: FastMCP context for dependency injection
content: Data content to write
user_id: User identifier
apply_id: Application identifier
group_id: Group identifier
config_id: Configuration ID for processing (optional, integer)
Returns:
dict: Contains 'status', 'saved_to', and 'data' fields
"""
try:
# Ensure output directory exists
os.makedirs("data_output", exist_ok=True)
file_path = os.path.join("data_output", "user_data.csv")
# Write data using utility function
try:
await write(content, user_id, apply_id, group_id, config_id=config_id)
logger.info(f"写入成功Config ID: {config_id if config_id else 'None'}")
return {
"status": "success",
"saved_to": file_path,
"data": content,
"config_id": config_id
}
except Exception as e:
logger.error(f"写入失败: {e}", exc_info=True)
return {
"status": "error",
"message": str(e)
}
except Exception as e:
logger.error(
f"Data_write failed: {e}",
exc_info=True
)
return {
"status": "error",
"message": str(e)
}

View File

@@ -0,0 +1,293 @@
"""
Problem Tools for question segmentation and extension.
This module contains MCP tools for breaking down and extending user questions.
"""
import json
import time
from typing import List
from pydantic import BaseModel, Field, RootModel
from mcp.server.fastmcp import Context
from app.core.logging_config import get_agent_logger, log_time
from app.core.memory.agent.mcp_server.mcp_instance import mcp
from app.core.memory.agent.mcp_server.server import get_context_resource
from app.core.memory.agent.mcp_server.models.problem_models import (
ProblemBreakdownItem,
ProblemBreakdownResponse,
ExtendedQuestionItem,
ProblemExtensionResponse
)
from app.core.memory.agent.utils.messages_tool import Problem_Extension_messages_deal
logger = get_agent_logger(__name__)
@mcp.tool()
async def Split_The_Problem(
ctx: Context,
sentence: str,
sessionid: str,
messages_id: str,
apply_id: str,
group_id: str
) -> dict:
"""
Segment the dialogue or sentence into sub-problems.
Args:
ctx: FastMCP context for dependency injection
sentence: Original sentence to split
sessionid: Session identifier
messages_id: Message identifier
apply_id: Application identifier
group_id: Group identifier
Returns:
dict: Contains 'context' (JSON string of split results) and 'original' sentence
"""
start = time.time()
try:
# Extract services from context
template_service = get_context_resource(ctx, 'template_service')
session_service = get_context_resource(ctx, 'session_service')
llm_client = get_context_resource(ctx, 'llm_client')
# Extract user ID from session
user_id = session_service.resolve_user_id(sessionid)
# Get conversation history
history = await session_service.get_history(user_id, apply_id, group_id)
# Override with empty list for now (as in original)
history = []
# Render template
try:
system_prompt = await template_service.render_template(
template_name='problem_breakdown_prompt.jinja2',
operation_name='split_the_problem',
history=history,
sentence=sentence
)
except Exception as e:
logger.error(
f"Template rendering failed for Split_The_Problem: {e}",
exc_info=True
)
return {
"context": json.dumps([], ensure_ascii=False),
"original": sentence,
"error": f"Prompt rendering failed: {str(e)}"
}
# Call LLM with structured response
try:
structured = await llm_client.response_structured(
messages=[{"role": "system", "content": system_prompt}],
response_model=ProblemBreakdownResponse
)
# Handle RootModel response with .root attribute access
if structured is None:
# LLM returned None, use empty list as fallback
split_result = json.dumps([], ensure_ascii=False)
elif hasattr(structured, 'root') and structured.root is not None:
split_result = json.dumps(
[item.model_dump() for item in structured.root],
ensure_ascii=False
)
elif isinstance(structured, list):
# Fallback: treat structured itself as the list
split_result = json.dumps(
[item.model_dump() for item in structured],
ensure_ascii=False
)
else:
# Last resort: use empty list
split_result = json.dumps([], ensure_ascii=False)
except Exception as e:
logger.error(
f"LLM call failed for Split_The_Problem: {e}",
exc_info=True
)
split_result = json.dumps([], ensure_ascii=False)
logger.info(f"问题拆分")
logger.info(f"问题拆分结果==>>:{split_result}")
# Emit intermediate output for frontend
result = {
"context": split_result,
"original": sentence,
"_intermediate": {
"type": "problem_split",
"data": json.loads(split_result) if split_result else [],
"original_query": sentence
}
}
return result
except Exception as e:
logger.error(
f"Split_The_Problem failed: {e}",
exc_info=True
)
return {
"context": json.dumps([], ensure_ascii=False),
"original": sentence,
"error": str(e)
}
finally:
# Log execution time
end = time.time()
try:
duration = end - start
except Exception:
duration = 0.0
log_time('问题拆分', duration)
@mcp.tool()
async def Problem_Extension(
ctx: Context,
context: dict,
usermessages: str,
apply_id: str,
group_id: str,
storage_type: str = "",
user_rag_memory_id: str = ""
) -> dict:
"""
Extend the problem with additional sub-questions.
Args:
ctx: FastMCP context for dependency injection
context: Dictionary containing split problem results
usermessages: User messages identifier
apply_id: Application identifier
group_id: Group identifier
storage_type: Storage type for the workspace (optional)
user_rag_memory_id: User RAG memory identifier (optional)
Returns:
dict: Contains 'context' (aggregated questions) and 'original' question
"""
start = time.time()
try:
# Extract services from context
template_service = get_context_resource(ctx, 'template_service')
session_service = get_context_resource(ctx, 'session_service')
llm_client = get_context_resource(ctx, 'llm_client')
# Resolve session ID from usermessages
from app.core.memory.agent.utils.messages_tool import Resolve_username
sessionid = Resolve_username(usermessages)
# Get conversation history
history = await session_service.get_history(sessionid, apply_id, group_id)
# Override with empty list for now (as in original)
history = []
# Process context to extract questions
extent_quest, original = await Problem_Extension_messages_deal(context)
# Format questions for template rendering
questions_formatted = []
for msg in extent_quest:
if msg.get("role") == "user":
questions_formatted.append(msg.get("content", ""))
# Render template
try:
system_prompt = await template_service.render_template(
template_name='Problem_Extension_prompt.jinja2',
operation_name='problem_extension',
history=history,
questions=questions_formatted
)
except Exception as e:
logger.error(
f"Template rendering failed for Problem_Extension: {e}",
exc_info=True
)
return {
"context": {},
"original": original,
"error": f"Prompt rendering failed: {str(e)}"
}
# Call LLM with structured response
try:
response_content = await llm_client.response_structured(
messages=[{"role": "system", "content": system_prompt}],
response_model=ProblemExtensionResponse
)
# Aggregate results by original question
aggregated_dict = {}
for item in response_content.root:
key = getattr(item, "original_question", None) or (
item.get("original_question") if isinstance(item, dict) else None
)
value = getattr(item, "extended_question", None) or (
item.get("extended_question") if isinstance(item, dict) else None
)
if not key or not value:
continue
aggregated_dict.setdefault(key, []).append(value)
except Exception as e:
logger.error(
f"LLM call failed for Problem_Extension: {e}",
exc_info=True
)
aggregated_dict = {}
logger.info(f"问题扩展")
logger.info(f"问题扩展==>>:{aggregated_dict}")
# Emit intermediate output for frontend
result = {
"context": aggregated_dict,
"original": original,
"storage_type": storage_type,
"user_rag_memory_id": user_rag_memory_id,
"_intermediate": {
"type": "problem_extension",
"data": aggregated_dict,
"original_query": original,
"storage_type": storage_type,
"user_rag_memory_id": user_rag_memory_id
}
}
return result
except Exception as e:
logger.error(
f"Problem_Extension failed: {e}",
exc_info=True
)
return {
"context": {},
"original": context.get("original", ""),
"storage_type": storage_type,
"user_rag_memory_id": user_rag_memory_id,
"error": str(e)
}
finally:
# Log execution time
end = time.time()
try:
duration = end - start
except Exception:
duration = 0.0
log_time('问题扩展', duration)

View File

@@ -0,0 +1,282 @@
"""
Retrieval Tools for database and context retrieval.
This module contains MCP tools for retrieving data using hybrid search.
"""
from dotenv import load_dotenv
import os
from app.core.rag.nlp.search import knowledge_retrieval
# 加载.env文件
load_dotenv()
import time
from typing import List
from mcp.server.fastmcp import Context
from app.core.logging_config import get_agent_logger, log_time
from app.core.memory.agent.mcp_server.mcp_instance import mcp
from app.core.memory.agent.mcp_server.server import get_context_resource
from app.core.memory.agent.utils.llm_tools import deduplicate_entries, merge_to_key_value_pairs
from app.core.memory.agent.utils.messages_tool import Retriev_messages_deal
logger = get_agent_logger(__name__)
@mcp.tool()
async def Retrieve(
ctx: Context,
context,
usermessages: str,
apply_id: str,
group_id: str,
storage_type: str = "",
user_rag_memory_id: str = ""
) -> dict:
"""
Retrieve data from the database using hybrid search.
Args:
ctx: FastMCP context for dependency injection
context: Dictionary or string containing query information
usermessages: User messages identifier
apply_id: Application identifier
group_id: Group identifier
storage_type: Storage type for the workspace (e.g., 'rag', 'vector')
user_rag_memory_id: User RAG memory identifier
Returns:
dict: Contains 'context' with Query and Expansion_issue results
"""
kb_config = {
"knowledge_bases": [
{
"kb_id": user_rag_memory_id,
"similarity_threshold": 0.7,
"vector_similarity_weight": 0.5,
"top_k": 10,
"retrieve_type": "participle"
}
],
"merge_strategy": "weight",
"reranker_id": os.getenv('reranker_id'),
"reranker_top_k": 10
}
start = time.time()
logger.info(f"Retrieve: storage_type={storage_type}, user_rag_memory_id={user_rag_memory_id}")
try:
# Extract services from context
search_service = get_context_resource(ctx, 'search_service')
databases_anser = []
# Handle both dict and string context
if isinstance(context, dict):
# Process dict context with extended questions
all_items = []
content, original = await Retriev_messages_deal(context)
# Extract all query items from content
# content is like {original_question: [extended_questions...], ...}
for key, values in content.items():
if isinstance(values, list):
all_items.extend(values)
elif isinstance(values, str):
all_items.append(values)
elif values is not None:
# Fallback: convert non-empty non-list values to string
all_items.append(str(values))
# Execute search for each question
for idx, question in enumerate(all_items):
try:
# Prepare search parameters based on storage type
search_params = {
"group_id": group_id,
"question": question,
"return_raw_results": True
}
# Add storage-specific parameters
if storage_type == "rag" and user_rag_memory_id:
retrieve_chunks_result = knowledge_retrieval(question, kb_config,[str(group_id)])
try:
retrieval_knowledge = [i.page_content for i in retrieve_chunks_result]
clean_content = '\n\n'.join(retrieval_knowledge)
cleaned_query=question
raw_results=clean_content
logger.info(f" Using RAG storage with memory_id={user_rag_memory_id}")
except:
clean_content = ''
raw_results=''
cleaned_query = question
logger.info(f"知识库没有检索的内容{user_rag_memory_id}")
else:
clean_content, cleaned_query, raw_results = await search_service.execute_hybrid_search(**search_params)
databases_anser.append({
"Query_small": cleaned_query,
"Result_small": clean_content,
"_intermediate": {
"type": "search_result",
"query": cleaned_query,
"raw_results": raw_results,
"index": idx + 1,
"total": len(all_items)
}
})
except Exception as e:
logger.error(
f"Retrieve: hybrid_search failed for question '{question}': {e}",
exc_info=True
)
# Continue with empty result for this question
databases_anser.append({
"Query_small": question,
"Result_small": ""
})
# Build initial database data structure
databases_data = {
"Query": original,
"Expansion_issue": databases_anser
}
# Collect intermediate outputs before deduplication
intermediate_outputs = []
for item in databases_anser:
if '_intermediate' in item:
intermediate_outputs.append(item['_intermediate'])
# Deduplicate and merge results
deduplicated_data = deduplicate_entries(databases_data['Expansion_issue'])
deduplicated_data_merged = merge_to_key_value_pairs(
deduplicated_data,
'Query_small',
'Result_small'
)
# Restructure for Verify/Retrieve_Summary compatibility
keys, val = [], []
for item in deduplicated_data_merged:
for items_key, items_value in item.items():
keys.append(items_key)
val.append(items_value)
send_verify = []
for i, j in zip(keys, val):
send_verify.append({
"Query_small": i,
"Answer_Small": j
})
dup_databases = {
"Query": original,
"Expansion_issue": send_verify,
"_intermediate_outputs": intermediate_outputs # Preserve intermediate outputs
}
logger.info(f"Collected {len(intermediate_outputs)} intermediate outputs from search results")
else:
# Handle string context (simple query)
query = str(context).strip()
try:
# Prepare search parameters based on storage type
search_params = {
"group_id": group_id,
"question": query,
"return_raw_results": True
}
# Add storage-specific parameters
if storage_type == "rag" and user_rag_memory_id:
retrieve_chunks_result = knowledge_retrieval(query, kb_config,[str(group_id)])
try:
retrieval_knowledge = [i.page_content for i in retrieve_chunks_result]
clean_content = '\n\n'.join(retrieval_knowledge)
cleaned_query = query
raw_results = clean_content
logger.info(f" Using RAG storage with memory_id={user_rag_memory_id}")
except:
clean_content = ''
raw_results = ''
cleaned_query = query
logger.info(f"知识库没有检索的内容{user_rag_memory_id}")
else:
clean_content, cleaned_query, raw_results = await search_service.execute_hybrid_search(**search_params)
# Keep structure for Verify/Retrieve_Summary compatibility
dup_databases = {
"Query": cleaned_query,
"Expansion_issue": [{
"Query_small": cleaned_query,
"Answer_Small": clean_content,
"_intermediate": {
"type": "search_result",
"query": cleaned_query,
"raw_results": raw_results,
"index": 1,
"total": 1
}
}]
}
except Exception as e:
logger.error(
f"Retrieve: hybrid_search failed for query '{query}': {e}",
exc_info=True
)
# Return empty results on failure
dup_databases = {
"Query": query,
"Expansion_issue": []
}
logger.info(
f"检索==>>:{storage_type}--{user_rag_memory_id}--Query={dup_databases.get('Query', '')}, "
f"Expansion_issue count={len(dup_databases.get('Expansion_issue', []))}"
)
# Build result with intermediate outputs
result = {
"context": dup_databases,
"storage_type": storage_type,
"user_rag_memory_id": user_rag_memory_id
}
# Add intermediate outputs list if they exist
intermediate_outputs = dup_databases.get('_intermediate_outputs', [])
if intermediate_outputs:
result['_intermediates'] = intermediate_outputs
logger.info(f"Adding {len(intermediate_outputs)} intermediate outputs to result")
else:
logger.warning("No intermediate outputs found in dup_databases")
return result
except Exception as e:
logger.error(
f"Retrieve failed: {e}",
exc_info=True
)
return {
"context": {
"Query": "",
"Expansion_issue": []
},
"storage_type": storage_type,
"user_rag_memory_id": user_rag_memory_id,
"error": str(e)
}
finally:
# Log execution time
end = time.time()
try:
duration = end - start
except Exception:
duration = 0.0
log_time('检索', duration)

View File

@@ -0,0 +1,647 @@
"""
Summary Tools for data summarization.
This module contains MCP tools for summarizing retrieved data and generating responses.
"""
import json
import re
import time
from typing import List
from pydantic import BaseModel, Field
from mcp.server.fastmcp import Context
from app.core.logging_config import get_agent_logger, log_time
from app.core.memory.agent.mcp_server.mcp_instance import mcp
from app.core.memory.agent.mcp_server.server import get_context_resource
from app.core.memory.agent.mcp_server.models.summary_models import (
SummaryData,
SummaryResponse,
RetrieveSummaryData,
RetrieveSummaryResponse
)
from app.core.memory.agent.utils.messages_tool import (
Summary_messages_deal,
Resolve_username
)
from app.core.rag.nlp.search import knowledge_retrieval
from dotenv import load_dotenv
import os
# 加载.env文件
load_dotenv()
logger = get_agent_logger(__name__)
@mcp.tool()
async def Summary(
ctx: Context,
context: str,
usermessages: str,
apply_id: str,
group_id: str,
storage_type: str = "",
user_rag_memory_id: str = ""
) -> dict:
"""
Summarize the verified data.
Args:
ctx: FastMCP context for dependency injection
context: JSON string containing verified data
usermessages: User messages identifier
apply_id: Application identifier
group_id: Group identifier
storage_type: Storage type for the workspace (optional)
user_rag_memory_id: User RAG memory identifier (optional)
Returns:
dict: Contains 'status' and 'summary_result'
"""
start = time.time()
try:
# Extract services from context
template_service = get_context_resource(ctx, 'template_service')
session_service = get_context_resource(ctx, 'session_service')
llm_client = get_context_resource(ctx, 'llm_client')
# Resolve session ID
sessionid = Resolve_username(usermessages)
# Process context to extract answer and query
answer_small, query = await Summary_messages_deal(context)
# Get conversation history
history = await session_service.get_history(sessionid, apply_id, group_id)
# Override with empty list for now (as in original)
# Prepare data for template
data = {
"query": query,
"history": history,
"retrieve_info": answer_small
}
except Exception as e:
logger.error(
f"Summary: initialization failed: {e}",
exc_info=True
)
return {
"status": "error",
"summary_result": "信息不足,无法回答"
}
try:
# Render template
system_prompt = await template_service.render_template(
template_name='summary_prompt.jinja2',
operation_name='summary',
data=data,
query=query
)
except Exception as e:
logger.error(
f"Template rendering failed for Summary: {e}",
exc_info=True
)
return {
"status": "error",
"message": f"Prompt rendering failed: {str(e)}"
}
try:
# Call LLM with structured response
structured = await llm_client.response_structured(
messages=[{"role": "system", "content": system_prompt}],
response_model=SummaryResponse
)
aimessages = structured.query_answer or ""
except Exception as e:
logger.error(
f"LLM call failed for Summary: {e}",
exc_info=True
)
aimessages = ""
try:
# Save session
if aimessages != "":
await session_service.save_session(
user_id=sessionid,
query=query,
apply_id=apply_id,
group_id=group_id,
ai_response=aimessages
)
logger.info(f"sessionid: {aimessages} 写入成功")
except Exception as e:
logger.error(
f"sessionid: {sessionid} 写入失败,错误信息:{str(e)}",
exc_info=True
)
return {
"status": "error",
"message": str(e)
}
# Cleanup duplicate sessions
await session_service.cleanup_duplicates()
# Use fallback if empty
if aimessages == '':
aimessages = '信息不足,无法回答'
logger.info(f"验证之后的总结==>>:{aimessages}")
# Log execution time
end = time.time()
try:
duration = end - start
except Exception:
duration = 0.0
log_time('总结', duration)
return {
"status": "success",
"summary_result": aimessages,
"storage_type": storage_type,
"user_rag_memory_id": user_rag_memory_id
}
@mcp.tool()
async def Retrieve_Summary(
ctx: Context,
context: dict,
usermessages: str,
apply_id: str,
group_id: str,
storage_type: str = "",
user_rag_memory_id: str = ""
) -> dict:
"""
Summarize data directly from retrieval results.
Args:
ctx: FastMCP context for dependency injection
context: Dictionary containing Query and Expansion_issue from Retrieve
usermessages: User messages identifier
apply_id: Application identifier
group_id: Group identifier
storage_type: Storage type for the workspace (optional)
user_rag_memory_id: User RAG memory identifier (optional)
Returns:
dict: Contains 'status' and 'summary_result'
"""
start = time.time()
try:
# Extract services from context
template_service = get_context_resource(ctx, 'template_service')
session_service = get_context_resource(ctx, 'session_service')
llm_client = get_context_resource(ctx, 'llm_client')
# Resolve session ID
sessionid = Resolve_username(usermessages)
# Handle both 'content' and 'context' keys (LangGraph uses 'content')
if isinstance(context, dict):
if "content" in context:
inner = context["content"]
# If it's a JSON string, parse it
if isinstance(inner, str):
try:
parsed = json.loads(inner)
logger.info(f"Retrieve_Summary: successfully parsed JSON")
except json.JSONDecodeError:
# Try unescaping first
try:
unescaped = inner.encode('utf-8').decode('unicode_escape')
parsed = json.loads(unescaped)
logger.info(f"Retrieve_Summary: parsed after unescaping")
except (json.JSONDecodeError, UnicodeDecodeError) as e:
logger.error(
f"Retrieve_Summary: parsing failed even after unescape: {e}"
)
context_dict = {"Query": "", "Expansion_issue": []}
parsed = None
if parsed:
# Check if parsed has 'context' wrapper
if isinstance(parsed, dict) and "context" in parsed:
context_dict = parsed["context"]
else:
context_dict = parsed
elif isinstance(inner, dict):
context_dict = inner
else:
context_dict = {"Query": "", "Expansion_issue": []}
elif "context" in context:
context_dict = context["context"] if isinstance(context["context"], dict) else context
else:
context_dict = context
else:
context_dict = {"Query": "", "Expansion_issue": []}
query = context_dict.get("Query", "")
expansion_issue = context_dict.get("Expansion_issue", [])
# Extract retrieve_info from expansion_issue
retrieve_info = []
for item in expansion_issue:
# Check for both Answer_Small and Answer_Samll (typo) for backward compatibility
answer = None
if isinstance(item, dict):
if "Answer_Small" in item:
answer = item["Answer_Small"]
elif "Answer_Samll" in item:
answer = item["Answer_Samll"]
if answer is not None:
# Handle both string and list formats
if isinstance(answer, list):
# Join list of characters/strings into a single string
retrieve_info.append(''.join(str(x) for x in answer))
elif isinstance(answer, str):
retrieve_info.append(answer)
else:
retrieve_info.append(str(answer))
# Join all retrieve_info into a single string
retrieve_info_str = '\n\n'.join(retrieve_info) if retrieve_info else ""
# Get conversation history
history = await session_service.get_history(sessionid, apply_id, group_id)
# Override with empty list for now (as in original)
except Exception as e:
logger.error(
f"Retrieve_Summary: initialization failed: {e}",
exc_info=True
)
return {
"status": "error",
"summary_result": "信息不足,无法回答"
}
try:
# Render template
system_prompt = await template_service.render_template(
template_name='Retrieve_Summary_prompt.jinja2',
operation_name='retrieve_summary',
query=query,
history=history,
retrieve_info=retrieve_info_str
)
except Exception as e:
logger.error(
f"Template rendering failed for Retrieve_Summary: {e}",
exc_info=True
)
return {
"status": "error",
"message": f"Prompt rendering failed: {str(e)}"
}
try:
# Call LLM with structured response
structured = await llm_client.response_structured(
messages=[{"role": "system", "content": system_prompt}],
response_model=RetrieveSummaryResponse
)
# Handle case where structured response might be None or incomplete
if structured and hasattr(structured, 'data') and structured.data:
aimessages = structured.data.query_answer or ""
else:
logger.warning("Structured response is None or incomplete, using default message")
aimessages = "信息不足,无法回答"
# Check for insufficient information response
if '信息不足,无法回答' not in str(aimessages) or str(aimessages)!="":
# Save session
await session_service.save_session(
user_id=sessionid,
query=query,
apply_id=apply_id,
group_id=group_id,
ai_response=aimessages
)
logger.info(f"sessionid: {aimessages} 写入成功")
except Exception as e:
logger.error(
f"Retrieve_Summary: LLM call failed: {e}",
exc_info=True
)
aimessages = ""
# Cleanup duplicate sessions
await session_service.cleanup_duplicates()
# Use fallback if empty
if aimessages == '':
aimessages = '信息不足,无法回答'
logger.info(f"检索之后的总结==>>:{aimessages}")
# Log execution time
end = time.time()
try:
duration = end - start
except Exception:
duration = 0.0
log_time('检索总结', duration)
# Emit intermediate output for frontend
return {
"status": "success",
"summary_result": aimessages,
"storage_type": storage_type,
"user_rag_memory_id": user_rag_memory_id,
"_intermediate": {
"type": "retrieval_summary",
"summary": aimessages,
"query": query,
"storage_type": storage_type,
"user_rag_memory_id": user_rag_memory_id
}
}
@mcp.tool()
async def Input_Summary(
ctx: Context,
context: str,
usermessages: str,
search_switch: str,
apply_id: str,
group_id: str,
storage_type: str = "",
user_rag_memory_id: str = ""
) -> dict:
"""
Generate a quick summary for direct input without verification.
Args:
ctx: FastMCP context for dependency injection
context: String containing the input sentence
usermessages: User messages identifier
search_switch: Search switch value for routing ('2' for summaries only)
apply_id: Application identifier
group_id: Group identifier
storage_type: Storage type for the workspace (e.g., 'rag', 'vector')
user_rag_memory_id: User RAG memory identifier
Returns:
dict: Contains 'query_answer' with the summary result
"""
start = time.time()
logger.info(f"Input_Summary: storage_type={storage_type}, user_rag_memory_id={user_rag_memory_id}")
# Initialize variables to avoid UnboundLocalError
try:
# Extract services from context
template_service = get_context_resource(ctx, 'template_service')
session_service = get_context_resource(ctx, 'session_service')
llm_client = get_context_resource(ctx, 'llm_client')
search_service = get_context_resource(ctx, 'search_service')
# Check if llm_client is None
if llm_client is None:
error_msg = "LLM client is not available. Please check server configuration and SELECTED_LLM_ID environment variable."
logger.error(error_msg)
return error_msg
# Resolve session ID
sessionid = Resolve_username(usermessages) or ""
sessionid = sessionid.replace('call_id_', '')
# Get conversation history
history = await session_service.get_history(
str(sessionid),
str(apply_id),
str(group_id)
)
# Override with empty list for now (as in original)
# Log the raw context for debugging
logger.info(f"Input_Summary: Received context type={type(context)}, value={context[:200] if isinstance(context, str) else context}")
# Extract sentence from context
# Context can be a string or might contain the sentence in various formats
try:
# Try to parse as JSON first
if isinstance(context, str) and (context.startswith('{') or context.startswith('[')):
try:
import json
context_dict = json.loads(context)
if isinstance(context_dict, dict):
query = context_dict.get('sentence', context_dict.get('content', context))
else:
query = context
except json.JSONDecodeError:
# Not valid JSON, try regex
match = re.search(r"'sentence':\s*['\"]?(.*?)['\"]?\s*,", context)
query = match.group(1) if match else context
else:
query = context
except Exception as e:
logger.warning(f"Failed to extract query from context: {e}")
query = context
# Clean query
query = str(query).strip().strip("\"'")
logger.debug(f"Input_Summary: Extracted query='{query}' from context type={type(context)}")
# Execute search based on search_switch and storage_type
try:
logger.info(f"search_switch: {search_switch}, storage_type: {storage_type}")
# Prepare search parameters based on storage type
search_params = {
"group_id": group_id,
"question": query,
"return_raw_results": True
}
# Add storage-specific parameters
'''检索'''
if search_switch == '2':
search_params["include"] = ["summaries"]
if storage_type == "rag" and user_rag_memory_id:
raw_results = []
retrieve_info = ""
kb_config={
"knowledge_bases": [
{
"kb_id": user_rag_memory_id,
"similarity_threshold": 0.7,
"vector_similarity_weight": 0.5,
"top_k": 10,
"retrieve_type": "participle"
}
],
"merge_strategy": "weight",
"reranker_id":os.getenv('reranker_id'),
"reranker_top_k": 10
}
retrieve_chunks_result = knowledge_retrieval(query, kb_config,[str(group_id)])
try:
retrieval_knowledge = [i.page_content for i in retrieve_chunks_result]
retrieve_info = '\n\n'.join(retrieval_knowledge)
raw_results=[retrieve_info]
logger.info(f"Input_Summary: Using RAG storage with memory_id={user_rag_memory_id}")
except:
retrieve_info=''
raw_results=['']
logger.info(f"知识库没有检索的内容{user_rag_memory_id}")
else:
retrieve_info, question, raw_results = await search_service.execute_hybrid_search(**search_params)
logger.info(f"Input_Summary: 使用 summary 进行检索")
else:
retrieve_info, question, raw_results = await search_service.execute_hybrid_search(**search_params)
except Exception as e:
logger.error(
f"Input_Summary: hybrid_search failed, using empty results: {e}",
exc_info=True
)
retrieve_info, question, raw_results = "", query, []
# Render template
system_prompt = await template_service.render_template(
template_name='Retrieve_Summary_prompt.jinja2',
operation_name='input_summary',
query=query,
history=history,
retrieve_info=retrieve_info
)
# Call LLM with structured response
try:
structured = await llm_client.response_structured(
messages=[{"role": "system", "content": system_prompt}],
response_model=RetrieveSummaryResponse
)
aimessages = structured.data.query_answer or "信息不足,无法回答"
except Exception as e:
logger.error(
f"Input_Summary: response_structured failed, using default answer: {e}",
exc_info=True
)
aimessages = "信息不足,无法回答"
logger.info(f"快速答案总结==>>:{storage_type}--{user_rag_memory_id}--{aimessages}")
# Emit intermediate output for frontend
return {
"status": "success",
"summary_result": aimessages,
"storage_type": storage_type,
"user_rag_memory_id": user_rag_memory_id,
"_intermediate": {
"type": "input_summary",
"title": "快速答案",
"summary": aimessages,
"query": query,
"raw_results": raw_results,
"search_mode": "quick_search",
"storage_type": storage_type,
"user_rag_memory_id": user_rag_memory_id
}
}
except Exception as e:
logger.error(
f"Input_Summary failed: {e}",
exc_info=True
)
return {
"status": "fail",
"summary_result": "信息不足,无法回答",
"storage_type": storage_type,
"user_rag_memory_id": user_rag_memory_id,
"error": str(e)
}
finally:
# Log execution time
end = time.time()
try:
duration = end - start
except Exception:
duration = 0.0
log_time('检索', duration)
@mcp.tool()
async def Summary_fails(
ctx: Context,
context: str,
usermessages: str,
apply_id: str,
group_id: str,
storage_type: str = "",
user_rag_memory_id: str = ""
) -> dict:
"""
Handle workflow failure when summary cannot be generated.
Args:
ctx: FastMCP context for dependency injection
context: Failure context string
usermessages: User messages identifier
apply_id: Application identifier
group_id: Group identifier
storage_type: Storage type for the workspace (optional)
user_rag_memory_id: User RAG memory identifier (optional)
Returns:
dict: Contains 'query_answer' with failure message
"""
try:
# Extract services from context
session_service = get_context_resource(ctx, 'session_service')
# Parse session ID from usermessages
usermessages_parts = usermessages.split('_')[1:]
sessionid = '_'.join(usermessages_parts[:-1])
# Cleanup duplicate sessions
await session_service.cleanup_duplicates()
logger.info(f"没有相关数据")
logger.debug(f"Summary_fails called with apply_id: {apply_id}, group_id: {group_id}")
return {
"status": "success",
"summary_result": "没有相关数据",
"storage_type": storage_type,
"user_rag_memory_id": user_rag_memory_id
}
except Exception as e:
logger.error(
f"Summary_fails failed: {e}",
exc_info=True
)
return {
"status": "fail",
"summary_result": "没有相关数据",
"storage_type": storage_type,
"user_rag_memory_id": user_rag_memory_id,
"error": str(e)
}

View File

@@ -0,0 +1,169 @@
"""
Verification Tools for data verification.
This module contains MCP tools for verifying retrieved data.
"""
import time
from jinja2 import Template
from mcp.server.fastmcp import Context
from app.core.logging_config import get_agent_logger, log_time
from app.core.memory.agent.mcp_server.mcp_instance import mcp
from app.core.memory.agent.mcp_server.server import get_context_resource
from app.core.memory.agent.utils.verify_tool import VerifyTool
from app.core.memory.agent.utils.messages_tool import (
Verify_messages_deal,
Retrieve_verify_tool_messages_deal,
Resolve_username
)
from app.core.memory.agent.utils.llm_tools import PROJECT_ROOT_
logger = get_agent_logger(__name__)
@mcp.tool()
async def Verify(
ctx: Context,
context: dict,
usermessages: str,
apply_id: str,
group_id: str,
storage_type: str = "",
user_rag_memory_id: str = ""
) -> dict:
"""
Verify the retrieved data.
Args:
ctx: FastMCP context for dependency injection
context: Dictionary containing query and expansion issues
usermessages: User messages identifier
apply_id: Application identifier
group_id: Group identifier
storage_type: Storage type for the workspace (optional)
user_rag_memory_id: User RAG memory identifier (optional)
Returns:
dict: Contains 'status' and 'verified_data' with verification results
"""
start = time.time()
try:
# Extract services from context
session_service = get_context_resource(ctx, 'session_service')
# Load verification prompt template
file_path = PROJECT_ROOT_ + '/agent/utils/prompt/split_verify_prompt.jinja2'
# Read template file directly (VerifyTool expects raw template content)
from app.core.memory.agent.utils.messages_tool import read_template_file
system_prompt = await read_template_file(file_path)
# Resolve session ID
sessionid = Resolve_username(usermessages)
# Get conversation history
history = await session_service.get_history(sessionid, apply_id, group_id)
template = Template(system_prompt)
system_prompt = template.render(history=history, sentence=context)
# Process context to extract query and results
Query_small, Result_small, query = await Verify_messages_deal(context)
# Build query list for verification
query_list = []
for query_small, anser in zip(Query_small, Result_small):
query_list.append({
'Query_small': query_small,
'Answer_Small': anser
})
messages = {
"Query": query,
"Expansion_issue": query_list
}
# Call verification workflow
verify_tool = VerifyTool(system_prompt, messages)
verify_result = await verify_tool.verify()
# Parse LLM verification result with error handling
try:
messages_deal = await Retrieve_verify_tool_messages_deal(
verify_result,
history,
query
)
except Exception as e:
logger.error(
f"Retrieve_verify_tool_messages_deal parsing failed: {e}",
exc_info=True
)
# Fallback to avoid 500 errors
messages_deal = {
"data": {
"query": query,
"expansion_issue": []
},
"split_result": "failed",
"reason": str(e),
"history": history,
}
logger.info(f"验证==>>:{messages_deal}")
# Emit intermediate output for frontend
return {
"status": "success",
"verified_data": messages_deal,
"storage_type": storage_type,
"user_rag_memory_id": user_rag_memory_id,
"_intermediate": {
"type": "verification",
"title": "数据验证",
"result": messages_deal.get("split_result", "unknown"),
"reason": messages_deal.get("reason", ""),
"query": query,
"verified_count": len(query_list),
"storage_type": storage_type,
"user_rag_memory_id": user_rag_memory_id
}
}
except Exception as e:
logger.error(
f"Verify failed: {e}",
exc_info=True
)
return {
"status": "error",
"message": str(e),
"storage_type": storage_type,
"user_rag_memory_id": user_rag_memory_id,
"verified_data": {
"data": {
"query": "",
"expansion_issue": []
},
"split_result": "failed",
"reason": str(e),
"history": [],
}
}
finally:
# Log execution time
end = time.time()
try:
duration = end - start
except Exception:
duration = 0.0
log_time('验证', duration)

View File

@@ -0,0 +1,7 @@
"""Agent utilities."""
from app.core.memory.agent.utils.multimodal import MultimodalProcessor
__all__ = [
"MultimodalProcessor",
]

View File

@@ -0,0 +1,70 @@
import os
import json
from typing import List
from datetime import datetime
from app.core.memory.storage_services.extraction_engine.knowledge_extraction.chunk_extraction import DialogueChunker
from app.core.memory.models.message_models import DialogData, ConversationContext, ConversationMessage
async def get_chunked_dialogs(
chunker_strategy: str = "RecursiveChunker",
group_id: str = "group_1",
user_id: str = "user1",
apply_id: str = "applyid",
content: str = "这是用户的输入",
ref_id: str = "wyl_20251027",
config_id: str = None
) -> List[DialogData]:
"""Generate chunks from all test data entries using the specified chunker strategy.
Args:
chunker_strategy: The chunking strategy to use (default: RecursiveChunker)
group_id: Group identifier
user_id: User identifier
apply_id: Application identifier
content: Dialog content
ref_id: Reference identifier
config_id: Configuration ID for processing
Returns:
List of DialogData objects with generated chunks for each test entry
"""
dialog_data_list = []
messages = []
messages.append(ConversationMessage(role="用户", msg=content))
# Create DialogData
conversation_context = ConversationContext(msgs=messages)
# Create DialogData with group_id based on the entry's id for uniqueness
dialog_data = DialogData(
context=conversation_context,
ref_id=ref_id,
group_id=group_id,
user_id=user_id,
apply_id=apply_id,
config_id=config_id
)
# Create DialogueChunker and process the dialogue
chunker = DialogueChunker(chunker_strategy)
extracted_chunks = await chunker.process_dialogue(dialog_data)
dialog_data.chunks = extracted_chunks
dialog_data_list.append(dialog_data)
# Convert to dict with datetime serialized
def serialize_datetime(obj):
if isinstance(obj, datetime):
return obj.isoformat()
raise TypeError(f"Object of type {obj.__class__.__name__} is not JSON serializable")
combined_output = [dd.model_dump() for dd in dialog_data_list]
print(dialog_data_list)
# with open(os.path.join(os.path.dirname(__file__), "chunker_test_output.txt"), "w", encoding="utf-8") as f:
# json.dump(combined_output, f, ensure_ascii=False, indent=4, default=serialize_datetime)
return dialog_data_list

View File

@@ -0,0 +1,204 @@
import asyncio
import json
from collections import defaultdict
from typing import TypedDict, Annotated
import os
import logging
from jinja2 import Template
from langchain_core.messages import AnyMessage
from dotenv import load_dotenv
from langgraph.graph import add_messages
from openai import OpenAI
from app.core.memory.agent.utils.messages_tool import read_template_file
from app.core.memory.utils.config.config_utils import get_picture_config, get_voice_config
from app.core.memory.utils.llm.llm_utils import get_llm_client
from app.core.memory.utils.config.definitions import SELECTED_LLM_ID, SELECTED_LLM_PICTURE_NAME, SELECTED_LLM_VOICE_NAME
from app.core.models.base import RedBearModelConfig
from app.core.memory.src.llm_tools.openai_client import OpenAIClient
PROJECT_ROOT_ = os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
logger = logging.getLogger(__name__)
load_dotenv()
#TODO: Refactor entire picture/voice
# async def LLM_model_request(context,data,query):
# '''
# Agent model request
# Args:
# context:Input request
# data: template parameters
# query:request content
# Returns:
# '''
# template = Template(context)
# system_prompt = template.render(**data)
# llm_client = get_llm_client(SELECTED_LLM_ID)
# result = await llm_client.chat(
# messages=[{"role": "system", "content": system_prompt}] + [{"role": "user", "content": query}]
# )
# return result
async def picture_model_requests(image_url):
'''
Args:
image_url:
Returns:
'''
file_path = PROJECT_ROOT_ + '/agent/utils/prompt/Template_for_image_recognition_prompt.jinja2 '
system_prompt = await read_template_file(file_path)
result = await Picture_recognize(image_url,system_prompt)
return (result)
class WriteState(TypedDict):
'''
Langgrapg Writing TypedDict
'''
messages: Annotated[list[AnyMessage], add_messages]
user_id:str
apply_id:str
group_id:str
class ReadState(TypedDict):
'''
Langgrapg READING TypedDict
name:
id:user id
loop_count:Traverse times
search_switchtype
config_id: configuration id for filtering results
'''
messages: Annotated[list[AnyMessage], add_messages] #消息追加的模式增加消息
name: str
id: str
loop_count:int
search_switch: str
user_id: str
apply_id: str
group_id: str
config_id: str
class COUNTState:
'''
The number of times the workflow dialogue retrieval content has no correct message recall traversal
'''
def __init__(self, limit: int = 5):
self.total: int = 0 # 当前累加值
self.limit: int = limit # 最大上限
def add(self, value: int = 1):
"""累加数字,如果达到上限就保持最大值"""
self.total += value
print(f"[COUNTState] 当前值: {self.total}")
if self.total >= self.limit:
print(f"[COUNTState] 达到上限 {self.limit}")
self.total = self.limit # 达到上限不再增加
def get_total(self) -> int:
"""获取当前累加值"""
return self.total
def reset(self):
"""手动重置累加值"""
self.total = 0
print(f"[COUNTState] 已重置为 0")
# def embed(texts: list[str]) -> list[list[float]]:
# # 这里可以换成 LangChain Embeddings
# return [[float(len(t) % 5), float(len(t) % 3)] for t in texts]
# def export_store_to_json(store, namespace):
# """Export the entire storage content to a JSON file"""
# # 搜索所有存储项
# all_items = store.search(namespace)
# # 整理数据
# export_data = {}
# for item in all_items:
# if hasattr(item, 'key') and hasattr(item, 'value'):
# export_data[item.key] = item.value
# # 保存到文件
# os.makedirs("memory_logs", exist_ok=True)
# with open("memory_logs/full_memory_export.json", "w", encoding="utf-8") as f:
# json.dump(export_data, f, ensure_ascii=False, indent=2)
# print(f"{len(export_data)} 条记忆到 JSON 文件")
def merge_to_key_value_pairs(data, query_key, result_key):
grouped = defaultdict(list)
for item in data:
grouped[item[query_key]].append(item[result_key])
return [{key: values} for key, values in grouped.items()]
def deduplicate_entries(entries):
seen = set()
deduped = []
for entry in entries:
key = (entry['Query_small'], entry['Result_small'])
if key not in seen:
seen.add(key)
deduped.append(entry)
return deduped
async def Picture_recognize(image_path,PROMPT_TICKET_EXTRACTION) -> str:
try:
model_config = get_picture_config(SELECTED_LLM_PICTURE_NAME)
except Exception as e:
err = f"LLM配置不可用{str(e)}。请检查 config.json 和 runtime.json。"
logger.error(err)
return err
api_key = os.getenv(model_config["api_key"]) # 从环境变量读取对应后端的 API key
backend_model_name = model_config["llm_name"].split("/")[-1]
api_base=model_config['api_base']
logger.info(f"model_name: {backend_model_name}")
logger.info(f"api_key set: {'yes' if api_key else 'no'}")
logger.info(f"base_url: {model_config['api_base']}")
client = OpenAI(
api_key=api_key, base_url=api_base,
)
completion = client.chat.completions.create(
model=backend_model_name,
messages=[
{
"role": "user",
"content": [
{
"type": "image_url",
"image_url":image_path,
},
{"type": "text",
"text": PROMPT_TICKET_EXTRACTION}
]
}
])
picture_text = completion.choices[0].message.content
picture_text = picture_text.replace('```json', '').replace('```', '')
picture_text = json.loads(picture_text)
return (picture_text['statement'])
async def Voice_recognize():
try:
model_config = get_voice_config(SELECTED_LLM_VOICE_NAME)
except Exception as e:
err = f"LLM配置不可用{str(e)}。请检查 config.json 和 runtime.json。"
logger.error(err)
return err
api_key = os.getenv(model_config["api_key"]) # 从环境变量读取对应后端的 API key
backend_model_name = model_config["llm_name"].split("/")[-1]
api_base = model_config['api_base']
return api_key,backend_model_name,api_base

View File

@@ -0,0 +1,15 @@
from app.core.config import settings
def get_mcp_server_config():
"""
Get the MCP server configuration
"""
mcp_server_config = {
"data_flow": {
"url": f"http://{settings.SERVER_IP}:8081/sse", # 你前面的 FastMCP(weather) 服务端口
"transport": "sse",
"timeout": 15000,
"sse_read_timeout": 15000,
}
}
return mcp_server_config

View File

@@ -0,0 +1,239 @@
import json
import logging
import re
from typing import List, Any
from langchain_core.messages import AnyMessage
from app.core.logging_config import get_agent_logger
logger = get_agent_logger(__name__)
def _to_openai_messages(msgs: List[AnyMessage]) -> List[dict]:
out = []
for m in msgs:
if hasattr(m, "content"):
out.append({"role": "user", "content": getattr(m, "content", "")})
elif isinstance(m, dict) and "role" in m and "content" in m:
out.append(m)
else:
out.append({"role": "user", "content": str(m)})
return out
def _extract_content(resp: Any) -> str:
"""Extract LLM content and sanitize to raw JSON/text.
- Supports both object and dict response shapes.
- Removes leading role labels (e.g., "Assistant:").
- Strips Markdown code fences like ```json ... ```.
- Attempts to isolate the first valid JSON array/object block when extra text is present.
"""
def _to_text(r: Any) -> str:
try:
# 对象形式: resp.choices[0].message.content
if hasattr(r, "choices") and getattr(r, "choices", None):
msg = r.choices[0].message
if hasattr(msg, "content"):
return msg.content
if isinstance(msg, dict) and "content" in msg:
return msg["content"]
# 字典形式: resp["choices"][0]["message"]["content"]
if isinstance(r, dict):
return r.get("choices", [{}])[0].get("message", {}).get("content", "")
except Exception:
pass
return str(r)
def _clean_text(text: str) -> str:
s = str(text).strip()
# 移除可能的角色前缀
s = re.sub(r"^\s*(Assistant|assistant)\s*:\s*", "", s)
# 提取 ```json ... ``` 代码块
m = re.search(r"```json\s*(.*?)\s*```", s, flags=re.S | re.I)
if m:
s = m.group(1).strip()
# 如果仍然包含多余文本,尝试截取第一个 JSON 数组/对象片段
if not (s.startswith("{") or s.startswith("[")):
left = s.find("[")
right = s.rfind("]")
if left != -1 and right != -1 and right > left:
s = s[left:right + 1].strip()
else:
left = s.find("{")
right = s.rfind("}")
if left != -1 and right != -1 and right > left:
s = s[left:right + 1].strip()
return s
raw = _to_text(resp)
return _clean_text(raw)
def Resolve_username(usermessages):
'''
Extract username
Args:
usermessages: user name
Returns:
'''
usermessages = usermessages.split('_')[1:]
sessionid = '_'.join(usermessages[:-1])
return sessionid
# TODO: USE app.core.memory.src.utils.render_template instead
async def read_template_file(template_path: str) -> str:
"""
读取模板文件
Args:
template_path: 模板文件路径
Returns:
模板内容字符串
Note:
建议使用 app.core.memory.utils.template_render 中的统一模板渲染功能
"""
try:
with open(template_path, "r", encoding="utf-8") as f:
return f.read()
except FileNotFoundError:
logger.error(f"模板文件未找到: {template_path}")
raise
except IOError as e:
logger.error(f"读取模板文件失败: {template_path}, 错误: {str(e)}", exc_info=True)
raise
async def Problem_Extension_messages_deal(context):
'''
Extract data
Args:
context:
Returns:
'''
extent_quest = []
original = context.get('original', '')
messages = context.get('context', '')
messages = json.loads(messages)
for message in messages:
question = message.get('question', '')
type = message.get('type', '')
extent_quest.append({"role": "user", "content": f"问题:{question};问题类型:{type}"})
return extent_quest, original
async def Retriev_messages_deal(context):
'''
Extract data
Args:
context:
Returns:
'''
if isinstance(context, dict):
if 'context' in context or 'original' in context:
return context.get('context', {}), context.get('original', '')
return content, original_value
async def Verify_messages_deal(context):
'''
Extract data
Args:
context:
Returns:
'''
query = context['context']['Query']
Query_small_list = context['context']['Expansion_issue']
Result_small = []
Query_small = []
for i in Query_small_list:
Result_small.append(i['Answer_Small'][0])
Query_small.append(i['Query_small'])
return Query_small, Result_small, query
async def Summary_messages_deal(context):
'''
Extract data
Args:
context:
Returns:
'''
messages = str(context).replace('\\n', '').replace('\n', '').replace('\\', '')
query = re.findall(r'"query": (.*?),', messages)[0]
query = query.replace('[', '').replace(']', '').strip()
matches = re.findall(r'"answer_small"\s*:\s*"(\[.*?\])"', messages)
answer_small_texts = []
for m in matches:
try:
parsed = json.loads(m)
for item in parsed:
answer_small_texts.append(item.strip().replace('\\', '').replace('[', '').replace(']', ''))
except Exception:
answer_small_texts.append(m.strip().replace('\\', '').replace('[', '').replace(']', ''))
return answer_small_texts, query
async def VerifyTool_messages_deal(context):
'''
Extract data
Args:
context:
Returns:
'''
messages = str(context).replace('\\n', '').replace('\n', '').replace('\\', '')
content_messages = messages.split('"context":')[1].replace('""', '"')
messages = str(content_messages).split("name='Retrieve'")[0]
query = re.findall(f'"Query": "(.*?)"', messages)[0]
Query_small = re.findall(f'"Query_small": "(.*?)"', messages)
Result_small = re.findall(f'"Result_small": "(.*?)"', messages)
return Query_small, Result_small, query
async def Retrieve_Summary_messages_deal(context):
pass
async def Retrieve_verify_tool_messages_deal(context, history, query):
'''
Extract data
Args:
context:
Returns:
'''
results = []
# 统一转为字符串,避免 None 或非字符串导致正则报错
text = str(context)
blocks = re.findall(r'\{(.*?)\}', text, flags=re.S)
for block in blocks:
query_small = re.search(r'"Query_small"\s*:\s*"([^"]*)"', block)
answer_small = re.search(r'"Answer_Small"\s*:\s*(\[[^\]]*\])', block)
status = re.search(r'"status"\s*:\s*"([^"]*)"', block)
query_answer = re.search(r'"Query_answer"\s*:\s*"([^"]*)"', block)
results.append({
"query_small": query_small.group(1) if query_small else None,
"answer_small": answer_small.group(1) if answer_small else None,
# 将缺失的 status 统一为空字符串,后续用字符串判定,避免 NoneType 错误
"status": status.group(1) if status else "",
"query_answer": query_answer.group(1) if query_answer else None
})
result = []
for r in results:
# 统一按字符串判定状态,兼容大小写和缺失情况
status_str = str(r.get('status', '')).strip().lower()
if status_str == 'false':
continue
else:
result.append(r)
split_result = 'failed' if not result else 'success'
result = {"data": {"query": query, "expansion_issue": result}, "split_result": split_result, "reason": "",
"history": history}
return result

View File

@@ -0,0 +1,38 @@
# project_root = os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
# sys.path.insert(0, project_root)
# load_dotenv()
# async def llm_client_chat(messages: List[dict]) -> str:
# """使用 OpenAI 兼容接口进行对话,返回内容字符串。"""
# try:
# cfg = get_model_config(SELECTED_LLM_ID)
# rb_config = RedBearModelConfig(
# model_name=cfg["model_name"],
# provider=cfg["provider"],
# api_key=cfg["api_key"],
# base_url=cfg["base_url"],
# )
# client = OpenAIClient(model_config=rb_config, type_="chat")
# except Exception as e:
# logger.error(f"获取模型配置失败:{e}")
# err = f"获取模型配置失败:{str(e)}。请检查!!!"
# return err
# try:
# response = await client.chat(messages)
# print(f"model_tool's llm_client_chat response ======>:\n {response}")
# return _extract_content(response)
# # return _extract_content(result)
# except Exception as e:
# logger.error(f"LLM调用失败{str(e)}。请检查 model_name、api_key、api_base 是否正确。")
# return f"LLM调用失败{str(e)}。请检查 model_name、api_key、api_base 是否正确。"
# async def main(image_url):
# await llm_client_chat(image_url)
#
# # 运行主函数
# asyncio.run(main(['https://dashscope.oss-cn-beijing.aliyuncs.com/samples/audio/paraformer/hello_world_male2.wav']))
#

View File

@@ -0,0 +1,131 @@
"""
Multimodal input processor for handling image and audio content.
This module provides utilities for detecting and processing multimodal inputs
(images and audio files) by converting them to text using appropriate models.
"""
import logging
from typing import List
from app.core.memory.agent.multimodal.speech_model import Vico_recognition
from app.core.memory.agent.utils.llm_tools import picture_model_requests
logger = logging.getLogger(__name__)
class MultimodalProcessor:
"""
Processor for handling multimodal inputs (images and audio).
This class detects image and audio file paths in input content and converts
them to text using appropriate recognition models.
"""
# Supported file extensions
IMAGE_EXTENSIONS = ['.jpg', '.png']
AUDIO_EXTENSIONS = [
'aac', 'amr', 'avi', 'flac', 'flv', 'm4a', 'mkv', 'mov',
'mp3', 'mp4', 'mpeg', 'ogg', 'opus', 'wav', 'webm', 'wma', 'wmv'
]
def __init__(self):
"""Initialize the multimodal processor."""
pass
def is_image(self, content: str) -> bool:
"""
Check if content is an image file path.
Args:
content: Input string to check
Returns:
True if content ends with a supported image extension
Examples:
>>> processor = MultimodalProcessor()
>>> processor.is_image("photo.jpg")
True
>>> processor.is_image("document.pdf")
False
"""
if not isinstance(content, str):
return False
content_lower = content.lower()
return any(content_lower.endswith(ext) for ext in self.IMAGE_EXTENSIONS)
def is_audio(self, content: str) -> bool:
"""
Check if content is an audio file path.
Args:
content: Input string to check
Returns:
True if content ends with a supported audio extension
Examples:
>>> processor = MultimodalProcessor()
>>> processor.is_audio("recording.mp3")
True
>>> processor.is_audio("video.mp4")
True
>>> processor.is_audio("document.txt")
False
"""
if not isinstance(content, str):
return False
content_lower = content.lower()
return any(content_lower.endswith(f'.{ext}') for ext in self.AUDIO_EXTENSIONS)
async def process_input(self, content: str) -> str:
"""
Process input content, converting images/audio to text if needed.
This method detects if the input is an image or audio file and converts
it to text using the appropriate recognition model. If processing fails
or the content is not multimodal, it returns the original content.
Args:
content: Input string (may be file path or regular text)
Returns:
Text content (original or converted from image/audio)
Examples:
>>> processor = MultimodalProcessor()
>>> await processor.process_input("photo.jpg")
"Recognized text from image..."
>>> await processor.process_input("Hello world")
"Hello world"
"""
if not isinstance(content, str):
logger.warning(f"[MultimodalProcessor] Content is not a string: {type(content)}")
return str(content)
try:
# Check for image input
if self.is_image(content):
logger.info(f"[MultimodalProcessor] Detected image input: {content}")
result = await picture_model_requests(content)
logger.info(f"[MultimodalProcessor] Image recognition result: {result[:100]}...")
return result
# Check for audio input
if self.is_audio(content):
logger.info(f"[MultimodalProcessor] Detected audio input: {content}")
result = await Vico_recognition([content]).run()
logger.info(f"[MultimodalProcessor] Audio recognition result: {result[:100]}...")
return result
except Exception as e:
logger.error(f"[MultimodalProcessor] Error processing multimodal input: {e}", exc_info=True)
logger.info(f"[MultimodalProcessor] Falling back to original content")
return content
# Return original content if not multimodal
return content

View File

@@ -0,0 +1,81 @@
你是一个高效的问题拆分助手,任务是根据用户提供的原始问题和问题类型,生成可操作的扩展问题,用于精确回答原问题。请严格遵循以下规则:
角色:
- 你是“问题拆分专家”,专注于逻辑、信息完整性和可操作性。
- 你能够结合【历史信息】、【上下文】、【背景知识】进行分析,以保持问题拆分的连贯性和相关性。
- 如果历史信息或上下文与当前问题无关,可忽略。
---
### 历史信息参考
在生成扩展问题时,你可以参考以下历史数据(如果提供):
- 历史对话或任务的主题;
- 历史中出现的关键实体(时间、人物、地点、研究主题等);
- 历史中已解答的问题(避免重复);
- 历史推理链(保持逻辑一致性)。
> 如果没有提供历史信息,则仅根据当前输入问题进行分析。
输入历史信息内容:{{history}}
## User Input
{% if questions is string %}
{{ questions }}
{% else %}
{% for question in questions %}
- {{ question }}
{% endfor %}
{% endif %}
需求:
- 如果问题是单跳问题(单步可答),直接保留原问题提取重要提问部分作为拆分/扩展问题。
- 如果问题是多跳问题(需多个信息点才能回答),对问题进行扩展拆分。
- 扩展问题必须完整覆盖原问题的所有关键要素,包括时间、主体、动作、目标等,不得遗漏。
- 扩展问题不得冗余:避免重复询问相同信息或过度拆分同一主题。
- 扩展问题必须高度相关:每个子问题直接服务于原问题,不引入未提及的新概念、人物或细节。
- 扩展问题必须可操作:每个子问题能在有限资源下独立解答。
- 子问题数量不超过4个。
- 拆分问题的时候可以考虑输入的历史内容,以保持逻辑连贯。
比如:输入历史信息内容:[{'Query': '4月27日我和你推荐过一本书书名是什么', 'ANswer': '张曼玉推荐了《小王子》'}]
拆分问题4月27日我和你推荐过一本书书名是什么可以拆分为4月27日张曼玉推荐过一本书书名是什么
输出要求:
- 仅输出 JSON 数组,不要包含任何解释或代码块。
- 每个元素包含:
- `original_question`: 原始问题
- `extended_question`: 扩展后的问题
- `type`: 类型(事实检索/澄清/定义/比较/行动建议)
- `reason`: 生成该扩展问题的简短理由
- 使用标准 ASCII 双引号,无换行;确保字符串正确关闭并以逗号分隔。
示例:
输入:
[
"问题:今年诺贝尔物理学奖的获奖者是谁,他们因为什么贡献获奖?;问题类型:多跳",
]
输出:
[
{
"original_question": "今年诺贝尔物理学奖的获奖者是谁,他们因为什么贡献获奖?",
"extended_question": "今年诺贝尔物理学奖的获奖者有哪些人?",
"type": "多跳",
"reason": "输出原问题的关键要素"
},
{
"original_question": "今年诺贝尔物理学奖的获奖者是谁,他们因为什么贡献获奖?",
"extended_question": "今年诺贝尔物理学奖的获奖者是因哪些具体贡献获奖的?",
"type": "多跳",
"reason": "输出原问题的关键要素"
}
]
**Output format**
**CRITICAL JSON FORMATTING REQUIREMENTS:**
1. Use only standard ASCII double quotes (") for JSON structure - never use Chinese quotation marks ("") or other Unicode quotes
2. If the extracted statement text contains quotation marks, escape them properly using backslashes (\")
3. Ensure all JSON strings are properly closed and comma-separated
4. Do not include line breaks within JSON string values
The output language should always be the same as the input language.{{ json_schema }}

View File

@@ -0,0 +1,37 @@
# 角色
你是一个专业的问答助手,擅长基于检索信息和历史对话回答用户问题。
# 任务
根据提供的上下文信息回答用户的问题。
# 输入信息
- 历史对话:{{history}}
- 检索信息:{{retrieve_info}}
## User Query
{{query}}
# 回答指南
1. 仔细分析用户的问题
2. 优先使用检索信息中的相关内容回答
3. 结合历史对话提供连贯的回复
4. 如果信息不足:
- 对于简单问候或日常对话,给出自然简短的回复
- 对于复杂问题,诚实说明信息不足
5. 保持回答简洁、相关、自然
6. 使用与问题相同的语言回答
**Output format**
- 直接回答问题,像人类对话一样自然流畅
- 不要提及"检索信息"、"搜索结果"、"根据资料"等技术术语
- 不要解释推理过程或评论信息来源
- 如果只能部分回答问题,先回答能回答的部分,然后说明哪些方面信息不足
- 如果完全无法回答,简洁地说明:"信息不足,无法回答。"
**CRITICAL JSON FORMATTING REQUIREMENTS:**
1. Use only standard ASCII double quotes (") for JSON structure - never use Chinese quotation marks ("") or other Unicode quotes
2. If the extracted statement text contains quotation marks, escape them properly using backslashes (\")
3. Ensure all JSON strings are properly closed and comma-separated
4. Do not include line breaks within JSON string values
The output language should always be the same as the input language.{{ json_schema }}

View File

@@ -0,0 +1,29 @@
# 角色:{#InputSlot placeholder="角色名称" mode="input"#}{#/InputSlot#}
你是一个智能问答助手,任务如下
## 目标:
1. 接收一个字典,格式为 {'问题': [答案列表]}。
2. 接收一个问题(字典中的 key
3. 找到与问题匹配的答案列表。
4. 将答案列表合并成一句自然流畅的话:
- 如果答案有两条使用“是”连接例如“A是B”。
- 如果答案有三条或以上,使用“,并且”“另外”等自然连词,保证句子流畅。
5. 输出内容时只输出合并后的答案,不输出关键点或其他文字。
6. 如果问题未在字典中找到对应答案,请输出:
对不起,我没有找到相关信息。
输出要求:
- 文本形式
---
字典示例:
{
'今天的天气怎么样': ['今天天气很好', '今天是晴天']
}
问题示例:
今天的天气怎么样
输出要求:
今天天气很好,是晴天

View File

@@ -0,0 +1,10 @@
请提图像内的文本
返回数据格式以json方式输出,
- 必须通过json.loads()的格式支持的形式输出,响应必须是与此确切模式匹配的有效JSON对象。不要在JSON之前或之后包含任何文本。
- 关键的JSON格式要求{"statement":识别出的文本内容}
1.JSON结构仅使用标准ASCII双引号-切勿使用中文引号“”或其他Unicode引号
2.如果提取的语句文本包含引号,请使用反斜杠(\“)正确转义它们
3.确保所有JSON字符串都正确关闭并以逗号分隔
4.JSON字符串值中不包括换行符
5.正确转义的例子“statement”“Zhang Xinhua said\”我非常喜欢这本书\""
6.不允许输出```json```相关符号,如```json```、``````、```python```、```javascript```、```html```、```css```、```sql```、```java```、```c```、```c++```、```c#```、```ruby```

View File

@@ -0,0 +1,34 @@
你是一个输入分类助手,负责判断用户输入的意图类型。
## User Input
{{ user_query }}
请你根据以下规则判断:
1. 如果输入是在寻求信息、提问、请求解释、或疑问句(包括隐含的问题),则分类为 "question"。
2. 如果输入是命令、陈述、描述、感叹、或其他类型,不在寻求答案,则分类为 "other"。
只输出:
{
"type": "question"
}
{
"type": "other"
}
示例:
输入:"Python怎么读取文件"
输出:{"type": "question"}
输入:"帮我写个读取文件的函数"
输出:{"type": "other"}
输入:"今天是星期几?"
输出:{"type": "question"}
返回数据格式以json方式输出,
- 必须通过json.loads()的格式支持的形式输出,响应必须是与此确切模式匹配的有效JSON对象。不要在JSON之前或之后包含任何文本。
- 关键的JSON格式要求{"statement":识别出的文本内容}
1.JSON结构仅使用标准ASCII双引号-切勿使用中文引号“”或其他Unicode引号
2.如果提取的语句文本包含引号,请使用反斜杠(\“)正确转义它们
3.确保所有JSON字符串都正确关闭并以逗号分隔
4.JSON字符串值中不包括换行符
5.正确转义的例子“statement”“Zhang Xinhua said\”我非常喜欢这本书\""
6.不允许输出```json```相关符号,如```json```、``````、```python```、```javascript```、```html```、```css```、```sql```、```java```、```c```、```c++```、```c#```、```ruby```

View File

@@ -0,0 +1,160 @@
# 角色:{#InputSlot placeholder="角色名称" mode="input"#}{#/InputSlot#}
你是一个智能数据拆分助手,请根据数据特性判断输入属于哪种类型:
## 目标:
你需要根据以下类型对输入数据进行分类,并生成相应的拆分策略和示例。
---
### 历史信息参考
在生成扩展问题时,你可以参考以下历史数据(如果提供):
- 历史对话或任务的主题;
- 历史中出现的关键实体(时间、人物、地点、研究主题等);
- 历史中已解答的问题(避免重复);
- 历史推理链(保持逻辑一致性)。
> 如果没有提供历史信息,则仅根据当前输入问题进行分析。
输入历史信息内容:{{history}}
## User Input
{{ sentence }}
## 需求:
1:首先判断类型(单跳、多跳、开放域、时间)。
2:根据类型进行拆分。
3:拆分后的内容需保证信息完整且可独立处理。
4:对每个拆分条目,可附加示例或说明。
5:拆分问题的时候可以考虑输入的历史内容,以保持逻辑连贯。
比如:输入历史信息内容:[{'Query': '4月27日我和你推荐过一本书书名是什么', 'ANswer': '张曼玉推荐了《小王子》'}]
拆分问题4月27日我和你推荐过一本书书名是什么可以拆分为4月27日张曼玉推荐过一本书书名是什么
## 指令:
你是一个智能数据拆分助手,请根据数据特性判断输入属于哪种类型:
单跳Single-hop
描述:问题或数据只需要通过一步即可得到答案或完成拆分,不依赖其他信息。
拆分策略:直接识别核心信息或关键字段,生成可独立处理的片段。
示例:
输入数据:"请列出今年诺贝尔物理学奖的得主"
拆分结果:[
{
"id": "Q1",
"question": "今年诺贝尔物理学奖得主是谁",
"type": "单跳’",
}
]
注意: 当遇到上下文依赖问题时明确指出缺失的信息类型并且question可填写输入问题
多跳Multi-hop:
描述:问题或数据需要通过多步推理或跨多个信息源才能得到答案。
拆分策略:将问题拆解为多个子问题,每个子问题对应一个独立处理步骤,需要具备推理链条与逻辑连接数量。
示例:
输入数据:"今年诺贝尔物理学奖得主的研究领域及代表性成果"
拆分结果:
[
{
"id": "Q1",
"question": 今年诺贝尔物理学奖得主是谁?",
"type": "多跳’",
},
{
"id": "Q2",
"question": "该得主的研究领域是什么?",
"type": "多跳’",
},
{
"id": "Q3",
"question": "该得主的代表性成果有哪些?",
"type": "多跳’"
}
]
开放域Open-domain:
描述:问题或数据不局限于特定知识库,需要从大范围信息中检索和生成答案,而不是从一个已知的小范围数据源中查找。。
拆分策略:根据主题或关键实体拆分,同时保留上下文以便检索外部知识,问题涉及一般性、常识性、跨学科内容,可能是开放式回答(描述性、推理性、综合性)
需要外部知识检索或推理才能确定,比如:“为什么人类需要睡眠?”、“量子计算与经典计算的主要区别是什么?”。
示例:
输入数据:"介绍量子计算的最新研究进展"
拆分结果:
[
{
"id": "Q1",
"question": 量子计算的基本概念是什么?",
"type": "开放域’",
},
{
"id": "Q2",
"question": "当前量子计算的主要研究方向有哪些?",
"type": "开放域’",
},
{
"id": "Q3",
"question": "近期在量子计算领域有哪些重大进展?",
"type": "开放域’",
}
]
时间Temporal:
描述:问题或数据涉及时间维度,需要按时间顺序或时间点拆分。
拆分策略:根据事件时间或时间段拆分为独立条目或问题。
示例:
输入数据:"列出苹果公司过去五年的重大事件"
拆分结果:
[
{
"id": "Q1",
"question": 苹果公司2019年的重大事件有哪些",
"type": "时间’",
},
{
"id": "Q2",
"question": "苹果公司2020年的重大事件有哪些",
"type": "时间’",
},
{
"id": "Q3",
"question": "苹果公司2021年的重大事件有哪些",
"type": "时间’",
},
{
"id": "Q3",
"question": "苹果公司2022年的重大事件有哪些",
"type": "时间’",
}
,
{
"id": "Q4",
"question": "苹果公司2023年的重大事件有哪些",
"type": "时间’",
}
]
输出要求:
- 每个子问题包括:
- `id`: 子问题编号Q1, Q2...
- `question`: 子问题内容
- `type`: 类型(事实检索 / 澄清 / 定义 / 比较 / 行动建议等)
- `reason`: 拆分的理由(为什么要这样拆)
- 格式案例:
[
{
"id": "Q1",
"question": 量子计算的基本概念是什么?",
"type": "开放域’",
},
{
"id": "Q2",
"question": "当前量子计算的主要研究方向有哪些?",
"type": "开放域’",
},
{
"id": "Q3",
"question": "近期在量子计算领域有哪些重大进展?",
"type": "开放域’",
}
]
- 必须通过json.loads()的格式支持的形式输出
- 必须通过json.loads()的格式支持的形式输出,响应必须是与此确切模式匹配的有效JSON对象。不要在JSON之前或之后包含任何文本。
- 关键的JSON格式要求
1.JSON结构仅使用标准ASCII双引号-切勿使用中文引号“”或其他Unicode引号
2.如果提取的语句文本包含引号,请使用反斜杠(\“)正确转义它们
3.确保所有JSON字符串都正确关闭并以逗号分隔
4.JSON字符串值中不包括换行符
5.正确转义的例子“statement”“Zhang Xinhua said\”我非常喜欢这本书\""
6.不允许输出```json```相关符号,如```json```、``````、```python```、```javascript```、```html```、```css```、```sql```、```java```、```c```、```c++```、```c#```、```ruby```

View File

@@ -0,0 +1,60 @@
# 角色
你是验证专家
你的目标是针对用户的输入Query_Samll字段的提问和Answer_Samll的回答分析是不是回答Query_Samll这个字段的问题
{#以下可以采用先总括,再展开详细说明的方式,描述你希望智能体在每一个步骤如何进行工作,具体的工作步骤数量可以根据实际需求增删#}
## 工作步骤
1. 获取所有的Query_Samll字段和Answer_Samll字段
2. 分析Answer_Samll的回复是不是和Query_Samll有关系
3. 判断Answer_Samll和Query_Samll之间分析出来的关系状态
4. 如果是True保留否则不要相对应的问题和回答
5. 输出,需要严格按照模版
输入:{{history}}
历史消息:{"history":{{sentence}}}
### 第一步 获取用户的输入
获取用户的输入提取对应的Query_Samll和Answer_Samll
### 第二步 分析验证
需要分析Query_Samll和Answer_Samll之间的关系可以参考history字段的内容如果有关系不是答非所问
## 核心验证标准
在评估子问题拆分时必须严格遵循以下标准且验证过程中完全不依赖于子问题的相关信息Answer_Samll
1. 合理性标准(必须全部满足)
- 完整性:每个不同的子问题必须完整覆盖原问题的所有关键要素(如时间、主体、动作、目标等),无遗漏。
- 最小化每个不同的子问题数量应尽可能少通常不超过原问题关键要素数量的2倍建议2-4个避免冗余和不必要拆分。
- 相关性:每个不同的子问题必须直接服务于原问题的解答,不引入无关内容或扩展原问题未提及的主题。
- 可操作性:每个不同的子问题应能在有限资源(如标准工具或合理时间)内独立解答,且难度适中。
- 逻辑性:每个不同的子问题间应有清晰的逻辑关系(如并列、递进、因果),共同构成原问题的解答路径。
2. 不合理拆分的特征(出现任一特征即为不合理):
- 不同的子问题数量超过5个或明显多于必要数量。
- 引入原问题未提及的新主题、人物、细节或个人看法。
- 拆分过于细碎,失去实用价值,无法高效合成原问题答案。
3. 特殊情况说明:
- 每个不同的子问题与原问题相同,需进一步判断:
- 每个不同的子问题不可进一步拆分 → success合理最小化拆分
- 每个不同的子问题能够进一步拆分为更小、更合理的问题 → failed不合理拆分没有最小化
- 每个不同的子问题数量=原问题核心要素数量 → success理想情况
- 每个不同的子问题数量=核心要素数量+1 → success通常合理
### 第三步 添加状态
如果有相关性并且比较高给一个状态TRUE否则给一个FLASE的状态
### 第四步 判断
如果状态是TRUE保留这条数据否则需不需要这条数据
### 第五步 输出格式
按照json的形式输出
{"data":"Query":原来Query的字段"history":原来的history字段
"expansion_issue":以为列表的形式存储验证之后的数据比如[
{"query_small": query_small,
"answer_small": answer_small,,
"status": 回答的结果是否符合query_small填写状态,
"query_answer": answer_small},
{
"query_small": "张曼婷生日是什么时候?",
"answer_small": "张曼婷喜欢绘画。",
"status": "True",
"query_answer": "张曼 婷喜欢绘画。"
},{}......]
,
"split_result":如果expansion_issue是空的列表返回failed不是空列表返回success,
"reason": 为以上分析完之后的结果给一个说明
}

View File

@@ -0,0 +1,57 @@
{# 角色定义 #}
你是专业的问题解答专家,负责根据上下文信息和检索到的所有信息准确回答用户的问题。
{# 输入数据展示 #}
{% if data %}
## 输入数据
上下文信息:
{% for item in data.history %}
- {{ item }}
{% endfor %}
检索到的所有信息:
{% for item in data.retrieve_info %}
- {{ item }}
{% endfor %}
{% endif %}
## User Query
{{ query }}
{# 问题回答标准 #}
## 问题回答核心标准
根据上下文信息(history)和检索到的所有信息(retrieve_info)准确回答用户的问题(query)。注意,若不能根据已有信息回答用户的问题,应直接回复“信息不足,无法回答。”,不能自己编造答案。
- 若能根据已有信息回答用户的问题,应根据上下文信息和检索到的所有信息提供简明扼要的答案。
- 若不能根据已有信息回答用户的问题,应直接回复“信息不足,无法回答。”,不能自己编造答案。
{# 重要提醒 #}
再次提醒,给出问题的答案时,仅根据已有的信息进行回答,不能自己编造答案。
{# 输出格式模板 #}
## 输出格式
严格按照以下JSON格式输出不添加任何其他内容
{
"data": {
"query": "{{ query }}",
"history": [
{% for item in data.history %}
"{{ item | replace('"', '\\"') }}"
{% if not loop.last %},{% endif %}
{% endfor %}
],
"retrieve_info": [
{% for item in data.retrieve_info %}
"{{ item | replace('"', '\\"') }}"
{% if not loop.last %},{% endif %}
{% endfor %}
]
},
"query_answer": "{% if not data.history and not data.retrieve_info %}信息不足,无法回答。{% endif %}"
}
**Output format**
**CRITICAL JSON FORMATTING REQUIREMENTS:**
1. Use only standard ASCII double quotes (") for JSON structure - never use Chinese quotation marks ("") or other Unicode quotes
2. If the extracted statement text contains quotation marks, escape them properly using backslashes (\")
3. Ensure all JSON strings are properly closed and comma-separated
4. Do not include line breaks within JSON string values
The output language should always be the same as the input language.{{ json_schema }}

View File

@@ -0,0 +1,203 @@
import redis
import uuid
from datetime import datetime
from app.core.config import settings
class RedisSessionStore:
def __init__(self, host='localhost', port=6379, db=0, password=None,session_id=''):
self.r = redis.Redis(host=host, port=port, db=db, password=password)
self.uudi=session_id
# 修改后的 save_session 方法
def save_session(self, userid, messages, aimessages, apply_id, group_id):
"""
写入一条会话数据,返回 session_id
"""
try:
session_id = str(uuid.uuid4()) # 为每次会话生成新的 ID
starttime = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
key = f"session:{session_id}" # 使用新生成的 session_id 作为 key
# 使用 Hash 存储结构化数据
result = self.r.hset(key, mapping={
"id": self.uudi,
"sessionid": userid,
"apply_id": apply_id,
"group_id": group_id,
"messages": messages,
"aimessages": aimessages,
"starttime": starttime
})
print(f"保存结果: {result}, session_id: {session_id}")
return session_id # 返回新生成的 session_id
except Exception as e:
print(f"保存会话失败: {e}")
raise e
# ---------------- 读取 ----------------
def get_session(self, session_id):
"""
读取一条会话数据
"""
key = f"session:{session_id}"
data = self.r.hgetall(key)
if data:
return {k.decode('utf-8'): v.decode('utf-8') for k, v in data.items()}
return None
def get_session_apply_group(self, sessionid, apply_id, group_id):
"""
根据 sessionid、apply_id 和 group_id 三个条件查询会话数据
"""
result_items = []
# 遍历所有会话数据
for key_bytes in self.r.keys('session:*'):
key = key_bytes.decode('utf-8')
data = self.r.hgetall(key)
if not data:
continue
# 解码数据
decoded_data = {k.decode('utf-8'): v.decode('utf-8') for k, v in data.items()}
# 检查三个条件是否都匹配
if (decoded_data.get('sessionid') == sessionid and
decoded_data.get('apply_id') == apply_id and
decoded_data.get('group_id') == group_id):
result_items.append(decoded_data)
return result_items
def get_all_sessions(self):
"""
获取所有会话数据
"""
sessions = {}
for key in self.r.keys('session:*'):
sid = key.decode('utf-8').split(':')[1]
sessions[sid] = self.get_session(sid)
return sessions
# ---------------- 更新 ----------------
def update_session(self, session_id, field, value):
"""
更新单个字段
"""
key = f"session:{session_id}"
if self.r.exists(key):
self.r.hset(key, field, value)
return True
return False
# ---------------- 删除 ----------------
def delete_session(self, session_id):
"""
删除单条会话
"""
key = f"session:{session_id}"
return self.r.delete(key)
def delete_all_sessions(self):
"""
删除所有会话
"""
keys = self.r.keys('session:*')
if keys:
return self.r.delete(*keys)
return 0
def delete_duplicate_sessions(self):
"""
删除重复会话数据,条件:
"sessionid""user_id""group_id""messages""aimessages" 五个字段都相同的只保留一个,其他删除
"""
seen = set() # 用来记录已出现的唯一组合
deleted_count = 0
for key_bytes in self.r.keys('session:*'):
key = key_bytes.decode('utf-8')
data = self.r.hgetall(key)
if not data:
continue
# 获取五个字段的值并解码
sessionid = data.get(b'sessionid', b'').decode('utf-8')
user_id = data.get(b'id', b'').decode('utf-8') # 对应user_id
group_id = data.get(b'group_id', b'').decode('utf-8')
messages = data.get(b'messages', b'').decode('utf-8')
aimessages = data.get(b'aimessages', b'').decode('utf-8')
# 用五元组作为唯一标识
identifier = (sessionid, user_id, group_id, messages, aimessages)
if identifier in seen:
# 重复,删除该 key
self.r.delete(key)
deleted_count += 1
else:
# 第一次出现,加入 seen
seen.add(identifier)
print(f"[delete_duplicate_sessions] 删除重复会话数量: {deleted_count}")
return deleted_count
def find_user_session(self,sessionid):
user_id = sessionid
result_items = []
for key, values in store.get_all_sessions().items():
history = {}
if user_id == str(values['sessionid']):
history["Query"] = values['messages']
history["Answer"] = values['aimessages']
result_items.append(history)
if len(result_items) <= 1:
result_items = []
return (result_items)
def find_user_apply_group(self, sessionid, apply_id, group_id):
"""
根据 sessionid、apply_id 和 group_id 三个条件查询会话数据
"""
result_items = []
# 遍历所有会话数据
for key_bytes in self.r.keys('session:*'):
key = key_bytes.decode('utf-8')
data = self.r.hgetall(key)
if not data:
continue
# 解码数据
decoded_data = {k.decode('utf-8'): v.decode('utf-8') for k, v in data.items()}
# 检查三个条件是否都匹配
if (decoded_data.get('sessionid') == sessionid and
decoded_data.get('apply_id') == apply_id and
decoded_data.get('group_id') == group_id):
history = {
"Query": decoded_data.get('messages'),
"Answer": decoded_data.get('aimessages')
}
result_items.append(history)
# 如果结果少于等于1条返回空列表
if len(result_items) <= 1:
result_items = []
return result_items
store = RedisSessionStore(
host=settings.REDIS_HOST,
port=settings.REDIS_PORT,
db=settings.REDIS_DB,
password=settings.REDIS_PASSWORD if settings.REDIS_PASSWORD else None,
session_id=str(uuid.uuid4())
)

View File

@@ -0,0 +1,59 @@
"""
Type classification utility for distinguishing read/write operations.
"""
from jinja2 import Template
from pydantic import BaseModel
from app.core.logging_config import get_agent_logger, log_prompt_rendering
from app.core.memory.agent.utils.llm_tools import PROJECT_ROOT_
from app.core.memory.agent.utils.messages_tool import read_template_file
from app.core.memory.utils.llm.llm_utils import get_llm_client
from app.core.config import settings
logger = get_agent_logger(__name__)
class DistinguishTypeResponse(BaseModel):
"""Response model for type classification"""
type: str
async def status_typle(messages: str) -> dict:
"""
Classify message type as read or write operation.
Args:
messages: User message to classify
Returns:
dict: Contains 'type' field with classification result
"""
try:
file_path = PROJECT_ROOT_ + '/agent/utils/prompt/distinguish_types_prompt.jinja2'
template_content = await read_template_file(file_path)
template = Template(template_content)
system_prompt = template.render(user_query=messages)
log_prompt_rendering("status_typle", system_prompt)
except Exception as e:
logger.error(f"Template rendering failed for status_typle: {e}", exc_info=True)
return {
"type": "error",
"message": f"Prompt rendering failed: {str(e)}"
}
from app.core.memory.utils.config import definitions as config_defs
llm_client = get_llm_client(config_defs.SELECTED_LLM_ID)
try:
structured = await llm_client.response_structured(
messages=[{"role": "system", "content": system_prompt}],
response_model=DistinguishTypeResponse
)
return structured.model_dump()
except Exception as e:
logger.error(f"LLM call failed for status_typle: {e}", exc_info=True)
return {
"type": "error",
"message": f"LLM call failed: {str(e)}"
}

View File

@@ -0,0 +1,76 @@
from typing import TypedDict, Annotated, List, Any
from langchain_core.messages import AnyMessage
from langgraph.constants import START, END
from langgraph.graph import StateGraph, add_messages
import asyncio
import json
from dotenv import load_dotenv, find_dotenv
import os
from app.core.memory.agent.utils.llm_tools import PROJECT_ROOT_
from langchain_core.messages import HumanMessage
from jinja2 import Environment, FileSystemLoader
from app.core.memory.agent.utils.messages_tool import _to_openai_messages
from app.core.memory.utils.llm.llm_utils import get_llm_client
from app.core.memory.utils.config.definitions import SELECTED_LLM_ID
from app.core.logging_config import get_agent_logger
load_dotenv(find_dotenv())
logger = get_agent_logger(__name__)
def keep_last(_, right):
return right
class State(TypedDict):
user_input: Annotated[dict, keep_last]
messages: Annotated[List[AnyMessage], add_messages]
agent1_response: str
agent2_response: str
agent3_response: str
final_response: str
status: Annotated[str, keep_last]
class VerifyTool:
def __init__(self, system_prompt: str="", verify_data: Any=None):
self.system_prompt = system_prompt
if isinstance(verify_data, str):
self.verify_data = verify_data
else:
try:
self.verify_data = json.dumps(verify_data, ensure_ascii=False)
except Exception:
self.verify_data = str(verify_data)
async def model_1(self, state: State) -> State:
llm_client = get_llm_client(SELECTED_LLM_ID)
response_content = await llm_client.chat(
messages=[{"role": "system", "content": self.system_prompt}] + _to_openai_messages(state["messages"])
)
return {
"agent1_response": response_content,
"status": "processed",
}
def get_graph(self):
graph = StateGraph(State)
graph.add_node("model_1", self.model_1)
graph.add_edge(START, "model_1")
graph.add_edge("model_1", END)
compiled_graph = graph.compile()
return compiled_graph
async def verify(self):
graph = self.get_graph()
initial_state = {
"user_input": self.verify_data,
"messages": [HumanMessage(content=self.verify_data)],
"final_response": "",
"status": ""
}
final_state = await graph.ainvoke(initial_state)
# return final_state["final_response"]
return final_state["agent1_response"]

View File

@@ -0,0 +1,49 @@
import os
import uuid
from datetime import datetime
from typing import Any
from sqlalchemy.orm import Session
import logging
import json
from app.db import get_db
from app.models.retrieval_info import RetrievalInfo
logger = logging.getLogger(__name__)
async def write_to_database(host_id: uuid.UUID, data: Any) -> str:
"""
将数据写入数据库
:param host_id: 宿主 ID
:param data: 要写入的数据
:return: 写入数据库的结果
"""
# 从数据库会话中获取会话
db: Session = next(get_db())
try:
if isinstance(data, (dict, list)):
serialized = json.dumps(data, ensure_ascii=False)
elif isinstance(data, str):
serialized = data
else:
serialized = str(data)
new_retrieval_info = RetrievalInfo(
# host_id=host_id,
host_id=uuid.UUID("2f6ff1eb-50c7-4765-8e89-e4566be19122"),
retrieve_info=serialized,
created_at=datetime.now()
)
db.add(new_retrieval_info)
db.commit()
logger.info(f"success to write data to database, host_id: {host_id}, retrieve_info: {serialized}")
return "success to write data to database"
except Exception as e:
db.rollback()
logger.error(f"failed to write data to database, host_id: {host_id}, retrieve_info: {data}, error: {e}")
raise e
finally:
try:
db.close()
except Exception:
pass

View File

@@ -0,0 +1,183 @@
import asyncio
from dotenv import load_dotenv
import time
from datetime import datetime
from app.repositories.neo4j.graph_saver import save_dialog_and_statements_to_neo4j
from app.core.memory.agent.utils.get_dialogs import get_chunked_dialogs
from app.core.logging_config import get_agent_logger
logger = get_agent_logger(__name__)
# 使用新的模块化架构
from app.core.memory.storage_services.extraction_engine.extraction_orchestrator import ExtractionOrchestrator
from app.core.memory.storage_services.extraction_engine.knowledge_extraction.embedding_generation import (
embedding_generation_all,
)
# 使用新的仓储层
from app.repositories.neo4j.neo4j_connector import Neo4jConnector
# 导入配置模块(而不是直接导入变量)
from app.core.memory.utils.config import definitions as config_defs
from app.core.memory.utils.llm.llm_utils import get_llm_client
from app.core.memory.utils.log.logging_utils import log_time
from app.core.memory.storage_services.extraction_engine.knowledge_extraction.memory_summary import Memory_summary_generation
from app.repositories.neo4j.add_nodes import add_memory_summary_nodes
from app.repositories.neo4j.add_edges import add_memory_summary_statement_edges
load_dotenv()
async def write(content: str, user_id: str, apply_id: str, group_id: str, ref_id: str = "wyl20251027", config_id: str = None) -> None:
"""
执行完整的知识提取流水线(使用新的 ExtractionOrchestrator
Args:
content: 对话内容
user_id: 用户ID
apply_id: 应用ID
group_id: 组ID
ref_id: 参考ID默认为 "wyl20251027"
config_id: 配置ID用于标记数据处理配置
"""
logger.info("=== MemSci Knowledge Extraction Pipeline ===")
logger.info(f"Using model: {config_defs.SELECTED_LLM_NAME}")
logger.info(f"Using LLM ID: {config_defs.SELECTED_LLM_ID}")
logger.info(f"Using chunker strategy: {config_defs.SELECTED_CHUNKER_STRATEGY}")
logger.info(f"Using group ID: {config_defs.SELECTED_GROUP_ID}")
logger.info(f"Using embedding ID: {config_defs.SELECTED_EMBEDDING_ID}")
logger.info(f"Config ID: {config_id if config_id else 'None'}")
logger.info(f"LANGFUSE_ENABLED: {config_defs.LANGFUSE_ENABLED}")
logger.info(f"AGENTA_ENABLED: {config_defs.AGENTA_ENABLED}")
# Initialize timing log
log_file = "logs/time.log"
timestamp = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
with open(log_file, "a", encoding="utf-8") as f:
f.write(f"\n=== Pipeline Run Started: {timestamp} ===\n")
pipeline_start = time.time()
# 初始化客户端
llm_client = get_llm_client(config_defs.SELECTED_LLM_ID)
# 获取 embedder 配置
from app.core.models.base import RedBearModelConfig
from app.core.memory.utils.config.config_utils import get_embedder_config
from app.core.memory.src.llm_tools.openai_embedder import OpenAIEmbedderClient
embedder_config_dict = get_embedder_config(config_defs.SELECTED_EMBEDDING_ID)
embedder_config = RedBearModelConfig(**embedder_config_dict)
embedder_client = OpenAIEmbedderClient(embedder_config)
neo4j_connector = Neo4jConnector()
# Step 1: 加载和分块数据
step_start = time.time()
chunked_dialogs = await get_chunked_dialogs(
chunker_strategy=config_defs.SELECTED_CHUNKER_STRATEGY,
group_id=group_id,
user_id=user_id,
apply_id=apply_id,
content=content,
ref_id=ref_id,
config_id=config_id,
)
log_time("Data Loading & Chunking", time.time() - step_start, log_file)
# Step 2: 初始化并运行 ExtractionOrchestrator
step_start = time.time()
from app.core.memory.utils.config.config_utils import get_pipeline_config
config = get_pipeline_config()
orchestrator = ExtractionOrchestrator(
llm_client=llm_client,
embedder_client=embedder_client,
connector=neo4j_connector,
config=config,
)
# 运行完整的提取流水线
# orchestrator.run returns a flat tuple of 7 values after deduplication
(
all_dialogue_nodes,
all_chunk_nodes,
all_statement_nodes,
all_entity_nodes,
all_statement_chunk_edges,
all_statement_entity_edges,
all_entity_entity_edges,
) = await orchestrator.run(chunked_dialogs, is_pilot_run=False)
log_time("Extraction Pipeline", time.time() - step_start, log_file)
# Step 8: Save all data to Neo4j database using graph models
step_start = time.time()
# 运行索引创建
from app.repositories.neo4j.create_indexes import create_fulltext_indexes
try:
await create_fulltext_indexes()
except Exception as e:
logger.error(f"Error creating indexes: {e}", exc_info=True)
try:
success = await save_dialog_and_statements_to_neo4j(
dialogue_nodes=all_dialogue_nodes,
chunk_nodes=all_chunk_nodes,
statement_nodes=all_statement_nodes,
entity_nodes=all_entity_nodes,
statement_chunk_edges=all_statement_chunk_edges,
statement_entity_edges=all_statement_entity_edges,
entity_edges=all_entity_entity_edges,
connector=neo4j_connector
)
if success:
logger.info("Successfully saved all data to Neo4j")
else:
logger.warning("Failed to save some data to Neo4j")
finally:
await neo4j_connector.close()
log_time("Neo4j Database Save", time.time() - step_start, log_file)
# Step 9: Generate Memory summaries and save to local vector DB and Neo4j
step_start = time.time()
try:
summaries = await Memory_summary_generation(
chunked_dialogs, llm_client=llm_client, embedding_id=config_defs.SELECTED_EMBEDDING_ID
)
# Save memory summaries to Neo4j as nodes
try:
ms_connector = Neo4jConnector()
await add_memory_summary_nodes(summaries, ms_connector)
# Link summaries to statements via chunks for summary→entity queries
await add_memory_summary_statement_edges(summaries, ms_connector)
finally:
try:
await ms_connector.close()
except Exception:
pass
except Exception as e:
logger.error(f"Memory summary step failed: {e}", exc_info=True)
finally:
log_time("Memory Summary (Local Vector DB & Neo4j)", time.time() - step_start, log_file)
# Log total pipeline time
total_time = time.time() - pipeline_start
log_time("TOTAL PIPELINE TIME", total_time, log_file)
# Add completion marker to log
timestamp = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
with open(log_file, "a", encoding="utf-8") as f:
f.write(f"=== Pipeline Run Completed: {timestamp} ===\n\n")
logger.info("=== Pipeline Complete ===")
logger.info(f"Total execution time: {total_time:.2f} seconds")
logger.info(f"Timing details saved to: {log_file}")
if __name__ == "__main__":
content = "你好,我是张三,是张曼婷的新朋友。请问张曼婷喜欢什么?"
asyncio.run(write(content, ref_id="wyl20251027"))

View File

@@ -0,0 +1,19 @@
"""
LLM 工具模块
提供 LLM 和 Embedder 客户端的抽象基类和具体实现。
"""
from app.core.memory.llm_tools.llm_client import LLMClient
from app.core.memory.llm_tools.embedder_client import EmbedderClient
from app.core.memory.llm_tools.openai_client import OpenAIClient
from app.core.memory.llm_tools.openai_embedder import OpenAIEmbedderClient
from app.core.memory.llm_tools.chunker_client import ChunkerClient
__all__ = [
"LLMClient",
"EmbedderClient",
"OpenAIClient",
"OpenAIEmbedderClient",
"ChunkerClient",
]

View File

@@ -0,0 +1,330 @@
from typing import Any, List
import re
import os
import asyncio
import json
import numpy as np
# Fix tokenizer parallelism warning
os.environ["TOKENIZERS_PARALLELISM"] = "false"
from chonkie import (
SemanticChunker,
RecursiveChunker,
RecursiveRules,
LateChunker,
NeuralChunker,
SentenceChunker,
TokenChunker,
)
from app.core.memory.models.config_models import ChunkerConfig
from app.core.memory.models.message_models import DialogData, Chunk
try:
from app.core.memory.llm_tools.openai_client import OpenAIClient
except Exception:
# 在测试或无可用依赖(如 langfuse环境下允许惰性导入
OpenAIClient = Any
class LLMChunker:
"""基于LLM的智能分块策略"""
def __init__(self, llm_client: OpenAIClient, chunk_size: int = 1000):
self.llm_client = llm_client
self.chunk_size = chunk_size
async def __call__(self, text: str) -> List[Any]:
# 使用LLM分析文本结构并进行智能分块
prompt = f"""
请将以下文本分割成语义连贯的段落。每个段落应该围绕一个主题,长度大约在{self.chunk_size}字符左右。
请以JSON格式返回结果包含chunks数组每个chunk有text字段。
文本内容:
{text[:5000]}
"""
messages = [
{"role": "system", "content": "你是一个专业的文本分析助手,擅长将长文本分割成语义连贯的段落。"},
{"role": "user", "content": prompt}
]
try:
# 使用异步的 achat 方法
if hasattr(self.llm_client, 'achat'):
response = await self.llm_client.achat(messages)
else:
# 如果没有异步方法,使用同步方法并转换为异步
response = await asyncio.to_thread(self.llm_client.chat, messages)
# 检查响应格式并提取内容
if hasattr(response, 'choices') and len(response.choices) > 0:
content = response.choices[0].message.content
elif hasattr(response, 'content'):
content = response.content
else:
content = str(response)
# 解析LLM响应
if "```json" in content:
json_str = content.split("```json")[1].split("```")[0].strip()
elif "```" in content:
json_str = content.split("```")[1].split("```")[0].strip()
else:
json_str = content
result = json.loads(json_str)
class SimpleChunk:
def __init__(self, text, index):
self.text = text
self.start_index = index * 100 # 近似位置
self.end_index = (index + 1) * 100
return [SimpleChunk(chunk["text"], i) for i, chunk in enumerate(result.get("chunks", []))]
except Exception as e:
print(f"LLM分块失败: {e}")
# 失败时返回空列表,外层会处理回退方案
return []
class HybridChunker:
"""混合分块策略:先按结构分块,再按语义合并"""
def __init__(self, semantic_threshold: float = 0.8, base_chunk_size: int = 300):
self.semantic_threshold = semantic_threshold
self.base_chunk_size = base_chunk_size
self.base_chunker = TokenChunker(tokenizer="character", chunk_size=base_chunk_size)
self.semantic_chunker = SemanticChunker(threshold=semantic_threshold)
def __call__(self, text: str) -> List[Any]:
# 先用基础分块
base_chunks = self.base_chunker(text)
# 如果文本不长,直接返回基础分块
if len(base_chunks) <= 3:
return base_chunks
# 对基础分块进行语义合并
combined_text = " ".join([chunk.text for chunk in base_chunks])
return self.semantic_chunker(combined_text)
class ChunkerClient:
def __init__(self, chunker_config: ChunkerConfig, llm_client: OpenAIClient = None):
self.chunker_config = chunker_config
self.embedding_model = chunker_config.embedding_model
self.chunk_size = chunker_config.chunk_size
self.threshold = chunker_config.threshold
self.language = chunker_config.language
self.skip_window = chunker_config.skip_window
self.min_sentences = chunker_config.min_sentences
self.min_characters_per_chunk = chunker_config.min_characters_per_chunk
self.llm_client = llm_client
# 可选参数(从配置中安全获取,提供默认值)
self.chunk_overlap = getattr(chunker_config, 'chunk_overlap', 0)
self.min_sentences_per_chunk = getattr(chunker_config, 'min_sentences_per_chunk', 1)
self.min_characters_per_sentence = getattr(chunker_config, 'min_characters_per_sentence', 12)
self.delim = getattr(chunker_config, 'delim', [".", "!", "?", "\n"])
self.include_delim = getattr(chunker_config, 'include_delim', "prev")
self.tokenizer_or_token_counter = getattr(chunker_config, 'tokenizer_or_token_counter', "character")
# 初始化具体分块器策略
if chunker_config.chunker_strategy == "TokenChunker":
self.chunker = TokenChunker(
tokenizer=self.tokenizer_or_token_counter,
chunk_size=self.chunk_size,
chunk_overlap=self.chunk_overlap,
)
elif chunker_config.chunker_strategy == "SemanticChunker":
self.chunker = SemanticChunker(
embedding_model=self.embedding_model,
threshold=self.threshold,
chunk_size=self.chunk_size,
min_sentences=self.min_sentences,
)
elif chunker_config.chunker_strategy == "RecursiveChunker":
self.chunker = RecursiveChunker(
rules=RecursiveRules(),
min_characters_per_chunk=self.min_characters_per_chunk or 50,
chunk_size=self.chunk_size,
)
elif chunker_config.chunker_strategy == "LateChunker":
self.chunker = LateChunker(
embedding_model=self.embedding_model,
chunk_size=self.chunk_size,
rules=RecursiveRules(),
min_characters_per_chunk=self.min_characters_per_chunk,
)
elif chunker_config.chunker_strategy == "NeuralChunker":
self.chunker = NeuralChunker(
model=self.embedding_model,
min_characters_per_chunk=self.min_characters_per_chunk,
)
elif chunker_config.chunker_strategy == "LLMChunker":
if not llm_client:
raise ValueError("LLMChunker requires an LLM client")
self.chunker = LLMChunker(llm_client, self.chunk_size)
elif chunker_config.chunker_strategy == "HybridChunker":
self.chunker = HybridChunker(
semantic_threshold=self.threshold,
base_chunk_size=self.chunk_size,
)
elif chunker_config.chunker_strategy == "SentenceChunker":
# 某些 chonkie 版本的 SentenceChunker 不支持 tokenizer_or_token_counter 参数
# 为了兼容不同版本,这里仅传递广泛支持的参数
self.chunker = SentenceChunker(
chunk_size=self.chunk_size,
chunk_overlap=self.chunk_overlap,
min_sentences_per_chunk=self.min_sentences_per_chunk,
min_characters_per_sentence=self.min_characters_per_sentence,
delim=self.delim,
include_delim=self.include_delim,
)
else:
raise ValueError(f"Unknown chunker strategy: {chunker_config.chunker_strategy}")
async def generate_chunks(self, dialogue: DialogData):
"""
生成分块,支持异步操作
"""
try:
# 预处理文本:确保对话标记格式统一
content = dialogue.content
content = content.replace('AI', 'AI:').replace('用户:', '用户:') # 统一冒号
content = re.sub(r'(\n\s*)+\n', '\n\n', content) # 合并多个空行
if hasattr(self.chunker, '__call__') and not asyncio.iscoroutinefunction(self.chunker.__call__):
# 同步分块器
chunks = self.chunker(content)
else:
# 异步分块器如LLMChunker
chunks = await self.chunker(content)
# 过滤空块和过小的块
valid_chunks = []
for c in chunks:
chunk_text = getattr(c, 'text', str(c)) if not isinstance(c, str) else c
if isinstance(chunk_text, str) and len(chunk_text.strip()) >= (self.min_characters_per_chunk or 50):
valid_chunks.append(c)
dialogue.chunks = [
Chunk(
content=c.text if hasattr(c, 'text') else str(c),
metadata={
"start_index": getattr(c, "start_index", None),
"end_index": getattr(c, "end_index", None),
"chunker_strategy": self.chunker_config.chunker_strategy,
},
)
for c in valid_chunks
]
return dialogue
except Exception as e:
print(f"分块失败: {e}")
# 改进的后备方案:尝试按对话回合分割
try:
# 简单的按对话分割
dialogue_pattern = r'(AI:|用户:)(.*?)(?=AI:|用户:|$)'
matches = re.findall(dialogue_pattern, dialogue.content, re.DOTALL)
class SimpleChunk:
def __init__(self, text, start_index, end_index):
self.text = text
self.start_index = start_index
self.end_index = end_index
chunks = []
current_chunk = ""
current_start = 0
for match in matches:
speaker, ct = match[0], match[1].strip()
turn_text = f"{speaker} {ct}"
if len(current_chunk) + len(turn_text) > (self.chunk_size or 500):
if current_chunk:
chunks.append(SimpleChunk(current_chunk, current_start, current_start + len(current_chunk)))
current_chunk = turn_text
current_start = dialogue.content.find(turn_text, current_start)
else:
current_chunk += ("\n" + turn_text) if current_chunk else turn_text
if current_chunk:
chunks.append(SimpleChunk(current_chunk, current_start, current_start + len(current_chunk)))
dialogue.chunks = [
Chunk(
content=c.text,
metadata={
"start_index": c.start_index,
"end_index": c.end_index,
"chunker_strategy": "DialogueTurnFallback",
},
)
for c in chunks
]
except Exception:
# 最后的手段:单一大块
dialogue.chunks = [Chunk(
content=dialogue.content,
metadata={"chunker_strategy": "SingleChunkFallback"},
)]
return dialogue
def evaluate_chunking(self, dialogue: DialogData) -> dict:
"""
评估分块质量
"""
if not getattr(dialogue, 'chunks', None):
return {}
chunks = dialogue.chunks
total_chars = sum(len(chunk.content) for chunk in chunks)
avg_chunk_size = total_chars / len(chunks)
# 计算各种指标
chunk_sizes = [len(chunk.content) for chunk in chunks]
metrics = {
"strategy": self.chunker_config.chunker_strategy,
"num_chunks": len(chunks),
"total_characters": total_chars,
"avg_chunk_size": avg_chunk_size,
"min_chunk_size": min(chunk_sizes),
"max_chunk_size": max(chunk_sizes),
"chunk_size_std": np.std(chunk_sizes) if len(chunk_sizes) > 1 else 0,
"coverage_ratio": total_chars / len(dialogue.content) if dialogue.content else 0,
}
return metrics
def save_chunking_results(self, dialogue: DialogData, output_path: str):
"""
保存分块结果到文件,文件名包含策略名称
"""
strategy_name = self.chunker_config.chunker_strategy
# 在文件名中添加策略名称
base_name, ext = os.path.splitext(output_path)
strategy_output_path = f"{base_name}_{strategy_name}{ext}"
with open(strategy_output_path, 'w', encoding='utf-8') as f:
f.write(f"=== Chunking Strategy: {strategy_name} ===\n")
f.write(f"Total chunks: {len(dialogue.chunks)}\n")
f.write(f"Total characters: {sum(len(chunk.content) for chunk in dialogue.chunks)}\n")
f.write("=" * 60 + "\n\n")
for i, chunk in enumerate(dialogue.chunks):
f.write(f"Chunk {i+1}:\n")
f.write(f"Size: {len(chunk.content)} characters\n")
if hasattr(chunk, 'metadata') and 'start_index' in chunk.metadata:
f.write(f"Position: {chunk.metadata.get('start_index')}-{chunk.metadata.get('end_index')}\n")
f.write(f"Content: {chunk.content}\n")
f.write("-" * 40 + "\n\n")
print(f"Chunking results saved to: {strategy_output_path}")
return strategy_output_path

View File

@@ -0,0 +1,176 @@
"""
Embedder 客户端抽象基类
提供统一的嵌入向量生成接口,支持重试机制和错误处理。
"""
from abc import ABC, abstractmethod
from typing import List, Optional
import asyncio
import logging
from tenacity import (
retry,
stop_after_attempt,
wait_exponential,
retry_if_exception_type,
before_sleep_log,
)
from app.core.models.base import RedBearModelConfig
from app.core.exceptions import BusinessException
from app.core.error_codes import BizCode
logger = logging.getLogger(__name__)
class EmbedderClientException(BusinessException):
"""Embedder 客户端异常"""
def __init__(self, message: str, code: str = BizCode.EMBEDDING_ERROR):
super().__init__(message, code=code)
class EmbedderClient(ABC):
"""
Embedder 客户端抽象基类
提供统一的嵌入向量生成接口,包括:
- 批量文本嵌入response
- 自动重试机制
- 错误处理
"""
def __init__(self, model_config: RedBearModelConfig):
"""
初始化 Embedder 客户端
Args:
model_config: 模型配置包含模型名称、提供商、API密钥等信息
"""
self.config = model_config
self.model_name = model_config.model_name
self.provider = model_config.provider
self.api_key = model_config.api_key
self.base_url = model_config.base_url
self.max_retries = model_config.max_retries
self.timeout = model_config.timeout
logger.info(
f"初始化 Embedder 客户端: provider={self.provider}, "
f"model={self.model_name}, max_retries={self.max_retries}"
)
@abstractmethod
async def response(
self,
messages: List[str],
**kwargs
) -> List[List[float]]:
"""
生成嵌入向量
Args:
messages: 文本列表
**kwargs: 额外参数
Returns:
嵌入向量列表,每个向量是一个浮点数列表
Raises:
EmbedderClientException: 嵌入向量生成失败
"""
pass
def _create_retry_decorator(self):
"""
创建重试装饰器
Returns:
配置好的 tenacity retry 装饰器
"""
return retry(
stop=stop_after_attempt(self.max_retries),
wait=wait_exponential(multiplier=1, min=2, max=10),
retry=retry_if_exception_type((
asyncio.TimeoutError,
ConnectionError,
Exception, # 可以根据需要细化异常类型
)),
before_sleep=before_sleep_log(logger, logging.WARNING),
reraise=True,
)
async def response_with_retry(
self,
messages: List[str],
**kwargs
) -> List[List[float]]:
"""
带重试机制的嵌入向量生成接口
Args:
messages: 文本列表
**kwargs: 额外参数
Returns:
嵌入向量列表
Raises:
EmbedderClientException: 重试失败后抛出
"""
retry_decorator = self._create_retry_decorator()
@retry_decorator
async def _response_with_retry():
try:
return await self.response(messages, **kwargs)
except Exception as e:
logger.error(f"嵌入向量生成失败: {e}")
raise EmbedderClientException(f"嵌入向量生成失败: {e}") from e
return await _response_with_retry()
async def embed_single(self, text: str, **kwargs) -> List[float]:
"""
为单个文本生成嵌入向量
Args:
text: 单个文本
**kwargs: 额外参数
Returns:
嵌入向量(浮点数列表)
Raises:
EmbedderClientException: 嵌入向量生成失败
"""
embeddings = await self.response_with_retry([text], **kwargs)
return embeddings[0] if embeddings else []
async def embed_batch(
self,
texts: List[str],
batch_size: int = 100,
**kwargs
) -> List[List[float]]:
"""
批量生成嵌入向量(支持大批量文本)
Args:
texts: 文本列表
batch_size: 每批处理的文本数量
**kwargs: 额外参数
Returns:
嵌入向量列表
Raises:
EmbedderClientException: 嵌入向量生成失败
"""
all_embeddings = []
for i in range(0, len(texts), batch_size):
batch = texts[i:i + batch_size]
batch_embeddings = await self.response_with_retry(batch, **kwargs)
all_embeddings.extend(batch_embeddings)
return all_embeddings

Some files were not shown because too many files have changed in this diff Show More