feat(multimodal): support document image extraction and inline vision processing
Added document image extraction capability for PDF and DOCX files, including page/index metadata and storage integration. Extended `process_files` with `document_image_recognition` flag to conditionally enable vision-based image processing when model supports it. Updated knowledge repository and workflow node logic to enforce status=1 checks. Added PyMuPDF dependency.
This commit is contained in:
@@ -16,7 +16,7 @@ from app.models import MultiAgentConfig, AgentConfig, ModelType
|
||||
from app.models import WorkflowConfig
|
||||
from app.repositories.tool_repository import ToolRepository
|
||||
from app.schemas import DraftRunRequest
|
||||
from app.schemas.app_schema import FileInput
|
||||
from app.schemas.app_schema import FileInput, FileType
|
||||
from app.schemas.model_schema import ModelInfo
|
||||
from app.schemas.prompt_schema import render_prompt_message, PromptMessageRole
|
||||
from app.services.conversation_service import ConversationService
|
||||
@@ -165,8 +165,27 @@ class AppChatService:
|
||||
processed_files = None
|
||||
if files:
|
||||
multimodal_service = MultimodalService(self.db, model_info)
|
||||
processed_files = await multimodal_service.process_files(files)
|
||||
fu_config = features_config.get("file_upload", {})
|
||||
if hasattr(fu_config, "model_dump"):
|
||||
fu_config = fu_config.model_dump()
|
||||
doc_img_recognition = isinstance(fu_config, dict) and fu_config.get("document_image_recognition", False)
|
||||
processed_files = await multimodal_service.process_files(
|
||||
files, document_image_recognition=doc_img_recognition
|
||||
)
|
||||
logger.info(f"处理了 {len(processed_files)} 个文件")
|
||||
if doc_img_recognition and "vision" in (api_key_obj.capability or []) and any(
|
||||
f.type == FileType.DOCUMENT for f in files
|
||||
):
|
||||
from langchain.agents import create_agent
|
||||
agent.system_prompt += (
|
||||
"\n\n文档中包含图片,图片位置已在文本中以 [第N页 第M张图片]: URL 标记。"
|
||||
"请在回答中用 Markdown 格式  展示相关图片,做到图文并茂。"
|
||||
)
|
||||
agent.agent = create_agent(
|
||||
model=agent.llm,
|
||||
tools=agent._wrap_tools_with_tracking(agent.tools) if agent.tools else None,
|
||||
system_prompt=agent.system_prompt
|
||||
)
|
||||
# 为需要运行时上下文的工具注入上下文
|
||||
for t in tools:
|
||||
if hasattr(t, 'tool_instance') and hasattr(t.tool_instance, 'set_runtime_context'):
|
||||
@@ -438,8 +457,27 @@ class AppChatService:
|
||||
processed_files = None
|
||||
if files:
|
||||
multimodal_service = MultimodalService(self.db, model_info)
|
||||
processed_files = await multimodal_service.process_files(files)
|
||||
fu_config = features_config.get("file_upload", {})
|
||||
if hasattr(fu_config, "model_dump"):
|
||||
fu_config = fu_config.model_dump()
|
||||
doc_img_recognition = isinstance(fu_config, dict) and fu_config.get("document_image_recognition", False)
|
||||
processed_files = await multimodal_service.process_files(
|
||||
files, document_image_recognition=doc_img_recognition
|
||||
)
|
||||
logger.info(f"处理了 {len(processed_files)} 个文件")
|
||||
if doc_img_recognition and "vision" in (api_key_obj.capability or []) and any(
|
||||
f.type == FileType.DOCUMENT for f in files
|
||||
):
|
||||
from langchain.agents import create_agent
|
||||
agent.system_prompt += (
|
||||
"\n\n文档中包含图片,图片位置已在文本中以 [第N页 第M张图片]: URL 标记。"
|
||||
"请在回答中用 Markdown 格式  展示相关图片,做到图文并茂。"
|
||||
)
|
||||
agent.agent = create_agent(
|
||||
model=agent.llm,
|
||||
tools=agent._wrap_tools_with_tracking(agent.tools) if agent.tools else None,
|
||||
system_prompt=agent.system_prompt
|
||||
)
|
||||
|
||||
# 为需要运行时上下文的工具注入上下文
|
||||
for t in tools:
|
||||
|
||||
@@ -1,3 +1,5 @@
|
||||
import uuid
|
||||
|
||||
from sqlalchemy.orm import Session
|
||||
from typing import Optional, Tuple, Union
|
||||
import jwt
|
||||
@@ -130,7 +132,7 @@ def register_user_with_invite(
|
||||
email: str,
|
||||
password: str,
|
||||
invite_token: str,
|
||||
workspace_id: str,
|
||||
workspace_id: uuid.UUID,
|
||||
username: Optional[str] = None,
|
||||
) -> User:
|
||||
"""
|
||||
@@ -147,6 +149,7 @@ def register_user_with_invite(
|
||||
from app.schemas.user_schema import UserCreate
|
||||
from app.schemas.workspace_schema import InviteAcceptRequest
|
||||
from app.services import user_service, workspace_service
|
||||
from app.repositories import workspace_repository as ws_repo
|
||||
from app.core.logging_config import get_business_logger
|
||||
|
||||
logger = get_business_logger()
|
||||
@@ -159,7 +162,8 @@ def register_user_with_invite(
|
||||
password=password,
|
||||
username=email.split('@')[0] if not username else username
|
||||
)
|
||||
user = user_service.create_user(db=db, user=user_create)
|
||||
workspace = ws_repo.get_workspace_by_id(db=db, workspace_id=workspace_id)
|
||||
user = user_service.create_user(db=db, user=user_create, workspace=workspace)
|
||||
logger.info(f"用户创建成功: {user.email} (ID: {user.id})")
|
||||
|
||||
# 接受工作空间邀请(此时用户已成为工作空间成员,并且会 commit)
|
||||
|
||||
@@ -10,6 +10,7 @@ import time
|
||||
import uuid
|
||||
from typing import Any, AsyncGenerator, Dict, List, Optional
|
||||
|
||||
from langchain.agents import create_agent
|
||||
from langchain.tools import tool
|
||||
from pydantic import BaseModel, Field
|
||||
from sqlalchemy import select
|
||||
@@ -27,7 +28,7 @@ from app.core.rag.nlp.search import knowledge_retrieval
|
||||
from app.db import get_db_context
|
||||
from app.models import AgentConfig, ModelConfig
|
||||
from app.repositories.tool_repository import ToolRepository
|
||||
from app.schemas.app_schema import FileInput, Citation
|
||||
from app.schemas.app_schema import FileInput, Citation, FileType
|
||||
from app.schemas.model_schema import ModelInfo
|
||||
from app.schemas.prompt_schema import PromptMessageRole, render_prompt_message
|
||||
from app.services.conversation_service import ConversationService
|
||||
@@ -637,12 +638,35 @@ class AgentRunService:
|
||||
|
||||
# 6. 处理多模态文件
|
||||
processed_files = None
|
||||
has_doc_with_images = False
|
||||
if files:
|
||||
# 获取 provider 信息
|
||||
provider = api_key_config.get("provider", "openai")
|
||||
multimodal_service = MultimodalService(self.db, model_info)
|
||||
processed_files = await multimodal_service.process_files(files)
|
||||
fu_config = features_config.get("file_upload", {})
|
||||
if hasattr(fu_config, "model_dump"):
|
||||
fu_config = fu_config.model_dump()
|
||||
doc_img_recognition = isinstance(fu_config, dict) and fu_config.get("document_image_recognition", False)
|
||||
processed_files = await multimodal_service.process_files(
|
||||
files, document_image_recognition=doc_img_recognition
|
||||
)
|
||||
logger.info(f"处理了 {len(processed_files)} 个文件,provider={provider}")
|
||||
capability = api_key_config.get("capability", [])
|
||||
has_doc_with_images = (
|
||||
doc_img_recognition
|
||||
and "vision" in capability
|
||||
and any(f.type == FileType.DOCUMENT for f in files)
|
||||
)
|
||||
if has_doc_with_images:
|
||||
agent.system_prompt += (
|
||||
"\n\n文档中包含图片,图片位置已在文本中以 [第N页 第M张图片]: URL 标记。"
|
||||
"请在回答中用 Markdown 格式  展示相关图片,做到图文并茂。"
|
||||
)
|
||||
# 重建 agent graph 以使新 system_prompt 生效
|
||||
agent.agent = create_agent(
|
||||
model=agent.llm,
|
||||
tools=agent._wrap_tools_with_tracking(agent.tools) if agent.tools else None,
|
||||
system_prompt=agent.system_prompt
|
||||
)
|
||||
# 为需要运行时上下文的工具注入上下文
|
||||
for t in tools:
|
||||
if hasattr(t, 'tool_instance') and hasattr(t.tool_instance, 'set_runtime_context'):
|
||||
@@ -895,12 +919,34 @@ class AgentRunService:
|
||||
|
||||
# 6. 处理多模态文件
|
||||
processed_files = None
|
||||
has_doc_with_images = False
|
||||
if files:
|
||||
# 获取 provider 信息
|
||||
provider = api_key_config.get("provider", "openai")
|
||||
multimodal_service = MultimodalService(self.db, model_info)
|
||||
processed_files = await multimodal_service.process_files(files)
|
||||
fu_config = features_config.get("file_upload", {})
|
||||
if hasattr(fu_config, "model_dump"):
|
||||
fu_config = fu_config.model_dump()
|
||||
doc_img_recognition = isinstance(fu_config, dict) and fu_config.get("document_image_recognition", False)
|
||||
processed_files = await multimodal_service.process_files(
|
||||
files, document_image_recognition=doc_img_recognition
|
||||
)
|
||||
logger.info(f"处理了 {len(processed_files)} 个文件,provider={provider}")
|
||||
capability = api_key_config.get("capability", [])
|
||||
has_doc_with_images = (
|
||||
doc_img_recognition
|
||||
and "vision" in capability
|
||||
and any(f.type == FileType.DOCUMENT for f in files)
|
||||
)
|
||||
if has_doc_with_images:
|
||||
agent.system_prompt += (
|
||||
"\n\n文档中包含图片,图片位置已在文本中以 [第N页 第M张图片]: URL 标记。"
|
||||
"请在回答中用 Markdown 格式  展示相关图片,做到图文并茂。"
|
||||
)
|
||||
agent.agent = create_agent(
|
||||
model=agent.llm,
|
||||
tools=agent._wrap_tools_with_tracking(agent.tools) if agent.tools else None,
|
||||
system_prompt=agent.system_prompt
|
||||
)
|
||||
# 为需要运行时上下文的工具注入上下文
|
||||
for t in tools:
|
||||
if hasattr(t, 'tool_instance') and hasattr(t.tool_instance, 'set_runtime_context'):
|
||||
|
||||
@@ -821,7 +821,7 @@ def get_rag_content(
|
||||
for document in documents:
|
||||
try:
|
||||
kb = knowledge_repository.get_knowledge_by_id(db, document.kb_id)
|
||||
if not kb:
|
||||
if not (kb and kb.status == 1):
|
||||
business_logger.warning(f"知识库不存在: kb_id={document.kb_id}")
|
||||
continue
|
||||
|
||||
|
||||
@@ -344,6 +344,7 @@ class MultimodalService:
|
||||
async def process_files(
|
||||
self,
|
||||
files: Optional[List[FileInput]],
|
||||
document_image_recognition: bool = False,
|
||||
) -> List[Dict[str, Any]]:
|
||||
"""
|
||||
处理文件列表,返回 LLM 可用的格式
|
||||
@@ -379,6 +380,31 @@ class MultimodalService:
|
||||
elif file.type == FileType.DOCUMENT:
|
||||
is_support, content = await self._process_document(file, strategy)
|
||||
result.append(content)
|
||||
# 仅当开关开启且模型支持视觉时,才提取文档内嵌图片
|
||||
if document_image_recognition and "vision" in self.capability:
|
||||
img_infos = await self.extract_document_images(file)
|
||||
for img_info in img_infos:
|
||||
page = img_info["page"]
|
||||
index = img_info["index"]
|
||||
ext = img_info.get("ext", "png")
|
||||
try:
|
||||
_, img_url = await self._save_doc_image_to_storage(img_info["bytes"], ext)
|
||||
placeholder = f"第{page}页 第{index + 1}张图片" if page > 0 else f"第{index + 1}张图片"
|
||||
# 在文本内容中追加图片位置标记
|
||||
if result and result[-1].get("type") in ("text", "document"):
|
||||
key = "text" if "text" in result[-1] else list(result[-1].keys())[-1]
|
||||
result[-1][key] = result[-1].get(key, "") + f"\n[{placeholder}]: {img_url}"
|
||||
# 将图片以视觉格式追加到消息内容中
|
||||
img_file = FileInput(
|
||||
type=FileType.IMAGE,
|
||||
transfer_method=TransferMethod.REMOTE_URL,
|
||||
url=img_url,
|
||||
file_type="image/png",
|
||||
)
|
||||
_, img_content = await self._process_image(img_file, strategy_class(img_file))
|
||||
result.append(img_content)
|
||||
except Exception as img_err:
|
||||
logger.warning(f"文档图片处理失败: {img_err}")
|
||||
elif file.type == FileType.AUDIO and "audio" in self.capability:
|
||||
is_support, content = await self._process_audio(file, strategy)
|
||||
result.append(content)
|
||||
@@ -431,12 +457,8 @@ class MultimodalService:
|
||||
"""
|
||||
处理文档文件(PDF、Word 等)
|
||||
|
||||
Args:
|
||||
file: 文档文件输入
|
||||
strategy: 格式化策略
|
||||
|
||||
Returns:
|
||||
Dict: 根据 provider 返回不同格式的文档内容
|
||||
仅返回文本内容(图片通过 process_files 中的额外步骤追加)
|
||||
"""
|
||||
if file.transfer_method == TransferMethod.REMOTE_URL:
|
||||
return True, {
|
||||
@@ -444,19 +466,63 @@ class MultimodalService:
|
||||
"text": f"<document url=\"{file.url}\">\n{await self.extract_document_text(file)}\n</document>"
|
||||
}
|
||||
else:
|
||||
# 本地文件,提取文本内容
|
||||
server_url = settings.FILE_LOCAL_SERVER_URL
|
||||
file.url = f"{server_url}/storage/permanent/{file.upload_file_id}"
|
||||
text = await self.extract_document_text(file)
|
||||
file_metadata = self.db.query(FileMetadata).filter(
|
||||
FileMetadata.id == file.upload_file_id
|
||||
).first()
|
||||
|
||||
file_name = file_metadata.file_name if file_metadata else "unknown"
|
||||
|
||||
# 使用策略格式化文档
|
||||
return await strategy.format_document(file_name, text)
|
||||
|
||||
async def _save_doc_image_to_storage(
    self,
    img_bytes: bytes,
    ext: str,
) -> tuple[str, str]:
    """
    Upload an image extracted from a document and record it in FileMetadata.

    tenant_id / workspace_id are currently always the all-zero placeholder
    UUID: this method performs no lookup of the owning tenant or workspace.
    NOTE(review): the original note says the image stays reachable through
    its permanent URL despite the placeholder ids - confirm against the
    storage route.

    Args:
        img_bytes: Raw bytes of the image.
        ext: Image extension, with or without a leading dot ("png", ".jpeg").

    Returns:
        (file_id_str, permanent_url)
    """
    import mimetypes
    import uuid as _uuid
    from app.services.file_storage_service import FileStorageService, generate_file_key
    from app.db import get_db_context

    file_id = _uuid.uuid4()

    # Normalise the extension once so the stored suffix and the MIME type
    # agree; previously ".png" produced the invalid type "image/.png" and
    # "jpg" produced the non-standard "image/jpg".
    bare_ext = ext.lstrip(".") or "png"
    file_ext = f".{bare_ext}"
    guessed_type, _ = mimetypes.guess_type(f"image{file_ext.lower()}")
    if guessed_type and guessed_type.startswith("image/"):
        content_type = guessed_type
    else:
        # Unknown to the mimetypes table: fall back to the bare extension.
        content_type = f"image/{bare_ext.lower()}"

    # No tenant/workspace context is available here, so a placeholder is used.
    placeholder = _uuid.UUID(int=0)
    tenant_id = placeholder
    workspace_id = placeholder

    file_key = generate_file_key(tenant_id, workspace_id, file_id, file_ext)
    storage_svc = FileStorageService()
    await storage_svc.storage.upload(file_key, img_bytes, content_type)

    # A dedicated DB context is opened for the metadata row; self.db is not used.
    with get_db_context() as db:
        meta = FileMetadata(
            id=file_id,
            tenant_id=tenant_id,
            workspace_id=workspace_id,
            file_key=file_key,
            file_name=f"doc_image_{file_id}{file_ext}",
            file_ext=file_ext,
            file_size=len(img_bytes),
            content_type=content_type,
            status="completed",
        )
        db.add(meta)
        db.commit()

    url = f"{settings.FILE_LOCAL_SERVER_URL}/storage/permanent/{file_id}"
    return str(file_id), url
|
||||
|
||||
async def _process_audio(self, file: FileInput, strategy) -> tuple[bool, Dict[str, Any]]:
|
||||
"""
|
||||
处理音频文件
|
||||
@@ -582,6 +648,84 @@ class MultimodalService:
|
||||
logger.error(f"Failed to load file. - {e}")
|
||||
return "[Failed to load file.]"
|
||||
|
||||
async def extract_document_images(self, file: FileInput) -> list[dict]:
    """
    Extract the images embedded in a document (PDF or DOCX) with position info.

    The document bytes are taken from the file's cached content; when none is
    cached they are downloaded from ``file.url`` and stored back on the file.
    Any failure is logged and results in an empty list.

    Returns:
        list[dict]: one entry per image, containing:
            - bytes: raw image data
            - page: page number (1-based for PDF, 0 for DOCX)
            - index: image position within the page/document (0-based)
            - ext: image extension (e.g. png, jpeg)
    """
    try:
        payload = file.get_content()
        if not payload:
            # Nothing cached yet: fetch the document and remember its bytes.
            async with httpx.AsyncClient(timeout=30.0) as http:
                resp = await http.get(file.url, follow_redirects=True)
                resp.raise_for_status()
                payload = resp.content
                file.set_content(payload)

        # Dispatch on the sniffed MIME type rather than on the file name.
        mime = magic.from_buffer(payload, mime=True)
        if mime in PDF_MIME:
            extracted = self._extract_pdf_images(payload)
        elif self._is_word_file(payload, mime):
            extracted = self._extract_docx_images(payload)
        else:
            extracted = []
        return extracted
    except Exception as exc:
        logger.error(f"提取文档图片失败: {exc}")
        return []
|
||||
|
||||
@staticmethod
def _extract_pdf_images(file_content: bytes) -> list[dict]:
    """
    Extract embedded images from a PDF, tagged with page number and index.

    Best-effort: an image that cannot be extracted is skipped with a warning,
    and any other failure is logged; whatever was collected is returned.

    Args:
        file_content: Raw bytes of the PDF document.

    Returns:
        list[dict]: entries with keys "bytes", "ext", "page" (1-based) and
        "index" (0-based position in that page's image list).
    """
    images = []
    try:
        import fitz  # PyMuPDF

        # The context manager closes the document even when iteration fails
        # part-way; the explicit close() used before was skipped on errors.
        with fitz.open(stream=file_content, filetype="pdf") as doc:
            for page_num, page in enumerate(doc, start=1):
                for idx, img in enumerate(page.get_images(full=True)):
                    xref = img[0]
                    try:
                        base_image = doc.extract_image(xref)
                        image_bytes = base_image["image"]
                    except Exception as e:
                        # One bad xref must not discard the remaining images.
                        logger.warning(
                            f"跳过无法提取的 PDF 图片: page={page_num}, xref={xref}, error={e}"
                        )
                        continue
                    images.append({
                        "bytes": image_bytes,
                        "ext": base_image.get("ext", "png"),
                        "page": page_num,
                        "index": idx,
                    })
    except ImportError:
        logger.warning("PyMuPDF 未安装,无法提取 PDF 图片,请执行: uv add pymupdf")
    except Exception as e:
        logger.error(f"提取 PDF 图片失败: {e}")
    return images
|
||||
|
||||
@staticmethod
def _extract_docx_images(file_content: bytes) -> list[dict]:
    """
    Extract embedded images from a DOCX archive, tagged with a running index.

    DOCX has no page concept, so "page" is always 0. Entries under
    "word/media/" are ordered by natural (numeric-aware) name order so that
    "image2" comes before "image10".
    NOTE(review): this follows the media file numbering, which presumably
    tracks insertion order rather than position in the document body.

    Args:
        file_content: Raw bytes of the document; anything that does not start
            with the ZIP signature yields an empty list.

    Returns:
        list[dict]: entries with keys "bytes", "ext", "page" (always 0) and
        "index" (0-based).
    """
    import re

    def _natural_key(name: str) -> list:
        # re.split with a capturing group alternates text/digit runs, so odd
        # positions are always numeric and compare as integers.
        return [
            int(chunk) if pos % 2 else chunk
            for pos, chunk in enumerate(re.split(r"(\d+)", name))
        ]

    images = []
    try:
        # Cheap signature check: legacy .doc files and other non-ZIP input
        # are not DOCX archives.
        if file_content[:2] != b'PK':
            return []
        with zipfile.ZipFile(io.BytesIO(file_content)) as zf:
            media_files = sorted(
                (
                    name for name in zf.namelist()
                    if name.startswith("word/media/") and not name.endswith("/")
                ),
                key=_natural_key,
            )
            for idx, name in enumerate(media_files):
                ext = name.rsplit(".", 1)[-1].lower() if "." in name else "png"
                images.append({
                    "bytes": zf.read(name),
                    "ext": ext,
                    "page": 0,
                    "index": idx,
                })
    except Exception as e:
        logger.error(f"提取 DOCX 图片失败: {e}")
    return images
|
||||
|
||||
@staticmethod
|
||||
async def _extract_pdf_text(file_content: bytes) -> str:
|
||||
"""提取 PDF 文本"""
|
||||
|
||||
@@ -8,6 +8,7 @@ from sqlalchemy.orm import Session
|
||||
import uuid
|
||||
|
||||
from app.aioRedis import aio_redis_set, aio_redis_get, aio_redis_delete
|
||||
from app.models import Workspace
|
||||
from app.models.user_model import User
|
||||
from app.repositories import user_repository
|
||||
from app.schemas.user_schema import UserCreate
|
||||
@@ -74,7 +75,7 @@ def create_initial_superuser(db: Session):
|
||||
)
|
||||
|
||||
|
||||
def create_user(db: Session, user: UserCreate) -> User:
|
||||
def create_user(db: Session, user: UserCreate, workspace: Workspace) -> User:
|
||||
business_logger.info(f"创建用户: {user.username}, email: {user.email}")
|
||||
|
||||
try:
|
||||
@@ -93,24 +94,9 @@ def create_user(db: Session, user: UserCreate) -> User:
|
||||
business_logger.debug(f"开始创建用户: {user.username}")
|
||||
hashed_password = get_password_hash(user.password)
|
||||
|
||||
# 获取默认租户(第一个活跃租户)
|
||||
from app.repositories.tenant_repository import TenantRepository
|
||||
tenant_repo = TenantRepository(db)
|
||||
tenants = tenant_repo.get_tenants(skip=0, limit=1, is_active=True)
|
||||
|
||||
if not tenants:
|
||||
business_logger.error("系统中没有可用的租户")
|
||||
raise BusinessException(
|
||||
"系统配置错误:没有可用的租户",
|
||||
code=BizCode.TENANT_NOT_FOUND,
|
||||
context={"username": user.username, "email": user.email}
|
||||
)
|
||||
|
||||
default_tenant = tenants[0]
|
||||
|
||||
new_user = user_repository.create_user(
|
||||
db=db, user=user, hashed_password=hashed_password,
|
||||
tenant_id=default_tenant.id, is_superuser=False
|
||||
tenant_id=workspace.tenant_id, is_superuser=False
|
||||
)
|
||||
|
||||
db.commit()
|
||||
|
||||
@@ -694,7 +694,8 @@ class WorkflowService:
|
||||
"nodes": config.nodes,
|
||||
"edges": config.edges,
|
||||
"variables": config.variables,
|
||||
"execution_config": config.execution_config
|
||||
"execution_config": config.execution_config,
|
||||
"features": feature_configs
|
||||
}
|
||||
|
||||
try:
|
||||
@@ -894,7 +895,8 @@ class WorkflowService:
|
||||
"nodes": config.nodes,
|
||||
"edges": config.edges,
|
||||
"variables": config.variables,
|
||||
"execution_config": config.execution_config
|
||||
"execution_config": config.execution_config,
|
||||
"features": feature_configs
|
||||
}
|
||||
|
||||
try:
|
||||
|
||||
Reference in New Issue
Block a user