From ba30161559ea2c15164934826689c38521d0ae49 Mon Sep 17 00:00:00 2001 From: zhaoying Date: Thu, 15 Jan 2026 14:58:54 +0800 Subject: [PATCH 01/12] fix(web): stream api support refresh token --- web/src/utils/stream.ts | 131 ++++++++++++++++++++++++++-------------- 1 file changed, 87 insertions(+), 44 deletions(-) diff --git a/web/src/utils/stream.ts b/web/src/utils/stream.ts index 7688cdd5..e4179e25 100644 --- a/web/src/utils/stream.ts +++ b/web/src/utils/stream.ts @@ -1,8 +1,47 @@ import { message } from 'antd'; import i18n from '@/i18n' import { cookieUtils } from './request' +import { refreshToken } from '@/api/user' +import { clearAuthData } from './auth' const API_PREFIX = '/api' +// Token refresh state +let isRefreshing = false; +let refreshPromise: Promise | null = null; + +// Refresh token function for SSE +const refreshTokenForSSE = async (): Promise => { + if (isRefreshing && refreshPromise) { + return refreshPromise; + } + + isRefreshing = true; + refreshPromise = (async () => { + try { + const refresh_token = cookieUtils.get('refreshToken'); + if (!refresh_token) { + throw new Error(i18n.t('common.refreshTokenNotExist')); + } + const response: any = await refreshToken(); + const newToken = response.access_token; + cookieUtils.set('authToken', newToken); + return newToken; + } catch (error) { + clearAuthData(); + message.warning(i18n.t('common.loginExpired')); + if (!window.location.hash.includes('#/login')) { + window.location.href = `/#/login`; + } + throw error; + } finally { + isRefreshing = false; + refreshPromise = null; + } + })(); + + return refreshPromise; +}; + export interface SSEMessage { event?: string data?: string | object @@ -66,62 +105,66 @@ function parseDataContent(dataContent: string): string | object { } } +const makeSSERequest = async (url: string, data: any, token: string, config = { headers: {} }) => { + return fetch(`${API_PREFIX}${url}`, { + method: 'POST', + headers: { + 'Content-Type': 'application/json', + 'Authorization': `Bearer ${token}`, + ...config.headers, + }, + body: JSON.stringify(data) + }); +}; export const handleSSE = async (url: string, data: any, onMessage?: (data: SSEMessage[]) => void, config = { headers: {} }) => { try { - const token = cookieUtils.get('authToken'); - const response = await fetch(`${API_PREFIX}${url}`, { - method: 'POST', - headers: { - 'Content-Type': 'application/json', - 'Authorization': `Bearer ${token}`, - ...config.headers, - }, - body: JSON.stringify(data) - }); + let token = cookieUtils.get('authToken'); + let response = await makeSSERequest(url, data, token || '', config); - const { status } = response - - switch(status) { + switch (response.status) { case 401: if (url?.includes('/public')) { return message.warning(i18n.t('common.publicApiCannotRefreshToken')); } - window.location.href = `/#/login`; - break; - default: - if (!response.body) throw new Error('No response body'); - - const reader = response.body.getReader(); - const decoder = new TextDecoder(); - let buffer = ''; // 添加缓冲区来处理不完整的消息 - - while (true) { - const { done, value } = await reader.read(); - if (done) break; - - const chunk = decoder.decode(value, { stream: true }); - buffer += chunk; - - // 处理完整的事件 - const events = buffer.split('\n\n'); - buffer = events.pop() || ''; // 保留最后一个可能不完整的事件 - - for (const event of events) { - if (event.trim() && onMessage) { - onMessage(parseSSEToJSON(event) ?? {}); - } - } - } - - // 处理剩余的缓冲区内容 - if (buffer.trim() && onMessage) { - onMessage(parseSSEToJSON(buffer) ?? {}); + try { + const newToken = await refreshTokenForSSE(); + response = await makeSSERequest(url, data, newToken, config); + } catch (refreshError) { + return; } break; } + if (!response.body) throw new Error('No response body'); + + const reader = response.body.getReader(); + const decoder = new TextDecoder(); + let buffer = ''; // 添加缓冲区来处理不完整的消息 + + while (true) { + const { done, value } = await reader.read(); + if (done) break; + + const chunk = decoder.decode(value, { stream: true }); + buffer += chunk; + + // 处理完整的事件 + const events = buffer.split('\n\n'); + buffer = events.pop() || ''; // 保留最后一个可能不完整的事件 + + for (const event of events) { + if (event.trim() && onMessage) { + onMessage(parseSSEToJSON(event) ?? {}); + } + } + } + + // 处理剩余的缓冲区内容 + if (buffer.trim() && onMessage) { + onMessage(parseSSEToJSON(buffer) ?? {}); + } } catch (error) { console.error('Request failed:', error); throw error; } -} \ No newline at end of file +}; \ No newline at end of file From 67d0b196b8202215e5178baa8877d58268eda33e Mon Sep 17 00:00:00 2001 From: zhaoying Date: Fri, 16 Jan 2026 13:56:36 +0800 Subject: [PATCH 02/12] =?UTF-8?q?fix(web):=20loop=E3=80=81iteration=20sub?= =?UTF-8?q?=20node=20move=20bugfix?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .../views/Workflow/hooks/useWorkflowGraph.ts | 90 +++++++++++++++++-- 1 file changed, 85 insertions(+), 5 deletions(-) diff --git a/web/src/views/Workflow/hooks/useWorkflowGraph.ts b/web/src/views/Workflow/hooks/useWorkflowGraph.ts index 77ea56ca..75bd2517 100644 --- a/web/src/views/Workflow/hooks/useWorkflowGraph.ts +++ b/web/src/views/Workflow/hooks/useWorkflowGraph.ts @@ -666,6 +666,77 @@ export const useWorkflowGraph = ({ graphRef.current.resize(containerRef.current.offsetWidth, containerRef.current.offsetHeight); } }; + + const nodeChangePosition = ({ node, options }: { node: Node; options: { skipParentHandler?: boolean } }) => { + const embedPadding = 50; // Define the embed padding constant + if (options.skipParentHandler) { + return + } + + const children = node.getChildren() + if (children && children.length) { + node.prop('originPosition', node.getPosition()) + } + + const parent = node.getParent() + if (parent && parent.isNode()) { + let originSize = parent.prop('originSize') + if (originSize == null) { + originSize = parent.getSize() + parent.prop('originSize', originSize) + } + + let originPosition = parent.prop('originPosition') + if (originPosition == null) { + originPosition = parent.getPosition() + parent.prop('originPosition', originPosition) + } + + let x = originPosition.x + let y = originPosition.y + let cornerX = originPosition.x + originSize.width + let cornerY = originPosition.y + originSize.height + let hasChange = false + + const children = parent.getChildren() + if (children) { + children.forEach((child) => { + const bbox = child.getBBox().inflate(embedPadding) + const corner = bbox.getCorner() + + if (bbox.x < x) { + x = bbox.x + hasChange = true + } + + if (bbox.y < y) { + y = bbox.y + hasChange = true + } + + if (corner.x > cornerX) { + cornerX = corner.x + hasChange = true + } + + if (corner.y > cornerY) { + cornerY = corner.y + hasChange = true + } + }) + } + + if (hasChange) { + parent.prop( + { + position: { x, y }, + size: { width: cornerX - x, height: cornerY - y }, + }, + { skipParentHandler: true }, + ) + } + } + } // 初始化 const init = () => { @@ -764,10 +835,7 @@ export const useWorkflowGraph = ({ }, }, embedding: { - enabled: true, - validate (this) { - return false - } + enabled: false, }, translating: { restrict(view) { @@ -783,6 +851,17 @@ export const useWorkflowGraph = ({ return null }, }, + highlighting: { + embedding: { + name: 'stroke', + args: { + padding: -1, + attrs: { + stroke: '#73d13d', + }, + }, + }, + }, }); // 使用插件 setupPlugins(); @@ -824,7 +903,8 @@ export const useWorkflowGraph = ({ // 监听缩放事件 graphRef.current.on('scale', scaleEvent); // 监听节点移动事件 - graphRef.current.on('node:moved', nodeMoved); + // graphRef.current.on('node:moved', nodeMoved); + graphRef.current.on('node:change:position', nodeChangePosition); // 监听画布变化事件 const events = [ From a6a18b73046a66a16690260db42f4635b77df7af Mon Sep 17 00:00:00 2001 From: zhaoying Date: Fri, 16 Jan 2026 13:57:46 +0800 Subject: [PATCH 03/12] feat(web): menu order adjustment --- web/src/store/menu.json | 30 +++++++++++++++--------------- 1 file changed, 15 insertions(+), 15 deletions(-) diff --git a/web/src/store/menu.json b/web/src/store/menu.json index b49788a8..62f6c13c 100644 --- a/web/src/store/menu.json +++ b/web/src/store/menu.json @@ -332,21 +332,6 @@ } ] }, - { - "id": 19, - "parent": 0, - "code": "member", - "label": "成员管理", - "i18nKey": "menu.memberManagement", - "path": "/member", - "enable": true, - "display": true, - "level": 1, - "sort": 0, - "icon": null, - "iconActive": null, - "subs": null - }, { "id": 10, "parent": 0, @@ -377,6 +362,21 @@ "iconActive": null, "subs": null }, + { + "id": 19, + "parent": 0, + "code": "member", + "label": "成员管理", + "i18nKey": "menu.memberManagement", + "path": "/member", + "enable": true, + "display": true, + "level": 1, + "sort": 0, + "icon": null, + "iconActive": null, + "subs": null + }, { "id": 12, "parent": 0, From c2c2b306a28eade36e8391d53d6aff3d43a5902c Mon Sep 17 00:00:00 2001 From: zhaoying Date: Fri, 16 Jan 2026 15:48:02 +0800 Subject: [PATCH 04/12] refactor: agent config refactor --- web/src/i18n/zh.ts | 2 +- web/src/views/ApplicationConfig/Agent.tsx | 209 ++++++------------ .../ApplicationConfig/components/Card.tsx | 3 + .../components/ChatVariableConfigModal.tsx | 101 +++++++++ .../components/{ => Knowledge}/Knowledge.tsx | 105 +++++---- .../{ => Knowledge}/KnowledgeConfigModal.tsx | 39 ++-- .../KnowledgeGlobalConfigModal.tsx | 19 +- .../{ => Knowledge}/KnowledgeListModal.tsx | 17 +- .../components/Knowledge/types.ts | 30 +++ .../components/{ => ToolList}/ToolList.tsx | 19 +- .../components/{ => ToolList}/ToolModal.tsx | 0 .../components/ToolList/types.ts | 26 +++ .../components/VariableList.tsx | 131 ----------- .../{ => VariableList}/ApiExtensionModal.tsx | 2 +- .../{ => VariableList}/VariableEditModal.tsx | 11 +- .../components/VariableList/VariableList.tsx | 110 +++++++++ .../components/VariableList/types.ts | 28 +++ web/src/views/ApplicationConfig/types.ts | 92 +------- 18 files changed, 501 insertions(+), 443 deletions(-) create mode 100644 web/src/views/ApplicationConfig/components/ChatVariableConfigModal.tsx rename web/src/views/ApplicationConfig/components/{ => Knowledge}/Knowledge.tsx (54%) rename web/src/views/ApplicationConfig/components/{ => Knowledge}/KnowledgeConfigModal.tsx (76%) rename web/src/views/ApplicationConfig/components/{ => Knowledge}/KnowledgeGlobalConfigModal.tsx (86%) rename web/src/views/ApplicationConfig/components/{ => Knowledge}/KnowledgeListModal.tsx (88%) create mode 100644 web/src/views/ApplicationConfig/components/Knowledge/types.ts rename web/src/views/ApplicationConfig/components/{ => ToolList}/ToolList.tsx (93%) rename web/src/views/ApplicationConfig/components/{ => ToolList}/ToolModal.tsx (100%) create mode 100644 web/src/views/ApplicationConfig/components/ToolList/types.ts delete mode 100644 web/src/views/ApplicationConfig/components/VariableList.tsx rename web/src/views/ApplicationConfig/components/{ => VariableList}/ApiExtensionModal.tsx (99%) rename web/src/views/ApplicationConfig/components/{ => VariableList}/VariableEditModal.tsx (96%) create mode 100644 web/src/views/ApplicationConfig/components/VariableList/VariableList.tsx create mode 100644 web/src/views/ApplicationConfig/components/VariableList/types.ts diff --git a/web/src/i18n/zh.ts b/web/src/i18n/zh.ts index 028202d1..6b46084f 100644 --- a/web/src/i18n/zh.ts +++ b/web/src/i18n/zh.ts @@ -658,8 +658,8 @@ export const zh = { priority: '结构化整合', addTool: '添加工具', tool: '工具', + variableConfig: '配置变量' }, - // 角色管理相关翻译 role: { roleManagement: '角色管理', roleId: '角色ID', diff --git a/web/src/views/ApplicationConfig/Agent.tsx b/web/src/views/ApplicationConfig/Agent.tsx index 92170d55..9aab1110 100644 --- a/web/src/views/ApplicationConfig/Agent.tsx +++ b/web/src/views/ApplicationConfig/Agent.tsx @@ -13,26 +13,25 @@ import type { Config, ModelConfig, AgentRef, - KnowledgeBase, - KnowledgeConfig, - Variable, MemoryConfig, AiPromptModalRef, Source, - ToolOption + ChatVariableConfigModalRef } from './types' +import type { Variable } from './components/VariableList/types' +import type { KnowledgeConfig } from './components/Knowledge/types' import type { Model } from '@/views/ModelManagement/types' import { getModelList } from '@/api/models'; import { saveAgentConfig } from '@/api/application' -import Knowledge from './components/Knowledge' -import VariableList from './components/VariableList' +import Knowledge from './components/Knowledge/Knowledge' +import VariableList from './components/VariableList/VariableList' import { getApplicationConfig } from '@/api/application' -import { getKnowledgeBaseList } from '@/api/knowledgeBase' import { memoryConfigListUrl } from '@/api/memory' import CustomSelect from '@/components/CustomSelect' import aiPrompt from '@/assets/images/application/aiPrompt.png' import AiPromptModal from './components/AiPromptModal' -import ToolList from './components/ToolList' +import ToolList from './components/ToolList/ToolList' +import ChatVariableConfigModal from './components/ChatVariableConfigModal'; const DescWrapper: FC<{desc: string, className?: string}> = ({desc, className}) => { return ( @@ -66,7 +65,7 @@ const SwitchWrapper: FC<{ title: string, desc?: string, name: string | string[]; ) } -const SelectWrapper: FC<{ title: string, desc: string, name: string, url: string }> = ({ title, desc, name, url }) => { +const SelectWrapper: FC<{ title: string, desc: string, name: string | string[], url: string }> = ({ title, desc, name, url }) => { const { t } = useTranslation(); return ( <> @@ -77,6 +76,7 @@ const SelectWrapper: FC<{ title: string, desc: string, name: string, url: string className="rb:mb-0!" > ((_props, ref) => { const [modelList, setModelList] = useState([]) const [defaultModel, setDefaultModel] = useState(null) const [chatList, setChatList] = useState([]) - const [formData, setFormData] = useState<{ - default_model_config_id?: string, - model_parameters?: Config['model_parameters'], - tools: ToolOption[], - } | null>(null) - const values = Form.useWatch<{ - memoryEnabled: boolean; - memory_content?: string | number; - } & Config>([], form) - - const [knowledgeConfig, setKnowledgeConfig] = useState({ knowledge_bases: [] }) - const [variableList, setVariableList] = useState([]) + const values = Form.useWatch([], form) const [isSave, setIsSave] = useState(false) const initialized = useRef(false) - const [toolList, setToolList] = useState([]) // 初始化完成标记 useEffect(() => { - if (data && values && formData) { + if (data) { initialized.current = true } - }, [data, values, formData]) + }, [data]) - useEffect(() => { - if (!initialized.current) return - if (isSave) return - setIsSave(true) - }, [knowledgeConfig]) - useEffect(() => { - if (!initialized.current) return - if (isSave) return - setIsSave(true) - }, [variableList]) - useEffect(() => { - if (!initialized.current) return - if (isSave) return - setIsSave(true) - }, [formData]) useEffect(() => { if (!initialized.current) return if (isSave) return setIsSave(true) }, [values]) - useEffect(() => { - if (!initialized.current) return - if (isSave) return - setIsSave(true) - }, [toolList]) useEffect(() => { getModels() @@ -157,68 +125,19 @@ const Agent = forwardRef((_props, ref) => { setLoading(true) getApplicationConfig(id as string).then(res => { const response = res as Config - setData({ - ...response, - tools: Array.isArray(response.tools) ? response.tools : [] - }) - const { memory, tools } = response + let allTools = Array.isArray(response.tools) ? response.tools : [] form.setFieldsValue({ ...response, - memoryEnabled: memory?.enabled || false, - memory_content: memory?.memory_content ? Number(memory?.memory_content) : undefined, - tools: Array.isArray(tools) ? tools : [] + tools: allTools }) - setFormData({ - default_model_config_id: response.default_model_config_id, - model_parameters: response.model_parameters || {}, - tools: Array.isArray(tools) ? tools : [] + setData({ + ...response, + tools: allTools }) - if (response?.knowledge_retrieval?.knowledge_bases?.length) { - getDefaultKnowledgeList(response) - } - if (response?.tools?.length) { - setToolList(response?.tools) - } }).finally(() => { setLoading(false) }) } - const getDefaultKnowledgeList = (data: Config) => { - if (!data || !data.knowledge_retrieval || !data.knowledge_retrieval?.knowledge_bases?.length) { - return - } - const initialList = [...(data?.knowledge_retrieval?.knowledge_bases || [])] - getKnowledgeBaseList(undefined, { - kb_ids: initialList.map(vo => vo.kb_id).join(','), - page: 1, - pagesize: 100, - }) - .then(res => { - const list = res.items || [] - const knowledge_bases: KnowledgeBase[] = list.map(item => { - const filterItem = initialList.find(vo => vo.kb_id === item.id) - return { - ...item, - ...filterItem - } - }) - setKnowledgeConfig(prev => ({ - ...prev, - knowledge_bases: [...knowledge_bases] - })) - setData((prev) => { - prev = prev as Config - const knowledge_retrieval: KnowledgeConfig = { - ...(prev?.knowledge_retrieval || {}), - knowledge_bases: [...knowledge_bases] - } - return { - ...(prev || {}), - knowledge_retrieval - } - }) - }) - } const refresh = (vo: ModelConfig, type: Source) => { if (type === 'model') { @@ -227,15 +146,7 @@ const Agent = forwardRef((_props, ref) => { default_model_config_id, model_parameters: {...rest} }) - setFormData((prevState) => { - const prev = prevState as Config - return { - ...(prev || {}), - default_model_config_id, - model_parameters: {...rest} - }; - }) - if (default_model_config_id === formData?.default_model_config_id) { + if (default_model_config_id === values?.default_model_config_id) { setChatList([{ label: vo.label || '', model_config_id: default_model_config_id || '', @@ -279,24 +190,20 @@ const Agent = forwardRef((_props, ref) => { // 保存Agent配置 const handleSave = (flag = true) => { if (!isSave || !data) return Promise.resolve() - const { memoryEnabled, memory_content, ...rest } = values - const { knowledge_bases = [], ...knowledgeRest } = knowledgeConfig || {} - - + const { memory, knowledge_retrieval, tools, ...rest } = values + const { knowledge_bases = [], ...knowledgeRest } = knowledge_retrieval || {} + const { memory_content } = memory || {} // 从原数据中获取memory的其他必要属性 const originalMemory = data.memory || ({} as MemoryConfig) const params: Config = { ...data, ...rest, - ...(formData || {}), memory: { ...originalMemory, - enabled: memoryEnabled, + ...memory, memory_content: memory_content ? String(memory_content) : '', - max_history: originalMemory.max_history || '', }, - variables: variableList || [], knowledge_retrieval: knowledge_bases.length > 0 ? { ...data.knowledge_retrieval, ...knowledgeRest, @@ -305,14 +212,12 @@ const Agent = forwardRef((_props, ref) => { ...(item.config || {}) })) } as KnowledgeConfig : null, - tools: toolList.map(vo => ({ + tools: tools.map(vo => ({ tool_id: vo.tool_id, operation: vo.operation, enabled: vo.enabled })) } - - console.log('params', rest, params) return new Promise((resolve, reject) => { saveAgentConfig(data.app_id, params) @@ -338,8 +243,8 @@ const Agent = forwardRef((_props, ref) => { modelConfigModalRef.current?.handleOpen('chat') } useEffect(() => { - if (formData?.default_model_config_id && modelList.length > 0) { - const filterValue = modelList.find(item => item.id === formData.default_model_config_id) + if (values?.default_model_config_id && modelList.length > 0) { + const filterValue = modelList.find(item => item.id === values.default_model_config_id) setDefaultModel(filterValue as Model | null) setChatList([{ label: filterValue?.name || '', @@ -348,7 +253,7 @@ const Agent = forwardRef((_props, ref) => { list: [] }]) } - }, [modelList, formData?.default_model_config_id]) + }, [modelList, values?.default_model_config_id]) useImperativeHandle(ref, () => ({ handleSave @@ -360,8 +265,31 @@ const Agent = forwardRef((_props, ref) => { } const updatePrompt = (value: string) => { form.setFieldValue('system_prompt', value) + const variables = value.match(/\{\{([^}]+)\}\}/g)?.map(match => match.slice(2, -2)) || [] + const uniqueVariables = [...new Set(variables)] + const newVariableList: Variable[] = uniqueVariables.map((name, index) => ({ + index, + type: 'text', + name, + display_name: name, + required: false + })) + updateVariableList(newVariableList) } + const updateVariableList = (list: Variable[]) => { + form.setFieldValue('variables', [...list]) + setChatVariables([...list]) + } + const chatVariableConfigModalRef = useRef(null) + const [chatVariables, setChatVariables] = useState([]) + const handleOpenVariableConfig = () => { + chatVariableConfigModalRef.current?.handleOpen(chatVariables) + } + const handleSaveChatVariable = (values: Variable[]) => { + setChatVariables(values) + } + console.log('values', values) return ( <> {loading && } @@ -379,8 +307,9 @@ const Agent = forwardRef((_props, ref) => {
+ + - {/* 提示词 */}
@@ -406,36 +335,31 @@ const Agent = forwardRef((_props, ref) => { - {/* 知识库 */} - + + + {/* 记忆配置 */} - + - {/* 变量配置 */} - + + + {/* 工具配置 */} - + + + @@ -444,6 +368,9 @@ const Agent = forwardRef((_props, ref) => { {t('application.debuggingAndPreview')} + @@ -463,7 +390,7 @@ const Agent = forwardRef((_props, ref) => { @@ -472,6 +399,10 @@ const Agent = forwardRef((_props, ref) => { defaultModel={defaultModel} refresh={updatePrompt} /> + ); }); diff --git a/web/src/views/ApplicationConfig/components/Card.tsx b/web/src/views/ApplicationConfig/components/Card.tsx index 7d9328ea..f414848f 100644 --- a/web/src/views/ApplicationConfig/components/Card.tsx +++ b/web/src/views/ApplicationConfig/components/Card.tsx @@ -3,18 +3,21 @@ import RbCard from '@/components/RbCard/Card' interface CardProps { title?: string | ReactNode; + subTitle?: string | ReactNode; children: ReactNode; extra?: ReactNode; } const Card: FC = ({ title, + subTitle, children, extra, }) => { return ( void; +} + +const ChatVariableConfigModal = forwardRef(({ + refresh, +}, ref) => { + const { t } = useTranslation(); + const [visible, setVisible] = useState(false); + const [form] = Form.useForm<{variables: Variable[]}>(); + const [loading, setLoading] = useState(false) + const [initialValues, setInitialValues] = useState([]) + + // 封装取消方法,添加关闭弹窗逻辑 + const handleClose = () => { + setVisible(false); + form.resetFields(); + setLoading(false) + }; + + const handleOpen = (values: Variable[]) => { + console.log('values', values) + setVisible(true); + form.setFieldsValue({variables: values}) + setInitialValues([...values]) + }; + // 封装保存方法,添加提交逻辑 + const handleSave = () => { + form.validateFields().then((values) => { + refresh([ + ...(values?.variables ?? []), + ]) + handleClose() + }) + } + + // 暴露给父组件的方法 + useImperativeHandle(ref, () => ({ + handleOpen, + handleClose + })); + + console.log(form.getFieldValue('variables')) + + return ( + +
+ + {(fields) => ( + <> + {fields.map(({ name }, index) => { + const field = initialValues[index] + return ( + + { + field.type === 'text' && + } + { + field.type === 'number' && form.setFieldValue(['variables', name, 'value'], value)} /> + } + { + field.type === 'paragraph' && + } + + ) + })} + + )} + +
+
+ ); +}); + +export default ChatVariableConfigModal; \ No newline at end of file diff --git a/web/src/views/ApplicationConfig/components/Knowledge.tsx b/web/src/views/ApplicationConfig/components/Knowledge/Knowledge.tsx similarity index 54% rename from web/src/views/ApplicationConfig/components/Knowledge.tsx rename to web/src/views/ApplicationConfig/components/Knowledge/Knowledge.tsx index bc1207e4..1e59f26d 100644 --- a/web/src/views/ApplicationConfig/components/Knowledge.tsx +++ b/web/src/views/ApplicationConfig/components/Knowledge/Knowledge.tsx @@ -2,7 +2,6 @@ import { type FC, useRef, useState, useEffect } from 'react' import { useTranslation } from 'react-i18next' import { Space, Button, List } from 'antd' import knowledgeEmpty from '@/assets/images/application/knowledgeEmpty.svg' -import Card from './Card' import type { KnowledgeConfigForm, KnowledgeConfig, @@ -11,14 +10,16 @@ import type { KnowledgeModalRef, KnowledgeConfigModalRef, KnowledgeGlobalConfigModalRef, -} from '../types' +} from './types' import Empty from '@/components/Empty' import KnowledgeListModal from './KnowledgeListModal' import KnowledgeConfigModal from './KnowledgeConfigModal' import KnowledgeGlobalConfigModal from './KnowledgeGlobalConfigModal' import Tag from '@/components/Tag' +import { getKnowledgeBaseList } from '@/api/knowledgeBase' +import Card from '../Card' -const Knowledge: FC<{data: KnowledgeConfig; onUpdate: (config: KnowledgeConfig) => void}> = ({data, onUpdate}) => { +const Knowledge: FC<{value?: KnowledgeConfig; onChange?: (config: KnowledgeConfig) => void}> = ({value = {knowledge_bases: []}, onChange}) => { const { t } = useTranslation() const knowledgeModalRef = useRef(null) const knowledgeConfigModalRef = useRef(null) @@ -27,12 +28,31 @@ const Knowledge: FC<{data: KnowledgeConfig; onUpdate: (config: KnowledgeConfig) const [editConfig, setEditConfig] = useState({} as KnowledgeConfig) useEffect(() => { - if (data) { - setEditConfig({ ...(data || {}) }) - const knowledge_bases = [...(data.knowledge_bases || [])] - setKnowledgeList(knowledge_bases) + if (value && JSON.stringify(value) !== JSON.stringify(editConfig)) { + setEditConfig({ ...(value || {}) }) + const knowledge_bases = [...(value.knowledge_bases || [])] + + // 检查是否有knowledge_bases缺少name字段 + const basesWithoutName = knowledge_bases.filter(base => !base.name) + if (basesWithoutName.length > 0) { + // 调用接口获取完整的知识库信息 + getKnowledgeBaseList().then(res => { + const fullBases = knowledge_bases.map(base => { + if (!base.name) { + const fullBase = res.items.find((item: any) => item.id === base.kb_id) + return fullBase ? { ...base, ...fullBase } : base + } + return base + }) + setKnowledgeList(fullBases) + }).catch(() => { + setKnowledgeList(knowledge_bases) + }) + } else { + setKnowledgeList(knowledge_bases) + } } - }, [data]) + }, [value]) const handleKnowledgeConfig = () => { knowledgeGlobalConfigModalRef.current?.handleOpen() @@ -43,7 +63,7 @@ const Knowledge: FC<{data: KnowledgeConfig; onUpdate: (config: KnowledgeConfig) const handleDeleteKnowledge = (id: string) => { const list = knowledgeList.filter(item => item.id !== id) setKnowledgeList([...list]) - onUpdate({ + onChange && onChange({ ...editConfig, knowledge_bases: [...list], }) @@ -65,7 +85,7 @@ const Knowledge: FC<{data: KnowledgeConfig; onUpdate: (config: KnowledgeConfig) list = [...values as KnowledgeBase[]] } setKnowledgeList([...list]) - onUpdate({ + onChange && onChange({ ...editConfig, knowledge_bases: [...list], }) @@ -77,14 +97,14 @@ const Knowledge: FC<{data: KnowledgeConfig; onUpdate: (config: KnowledgeConfig) config: {...values as KnowledgeConfigForm} } setKnowledgeList([...list]) - onUpdate({ + onChange && onChange({ ...editConfig, knowledge_bases: [...list], }) } else if (type === 'rerankerConfig') { const rerankerValues = values as RerankerConfig setEditConfig(prev => ({ ...prev, ...rerankerValues })) - onUpdate({ + onChange && onChange({ ...editConfig, ...rerankerValues, reranker_id: rerankerValues.rerank_model ? rerankerValues.reranker_id : undefined, @@ -93,55 +113,54 @@ const Knowledge: FC<{data: KnowledgeConfig; onUpdate: (config: KnowledgeConfig) } } return ( - handleKnowledgeConfig()}>{t('application.globalConfig')} + + + + } > -
-
{t('application.associatedKnowledgeBase')}
- -
- {knowledgeList.length === 0 ? : ( - -
-
- {item.name} - - {item.status === 1 ? t('common.enable') : item.status === 0 ? t('common.disabled') : t('common.deleted')} - -
{t('application.contains', {include_count: item.doc_num})}
+ renderItem={(item) => { + if (!item.id) return null + return ( + +
+
+ {item.name} + + {item.status === 1 ? t('common.enable') : item.status === 0 ? t('common.disabled') : t('common.deleted')} + +
{t('application.contains', {include_count: item.doc_num})}
+
+ +
handleEditKnowledge(item)} + >
+
handleDeleteKnowledge(item.id)} + >
+
- -
handleEditKnowledge(item)} - >
-
handleDeleteKnowledge(item.id)} - >
-
-
- - )} + + ) + }} /> } - {/* 全局设置 */} - {/* 知识库列表 */} void; } -const retrieveTypes = ['participle', 'semantic', 'hybrid'] +const retrieveTypes: RetrieveType[] = ['participle', 'semantic', 'hybrid'] const KnowledgeConfigModal = forwardRef(({ refresh, @@ -33,8 +33,11 @@ const KnowledgeConfigModal = forwardRef { form.setFieldsValue({ - retrieve_type: retrieveTypes[0], + retrieve_type: data?.config?.retrieve_type || retrieveTypes[0], kb_id: data.id, + top_k: data?.config?.top_k || 5, + similarity_threshold: data?.config?.similarity_threshold || 0.5, + vector_similarity_weight: data?.config?.vector_similarity_weight || 0.5, ...(data || {}), ...(data?.config || {}), }) @@ -62,12 +65,10 @@ const KnowledgeConfigModal = forwardRef { if (values?.retrieve_type) { - const initialValues = Object.keys(values).map(key => { - return { - [key as keyof KnowledgeConfigForm]: (key === 'kb_id' || key === 'retrieve_type') ? values[key] : undefined - } - }) - form.resetFields(initialValues) + const fieldsToReset = Object.keys(values).filter(key => + key !== 'kb_id' && key !== 'retrieve_type' + ) as (keyof KnowledgeConfigForm)[]; + form.resetFields(fieldsToReset); } }, [values?.retrieve_type]) @@ -84,12 +85,12 @@ const KnowledgeConfigModal = forwardRef {data && ( -
-
+
+
{data.name} -
{t('application.contains', {include_count: data.doc_num})}
+
{t('application.contains', {include_count: data.doc_num})}
-
{formatDateTime(data.updated_at, 'YYYY-MM-DD HH:mm:ss')}
+
{formatDateTime(data.updated_at, 'YYYY-MM-DD HH:mm:ss')}
)} {/* 语义相似度阈值 similarity_threshold */} {values?.retrieve_type === 'semantic' && ( @@ -123,6 +130,7 @@ const KnowledgeConfigModal = forwardRef -
{t('application.globalConfigDesc')}
+
{t('application.globalConfigDesc')}
{/* 结果重排 */} -
-
+
+
{t('application.rerankModel')} -
{t('application.rerankModelDesc')}
+
{t('application.rerankModelDesc')}
@@ -110,7 +110,12 @@ const KnowledgeGlobalConfigModal = forwardRef - + form.setFieldValue('reranker_top_k', value)} + /> } diff --git a/web/src/views/ApplicationConfig/components/KnowledgeListModal.tsx b/web/src/views/ApplicationConfig/components/Knowledge/KnowledgeListModal.tsx similarity index 88% rename from web/src/views/ApplicationConfig/components/KnowledgeListModal.tsx rename to web/src/views/ApplicationConfig/components/Knowledge/KnowledgeListModal.tsx index 0c7b47b2..f1ebd516 100644 --- a/web/src/views/ApplicationConfig/components/KnowledgeListModal.tsx +++ b/web/src/views/ApplicationConfig/components/Knowledge/KnowledgeListModal.tsx @@ -2,7 +2,7 @@ import { forwardRef, useEffect, useImperativeHandle, useState } from 'react'; import { Space, List } from 'antd'; import { useTranslation } from 'react-i18next'; import clsx from 'clsx' -import type { KnowledgeModalRef, KnowledgeBase } from '../types' +import type { KnowledgeModalRef, KnowledgeBase } from './types' import type { KnowledgeBaseListItem } from '@/views/KnowledgeBase/types' import RbModal from '@/components/RbModal' import { getKnowledgeBaseList } from '@/api/knowledgeBase' @@ -39,12 +39,13 @@ const KnowledgeListModal = forwardRef(({ setQuery({}) setSelectedIds([]) setSelectedRows([]) - getList() }; useEffect(() => { - getList() - }, [query.keywords]) + if (visible) { + getList() + } + }, [query.keywords, visible]) const getList = () => { getKnowledgeBaseList(undefined, { ...query, @@ -124,15 +125,15 @@ const KnowledgeListModal = forwardRef(({ dataSource={filterList} renderItem={(item: KnowledgeBase) => ( -
handleSelect(item)}> -
+
{item.name} -
{t('application.contains', {include_count: item.doc_num})}
+
{t('application.contains', {include_count: item.doc_num})}
-
{formatDateTime(item.created_at, 'YYYY-MM-DD HH:mm:ss')}
+
{formatDateTime(item.created_at, 'YYYY-MM-DD HH:mm:ss')}
)} diff --git a/web/src/views/ApplicationConfig/components/Knowledge/types.ts b/web/src/views/ApplicationConfig/components/Knowledge/types.ts new file mode 100644 index 00000000..f4f9ed17 --- /dev/null +++ b/web/src/views/ApplicationConfig/components/Knowledge/types.ts @@ -0,0 +1,30 @@ +import type { KnowledgeBaseListItem } from '@/views/KnowledgeBase/types' +export interface RerankerConfig { + rerank_model?: boolean | undefined; + reranker_id?: string | undefined; + reranker_top_k?: number | undefined; +} +export type RetrieveType = 'participle' | 'semantic' | 'hybrid' +export interface KnowledgeConfigForm { + kb_id?: string; + similarity_threshold?: number; + vector_similarity_weight?: number; + top_k?: number; + retrieve_type?: RetrieveType; +} +export interface KnowledgeBase extends KnowledgeBaseListItem, KnowledgeConfigForm { + config?: KnowledgeConfigForm +} +export interface KnowledgeConfig extends RerankerConfig { + knowledge_bases: KnowledgeBase[]; +} + +export interface KnowledgeConfigModalRef { + handleOpen: (data: KnowledgeBase) => void; +} +export interface KnowledgeGlobalConfigModalRef { + handleOpen: () => void; +} +export interface KnowledgeModalRef { + handleOpen: (config?: KnowledgeConfig[]) => void; +} \ No newline at end of file diff --git a/web/src/views/ApplicationConfig/components/ToolList.tsx b/web/src/views/ApplicationConfig/components/ToolList/ToolList.tsx similarity index 93% rename from web/src/views/ApplicationConfig/components/ToolList.tsx rename to web/src/views/ApplicationConfig/components/ToolList/ToolList.tsx index fde7286b..e914d879 100644 --- a/web/src/views/ApplicationConfig/components/ToolList.tsx +++ b/web/src/views/ApplicationConfig/components/ToolList/ToolList.tsx @@ -1,22 +1,22 @@ import { type FC, useRef, useState, useEffect } from 'react' import { useTranslation } from 'react-i18next' import { Space, Button, List, Switch } from 'antd' -import Card from './Card' +import Card from '../Card' import type { ToolModalRef, ToolOption -} from '../types' +} from './types' import Empty from '@/components/Empty' import ToolModal from './ToolModal' import { getToolMethods, getToolDetail } from '@/api/tools' -const ToolList: FC<{ data: ToolOption[]; onUpdate: (config: ToolOption[]) => void}> = ({data, onUpdate}) => { +const ToolList: FC<{ value?: ToolOption[]; onChange?: (config: ToolOption[]) => void}> = ({value, onChange}) => { const { t } = useTranslation() const toolModalRef = useRef(null) const [toolList, setToolList] = useState([]) useEffect(() => { - if (data) { - const processedData = data.map(async (item) => { + if (value) { + const processedData = value.map(async (item) => { if (!item.label && item.tool_id) { try { const [toolDetail, methods] = await Promise.all([ @@ -77,7 +77,7 @@ const ToolList: FC<{ data: ToolOption[]; onUpdate: (config: ToolOption[]) => voi Promise.all(processedData).then(setToolList) } - }, [data]) + }, [value]) const handleAddTool = () => { toolModalRef.current?.handleOpen() @@ -85,12 +85,12 @@ const ToolList: FC<{ data: ToolOption[]; onUpdate: (config: ToolOption[]) => voi const updateTools = (tool: ToolOption) => { const list = [...toolList, tool] setToolList(list) - onUpdate(list) + onChange && onChange(list) } const handleDeleteTool = (index: number) => { const list = toolList.filter((_item, idx) => idx !== index) setToolList([...list]) - onUpdate(list) + onChange && onChange(list) } const handleChangeEnabled = (index: number) => { const list = toolList.map((item, idx) => { @@ -103,7 +103,7 @@ const ToolList: FC<{ data: ToolOption[]; onUpdate: (config: ToolOption[]) => voi return item }) setToolList([...list]) - onUpdate(list) + onChange && onChange(list) } return ( voi } > - {toolList.length === 0 ? : diff --git a/web/src/views/ApplicationConfig/components/ToolModal.tsx b/web/src/views/ApplicationConfig/components/ToolList/ToolModal.tsx similarity index 100% rename from web/src/views/ApplicationConfig/components/ToolModal.tsx rename to web/src/views/ApplicationConfig/components/ToolList/ToolModal.tsx diff --git a/web/src/views/ApplicationConfig/components/ToolList/types.ts b/web/src/views/ApplicationConfig/components/ToolList/types.ts new file mode 100644 index 00000000..142ffe26 --- /dev/null +++ b/web/src/views/ApplicationConfig/components/ToolList/types.ts @@ -0,0 +1,26 @@ +export interface ToolOption { + value?: string | number | null; + label?: React.ReactNode; + description?: string; + children?: ToolOption[]; + isLeaf?: boolean; + method_id?: string; + operation?: string; + parameters?: Parameter[]; + tool_id?: string; + enabled?: boolean; +} +export interface Parameter { + name: string; + type: string; + description: string; + required: boolean; + default: any; + enum: null | string[]; + minimum: number; + maximum: number; + pattern: null | string; +} +export interface ToolModalRef { + handleOpen: () => void; +} \ No newline at end of file diff --git a/web/src/views/ApplicationConfig/components/VariableList.tsx b/web/src/views/ApplicationConfig/components/VariableList.tsx deleted file mode 100644 index fbadf2ea..00000000 --- a/web/src/views/ApplicationConfig/components/VariableList.tsx +++ /dev/null @@ -1,131 +0,0 @@ -import { type FC, useRef, useState, useEffect } from 'react' -import { useTranslation } from 'react-i18next' -import { Space, Button, Switch } from 'antd' -import variablesEmpty from '@/assets/images/application/variablesEmpty.svg' -import Card from './Card' -import Table from '@/components/Table'; -import type { Variable, VariableEditModalRef } from '../types' -import Empty from '@/components/Empty' -import VariableEditModal from './VariableEditModal' - -interface VariableListProps { - data?: Variable[]; - onUpdate: (data: Variable[]) => void; -} -const VariableList: FC = ({data = [], onUpdate}) => { - const { t } = useTranslation() - const variableEditModalRef = useRef(null) - const [variableList, setVariableList] = useState([]) - const [maxIndex, setMaxIndex] = useState(0) - - useEffect(() => { - if (!data || data.length === 0) return - const list = data.map((item, index) => ({ - ...item, - index - })) - setVariableList(list) - onUpdate(list) - setMaxIndex(list.length) - }, [data]) - - const handleAddVariable = () => { - variableEditModalRef.current?.handleOpen() - } - const handleSaveVariable = (value: Variable) => { - if (value.index !== undefined && value.index >= 0) { - const index = variableList.findIndex(item => item.index === value.index) - if (index !== -1) { - const newData = [...variableList] - newData[index] = value - setVariableList([...newData]) - onUpdate([...newData]) - } - } else { - const list = [...variableList, { - index: maxIndex + 1, - ...value - }] - setVariableList(list) - onUpdate([...list]) - setMaxIndex(maxIndex + 1) - } - } - const handleDeleteVariable = (index: number) => { - const list = variableList.filter((_, i) => i !== index) - setVariableList(list) - onUpdate([...list]) - } - return ( - -
-
- {t('application.VariableManagement')} - ({t('application.VariableManagementDesc')}) -
- -
- - {/* List */} - {variableList.length > 0 - ? ( -
- t(`application.${type}`) - }, - { - title: t('application.variableKey'), - dataIndex: 'name', - key: 'name', - }, - { - title: t('application.variableName'), - dataIndex: 'display_name', - key: 'display_name', - }, - { - title: t('application.optional'), - dataIndex: 'required', - key: 'required', - render: (required) => - }, - { - title: t('common.operation'), - key: 'action', - render: (_, record, index: number) => ( - - - - - ), - }, - ]} - initialData={variableList as unknown as Record[]} - emptySize={88} - /> - - ) - : - } - - - ) -} -export default VariableList \ No newline at end of file diff --git a/web/src/views/ApplicationConfig/components/ApiExtensionModal.tsx b/web/src/views/ApplicationConfig/components/VariableList/ApiExtensionModal.tsx similarity index 99% rename from web/src/views/ApplicationConfig/components/ApiExtensionModal.tsx rename to web/src/views/ApplicationConfig/components/VariableList/ApiExtensionModal.tsx index b1c7450a..4f4f9047 100644 --- a/web/src/views/ApplicationConfig/components/ApiExtensionModal.tsx +++ b/web/src/views/ApplicationConfig/components/VariableList/ApiExtensionModal.tsx @@ -2,7 +2,7 @@ import { forwardRef, useImperativeHandle, useState } from 'react'; import { Form, Input } from 'antd'; import { useTranslation } from 'react-i18next'; -import type { ApiExtensionModalData, ApiExtensionModalRef } from '../types' +import type { ApiExtensionModalData, ApiExtensionModalRef } from './types' import RbModal from '@/components/RbModal' const FormItem = Form.Item; diff --git a/web/src/views/ApplicationConfig/components/VariableEditModal.tsx b/web/src/views/ApplicationConfig/components/VariableList/VariableEditModal.tsx similarity index 96% rename from web/src/views/ApplicationConfig/components/VariableEditModal.tsx rename to web/src/views/ApplicationConfig/components/VariableList/VariableEditModal.tsx index 3efd721c..69e213fb 100644 --- a/web/src/views/ApplicationConfig/components/VariableEditModal.tsx +++ b/web/src/views/ApplicationConfig/components/VariableList/VariableEditModal.tsx @@ -2,7 +2,7 @@ import { forwardRef, useImperativeHandle, useState, useRef } from 'react'; import { Form, Input, Select, InputNumber, Checkbox, Tag, Divider, Button } from 'antd'; import { useTranslation } from 'react-i18next'; -import type { ApiExtensionModalRef, Variable, VariableEditModalRef } from '../types' +import type { ApiExtensionModalRef, Variable, VariableEditModalRef } from './types' import RbModal from '@/components/RbModal' import SortableList from '@/components/SortableList' import ApiExtensionModal from './ApiExtensionModal' @@ -137,7 +137,14 @@ const VariableEditModal = forwardRef - + { + if (!form.getFieldValue('display_name')) { + form.setFieldValue('display_name', e.target.value) + } + }} + /> {/* 显示名称 */} void; +} +const VariableList: FC = ({value = [], onChange}) => { + const { t } = useTranslation() + const variableEditModalRef = useRef(null) + + const handleAddVariable = () => { + variableEditModalRef.current?.handleOpen() + } + const handleSaveVariable = (variable: Variable) => { + const newList = [...(value || [])] + if (variable.index !== undefined && variable.index >= 0) { + const index = newList.findIndex(item => item.index === variable.index) + if (index !== -1) { + newList[index] = variable + } + } else { + newList.push({ ...variable, index: Date.now() }) + } + onChange?.(newList) + } + return ( + + {t('application.variableConfiguration')} + ({t('application.VariableManagementDesc')}) + } + extra={} + > + + {(fields, { remove }) => { + return ( + <> + {fields.length > 0 ? ( +
+
t(`application.${type}`) + }, + { + title: t('application.variableKey'), + dataIndex: 'name', + key: 'name', + }, + { + title: t('application.variableName'), + dataIndex: 'display_name', + key: 'display_name', + }, + { + title: t('application.optional'), + dataIndex: 'required', + key: 'required', + render: (required) => + }, + { + title: t('common.operation'), + key: 'action', + render: (_, record, index: number) => ( + + + + + ), + }, + ]} + initialData={value as unknown as Record[]} + emptySize={88} + /> + + ) : ( + + )} + + ) + }} + + + + ) +} +export default VariableList \ No newline at end of file diff --git a/web/src/views/ApplicationConfig/components/VariableList/types.ts b/web/src/views/ApplicationConfig/components/VariableList/types.ts new file mode 100644 index 00000000..f262dda1 --- /dev/null +++ b/web/src/views/ApplicationConfig/components/VariableList/types.ts @@ -0,0 +1,28 @@ +export interface Variable { + index?: number; + name: string; + display_name: string; + type: string; + required: boolean; + max_length?: number; + description?: string; + + key?: string; + default_value?: string; + options?: string[]; + api_extension?: string; + hidden?: boolean; + value?: any; +} +export interface VariableEditModalRef { + handleOpen: (values?: Variable) => void; +} + +export interface ApiExtensionModalData { + name: string; + apiEndpoint: string; + apiKey: string; +} +export interface ApiExtensionModalRef { + handleOpen: () => void; +} \ No newline at end of file diff --git a/web/src/views/ApplicationConfig/types.ts b/web/src/views/ApplicationConfig/types.ts index 6eb97f22..6f641ebb 100644 --- a/web/src/views/ApplicationConfig/types.ts +++ b/web/src/views/ApplicationConfig/types.ts @@ -1,4 +1,6 @@ -import type { KnowledgeBaseListItem } from '@/views/KnowledgeBase/types' +import type { KnowledgeConfig } from './components/Knowledge/types' +import type { Variable } from './components/VariableList/types' +import type { ToolOption } from './components/ToolList/types' import type { ChatItem } from '@/components/Chat/types' import type { GraphRef } from '@/views/Workflow/types'; import type { ApiKey } from '@/views/ApiKeyManagement/types' @@ -14,55 +16,6 @@ export interface ModelConfig { n: number; stop?: string; } - -/*************** 知识库相关 ******************/ -export interface RerankerConfig { - rerank_model?: boolean | undefined; - reranker_id?: string | undefined; - reranker_top_k?: number | undefined; -} -export interface KnowledgeConfigForm { - kb_id?: string; - similarity_threshold?: number; - vector_similarity_weight?: number; - top_k?: number; - retrieve_type?: 'participle' | 'semantic' | 'hybrid'; -} -export interface KnowledgeBase extends KnowledgeBaseListItem, KnowledgeConfigForm { - config?: KnowledgeConfigForm -} -export interface KnowledgeConfig extends RerankerConfig { - knowledge_bases: KnowledgeBase[]; -} - -export interface KnowledgeConfigModalRef { - handleOpen: (data: KnowledgeBase) => void; -} -export interface KnowledgeGlobalConfigModalRef { - handleOpen: () => void; -} -/*********** end 知识库相关 ******************/ - -/*************** 变量相关 ******************/ -export interface Variable { - index?: number; - name: string; - display_name: string; - type: string; - required: boolean; - max_length?: number; - description?: string; - - key: string; - default_value?: string; - options?: string[]; - api_extension?: string; - hidden?: boolean; -} -export interface VariableEditModalRef { - handleOpen: (values?: Variable) => void; -} -/*********** end 变量相关 ******************/ export interface MemoryConfig { enabled: boolean; memory_content?: string; @@ -131,17 +84,6 @@ export interface ModelConfigModalData { export interface AiPromptModalRef { handleOpen: () => void; } -export interface KnowledgeModalRef { - handleOpen: (config?: KnowledgeConfig[]) => void; -} -export interface ApiExtensionModalData { - name: string; - apiEndpoint: string; - apiKey: string; -} -export interface ApiExtensionModalRef { - handleOpen: () => void; -} export interface ChatData { label?: string; model_config_id?: string; @@ -206,30 +148,6 @@ export interface AiPromptForm { message?: string; current_prompt?: string; } -export interface ToolModalRef { - handleOpen: () => void; -} - -export interface ToolOption { - value?: string | number | null; - label?: React.ReactNode; - description?: string; - children?: ToolOption[]; - isLeaf?: boolean; - method_id?: string; - operation?: string; - parameters?: Parameter[]; - tool_id?: string; - enabled?: boolean; -} -export interface Parameter { - name: string; - type: string; - description: string; - required: boolean; - default: any; - enum: null | string[]; - minimum: number; - maximum: number; - pattern: null | string; +export interface ChatVariableConfigModalRef { + handleOpen: (values: Variable[]) => void; } \ No newline at end of file From 2d90b0c7527f288f71492198d86bcf4642614b1a Mon Sep 17 00:00:00 2001 From: zhaoying Date: Mon, 19 Jan 2026 17:00:26 +0800 Subject: [PATCH 05/12] refactor: extract useVariableList; properties add output variable --- web/src/i18n/en.ts | 1 + web/src/i18n/zh.ts | 1 + .../Editor/plugin/InitialValuePlugin.tsx | 6 +- .../Properties/GroupVariableList/index.tsx | 16 +- .../components/Properties/VariableSelect.tsx | 4 +- .../Properties/hooks/useVariableList.ts | 209 +++ .../Workflow/components/Properties/index.tsx | 1155 +++++------------ .../views/Workflow/hooks/useWorkflowGraph.ts | 5 - 8 files changed, 562 insertions(+), 835 deletions(-) create mode 100644 web/src/views/Workflow/components/Properties/hooks/useVariableList.ts diff --git a/web/src/i18n/en.ts b/web/src/i18n/en.ts index 1341bf55..bc757797 100644 --- a/web/src/i18n/en.ts +++ b/web/src/i18n/en.ts @@ -1967,6 +1967,7 @@ Memory Bear: After the rebellion, regional warlordism intensified for several re value: 'Value', addCase: 'Add Condition', addVariable: 'Add Variables', + output: 'Output Variable' }, clear: 'Clear', diff --git a/web/src/i18n/zh.ts b/web/src/i18n/zh.ts index b6834fea..eeee6bc9 100644 --- a/web/src/i18n/zh.ts +++ b/web/src/i18n/zh.ts @@ -2061,6 +2061,7 @@ export const zh = { value: '值', addCase: '添加条件', addVariable: '添加变量', + output: '输出变量' }, clear: '清空', diff --git a/web/src/views/Workflow/components/Editor/plugin/InitialValuePlugin.tsx b/web/src/views/Workflow/components/Editor/plugin/InitialValuePlugin.tsx index 5ad18dcd..22de9592 100644 --- a/web/src/views/Workflow/components/Editor/plugin/InitialValuePlugin.tsx +++ b/web/src/views/Workflow/components/Editor/plugin/InitialValuePlugin.tsx @@ -33,7 +33,8 @@ const InitialValuePlugin: React.FC = ({ value, options useEffect(() => { if (value !== prevValueRef.current && !isUserInputRef.current) { - editor.update(() => { + queueMicrotask(() => { + editor.update(() => { const root = $getRoot(); root.clear(); @@ -98,7 +99,8 @@ const InitialValuePlugin: React.FC = ({ value, options }); root.append(paragraph); } - }, { discrete: true }); + }, { discrete: true }); + }); } prevValueRef.current = value; diff --git a/web/src/views/Workflow/components/Properties/GroupVariableList/index.tsx b/web/src/views/Workflow/components/Properties/GroupVariableList/index.tsx index 06ea9e86..81eac38e 100644 --- a/web/src/views/Workflow/components/Properties/GroupVariableList/index.tsx +++ b/web/src/views/Workflow/components/Properties/GroupVariableList/index.tsx @@ -17,7 +17,7 @@ const GroupVariableList: FC = ({ name, options = [], isCanAdd = false, - size = "middle" + size = "small" }) => { const { t } = useTranslation(); const form = Form.useFormInstance(); @@ -37,16 +37,10 @@ const GroupVariableList: FC = ({ } return ( -
- -
- - {t('workflow.config.var-aggregator.variable')} - - - +
+
+ {t('workflow.config.var-aggregator.variable')} +
= ({ if (filterOption) { return ( diff --git a/web/src/views/Workflow/components/Properties/hooks/useVariableList.ts b/web/src/views/Workflow/components/Properties/hooks/useVariableList.ts new file mode 100644 index 00000000..ab37fec9 --- /dev/null +++ b/web/src/views/Workflow/components/Properties/hooks/useVariableList.ts @@ -0,0 +1,209 @@ +import { useMemo, useEffect, useState } from 'react'; +import { Graph, Node } from '@antv/x6'; +import type { Suggestion } from '../../Editor/plugin/AutocompletePlugin'; +import type { ChatVariable } from '../../../types'; + +const NODE_VARIABLES = { + llm: [{ label: 'output', dataType: 'string', field: 'output' }], + 'jinja-render': [{ label: 'output', dataType: 'string', field: 'output' }], + tool: [{ label: 'data', dataType: 'string', field: 'data' }], + 'knowledge-retrieval': [{ label: 'output', dataType: 'array[object]', field: 'output' }], + 'parameter-extractor': [ + { label: '__is_success', dataType: 'number', field: '__is_success' }, + { label: '__reason', dataType: 'string', field: '__reason' } + ], + 'http-request': [ + { label: 'body', dataType: 'string', field: 'body' }, + { label: 'status_code', dataType: 'number', field: 'status_code' } + ], + 'question-classifier': [{ label: 'class_name', dataType: 'string', field: 'class_name' }], + 'memory-read': [ + { label: 'answer', dataType: 'string', field: 'answer' }, + { label: 'intermediate_outputs', dataType: 'array[object]', field: 'intermediate_outputs' } + ] +} as const; + +const addVariable = ( + list: Suggestion[], + keys: Set, + key: string, + label: string, + dataType: string, + value: string, + nodeData: any, + extra?: Partial +) => { + if (!keys.has(key)) { + keys.add(key); + list.push({ key, label, type: 'variable', dataType, value, nodeData, ...extra }); + } +}; + +const processNodeVariables = ( + nodeData: any, + dataNodeId: string, + variableList: Suggestion[], + addedKeys: Set +) => { + const { type, config } = nodeData; + + if (type in NODE_VARIABLES) { + NODE_VARIABLES[type as keyof typeof NODE_VARIABLES].forEach(({ label, dataType, field }) => { + addVariable(variableList, addedKeys, `${dataNodeId}_${label}`, label, dataType, `${dataNodeId}.${field}`, nodeData); + }); + } + + switch (type) { + case 'start': + [...(config?.variables?.defaultValue ?? []), ...(config?.variables?.value ?? [])].forEach((v: any) => { + if (v?.name) addVariable(variableList, addedKeys, `${dataNodeId}_${v.name}`, v.name, v.type, `${dataNodeId}.${v.name}`, nodeData); + }); + config?.variables?.sys?.forEach((v: any) => { + if (v?.name) addVariable(variableList, addedKeys, `${dataNodeId}_sys_${v.name}`, `sys.${v.name}`, v.type, `sys.${v.name}`, nodeData); + }); + break; + + case 'parameter-extractor': + (config?.params?.defaultValue || []).forEach((p: any) => { + if (p?.name) addVariable(variableList, addedKeys, `${dataNodeId}_${p.name}`, p.name, p.type || 'string', `${dataNodeId}.${p.name}`, nodeData); + }); + break; + + case 'var-aggregator': + if (config.group.defaultValue) { + (config.group_variables.defaultValue || []).forEach((gv: any) => { + if (gv?.key) { + let dt = 'string'; + if (gv.value?.[0]) { + const fv = variableList.find(v => `{{${v.value}}}` === gv.value[0]); + if (fv) dt = fv.dataType; + } + addVariable(variableList, addedKeys, `${dataNodeId}_${gv.key}`, gv.key, dt, `${dataNodeId}.${gv.key}`, nodeData); + } + }); + } else { + const fv = (config.group_variables.defaultValue || [])[0]; + let dt = 'any'; + if (fv) { + const found = variableList.find(v => `{{${v.value}}}` === fv); + if (found) dt = found.dataType; + } + addVariable(variableList, addedKeys, `${dataNodeId}_output`, 'output', dt, `${dataNodeId}.output`, nodeData); + } + break; + + case 'iteration': + let dt = 'string'; + if (nodeData.output) { + const sv = variableList.find(v => v.value === nodeData.output); + if (sv) dt = sv.dataType; + } + addVariable(variableList, addedKeys, `${dataNodeId}_output`, 'output', `array[${dt}]`, `${dataNodeId}.output`, nodeData); + break; + + case 'loop': + (config.cycle_vars.defaultValue || []).forEach((cv: any) => { + if (cv.name?.trim()) addVariable(variableList, addedKeys, `${dataNodeId}_cycle_${cv.name}`, cv.name, cv.type || 'string', `${dataNodeId}.${cv.name}`, nodeData); + }); + break; + } +}; + +const hasOutputNodeTypes = [ + 'llm', + 'knowledge-retrieval', + 'memory-read', + 'question-classifier', + 'var-aggregator', + 'http-request', + 'tool', + 'jinja-render' +] +export const getCurrentNodeVariables = (nodeData: any, values: any): Suggestion[] => { + if (!nodeData || !hasOutputNodeTypes.includes(nodeData.type)) return []; + const list: Suggestion[] = []; + const keys = new Set(); + const dataNodeId = nodeData.id; + + processNodeVariables({ + ...nodeData, + config: { + ...nodeData.config, + ...values + } + }, dataNodeId, list, keys); + return nodeData.type === 'var-aggregator' && !nodeData.config.group.defaultValue ? [] : list; +}; + +export const useVariableList = ( + selectedNode: Node | null | undefined, + graphRef: React.MutableRefObject, + chatVariables: ChatVariable[] +) => { + const [trigger, setTrigger] = useState(0); + + const variableList = useMemo(() => { + if (!selectedNode || !graphRef?.current) return []; + + const list: Suggestion[] = []; + const graph = graphRef.current; + const edges = graph.getEdges(); + const nodes = graph.getNodes(); + const keys = new Set(); + + const getPreviousNodes = (nodeId: string, visited = new Set()): string[] => { + if (visited.has(nodeId)) return []; + visited.add(nodeId); + const prev = edges.filter(e => e.getTargetCellId() === nodeId).map(e => e.getSourceCellId()); + return [...prev, ...prev.flatMap(id => getPreviousNodes(id, visited))]; + }; + + const getParentLoop = (nodeId: string): Node | null => { + const node = nodes.find(n => n.id === nodeId); + const cycle = node?.getData()?.cycle; + if (cycle) { + const parent = nodes.find(n => n.getData().id === cycle); + if (parent?.getData()?.type === 'loop' || parent?.getData()?.type === 'iteration') return parent; + } + return null; + }; + + const childIds = nodes.filter(n => n.getData()?.cycle === selectedNode.id).map(n => n.id); + const parentLoop = getParentLoop(selectedNode.id); + const relevantIds = [...getPreviousNodes(selectedNode.id), ...childIds, ...(parentLoop ? getPreviousNodes(parentLoop.id) : [])]; + + chatVariables?.forEach(v => addVariable(list, keys, `CONVERSATION_${v.name}`, v.name, v.type, `conv.${v.name}`, { type: 'CONVERSATION', name: 'CONVERSATION', icon: '' }, { group: 'CONVERSATION' })); + + relevantIds.forEach(id => { + const node = nodes.find(n => n.id === id); + if (node) processNodeVariables(node.getData(), node.getData().id, list, keys); + }); + + if (parentLoop) { + const pd = parentLoop.getData(); + const pid = pd.id; + if (pd.type === 'loop') { + (pd.cycle_vars || []).forEach((cv: any) => addVariable(list, keys, `${pid}_cycle_${cv.name}`, cv.name, cv.type || 'String', `${pid}.${cv.name}`, pd)); + } else if (pd.type === 'iteration' && pd.config.input.defaultValue) { + let itemType = 'object'; + const iv = list.find(v => `{{${v.value}}}` === pd.config.input.defaultValue); + if (iv?.dataType.startsWith('array[')) itemType = iv.dataType.replace(/^array\[(.+)\]$/, '$1'); + addVariable(list, keys, `${pid}_item`, 'item', itemType, `${pid}.item`, pd); + addVariable(list, keys, `${pid}_index`, 'index', 'number', `${pid}.index`, pd); + } + } + + return list; + }, [selectedNode, graphRef, trigger, chatVariables]); + + useEffect(() => { + if (!graphRef?.current) return; + const graph = graphRef.current; + const handler = () => setTrigger(p => p + 1); + const events = ['edge:added', 'edge:removed', 'edge:changed', 'edge:connected', 'node:added', 'node:removed', 'node:change:data']; + events.forEach(e => graph.on(e, handler)); + return () => events.forEach(e => graph.off(e, handler)); + }, [graphRef]); + + return variableList; +}; diff --git a/web/src/views/Workflow/components/Properties/index.tsx b/web/src/views/Workflow/components/Properties/index.tsx index 0ea5e284..6d4571dc 100644 --- a/web/src/views/Workflow/components/Properties/index.tsx +++ b/web/src/views/Workflow/components/Properties/index.tsx @@ -2,7 +2,8 @@ import { type FC, useEffect, useState, useRef, useMemo } from "react"; import clsx from 'clsx' import { useTranslation } from 'react-i18next' import { Graph, Node } from '@antv/x6'; -import { Form, Input, Select, InputNumber, Switch } from 'antd' +import { Form, Input, Select, InputNumber, Switch, Divider, Space } from 'antd' +import { CaretDownOutlined, CaretRightOutlined } from '@ant-design/icons'; import type { NodeConfig, NodeProperties, ChatVariable } from '../../types' import Empty from '@/components/Empty'; @@ -24,7 +25,7 @@ import AssignmentList from './AssignmentList' import ToolConfig from './ToolConfig' import MemoryConfig from './MemoryConfig' import VariableList from './VariableList' -// import { calculateVariableList } from './utils/variableListCalculator' +import { useVariableList, getCurrentNodeVariables } from './hooks/useVariableList' import styles from './properties.module.css' import Editor from "../Editor"; import RbSlider from './RbSlider' @@ -49,12 +50,12 @@ const Properties: FC = ({ const [form] = Form.useForm(); const [configs, setConfigs] = useState>({} as Record) const values = Form.useWatch([], form); - const [graphUpdateTrigger, setGraphUpdateTrigger] = useState(0) const prevMappingNamesRef = useRef([]) const prevTemplateVarsRef = useRef([]) const syncTimeoutRef = useRef(null) const isSyncingRef = useRef(false) const lastSyncSourceRef = useRef<'mapping' | 'template' | null>(null) + const variableList = useVariableList(selectedNode, graphRef, chatVariables) useEffect(() => { if (selectedNode?.getData()?.id) { @@ -62,6 +63,7 @@ const Properties: FC = ({ prevMappingNamesRef.current = [] prevTemplateVarsRef.current = [] lastSyncSourceRef.current = null + setOutputCollapsed(true) } }, [selectedNode?.getData()?.id]) @@ -244,513 +246,7 @@ const Properties: FC = ({ } }, [values, selectedNode, form]) - const variableList = useMemo(() => { - if (!selectedNode || !graphRef?.current) return []; - - const variableList: Suggestion[] = []; - const graph = graphRef.current; - const edges = graph.getEdges(); - const nodes = graph.getNodes(); - const addedKeys = new Set(); - - // Find all connected previous nodes (recursive) - const getAllPreviousNodes = (nodeId: string, visited = new Set()): string[] => { - if (visited.has(nodeId)) return []; - visited.add(nodeId); - - const directPrevious = edges - .filter(edge => edge.getTargetCellId() === nodeId) - .map(edge => edge.getSourceCellId()); - - const allPrevious = [...directPrevious]; - directPrevious.forEach(prevNodeId => { - allPrevious.push(...getAllPreviousNodes(prevNodeId, visited)); - }); - - return allPrevious; - }; - - // Find child nodes (nodes whose cycle field equals current node's ID) - const getChildNodes = (nodeId: string): string[] => { - return nodes - .filter(node => node.getData()?.cycle === nodeId) - .map(node => node.id); - }; - - // Find parent loop/iteration node if current node is a child - const getParentLoopNode = (nodeId: string): Node | null => { - const node = nodes.find(n => n.id === nodeId); - if (!node) return null; - - const nodeData = node.getData(); - const cycle = nodeData?.cycle; - - if (cycle) { - const parentNode = nodes.find(n => n.getData().id === cycle); - if (parentNode) { - const parentData = parentNode.getData(); - if (parentData?.type === 'loop' || parentData?.type === 'iteration') { - return parentNode; - } - } - } - return null; - }; - - const allPreviousNodeIds = getAllPreviousNodes(selectedNode.id); - const childNodeIds = getChildNodes(selectedNode.id); - const parentLoopNode = getParentLoopNode(selectedNode.id); - - console.log('childNodeIds', selectedNode, childNodeIds) - let allRelevantNodeIds = [...allPreviousNodeIds, ...childNodeIds]; - - // Add variables from nodes preceding the parent loop/iteration node if current node is a child - if (parentLoopNode) { - const parentPreviousNodeIds = getAllPreviousNodes(parentLoopNode.id); - allRelevantNodeIds.push(...parentPreviousNodeIds); - } - // Add conversation variables from global config - const conversationVariables = chatVariables || []; - - conversationVariables.forEach((variable: any) => { - const key = `CONVERSATION_${variable.name}`; - if (!addedKeys.has(key)) { - addedKeys.add(key); - variableList.push({ - key, - label: variable.name, - type: 'variable', - dataType: variable.type, - value: `conv.${variable.name}`, - nodeData: { type: 'CONVERSATION', name: 'CONVERSATION', icon: '' }, - group: 'CONVERSATION' - }); - } - }); - - allRelevantNodeIds.forEach(nodeId => { - const node = nodes.find(n => n.id === nodeId); - if (!node) return; - - const nodeData = node.getData(); - const dataNodeId = nodeData.id; // Use the data.id instead of node.id for consistency - - switch(nodeData.type) { - case 'start': - const list = [ - ...(nodeData.config?.variables?.defaultValue ?? []), - ...(nodeData.config?.variables?.value ?? []) - ] - list.forEach((variable: any) => { - if (!variable || !variable?.name) return; - const key = `${dataNodeId}_${variable.name}`; - if (!addedKeys.has(key)) { - addedKeys.add(key); - variableList.push({ - key, - label: variable.name, - type: 'variable', - dataType: variable.type, - value: `${dataNodeId}.${variable.name}`, - nodeData: nodeData, - }); - } - }); - nodeData.config?.variables?.sys?.forEach((variable: any) => { - if (!variable || !variable?.name) return; - const key = `${dataNodeId}_sys_${variable.name}`; - if (!addedKeys.has(key)) { - addedKeys.add(key); - variableList.push({ - key, - label: `sys.${variable.name}`, - type: 'variable', - dataType: variable.type, - value: `sys.${variable.name}`, - nodeData: nodeData, - }); - } - }); - break - case 'llm': - const llmKey = `${dataNodeId}_output`; - if (!addedKeys.has(llmKey)) { - addedKeys.add(llmKey); - variableList.push({ - key: llmKey, - label: 'output', - type: 'variable', - dataType: 'string', - value: `${dataNodeId}.output`, - nodeData: nodeData, - }); - } - break - case 'knowledge-retrieval': - const knowledgeKey = `${dataNodeId}_output`; - if (!addedKeys.has(knowledgeKey)) { - addedKeys.add(knowledgeKey); - variableList.push({ - key: knowledgeKey, - label: 'output', - type: 'variable', - dataType: 'array[object]', - value: `${dataNodeId}.output`, - nodeData: nodeData, - }); - } - break - case 'parameter-extractor': - const successKey = `${dataNodeId}___is_success`; - const reasonKey = `${dataNodeId}___reason`; - if (!addedKeys.has(successKey)) { - addedKeys.add(successKey); - variableList.push({ - key: successKey, - label: '__is_success', - type: 'variable', - dataType: 'number', - value: `${dataNodeId}.__is_success`, - nodeData: nodeData, - }); - } - if (!addedKeys.has(reasonKey)) { - addedKeys.add(reasonKey); - variableList.push({ - key: reasonKey, - label: '__reason', - type: 'variable', - dataType: 'string', - value: `${dataNodeId}.__reason`, - nodeData: nodeData, - }); - } - // Add params variables - const paramsList = nodeData.config?.params?.defaultValue || []; - paramsList.forEach((param: any) => { - if (!param || !param?.name) return; - const paramKey = `${dataNodeId}_${param.name}`; - if (!addedKeys.has(paramKey)) { - addedKeys.add(paramKey); - variableList.push({ - key: paramKey, - label: param.name, - type: 'variable', - dataType: param.type || 'string', - value: `${dataNodeId}.${param.name}`, - nodeData: nodeData, - }); - } - }); - break - case 'var-aggregator': - if (nodeData.config.group.defaultValue) { - // If group=true, add variables from group_variables with key as variable name - const groupVariables = nodeData.config.group_variables.defaultValue || []; - groupVariables?.forEach((groupVar: any) => { - if (!groupVar || !groupVar.key) return; - - // Determine dataType from first variable in the group - let groupDataType = 'string'; - if (groupVar.value && Array.isArray(groupVar.value) && groupVar.value.length > 0) { - const firstVariableValue = groupVar.value[0]; - const firstVariable = variableList.find(v => `{{${v.value}}}` === firstVariableValue); - if (firstVariable) { - groupDataType = firstVariable.dataType; - } - } - - const groupVarKey = `${dataNodeId}_${groupVar.key}`; - if (!addedKeys.has(groupVarKey)) { - addedKeys.add(groupVarKey); - variableList.push({ - key: groupVarKey, - label: groupVar.key, - type: 'variable', - dataType: groupDataType, - value: `${dataNodeId}.${groupVar.key}`, - nodeData: nodeData, - }); - } - }); - } else { - // If group=false, add output variable with type from first group_variable - const groupVariables = nodeData.config.group_variables.defaultValue || []; - const firstVariable = groupVariables[0]; - let outputDataType: string = 'any'; - if (firstVariable) { - const filterVo = [...variableList].find(v => { - return `{{${v.value}}}` === firstVariable - }) - if (filterVo) { - outputDataType = filterVo?.dataType - } - } - - const varAggregatorKey = `${dataNodeId}_output`; - if (!addedKeys.has(varAggregatorKey)) { - addedKeys.add(varAggregatorKey); - variableList.push({ - key: varAggregatorKey, - label: 'output', - type: 'variable', - dataType: outputDataType, - value: `${dataNodeId}.output`, - nodeData: nodeData, - }); - } - } - break - case 'http-request': - const httpBodyKey = `${dataNodeId}_body`; - const httpStatusKey = `${dataNodeId}_status_code`; - if (!addedKeys.has(httpBodyKey)) { - addedKeys.add(httpBodyKey); - variableList.push({ - key: httpBodyKey, - label: 'body', - type: 'variable', - dataType: 'string', - value: `${dataNodeId}.body`, - nodeData: nodeData, - }); - } - if (!addedKeys.has(httpStatusKey)) { - addedKeys.add(httpStatusKey); - variableList.push({ - key: httpStatusKey, - label: 'status_code', - type: 'variable', - dataType: 'number', - value: `${dataNodeId}.status_code`, - nodeData: nodeData, - }); - } - break - case 'jinja-render': - const jinjaOutputKey = `${dataNodeId}_output`; - if (!addedKeys.has(jinjaOutputKey)) { - addedKeys.add(jinjaOutputKey); - variableList.push({ - key: jinjaOutputKey, - label: 'output', - type: 'variable', - dataType: 'string', - value: `${dataNodeId}.output`, - nodeData: nodeData, - }); - } - break - case 'question-classifier': - const classNameKey = `${dataNodeId}_class_name`; - // const outputKey = `${dataNodeId}_output`; - if (!addedKeys.has(classNameKey)) { - addedKeys.add(classNameKey); - variableList.push({ - key: classNameKey, - label: 'class_name', - type: 'variable', - dataType: 'string', - value: `${dataNodeId}.class_name`, - nodeData: nodeData, - }); - } - // if (!addedKeys.has(outputKey)) { - // addedKeys.add(outputKey); - // variableList.push({ - // key: outputKey, - // label: 'output', - // type: 'variable', - // dataType: 'string', - // value: `${dataNodeId}.output`, - // nodeData: nodeData, - // }); - // } - break - case 'iteration': - const iterationOutputKey = `${dataNodeId}_output`; - if (!addedKeys.has(iterationOutputKey)) { - addedKeys.add(iterationOutputKey); - // Get the data type from the output configuration, default to string - const outputConfig = nodeData.output; - let outputDataType = 'string'; - if (outputConfig) { - // Find the selected variable from variableList to get its type - const selectedVariable = variableList.find(v => v.value === outputConfig); - if (selectedVariable) { - outputDataType = selectedVariable.dataType; - } - } - variableList.push({ - key: iterationOutputKey, - label: 'output', - type: 'variable', - dataType: `array[${outputDataType}]`, - value: `${dataNodeId}.output`, - nodeData: nodeData, - }); - } - break - case 'loop': - const cycleVars = nodeData.config.cycle_vars.defaultValue || []; - console.log('cycleVars', cycleVars) - cycleVars.forEach((cycleVar: any) => { - const cycleVarKey = `${dataNodeId}_cycle_${cycleVar.name}`; - if (!addedKeys.has(cycleVarKey)) { - addedKeys.add(cycleVarKey); - if (cycleVar.name && cycleVar.name.trim() !== '') { - variableList.push({ - key: cycleVarKey, - label: cycleVar.name, - type: 'variable', - dataType: cycleVar.type || 'string', - value: `${dataNodeId}.${cycleVar.name}`, - nodeData: nodeData, - }); - } - } - }); - break - case 'tool': - const toolDataKey = `${dataNodeId}_data`; - if (!addedKeys.has(toolDataKey)) { - addedKeys.add(toolDataKey); - variableList.push({ - key: toolDataKey, - label: 'data', - type: 'variable', - dataType: 'string', - value: `${dataNodeId}.data`, - nodeData: nodeData, - }); - } - break - case 'memory-read': - const memoryReadAnswerKey = `${dataNodeId}_answer`; - const memoryReadIntermediateOutputs = `${dataNodeId}_intermediate_outputs`; - if (!addedKeys.has(memoryReadAnswerKey)) { - addedKeys.add(memoryReadAnswerKey); - variableList.push({ - key: memoryReadAnswerKey, - label: 'answer', - type: 'variable', - dataType: 'string', - value: `${dataNodeId}.answer`, - nodeData: nodeData, - }); - } - if (!addedKeys.has(memoryReadIntermediateOutputs)) { - addedKeys.add(memoryReadIntermediateOutputs); - variableList.push({ - key: memoryReadIntermediateOutputs, - label: 'intermediate_outputs', - type: 'variable', - dataType: 'array[object]', - value: `${dataNodeId}.intermediate_outputs`, - nodeData: nodeData, - }); - } - break - } - }); - - - // Add parent loop/iteration node variables if current node is a child - if (parentLoopNode) { - const parentData = parentLoopNode.getData(); - const parentNodeId = parentLoopNode.getData().id; - - if (parentData.type === 'loop') { - const cycleVars = parentData.cycle_vars || []; - cycleVars.forEach((cycleVar: any) => { - const key = `${parentNodeId}_cycle_${cycleVar.name}`; - if (!addedKeys.has(key)) { - addedKeys.add(key); - variableList.push({ - key, - label: cycleVar.name, - type: 'variable', - dataType: cycleVar.type || 'String', - value: `${parentNodeId}.${cycleVar.name}`, - nodeData: parentData, - }); - } - }); - } else if (parentData.type === 'iteration') { - // Add item and index variables for iteration parent only if input has value - if (parentData.config.input.defaultValue) { - const itemKey = `${parentNodeId}_item`; - const indexKey = `${parentNodeId}_index`; - - // Determine item dataType from input variable - let itemDataType = 'object'; - const inputVariable = variableList.find(v => `{{${v.value}}}` === parentData.config.input.defaultValue); - console.log('itemDataType defaultValue', parentData.config.input.defaultValue, variableList, inputVariable) - if (inputVariable && inputVariable.dataType.startsWith('array[')) { - itemDataType = inputVariable.dataType.replace(/^array\[(.+)\]$/, '$1'); - console.log('itemDataType', itemDataType) - } - - - if (!addedKeys.has(itemKey)) { - addedKeys.add(itemKey); - variableList.push({ - key: itemKey, - label: 'item', - type: 'variable', - dataType: itemDataType, - value: `${parentNodeId}.item`, - nodeData: parentData, - }); - } - - if (!addedKeys.has(indexKey)) { - addedKeys.add(indexKey); - variableList.push({ - key: indexKey, - label: 'index', - type: 'variable', - dataType: 'number', - value: `${parentNodeId}.index`, - nodeData: parentData, - }); - } - } - } - } - - return variableList; - }, [selectedNode, graphRef, graphUpdateTrigger, chatVariables]); - - // Trigger variableList update when graph edges or nodes change - useEffect(() => { - if (!graphRef?.current) return; - - const graph = graphRef.current; - const handleGraphChange = () => { - console.log('handleGraphChange') - // Force variableList recalculation by updating trigger - setGraphUpdateTrigger(prev => prev + 1); - }; - - // Listen to graph changes - graph.on('edge:added', handleGraphChange); - graph.on('edge:removed', handleGraphChange); - graph.on('edge:changed', handleGraphChange); - graph.on('node:added', handleGraphChange); - graph.on('node:removed', handleGraphChange); - graph.on('node:change:data', handleGraphChange); - - return () => { - graph.off('edge:added', handleGraphChange); - graph.off('edge:removed', handleGraphChange); - graph.off('edge:changed', handleGraphChange); - graph.off('node:added', handleGraphChange); - graph.off('node:removed', handleGraphChange); - graph.off('node:change:data', handleGraphChange); - }; - }, [graphRef]); // Filter out boolean type variables for loop and llm nodes const getFilteredVariableList = (nodeType?: string, key?: string) => { @@ -994,324 +490,353 @@ const Properties: FC = ({ // const defaultVariableList = calculateVariableList(selectedNode as Node, graphRef, workflowConfig ) console.log('values', values) - console.log('variableList', variableList) + + const currentNodeVariables = useMemo(() => { + if (!selectedNode) return [] + return getCurrentNodeVariables(selectedNode?.getData(), values) + }, [selectedNode?.getData(), values]) + + const [outputCollapsed, setOutputCollapsed] = useState(true) + const handleToggle = () => { + setOutputCollapsed((prev: boolean) => !prev) + } + console.log('variableList', variableList, currentNodeVariables) return (
{t('workflow.nodeProperties')}
{!selectedNode ? - :
- - { - updateNodeLabel(e.target.value); - }} + :
+ + + { + updateNodeLabel(e.target.value); + }} + /> + + + + + + {selectedNode?.data?.type === 'http-request' + ? - - - - - - {selectedNode?.data?.type === 'http-request' - ? - : selectedNode?.data?.type === 'tool' - ? - : configs && Object.keys(configs).length > 0 && Object.keys(configs).map((key) => { - const config = configs[key] || {} + : selectedNode?.data?.type === 'tool' + ? + : configs && Object.keys(configs).length > 0 && Object.keys(configs).map((key) => { + const config = configs[key] || {} - if (config.dependsOn && (values as any)?.[config.dependsOn as string] !== config.dependsOnValue) { - return null - } - - if (selectedNode?.data?.type === 'start' && key === 'variables' && config.type === 'define') { - return ( - - - - ) - } - - if (selectedNode?.data?.type === 'llm' && key === 'messages' && config.type === 'define') { - // 为llm节点且isArray=true时添加context变量支持 - let contextVariableList = [...getFilteredVariableList('llm')]; - const isArrayMode = config.isArray !== false; // 默认为true - - if (isArrayMode) { - const contextKey = `${selectedNode.id}_context`; - const hasContextVariable = contextVariableList.some(v => v.key === contextKey); - - if (!hasContextVariable) { - contextVariableList.unshift({ - key: contextKey, - label: 'context', - type: 'variable', - dataType: 'String', - value: `context`, - nodeData: selectedNode.getData(), - isContext: true, - }); - } - } - return ( - - variable.nodeData?.type !== 'knowledge-retrieval')} - parentName={key} - placeholder={t(config.placeholder || 'common.pleaseSelect')} - size="small" - /> - - ) - } - if (config.type === 'define') { - return null - } - - if (config.type === 'knowledge') { - return ( - - - - ) - } - - if (config.type === 'messageEditor') { - return ( - - - - ) - } - - if (config.type === 'paramList') { - return ( - - - - - ) - } - if (config.type === 'groupVariableList') { - return ( - - - - ) - } - if (config.type === 'caseList') { - return ( - - - - ) - } - - if (config.type === 'mappingList') { - return ( - - - - - ) - } - if (config.type === 'cycleVarsList') { - return ( - - - - ) - } - if (config.type === 'assignmentList') { - return ( - - { - if (config.filterLoopIterationVars) { - const loopIterationVars: Suggestion[] = []; - - return [...getFilteredVariableList(selectedNode?.data?.type, key), ...loopIterationVars]; - } - return getFilteredVariableList(selectedNode?.data?.type, key); - })() - } - /> - - ) - } - if (config.type === 'memoryConfig') { - return ( - - - - ) - } - if (config.type === 'conditionList') { - return ( - - { - const cycleVars = values?.cycle_vars || []; - const cycleVarSuggestions: Suggestion[] = cycleVars.filter(vo => vo.name && vo.name.trim() !== '').map((cycleVar: any) => ({ - key: `${selectedNode.id}_cycle_${cycleVar.name}`, - label: cycleVar.name, - type: 'variable', - dataType: cycleVar.type || 'String', - value: `${selectedNode.getData().id}.${cycleVar.name}`, - nodeData: selectedNode.getData(), - })); - - return [...getFilteredVariableList(selectedNode?.data?.type, key), ...cycleVarSuggestions]; - })()} - selectedNode={selectedNode} - graphRef={graphRef} - addBtnText={t('workflow.config.addCase')} - /> - - ) - } + if (config.dependsOn && (values as any)?.[config.dependsOn as string] !== config.dependsOnValue) { + return null + } + if (selectedNode?.data?.type === 'start' && key === 'variables' && config.type === 'define') { return ( - {t(`workflow.config.${selectedNode?.data?.type}.${key}`)} : t(`workflow.config.${selectedNode?.data?.type}.${key}`)} - layout={config.type === 'switch' ? 'horizontal' : 'vertical'} - className={key === 'parallel_count' ? 'rb:-mt-3! rb:leading-3.5!' : ''} - > - {config.type === 'input' - ? - : config.type === 'textarea' - ? - : config.type === 'select' - ? + : config.type === 'textarea' + ? + : config.type === 'select' + ? - {hasAll && ({allTitle || t('common.all')})} - {(format ? format(options) : options)?.map(option => ( + {hasAll && {allTitle || t('common.all')}} + {displayOptions.map((option) => ( {String(option[labelKey])} ))} ); -} +}; + export default CustomSelect; \ No newline at end of file From cd1a50a1d1382a209f8ad6c8d9bfac3f9e75cb1a Mon Sep 17 00:00:00 2001 From: zhaoying Date: Tue, 20 Jan 2026 10:21:00 +0800 Subject: [PATCH 09/12] fix(web): node cannot be connected to itself --- web/src/views/Workflow/hooks/useWorkflowGraph.ts | 3 +++ 1 file changed, 3 insertions(+) diff --git a/web/src/views/Workflow/hooks/useWorkflowGraph.ts b/web/src/views/Workflow/hooks/useWorkflowGraph.ts index a7ebb29a..615cd3e5 100644 --- a/web/src/views/Workflow/hooks/useWorkflowGraph.ts +++ b/web/src/views/Workflow/hooks/useWorkflowGraph.ts @@ -729,6 +729,9 @@ export const useWorkflowGraph = ({ validateConnection({ sourceCell, targetCell, targetMagnet }) { if (!targetMagnet) return false; + // 节点不能与自己连线 + if (sourceCell?.id === targetCell?.id) return false; + const sourceType = sourceCell?.getData()?.type; const targetType = targetCell?.getData()?.type; From 642587fc9714fc16a20494054a30bfc26986c444 Mon Sep 17 00:00:00 2001 From: lixinyue11 <94037597+lixinyue11@users.noreply.github.com> Date: Tue, 20 Jan 2026 10:36:30 +0800 Subject: [PATCH 10/12] Fix/memory mcp2 1 (#145) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * 去掉MCP框架,重构 * 去掉MCP框架,重构 * 去掉MCP框架,重构 * 去掉MCP框架,重构 * 去掉MCP框架,重构 * 去掉MCP框架,重构 * 去掉MCP框架,重构 * feat(celery): add comprehensive logging to worker and write task - Initialize logging system in Celery worker entry point with LoggingConfig - Add logger instance and startup message to celery_worker.py - Reorganize imports in tasks.py for better readability and consistency - Add detailed logging to write_message_task for debugging and monitoring - Log task start with group_id, config_id, and storage_type parameters - Log service execution and completion status with results - Add exception handling with error logging and stack trace capture - Log task completion time and Celery task ID for performance tracking - Improves observability and troubleshooting of async task execution * 去掉MCP框架,重构 * 去掉MCP框架,重构 --------- Co-authored-by: Ke Sun --- api/app/celery_worker.py | 6 + api/app/core/config.py | 1 + api/app/core/memory/agent/__init__.py | 0 .../memory/agent/langgraph_graph/__init__.py | 16 - .../agent/langgraph_graph/nodes/__init__.py | 8 +- .../agent/langgraph_graph/nodes/data_nodes.py | 16 + .../agent/langgraph_graph/nodes/input_node.py | 150 --- .../langgraph_graph/nodes/problem_nodes.py | 237 +++++ .../langgraph_graph/nodes/retrieve_nodes.py | 417 ++++++++ .../langgraph_graph/nodes/summary_nodes.py | 303 ++++++ .../agent/langgraph_graph/nodes/tool_node.py | 234 ----- .../nodes/verification_nodes.py | 85 ++ .../langgraph_graph/nodes/write_nodes.py | 50 + .../agent/langgraph_graph/read_graph.py | 612 ++++-------- .../agent/langgraph_graph/routing/__init__.py | 13 - .../agent/langgraph_graph/routing/routers.py | 149 +-- .../agent/langgraph_graph/state/__init__.py | 13 - .../agent/langgraph_graph/state/extractors.py | 179 ---- .../agent/langgraph_graph/tools/tool.py | 320 +++++++ .../agent/langgraph_graph/write_graph.py | 103 +- .../core/memory/agent/mcp_server/__init__.py | 28 - .../memory/agent/mcp_server/mcp_instance.py | 11 - .../core/memory/agent/mcp_server/server.py | 159 ---- .../memory/agent/mcp_server/tools/__init__.py | 27 - .../agent/mcp_server/tools/data_tools.py | 155 --- .../agent/mcp_server/tools/problem_tools.py | 304 ------ .../agent/mcp_server/tools/retrieval_tools.py | 294 ------ .../agent/mcp_server/tools/summary_tools.py | 640 ------------- .../mcp_server/tools/verification_tools.py | 174 ---- .../agent/{mcp_server => }/models/__init__.py | 0 .../{mcp_server => }/models/problem_models.py | 0 .../models/retrieval_models.py | 0 .../{mcp_server => }/models/summary_models.py | 0 .../models/verification_models.py | 0 .../memory/agent/multimodal/oss_picture.py | 114 --- .../memory/agent/multimodal/speech_model.py | 121 --- .../{mcp_server => }/services/__init__.py | 0 .../agent/services/optimized_llm_service.py | 277 ++++++ .../services/parameter_builder.py | 22 +- .../services/search_service.py | 66 +- .../services/session_service.py | 0 .../services/template_service.py | 16 +- api/app/core/memory/agent/utils/__init__.py | 7 - .../memory/agent/utils/llm_client_pool.py | 56 ++ api/app/core/memory/agent/utils/llm_tools.py | 211 ++-- api/app/core/memory/agent/utils/mcp_tools.py | 33 - .../core/memory/agent/utils/messages_tool.py | 260 ----- .../core/memory/agent/utils/messages_tools.py | 194 ++++ api/app/core/memory/agent/utils/model_tool.py | 38 - api/app/core/memory/agent/utils/multimodal.py | 131 --- .../memory/agent/utils/performance_monitor.py | 56 ++ ...Problem_Extension_prompt_simplified.jinja2 | 81 ++ .../prompt/Retrieve_Summary_prompt.jinja2 | 3 - .../utils/prompt/split_verify_prompt.jinja2 | 4 +- .../core/memory/agent/utils/session_tools.py | 169 ++++ .../core/memory/agent/utils/template_tools.py | 117 +++ .../memory/agent/utils/type_classifier.py | 3 +- .../memory/agent/utils/write_to_database.py | 49 - .../core/memory/agent/utils/write_tools.py | 12 +- api/app/services/memory_agent_service.py | 899 ++++++------------ api/app/tasks.py | 23 +- api/docker-compose.yml | 47 +- 62 files changed, 3128 insertions(+), 4585 deletions(-) delete mode 100644 api/app/core/memory/agent/__init__.py delete mode 100644 api/app/core/memory/agent/langgraph_graph/__init__.py create mode 100644 api/app/core/memory/agent/langgraph_graph/nodes/data_nodes.py delete mode 100644 api/app/core/memory/agent/langgraph_graph/nodes/input_node.py create mode 100644 api/app/core/memory/agent/langgraph_graph/nodes/problem_nodes.py create mode 100644 api/app/core/memory/agent/langgraph_graph/nodes/retrieve_nodes.py create mode 100644 api/app/core/memory/agent/langgraph_graph/nodes/summary_nodes.py delete mode 100644 api/app/core/memory/agent/langgraph_graph/nodes/tool_node.py create mode 100644 api/app/core/memory/agent/langgraph_graph/nodes/verification_nodes.py create mode 100644 api/app/core/memory/agent/langgraph_graph/nodes/write_nodes.py delete mode 100644 api/app/core/memory/agent/langgraph_graph/routing/__init__.py delete mode 100644 api/app/core/memory/agent/langgraph_graph/state/__init__.py delete mode 100644 api/app/core/memory/agent/langgraph_graph/state/extractors.py create mode 100644 api/app/core/memory/agent/langgraph_graph/tools/tool.py delete mode 100644 api/app/core/memory/agent/mcp_server/__init__.py delete mode 100644 api/app/core/memory/agent/mcp_server/mcp_instance.py delete mode 100644 api/app/core/memory/agent/mcp_server/server.py delete mode 100644 api/app/core/memory/agent/mcp_server/tools/__init__.py delete mode 100644 api/app/core/memory/agent/mcp_server/tools/data_tools.py delete mode 100644 api/app/core/memory/agent/mcp_server/tools/problem_tools.py delete mode 100644 api/app/core/memory/agent/mcp_server/tools/retrieval_tools.py delete mode 100644 api/app/core/memory/agent/mcp_server/tools/summary_tools.py delete mode 100644 api/app/core/memory/agent/mcp_server/tools/verification_tools.py rename api/app/core/memory/agent/{mcp_server => }/models/__init__.py (100%) rename api/app/core/memory/agent/{mcp_server => }/models/problem_models.py (100%) rename api/app/core/memory/agent/{mcp_server => }/models/retrieval_models.py (100%) rename api/app/core/memory/agent/{mcp_server => }/models/summary_models.py (100%) rename api/app/core/memory/agent/{mcp_server => }/models/verification_models.py (100%) delete mode 100644 api/app/core/memory/agent/multimodal/oss_picture.py delete mode 100644 api/app/core/memory/agent/multimodal/speech_model.py rename api/app/core/memory/agent/{mcp_server => }/services/__init__.py (100%) create mode 100644 api/app/core/memory/agent/services/optimized_llm_service.py rename api/app/core/memory/agent/{mcp_server => }/services/parameter_builder.py (87%) rename api/app/core/memory/agent/{mcp_server => }/services/search_service.py (75%) rename api/app/core/memory/agent/{mcp_server => }/services/session_service.py (100%) rename api/app/core/memory/agent/{mcp_server => }/services/template_service.py (94%) delete mode 100644 api/app/core/memory/agent/utils/__init__.py create mode 100644 api/app/core/memory/agent/utils/llm_client_pool.py delete mode 100644 api/app/core/memory/agent/utils/mcp_tools.py delete mode 100644 api/app/core/memory/agent/utils/messages_tool.py create mode 100644 api/app/core/memory/agent/utils/messages_tools.py delete mode 100644 api/app/core/memory/agent/utils/model_tool.py delete mode 100644 api/app/core/memory/agent/utils/multimodal.py create mode 100644 api/app/core/memory/agent/utils/performance_monitor.py create mode 100644 api/app/core/memory/agent/utils/prompt/Problem_Extension_prompt_simplified.jinja2 create mode 100644 api/app/core/memory/agent/utils/session_tools.py create mode 100644 api/app/core/memory/agent/utils/template_tools.py delete mode 100644 api/app/core/memory/agent/utils/write_to_database.py diff --git a/api/app/celery_worker.py b/api/app/celery_worker.py index baecdb3d..7d3ee686 100644 --- a/api/app/celery_worker.py +++ b/api/app/celery_worker.py @@ -3,6 +3,12 @@ Celery Worker 入口点 用于启动 Celery Worker: celery -A app.celery_worker worker --loglevel=info """ from app.celery_app import celery_app +from app.core.logging_config import LoggingConfig, get_logger + +# Initialize logging system for Celery worker +LoggingConfig.setup_logging() +logger = get_logger(__name__) +logger.info("Celery worker logging initialized") # 导入任务模块以注册任务 import app.tasks diff --git a/api/app/core/config.py b/api/app/core/config.py index 01983457..9600b551 100644 --- a/api/app/core/config.py +++ b/api/app/core/config.py @@ -147,6 +147,7 @@ class Settings: # Celery configuration (internal) CELERY_BROKER: int = int(os.getenv("CELERY_BROKER", "1")) CELERY_BACKEND: int = int(os.getenv("CELERY_BACKEND", "2")) + REFLECTION_INTERVAL_SECONDS: float = float(os.getenv("REFLECTION_INTERVAL_SECONDS", "300")) HEALTH_CHECK_SECONDS: float = float(os.getenv("HEALTH_CHECK_SECONDS", "600")) MEMORY_INCREMENT_INTERVAL_HOURS: float = float(os.getenv("MEMORY_INCREMENT_INTERVAL_HOURS", "24")) diff --git a/api/app/core/memory/agent/__init__.py b/api/app/core/memory/agent/__init__.py deleted file mode 100644 index e69de29b..00000000 diff --git a/api/app/core/memory/agent/langgraph_graph/__init__.py b/api/app/core/memory/agent/langgraph_graph/__init__.py deleted file mode 100644 index a0596e38..00000000 --- a/api/app/core/memory/agent/langgraph_graph/__init__.py +++ /dev/null @@ -1,16 +0,0 @@ -""" -LangGraph Graph package for memory agent. - -This package provides the LangGraph workflow orchestrator with modular -node implementations, routing logic, and state management. - -Package structure: -- read_graph: Main graph factory for read operations -- write_graph: Main graph factory for write operations -- nodes: LangGraph node implementations -- routing: State routing logic -- state: State management utilities -""" -from app.core.memory.agent.langgraph_graph.read_graph import make_read_graph - -__all__ = ['make_read_graph'] \ No newline at end of file diff --git a/api/app/core/memory/agent/langgraph_graph/nodes/__init__.py b/api/app/core/memory/agent/langgraph_graph/nodes/__init__.py index 4e808919..231a167c 100644 --- a/api/app/core/memory/agent/langgraph_graph/nodes/__init__.py +++ b/api/app/core/memory/agent/langgraph_graph/nodes/__init__.py @@ -4,7 +4,7 @@ LangGraph node implementations. This module contains custom node implementations for the LangGraph workflow. """ -from app.core.memory.agent.langgraph_graph.nodes.tool_node import ToolExecutionNode -from app.core.memory.agent.langgraph_graph.nodes.input_node import create_input_message - -__all__ = ["ToolExecutionNode", "create_input_message"] +# from app.core.memory.agent.langgraph_graph.nodes.tool_node import ToolExecutionNode +# from app.core.memory.agent.langgraph_graph.nodes.input_node import create_input_message +# +# __all__ = ["ToolExecutionNode", "create_input_message"] diff --git a/api/app/core/memory/agent/langgraph_graph/nodes/data_nodes.py b/api/app/core/memory/agent/langgraph_graph/nodes/data_nodes.py new file mode 100644 index 00000000..6595a2ce --- /dev/null +++ b/api/app/core/memory/agent/langgraph_graph/nodes/data_nodes.py @@ -0,0 +1,16 @@ +from app.core.memory.agent.utils.llm_tools import ReadState, WriteState + + +def content_input_node(state: ReadState) -> ReadState: + """开始节点 - 提取内容并保持状态信息""" + + content = state['messages'][0].content if state.get('messages') else '' + # 返回内容并保持所有状态信息 + return {"data": content} + +def content_input_write(state: WriteState) -> WriteState: + """开始节点 - 提取内容并保持状态信息""" + + content = state['messages'][0].content if state.get('messages') else '' + # 返回内容并保持所有状态信息 + return {"data": content} \ No newline at end of file diff --git a/api/app/core/memory/agent/langgraph_graph/nodes/input_node.py b/api/app/core/memory/agent/langgraph_graph/nodes/input_node.py deleted file mode 100644 index 3eed497f..00000000 --- a/api/app/core/memory/agent/langgraph_graph/nodes/input_node.py +++ /dev/null @@ -1,150 +0,0 @@ -""" -Input node for LangGraph workflow entry point. - -This module provides the create_input_message function which processes initial -user input with multimodal support and creates the first tool call message. -""" - -import logging -import re -import uuid -from datetime import datetime -from typing import Any, Dict - -from app.core.memory.agent.utils.multimodal import MultimodalProcessor -from app.schemas.memory_config_schema import MemoryConfig -from langchain_core.messages import AIMessage - -logger = logging.getLogger(__name__) - - -async def create_input_message( - state: Dict[str, Any], - tool_name: str, - session_id: str, - search_switch: str, - apply_id: str, - group_id: str, - multimodal_processor: MultimodalProcessor, - memory_config: MemoryConfig, -) -> Dict[str, Any]: - """ - Create initial tool call message from user input. - - This function: - 1. Extracts the last message content from state - 2. Processes multimodal inputs (images/audio) using the multimodal processor - 3. Generates a unique message ID - 4. Extracts namespace from session_id - 5. Handles verified_data extraction for backward compatibility - 6. Returns AIMessage with complete tool_calls structure - - Args: - state: LangGraph state dictionary containing messages - tool_name: Name of the tool to invoke (typically "Split_The_Problem") - session_id: Session identifier (format: "call_id_{namespace}") - search_switch: Search routing parameter - apply_id: Application identifier - group_id: Group identifier - multimodal_processor: Processor for handling image/audio inputs - memory_config: MemoryConfig object containing all configuration - - Returns: - State update with AIMessage containing tool_call - - Examples: - >>> state = {"messages": [HumanMessage(content="What is AI?")]} - >>> result = await create_input_message( - ... state, "Split_The_Problem", "call_id_user123", "0", "app1", "group1", processor, config - ... ) - >>> result["messages"][0].tool_calls[0]["name"] - 'Split_The_Problem' - """ - messages = state.get("messages", []) - - # Extract last message content - if messages: - last_message = messages[-1].content if hasattr(messages[-1], 'content') else str(messages[-1]) - else: - logger.warning("[create_input_message] No messages in state, using empty string") - last_message = "" - - logger.debug(f"[create_input_message] Original input: {last_message[:100]}...") - - # Process multimodal input (images/audio) - try: - processed_content = await multimodal_processor.process_input(last_message) - if processed_content != last_message: - logger.info( - f"[create_input_message] Multimodal processing converted input " - f"from {len(last_message)} to {len(processed_content)} chars" - ) - last_message = processed_content - except Exception as e: - logger.error( - f"[create_input_message] Multimodal processing failed: {e}", - exc_info=True - ) - # Continue with original content - - # Generate unique message ID - uuid_str = uuid.uuid4() - time_str = datetime.now().strftime("%Y-%m-%d %H:%M:%S") - - # Extract namespace from session_id - # Expected format: "call_id_{namespace}" or similar - try: - namespace = str(session_id).split('_id_')[1] - except (IndexError, AttributeError): - logger.warning( - f"[create_input_message] Could not extract namespace from session_id: {session_id}" - ) - namespace = "unknown" - - # Handle verified_data extraction (backward compatibility) - # This regex-based extraction is kept for compatibility with existing data formats - if 'verified_data' in str(last_message): - try: - messages_last = str(last_message).replace('\\n', '').replace('\\', '') - query_match = re.findall(r'"query": "(.*?)",', messages_last) - if query_match: - last_message = query_match[0] - logger.debug( - f"[create_input_message] Extracted query from verified_data: {last_message}" - ) - except Exception as e: - logger.warning( - f"[create_input_message] Failed to extract query from verified_data: {e}" - ) - - # Construct tool call message - tool_call_id = f"{session_id}_{uuid_str}" - - logger.info( - f"[create_input_message] Creating tool call for '{tool_name}' " - f"with ID: {tool_call_id}" - ) - - # Build tool arguments - tool_args = { - "sentence": last_message, - "sessionid": session_id, - "messages_id": str(uuid_str), - "search_switch": search_switch, - "apply_id": apply_id, - "group_id": group_id, - "memory_config": memory_config, - } - - return { - "messages": [ - AIMessage( - content="", - tool_calls=[{ - "name": tool_name, - "args": tool_args, - "id": tool_call_id - }] - ) - ] - } diff --git a/api/app/core/memory/agent/langgraph_graph/nodes/problem_nodes.py b/api/app/core/memory/agent/langgraph_graph/nodes/problem_nodes.py new file mode 100644 index 00000000..0c68a47e --- /dev/null +++ b/api/app/core/memory/agent/langgraph_graph/nodes/problem_nodes.py @@ -0,0 +1,237 @@ +import json +import time +from app.core.logging_config import get_agent_logger +from app.db import get_db + +from app.core.memory.agent.models.problem_models import ProblemExtensionResponse +from app.core.memory.agent.utils.llm_tools import ( + PROJECT_ROOT_, + ReadState, +) +from app.core.memory.agent.utils.redis_tool import store +from app.core.memory.agent.utils.session_tools import SessionService +from app.core.memory.agent.utils.template_tools import TemplateService +from app.core.memory.agent.services.optimized_llm_service import LLMServiceMixin + +template_root = PROJECT_ROOT_ + '/agent/utils/prompt' +db_session = next(get_db()) +logger = get_agent_logger(__name__) + +class ProblemNodeService(LLMServiceMixin): + """问题处理节点服务类""" + + def __init__(self): + super().__init__() + self.template_service = TemplateService(template_root) + +# 创建全局服务实例 +problem_service = ProblemNodeService() + +async def Split_The_Problem(state: ReadState) -> ReadState: + """问题分解节点""" + # 从状态中获取数据 + content = state.get('data', '') + group_id = state.get('group_id', '') + memory_config = state.get('memory_config', None) + + history = await SessionService(store).get_history(group_id, group_id, group_id) + system_prompt = await problem_service.template_service.render_template( + template_name='problem_breakdown_prompt.jinja2', + operation_name='split_the_problem', + history=history, + sentence=content + ) + + try: + # 使用优化的LLM服务 + structured = await problem_service.call_llm_structured( + state=state, + db_session=db_session, + system_prompt=system_prompt, + response_model=ProblemExtensionResponse, + fallback_value=[] + ) + + # 添加更详细的日志记录 + logger.info(f"Split_The_Problem: 开始处理问题分解,内容长度: {len(content)}") + + # 验证结构化响应 + if not structured or not hasattr(structured, 'root'): + logger.warning("Split_The_Problem: 结构化响应为空或格式不正确") + split_result = json.dumps([], ensure_ascii=False) + elif not structured.root: + logger.warning("Split_The_Problem: 结构化响应的root为空") + split_result = json.dumps([], ensure_ascii=False) + else: + split_result = json.dumps( + [item.model_dump() for item in structured.root], + ensure_ascii=False + ) + + split_result_dict = [] + for index, item in enumerate(json.loads(split_result)): + split_data = { + "id": f"Q{index+1}", + "question": item['extended_question'], + "type": item['type'], + "reason": item['reason'] + } + split_result_dict.append(split_data) + + logger.info(f"Split_The_Problem: 成功生成 {len(structured.root) if structured.root else 0} 个分解项") + + result = { + "context": split_result, + "original": content, + "_intermediate": { + "type": "problem_split", + "title": "问题拆分", + "data": split_result_dict, + "original_query": content + } + } + + except Exception as e: + logger.error( + f"Split_The_Problem failed: {e}", + exc_info=True + ) + + # 提供更详细的错误信息 + error_details = { + "error_type": type(e).__name__, + "error_message": str(e), + "content_length": len(content), + "llm_model_id": memory_config.llm_model_id if memory_config else None + } + + logger.error(f"Split_The_Problem error details: {error_details}") + + # 创建默认的空结果 + result = { + "context": json.dumps([], ensure_ascii=False), + "original": content, + "error": str(e), + "_intermediate": { + "type": "problem_split", + "title": "问题拆分", + "data": [], + "original_query": content, + "error": error_details + } + } + + # 返回更新后的状态,包含spit_context字段 + return {"spit_data": result} + +async def Problem_Extension(state: ReadState) -> ReadState: + """问题扩展节点""" + # 获取原始数据和分解结果 + start = time.time() + content = state.get('data', '') + data = state.get('spit_data', '')['context'] + group_id = state.get('group_id', '') + storage_type = state.get('storage_type', '') + user_rag_memory_id = state.get('user_rag_memory_id', '') + memory_config = state.get('memory_config', None) + + databasets = {} + try: + data = json.loads(data) + for i in data: + databasets[i['extended_question']] = i['type'] + except (json.JSONDecodeError, KeyError, TypeError) as e: + logger.error(f"Problem_Extension: 数据解析失败: {e}") + # 使用空字典作为fallback + databasets = {} + data = [] + + history = await SessionService(store).get_history(group_id, group_id, group_id) + system_prompt = await problem_service.template_service.render_template( + template_name='Problem_Extension_prompt.jinja2', + operation_name='problem_extension', + history=history, + questions=databasets + ) + + try: + # 使用优化的LLM服务 + response_content = await problem_service.call_llm_structured( + state=state, + db_session=db_session, + system_prompt=system_prompt, + response_model=ProblemExtensionResponse, + fallback_value=[] + ) + + logger.info(f"Problem_Extension: 开始处理问题扩展,问题数量: {len(databasets)}") + + # 验证结构化响应 + if not response_content or not hasattr(response_content, 'root'): + logger.warning("Problem_Extension: 结构化响应为空或格式不正确") + aggregated_dict = {} + elif not response_content.root: + logger.warning("Problem_Extension: 结构化响应的root为空") + aggregated_dict = {} + else: + # Aggregate results by original question + aggregated_dict = {} + for item in response_content.root: + try: + key = getattr(item, "original_question", None) or ( + item.get("original_question") if isinstance(item, dict) else None + ) + value = getattr(item, "extended_question", None) or ( + item.get("extended_question") if isinstance(item, dict) else None + ) + if not key or not value: + logger.warning(f"Problem_Extension: 跳过无效项: key={key}, value={value}") + continue + aggregated_dict.setdefault(key, []).append(value) + except Exception as item_error: + logger.warning(f"Problem_Extension: 处理项目时出错: {item_error}") + continue + + logger.info(f"Problem_Extension: 成功生成 {len(aggregated_dict)} 个扩展问题组") + + except Exception as e: + logger.error( + f"LLM call failed for Problem_Extension: {e}", + exc_info=True + ) + + # 提供更详细的错误信息 + error_details = { + "error_type": type(e).__name__, + "error_message": str(e), + "questions_count": len(databasets), + "llm_model_id": memory_config.llm_model_id if memory_config else None + } + + logger.error(f"Problem_Extension error details: {error_details}") + aggregated_dict = {} + + logger.info("Problem extension") + logger.info(f"Problem extension result: {aggregated_dict}") + + # Emit intermediate output for frontend + print(time.time() - start) + result = { + "context": aggregated_dict, + "original": data, + "storage_type": storage_type, + "user_rag_memory_id": user_rag_memory_id, + "_intermediate": { + "type": "problem_extension", + "title": "问题扩展", + "data": aggregated_dict, + "original_query": content, + "storage_type": storage_type, + "user_rag_memory_id": user_rag_memory_id + } + } + + return {"problem_extension": result} + + + diff --git a/api/app/core/memory/agent/langgraph_graph/nodes/retrieve_nodes.py b/api/app/core/memory/agent/langgraph_graph/nodes/retrieve_nodes.py new file mode 100644 index 00000000..14f8fa8b --- /dev/null +++ b/api/app/core/memory/agent/langgraph_graph/nodes/retrieve_nodes.py @@ -0,0 +1,417 @@ +# ===== 标准库 ===== +import asyncio +import json +import os + +# ===== 第三方库 ===== +from langchain.agents import create_agent +from langchain_openai import ChatOpenAI +from app.core.logging_config import get_agent_logger +from app.db import get_db, get_db_context + +from app.schemas import model_schema +from app.services.memory_config_service import MemoryConfigService +from app.services.model_service import ModelConfigService + +from app.core.memory.agent.services.search_service import SearchService +from app.core.memory.agent.utils.llm_tools import ( + COUNTState, + ReadState, + deduplicate_entries, + merge_to_key_value_pairs, +) +from app.core.memory.agent.langgraph_graph.tools.tool import ( + create_hybrid_retrieval_tool_sync, + create_time_retrieval_tool, + extract_tool_message_content, +) + +from app.core.rag.nlp.search import knowledge_retrieval + +logger = get_agent_logger(__name__) +db = next(get_db()) + + + +async def rag_config(state): + user_rag_memory_id = state.get('user_rag_memory_id', '') + kb_config = { + "knowledge_bases": [ + { + "kb_id": user_rag_memory_id, + "similarity_threshold": 0.7, + "vector_similarity_weight": 0.5, + "top_k": 10, + "retrieve_type": "participle" + } + ], + "merge_strategy": "weight", + "reranker_id": os.getenv('reranker_id'), + "reranker_top_k": 10 + } + return kb_config +async def rag_knowledge(state,question): + kb_config = await rag_config(state) + group_id = state.get('group_id', '') + user_rag_memory_id=state.get("user_rag_memory_id",'') + retrieve_chunks_result = knowledge_retrieval(question, kb_config, [str(group_id)]) + try: + retrieval_knowledge = [i.page_content for i in retrieve_chunks_result] + clean_content = '\n\n'.join(retrieval_knowledge) + cleaned_query = question + raw_results = clean_content + logger.info(f" Using RAG storage with memory_id={user_rag_memory_id}") + except Exception : + retrieval_knowledge=[] + clean_content = '' + raw_results = '' + cleaned_query = question + logger.info(f"No content retrieved from knowledge base: {user_rag_memory_id}") + return retrieval_knowledge,clean_content,cleaned_query,raw_results + + +async def llm_infomation(state: ReadState) -> ReadState: + memory_config = state.get('memory_config', None) + model_id = memory_config.llm_model_id + tenant_id = memory_config.tenant_id + + # 使用现有的 memory_config 而不是重新查询数据库 + # 或者使用线程安全的数据库访问 + with get_db_context() as db: + result_orm = ModelConfigService.get_model_by_id(db=db, model_id=model_id, tenant_id=tenant_id) + result_pydantic = model_schema.ModelConfig.model_validate(result_orm) + return result_pydantic + + +async def clean_databases(data) -> str: + """ + 简化的数据库搜索结果清理函数 + + Args: + data: 搜索结果数据 + + Returns: + 清理后的内容字符串 + """ + try: + # 解析JSON字符串 + if isinstance(data, str): + try: + data = json.loads(data) + except json.JSONDecodeError: + return data + + if not isinstance(data, dict): + return str(data) + + # 获取结果数据 + # with open("搜索结果.json","w",encoding='utf-8') as f: + # f.write(json.dumps(data, indent=4, ensure_ascii=False)) + results = data.get('results', data) + if not isinstance(results, dict): + return str(results) + + # 收集所有内容 + content_list = [] + + # 处理重排序结果 + reranked = results.get('reranked_results', {}) + if reranked: + for category in ['summaries', 'statements', 'chunks', 'entities']: + items = reranked.get(category, []) + if isinstance(items, list): + content_list.extend(items) + # 处理时间搜索结果 + time_search = results.get('time_search', {}) + if time_search: + if isinstance(time_search, dict): + statements = time_search.get('statements', time_search.get('time_search', [])) + if isinstance(statements, list): + content_list.extend(statements) + elif isinstance(time_search, list): + content_list.extend(time_search) + + # 提取文本内容 + text_parts = [] + for item in content_list: + if isinstance(item, dict): + text = item.get('statement') or item.get('content', '') + if text: + text_parts.append(text) + elif isinstance(item, str): + text_parts.append(item) + + + return '\n'.join(text_parts).strip() + + except Exception as e: + logger.error(f"clean_databases failed: {e}", exc_info=True) + return str(data) + + +async def retrieve_nodes(state: ReadState) -> ReadState: + + ''' + + 模型信息 + ''' + + problem_extension=state.get('problem_extension', '')['context'] + storage_type=state.get('storage_type', '') + user_rag_memory_id=state.get('user_rag_memory_id', '') + group_id=state.get('group_id', '') + memory_config = state.get('memory_config', None) + original=state.get('data', '') + problem_list=[] + for key,values in problem_extension.items(): + for data in values: + problem_list.append(data) + logger.info(f"Retrieve: storage_type={storage_type}, user_rag_memory_id={user_rag_memory_id}") + # 创建异步任务处理单个问题 + async def process_question_nodes(idx, question): + try: + # Prepare search parameters based on storage type + search_params = { + "group_id": group_id, + "question": question, + "return_raw_results": True + } + if storage_type == "rag" and user_rag_memory_id: + retrieval_knowledge, clean_content, cleaned_query, raw_results = await rag_knowledge(state, question) + else: + clean_content, cleaned_query, raw_results = await SearchService().execute_hybrid_search( + **search_params, memory_config=memory_config + ) + + return { + "Query_small": cleaned_query, + "Result_small": clean_content, + "_intermediate": { + "type": "search_result", + "query": cleaned_query, + "raw_results": raw_results, + "index": idx + 1, + "total": len(problem_list) + } + } + + except Exception as e: + logger.error( + f"Retrieve: hybrid_search failed for question '{question}': {e}", + exc_info=True + ) + # Return empty result for this question + return { + "Query_small": question, + "Result_small": "", + "_intermediate": { + "type": "search_result", + "query": question, + "raw_results": [], + "index": idx + 1, + "total": len(problem_list) + } + } + + # 并发处理所有问题 + tasks = [process_question_nodes(idx, question) for idx, question in enumerate(problem_list)] + databases_anser = await asyncio.gather(*tasks) + databases_data = { + "Query": original, + "Expansion_issue": databases_anser + } + + # Collect intermediate outputs before deduplication + intermediate_outputs = [] + for item in databases_anser: + if '_intermediate' in item: + intermediate_outputs.append(item['_intermediate']) + + # Deduplicate and merge results + deduplicated_data = deduplicate_entries(databases_data['Expansion_issue']) + deduplicated_data_merged = merge_to_key_value_pairs( + deduplicated_data, + 'Query_small', + 'Result_small' + ) + + # Restructure for Verify/Retrieve_Summary compatibility + keys, val = [], [] + for item in deduplicated_data_merged: + for items_key, items_value in item.items(): + keys.append(items_key) + val.append(items_value) + + send_verify = [] + for i, j in zip(keys, val, strict=False): + if j!=['']: + send_verify.append({ + "Query_small": i, + "Answer_Small": j + }) + + dup_databases = { + "Query": original, + "Expansion_issue": send_verify, + "_intermediate_outputs": intermediate_outputs # Preserve intermediate outputs + } + + logger.info(f"Collected {len(intermediate_outputs)} intermediate outputs from search results") + return {'retrieve':dup_databases} + + + + +async def retrieve(state: ReadState) -> ReadState: + # 从state中获取group_id + import time + start=time.time() + problem_extension = state.get('problem_extension', '')['context'] + storage_type = state.get('storage_type', '') + user_rag_memory_id = state.get('user_rag_memory_id', '') + group_id = state.get('group_id', '') + memory_config = state.get('memory_config', None) + original = state.get('data', '') + problem_list = [] + for key, values in problem_extension.items(): + for data in values: + problem_list.append(data) + logger.info(f"Retrieve: storage_type={storage_type}, user_rag_memory_id={user_rag_memory_id}") + databases_anser = [] + + async def get_llm_info(): + with get_db_context() as db: # 使用同步数据库上下文管理器 + config_service = MemoryConfigService(db) + return await llm_infomation(state) + llm_config = await get_llm_info() + api_key_obj = llm_config.api_keys[0] + api_key = api_key_obj.api_key + api_base = api_key_obj.api_base + model_name = api_key_obj.model_name + llm = ChatOpenAI( + model=model_name, + api_key=api_key, + base_url=api_base, + temperature=0.2, + ) + + time_retrieval_tool = create_time_retrieval_tool(group_id) + search_params = { "group_id": group_id, "return_raw_results": True } + hybrid_retrieval=create_hybrid_retrieval_tool_sync(memory_config, **search_params) + agent = create_agent( + llm, + tools=[time_retrieval_tool,hybrid_retrieval], + system_prompt=f"我是检索专家,可以根据适合的工具进行检索。当前使用的group_id是: {group_id}" + ) + + # 创建异步任务处理单个问题 + import asyncio + + # 在模块级别定义信号量,限制最大并发数 + SEMAPHORE = asyncio.Semaphore(5) # 限制最多5个并发数据库操作 + + async def process_question(idx, question): + async with SEMAPHORE: # 限制并发 + try: + if storage_type == "rag" and user_rag_memory_id: + retrieval_knowledge, clean_content, cleaned_query, raw_results = await rag_knowledge(state, question) + else: + cleaned_query = question + # 使用 asyncio 在线程池中运行同步的 agent.invoke + import asyncio + response = await asyncio.get_event_loop().run_in_executor( + None, + lambda: agent.invoke({"messages": question}) + ) + tool_results = extract_tool_message_content(response) + if tool_results == None: + raw_results = [] + clean_content = '' + else: + raw_results = tool_results['content'] + clean_content = await clean_databases(raw_results) + + try: + raw_results = raw_results['results'] + except Exception: + raw_results = [] + + return { + "Query_small": cleaned_query, + "Result_small": clean_content, + "_intermediate": { + "type": "search_result", + "query": cleaned_query, + "raw_results": raw_results, + "index": idx + 1, + "total": len(problem_list) + } + } + + except Exception as e: + logger.error( + f"Retrieve: hybrid_search failed for question '{question}': {e}", + exc_info=True + ) + # Return empty result for this question + return { + "Query_small": question, + "Result_small": "", + "_intermediate": { + "type": "search_result", + "query": question, + "raw_results": [], + "index": idx + 1, + "total": len(problem_list) + } + } + + # 并发处理所有问题 + import asyncio + tasks = [process_question(idx, question) for idx, question in enumerate(problem_list)] + databases_anser = await asyncio.gather(*tasks) + databases_data = { + "Query": original, + "Expansion_issue": databases_anser + } + + # Collect intermediate outputs before deduplication + intermediate_outputs = [] + for item in databases_anser: + if '_intermediate' in item: + intermediate_outputs.append(item['_intermediate']) + + # Deduplicate and merge results + deduplicated_data = deduplicate_entries(databases_data['Expansion_issue']) + deduplicated_data_merged = merge_to_key_value_pairs( + deduplicated_data, + 'Query_small', + 'Result_small' + ) + + # Restructure for Verify/Retrieve_Summary compatibility + keys, val = [], [] + for item in deduplicated_data_merged: + for items_key, items_value in item.items(): + keys.append(items_key) + val.append(items_value) + + send_verify = [] + for i, j in zip(keys, val, strict=False): + if j != ['']: + send_verify.append({ + "Query_small": i, + "Answer_Small": j + }) + + dup_databases = { + "Query": original, + "Expansion_issue": send_verify, + "_intermediate_outputs": intermediate_outputs # Preserve intermediate outputs + } + # with open('retrieve_text.json', 'w') as f: + # json.dump(dup_databases, f, indent=4) + logger.info(f"Collected {len(intermediate_outputs)} intermediate outputs from search results") + return {'retrieve': dup_databases} + + diff --git a/api/app/core/memory/agent/langgraph_graph/nodes/summary_nodes.py b/api/app/core/memory/agent/langgraph_graph/nodes/summary_nodes.py new file mode 100644 index 00000000..7b727da5 --- /dev/null +++ b/api/app/core/memory/agent/langgraph_graph/nodes/summary_nodes.py @@ -0,0 +1,303 @@ + + +import time + +from app.core.logging_config import get_agent_logger, log_time +from app.db import get_db + +from app.core.memory.agent.models.summary_models import ( + RetrieveSummaryResponse, + SummaryResponse, +) +from app.core.memory.agent.services.search_service import SearchService +from app.core.memory.agent.utils.llm_tools import ( + PROJECT_ROOT_, + ReadState, +) +from app.core.memory.agent.utils.redis_tool import store +from app.core.memory.agent.utils.session_tools import SessionService +from app.core.memory.agent.utils.template_tools import TemplateService +from app.core.memory.agent.services.optimized_llm_service import LLMServiceMixin + +template_root = PROJECT_ROOT_ + '/agent/utils/prompt' +logger = get_agent_logger(__name__) +db_session = next(get_db()) + +class SummaryNodeService(LLMServiceMixin): + """总结节点服务类""" + + def __init__(self): + super().__init__() + self.template_service = TemplateService(template_root) + +# 创建全局服务实例 +summary_service = SummaryNodeService() + +async def summary_history(state: ReadState) -> ReadState: + group_id = state.get("group_id", '') + history = await SessionService(store).get_history(group_id, group_id, group_id) + return history + +async def summary_llm(state: ReadState, history, retrieve_info, template_name, operation_name, response_model,search_mode) -> str: + """ + 增强的summary_llm函数,包含更好的错误处理和数据验证 + """ + data = state.get("data", '') + + # 构建系统提示词 + if str(search_mode) == "0": + system_prompt = await summary_service.template_service.render_template( + template_name=template_name, + operation_name=operation_name, + data=retrieve_info, + query=data + ) + else: + system_prompt = await summary_service.template_service.render_template( + template_name=template_name, + operation_name=operation_name, + query=data, + history=history, + retrieve_info=retrieve_info + ) + try: + # 使用优化的LLM服务进行结构化输出 + structured = await summary_service.call_llm_structured( + state=state, + db_session=db_session, + system_prompt=system_prompt, + response_model=response_model, + fallback_value=None + ) + # 验证结构化响应 + if structured is None: + logger.warning(f"LLM返回None,使用默认回答") + return "信息不足,无法回答" + + # 根据操作类型提取答案 + if operation_name == "summary": + aimessages = getattr(structured, 'query_answer', None) or "信息不足,无法回答" + else: + # 处理RetrieveSummaryResponse + if hasattr(structured, 'data') and structured.data: + aimessages = getattr(structured.data, 'query_answer', None) or "信息不足,无法回答" + else: + logger.warning(f"结构化响应缺少data字段") + aimessages = "信息不足,无法回答" + + # 验证答案不为空 + if not aimessages or aimessages.strip() == "": + aimessages = "信息不足,无法回答" + + return aimessages + + except Exception as e: + logger.error(f"结构化输出失败: {e}", exc_info=True) + + # 尝试非结构化输出作为fallback + try: + logger.info("尝试非结构化输出作为fallback") + response = await summary_service.call_llm_simple( + state=state, + db_session=db_session, + system_prompt=system_prompt, + fallback_message="信息不足,无法回答" + ) + + if response and response.strip(): + # 简单清理响应 + cleaned_response = response.strip() + # 移除可能的JSON标记 + if cleaned_response.startswith('```'): + lines = cleaned_response.split('\n') + cleaned_response = '\n'.join(lines[1:-1]) + + return cleaned_response + else: + return "信息不足,无法回答" + + except Exception as fallback_error: + logger.error(f"Fallback也失败: {fallback_error}") + return "信息不足,无法回答" + +async def summary_redis_save(state: ReadState,aimessages) -> ReadState: + data = state.get("data", '') + group_id = state.get("group_id", '') + await SessionService(store).save_session( + user_id=group_id, + query=data, + apply_id=group_id, + group_id=group_id, + ai_response=aimessages + ) + await SessionService(store).cleanup_duplicates() + logger.info(f"sessionid: {aimessages} 写入成功") +async def summary_prompt(state: ReadState,aimessages,raw_results) -> ReadState: + storage_type=state.get("storage_type",'') + user_rag_memory_id=state.get("user_rag_memory_id",'') + data=state.get("data", '') + input_summary = { + "status": "success", + "summary_result": aimessages, + "storage_type": storage_type, + "user_rag_memory_id": user_rag_memory_id, + "_intermediate": { + "type": "input_summary", + "title": "快速答案", + "summary": aimessages, + "query": data, + "raw_results": raw_results, + "search_mode": "quick_search", + "storage_type": storage_type, + "user_rag_memory_id": user_rag_memory_id + } + } + retrieve={ + "status": "success", + "summary_result": aimessages, + "storage_type": storage_type, + "user_rag_memory_id": user_rag_memory_id, + "_intermediate": { + "type": "retrieval_summary", + "title":"快速检索", + "summary": aimessages, + "query": data, + "storage_type": storage_type, + "user_rag_memory_id": user_rag_memory_id + } + } + + return input_summary,retrieve + +async def Input_Summary(state: ReadState) -> ReadState: + start=time.time() + storage_type=state.get("storage_type",'') + memory_config = state.get('memory_config', None) + user_rag_memory_id=state.get("user_rag_memory_id",'') + data=state.get("data", '') + group_id=state.get("group_id", '') + logger.info(f"Input_Summary: storage_type={storage_type}, user_rag_memory_id={user_rag_memory_id}") + history = await summary_history( state) + search_params = { + "group_id": group_id, + "question": data, + "return_raw_results": True + } + + try: + retrieve_info, question, raw_results = await SearchService().execute_hybrid_search(**search_params, memory_config=memory_config) + except Exception as e: + logger.error( f"Input_Summary: hybrid_search failed, using empty results: {e}", exc_info=True ) + retrieve_info, question, raw_results = "", data, [] + + + try: + # aimessages=await summary_llm(state,history,retrieve_info,'Retrieve_Summary_prompt.jinja2', + # 'input_summary',RetrieveSummaryResponse) + # logger.info(f"快速答案总结==>>:{storage_type}--{user_rag_memory_id}--{aimessages}") + summary_result = await summary_prompt(state, retrieve_info, retrieve_info) + summary = summary_result[0] + except Exception as e: + logger.error( f"Input_Summary failed: {e}", exc_info=True ) + summary= { + "status": "fail", + "summary_result": "信息不足,无法回答", + "storage_type": storage_type, + "user_rag_memory_id": user_rag_memory_id, + "error": str(e) + } + end = time.time() + try: + duration = end - start + except Exception: + duration = 0.0 + log_time('检索', duration) + return {"summary":summary} + +async def Retrieve_Summary(state: ReadState)-> ReadState: + retrieve=state.get("retrieve", '') + history = await summary_history( state) + import json + with open("检索.json","w",encoding='utf-8') as f: + f.write(json.dumps(retrieve, indent=4, ensure_ascii=False)) + retrieve=retrieve.get("Expansion_issue", []) + start=time.time() + retrieve_info_str=[] + for data in retrieve: + if data=='': + retrieve_info_str='' + else: + for key, value in data.items(): + if key=='Answer_Small': + for i in value: + retrieve_info_str.append(i) + retrieve_info_str=list(set(retrieve_info_str)) + retrieve_info_str='\n'.join(retrieve_info_str) + + aimessages=await summary_llm(state,history,retrieve_info_str, + 'Retrieve_Summary_prompt.jinja2','retrieve_summary',RetrieveSummaryResponse,"1") + if '信息不足,无法回答' not in str(aimessages) or str(aimessages) != "": + await summary_redis_save(state, aimessages) + if aimessages == '': + aimessages = '信息不足,无法回答' + logger.info(f"Summary after retrieval: {aimessages}") + end = time.time() + try: + duration = end - start + except Exception: + duration = 0.0 + log_time('Retrieval summary', duration) + + # 修复协程调用 - 先await,然后访问返回值 + summary_result = await summary_prompt(state, aimessages, retrieve_info_str) + summary = summary_result[1] + return {"summary":summary} + + +async def Summary(state: ReadState)-> ReadState: + start=time.time() + query = state.get("data", '') + verify=state.get("verify", '') + verify_expansion_issue=verify.get("verified_data", '') + retrieve_info_str='' + for data in verify_expansion_issue: + for key, value in data.items(): + if key=='answer_small': + for i in value: + retrieve_info_str+=i+'\n' + history=await summary_history(state) + + data = { + "query": query, + "history": history, + "retrieve_info": retrieve_info_str + } + aimessages=await summary_llm(state,history,data, + 'summary_prompt.jinja2','summary',SummaryResponse,0) + + + if '信息不足,无法回答' not in str(aimessages) or str(aimessages) != "": + await summary_redis_save(state, aimessages) + if aimessages == '': + aimessages = '信息不足,无法回答' + try: + duration = time.time() - start + except Exception: + duration = 0.0 + log_time('Retrieval summary', duration) + + # 修复协程调用 - 先await,然后访问返回值 + summary_result = await summary_prompt(state, aimessages, retrieve_info_str) + summary = summary_result[1] + return {"summary":summary} + +async def Summary_fails(state: ReadState)-> ReadState: + storage_type=state.get("storage_type", '') + user_rag_memory_id=state.get("user_rag_memory_id", '') + result= { + "status": "success", + "summary_result": "没有相关数据", + "storage_type": storage_type, + "user_rag_memory_id": user_rag_memory_id + } + return {"summary":result} \ No newline at end of file diff --git a/api/app/core/memory/agent/langgraph_graph/nodes/tool_node.py b/api/app/core/memory/agent/langgraph_graph/nodes/tool_node.py deleted file mode 100644 index 4727fb9c..00000000 --- a/api/app/core/memory/agent/langgraph_graph/nodes/tool_node.py +++ /dev/null @@ -1,234 +0,0 @@ -""" -Tool execution node for LangGraph workflow. - -This module provides the ToolExecutionNode class which wraps tool execution -with parameter transformation logic using the ParameterBuilder service. -""" - -import logging -import time -from typing import Any, Callable, Dict - -from app.core.memory.agent.langgraph_graph.state.extractors import ( - extract_content_payload, - extract_tool_call_id, -) -from app.core.memory.agent.mcp_server.services.parameter_builder import ParameterBuilder -from app.schemas.memory_config_schema import MemoryConfig -from langchain_core.messages import AIMessage -from langgraph.prebuilt import ToolNode - -logger = logging.getLogger(__name__) - - -class ToolExecutionNode: - """ - Custom LangGraph node that wraps tool execution with parameter transformation. - - This node extracts content from previous tool results, transforms parameters - based on tool type using ParameterBuilder, and invokes the tool with the - correct argument structure. - - Attributes: - tool_node: LangGraph ToolNode wrapping the actual tool - id: Node identifier for message IDs - tool_name: Name of the tool being executed - namespace: Namespace for session management - search_switch: Search routing parameter - apply_id: Application identifier - group_id: Group identifier - parameter_builder: Service for building tool-specific arguments - memory_config: MemoryConfig object containing all configuration - """ - - def __init__( - self, - tool: Callable, - node_id: str, - namespace: str, - search_switch: str, - apply_id: str, - group_id: str, - parameter_builder: ParameterBuilder, - storage_type: str, - user_rag_memory_id: str, - memory_config: MemoryConfig, - ): - """ - Initialize the tool execution node. - - Args: - tool: The tool function to execute - node_id: Identifier for this node (used in message IDs) - namespace: Namespace for session management - search_switch: Search routing parameter - apply_id: Application identifier - group_id: Group identifier - parameter_builder: Service for building tool-specific arguments - storage_type: Storage type for the workspace - user_rag_memory_id: User RAG memory identifier - memory_config: MemoryConfig object containing all configuration - """ - self.tool_node = ToolNode([tool]) - self.id = node_id - self.tool_name = tool.name if hasattr(tool, 'name') else str(tool) - self.namespace = namespace - self.search_switch = search_switch - self.apply_id = apply_id - self.group_id = group_id - self.parameter_builder = parameter_builder - self.storage_type = storage_type - self.user_rag_memory_id = user_rag_memory_id - self.memory_config = memory_config - - logger.info( - f"[ToolExecutionNode] Initialized node '{self.id}' for tool '{self.tool_name}'" - ) - - async def __call__(self, state: Dict[str, Any]) -> Dict[str, Any]: - """ - Execute the tool with transformed parameters. - - This method: - 1. Extracts the last message from state - 2. Extracts tool call ID using state extractors - 3. Extracts content payload using state extractors - 4. Builds tool arguments using parameter builder - 5. Constructs AIMessage with tool_calls - 6. Invokes the tool and returns the result - - Args: - state: LangGraph state dictionary - - Returns: - Updated state with tool result in messages - """ - messages = state.get("messages", []) - logger.debug( self.tool_name) - - if not messages: - logger.warning(f"[ToolExecutionNode] {self.id} - No messages in state") - return {"messages": [AIMessage(content="Error: No messages in state")]} - - last_message = messages[-1] - logger.debug( - f"[ToolExecutionNode] {self.id} - Processing message at {time.time()}" - ) - - try: - # Extract tool call ID using state extractors - tool_call_id = extract_tool_call_id(last_message) - logger.debug(f"[ToolExecutionNode] {self.id} - Extracted tool_call_id: {tool_call_id}") - - except ValueError as e: - logger.error( - f"[ToolExecutionNode] {self.id} - Failed to extract tool call ID: {e}" - ) - return {"messages": [AIMessage(content=f"Error: {str(e)}")]} - - try: - # Extract content payload using state extractors - content = extract_content_payload(last_message) - logger.debug( - f"[ToolExecutionNode] {self.id} - Extracted content type: {type(content)}, content_keys: {list(content.keys()) if isinstance(content, dict) else 'N/A'}" - ) - # Log raw message content for debugging - if hasattr(last_message, 'content'): - raw = last_message.content - logger.debug(f"[ToolExecutionNode] {self.id} - Raw message content (first 500 chars): {str(raw)[:500]}") - - except Exception as e: - logger.error( - f"[ToolExecutionNode] {self.id} - Failed to extract content: {e}", - exc_info=True - ) - content = {} - - try: - # Build tool arguments using parameter builder - tool_args = self.parameter_builder.build_tool_args( - tool_name=self.tool_name, - content=content, - tool_call_id=tool_call_id, - search_switch=self.search_switch, - apply_id=self.apply_id, - group_id=self.group_id, - memory_config=self.memory_config, - storage_type=self.storage_type, - user_rag_memory_id=self.user_rag_memory_id, - ) - logger.debug( - f"[ToolExecutionNode] {self.id} - Built tool args with keys: {list(tool_args.keys())}" - ) - - except Exception as e: - logger.error( - f"[ToolExecutionNode] {self.id} - Failed to build tool args: {e}", - exc_info=True - ) - return {"messages": [AIMessage(content=f"Error building arguments: {str(e)}")]} - - # Construct tool input message - tool_input = { - "messages": [ - AIMessage( - content="", - tool_calls=[{ - "name": self.tool_name, - "args": tool_args, - "id": f"{self.id}_{tool_call_id}", - }] - ) - ] - } - - try: - # Invoke the tool - result = await self.tool_node.ainvoke(tool_input) - - logger.debug( - f"[ToolExecutionNode] {self.id} - Tool execution completed" - ) - - # Check for error in tool response - error_entry = None - if result and "messages" in result: - for msg in result["messages"]: - if hasattr(msg, 'content'): - try: - import json - content = msg.content - if isinstance(content, str): - parsed = json.loads(content) - if isinstance(parsed, dict) and "error" in parsed: - error_msg = parsed["error"] - logger.warning( - f"[ToolExecutionNode] {self.id} - Tool returned error: {error_msg}" - ) - error_entry = {"tool": self.tool_name, "error": error_msg, "node_id": self.id} - except (json.JSONDecodeError, TypeError): - pass - - # Return result with error tracking if error was found - if error_entry: - result["errors"] = [error_entry] - - return result - - except Exception as e: - logger.error( - f"[ToolExecutionNode] {self.id} - Tool execution failed: {e}", - exc_info=True - ) - # Track error in state and return error message - from langchain_core.messages import ToolMessage - error_entry = {"tool": self.tool_name, "error": str(e), "node_id": self.id} - return { - "messages": [ - ToolMessage( - content=f"Error executing tool: {str(e)}", - tool_call_id=f"{self.id}_{tool_call_id}" - ) - ], - "errors": [error_entry] - } diff --git a/api/app/core/memory/agent/langgraph_graph/nodes/verification_nodes.py b/api/app/core/memory/agent/langgraph_graph/nodes/verification_nodes.py new file mode 100644 index 00000000..f3a39afb --- /dev/null +++ b/api/app/core/memory/agent/langgraph_graph/nodes/verification_nodes.py @@ -0,0 +1,85 @@ + +from app.core.logging_config import get_agent_logger +from app.db import get_db + +from app.core.memory.agent.models.verification_models import VerificationResult +from app.core.memory.agent.utils.llm_tools import ( + PROJECT_ROOT_, + ReadState, +) +from app.core.memory.agent.utils.redis_tool import store +from app.core.memory.agent.utils.session_tools import SessionService +from app.core.memory.agent.utils.template_tools import TemplateService +from app.core.memory.agent.services.optimized_llm_service import LLMServiceMixin + +template_root = PROJECT_ROOT_ + '/agent/utils/prompt' +db_session = next(get_db()) +logger = get_agent_logger(__name__) + +class VerificationNodeService(LLMServiceMixin): + """验证节点服务类""" + + def __init__(self): + super().__init__() + self.template_service = TemplateService(template_root) + +# 创建全局服务实例 +verification_service = VerificationNodeService() + +async def Verify_prompt(state: ReadState,messages_deal): + storage_type = state.get('storage_type', '') + user_rag_memory_id = state.get('user_rag_memory_id', '') + data = state.get('data', '') + Verify_result = { + "status": messages_deal.split_result, + "verified_data": messages_deal.expansion_issue, + "storage_type": storage_type, + "user_rag_memory_id": user_rag_memory_id, + "_intermediate": { + "type": "verification", + "title": "Data Verification", + "result": messages_deal.split_result, + "reason": messages_deal.reason, + "query": data, + "verified_count": len(messages_deal.expansion_issue), + "storage_type": storage_type, + "user_rag_memory_id": user_rag_memory_id + } + } + return Verify_result +async def Verify(state: ReadState): + content = state.get('data', '') + group_id = state.get('group_id', '') + memory_config = state.get('memory_config', None) + + history = await SessionService(store).get_history(group_id, group_id, group_id) + + retrieve = state.get("retrieve", '') + retrieve = retrieve.get("Expansion_issue", []) + messages = { + "Query": content, + "Expansion_issue": retrieve + } + + system_prompt = await verification_service.template_service.render_template( + template_name='split_verify_prompt.jinja2', + operation_name='split_verify_prompt', + history=history, + sentence=messages + ) + + # 使用优化的LLM服务 + structured = await verification_service.call_llm_structured( + state=state, + db_session=db_session, + system_prompt=system_prompt, + response_model=VerificationResult, + fallback_value={ + "split_result": "fail", + "expansion_issue": [], + "reason": "验证失败" + } + ) + + result = await Verify_prompt(state, structured) + return {"verify": result} \ No newline at end of file diff --git a/api/app/core/memory/agent/langgraph_graph/nodes/write_nodes.py b/api/app/core/memory/agent/langgraph_graph/nodes/write_nodes.py new file mode 100644 index 00000000..8421d059 --- /dev/null +++ b/api/app/core/memory/agent/langgraph_graph/nodes/write_nodes.py @@ -0,0 +1,50 @@ + +from app.core.memory.agent.utils.llm_tools import WriteState +from app.core.memory.agent.utils.write_tools import write +from app.core.logging_config import get_agent_logger + +logger = get_agent_logger(__name__) +async def write_node(state: WriteState) -> WriteState: + """ + Write data to the database/file system. + + Args: + ctx: FastMCP context for dependency injection + content: Data content to write + user_id: User identifier + apply_id: Application identifier + group_id: Group identifier + memory_config: MemoryConfig object containing all configuration + + Returns: + dict: Contains 'status', 'saved_to', and 'data' fields + """ + content=state.get('data','') + group_id=state.get('group_id','') + memory_config=state.get('memory_config', '') + try: + result=await write( + content=content, + user_id=group_id, + apply_id=group_id, + group_id=group_id, + memory_config=memory_config, + ) + logger.info(f"Write completed successfully! Config: {memory_config.config_name}") + + write_result= { + "status": "success", + "data": content, + "config_id": memory_config.config_id, + "config_name": memory_config.config_name, + } + return {"write_result":write_result} + + + except Exception as e: + logger.error(f"Data_write failed: {e}", exc_info=True) + write_result= { + "status": "error", + "message": str(e), + } + return {"write_result": write_result} diff --git a/api/app/core/memory/agent/langgraph_graph/read_graph.py b/api/app/core/memory/agent/langgraph_graph/read_graph.py index c29b5d86..19011a5f 100644 --- a/api/app/core/memory/agent/langgraph_graph/read_graph.py +++ b/api/app/core/memory/agent/langgraph_graph/read_graph.py @@ -1,469 +1,177 @@ -import json -import os -import re -import time -import warnings +#!/usr/bin/env python3 from contextlib import asynccontextmanager -from typing import Literal -from app.core.logging_config import get_agent_logger -from app.core.memory.agent.langgraph_graph.nodes import ( - ToolExecutionNode, - create_input_message, -) -from app.core.memory.agent.mcp_server.services.parameter_builder import ParameterBuilder -from app.core.memory.agent.utils.llm_tools import COUNTState, ReadState -from app.core.memory.agent.utils.multimodal import MultimodalProcessor -from app.schemas.memory_config_schema import MemoryConfig -from dotenv import load_dotenv -from langchain_core.messages import AIMessage -from langgraph.checkpoint.memory import InMemorySaver -from langgraph.constants import END, START +from langchain_core.messages import HumanMessage +from langgraph.constants import START, END from langgraph.graph import StateGraph -from langgraph.prebuilt import ToolNode - -logger = get_agent_logger(__name__) - -warnings.filterwarnings("ignore", category=RuntimeWarning) -load_dotenv() -redishost=os.getenv("REDISHOST") -redisport=os.getenv('REDISPORT') -redisdb=os.getenv('REDISDB') -redispassword=os.getenv('REDISPASSWORD') -counter = COUNTState(limit=3) - -# Update loop count in workflow -async def update_loop_count(state): - """Update loop counter""" - current_count = state.get("loop_count", 0) - return {"loop_count": current_count + 1} -def Verify_continue(state: ReadState) -> Literal["Summary", "Summary_fails", "content_input"]: - messages = state["messages"] +from app.db import get_db +from app.services.memory_config_service import MemoryConfigService - # Add boundary check - if not messages: - return END - counter.add(1) # Increment by 1 +from app.core.memory.agent.utils.llm_tools import ReadState +from app.core.memory.agent.langgraph_graph.nodes.data_nodes import content_input_node +from app.core.memory.agent.langgraph_graph.nodes.problem_nodes import ( + Split_The_Problem, + Problem_Extension, +) +from app.core.memory.agent.langgraph_graph.nodes.retrieve_nodes import ( + retrieve, +) +from app.core.memory.agent.langgraph_graph.nodes.summary_nodes import ( + Input_Summary, + Retrieve_Summary, + Summary_fails, + Summary, +) +from app.core.memory.agent.langgraph_graph.nodes.verification_nodes import Verify +from app.core.memory.agent.langgraph_graph.routing.routers import ( + Split_continue, + Retrieve_continue, + Verify_continue, +) - loop_count = counter.get_total() - logger.debug(f"[should_continue] Current loop count: {loop_count}") - - last_message = messages[-1] - last_message_str = str(last_message).replace('\\', '') - status_tools = re.findall(r'"split_result": "(.*?)"', last_message_str) - logger.debug(f"Status tools: {status_tools}") - - if "success" in status_tools: - counter.reset() - return "Summary" - elif "failed" in status_tools: - if loop_count < 2: # Maximum loop count is 3 - return "content_input" - else: - counter.reset() - return "Summary_fails" - else: - # Add default return value to avoid returning None - counter.reset() - return "Summary" # Default based on business requirements - - -def Retrieve_continue(state) -> Literal["Verify", "Retrieve_Summary"]: - """ - Determine routing based on search_switch value. - - Args: - state: State dictionary containing search_switch - - Returns: - Next node to execute - """ - # Direct dictionary access instead of regex parsing - search_switch = state.get("search_switch") - - # Handle case where search_switch might be in messages - if search_switch is None and "messages" in state: - messages = state.get("messages", []) - if messages: - last_message = messages[-1] - # Try to extract from tool_calls args - if hasattr(last_message, "tool_calls") and last_message.tool_calls: - for tool_call in last_message.tool_calls: - if isinstance(tool_call, dict) and "args" in tool_call: - search_switch = tool_call["args"].get("search_switch") - break - - # Convert to string for comparison if needed - if search_switch is not None: - search_switch = str(search_switch) - if search_switch == '0': - return 'Verify' - elif search_switch == '1': - return 'Retrieve_Summary' - - # Add default return value to avoid returning None - return 'Retrieve_Summary' # Default based on business logic - - -def Split_continue(state) -> Literal["Split_The_Problem", "Input_Summary"]: - """ - Determine routing based on search_switch value. - - Args: - state: State dictionary containing search_switch - - Returns: - Next node to execute - """ - logger.debug(f"Split_continue state: {state}") - - # Direct dictionary access instead of regex parsing - search_switch = state.get("search_switch") - - # Handle case where search_switch might be in messages - if search_switch is None and "messages" in state: - messages = state.get("messages", []) - if messages: - last_message = messages[-1] - # Try to extract from tool_calls args - if hasattr(last_message, "tool_calls") and last_message.tool_calls: - for tool_call in last_message.tool_calls: - if isinstance(tool_call, dict) and "args" in tool_call: - search_switch = tool_call["args"].get("search_switch") - break - - # Convert to string for comparison if needed - if search_switch is not None: - search_switch = str(search_switch) - if search_switch == '2': - return 'Input_Summary' - return 'Split_The_Problem' # Default case - - -class ProblemExtensionNode: - def __init__(self, tool, id, namespace, search_switch, apply_id, group_id, storage_type="", user_rag_memory_id=""): - self.tool_node = ToolNode([tool]) - self.id = id - self.tool_name = tool.name if hasattr(tool, 'name') else str(tool) - self.namespace = namespace - self.search_switch = search_switch - self.apply_id = apply_id - self.group_id = group_id - self.storage_type = storage_type - self.user_rag_memory_id = user_rag_memory_id - - async def __call__(self, state): - messages = state["messages"] - last_message = messages[-1] if messages else "" - logger.debug(f"ProblemExtensionNode {self.id} - Current time: {time.time()} - Message: {last_message}") - if self.tool_name == 'Input_Summary': - tool_call = re.findall("'id': '(.*?)'", str(last_message))[0] - else: - tool_call = str(re.findall(r"tool_call_id=.*?'(.*?)'", str(last_message))[0]).replace('\\', '').split('_id')[1] - - # Try to extract actual content payload from previous tool result - raw_msg = last_message.content if hasattr(last_message, 'content') else str(last_message) - extracted_payload = None - # Capture ToolMessage content field (supports single/double quotes), avoid greedy matching - m = re.search(r"content=(?:\"|\')(.*?)(?:\"|\'),\s*name=", raw_msg, flags=re.S) - if m: - extracted_payload = m.group(1) - else: - # Fallback: use raw string directly - extracted_payload = raw_msg - - # Try to parse content as JSON first - try: - content = json.loads(extracted_payload) - except Exception: - # Try to extract JSON fragment from text and parse - parsed = None - candidates = re.findall(r"[\[{].*[\]}]", extracted_payload, flags=re.S) - for cand in candidates: - try: - parsed = json.loads(cand) - break - except Exception: - continue - # If still fails, use raw string as content - content = parsed if parsed is not None else extracted_payload - - # Build correct parameters based on tool name - tool_args = {} - - if self.tool_name == "Verify": - # Verify tool requires context and usermessages parameters - if isinstance(content, dict): - tool_args["context"] = content - else: - tool_args["context"] = {"content": content} - tool_args["usermessages"] = str(tool_call) - tool_args["apply_id"] = str(self.apply_id) - tool_args["group_id"] = str(self.group_id) - elif self.tool_name == "Retrieve": - # Retrieve tool requires context and usermessages parameters - if isinstance(content, dict): - tool_args["context"] = content - else: - tool_args["context"] = {"content": content} - tool_args["usermessages"] = str(tool_call) - tool_args["search_switch"] = str(self.search_switch) - tool_args["apply_id"] = str(self.apply_id) - tool_args["group_id"] = str(self.group_id) - elif self.tool_name == "Summary": - # Summary tool requires string type context parameter - if isinstance(content, dict): - # Convert dict to JSON string - tool_args["context"] = json.dumps(content, ensure_ascii=False) - else: - tool_args["context"] = str(content) - tool_args["usermessages"] = str(tool_call) - tool_args["apply_id"] = str(self.apply_id) - tool_args["group_id"] = str(self.group_id) - elif self.tool_name == "Summary_fails": - # Summary_fails tool requires string type context parameter - if isinstance(content, dict): - # Convert dict to JSON string - tool_args["context"] = json.dumps(content, ensure_ascii=False) - else: - tool_args["context"] = str(content) - tool_args["usermessages"] = str(tool_call) - tool_args["apply_id"] = str(self.apply_id) - tool_args["group_id"] = str(self.group_id) - elif self.tool_name == 'Input_Summary': - tool_args["context"] = str(last_message) - tool_args["usermessages"] = str(tool_call) - tool_args["search_switch"] = str(self.search_switch) - tool_args["apply_id"] = str(self.apply_id) - tool_args["group_id"] = str(self.group_id) - tool_args["storage_type"] = getattr(self, 'storage_type', "") - tool_args["user_rag_memory_id"] = getattr(self, 'user_rag_memory_id', "") - elif self.tool_name == 'Retrieve_Summary': - # Retrieve_Summary expects dict directly, not JSON string - # content might be a JSON string, try to parse it - if isinstance(content, str): - try: - parsed_content = json.loads(content) - # Check if it has a "context" key - if isinstance(parsed_content, dict) and "context" in parsed_content: - tool_args["context"] = parsed_content["context"] - else: - tool_args["context"] = parsed_content - except json.JSONDecodeError: - # If parsing fails, wrap the string - tool_args["context"] = {"content": content} - elif isinstance(content, dict): - # Check if content has a "context" key that needs unwrapping - if "context" in content: - tool_args["context"] = content["context"] - else: - tool_args["context"] = content - else: - tool_args["context"] = {"content": str(content)} - - tool_args["usermessages"] = str(tool_call) - tool_args["apply_id"] = str(self.apply_id) - tool_args["group_id"] = str(self.group_id) - else: - # Other tools use context parameter - if isinstance(content, dict): - tool_args["context"] = content - else: - tool_args["context"] = {"content": content} - tool_args["usermessages"] = str(tool_call) - tool_args["apply_id"] = str(self.apply_id) - tool_args["group_id"] = str(self.group_id) - - - tool_input = { - "messages": [ - AIMessage( - content="", - tool_calls=[{ - "name": self.tool_name, - "args": tool_args, - "id": self.id + f"{tool_call}", - }] - ) - ] - } - result = await self.tool_node.ainvoke(tool_input) - result_text = str(result) - - return {"messages": [AIMessage(content=result_text)]} @asynccontextmanager -async def make_read_graph(namespace, tools, search_switch, apply_id, group_id, memory_config: MemoryConfig, storage_type=None, user_rag_memory_id=None): - """ - Create a read graph workflow for memory operations. - - Args: - namespace: Namespace identifier - tools: MCP tools loaded from session - search_switch: Search mode switch ("0", "1", or "2") - apply_id: Application identifier - group_id: Group identifier - memory_config: MemoryConfig object containing all configuration - storage_type: Storage type (optional) - user_rag_memory_id: User RAG memory ID (optional) - """ - memory = InMemorySaver() - tool = [i.name for i in tools] - logger.info(f"Initializing read graph with tools: {tool}") - logger.info(f"Using memory_config: {memory_config.config_name} (id={memory_config.config_id})") - - # Extract tool functions - Split_The_Problem_ = next((t for t in tools if t.name == "Split_The_Problem"), None) - Problem_Extension_ = next((t for t in tools if t.name == "Problem_Extension"), None) - Retrieve_ = next((t for t in tools if t.name == "Retrieve"), None) - Verify_ = next((t for t in tools if t.name == "Verify"), None) - Summary_ = next((t for t in tools if t.name == "Summary"), None) - Summary_fails_ = next((t for t in tools if t.name == "Summary_fails"), None) - Retrieve_Summary_ = next((t for t in tools if t.name == "Retrieve_Summary"), None) - Input_Summary_ = next((t for t in tools if t.name == "Input_Summary"), None) - - # Instantiate services - parameter_builder = ParameterBuilder() - multimodal_processor = MultimodalProcessor() - - # Create nodes using new modular components - Split_The_Problem_node = ToolNode([Split_The_Problem_]) - - Problem_Extension_node = ToolExecutionNode( - tool=Problem_Extension_, - node_id="Problem_Extension_id", - namespace=namespace, - search_switch=search_switch, - apply_id=apply_id, - group_id=group_id, - parameter_builder=parameter_builder, - storage_type=storage_type, - user_rag_memory_id=user_rag_memory_id, - memory_config=memory_config, +async def make_read_graph(): + """创建并返回 LangGraph 工作流""" + try: + # Build workflow graph + workflow = StateGraph(ReadState) + workflow.add_node("content_input", content_input_node) + workflow.add_node("Split_The_Problem", Split_The_Problem) + workflow.add_node("Problem_Extension", Problem_Extension) + workflow.add_node("Input_Summary", Input_Summary) + # workflow.add_node("Retrieve", retrieve_nodes) + workflow.add_node("Retrieve", retrieve) + workflow.add_node("Verify", Verify) + workflow.add_node("Retrieve_Summary", Retrieve_Summary) + workflow.add_node("Summary", Summary) + workflow.add_node("Summary_fails", Summary_fails) + + # 添加边 + workflow.add_edge(START, "content_input") + workflow.add_conditional_edges("content_input", Split_continue) + workflow.add_edge("Input_Summary", END) + workflow.add_edge("Split_The_Problem", "Problem_Extension") + workflow.add_edge("Problem_Extension", "Retrieve") + workflow.add_conditional_edges("Retrieve", Retrieve_continue) + workflow.add_edge("Retrieve_Summary", END) + workflow.add_conditional_edges("Verify", Verify_continue) + workflow.add_edge("Summary_fails", END) + workflow.add_edge("Summary", END) + + + '''-----''' + # workflow.add_edge("Retrieve", END) + + # 编译工作流 + graph = workflow.compile() + yield graph + + except Exception as e: + print(f"创建工作流失败: {e}") + raise + finally: + print("工作流创建完成") + +async def main(): + """主函数 - 运行工作流""" + message = "昨天有什么好看的电影" + group_id = '88a459f5_text09' # 组ID + storage_type = 'neo4j' # 存储类型 + search_switch = '1' # 搜索开关 + user_rag_memory_id = 'wwwwwwww' # 用户RAG记忆ID + + # 获取数据库会话 + db_session = next(get_db()) + config_service = MemoryConfigService(db_session) + memory_config = config_service.load_memory_config( + config_id=17, # 改为整数 + service_name="MemoryAgentService" ) + import time + start=time.time() + try: + async with make_read_graph() as graph: + config = {"configurable": {"thread_id": group_id}} + # 初始状态 - 包含所有必要字段 + initial_state = {"messages": [HumanMessage(content=message)] ,"search_switch":search_switch,"group_id":group_id + ,"storage_type":storage_type,"user_rag_memory_id":user_rag_memory_id,"memory_config":memory_config} + # 获取节点更新信息 + _intermediate_outputs = [] + summary = '' + + async for update_event in graph.astream( + initial_state, + stream_mode="updates", + config=config + ): + for node_name, node_data in update_event.items(): + print(f"处理节点: {node_name}") + + # 处理不同Summary节点的返回结构 + if 'Summary' in node_name: + if 'InputSummary' in node_data and 'summary_result' in node_data['InputSummary']: + summary = node_data['InputSummary']['summary_result'] + elif 'RetrieveSummary' in node_data and 'summary_result' in node_data['RetrieveSummary']: + summary = node_data['RetrieveSummary']['summary_result'] + elif 'summary' in node_data and 'summary_result' in node_data['summary']: + summary = node_data['summary']['summary_result'] + elif 'SummaryFails' in node_data and 'summary_result' in node_data['SummaryFails']: + summary = node_data['SummaryFails']['summary_result'] - Retrieve_node = ToolExecutionNode( - tool=Retrieve_, - node_id="Retrieve_id", - namespace=namespace, - search_switch=search_switch, - apply_id=apply_id, - group_id=group_id, - parameter_builder=parameter_builder, - storage_type=storage_type, - user_rag_memory_id=user_rag_memory_id, - memory_config=memory_config, - ) + spit_data = node_data.get('spit_data', {}).get('_intermediate', None) + if spit_data and spit_data != [] and spit_data != {}: + _intermediate_outputs.append(spit_data) + + # Problem_Extension 节点 + problem_extension = node_data.get('problem_extension', {}).get('_intermediate', None) + if problem_extension and problem_extension != [] and problem_extension != {}: + _intermediate_outputs.append(problem_extension) + + # Retrieve 节点 + retrieve_node = node_data.get('retrieve', {}).get('_intermediate_outputs', None) + if retrieve_node and retrieve_node != [] and retrieve_node != {}: + _intermediate_outputs.extend(retrieve_node) + + # Verify 节点 + verify_n = node_data.get('verify', {}).get('_intermediate', None) + if verify_n and verify_n != [] and verify_n != {}: + _intermediate_outputs.append(verify_n) - Verify_node = ToolExecutionNode( - tool=Verify_, - node_id="Verify_id", - namespace=namespace, - search_switch=search_switch, - apply_id=apply_id, - group_id=group_id, - parameter_builder=parameter_builder, - storage_type=storage_type, - user_rag_memory_id=user_rag_memory_id, - memory_config=memory_config, - ) - - Summary_node = ToolExecutionNode( - tool=Summary_, - node_id="Summary_id", - namespace=namespace, - search_switch=search_switch, - apply_id=apply_id, - group_id=group_id, - parameter_builder=parameter_builder, - storage_type=storage_type, - user_rag_memory_id=user_rag_memory_id, - memory_config=memory_config, - ) + + # Summary 节点 + summary_n = node_data.get('summary', {}).get('_intermediate', None) + if summary_n and summary_n != [] and summary_n != {}: + _intermediate_outputs.append(summary_n) - Summary_fails_node = ToolExecutionNode( - tool=Summary_fails_, - node_id="Summary_fails_id", - namespace=namespace, - search_switch=search_switch, - apply_id=apply_id, - group_id=group_id, - parameter_builder=parameter_builder, - storage_type=storage_type, - user_rag_memory_id=user_rag_memory_id, - memory_config=memory_config, - ) + # # 过滤掉空值 + # _intermediate_outputs = [item for item in _intermediate_outputs if item and item != [] and item != {}] + # + # # 优化搜索结果 + # print("=== 开始优化搜索结果 ===") + # optimized_outputs = merge_multiple_search_results(_intermediate_outputs) + # result=reorder_output_results(optimized_outputs) + # # 保存优化后的结果到文件 + # with open('_intermediate_outputs_optimized.json', 'w', encoding='utf-8') as f: + # import json + # f.write(json.dumps(result, indent=4, ensure_ascii=False)) + # + print(f"=== 最终摘要 ===") + print(summary) + + except Exception as e: + import traceback + traceback.print_exc() - Retrieve_Summary_node = ToolExecutionNode( - tool=Retrieve_Summary_, - node_id="Retrieve_Summary_id", - namespace=namespace, - search_switch=search_switch, - apply_id=apply_id, - group_id=group_id, - parameter_builder=parameter_builder, - storage_type=storage_type, - user_rag_memory_id=user_rag_memory_id, - memory_config=memory_config, - ) + end=time.time() + print(100*'y') + print(f"总耗时: {end-start}s") + print(100*'y') - Input_Summary_node = ToolExecutionNode( - tool=Input_Summary_, - node_id="Input_Summary_id", - namespace=namespace, - search_switch=search_switch, - apply_id=apply_id, - group_id=group_id, - parameter_builder=parameter_builder, - storage_type=storage_type, - user_rag_memory_id=user_rag_memory_id, - memory_config=memory_config, - ) - async def content_input_node(state): - state_search_switch = state.get("search_switch", search_switch) - - tool_name = "Input_Summary" if state_search_switch == '2' else "Split_The_Problem" - session_prefix = "input_summary_call_id" if state_search_switch == '2' else "split_call_id" - - return await create_input_message( - state=state, - tool_name=tool_name, - session_id=f"{session_prefix}_{namespace}", - search_switch=search_switch, - apply_id=apply_id, - group_id=group_id, - multimodal_processor=multimodal_processor, - memory_config=memory_config, - ) - - - # Build workflow graph - workflow = StateGraph(ReadState) - workflow.add_node("content_input", content_input_node) - workflow.add_node("Split_The_Problem", Split_The_Problem_node) - workflow.add_node("Problem_Extension", Problem_Extension_node) - workflow.add_node("Retrieve", Retrieve_node) - workflow.add_node("Verify", Verify_node) - workflow.add_node("Summary", Summary_node) - workflow.add_node("Summary_fails", Summary_fails_node) - workflow.add_node("Retrieve_Summary", Retrieve_Summary_node) - workflow.add_node("Input_Summary", Input_Summary_node) - - # Add edges using imported routers - workflow.add_edge(START, "content_input") - workflow.add_conditional_edges("content_input", Split_continue) - workflow.add_edge("Input_Summary", END) - workflow.add_edge("Split_The_Problem", "Problem_Extension") - workflow.add_edge("Problem_Extension", "Retrieve") - workflow.add_conditional_edges("Retrieve", Retrieve_continue) - workflow.add_edge("Retrieve_Summary", END) - workflow.add_conditional_edges("Verify", Verify_continue) - workflow.add_edge("Summary_fails", END) - workflow.add_edge("Summary", END) - - graph = workflow.compile(checkpointer=memory) - yield graph +if __name__ == "__main__": + import asyncio + asyncio.run(main()) diff --git a/api/app/core/memory/agent/langgraph_graph/routing/__init__.py b/api/app/core/memory/agent/langgraph_graph/routing/__init__.py deleted file mode 100644 index a9366bd0..00000000 --- a/api/app/core/memory/agent/langgraph_graph/routing/__init__.py +++ /dev/null @@ -1,13 +0,0 @@ -"""LangGraph routing logic.""" - -from app.core.memory.agent.langgraph_graph.routing.routers import ( - Verify_continue, - Retrieve_continue, - Split_continue, -) - -__all__ = [ - "Verify_continue", - "Retrieve_continue", - "Split_continue", -] diff --git a/api/app/core/memory/agent/langgraph_graph/routing/routers.py b/api/app/core/memory/agent/langgraph_graph/routing/routers.py index c8abd544..c0b01be1 100644 --- a/api/app/core/memory/agent/langgraph_graph/routing/routers.py +++ b/api/app/core/memory/agent/langgraph_graph/routing/routers.py @@ -1,123 +1,62 @@ -""" -Routing functions for LangGraph conditional edges. -This module provides routing functions that determine the next node to execute -based on state values. All functions return Literal types for type safety. -""" - -import logging -import re from typing import Literal -from app.core.memory.agent.langgraph_graph.state.extractors import extract_search_switch +from app.core.logging_config import get_agent_logger from app.core.memory.agent.utils.llm_tools import ReadState, COUNTState -logger = logging.getLogger(__name__) -# Global counter for Verify routing +logger = get_agent_logger(__name__) counter = COUNTState(limit=3) - - -def Verify_continue(state: ReadState) -> Literal["Summary", "Summary_fails", "content_input"]: +def Split_continue(state:ReadState) -> Literal["Split_The_Problem", "Input_Summary"]: """ - Determine routing after Verify node based on verification result. - - This function checks the verification result in the last message and routes to: - - Summary: if verification succeeded - - content_input: if verification failed and retry limit not reached - - Summary_fails: if verification failed and retry limit reached - + Determine routing based on search_switch value. + Args: - state: LangGraph state containing messages - + state: State dictionary containing search_switch + Returns: - Next node name as Literal type + Next node to execute """ - messages = state.get("messages", []) - - # Boundary check - if not messages: - logger.warning("[Verify_continue] No messages in state, defaulting to Summary") - counter.reset() - return "Summary" - - # Increment counter - counter.add(1) + logger.debug(f"Split_continue state: {state}") + search_switch = state.get('search_switch', '') + if search_switch is not None: + search_switch = str(search_switch) + if search_switch == '2': + return 'Input_Summary' + return 'Split_The_Problem' # 默认情况 + +def Retrieve_continue(state) -> Literal["Verify", "Retrieve_Summary"]: + """ + Determine routing based on search_switch value. + + Args: + state: State dictionary containing search_switch + + Returns: + Next node to execute + """ + search_switch = state.get('search_switch', '') + if search_switch is not None: + search_switch = str(search_switch) + if search_switch == '0': + return 'Verify' + elif search_switch == '1': + return 'Retrieve_Summary' + return 'Retrieve_Summary' # Default based on business logic +def Verify_continue(state: ReadState) -> Literal["Summary", "Summary_fails", "content_input"]: + status=state.get('verify', '')['status'] loop_count = counter.get_total() - logger.debug(f"[Verify_continue] Current loop count: {loop_count}") - - # Extract verification result from last message - last_message = messages[-1] - last_message_str = str(last_message).replace('\\', '') - status_tools = re.findall(r'"split_result": "(.*?)"', last_message_str) - logger.debug(f"[Verify_continue] Status tools: {status_tools}") - - # Route based on verification result - if "success" in status_tools: + print(status) + if "success" in status: counter.reset() return "Summary" - elif "failed" in status_tools: - if loop_count < 2: # Max retry count is 2 + elif "failed" in status: + if loop_count < 2: # Maximum loop count is 3 return "content_input" else: counter.reset() return "Summary_fails" - else: - # Default to Summary if status is unclear - counter.reset() - return "Summary" - - -def Retrieve_continue(state: dict) -> Literal["Verify", "Retrieve_Summary"]: - """ - Determine routing after Retrieve node based on search_switch value. - - This function routes based on the search_switch parameter: - - search_switch == '0': Route to Verify (verification needed) - - search_switch == '1': Route to Retrieve_Summary (direct summary) - - Args: - state: LangGraph state dictionary - - Returns: - Next node name as Literal type - """ - search_switch = extract_search_switch(state) - - logger.debug(f"[Retrieve_continue] search_switch: {search_switch}") - - if search_switch == '0': - return 'Verify' - elif search_switch == '1': - return 'Retrieve_Summary' - - # Default to Retrieve_Summary - logger.debug("[Retrieve_continue] No valid search_switch, defaulting to Retrieve_Summary") - return 'Retrieve_Summary' - - -def Split_continue(state: dict) -> Literal["Split_The_Problem", "Input_Summary"]: - """ - Determine routing after content_input node based on search_switch value. - - This function routes based on the search_switch parameter: - - search_switch == '2': Route to Input_Summary (direct input summary) - - Otherwise: Route to Split_The_Problem (problem decomposition) - - Args: - state: LangGraph state dictionary - - Returns: - Next node name as Literal type - """ - logger.debug(f"[Split_continue] state keys: {state.keys()}") - - search_switch = extract_search_switch(state) - - logger.debug(f"[Split_continue] search_switch: {search_switch}") - - if search_switch == '2': - return 'Input_Summary' - - # Default to Split_The_Problem - return 'Split_The_Problem' + # else: + # # Add default return value to avoid returning None + # counter.reset() + # return "Summary" # Default based on business requirements diff --git a/api/app/core/memory/agent/langgraph_graph/state/__init__.py b/api/app/core/memory/agent/langgraph_graph/state/__init__.py deleted file mode 100644 index 279c6463..00000000 --- a/api/app/core/memory/agent/langgraph_graph/state/__init__.py +++ /dev/null @@ -1,13 +0,0 @@ -"""LangGraph state management utilities.""" - -from app.core.memory.agent.langgraph_graph.state.extractors import ( - extract_search_switch, - extract_tool_call_id, - extract_content_payload, -) - -__all__ = [ - "extract_search_switch", - "extract_tool_call_id", - "extract_content_payload", -] diff --git a/api/app/core/memory/agent/langgraph_graph/state/extractors.py b/api/app/core/memory/agent/langgraph_graph/state/extractors.py deleted file mode 100644 index f5a32f5d..00000000 --- a/api/app/core/memory/agent/langgraph_graph/state/extractors.py +++ /dev/null @@ -1,179 +0,0 @@ -""" -State extraction utilities for type-safe access to LangGraph state values. - -This module provides utility functions for extracting values from LangGraph state -dictionaries with proper error handling and sensible defaults. -""" - -import json -import logging -from typing import Any, Optional - -logger = logging.getLogger(__name__) - -def extract_search_switch(state: dict) -> Optional[str]: - """ - Extract search_switch from state or messages. - """ - - search_switch = state.get("search_switch") - - if search_switch is not None: - return str(search_switch) - - # Try to extract from messages - messages = state.get("messages", []) - if not messages: - return None - - # 从最新的消息开始查找 - for message in reversed(messages): - # 尝试从 tool_calls 中提取 - if hasattr(message, "tool_calls") and message.tool_calls: - for tool_call in message.tool_calls: - if isinstance(tool_call, dict): - # 从 tool_call 的 args 中提取 - if "args" in tool_call and isinstance(tool_call["args"], dict): - search_switch = tool_call["args"].get("search_switch") - if search_switch is not None: - return str(search_switch) - # 直接从 tool_call 中提取 - search_switch = tool_call.get("search_switch") - if search_switch is not None: - return str(search_switch) - - # 尝试从 content 中提取(如果是 JSON 格式) - if hasattr(message, "content"): - try: - import json - if isinstance(message.content, str): - content_data = json.loads(message.content) - if isinstance(content_data, dict): - search_switch = content_data.get("search_switch") - if search_switch is not None: - return str(search_switch) - except (json.JSONDecodeError, ValueError): - pass - - return None - - -def extract_tool_call_id(message: Any) -> str: - """ - Extract tool call ID from message using structured attributes. - - This function extracts the tool call ID from a message object, handling both - direct attribute access and tool_calls list structures. - - Args: - message: Message object (typically ToolMessage or AIMessage) - - Returns: - Tool call ID as string - - Raises: - ValueError: If tool call ID cannot be extracted - - Examples: - >>> message = ToolMessage(content="...", tool_call_id="call_123") - >>> extract_tool_call_id(message) - 'call_123' - """ - # Try direct attribute access for ToolMessage - if hasattr(message, "tool_call_id"): - tool_call_id = message.tool_call_id - if tool_call_id: - return str(tool_call_id) - - # Try extracting from tool_calls list for AIMessage - if hasattr(message, "tool_calls") and message.tool_calls: - tool_call = message.tool_calls[0] - if isinstance(tool_call, dict) and "id" in tool_call: - return str(tool_call["id"]) - - # Try extracting from id attribute - if hasattr(message, "id"): - message_id = message.id - if message_id: - return str(message_id) - - # If all else fails, raise an error - raise ValueError(f"Could not extract tool call ID from message: {type(message)}") - - -def extract_content_payload(message: Any) -> Any: - """ - Extract content payload from ToolMessage, parsing JSON if needed. - - This function extracts the content from a message and attempts to parse it as JSON - if it appears to be a JSON string. It handles various message formats and provides - sensible fallbacks. - - Args: - message: Message object (typically ToolMessage) - - Returns: - Parsed content (dict, list, or str) - - Examples: - >>> message = ToolMessage(content='{"key": "value"}') - >>> extract_content_payload(message) - {'key': 'value'} - - >>> message = ToolMessage(content='plain text') - >>> extract_content_payload(message) - 'plain text' - """ - # Extract raw content - # For ToolMessages (responses from tools), extract from content - if hasattr(message, "content"): - raw_content = message.content - logger.info(f"extract_content_payload: raw_content type={type(raw_content)}, value={str(raw_content)[:500]}") - - # Handle MCP content format: [{'type': 'text', 'text': '...'}] - if isinstance(raw_content, list): - for block in raw_content: - if isinstance(block, dict) and block.get('type') == 'text': - raw_content = block.get('text', '') - logger.info(f"extract_content_payload: extracted text from MCP format: {str(raw_content)[:300]}") - break - - # If content is empty and this is an AIMessage with tool_calls, - # extract from args (this handles the initial tool call from content_input) - if not raw_content and hasattr(message, "tool_calls") and message.tool_calls: - tool_call = message.tool_calls[0] - if isinstance(tool_call, dict) and "args" in tool_call: - return tool_call["args"] - else: - raw_content = str(message) - - # If content is already a dict or list, return it directly - if isinstance(raw_content, (dict, list)): - logger.info(f"extract_content_payload: returning raw dict/list with keys={list(raw_content.keys()) if isinstance(raw_content, dict) else 'list'}") - return raw_content - - # Try to parse as JSON - if isinstance(raw_content, str): - # First, try direct JSON parsing - try: - parsed = json.loads(raw_content) - logger.info(f"extract_content_payload: parsed JSON, keys={list(parsed.keys()) if isinstance(parsed, dict) else 'list'}") - return parsed - except (json.JSONDecodeError, ValueError): - pass - - # If that fails, try to extract JSON from the string - # This handles cases where the content is embedded in a larger string - import re - json_candidates = re.findall(r'[\[{].*[\]}]', raw_content, flags=re.DOTALL) - for candidate in json_candidates: - try: - parsed = json.loads(candidate) - logger.info(f"extract_content_payload: parsed JSON from candidate, keys={list(parsed.keys()) if isinstance(parsed, dict) else 'list'}") - return parsed - except (json.JSONDecodeError, ValueError): - continue - - # If all parsing attempts fail, return the raw content - logger.info(f"extract_content_payload: returning raw content (parsing failed)") - return raw_content diff --git a/api/app/core/memory/agent/langgraph_graph/tools/tool.py b/api/app/core/memory/agent/langgraph_graph/tools/tool.py new file mode 100644 index 00000000..ce6d5dd4 --- /dev/null +++ b/api/app/core/memory/agent/langgraph_graph/tools/tool.py @@ -0,0 +1,320 @@ +import asyncio +import json +from datetime import datetime, timedelta + + +from langchain.tools import tool +from pydantic import BaseModel, Field + + +from app.core.memory.src.search import ( + search_by_temporal, + search_by_keyword_temporal, +) + +def extract_tool_message_content(response): + """从agent响应中提取ToolMessage内容和工具名称""" + messages = response.get('messages', []) + + for message in messages: + if hasattr(message, 'tool_call_id') and hasattr(message, 'content'): + # 这是一个ToolMessage + tool_content = message.content + tool_name = None + + # 尝试获取工具名称 + if hasattr(message, 'name'): + tool_name = message.name + elif hasattr(message, 'tool_name'): + tool_name = message.tool_name + + try: + # 解析JSON内容 + parsed_content = json.loads(tool_content) + return { + 'tool_name': tool_name, + 'content': parsed_content + } + except json.JSONDecodeError: + # 如果不是JSON格式,直接返回内容 + return { + 'tool_name': tool_name, + 'content': tool_content + } + + return None + + +class TimeRetrievalInput(BaseModel): + """时间检索工具的输入模式""" + context: str = Field(description="用户输入的查询内容") + group_id: str = Field(default="88a459f5_text09", description="组ID,用于过滤搜索结果") + +def create_time_retrieval_tool(group_id: str): + """ + 创建一个带有特定group_id的TimeRetrieval工具(同步版本),用于按时间范围搜索语句(Statements) + """ + + def clean_temporal_result_fields(data): + """ + 清理时间搜索结果中不需要的字段,并修改结构 + + Args: + data: 要清理的数据 + + Returns: + 清理后的数据 + """ + # 需要过滤的字段列表 + fields_to_remove = { + 'id', 'apply_id', 'user_id', 'chunk_id', 'created_at', + 'valid_at', 'invalid_at', 'statement_ids' + } + + if isinstance(data, dict): + cleaned = {} + for key, value in data.items(): + if key == 'statements' and isinstance(value, dict) and 'statements' in value: + # 将 statements: {"statements": [...]} 改为 time_search: {"statements": [...]} + cleaned_value = clean_temporal_result_fields(value) + # 进一步将内部的 statements 改为 time_search + if 'statements' in cleaned_value: + cleaned['results'] = { + 'time_search': cleaned_value['statements'] + } + else: + cleaned['results'] = cleaned_value + elif key not in fields_to_remove: + cleaned[key] = clean_temporal_result_fields(value) + return cleaned + elif isinstance(data, list): + return [clean_temporal_result_fields(item) for item in data] + else: + return data + + @tool + def TimeRetrievalWithGroupId(context: str, start_date: str = None, end_date: str = None, group_id_param: str = None, clean_output: bool = True) -> str: + """ + 优化的时间检索工具,只结合时间范围搜索(同步版本),自动过滤不需要的元数据字段 + 显式接收参数: + - context: 查询上下文内容 + - start_date: 开始时间(可选,格式:YYYY-MM-DD) + - end_date: 结束时间(可选,格式:YYYY-MM-DD) + - group_id_param: 组ID(可选,用于覆盖默认组ID) + - clean_output: 是否清理输出中的元数据字段 + -end_date 需要根据用户的描述获取结束的时间,输出格式用strftime("%Y-%m-%d") + """ + async def _async_search(): + # 使用传入的参数或默认值 + actual_group_id = group_id_param or group_id + actual_end_date = end_date or datetime.now().strftime("%Y-%m-%d") + actual_start_date = start_date or (datetime.now() - timedelta(days=7)).strftime("%Y-%m-%d") + + # 基本时间搜索 + results = await search_by_temporal( + group_id=actual_group_id, + start_date=actual_start_date, + end_date=actual_end_date, + limit=10 + ) + + # 清理结果中不需要的字段 + if clean_output: + cleaned_results = clean_temporal_result_fields(results) + else: + cleaned_results = results + + return json.dumps(cleaned_results, ensure_ascii=False, indent=2) + + return asyncio.run(_async_search()) + + @tool + def KeywordTimeRetrieval(context: str, days_back: int = 7, start_date: str = None, end_date: str = None, clean_output: bool = True) -> str: + """ + 优化的关键词时间检索工具,结合关键词和时间范围搜索(同步版本),自动过滤不需要的元数据字段 + 显式接收参数: + - context: 查询内容 + - days_back: 向前搜索的天数,默认7天 + - start_date: 开始时间(可选,格式:YYYY-MM-DD) + - end_date: 结束时间(可选,格式:YYYY-MM-DD) + - clean_output: 是否清理输出中的元数据字段 + - end_date 需要根据用户的描述获取结束的时间,输出格式用strftime("%Y-%m-%d") + """ + async def _async_search(): + actual_end_date = end_date or datetime.now().strftime("%Y-%m-%d") + actual_start_date = start_date or (datetime.now() - timedelta(days=days_back)).strftime("%Y-%m-%d") + + # 关键词时间搜索 + results = await search_by_keyword_temporal( + query_text=context, + group_id=group_id, + start_date=actual_start_date, + end_date=actual_end_date, + limit=15 + ) + + # 清理结果中不需要的字段 + if clean_output: + cleaned_results = clean_temporal_result_fields(results) + else: + cleaned_results = results + + return json.dumps(cleaned_results, ensure_ascii=False, indent=2) + + return asyncio.run(_async_search()) + + return TimeRetrievalWithGroupId + + +def create_hybrid_retrieval_tool_async(memory_config, **search_params): + """ + 创建混合检索工具,使用run_hybrid_search进行混合检索,优化输出格式并过滤不需要的字段 + + Args: + memory_config: 内存配置对象 + **search_params: 搜索参数,包含group_id, limit, include等 + """ + + def clean_result_fields(data): + """ + 递归清理结果中不需要的字段 + + Args: + data: 要清理的数据(可能是字典、列表或其他类型) + + Returns: + 清理后的数据 + """ + # 需要过滤的字段列表 + fields_to_remove = { + 'invalid_at', 'valid_at', 'chunk_id_from_rel', 'entity_ids', + 'expired_at', 'created_at', 'chunk_id', 'id', 'apply_id', + 'user_id', 'statement_ids', 'updated_at',"chunk_ids","fact_summary" + } + + if isinstance(data, dict): + # 对字典进行清理 + cleaned = {} + for key, value in data.items(): + if key not in fields_to_remove: + cleaned[key] = clean_result_fields(value) # 递归清理嵌套数据 + return cleaned + elif isinstance(data, list): + # 对列表中的每个元素进行清理 + return [clean_result_fields(item) for item in data] + else: + # 其他类型直接返回 + return data + + @tool + async def HybridSearch( + context: str, + search_type: str = "hybrid", + limit: int = 10, + group_id: str = None, + rerank_alpha: float = 0.6, + use_forgetting_rerank: bool = False, + use_llm_rerank: bool = False, + clean_output: bool = True # 新增:是否清理输出字段 + ) -> str: + """ + 优化的混合检索工具,支持关键词、向量和混合搜索,自动过滤不需要的元数据字段 + + Args: + context: 查询内容 + search_type: 搜索类型 ('keyword', 'embedding', 'hybrid') + limit: 结果数量限制 + group_id: 组ID,用于过滤搜索结果 + rerank_alpha: 重排序权重参数 + use_forgetting_rerank: 是否使用遗忘重排序 + use_llm_rerank: 是否使用LLM重排序 + clean_output: 是否清理输出中的元数据字段 + """ + try: + # 导入run_hybrid_search函数 + from app.core.memory.src.search import run_hybrid_search + + # 合并参数,优先使用传入的参数 + final_params = { + "query_text": context, + "search_type": search_type, + "group_id": group_id or search_params.get("group_id"), + "limit": limit or search_params.get("limit", 10), + "include": search_params.get("include", ["summaries", "statements", "chunks", "entities"]), + "output_path": None, # 不保存到文件 + "memory_config": memory_config, + "rerank_alpha": rerank_alpha, + "use_forgetting_rerank": use_forgetting_rerank, + "use_llm_rerank": use_llm_rerank + } + + # 执行混合检索 + raw_results = await run_hybrid_search(**final_params) + + # 清理结果中不需要的字段 + if clean_output: + cleaned_results = clean_result_fields(raw_results) + else: + cleaned_results = raw_results + + # 格式化返回结果 + formatted_results = { + "search_query": context, + "search_type": search_type, + "results": cleaned_results + } + + return json.dumps(formatted_results, ensure_ascii=False, indent=2, default=str) + + except Exception as e: + error_result = { + "error": f"混合检索失败: {str(e)}", + "search_query": context, + "search_type": search_type, + "timestamp": datetime.now().isoformat() + } + return json.dumps(error_result, ensure_ascii=False, indent=2) + + return HybridSearch + + +def create_hybrid_retrieval_tool_sync(memory_config, **search_params): + """ + 创建同步版本的混合检索工具,优化输出格式并过滤不需要的字段 + + Args: + memory_config: 内存配置对象 + **search_params: 搜索参数 + """ + @tool + def HybridSearchSync( + context: str, + search_type: str = "hybrid", + limit: int = 10, + group_id: str = None, + clean_output: bool = True + ) -> str: + """ + 优化的混合检索工具(同步版本),自动过滤不需要的元数据字段 + + Args: + context: 查询内容 + search_type: 搜索类型 ('keyword', 'embedding', 'hybrid') + limit: 结果数量限制 + group_id: 组ID,用于过滤搜索结果 + clean_output: 是否清理输出中的元数据字段 + """ + async def _async_search(): + # 创建异步工具并执行 + async_tool = create_hybrid_retrieval_tool_async(memory_config, **search_params) + return await async_tool.ainvoke({ + "context": context, + "search_type": search_type, + "limit": limit, + "group_id": group_id, + "clean_output": clean_output + }) + + return asyncio.run(_async_search()) + + return HybridSearchSync \ No newline at end of file diff --git a/api/app/core/memory/agent/langgraph_graph/write_graph.py b/api/app/core/memory/agent/langgraph_graph/write_graph.py index ae333e84..5a6f1e28 100644 --- a/api/app/core/memory/agent/langgraph_graph/write_graph.py +++ b/api/app/core/memory/agent/langgraph_graph/write_graph.py @@ -1,30 +1,32 @@ + import asyncio -import json import sys import warnings from contextlib import asynccontextmanager -from app.core.logging_config import get_agent_logger -from app.core.memory.agent.utils.llm_tools import WriteState -from app.schemas.memory_config_schema import MemoryConfig -from langchain_core.messages import AIMessage + +from langchain_core.messages import HumanMessage from langgraph.constants import END, START from langgraph.graph import StateGraph -from langgraph.prebuilt import ToolNode + + +from app.db import get_db +from app.core.logging_config import get_agent_logger +from app.core.memory.agent.utils.llm_tools import WriteState +from app.core.memory.agent.langgraph_graph.nodes.write_nodes import write_node +from app.core.memory.agent.langgraph_graph.nodes.data_nodes import content_input_write +from app.services.memory_config_service import MemoryConfigService warnings.filterwarnings("ignore", category=RuntimeWarning) - logger = get_agent_logger(__name__) if sys.platform.startswith("win"): asyncio.set_event_loop_policy(asyncio.WindowsSelectorEventLoopPolicy()) - - @asynccontextmanager -async def make_write_graph(user_id, tools, apply_id, group_id, memory_config: MemoryConfig): +async def make_write_graph(): """ Create a write graph workflow for memory operations. - + Args: user_id: User identifier tools: MCP tools loaded from session @@ -32,43 +34,8 @@ async def make_write_graph(user_id, tools, apply_id, group_id, memory_config: Me group_id: Group identifier memory_config: MemoryConfig object containing all configuration """ - logger.info("Loading MCP tools: %s", [t.name for t in tools]) - logger.info(f"Using memory_config: {memory_config.config_name} (id={memory_config.config_id})") - - data_write_tool = next((t for t in tools if t.name == "Data_write"), None) - - if not data_write_tool: - logger.error("Data_write tool not found", exc_info=True) - raise ValueError("Data_write tool not found") - - write_node = ToolNode([data_write_tool]) - - async def call_model(state): - messages = state["messages"] - last_message = messages[-1] - content = last_message[1] if isinstance(last_message, tuple) else last_message.content - - # Call Data_write directly with memory_config - write_params = { - "content": content, - "apply_id": apply_id, - "group_id": group_id, - "user_id": user_id, - "memory_config": memory_config, - } - logger.debug(f"Passing memory_config to Data_write: {memory_config.config_id}") - - write_result = await data_write_tool.ainvoke(write_params) - - if isinstance(write_result, dict): - result_content = write_result.get("data", str(write_result)) - else: - result_content = str(write_result) - logger.info("Write content: %s", result_content) - return {"messages": [AIMessage(content=result_content)]} - workflow = StateGraph(WriteState) - workflow.add_node("content_input", call_model) + workflow.add_node("content_input", content_input_write) workflow.add_node("save_neo4j", write_node) workflow.add_edge(START, "content_input") workflow.add_edge("content_input", "save_neo4j") @@ -76,5 +43,45 @@ async def make_write_graph(user_id, tools, apply_id, group_id, memory_config: Me graph = workflow.compile() - yield graph + + +async def main(): + """主函数 - 运行工作流""" + message = "今天周一" + group_id = 'new_2025test1103' # 组ID + + + # 获取数据库会话 + db_session = next(get_db()) + config_service = MemoryConfigService(db_session) + memory_config = config_service.load_memory_config( + config_id=17, # 改为整数 + service_name="MemoryAgentService" + ) + try: + async with make_write_graph() as graph: + config = {"configurable": {"thread_id": group_id}} + # 初始状态 - 包含所有必要字段 + initial_state = {"messages": [HumanMessage(content=message)], "group_id": group_id, "memory_config": memory_config} + + # 获取节点更新信息 + async for update_event in graph.astream( + initial_state, + stream_mode="updates", + config=config + ): + for node_name, node_data in update_event.items(): + if 'save_neo4j'==node_name: + massages=node_data + massages=massages.get('write_result')['status'] + print(massages) # | 更新数据: {node_data} + + except Exception as e: + import traceback + traceback.print_exc() + + +if __name__ == "__main__": + import asyncio + asyncio.run(main()) \ No newline at end of file diff --git a/api/app/core/memory/agent/mcp_server/__init__.py b/api/app/core/memory/agent/mcp_server/__init__.py deleted file mode 100644 index efd03773..00000000 --- a/api/app/core/memory/agent/mcp_server/__init__.py +++ /dev/null @@ -1,28 +0,0 @@ -""" -MCP Server package for memory agent. - -This package provides the FastMCP server implementation with context-based -dependency injection for tool functions. - -Package structure: -- server: FastMCP server initialization and context setup -- tools: MCP tool implementations -- models: Pydantic response models -- services: Business logic services -""" -# from app.core.memory.agent.mcp_server.server import ( -# mcp, -# initialize_context, -# main, -# get_context_resource -# ) - -# # Import tools to register them (but don't export them) -# from app.core.memory.agent.mcp_server import tools - -# __all__ = [ -# 'mcp', -# 'initialize_context', -# 'main', -# 'get_context_resource', -# ] \ No newline at end of file diff --git a/api/app/core/memory/agent/mcp_server/mcp_instance.py b/api/app/core/memory/agent/mcp_server/mcp_instance.py deleted file mode 100644 index 3a2eeb78..00000000 --- a/api/app/core/memory/agent/mcp_server/mcp_instance.py +++ /dev/null @@ -1,11 +0,0 @@ -""" -MCP Server Instance - -This module contains the FastMCP server instance that is shared across all modules. -It's in a separate file to avoid circular import issues. -""" -from mcp.server.fastmcp import FastMCP - -# Initialize FastMCP server instance -# This instance is shared across all tool modules -mcp = FastMCP('data_flow') diff --git a/api/app/core/memory/agent/mcp_server/server.py b/api/app/core/memory/agent/mcp_server/server.py deleted file mode 100644 index 26f24824..00000000 --- a/api/app/core/memory/agent/mcp_server/server.py +++ /dev/null @@ -1,159 +0,0 @@ -""" -MCP Server initialization with FastMCP context setup. - -This module initializes the FastMCP server and registers shared resources -in the context for dependency injection into tool functions. -""" -import os -import sys - -from app.core.config import settings -from app.core.logging_config import get_agent_logger -from app.core.memory.agent.mcp_server.mcp_instance import mcp -from app.core.memory.agent.mcp_server.services.search_service import SearchService -from app.core.memory.agent.mcp_server.services.session_service import SessionService -from app.core.memory.agent.mcp_server.services.template_service import TemplateService -from app.core.memory.agent.utils.llm_tools import PROJECT_ROOT_ -from app.core.memory.agent.utils.redis_tool import store - -logger = get_agent_logger(__name__) - - -def get_context_resource(ctx, resource_name: str): - """ - Helper function to retrieve a resource from the FastMCP context. - - Args: - ctx: FastMCP Context object (passed to tool functions) - resource_name: Name of the resource to retrieve - - Returns: - The requested resource - - Raises: - AttributeError: If the resource doesn't exist - - Example: - @mcp.tool() - async def my_tool(ctx: Context): - template_service = get_context_resource(ctx, 'template_service') - llm_client = get_context_resource(ctx, 'llm_client') - """ - if not hasattr(ctx, 'fastmcp') or ctx.fastmcp is None: - raise RuntimeError("Context does not have fastmcp attribute") - - if not hasattr(ctx.fastmcp, resource_name): - raise AttributeError( - f"Resource '{resource_name}' not found in context. " - f"Available resources: {[k for k in dir(ctx.fastmcp) if not k.startswith('_')]}" - ) - - return getattr(ctx.fastmcp, resource_name) - - -def initialize_context(): - """ - Initialize and register shared resources in FastMCP context. - - This function sets up all shared resources that will be available - to tool functions via dependency injection through the context parameter. - - Resources are stored as attributes on the FastMCP instance and can be - accessed via ctx.fastmcp in tool functions. - - Resources registered: - - session_store: RedisSessionStore for session management - - llm_client: LLM client for structured API calls - - app_settings: Application settings (renamed to avoid conflict with FastMCP settings) - - template_service: Service for template rendering - - search_service: Service for hybrid search - - session_service: Service for session operations - """ - try: - # Register Redis session store - logger.info("Registering session_store in context") - mcp.session_store = store - - # Note: LLM client is NOT loaded at server startup - # It should be loaded dynamically when needed, with config_id passed explicitly - # to make_write_graph or make_read_graph functions - logger.info("LLM client will be loaded dynamically with config_id when needed") - mcp.llm_client = None # Placeholder - actual client loaded per-request with config_id - - # Register application settings (renamed to avoid conflict with FastMCP's settings) - logger.info("Registering app_settings in context") - mcp.app_settings = settings - - # Register template service - template_root = PROJECT_ROOT_ + '/agent/utils/prompt' - # logger.info(f"Registering template_service in context with root: {template_root}") - template_service = TemplateService(template_root) - mcp.template_service = template_service - - # Register search service - # logger.info("Registering search_service in context") - search_service = SearchService() - mcp.search_service = search_service - - # Register session service - # logger.info("Registering session_service in context") - session_service = SessionService(store) - mcp.session_service = session_service - - # logger.info("All context resources registered successfully") - - except Exception as e: - logger.error(f"Failed to initialize context: {e}", exc_info=True) - raise - - -def main(): - """ - Main entry point for the MCP server. - - Initializes context and starts the server with SSE transport. - """ - try: - logger.info("Starting MCP server initialization") - # Initialize context resources - initialize_context() - - # Import and register tools (imports trigger tool registration) - from app.core.memory.agent.mcp_server.tools import ( # noqa: F401 - data_tools, - problem_tools, - retrieval_tools, - summary_tools, - verification_tools, - ) - - # Tools are registered via imports above - - # Get MCP port from environment (default: 8081) - mcp_port = int(os.getenv("MCP_PORT", "8081")) - logger.info(f"Starting MCP server on {settings.SERVER_IP}:{mcp_port} with SSE transport") - - # Configure DNS rebinding protection for Docker container compatibility - from mcp.server.fastmcp.server import TransportSecuritySettings - - # Disable DNS rebinding protection to allow Docker container hostnames - # This allows containers to connect using service names like 'mcp-server' - mcp.settings.transport_security = TransportSecuritySettings( - enable_dns_rebinding_protection=False, - ) - logger.info("DNS rebinding protection: disabled for Docker container compatibility") - - # logger.info(f"Starting MCP server on {settings.SERVER_IP}:{mcp_port} with SSE transport") - - # Run the server with SSE transport for HTTP connections - import uvicorn - app = mcp.sse_app() - uvicorn.run(app, host=settings.SERVER_IP, port=mcp_port, log_level="info") - - except Exception as e: - logger.error(f"Failed to start MCP server: {e}", exc_info=True) - sys.exit(1) - - -if __name__ == "__main__": - main() diff --git a/api/app/core/memory/agent/mcp_server/tools/__init__.py b/api/app/core/memory/agent/mcp_server/tools/__init__.py deleted file mode 100644 index 5ce04ef3..00000000 --- a/api/app/core/memory/agent/mcp_server/tools/__init__.py +++ /dev/null @@ -1,27 +0,0 @@ -""" -MCP Tools module. - -This module contains all MCP tool implementations organized by functionality. - -Tools are organized into the following modules: -- problem_tools: Question segmentation and extension -- retrieval_tools: Database and context retrieval -- verification_tools: Data verification -- summary_tools: Summarization and summary retrieval -- data_tools: Data type differentiation and writing -""" - -# Import all tool modules to register them with the MCP server -from . import problem_tools -from . import retrieval_tools -from . import verification_tools -from . import summary_tools -from . import data_tools - -__all__ = [ - 'problem_tools', - 'retrieval_tools', - 'verification_tools', - 'summary_tools', - 'data_tools', -] diff --git a/api/app/core/memory/agent/mcp_server/tools/data_tools.py b/api/app/core/memory/agent/mcp_server/tools/data_tools.py deleted file mode 100644 index 631f7fd7..00000000 --- a/api/app/core/memory/agent/mcp_server/tools/data_tools.py +++ /dev/null @@ -1,155 +0,0 @@ -""" -Data Tools for data type differentiation and writing. - -This module contains MCP tools for distinguishing data types and writing data. -""" - -import os - -from app.core.logging_config import get_agent_logger -from app.core.memory.agent.mcp_server.mcp_instance import mcp -from app.core.memory.agent.mcp_server.models.retrieval_models import ( - DistinguishTypeResponse, -) -from app.core.memory.agent.mcp_server.server import get_context_resource -from app.core.memory.agent.utils.write_tools import write -from app.core.memory.utils.llm.llm_utils import MemoryClientFactory -from app.db import get_db_context -from app.schemas.memory_config_schema import MemoryConfig -from mcp.server.fastmcp import Context - -logger = get_agent_logger(__name__) - - -@mcp.tool() -async def Data_type_differentiation( - ctx: Context, - context: str, - memory_config: MemoryConfig, -) -> dict: - """ - Distinguish the type of data (read or write). - - Args: - ctx: FastMCP context for dependency injection - context: Text to analyze for type differentiation - memory_config: MemoryConfig object containing LLM configuration - - Returns: - dict: Contains 'context' with the original text and 'type' field - """ - try: - # Extract services from context - template_service = get_context_resource(ctx, 'template_service') - - # Get LLM client from memory_config using factory pattern - with get_db_context() as db: - factory = MemoryClientFactory(db) - llm_client = factory.get_llm_client_from_config(memory_config) - - # Render template - try: - system_prompt = await template_service.render_template( - template_name='distinguish_types_prompt.jinja2', - operation_name='status_typle', - user_query=context - ) - except Exception as e: - logger.error( - f"Template rendering failed for Data_type_differentiation: {e}", - exc_info=True - ) - return { - "type": "error", - "message": f"Prompt rendering failed: {str(e)}" - } - - # Call LLM with structured response - try: - structured = await llm_client.response_structured( - messages=[{"role": "system", "content": system_prompt}], - response_model=DistinguishTypeResponse - ) - - result = structured.model_dump() - - # Add context to result - result["context"] = context - - return result - - except Exception as e: - logger.error( - f"LLM call failed for Data_type_differentiation: {e}", - exc_info=True - ) - return { - "context": context, - "type": "error", - "message": f"LLM call failed: {str(e)}" - } - - except Exception as e: - logger.error( - f"Data_type_differentiation failed: {e}", - exc_info=True - ) - return { - "context": context, - "type": "error", - "message": str(e) - } - - -@mcp.tool() -async def Data_write( - ctx: Context, - content: str, - user_id: str, - apply_id: str, - group_id: str, - memory_config: MemoryConfig, -) -> dict: - """ - Write data to the database/file system. - - Args: - ctx: FastMCP context for dependency injection - content: Data content to write - user_id: User identifier - apply_id: Application identifier - group_id: Group identifier - memory_config: MemoryConfig object containing all configuration - - Returns: - dict: Contains 'status', 'saved_to', and 'data' fields - """ - try: - # Ensure output directory exists - os.makedirs("data_output", exist_ok=True) - file_path = os.path.join("data_output", "user_data.csv") - - # Write data - clients are constructed inside write() from memory_config - await write( - content=content, - user_id=user_id, - apply_id=apply_id, - group_id=group_id, - memory_config=memory_config, - ) - logger.info(f"Write completed successfully! Config: {memory_config.config_name}") - - return { - "status": "success", - "saved_to": file_path, - "data": content, - "config_id": memory_config.config_id, - "config_name": memory_config.config_name, - } - - except Exception as e: - logger.error(f"Data_write failed: {e}", exc_info=True) - return { - "status": "error", - "message": str(e), - } diff --git a/api/app/core/memory/agent/mcp_server/tools/problem_tools.py b/api/app/core/memory/agent/mcp_server/tools/problem_tools.py deleted file mode 100644 index 49812e38..00000000 --- a/api/app/core/memory/agent/mcp_server/tools/problem_tools.py +++ /dev/null @@ -1,304 +0,0 @@ -""" -Problem Tools for question segmentation and extension. - -This module contains MCP tools for breaking down and extending user questions. -LLM clients are constructed from MemoryConfig when needed. -""" - -import json -import time - -from app.core.logging_config import get_agent_logger, log_time -from app.core.memory.agent.mcp_server.mcp_instance import mcp -from app.core.memory.agent.mcp_server.models.problem_models import ( - ProblemBreakdownResponse, - ProblemExtensionResponse, -) -from app.core.memory.agent.mcp_server.server import get_context_resource -from app.core.memory.agent.utils.messages_tool import Problem_Extension_messages_deal -from app.core.memory.utils.llm.llm_utils import MemoryClientFactory -from app.db import get_db_context -from app.schemas.memory_config_schema import MemoryConfig -from mcp.server.fastmcp import Context - -logger = get_agent_logger(__name__) - - -@mcp.tool() -async def Split_The_Problem( - ctx: Context, - sentence: str, - sessionid: str, - messages_id: str, - apply_id: str, - group_id: str, - memory_config: MemoryConfig, -) -> dict: - """ - Segment the dialogue or sentence into sub-problems. - - Args: - ctx: FastMCP context for dependency injection - sentence: Original sentence to split - sessionid: Session identifier - messages_id: Message identifier - apply_id: Application identifier - group_id: Group identifier - memory_config: MemoryConfig object containing all configuration - - Returns: - dict: Contains 'context' (JSON string of split results) and 'original' sentence - """ - start = time.time() - - try: - # Extract services from context - template_service = get_context_resource(ctx, "template_service") - session_service = get_context_resource(ctx, "session_service") - - # Get LLM client from memory_config - with get_db_context() as db: - factory = MemoryClientFactory(db) - llm_client = factory.get_llm_client_from_config(memory_config) - - # Extract user ID from session - user_id = session_service.resolve_user_id(sessionid) - - # Get conversation history - history = await session_service.get_history(user_id, apply_id, group_id) - # Override with empty list for now (as in original) - history = [] - - # Render template - try: - system_prompt = await template_service.render_template( - template_name='problem_breakdown_prompt.jinja2', - operation_name='split_the_problem', - history=history, - sentence=sentence - ) - except Exception as e: - logger.error( - f"Template rendering failed for Split_The_Problem: {e}", - exc_info=True - ) - return { - "context": json.dumps([], ensure_ascii=False), - "original": sentence, - "error": f"Prompt rendering failed: {str(e)}" - } - - # Call LLM with structured response - try: - structured = await llm_client.response_structured( - messages=[{"role": "system", "content": system_prompt}], - response_model=ProblemBreakdownResponse - ) - - # Handle RootModel response with .root attribute access - if structured is None: - # LLM returned None, use empty list as fallback - split_result = json.dumps([], ensure_ascii=False) - elif hasattr(structured, 'root') and structured.root is not None: - split_result = json.dumps( - [item.model_dump() for item in structured.root], - ensure_ascii=False - ) - elif isinstance(structured, list): - # Fallback: treat structured itself as the list - split_result = json.dumps( - [item.model_dump() for item in structured], - ensure_ascii=False - ) - else: - # Last resort: use empty list - split_result = json.dumps([], ensure_ascii=False) - - except Exception as e: - logger.error( - f"LLM call failed for Split_The_Problem: {e}", - exc_info=True - ) - split_result = json.dumps([], ensure_ascii=False) - - logger.info("Problem splitting") - logger.info(f"Problem split result: {split_result}") - - # Emit intermediate output for frontend - result = { - "context": split_result, - "original": sentence, - "_intermediate": { - "type": "problem_split", - "data": json.loads(split_result) if split_result else [], - "original_query": sentence - } - } - - return result - - except Exception as e: - logger.error( - f"Split_The_Problem failed: {e}", - exc_info=True - ) - return { - "context": json.dumps([], ensure_ascii=False), - "original": sentence, - "error": str(e) - } - - finally: - # Log execution time - end = time.time() - try: - duration = end - start - except Exception: - duration = 0.0 - log_time('Problem splitting', duration) - - -@mcp.tool() -async def Problem_Extension( - ctx: Context, - context: dict, - usermessages: str, - apply_id: str, - group_id: str, - memory_config: MemoryConfig, - storage_type: str = "", - user_rag_memory_id: str = "", -) -> dict: - """ - Extend the problem with additional sub-questions. - - Args: - ctx: FastMCP context for dependency injection - context: Dictionary containing split problem results - usermessages: User messages identifier - apply_id: Application identifier - group_id: Group identifier - memory_config: MemoryConfig object containing all configuration - storage_type: Storage type for the workspace (optional) - user_rag_memory_id: User RAG memory identifier (optional) - - Returns: - dict: Contains 'context' (aggregated questions) and 'original' question - """ - start = time.time() - - try: - # Extract services from context - template_service = get_context_resource(ctx, "template_service") - session_service = get_context_resource(ctx, "session_service") - - # Get LLM client from memory_config - with get_db_context() as db: - factory = MemoryClientFactory(db) - llm_client = factory.get_llm_client_from_config(memory_config) - - # Resolve session ID from usermessages - from app.core.memory.agent.utils.messages_tool import Resolve_username - sessionid = Resolve_username(usermessages) - - # Get conversation history - history = await session_service.get_history(sessionid, apply_id, group_id) - # Override with empty list for now (as in original) - history = [] - - # Process context to extract questions - extent_quest, original = await Problem_Extension_messages_deal(context) - - # Format questions for template rendering - questions_formatted = [] - for msg in extent_quest: - if msg.get("role") == "user": - questions_formatted.append(msg.get("content", "")) - - # Render template - try: - system_prompt = await template_service.render_template( - template_name='Problem_Extension_prompt.jinja2', - operation_name='problem_extension', - history=history, - questions=questions_formatted - ) - except Exception as e: - logger.error( - f"Template rendering failed for Problem_Extension: {e}", - exc_info=True - ) - return { - "context": {}, - "original": original, - "error": f"Prompt rendering failed: {str(e)}" - } - - # Call LLM with structured response - try: - response_content = await llm_client.response_structured( - messages=[{"role": "system", "content": system_prompt}], - response_model=ProblemExtensionResponse - ) - - # Aggregate results by original question - aggregated_dict = {} - for item in response_content.root: - key = getattr(item, "original_question", None) or ( - item.get("original_question") if isinstance(item, dict) else None - ) - value = getattr(item, "extended_question", None) or ( - item.get("extended_question") if isinstance(item, dict) else None - ) - if not key or not value: - continue - aggregated_dict.setdefault(key, []).append(value) - - except Exception as e: - logger.error( - f"LLM call failed for Problem_Extension: {e}", - exc_info=True - ) - aggregated_dict = {} - - logger.info("Problem extension") - logger.info(f"Problem extension result: {aggregated_dict}") - - # Emit intermediate output for frontend - result = { - "context": aggregated_dict, - "original": original, - "storage_type": storage_type, - "user_rag_memory_id": user_rag_memory_id, - "_intermediate": { - "type": "problem_extension", - "data": aggregated_dict, - "original_query": original, - "storage_type": storage_type, - "user_rag_memory_id": user_rag_memory_id - } - } - - return result - - except Exception as e: - logger.error( - f"Problem_Extension failed: {e}", - exc_info=True - ) - return { - "context": {}, - "original": context.get("original", ""), - "storage_type": storage_type, - "user_rag_memory_id": user_rag_memory_id, - "error": str(e) - } - - finally: - # Log execution time - end = time.time() - try: - duration = end - start - except Exception: - duration = 0.0 - log_time('Problem extension', duration) diff --git a/api/app/core/memory/agent/mcp_server/tools/retrieval_tools.py b/api/app/core/memory/agent/mcp_server/tools/retrieval_tools.py deleted file mode 100644 index db18ba04..00000000 --- a/api/app/core/memory/agent/mcp_server/tools/retrieval_tools.py +++ /dev/null @@ -1,294 +0,0 @@ -""" -Retrieval Tools for database and context retrieval. - -This module contains MCP tools for retrieving data using hybrid search. -""" - -import os -import time - -from app.core.logging_config import get_agent_logger, log_time -from app.core.memory.agent.mcp_server.mcp_instance import mcp -from app.core.memory.agent.mcp_server.server import get_context_resource -from app.core.memory.agent.utils.llm_tools import ( - deduplicate_entries, - merge_to_key_value_pairs, -) -from app.core.memory.agent.utils.messages_tool import Retriev_messages_deal -from app.core.rag.nlp.search import knowledge_retrieval -from app.schemas.memory_config_schema import MemoryConfig -from dotenv import load_dotenv -from mcp.server.fastmcp import Context - -load_dotenv() -logger = get_agent_logger(__name__) - - -@mcp.tool() -async def Retrieve( - ctx: Context, - context, - usermessages: str, - apply_id: str, - group_id: str, - memory_config: MemoryConfig, - storage_type: str = "", - user_rag_memory_id: str = "", -) -> dict: - """ - Retrieve data from the database using hybrid search. - - Args: - ctx: FastMCP context for dependency injection - context: Dictionary or string containing query information - usermessages: User messages identifier - apply_id: Application identifier - group_id: Group identifier - memory_config: MemoryConfig object containing all configuration - storage_type: Storage type for the workspace (e.g., 'rag', 'vector') - user_rag_memory_id: User RAG memory identifier - - Returns: - dict: Contains 'context' with Query and Expansion_issue results - """ - kb_config = { - "knowledge_bases": [ - { - "kb_id": user_rag_memory_id, - "similarity_threshold": 0.7, - "vector_similarity_weight": 0.5, - "top_k": 10, - "retrieve_type": "participle" - } - ], - "merge_strategy": "weight", - "reranker_id": os.getenv('reranker_id'), - "reranker_top_k": 10 - } - start = time.time() - logger.info(f"Retrieve: storage_type={storage_type}, user_rag_memory_id={user_rag_memory_id}") - logger.info(f"Retrieve: context type={type(context)}, context={str(context)[:500]}") - - try: - # Extract services from context - search_service = get_context_resource(ctx, 'search_service') - - databases_anser = [] - - # Handle both dict and string context - if isinstance(context, dict): - # Process dict context with extended questions - all_items = [] - logger.info(f"Retrieve: context keys={list(context.keys())}") - content, original = await Retriev_messages_deal(context) - logger.info(f"Retrieve: after Retriev_messages_deal - content_type={type(content)}, content={str(content)[:300]}") - logger.info(f"Retrieve: original='{original[:100] if original else 'EMPTY'}'") - - if not original: - logger.warning(f"Retrieve: original query is empty! context={context}") - - # Extract all query items from content - # content is like {original_question: [extended_questions...], ...} - for key, values in content.items(): - if isinstance(values, list): - all_items.extend(values) - elif isinstance(values, str): - all_items.append(values) - elif values is not None: - # Fallback: convert non-empty non-list values to string - all_items.append(str(values)) - - # Execute search for each question - for idx, question in enumerate(all_items): - try: - # Prepare search parameters based on storage type - search_params = { - "group_id": group_id, - "question": question, - "return_raw_results": True - } - - # Add storage-specific parameters - if storage_type == "rag" and user_rag_memory_id: - retrieve_chunks_result = knowledge_retrieval(question, kb_config,[str(group_id)]) - try: - retrieval_knowledge = [i.page_content for i in retrieve_chunks_result] - clean_content = '\n\n'.join(retrieval_knowledge) - cleaned_query=question - raw_results=clean_content - logger.info(f" Using RAG storage with memory_id={user_rag_memory_id}") - except: - clean_content = '' - raw_results='' - cleaned_query = question - logger.info(f"No content retrieved from knowledge base: {user_rag_memory_id}") - else: - clean_content, cleaned_query, raw_results = await search_service.execute_hybrid_search( - **search_params, memory_config=memory_config - ) - - databases_anser.append({ - "Query_small": cleaned_query, - "Result_small": clean_content, - "_intermediate": { - "type": "search_result", - "query": cleaned_query, - "raw_results": raw_results, - "index": idx + 1, - "total": len(all_items) - } - }) - except Exception as e: - logger.error( - f"Retrieve: hybrid_search failed for question '{question}': {e}", - exc_info=True - ) - # Continue with empty result for this question - databases_anser.append({ - "Query_small": question, - "Result_small": "" - }) - - # Build initial database data structure - databases_data = { - "Query": original, - "Expansion_issue": databases_anser - } - - # Collect intermediate outputs before deduplication - intermediate_outputs = [] - for item in databases_anser: - if '_intermediate' in item: - intermediate_outputs.append(item['_intermediate']) - - # Deduplicate and merge results - deduplicated_data = deduplicate_entries(databases_data['Expansion_issue']) - deduplicated_data_merged = merge_to_key_value_pairs( - deduplicated_data, - 'Query_small', - 'Result_small' - ) - - # Restructure for Verify/Retrieve_Summary compatibility - keys, val = [], [] - for item in deduplicated_data_merged: - for items_key, items_value in item.items(): - keys.append(items_key) - val.append(items_value) - - send_verify = [] - for i, j in zip(keys, val, strict=False): - send_verify.append({ - "Query_small": i, - "Answer_Small": j - }) - - dup_databases = { - "Query": original, - "Expansion_issue": send_verify, - "_intermediate_outputs": intermediate_outputs # Preserve intermediate outputs - } - - logger.info(f"Collected {len(intermediate_outputs)} intermediate outputs from search results") - - else: - # Handle string context (simple query) - query = str(context).strip() - - try: - # Prepare search parameters based on storage type - search_params = { - "group_id": group_id, - "question": query, - "return_raw_results": True - } - - # Add storage-specific parameters - if storage_type == "rag" and user_rag_memory_id: - retrieve_chunks_result = knowledge_retrieval(query, kb_config,[str(group_id)]) - try: - retrieval_knowledge = [i.page_content for i in retrieve_chunks_result] - clean_content = '\n\n'.join(retrieval_knowledge) - cleaned_query = query - raw_results = clean_content - logger.info(f" Using RAG storage with memory_id={user_rag_memory_id}") - except: - clean_content = '' - raw_results = '' - cleaned_query = query - logger.info(f"No content retrieved from knowledge base: {user_rag_memory_id}") - else: - clean_content, cleaned_query, raw_results = await search_service.execute_hybrid_search( - **search_params, memory_config=memory_config - ) - # Keep structure for Verify/Retrieve_Summary compatibility - dup_databases = { - "Query": cleaned_query, - "Expansion_issue": [{ - "Query_small": cleaned_query, - "Answer_Small": clean_content, - "_intermediate": { - "type": "search_result", - "query": cleaned_query, - "raw_results": raw_results, - "index": 1, - "total": 1 - } - }] - } - except Exception as e: - logger.error( - f"Retrieve: hybrid_search failed for query '{query}': {e}", - exc_info=True - ) - # Return empty results on failure - dup_databases = { - "Query": query, - "Expansion_issue": [] - } - - logger.info( - f"Retrieval: {storage_type}--{user_rag_memory_id}--Query={dup_databases.get('Query', '')}, " - f"Expansion_issue count={len(dup_databases.get('Expansion_issue', []))}" - ) - - # Build result with intermediate outputs - result = { - "context": dup_databases, - "storage_type": storage_type, - "user_rag_memory_id": user_rag_memory_id - } - - # Add intermediate outputs list if they exist - intermediate_outputs = dup_databases.get('_intermediate_outputs', []) - if intermediate_outputs: - result['_intermediates'] = intermediate_outputs - logger.info(f"Adding {len(intermediate_outputs)} intermediate outputs to result") - else: - logger.warning("No intermediate outputs found in dup_databases") - - return result - - except Exception as e: - logger.error( - f"Retrieve failed: {e}", - exc_info=True - ) - return { - "context": { - "Query": "", - "Expansion_issue": [] - }, - "storage_type": storage_type, - "user_rag_memory_id": user_rag_memory_id, - "error": str(e) - } - - finally: - # Log execution time - end = time.time() - try: - duration = end - start - except Exception: - duration = 0.0 - log_time('Retrieval', duration) diff --git a/api/app/core/memory/agent/mcp_server/tools/summary_tools.py b/api/app/core/memory/agent/mcp_server/tools/summary_tools.py deleted file mode 100644 index 0f306572..00000000 --- a/api/app/core/memory/agent/mcp_server/tools/summary_tools.py +++ /dev/null @@ -1,640 +0,0 @@ -""" -Summary Tools for data summarization. - -This module contains MCP tools for summarizing retrieved data and generating responses. -LLM clients are constructed from MemoryConfig when needed. -""" - -import json -import os -import re -import time - -from app.core.logging_config import get_agent_logger, log_time -from app.core.memory.agent.mcp_server.mcp_instance import mcp -from app.core.memory.agent.mcp_server.models.summary_models import ( - RetrieveSummaryResponse, - SummaryResponse, -) -from app.core.memory.agent.mcp_server.server import get_context_resource -from app.core.memory.agent.utils.messages_tool import ( - Resolve_username, - Summary_messages_deal, -) -from app.core.memory.utils.llm.llm_utils import MemoryClientFactory -from app.core.rag.nlp.search import knowledge_retrieval -from app.db import get_db_context -from app.schemas.memory_config_schema import MemoryConfig -from dotenv import load_dotenv -from mcp.server.fastmcp import Context - -load_dotenv() -logger = get_agent_logger(__name__) - - -@mcp.tool() -async def Summary( - ctx: Context, - context: str, - usermessages: str, - apply_id: str, - group_id: str, - memory_config: MemoryConfig, - storage_type: str = "", - user_rag_memory_id: str = "", -) -> dict: - """ - Summarize the verified data. - - Args: - ctx: FastMCP context for dependency injection - context: JSON string containing verified data - usermessages: User messages identifier - apply_id: Application identifier - group_id: Group identifier - memory_config: MemoryConfig object containing all configuration - storage_type: Storage type for the workspace (optional) - user_rag_memory_id: User RAG memory identifier (optional) - - Returns: - dict: Contains 'status' and 'summary_result' - """ - start = time.time() - - try: - # Extract services from context - template_service = get_context_resource(ctx, "template_service") - session_service = get_context_resource(ctx, "session_service") - - # Get LLM client from memory_config - with get_db_context() as db: - factory = MemoryClientFactory(db) - llm_client = factory.get_llm_client_from_config(memory_config) - - # Resolve session ID - sessionid = Resolve_username(usermessages) - - # Process context to extract answer and query - answer_small, query = await Summary_messages_deal(context) - - - start_time= time.time() - history = await session_service.get_history(sessionid, apply_id, group_id) - end_time=time.time() - logger.info(f"Retrieve_Summary-REDIS搜索:{end_time - start_time}") - data = { - "query": query, - "history": history, - "retrieve_info": answer_small - } - - except Exception as e: - logger.error( - f"Summary: initialization failed: {e}", - exc_info=True - ) - return { - "status": "error", - "summary_result": "信息不足,无法回答" - } - - try: - # Render template - system_prompt = await template_service.render_template( - template_name='summary_prompt.jinja2', - operation_name='summary', - data=data, - query=query - ) - except Exception as e: - logger.error( - f"Template rendering failed for Summary: {e}", - exc_info=True - ) - return { - "status": "error", - "message": f"Prompt rendering failed: {str(e)}" - } - - try: - # Call LLM with structured response - structured = await llm_client.response_structured( - messages=[{"role": "system", "content": system_prompt}], - response_model=SummaryResponse - ) - - aimessages = structured.query_answer or "" - - except Exception as e: - logger.error( - f"LLM call failed for Summary: {e}", - exc_info=True - ) - aimessages = "" - - try: - # Save session - if aimessages != "": - await session_service.save_session( - user_id=sessionid, - query=query, - apply_id=apply_id, - group_id=group_id, - ai_response=aimessages - ) - logger.info(f"sessionid: {aimessages} 写入成功") - except Exception as e: - logger.error( - f"sessionid: {sessionid} 写入失败,错误信息:{str(e)}", - exc_info=True - ) - return { - "status": "error", - "message": str(e) - } - - # Cleanup duplicate sessions - await session_service.cleanup_duplicates() - - # Use fallback if empty - if aimessages == '': - aimessages = '信息不足,无法回答' - - logger.info(f"Summary after verification: {aimessages}") - - # Log execution time - end = time.time() - try: - duration = end - start - except Exception: - duration = 0.0 - log_time('Summary', duration) - - return { - "status": "success", - "summary_result": aimessages, - "storage_type": storage_type, - "user_rag_memory_id": user_rag_memory_id - } - - -@mcp.tool() -async def Retrieve_Summary( - ctx: Context, - context: dict, - usermessages: str, - apply_id: str, - group_id: str, - memory_config: MemoryConfig, - storage_type: str = "", - user_rag_memory_id: str = "", -) -> dict: - """ - Summarize data directly from retrieval results. - - Args: - ctx: FastMCP context for dependency injection - context: Dictionary containing Query and Expansion_issue from Retrieve - usermessages: User messages identifier - apply_id: Application identifier - group_id: Group identifier - memory_config: MemoryConfig object containing all configuration - storage_type: Storage type for the workspace (optional) - user_rag_memory_id: User RAG memory identifier (optional) - - Returns: - dict: Contains 'status' and 'summary_result' - """ - start = time.time() - - try: - # Extract services from context - template_service = get_context_resource(ctx, "template_service") - session_service = get_context_resource(ctx, "session_service") - - # Get LLM client from memory_config - with get_db_context() as db: - factory = MemoryClientFactory(db) - llm_client = factory.get_llm_client_from_config(memory_config) - - # Resolve session ID - sessionid = Resolve_username(usermessages) - - - - # Handle both 'content' and 'context' keys (LangGraph uses 'content') - logger.debug(f"Retrieve_Summary: raw context type={type(context)}, keys={list(context.keys()) if isinstance(context, dict) else 'N/A'}") - - if isinstance(context, dict): - if "content" in context: - inner = context["content"] - # If it's a JSON string, parse it - if isinstance(inner, str): - try: - parsed = json.loads(inner) - logger.info("Retrieve_Summary: successfully parsed JSON") - except json.JSONDecodeError: - # Try unescaping first - try: - unescaped = inner.encode('utf-8').decode('unicode_escape') - parsed = json.loads(unescaped) - logger.info("Retrieve_Summary: parsed after unescaping") - except (json.JSONDecodeError, UnicodeDecodeError) as e: - logger.error( - f"Retrieve_Summary: parsing failed even after unescape: {e}" - ) - context_dict = {"Query": "", "Expansion_issue": []} - parsed = None - - if parsed: - # Check if parsed has 'context' wrapper - if isinstance(parsed, dict) and "context" in parsed: - context_dict = parsed["context"] - else: - context_dict = parsed - elif isinstance(inner, dict): - context_dict = inner - else: - context_dict = {"Query": "", "Expansion_issue": []} - elif "context" in context: - context_dict = context["context"] if isinstance(context["context"], dict) else context - else: - context_dict = context - else: - context_dict = {"Query": "", "Expansion_issue": []} - - query = context_dict.get("Query", "") - expansion_issue = context_dict.get("Expansion_issue", []) - - logger.debug(f"Retrieve_Summary: query='{query}', expansion_issue count={len(expansion_issue)}") - logger.debug(f"Retrieve_Summary: expansion_issue={expansion_issue[:2] if expansion_issue else 'empty'}") - - # Extract retrieve_info from expansion_issue - retrieve_info = [] - for item in expansion_issue: - # Check for both Answer_Small and Answer_Small (typo) for backward compatibility - answer = None - if isinstance(item, dict): - if "Answer_Small" in item: - answer = item["Answer_Small"] - - - if answer is not None: - # Handle both string and list formats - if isinstance(answer, list): - # Join list of characters/strings into a single string - retrieve_info.append(''.join(str(x) for x in answer)) - elif isinstance(answer, str): - retrieve_info.append(answer) - else: - retrieve_info.append(str(answer)) - - # Join all retrieve_info into a single string - retrieve_info_str = '\n\n'.join(retrieve_info) if retrieve_info else "" - - start_time=time.time() - history = await session_service.get_history(sessionid, apply_id, group_id) - # Override with empty list for now (as in original) - end_time=time.time() - logger.info(f"Retrieve_Summary-REDIS搜索:{end_time - start_time}") - except Exception as e: - logger.error( - f"Retrieve_Summary: initialization failed: {e}", - exc_info=True - ) - return { - "status": "error", - "summary_result": "信息不足,无法回答" - } - - try: - # Render template - system_prompt = await template_service.render_template( - template_name='Retrieve_Summary_prompt.jinja2', - operation_name='retrieve_summary', - query=query, - history=history, - retrieve_info=retrieve_info_str - ) - except Exception as e: - logger.error( - f"Template rendering failed for Retrieve_Summary: {e}", - exc_info=True - ) - return { - "status": "error", - "message": f"Prompt rendering failed: {str(e)}" - } - - try: - # Call LLM with structured response - structured = await llm_client.response_structured( - messages=[{"role": "system", "content": system_prompt}], - response_model=RetrieveSummaryResponse - ) - - # Handle case where structured response might be None or incomplete - if structured and hasattr(structured, 'data') and structured.data: - aimessages = structured.data.query_answer or "" - else: - logger.warning("Structured response is None or incomplete, using default message") - aimessages = "信息不足,无法回答" - - - # Check for insufficient information response - if '信息不足,无法回答' not in str(aimessages) or str(aimessages)!="": - # Save session - await session_service.save_session( - user_id=sessionid, - query=query, - apply_id=apply_id, - group_id=group_id, - ai_response=aimessages - ) - logger.info(f"sessionid: {aimessages} 写入成功") - except Exception as e: - logger.error( - f"Retrieve_Summary: LLM call failed: {e}", - exc_info=True - ) - aimessages = "" - # Cleanup duplicate sessions - await session_service.cleanup_duplicates() - - # Use fallback if empty - if aimessages == '': - aimessages = '信息不足,无法回答' - - logger.info(f"Summary after retrieval: {aimessages}") - - # Log execution time - end = time.time() - try: - duration = end - start - except Exception: - duration = 0.0 - log_time('Retrieval summary', duration) - - # Emit intermediate output for frontend - return { - "status": "success", - "summary_result": aimessages, - "storage_type": storage_type, - "user_rag_memory_id": user_rag_memory_id, - "_intermediate": { - "type": "retrieval_summary", - "summary": aimessages, - "query": query, - "storage_type": storage_type, - "user_rag_memory_id": user_rag_memory_id - } - } - - -@mcp.tool() -async def Input_Summary( - ctx: Context, - context: str, - usermessages: str, - search_switch: str, - apply_id: str, - group_id: str, - memory_config: MemoryConfig, - storage_type: str = "", - user_rag_memory_id: str = "", -) -> dict: - """ - Generate a quick summary for direct input without verification. - - Args: - ctx: FastMCP context for dependency injection - context: String containing the input sentence - usermessages: User messages identifier - search_switch: Search switch value for routing ('2' for summaries only) - apply_id: Application identifier - group_id: Group identifier - memory_config: MemoryConfig object containing all configuration - storage_type: Storage type for the workspace (e.g., 'rag', 'vector') - user_rag_memory_id: User RAG memory identifier - - Returns: - dict: Contains 'query_answer' with the summary result - """ - start = time.time() - logger.info(f"Input_Summary: storage_type={storage_type}, user_rag_memory_id={user_rag_memory_id}") - - try: - # Extract services from context - session_service = get_context_resource(ctx, "session_service") - search_service = get_context_resource(ctx, "search_service") - - # Resolve session ID - sessionid = Resolve_username(usermessages) or "" - sessionid = sessionid.replace('call_id_', '') - - start_time=time.time() - history = await session_service.get_history( - str(sessionid), - str(apply_id), - str(group_id) - ) - end_time=time.time() - logger.info(f"Input_Summary-REDIS搜索:{end_time - start_time}") - # Override with empty list for now (as in original) - - # Log the raw context for debugging - logger.info(f"Input_Summary: Received context type={type(context)}, value={context[:200] if isinstance(context, str) else context}") - - # Extract sentence from context - # Context can be a string or might contain the sentence in various formats - try: - # Try to parse as JSON first - if isinstance(context, str) and (context.startswith('{') or context.startswith('[')): - try: - import json - context_dict = json.loads(context) - if isinstance(context_dict, dict): - query = context_dict.get('sentence', context_dict.get('content', context)) - else: - query = context - except json.JSONDecodeError: - # Not valid JSON, try regex - match = re.search(r"'sentence':\s*['\"]?(.*?)['\"]?\s*,", context) - query = match.group(1) if match else context - else: - query = context - except Exception as e: - logger.warning(f"Failed to extract query from context: {e}") - query = context - - # Clean query - query = str(query).strip().strip("\"'") - - logger.debug(f"Input_Summary: Extracted query='{query}' from context type={type(context)}") - - # Execute search based on search_switch and storage_type - try: - logger.info(f"search_switch: {search_switch}, storage_type: {storage_type}") - - # Prepare search parameters based on storage type - search_params = { - "group_id": group_id, - "question": query, - "return_raw_results": True - } - - # Add storage-specific parameters - - # Retrieval - if search_switch == '2': - search_params["include"] = ["summaries"] - if storage_type == "rag" and user_rag_memory_id: - raw_results = [] - retrieve_info = "" - kb_config={ - "knowledge_bases": [ - { - "kb_id": user_rag_memory_id, - "similarity_threshold": 0.7, - "vector_similarity_weight": 0.5, - "top_k": 10, - "retrieve_type": "participle" - } - ], - "merge_strategy": "weight", - "reranker_id":os.getenv('reranker_id'), - "reranker_top_k": 10 - } - - retrieve_chunks_result = knowledge_retrieval(query, kb_config,[str(group_id)]) - try: - retrieval_knowledge = [i.page_content for i in retrieve_chunks_result] - retrieve_info = '\n\n'.join(retrieval_knowledge) - raw_results=[retrieve_info] - logger.info(f"Input_Summary: Using RAG storage with memory_id={user_rag_memory_id}") - except: - retrieve_info='' - raw_results=[''] - logger.info(f"No content retrieved from knowledge base: {user_rag_memory_id}") - else: - retrieve_info, question, raw_results = await search_service.execute_hybrid_search( - **search_params, memory_config=memory_config - ) - logger.info("Input_Summary: Using summary for retrieval") - else: - retrieve_info, question, raw_results = await search_service.execute_hybrid_search( - **search_params, memory_config=memory_config - ) - - except Exception as e: - logger.error( - f"Input_Summary: hybrid_search failed, using empty results: {e}", - exc_info=True - ) - retrieve_info, question, raw_results = "", query, [] - - # Return retrieved information directly without LLM processing - # Use the raw retrieved info as the answer - aimessages = retrieve_info if retrieve_info else "信息不足,无法回答" - - logger.info(f"Quick answer (no LLM): {storage_type}--{user_rag_memory_id}--{aimessages[:500]}...") - - # Emit intermediate output for frontend - return { - "status": "success", - "summary_result": aimessages, - "storage_type": storage_type, - "user_rag_memory_id": user_rag_memory_id, - "_intermediate": { - "type": "input_summary", - "title": "快速答案", - "summary": aimessages, - "query": query, - "raw_results": raw_results, - "search_mode": "quick_search", - "storage_type": storage_type, - "user_rag_memory_id": user_rag_memory_id - } - } - - except Exception as e: - logger.error( - f"Input_Summary failed: {e}", - exc_info=True - ) - return { - "status": "fail", - "summary_result": "信息不足,无法回答", - "storage_type": storage_type, - "user_rag_memory_id": user_rag_memory_id, - "error": str(e) - } - - finally: - # Log execution time - end = time.time() - try: - duration = end - start - except Exception: - duration = 0.0 - log_time('Retrieval', duration) - - -@mcp.tool() -async def Summary_fails( - ctx: Context, - context: str, - usermessages: str, - apply_id: str, - group_id: str, - storage_type: str = "", - user_rag_memory_id: str = "" -) -> dict: - """ - Handle workflow failure when summary cannot be generated. - - Args: - ctx: FastMCP context for dependency injection - context: Failure context string - usermessages: User messages identifier - apply_id: Application identifier - group_id: Group identifier - storage_type: Storage type for the workspace (optional) - user_rag_memory_id: User RAG memory identifier (optional) - - Returns: - dict: Contains 'query_answer' with failure message - """ - try: - # Extract services from context - session_service = get_context_resource(ctx, 'session_service') - - # Parse session ID from usermessages - usermessages_parts = usermessages.split('_')[1:] - sessionid = '_'.join(usermessages_parts[:-1]) - - # Cleanup duplicate sessions - await session_service.cleanup_duplicates() - - logger.info("没有相关数据") - logger.debug(f"Summary_fails called with apply_id: {apply_id}, group_id: {group_id}") - - return { - "status": "success", - "summary_result": "没有相关数据", - "storage_type": storage_type, - "user_rag_memory_id": user_rag_memory_id - } - - except Exception as e: - logger.error( - f"Summary_fails failed: {e}", - exc_info=True - ) - return { - "status": "fail", - "summary_result": "没有相关数据", - "storage_type": storage_type, - "user_rag_memory_id": user_rag_memory_id, - "error": str(e) - } diff --git a/api/app/core/memory/agent/mcp_server/tools/verification_tools.py b/api/app/core/memory/agent/mcp_server/tools/verification_tools.py deleted file mode 100644 index cb6af5bd..00000000 --- a/api/app/core/memory/agent/mcp_server/tools/verification_tools.py +++ /dev/null @@ -1,174 +0,0 @@ -""" -Verification Tools for data verification. - -This module contains MCP tools for verifying retrieved data. -""" -import time - -from app.core.logging_config import get_agent_logger, log_time -from app.core.memory.agent.mcp_server.mcp_instance import mcp -from app.core.memory.agent.mcp_server.server import get_context_resource -from app.core.memory.agent.utils.llm_tools import PROJECT_ROOT_ -from app.core.memory.agent.utils.messages_tool import ( - Resolve_username, - Retrieve_verify_tool_messages_deal, - Verify_messages_deal, -) -from app.core.memory.agent.utils.verify_tool import VerifyTool -from app.schemas.memory_config_schema import MemoryConfig -from jinja2 import Template -from mcp.server.fastmcp import Context - -logger = get_agent_logger(__name__) - - -@mcp.tool() -async def Verify( - ctx: Context, - context: dict, - usermessages: str, - apply_id: str, - group_id: str, - memory_config: MemoryConfig, - storage_type: str = "", - user_rag_memory_id: str = "" -) -> dict: - """ - Verify the retrieved data. - - Args: - ctx: FastMCP context for dependency injection - context: Dictionary containing query and expansion issues - usermessages: User messages identifier - apply_id: Application identifier - group_id: Group identifier - memory_config: MemoryConfig object containing all configuration - storage_type: Storage type for the workspace (optional) - user_rag_memory_id: User RAG memory identifier (optional) - - Returns: - dict: Contains 'status' and 'verified_data' with verification results - """ - start = time.time() - - - try: - # Extract services from context - session_service = get_context_resource(ctx, 'session_service') - - # Load verification prompt template - file_path = PROJECT_ROOT_ + '/agent/utils/prompt/split_verify_prompt.jinja2' - - # Read template file directly (VerifyTool expects raw template content) - from app.core.memory.agent.utils.messages_tool import read_template_file - system_prompt = await read_template_file(file_path) - - - - # Resolve session ID - sessionid = Resolve_username(usermessages) - - # Get conversation history - history = await session_service.get_history(sessionid, apply_id, group_id) - - template = Template(system_prompt) - system_prompt = template.render(history=history, sentence=context) - - # Process context to extract query and results - Query_small, Result_small, query = await Verify_messages_deal(context) - - # Build query list for verification - query_list = [] - for query_small, anser in zip(Query_small, Result_small, strict=False): - query_list.append({ - 'Query_small': query_small, - 'Answer_Small': anser - }) - - messages = { - "Query": query, - "Expansion_issue": query_list - } - - - - # Call verification workflow with LLM model ID from memory_config - verify_tool = VerifyTool( - system_prompt=system_prompt, - verify_data=messages, - llm_model_id=str(memory_config.llm_model_id) - ) - verify_result = await verify_tool.verify() - - # Parse LLM verification result with error handling - try: - messages_deal = await Retrieve_verify_tool_messages_deal( - verify_result, - history, - query - ) - except Exception as e: - logger.error( - f"Retrieve_verify_tool_messages_deal parsing failed: {e}", - exc_info=True - ) - # Fallback to avoid 500 errors - messages_deal = { - "data": { - "query": query, - "expansion_issue": [] - }, - "split_result": "failed", - "reason": str(e), - "history": history, - } - - logger.info(f"Verification result: {messages_deal}") - - # Emit intermediate output for frontend - return { - "status": "success", - "verified_data": messages_deal, - "storage_type": storage_type, - "user_rag_memory_id": user_rag_memory_id, - "_intermediate": { - "type": "verification", - "title": "Data Verification", - "result": messages_deal.get("split_result", "unknown"), - "reason": messages_deal.get("reason", ""), - "query": query, - "verified_count": len(query_list), - "storage_type": storage_type, - "user_rag_memory_id": user_rag_memory_id - } - } - - except Exception as e: - logger.error( - f"Verify failed: {e}", - exc_info=True - ) - return { - "status": "error", - "message": str(e), - "storage_type": storage_type, - "user_rag_memory_id": user_rag_memory_id, - "verified_data": { - "data": { - "query": "", - "expansion_issue": [] - }, - "split_result": "failed", - "reason": str(e), - "history": [], - } - } - - finally: - # Log execution time - end = time.time() - try: - duration = end - start - except Exception: - duration = 0.0 - log_time('Verification', duration) diff --git a/api/app/core/memory/agent/mcp_server/models/__init__.py b/api/app/core/memory/agent/models/__init__.py similarity index 100% rename from api/app/core/memory/agent/mcp_server/models/__init__.py rename to api/app/core/memory/agent/models/__init__.py diff --git a/api/app/core/memory/agent/mcp_server/models/problem_models.py b/api/app/core/memory/agent/models/problem_models.py similarity index 100% rename from api/app/core/memory/agent/mcp_server/models/problem_models.py rename to api/app/core/memory/agent/models/problem_models.py diff --git a/api/app/core/memory/agent/mcp_server/models/retrieval_models.py b/api/app/core/memory/agent/models/retrieval_models.py similarity index 100% rename from api/app/core/memory/agent/mcp_server/models/retrieval_models.py rename to api/app/core/memory/agent/models/retrieval_models.py diff --git a/api/app/core/memory/agent/mcp_server/models/summary_models.py b/api/app/core/memory/agent/models/summary_models.py similarity index 100% rename from api/app/core/memory/agent/mcp_server/models/summary_models.py rename to api/app/core/memory/agent/models/summary_models.py diff --git a/api/app/core/memory/agent/mcp_server/models/verification_models.py b/api/app/core/memory/agent/models/verification_models.py similarity index 100% rename from api/app/core/memory/agent/mcp_server/models/verification_models.py rename to api/app/core/memory/agent/models/verification_models.py diff --git a/api/app/core/memory/agent/multimodal/oss_picture.py b/api/app/core/memory/agent/multimodal/oss_picture.py deleted file mode 100644 index b5b4bd6b..00000000 --- a/api/app/core/memory/agent/multimodal/oss_picture.py +++ /dev/null @@ -1,114 +0,0 @@ -import os -import sys -import traceback - -import requests - -# from qcloud_cos import CosConfig, CosS3Client -# from qcloud_cos.cos_exception import CosClientError, CosServiceError - -# from config.paths import BASE_DIR -BASE_DIR = os.path.dirname(os.path.realpath(sys.argv[0])) - -class OSSUploader: - """对象存储文件上传工具类""" - - def __init__(self, env): - api = { - "test": "https://testlingqi.redbearai.com/api/user/file/common/upload/v2/anon", - "prod": "https://lingqi.redbearai.com/api/user/file/common/upload/v2/anon" - } - self.api = api.get(env, "https://testlingqi.redbearai.com/api/user/file/common/upload/v2/anon") - self.privacy = "false" - self.headers = { - "User-Agent": 'Mozilla/5.0 (Windows NT 10.0; Win64; x64) ' - 'AppleWebKit/537.36 (KHTML, like Gecko)' - ' Chrome/133.0.6833.84 Safari/537.36' - } - - @staticmethod - def _generate_object_key(file_path, prefix='xhs_'): - """ - 生成对象存储的Key - - :param file_path: 本地文件路径 - :param prefix: 存储前缀,用于分类存储 - :return: 生成的对象Key - """ - # 文件md5值.后缀名 - filename = os.path.basename(file_path) - filename = f"{filename}" - - # 组合成完整的对象Key - return f"{prefix}{filename}" - - def upload_image(self, file_name, prefix='jd_'): - """ - 上传文件到COS并返回可访问的URL - - :param file_url: 文件路径 - :param file_name: 文件名称 - :param media_type: 文件类型 - :param prefix: 存储前缀,用于分类存储 - :return: 文件访问URL - """ - # 检查文件是否存在 - - - - file_path = os.path.join(BASE_DIR, file_name) - - # response = requests.get(url, headers=self.headers, stream=True) - - # if response.status_code == 200: - # with open(file_path, "wb") as f: - # for chunk in response.iter_content(1024): # 分块写入,避免内存占用过大 - # f.write(chunk) - # else: - # raise Exception(f"文件下载失败,{file_name}") - - # 生成对象Key - object_key = self._generate_object_key(file_path, prefix +file_name.split('.')[-1]) - - try: - upload_response = requests.post( - self.api, - data={ - "privacy": self.privacy, - "fileName": object_key, - } - ) - - if upload_response.status_code != 200: - raise Exception('上传接口请求失败') - resp = upload_response.json() - name = resp["data"]["name"] - file_url = resp["data"]["path"] - policy = resp["data"]["policy"] - with open(file_path, 'rb') as f: - oss_push_resp = requests.post( - policy["host"], - files={ - "key": policy["dir"], - "OSSAccessKeyId": policy["accessid"], - "name": name, - "policy": policy["policy"], - "success_action_status": 200, - "signature": policy["signature"], - "file": f, - } - ) - if oss_push_resp.status_code == 200: - return file_url - raise Exception("OSS上传失败") - except Exception: - raise Exception(f"上传失败: \n{traceback.format_exc()}") - finally: - print('success') - # os.remove(file_path) - - -if __name__ == '__main__': - cos_uploader = OSSUploader("prod") - url =cos_uploader.upload_image('./example01.jpg') - print(url) diff --git a/api/app/core/memory/agent/multimodal/speech_model.py b/api/app/core/memory/agent/multimodal/speech_model.py deleted file mode 100644 index 2df32dd0..00000000 --- a/api/app/core/memory/agent/multimodal/speech_model.py +++ /dev/null @@ -1,121 +0,0 @@ -import asyncio -import re - -from app.core.memory.agent.utils.llm_tools import PROJECT_ROOT_, picture_model_requests,Picture_recognize, Voice_recognize -from app.core.memory.agent.utils.messages_tool import read_template_file - -import requests -import json -import os -import time -# file_urls = [ -# "https://dashscope.oss-cn-beijing.aliyuncs.com/samples/audio/paraformer/hello_world_female2.wav", -# "https://dashscope.oss-cn-beijing.aliyuncs.com/samples/audio/paraformer/hello_world_male2.wav", -# ] -class Vico_recognition: - def __init__(self,file_urls): - self.api_key='' - self.backend_model_name='' - self.api_base='' - self.file_urls=file_urls - - # 提交文件转写任务,包含待转写文件url列表 - async def submit_task(self) -> str: - self.api_key, self.backend_model_name, self.api_base =await Voice_recognize() - - headers = { - "Authorization": f"Bearer {self.api_key}", - "Content-Type": "application/json", - "X-DashScope-Async": "enable", - } - data = { - "model": self.backend_model_name, - "input": {"file_urls": self.file_urls}, - "parameters": { - "channel_id": [0], - "vocabulary_id": "vocab-Xxxx", - }, - } - # 录音文件转写服务url - service_url = ( - "https://dashscope.aliyuncs.com/api/v1/services/audio/asr/transcription" - ) - response = requests.post( - service_url, headers=headers, data=json.dumps(data) - ) - - # 打印响应内容 - if response.status_code == 200: - return response.json()["output"]["task_id"] - else: - print("task failed!") - print(response.json()) - return None - - async def download_transcription_result(self, transcription_url): - """ - Args: - transcription_url (str): 转写结果文件URL - Returns: - dict: 转写结果内容 - """ - try: - response = requests.get(transcription_url) - response.raise_for_status() - return response.json() - except Exception as e: - print(f"下载转写结果失败: {e}") - return None - - # 循环查询任务状态直到成功 - async def wait_for_complete(self,task_id): - self.api_key, self.backend_model_name, self.api_base = await Voice_recognize() - headers = { - "Authorization": f"Bearer {self.api_key}", - "Content-Type": "application/json", - "X-DashScope-Async": "enable", - } - - pending = True - while pending: - # 查询任务状态服务url - service_url = f"https://dashscope.aliyuncs.com/api/v1/tasks/{task_id}" - response = requests.post( - service_url, headers=headers - ) - if response.status_code == 200: - status = response.json()['output']['task_status'] - if status == 'SUCCEEDED': - print("task succeeded!") - pending = False - return response.json()['output']['results'] - elif status == 'RUNNING' or status == 'PENDING': - pass - else: - print("task failed!") - pending = False - else: - print("query failed!") - pending = False - time.sleep(0.1) - async def run(self): - self.api_key, self.backend_model_name, self.api_base = await Voice_recognize() - task_id=await self.submit_task() - result=await self.wait_for_complete(task_id) - result_context=[] - for i in result: - transcription_url=i['transcription_url'] - print(f"转写URL: {transcription_url}") - - # 下载并打印转写内容 - content = await self.download_transcription_result(transcription_url) - if content: - content=json.dumps(content, indent=2, ensure_ascii=False) - context=re.findall(r'"text": "(.*?)"', content) - result_context.append(context[0]) - result=''.join(result_context) - return (result) - - - - diff --git a/api/app/core/memory/agent/mcp_server/services/__init__.py b/api/app/core/memory/agent/services/__init__.py similarity index 100% rename from api/app/core/memory/agent/mcp_server/services/__init__.py rename to api/app/core/memory/agent/services/__init__.py diff --git a/api/app/core/memory/agent/services/optimized_llm_service.py b/api/app/core/memory/agent/services/optimized_llm_service.py new file mode 100644 index 00000000..6942d421 --- /dev/null +++ b/api/app/core/memory/agent/services/optimized_llm_service.py @@ -0,0 +1,277 @@ +""" +优化的LLM服务类,用于压缩和统一LLM调用 +""" + +import asyncio +from typing import Any, Dict, List, Optional, Type, TypeVar, Union +from pydantic import BaseModel +from sqlalchemy.orm import Session + +from app.core.logging_config import get_agent_logger +from app.core.memory.utils.llm.llm_utils import MemoryClientFactory +from app.core.memory.llm_tools.openai_client import OpenAIClient + +T = TypeVar('T', bound=BaseModel) + +logger = get_agent_logger(__name__) + + +class OptimizedLLMService: + """ + 优化的LLM服务类,提供统一的LLM调用接口 + + 特性: + 1. 客户端复用 - 避免重复创建LLM客户端 + 2. 批量处理 - 支持并发处理多个请求 + 3. 错误处理 - 统一的错误处理和降级策略 + 4. 性能优化 - 缓存和连接池优化 + """ + + def __init__(self, db_session: Session): + self.db_session = db_session + self.client_factory = MemoryClientFactory(db_session) + self._client_cache: Dict[str, OpenAIClient] = {} + + def _get_cached_client(self, llm_model_id: str) -> OpenAIClient: + """获取缓存的LLM客户端,避免重复创建""" + if llm_model_id not in self._client_cache: + self._client_cache[llm_model_id] = self.client_factory.get_llm_client(llm_model_id) + return self._client_cache[llm_model_id] + + async def structured_response( + self, + llm_model_id: str, + system_prompt: str, + response_model: Type[T], + user_message: Optional[str] = None, + fallback_value: Optional[Any] = None + ) -> T: + """ + 统一的结构化响应接口 + + Args: + llm_model_id: LLM模型ID + system_prompt: 系统提示词 + response_model: 响应模型类 + user_message: 用户消息(可选) + fallback_value: 失败时的降级值 + + Returns: + 结构化响应对象 + """ + try: + llm_client = self._get_cached_client(llm_model_id) + + messages = [{"role": "system", "content": system_prompt}] + if user_message: + messages.append({"role": "user", "content": user_message}) + + logger.debug(f"LLM调用: model={llm_model_id}, prompt_length={len(system_prompt)}") + + structured = await llm_client.response_structured( + messages=messages, + response_model=response_model + ) + + if structured is None: + logger.warning(f"LLM返回None,使用降级值") + return self._create_fallback_response(response_model, fallback_value) + + return structured + + except Exception as e: + logger.error(f"结构化响应失败: {e}", exc_info=True) + return self._create_fallback_response(response_model, fallback_value) + + async def batch_structured_response( + self, + llm_model_id: str, + requests: List[Dict[str, Any]], + response_model: Type[T], + max_concurrent: int = 5 + ) -> List[T]: + """ + 批量处理结构化响应 + + Args: + llm_model_id: LLM模型ID + requests: 请求列表,每个请求包含system_prompt等参数 + response_model: 响应模型类 + max_concurrent: 最大并发数 + + Returns: + 结构化响应列表 + """ + semaphore = asyncio.Semaphore(max_concurrent) + + async def process_single_request(request: Dict[str, Any]) -> T: + async with semaphore: + return await self.structured_response( + llm_model_id=llm_model_id, + system_prompt=request.get('system_prompt', ''), + response_model=response_model, + user_message=request.get('user_message'), + fallback_value=request.get('fallback_value') + ) + + tasks = [process_single_request(req) for req in requests] + return await asyncio.gather(*tasks) + + async def simple_response( + self, + llm_model_id: str, + system_prompt: str, + user_message: Optional[str] = None, + fallback_message: str = "信息不足,无法回答" + ) -> str: + """ + 简单的文本响应接口 + + Args: + llm_model_id: LLM模型ID + system_prompt: 系统提示词 + user_message: 用户消息(可选) + fallback_message: 失败时的降级消息 + + Returns: + 响应文本 + """ + try: + llm_client = self._get_cached_client(llm_model_id) + + messages = [{"role": "system", "content": system_prompt}] + if user_message: + messages.append({"role": "user", "content": user_message}) + + response = await llm_client.response(messages=messages) + + if not response or not response.strip(): + return fallback_message + + return response.strip() + + except Exception as e: + logger.error(f"简单响应失败: {e}", exc_info=True) + return fallback_message + + def _create_fallback_response(self, response_model: Type[T], fallback_value: Optional[Any]) -> T: + """创建降级响应""" + try: + if fallback_value is not None: + if isinstance(fallback_value, response_model): + return fallback_value + elif isinstance(fallback_value, dict): + return response_model(**fallback_value) + + # 尝试创建空的响应模型 + if hasattr(response_model, 'root'): + # RootModel类型 + return response_model([]) + else: + # 普通BaseModel类型 + return response_model() + + except Exception as e: + logger.error(f"创建降级响应失败: {e}") + # 最后的降级策略 + if hasattr(response_model, 'root'): + return response_model([]) + else: + return response_model() + + def clear_cache(self): + """清理客户端缓存""" + self._client_cache.clear() + + +class LLMServiceMixin: + """ + LLM服务混入类,为节点提供便捷的LLM调用方法 + """ + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self._llm_service: Optional[OptimizedLLMService] = None + + def get_llm_service(self, db_session: Session) -> OptimizedLLMService: + """获取LLM服务实例""" + if self._llm_service is None: + self._llm_service = OptimizedLLMService(db_session) + return self._llm_service + + async def call_llm_structured( + self, + state: Dict[str, Any], + db_session: Session, + system_prompt: str, + response_model: Type[T], + user_message: Optional[str] = None, + fallback_value: Optional[Any] = None + ) -> T: + """ + 便捷的结构化LLM调用方法 + + Args: + state: 状态字典,包含memory_config + db_session: 数据库会话 + system_prompt: 系统提示词 + response_model: 响应模型类 + user_message: 用户消息(可选) + fallback_value: 失败时的降级值 + + Returns: + 结构化响应对象 + """ + memory_config = state.get('memory_config') + if not memory_config: + raise ValueError("State中缺少memory_config") + + llm_model_id = memory_config.llm_model_id + if not llm_model_id: + raise ValueError("Memory config中缺少llm_model_id") + + llm_service = self.get_llm_service(db_session) + return await llm_service.structured_response( + llm_model_id=llm_model_id, + system_prompt=system_prompt, + response_model=response_model, + user_message=user_message, + fallback_value=fallback_value + ) + + async def call_llm_simple( + self, + state: Dict[str, Any], + db_session: Session, + system_prompt: str, + user_message: Optional[str] = None, + fallback_message: str = "信息不足,无法回答" + ) -> str: + """ + 便捷的简单LLM调用方法 + + Args: + state: 状态字典,包含memory_config + db_session: 数据库会话 + system_prompt: 系统提示词 + user_message: 用户消息(可选) + fallback_message: 失败时的降级消息 + + Returns: + 响应文本 + """ + memory_config = state.get('memory_config') + if not memory_config: + raise ValueError("State中缺少memory_config") + + llm_model_id = memory_config.llm_model_id + if not llm_model_id: + raise ValueError("Memory config中缺少llm_model_id") + + llm_service = self.get_llm_service(db_session) + return await llm_service.simple_response( + llm_model_id=llm_model_id, + system_prompt=system_prompt, + user_message=user_message, + fallback_message=fallback_message + ) \ No newline at end of file diff --git a/api/app/core/memory/agent/mcp_server/services/parameter_builder.py b/api/app/core/memory/agent/services/parameter_builder.py similarity index 87% rename from api/app/core/memory/agent/mcp_server/services/parameter_builder.py rename to api/app/core/memory/agent/services/parameter_builder.py index d5305dc6..a58fcf1a 100644 --- a/api/app/core/memory/agent/mcp_server/services/parameter_builder.py +++ b/api/app/core/memory/agent/services/parameter_builder.py @@ -4,22 +4,19 @@ Parameter Builder for constructing tool call arguments. This service provides tool-specific parameter transformation logic to build correct arguments for each tool type. """ - from typing import Any, Dict, Optional - from app.core.logging_config import get_agent_logger -from app.schemas.memory_config_schema import MemoryConfig logger = get_agent_logger(__name__) class ParameterBuilder: """Service for building tool call arguments based on tool type.""" - + def __init__(self): """Initialize the parameter builder.""" logger.info("ParameterBuilder initialized") - + def build_tool_args( self, tool_name: str, @@ -28,9 +25,8 @@ class ParameterBuilder: search_switch: str, apply_id: str, group_id: str, - memory_config: MemoryConfig, storage_type: Optional[str] = None, - user_rag_memory_id: Optional[str] = None, + user_rag_memory_id: Optional[str] = None ) -> Dict[str, Any]: """ Build tool arguments based on tool type. @@ -49,7 +45,6 @@ class ParameterBuilder: search_switch: Search routing parameter apply_id: Application identifier group_id: Group identifier - memory_config: MemoryConfig object containing all configuration storage_type: Storage type for the workspace (optional) user_rag_memory_id: User RAG memory ID for knowledge base retrieval (optional) @@ -60,19 +55,18 @@ class ParameterBuilder: base_args = { "usermessages": tool_call_id, "apply_id": apply_id, - "group_id": group_id, - "memory_config": memory_config, + "group_id": group_id } - + # Always add storage_type and user_rag_memory_id (with defaults if None) base_args["storage_type"] = storage_type if storage_type is not None else "" base_args["user_rag_memory_id"] = user_rag_memory_id if user_rag_memory_id is not None else "" # Tool-specific argument construction - if tool_name in ["Verify", "Summary", "Summary_fails", "Retrieve_Summary", "Problem_Extension"]: - # These tools expect dict context + if tool_name in ["Verify","Summary", "Summary_fails",'Retrieve_Summary']: + # Verify expects dict context return { - "context": content if isinstance(content, dict) else {"content": content}, + "context": content if isinstance(content, dict) else {}, **base_args } diff --git a/api/app/core/memory/agent/mcp_server/services/search_service.py b/api/app/core/memory/agent/services/search_service.py similarity index 75% rename from api/app/core/memory/agent/mcp_server/services/search_service.py rename to api/app/core/memory/agent/services/search_service.py index 47295f87..8a2e7cfe 100644 --- a/api/app/core/memory/agent/mcp_server/services/search_service.py +++ b/api/app/core/memory/agent/services/search_service.py @@ -4,31 +4,21 @@ Search Service for executing hybrid search and processing results. This service provides clean search result processing with content extraction and deduplication. """ - -from typing import TYPE_CHECKING, List, Optional, Tuple +from typing import List, Tuple, Optional from app.core.logging_config import get_agent_logger from app.core.memory.src.search import run_hybrid_search from app.core.memory.utils.data.text_utils import escape_lucene_query -if TYPE_CHECKING: - from app.schemas.memory_config_schema import MemoryConfig logger = get_agent_logger(__name__) class SearchService: """Service for executing hybrid search and processing results.""" - - def __init__(self, memory_config: "MemoryConfig" = None): - """ - Initialize the search service. - - Args: - memory_config: Optional MemoryConfig for embedding model configuration. - If not provided, must be passed to execute_hybrid_search. - """ - self.memory_config = memory_config + + def __init__(self): + """Initialize the search service.""" logger.info("SearchService initialized") def extract_content_from_result(self, result: dict) -> str: @@ -103,49 +93,40 @@ class SearchService: self, group_id: str, question: str, - limit: int = 15, + limit: int = 5, search_type: str = "hybrid", include: Optional[List[str]] = None, - rerank_alpha: float = 0.6, - activation_boost_factor: float = 0.8, + rerank_alpha: float = 0.4, output_path: str = "search_results.json", return_raw_results: bool = False, - memory_config: "MemoryConfig" = None, + memory_config = None ) -> Tuple[str, str, Optional[dict]]: """ - Execute hybrid search with two-stage ranking. - - Stage 1: Filter by content relevance (BM25 + Embedding) - Stage 2: Rerank by activation values (ACTR) + Execute hybrid search and return clean content. Args: - group_id: Group identifier for filtering + group_id: Group identifier for filtering results question: Search query text - limit: Max results per category (default: 15) - search_type: "hybrid", "keyword", or "embedding" (default: "hybrid") - include: Result types (default: ["statements", "chunks", "entities", "summaries"]) - rerank_alpha: BM25 weight (default: 0.6) - activation_boost_factor: Activation impact on memory strength (default: 0.8) - output_path: JSON output path (default: "search_results.json") - return_raw_results: Return full metadata (default: False) - memory_config: MemoryConfig for embedding model + limit: Maximum number of results to return (default: 5) + search_type: Type of search - "hybrid", "keyword", or "embedding" (default: "hybrid") + include: List of result types to include (default: ["statements", "chunks", "entities", "summaries"]) + rerank_alpha: Weight for BM25 scores in reranking (default: 0.4) + output_path: Path to save search results (default: "search_results.json") + return_raw_results: If True, also return the raw search results as third element (default: False) + memory_config: Memory configuration object (required) Returns: - Tuple[str, str, Optional[dict]]: (clean_content, cleaned_query, raw_results) + Tuple of (clean_content, cleaned_query, raw_results) + raw_results is None if return_raw_results=False """ if include is None: include = ["statements", "chunks", "entities", "summaries"] - - # Use provided memory_config or fall back to instance config - config = memory_config or self.memory_config - if not config: - raise ValueError("memory_config is required for search - either pass it to __init__ or execute_hybrid_search") - + # Clean query cleaned_query = self.clean_query(question) - + try: - # Execute search using memory_config + # Execute search answer = await run_hybrid_search( query_text=cleaned_query, search_type=search_type, @@ -153,9 +134,8 @@ class SearchService: limit=limit, include=include, output_path=output_path, - memory_config=config, - rerank_alpha=rerank_alpha, - activation_boost_factor=activation_boost_factor, + memory_config=memory_config, + rerank_alpha=rerank_alpha ) # Extract results based on search type and include parameter diff --git a/api/app/core/memory/agent/mcp_server/services/session_service.py b/api/app/core/memory/agent/services/session_service.py similarity index 100% rename from api/app/core/memory/agent/mcp_server/services/session_service.py rename to api/app/core/memory/agent/services/session_service.py diff --git a/api/app/core/memory/agent/mcp_server/services/template_service.py b/api/app/core/memory/agent/services/template_service.py similarity index 94% rename from api/app/core/memory/agent/mcp_server/services/template_service.py rename to api/app/core/memory/agent/services/template_service.py index 95223f0b..1bf86375 100644 --- a/api/app/core/memory/agent/mcp_server/services/template_service.py +++ b/api/app/core/memory/agent/services/template_service.py @@ -3,12 +3,22 @@ Template Service for loading and rendering Jinja2 templates. This service provides centralized template management with caching and error handling. """ + import os from functools import lru_cache -from typing import Optional -from jinja2 import Environment, FileSystemLoader, Template, TemplateNotFound -from app.core.logging_config import get_agent_logger, log_prompt_rendering +from jinja2 import ( + Environment, + FileSystemLoader, + Template, + TemplateNotFound, +) + +from app.core.logging_config import ( + get_agent_logger, + log_prompt_rendering, +) + logger = get_agent_logger(__name__) diff --git a/api/app/core/memory/agent/utils/__init__.py b/api/app/core/memory/agent/utils/__init__.py deleted file mode 100644 index 2b77e240..00000000 --- a/api/app/core/memory/agent/utils/__init__.py +++ /dev/null @@ -1,7 +0,0 @@ -"""Agent utilities.""" - -from app.core.memory.agent.utils.multimodal import MultimodalProcessor - -__all__ = [ - "MultimodalProcessor", -] diff --git a/api/app/core/memory/agent/utils/llm_client_pool.py b/api/app/core/memory/agent/utils/llm_client_pool.py new file mode 100644 index 00000000..fddd54f6 --- /dev/null +++ b/api/app/core/memory/agent/utils/llm_client_pool.py @@ -0,0 +1,56 @@ + +import asyncio +from typing import Dict, Optional +from app.core.memory.utils.llm.llm_utils import get_llm_client_fast +from app.db import get_db +from app.core.logging_config import get_agent_logger + +logger = get_agent_logger(__name__) + +class LLMClientPool: + """LLM客户端连接池""" + + def __init__(self, max_size: int = 5): + self.max_size = max_size + self.pools: Dict[str, asyncio.Queue] = {} + self.active_clients: Dict[str, int] = {} + + async def get_client(self, llm_model_id: str): + """获取LLM客户端""" + if llm_model_id not in self.pools: + self.pools[llm_model_id] = asyncio.Queue(maxsize=self.max_size) + self.active_clients[llm_model_id] = 0 + + pool = self.pools[llm_model_id] + + try: + # 尝试从池中获取客户端 + client = pool.get_nowait() + logger.debug(f"从池中获取LLM客户端: {llm_model_id}") + return client + except asyncio.QueueEmpty: + # 池为空,创建新客户端 + if self.active_clients[llm_model_id] < self.max_size: + db_session = next(get_db()) + client = get_llm_client_fast(llm_model_id, db_session) + self.active_clients[llm_model_id] += 1 + logger.debug(f"创建新LLM客户端: {llm_model_id}") + return client + else: + # 等待可用客户端 + logger.debug(f"等待LLM客户端可用: {llm_model_id}") + return await pool.get() + + async def return_client(self, llm_model_id: str, client): + """归还LLM客户端到池中""" + if llm_model_id in self.pools: + try: + self.pools[llm_model_id].put_nowait(client) + logger.debug(f"归还LLM客户端到池: {llm_model_id}") + except asyncio.QueueFull: + # 池已满,丢弃客户端 + self.active_clients[llm_model_id] -= 1 + logger.debug(f"池已满,丢弃LLM客户端: {llm_model_id}") + +# 全局客户端池 +llm_client_pool = LLMClientPool() diff --git a/api/app/core/memory/agent/utils/llm_tools.py b/api/app/core/memory/agent/utils/llm_tools.py index ec22b628..8dd2f1d3 100644 --- a/api/app/core/memory/agent/utils/llm_tools.py +++ b/api/app/core/memory/agent/utils/llm_tools.py @@ -1,40 +1,12 @@ -import asyncio -import json -import logging import os from collections import defaultdict from typing import Annotated, TypedDict -from app.core.memory.agent.utils.messages_tool import read_template_file -from app.core.memory.utils.config.config_utils import ( - get_picture_config, - get_voice_config, -) - -# Removed global variable imports - use dependency injection instead -from dotenv import load_dotenv from langchain_core.messages import AnyMessage from langgraph.graph import add_messages -from openai import OpenAI PROJECT_ROOT_ = os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) -logger = logging.getLogger(__name__) -load_dotenv() - - -async def picture_model_requests(image_url): - ''' - - Args: - image_url: - Returns: - - ''' - file_path = PROJECT_ROOT_ + '/agent/utils/prompt/Template_for_image_recognition_prompt.jinja2 ' - system_prompt = await read_template_file(file_path) - result = await Picture_recognize(image_url,system_prompt) - return (result) class WriteState(TypedDict): ''' Langgrapg Writing TypedDict @@ -44,39 +16,69 @@ class WriteState(TypedDict): apply_id:str group_id:str errors: list[dict] # Track errors: [{"tool": "tool_name", "error": "message"}] + memory_config: object + write_result: dict + data:str class ReadState(TypedDict): - ''' - Langgrapg READING TypedDict - name: - id:user id - loop_count:Traverse times - search_switch:type - config_id: configuration id for filtering results - errors: list of errors that occurred during workflow execution - ''' - messages: Annotated[list[AnyMessage], add_messages] #消息追加的模式增加消息 - name: str - id: str - loop_count:int + """ + LangGraph 工作流状态定义 + + Attributes: + messages: 消息列表,支持自动追加 + loop_count: 遍历次数 + search_switch: 搜索类型开关 + group_id: 组标识 + config_id: 配置ID,用于过滤结果 + data: 从content_input_node传递的内容数据 + spit_data: 从Split_The_Problem传递的分解结果 + tool_calls: 工具调用请求列表 + tool_results: 工具执行结果列表 + memory_config: 内存配置对象 + """ + messages: Annotated[list[AnyMessage], add_messages] # 消息追加模式 + loop_count: int search_switch: str - user_id: str - apply_id: str group_id: str config_id: str - errors: list[dict] # Track errors: [{"tool": "tool_name", "error": "message"}] - - + data: str # 新增字段用于传递内容 + spit_data: dict # 新增字段用于传递问题分解结果 + problem_extension:dict + storage_type: str + user_rag_memory_id: str + llm_id: str + embedding_id: str + memory_config: object # 新增字段用于传递内存配置对象 + retrieve:dict + RetrieveSummary: dict + InputSummary: dict + verify: dict + SummaryFails: dict + summary: dict class COUNTState: - ''' - The number of times the workflow dialogue retrieval content has no correct message recall traversal - ''' + """ + 工作流对话检索内容计数器 + + 用于记录工作流对话检索内容没有正确消息召回遍历的次数。 + """ + def __init__(self, limit: int = 5): + """ + 初始化计数器 + + Args: + limit: 最大计数限制,默认为5 + """ self.total: int = 0 # 当前累加值 self.limit: int = limit # 最大上限 - def add(self, value: int = 1): - """累加数字,如果达到上限就保持最大值""" + def add(self, value: int = 1) -> None: + """ + 累加数字,如果达到上限就保持最大值 + + Args: + value: 要累加的值,默认为1 + """ self.total += value print(f"[COUNTState] 当前值: {self.total}") if self.total >= self.limit: @@ -84,21 +86,19 @@ class COUNTState: self.total = self.limit # 达到上限不再增加 def get_total(self) -> int: - """获取当前累加值""" + """ + 获取当前累加值 + + Returns: + 当前累加值 + """ return self.total - def reset(self): + def reset(self) -> None: """手动重置累加值""" self.total = 0 print("[COUNTState] 已重置为 0") - -def merge_to_key_value_pairs(data, query_key, result_key): - grouped = defaultdict(list) - for item in data: - grouped[item[query_key]].append(item[result_key]) - return [{key: values} for key, values in grouped.items()] - def deduplicate_entries(entries): seen = set() deduped = [] @@ -109,70 +109,37 @@ def deduplicate_entries(entries): deduped.append(entry) return deduped +def merge_to_key_value_pairs(data, query_key, result_key): + grouped = defaultdict(list) + for item in data: + grouped[item[query_key]].append(item[result_key]) + return [{key: values} for key, values in grouped.items()] -async def Picture_recognize(image_path, PROMPT_TICKET_EXTRACTION, picture_model_name: str) -> str: +def convert_extended_question_to_question(data): """ - Updated to eliminate global variables in favor of explicit parameters. - + 递归地将数据中的 extended_question 字段转换为 question 字段 + Args: - image_path: Path to image file - PROMPT_TICKET_EXTRACTION: Extraction prompt - picture_model_name: Picture model name (required, no longer from global variables) + data: 要转换的数据(可能是字典、列表或其他类型) + + Returns: + 转换后的数据 """ - try: - model_config = get_picture_config(picture_model_name) - except Exception as e: - err = f"LLM配置不可用:{str(e)}。请检查 config.json 和 runtime.json。" - logger.error(err) - return err - api_key = os.getenv(model_config["api_key"]) # 从环境变量读取对应后端的 API key - backend_model_name = model_config["llm_name"].split("/")[-1] - api_base=model_config['api_base'] - - logger.info(f"model_name: {backend_model_name}") - logger.info(f"api_key set: {'yes' if api_key else 'no'}") - logger.info(f"base_url: {model_config['api_base']}") - - client = OpenAI( - api_key=api_key, base_url=api_base, - ) - completion = client.chat.completions.create( - model=backend_model_name, - messages=[ - { - "role": "user", - "content": [ - { - "type": "image_url", - "image_url":image_path, - }, - {"type": "text", - "text": PROMPT_TICKET_EXTRACTION} - ] - } - ]) - picture_text = completion.choices[0].message.content - picture_text = picture_text.replace('```json', '').replace('```', '') - picture_text = json.loads(picture_text) - return (picture_text['statement']) - -async def Voice_recognize(voice_model_name: str): - """ - Updated to eliminate global variables in favor of explicit parameters. - - Args: - voice_model_name: Voice model name (required, no longer from global variables) - """ - try: - model_config = get_voice_config(voice_model_name) - except Exception as e: - err = f"LLM配置不可用:{str(e)}。请检查 config.json 和 runtime.json。" - logger.error(err) - return err - api_key = os.getenv(model_config["api_key"]) # 从环境变量读取对应后端的 API key - backend_model_name = model_config["llm_name"].split("/")[-1] - api_base = model_config['api_base'] - return api_key,backend_model_name,api_base - - + if isinstance(data, dict): + # 创建新字典来存储转换后的数据 + converted = {} + for key, value in data.items(): + if key == 'extended_question': + # 将 extended_question 转换为 question + converted['question'] = convert_extended_question_to_question(value) + else: + # 递归处理其他字段 + converted[key] = convert_extended_question_to_question(value) + return converted + elif isinstance(data, list): + # 递归处理列表中的每个元素 + return [convert_extended_question_to_question(item) for item in data] + else: + # 其他类型直接返回 + return data \ No newline at end of file diff --git a/api/app/core/memory/agent/utils/mcp_tools.py b/api/app/core/memory/agent/utils/mcp_tools.py deleted file mode 100644 index 7ede9843..00000000 --- a/api/app/core/memory/agent/utils/mcp_tools.py +++ /dev/null @@ -1,33 +0,0 @@ -import os -from app.core.config import settings - -def get_mcp_server_config(): - """ - Get the MCP server configuration. - - Uses MCP_SERVER_URL environment variable if set (for Docker), - otherwise falls back to SERVER_IP and MCP_PORT (for local development). - """ - # Get MCP port from environment (default: 8081) - mcp_port = os.getenv("MCP_PORT", "8081") - - # In Docker: MCP_SERVER_URL=http://mcp-server:8081 - # In local dev: uses SERVER_IP (127.0.0.1 or localhost) - mcp_server_url = os.getenv("MCP_SERVER_URL") - - if mcp_server_url: - # Docker environment: use full URL from environment - base_url = mcp_server_url - else: - # Local development: build URL from SERVER_IP and MCP_PORT - base_url = f"http://{settings.SERVER_IP}:{mcp_port}" - - mcp_server_config = { - "data_flow": { - "url": f"{base_url}/sse", - "transport": "sse", - "timeout": 15000, - "sse_read_timeout": 15000, - } - } - return mcp_server_config diff --git a/api/app/core/memory/agent/utils/messages_tool.py b/api/app/core/memory/agent/utils/messages_tool.py deleted file mode 100644 index 769e795a..00000000 --- a/api/app/core/memory/agent/utils/messages_tool.py +++ /dev/null @@ -1,260 +0,0 @@ -import json -import logging -import re -from typing import Any, List - -from app.core.logging_config import get_agent_logger -from langchain_core.messages import AnyMessage - -logger = get_agent_logger(__name__) - - -def _to_openai_messages(msgs: List[AnyMessage]) -> List[dict]: - out = [] - for m in msgs: - if hasattr(m, "content"): - out.append({"role": "user", "content": getattr(m, "content", "")}) - elif isinstance(m, dict) and "role" in m and "content" in m: - out.append(m) - else: - out.append({"role": "user", "content": str(m)}) - return out - - -def _extract_content(resp: Any) -> str: - """Extract LLM content and sanitize to raw JSON/text. - - - Supports both object and dict response shapes. - - Removes leading role labels (e.g., "Assistant:"). - - Strips Markdown code fences like ```json ... ```. - - Attempts to isolate the first valid JSON array/object block when extra text is present. - """ - - def _to_text(r: Any) -> str: - try: - # 对象形式: resp.choices[0].message.content - if hasattr(r, "choices") and getattr(r, "choices", None): - msg = r.choices[0].message - if hasattr(msg, "content"): - return msg.content - if isinstance(msg, dict) and "content" in msg: - return msg["content"] - # 字典形式: resp["choices"][0]["message"]["content"] - if isinstance(r, dict): - return r.get("choices", [{}])[0].get("message", {}).get("content", "") - except Exception: - pass - return str(r) - - def _clean_text(text: str) -> str: - s = str(text).strip() - # 移除可能的角色前缀 - s = re.sub(r"^\s*(Assistant|assistant)\s*:\s*", "", s) - # 提取 ```json ... ``` 代码块 - m = re.search(r"```json\s*(.*?)\s*```", s, flags=re.S | re.I) - if m: - s = m.group(1).strip() - # 如果仍然包含多余文本,尝试截取第一个 JSON 数组/对象片段 - if not (s.startswith("{") or s.startswith("[")): - left = s.find("[") - right = s.rfind("]") - if left != -1 and right != -1 and right > left: - s = s[left:right + 1].strip() - else: - left = s.find("{") - right = s.rfind("}") - if left != -1 and right != -1 and right > left: - s = s[left:right + 1].strip() - return s - - raw = _to_text(resp) - return _clean_text(raw) - -def Resolve_username(usermessages): - ''' - Extract username - Args: - usermessages: user name - - Returns: - - ''' - usermessages = usermessages.split('_')[1:] - sessionid = '_'.join(usermessages[:-1]) - return sessionid - - -# TODO: USE app.core.memory.src.utils.render_template instead -async def read_template_file(template_path: str) -> str: - """ - 读取模板文件 - - Args: - template_path: 模板文件路径 - - Returns: - 模板内容字符串 - - Note: - 建议使用 app.core.memory.utils.template_render 中的统一模板渲染功能 - """ - try: - with open(template_path, "r", encoding="utf-8") as f: - return f.read() - except FileNotFoundError: - logger.error(f"模板文件未找到: {template_path}") - raise - except IOError as e: - logger.error(f"读取模板文件失败: {template_path}, 错误: {str(e)}", exc_info=True) - raise - - -async def Problem_Extension_messages_deal(context): - ''' - Extract data - Args: - context: - Returns: - ''' - extent_quest = [] - original = context.get('original', '') - messages = context.get('context', '') - - # Handle empty or non-string messages - if not messages: - return extent_quest, original - - if isinstance(messages, str): - try: - messages = json.loads(messages) - except json.JSONDecodeError: - # If JSON parsing fails, return empty list - return extent_quest, original - - if isinstance(messages, list): - for message in messages: - question = message.get('question', '') - type = message.get('type', '') - extent_quest.append({"role": "user", "content": f"问题:{question};问题类型:{type}"}) - - return extent_quest, original - - -async def Retriev_messages_deal(context): - ''' - Extract data - Args: - context: - Returns: - ''' - logger.info(f"Retriev_messages_deal input: type={type(context)}, value={str(context)[:500]}") - - if isinstance(context, dict): - logger.info(f"Retriev_messages_deal: context is dict with keys={list(context.keys())}") - if 'context' in context or 'original' in context: - content = context.get('context', {}) - original = context.get('original', '') - logger.info(f"Retriev_messages_deal output: content_type={type(content)}, content={str(content)[:300]}, original='{original[:50] if original else ''}'") - return content, original - - # Return empty defaults if context is not a dict or doesn't have expected keys - logger.warning(f"Retriev_messages_deal: context missing expected keys, returning empty defaults") - return {}, '' - -async def Verify_messages_deal(context): - ''' - Extract data - Args: - context: - Returns: - ''' - - query = context['context']['Query'] - Query_small_list = context['context']['Expansion_issue'] - Result_small = [] - Query_small = [] - for i in Query_small_list: - Result_small.append(i['Answer_Small'][0]) - Query_small.append(i['Query_small']) - return Query_small, Result_small, query - - -async def Summary_messages_deal(context): - ''' - Extract data - Args: - context: - Returns: - ''' - messages = str(context).replace('\\n', '').replace('\n', '').replace('\\', '') - query = re.findall(r'"query": (.*?),', messages)[0] - query = query.replace('[', '').replace(']', '').strip() - matches = re.findall(r'"answer_small"\s*:\s*"(\[.*?\])"', messages) - answer_small_texts = [] - for m in matches: - try: - parsed = json.loads(m) - for item in parsed: - answer_small_texts.append(item.strip().replace('\\', '').replace('[', '').replace(']', '')) - except Exception: - answer_small_texts.append(m.strip().replace('\\', '').replace('[', '').replace(']', '')) - - return answer_small_texts, query - - -async def VerifyTool_messages_deal(context): - ''' - Extract data - Args: - context: - Returns: - ''' - messages = str(context).replace('\\n', '').replace('\n', '').replace('\\', '') - content_messages = messages.split('"context":')[1].replace('""', '"') - messages = str(content_messages).split("name='Retrieve'")[0] - query = re.findall('"Query": "(.*?)"', messages)[0] - Query_small = re.findall('"Query_small": "(.*?)"', messages) - Result_small = re.findall('"Result_small": "(.*?)"', messages) - return Query_small, Result_small, query - - -async def Retrieve_Summary_messages_deal(context): - pass - - -async def Retrieve_verify_tool_messages_deal(context, history, query): - ''' - Extract data - Args: - context: - Returns: - ''' - results = [] - # 统一转为字符串,避免 None 或非字符串导致正则报错 - text = str(context) - blocks = re.findall(r'\{(.*?)\}', text, flags=re.S) - for block in blocks: - query_small = re.search(r'"Query_small"\s*:\s*"([^"]*)"', block) - answer_small = re.search(r'"Answer_Small"\s*:\s*(\[[^\]]*\])', block) - status = re.search(r'"status"\s*:\s*"([^"]*)"', block) - query_answer = re.search(r'"Query_answer"\s*:\s*"([^"]*)"', block) - - results.append({ - "query_small": query_small.group(1) if query_small else None, - "answer_small": answer_small.group(1) if answer_small else None, - # 将缺失的 status 统一为空字符串,后续用字符串判定,避免 NoneType 错误 - "status": status.group(1) if status else "", - "query_answer": query_answer.group(1) if query_answer else None - }) - result = [] - for r in results: - # 统一按字符串判定状态,兼容大小写和缺失情况 - status_str = str(r.get('status', '')).strip().lower() - if status_str == 'false': - continue - else: - result.append(r) - split_result = 'failed' if not result else 'success' - result = {"data": {"query": query, "expansion_issue": result}, "split_result": split_result, "reason": "", - "history": history} - return result diff --git a/api/app/core/memory/agent/utils/messages_tools.py b/api/app/core/memory/agent/utils/messages_tools.py new file mode 100644 index 00000000..db95319f --- /dev/null +++ b/api/app/core/memory/agent/utils/messages_tools.py @@ -0,0 +1,194 @@ +from typing import List, Dict, Any +from app.core.logging_config import get_agent_logger + +logger = get_agent_logger(__name__) +async def read_template_file(template_path: str) -> str: + """ + 读取模板文件 + + Args: + template_path: 模板文件路径 + + Returns: + 模板内容字符串 + + Note: + 建议使用 app.core.memory.utils.template_render 中的统一模板渲染功能 + """ + try: + with open(template_path, "r", encoding="utf-8") as f: + return f.read() + except FileNotFoundError: + logger.error(f"模板文件未找到: {template_path}") + raise + except IOError as e: + logger.error(f"读取模板文件失败: {template_path}, 错误: {str(e)}", exc_info=True) + raise + +def reorder_output_results(results: List[Dict[str, Any]]) -> List[Dict[str, Any]]: + """ + 重新排序输出结果,将 retrieval_summary 类型的数据放到最后面 + + Args: + results: 原始输出结果列表 + + Returns: + 重新排序后的结果列表 + """ + retrieval_summaries = [] + other_results = [] + + # 分离 retrieval_summary 和其他类型的结果 + for result in results: + if 'summary' in result.get('type'): + retrieval_summaries.append(result) + else: + other_results.append(result) + + # 将 retrieval_summary 放到最后 + return other_results + retrieval_summaries + +def optimize_search_results(intermediate_outputs): + """ + 优化检索结果,合并多个搜索结果,过滤空结果,统一格式 + + Args: + intermediate_outputs: 原始的中间输出列表 + + Returns: + 优化后的检索结果列表 + """ + optimized_results = [] + + for item in intermediate_outputs: + if not item or item == [] or item == {}: + continue + + # 检查是否是搜索结果类型 + if isinstance(item, dict) and item.get('type') == 'search_result': + raw_results = item.get('raw_results', {}) + + # 如果 raw_results 为空,跳过 + if not raw_results or raw_results == [] or raw_results == {}: + continue + + # 创建优化后的结果结构 + optimized_item = { + "type": "search_result", + "title": f"检索结果 ({item.get('index', 1)}/{item.get('total', 1)})", + "query": item.get('query', ''), + "raw_results": {}, + "index": item.get('index', 1), + "total": item.get('total', 1) + } + + # 合并所有搜索结果类型到一个 raw_results 中 + merged_raw_results = {} + + # 处理 time_search + if 'time_search' in raw_results and raw_results['time_search']: + merged_raw_results['time_search'] = raw_results['time_search'] + + # 处理 keyword_search + if 'keyword_search' in raw_results and raw_results['keyword_search']: + merged_raw_results['keyword_search'] = raw_results['keyword_search'] + + # 处理 embedding_search + if 'embedding_search' in raw_results and raw_results['embedding_search']: + merged_raw_results['embedding_search'] = raw_results['embedding_search'] + + # 处理 combined_summary + if 'combined_summary' in raw_results and raw_results['combined_summary']: + merged_raw_results['combined_summary'] = raw_results['combined_summary'] + + # 处理 reranked_results + if 'reranked_results' in raw_results and raw_results['reranked_results']: + merged_raw_results['reranked_results'] = raw_results['reranked_results'] + + # 如果合并后的结果不为空,添加到优化结果中 + if merged_raw_results: + optimized_item['raw_results'] = merged_raw_results + optimized_results.append(optimized_item) + else: + # 非搜索结果类型,直接添加 + optimized_results.append(item) + + return optimized_results + + +def merge_multiple_search_results(intermediate_outputs): + """ + 将多个搜索结果合并为一个统一的搜索结果 + + Args: + intermediate_outputs: 原始的中间输出列表 + + Returns: + 合并后的结果列表 + """ + search_results = [] + other_results = [] + + # 分离搜索结果和其他结果 + for item in intermediate_outputs: + if isinstance(item, dict) and item.get('type') == 'search_result': + raw_results = item.get('raw_results', {}) + # 只保留有内容的搜索结果 + if raw_results and raw_results != [] and raw_results != {}: + search_results.append(item) + else: + other_results.append(item) + + # 如果没有搜索结果,返回原始结果 + if not search_results: + return intermediate_outputs + + # 如果只有一个搜索结果,优化格式后返回 + if len(search_results) == 1: + optimized = optimize_search_results(search_results) + return other_results + optimized + + # 合并多个搜索结果 + merged_raw_results = {} + all_queries = [] + + for result in search_results: + query = result.get('query', '') + if query: + all_queries.append(query) + + raw_results = result.get('raw_results', {}) + + # 合并各种搜索类型的结果 + for search_type in ['time_search', 'keyword_search', 'embedding_search', 'combined_summary', + 'reranked_results']: + if search_type in raw_results and raw_results[search_type]: + if search_type not in merged_raw_results: + merged_raw_results[search_type] = raw_results[search_type] + else: + # 如果是字典类型,需要合并 + if isinstance(raw_results[search_type], dict) and isinstance(merged_raw_results[search_type], dict): + for key, value in raw_results[search_type].items(): + if key not in merged_raw_results[search_type]: + merged_raw_results[search_type][key] = value + elif isinstance(value, list) and isinstance(merged_raw_results[search_type][key], list): + merged_raw_results[search_type][key].extend(value) + elif isinstance(raw_results[search_type], list): + if isinstance(merged_raw_results[search_type], list): + merged_raw_results[search_type].extend(raw_results[search_type]) + else: + merged_raw_results[search_type] = raw_results[search_type] + + # 创建合并后的结果 + if merged_raw_results: + merged_result = { + "type": "search_result", + "title": f"合并检索结果 (共{len(search_results)}个查询)", + "query": " | ".join(all_queries), + "raw_results": merged_raw_results, + "index": 1, + "total": 1 + } + return other_results + [merged_result] + + return other_results diff --git a/api/app/core/memory/agent/utils/model_tool.py b/api/app/core/memory/agent/utils/model_tool.py deleted file mode 100644 index 969a2a91..00000000 --- a/api/app/core/memory/agent/utils/model_tool.py +++ /dev/null @@ -1,38 +0,0 @@ - - -# project_root = os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) -# sys.path.insert(0, project_root) - -# load_dotenv() - -# async def llm_client_chat(messages: List[dict]) -> str: -# """使用 OpenAI 兼容接口进行对话,返回内容字符串。""" -# try: -# cfg = get_model_config(SELECTED_LLM_ID) -# rb_config = RedBearModelConfig( -# model_name=cfg["model_name"], -# provider=cfg["provider"], -# api_key=cfg["api_key"], -# base_url=cfg["base_url"], -# ) -# client = OpenAIClient(model_config=rb_config, type_="chat") - -# except Exception as e: -# logger.error(f"获取模型配置失败:{e}") -# err = f"获取模型配置失败:{str(e)}。请检查!!!" -# return err -# try: -# response = await client.chat(messages) -# print(f"model_tool's llm_client_chat response ======>:\n {response}") -# return _extract_content(response) -# # return _extract_content(result) -# except Exception as e: -# logger.error(f"LLM调用失败:{str(e)}。请检查 model_name、api_key、api_base 是否正确。") -# return f"LLM调用失败:{str(e)}。请检查 model_name、api_key、api_base 是否正确。" - -# async def main(image_url): -# await llm_client_chat(image_url) -# -# # 运行主函数 -# asyncio.run(main(['https://dashscope.oss-cn-beijing.aliyuncs.com/samples/audio/paraformer/hello_world_male2.wav'])) -# diff --git a/api/app/core/memory/agent/utils/multimodal.py b/api/app/core/memory/agent/utils/multimodal.py deleted file mode 100644 index 0fc52634..00000000 --- a/api/app/core/memory/agent/utils/multimodal.py +++ /dev/null @@ -1,131 +0,0 @@ -""" -Multimodal input processor for handling image and audio content. - -This module provides utilities for detecting and processing multimodal inputs -(images and audio files) by converting them to text using appropriate models. -""" - -import logging -from typing import List - -from app.core.memory.agent.multimodal.speech_model import Vico_recognition -from app.core.memory.agent.utils.llm_tools import picture_model_requests - -logger = logging.getLogger(__name__) - - -class MultimodalProcessor: - """ - Processor for handling multimodal inputs (images and audio). - - This class detects image and audio file paths in input content and converts - them to text using appropriate recognition models. - """ - - # Supported file extensions - IMAGE_EXTENSIONS = ['.jpg', '.png'] - AUDIO_EXTENSIONS = [ - 'aac', 'amr', 'avi', 'flac', 'flv', 'm4a', 'mkv', 'mov', - 'mp3', 'mp4', 'mpeg', 'ogg', 'opus', 'wav', 'webm', 'wma', 'wmv' - ] - - def __init__(self): - """Initialize the multimodal processor.""" - pass - - def is_image(self, content: str) -> bool: - """ - Check if content is an image file path. - - Args: - content: Input string to check - - Returns: - True if content ends with a supported image extension - - Examples: - >>> processor = MultimodalProcessor() - >>> processor.is_image("photo.jpg") - True - >>> processor.is_image("document.pdf") - False - """ - if not isinstance(content, str): - return False - - content_lower = content.lower() - return any(content_lower.endswith(ext) for ext in self.IMAGE_EXTENSIONS) - - def is_audio(self, content: str) -> bool: - """ - Check if content is an audio file path. - - Args: - content: Input string to check - - Returns: - True if content ends with a supported audio extension - - Examples: - >>> processor = MultimodalProcessor() - >>> processor.is_audio("recording.mp3") - True - >>> processor.is_audio("video.mp4") - True - >>> processor.is_audio("document.txt") - False - """ - if not isinstance(content, str): - return False - - content_lower = content.lower() - return any(content_lower.endswith(f'.{ext}') for ext in self.AUDIO_EXTENSIONS) - - async def process_input(self, content: str) -> str: - """ - Process input content, converting images/audio to text if needed. - - This method detects if the input is an image or audio file and converts - it to text using the appropriate recognition model. If processing fails - or the content is not multimodal, it returns the original content. - - Args: - content: Input string (may be file path or regular text) - - Returns: - Text content (original or converted from image/audio) - - Examples: - >>> processor = MultimodalProcessor() - >>> await processor.process_input("photo.jpg") - "Recognized text from image..." - - >>> await processor.process_input("Hello world") - "Hello world" - """ - if not isinstance(content, str): - logger.warning(f"[MultimodalProcessor] Content is not a string: {type(content)}") - return str(content) - - try: - # Check for image input - if self.is_image(content): - logger.info(f"[MultimodalProcessor] Detected image input: {content}") - result = await picture_model_requests(content) - logger.info(f"[MultimodalProcessor] Image recognition result: {result[:100]}...") - return result - - # Check for audio input - if self.is_audio(content): - logger.info(f"[MultimodalProcessor] Detected audio input: {content}") - result = await Vico_recognition([content]).run() - logger.info(f"[MultimodalProcessor] Audio recognition result: {result[:100]}...") - return result - - except Exception as e: - logger.error(f"[MultimodalProcessor] Error processing multimodal input: {e}", exc_info=True) - logger.info("[MultimodalProcessor] Falling back to original content") - return content - - # Return original content if not multimodal - return content diff --git a/api/app/core/memory/agent/utils/performance_monitor.py b/api/app/core/memory/agent/utils/performance_monitor.py new file mode 100644 index 00000000..d2d9fdfa --- /dev/null +++ b/api/app/core/memory/agent/utils/performance_monitor.py @@ -0,0 +1,56 @@ + +import time +import json +from collections import defaultdict +from typing import Dict, List +from app.core.logging_config import get_agent_logger + +logger = get_agent_logger(__name__) + +class ProblemExtensionMonitor: + """Problem_Extension性能监控器""" + + def __init__(self): + self.metrics = defaultdict(list) + self.slow_queries = [] + self.error_count = 0 + + def record_execution(self, duration: float, question_count: int, success: bool): + """记录执行指标""" + self.metrics['durations'].append(duration) + self.metrics['question_counts'].append(question_count) + + if not success: + self.error_count += 1 + + # 记录慢查询(超过10秒) + if duration > 10.0: + self.slow_queries.append({ + 'duration': duration, + 'question_count': question_count, + 'timestamp': time.time() + }) + + def get_stats(self) -> Dict: + """获取统计信息""" + durations = self.metrics['durations'] + if not durations: + return {"message": "暂无数据"} + + return { + "total_executions": len(durations), + "avg_duration": sum(durations) / len(durations), + "max_duration": max(durations), + "min_duration": min(durations), + "slow_queries_count": len(self.slow_queries), + "error_rate": self.error_count / len(durations) if durations else 0, + "recent_slow_queries": self.slow_queries[-5:] # 最近5个慢查询 + } + + def log_stats(self): + """记录统计信息到日志""" + stats = self.get_stats() + logger.info(f"Problem_Extension性能统计: {json.dumps(stats, indent=2)}") + +# 全局监控器实例 +performance_monitor = ProblemExtensionMonitor() diff --git a/api/app/core/memory/agent/utils/prompt/Problem_Extension_prompt_simplified.jinja2 b/api/app/core/memory/agent/utils/prompt/Problem_Extension_prompt_simplified.jinja2 new file mode 100644 index 00000000..a0e21fbd --- /dev/null +++ b/api/app/core/memory/agent/utils/prompt/Problem_Extension_prompt_simplified.jinja2 @@ -0,0 +1,81 @@ + +你是一个高效的问题拆分助手,任务是根据用户提供的原始问题和问题类型,生成可操作的扩展问题,用于精确回答原问题。请严格遵循以下规则: + +角色: +- 你是“问题拆分专家”,专注于逻辑、信息完整性和可操作性。 +- 你能够结合【历史信息】、【上下文】、【背景知识】进行分析,以保持问题拆分的连贯性和相关性。 +- 如果历史信息或上下文与当前问题无关,可忽略。 + +--- + +### 历史信息参考 +在生成扩展问题时,你可以参考以下历史数据(如果提供): +- 历史对话或任务的主题; +- 历史中出现的关键实体(时间、人物、地点、研究主题等); +- 历史中已解答的问题(避免重复); +- 历史推理链(保持逻辑一致性)。 + +> 如果没有提供历史信息,则仅根据当前输入问题进行分析。 +输入历史信息内容:{{history}} + +## User Input +{% if questions is string %} +{{ questions }} +{% else %} +{% for question in questions %} +- {{ question }} +{% endfor %} +{% endif %} + +需求: +- 如果问题是单跳问题(单步可答),直接保留原问题提取重要提问部分作为拆分/扩展问题。 +- 如果问题是多跳问题(需多个信息点才能回答),对问题进行扩展拆分。 +- 扩展问题必须完整覆盖原问题的所有关键要素,包括时间、主体、动作、目标等,不得遗漏。 +- 扩展问题不得冗余:避免重复询问相同信息或过度拆分同一主题。 +- 扩展问题必须高度相关:每个子问题直接服务于原问题,不引入未提及的新概念、人物或细节。 +- 扩展问题必须可操作:每个子问题能在有限资源下独立解答。 +- 子问题数量不超过4个。 +- 拆分问题的时候可以考虑输入的历史内容,以保持逻辑连贯。 + 比如:输入历史信息内容:[{'Query': '4月27日,我和你推荐过一本书,书名是什么?', 'ANswer': '张曼玉推荐了《小王子》'}] + 拆分问题:4月27日,我和你推荐过一本书,书名是什么?,可以拆分为:4月27日,张曼玉推荐过一本书,书名是什么? + + + +输出要求: +- 仅输出 JSON 数组,不要包含任何解释或代码块。 +- 每个元素包含: + - `original_question`: 原始问题 + - `extended_question`: 扩展后的问题 + - `type`: 类型(事实检索/澄清/定义/比较/行动建议) + - `reason`: 生成该扩展问题的简短理由 +- 使用标准 ASCII 双引号,无换行;确保字符串正确关闭并以逗号分隔。 + +示例: +输入: +[ + "问题:今年诺贝尔物理学奖的获奖者是谁,他们因为什么贡献获奖?;问题类型:多跳", +] + +输出: +[ + { + "original_question": "今年诺贝尔物理学奖的获奖者是谁,他们因为什么贡献获奖?", + "extended_question": "今年诺贝尔物理学奖的获奖者有哪些人?", + "type": "多跳", + "reason": "输出原问题的关键要素" + }, + { + "original_question": "今年诺贝尔物理学奖的获奖者是谁,他们因为什么贡献获奖?", + "extended_question": "今年诺贝尔物理学奖的获奖者是因哪些具体贡献获奖的?", + "type": "多跳", + "reason": "输出原问题的关键要素" + } +] +**Output format** +**CRITICAL JSON FORMATTING REQUIREMENTS:** +1. Use only standard ASCII double quotes (") for JSON structure - never use Chinese quotation marks ("") or other Unicode quotes +2. If the extracted statement text contains quotation marks, escape them properly using backslashes (\") +3. Ensure all JSON strings are properly closed and comma-separated +4. Do not include line breaks within JSON string values + +The output language should always be the same as the input language.{{ json_schema }} diff --git a/api/app/core/memory/agent/utils/prompt/Retrieve_Summary_prompt.jinja2 b/api/app/core/memory/agent/utils/prompt/Retrieve_Summary_prompt.jinja2 index 1fa71df3..5fbe8574 100644 --- a/api/app/core/memory/agent/utils/prompt/Retrieve_Summary_prompt.jinja2 +++ b/api/app/core/memory/agent/utils/prompt/Retrieve_Summary_prompt.jinja2 @@ -1,13 +1,10 @@ # 角色 你是一个专业的问答助手,擅长基于检索信息和历史对话回答用户问题。 - # 任务 根据提供的上下文信息回答用户的问题。 - # 输入信息 - 历史对话:{{history}} - 检索信息:{{retrieve_info}} - ## User Query {{query}} diff --git a/api/app/core/memory/agent/utils/prompt/split_verify_prompt.jinja2 b/api/app/core/memory/agent/utils/prompt/split_verify_prompt.jinja2 index f4d4665c..d6ad8cab 100644 --- a/api/app/core/memory/agent/utils/prompt/split_verify_prompt.jinja2 +++ b/api/app/core/memory/agent/utils/prompt/split_verify_prompt.jinja2 @@ -9,8 +9,8 @@ 3. 判断Answer_Small和Query_Small之间分析出来的关系状态 4. 如果是True保留,否则不要相对应的问题和回答 5. 输出,需要严格按照模版 -输入:{{history}} -历史消息:{"history":{{sentence}}} +输入:{{sentence}} +历史消息:{"history":{{history}}} ### 第一步 获取用户的输入 获取用户的输入提取对应的Query_Small和Answer_Small ### 第二步 分析验证 diff --git a/api/app/core/memory/agent/utils/session_tools.py b/api/app/core/memory/agent/utils/session_tools.py new file mode 100644 index 00000000..b2d4f0ff --- /dev/null +++ b/api/app/core/memory/agent/utils/session_tools.py @@ -0,0 +1,169 @@ +""" +Session Service for managing user sessions and conversation history. + +This service provides clean Redis interactions with error handling and +session management utilities. +""" +from typing import List, Optional + +from app.core.logging_config import get_agent_logger +from app.core.memory.agent.utils.redis_tool import RedisSessionStore + + +logger = get_agent_logger(__name__) + + +class SessionService: + """Service for managing user sessions and conversation history.""" + + def __init__(self, store: RedisSessionStore): + """ + Initialize the session service. + + Args: + store: Redis session store instance + """ + self.store = store + logger.info("SessionService initialized") + + def resolve_user_id(self, session_string: str) -> str: + """ + Extract user ID from session string. + + Handles formats like: + - 'call_id_user123' -> 'user123' + - 'prefix_id_user456_suffix' -> 'user456_suffix' + + Args: + session_string: Session identifier string + + Returns: + Extracted user ID + """ + try: + # Split by '_id_' and take everything after it + parts = session_string.split('_id_') + if len(parts) > 1: + return parts[1] + + # Fallback: return original string + return session_string + + except Exception as e: + logger.warning( + f"Failed to parse user ID from session string '{session_string}': {e}" + ) + return session_string + + async def get_history( + self, + user_id: str, + apply_id: str, + group_id: str + ) -> List[dict]: + """ + Retrieve conversation history from Redis. + + Args: + user_id: User identifier + apply_id: Application identifier + group_id: Group identifier + + Returns: + List of conversation history items with Query and Answer keys + Returns empty list if no history found or on error + """ + try: + history = self.store.find_user_apply_group(user_id, apply_id, group_id) + + # Validate history structure + if not isinstance(history, list): + logger.warning( + f"Invalid history format for user {user_id}, " + f"apply {apply_id}, group {group_id}: expected list, got {type(history)}" + ) + return [] + + return history + + except Exception as e: + logger.error( + f"Failed to retrieve history for user {user_id}, " + f"apply {apply_id}, group {group_id}: {e}", + exc_info=True + ) + # Return empty list on error to allow execution to continue + return [] + + async def save_session( + self, + user_id: str, + query: str, + apply_id: str, + group_id: str, + ai_response: str + ) -> Optional[str]: + """ + Save conversation turn to Redis. + + Args: + user_id: User identifier + query: User query/message + apply_id: Application identifier + group_id: Group identifier + ai_response: AI response/answer + + Returns: + Session ID if successful, None on error + """ + try: + # Validate required fields + if not user_id: + logger.warning("Cannot save session: user_id is empty") + return None + + if not query: + logger.warning("Cannot save session: query is empty") + return None + + # Save session + session_id = self.store.save_session( + userid=user_id, + messages=query, + apply_id=apply_id, + group_id=group_id, + aimessages=ai_response + ) + + logger.info(f"Session saved successfully: {session_id}") + return session_id + + except Exception as e: + logger.error( + f"Failed to save session for user {user_id}: {e}", + exc_info=True + ) + return None + + async def cleanup_duplicates(self) -> int: + """ + Remove duplicate session entries. + + Duplicates are identified by matching: + - sessionid + - user_id (id field) + - group_id + - messages + - aimessages + + Returns: + Number of duplicate sessions deleted + """ + try: + deleted_count = self.store.delete_duplicate_sessions() + logger.info(f"Cleaned up {deleted_count} duplicate sessions") + return deleted_count + + except Exception as e: + logger.error(f"Failed to cleanup duplicate sessions: {e}", exc_info=True) + return 0 diff --git a/api/app/core/memory/agent/utils/template_tools.py b/api/app/core/memory/agent/utils/template_tools.py new file mode 100644 index 00000000..854c5383 --- /dev/null +++ b/api/app/core/memory/agent/utils/template_tools.py @@ -0,0 +1,117 @@ +""" +Template Service for loading and rendering Jinja2 templates. + +This service provides centralized template management with caching and error handling. +""" +# 标准库 +import os +from functools import lru_cache + +from jinja2 import Environment, FileSystemLoader, Template, TemplateNotFound + +from app.core.logging_config import get_agent_logger, log_prompt_rendering + + +logger = get_agent_logger(__name__) + + +class TemplateRenderError(Exception): + """Exception raised when template rendering fails.""" + + def __init__(self, template_name: str, error: Exception, variables: dict): + self.template_name = template_name + self.error = error + self.variables = variables + super().__init__( + f"Failed to render template '{template_name}': {str(error)}" + ) + + +class TemplateService: + """Service for loading and rendering Jinja2 templates with caching.""" + + def __init__(self, template_root: str): + """ + Initialize the template service. + + Args: + template_root: Root directory containing template files + """ + self.template_root = template_root + self.env = Environment( + loader=FileSystemLoader(template_root), + autoescape=False # Disable autoescape for prompt templates + ) + logger.info(f"TemplateService initialized with root: {template_root}") + + @lru_cache(maxsize=128) + def _load_template(self, template_name: str) -> Template: + """ + Load a template from disk with caching. + + Args: + template_name: Relative path to template file + + Returns: + Loaded Jinja2 Template object + + Raises: + TemplateNotFound: If template file doesn't exist + """ + try: + return self.env.get_template(template_name) + except TemplateNotFound as e: + expected_path = os.path.join(self.template_root, template_name) + logger.error( + f"Template not found: {template_name}. " + f"Expected path: {expected_path}" + ) + raise + + async def render_template( + self, + template_name: str, + operation_name: str, + **variables + ) -> str: + """ + Load and render a Jinja2 template. + + Args: + template_name: Relative path to template file + operation_name: Name for logging (e.g., "split_the_problem") + **variables: Template variables to render + + Returns: + Rendered template string + + Raises: + TemplateRenderError: If template loading or rendering fails + """ + try: + # Load template (cached) + template = self._load_template(template_name) + + # Render template + rendered = template.render(**variables) + + # Log rendered prompt + log_prompt_rendering(operation_name, rendered) + + return rendered + + except TemplateNotFound as e: + logger.error( + f"Template rendering failed for {operation_name} " + f"({template_name}): Template not found", + exc_info=True + ) + raise TemplateRenderError(template_name, e, variables) + + except Exception as e: + logger.error( + f"Template rendering failed for {operation_name} " + f"({template_name}): {e}", + exc_info=True + ) + raise TemplateRenderError(template_name, e, variables) diff --git a/api/app/core/memory/agent/utils/type_classifier.py b/api/app/core/memory/agent/utils/type_classifier.py index 3e5358bd..f1df6f04 100644 --- a/api/app/core/memory/agent/utils/type_classifier.py +++ b/api/app/core/memory/agent/utils/type_classifier.py @@ -1,10 +1,9 @@ """ Type classification utility for distinguishing read/write operations. """ -from app.core.config import settings from app.core.logging_config import get_agent_logger, log_prompt_rendering from app.core.memory.agent.utils.llm_tools import PROJECT_ROOT_ -from app.core.memory.agent.utils.messages_tool import read_template_file +from app.core.memory.agent.utils.messages_tools import read_template_file from app.core.memory.utils.llm.llm_utils import MemoryClientFactory from app.db import get_db_context from jinja2 import Template diff --git a/api/app/core/memory/agent/utils/write_to_database.py b/api/app/core/memory/agent/utils/write_to_database.py deleted file mode 100644 index bd78fe9d..00000000 --- a/api/app/core/memory/agent/utils/write_to_database.py +++ /dev/null @@ -1,49 +0,0 @@ -import os -import uuid -from datetime import datetime -from typing import Any -from sqlalchemy.orm import Session -import logging -import json - -from app.db import get_db -from app.models.retrieval_info import RetrievalInfo - -logger = logging.getLogger(__name__) - -async def write_to_database(host_id: uuid.UUID, data: Any) -> str: - """ - 将数据写入数据库 - :param host_id: 宿主 ID - :param data: 要写入的数据 - :return: 写入数据库的结果 - """ - # 从数据库会话中获取会话 - db: Session = next(get_db()) - try: - if isinstance(data, (dict, list)): - serialized = json.dumps(data, ensure_ascii=False) - elif isinstance(data, str): - serialized = data - else: - serialized = str(data) - - new_retrieval_info = RetrievalInfo( - # host_id=host_id, - host_id=uuid.UUID("2f6ff1eb-50c7-4765-8e89-e4566be19122"), - retrieve_info=serialized, - created_at=datetime.now() - ) - db.add(new_retrieval_info) - db.commit() - logger.info(f"success to write data to database, host_id: {host_id}, retrieve_info: {serialized}") - return "success to write data to database" - except Exception as e: - db.rollback() - logger.error(f"failed to write data to database, host_id: {host_id}, retrieve_info: {data}, error: {e}") - raise e - finally: - try: - db.close() - except Exception: - pass diff --git a/api/app/core/memory/agent/utils/write_tools.py b/api/app/core/memory/agent/utils/write_tools.py index f09b35e8..53c941ad 100644 --- a/api/app/core/memory/agent/utils/write_tools.py +++ b/api/app/core/memory/agent/utils/write_tools.py @@ -7,14 +7,12 @@ pipeline. Only MemoryConfig is needed - clients are constructed internally. import time from datetime import datetime +from dotenv import load_dotenv + from app.core.logging_config import get_agent_logger from app.core.memory.agent.utils.get_dialogs import get_chunked_dialogs -from app.core.memory.storage_services.extraction_engine.extraction_orchestrator import ( - ExtractionOrchestrator, -) -from app.core.memory.storage_services.extraction_engine.knowledge_extraction.memory_summary import ( - memory_summary_generation, -) +from app.core.memory.storage_services.extraction_engine.extraction_orchestrator import ExtractionOrchestrator +from app.core.memory.storage_services.extraction_engine.knowledge_extraction.memory_summary import memory_summary_generation from app.core.memory.utils.llm.llm_utils import MemoryClientFactory from app.core.memory.utils.log.logging_utils import log_time from app.db import get_db_context @@ -23,7 +21,7 @@ from app.repositories.neo4j.add_nodes import add_memory_summary_nodes from app.repositories.neo4j.graph_saver import save_dialog_and_statements_to_neo4j from app.repositories.neo4j.neo4j_connector import Neo4jConnector from app.schemas.memory_config_schema import MemoryConfig -from dotenv import load_dotenv + load_dotenv() diff --git a/api/app/services/memory_agent_service.py b/api/app/services/memory_agent_service.py index f0756764..c9230a26 100644 --- a/api/app/services/memory_agent_service.py +++ b/api/app/services/memory_agent_service.py @@ -9,30 +9,27 @@ import os import re import time import uuid - from typing import Any, AsyncGenerator, Dict, List, Optional - import redis +from langchain_core.messages import HumanMessage + from app.core.config import settings from app.core.logging_config import get_config_logger, get_logger from app.core.memory.agent.langgraph_graph.read_graph import make_read_graph from app.core.memory.agent.langgraph_graph.write_graph import make_write_graph from app.core.memory.agent.logger_file.log_streamer import LogStreamer -from app.core.memory.agent.utils.mcp_tools import get_mcp_server_config +from app.core.memory.agent.utils.messages_tools import merge_multiple_search_results, reorder_output_results from app.core.memory.agent.utils.type_classifier import status_typle from app.core.memory.analytics.hot_memory_tags import get_hot_memory_tags from app.core.memory.utils.llm.llm_utils import MemoryClientFactory from app.db import get_db_context from app.models.knowledge_model import Knowledge, KnowledgeType -from app.repositories.memory_short_repository import ShortTermMemoryRepository from app.repositories.neo4j.neo4j_connector import Neo4jConnector from app.schemas.memory_config_schema import ConfigurationError from app.services.memory_config_service import MemoryConfigService from app.services.memory_konwledges_server import ( write_rag, ) -from langchain_mcp_adapters.client import MultiServerMCPClient -from langchain_mcp_adapters.tools import load_mcp_tools from pydantic import BaseModel, Field from sqlalchemy import func from sqlalchemy.orm import Session @@ -50,21 +47,17 @@ _neo4j_connector = Neo4jConnector() class MemoryAgentService: """Service for memory agent operations""" - - - def writer_messages_deal(self,messages,start_time,group_id,config_id,message): - messages = str(messages).replace("'", '"').replace('\\n', '').replace('\n', '').replace('\\', '') - countext = re.findall(r'"status": "(.*?)",', messages)[0] + def writer_messages_deal(self, messages, start_time, group_id, config_id, message, context): duration = time.time() - start_time - if countext == 'success': + if str(messages) == 'success': logger.info(f"Write operation successful for group {group_id} with config_id {config_id}") # 记录成功的操作 if audit_logger: audit_logger.log_operation(operation="WRITE", config_id=config_id, group_id=group_id, success=True, duration=duration, details={"message_length": len(message)}) - return countext + return context else: logger.warning(f"Write operation failed for group {group_id}") @@ -80,9 +73,9 @@ class MemoryAgentService: ) raise ValueError(f"写入失败: {messages}") - - + + def extract_tool_call_info(self, event: Dict) -> bool: """Extract tool call information from event""" last_message = event["messages"][-1] @@ -119,15 +112,15 @@ class MemoryAgentService: return True return False - + async def get_health_status(self) -> Dict: """ Get latest health status from Redis cache - + Returns health status information written by Celery periodic task """ logger.info("Checking health status") - + client = redis.Redis( host=settings.REDIS_HOST, port=settings.REDIS_PORT, @@ -135,34 +128,51 @@ class MemoryAgentService: password=settings.REDIS_PASSWORD if settings.REDIS_PASSWORD else None ) payload = client.hgetall("memsci:health:read_service") or {} - + if payload: # decode bytes to str decoded = {k.decode("utf-8"): v.decode("utf-8") for k, v in payload.items()} status = decoded.get("status", "unknown") else: status = "unknown" - + + # Add database connection pool status + try: + from app.db import get_pool_status + pool_status = get_pool_status() + logger.info(f"Database pool status: {pool_status}") + + # Check if pool usage is too high + if pool_status.get("usage_percent", 0) > 80: + logger.warning(f"High database pool usage: {pool_status['usage_percent']}%") + status = "warning" + + except Exception as e: + logger.error(f"Failed to get pool status: {e}") + pool_status = {"error": str(e)} + logger.info(f"Health status: {status}") - return {"status": status} + return { + "status": status, + "database_pool": pool_status + } def get_log_content(self) -> str: """ Read and return agent service log file content - - Returns cleaned log content using the same cleaning logic as transmission mode + + Returns cleaned log content using the same cleaning logic as transmission mode Returns cleaned log content using the same cleaning logic as transmission mode """ logger.info("Reading log file") - # Use project root directory for logs - # Get the project root (redbear-mem directory) + current_file = os.path.abspath(__file__) # app/services/memory_agent_service.py app_dir = os.path.dirname(os.path.dirname(current_file)) # app directory project_root = os.path.dirname(app_dir) # redbear-mem directory log_path = os.path.join(project_root, "logs", "agent_service.log") - + summer = '' with open(log_path, "r", encoding="utf-8") as infile: @@ -176,83 +186,83 @@ class MemoryAgentService: logger.info(f"Log content retrieved, size: {len(summer)} bytes") return summer - + async def stream_log_content(self) -> AsyncGenerator[str, None]: """ Stream log content in real-time using Server-Sent Events (SSE) - + This method establishes a streaming connection and transmits log entries as they are written to the log file. It uses the LogStreamer to watch the file and yields SSE-formatted messages. - + Yields: SSE-formatted strings with the following event types: - log: Contains log content and timestamp - keepalive: Periodic keepalive messages to maintain connection - error: Error information if streaming fails - done: Indicates streaming has completed - + Raises: FileNotFoundError: If log file doesn't exist at stream start Exception: For other unexpected errors during streaming """ logger.info("Starting log content streaming") - + # Get log file path - use project root directory current_file = os.path.abspath(__file__) # app/services/memory_agent_service.py app_dir = os.path.dirname(os.path.dirname(current_file)) # app directory project_root = os.path.dirname(app_dir) # redbear-mem directory log_path = os.path.join(project_root, "logs", "agent_service.log") - + # Check if file exists before starting stream if not os.path.exists(log_path): logger.error(f"Log file not found: {log_path}") # Send error event in SSE format yield f"event: error\ndata: {json.dumps({'code': 4006, 'message': '日志文件不存在', 'error': f'File not found: {log_path}'})}\n\n" return - + streamer = None try: # Initialize LogStreamer with keepalive interval from settings (default 300 seconds) keepalive_interval = getattr(settings, 'LOG_STREAM_KEEPALIVE_INTERVAL', 300) streamer = LogStreamer(log_path, keepalive_interval=keepalive_interval) - + logger.info(f"LogStreamer initialized for {log_path}") - + # Stream log content using read_existing_and_stream to get all existing content first async for message in streamer.read_existing_and_stream(): event_type = message.get("event") data = message.get("data") - + # Format as SSE message # SSE format: "event: \ndata: \n\n" sse_message = f"event: {event_type}\ndata: {json.dumps(data)}\n\n" - + logger.debug(f"Streaming event: {event_type}") yield sse_message - + # If error or done event, stop streaming if event_type in ["error", "done"]: logger.info(f"Stream ended with event: {event_type}") break - + except FileNotFoundError as e: logger.error(f"Log file not found during streaming: {e}") yield f"event: error\ndata: {json.dumps({'code': 4006, 'message': '日志文件在流式传输期间变得不可用', 'error': str(e)})}\n\n" - + except Exception as e: logger.error(f"Unexpected error during log streaming: {e}", exc_info=True) yield f"event: error\ndata: {json.dumps({'code': 8001, 'message': '流式传输期间发生错误', 'error': str(e)})}\n\n" - + finally: # Resource cleanup logger.info("Log streaming completed, cleaning up resources") # LogStreamer uses context manager for file handling, so cleanup is automatic - + async def write_memory(self, group_id: str, message: str, config_id: Optional[str], db: Session, storage_type: str, user_rag_memory_id: str) -> str: """ Process write operation with config_id - + Args: group_id: Group identifier (also used as end_user_id) message: Message to write @@ -260,10 +270,10 @@ class MemoryAgentService: db: SQLAlchemy database session storage_type: Storage type (neo4j or rag) user_rag_memory_id: User RAG memory ID - + Returns: Write operation result status - + Raises: ValueError: If config loading fails or write operation fails """ @@ -279,7 +289,7 @@ class MemoryAgentService: raise # Re-raise our specific error logger.error(f"Failed to get connected config for end_user {group_id}: {e}") raise ValueError(f"Unable to determine memory configuration for end_user {group_id}: {e}") - + import time start_time = time.time() @@ -294,61 +304,49 @@ class MemoryAgentService: except ConfigurationError as e: error_msg = f"Failed to load configuration for config_id: {config_id}: {e}" logger.error(error_msg) - + # Log failed operation if audit_logger: duration = time.time() - start_time audit_logger.log_operation(operation="WRITE", config_id=config_id, group_id=group_id, success=False, duration=duration, error=error_msg) - + raise ValueError(error_msg) - mcp_config = get_mcp_server_config() - client = MultiServerMCPClient(mcp_config) - - if storage_type == "rag": - result = await write_rag(group_id, message, user_rag_memory_id) - return result - else: - async with client.session("data_flow") as session: - logger.debug("Connected to MCP Server: data_flow") - tools = await load_mcp_tools(session) - workflow_errors = [] # Track errors from workflow - - # Pass memory_config to the graph workflow - async with make_write_graph(group_id, tools, group_id, group_id, memory_config=memory_config) as graph: - logger.debug("Write graph created successfully") + try: + if storage_type == "rag": + result = await write_rag(group_id, message, user_rag_memory_id) + return result + else: + async with make_write_graph() as graph: config = {"configurable": {"thread_id": group_id}} + # 初始状态 - 包含所有必要字段 + initial_state = {"messages": [HumanMessage(content=message)], "group_id": group_id, + "memory_config": memory_config} - async for event in graph.astream( - {"messages": message, "memory_config": memory_config, "errors": []}, - stream_mode="values", + # 获取节点更新信息 + async for update_event in graph.astream( + initial_state, + stream_mode="updates", config=config ): - messages = event.get('messages') - # Capture any errors from the state - if event.get('errors'): - workflow_errors.extend(event.get('errors', [])) - - # Check for workflow errors - if workflow_errors: - error_details = "; ".join([f"{e['tool']}: {e['error']}" for e in workflow_errors]) - logger.error(f"Write workflow failed with errors: {error_details}") - - if audit_logger: - duration = time.time() - start_time - audit_logger.log_operation( - operation="WRITE", - config_id=config_id, - group_id=group_id, - success=False, - duration=duration, - error=error_details - ) - - raise ValueError(f"Write workflow failed: {error_details}") - - return self.writer_messages_deal(messages, start_time, group_id, config_id, message) - + for node_name, node_data in update_event.items(): + if 'save_neo4j' == node_name: + massages = node_data + massagesstatus = massages.get('write_result')['status'] + contents = massages.get('write_result') + return self.writer_messages_deal(massagesstatus, start_time, group_id, config_id, message, contents) + except Exception as e: + # Ensure proper error handling and logging + error_msg = f"Write operation failed: {str(e)}" + logger.error(error_msg) + if audit_logger: + duration = time.time() - start_time + audit_logger.log_operation(operation="WRITE", config_id=config_id, group_id=group_id, success=False, duration=duration, error=error_msg) + raise ValueError(error_msg) + + + + async def read_memory( self, group_id: str, @@ -362,12 +360,12 @@ class MemoryAgentService: ) -> Dict: """ Process read operation with config_id - + search_switch values: - "0": Requires verification - "1": No verification, direct split - "2": Direct answer based on context - + Args: group_id: Group identifier (also used as end_user_id) message: User message @@ -377,18 +375,17 @@ class MemoryAgentService: db: SQLAlchemy database session storage_type: Storage type (neo4j or rag) user_rag_memory_id: User RAG memory ID - + Returns: Dict with 'answer' and 'intermediate_outputs' keys - + Raises: ValueError: If config loading fails """ import time start_time = time.time() - ori_message=message - end_user_id=group_id + # Resolve config_id if None using end_user's connected config if config_id is None: try: @@ -410,6 +407,7 @@ class MemoryAgentService: except ImportError: audit_logger = None + try: config_service = MemoryConfigService(db) memory_config = config_service.load_memory_config( @@ -440,326 +438,128 @@ class MemoryAgentService: logger.debug(f"Group ID:{group_id}, Message:{message}, History:{history}, Config ID:{config_id}") # Step 3: Initialize MCP client and execute read workflow - mcp_config = get_mcp_server_config() - client = MultiServerMCPClient(mcp_config) - - async with client.session('data_flow') as session: - session_start = time.time() - logger.debug("Connected to MCP Server: data_flow") - - tools_start = time.time() - tools = await load_mcp_tools(session) - tools_time = time.time() - tools_start - logger.info(f"[PERF] MCP tools loading took: {tools_time:.4f}s") - - outputs = [] - intermediate_outputs = [] - seen_intermediates = set() # Track seen intermediate outputs to avoid duplicates - - # Pass memory_config to the graph workflow - graph_start = time.time() - async with make_read_graph(group_id, tools, search_switch, group_id, group_id, memory_config=memory_config, storage_type=storage_type, user_rag_memory_id=user_rag_memory_id) as graph: - graph_init_time = time.time() - graph_start - logger.info(f"[PERF] Graph initialization took: {graph_init_time:.4f}s") - - start = time.time() + try: + async with make_read_graph() as graph: config = {"configurable": {"thread_id": group_id}} - workflow_errors = [] # Track errors from workflow - - event_count = 0 - async for event in graph.astream( - {"messages": history, "memory_config": memory_config, "errors": []}, - stream_mode="values", + # 初始状态 - 包含所有必要字段 + initial_state = {"messages": [HumanMessage(content=message)], "search_switch": search_switch, + "group_id": group_id + , "storage_type": storage_type, "user_rag_memory_id": user_rag_memory_id, + "memory_config": memory_config} + # 获取节点更新信息 + _intermediate_outputs = [] + summary = '' + async for update_event in graph.astream( + initial_state, + stream_mode="updates", config=config ): - event_count += 1 - event_start = time.time() - messages = event.get('messages') - # Capture any errors from the state - if event.get('errors'): - workflow_errors.extend(event.get('errors', [])) + for node_name, node_data in update_event.items(): + # if 'save_neo4j' == node_name: + # massages = node_data + print(f"处理节点: {node_name}") - for msg in messages: - msg_content = msg.content - msg_role = msg.__class__.__name__.lower().replace("message", "") - outputs.append({ - "role": msg_role, - "content": msg_content - }) + # 处理不同Summary节点的返回结构 + if 'Summary' in node_name: + if 'InputSummary' in node_data and 'summary_result' in node_data['InputSummary']: + summary = node_data['InputSummary']['summary_result'] + elif 'RetrieveSummary' in node_data and 'summary_result' in node_data['RetrieveSummary']: + summary = node_data['RetrieveSummary']['summary_result'] + elif 'summary' in node_data and 'summary_result' in node_data['summary']: + summary = node_data['summary']['summary_result'] + elif 'SummaryFails' in node_data and 'summary_result' in node_data['SummaryFails']: + summary = node_data['SummaryFails']['summary_result'] - # Extract intermediate outputs - if hasattr(msg, 'content'): - try: - # Handle MCP content format: [{'type': 'text', 'text': '...'}] - content_to_parse = msg_content - if isinstance(msg_content, list): - for block in msg_content: - if isinstance(block, dict) and block.get('type') == 'text': - content_to_parse = block.get('text', '') - break - else: - continue # No text block found + spit_data = node_data.get('spit_data', {}).get('_intermediate', None) + if spit_data and spit_data != [] and spit_data != {}: + _intermediate_outputs.append(spit_data) - # Try to parse content as JSON - if isinstance(content_to_parse, str): - try: - parsed = json.loads(content_to_parse) - if isinstance(parsed, dict): - # Check for single intermediate output - if '_intermediate' in parsed: - intermediate_data = parsed['_intermediate'] - output_key = self._create_intermediate_key(intermediate_data) + # Problem_Extension 节点 + problem_extension = node_data.get('problem_extension', {}).get('_intermediate', None) + if problem_extension and problem_extension != [] and problem_extension != {}: + _intermediate_outputs.append(problem_extension) - if output_key not in seen_intermediates: - seen_intermediates.add(output_key) - intermediate_outputs.append(self._format_intermediate_output(intermediate_data)) + # Retrieve 节点 + retrieve_node = node_data.get('retrieve', {}).get('_intermediate_outputs', None) + if retrieve_node and retrieve_node != [] and retrieve_node != {}: + _intermediate_outputs.extend(retrieve_node) - # Check for multiple intermediate outputs (from Retrieve) - if '_intermediates' in parsed: - for intermediate_data in parsed['_intermediates']: - output_key = self._create_intermediate_key(intermediate_data) + # Verify 节点 + verify_n = node_data.get('verify', {}).get('_intermediate', None) + if verify_n and verify_n != [] and verify_n != {}: + _intermediate_outputs.append(verify_n) - if output_key not in seen_intermediates: - seen_intermediates.add(output_key) - intermediate_outputs.append(self._format_intermediate_output(intermediate_data)) - except (json.JSONDecodeError, ValueError): - pass - except Exception as e: - logger.debug(f"Failed to extract intermediate output: {e}") + # Summary 节点 + summary_n = node_data.get('summary', {}).get('_intermediate', None) + if summary_n and summary_n != [] and summary_n != {}: + _intermediate_outputs.append(summary_n) - event_time = time.time() - event_start - logger.info(f"[PERF] Event {event_count} processing took: {event_time:.4f}s") + _intermediate_outputs = [item for item in _intermediate_outputs if item and item != [] and item != {}] - workflow_duration = time.time() - start - session_duration = time.time() - session_start - logger.info(f"[PERF] Read graph workflow completed in {workflow_duration}s") - logger.info(f"[PERF] Total session duration: {session_duration:.4f}s") - logger.info(f"[PERF] Total events processed: {event_count}") - # Extract final answer - final_answer = "" - for messages in outputs: - if messages['role'] == 'tool': - message = messages['content'] + optimized_outputs = merge_multiple_search_results(_intermediate_outputs) + result = reorder_output_results(optimized_outputs) - # Handle MCP content format: [{'type': 'text', 'text': '...'}] - if isinstance(message, list): - # Extract text from MCP content blocks - for block in message: - if isinstance(block, dict) and block.get('type') == 'text': - message = block.get('text', '') - break - else: - continue # No text block found - - try: - parsed = json.loads(message) if isinstance(message, str) else message - if isinstance(parsed, dict): - if parsed.get('status') == 'success': - summary_result = parsed.get('summary_result') - if summary_result: - final_answer = summary_result - except (json.JSONDecodeError, ValueError): - pass - - # 记录成功的操作 - total_duration = time.time() - start_time - - # Check for workflow errors - if workflow_errors: - error_details = "; ".join([f"{e['tool']}: {e['error']}" for e in workflow_errors]) - logger.warning(f"Read workflow completed with errors: {error_details}") + # Log successful operation + if audit_logger: + duration = time.time() - start_time + audit_logger.log_operation( + operation="READ", + config_id=config_id, + group_id=group_id, + success=True, + duration=duration + ) + return { + "answer": summary, + "intermediate_outputs": result + } + except Exception as e: + # Ensure proper error handling and logging + error_msg = f"Read operation failed: {str(e)}" + logger.error(error_msg) if audit_logger: + duration = time.time() - start_time audit_logger.log_operation( operation="READ", config_id=config_id, group_id=group_id, success=False, - duration=total_duration, - error=error_details, - details={ - "search_switch": search_switch, - "history_length": len(history), - "intermediate_outputs_count": len(intermediate_outputs), - "has_answer": bool(final_answer), - "errors": workflow_errors - } + duration=duration, + error=error_msg ) - - # Raise error if no answer was produced - if not final_answer: - raise ValueError(f"Read workflow failed: {error_details}") - - if audit_logger and not workflow_errors: - audit_logger.log_operation( - operation="READ", - config_id=config_id, - group_id=group_id, - success=True, - duration=total_duration, - details={ - "search_switch": search_switch, - "history_length": len(history), - "intermediate_outputs_count": len(intermediate_outputs), - "has_answer": bool(final_answer) - } - ) - retrieved_content=[] - repo = ShortTermMemoryRepository(db) - if str(search_switch)!="2": - for intermediate in intermediate_outputs: - print(intermediate) - intermediate_type=intermediate['type'] - if intermediate_type=="search_result": - query=intermediate['query'] - raw_results=intermediate['raw_results'] - reranked_results=raw_results.get('reranked_results',[]) - try: - statements=[statement['statement'] for statement in reranked_results.get('statements', [])] - except Exception: - statements=[] - statements=list(set(statements)) - retrieved_content.append({query:statements}) - if retrieved_content==[]: - retrieved_content='' - if '信息不足,无法回答。' != str(final_answer) and str(search_switch).strip() != "2":#and retrieved_content!=[] - # 使用 upsert 方法 - repo.upsert( - end_user_id=end_user_id, # 确保这个变量在作用域内 - messages=ori_message, - aimessages=final_answer, - retrieved_content=retrieved_content, - search_switch=str(search_switch) - ) - print("写入成功") + raise ValueError(error_msg) - return { - "answer": final_answer, - "intermediate_outputs": intermediate_outputs - } - - def _create_intermediate_key(self, output: Dict) -> str: - """ - Create a unique key for an intermediate output to detect duplicates. - - Args: - output: Intermediate output dictionary - - Returns: - Unique string key for this output - """ - output_type = output.get('type', 'unknown') - - if output_type == 'problem_split': - # Use type + original query as key - return f"split:{output.get('original_query', '')}" - elif output_type == 'problem_extension': - # Use type + original query as key - return f"extension:{output.get('original_query', '')}" - elif output_type == 'search_result': - # Use type + query + index as key - return f"search:{output.get('query', '')}:{output.get('index', 0)}" - elif output_type == 'retrieval_summary': - # Use type + query as key - return f"summary:{output.get('query', '')}" - elif output_type == 'verification': - # Use type + query as key - return f"verification:{output.get('query', '')}" - elif output_type == 'input_summary': - # Use type + query as key - return f"input_summary:{output.get('query', '')}" - else: - # Fallback: use JSON representation - import json - return json.dumps(output, sort_keys=True) - - def _format_intermediate_output(self, output: Dict) -> Dict: - """Format intermediate output for frontend display.""" - output_type = output.get('type', 'unknown') - - if output_type == 'problem_split': - return { - 'type': 'problem_split', - 'title': '问题拆分', - 'data': output.get('data', []), - 'original_query': output.get('original_query', '') - } - elif output_type == 'problem_extension': - return { - 'type': 'problem_extension', - 'title': '问题扩展', - 'data': output.get('data', {}), - 'original_query': output.get('original_query', '') - } - elif output_type == 'search_result': - return { - 'type': 'search_result', - 'title': f'检索结果 ({output.get("index", 0)}/{output.get("total", 0)})', - 'query': output.get('query', ''), - 'raw_results': output.get('raw_results', ''), - 'index': output.get('index', 0), - 'total': output.get('total', 0) - } - elif output_type == 'retrieval_summary': - return { - 'type': 'retrieval_summary', - 'title': '检索总结', - 'summary': output.get('summary', ''), - 'query': output.get('query', ''), - 'raw_results': output.get('raw_results'), - - } - elif output_type == 'verification': - return { - 'type': 'verification', - 'title': '数据验证', - 'result': output.get('result', 'unknown'), - 'reason': output.get('reason', ''), - 'query': output.get('query', ''), - 'verified_count': output.get('verified_count', 0) - } - elif output_type == 'input_summary': - return { - 'type': 'input_summary', - 'title': '快速答案', - 'summary': output.get('summary', ''), - 'query': output.get('query', ''), - 'raw_results': output.get('raw_results'), - - } - else: - return output - async def classify_message_type(self, message: str, config_id: int, db: Session) -> Dict: """ Determine the type of user message (read or write) Updated to eliminate global variables in favor of explicit parameters. - + Args: message: User message to classify config_id: Configuration ID to load LLM model from database db: Database session - + Returns: Type classification result """ logger.info("Classifying message type") - + # Load configuration to get LLM model ID config_service = MemoryConfigService(db) memory_config = config_service.load_memory_config( config_id=config_id, service_name="MemoryAgentService" ) - + status = await status_typle(message, memory_config.llm_model_id) logger.debug(f"Message type: {status}") return status - + # ==================== 新增的三个接口方法 ==================== - + async def get_knowledge_type_stats( self, end_user_id: Optional[str] = None, @@ -772,13 +572,13 @@ class MemoryAgentService: 1. PostgreSQL 中的知识库类型:General, Web, Third-party, Folder(根据 workspace_id 过滤) 2. Neo4j 中的 memory 类型(仅统计 Chunk 数量,根据 end_user_id/group_id 过滤) 3. total: 所有类型的总和 - + 参数: - end_user_id: 用户组ID(可选,未提供时 memory 统计为 0) - only_active: 是否仅统计有效记录 - current_workspace_id: 当前工作空间ID(可选,未提供时知识库统计为 0) - db: 数据库会话 - + 返回格式: { "General": count, @@ -790,18 +590,18 @@ class MemoryAgentService: } """ result = {} - + # 1. 统计 PostgreSQL 中的知识库类型 try: if db is None: from app.db import get_db db_gen = get_db() db = next(db_gen) - + # 初始化所有标准类型为 0 for kb_type in KnowledgeType: result[kb_type.value] = 0 - + # 如果提供了 workspace_id,则按 workspace_id 过滤 if current_workspace_id: # 构建查询条件 @@ -809,47 +609,48 @@ class MemoryAgentService: Knowledge.type, func.count(Knowledge.id).label('count') ).filter(Knowledge.workspace_id == current_workspace_id) - + # 检查 Knowledge 模型是否有 status 字段 if only_active and hasattr(Knowledge, 'status'): query = query.filter(Knowledge.status == 1) - + # 按类型分组 type_counts = query.group_by(Knowledge.type).all() - + # 只填充标准类型的统计值,忽略其他类型 valid_types = {kb_type.value for kb_type in KnowledgeType} for type_name, count in type_counts: if type_name in valid_types: result[type_name] = count - + logger.info(f"知识库类型统计成功 (workspace_id={current_workspace_id}): {result}") else: # 没有提供 workspace_id,所有知识库类型返回 0 logger.info("未提供 workspace_id,知识库类型统计全部为 0") - + except Exception as e: logger.error(f"知识库类型统计失败: {e}") raise Exception(f"知识库类型统计失败: {e}") - + # 2. 统计 Neo4j 中的 memory 总量(统计当前空间下所有宿主的 Chunk 总数) try: if current_workspace_id: # 获取当前空间下的所有宿主 from app.repositories import app_repository, end_user_repository from app.schemas.app_schema import App as AppSchema - + from app.schemas.end_user_schema import EndUser as EndUserSchema + # 查询应用并转换为 Pydantic 模型 apps_orm = app_repository.get_apps_by_workspace_id(db, current_workspace_id) apps = [AppSchema.model_validate(h) for h in apps_orm] app_ids = [app.id for app in apps] - + # 获取所有宿主 end_users = [] for app_id in app_ids: end_user_orm_list = end_user_repository.get_end_users_by_app_id(db, app_id) end_users.extend(h for h in end_user_orm_list) - + # 统计所有宿主的 Chunk 总数 total_chunks = 0 for end_user in end_users: @@ -864,27 +665,27 @@ class MemoryAgentService: chunk_count = neo4j_result[0]["Count"] if neo4j_result else 0 total_chunks += chunk_count logger.debug(f"EndUser {end_user_id_str} Chunk数量: {chunk_count}") - + result["memory"] = total_chunks logger.info(f"Neo4j memory统计成功: 总Chunk数={total_chunks}, 宿主数={len(end_users)}") else: # 没有 workspace_id 时,返回 0 result["memory"] = 0 logger.info("未提供 workspace_id,memory 统计为 0") - + except Exception as e: logger.error(f"Neo4j memory统计失败: {e}", exc_info=True) # 如果 Neo4j 查询失败,memory 设为 0 result["memory"] = 0 - + # 3. 计算知识库类型总和(不包括 memory) result["total"] = ( - result.get("General", 0) + - result.get("Web", 0) + - result.get("Third-party", 0) + + result.get("General", 0) + + result.get("Web", 0) + + result.get("Third-party", 0) + result.get("Folder", 0) ) - + return result @@ -895,11 +696,11 @@ class MemoryAgentService: ) -> List[Dict[str, Any]]: """ 获取指定用户的热门记忆标签 - + 参数: - end_user_id: 用户ID(可选),对应Neo4j中的group_id字段 - limit: 返回标签数量限制 - + 返回格式: [ {"name": "标签名", "frequency": 频次}, @@ -928,13 +729,13 @@ class MemoryAgentService: 1. 用户名字(直接使用 end_user_name) 2. 用户标签(从摘要中用LLM总结3个标签) 3. 热门记忆标签(从hot_memory_tags获取前4个) - + 参数: - end_user_id: 用户ID(可选) - current_user_id: 当前登录用户的ID(保留参数) - llm_id: LLM模型ID(用于生成标签,可选,如果不提供则跳过标签生成) - db: 数据库会话(可选) - + 返回格式: { "name": "用户名", @@ -947,13 +748,13 @@ class MemoryAgentService: } """ result = {} - + # 1. 根据 end_user_id 获取 end_user_name try: if end_user_id and db: from app.repositories import end_user_repository from app.schemas.end_user_schema import EndUser as EndUserSchema - + end_user_orm = end_user_repository.get_end_user_by_id(db, end_user_id) if end_user_orm: end_user = EndUserSchema.model_validate(end_user_orm) @@ -965,14 +766,14 @@ class MemoryAgentService: except Exception as e: logger.error(f"Failed to get end_user_name: {e}") end_user_name = "默认用户" - + result["name"] = end_user_name logger.debug(f"The end_user is: {end_user_name}") - + # 2. 使用LLM从语句和实体中提取标签 try: connector = Neo4jConnector() - + # 查询该用户的语句 query = ( "MATCH (s:Statement) " @@ -982,7 +783,7 @@ class MemoryAgentService: ) rows = await connector.execute_query(query, group_id=end_user_id) statements = [r.get("statement", "") for r in rows if r.get("statement")] - + # 查询该用户的热门实体 entity_query = ( "MATCH (e:ExtractedEntity) " @@ -992,9 +793,9 @@ class MemoryAgentService: ) entity_rows = await connector.execute_query(entity_query, group_id=end_user_id) entities = [f"{r['name']} ({r['frequency']})" for r in entity_rows] - + await connector.close() - + if not statements or not llm_id: result["tags"] = [] if not llm_id and statements: @@ -1003,16 +804,16 @@ class MemoryAgentService: # 构建摘要文本 summary_text = f"用户语句样本:{' | '.join(statements[:20])}\n核心实体:{', '.join(entities)}" logger.debug(f"User data found: {len(statements)} statements, {len(entities)} entities") - + # 使用LLM提取标签 with get_db_context() as db: factory = MemoryClientFactory(db) llm_client = factory.get_llm_client(llm_id) - + # 定义标签提取的结构 class UserTags(BaseModel): tags: list[str] = Field(..., description="3个描述用户特征的标签,如:产品设计师、旅行爱好者、摄影发烧友") - + messages = [ { "role": "system", @@ -1023,20 +824,20 @@ class MemoryAgentService: "content": f"请从以下用户信息中提取3个标签:\n\n{summary_text}" } ] - + user_tags = await llm_client.response_structured( messages=messages, response_model=UserTags ) - + result["tags"] = user_tags.tags logger.debug(f"Extracted tags: {user_tags.tags}") - + except Exception as e: # 如果提取失败,使用默认值 logger.error(f"Failed to extract user tags: {e}") result["tags"] = [] - + try: # 3. 获取热门记忆标签(前4个) connector = Neo4jConnector() @@ -1049,18 +850,18 @@ class MemoryAgentService: "ORDER BY frequency DESC LIMIT 4" ) hot_tag_rows = await connector.execute_query( - hot_tag_query, - group_id=end_user_id, + hot_tag_query, + group_id=end_user_id, names_to_exclude=names_to_exclude ) await connector.close() - + result["hot_tags"] = [{"name": r["name"], "frequency": r["frequency"]} for r in hot_tag_rows] logger.debug(f"Hot tags found: {len(result['hot_tags'])} tags") except Exception as e: logger.error(f"Failed to get hot tags: {e}") result["hot_tags"] = [] - + return result async def stream_log_content(self) -> AsyncGenerator[str, None]: @@ -1135,79 +936,40 @@ class MemoryAgentService: logger.info("Log streaming completed, cleaning up resources") # LogStreamer uses context manager for file handling, so cleanup is automatic -# async def get_api_docs(self, file_path: Optional[str] = None) -> Dict[str, Any]: -# """ -# Parse and return API documentation - -# Args: -# file_path: Optional path to API docs file. If None, uses default path. - -# Returns: -# Dict containing parsed API documentation or error information -# """ -# try: -# target = file_path or get_default_docs_path() - -# if not os.path.isfile(target): -# return { -# "success": False, -# "msg": "API文档文件不存在", -# "error_code": "DOC_NOT_FOUND", -# "data": {"path": target} -# } - -# data = parse_api_docs(target) -# return { -# "success": True, -# "msg": "解析成功", -# "data": data -# } -# except Exception as e: -# logger.error(f"Failed to parse API docs: {e}") -# return { -# "success": False, -# "msg": "解析失败", -# "error_code": "DOC_PARSE_ERROR", -# "data": {"error": str(e)} -# } - - def get_end_user_connected_config(end_user_id: str, db: Session) -> Dict[str, Any]: """ 获取终端用户关联的记忆配置 - + 通过以下流程获取配置: 1. 根据 end_user_id 获取用户的 app_id 2. 获取该应用的最新发布版本 3. 从发布版本的 config 字段中提取 memory_config_id - 4. 根据 memory_config_id 查询配置名称 - + Args: end_user_id: 终端用户ID db: 数据库会话 - + Returns: - 包含 memory_config_id、config_name 和相关信息的字典 - + 包含 memory_config_id 和相关信息的字典 + Raises: ValueError: 当终端用户不存在或应用未发布时 """ from app.models.app_release_model import AppRelease - from app.models.data_config_model import DataConfig from app.models.end_user_model import EndUser from sqlalchemy import select - + logger.info(f"Getting connected config for end_user: {end_user_id}") - + # 1. 获取 end_user 及其 app_id end_user = db.query(EndUser).filter(EndUser.id == end_user_id).first() if not end_user: logger.warning(f"End user not found: {end_user_id}") raise ValueError(f"终端用户不存在: {end_user_id}") - + app_id = end_user.app_id logger.debug(f"Found end_user app_id: {app_id}") - + # 2. 获取该应用的最新发布版本 stmt = ( select(AppRelease) @@ -1215,170 +977,135 @@ def get_end_user_connected_config(end_user_id: str, db: Session) -> Dict[str, An .order_by(AppRelease.version.desc()) ) latest_release = db.scalars(stmt).first() - + if not latest_release: logger.warning(f"No active release found for app: {app_id}") raise ValueError(f"应用未发布: {app_id}") - + logger.debug(f"Found latest release: version={latest_release.version}, id={latest_release.id}") - + # 3. 从 config 中提取 memory_config_id config = latest_release.config or {} memory_obj = config.get('memory', {}) memory_config_id = memory_obj.get('memory_content') if isinstance(memory_obj, dict) else None - - # 4. 根据 memory_config_id 查询配置名称 - config_name = None - if memory_config_id: - try: - # memory_config_id 可能是整数或字符串,需要转换 - config_id = int(memory_config_id) if isinstance(memory_config_id, str) else memory_config_id - data_config = db.query(DataConfig).filter(DataConfig.config_id == config_id).first() - if data_config: - config_name = data_config.config_name - logger.debug(f"Found config_name: {config_name} for config_id: {config_id}") - else: - logger.warning(f"DataConfig not found for config_id: {config_id}") - except (ValueError, TypeError) as e: - logger.warning(f"Invalid memory_config_id format: {memory_config_id}, error: {str(e)}") - + result = { "end_user_id": str(end_user_id), "app_id": str(app_id), "release_id": str(latest_release.id), "release_version": latest_release.version, - "memory_config_id": memory_config_id, - "memory_config_name": config_name + "memory_config_id": memory_config_id } - - logger.info(f"Successfully retrieved connected config: memory_config_id={memory_config_id}, config_name={config_name}") + + logger.info(f"Successfully retrieved connected config: memory_config_id={memory_config_id}") return result def get_end_users_connected_configs_batch(end_user_ids: List[str], db: Session) -> Dict[str, Dict[str, Any]]: """ - 批量获取多个终端用户关联的记忆配置 - - 通过优化的查询减少数据库往返次数: - 1. 一次性查询所有 end_user 及其 app_id - 2. 批量查询所有相关的 app_release - 3. 批量查询所有相关的 data_config - + 批量获取多个终端用户关联的记忆配置(优化版本,减少数据库查询次数) + + 通过以下流程获取配置: + 1. 批量查询所有 end_user_id 对应的 app_id + 2. 批量获取这些应用的最新发布版本 + 3. 从发布版本的 config 字段中提取 memory_config_id + Args: end_user_ids: 终端用户ID列表 db: 数据库会话 - + Returns: - 字典,key 为 end_user_id,value 为配置信息字典 - 对于查询失败的用户,value 包含 error 字段 + 字典,key 为 end_user_id,value 为包含 memory_config_id 和 memory_config_name 的字典 + 格式: { + "user_id_1": {"memory_config_id": "xxx", "memory_config_name": "xxx"}, + "user_id_2": {"memory_config_id": None, "memory_config_name": None}, + ... + } """ from app.models.app_release_model import AppRelease - from app.models.data_config_model import DataConfig from app.models.end_user_model import EndUser + from app.models.memory_config_model import MemoryConfig from sqlalchemy import select - - logger.info(f"Batch getting connected configs for {len(end_user_ids)} end users") - + + logger.info(f"Batch getting connected configs for {len(end_user_ids)} end_users") + result = {} - + + # 如果列表为空,直接返回空字典 + if not end_user_ids: + return result + # 1. 批量查询所有 end_user 及其 app_id end_users = db.query(EndUser).filter(EndUser.id.in_(end_user_ids)).all() - # 构建 end_user_id -> end_user 的映射 - end_user_map = {str(user.id): user for user in end_users} + # 创建 end_user_id -> app_id 的映射 + user_to_app = {str(eu.id): eu.app_id for eu in end_users} - # 记录不存在的用户 - for user_id in end_user_ids: - if user_id not in end_user_map: - result[user_id] = { - "end_user_id": user_id, - "memory_config_id": None, - "memory_config_name": None, - "error": f"终端用户不存在: {user_id}" - } - - if not end_users: - logger.warning("No valid end users found") + # 记录未找到的用户 + found_user_ids = set(user_to_app.keys()) + missing_user_ids = set(end_user_ids) - found_user_ids + if missing_user_ids: + logger.warning(f"End users not found: {missing_user_ids}") + for user_id in missing_user_ids: + result[user_id] = {"memory_config_id": None, "memory_config_name": None} + + # 2. 批量获取所有相关应用的最新发布版本 + app_ids = list(user_to_app.values()) + if not app_ids: return result - - # 2. 批量查询所有相关应用的最新发布版本 - app_ids = [user.app_id for user in end_users] - - # 使用子查询找到每个 app 的最新版本 - from sqlalchemy import and_ - - # 查询所有相关的活跃发布版本 - releases = db.query(AppRelease).filter( - and_( - AppRelease.app_id.in_(app_ids), - AppRelease.is_active.is_(True) - ) - ).order_by(AppRelease.app_id, AppRelease.version.desc()).all() - - # 构建 app_id -> latest_release 的映射(每个 app 只保留最新版本) - app_release_map = {} + + # 查询所有活跃的发布版本 + stmt = ( + select(AppRelease) + .where(AppRelease.app_id.in_(app_ids), AppRelease.is_active.is_(True)) + .order_by(AppRelease.app_id, AppRelease.version.desc()) + ) + releases = db.scalars(stmt).all() + + # 创建 app_id -> latest_release 的映射(每个 app 只保留最新版本) + app_to_release = {} for release in releases: - app_id_str = str(release.app_id) - if app_id_str not in app_release_map: - app_release_map[app_id_str] = release - - # 3. 收集所有 memory_config_id + if release.app_id not in app_to_release: + app_to_release[release.app_id] = release + + # 3. 收集所有 memory_config_id 并批量查询配置名称 memory_config_ids = [] - for release in app_release_map.values(): - config = release.config or {} - memory_obj = config.get('memory', {}) - memory_config_id = memory_obj.get('memory_content') if isinstance(memory_obj, dict) else None - if memory_config_id: - try: - config_id = int(memory_config_id) if isinstance(memory_config_id, str) else memory_config_id - memory_config_ids.append(config_id) - except (ValueError, TypeError): - pass - - # 4. 批量查询所有 data_config - config_name_map = {} + for end_user_id, app_id in user_to_app.items(): + release = app_to_release.get(app_id) + if release: + config = release.config or {} + memory_obj = config.get('memory', {}) + memory_config_id = memory_obj.get('memory_content') if isinstance(memory_obj, dict) else None + if memory_config_id: + memory_config_ids.append(memory_config_id) + + # 批量查询 memory_config_name + config_id_to_name = {} if memory_config_ids: - data_configs = db.query(DataConfig).filter( - DataConfig.config_id.in_(memory_config_ids) - ).all() - config_name_map = {config.config_id: config.config_name for config in data_configs} - - # 5. 组装结果 - for user in end_users: - user_id = str(user.id) - app_id = str(user.app_id) + memory_configs = db.query(MemoryConfig).filter(MemoryConfig.id.in_(memory_config_ids)).all() + config_id_to_name = {str(mc.id): mc.config_name for mc in memory_configs} + + # 4. 构建最终结果 + for end_user_id, app_id in user_to_app.items(): + release = app_to_release.get(app_id) - # 检查是否有发布版本 - if app_id not in app_release_map: - result[user_id] = { - "end_user_id": user_id, - "memory_config_id": None, - "memory_config_name": None, - "error": f"应用未发布: {app_id}" - } + if not release: + logger.warning(f"No active release found for app: {app_id} (end_user: {end_user_id})") + result[end_user_id] = {"memory_config_id": None, "memory_config_name": None} continue - - release = app_release_map[app_id] - - # 提取 memory_config_id + + # 从 config 中提取 memory_config_id config = release.config or {} memory_obj = config.get('memory', {}) memory_config_id = memory_obj.get('memory_content') if isinstance(memory_obj, dict) else None - # 获取 config_name - config_name = None - if memory_config_id: - try: - config_id = int(memory_config_id) if isinstance(memory_config_id, str) else memory_config_id - config_name = config_name_map.get(config_id) - except (ValueError, TypeError): - pass - - result[user_id] = { - "end_user_id": user_id, + # 获取配置名称 + memory_config_name = config_id_to_name.get(memory_config_id) if memory_config_id else None + + result[end_user_id] = { "memory_config_id": memory_config_id, - "memory_config_name": config_name + "memory_config_name": memory_config_name } - - logger.info(f"Successfully retrieved batch configs: total={len(result)}, with_config={sum(1 for v in result.values() if v.get('memory_config_id'))}") + + logger.info(f"Successfully retrieved {len(result)} connected configs") return result \ No newline at end of file diff --git a/api/app/tasks.py b/api/app/tasks.py index 28a882b7..fba9f290 100644 --- a/api/app/tasks.py +++ b/api/app/tasks.py @@ -1,27 +1,27 @@ import asyncio -import trio import json import os +import re import time import uuid from datetime import datetime, timezone from math import ceil from typing import Any, Dict, List, Optional -import re import redis import requests +import trio # Import a unified Celery instance from app.celery_app import celery_app from app.core.config import settings +from app.core.rag.graphrag.general.index import init_graphrag, run_graphrag_for_kb from app.core.rag.graphrag.utils import get_llm_cache, set_llm_cache from app.core.rag.llm.chat_model import Base from app.core.rag.llm.cv_model import QWenCV from app.core.rag.llm.embedding_model import OpenAIEmbed from app.core.rag.llm.sequence2txt_model import QWenSeq2txt from app.core.rag.models.chunk import DocumentChunk -from app.core.rag.graphrag.general.index import init_graphrag, run_graphrag_for_kb from app.core.rag.prompts.generator import question_proposal from app.core.rag.vdb.elasticsearch.elasticsearch_vector import ( ElasticSearchVectorFactory, @@ -486,6 +486,10 @@ def write_message_task(self, group_id: str, message: str, config_id: str,storage Raises: Exception on failure """ + from app.core.logging_config import get_logger + logger = get_logger(__name__) + + logger.info(f"[CELERY WRITE] Starting write task - group_id={group_id}, config_id={config_id}, storage_type={storage_type}") start_time = time.time() # Resolve config_id if None @@ -506,8 +510,14 @@ def write_message_task(self, group_id: str, message: str, config_id: str,storage async def _run() -> str: db = next(get_db()) try: + logger.info(f"[CELERY WRITE] Executing MemoryAgentService.write_memory") service = MemoryAgentService() - return await service.write_memory(group_id, message, actual_config_id, db, storage_type, user_rag_memory_id) + result = await service.write_memory(group_id, message, actual_config_id, db, storage_type, user_rag_memory_id) + logger.info(f"[CELERY WRITE] Write completed successfully: {result}") + return result + except Exception as e: + logger.error(f"[CELERY WRITE] Write failed: {e}", exc_info=True) + raise finally: db.close() @@ -532,6 +542,8 @@ def write_message_task(self, group_id: str, message: str, config_id: str,storage result = loop.run_until_complete(_run()) elapsed_time = time.time() - start_time + logger.info(f"[CELERY WRITE] Task completed successfully - elapsed_time={elapsed_time:.2f}s, task_id={self.request.id}") + return { "status": "SUCCESS", "result": result, @@ -548,6 +560,9 @@ def write_message_task(self, group_id: str, message: str, config_id: str,storage detailed_error = "; ".join(error_messages) else: detailed_error = str(e) + + logger.error(f"[CELERY WRITE] Task failed - elapsed_time={elapsed_time:.2f}s, error={detailed_error}", exc_info=True) + return { "status": "FAILURE", "error": detailed_error, diff --git a/api/docker-compose.yml b/api/docker-compose.yml index 8470a5d1..8bc19f3a 100644 --- a/api/docker-compose.yml +++ b/api/docker-compose.yml @@ -1,32 +1,5 @@ -version: '3.9' - services: - # MCP Server - standalone service - mcp-server: - image: redbear-mem-open:latest - container_name: mcp-server - ports: - - "8081:8081" # MCP server port - env_file: - - .env - environment: - - SERVER_IP=0.0.0.0 # Bind to all interfaces - volumes: - - ./files:/files - - /etc/localtime:/etc/localtime:ro - command: python -m app.core.memory.agent.mcp_server.server - healthcheck: - test: ["CMD", "python", "-c", "import urllib.request; urllib.request.urlopen('http://localhost:8081/sse')"] - interval: 10s - timeout: 5s - retries: 5 - start_period: 30s - restart: unless-stopped - networks: - - default - - celery - - # FastAPI application - connects to MCP server + # FastAPI application api: image: redbear-mem-open:latest container_name: api @@ -35,37 +8,31 @@ services: env_file: - .env environment: - - MCP_SERVER_URL=http://mcp-server:8081 # Back to using container name - - SERVER_IP=0.0.0.0 # Ensure MCP server binds to all interfaces + - SERVER_IP=0.0.0.0 + # 如果代码里必须要 MCP_SERVER_URL,可以先注释或指向占位 + # - MCP_SERVER_URL= volumes: - ./files:/files - /etc/localtime:/etc/localtime:ro command: uvicorn app.main:app --host 0.0.0.0 --port 8000 --reload --log-level debug - depends_on: - mcp-server: - condition: service_healthy restart: unless-stopped networks: - default - celery - # Celery worker - connects to MCP server + # Celery worker worker: image: redbear-mem-open:latest container_name: worker env_file: - .env - environment: - - MCP_SERVER_URL=http://mcp-server:8081 # Back to using container name volumes: - ./files:/files - /etc/localtime:/etc/localtime:ro command: celery -A app.celery_worker.celery_app worker --loglevel=info - depends_on: - mcp-server: - condition: service_healthy restart: unless-stopped networks: - celery + networks: - celery: \ No newline at end of file + celery: From e518b57deab18d3b9d635a67808311e5e1ef9692 Mon Sep 17 00:00:00 2001 From: lixinyue11 <94037597+lixinyue11@users.noreply.github.com> Date: Tue, 20 Jan 2026 10:39:12 +0800 Subject: [PATCH 11/12] Fix/memory bug fix (#150) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * 图谱数据量限制数量去掉 * 图谱数据量限制数量去掉 * 图谱数据量限制数量去掉 * 用户详情优化 * 用户详情优化 * 用户详情优化 * 用户详情优化 * 用户详情优化 * 用户详情优化 * 读取的接口,去掉全局锁 * 输出数组 * 反思优化1.0(优化隐私输出、时间检索) * 反思优化1.0(优化隐私输出、时间检索) * 反思优化1.0(优化隐私输出、时间检索) --- .../memory_reflection_controller.py | 95 +++++----------- .../reflection_engine/self_reflexion.py | 36 +++---- .../utils/prompt/prompts/evaluate.jinja2 | 3 +- .../utils/prompt/prompts/reflexion.jinja2 | 28 +++-- .../repositories/data_config_repository.py | 102 ++++++++++-------- api/app/repositories/neo4j/cypher_queries.py | 4 + api/app/repositories/neo4j/neo4j_update.py | 55 +++++++--- api/app/schemas/memory_storage_schema.py | 10 +- api/app/services/memory_reflection_service.py | 10 +- 9 files changed, 173 insertions(+), 170 deletions(-) diff --git a/api/app/controllers/memory_reflection_controller.py b/api/app/controllers/memory_reflection_controller.py index b0287d80..24c143b9 100644 --- a/api/app/controllers/memory_reflection_controller.py +++ b/api/app/controllers/memory_reflection_controller.py @@ -1,10 +1,11 @@ import asyncio import time +import uuid from app.core.logging_config import get_api_logger from app.core.memory.storage_services.reflection_engine.self_reflexion import ( ReflectionConfig, - ReflectionEngine, + ReflectionEngine, ReflectionRange, ReflectionBaseline, ) from app.core.response_utils import success from app.db import get_db @@ -39,9 +40,6 @@ async def save_reflection_config( db: Session = Depends(get_db), ) -> dict: """Save reflection configuration to data_comfig table""" - - - try: config_id = request.config_id if not config_id: @@ -52,51 +50,30 @@ async def save_reflection_config( api_logger.info(f"用户 {current_user.username} 保存反思配置,config_id: {config_id}") - update_params = { - "enable_self_reflexion": request.reflection_enabled, - "iteration_period": request.reflection_period_in_hours, - "reflexion_range": request.reflexion_range, - "baseline": request.baseline, - "reflection_model_id": request.reflection_model_id, - "memory_verify": request.memory_verify, - "quality_assessment": request.quality_assessment, - } + data_config = DataConfigRepository.update_reflection_config( + db, + config_id=config_id, + enable_self_reflexion=request.reflection_enabled, + iteration_period=request.reflection_period_in_hours, + reflexion_range=request.reflexion_range, + baseline=request.baseline, + reflection_model_id=request.reflection_model_id, + memory_verify=request.memory_verify, + quality_assessment=request.quality_assessment + ) - - - query, params = DataConfigRepository.build_update_reflection(config_id, **update_params) - - result = db.execute(text(query), params) - if result.rowcount == 0: - raise HTTPException( - status_code=status.HTTP_404_NOT_FOUND, - detail=f"未找到config_id为 {config_id} 的配置" - ) - db.commit() - - # 查询更新后的配置 - select_query, select_params = DataConfigRepository.build_select_reflection(config_id) - result = db.execute(text(select_query), select_params).fetchone() - - if not result: - raise HTTPException( - status_code=status.HTTP_404_NOT_FOUND, - detail=f"更新后未找到config_id为 {config_id} 的配置" - ) - - api_logger.info(f"成功保存反思配置到数据库,config_id: {config_id}") + db.refresh(data_config) reflection_result={ - "config_id": result.config_id, - "enable_self_reflexion": result.enable_self_reflexion, - "iteration_period": result.iteration_period, - "reflexion_range": result.reflexion_range, - "baseline": result.baseline, - "reflection_model_id": result.reflection_model_id, - "memory_verify": result.memory_verify, - "quality_assessment": result.quality_assessment, - "user_id": result.user_id} + "config_id": data_config.config_id, + "enable_self_reflexion": data_config.enable_self_reflexion, + "iteration_period": data_config.iteration_period, + "reflexion_range": data_config.reflexion_range, + "baseline": data_config.baseline, + "reflection_model_id": data_config.reflection_model_id, + "memory_verify": data_config.memory_verify, + "quality_assessment": data_config.quality_assessment} return success(data=reflection_result, msg="反思配置成功") @@ -116,9 +93,8 @@ async def save_reflection_config( ) -@router.post("/reflection") +@router.get("/reflection") async def start_workspace_reflection( - config_id: int, current_user: User = Depends(get_current_user), db: Session = Depends(get_db), ) -> dict: @@ -178,17 +154,7 @@ async def start_reflection_configs( """通过config_id查询data_config表中的反思配置信息""" try: api_logger.info(f"用户 {current_user.username} 查询反思配置,config_id: {config_id}") - - # 使用DataConfigRepository查询反思配置 - select_query, select_params = DataConfigRepository.build_select_reflection(config_id) - result = db.execute(text(select_query), select_params).fetchone() - - if not result: - raise HTTPException( - status_code=status.HTTP_404_NOT_FOUND, - detail=f"未找到config_id为 {config_id} 的配置" - ) - + result = DataConfigRepository.query_reflection_config_by_id(db, config_id) # 构建返回数据 reflection_config = { "config_id": result.config_id, @@ -198,8 +164,7 @@ async def start_reflection_configs( "baseline": result.baseline, "reflection_model_id": result.reflection_model_id, "memory_verify": result.memory_verify, - "quality_assessment": result.quality_assessment, - "user_id": result.user_id + "quality_assessment": result.quality_assessment } api_logger.info(f"成功查询反思配置,config_id: {config_id}") return success(data=reflection_config, msg="反思配置查询成功") @@ -227,9 +192,7 @@ async def reflection_run( api_logger.info(f"用户 {current_user.username} 查询反思配置,config_id: {config_id}") # 使用DataConfigRepository查询反思配置 - select_query, select_params = DataConfigRepository.build_select_reflection(config_id) - result = db.execute(text(select_query), select_params).fetchone() - + result = DataConfigRepository.query_reflection_config_by_id(db, config_id) if not result: raise HTTPException( status_code=status.HTTP_404_NOT_FOUND, @@ -242,7 +205,7 @@ async def reflection_run( model_id = result.reflection_model_id if model_id: try: - ModelConfigService.get_model_by_id(db=db, model_id=model_id) + ModelConfigService.get_model_by_id(db=db, model_id=uuid.UUID(model_id)) api_logger.info(f"模型ID验证成功: {model_id}") except Exception as e: api_logger.warning(f"模型ID '{model_id}' 不存在,将使用默认模型: {str(e)}") @@ -252,8 +215,8 @@ async def reflection_run( config = ReflectionConfig( enabled=result.enable_self_reflexion, iteration_period=result.iteration_period, - reflexion_range=result.reflexion_range, - baseline=result.baseline, + reflexion_range=ReflectionRange(result.reflexion_range), + baseline=ReflectionBaseline(result.baseline), output_example='', memory_verify=result.memory_verify, quality_assessment=result.quality_assessment, diff --git a/api/app/core/memory/storage_services/reflection_engine/self_reflexion.py b/api/app/core/memory/storage_services/reflection_engine/self_reflexion.py index e9fb8855..bd3a9190 100644 --- a/api/app/core/memory/storage_services/reflection_engine/self_reflexion.py +++ b/api/app/core/memory/storage_services/reflection_engine/self_reflexion.py @@ -24,15 +24,9 @@ from app.core.memory.utils.config.get_data import ( get_data, get_data_statement, ) -from app.core.memory.utils.llm.llm_utils import get_llm_client -from app.core.memory.utils.prompt.template_render import ( - render_evaluate_prompt, - render_reflexion_prompt, -) + from app.core.models.base import RedBearModelConfig -from app.core.response_utils import success from app.repositories.neo4j.cypher_queries import ( - UPDATE_STATEMENT_INVALID_AT, neo4j_query_all, neo4j_query_part, neo4j_statement_all, @@ -160,12 +154,11 @@ class ReflectionEngine: self.neo4j_connector = Neo4jConnector() if self.llm_client is None: - from app.core.memory.utils.config import definitions as config_defs from app.core.memory.utils.llm.llm_utils import MemoryClientFactory from app.db import get_db_context with get_db_context() as db: factory = MemoryClientFactory(db) - self.llm_client = factory.get_llm_client(config_defs.SELECTED_LLM_ID) + self.llm_client = factory.get_llm_client(self.config.model_id) elif isinstance(self.llm_client, str): # 如果 llm_client 是字符串(model_id),则用它初始化客户端 from app.core.memory.utils.llm.llm_utils import MemoryClientFactory @@ -263,25 +256,23 @@ class ReflectionEngine: # 2. 检测冲突(基于事实的反思) conflict_data = await self._detect_conflicts(reflexion_data, statement_databasets) - print(100 * '-') - print(conflict_data) - print(100 * '-') - # # 检查是否真的有冲突 - conflicts_found='' + conflict_list=[] + for i in conflict_data: + conflict_list.append(i['data']) - conflicts_found='' + + + conflicts_found=0 # 3. 解决冲突 - solved_data = await self._resolve_conflicts(conflict_data, statement_databasets) + solved_data = await self._resolve_conflicts(conflict_list, statement_databasets) + if not solved_data: return ReflectionResult( success=False, - message="反思失败,未解决冲突", + message=f"没有{self.config.baseline}相关的冲突数据", conflicts_found=conflicts_found, execution_time=asyncio.get_event_loop().time() - start_time ) - print(100 * '*') - print(solved_data) - print(100 * '*') conflicts_resolved = len(solved_data) logging.info(f"解决了 {conflicts_resolved} 个冲突") @@ -386,7 +377,7 @@ class ReflectionEngine: memory_verifies.append(item['memory_verify']) result_data['memory_verifies'] = memory_verifies result_data['quality_assessments'] = quality_assessments - conflicts_found='' + conflicts_found = 0 # 初始化为整数0而不是空字符串 REMOVE_KEYS = {"created_at", "expired_at","relationship","predicate","statement_id","id","statement_id","relationship_statement_id"} # Clearn conflict_data,And memory_verify和quality_assessment cleaned_conflict_data = [] @@ -414,7 +405,7 @@ class ReflectionEngine: cleaned_conflict_data_.append(cleaned_item) print(cleaned_conflict_data_) # 3. 解决冲突 - solved_data = await self._resolve_conflicts(cleaned_conflict_data, source_data) + solved_data = await self._resolve_conflicts(cleaned_conflict_data_, source_data) if not solved_data: return ReflectionResult( success=False, @@ -739,4 +730,3 @@ class ReflectionEngine: raise ValueError(f"未知的反思基线: {self.config.baseline}") - diff --git a/api/app/core/memory/utils/prompt/prompts/evaluate.jinja2 b/api/app/core/memory/utils/prompt/prompts/evaluate.jinja2 index e649897a..5da6d4b5 100644 --- a/api/app/core/memory/utils/prompt/prompts/evaluate.jinja2 +++ b/api/app/core/memory/utils/prompt/prompts/evaluate.jinja2 @@ -24,7 +24,8 @@ - **身份冲突**: 同一实体被赋予不同类型或角色 - **隐私审核**: 存在隐私信息也作为冲突输出当{{ memory_verify }}是true的时候 ### 混合冲突 -检测所有逻辑不一致或相互矛盾的记录。 +- 检测所有逻辑不一致或相互矛盾的记录。 +- **隐私审核**: 存在隐私信息也作为冲突输出当{{ memory_verify }}是true的时候 **检测原则**: - 重点检查相同实体的记录 - 分析description字段语义冲突 diff --git a/api/app/core/memory/utils/prompt/prompts/reflexion.jinja2 b/api/app/core/memory/utils/prompt/prompts/reflexion.jinja2 index ed3aad32..99660aa4 100644 --- a/api/app/core/memory/utils/prompt/prompts/reflexion.jinja2 +++ b/api/app/core/memory/utils/prompt/prompts/reflexion.jinja2 @@ -63,7 +63,7 @@ **脱敏字段**: name、entity1_name、entity2_name、description、relationship ## 4. 处理流程 - +###如果存在冲突数据执行以下步骤,不存在返回【】在data中 ### 步骤1: 类型匹配验证 **匹配规则**: - baseline="TIME": 只处理时间相关冲突(涉及时间表达式、日期、时间点) @@ -78,7 +78,7 @@ ### 步骤2: 冲突数据分组 **分组策略**: -- 时间冲突组: 涉及用户时间的记录 +- 时间冲突组: 涉及用户时间的记录比如(生日在2月17...) - 活动时间冲突组: 同一活动不同时间的记录 - 事实冲突组: 同一实体不同属性的记录 - 其他冲突组: 其他类型冲突记录 @@ -97,11 +97,12 @@ ### 处理规则 ** baseline是TIME - -保留正确记录不变修改错误记录的expired_at为当前时间(2025-12-16T12:00:00),以及name需要修改成正确的 -** baseline不是TIME + - 只处理时间相关的内容,比如时间表达式、日期、时间点 + -保留正确记录不变修改错误记录的expired_at为当前时间,比如(2025-12-16T12:00:00) +** baseline是FACT或者HYBRID + - 处理不是时间相关的内容 - 修改字段内容( name、entity1_name、entity2_name、description、relationship)字段内容是否正确,如果不正确,需要对这些字段的内容重新生成,则不需要修改expired_at字段, 如果涉及到修改entity1_name/entity2_name字段的时候,同时也需要修改description字段,输出修改前和修改后的放入change里面的field - **核心原则**: - 只输出需要修改的记录 - 优先保留策略: 时间冲突保留最可信created_at时间,事实冲突选择最新且可信度最高记录 @@ -110,22 +111,26 @@ - 脱敏变更记录: 隐私脱敏变更也必须在change字段中记录{% endif %} - 不可修改数据: 数据被判定为正确时不可修改,无数据可输出时为空 - 输出的结果reflexion字段中的reason字段和solution不允许含有(expired_at设为2024-01-01T00:00:00Z、memory_verify=true、memory_verify=false)等原数据字段以及涉及需要修改的字段以及内容, - ,如果是FACT,只记录事实冲突相关的数据;如果是TIME,只记录时间冲突相关的数据;如果是HYBRID,则记录所有冲突相关的数据 + ,如果是FACT,只记录事实冲突相关的数据;如果是TIME,只记录时间冲突相关的数据;如果是HYBRID,则记录所有冲突相关的数据,如果存在隐私审核,隐私审核是true,也需要放到reflexion的reason字段和solution **变更记录格式**: ```json "change": [ { "field": [ - {"id":修改字段对应的ID} - {"statement_id":需要修改的对象对应的statement_id} - {"字段名1": ["修改前的值1","修改后的值1"]}, - {"字段名2": ["修改前的值2","修改后的值2"]} + {"id": "修改字段对应的ID"}, + {"字段名1": ["修改前的值1", "修改后的值1"]}, + {"字段名2": ["修改前的值2", "修改后的值2"]} ] } ] ``` +**resolved_memory格式说明**: +- 对于TIME类型冲突: 只需expired_at字段即可 +- 对于FACT/HYBRID类型冲突: 需要包含完整的记录对象(包括name、entity1_name、entity2_name、description、relationship等所有相关字段) +- resolved_memory中只包含需要修改的记录,不需要修改的记录不要包含在内 + **类型不匹配处理**: - 冲突类型与baseline不匹配时,resolved设为null - reflexion.reason说明类型不匹配原因 @@ -157,7 +162,8 @@ "conflict": true }, "reflexion": { - "reason": "该冲突类型的原因分析,如果是FACT就是存在事实冲突,分析该冲突原因,如果是TIME就是存在时间冲突,分析该冲突原因,如果是HYBRID,可以输出存在时间与事实的混合冲突再添加上原因分析, + "reason": "该冲突类型的原因分析,如果是FACT就是存在事实冲突,分析该冲突原因,如果是TIME就是存在时间冲突,分析该冲突原因,如果是HYBRID,可以输出存在时间与事实的混合冲突再添加上原因分析,如果 + 隐私审核打开的时候如果存在冲突,分析该冲突的原因 不可以随意分配冲突类型以及原因,不允许输出字段比如(statement、description、entity1_name、entity2_name、name、memory_verify、expired_at、conflict)等类似这种", "solution": "该冲突类型的解决方案(不允许输出字段比如(statement、description、entity1_name、entity2_name、name、memory_verify、expired_at、conflict)等类似这种)" }, diff --git a/api/app/repositories/data_config_repository.py b/api/app/repositories/data_config_repository.py index 135c0063..d26058b2 100644 --- a/api/app/repositories/data_config_repository.py +++ b/api/app/repositories/data_config_repository.py @@ -10,7 +10,7 @@ Classes: import uuid from typing import Dict, List, Optional, Tuple - +from app.core.exceptions import BusinessException from app.core.logging_config import get_config_logger, get_db_logger from app.models.data_config_model import DataConfig from app.schemas.memory_storage_schema import ( @@ -20,7 +20,7 @@ from app.schemas.memory_storage_schema import ( ConfigUpdateExtracted, ConfigUpdateForget, ) -from sqlalchemy import desc +from sqlalchemy import desc, select from sqlalchemy.orm import Session # 获取数据库专用日志器 @@ -136,72 +136,88 @@ class DataConfigRepository: id: m.id } AS targetNode """ - - # ==================== SQLAlchemy ORM 数据库操作方法 ==================== @staticmethod - def build_update_reflection(config_id: int, **kwargs) -> Tuple[str, Dict]: + def update_reflection_config( + db: Session, + config_id: int, + enable_self_reflexion: bool, + iteration_period: str, + reflexion_range: str, + baseline: str, + reflection_model_id: str, + memory_verify: bool, + quality_assessment: bool + ) -> DataConfig: """构建反思配置更新语句(SQLAlchemy text() 命名参数) Args: + quality_assessment: + memory_verify: + reflection_model_id: + baseline: + reflexion_range: + iteration_period: + enable_self_reflexion: + db: database object config_id: 配置ID - **kwargs: 反思配置参数 Returns: - Tuple[str, Dict]: (SQL查询字符串, 参数字典) + Data Raises: ValueError: 没有字段需要更新时抛出 """ db_logger.debug(f"构建反思配置更新语句: config_id={config_id}") + stmt = select(DataConfig).where(DataConfig.config_id == config_id) + data_config_obj = db.scalars(stmt).first() + if not data_config_obj: + raise BusinessException + data_config_obj.enable_self_reflexion = enable_self_reflexion + data_config_obj.iteration_period = iteration_period + data_config_obj.reflexion_range = reflexion_range + data_config_obj.baseline = baseline + data_config_obj.reflection_model_id = reflection_model_id + data_config_obj.memory_verify = memory_verify + data_config_obj.quality_assessment = quality_assessment - key_where = "config_id = :config_id" - set_fields: List[str] = [] - params: Dict = { - "config_id": config_id, - } - - # 反思配置字段映射 - mapping = { - "enable_self_reflexion": "enable_self_reflexion", - "iteration_period": "iteration_period", - "reflexion_range": "reflexion_range", - "baseline": "baseline", - "reflection_model_id": "reflection_model_id", - "memory_verify": "memory_verify", - "quality_assessment": "quality_assessment", - } - - for api_field, db_col in mapping.items(): - if api_field in kwargs and kwargs[api_field] is not None: - set_fields.append(f"{db_col} = :{api_field}") - params[api_field] = kwargs[api_field] - - if not set_fields: - raise ValueError("No fields to update") - - set_fields.append("updated_at = timezone('Asia/Shanghai', now())") - query = f"UPDATE {TABLE_NAME} SET " + ", ".join(set_fields) + f" WHERE {key_where}" - return query, params + return data_config_obj @staticmethod - def build_select_reflection(config_id: int) -> Tuple[str, Dict]: + def query_reflection_config_by_id(db: Session, config_id: int) -> DataConfig: """构建反思配置查询语句,通过config_id查询反思配置(SQLAlchemy text() 命名参数) Args: + db: database object config_id: 配置ID Returns: Tuple[str, Dict]: (SQL查询字符串, 参数字典) """ db_logger.debug(f"构建反思配置查询语句: config_id={config_id}") + stmt = select(DataConfig).where(DataConfig.config_id == config_id) + data_config = db.scalars(stmt).first() + if not data_config: + raise RuntimeError("reflection config not found") + return data_config + @staticmethod + def query_reflection_config_by_workspace_id(db: Session, workspace_id: uuid.UUID) -> DataConfig: + """构建查询所有配置的语句(SQLAlchemy text() 命名参数) + + Args: + db: database object + workspace_id: 工作空间ID + + Returns: + Tuple[str, Dict]: (SQL查询字符串, 参数字典) + """ + db_logger.debug(f"构建查询所有配置语句: workspace_id={workspace_id}") + + stmt = select(DataConfig).where(DataConfig.workspace_id == workspace_id) + data_config = db.scalars(stmt).first() + if not data_config: + raise RuntimeError("reflection config not found") + return data_config - query = ( - f"SELECT config_id, enable_self_reflexion, iteration_period, reflexion_range, baseline, " - f"reflection_model_id, memory_verify, quality_assessment, user_id " - f"FROM {TABLE_NAME} WHERE config_id = :config_id" - ) - params = {"config_id": config_id} - return query, params @staticmethod def build_select_all(workspace_id: uuid.UUID) -> Tuple[str, Dict]: diff --git a/api/app/repositories/neo4j/cypher_queries.py b/api/app/repositories/neo4j/cypher_queries.py index c91c2e80..cd3cbed7 100644 --- a/api/app/repositories/neo4j/cypher_queries.py +++ b/api/app/repositories/neo4j/cypher_queries.py @@ -837,12 +837,14 @@ neo4j_query_part = """ WITH DISTINCT m OPTIONAL MATCH (m)-[rel]-(other:ExtractedEntity) RETURN + elementId(m) as id, m.name as entity1_name, m.description as description, m.statement_id as statement_id, m.created_at as created_at, m.expired_at as expired_at, CASE WHEN rel IS NULL THEN "NO_RELATIONSHIP" ELSE type(rel) END as relationship_type, + elementId(rel) as rel_id, rel.predicate as predicate, rel.statement as relationship, rel.statement_id as relationship_statement_id, @@ -855,12 +857,14 @@ neo4j_query_all = """ WITH DISTINCT m OPTIONAL MATCH (m)-[rel]-(other:ExtractedEntity) RETURN + elementId(m) as id, m.name as entity1_name, m.description as description, m.statement_id as statement_id, m.created_at as created_at, m.expired_at as expired_at, CASE WHEN rel IS NULL THEN "NO_RELATIONSHIP" ELSE type(rel) END as relationship_type, + elementId(rel) as rel_id, rel.predicate as predicate, rel.statement as relationship, rel.statement_id as relationship_statement_id, diff --git a/api/app/repositories/neo4j/neo4j_update.py b/api/app/repositories/neo4j/neo4j_update.py index 73b44396..753ae256 100644 --- a/api/app/repositories/neo4j/neo4j_update.py +++ b/api/app/repositories/neo4j/neo4j_update.py @@ -11,22 +11,28 @@ async def update_neo4j_data(neo4j_dict_data, update_databases): update_databases: update """ try: - # 构建WHERE条件 + # 构建WHERE条件 - 只使用elementId where_conditions = [] params = {} - for key, value in neo4j_dict_data.items(): - if value is not None: - param_name = f"param_{key}" - where_conditions.append(f"e.{key} = ${param_name}") - params[param_name] = value + # 优先使用id作为elementId进行查询 + if 'id' in neo4j_dict_data and neo4j_dict_data['id'] is not None: + where_conditions.append(f"elementId(e) = $param_id") + params['param_id'] = neo4j_dict_data['id'] + else: + # 如果没有id,使用其他字段作为条件 + for key, value in neo4j_dict_data.items(): + if value is not None: + param_name = f"param_{key}" + where_conditions.append(f"e.{key} = ${param_name}") + params[param_name] = value where_clause = " AND ".join(where_conditions) if where_conditions else "1=1" - # 构建SET条件 + # 构建SET条件 - 排除id字段 set_conditions = [] for key, value in update_databases.items(): - if value is not None: + if value is not None and key != 'id': # 不更新id字段 param_name = f"update_{key}" set_conditions.append(f"e.{key} = ${param_name}") params[param_name] = value @@ -76,22 +82,28 @@ async def update_neo4j_data_edge(neo4j_dict_data, update_databases): update_databases: update """ try: - # 构建WHERE条件 + # 构建WHERE条件 - 只使用elementId where_conditions = [] params = {} - for key, value in neo4j_dict_data.items(): - if value is not None: - param_name = f"param_{key}" - where_conditions.append(f"r.{key} = ${param_name}") - params[param_name] = value + # 优先使用id作为elementId进行查询 + if 'id' in neo4j_dict_data and neo4j_dict_data['id'] is not None: + where_conditions.append(f"elementId(r) = $param_id") + params['param_id'] = neo4j_dict_data['id'] + else: + # 如果没有id,使用其他字段作为条件 + for key, value in neo4j_dict_data.items(): + if value is not None: + param_name = f"param_{key}" + where_conditions.append(f"r.{key} = ${param_name}") + params[param_name] = value where_clause = " AND ".join(where_conditions) if where_conditions else "1=1" - # 构建SET条件 + # 构建SET条件 - 排除id字段 set_conditions = [] for key, value in update_databases.items(): - if value is not None: + if value is not None and key != 'id': # 不更新id字段 param_name = f"update_{key}" set_conditions.append(f"r.{key} = ${param_name}") params[param_name] = value @@ -242,7 +254,16 @@ async def neo4j_data(solved_data): if key=='expired_at': updat_expired_at[key] = values[1] - elif key == 'statement_id': + elif key == 'id': + ori_edge[key] = values + updata_edge[key] = values + + ori_entity[key] = values + updata_entity[key] = values + + ori_expired_at[key] = values + elif key == 'rel_id': + key='id' ori_edge[key] = values updata_edge[key] = values diff --git a/api/app/schemas/memory_storage_schema.py b/api/app/schemas/memory_storage_schema.py index ecb1570f..d17a9f2c 100644 --- a/api/app/schemas/memory_storage_schema.py +++ b/api/app/schemas/memory_storage_schema.py @@ -35,10 +35,10 @@ class BaseDataSchema(BaseModel): expired_at: Optional[str] = Field(None, description="The expiration timestamp in ISO 8601 format.") description: Optional[str] = Field(None, description="The description of the data entry.") - # 新增字段以匹配实际输入数据 - entity1_name: str = Field(..., description="The first entity name.") + # 新增字段以匹配实际输入数据 - 改为可选以支持resolved_memory场景 + entity1_name: Optional[str] = Field(None, description="The first entity name.") entity2_name: Optional[str] = Field(None, description="The second entity name.") - statement_id: str = Field(..., description="The statement identifier.") + statement_id: Optional[str] = Field(None, description="The statement identifier.") # 新增字段 - 设为可选以保持向后兼容性 predicate: Optional[str] = Field(None, description="The predicate describing the relationship between entities.") relationship_statement_id: Optional[str] = Field(None, description="The relationship statement identifier.") @@ -108,13 +108,13 @@ class ChangeRecordSchema(BaseModel): """Schema for individual change records 字段值格式说明: - - id 和 statement_id: 字符串或 None + - id: 字符串,表示修改字段对应的记录ID - 其他字段: 可以是字符串、None,数组 [修改前的值, 修改后的值],或嵌套字典结构 - entity2等嵌套对象的字段也遵循 [old_value, new_value] 格式 """ field: List[Dict[str, Any]] = Field( ..., - description="List of field changes. First item: {id: value or None}, second: {statement_id: value}, followed by changed fields as {field_name: [old_value, new_value]} or {field_name: new_value} or nested structures like {entity2: {field_name: [old, new]}}" + description="List of field changes. First item: {id: value}, followed by changed fields as {field_name: [old_value, new_value]} or {field_name: new_value} or nested structures like {entity2: {field_name: [old, new]}}" ) class ResolvedSchema(BaseModel): diff --git a/api/app/services/memory_reflection_service.py b/api/app/services/memory_reflection_service.py index 0f8fb569..015cc08a 100644 --- a/api/app/services/memory_reflection_service.py +++ b/api/app/services/memory_reflection_service.py @@ -120,10 +120,12 @@ class WorkspaceAppService: def _get_data_config(self, memory_content: str) -> Dict[str, Any]: """Retrieve data_comfig information based on memory_comtent""" try: - data_config_query, data_config_params = DataConfigRepository.build_select_reflection(memory_content) - data_config_result = self.db.execute(text(data_config_query), data_config_params).fetchone() - if data_config_result is None: - return None + data_config_result = DataConfigRepository.query_reflection_config_by_id(self.db, int(memory_content)) + + # data_config_query, data_config_params = DataConfigRepository.build_select_reflection(memory_content) + # data_config_result = self.db.execute(text(data_config_query), data_config_params).fetchone() + # if data_config_result is None: + # return None if data_config_result: return { From 804d87bca2f9a7b2fbd341efb851d7da23778949 Mon Sep 17 00:00:00 2001 From: zhaoying Date: Tue, 20 Jan 2026 10:42:13 +0800 Subject: [PATCH 12/12] refactor: extract jinja render's form --- .../Workflow/components/Editor/index.tsx | 36 +-- .../components/Editor/nodes/VariableNode.tsx | 2 +- .../Editor/plugin/AutocompletePlugin.tsx | 8 +- .../components/Editor/plugin/BlurPlugin.tsx | 33 +++ .../Properties/JinjaRender/index.tsx | 206 ++++++++++++++++++ .../Workflow/components/Properties/index.tsx | 145 +----------- 6 files changed, 279 insertions(+), 151 deletions(-) create mode 100644 web/src/views/Workflow/components/Editor/plugin/BlurPlugin.tsx create mode 100644 web/src/views/Workflow/components/Properties/JinjaRender/index.tsx diff --git a/web/src/views/Workflow/components/Editor/index.tsx b/web/src/views/Workflow/components/Editor/index.tsx index ba2e3a41..fd3e937b 100644 --- a/web/src/views/Workflow/components/Editor/index.tsx +++ b/web/src/views/Workflow/components/Editor/index.tsx @@ -16,6 +16,7 @@ import InitialValuePlugin from './plugin/InitialValuePlugin'; import CommandPlugin from './plugin/CommandPlugin'; import Jinja2HighlightPlugin from './plugin/Jinja2HighlightPlugin'; import LineNumberPlugin from './plugin/LineNumberPlugin'; +import BlurPlugin from './plugin/BlurPlugin'; import { VariableNode } from './nodes/VariableNode' interface LexicalEditorProps { @@ -113,8 +114,10 @@ const Editor: FC =({ display: flex; align-items: flex-start; } - .editor-content-with-numbers { + .editor-content-wrapper { flex: 1; + } + .editor-content-with-numbers { white-space: pre-wrap; } .editor-content-with-numbers p { @@ -174,18 +177,20 @@ const Editor: FC =({
1
- +
+ +
) : ( =({ style={{ minHeight: placeHolderMinheight, position: 'absolute', - top: variant === 'borderless' ? '0' : '6px', - left: enableJinja2 ? '59px' : (variant === 'borderless' ? '0' : '11px'), + top: enableJinja2 ? '4px' : variant === 'borderless' ? '0' : '6px', + left: enableJinja2 ? '16px' : (variant === 'borderless' ? '0' : '11px'), color: '#A8A9AA', fontSize: fontSize, lineHeight: placeHolderMinheight, @@ -227,6 +232,7 @@ const Editor: FC =({ { setCount(count) }} onChange={onChange} /> + {enableJinja2 && }
); diff --git a/web/src/views/Workflow/components/Editor/nodes/VariableNode.tsx b/web/src/views/Workflow/components/Editor/nodes/VariableNode.tsx index 13d12ee1..d29fba4c 100644 --- a/web/src/views/Workflow/components/Editor/nodes/VariableNode.tsx +++ b/web/src/views/Workflow/components/Editor/nodes/VariableNode.tsx @@ -36,7 +36,7 @@ const VariableComponent: React.FC<{ nodeKey: NodeKey; data: Suggestion }> = ({ return ( const textAfter = nodeText.substring(anchorOffset); const newText = textBefore + `{{${suggestion.value}}}` + textAfter; - anchorNode.setTextContent(newText); + if ($isTextNode(anchorNode)) { + anchorNode.setTextContent(newText); + } // 设置光标位置到插入文本之后 const newOffset = textBefore.length + `{{${suggestion.value}}}`.length; @@ -129,6 +131,8 @@ const AutocompletePlugin: FC<{ options: Suggestion[], enableJinja2?: boolean }> } return (
e.preventDefault()} style={{ position: 'fixed', top: popupPosition.top, diff --git a/web/src/views/Workflow/components/Editor/plugin/BlurPlugin.tsx b/web/src/views/Workflow/components/Editor/plugin/BlurPlugin.tsx new file mode 100644 index 00000000..b636605b --- /dev/null +++ b/web/src/views/Workflow/components/Editor/plugin/BlurPlugin.tsx @@ -0,0 +1,33 @@ +import { useLexicalComposerContext } from '@lexical/react/LexicalComposerContext'; +import { useEffect } from 'react'; +import { $setSelection } from 'lexical'; + +export default function BlurPlugin() { + const [editor] = useLexicalComposerContext(); + + useEffect(() => { + return editor.registerRootListener((rootElement) => { + if (rootElement) { + const handleBlur = (e: FocusEvent) => { + // 检查是否点击了自动完成弹窗 + const target = e.target as HTMLElement; + console.log('target', target) + if (target?.closest('[data-autocomplete-popup="true"]')) { + return; + } + + editor.update(() => { + $setSelection(null); + }); + }; + + rootElement.addEventListener('blur', handleBlur); + return () => { + rootElement.removeEventListener('blur', handleBlur); + }; + } + }); + }, [editor]); + + return null; +} diff --git a/web/src/views/Workflow/components/Properties/JinjaRender/index.tsx b/web/src/views/Workflow/components/Properties/JinjaRender/index.tsx new file mode 100644 index 00000000..a2c9da37 --- /dev/null +++ b/web/src/views/Workflow/components/Properties/JinjaRender/index.tsx @@ -0,0 +1,206 @@ +import { type FC, useEffect, useRef } from 'react' +import { useTranslation } from 'react-i18next' +import { Form } from 'antd' +import { Node } from '@antv/x6' +import type { Suggestion } from '../../Editor/plugin/AutocompletePlugin' +import MappingList from '../MappingList' +import MessageEditor from '../MessageEditor' + +interface MappingItem { + name?: string + value?: string +} + +interface JinjaRenderProps { + options: Suggestion[] + templateOptions: Suggestion[] + selectedNode: Node +} + +const extractTemplateVars = (template: string): string[] => { + return (template.match(/{{\s*([\w.]+)\s*}}/g) || []) + .map(m => m.replace(/{{\s*|\s*}}/g, '')) +} + +const getMappingNames = (mapping: MappingItem[]): string[] => { + return mapping.filter(item => item?.name).map(item => item.name!) +} + +const JinjaRender: FC = ({ selectedNode, options, templateOptions }) => { + const { t } = useTranslation() + const form = Form.useFormInstance() + const values = Form.useWatch([], form) || {} + + console.log('JinjaRender values', values) + + const prevMappingNamesRef = useRef([]) + const prevTemplateVarsRef = useRef([]) + const syncTimeoutRef = useRef(null) + const isSyncingRef = useRef(false) + const lastSyncSourceRef = useRef<'mapping' | 'template' | null>(null) + + // Reset refs when node changes + useEffect(() => { + if (selectedNode?.getData()?.id) { + prevMappingNamesRef.current = [] + prevTemplateVarsRef.current = [] + lastSyncSourceRef.current = null + } + }, [selectedNode?.getData()?.id]) + + // Sync template when mapping names change + useEffect(() => { + if ( + isSyncingRef.current || + lastSyncSourceRef.current === 'mapping' || + selectedNode?.data?.type !== 'jinja-render' || + !values?.mapping || + !values?.template + ) return + + const currentMappingNames = Array.isArray(values.mapping) ? getMappingNames(values.mapping) : [] + const prevNames = prevMappingNamesRef.current + + if (prevNames.length === 0) { + prevMappingNamesRef.current = currentMappingNames + return + } + + if (JSON.stringify(prevNames) === JSON.stringify(currentMappingNames)) return + + if (syncTimeoutRef.current) clearTimeout(syncTimeoutRef.current) + const activeElement = document.activeElement as HTMLElement + + syncTimeoutRef.current = setTimeout(() => { + let updatedTemplate = String(form.getFieldValue('template') || '') + + prevNames.forEach((oldName, index) => { + const newName = currentMappingNames[index] + if (newName && oldName !== newName) { + updatedTemplate = updatedTemplate.replace( + new RegExp(`{{\\s*${oldName}\\s*}}`, 'g'), + `{{${newName}}}` + ) + } + }) + + if (updatedTemplate !== form.getFieldValue('template')) { + isSyncingRef.current = true + lastSyncSourceRef.current = 'mapping' + + prevTemplateVarsRef.current = extractTemplateVars(updatedTemplate) + prevMappingNamesRef.current = currentMappingNames + form.setFieldValue('template', updatedTemplate) + + requestAnimationFrame(() => { + activeElement?.focus?.() + setTimeout(() => { + isSyncingRef.current = false + lastSyncSourceRef.current = null + }, 50) + }) + } else { + prevMappingNamesRef.current = currentMappingNames + } + }, 0) + }, [values?.mapping, selectedNode?.data?.type, form]) + + // Sync mapping when template variables change + useEffect(() => { + console.log('values?.template', values?.template) + if ( + isSyncingRef.current || + lastSyncSourceRef.current === 'template' || + selectedNode?.data?.type !== 'jinja-render' || + !values?.template || + !values?.mapping + ) return + + const templateVars = extractTemplateVars(String(values.template)) + if (JSON.stringify(prevTemplateVarsRef.current) === JSON.stringify(templateVars)) return + + const isTemplateEditor = document.activeElement?.closest('[data-editor-type="template"]') + if (!isTemplateEditor) { + prevTemplateVarsRef.current = templateVars + return + } + + const updatedMapping: MappingItem[] = Array.isArray(values.mapping) + ? [...values.mapping.filter((item: MappingItem) => item)] + : [] + const existingNames = getMappingNames(updatedMapping) + let updatedTemplate = String(values.template) + + // Update existing mapping names based on position + if (prevTemplateVarsRef.current.length > 0) { + prevTemplateVarsRef.current.forEach((oldVar, index) => { + const newVar = templateVars[index] + if (newVar && oldVar !== newVar && updatedMapping[index]) { + updatedMapping[index] = { ...updatedMapping[index], name: newVar } + } + }) + } + + // Add new mappings and normalize template + templateVars.forEach(varName => { + const existingMapping = updatedMapping.find(item => item.value === `{{${varName}}}`) + const regex = new RegExp(`{{\\s*${varName.replace(/\./g, '\\.')}\\s*}}`, 'g') + + if (existingMapping) { + updatedTemplate = updatedTemplate.replace(regex, `{{${existingMapping.name}}}`) + } else if (!existingNames.includes(varName)) { + const mappingName = varName.includes('.') ? varName.split('.').pop() || varName : varName + updatedMapping.push({ name: mappingName, value: `{{${varName}}}` }) + updatedTemplate = updatedTemplate.replace(regex, `{{${mappingName}}}`) + } + }) + + // Remove unused mappings and duplicates + const seenNames = new Set() + const finalMapping = updatedMapping.filter(item => { + const isUsed = templateVars.some(v => item.name === v || item.value === `{{${v}}}`) + if (!isUsed || !item.name || seenNames.has(item.name)) return false + seenNames.add(item.name) + return true + }) + + isSyncingRef.current = true + lastSyncSourceRef.current = 'template' + prevMappingNamesRef.current = getMappingNames(finalMapping) + prevTemplateVarsRef.current = templateVars + + if (JSON.stringify(finalMapping) !== JSON.stringify(values.mapping)) { + form.setFieldValue('mapping', finalMapping) + } + if (updatedTemplate !== String(values.template)) { + form.setFieldValue('template', updatedTemplate) + } + + setTimeout(() => { + isSyncingRef.current = false + lastSyncSourceRef.current = null + }, 50) + }, [values?.template, selectedNode?.data?.type, form]) + + return ( + <> + + + + + + + + + ) +} + +export default JinjaRender diff --git a/web/src/views/Workflow/components/Properties/index.tsx b/web/src/views/Workflow/components/Properties/index.tsx index 6d4571dc..d55e1d9e 100644 --- a/web/src/views/Workflow/components/Properties/index.tsx +++ b/web/src/views/Workflow/components/Properties/index.tsx @@ -1,4 +1,4 @@ -import { type FC, useEffect, useState, useRef, useMemo } from "react"; +import { type FC, useEffect, useState, useMemo } from "react"; import clsx from 'clsx' import { useTranslation } from 'react-i18next' import { Graph, Node } from '@antv/x6'; @@ -17,7 +17,6 @@ import ParamsList from './ParamsList'; import GroupVariableList from './GroupVariableList' import CaseList from './CaseList' import HttpRequest from './HttpRequest'; -import MappingList from './MappingList' import CategoryList from './CategoryList' import ConditionList from './ConditionList' import CycleVarsList from './CycleVarsList' @@ -29,6 +28,7 @@ import { useVariableList, getCurrentNodeVariables } from './hooks/useVariableLis import styles from './properties.module.css' import Editor from "../Editor"; import RbSlider from './RbSlider' +import JinjaRender from './JinjaRender' interface PropertiesProps { selectedNode?: Node | null; @@ -50,136 +50,16 @@ const Properties: FC = ({ const [form] = Form.useForm(); const [configs, setConfigs] = useState>({} as Record) const values = Form.useWatch([], form); - const prevMappingNamesRef = useRef([]) - const prevTemplateVarsRef = useRef([]) - const syncTimeoutRef = useRef(null) - const isSyncingRef = useRef(false) - const lastSyncSourceRef = useRef<'mapping' | 'template' | null>(null) const variableList = useVariableList(selectedNode, graphRef, chatVariables) useEffect(() => { if (selectedNode?.getData()?.id) { - form.resetFields() - prevMappingNamesRef.current = [] - prevTemplateVarsRef.current = [] - lastSyncSourceRef.current = null setOutputCollapsed(true) + } else { + form.resetFields() } }, [selectedNode?.getData()?.id]) - // Sync template when mapping names change - useEffect(() => { - if (isSyncingRef.current || lastSyncSourceRef.current === 'mapping' || selectedNode?.data?.type !== 'jinja-render' || !values?.mapping || !values?.template) return - - const currentMappingNames = Array.isArray(values.mapping) ? values.mapping.filter(item => item && item.name).map((item: any) => item.name) : [] - const prevNames = prevMappingNamesRef.current - - if (prevNames.length === 0) { - prevMappingNamesRef.current = currentMappingNames - return - } - - if (JSON.stringify(prevNames) === JSON.stringify(currentMappingNames)) return - - if (syncTimeoutRef.current) clearTimeout(syncTimeoutRef.current) - const activeElement = document.activeElement as HTMLElement - - syncTimeoutRef.current = setTimeout(() => { - let updatedTemplate = String(form.getFieldValue('template') || '') - - prevNames.forEach((oldName, index) => { - const newName = currentMappingNames[index] - if (newName && oldName !== newName) { - updatedTemplate = updatedTemplate.replace(new RegExp(`{{\\s*${oldName}\\s*}}`, 'g'), `{{${newName}}}`) - } - }) - - if (updatedTemplate !== form.getFieldValue('template')) { - isSyncingRef.current = true - lastSyncSourceRef.current = 'mapping' - const newTemplateVars = (updatedTemplate.match(/{{\s*([\w.]+)\s*}}/g) || []).map(m => m.replace(/{{\s*|\s*}}/g, '')) - prevTemplateVarsRef.current = newTemplateVars - prevMappingNamesRef.current = currentMappingNames - form.setFieldValue('template', updatedTemplate) - - requestAnimationFrame(() => { - activeElement?.focus?.() - setTimeout(() => { - isSyncingRef.current = false - lastSyncSourceRef.current = null - }, 50) - }) - } else { - prevMappingNamesRef.current = currentMappingNames - } - }, 0) - }, [values?.mapping, selectedNode?.data?.type, form]) - - // Sync mapping when template variables change - useEffect(() => { - if (isSyncingRef.current || lastSyncSourceRef.current === 'template' || selectedNode?.data?.type !== 'jinja-render' || !values?.template || !values?.mapping) return - - const templateVars = (String(values.template).match(/{{\s*([\w.]+)\s*}}/g) || []).map(m => m.replace(/{{\s*|\s*}}/g, '')) - if (JSON.stringify(prevTemplateVarsRef.current) === JSON.stringify(templateVars)) return - - const isTemplateEditor = document.activeElement?.closest('[data-editor-type="template"]') - if (!isTemplateEditor) { - prevTemplateVarsRef.current = templateVars - return - } - - const updatedMapping = Array.isArray(values.mapping) ? [...values.mapping.filter(item => item)] : [] - const existingNames = updatedMapping.filter(item => item && item.name).map(item => item.name) - let updatedTemplate = String(values.template) - - if (prevTemplateVarsRef.current.length > 0) { - prevTemplateVarsRef.current.forEach((oldVar, index) => { - const newVar = templateVars[index] - if (newVar && oldVar !== newVar && updatedMapping[index]) { - updatedMapping[index] = { ...updatedMapping[index], name: newVar } - } - }) - } - - templateVars.forEach(varName => { - const existingMapping = updatedMapping.find(item => item.value === `{{${varName}}}`) - const regex = new RegExp(`{{\\s*${varName.replace(/\./g, '\\.')}\\s*}}`, 'g') - - if (existingMapping) { - updatedTemplate = updatedTemplate.replace(regex, `{{${existingMapping.name}}}`) - } else if (!existingNames.includes(varName)) { - const mappingName = varName.includes('.') ? varName.split('.').pop() || varName : varName - updatedMapping.push({ name: mappingName, value: `{{${varName}}}` }) - updatedTemplate = updatedTemplate.replace(regex, `{{${mappingName}}}`) - } - }) - - const seenNames = new Set() - const finalMapping = updatedMapping.filter(item => { - const isUsed = templateVars.some(v => item.name === v || item.value === `{{${v}}}`) - if (!isUsed || seenNames.has(item.name)) return false - seenNames.add(item.name) - return true - }) - - isSyncingRef.current = true - lastSyncSourceRef.current = 'template' - prevMappingNamesRef.current = finalMapping.filter(item => item && item.name).map((item: any) => item.name) - prevTemplateVarsRef.current = templateVars - - if (JSON.stringify(finalMapping) !== JSON.stringify(values.mapping)) { - form.setFieldValue('mapping', finalMapping) - } - if (updatedTemplate !== String(values.template)) { - form.setFieldValue('template', updatedTemplate) - } - - setTimeout(() => { - isSyncingRef.current = false - lastSyncSourceRef.current = null - }, 50) - }, [values?.template, selectedNode?.data?.type, form]) - useEffect(() => { if (selectedNode && form) { const { type = 'default', name = '', config } = selectedNode.getData() || {} @@ -197,6 +77,8 @@ const Properties: FC = ({ ...initialValue, }) setConfigs(config || {}) + } else { + form.resetFields() } }, [selectedNode, form]) @@ -529,6 +411,12 @@ const Properties: FC = ({ /> : selectedNode?.data?.type === 'tool' ? + : selectedNode?.data.type === 'jinja-render' + ? : configs && Object.keys(configs).length > 0 && Object.keys(configs).map((key) => { const config = configs[key] || {} @@ -646,15 +534,6 @@ const Properties: FC = ({ ) } - - if (config.type === 'mappingList') { - return ( - - - - - ) - } if (config.type === 'cycleVarsList') { return (