Merge remote-tracking branch 'origin/release/v0.2.9' into develop
This commit is contained in:
@@ -151,11 +151,6 @@ async def write(
|
||||
|
||||
# Step 3: Save all data to Neo4j database
|
||||
step_start = time.time()
|
||||
from app.repositories.neo4j.create_indexes import create_fulltext_indexes
|
||||
try:
|
||||
await create_fulltext_indexes()
|
||||
except Exception as e:
|
||||
logger.error(f"Error creating indexes: {e}", exc_info=True)
|
||||
|
||||
# 添加死锁重试机制
|
||||
max_retries = 3
|
||||
@@ -279,5 +274,21 @@ async def write(
|
||||
except Exception as cache_err:
|
||||
logger.warning(f"[WRITE] 写入活动统计缓存失败(不影响主流程): {cache_err}", exc_info=True)
|
||||
|
||||
# Close LLM/Embedder underlying httpx clients to prevent
|
||||
# 'RuntimeError: Event loop is closed' during garbage collection
|
||||
for client_obj in (llm_client, embedder_client):
|
||||
try:
|
||||
underlying = getattr(client_obj, 'client', None) or getattr(client_obj, 'model', None)
|
||||
if underlying is None:
|
||||
continue
|
||||
# Unwrap RedBearLLM / RedBearEmbeddings to get the LangChain model
|
||||
inner = getattr(underlying, '_model', underlying)
|
||||
# LangChain OpenAI models expose async_client (httpx.AsyncClient)
|
||||
http_client = getattr(inner, 'async_client', None)
|
||||
if http_client is not None and hasattr(http_client, 'aclose'):
|
||||
await http_client.aclose()
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
logger.info("=== Pipeline Complete ===")
|
||||
logger.info(f"Total execution time: {total_time:.2f} seconds")
|
||||
|
||||
@@ -65,7 +65,7 @@ class OpenAIClient(LLMClient):
|
||||
type=type_
|
||||
)
|
||||
|
||||
logger.info(f"OpenAI 客户端初始化完成: type={type_}")
|
||||
logger.debug(f"OpenAI 客户端初始化完成: type={type_}")
|
||||
|
||||
async def chat(self, messages: List[Dict[str, str]], **kwargs) -> Any:
|
||||
"""
|
||||
|
||||
@@ -30,6 +30,18 @@ from app.core.memory.storage_services.extraction_engine.data_preprocessing.scene
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def message_has_files(message: "ConversationMessage") -> bool:
|
||||
"""检查消息是否包含文件。
|
||||
|
||||
Args:
|
||||
message: 待检查的消息对象
|
||||
|
||||
Returns:
|
||||
bool: 如果消息包含文件则返回 True,否则返回 False
|
||||
"""
|
||||
return message.files and len(message.files) > 0
|
||||
|
||||
|
||||
class DialogExtractionResponse(BaseModel):
|
||||
"""对话级一次性抽取的结构化返回,用于加速剪枝。
|
||||
|
||||
@@ -128,7 +140,7 @@ class SemanticPruner:
|
||||
1. 空消息
|
||||
2. 场景特定填充词库精确匹配
|
||||
3. 常见寒暄精确匹配
|
||||
4. 组合寒暄模式(前缀+后缀组合,如"好的谢谢"、"同学你好"、"明白了")
|
||||
4. 组合寒暄模式(前缀 + 后缀组合,如"好的谢谢"、"同学你好"、"明白了")
|
||||
5. 纯表情/标点
|
||||
"""
|
||||
t = message.msg.strip()
|
||||
@@ -482,6 +494,11 @@ class SemanticPruner:
|
||||
"""
|
||||
to_delete_ids: set = set()
|
||||
for m in msgs:
|
||||
# 最高优先级保护:带有文件的消息一律保留,不参与任何剪枝判断
|
||||
if message_has_files(m):
|
||||
self._log(f" [保护] 带文件的消息(不参与剪枝):'{m.msg[:40]}',文件数={len(m.files)}")
|
||||
continue
|
||||
|
||||
# 填充检测优先:先判断是否为填充,再看 LLM 保护
|
||||
if self._is_filler_message(m):
|
||||
to_delete_ids.add(id(m))
|
||||
@@ -549,6 +566,11 @@ class SemanticPruner:
|
||||
to_delete_ids: set = set()
|
||||
for m in msgs:
|
||||
msg_text = m.msg.strip()
|
||||
|
||||
# 最高优先级保护:带有文件的消息一律保留,不参与任何剪枝判断
|
||||
if message_has_files(m):
|
||||
self._log(f" [保护] 带文件的消息(不参与剪枝):'{msg_text[:40]}',文件数={len(m.files)}")
|
||||
continue
|
||||
|
||||
# 第一优先级:填充消息无论模式直接删除,不参与后续场景判断
|
||||
if self._is_filler_message(m):
|
||||
@@ -801,6 +823,12 @@ class SemanticPruner:
|
||||
|
||||
for idx, m in enumerate(msgs):
|
||||
msg_text = m.msg.strip()
|
||||
|
||||
# 最高优先级保护:带有文件的消息一律保留,不参与分类
|
||||
if message_has_files(m):
|
||||
self._log(f" [保护] 带文件的消息(不参与分类,直接保留):索引{idx}, '{msg_text[:40]}', 文件数={len(m.files)}")
|
||||
llm_protected_msgs.append((idx, m)) # 放入保护列表
|
||||
continue
|
||||
|
||||
if self._msg_matches_tokens(m, preserve_tokens):
|
||||
llm_protected_msgs.append((idx, m))
|
||||
|
||||
@@ -182,7 +182,7 @@ class ExtractionOrchestrator:
|
||||
list[StatementEntityEdge],
|
||||
list[EntityEntityEdge],
|
||||
list[PerceptualEdge],
|
||||
dict
|
||||
list[DialogData]
|
||||
]:
|
||||
"""
|
||||
运行完整的知识提取流水线(优化版:并行执行)
|
||||
@@ -295,6 +295,7 @@ class ExtractionOrchestrator:
|
||||
statement_entity_edges,
|
||||
entity_entity_edges,
|
||||
dialog_data_list,
|
||||
dedup_details,
|
||||
) = await self._run_dedup_and_write_summary(
|
||||
dialogue_nodes,
|
||||
chunk_nodes,
|
||||
@@ -306,6 +307,11 @@ class ExtractionOrchestrator:
|
||||
dialog_data_list,
|
||||
)
|
||||
|
||||
# 步骤 7: 同步用户别名到数据库表(仅正式模式)
|
||||
if not is_pilot_run:
|
||||
logger.info("步骤 7: 同步用户别名到 end_user 和 end_user_info 表")
|
||||
await self._update_end_user_other_name(entity_nodes, dialog_data_list)
|
||||
|
||||
logger.info(f"知识提取流水线运行完成({mode_str})")
|
||||
return (
|
||||
dialogue_nodes,
|
||||
@@ -1399,7 +1405,8 @@ class ExtractionOrchestrator:
|
||||
logger.info(f"同步 Neo4j aliases 到 end_user_info: {neo4j_aliases}")
|
||||
else:
|
||||
first_alias = current_aliases[0].strip() if current_aliases else ""
|
||||
if first_alias:
|
||||
# 确保 first_alias 不是占位名称
|
||||
if first_alias and first_alias not in self.USER_PLACEHOLDER_NAMES:
|
||||
db.add(EndUserInfo(
|
||||
end_user_id=end_user_uuid,
|
||||
other_name=first_alias,
|
||||
@@ -1415,29 +1422,33 @@ class ExtractionOrchestrator:
|
||||
|
||||
|
||||
|
||||
# 用户实体占位名称,不允许作为 other_name 或出现在 aliases 中
|
||||
USER_PLACEHOLDER_NAMES = {'用户', '我', 'User', 'I'}
|
||||
|
||||
def _extract_current_aliases(self, entity_nodes: List[ExtractedEntityNode]) -> List[str]:
|
||||
"""从实体节点提取用户别名(保持 LLM 提取的原始顺序,不进行任何排序)
|
||||
|
||||
这个方法直接返回 LLM 提取的别名列表,不做任何修改。
|
||||
这个方法直接返回 LLM 提取的别名列表,并过滤掉占位名称("用户"、"我"、"User"、"I")。
|
||||
第一个别名将被用作 other_name。
|
||||
|
||||
Args:
|
||||
entity_nodes: 实体节点列表
|
||||
|
||||
Returns:
|
||||
别名列表(保持 LLM 提取的原始顺序)
|
||||
别名列表(保持 LLM 提取的原始顺序,已过滤占位名称)
|
||||
"""
|
||||
USER_NAMES = {'用户', '我', 'User', 'I'}
|
||||
for entity in entity_nodes:
|
||||
if getattr(entity, 'name', '').strip() in USER_NAMES:
|
||||
if getattr(entity, 'name', '').strip() in self.USER_PLACEHOLDER_NAMES:
|
||||
aliases = getattr(entity, 'aliases', []) or []
|
||||
logger.debug(f"提取到用户别名(原始顺序): {aliases}")
|
||||
return aliases
|
||||
# 过滤掉占位名称,防止 "用户"/"我"/"User"/"I" 被存入 aliases 和 other_name
|
||||
filtered = [a for a in aliases if a.strip() not in self.USER_PLACEHOLDER_NAMES]
|
||||
logger.debug(f"提取到用户别名(原始顺序,已过滤占位名称): {filtered}")
|
||||
return filtered
|
||||
return []
|
||||
|
||||
|
||||
async def _fetch_neo4j_user_aliases(self, end_user_id: str) -> List[str]:
|
||||
"""从 Neo4j 查询用户实体的完整 aliases 列表"""
|
||||
"""从 Neo4j 查询用户实体的完整 aliases 列表(已过滤占位名称)"""
|
||||
cypher = """
|
||||
MATCH (e:ExtractedEntity)
|
||||
WHERE e.end_user_id = $end_user_id AND e.name IN ['用户', '我', 'User', 'I']
|
||||
@@ -1451,7 +1462,10 @@ class ExtractionOrchestrator:
|
||||
aliases = result[0].get('aliases') or []
|
||||
if not aliases:
|
||||
logger.debug(f"Neo4j 用户实体 aliases 为空: end_user_id={end_user_id}")
|
||||
return aliases
|
||||
return []
|
||||
# 过滤掉占位名称,防止历史脏数据传播
|
||||
filtered = [a for a in aliases if a.strip() not in self.USER_PLACEHOLDER_NAMES]
|
||||
return filtered
|
||||
|
||||
def _resolve_other_name(
|
||||
self,
|
||||
@@ -1463,14 +1477,25 @@ class ExtractionOrchestrator:
|
||||
决定 other_name 是否需要更新,返回新值;无需更新返回 None。
|
||||
|
||||
决策规则:
|
||||
- 为空 → 用本次对话第一个别名
|
||||
- 为空或为占位名称 → 用本次对话第一个别名
|
||||
- 不在 Neo4j aliases 中 → 用 Neo4j 第一个别名(说明已被删除)
|
||||
- 否则 → 保持不变(返回 None)
|
||||
|
||||
注意:返回值不允许是占位名称("用户"、"我"、"User"、"I")
|
||||
"""
|
||||
if not current or not current.strip():
|
||||
return current_aliases[0].strip() if current_aliases else None
|
||||
# 当前值为空或为占位名称时,需要更新
|
||||
if not current or not current.strip() or current.strip() in self.USER_PLACEHOLDER_NAMES:
|
||||
candidate = current_aliases[0].strip() if current_aliases else None
|
||||
# 确保候选值不是占位名称
|
||||
if candidate and candidate in self.USER_PLACEHOLDER_NAMES:
|
||||
return None
|
||||
return candidate
|
||||
if current not in neo4j_aliases:
|
||||
return neo4j_aliases[0].strip() if neo4j_aliases else None
|
||||
candidate = neo4j_aliases[0].strip() if neo4j_aliases else None
|
||||
# 确保候选值不是占位名称
|
||||
if candidate and candidate in self.USER_PLACEHOLDER_NAMES:
|
||||
return None
|
||||
return candidate
|
||||
|
||||
return None
|
||||
|
||||
@@ -1492,6 +1517,7 @@ class ExtractionOrchestrator:
|
||||
list[StatementChunkEdge],
|
||||
list[StatementEntityEdge],
|
||||
list[EntityEntityEdge],
|
||||
list[DialogData],
|
||||
dict
|
||||
]:
|
||||
"""
|
||||
@@ -1555,6 +1581,8 @@ class ExtractionOrchestrator:
|
||||
statement_chunk_edges,
|
||||
dedup_statement_entity_edges,
|
||||
dedup_entity_entity_edges,
|
||||
dialog_data_list,
|
||||
dedup_details,
|
||||
)
|
||||
|
||||
final_entity_nodes = dedup_entity_nodes
|
||||
@@ -1562,7 +1590,16 @@ class ExtractionOrchestrator:
|
||||
final_entity_entity_edges = dedup_entity_entity_edges
|
||||
else:
|
||||
# 正式模式:执行完整的两阶段去重
|
||||
result_tuple = await dedup_layers_and_merge_and_return(
|
||||
(
|
||||
dialogue_nodes,
|
||||
chunk_nodes,
|
||||
statement_nodes,
|
||||
final_entity_nodes,
|
||||
statement_chunk_edges,
|
||||
final_statement_entity_edges,
|
||||
final_entity_entity_edges,
|
||||
dedup_details,
|
||||
) = await dedup_layers_and_merge_and_return(
|
||||
dialogue_nodes,
|
||||
chunk_nodes,
|
||||
statement_nodes,
|
||||
@@ -1576,21 +1613,21 @@ class ExtractionOrchestrator:
|
||||
llm_client=self.llm_client,
|
||||
)
|
||||
|
||||
# 解包返回值
|
||||
(
|
||||
_,
|
||||
_,
|
||||
_,
|
||||
final_entity_nodes,
|
||||
_,
|
||||
final_statement_entity_edges,
|
||||
final_entity_entity_edges,
|
||||
dedup_details,
|
||||
) = result_tuple
|
||||
|
||||
# 保存去重消歧的详细记录到实例变量
|
||||
self._save_dedup_details(dedup_details, entity_nodes, final_entity_nodes)
|
||||
|
||||
result_tuple = (
|
||||
dialogue_nodes,
|
||||
chunk_nodes,
|
||||
statement_nodes,
|
||||
final_entity_nodes,
|
||||
statement_chunk_edges,
|
||||
final_statement_entity_edges,
|
||||
final_entity_entity_edges,
|
||||
dialog_data_list,
|
||||
dedup_details,
|
||||
)
|
||||
|
||||
logger.info(
|
||||
f"去重后: {len(final_entity_nodes)} 个实体节点, "
|
||||
f"{len(final_statement_entity_edges)} 条陈述句-实体边, "
|
||||
|
||||
@@ -105,13 +105,19 @@ Extract entities and knowledge triplets from the given statement.
|
||||
{% if language == "zh" %}
|
||||
- 用户实体的 name 字段:使用 "用户" 或 "我"
|
||||
- 用户的真实姓名:放入 aliases
|
||||
- **🚨 禁止将 "用户"、"我" 放入 aliases 中,aliases 只能包含用户的真实姓名、昵称等**
|
||||
- 示例:
|
||||
* "我叫李明" → name="用户", aliases=["李明"]
|
||||
* ❌ 错误:aliases=["用户", "李明"]("用户"不是真实姓名,禁止放入 aliases)
|
||||
* ❌ 错误:aliases=["我", "李明"]("我"不是真实姓名,禁止放入 aliases)
|
||||
{% else %}
|
||||
- User entity name field: use "User" or "I"
|
||||
- User's real name: put in aliases
|
||||
- **🚨 NEVER put "User" or "I" in aliases. Aliases must only contain real names, nicknames, etc.**
|
||||
- Examples:
|
||||
* "I'm John" → name="User", aliases=["John"]
|
||||
* ❌ Wrong: aliases=["User", "John"] ("User" is not a real name, FORBIDDEN in aliases)
|
||||
* ❌ Wrong: aliases=["I", "John"] ("I" is not a real name, FORBIDDEN in aliases)
|
||||
{% endif %}
|
||||
|
||||
|
||||
|
||||
@@ -44,6 +44,8 @@ class OSSStorage(StorageBackend):
|
||||
access_key_id: str,
|
||||
access_key_secret: str,
|
||||
bucket_name: str,
|
||||
connect_timeout: int = 30,
|
||||
multipart_threshold: int = 10 * 1024 * 1024, # 10MB
|
||||
):
|
||||
"""
|
||||
Initialize the OSSStorage backend.
|
||||
@@ -53,6 +55,8 @@ class OSSStorage(StorageBackend):
|
||||
access_key_id: The Aliyun access key ID.
|
||||
access_key_secret: The Aliyun access key secret.
|
||||
bucket_name: The name of the OSS bucket.
|
||||
connect_timeout: Connection timeout in seconds (default: 30).
|
||||
multipart_threshold: File size threshold for multipart upload (default: 10MB).
|
||||
|
||||
Raises:
|
||||
StorageConfigError: If any required configuration is missing.
|
||||
@@ -69,10 +73,17 @@ class OSSStorage(StorageBackend):
|
||||
|
||||
self.endpoint = endpoint
|
||||
self.bucket_name = bucket_name
|
||||
self.multipart_threshold = multipart_threshold
|
||||
|
||||
try:
|
||||
auth = oss2.Auth(access_key_id, access_key_secret)
|
||||
self.bucket = oss2.Bucket(auth, endpoint, bucket_name)
|
||||
# 设置超时和重试
|
||||
self.bucket = oss2.Bucket(
|
||||
auth,
|
||||
endpoint,
|
||||
bucket_name,
|
||||
connect_timeout=connect_timeout
|
||||
)
|
||||
logger.info(
|
||||
f"OSSStorage initialized with endpoint: {endpoint}, bucket: {bucket_name}"
|
||||
)
|
||||
@@ -108,21 +119,38 @@ class OSSStorage(StorageBackend):
|
||||
if content_type:
|
||||
headers["Content-Type"] = content_type
|
||||
|
||||
self.bucket.put_object(file_key, content, headers=headers if headers else None)
|
||||
# 大文件使用分片上传
|
||||
if len(content) > self.multipart_threshold:
|
||||
logger.info(f"Using multipart upload for large file: {file_key} ({len(content)} bytes)")
|
||||
upload_id = self.bucket.init_multipart_upload(file_key, headers=headers if headers else None).upload_id
|
||||
parts = []
|
||||
part_size = 5 * 1024 * 1024 # 5MB per part
|
||||
part_num = 1
|
||||
|
||||
for offset in range(0, len(content), part_size):
|
||||
chunk = content[offset:offset + part_size]
|
||||
result = self.bucket.upload_part(file_key, upload_id, part_num, chunk)
|
||||
parts.append(oss2.models.PartInfo(part_num, result.etag))
|
||||
part_num += 1
|
||||
|
||||
self.bucket.complete_multipart_upload(file_key, upload_id, parts)
|
||||
else:
|
||||
self.bucket.put_object(file_key, content, headers=headers if headers else None)
|
||||
|
||||
logger.info(f"File uploaded to OSS successfully: {file_key}")
|
||||
return file_key
|
||||
|
||||
except OssError as e:
|
||||
logger.error(f"OSS error uploading file {file_key}: {e}")
|
||||
raise StorageUploadError(
|
||||
message=f"Failed to upload file to OSS: {e.message}",
|
||||
message=f"Failed to upload file to OSS: {str(e)}",
|
||||
file_key=file_key,
|
||||
cause=e,
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to upload file to OSS {file_key}: {e}")
|
||||
raise StorageUploadError(
|
||||
message=f"Failed to upload file to OSS: {e}",
|
||||
message=f"Failed to upload file to OSS: {str(e)}",
|
||||
file_key=file_key,
|
||||
cause=e,
|
||||
)
|
||||
@@ -135,28 +163,73 @@ class OSSStorage(StorageBackend):
|
||||
) -> int:
|
||||
"""Upload from async stream to OSS. Returns total bytes written."""
|
||||
buf = io.BytesIO()
|
||||
headers = {"Content-Type": content_type} if content_type else None
|
||||
upload_id = None
|
||||
|
||||
try:
|
||||
# 收集流数据
|
||||
total_size = 0
|
||||
async for chunk in stream:
|
||||
if not chunk:
|
||||
continue
|
||||
buf.write(chunk)
|
||||
total_size += len(chunk)
|
||||
|
||||
content = buf.getvalue()
|
||||
headers = {"Content-Type": content_type} if content_type else None
|
||||
self.bucket.put_object(file_key, content, headers=headers)
|
||||
logger.info(f"File stream uploaded to OSS successfully: {file_key}")
|
||||
return len(content)
|
||||
|
||||
if not content:
|
||||
raise StorageUploadError(
|
||||
message="Empty stream content",
|
||||
file_key=file_key,
|
||||
)
|
||||
|
||||
# 大文件使用分片上传
|
||||
if len(content) > self.multipart_threshold:
|
||||
logger.info(f"Using multipart upload for stream: {file_key} ({len(content)} bytes)")
|
||||
upload_id = self.bucket.init_multipart_upload(file_key, headers=headers).upload_id
|
||||
parts = []
|
||||
part_size = 5 * 1024 * 1024 # 5MB
|
||||
part_num = 1
|
||||
|
||||
for offset in range(0, len(content), part_size):
|
||||
chunk = content[offset:offset + part_size]
|
||||
result = self.bucket.upload_part(file_key, upload_id, part_num, chunk)
|
||||
parts.append(oss2.models.PartInfo(part_num, result.etag))
|
||||
part_num += 1
|
||||
|
||||
self.bucket.complete_multipart_upload(file_key, upload_id, parts)
|
||||
else:
|
||||
self.bucket.put_object(file_key, content, headers=headers)
|
||||
|
||||
logger.info(f"File stream uploaded to OSS successfully: {file_key} ({total_size} bytes)")
|
||||
return total_size
|
||||
|
||||
except OssError as e:
|
||||
if upload_id:
|
||||
try:
|
||||
self.bucket.abort_multipart_upload(file_key, upload_id)
|
||||
except:
|
||||
pass
|
||||
logger.error(f"OSS error stream uploading file {file_key}: {e}")
|
||||
raise StorageUploadError(
|
||||
message=f"Failed to stream upload file to OSS: {e.message}",
|
||||
message=f"Failed to stream upload file to OSS: {str(e)}",
|
||||
file_key=file_key,
|
||||
cause=e,
|
||||
)
|
||||
except Exception as e:
|
||||
if upload_id:
|
||||
try:
|
||||
self.bucket.abort_multipart_upload(file_key, upload_id)
|
||||
except:
|
||||
pass
|
||||
logger.error(f"Failed to stream upload file to OSS {file_key}: {e}")
|
||||
raise StorageUploadError(
|
||||
message=f"Failed to stream upload file to OSS: {e}",
|
||||
message=f"Failed to stream upload file to OSS: {str(e)}",
|
||||
file_key=file_key,
|
||||
cause=e,
|
||||
)
|
||||
finally:
|
||||
buf.close()
|
||||
|
||||
async def download(self, file_key: str) -> bytes:
|
||||
"""
|
||||
@@ -182,14 +255,14 @@ class OSSStorage(StorageBackend):
|
||||
except OssError as e:
|
||||
logger.error(f"OSS error downloading file {file_key}: {e}")
|
||||
raise StorageDownloadError(
|
||||
message=f"Failed to download file from OSS: {e.message}",
|
||||
message=f"Failed to download file from OSS: {str(e)}",
|
||||
file_key=file_key,
|
||||
cause=e,
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to download file from OSS {file_key}: {e}")
|
||||
raise StorageDownloadError(
|
||||
message=f"Failed to download file from OSS: {e}",
|
||||
message=f"Failed to download file from OSS: {str(e)}",
|
||||
file_key=file_key,
|
||||
cause=e,
|
||||
)
|
||||
@@ -215,14 +288,14 @@ class OSSStorage(StorageBackend):
|
||||
except OssError as e:
|
||||
logger.error(f"OSS error deleting file {file_key}: {e}")
|
||||
raise StorageDeleteError(
|
||||
message=f"Failed to delete file from OSS: {e.message}",
|
||||
message=f"Failed to delete file from OSS: {str(e)}",
|
||||
file_key=file_key,
|
||||
cause=e,
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to delete file from OSS {file_key}: {e}")
|
||||
raise StorageDeleteError(
|
||||
message=f"Failed to delete file from OSS: {e}",
|
||||
message=f"Failed to delete file from OSS: {str(e)}",
|
||||
file_key=file_key,
|
||||
cause=e,
|
||||
)
|
||||
|
||||
@@ -99,7 +99,7 @@ class SimpleMCPClient:
|
||||
# 建立 SSE 连接
|
||||
response = await self._session.get(self.server_url)
|
||||
|
||||
if response.status != 200:
|
||||
if response.status not in (200, 202):
|
||||
error_text = await response.text()
|
||||
raise MCPConnectionError(f"SSE 连接失败 {response.status}: {error_text}")
|
||||
|
||||
@@ -190,7 +190,9 @@ class SimpleMCPClient:
|
||||
|
||||
try:
|
||||
async with self._session.post(self._endpoint_url, json=request) as response:
|
||||
if response.status != 200:
|
||||
# MCP SSE 协议:POST 请求返回 200 或 202 均为正常
|
||||
# 202 Accepted 表示请求已接受,结果通过 SSE 流异步返回
|
||||
if response.status not in (200, 202):
|
||||
error_text = await response.text()
|
||||
raise MCPConnectionError(f"请求失败 {response.status}: {error_text}")
|
||||
|
||||
@@ -205,7 +207,7 @@ class SimpleMCPClient:
|
||||
raise MCPConnectionError("endpoint URL 未初始化")
|
||||
|
||||
async with self._session.post(self._endpoint_url, json=notification) as response:
|
||||
if response.status != 200:
|
||||
if response.status not in (200, 202):
|
||||
logger.warning(f"通知发送失败: {response.status}")
|
||||
|
||||
async def _initialize_modelscope_session(self):
|
||||
|
||||
Reference in New Issue
Block a user