diff --git a/api/app/controllers/app_controller.py b/api/app/controllers/app_controller.py index 3b4e5a25..d57ee69d 100644 --- a/api/app/controllers/app_controller.py +++ b/api/app/controllers/app_controller.py @@ -872,3 +872,44 @@ async def update_workflow_config( workspace_id = current_user.current_workspace_id cfg = app_service.update_workflow_config(db, app_id=app_id, data=payload, workspace_id=workspace_id) return success(data=WorkflowConfigSchema.model_validate(cfg)) + + +@router.get("/{app_id}/statistics", summary="应用统计数据") +@cur_workspace_access_guard() +def get_app_statistics( + app_id: uuid.UUID, + start_date: int, + end_date: int, + db: Session = Depends(get_db), + current_user=Depends(get_current_user), +): + """获取应用统计数据 + + Args: + app_id: 应用ID + start_date: 开始时间戳(毫秒) + end_date: 结束时间戳(毫秒) + + Returns: + - daily_conversations: 每日会话数统计 + - total_conversations: 总会话数 + - daily_new_users: 每日新增用户数 + - total_new_users: 总新增用户数 + - daily_api_calls: 每日API调用次数 + - total_api_calls: 总API调用次数 + - daily_tokens: 每日token消耗 + - total_tokens: 总token消耗 + """ + workspace_id = current_user.current_workspace_id + + from app.services.app_statistics_service import AppStatisticsService + stats_service = AppStatisticsService(db) + + result = stats_service.get_app_statistics( + app_id=app_id, + workspace_id=workspace_id, + start_date=start_date, + end_date=end_date + ) + + return success(data=result) diff --git a/api/app/controllers/model_controller.py b/api/app/controllers/model_controller.py index 42d59664..509f7cad 100644 --- a/api/app/controllers/model_controller.py +++ b/api/app/controllers/model_controller.py @@ -3,15 +3,17 @@ from sqlalchemy.orm import Session from typing import Optional import uuid - +from app.core.error_codes import BizCode +from app.core.exceptions import BusinessException from app.db import get_db from app.dependencies import get_current_user from app.models.models_model import ModelProvider, ModelType from app.models.user_model import User +from app.repositories.model_repository import ModelConfigRepository from app.schemas import model_schema from app.core.response_utils import success from app.schemas.response_schema import ApiResponse, PageData -from app.services.model_service import ModelConfigService, ModelApiKeyService +from app.services.model_service import ModelConfigService, ModelApiKeyService, ModelBaseService from app.core.logging_config import get_api_logger # 获取API专用日志器 @@ -24,7 +26,6 @@ router = APIRouter( @router.get("/type", response_model=ApiResponse) def get_model_types(): - return success(msg="获取模型类型成功", data=list(ModelType)) @@ -35,13 +36,68 @@ def get_model_providers(): @router.get("", response_model=ApiResponse) def get_model_list( - type: Optional[str] = Query(None, description="模型类型筛选(支持多个,如 ?type=LLM 或 ?type=LLM,EMBEDDING)"), - provider: Optional[model_schema.ModelProvider] = Query(None, description="提供商筛选(基于API Key)"), + type: Optional[list[str]] = Query(None, description="模型类型筛选(支持多个,如 ?type=LLM 或 ?type=LLM,EMBEDDING)"), + provider: Optional[model_schema.ModelProvider] = Query(None, description="提供商筛选(基于API Key)"), + is_active: Optional[bool] = Query(None, description="激活状态筛选"), + is_public: Optional[bool] = Query(None, description="公开状态筛选"), + search: Optional[str] = Query(None, description="搜索关键词"), + page: int = Query(1, ge=1, description="页码"), + pagesize: int = Query(10, ge=1, le=100, description="每页数量"), + db: Session = Depends(get_db), + current_user: User = Depends(get_current_user) +): + """ + 获取模型配置列表 + + 支持多个 type 参数: + - 单个:?type=LLM + - 多个(逗号分隔):?type=LLM,EMBEDDING + - 多个(重复参数):?type=LLM&type=EMBEDDING + """ + api_logger.info( + f"获取模型配置列表请求: type={type}, provider={provider}, page={page}, pagesize={pagesize}, tenant_id={current_user.tenant_id}") + + try: + # 解析 type 参数(支持逗号分隔) + type_list = [] + if type is not None: + flat_type = [] + for item in type: + split_items = [t.strip() for t in item.split(',') if t.strip()] + flat_type.extend(split_items) + + unique_flat_type = list(dict.fromkeys(flat_type)) + type_list = [ModelType(t.lower()) for t in unique_flat_type] + + api_logger.error(f"获取模型type_list: {type_list}") + query = model_schema.ModelConfigQuery( + type=type_list, + provider=provider, + is_active=is_active, + is_public=is_public, + search=search, + page=page, + pagesize=pagesize + ) + + api_logger.debug(f"开始获取模型配置列表: {query.dict()}") + result_orm = ModelConfigService.get_model_list(db=db, query=query, tenant_id=current_user.tenant_id) + result = PageData.model_validate(result_orm) + api_logger.info(f"模型配置列表获取成功: 总数={result.page.total}, 当前页={len(result.items)}") + return success(data=result, msg="模型配置列表获取成功") + except Exception as e: + api_logger.error(f"获取模型配置列表失败: {str(e)}") + raise + + +@router.get("/new", response_model=ApiResponse) +def get_model_list( + type: Optional[list[str]] = Query(None, description="模型类型筛选(支持多个,如 ?type=LLM 或 ?type=LLM,EMBEDDING)"), + provider: Optional[model_schema.ModelProvider] = Query(None, description="提供商筛选(基于ModelConfig)"), is_active: Optional[bool] = Query(None, description="激活状态筛选"), is_public: Optional[bool] = Query(None, description="公开状态筛选"), search: Optional[str] = Query(None, description="搜索关键词"), - page: int = Query(1, ge=1, description="页码"), - pagesize: int = Query(10, ge=1, le=100, description="每页数量"), + is_composite: Optional[bool] = Query(None, description="组合模型筛选"), db: Session = Depends(get_db), current_user: User = Depends(get_current_user) ): @@ -53,36 +109,123 @@ def get_model_list( - 多个(逗号分隔):?type=LLM,EMBEDDING - 多个(重复参数):?type=LLM&type=EMBEDDING """ - api_logger.info(f"获取模型配置列表请求: type={type}, provider={provider}, page={page}, pagesize={pagesize}, tenant_id={current_user.tenant_id}") + api_logger.info(f"获取模型配置列表请求: type={type}, provider={provider}, tenant_id={current_user.tenant_id}") try: # 解析 type 参数(支持逗号分隔) - type_list = None - if type: - type_values = [t.strip() for t in type.split(',')] - type_list = [model_schema.ModelType(t.lower()) for t in type_values if t] + type_list = [] + if type is not None: + flat_type = [] + for item in type: + split_items = [t.strip() for t in item.split(',') if t.strip()] + flat_type.extend(split_items) + + unique_flat_type = list(dict.fromkeys(flat_type)) + type_list = [ModelType(t.lower()) for t in unique_flat_type] - api_logger.error(f"获取模型type_list: {type_list}") - query = model_schema.ModelConfigQuery( + api_logger.info(f"获取模型type_list: {type_list}") + query = model_schema.ModelConfigQueryNew( type=type_list, provider=provider, is_active=is_active, is_public=is_public, - search=search, - page=page, - pagesize=pagesize + is_composite=is_composite, + search=search ) - api_logger.debug(f"开始获取模型配置列表: {query.dict()}") - result_orm = ModelConfigService.get_model_list(db=db, query=query, tenant_id=current_user.tenant_id) - result = PageData.model_validate(result_orm) - api_logger.info(f"模型配置列表获取成功: 总数={result.page.total}, 当前页={len(result.items)}") + api_logger.debug(f"开始获取模型配置列表: {query.model_dump()}") + result = ModelConfigService.get_model_list_new(db=db, query=query, tenant_id=current_user.tenant_id) + api_logger.info(f"模型配置列表获取成功: 分组数={len(result)}, 总模型数={sum(len(item['models']) for item in result)}") return success(data=result, msg="模型配置列表获取成功") except Exception as e: api_logger.error(f"获取模型配置列表失败: {str(e)}") raise +@router.get("/model_plaza", response_model=ApiResponse) +def get_model_plaza_list( + type: Optional[ModelType] = Query(None, description="模型类型"), + provider: Optional[ModelProvider] = Query(None, description="供应商"), + is_official: Optional[bool] = Query(None, description="是否官方模型"), + is_deprecated: Optional[bool] = Query(False, description="是否弃用"), + search: Optional[str] = Query(None, description="搜索关键词"), + db: Session = Depends(get_db), + current_user: User = Depends(get_current_user) +): + """模型广场查询接口(按供应商分组)""" + + query = model_schema.ModelBaseQuery( + type=type, + provider=provider, + is_official=is_official, + is_deprecated=is_deprecated, + search=search + ) + result = ModelBaseService.get_model_base_list(db=db, query=query, tenant_id=current_user.tenant_id) + return success(data=result, msg="模型广场列表获取成功") + + +@router.get("/model_plaza/{model_base_id}", response_model=ApiResponse) +def get_model_base_by_id( + model_base_id: uuid.UUID, + db: Session = Depends(get_db), + current_user: User = Depends(get_current_user) +): + """获取基础模型详情""" + + result = ModelBaseService.get_model_base_by_id(db=db, model_base_id=model_base_id) + return success(data=model_schema.ModelBase.model_validate(result), msg="基础模型获取成功") + + +@router.post("/model_plaza", response_model=ApiResponse) +def create_model_base( + data: model_schema.ModelBaseCreate, + db: Session = Depends(get_db), + current_user: User = Depends(get_current_user) +): + """创建基础模型""" + + result = ModelBaseService.create_model_base(db=db, data=data) + return success(data=model_schema.ModelBase.model_validate(result), msg="基础模型创建成功") + + +@router.put("/model_plaza/{model_base_id}", response_model=ApiResponse) +def update_model_base( + model_base_id: uuid.UUID, + data: model_schema.ModelBaseUpdate, + db: Session = Depends(get_db), + current_user: User = Depends(get_current_user) +): + """更新基础模型""" + + result = ModelBaseService.update_model_base(db=db, model_base_id=model_base_id, data=data) + return success(data=model_schema.ModelBase.model_validate(result), msg="基础模型更新成功") + + +@router.delete("/model_plaza/{model_base_id}", response_model=ApiResponse) +def delete_model_base( + model_base_id: uuid.UUID, + db: Session = Depends(get_db), + current_user: User = Depends(get_current_user) +): + """删除基础模型""" + + ModelBaseService.delete_model_base(db=db, model_base_id=model_base_id) + return success(msg="基础模型删除成功") + + +@router.post("/model_plaza/{model_base_id}/add", response_model=ApiResponse) +def add_model_from_plaza( + model_base_id: uuid.UUID, + db: Session = Depends(get_db), + current_user: User = Depends(get_current_user) +): + """从模型广场添加模型到模型列表""" + + result = ModelBaseService.add_model_from_plaza(db=db, model_base_id=model_base_id, tenant_id=current_user.tenant_id) + return success(data=model_schema.ModelConfig.model_validate(result), msg="模型添加成功") + + @router.get("/{model_id}", response_model=ApiResponse) def get_model_by_id( model_id: uuid.UUID, @@ -138,6 +281,71 @@ async def create_model( raise +@router.post("/composite", response_model=ApiResponse) +async def create_composite_model( + model_data: model_schema.CompositeModelCreate, + db: Session = Depends(get_db), + current_user: User = Depends(get_current_user) +): + """ + 创建组合模型 + + - 绑定一个或多个现有的 API Key + - 所有 API Key 必须来自非组合模型 + - 所有 API Key 关联的模型类型必须与组合模型类型一致 + """ + api_logger.info(f"创建组合模型请求: {model_data.name}, 用户: {current_user.username}, tenant_id={current_user.tenant_id}") + + try: + result_orm = await ModelConfigService.create_composite_model(db=db, model_data=model_data, tenant_id=current_user.tenant_id) + api_logger.info(f"组合模型创建成功: {result_orm.name} (ID: {result_orm.id})") + + result = model_schema.ModelConfig.model_validate(result_orm) + return success(data=result, msg="组合模型创建成功") + except Exception as e: + api_logger.error(f"创建组合模型失败: {model_data.name} - {str(e)}") + raise + + +@router.put("/composite/{model_id}", response_model=ApiResponse) +async def update_composite_model( + model_id: uuid.UUID, + model_data: model_schema.CompositeModelCreate, + db: Session = Depends(get_db), + current_user: User = Depends(get_current_user) +): + """更新组合模型""" + api_logger.info(f"更新组合模型请求: model_id={model_id}, 用户: {current_user.username}") + + try: + result_orm = await ModelConfigService.update_composite_model(db=db, model_id=model_id, model_data=model_data, tenant_id=current_user.tenant_id) + api_logger.info(f"组合模型更新成功: {result_orm.name} (ID: {model_id})") + + result = model_schema.ModelConfig.model_validate(result_orm) + return success(data=result, msg="组合模型更新成功") + except Exception as e: + api_logger.error(f"更新组合模型失败: model_id={model_id} - {str(e)}") + raise + + +@router.delete("/composite/{model_id}", response_model=ApiResponse) +def delete_composite_model( + model_id: uuid.UUID, + db: Session = Depends(get_db), + current_user: User = Depends(get_current_user) +): + """删除组合模型""" + api_logger.info(f"删除组合模型请求: model_id={model_id}, 用户: {current_user.username}") + + try: + ModelConfigService.delete_model(db=db, model_id=model_id, tenant_id=current_user.tenant_id) + api_logger.info(f"组合模型删除成功: model_id={model_id}") + return success(msg="组合模型删除成功") + except Exception as e: + api_logger.error(f"删除组合模型失败: model_id={model_id} - {str(e)}") + raise + + @router.put("/{model_id}", response_model=ApiResponse) def update_model( model_id: uuid.UUID, @@ -214,6 +422,51 @@ def get_model_api_keys( raise +@router.post("/provider/apikeys", response_model=ApiResponse) +async def create_model_api_key_by_provider( + api_key_data: model_schema.ModelApiKeyCreateByProvider, + db: Session = Depends(get_db), + current_user: User = Depends(get_current_user) +): + """ + 根据供应商为所有匹配的模型创建API Key + """ + api_logger.info(f"创建API Key请求: provider={api_key_data.provider}, 用户: {current_user.username}") + + try: + # 根据tenant_id和provider筛选model_config_id列表 + model_config_ids = api_key_data.model_config_ids + if not model_config_ids: + model_config_ids = ModelConfigRepository.get_model_config_ids_by_provider( + db=db, + tenant_id=current_user.tenant_id, + provider=api_key_data.provider + ) + + if not model_config_ids: + raise BusinessException(f"未找到供应商 {api_key_data.provider} 的模型配置", BizCode.MODEL_NOT_FOUND) + + # 构造schema并调用service + create_data = model_schema.ModelApiKeyCreateByProvider( + provider=api_key_data.provider, + api_key=api_key_data.api_key, + api_base=api_key_data.api_base, + description=api_key_data.description, + config=api_key_data.config, + is_active=api_key_data.is_active, + priority=api_key_data.priority, + model_config_ids=model_config_ids + ) + created_keys = await ModelApiKeyService.create_api_key_by_provider(db=db, data=create_data) + + api_logger.info(f"API Key创建成功: 关联{len(created_keys)}个模型") + result_list = [model_schema.ModelApiKey.model_validate(key) for key in created_keys] + return success(data=result_list, msg=f"成功为 {len(created_keys)} 个模型创建API Key") + except Exception as e: + api_logger.error(f"创建API Key失败: {str(e)}") + raise + + @router.post("/{model_id}/apikeys", response_model=ApiResponse, status_code=status.HTTP_201_CREATED) async def create_model_api_key( model_id: uuid.UUID, @@ -228,11 +481,12 @@ async def create_model_api_key( try: # 设置模型配置ID - api_key_data.model_config_id = model_id + api_key_data.model_config_ids = [model_id] api_logger.debug(f"开始创建模型API Key: {api_key_data.model_name}") - result = await ModelApiKeyService.create_api_key(db=db, api_key_data=api_key_data) - api_logger.info(f"模型API Key创建成功: {result.model_name} (ID: {result.id})") + result_orm = await ModelApiKeyService.create_api_key(db=db, api_key_data=api_key_data) + api_logger.info(f"模型API Key创建成功: {result_orm.model_name} (ID: {result_orm.id})") + result = model_schema.ModelApiKey.model_validate(result_orm) return success(data=result, msg="模型API Key创建成功") except Exception as e: api_logger.error(f"创建模型API Key失败: {api_key_data.model_name} - {str(e)}") @@ -334,5 +588,3 @@ async def validate_model_config( return success(data=model_schema.ModelValidateResponse(**result), msg="验证完成") - - diff --git a/api/app/core/workflow/executor.py b/api/app/core/workflow/executor.py index f0411ae3..b7abf659 100644 --- a/api/app/core/workflow/executor.py +++ b/api/app/core/workflow/executor.py @@ -16,7 +16,6 @@ from app.core.workflow.graph_builder import GraphBuilder, StreamOutputConfig from app.core.workflow.nodes import WorkflowState from app.core.workflow.nodes.base_config import VariableType from app.core.workflow.nodes.enums import NodeType -from app.core.workflow.template_renderer import render_template logger = logging.getLogger(__name__) @@ -157,12 +156,137 @@ class WorkflowExecutor: "error": result.get("error"), } - def _update_end_activate(self, node_id): + def _update_scope_activate(self, scope, status=None): + """ + Update the activation state of all End nodes based on a completed scope (node or variable). + + Iterates over all End nodes in `self.end_outputs` and calls + `update_activate` on each, which may: + - Activate variable segments that depend on the completed node/scope. + - Activate the entire End node output if all control conditions are met. + + If any End node becomes active and `self.activate_end` is not yet set, + this node will be marked as the currently active End node. + + Args: + scope (str): The node ID or scope that has completed execution. + status (str | None): Optional status of the node (used for branch/control nodes). + """ for node in self.end_outputs.keys(): - self.end_outputs[node].update_activate(node_id) + self.end_outputs[node].update_activate(scope, status) if self.end_outputs[node].activate and self.activate_end is None: self.activate_end = node + def _update_stream_output_status(self, activate, data): + """ + Update the stream output state of End nodes based on workflow state updates. + + This method checks which nodes/scopes are activated and propagates + activation to End nodes accordingly. + + Args: + activate (dict): Mapping of node_id -> bool indicating which nodes/scopes are activated. + data (dict): Mapping of node_id -> node runtime data, including outputs. + + Behavior: + For each node in `data`: + 1. If the node is activated (`activate[node_id]` is True), + retrieve its output status from `runtime_vars`. + 2. Call `_update_scope_activate` to propagate the activation + to all relevant End nodes and update `self.activate_end`. + """ + for node_id in data.keys(): + if activate.get(node_id): + node_output_status = ( + data[node_id] + .get('runtime_vars', {}) + .get(node_id) + .get("output") + ) + self._update_scope_activate(node_id, status=node_output_status) + + async def _emit_active_chunks( + self, + node_outputs: dict, + variables: dict, + force=False + ): + """ + Process and yield all currently active output segments for the currently active End node. + + This method handles stream-mode output for an End node by iterating through its output segments + (`OutputContent`). Only segments marked as active (`activate=True`) are processed, unless + `force=True`, which allows all segments to be processed regardless of their activation state. + + Behavior: + 1. Iterates from the current `cursor` position to the end of the outputs list. + 2. For each segment: + - If the segment is literal text (`is_variable=False`), append it directly. + - If the segment is a variable (`is_variable=True`), evaluate it using + `evaluate_expression` with the given `node_outputs` and `variables`, + then transform the result with `_trans_output_string`. + 3. Yield a stream event of type "message" containing the processed chunk. + 4. Move the `cursor` forward after processing each segment. + 5. When all segments have been processed, remove this End node from `end_outputs` + and reset `activate_end` to None. + + Args: + node_outputs (dict): Current runtime node outputs, used for variable evaluation. + variables (dict): Current runtime variables, used for variable evaluation. + force (bool, default=False): If True, process segments even if `activate=False`. + + Yields: + dict: A stream event of type "message" containing the processed chunk. + + Notes: + - Segments that fail evaluation (ValueError) are skipped with a warning logged. + - This method only processes the currently active End node (`self.activate_end`). + - Use `force=True` for final emission regardless of activation state. + """ + + end_info = self.end_outputs[self.activate_end] + + while end_info.cursor < len(end_info.outputs): + final_chunk = '' + current_segment = end_info.outputs[end_info.cursor] + + if not current_segment.activate and not force: + # Stop processing until this segment becomes active + break + + # Literal segment + if not current_segment.is_variable: + final_chunk += current_segment.literal + else: + # Variable segment: evaluate and transform + try: + chunk = evaluate_expression( + current_segment.literal, + variables=variables, + node_outputs=node_outputs + ) + chunk = self._trans_output_string(chunk) + final_chunk += chunk + except ValueError: + # Log failed evaluation but continue streaming + logger.warning(f"[STREAM] Failed to evaluate segment: {current_segment.literal}") + + if final_chunk: + yield { + "event": "message", + "data": { + "chunk": final_chunk + } + } + + # Advance cursor after processing + end_info.cursor += 1 + + # Remove End node from active tracking if all segments have been processed + if end_info.cursor >= len(end_info.outputs): + self.end_outputs.pop(self.activate_end) + self.activate_end = None + @staticmethod def _trans_output_string(content): if isinstance(content, str): @@ -218,14 +342,8 @@ class WorkflowExecutor: result = await graph.ainvoke(initial_state, config=self.checkpoint_config) full_content = '' - for end_info in self.end_outputs.values(): - output_template = "".join([output.literal for output in end_info.outputs]) - full_content += render_template( - output_template, - result.get("variables", {}), - result.get("runtime_vars", {}), - strict=False - ) + for end_id in self.end_outputs.keys(): + full_content += result.get('runtime_vars', {}).get(end_id, {}).get('output', '') result["messages"].extend( [ { @@ -306,7 +424,7 @@ class WorkflowExecutor: try: chunk_count = 0 full_content = '' - + self._update_scope_activate("sys") async for event in graph.astream( initial_state, stream_mode=["updates", "debug", "custom"], # Use updates + debug + custom mode @@ -333,9 +451,12 @@ class WorkflowExecutor: if not end_info or end_info.cursor >= len(end_info.outputs): continue current_output = end_info.outputs[end_info.cursor] - if current_output.is_variable and current_output.depends_on_node(node_id): + if current_output.is_variable and current_output.depends_on_scope(node_id): if data.get("done"): end_info.cursor += 1 + if end_info.cursor >= len(end_info.outputs): + self.end_outputs.pop(self.activate_end) + self.activate_end = None else: full_content += data.get("chunk") yield { @@ -415,91 +536,53 @@ class WorkflowExecutor: elif mode == "updates": # Handle state updates - store final state - for node_id in data.keys(): - self._update_end_activate(node_id) - wait = False - state = graph.get_state(config=self.checkpoint_config) - node_outputs = state.values.get("runtime_vars", {}) - for _ in data.keys(): - node_outputs = node_outputs | data.get(_).get("runtime_vars", {}) + state = graph.get_state(config=self.checkpoint_config).values + node_outputs = state.get("runtime_vars", {}) + variables = state.get("variables", {}) + activate = state.get("activate", {}) + for _, node_data in data.items(): + node_outputs |= node_data.get("runtime_vars", {}) + variables |= node_data.get("variables", {}) + self._update_stream_output_status(activate, data) + wait = False while self.activate_end and not wait: - message = '' - logger.info(self.activate_end) - end_info = self.end_outputs[self.activate_end] - content = end_info.outputs[end_info.cursor] - while content.activate: - if not content.is_variable: - full_content += content.literal - message += content.literal - else: - try: - chunk = evaluate_expression( - content.literal, - variables={}, - node_outputs=node_outputs - ) - chunk = self._trans_output_string(chunk) - message += chunk - full_content += chunk - except ValueError: - pass - end_info.cursor += 1 - if end_info.cursor == len(end_info.outputs): - break - content = end_info.outputs[end_info.cursor] - if end_info.cursor != len(end_info.outputs): + async for msg_event in self._emit_active_chunks( + node_outputs=node_outputs, + variables=variables + ): + full_content += msg_event["data"]['chunk'] + yield msg_event + + if self.activate_end: wait = True else: - self.end_outputs.pop(self.activate_end) - self.activate_end = None - for node_id in data.keys(): - self._update_end_activate(node_id) - if message: - yield { - "event": "message", - "data": { - "chunk": message - } - } + self._update_stream_output_status(activate, data) logger.debug(f"[UPDATES] 收到 state 更新 from {list(data.keys())} " f"- execution_id: {self.execution_id}") result = graph.get_state(self.checkpoint_config).values - while self.activate_end: - message = '' - end_info = self.end_outputs[self.activate_end] - content = end_info.outputs[end_info.cursor] - if not content.is_variable: - message += content.literal - else: - node_outputs = result.get("runtime_vars", {}) - variables = result.get("variables", {}) - try: - chunk = evaluate_expression( - content.literal, + node_outputs = result.get("runtime_vars", {}) + variables = result.get("variables", {}) + self.end_outputs = { + node_id: node_info + for node_id, node_info in self.end_outputs.items() + if node_info.activate + } + + if self.end_outputs or self.activate_end: + while self.activate_end: + async for msg_event in self._emit_active_chunks( + node_outputs=node_outputs, variables=variables, - node_outputs=node_outputs - ) - chunk = self._trans_output_string(chunk) - message += chunk - full_content += chunk - except ValueError: - pass - end_info.cursor += 1 - if end_info.cursor == len(end_info.outputs): - self.end_outputs.pop(self.activate_end) - self.activate_end = None - if self.end_outputs: + force=True + ): + full_content += msg_event["data"]['chunk'] + yield msg_event + + if not self.activate_end and self.end_outputs: self.activate_end = list(self.end_outputs.keys())[0] - if message: - yield { - "event": "message", - "data": { - "chunk": message - } - } # 计算耗时 end_time = datetime.datetime.now() diff --git a/api/app/core/workflow/graph_builder.py b/api/app/core/workflow/graph_builder.py index 9fa89fd2..b1d43e08 100644 --- a/api/app/core/workflow/graph_builder.py +++ b/api/app/core/workflow/graph_builder.py @@ -53,114 +53,110 @@ class OutputContent(BaseModel): ) ) - def depends_on_node(self, node_id: str) -> bool: + def depends_on_scope(self, scope: str) -> bool: """ - Check if this output segment depends on a specific node's variable. - - This method examines the `literal` of the output segment to see if it - contains a variable placeholder referencing the given node in the form: - - {{ node_id.field_name }} - - It uses a regular expression to match the exact node ID, avoiding - false positives from substring matches (e.g., 'node1' should not match 'node10'). + Check if this segment depends on a given scope. Args: - node_id (str): The ID of the node to check for in this segment's variable placeholders. + scope (str): Node ID or special variable prefix (e.g., "sys"). Returns: - bool: - - True if the segment contains a variable referencing the given node. - - False otherwise. - - Example: - literal = "{{node1.name}}" - - depends_on_node("node1") -> True - depends_on_node("node2") -> False - - Usage: - This method is primarily used in stream mode to determine whether - a particular variable output segment should be activated when a - specific upstream node completes execution. + bool: True if this segment references the given scope. """ - variable_pattern = rf"\{{\{{\s*{re.escape(node_id)}\.[a-zA-Z0-9_]+\s*\}}\}}" - pattern = re.compile(variable_pattern) - match = pattern.search(self.literal) - if match: - return True - return False + pattern = rf"\{{\{{\s*{re.escape(scope)}\.[a-zA-Z0-9_]+\s*\}}\}}" + return bool(re.search(pattern, self.literal)) class StreamOutputConfig(BaseModel): """ Streaming output configuration for an End node. - This structure controls: - - whether the End node output is globally active - - which upstream branch nodes are responsible for activation - - how each output segment behaves in streaming mode + This configuration describes how the End node output behaves in streaming mode, + including: + - whether output emission is globally activated + - which upstream branch/control nodes gate the activation + - how each parsed output segment is streamed and activated """ activate: bool = Field( ..., description=( - "Global activation state of the End node output.\n" - "If False, no output should be emitted until all control nodes are resolved." + "Global activation flag for the End node output.\n" + "When False, output segments should not be emitted even if available.\n" + "This flag typically becomes True once required control branch conditions " + "are satisfied." ) ) - control_nodes: list[str] = Field( + control_nodes: dict[str, str] = Field( ..., description=( - "List of upstream branch node IDs that control this End node.\n" - "Each node must signal completion before output becomes active." + "Control branch conditions for this End node output.\n" + "Mapping of `branch_node_id -> expected_branch_label`.\n" + "The End node output becomes globally active when a controlling branch node " + "reports a matching completion status." ) ) outputs: list[OutputContent] = Field( ..., - description="Ordered list of output segments parsed from the output template." + description=( + "Ordered list of output segments parsed from the output template.\n" + "Each segment represents either a literal text block or a variable placeholder " + "that may be activated independently." + ) ) cursor: int = Field( ..., description=( "Streaming cursor index.\n" - "Indicates how many output segments have already been emitted." + "Indicates the next output segment index to be emitted.\n" + "Segments with index < cursor are considered already streamed." ) ) - def update_activate(self, node_id): + def update_activate(self, scope: str, status=None): """ - Update activation state based on an upstream node completion. + Update streaming activation state based on an upstream node or special variable. - This method is typically called when a branch/control node finishes execution. + Args: + scope (str): + Identifier of the completed upstream entity. + - If a control branch node, it should match a key in `control_nodes`. + - If a variable placeholder (e.g., "sys.xxx"), it may appear in output segments. + status (optional): + Completion status of the control branch node. + Required when `scope` refers to a control node. Behavior: - 1. If the node is a control node: - - Remove it from `control_nodes` - - If all control nodes are resolved, activate the entire output + 1. Control branch nodes: + - If `scope` matches a key in `control_nodes` and `status` matches the expected + branch label, the End node output becomes globally active (`activate = True`). - 2. Activate variable output segments that depend on this node: - - If an output segment is a variable - - And its literal references the completed node_id - - Mark that segment as active + 2. Variable output segments: + - For each segment that is a variable (`is_variable=True`): + - If the segment literal references `scope`, mark the segment as active. + - This applies both to regular node variables (e.g., "node_id.field") + and special system variables (e.g., "sys.xxx"). + + Notes: + - This method does not emit output or advance the streaming cursor. + - It only updates activation flags based on upstream events or special variables. """ # Case 1: resolve control branch dependency - if node_id in self.control_nodes: - self.control_nodes.remove(node_id) - - # All branch constraints resolved → enable output - if not self.control_nodes: + if scope in self.control_nodes.keys(): + if status is None: + raise RuntimeError("[Stream Output] Control node activation status not provided") + if status == self.control_nodes[scope]: self.activate = True # Case 2: activate variable segments related to this node for i in range(len(self.outputs)): if ( self.outputs[i].is_variable - and self.outputs[i].depends_on_node(node_id) + and self.outputs[i].depends_on_scope(scope) ): self.outputs[i].activate = True @@ -184,11 +180,11 @@ class GraphBuilder: self._find_upstream_branch_node = lru_cache( maxsize=len(self.nodes) * 2 )(self._find_upstream_branch_node) - self._analyze_end_node_output() self.graph = StateGraph(WorkflowState) self.add_nodes() self.add_edges() + self._analyze_end_node_output() # EDGES MUST BE ADDED AFTER NODES ARE ADDED. @property @@ -216,30 +212,53 @@ class GraphBuilder: except KeyError: raise RuntimeError(f"Node not found: Id={node_id}") - def _find_upstream_branch_node(self, target_node: str) -> tuple[bool, tuple[str]]: - """Find upstream branch nodes for a given target node in the workflow graph. + def _find_upstream_branch_node(self, target_node: str) -> tuple[bool, tuple[tuple[str, str]]]: + """ + Recursively find all upstream branch (control) nodes that influence the execution + of the given target node. - This method identifies all upstream control (branch) nodes that can affect - the execution of `target_node`. If `target_node` is reachable from a start - node (i.e., a node with no upstream nodes), the method returns an empty tuple. + This method walks upstream along the workflow graph starting from `target_node`. + It distinguishes between: + - branch nodes (node types listed in `BRANCH_NODES`) + - non-branch nodes (ordinary processing nodes) - The function distinguishes between branch nodes (defined in `BRANCH_NODES`) - and non-branch nodes, recursively traversing upstream through non-branch - nodes. If any non-branch upstream path does not lead to a branch node, - the result will indicate that no valid upstream branch node exists. + Traversal rules: + 1. For each immediate upstream node: + - If it is a branch node, it is recorded as an affecting control node. + - If it is a non-branch node, the traversal continues recursively upstream. + 2. If ANY upstream path reaches a START / CYCLE_START node without encountering + a branch node, the traversal is considered invalid: + - `has_branch` will be False + - no branch nodes are returned. + 3. Only when ALL upstream non-branch paths eventually lead to at least one + branch node will `has_branch` be True. + + Special case: + - If `target_node` has no upstream nodes AND its type is START or CYCLE_START, + it is considered directly reachable from the workflow entry, and therefore + has no controlling branch nodes. Args: - target_node (str): The identifier of the target node. + target_node (str): + The identifier of the node whose upstream control branches + are to be resolved. Returns: - tuple[bool, tuple[str]]: - - has_branch (bool): True if all upstream non-branch paths lead to at least - one branch node; False if any path reaches a start node without a branch. - - branch_nodes (tuple[str]): A deduplicated tuple of upstream branch node IDs - affecting `target_node`. Returns an empty tuple if `has_branch` is False. + tuple[bool, tuple[tuple[str, str]]]: + - has_branch (bool): + True if every upstream path from `target_node` encounters + at least one branch node. + False if any path reaches a start node without a branch. + - branch_nodes (tuple[tuple[str, str]]): + A deduplicated tuple of `(branch_node_id, branch_label)` pairs + representing all branch nodes that can influence `target_node`. + Returns an empty tuple if `has_branch` is False. """ source_nodes = [ - edge.get("source") + { + "id": edge.get("source"), + "branch": edge.get("label") + } for edge in self.edges if edge.get("target") == target_node ] @@ -249,11 +268,13 @@ class GraphBuilder: branch_nodes = [] non_branch_nodes = [] - for node_id in source_nodes: - if self.get_node_type(node_id) in BRANCH_NODES: - branch_nodes.append(node_id) + for node_info in source_nodes: + if self.get_node_type(node_info["id"]) in BRANCH_NODES: + branch_nodes.append( + (node_info["id"], node_info["branch"]) + ) else: - non_branch_nodes.append(node_id) + non_branch_nodes.append(node_info["id"]) has_branch = True for node_id in non_branch_nodes: @@ -334,7 +355,7 @@ class GraphBuilder: activate=not has_branch, # Branch nodes that control activation of this End node - control_nodes=list(control_nodes), + control_nodes=dict(control_nodes), # Convert output segments into OutputContent objects outputs=list( @@ -362,7 +383,7 @@ class GraphBuilder: else: self.end_node_map[end_node_id] = StreamOutputConfig( activate=True, - control_nodes=[], + control_nodes={}, outputs=list( [ OutputContent( diff --git a/api/app/core/workflow/nodes/memory/config.py b/api/app/core/workflow/nodes/memory/config.py index 57ee6dc2..31881e24 100644 --- a/api/app/core/workflow/nodes/memory/config.py +++ b/api/app/core/workflow/nodes/memory/config.py @@ -25,6 +25,6 @@ class MemoryWriteNodeConfig(BaseNodeConfig): ... ) - config_id: UUID = Field( + config_id: UUID | int = Field( ... ) diff --git a/api/app/core/workflow/nodes/memory/node.py b/api/app/core/workflow/nodes/memory/node.py index f71c70ee..13860bec 100644 --- a/api/app/core/workflow/nodes/memory/node.py +++ b/api/app/core/workflow/nodes/memory/node.py @@ -36,9 +36,10 @@ class MemoryReadNode(BaseNode): class MemoryWriteNode(BaseNode): def __init__(self, node_config: dict[str, Any], workflow_config: dict[str, Any]): super().__init__(node_config, workflow_config) - self.typed_config = MemoryWriteNodeConfig(**self.config) + self.typed_config: MemoryWriteNodeConfig | None = None async def execute(self, state: WorkflowState) -> Any: + self.typed_config = MemoryWriteNodeConfig(**self.config) end_user_id = self.get_variable("sys.user_id", state) if not end_user_id: diff --git a/api/app/models/__init__.py b/api/app/models/__init__.py index e069b40d..a429dd8e 100644 --- a/api/app/models/__init__.py +++ b/api/app/models/__init__.py @@ -6,7 +6,7 @@ from .document_model import Document from .file_model import File from .file_metadata_model import FileMetadata from .generic_file_model import GenericFile -from .models_model import ModelConfig, ModelProvider, ModelType, ModelApiKey +from .models_model import ModelConfig, ModelProvider, ModelType, ModelApiKey, ModelBase, LoadBalanceStrategy from .memory_short_model import ShortTermMemory, LongTermMemory from .knowledgeshare_model import KnowledgeShare from .app_model import App @@ -79,4 +79,6 @@ __all__ = [ "AuthType", "ExecutionStatus", "MemoryPerceptualModel", + "ModelBase", + "LoadBalanceStrategy" ] diff --git a/api/app/models/models_model.py b/api/app/models/models_model.py index 2e60ef1c..a8918c7c 100644 --- a/api/app/models/models_model.py +++ b/api/app/models/models_model.py @@ -1,19 +1,31 @@ import datetime import uuid from enum import StrEnum -from typing import Optional, List -from sqlalchemy import Column, String, Boolean, DateTime, Text, ForeignKey, Enum as SQLEnum + +from sqlalchemy import Column, String, Boolean, DateTime, Text, ForeignKey, Enum as SQLEnum, UniqueConstraint, Integer, ARRAY, Table from sqlalchemy.dialects.postgresql import UUID, JSON from sqlalchemy.orm import relationship from app.db import Base +class BaseModel(Base): + """基础模型(抽象类,提取公共字段)""" + __abstract__ = True # 标记为抽象类,不生成表 + id = Column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4, index=True) + created_at = Column(DateTime, default=datetime.datetime.now, comment="创建时间") + updated_at = Column(DateTime, default=datetime.datetime.now, onupdate=datetime.datetime.now, comment="更新时间") + is_active = Column(Boolean, default=True, nullable=False, comment="是否激活") + + class ModelType(StrEnum): """模型类型枚举""" LLM = "llm" CHAT = "chat" EMBEDDING = "embedding" RERANK = "rerank" + # IMAGE = "image" + # AUDIO = "audio" + # VISION = "vision" class ModelProvider(StrEnum): @@ -30,16 +42,37 @@ class ModelProvider(StrEnum): XINFERENCE = "xinference" GPUSTACK = "gpustack" BEDROCK = "bedrock" + COMPOSITE = "composite" -class ModelConfig(Base): +class LoadBalanceStrategy(StrEnum): + """API Key负载均衡策略枚举""" + ROUND_ROBIN = "round_robin" # 轮询 + WEIGHTED_ROUND_ROBIN = "weighted_round_robin" # 加权轮询 + RANDOM = "random" # 随机 + + +# 多对多关联表 +model_config_api_key_association = Table( + 'model_config_api_key_association', + Base.metadata, + Column('model_config_id', UUID(as_uuid=True), ForeignKey('model_configs.id'), primary_key=True), + Column('api_key_id', UUID(as_uuid=True), ForeignKey('model_api_keys.id'), primary_key=True), + Column('created_at', DateTime, default=datetime.datetime.now) +) + + +class ModelConfig(BaseModel): """模型配置表""" __tablename__ = "model_configs" - id = Column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4, index=True) + model_id = Column(UUID(as_uuid=True), ForeignKey("model_bases.id"), nullable=True, index=True, comment="基础模型ID") tenant_id = Column(UUID(as_uuid=True), ForeignKey("tenants.id"), nullable=False, index=True, comment="租户ID") + logo = Column(String(255), nullable=True, comment="模型logo图片URL") name = Column(String, nullable=False, comment="模型显示名称") + provider = Column(String, nullable=False, comment="供应商", server_default=ModelProvider.COMPOSITE) type = Column(String, nullable=False, index=True, comment="模型类型") + is_composite = Column(Boolean, default=False, server_default="true", nullable=False, comment="是否为组合模型") description = Column(String, comment="模型描述") # 模型配置参数 @@ -56,29 +89,28 @@ class ModelConfig(Base): # context_length = Column(String, comment="上下文长度") # 状态管理 - is_active = Column(Boolean, default=True, nullable=False, comment="是否激活") is_public = Column(Boolean, default=False, nullable=False, comment="是否公开") - - # 时间戳 - created_at = Column(DateTime, default=datetime.datetime.now, comment="创建时间") - updated_at = Column(DateTime, default=datetime.datetime.now, onupdate=datetime.datetime.now, comment="更新时间") + load_balance_strategy = Column(String, nullable=True, comment="负载均衡策略") # 关联关系 - api_keys = relationship("ModelApiKey", back_populates="model_config", cascade="all, delete-orphan") + model_base = relationship("ModelBase", back_populates="configs") + api_keys = relationship( + "ModelApiKey", + secondary=model_config_api_key_association, + back_populates="model_configs" + ) def __repr__(self): return f"" -class ModelApiKey(Base): +class ModelApiKey(BaseModel): """模型API密钥表""" __tablename__ = "model_api_keys" - - id = Column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4, index=True) - model_config_id = Column(UUID(as_uuid=True), ForeignKey("model_configs.id"), nullable=False, comment="模型配置ID") # API Key 信息 model_name = Column(String, nullable=False, comment="模型实际名称") + description = Column(String, comment="备注") provider = Column(String, nullable=False, comment="API Key提供商") api_key = Column(String, nullable=False, comment="API密钥") api_base = Column(String, comment="API基础URL") @@ -91,15 +123,41 @@ class ModelApiKey(Base): last_used_at = Column(DateTime, comment="最后使用时间") # 状态管理 - is_active = Column(Boolean, default=True, nullable=False, comment="是否激活") priority = Column(String, default="1", comment="优先级") - - # 时间戳 - created_at = Column(DateTime, default=datetime.datetime.now, comment="创建时间") - updated_at = Column(DateTime, default=datetime.datetime.now, onupdate=datetime.datetime.now, comment="更新时间") - + # 关联关系 - model_config = relationship("ModelConfig", back_populates="api_keys") + model_configs = relationship( + "ModelConfig", + secondary=model_config_api_key_association, + back_populates="api_keys" + ) + def __repr__(self): - return f"" + return f"" + + +class ModelBase(Base): + """基础模型信息表(模型广场)""" + __tablename__ = "model_bases" + + id = Column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4, index=True) + logo = Column(String(255), nullable=True, comment="模型logo图片URL") + name = Column(String, nullable=False, comment="模型唯一标识(如gpt-3.5-turbo)") + type = Column(String, nullable=False, index=True, comment="模型类型") + provider = Column(String, nullable=False, index=True) + description = Column(Text, comment="模型描述") + is_deprecated = Column(Boolean, default=False, nullable=False, comment="是否弃用") + is_official = Column(Boolean, default=True, comment="是否供应商官方模型(区分自定义)") + tags = Column(ARRAY(String), default=list, nullable=False, comment="模型标签(如['聊天', '创作'])") + add_count = Column(Integer, default=0, nullable=False, comment="模型被用户添加的次数") + + # 关联关系 + configs = relationship("ModelConfig", back_populates="model_base", cascade="all, delete-orphan") + + __table_args__ = ( + UniqueConstraint("name", "provider", name="uk_model_name_provider"), + ) + + def __repr__(self): + return f"" \ No newline at end of file diff --git a/api/app/repositories/model_repository.py b/api/app/repositories/model_repository.py index 1fe29d66..8e4632cc 100644 --- a/api/app/repositories/model_repository.py +++ b/api/app/repositories/model_repository.py @@ -1,12 +1,12 @@ -from sqlalchemy.orm import Session, joinedload -from sqlalchemy import and_, or_, func, desc +from sqlalchemy.orm import Session, joinedload, selectinload +from sqlalchemy import and_, or_, func, desc, select from typing import List, Optional, Dict, Any, Tuple import uuid -from app.models.models_model import ModelConfig, ModelApiKey, ModelType +from app.models.models_model import ModelConfig, ModelApiKey, ModelType, ModelBase, model_config_api_key_association from app.schemas.model_schema import ( ModelConfigUpdate, ModelApiKeyCreate, ModelApiKeyUpdate, - ModelConfigQuery + ModelConfigQuery, ModelConfigQueryNew ) from app.core.logging_config import get_db_logger @@ -107,6 +107,80 @@ class ModelConfigRepository: def get_list(db: Session, query: ModelConfigQuery, tenant_id: uuid.UUID | None = None) -> Tuple[List[ModelConfig], int]: """获取模型配置列表""" db_logger.debug(f"查询模型配置列表: {query.dict()}, tenant_id={tenant_id}") + + try: + # 构建查询条件 + filters = [] + + # 添加租户过滤(查询本租户的模型或公开模型) + if tenant_id: + filters.append( + or_( + ModelConfig.tenant_id == tenant_id, + ModelConfig.is_public + ) + ) + + # 支持多个 type 值(使用 IN 查询) + # 兼容 chat 和 llm 类型:如果查询包含其中一个,则同时匹配两者 + if query.type: + type_values = list(query.type) + # 如果包含 chat 或 llm,则同时包含两者 + if ModelType.CHAT in type_values or ModelType.LLM in type_values: + if ModelType.CHAT not in type_values: + type_values.append(ModelType.CHAT) + if ModelType.LLM not in type_values: + type_values.append(ModelType.LLM) + filters.append(ModelConfig.type.in_(type_values)) + + if query.is_active is not None: + filters.append(ModelConfig.is_active == query.is_active) + + if query.is_public is not None: + filters.append(ModelConfig.is_public == query.is_public) + + if query.search: + # 搜索逻辑需要join ModelApiKey表来搜索model_name + search_filter = or_( + ModelConfig.name.ilike(f"%{query.search}%"), + # ModelConfig.description.ilike(f"%{query.search}%") + ) + filters.append(search_filter) + + # 构建基础查询 + base_query = db.query(ModelConfig).options( + joinedload(ModelConfig.api_keys) + ) + + # 如果需要按provider筛选,需要join ModelApiKey表 + if query.provider: + base_query = base_query.join(ModelApiKey).filter( + ModelApiKey.provider == query.provider + ).distinct() + + if filters: + base_query = base_query.filter(and_(*filters)) + + # 获取总数 + total = base_query.count() + + # 分页查询 + models = base_query.order_by(desc(ModelConfig.updated_at)).offset( + (query.page - 1) * query.pagesize + ).limit(query.pagesize).all() + + db_logger.debug(f"模型配置列表查询成功: 总数={total}, 当前页={len(models)}, type筛选={query.type}") + return models, total + + except Exception as e: + db_logger.error(f"查询模型配置列表失败: {str(e)}") + raise + + @staticmethod + def get_list_new(db: Session, query: ModelConfigQueryNew, tenant_id: uuid.UUID | None = None) -> tuple[ + dict[str, list[ModelConfig]], Any]: + """获取模型配置列表""" + db_logger.debug(f"查询模型配置列表: {query.model_dump()}, tenant_id={tenant_id}") try: # 构建查询条件 @@ -138,13 +212,15 @@ class ModelConfigRepository: if query.is_public is not None: filters.append(ModelConfig.is_public == query.is_public) + + if query.is_composite is not None: + filters.append(ModelConfig.is_composite == query.is_composite) + + if query.provider: + filters.append(ModelConfig.provider == query.provider) if query.search: - # 搜索逻辑需要join ModelApiKey表来搜索model_name - search_filter = or_( - ModelConfig.name.ilike(f"%{query.search}%"), - # ModelConfig.description.ilike(f"%{query.search}%") - ) + search_filter = ModelConfig.name.ilike(f"%{query.search}%") filters.append(search_filter) # 构建基础查询 @@ -152,28 +228,30 @@ class ModelConfigRepository: joinedload(ModelConfig.api_keys) ) - # 如果需要按provider筛选,需要join ModelApiKey表 - if query.provider: - base_query = base_query.join(ModelApiKey).filter( - ModelApiKey.provider == query.provider - ).distinct() - if filters: base_query = base_query.filter(and_(*filters)) # 获取总数 total = base_query.count() + + query_results = base_query.order_by(desc(ModelConfig.updated_at)).all() + + provider_groups: Dict[str, List[ModelConfig]] = {} + for model_config in query_results: + provider = model_config.provider + if provider not in provider_groups: + provider_groups[provider] = [] + provider_groups[provider].append(model_config) - # 分页查询 - models = base_query.order_by(desc(ModelConfig.updated_at)).offset( - (query.page - 1) * query.pagesize - ).limit(query.pagesize).all() - - db_logger.debug(f"模型配置列表查询成功: 总数={total}, 当前页={len(models)}, type筛选={query.type}") - return models, total + db_logger.debug( + f"模型配置列表查询成功: 总数={total}, " + f"分组数={len(provider_groups)}, " + f"各分组模型数={[len(v) for v in provider_groups.values()]}, " + f"type筛选={query.type}") + return provider_groups, total except Exception as e: - db_logger.error(f"查询模型配置列表失败: {str(e)}") + db_logger.error(f"查询模型配置列表失败(按provider分组/无分页): {str(e)}") raise @staticmethod @@ -241,7 +319,7 @@ class ModelConfigRepository: return None # 更新字段 - update_data = model_data.dict(exclude_unset=True) + update_data = model_data.model_dump(exclude_unset=True) for field, value in update_data.items(): setattr(db_model, field, value) @@ -303,8 +381,18 @@ class ModelConfigRepository: # 按提供商统计 - 现在从ModelApiKey表获取 provider_stats = {} provider_results = db.query( - ModelApiKey.provider, func.count(func.distinct(ModelApiKey.model_config_id)) - ).group_by(ModelApiKey.provider).all() + # 保留 provider 字段 + ModelApiKey.provider, + # 统计中间表中 唯一的 model_config_id 数量(替换原 ModelApiKey.model_config_id) + func.count(func.distinct(model_config_api_key_association.c.model_config_id)) + ).join( + # 联表:ModelApiKey <-> 中间表(多对多关联) + model_config_api_key_association, + ModelApiKey.id == model_config_api_key_association.c.api_key_id + ).group_by( + # 按 provider 分组(保留原有逻辑) + ModelApiKey.provider + ).all() for provider, count in provider_results: provider_stats[provider.value] = count @@ -325,6 +413,37 @@ class ModelConfigRepository: db_logger.error(f"获取模型统计信息失败: {str(e)}") raise + @staticmethod + def get_model_config_ids_by_provider( + db: Session, + tenant_id: uuid.UUID, + provider: Any + ) -> List[uuid.UUID]: + """根据tenant_id和provider获取model_config_id列表""" + db_logger.debug(f"查询model_config_id列表: tenant_id={tenant_id}, provider={provider}") + + try: + # 查询ModelConfig关联的ModelApiKey,筛选出匹配的model_config_id + model_config_ids = db.query(ModelConfig.id).join( + ModelBase, ModelConfig.model_id == ModelBase.id + ).filter( + and_( + or_( + ModelConfig.tenant_id == tenant_id, + ModelConfig.is_public + ), + ModelBase.provider == provider, + ~ModelConfig.is_composite + ) + ).distinct().all() + + db_logger.debug(f"查询成功: 数量={len(model_config_ids)}") + return [row[0] for row in model_config_ids] + + except Exception as e: + db_logger.error(f"查询model_config_id列表失败: {str(e)}") + raise + class ModelApiKeyRepository: """模型API Key Repository""" @@ -349,7 +468,14 @@ class ModelApiKeyRepository: db_logger.debug(f"根据模型配置ID查询API Key: model_config_id={model_config_id}") try: - query = db.query(ModelApiKey).filter(ModelApiKey.model_config_id == model_config_id) + from app.models.models_model import ModelConfig, model_config_api_key_association + + query = db.query(ModelApiKey).join( + model_config_api_key_association, + ModelApiKey.id == model_config_api_key_association.c.api_key_id + ).filter( + model_config_api_key_association.c.model_config_id == model_config_id + ) if is_active: query = query.filter(ModelApiKey.is_active) @@ -368,8 +494,20 @@ class ModelApiKeyRepository: db_logger.debug(f"创建API Key: {api_key_data.provider}") try: - db_api_key = ModelApiKey(**api_key_data.dict()) + from app.models.models_model import ModelConfig + + # 创建API Key,不包含model_config_ids + api_key_dict = api_key_data.model_dump(exclude={"model_config_ids"}) + db_api_key = ModelApiKey(**api_key_dict) db.add(db_api_key) + db.flush() # 获取生成的ID + + # 关联ModelConfig + if api_key_data.model_config_ids: + for model_config_id in api_key_data.model_config_ids: + model_config = db.query(ModelConfig).filter(ModelConfig.id == model_config_id).first() + if model_config: + db_api_key.model_configs.append(model_config) db_logger.info(f"API Key已添加到会话: {db_api_key.provider}") return db_api_key @@ -391,7 +529,7 @@ class ModelApiKeyRepository: return None # 更新字段 - update_data = api_key_data.dict(exclude_unset=True) + update_data = api_key_data.model_dump(exclude_unset=True) for field, value in update_data.items(): setattr(db_api_key, field, value) @@ -451,4 +589,74 @@ class ModelApiKeyRepository: except Exception as e: db.rollback() db_logger.error(f"更新API Key使用统计失败: api_key_id={api_key_id} - {str(e)}") - raise \ No newline at end of file + raise + + +class ModelBaseRepository: + """基础模型Repository""" + + @staticmethod + def get_by_id(db: Session, model_base_id: uuid.UUID) -> Optional['ModelBase']: + return db.query(ModelBase).filter(ModelBase.id == model_base_id).first() + + @staticmethod + def get_list(db: Session, query: 'ModelBaseQuery') -> List['ModelBase']: + + filters = [] + if query.type: + filters.append(ModelBase.type == query.type) + if query.provider: + filters.append(ModelBase.provider == query.provider) + if query.is_official is not None: + filters.append(ModelBase.is_official == query.is_official) + if query.is_deprecated is not None: + filters.append(ModelBase.is_deprecated == query.is_deprecated) + if query.search: + filters.append(or_( + ModelBase.name.ilike(f"%{query.search}%"), + # ModelBase.description.ilike(f"%{query.search}%") + )) + + q = db.query(ModelBase) + if filters: + q = q.filter(and_(*filters)) + + return q.order_by(ModelBase.add_count.desc()).all() + + @staticmethod + def create(db: Session, data: dict) -> 'ModelBase': + model_base = ModelBase(**data) + db.add(model_base) + return model_base + + @staticmethod + def update(db: Session, model_base_id: uuid.UUID, data: dict) -> Optional['ModelBase']: + model_base = db.query(ModelBase).filter(ModelBase.id == model_base_id).first() + if not model_base: + return None + for key, value in data.items(): + setattr(model_base, key, value) + return model_base + + @staticmethod + def delete(db: Session, model_base_id: uuid.UUID) -> bool: + model_base = db.query(ModelBase).filter(ModelBase.id == model_base_id).first() + if not model_base: + return False + db.delete(model_base) + return True + + @staticmethod + def increment_add_count(db: Session, model_base_id: uuid.UUID) -> bool: + model_base = db.query(ModelBase).filter(ModelBase.id == model_base_id).first() + if not model_base: + return False + model_base.add_count += 1 + return True + + @staticmethod + def check_added_by_tenant(db: Session, model_base_id: uuid.UUID, tenant_id: uuid.UUID) -> bool: + return db.query(ModelConfig).filter( + ModelConfig.model_id == model_base_id, + ModelConfig.tenant_id == tenant_id + ).first() is not None diff --git a/api/app/schemas/model_schema.py b/api/app/schemas/model_schema.py index 68f15115..ce1b36bb 100644 --- a/api/app/schemas/model_schema.py +++ b/api/app/schemas/model_schema.py @@ -4,6 +4,10 @@ import datetime import uuid from app.models.models_model import ModelProvider, ModelType +from app.core.logging_config import get_business_logger + +schema_logger = get_business_logger() + @@ -12,7 +16,9 @@ class ModelConfigBase(BaseModel): """模型配置基础Schema""" name: str = Field(..., description="模型显示名称", max_length=255) type: ModelType = Field(..., description="模型类型") + logo: Optional[str] = Field(None, description="模型logo图片URL", max_length=255) description: Optional[str] = Field(None, description="模型描述") + provider: str = Field(..., description="供应商") config: Optional[Dict[str, Any]] = Field({}, description="模型配置参数") is_active: bool = Field(True, description="是否激活") is_public: bool = Field(False, description="是否公开") @@ -21,6 +27,7 @@ class ModelConfigBase(BaseModel): class ApiKeyCreateNested(BaseModel): """用于在创建模型时内嵌创建API Key的Schema""" model_name: str = Field(..., description="模型实际名称", max_length=255) + description: Optional[str] = Field(None, description="备注") provider: ModelProvider = Field(..., description="API Key提供商") api_key: str = Field(..., description="API密钥", max_length=500) api_base: Optional[str] = Field(None, description="API基础URL", max_length=500) @@ -30,10 +37,22 @@ class ApiKeyCreateNested(BaseModel): class ModelConfigCreate(ModelConfigBase): """创建模型配置Schema""" - api_keys: Optional[ApiKeyCreateNested] = Field(None, description="同时创建的API Key配置") + api_keys: Optional[List[ApiKeyCreateNested]] = Field(None, description="同时创建的API Key配置") skip_validation: Optional[bool] = Field(False, description="是否跳过配置验证") +class CompositeModelCreate(BaseModel): + """创建组合模型Schema""" + name: str = Field(..., description="组合模型名称", max_length=255) + type: ModelType = Field(..., description="模型类型") + logo: Optional[str] = Field(None, description="模型logo图片URL", max_length=255) + description: Optional[str] = Field(None, description="模型描述") + config: Optional[Dict[str, Any]] = Field({}, description="模型配置参数") + is_active: bool = Field(True, description="是否激活") + is_public: bool = Field(False, description="是否公开") + api_key_ids: List[uuid.UUID] = Field(..., description="绑定的API Key ID列表") + + class ModelConfigUpdate(BaseModel): """更新模型配置Schema""" name: Optional[str] = Field(None, description="模型显示名称", max_length=255) @@ -53,22 +72,48 @@ class ModelConfig(ModelConfigBase): updated_at: datetime.datetime api_keys: List["ModelApiKey"] = [] + @field_validator("api_keys", mode="after") + @classmethod + def filter_active_api_keys(cls, api_keys: List["ModelApiKey"]) -> List["ModelApiKey"]: + return [key for key in api_keys if key.is_active] + + @field_serializer("created_at", when_used="json") + def _serialize_created_at(self, dt: datetime.datetime | None): + return int(dt.timestamp() * 1000) if dt else None + + @field_serializer("updated_at", when_used="json") + def _serialize_updated_at(self, dt: datetime.datetime): + return int(dt.timestamp() * 1000) if dt else None + # ModelApiKey Schemas -class ModelApiKeyBase(BaseModel): - """API Key基础Schema""" - model_name: str = Field(..., description="模型实际名称", max_length=255) +class ModelApiKeyCreateByProvider(BaseModel): + """基于供应商创建API Key Schema""" provider: ModelProvider = Field(..., description="API Key提供商") api_key: str = Field(..., description="API密钥", max_length=500) api_base: Optional[str] = Field(None, description="API基础URL", max_length=500) - config: Optional[Dict[str, Any]] = Field(None, description="API Key特定配置") + description: Optional[str] = Field(None, description="备注") + config: Optional[Dict[str, Any]] = Field({}, description="API Key特定配置") + is_active: bool = Field(True, description="是否激活") + priority: str = Field("1", description="优先级", max_length=10) + model_config_ids: Optional[List[uuid.UUID]] = Field(None, description="关联的模型配置ID列表") + + +class ModelApiKeyBase(BaseModel): + """API Key基础Schema""" + model_name: str = Field(..., description="模型实际名称", max_length=255) + description: Optional[str] = Field(None, description="备注") + provider: ModelProvider = Field(..., description="API Key提供商") + api_key: str = Field(..., description="API密钥", max_length=500) + api_base: Optional[str] = Field(None, description="API基础URL", max_length=500) + config: Optional[Dict[str, Any]] = Field({}, description="API Key特定配置") is_active: bool = Field(True, description="是否激活") priority: str = Field("1", description="优先级", max_length=10) class ModelApiKeyCreate(ModelApiKeyBase): """创建API Key Schema""" - model_config_id: uuid.UUID = Field(..., description="模型配置ID") + model_config_ids: Optional[List[uuid.UUID]] = Field(None, description="关联的模型配置ID列表") class ModelApiKeyUpdate(BaseModel): @@ -85,23 +130,54 @@ class ModelApiKeyUpdate(BaseModel): class ModelApiKey(ModelApiKeyBase): """API Key Schema""" id: uuid.UUID - model_config_id: uuid.UUID usage_count: str last_used_at: Optional[datetime.datetime] created_at: datetime.datetime updated_at: datetime.datetime + model_configs: Any = Field(default=None, exclude=True) + model_config_ids: List[uuid.UUID] = Field(default_factory=list, description="关联的模型配置ID列表") - @field_validator("config", mode="before") - @classmethod - def parse_config(cls, v): - """处理 config 字段,如果是字符串则解析为字典""" - if isinstance(v, str): - import json + def model_post_init(self, __context: Any) -> None: + """实例化后强制提取 model_configs 的ID到 model_config_ids""" + # 如果手动传入了 model_config_ids,不覆盖 + if self.model_config_ids and len(self.model_config_ids) > 0: + return + + # 从 model_configs 提取ID(只提取与 model_name 相同的非组合模型) + if self.model_configs is not None: try: - return json.loads(v) - except json.JSONDecodeError: - return {} - return v + # 情况1:ORM 对象列表(SQLAlchemy 关联) + if hasattr(self.model_configs, '__iter__') and not isinstance(self.model_configs, dict): + self.model_config_ids = [ + mc.id for mc in self.model_configs + if hasattr(mc, 'id') + and not getattr(mc, 'is_composite', False) + and getattr(mc, 'name', None) == self.model_name + ] + # 情况2:字典列表 + elif isinstance(self.model_configs, list): + self.model_config_ids = [ + mc['id'] if isinstance(mc, dict) else mc.id + for mc in self.model_configs + if ((isinstance(mc, dict) + and 'id' in mc + and not mc.get('is_composite', False) + and mc.get('name') == self.model_name) or + (hasattr(mc, 'id') + and not getattr(mc, 'is_composite', False) + and getattr(mc, 'name', None) == self.model_name)) + ] + except Exception as e: + schema_logger.warning(f"提取 model_config_ids 失败:{e}") + self.model_config_ids = [] + + model_config = ConfigDict( + from_attributes=True, # 支持从 ORM 解析 + arbitrary_types_allowed=True, # 允许任意类型(ORM 对象) + populate_by_name=True, # 按属性名匹配字段 + validate_assignment=True # 确保赋值触发校验 + ) + @field_serializer("created_at", when_used="json") def _serialize_created_at(self, dt: datetime.datetime): @@ -110,15 +186,12 @@ class ModelApiKey(ModelApiKeyBase): @field_serializer("updated_at", when_used="json") def _serialize_updated_at(self, dt: datetime.datetime): return int(dt.timestamp() * 1000) if dt else None - - model_config = ConfigDict(from_attributes=True) @field_serializer("last_used_at", when_used="json") def _serialize_last_used_at(self, dt: datetime.datetime): return int(dt.timestamp() * 1000) if dt else None -# 查询和响应Schemas class ModelConfigQuery(BaseModel): """模型配置查询Schema""" type: Optional[List[ModelType]] = Field(None, description="模型类型筛选(支持多个)") @@ -129,6 +202,17 @@ class ModelConfigQuery(BaseModel): page: int = Field(1, description="页码", ge=1) pagesize: int = Field(10, description="每页数量", ge=1, le=100) + +# 查询和响应Schemas +class ModelConfigQueryNew(BaseModel): + """模型配置查询Schema""" + type: Optional[List[ModelType]] = Field(None, description="模型类型筛选(支持多个)") + provider: Optional[ModelProvider] = Field(None, description="提供商筛选(通过API Key)") + is_active: Optional[bool] = Field(None, description="激活状态筛选") + is_public: Optional[bool] = Field(None, description="公开状态筛选") + is_composite: Optional[bool] = Field(None, description="组合模型筛选") + search: Optional[str] = Field(None, description="搜索关键词", max_length=255) + class ModelMarketplace(BaseModel): """模型广场响应Schema""" llm_models: List[ModelConfig] = [] @@ -171,4 +255,53 @@ class ModelValidateResponse(BaseModel): # 更新前向引用 -ModelConfig.model_rebuild() \ No newline at end of file +ModelConfig.model_rebuild() + + +# ModelBase Schemas +class ModelBaseCreate(BaseModel): + """创建基础模型Schema""" + name: str = Field(..., description="模型唯一标识", max_length=255) + type: ModelType = Field(..., description="模型类型") + provider: ModelProvider = Field(..., description="提供商") + logo: Optional[str] = Field(None, description="模型logo图片URL", max_length=255) + description: Optional[str] = Field(None, description="模型描述") + is_official: bool = Field(True, description="是否供应商官方模型") + tags: List[str] = Field(default_factory=list, description="模型标签") + + +class ModelBaseUpdate(BaseModel): + """更新基础模型Schema""" + name: Optional[str] = Field(None, description="模型唯一标识", max_length=255) + type: Optional[ModelType] = Field(None, description="模型类型") + provider: Optional[ModelProvider] = Field(None, description="提供商") + logo: Optional[str] = Field(None, description="模型logo图片URL", max_length=255) + description: Optional[str] = Field(None, description="模型描述") + is_deprecated: Optional[bool] = Field(None, description="是否弃用") + is_official: Optional[bool] = Field(None, description="是否供应商官方模型") + tags: Optional[List[str]] = Field(None, description="模型标签") + + +class ModelBase(BaseModel): + """基础模型Schema""" + model_config = ConfigDict(from_attributes=True) + + id: uuid.UUID + name: str + type: str + provider: str + logo: Optional[str] + description: Optional[str] + is_deprecated: bool + is_official: bool + tags: List[str] + add_count: int + + +class ModelBaseQuery(BaseModel): + """基础模型查询Schema""" + type: Optional[ModelType] = Field(None, description="模型类型") + provider: Optional[ModelProvider] = Field(None, description="提供商") + is_official: Optional[bool] = Field(None, description="是否官方模型") + is_deprecated: Optional[bool] = Field(None, description="是否弃用") + search: Optional[str] = Field(None, description="搜索关键词", max_length=255) diff --git a/api/app/services/app_statistics_service.py b/api/app/services/app_statistics_service.py new file mode 100644 index 00000000..c164924a --- /dev/null +++ b/api/app/services/app_statistics_service.py @@ -0,0 +1,193 @@ +"""应用统计服务""" +from datetime import datetime, timedelta +from typing import Dict, Any, List +import uuid +from sqlalchemy import func, and_, cast, Date +from sqlalchemy.orm import Session + +from app.models.conversation_model import Conversation, Message +from app.models.end_user_model import EndUser +from app.models.api_key_model import ApiKey, ApiKeyLog +from app.core.exceptions import BusinessException +from app.core.error_codes import BizCode + + +class AppStatisticsService: + """应用统计服务""" + + def __init__(self, db: Session): + self.db = db + + def get_app_statistics( + self, + app_id: uuid.UUID, + workspace_id: uuid.UUID, + start_date: int, + end_date: int + ) -> Dict[str, Any]: + """获取应用统计数据 + + Args: + app_id: 应用ID + workspace_id: 工作空间ID + start_date: 开始时间戳(毫秒) + end_date: 结束时间戳(毫秒) + + Returns: + 统计数据字典 + """ + # 将毫秒时间戳转换为 datetime + start_dt = datetime.fromtimestamp(start_date / 1000) + end_dt = datetime.fromtimestamp(end_date / 1000) + timedelta(days=1) + + # 1. 会话统计 + conversations_stats = self._get_conversations_statistics(app_id, workspace_id, start_dt, end_dt) + + # 2. 新增用户统计 + users_stats = self._get_new_users_statistics(app_id, start_dt, end_dt) + + # 3. API调用统计 + api_stats = self._get_api_calls_statistics(app_id, start_dt, end_dt) + + # 4. Token消耗统计 + token_stats = self._get_token_statistics(app_id, start_dt, end_dt) + + return { + "daily_conversations": conversations_stats["daily"], + "total_conversations": conversations_stats["total"], + "daily_new_users": users_stats["daily"], + "total_new_users": users_stats["total"], + "daily_api_calls": api_stats["daily"], + "total_api_calls": api_stats["total"], + "daily_tokens": token_stats["daily"], + "total_tokens": token_stats["total"] + } + + def _get_conversations_statistics( + self, + app_id: uuid.UUID, + workspace_id: uuid.UUID, + start_dt: datetime, + end_dt: datetime + ) -> Dict[str, Any]: + """获取会话统计""" + # 每日会话数 + daily_query = self.db.query( + cast(Conversation.created_at, Date).label('date'), + func.count(Conversation.id).label('count') + ).filter( + and_( + Conversation.app_id == app_id, + Conversation.workspace_id == workspace_id, + Conversation.created_at >= start_dt, + Conversation.created_at < end_dt + ) + ).group_by(cast(Conversation.created_at, Date)).all() + + daily_data = [{"date": str(row.date), "count": row.count} for row in daily_query] + total = sum(row["count"] for row in daily_data) + + return {"daily": daily_data, "total": total} + + def _get_new_users_statistics( + self, + app_id: uuid.UUID, + start_dt: datetime, + end_dt: datetime + ) -> Dict[str, Any]: + """获取新增用户统计""" + # 每日新增用户数 + daily_query = self.db.query( + cast(EndUser.created_at, Date).label('date'), + func.count(EndUser.id).label('count') + ).filter( + and_( + EndUser.app_id == app_id, + EndUser.created_at >= start_dt, + EndUser.created_at < end_dt + ) + ).group_by(cast(EndUser.created_at, Date)).all() + + daily_data = [{"date": str(row.date), "count": row.count} for row in daily_query] + total = sum(row["count"] for row in daily_data) + + return {"daily": daily_data, "total": total} + + def _get_api_calls_statistics( + self, + app_id: uuid.UUID, + start_dt: datetime, + end_dt: datetime + ) -> Dict[str, Any]: + """获取API调用统计""" + # 每日API调用次数 + daily_query = self.db.query( + cast(ApiKeyLog.created_at, Date).label('date'), + func.count(ApiKeyLog.id).label('count') + ).join( + ApiKey, ApiKeyLog.api_key_id == ApiKey.id + ).filter( + and_( + ApiKey.resource_id == app_id, + ApiKeyLog.created_at >= start_dt, + ApiKeyLog.created_at < end_dt + ) + ).group_by(cast(ApiKeyLog.created_at, Date)).all() + + daily_data = [{"date": str(row.date), "count": row.count} for row in daily_query] + total = sum(row["count"] for row in daily_data) + + return {"daily": daily_data, "total": total} + + def _get_token_statistics( + self, + app_id: uuid.UUID, + start_dt: datetime, + end_dt: datetime + ) -> Dict[str, Any]: + """获取Token消耗统计(从Message的meta_data中提取)""" + from sqlalchemy import text + + # 查询所有相关消息的token使用情况 + # meta_data中可能包含: {"usage": {"total_tokens": 100}} 或 {"tokens": 100} + daily_query = self.db.query( + cast(Message.created_at, Date).label('date'), + Message.meta_data + ).join( + Conversation, Message.conversation_id == Conversation.id + ).filter( + and_( + Conversation.app_id == app_id, + Message.created_at >= start_dt, + Message.created_at < end_dt, + Message.meta_data.isnot(None) + ) + ).all() + + # 按日期聚合token + daily_tokens = {} + for row in daily_query: + date_str = str(row.date) + meta = row.meta_data or {} + + # 提取token数量(支持多种格式) + tokens = 0 + if isinstance(meta, dict): + # 格式1: {"usage": {"total_tokens": 100}} + if "usage" in meta and isinstance(meta["usage"], dict): + tokens = meta["usage"].get("total_tokens", 0) + # 格式2: {"tokens": 100} + elif "tokens" in meta: + tokens = meta.get("tokens", 0) + # 格式3: {"total_tokens": 100} + elif "total_tokens" in meta: + tokens = meta.get("total_tokens", 0) + + if date_str not in daily_tokens: + daily_tokens[date_str] = 0 + daily_tokens[date_str] += int(tokens) + + daily_data = [{"date": date, "tokens": tokens} for date, tokens in sorted(daily_tokens.items()) if tokens != 0] + total = sum(row["tokens"] for row in daily_data) + + return {"daily": daily_data, "total": total} diff --git a/api/app/services/draft_run_service.py b/api/app/services/draft_run_service.py index 0d1f51a4..524c9ff6 100644 --- a/api/app/services/draft_run_service.py +++ b/api/app/services/draft_run_service.py @@ -16,6 +16,7 @@ from app.core.exceptions import BusinessException from app.core.logging_config import get_business_logger from app.core.rag.nlp.search import knowledge_retrieval from app.models import AgentConfig, ModelApiKey, ModelConfig +from app.repositories.model_repository import ModelApiKeyRepository from app.repositories.tool_repository import ToolRepository from app.schemas.prompt_schema import PromptMessageRole, render_prompt_message from app.services import task_service @@ -724,17 +725,21 @@ class DraftRunService: Raises: BusinessException: 当没有可用的 API Key 时 """ - stmt = ( - select(ModelApiKey) - .where( - ModelApiKey.model_config_id == model_config_id, - ModelApiKey.is_active.is_(True) - ) - .order_by(ModelApiKey.priority.desc()) - .limit(1) - ) - - api_key = self.db.scalars(stmt).first() + api_keys = ModelApiKeyRepository.get_by_model_config(self.db, model_config_id) + # stmt = ( + # select(ModelApiKey).join( + # ModelConfig, ModelApiKey.model_configs + # ) + # .where( + # ModelConfig.id == model_config_id, + # ModelApiKey.is_active.is_(True) + # ) + # .order_by(ModelApiKey.priority.desc()) + # .limit(1) + # ) + # + # api_key = self.db.scalars(stmt).first() + api_key = api_keys[0] if api_keys else None if not api_key: raise BusinessException("没有可用的 API Key", BizCode.AGENT_CONFIG_MISSING) diff --git a/api/app/services/llm_router.py b/api/app/services/llm_router.py index 9ef9dbb1..9e102ac3 100644 --- a/api/app/services/llm_router.py +++ b/api/app/services/llm_router.py @@ -5,6 +5,7 @@ import uuid from typing import Dict, Any, List, Optional, Tuple from sqlalchemy.orm import Session +from app.repositories.model_repository import ModelApiKeyRepository from app.services.conversation_state_manager import ConversationStateManager from app.models import ModelConfig, AgentConfig from app.core.logging_config import get_business_logger @@ -382,11 +383,14 @@ class LLMRouter: from app.core.models.base import RedBearModelConfig from app.models import ModelApiKey, ModelType - # 获取 API Key 配置 - api_key_config = self.db.query(ModelApiKey).filter( - ModelApiKey.model_config_id == self.routing_model_config.id, - ModelApiKey.is_active - ).first() + # 获取 API Key 配置(通过关联关系) + # api_key_config = self.db.query(ModelApiKey).join( + # ModelConfig, ModelApiKey.model_configs + # ).filter(ModelConfig.id == self.routing_model_config.id, + # ModelApiKey.is_active == True + # ).first() + api_keys = ModelApiKeyRepository.get_by_model_config(self.db, self.routing_model_config.id) + api_key_config = api_keys[0] if api_keys else None if not api_key_config: raise Exception("路由模型没有可用的 API Key") @@ -419,6 +423,9 @@ class LLMRouter: # 调用模型 response = await llm.ainvoke(prompt) + + from app.services.model_service import ModelApiKeyService + ModelApiKeyService.record_api_key_usage(self.db, api_key_config.id) # 提取响应内容 if hasattr(response, 'content'): diff --git a/api/app/services/memory_config_service.py b/api/app/services/memory_config_service.py index 26b86b71..e09cf67f 100644 --- a/api/app/services/memory_config_service.py +++ b/api/app/services/memory_config_service.py @@ -338,7 +338,7 @@ class MemoryConfigService: "provider": api_config.provider, "api_key": api_config.api_key, "base_url": api_config.api_base, - "model_config_id": api_config.model_config_id, + "model_config_id": str(config.id), "type": config.type, "timeout": settings.LLM_TIMEOUT, "max_retries": settings.LLM_MAX_RETRIES, @@ -370,7 +370,7 @@ class MemoryConfigService: "provider": api_config.provider, "api_key": api_config.api_key, "base_url": api_config.api_base, - "model_config_id": api_config.model_config_id, + "model_config_id": str(config.id), "type": config.type, "timeout": 120.0, "max_retries": 5, diff --git a/api/app/services/model_service.py b/api/app/services/model_service.py index e94a889b..5b2ab7e6 100644 --- a/api/app/services/model_service.py +++ b/api/app/services/model_service.py @@ -1,3 +1,4 @@ +from datetime import datetime from sqlalchemy.orm import Session from typing import List, Optional, Dict, Any import uuid @@ -6,11 +7,11 @@ import time import asyncio from app.models.models_model import ModelConfig, ModelApiKey, ModelType -from app.repositories.model_repository import ModelConfigRepository, ModelApiKeyRepository +from app.repositories.model_repository import ModelConfigRepository, ModelApiKeyRepository, ModelBaseRepository from app.schemas import model_schema from app.schemas.model_schema import ( ModelConfigCreate, ModelConfigUpdate, ModelApiKeyCreate, ModelApiKeyUpdate, - ModelConfigQuery, ModelStats + ModelConfigQuery, ModelStats, ModelConfigQueryNew ) from app.core.logging_config import get_business_logger from app.schemas.response_schema import PageData, PageMeta @@ -47,6 +48,26 @@ class ModelConfigService: items=[model_schema.ModelConfig.model_validate(model) for model in models] ) + @staticmethod + def get_model_list_new(db: Session, query: ModelConfigQueryNew, tenant_id: uuid.UUID | None = None) -> List[dict]: + """获取模型配置列表""" + provider_groups, total = ModelConfigRepository.get_list_new(db, query, tenant_id=tenant_id) + + items = [] + for provider, models in provider_groups.items(): + # 验证每个模型并封装分组信息 + validated_models = [model_schema.ModelConfig.model_validate(model) for model in models] + tags = list({model.type for model in validated_models}) + group_item = { + "provider": provider, # 服务商名称 + "logo": validated_models[0].logo, + "tags": tags, + "models": validated_models # 该服务商下的所有模型 + } + items.append(group_item) + + return items + @staticmethod def get_model_by_name(db: Session, name: str, tenant_id: uuid.UUID | None = None) -> ModelConfig: """根据名称获取模型配置""" @@ -228,37 +249,39 @@ class ModelConfigService: # 验证配置 if not model_data.skip_validation and model_data.api_keys: - api_key_data = model_data.api_keys - validation_result = await ModelConfigService.validate_model_config( - db=db, - model_name=api_key_data.model_name, - provider=api_key_data.provider, - api_key=api_key_data.api_key, - api_base=api_key_data.api_base, - model_type=model_data.type, # 传递模型类型 - test_message="Hello" - ) - if not validation_result["valid"]: - raise BusinessException( - f"模型配置验证失败: {validation_result['error']}", - BizCode.INVALID_PARAMETER + api_key_data_list = model_data.api_keys + for api_key_data in api_key_data_list: + validation_result = await ModelConfigService.validate_model_config( + db=db, + model_name=api_key_data.model_name, + provider=api_key_data.provider, + api_key=api_key_data.api_key, + api_base=api_key_data.api_base, + model_type=model_data.type, # 传递模型类型 + test_message="Hello" ) + if not validation_result["valid"]: + raise BusinessException( + f"模型配置验证失败: {validation_result['error']}", + BizCode.INVALID_PARAMETER + ) # 事务处理 - api_key_data = model_data.api_keys - model_config_data = model_data.dict(exclude={"api_keys", "skip_validation"}) + api_key_datas = model_data.api_keys + model_config_data = model_data.model_dump(exclude={"api_keys", "skip_validation"}) # 添加租户ID model_config_data["tenant_id"] = tenant_id model = ModelConfigRepository.create(db, model_config_data) db.flush() # 获取生成的 ID - if api_key_data: - api_key_create_schema = ModelApiKeyCreate( - model_config_id=model.id, - **api_key_data.dict() - ) - ModelApiKeyRepository.create(db, api_key_create_schema) + if api_key_datas: + for api_key_data in api_key_datas: + api_key_create_schema = ModelApiKeyCreate( + model_config_ids=[model.id], + **api_key_data.model_dump() + ) + ModelApiKeyRepository.create(db, api_key_create_schema) db.commit() db.refresh(model) @@ -280,6 +303,112 @@ class ModelConfigService: db.refresh(model) return model + @staticmethod + async def create_composite_model(db: Session, model_data: model_schema.CompositeModelCreate, tenant_id: uuid.UUID) -> ModelConfig: + """创建组合模型""" + if ModelConfigRepository.get_by_name(db, model_data.name, tenant_id=tenant_id): + raise BusinessException("模型名称已存在", BizCode.DUPLICATE_NAME) + + # 验证所有 API Key 存在且类型匹配 + for api_key_id in model_data.api_key_ids: + api_key = ModelApiKeyRepository.get_by_id(db, api_key_id) + if not api_key: + raise BusinessException(f"API Key {api_key_id} 不存在", BizCode.NOT_FOUND) + + # 检查 API Key 关联的模型配置类型 + for model_config in api_key.model_configs: + # chat 和 llm 类型可以兼容 + compatible_types = {ModelType.LLM, ModelType.CHAT} + config_type = model_config.type + request_type = model_data.type + + if not (config_type == request_type or + (config_type in compatible_types and request_type in compatible_types)): + raise BusinessException( + f"API Key {api_key_id} 关联的模型类型 ({model_config.type}) 与组合模型类型 ({model_data.type}) 不匹配", + BizCode.INVALID_PARAMETER + ) + # if model_config.is_composite: + # raise BusinessException( + # f"API Key {api_key_id} 关联的模型是组合模型,不能用于创建新的组合模型", + # BizCode.INVALID_PARAMETER + # ) + + # 创建组合模型 + model_config_data = { + "tenant_id": tenant_id, + "name": model_data.name, + "type": model_data.type, + "logo": model_data.logo, + "description": model_data.description, + "provider": "composite", + "config": model_data.config, + "is_active": model_data.is_active, + "is_public": model_data.is_public, + "is_composite": True + } + + model = ModelConfigRepository.create(db, model_config_data) + db.flush() + + # 关联 API Keys + for api_key_id in model_data.api_key_ids: + api_key = ModelApiKeyRepository.get_by_id(db, api_key_id) + if api_key: + model.api_keys.append(api_key) + + db.commit() + db.refresh(model) + return model + + @staticmethod + async def update_composite_model(db: Session, model_id: uuid.UUID, model_data: model_schema.CompositeModelCreate, tenant_id: uuid.UUID) -> ModelConfig: + """更新组合模型""" + existing_model = ModelConfigRepository.get_by_id(db, model_id, tenant_id=tenant_id) + if not existing_model: + raise BusinessException("模型配置不存在", BizCode.MODEL_NOT_FOUND) + + if not existing_model.is_composite: + raise BusinessException("该模型不是组合模型", BizCode.INVALID_PARAMETER) + + # 验证所有 API Key 存在且类型匹配 + for api_key_id in model_data.api_key_ids: + api_key = ModelApiKeyRepository.get_by_id(db, api_key_id) + if not api_key: + raise BusinessException(f"API Key {api_key_id} 不存在", BizCode.NOT_FOUND) + + for model_config in api_key.model_configs: + compatible_types = {ModelType.LLM, ModelType.CHAT} + config_type = model_config.type + request_type = model_data.type + + if not (config_type == request_type or + (config_type in compatible_types and request_type in compatible_types)): + raise BusinessException( + f"API Key {api_key_id} 关联的模型类型 ({model_config.type}) 与组合模型类型 ({model_data.type}) 不匹配", + BizCode.INVALID_PARAMETER + ) + + # 更新基本信息 + existing_model.name = model_data.name + existing_model.type = model_data.type + existing_model.logo = model_data.logo + existing_model.description = model_data.description + existing_model.config = model_data.config + existing_model.is_active = model_data.is_active + existing_model.is_public = model_data.is_public + + # 更新 API Keys 关联 + existing_model.api_keys.clear() + for api_key_id in model_data.api_key_ids: + api_key = ModelApiKeyRepository.get_by_id(db, api_key_id) + if api_key: + existing_model.api_keys.append(api_key) + + db.commit() + db.refresh(existing_model) + return existing_model + @staticmethod def delete_model(db: Session, model_id: uuid.UUID, tenant_id: uuid.UUID | None = None) -> bool: """删除模型配置""" @@ -324,27 +453,132 @@ class ModelApiKeyService: return ModelApiKeyRepository.get_by_model_config(db, model_config_id, is_active) @staticmethod - async def create_api_key(db: Session, api_key_data: ModelApiKeyCreate) -> ModelApiKey: - """创建API Key""" - model_config = ModelConfigRepository.get_by_id(db, api_key_data.model_config_id) - if not model_config: - raise BusinessException("模型配置不存在", BizCode.MODEL_NOT_FOUND) - - validation_result = await ModelConfigService.validate_model_config( + async def create_api_key_by_provider(db: Session, data: model_schema.ModelApiKeyCreateByProvider) -> List[ModelApiKey]: + """根据provider为多个ModelConfig创建API Key""" + created_keys = [] + + for model_config_id in data.model_config_ids: + model_config = ModelConfigRepository.get_by_id(db, model_config_id) + if not model_config: + continue + + # 从ModelBase获取model_name + model_name = model_config.model_base.name if model_config.model_base else model_config.name + + # 检查是否存在API Key(包括软删除) + existing_key = db.query(ModelApiKey).filter( + ModelApiKey.api_key == data.api_key, + ModelApiKey.provider == data.provider, + ModelApiKey.model_name == model_name + ).first() + + if existing_key: + # 如果已存在,重新激活并更新 + if existing_key.is_active: + continue + existing_key.is_active = True + existing_key.api_base = data.api_base + existing_key.description = data.description + existing_key.config = data.config + existing_key.priority = data.priority + existing_key.model_name = model_name + + # 检查是否已关联该模型配置 + if model_config not in existing_key.model_configs: + existing_key.model_configs.append(model_config) + + created_keys.append(existing_key) + continue + + # 验证配置 + validation_result = await ModelConfigService.validate_model_config( db=db, - model_name=api_key_data.model_name, - provider=api_key_data.provider, - api_key=api_key_data.api_key, - api_base=api_key_data.api_base, - model_type=model_config.type, # 传递模型类型 + model_name=model_name, + provider=data.provider, + api_key=data.api_key, + api_base=data.api_base, + model_type=model_config.type, test_message="Hello" ) - print(validation_result) - if not validation_result["valid"]: + if not validation_result["valid"]: raise BusinessException( f"模型配置验证失败: {validation_result['error']}", BizCode.INVALID_PARAMETER ) + + # 创建API Key + api_key_data = ModelApiKeyCreate( + model_config_ids=[model_config_id], + model_name=model_name, + description=data.description, + provider=data.provider, + api_key=data.api_key, + api_base=data.api_base, + config=data.config, + is_active=data.is_active, + priority=data.priority + ) + api_key_obj = ModelApiKeyRepository.create(db, api_key_data) + created_keys.append(api_key_obj) + + if created_keys: + db.commit() + for key in created_keys: + db.refresh(key) + + return created_keys + + @staticmethod + async def create_api_key(db: Session, api_key_data: ModelApiKeyCreate) -> ModelApiKey: + # 验证所有关联的模型配置是否存在 + if api_key_data.model_config_ids: + for model_config_id in api_key_data.model_config_ids: + model_config = ModelConfigRepository.get_by_id(db, model_config_id) + if not model_config: + raise BusinessException("模型配置不存在", BizCode.MODEL_NOT_FOUND) + + # 检查API Key是否已存在(包括软删除) + existing_key = db.query(ModelApiKey).filter( + ModelApiKey.api_key == api_key_data.api_key, + ModelApiKey.provider == api_key_data.provider, + ModelApiKey.model_name == api_key_data.model_name + ).first() + + if existing_key: + if existing_key.is_active: + # 如果已激活,跳过 + raise BusinessException("该API Key已存在", BizCode.DUPLICATE_NAME) + # 如果已存在,重新激活并更新 + existing_key.is_active = True + existing_key.api_base = api_key_data.api_base + existing_key.description = api_key_data.description + existing_key.config = api_key_data.config + existing_key.priority = api_key_data.priority + existing_key.model_name = api_key_data.model_name + + # 检查是否已关联该模型配置 + if model_config not in existing_key.model_configs: + existing_key.model_configs.append(model_config) + + db.commit() + db.refresh(existing_key) + return existing_key + + # 验证配置 + validation_result = await ModelConfigService.validate_model_config( + db=db, + model_name=api_key_data.model_name, + provider=api_key_data.provider, + api_key=api_key_data.api_key, + api_base=api_key_data.api_base, + model_type=model_config.type, + test_message="Hello" + ) + if not validation_result["valid"]: + raise BusinessException( + f"模型配置验证失败: {validation_result['error']}", + BizCode.INVALID_PARAMETER + ) api_key = ModelApiKeyRepository.create(db, api_key_data) db.commit() @@ -359,21 +593,19 @@ class ModelApiKeyService: raise BusinessException("API Key不存在", BizCode.NOT_FOUND) # 获取关联的模型配置以获取模型类型 - model_config = ModelConfigRepository.get_by_id(db, existing_api_key.model_config_id) - if not model_config: - raise BusinessException("关联的模型配置不存在", BizCode.MODEL_NOT_FOUND) - - validation_result = await ModelConfigService.validate_model_config( + if existing_api_key.model_configs: + model_config = existing_api_key.model_configs[0] + + validation_result = await ModelConfigService.validate_model_config( db=db, - model_name=api_key_data.model_name, - provider=api_key_data.provider, - api_key=api_key_data.api_key, - api_base=api_key_data.api_base, - model_type=model_config.type, # 传递模型类型 + model_name=api_key_data.model_name or existing_api_key.model_name, + provider=api_key_data.provider or existing_api_key.provider, + api_key=api_key_data.api_key or existing_api_key.api_key, + api_base=api_key_data.api_base or existing_api_key.api_base, + model_type=model_config.type, test_message="Hello" ) - print(validation_result) - if not validation_result["valid"]: + if not validation_result["valid"]: raise BusinessException( f"模型配置验证失败: {validation_result['error']}", BizCode.INVALID_PARAMETER @@ -417,3 +649,84 @@ class ModelApiKeyService: if api_kes and len(api_kes) > 0: return api_kes[0] raise BusinessException("没有可用的 API Key", BizCode.AGENT_CONFIG_MISSING) + + + +class ModelBaseService: + """基础模型服务""" + + @staticmethod + def get_model_base_list(db: Session, query: model_schema.ModelBaseQuery, tenant_id: uuid.UUID = None) -> List: + models = ModelBaseRepository.get_list(db, query) + + provider_groups = {} + for m in models: + model_dict = model_schema.ModelBase.model_validate(m).model_dump() + if tenant_id: + model_dict['is_added'] = ModelBaseRepository.check_added_by_tenant(db, m.id, tenant_id) + + provider = m.provider + if provider not in provider_groups: + provider_groups[provider] = { + "provider": provider, + "models": [] + } + provider_groups[provider]["models"].append(model_dict) + + return list(provider_groups.values()) + + @staticmethod + def get_model_base_by_id(db: Session, model_base_id: uuid.UUID): + model = ModelBaseRepository.get_by_id(db, model_base_id) + if not model: + raise BusinessException("基础模型不存在", BizCode.MODEL_NOT_FOUND) + return model + + @staticmethod + def create_model_base(db: Session, data: model_schema.ModelBaseCreate): + model_base = ModelBaseRepository.create(db, data.model_dump()) + db.commit() + db.refresh(model_base) + return model_base + + @staticmethod + def update_model_base(db: Session, model_base_id: uuid.UUID, data: model_schema.ModelBaseUpdate): + model_base = ModelBaseRepository.update(db, model_base_id, data.model_dump(exclude_unset=True)) + if not model_base: + raise BusinessException("基础模型不存在", BizCode.MODEL_NOT_FOUND) + db.commit() + db.refresh(model_base) + return model_base + + @staticmethod + def delete_model_base(db: Session, model_base_id: uuid.UUID) -> bool: + success = ModelBaseRepository.delete(db, model_base_id) + if not success: + raise BusinessException("基础模型不存在", BizCode.MODEL_NOT_FOUND) + db.commit() + return success + + @staticmethod + def add_model_from_plaza(db: Session, model_base_id: uuid.UUID, tenant_id: uuid.UUID) -> ModelConfig: + model_base = ModelBaseRepository.get_by_id(db, model_base_id) + if not model_base: + raise BusinessException("基础模型不存在", BizCode.MODEL_NOT_FOUND) + + if ModelBaseRepository.check_added_by_tenant(db, model_base_id, tenant_id): + raise BusinessException("模型已添加", BizCode.DUPLICATE_NAME) + + model_config_data = { + "model_id": model_base_id, + "tenant_id": tenant_id, + "name": model_base.name, + "provider": model_base.provider, + "type": model_base.type, + "logo": model_base.logo, + "description": model_base.description, + "is_composite": False + } + model_config = ModelConfigRepository.create(db, model_config_data) + ModelBaseRepository.increment_add_count(db, model_base_id) + db.commit() + db.refresh(model_config) + return model_config diff --git a/api/app/services/multi_agent_orchestrator.py b/api/app/services/multi_agent_orchestrator.py index 4bcd28cd..d9062eaf 100644 --- a/api/app/services/multi_agent_orchestrator.py +++ b/api/app/services/multi_agent_orchestrator.py @@ -7,6 +7,7 @@ from sqlalchemy.orm import Session from app.models import MultiAgentConfig, AgentConfig, ModelConfig from app.models.multi_agent_model import AggregationStrategy, OrchestrationMode +from app.repositories.model_repository import ModelApiKeyRepository from app.services.agent_registry import AgentRegistry from app.services.master_agent_router import MasterAgentRouter from app.services.conversation_state_manager import ConversationStateManager @@ -2546,10 +2547,14 @@ class MultiAgentOrchestrator: return self._smart_merge_results(results, strategy) # 获取 API Key 配置 - api_key_config = self.db.query(ModelApiKey).filter( - ModelApiKey.model_config_id == default_model_config_id, - ModelApiKey.is_active.is_(True) - ).first() + # api_key_config = self.db.query(ModelApiKey).join( + # ModelConfig, ModelApiKey.model_configs + # ).filter( + # ModelConfig.id == default_model_config_id, + # ModelApiKey.is_active.is_(True) + # ).first() + api_keys = ModelApiKeyRepository.get_by_model_config(self.db, default_model_config_id) + api_key_config = api_keys[0] if api_keys else None if not api_key_config: logger.warning("Master Agent 没有可用的 API Key,使用简单整合") @@ -2703,10 +2708,14 @@ class MultiAgentOrchestrator: return # 获取 API Key 配置 - api_key_config = self.db.query(ModelApiKey).filter( - ModelApiKey.model_config_id == default_model_config_id, - ModelApiKey.is_active.is_(True) - ).first() + # api_key_config = self.db.query(ModelApiKey).join( + # ModelConfig, ModelApiKey.model_configs + # ).filter( + # ModelConfig.id == default_model_config_id, + # ModelApiKey.is_active.is_(True) + # ).first() + api_keys = ModelApiKeyRepository.get_by_model_config(self.db, default_model_config_id) + api_key_config = api_keys[0] if api_keys else None if not api_key_config: logger.warning("Master Agent 没有可用的 API Key,使用简单整合") diff --git a/api/app/services/shared_chat_service.py b/api/app/services/shared_chat_service.py index 5eee5edc..1d012088 100644 --- a/api/app/services/shared_chat_service.py +++ b/api/app/services/shared_chat_service.py @@ -4,6 +4,8 @@ import time import asyncio from typing import Optional, Dict, Any, AsyncGenerator from sqlalchemy.orm import Session + +from app.repositories.model_repository import ModelApiKeyRepository from app.services.memory_konwledges_server import write_rag from app.models import ReleaseShare, AppRelease, Conversation from app.services.conversation_service import ConversationService @@ -164,16 +166,20 @@ class SharedChatService: raise ResourceNotFoundException("模型配置", str(model_config_id)) # 获取 API Key - stmt = ( - select(ModelApiKey) - .where( - ModelApiKey.model_config_id == model_config_id, - ModelApiKey.is_active.is_(True) - ) - .order_by(ModelApiKey.priority.desc()) - .limit(1) - ) - api_key_obj = self.db.scalars(stmt).first() + # stmt = ( + # select(ModelApiKey).join( + # ModelConfig, ModelApiKey.model_configs + # ) + # .where( + # ModelConfig.id == model_config_id, + # ModelApiKey.is_active.is_(True) + # ) + # .order_by(ModelApiKey.priority.desc()) + # .limit(1) + # ) + # api_key_obj = self.db.scalars(stmt).first() + api_keys = ModelApiKeyRepository.get_by_model_config(self.db, model_config_id) + api_key_obj = api_keys[0] if api_keys else None if not api_key_obj: raise BusinessException("没有可用的 API Key", BizCode.AGENT_CONFIG_MISSING) @@ -358,16 +364,20 @@ class SharedChatService: raise ResourceNotFoundException("模型配置", str(model_config_id)) # 获取 API Key - stmt = ( - select(ModelApiKey) - .where( - ModelApiKey.model_config_id == model_config_id, - ModelApiKey.is_active.is_(True) - ) - .order_by(ModelApiKey.priority.desc()) - .limit(1) - ) - api_key_obj = self.db.scalars(stmt).first() + # stmt = ( + # select(ModelApiKey).join( + # ModelConfig, ModelApiKey.model_configs + # ) + # .where( + # ModelConfig.id == model_config_id, + # ModelApiKey.is_active.is_(True) + # ) + # .order_by(ModelApiKey.priority.desc()) + # .limit(1) + # ) + # api_key_obj = self.db.scalars(stmt).first() + api_keys = ModelApiKeyRepository.get_by_model_config(self.db, model_config_id) + api_key_obj = api_keys[0] if api_keys else None if not api_key_obj: raise BusinessException("没有可用的 API Key", BizCode.AGENT_CONFIG_MISSING) diff --git a/api/app/version_info.json b/api/app/version_info.json index bee52989..86a5e33e 100644 --- a/api/app/version_info.json +++ b/api/app/version_info.json @@ -5,11 +5,12 @@ "releaseDate": "2026-1-23", "upgradePosition": "\uD83D\uDC3B 本次更新主要优化使用体验和修复已知问题,让系统更稳定、更好用。", "coreUpgrades": [ - "1. 工作流更好用了\n* 界面更清晰,一眼看懂怎么配置\n* 新增节点输出变量展示,方便其他节点引用\n* 修复了几个影响体验的bug", - "2. 智能体配置更简单\n* 提示词和变量联动更顺畅\n* 配置界面重新整理,找功能更方便", - "3. 记忆系统更稳定\n* 优化了情绪记忆和隐性记忆的缓存更新\n* 修复了记忆配置页面的报错问题\n* 现在能自动识别用户和AI的身份了", - "4. 知识库体验提升\n* 修复了文档解析异常的问题\n* 上传文档时能看到处理进度了\n* 取消了操作也不会报错了", - "5. 系统整体更可靠\n* 修复了新用户访问跳转问题\n* 流式接口更稳定,长对话不断线\n* 调整了菜单顺序,操作更顺手\n", + "1. 工作流更好用了
* 界面更清晰,一眼看懂怎么配置
* 新增节点输出变量展示,方便其他节点引用
* 修复了几个影响体验的bug", + "2. 智能体配置更简单
* 提示词和变量联动更顺畅
* 配置界面重新整理,找功能更方便", + "3. 记忆系统更稳定
* 优化了情绪记忆和隐性记忆的缓存更新
* 修复了记忆配置页面的报错问题
* 现在能自动识别用户和AI的身份了", + "4. 知识库体验提升
* 修复了文档解析异常的问题
* 上传文档时能看到处理进度了
* 取消了操作也不会报错了", + "5. 系统整体更可靠
* 修复了新用户访问跳转问题
* 流式接口更稳定,长对话不断线
* 调整了菜单顺序,操作更顺手", + "
", "这次更新虽然不大,但让记忆熊的基础更扎实、体验更流畅。我们继续努力,让AI记忆更好用!", "记忆熊,记得更牢,用得更好。\uD83D\uDC3B✨" ] @@ -19,12 +20,13 @@ "releaseDate": "2026-1-23", "upgradePosition": "\uD83D\uDC3B This update focuses on improving usability and fixing known issues, making the system more stable and easier to use overall.", "coreUpgrades": [ - "1. Improved Workflow Experience\nCleaner, more intuitive UI for easier configuration at a glance\nAdded visibility of node output variables, making them easier to reference in downstream nodes\nFixed several usability-related bugs that affected the workflow experience", - "2. Simpler Agent Configuration\nSmoother linkage between prompts and variables\nReorganized configuration layout for easier navigation and better clarity", - "3. More Stable Memory System\nOptimized cache refresh for emotional memory and implicit memory\nFixed error issues on the memory configuration page\nThe system can now automatically distinguish between user and AI roles", - "4. Enhanced Knowledge Base Experience\nFixed issues with document parsing failures\nUpload progress is now displayed during document processing\nCanceling an upload no longer triggers errors", - "5. Overall System Reliability Improvements\nFixed redirect issues affecting new users\nImproved stability of streaming APIs to prevent interruptions during long conversations\nAdjusted menu ordering for a smoother and more intuitive workflow\n", - "Although this is a relatively small update, it strengthens MemoryBear’s foundation and delivers a noticeably smoother experience.\nWe’ll keep refining the system to make AI memory more powerful and easier to use.", + "1. Improved Workflow Experience
* Cleaner, more intuitive UI for easier configuration at a glance
* Added visibility of node output variables, making them easier to reference in downstream nodes
* Fixed several usability-related bugs that affected the workflow experience", + "2. Simpler Agent Configuration
* Smoother linkage between prompts and variables
* Reorganized configuration layout for easier navigation and better clarity", + "3. More Stable Memory System
* Optimized cache refresh for emotional memory and implicit memory
* Fixed error issues on the memory configuration page
* The system can now automatically distinguish between user and AI roles", + "4. Enhanced Knowledge Base Experience
* Fixed issues with document parsing failures
* Upload progress is now displayed during document processing
* Canceling an upload no longer triggers errors", + "5. Overall System Reliability Improvements
* Fixed redirect issues affecting new users
* Improved stability of streaming APIs to prevent interruptions during long conversations
* Adjusted menu ordering for a smoother and more intuitive workflow", + "
", + "Although this is a relatively small update, it strengthens MemoryBear’s foundation and delivers a noticeably smoother experience. We’ll keep refining the system to make AI memory more powerful and easier to use.", "MemoryBear — remember better, work smarter. \uD83D\uDC3B✨" ] } @@ -35,10 +37,10 @@ "releaseDate": "2026-1-16", "upgradePosition": "本次为架构升级,核心目标是把\"被动存储\"升级为\"主动认知\",让系统具备情绪感知、情景理解与类人记忆机制,为后续多智能体协作与专业场景落地奠定底座。", "coreUpgrades": [ - "记忆详情:拟人记忆——情绪引擎、情景记忆、短期记忆、工作记忆、感知记忆、显性记忆、隐性记忆,并配套类脑遗忘机制,实现从感知→情绪→情景→长期沉淀的完整人类记忆闭环", - "可视化工作流:拖拽式节点编排(LLM、知识库、逻辑、工具),业务落地周期由天缩至小时。", - "多模态知识处理:PDF、PPT、MP3、MP4 一键解析,时间感知检索准确率 94.3%,问答对数据即插即用。", - "Agent集群内置\"记忆-知识-工具-审核\"四类角色模板,用户一键生成;主控Agent把复杂任务拆为子任务并行分发,再靠情景记忆统一消解冲突、校验一致性,输出完整报告。" + "1. 记忆详情:拟人记忆——情绪引擎、情景记忆、短期记忆、工作记忆、感知记忆、显性记忆、隐性记忆,并配套类脑遗忘机制,实现从感知→情绪→情景→长期沉淀的完整人类记忆闭环", + "2. 可视化工作流:拖拽式节点编排(LLM、知识库、逻辑、工具),业务落地周期由天缩至小时。", + "3. 多模态知识处理:PDF、PPT、MP3、MP4 一键解析,时间感知检索准确率 94.3%,问答对数据即插即用。", + "4. Agent集群内置\"记忆-知识-工具-审核\"四类角色模板,用户一键生成;主控Agent把复杂任务拆为子任务并行分发,再靠情景记忆统一消解冲突、校验一致性,输出完整报告。" ] }, "introduction_en": { @@ -46,10 +48,10 @@ "releaseDate": "2026-1-16", "upgradePosition": "This release marks a foundational upgrade to the system’s cognitive architecture. The core objective is to evolve the platform from passive information storage into active cognitive intelligence—enabling emotional awareness, situational understanding, and human-like memory mechanisms. This upgrade lays the groundwork for future multi-agent collaboration and domain-specific, production-grade AI applications.", "coreUpgrades": [ - "Human-Like Memory Architecture: A comprehensive, human-inspired memory system is introduced, encompassing emotional processing, situational memory, short-term and working memory, perceptual memory, as well as explicit and implicit memory. Combined with brain-inspired forgetting mechanisms, the system now supports a complete cognitive loop—from perception → emotion → context → long-term consolidation, closely mirroring human memory formation.", - "Visual Workflow Orchestration: A fully visual, drag-and-drop workflow enables modular composition of LLMs, knowledge bases, logic, and tools. This dramatically reduces the time required to move from experimentation to production—from days to hours.", - "Multimodal Knowledge Processing: The system now supports one-click parsing and ingestion of PDF, PPT, MP3, and MP4 content. With time-aware retrieval accuracy reaching 94.3%, structured Q&A data becomes instantly usable for downstream reasoning and generation.", - "Built-in Agent Clusters: Predefined role templates across four categories—Memory, Knowledge, Tools, and Review—can be generated with a single click. A Coordinator Agent decomposes complex tasks into parallel subtasks, while situational memory is used to resolve conflicts, validate consistency, and synthesize outputs into a coherent, end-to-end report." + "1. Human-Like Memory Architecture: A comprehensive, human-inspired memory system is introduced, encompassing emotional processing, situational memory, short-term and working memory, perceptual memory, as well as explicit and implicit memory. Combined with brain-inspired forgetting mechanisms, the system now supports a complete cognitive loop—from perception → emotion → context → long-term consolidation, closely mirroring human memory formation.", + "2. Visual Workflow Orchestration: A fully visual, drag-and-drop workflow enables modular composition of LLMs, knowledge bases, logic, and tools. This dramatically reduces the time required to move from experimentation to production—from days to hours.", + "3. Multimodal Knowledge Processing: The system now supports one-click parsing and ingestion of PDF, PPT, MP3, and MP4 content. With time-aware retrieval accuracy reaching 94.3%, structured Q&A data becomes instantly usable for downstream reasoning and generation.", + "4. Built-in Agent Clusters: Predefined role templates across four categories—Memory, Knowledge, Tools, and Review—can be generated with a single click. A Coordinator Agent decomposes complex tasks into parallel subtasks, while situational memory is used to resolve conflicts, validate consistency, and synthesize outputs into a coherent, end-to-end report." ] } }, @@ -59,16 +61,17 @@ "releaseDate": "2025-12-01", "upgradePosition": "这是一款专注于管理和利用AI记忆的工具,支持RAG和知识图谱两种主流存储方式,旨在为AI应用提供持久化、结构化的\"记忆\"能力。", "coreUpgrades": [ - "记忆空间:用户可以创建独立的空间来隔离不同记忆,并灵活选择存储方式。", - "记忆配置:简化了配置流程,内置自动提取关键信息的\"记忆萃取\"和管理生命周期的\"遗忘\"引擎。", - "知识检索:提供语义、分词和混合三种检索模式,并支持多种参数微调和结果重排序,以提升召回效果。", - "全局管理:支持统一设置默认检索参数,并可一键应用到所有知识库。", - "测试与调试:内置\"召回测试\"功能,方便用户实时验证检索效果并调整参数,支持通过分享码与他人协作。", - "记忆洞察:可查看详细的对话记录、用户画像和分析报告,帮助理解AI的\"记忆\"内容。", - "集成与管理:提供API Key用于系统集成,并包含基本的用户管理功能。", - "界面与体验:采用现代化的卡片式布局和渐变色设计,注重交互的流畅性和视觉美感。", - "起步与使用:文档中提供了清晰的基础使用流程,引导用户从创建空间、配置记忆到测试检索快速上手。", - "版本说明与限制: 记忆熊 v0.1.0 版本\"初心\"囊括智能记忆管理的核心思路和基础能力,为后续开发奠定了基础。", + "1. 记忆空间:用户可以创建独立的空间来隔离不同记忆,并灵活选择存储方式。", + "2. 记忆配置:简化了配置流程,内置自动提取关键信息的\"记忆萃取\"和管理生命周期的\"遗忘\"引擎。", + "3. 知识检索:提供语义、分词和混合三种检索模式,并支持多种参数微调和结果重排序,以提升召回效果。", + "4. 全局管理:支持统一设置默认检索参数,并可一键应用到所有知识库。", + "5. 测试与调试:内置\"召回测试\"功能,方便用户实时验证检索效果并调整参数,支持通过分享码与他人协作。", + "6. 记忆洞察:可查看详细的对话记录、用户画像和分析报告,帮助理解AI的\"记忆\"内容。", + "7. 集成与管理:提供API Key用于系统集成,并包含基本的用户管理功能。", + "8. 界面与体验:采用现代化的卡片式布局和渐变色设计,注重交互的流畅性和视觉美感。", + "9. 起步与使用:文档中提供了清晰的基础使用流程,引导用户从创建空间、配置记忆到测试检索快速上手。", + "10. 版本说明与限制: 记忆熊 v0.1.0 版本\"初心\"囊括智能记忆管理的核心思路和基础能力,为后续开发奠定了基础。", + "
", "文档资源:用户手册、API文档、FAQ", "问题反馈:GitHub Issues、邮件支持", "致谢:感谢所有参与测试和提供反馈的用户!" @@ -79,16 +82,17 @@ "releaseDate": "2025-12-01", "upgradePosition": "A tool focused on managing and utilizing AI memory, supporting both RAG and knowledge graph storage methods, aiming to provide persistent and structured 'memory' capabilities for AI applications.", "coreUpgrades": [ - "Memory Space: Users can create independent spaces to isolate different memories and flexibly choose storage methods.", - "Memory Configuration: Simplified configuration process with built-in 'memory extraction' for automatic key information extraction and 'forgetting' engine for lifecycle management.", - "Knowledge Retrieval: Provides semantic, tokenization, and hybrid retrieval modes with various parameter tuning and result reranking to improve recall.", - "Global Management: Supports unified default retrieval parameter settings with one-click application to all knowledge bases.", - "Testing & Debugging: Built-in 'recall testing' for real-time verification of retrieval effects and parameter adjustment, with sharing code support for collaboration.", - "Memory Insights: View detailed conversation records, user profiles, and analysis reports to understand AI 'memory' content.", - "Integration & Management: Provides API Key for system integration with basic user management features.", - "Interface & Experience: Modern card-based layout with gradient design, focusing on interaction fluidity and visual aesthetics.", - "Getting Started: Documentation provides clear basic usage flow, guiding users from creating spaces, configuring memory to testing retrieval.", - "Version Notes: MemoryBear v0.1.0 'Original Intent' encompasses core concepts and basic capabilities of intelligent memory management, laying foundation for future development.", + "1. Memory Space: Users can create independent spaces to isolate different memories and flexibly choose storage methods.", + "2. Memory Configuration: Simplified configuration process with built-in 'memory extraction' for automatic key information extraction and 'forgetting' engine for lifecycle management.", + "3. Knowledge Retrieval: Provides semantic, tokenization, and hybrid retrieval modes with various parameter tuning and result reranking to improve recall.", + "4. Global Management: Supports unified default retrieval parameter settings with one-click application to all knowledge bases.", + "5. Testing & Debugging: Built-in 'recall testing' for real-time verification of retrieval effects and parameter adjustment, with sharing code support for collaboration.", + "6. Memory Insights: View detailed conversation records, user profiles, and analysis reports to understand AI 'memory' content.", + "7. Integration & Management: Provides API Key for system integration with basic user management features.", + "8. Interface & Experience: Modern card-based layout with gradient design, focusing on interaction fluidity and visual aesthetics.", + "9. Getting Started: Documentation provides clear basic usage flow, guiding users from creating spaces, configuring memory to testing retrieval.", + "10. Version Notes: MemoryBear v0.1.0 'Original Intent' encompasses core concepts and basic capabilities of intelligent memory management, laying foundation for future development.", + "
", "Documentation: User Manual, API Documentation, FAQ", "Feedback: GitHub Issues, Email Support", "Acknowledgments: Thanks to all users who participated in testing and provided feedback!" diff --git a/api/migrations/versions/915bed077f8d_202601281340.py b/api/migrations/versions/915bed077f8d_202601281340.py new file mode 100644 index 00000000..022f0d25 --- /dev/null +++ b/api/migrations/versions/915bed077f8d_202601281340.py @@ -0,0 +1,224 @@ +"""202601281340 + +Revision ID: 915bed077f8d +Revises: 75f0ec80e50b +Create Date: 2026-01-28 13:38:49.471560 + +""" +from typing import Sequence, Union + +from alembic import op +import sqlalchemy as sa +from sqlalchemy.dialects import postgresql + +# revision identifiers, used by Alembic. +revision: str = '915bed077f8d' +down_revision: Union[str, None] = '75f0ec80e50b' +branch_labels: Union[str, Sequence[str], None] = None +depends_on: Union[str, Sequence[str], None] = None + +BACKUP_TABLE_NAME = 'model_api_keys_backup_20260123' + +def get_temp_models(): + """创建临时模型,用于迁移过程中查询数据""" + metadata = sa.MetaData() + + # 临时ModelApiKey表(仅包含需要的字段) + ModelApiKey = sa.Table( + 'model_api_keys', metadata, + sa.Column('id', sa.UUID(), primary_key=True), + sa.Column('model_config_id', sa.UUID(), nullable=True), + ) + + # 临时关联表(和升级脚本创建的表结构一致) + ModelConfigApiKeyAssociation = sa.Table( + 'model_config_api_key_association', metadata, + sa.Column('model_config_id', sa.UUID(), nullable=False), + sa.Column('api_key_id', sa.UUID(), nullable=False), + sa.Column('created_at', sa.DateTime(), nullable=True), + ) + + ModelApiKeyBackup = sa.Table( + BACKUP_TABLE_NAME, metadata, + sa.Column('id', sa.UUID(), primary_key=True), + sa.Column('model_name', sa.String(), nullable=False), + sa.Column('description', sa.String(), nullable=True), + sa.Column('provider', sa.String(), nullable=False), + sa.Column('api_key', sa.String(), nullable=False), + sa.Column('api_base', sa.String(), nullable=True), + sa.Column('config', sa.JSON(), nullable=True), + sa.Column('usage_count', sa.String(), default="0"), + sa.Column('last_used_at', sa.DateTime(), nullable=True), + sa.Column('priority', sa.String(), default="1"), + sa.Column('model_config_id', sa.UUID(), nullable=True), + sa.Column('created_at', sa.DateTime(), nullable=True), + sa.Column('updated_at', sa.DateTime(), nullable=True), + sa.Column('is_active', sa.Boolean(), default=True), + ) + + return ModelApiKey, ModelConfigApiKeyAssociation, ModelApiKeyBackup + + +def backup_model_api_keys(): + """备份model_api_keys表的结构和数据""" + connection = op.get_bind() + + # 检查备份表是否已存在 + result = connection.execute(sa.text(f""" + SELECT EXISTS ( + SELECT FROM information_schema.tables + WHERE table_name = '{BACKUP_TABLE_NAME}' + ); + """)).scalar() + + if result: + # 备份表已存在,先删除再重建(确保结构一致) + op.execute(f"DROP TABLE IF EXISTS {BACKUP_TABLE_NAME};") + + # 直接复制表结构和数据(PostgreSQL专用,一步完成) + op.execute(f""" + CREATE TABLE {BACKUP_TABLE_NAME} AS + SELECT * FROM model_api_keys; + """) + + # 统计行数 + backup_count = connection.execute(sa.text(f"SELECT COUNT(*) FROM {BACKUP_TABLE_NAME}")).scalar() + original_count = connection.execute(sa.text("SELECT COUNT(*) FROM model_api_keys")).scalar() + + print( + f"已备份model_api_keys表到 {BACKUP_TABLE_NAME} \n" + f" 原表数据行数:{original_count} | 备份表数据行数:{backup_count}" + ) + +# def restore_model_api_keys_from_backup(): +# """从备份表恢复model_api_keys数据(可选,用于回滚失败时手动恢复)""" +# # 1. 清空原表(谨慎使用!) +# # op.execute("TRUNCATE TABLE model_api_keys;") +# +# # 2. 从备份表恢复数据 +# op.execute(f""" +# INSERT INTO model_api_keys +# SELECT * FROM {BACKUP_TABLE_NAME} +# ON CONFLICT (id) DO UPDATE SET +# model_name = EXCLUDED.model_name, +# description = EXCLUDED.description, +# provider = EXCLUDED.provider, +# api_key = EXCLUDED.api_key, +# api_base = EXCLUDED.api_base, +# config = EXCLUDED.config, +# usage_count = EXCLUDED.usage_count, +# last_used_at = EXCLUDED.last_used_at, +# priority = EXCLUDED.priority, +# model_config_id = EXCLUDED.model_config_id, +# created_at = EXCLUDED.created_at, +# updated_at = EXCLUDED.updated_at, +# is_active = EXCLUDED.is_active; +# """) +# print(f"✅ 已从 {BACKUP_TABLE_NAME} 恢复model_api_keys表数据") + +def upgrade() -> None: + backup_model_api_keys() + # ### commands auto generated by Alembic - please adjust! ### + op.create_table('model_bases', + sa.Column('id', sa.UUID(), nullable=False), + sa.Column('logo', sa.String(length=255), nullable=True, comment='模型logo图片URL'), + sa.Column('name', sa.String(), nullable=False, comment='模型唯一标识(如gpt-3.5-turbo)'), + sa.Column('type', sa.String(), nullable=False, comment='模型类型'), + sa.Column('provider', sa.String(), nullable=False), + sa.Column('description', sa.Text(), nullable=True, comment='模型描述'), + sa.Column('is_deprecated', sa.Boolean(), nullable=False, comment='是否弃用'), + sa.Column('is_official', sa.Boolean(), nullable=True, comment='是否供应商官方模型(区分自定义)'), + sa.Column('tags', sa.ARRAY(sa.String()), nullable=False, comment="模型标签(如['聊天', '创作'])"), + sa.Column('add_count', sa.Integer(), nullable=False, comment='模型被用户添加的次数'), + sa.PrimaryKeyConstraint('id'), + sa.UniqueConstraint('name', 'provider', name='uk_model_name_provider') + ) + op.create_index(op.f('ix_model_bases_id'), 'model_bases', ['id'], unique=False) + op.create_index(op.f('ix_model_bases_provider'), 'model_bases', ['provider'], unique=False) + op.create_index(op.f('ix_model_bases_type'), 'model_bases', ['type'], unique=False) + op.create_table('model_config_api_key_association', + sa.Column('model_config_id', sa.UUID(), nullable=False), + sa.Column('api_key_id', sa.UUID(), nullable=False), + sa.Column('created_at', sa.DateTime(), nullable=True), + sa.ForeignKeyConstraint(['api_key_id'], ['model_api_keys.id'], ), + sa.ForeignKeyConstraint(['model_config_id'], ['model_configs.id'], ), + sa.PrimaryKeyConstraint('model_config_id', 'api_key_id') + ) + op.add_column('model_api_keys', sa.Column('description', sa.String(), nullable=True, comment='备注')) + op.add_column('model_configs', sa.Column('model_id', sa.UUID(), nullable=True, comment='基础模型ID')) + op.add_column('model_configs', sa.Column('logo', sa.String(length=255), nullable=True, comment='模型logo图片URL')) + op.add_column('model_configs', sa.Column('provider', sa.String(), server_default='composite', nullable=False, comment='供应商')) + op.add_column('model_configs', sa.Column('is_composite', sa.Boolean(), server_default='true', nullable=False, comment='是否为组合模型')) + op.add_column('model_configs', sa.Column('load_balance_strategy', sa.String(), nullable=True, comment='负载均衡策略')) + op.create_index(op.f('ix_model_configs_model_id'), 'model_configs', ['model_id'], unique=False) + op.create_foreign_key("model_configs_model_id_fkey", 'model_configs', 'model_bases', ['model_id'], ['id']) + connection = op.get_bind() + ModelApiKey, ModelConfigApiKeyAssociation, _ = get_temp_models() + + # 查询所有有model_config_id的API Key + api_keys = connection.execute( + sa.select(ModelApiKey.c.id, ModelApiKey.c.model_config_id) + .where(ModelApiKey.c.model_config_id.isnot(None)) + ).fetchall() + + # 批量插入到多对多表 + if api_keys: + association_data = [ + { + 'model_config_id': row.model_config_id, + 'api_key_id': row.id + } + for row in api_keys + ] + connection.execute(ModelConfigApiKeyAssociation.insert(), association_data) + op.drop_constraint(op.f('model_api_keys_model_config_id_fkey'), 'model_api_keys', type_='foreignkey') + op.drop_column('model_api_keys', 'model_config_id') + # ### end Alembic commands ### + + +def downgrade() -> None: + # ### commands auto generated by Alembic - please adjust! ### + op.drop_constraint("model_configs_model_id_fkey", 'model_configs', type_='foreignkey') + op.drop_index(op.f('ix_model_configs_model_id'), table_name='model_configs') + op.drop_column('model_configs', 'load_balance_strategy') + op.drop_column('model_configs', 'is_composite') + op.drop_column('model_configs', 'provider') + op.drop_column('model_configs', 'logo') + op.drop_column('model_configs', 'model_id') + op.add_column('model_api_keys', sa.Column('model_config_id', sa.UUID(), autoincrement=False, nullable=True, comment='模型配置ID')) + connection = op.get_bind() + ModelApiKey, ModelConfigApiKeyAssociation, _ = get_temp_models() + + # 查询多对多表中的关联数据(取每个API Key的第一个关联的model_config_id) + association_data = connection.execute( + sa.select( + ModelConfigApiKeyAssociation.c.api_key_id, + ModelConfigApiKeyAssociation.c.model_config_id + ).distinct(ModelConfigApiKeyAssociation.c.api_key_id) + ).fetchall() + + # 批量更新model_api_keys表 + if association_data: + for api_key_id, model_config_id in association_data: + connection.execute( + sa.update(ModelApiKey) + .where(ModelApiKey.c.id == api_key_id) + .values(model_config_id=model_config_id) + ) + + op.execute( + "UPDATE model_api_keys SET model_config_id = '00000000-0000-0000-0000-000000000000' WHERE model_config_id IS NULL") + op.alter_column('model_api_keys', 'model_config_id', nullable=False) + op.create_foreign_key(op.f('model_api_keys_model_config_id_fkey'), 'model_api_keys', 'model_configs', ['model_config_id'], ['id']) + op.drop_column('model_api_keys', 'description') + op.drop_table('model_config_api_key_association') + # ### 可选:回滚时恢复备份(如需)### + # restore_model_api_keys_from_backup() + + print( + f"回滚完成!备份表 {BACKUP_TABLE_NAME} 仍保留,如需手动恢复可执行 restore_model_api_keys_from_backup() 函数") + op.drop_index(op.f('ix_model_bases_type'), table_name='model_bases') + op.drop_index(op.f('ix_model_bases_provider'), table_name='model_bases') + op.drop_index(op.f('ix_model_bases_id'), table_name='model_bases') + op.drop_table('model_bases') + # ### end Alembic commands ### diff --git a/web/src/api/application.ts b/web/src/api/application.ts index 69d27d44..1f20282e 100644 --- a/web/src/api/application.ts +++ b/web/src/api/application.ts @@ -108,4 +108,8 @@ export const getShareToken = (share_token: string, user_id: string) => { // 复制应用 export const copyApplication = (app_id: string, new_name: string) => { return request.post(`/apps/${app_id}/copy?new_name=${new_name}`) -} \ No newline at end of file +} +// 数据统计 +export const getAppStatistics = (app_id: string, data: { start_date: number; end_date: number; }) => { + return request.get(`/apps/${app_id}/statistics`, data) +} diff --git a/web/src/api/fileStorage.ts b/web/src/api/fileStorage.ts new file mode 100644 index 00000000..e7b476a3 --- /dev/null +++ b/web/src/api/fileStorage.ts @@ -0,0 +1,25 @@ +import { request, API_PREFIX } from '@/utils/request' + +// Upload file,file storage has expiration period +export const fileUploadUrl = `${API_PREFIX}/storage/files` +export const fileUpload = (formData?: unknown) => { + return request.uploadFile('/storage/files', formData) +} + +// Get file access URL (no token required) +export const getFileUrl = (file_id: string) => `/storage/files/${file_id}/url` +export const getFileLink = (fileId: string, data: { permanent?: boolean } = { permanent: true }) => { + return request.get(getFileUrl(fileId), data) +} + +// Get file internally +export const getInternalFileUrl = (file_id: string) => `/storage/files/${file_id}` +export const getInternalFile = (fileId: string) => { + return request.get(getInternalFileUrl(fileId)) +} + +// Delete file +export const deleteFileUrl = (file_id: string) => `/storage/files/${file_id}` +export const deleteFile = (fileId: string) => { + return request.delete(deleteFileUrl(fileId)) +} diff --git a/web/src/api/knowledgeBase.ts b/web/src/api/knowledgeBase.ts index 5f171a72..38a0d40d 100644 --- a/web/src/api/knowledgeBase.ts +++ b/web/src/api/knowledgeBase.ts @@ -65,7 +65,7 @@ export const getModelTypeList = async () => { }; // 获取模型列表 export const getModelList = async (pageInfo: PageRequest) => { - const response = await request.get(`${apiPrefix}/models`, pageInfo); + const response = await request.get(`${apiPrefix}/models`, { ...pageInfo, is_active: true }); return response as any; }; //获取模型提供者 diff --git a/web/src/api/models.ts b/web/src/api/models.ts index 20fdf91a..e5d0f339 100644 --- a/web/src/api/models.ts +++ b/web/src/api/models.ts @@ -1,23 +1,68 @@ import { request } from '@/utils/request' -import type { ModelFormData } from '@/views/ModelManagement/types' +import type { MultiKeyForm, Query, KeyConfigModalForm, CompositeModelForm, CustomModelForm } from '@/views/ModelManagement/types' -// 模型列表 +// Model list export const getModelListUrl = '/models' -export const getModelList = (data: { type: string; pagesize: number; page: number; }) => { +export const getModelList = (data: Query) => { return request.get(getModelListUrl, data) } -// 创建模型 -export const addModel = (data: ModelFormData) => { - return request.post('/models', data) -} -// 更新模型 -export const updateModel = (apiKeyId: string, data: ModelFormData) => { - return request.put(`/models/apikeys/${apiKeyId}`, data) -} -// 模型类型列表 +// Model type list export const modelTypeUrl = '/models/type' -// 模型供应商列表 +// Model provider list export const modelProviderUrl = '/models/provider' export const getModelProviderList = () => { return request.get(modelProviderUrl) +} +// New model list +export const getModelNewListUrl = '/models/new' +export const getModelNewList = (data: Query) => { + return request.get(getModelNewListUrl, data) +} +// Get model information +export const getModelInfo = (model_id: string) => { + return request.get(`/models/${model_id}`) +} +// Create composite model +export const addCompositeModel = (data: CompositeModelForm) => { + return request.post('/models/composite', data) +} +// Update composite model +export const updateCompositeModel = (model_id: string, data: CompositeModelForm) => { + return request.put(`/models/composite/${model_id}`, data) +} +// Delete composite model +export const deleteCompositeModel = (model_id: string) => { + return request.delete(`/models/composite/${model_id}`) +} +// Create API keys for all matching models by provider +export const updateProviderApiKeys = (data: KeyConfigModalForm) => { + return request.post('/models/provider/apikeys', data) +} +// Create model API key +export const addModelApiKey = (model_id: string, data: MultiKeyForm) => { + return request.post(`/models/${model_id}/apikeys`, data) +} +// Delete model API key +export const deleteModelApiKey = (api_key_id: string) => { + return request.delete(`/models/apikeys/${api_key_id}`) +} +// Update model status +export const updateModelStatus = (model_id: string, data: { is_active: boolean; }) => { + return request.put(`/models/${model_id}`, data) +} +// Model plaza list +export const getModelPlaza = (data: { search?: string; provider?: string; }) => { + return request.get('/models/model_plaza', data) +} +// Add model to plaza +export const addModelPlaza = (model_base_id: string) => { + return request.post(`/models/model_plaza/${model_base_id}/add`) +} +// Create custom model +export const addCustomModel = (data: CustomModelForm) => { + return request.post('/models/model_plaza', data) +} +// Update custom model +export const updateCustomModel = (model_base_id: string, data: CustomModelForm) => { + return request.put(`/models/model_plaza/${model_base_id}`, data) } \ No newline at end of file diff --git a/web/src/assets/images/empty/pageEmpty.png b/web/src/assets/images/empty/pageEmpty.png new file mode 100644 index 00000000..f78cc42d Binary files /dev/null and b/web/src/assets/images/empty/pageEmpty.png differ diff --git a/web/src/components/CustomSelect/index.tsx b/web/src/components/CustomSelect/index.tsx index 6153a76d..1887d635 100644 --- a/web/src/components/CustomSelect/index.tsx +++ b/web/src/components/CustomSelect/index.tsx @@ -15,7 +15,7 @@ interface ApiResponse { interface CustomSelectProps extends Omit { url: string; params?: Record; - valueKey?: string | string[]; + valueKey?: string; labelKey?: string; placeholder?: string; hasAll?: boolean; @@ -66,18 +66,11 @@ const CustomSelect: FC = ({ {...props} > {hasAll && {allTitle || t('common.all')}} - {displayOptions.map((option) => { - const getValue = () => { - if (typeof valueKey === 'string') return option[valueKey]; - return valueKey.find(key => option[key] != null) ? option[valueKey.find(key => option[key] != null)!] : undefined; - }; - const value = getValue(); - return ( - - {String(option[labelKey])} - - ); - })} + {displayOptions.map((option) => ( + + {String(option[labelKey])} + + ))} ); }; diff --git a/web/src/components/Empty/PageEmpty.tsx b/web/src/components/Empty/PageEmpty.tsx new file mode 100644 index 00000000..17926fde --- /dev/null +++ b/web/src/components/Empty/PageEmpty.tsx @@ -0,0 +1,16 @@ +import { useTranslation } from 'react-i18next' +import pageEmptyIcon from '@/assets/images/empty/pageEmpty.png' +import Empty from './index' +const PageEmpty = ({ size = [240, 210] }: { size?: number | number[] }) => { + const { t } = useTranslation() + return ( + + ) +} +export default PageEmpty; \ No newline at end of file diff --git a/web/src/components/PageTabs/index.module.css b/web/src/components/PageTabs/index.module.css new file mode 100644 index 00000000..6eab8a48 --- /dev/null +++ b/web/src/components/PageTabs/index.module.css @@ -0,0 +1,13 @@ +.page-tabs:global(.ant-segmented) { + background-color: rgba(91, 97, 103, 0.08); + padding: 4px; +} +.page-tabs:global(.ant-segmented .ant-segmented-item-label) { + line-height: 24px; + min-height: 24px; + padding: 0 12px; +} + +.page-tabs:global(.ant-segmented .ant-segmented-item-selected) { + box-shadow: 0px 2px 4px 0px rgba(33, 35, 50, 0.16); +} \ No newline at end of file diff --git a/web/src/components/PageTabs/index.tsx b/web/src/components/PageTabs/index.tsx new file mode 100644 index 00000000..33f02097 --- /dev/null +++ b/web/src/components/PageTabs/index.tsx @@ -0,0 +1,18 @@ +import { type FC } from 'react'; +import { Segmented, type SegmentedProps } from 'antd'; +import styles from './index.module.css'; + +const PageTabs: FC = ({ + value, + options, + onChange +}) => { + return ; +}; + +export default PageTabs; diff --git a/web/src/components/RbCard/Card.tsx b/web/src/components/RbCard/Card.tsx index f86b1c60..eadd2916 100644 --- a/web/src/components/RbCard/Card.tsx +++ b/web/src/components/RbCard/Card.tsx @@ -1,5 +1,5 @@ import { type FC, type ReactNode } from 'react' -import { Card } from 'antd'; +import { Card, Tooltip } from 'antd'; import clsx from 'clsx'; interface RbCardProps { @@ -9,7 +9,7 @@ interface RbCardProps { extra?: ReactNode; children?: ReactNode; avatar?: ReactNode; - avatarUrl?: string; + avatarUrl?: string | null; bodyPadding?: string; bodyClassName?: string; headerType?: 'border' | 'borderless' | 'borderBL' | 'borderL'; @@ -63,7 +63,7 @@ const RbCard: FC = ({ } ) }> -
{title}
+
{title}
{subTitle &&
{subTitle}
} : null diff --git a/web/src/components/Upload/UploadImages.tsx b/web/src/components/Upload/UploadImages.tsx index 2006ea09..0875707a 100644 --- a/web/src/components/Upload/UploadImages.tsx +++ b/web/src/components/Upload/UploadImages.tsx @@ -1,23 +1,23 @@ import { useState, useEffect, forwardRef, useImperativeHandle } from 'react'; -import { Upload, Modal, Image, App } from 'antd'; +import { Upload, Image, App } from 'antd'; import type { GetProp, UploadFile, UploadProps } from 'antd'; // import { UploadOutlined, } from '@ant-design/icons'; import type { UploadProps as RcUploadProps } from 'antd/es/upload/interface'; import { useTranslation } from 'react-i18next'; import PlusIcon from '@/assets/images/plus.svg' import { cookieUtils } from '@/utils/request' +import { fileUploadUrl } from '@/api/fileStorage' +import styles from './index.module.less' -const { confirm } = Modal; - -interface UploadImagesProps extends Omit { +interface UploadImagesProps extends Omit { /** 上传接口地址 */ action?: string; /** 是否支持多选 */ multiple?: boolean; /** 已上传的文件列表 */ - fileList?: UploadFile[]; + fileList?: UploadFile[] | UploadFile; /** 文件列表变化回调 */ - onChange?: (fileList: UploadFile[]) => void; + onChange?: (fileList?: UploadFile[] | UploadFile) => void; /** 禁用上传 */ disabled?: boolean; /** 文件大小限制(MB) */ @@ -28,6 +28,7 @@ interface UploadImagesProps extends Omit { isAutoUpload?: boolean; /** 最大上传文件数 */ maxCount?: number; + className?: string; } const ALL_FILE_TYPE: { [key: string]: string; @@ -59,7 +60,7 @@ const getBase64 = (file: FileType): Promise => { * 支持单文件/多文件上传、拖拽上传、文件验证、预览等功能 */ const UploadImages = forwardRef(({ - action = '/api/upload', + action = fileUploadUrl, multiple = false, fileList: propFileList = [], onChange, @@ -68,27 +69,42 @@ const UploadImages = forwardRef(({ fileType = ['png', 'jpg', 'gif'], isAutoUpload = true, maxCount = 1, + className = 'rb:size-24! rb:leading-1!', ...props }, ref) => { const { t } = useTranslation(); - const { message } = App.useApp() - const [fileList, setFileList] = useState(propFileList); + const { message, modal } = App.useApp() + const [fileList, setFileList] = useState([]); const [accept, setAccept] = useState(); // const [loading, setLoading] = useState(false); const [previewOpen, setPreviewOpen] = useState(false); const [previewImage, setPreviewImage] = useState(''); + useEffect(() => { + if (!Array.isArray(propFileList) && typeof propFileList === 'object') { + setFileList([propFileList]); + } + }, [propFileList]) + + const updateValue = (list: UploadFile[]) => { + if (maxCount === 1) { + onChange?.(list[0]) + } else { + onChange?.(list) + } + } + // 处理文件移除 const handleRemove = (file: UploadFile) => { - confirm({ - title: '确定要删除此文件吗?', - okText: '确定', + modal.confirm({ + title: t('common.confirmRemoveFile'), + okText: `${t('common.confirm')}`, okType: 'danger', - cancelText: '取消', + cancelText: `${t('common.cancel')}`, onOk: () => { const newFileList = fileList.filter((item) => item.uid !== file.uid); setFileList(newFileList); - onChange?.(newFileList); + updateValue(newFileList) }, }); return false; // 阻止默认删除行为,由confirm控制 @@ -100,7 +116,7 @@ const UploadImages = forwardRef(({ if (fileSize && file.size) { const isLtMaxSize = (file.size / 1024 / 1024) < fileSize; if (!isLtMaxSize) { - message.error(`文件大小不能超过 ${fileSize}MB`); + message.error(t('common.fileSizeTip', { size: fileSize })); return Upload.LIST_IGNORE; } } @@ -108,7 +124,7 @@ const UploadImages = forwardRef(({ if (accept && accept.length > 0 && file.type) { const isAccept = accept.includes(file.type); if (!isAccept) { - message.error(`不支持的文件类型: ${file.type}`); + message.error(`${t('common.fileAcceptTip')}${file.type}`); return Upload.LIST_IGNORE; } } @@ -119,7 +135,7 @@ const UploadImages = forwardRef(({ } const newFileList = [...fileList, file]; setFileList(newFileList); - onChange?.(newFileList); + updateValue(newFileList); return Upload.LIST_IGNORE; // 阻止自动上传 } @@ -129,17 +145,13 @@ const UploadImages = forwardRef(({ // 处理上传状态变化 const handleChange: UploadProps['onChange'] = ({ fileList: newFileList }) => { setFileList(newFileList); - if (onChange) { - onChange(newFileList); - } + updateValue(newFileList); }; // 清空已上传文件 const clearFiles = () => { setFileList([]); - if (onChange) { - onChange([]); - } + updateValue([]); } const handlePreview = async (file: UploadFile) => { @@ -167,7 +179,7 @@ const UploadImages = forwardRef(({ fileList, beforeUpload, headers: { - authorization: cookieUtils.get('authToken') || '', + authorization: `Bearer ${cookieUtils.get('authToken') }`, }, onPreview: handlePreview, onRemove: handleRemove, @@ -180,6 +192,7 @@ const UploadImages = forwardRef(({ showRemoveIcon: true, showDownloadIcon: false, }, + className: `${styles.imageUpload} ${className}`, ...props, }; @@ -193,16 +206,9 @@ const UploadImages = forwardRef(({ <> {fileList.length < maxCount && ( -
- -
{t('common.clickUploadIcon')}
-
+ )}
{previewImage && ( diff --git a/web/src/components/Upload/index.module.less b/web/src/components/Upload/index.module.less new file mode 100644 index 00000000..a263d743 --- /dev/null +++ b/web/src/components/Upload/index.module.less @@ -0,0 +1,7 @@ +.image-upload:global(.ant-upload-wrapper.ant-upload-picture-card-wrapper .ant-upload-list.ant-upload-list-picture-card .ant-upload-list-item-container), +.image-upload:global(.ant-upload-wrapper.ant-upload-picture-circle-wrapper .ant-upload-list.ant-upload-list-picture-card .ant-upload-list-item-container), +.image-upload:global(.ant-upload-wrapper.ant-upload-picture-card-wrapper .ant-upload-list.ant-upload-list-picture-circle .ant-upload-list-item-container), +.image-upload:global(.ant-upload-wrapper.ant-upload-picture-circle-wrapper .ant-upload-list.ant-upload-list-picture-circle .ant-upload-list-item-container) { + width: 96px; + height: 96px; +} \ No newline at end of file diff --git a/web/src/i18n/en.ts b/web/src/i18n/en.ts index 60c06acf..ea45ea6d 100644 --- a/web/src/i18n/en.ts +++ b/web/src/i18n/en.ts @@ -419,6 +419,9 @@ export const en = { statusEnabled: 'Available', statusDisabled: 'Unavailable', remove: 'Remove', + + fileSizeTip: 'File size cannot exceed {{size}}MB', + fileAcceptTip: 'Unsupported file type:' }, model: { searchPlaceholder: 'search model…', @@ -510,6 +513,59 @@ export const en = { gpustack: "Gpustack", bedrock: "Bedrock" }, + modelNew: { + group: 'Model Group', + list: 'Model List', + square: 'Model Plaza', + createGroupModel: 'Create Model Group', + groupSearchPlaceholder: 'Search model groups', + listSearchPlaceholder: 'Search available models', + squareSearchPlaceholder: 'Search platform models', + status: 'Model Status', + created_at: 'Created At', + configureBtn: 'Click to Configure', + showModel: 'Show Model', + keyConfig: 'Configure KEY', + + modelConfiguration: 'Model Configuration', + logo: 'Model LOGO', + name: 'Model Name', + type: 'Model Type', + modelImplement: 'Model Implementation', + addImplement: 'Add Implementation', + noAuth: 'Unauthorized (Limited to 1 implementation)', + implementConfig: 'Configure Model Implementation', + provider: 'Model Provider', + api_key_ids: 'Select Model', + viewAll: 'More', + modelCount: 'Total {{count}} models', + modelList: 'Model List', + added: ' Added', + addSuccess: 'Added successfully', + model_name: 'Model Name', + tags: 'Tags', + createCustomModel: 'Add Custom Model', + edit: 'Edit', + selectOneTip: 'Model API KEY not configured, please configure in Model Plaza first', + + api_key: 'API KEY', + api_base: 'API Base URL', + description: 'Description', + add: 'Add', + item: 'item', + apiKeyNum: ' API Keys', + + llm: 'LLM', + chat: 'Chat', + embedding: 'Embedding', + rerank: 'Rerank', + openai: "Openai", + dashscope: "Dashscope", + ollama: "Ollama", + xinference: "Xinference", + gpustack: "Gpustack", + bedrock: "Bedrock" + }, knowledgeBase: { pleaseUploadFileFirst: 'Please upload file first', shareSuccess: 'Share successfully', @@ -1175,6 +1231,12 @@ export const en = { priority: 'Structured Integration', addTool: 'Add Tool', tool: 'Tool', + + statistics: 'Data Statistics', + daily_conversations: 'Daily Conversations', + daily_new_users: 'Daily New Users', + daily_api_calls: 'Daily API Calls', + daily_tokens: 'Token Consumption', }, userMemory: { userMemory: 'User Memory', @@ -1534,7 +1596,9 @@ Memory Bear: After the rebellion, regional warlordism intensified for several re noPermissionDesc: ' Please contact the administrator to grant permission', tableEmpty: 'No data available.', loadingEmpty: 'The content is loading…', - loadingEmptyDesc: 'Your content is on its way by rocket! It will soon land on your screen' + loadingEmptyDesc: 'Your content is on its way by rocket! It will soon land on your screen', + pageEmpty: 'Oops! No search results available at the moment', + pageEmptyDesc: "Red Bear tilts its head and waits for you to change a new keyword, let's explore together.", }, apiKey: { name: 'Project Name', diff --git a/web/src/i18n/zh.ts b/web/src/i18n/zh.ts index 76a95da4..0e5c9288 100644 --- a/web/src/i18n/zh.ts +++ b/web/src/i18n/zh.ts @@ -658,7 +658,13 @@ export const zh = { priority: '结构化整合', addTool: '添加工具', tool: '工具', - variableConfig: '配置变量' + variableConfig: '配置变量', + + statistics: '数据统计', + daily_conversations: '消息会话数', + daily_new_users: '新增用户数', + daily_api_calls: '调用次数', + daily_tokens: 'Token消耗', }, role: { roleManagement: '角色管理', @@ -967,6 +973,9 @@ export const zh = { statusEnabled: '可用', statusDisabled: '不可用', remove: '删除', + + fileSizeTip: '文件大小不能超过 {{size}}MB', + fileAcceptTip: '不支持的文件类型:' }, product: { applicationManagement: '应用管理', @@ -1076,6 +1085,59 @@ export const zh = { gpustack: "Gpustack", bedrock: "Bedrock" }, + modelNew: { + group: '模型组合', + list: '模型列表', + square: '模型广场', + createGroupModel: '创建模型组合', + groupSearchPlaceholder: '搜索模型组合', + listSearchPlaceholder: '搜索可用模型', + squareSearchPlaceholder: '搜索平台模型', + status: '模型状态', + created_at: '创建时间', + configureBtn: '点击配置', + showModel: '显示模型', + keyConfig: '配置 KEY', + + modelConfiguration: '模型配置', + logo: '模型LOGO', + name: '模型名称', + type: '模型类型', + modelImplement: '模型实现', + addImplement: '添加实现', + noAuth: '未授权(限1个实现)', + implementConfig: '配置模型实现', + provider: '模型供应商', + api_key_ids: '选择模型', + viewAll: '更多', + modelCount: '共 {{count}} 个模型', + modelList: '模型列表', + added: ' 已添加', + addSuccess: '添加成功', + model_name: '模型名称', + tags: '标签', + createCustomModel: '添加自定义模型', + edit: '编辑', + selectOneTip: '模型未配置API KEY,请先在模型广场配置', + + api_key: 'API KEY', + api_base: 'API Base URL', + description: '描述', + add: '添加', + item: '个', + apiKeyNum: '个 API Key', + + llm: 'LLM', + chat: 'Chat', + embedding: 'Embedding', + rerank: 'Rerank', + openai: "Openai", + dashscope: "Dashscope", + ollama: "Ollama", + xinference: "Xinference", + gpustack: "Gpustack", + bedrock: "Bedrock" + }, timezones: { 'Asia/Shanghai': '中国标准时间 (UTC+8)', 'Asia/Kolkata': '印度标准时间 (UTC+5:30)', @@ -1607,7 +1669,9 @@ export const zh = { noPermissionDesc: '请联系管理员授予权限', tableEmpty: '目前没有数据', loadingEmpty: '内容正在加载中…', - loadingEmptyDesc: '您的内容正在火箭运输中!很快就会降落在您的屏幕上' + loadingEmptyDesc: '您的内容正在火箭运输中!很快就会降落在您的屏幕上', + pageEmpty: '哎呀!暂无搜索结果', + pageEmptyDesc: '红熊歪着头等待您更换新的关键词,让我们一起探索吧。', }, home: { diff --git a/web/src/styles/antdThemeConfig.ts b/web/src/styles/antdThemeConfig.ts index db1166fb..1d281730 100644 --- a/web/src/styles/antdThemeConfig.ts +++ b/web/src/styles/antdThemeConfig.ts @@ -22,7 +22,7 @@ export const lightTheme: ThemeConfig = { // colorBgContainer: '#FBFDFF', colorError: '#FF5D34', sizeSM: 12, - fontSizeSM: 12, + fontSizeSM: 12, }, components: { Layout: { @@ -105,6 +105,9 @@ export const lightTheme: ThemeConfig = { }, Select: { lineHeightSM: 26 + }, + Upload: { + pictureCardSize: 96, } } }; \ No newline at end of file diff --git a/web/src/utils/request.ts b/web/src/utils/request.ts index 479fc1f3..e7112ded 100644 --- a/web/src/utils/request.ts +++ b/web/src/utils/request.ts @@ -23,9 +23,10 @@ interface data { } +export const API_PREFIX = '/api' // 创建axios实例 const service = axios.create({ - baseURL: '/api', // 与vite.config.ts中的代理配置对应 + baseURL: API_PREFIX, // 与vite.config.ts中的代理配置对应 // timeout: 10000, // 请求超时时间 withCredentials: false, headers: { @@ -126,7 +127,7 @@ service.interceptors.response.use( if (axios.isCancel(error) || error.name === 'AbortError' || error.code === 'ERR_CANCELED') { return Promise.reject(error); } - + // 处理网络错误、超时等 let msg = error.response?.data?.error || error.response?.error; const status = error?.response ? error.response.status : error; diff --git a/web/src/views/ApplicationConfig/Agent.tsx b/web/src/views/ApplicationConfig/Agent.tsx index 97a622d1..8898897a 100644 --- a/web/src/views/ApplicationConfig/Agent.tsx +++ b/web/src/views/ApplicationConfig/Agent.tsx @@ -20,7 +20,7 @@ import type { } from './types' import type { Variable } from './components/VariableList/types' import type { KnowledgeConfig } from './components/Knowledge/types' -import type { Model } from '@/views/ModelManagement/types' +import type { ModelListItem } from '@/views/ModelManagement/types' import { getModelList } from '@/api/models'; import { saveAgentConfig } from '@/api/application' import Knowledge from './components/Knowledge/Knowledge' @@ -79,7 +79,7 @@ const SelectWrapper: FC<{ title: string, desc: string, name: string | string[], placeholder={t('common.pleaseSelect')} url={url} hasAll={false} - valueKey={['config_id_old', 'config_id']} + valueKey='config_id' labelKey="config_name" /> @@ -96,8 +96,8 @@ const Agent = forwardRef((_props, ref) => { const [loading, setLoading] = useState(false) const [data, setData] = useState(null); const modelConfigModalRef = useRef(null) - const [modelList, setModelList] = useState([]) - const [defaultModel, setDefaultModel] = useState(null) + const [modelList, setModelList] = useState([]) + const [defaultModel, setDefaultModel] = useState(null) const [chatList, setChatList] = useState([]) const values = Form.useWatch([], form) const [isSave, setIsSave] = useState(false) @@ -126,14 +126,12 @@ const Agent = forwardRef((_props, ref) => { getApplicationConfig(id as string).then(res => { const response = res as Config let allTools = Array.isArray(response.tools) ? response.tools : [] - const memoryContent = response.memory?.memory_content - const convertedMemoryContent = memoryContent && !isNaN(Number(memoryContent)) ? Number(memoryContent) : memoryContent form.setFieldsValue({ ...response, tools: allTools, memory: { ...response.memory, - memory_content: convertedMemoryContent + memory_content: response.memory?.memory_content ? Number(response.memory?.memory_content) : undefined } }) setData({ @@ -214,7 +212,7 @@ const Agent = forwardRef((_props, ref) => { ...data.knowledge_retrieval, ...knowledgeRest, knowledge_bases: knowledge_bases.map(item => ({ - kb_id: item.id, + kb_id: item.kb_id || item.id, ...(item.config || {}) })) } as KnowledgeConfig : null, @@ -239,9 +237,9 @@ const Agent = forwardRef((_props, ref) => { }) } const getModels = () => { - getModelList({ type: 'llm,chat', pagesize: 100, page: 1 }) + getModelList({ type: 'llm,chat', pagesize: 100, page: 1, is_active: true }) .then(res => { - const response = res as { items: Model[] } + const response = res as { items: ModelListItem[] } setModelList(response.items) }) } @@ -251,7 +249,7 @@ const Agent = forwardRef((_props, ref) => { useEffect(() => { if (values?.default_model_config_id && modelList.length > 0) { const filterValue = modelList.find(item => item.id === values.default_model_config_id) - setDefaultModel(filterValue as Model | null) + setDefaultModel(filterValue as ModelListItem | null) setChatList([{ label: filterValue?.name || '', model_config_id: filterValue?.id || '', diff --git a/web/src/views/ApplicationConfig/Cluster.tsx b/web/src/views/ApplicationConfig/Cluster.tsx index 3081aa04..aa4a5d98 100644 --- a/web/src/views/ApplicationConfig/Cluster.tsx +++ b/web/src/views/ApplicationConfig/Cluster.tsx @@ -225,7 +225,7 @@ const Cluster = forwardRef((_props, ref) => { = { + daily_conversations: 'total_conversations', + daily_new_users: 'total_new_users', + daily_api_calls: 'total_api_calls', + daily_tokens: 'total_tokens', +} +const Statistics: FC<{ application: Application | null }> = ({ application }) => { + const [data, setData] = useState({ + daily_conversations: [], + total_conversations: 0, + daily_new_users: [], + total_new_users: 0, + daily_api_calls: [], + total_api_calls: 0, + daily_tokens: [], + total_tokens: 0 + }) + const [query, setQuery] = useState({ + start_date: dayjs().subtract(6, 'd'), + end_date: dayjs().subtract(0, 'd'), + }) + + useEffect(() => { + getData() + }, [application, query]) + const getData = () => { + if (!application?.id) { + return + } + const params = { + start_date: query.start_date.startOf('d').valueOf(), + end_date: query.end_date.endOf('d').valueOf(), + } + + getAppStatistics(application.id, params) + .then(res => { + setData(res as StatisticsData) + }) + } + const handleChange = (date: [Dayjs | null, Dayjs | null] | null) => { + if (!date || !date[0] || !date[1]) return + setQuery({ + start_date: date[0], + end_date: date[1], + }) + } + return ( +
+ + + + + + + {Object.entries(data).map(([key, value]) => { + if (key.includes('total')) { + return null + } + const totalKey = TotalObj[key]; + return ( + + + + ) + })} + +
+ ); +} +export default Statistics; \ No newline at end of file diff --git a/web/src/views/ApplicationConfig/components/AiPromptModal.tsx b/web/src/views/ApplicationConfig/components/AiPromptModal.tsx index b910e1b0..0c7bf480 100644 --- a/web/src/views/ApplicationConfig/components/AiPromptModal.tsx +++ b/web/src/views/ApplicationConfig/components/AiPromptModal.tsx @@ -181,7 +181,7 @@ const AiPromptModal = forwardRef(({ > = { edit: editIcon, copy: copyIcon, diff --git a/web/src/views/ApplicationConfig/components/Knowledge/KnowledgeGlobalConfigModal.tsx b/web/src/views/ApplicationConfig/components/Knowledge/KnowledgeGlobalConfigModal.tsx index 2f349487..e4204836 100644 --- a/web/src/views/ApplicationConfig/components/Knowledge/KnowledgeGlobalConfigModal.tsx +++ b/web/src/views/ApplicationConfig/components/Knowledge/KnowledgeGlobalConfigModal.tsx @@ -97,7 +97,7 @@ const KnowledgeGlobalConfigModal = forwardRef = { + daily_conversations: '#FFB048', + daily_new_users: '#4DA8FF', + daily_api_calls: '#155EEF', + daily_tokens: '#AD88FF' +} + +const LineCard: FC = ({ chartData, type, total }) => { + const { t } = useTranslation() + const chartRef = useRef(null); + + useEffect(() => { + + }, [chartData]) + + const getSeries = () => { + return [{ + ...SeriesConfig, + name: t(`application.${type}`), + data: chartData.map(vo => vo.count), + areaStyle: { + opacity: 0.8, + color: new echarts.graphic.LinearGradient(0, 0, 0, 1, [ + { offset: 0, color: ColorObj[type] }, + { offset: 1, color: '#FFFFFF' } + ]) + }, + }] + } + + return ( + {t(`application.${type}`)} {total}} + > + {chartData && chartData.length > 0 ? ( + item.date), + boundaryGap: false, + }, + yAxis: { + type: 'value', + axisLabel: { + color: '#A8A9AA', + fontFamily: 'PingFangSC, PingFang SC', + align: 'right', + lineHeight: 17, + }, + axisLine: { + lineStyle: { + color: '#EBEBEB', + } + }, + }, + series: getSeries() + }} + style={{ height: '265px', width: '100%', minWidth: '100%', boxSizing: 'border-box' }} + opts={{ renderer: 'canvas' }} + notMerge={true} + lazyUpdate={true} + /> + ) : } + + ) +} + +export default LineCard diff --git a/web/src/views/ApplicationConfig/index.tsx b/web/src/views/ApplicationConfig/index.tsx index 7d5d5950..4dd9231a 100644 --- a/web/src/views/ApplicationConfig/index.tsx +++ b/web/src/views/ApplicationConfig/index.tsx @@ -9,6 +9,7 @@ import ReleasePage from './ReleasePage' import Cluster from './Cluster' import { getApplication } from '@/api/application' import Workflow from '@/views/Workflow'; +import Statistics from './Statistics' const ApplicationConfig: React.FC = () => { const { id } = useParams(); @@ -68,6 +69,7 @@ const ApplicationConfig: React.FC = () => { {activeTab === 'arrangement' && application?.type === 'workflow' && } {activeTab === 'api' && } {activeTab === 'release' && } + {activeTab === 'statistics' && } ); }; diff --git a/web/src/views/ApplicationConfig/types.ts b/web/src/views/ApplicationConfig/types.ts index 6f641ebb..9df6e04a 100644 --- a/web/src/views/ApplicationConfig/types.ts +++ b/web/src/views/ApplicationConfig/types.ts @@ -150,4 +150,19 @@ export interface AiPromptForm { } export interface ChatVariableConfigModalRef { handleOpen: (values: Variable[]) => void; +} + +export interface StatisticsItem { + count: number; + date: string; +} +export interface StatisticsData { + daily_conversations: StatisticsItem[]; + daily_new_users: StatisticsItem[]; + daily_api_calls: StatisticsItem[]; + daily_tokens: StatisticsItem[]; + total_conversations: number; + total_new_users: number; + total_api_calls: number; + total_tokens: number; } \ No newline at end of file diff --git a/web/src/views/EmotionEngine/index.tsx b/web/src/views/EmotionEngine/index.tsx index 73bfd376..6528bbbe 100644 --- a/web/src/views/EmotionEngine/index.tsx +++ b/web/src/views/EmotionEngine/index.tsx @@ -20,7 +20,7 @@ const configList = [ key: 'emotion_model_id', type: 'customSelect', url: getModelListUrl, - params: { type: 'chat,llm', page: 1, pagesize: 100 }, // chat,llm + params: { type: 'chat,llm', page: 1, pagesize: 100, is_active: true }, // chat,llm }, { key: 'emotion_min_intensity', diff --git a/web/src/views/MemberManagement/index.tsx b/web/src/views/MemberManagement/index.tsx index 8ce2fc62..68c90410 100644 --- a/web/src/views/MemberManagement/index.tsx +++ b/web/src/views/MemberManagement/index.tsx @@ -39,7 +39,7 @@ const MemberManagement: React.FC = () => { onOk: () => { deleteMember(member.id) .then(() => { - message.success(t('member.deleteSuccess')); + message.success(t('common.deleteSuccess')); refreshTable(); }) } @@ -93,7 +93,7 @@ const MemberManagement: React.FC = () => { return ( <> -
+
diff --git a/web/src/views/MemoryExtractionEngine/index.tsx b/web/src/views/MemoryExtractionEngine/index.tsx index 3d67270c..96138a55 100644 --- a/web/src/views/MemoryExtractionEngine/index.tsx +++ b/web/src/views/MemoryExtractionEngine/index.tsx @@ -1,14 +1,14 @@ import { type FC, useState, useEffect } from 'react' import { useTranslation } from 'react-i18next' import { useParams } from 'react-router-dom' -import { Row, Col, Space, Switch, Select, InputNumber, Slider, App, Form } from 'antd' +import { Row, Col, Space, Select, InputNumber, Slider, App, Form } from 'antd' import clsx from 'clsx' import Card from './components/Card' import type { ConfigForm, Variable } from './types' import { getMemoryExtractionConfig, updateMemoryExtractionConfig } from '@/api/memory' import Markdown from '@/components/Markdown' import { getModelList } from '@/api/models'; -import type { Model } from '@/views/ModelManagement/types' +import type { ModelListItem } from '@/views/ModelManagement/types' import { configList } from './constant' import Result from './components/Result' import SwitchFormItem from '@/components/FormItem/SwitchFormItem' @@ -43,7 +43,7 @@ const MemoryExtractionEngine: FC = () => { const values = Form.useWatch([], form) const [loading, setLoading] = useState(false) const [iterationPeriodDisabled, setIterationPeriodDisabled] = useState(false) - const [modelList, setModelList] = useState([]) + const [modelList, setModelList] = useState([]) useEffect(() => { if (values?.reflexion_range === 'database') { @@ -55,9 +55,9 @@ const MemoryExtractionEngine: FC = () => { }, [values]) const getModels = () => { - getModelList({ type: 'llm,chat', pagesize: 100, page: 1 }) + getModelList({ type: 'llm,chat', pagesize: 100, page: 1, is_active: true }) .then(res => { - const response = res as { items: Model[] } + const response = res as { items: ModelListItem[] } setModelList(response.items) }) } diff --git a/web/src/views/ModelManagement/Group.tsx b/web/src/views/ModelManagement/Group.tsx new file mode 100644 index 00000000..311455b4 --- /dev/null +++ b/web/src/views/ModelManagement/Group.tsx @@ -0,0 +1,97 @@ +import { useState, useEffect, forwardRef, useImperativeHandle } from 'react'; +import clsx from 'clsx' +import { Button } from 'antd' +import { useTranslation } from 'react-i18next'; + +import type { ProviderModelItem, ModelListItem, DescriptionItem, BaseRef } from './types' +import RbCard from '@/components/RbCard/Card' +import { getModelNewList } from '@/api/models' +import PageEmpty from '@/components/Empty/PageEmpty'; +import { formatDateTime } from '@/utils/format'; + +const Group = forwardRef void; }>(({ query, handleEdit }, ref) => { + const { t } = useTranslation(); + const [list, setList] = useState([]) + useEffect(() => { + getList() + }, [query]) + const getList = () => { + getModelNewList({ + ...query, + is_composite: true, + is_active: true, + }) + .then(res => { + const response = res as ProviderModelItem[] + setList(response[0]?.models || []) + }) + } + const formatData = (data: ModelListItem) => { + return [ + { + key: 'type', + label: t(`modelNew.type`), + children: data.type || '-', + }, + { + key: 'provider', + label: t(`modelNew.provider`), + children: data.provider || '-', + }, + { + key: 'is_active', + label: t(`modelNew.status`), + children: data.is_active ? t(`common.statusEnabled`) : t(`common.statusDisabled`), + }, + { + key: 'created_at', + label: t(`modelNew.created_at`), + children: data.created_at ? formatDateTime(data.created_at, 'YYYY-MM-DD HH:mm:ss') : '-', + }, + ] + } + + useImperativeHandle(ref, () => ({ + getList, + })); + + return ( + <> + {list.length === 0 + ? + :( +
+ {list.map(item => ( + + {item.name[0]} +
+ } + > + {formatData(item)?.map((description: DescriptionItem) => ( +
+ {(description.label as string)} + {(description.children as string)} +
+ ))} + + + ))} +
+ ) + } + + ) +}) + +export default Group \ No newline at end of file diff --git a/web/src/views/ModelManagement/List.tsx b/web/src/views/ModelManagement/List.tsx new file mode 100644 index 00000000..f1127623 --- /dev/null +++ b/web/src/views/ModelManagement/List.tsx @@ -0,0 +1,83 @@ +import { useRef, useState, useEffect, type FC } from 'react'; +import { Button, Space, Row, Col } from 'antd' +import { useTranslation } from 'react-i18next'; + +import type { ProviderModelItem, KeyConfigModalRef, ModelListDetailRef } from './types' +import RbCard from '@/components/RbCard/Card' +import { getModelNewList } from '@/api/models' +import PageEmpty from '@/components/Empty/PageEmpty'; +import Tag from '@/components/Tag'; +import KeyConfigModal from './components/KeyConfigModal' +import ModelListDetail from './components/ModelListDetail' + +const ModelList: FC<{ query: any }> = ({ query }) => { + const { t } = useTranslation(); + const keyConfigModalRef = useRef(null) + const modelListDetailRef = useRef(null) + const [list, setList] = useState([]) + useEffect(() => { + getList() + }, [query]) + const getList = () => { + getModelNewList({ + ...query, + is_composite: false, + is_active: true, + }) + .then(res => { + setList((res || []) as ProviderModelItem[]) + }) + } + + const handleShowModel = (vo: ProviderModelItem) => { + modelListDetailRef.current?.handleOpen(vo) + } + const handleKeyConfig = (vo: ProviderModelItem) => { + keyConfigModalRef.current?.handleOpen(vo) + } + + return ( + <> + {list.length === 0 + ? + :( +
+ {list.map(item => ( + + {item.provider[0]} +
+ } + > + {item.tags.map(tag => {t(`modelNew.${tag}`)})} + + + + + + + + + + ))} +
+ ) + } + + + + + ) +} + +export default ModelList \ No newline at end of file diff --git a/web/src/views/ModelManagement/Square.tsx b/web/src/views/ModelManagement/Square.tsx new file mode 100644 index 00000000..7ecd838c --- /dev/null +++ b/web/src/views/ModelManagement/Square.tsx @@ -0,0 +1,95 @@ +import { useRef, useState, useEffect, forwardRef, useImperativeHandle } from 'react'; +import { Button, Space, App, Divider, Flex } from 'antd' +import { UsergroupAddOutlined } from '@ant-design/icons'; +import { useTranslation } from 'react-i18next'; + +import type { ModelPlaza, ModelPlazaItem, ModelSquareDetailRef, BaseRef } from './types' +import RbCard from '@/components/RbCard/Card' +import { getModelPlaza, addModelPlaza } from '@/api/models' +import PageEmpty from '@/components/Empty/PageEmpty'; +import Tag from '@/components/Tag'; +import ModelSquareDetail from './components/ModelSquareDetail' + +const ModelSquare = forwardRef void; }>(({ query, handleEdit }, ref) => { + const { t } = useTranslation(); + const { message } = App.useApp() + const modelSquareDetailRef = useRef(null) + const [list, setList] = useState([]) + useEffect(() => { + getList() + }, [query]) + const getList = () => { + getModelPlaza(query) + .then(res => { + setList((res as ModelPlaza[]) || []) + }) + } + + const handleMore = (vo: ModelPlaza) => { + modelSquareDetailRef.current?.handleOpen(vo) + } + const handleAdd = (item: ModelPlazaItem) => { + addModelPlaza(item.id) + .then(() => { + message.success(`${item.name}${t('modelNew.addSuccess')}`) + getList() + }) + } + + useImperativeHandle(ref, () => ({ + getList, + })); + return ( + <> + {list.length === 0 + ? + : list.map(vo => ( +
+
+
{vo.provider}
+ +
+ +
+ {vo.models.slice(0, 6).map(item => ( + + {item.name[0]} +
+ } + > + {t(`modelNew.${item.type}`)} +
{item.description}
+ {item.tags.map((tag, tagIndex) => {tag})} + + + {item.add_count} + + {!item.is_official && } + {item.is_added + ? + : + } + + + + ))} +
+ + )) + } + + + + ) +}) + +export default ModelSquare \ No newline at end of file diff --git a/web/src/views/ModelManagement/components/ConfigModal.tsx b/web/src/views/ModelManagement/components/ConfigModal.tsx deleted file mode 100644 index e4bdf84c..00000000 --- a/web/src/views/ModelManagement/components/ConfigModal.tsx +++ /dev/null @@ -1,171 +0,0 @@ -import { forwardRef, useImperativeHandle, useState } from 'react'; -import { Form, Input, App } from 'antd'; -import { useTranslation } from 'react-i18next'; -import type { ModelFormData, Model, ConfigModalRef, ConfigModalProps } from '../types'; -import RbModal from '@/components/RbModal' -import CustomSelect from '@/components/CustomSelect' -import { updateModel, addModel, modelTypeUrl, modelProviderUrl } from '@/api/models' - -const ConfigModal = forwardRef(({ - refresh -}, ref) => { - const { t } = useTranslation(); - const { message } = App.useApp(); - const [visible, setVisible] = useState(false); - const [model, setModel] = useState({} as Model); - const [isEdit, setIsEdit] = useState(false); - const [form] = Form.useForm(); - const [loading, setLoading] = useState(false) - - const values = Form.useWatch([], form); - - // 封装取消方法,添加关闭弹窗逻辑 - const handleClose = () => { - setModel({} as Model); - form.resetFields(); - setLoading(false) - setVisible(false); - }; - - const handleOpen = (model?: Model) => { - if (model) { - setIsEdit(true); - setModel(model); - // 设置表单值 - const apiKeyInfo = model.api_keys[0] - form.setFieldsValue({ - provider: apiKeyInfo.provider, - model_name: apiKeyInfo.model_name, - api_key: apiKeyInfo.api_key, - api_base: apiKeyInfo.api_base - }); - } else { - setIsEdit(false); - form.resetFields(); - } - setVisible(true); - }; - // 封装保存方法,添加提交逻辑 - const handleSave = () => { - form - .validateFields() - .then(() => { - const data = { - name: values.name, - type: values.type, - api_keys: { - provider: values.provider, - model_name: values.model_name, - api_key: values.api_key, - api_base: values.api_base - }, - } - setLoading(true) - const res = isEdit - ? updateModel(model.api_keys[0].id, { - provider: values.provider, - model_name: values.model_name, - api_key: values.api_key, - api_base: values.api_base - } as ModelFormData) - : addModel(data as ModelFormData) - - res.then(() => { - if (refresh) { - refresh(); - } - handleClose() - message.success(isEdit ? t('common.updateSuccess') : t('common.createSuccess')) - }) - .catch(() => { - setLoading(false) - }); - }) - .catch((err) => { - console.log('err', err) - }); - } - - // 暴露给父组件的方法 - useImperativeHandle(ref, () => ({ - handleOpen, - handleClose - })); - - return ( - -
- {!isEdit && ( - <> - - - - - items.map((item) => ({ label: t(`model.${item}`), value: item }))} - /> - - - )} - - - - items.map((item) => ({ label: t(`model.${item}`), value: item }))} - /> - - - - - - - - - - - - -
-
- ); -}); - -export default ConfigModal; \ No newline at end of file diff --git a/web/src/views/ModelManagement/components/CustomModelModal.tsx b/web/src/views/ModelManagement/components/CustomModelModal.tsx new file mode 100644 index 00000000..d22fbcdd --- /dev/null +++ b/web/src/views/ModelManagement/components/CustomModelModal.tsx @@ -0,0 +1,165 @@ +import { forwardRef, useImperativeHandle, useState } from 'react'; +import { Form, Input, App, Select } from 'antd'; +import { useTranslation } from 'react-i18next'; + +import type { CustomModelForm, ModelPlazaItem, CustomModelModalRef, CustomModelModalProps } from '../types'; +import RbModal from '@/components/RbModal' +import CustomSelect from '@/components/CustomSelect' +import UploadImages from '@/components/Upload/UploadImages' +import { updateCustomModel, addCustomModel, modelTypeUrl, modelProviderUrl } from '@/api/models' +import { getFileLink } from '@/api/fileStorage' + +const CustomModelModal = forwardRef(({ + refresh +}, ref) => { + const { t } = useTranslation(); + const { message } = App.useApp(); + const [visible, setVisible] = useState(false); + const [model, setModel] = useState({} as ModelPlazaItem); + const [isEdit, setIsEdit] = useState(false); + const [form] = Form.useForm(); + const [loading, setLoading] = useState(false) + const formValues = Form.useWatch([], form) + + const handleClose = () => { + setModel({} as ModelPlazaItem); + form.resetFields(); + setLoading(false) + setVisible(false); + }; + + const handleOpen = (model?: ModelPlazaItem) => { + if (model) { + setIsEdit(true); + setModel(model); + form.setFieldsValue({ + ...model, + logo: model.logo ? { url: model.logo, uid: model.logo, status: 'done', name: 'logo' } : undefined + }); + } else { + setIsEdit(false); + form.resetFields(); + } + setVisible(true); + }; + const handleUpdate = (data: CustomModelForm) => { + setLoading(true) + const res = isEdit ? updateCustomModel(model.id, data) : addCustomModel(data) + + res.then(() => { + refresh && refresh() + handleClose() + message.success(isEdit ? t('common.updateSuccess') : t('common.createSuccess')) + }) + .catch(() => { + setLoading(false) + }); + } + const handleSave = () => { + form + .validateFields() + .then((values) => { + setLoading(true) + const { logo, ...rest } = values; + let formData: CustomModelForm = { + ...rest + } + formData.is_official = false; + + if (typeof logo === 'object' && logo?.response?.data.file_id) { + getFileLink(logo?.response?.data.file_id) + .then(res => { + const logoRes = res as { url: string } + formData.logo = logoRes.url + handleUpdate(formData) + }) + .catch(() => { + handleUpdate(formData) + }) + } else { + formData.logo = typeof logo === 'string' ? logo : logo.url + handleUpdate(formData) + } + }) + .catch((err) => { + console.log('err', err) + }); + } + + useImperativeHandle(ref, () => ({ + handleOpen, + })); + + console.log('formValues', formValues) + + return ( + +
+ + + + + + + + + items.map((item) => ({ label: t(`modelNew.${item}`), value: String(item) }))} + /> + + + + items.map((item) => ({ label: t(`modelNew.${item}`), value: String(item) }))} + /> + + + + + + + + + + + items.map((item) => ({ + label: t(`modelNew.${typeof item === 'object' ? item.value : item}`), + value: typeof item === 'object' ? item.value : item + }))} + disabled={isEdit} + /> + + + + + + + + + +
+
+ ); +}); + +export default GroupModelModal; \ No newline at end of file diff --git a/web/src/views/ModelManagement/components/KeyConfigModal.tsx b/web/src/views/ModelManagement/components/KeyConfigModal.tsx new file mode 100644 index 00000000..d157dde7 --- /dev/null +++ b/web/src/views/ModelManagement/components/KeyConfigModal.tsx @@ -0,0 +1,92 @@ +import { forwardRef, useImperativeHandle, useState } from 'react'; +import { Form, Input, App } from 'antd'; +import { useTranslation } from 'react-i18next'; +import type { KeyConfigModalForm, ProviderModelItem, KeyConfigModalRef, KeyConfigModalProps } from '../types'; +import RbModal from '@/components/RbModal' +import { updateProviderApiKeys } from '@/api/models' + +const KeyConfigModal = forwardRef(({ + refresh +}, ref) => { + const { t } = useTranslation(); + const { message } = App.useApp(); + const [visible, setVisible] = useState(false); + const [model, setModel] = useState({} as ProviderModelItem); + const [form] = Form.useForm(); + const [loading, setLoading] = useState(false) + + const handleClose = () => { + setModel({} as ProviderModelItem); + form.resetFields(); + setLoading(false) + setVisible(false); + }; + + const handleOpen = (vo: ProviderModelItem) => { + setVisible(true); + setModel(vo); + }; + const handleSave = () => { + form + .validateFields() + .then((values) => { + setLoading(true) + + updateProviderApiKeys({ + ...values, + provider: model.provider + }).then(() => { + if (refresh) { + refresh(); + } + handleClose() + message.success(t('common.updateSuccess')) + }) + .catch(() => { + setLoading(false) + }); + }) + .catch((err) => { + console.log('err', err) + }); + } + + useImperativeHandle(ref, () => ({ + handleOpen, + handleClose + })); + + return ( + +
+ + + + + + + +
+
+ ); +}); + +export default KeyConfigModal; \ No newline at end of file diff --git a/web/src/views/ModelManagement/components/ModelImplement/SubModelModal.tsx b/web/src/views/ModelManagement/components/ModelImplement/SubModelModal.tsx new file mode 100644 index 00000000..069f785d --- /dev/null +++ b/web/src/views/ModelManagement/components/ModelImplement/SubModelModal.tsx @@ -0,0 +1,164 @@ +import { forwardRef, useImperativeHandle, useState } from 'react'; +import { Form, Cascader, App } from 'antd'; +import { useTranslation } from 'react-i18next'; + +import type { SubModelModalForm, SubModelModalRef, SubModelModalProps, ModelList } from './types'; +import RbModal from '@/components/RbModal' +import CustomSelect from '@/components/CustomSelect' +import { modelProviderUrl, getModelNewList } from '@/api/models' +import type { ProviderModelItem } from '../../types' + +const { SHOW_CHILD } = Cascader; + +interface Option { + value: string | number; + label: string; + children?: Option[]; + [key: string]: any; +} +const SubModelModal = forwardRef(({ + refresh, + type +}, ref) => { + const { t } = useTranslation(); + const { message } = App.useApp() + const [visible, setVisible] = useState(false); + const [form] = Form.useForm(); + const [selecteds, setSelecteds] = useState([]) + const [modelList, setModelList] = useState([]) + + // 封装取消方法,添加关闭弹窗逻辑 + const handleClose = () => { + form.resetFields(); + setVisible(false); + setSelecteds([]) + }; + + const handleOpen = (list?: ModelList[], provider?: string) => { + if (list?.length && provider) { + const initialValue: SubModelModalForm = { + provider, + api_key_ids: list.map(vo => { + return [vo.model_config_ids[0], vo.id] + }) + } + + form.setFieldsValue(initialValue); + handleChangeProvider(provider, initialValue.api_key_ids) + } else { + form.resetFields() + } + setVisible(true); + }; + // 封装保存方法,添加提交逻辑 + const handleSave = () => { + form + .validateFields() + .then(() => { + refresh?.(selecteds.map(vo => ({ + ...vo[0], + model_name: vo[0].name, + model_config_ids: [vo[0].id], + id: vo[1].value + }))) + handleClose() + }) + } + const handleChange = (value: (string | number)[][], selectedOptions: Option[][]) => { + const filterList = selectedOptions.filter(vo => vo.length === 1).map(item => item[0]) + const lastFilterLit = value.filter(vo => vo.length !== 1) + console.log('onchange', value, lastFilterLit, selectedOptions, filterList) + if (filterList.length) { + message.warning(`【${filterList.map(vo => vo.label)}】${t('modelNew.selectOneTip')}`) + form.setFieldValue('api_key_ids', lastFilterLit) + } + setSelecteds(selectedOptions) + } + + const handleChangeProvider = (provider: string, api_key_ids?: any[]) => { + form.setFieldValue('api_key_ids', undefined) + getModelNewList({ + provider: provider, + is_composite: false, + is_active: true, + type + }) + .then(res => { + const response = res as ProviderModelItem[] + const list = response[0]?.models || [] + setModelList(list.map(vo => { + const children = vo.api_keys.map(item => ({ + label: item.api_key, + value: item.id, + })) + return { + ...vo, + label: vo.name, + value: vo.id, + children: children + } + })) + + if (api_key_ids?.length) { + form.setFieldsValue({ + api_key_ids: api_key_ids + }) + } + }) + } + + // 暴露给父组件的方法 + useImperativeHandle(ref, () => ({ + handleOpen, + })); + + return ( + +
+ + items.map((item) => ({ + label: t(`modelNew.${typeof item === 'object' ? item.value : item}`), + value: typeof item === 'object' ? item.value : item + }))} + onChange={(value) => handleChangeProvider(value)} + /> + + + + +
+
+ ); +}); + +export default SubModelModal; \ No newline at end of file diff --git a/web/src/views/ModelManagement/components/ModelImplement/index.tsx b/web/src/views/ModelManagement/components/ModelImplement/index.tsx new file mode 100644 index 00000000..a876587d --- /dev/null +++ b/web/src/views/ModelManagement/components/ModelImplement/index.tsx @@ -0,0 +1,106 @@ +import { type FC, useRef } from "react"; +import { useTranslation } from 'react-i18next'; +import { Flex, Button, Space, App } from 'antd' + +import type { SubModelModalRef, ModelList } from './types' +import SubModelModal from './SubModelModal' +import Empty from '@/components/Empty' +import Tag from '@/components/Tag' + +interface ModelImplementProps { + type?: string; + value?: any; + onChange?: (value: any) => void; +} +const ModelImplement: FC = ({ type, value, onChange }) => { + const { t } = useTranslation(); + const { modal, message } = App.useApp(); + const subModelModalRef = useRef(null) + + const handleAdd = () => { + if (!type || type.trim() === '') { + message.warning(t('common.selectPlaceholder', { title: t('modelNew.type') })) + return + } + subModelModalRef.current?.handleOpen() + } + const handleEdit = (list: ModelList[], provider: string ) => { + subModelModalRef.current?.handleOpen(list, provider) + } + const handleDelete = (provider: string) => { + modal.confirm({ + title: t('common.confirmDeleteDesc', { name: provider }), + content: t('application.apiKeyDeleteContent'), + okText: t('common.delete'), + cancelText: t('common.cancel'), + okType: 'danger', + onOk: () => { + onChange?.(value?.filter((item: any) => item.provider !== provider)) + } + }) + } + const handleRefresh = (list: ModelList[]) => { + const existingModels = value || []; + let updatedModels = [...existingModels]; + + const provider = list[0].provider + + updatedModels = updatedModels.filter(item => item.provider !== provider) + updatedModels = [...updatedModels, ...list] + + onChange?.([...updatedModels]); + } + + const groupedByProvider: Record = (value || []).reduce((acc: Record, item: ModelList) => { + const provider = item.provider || 'unknown'; + if (!acc[provider]) acc[provider] = []; + acc[provider].push(item); + return acc; + }, {} as Record); + + return ( +
+ + {t('modelNew.modelImplement')} + + + + + + + + +
+ {!value || value.length === 0 + ? + : Object.entries(groupedByProvider).map(([provider, items]: [string, ModelList[]]) => { + return ( +
+ +
{[...new Set(items?.map((vo) => vo.model_name))].join(', ')}
+ +
handleEdit(items, provider)} + >
+
handleDelete(provider)} + >
+
+
+ {provider} +
+ ) + })} +
+ +
+ ) +} + +export default ModelImplement \ No newline at end of file diff --git a/web/src/views/ModelManagement/components/ModelImplement/types.ts b/web/src/views/ModelManagement/components/ModelImplement/types.ts new file mode 100644 index 00000000..c6d2f6d6 --- /dev/null +++ b/web/src/views/ModelManagement/components/ModelImplement/types.ts @@ -0,0 +1,16 @@ +import type { ModelListItem } from '../../types' + +export interface ModelList extends ModelListItem { + api_key_id: string; +} +export interface SubModelModalForm { + provider: string; + api_key_ids: string[][]; +} +export interface SubModelModalRef { + handleOpen: (list?: ModelList[], provider?: string) => void; +} +export interface SubModelModalProps { + type?: string; + refresh?: (vo: ModelList[]) => void; +} \ No newline at end of file diff --git a/web/src/views/ModelManagement/components/ModelListDetail.tsx b/web/src/views/ModelManagement/components/ModelListDetail.tsx new file mode 100644 index 00000000..48abd953 --- /dev/null +++ b/web/src/views/ModelManagement/components/ModelListDetail.tsx @@ -0,0 +1,111 @@ +import { useState, useImperativeHandle, forwardRef, useRef } from 'react'; +import { useTranslation } from 'react-i18next'; +import { Button, Switch, Row, Col, Space } from 'antd' + +import type { ProviderModelItem, ModelListItem, ModelListDetailRef, MultiKeyConfigModalRef } from '../types'; +import RbDrawer from '@/components/RbDrawer'; +import RbCard from '@/components/RbCard/Card' +import Tag from '@/components/Tag'; +import PageEmpty from '@/components/Empty/PageEmpty'; +import MultiKeyConfigModal from './MultiKeyConfigModal' +import { getModelNewList, updateModelStatus } from '@/api/models' + +interface ModelListDetailProps { + refresh?: () => void; +} + +const ModelListDetail = forwardRef(({ refresh }, ref) => { + const { t } = useTranslation(); + const [open, setOpen] = useState(false); + const [data, setData] = useState({} as ProviderModelItem) + const [list, setList] = useState([]) + const multiKeyConfigModalRef = useRef(null) + const [loading, setLoading] = useState(false) + + const handleOpen = (vo: ProviderModelItem) => { + setOpen(true) + getData(vo) + } + + const getData = (vo: ProviderModelItem) => { + if (!vo.provider) return + + getModelNewList({ + provider: vo.provider + }) + .then(res => { + const response = res as ProviderModelItem[] + setData(response[0]) + setList(response[0].models) + }) + } + const handleKeyConfig = (vo: ModelListItem) => { + multiKeyConfigModalRef.current?.handleOpen(vo, data.provider) + } + const handleChange = (vo: ModelListItem) => { + setLoading(true) + updateModelStatus(vo.id, { is_active: !vo.is_active }) + .finally(() => { + getData(data) + setLoading(false) + }) + } + + const handleClose = () => { + setOpen(false) + refresh?.() + } + const handleRefresh = () => { + getData(data) + } + + useImperativeHandle(ref, () => ({ + handleOpen, + })); + + return ( + {data.provider} {t('modelNew.modelList')} ({list.length}{t('modelNew.item')})} + open={open} + onClose={handleClose} + > + {list.length === 0 + ? + :
+ {list.map(item => ( + + {t(`modelNew.${item.type}`)} + {item.api_keys.length}{t('modelNew.apiKeyNum')} + } + avatarUrl={item.logo} + avatar={ +
+ {item.name[0]} +
+ } + extra={ handleChange(item)} />} + > + +
{item.description}
+ + + + + +
+ ))} +
+ } + + +
+ ); +}); + +export default ModelListDetail; \ No newline at end of file diff --git a/web/src/views/ModelManagement/components/ModelSquareDetail.tsx b/web/src/views/ModelManagement/components/ModelSquareDetail.tsx new file mode 100644 index 00000000..d7a5f807 --- /dev/null +++ b/web/src/views/ModelManagement/components/ModelSquareDetail.tsx @@ -0,0 +1,95 @@ +import { useState, useImperativeHandle, forwardRef } from 'react'; +import { useTranslation } from 'react-i18next'; +import { Button, Space, App, Flex } from 'antd' +import { UsergroupAddOutlined } from '@ant-design/icons'; + +import type { ModelPlaza, ModelPlazaItem, ModelSquareDetailRef } from '../types'; +import RbDrawer from '@/components/RbDrawer'; +import { getModelPlaza, addModelPlaza } from '@/api/models' +import RbCard from '@/components/RbCard/Card' +import Tag from '@/components/Tag'; +import PageEmpty from '@/components/Empty/PageEmpty'; + +interface ModelSquareDetailProps { + refresh: () => void; + handleEdit: (vo: ModelPlazaItem) => void; +} +const ModelSquareDetail = forwardRef(({ refresh, handleEdit }, ref) => { + const { t } = useTranslation(); + const { message } = App.useApp() + const [model, setModel] = useState({} as ModelPlaza) + const [open, setOpen] = useState(false); + + const [list, setList] = useState([]) + + const handleOpen = (vo: ModelPlaza) => { + setModel(vo) + setOpen(true) + getList(vo) + } + const handleClose = () => { + setOpen(false) + refresh() + } + const getList = (vo: ModelPlaza) => { + getModelPlaza({ provider: vo.provider }) + .then(res => { + const response = res as ModelPlaza[] + setList(response.length > 0 ? response[0].models : []) + }) + } + const handleAdd = (item: ModelPlazaItem) => { + addModelPlaza(item.id) + .then(() => { + message.success(`${item.name}${t('modelNew.addSuccess')}`) + getList(model) + }) + } + + useImperativeHandle(ref, () => ({ + handleOpen, + })); + + return ( + {model.provider} {t('modelNew.modelList')} ({list.length}{t('modelNew.item')})} + open={open} + onClose={handleClose} + > + {list.length === 0 + ? + :
+ {list.map(item => ( + + {item.name[0]} +
+ } + > + {t(`modelNew.${item.type}`)} +
{item.description}
+ {item.tags.map((tag, tagIndex) => {tag})} + + + {item.add_count} + + {!item.is_official && } + {item.is_added + ? + : + } + + + + ))} + + } +
+ ); +}); + +export default ModelSquareDetail; \ No newline at end of file diff --git a/web/src/views/ModelManagement/components/MultiKeyConfigModal.tsx b/web/src/views/ModelManagement/components/MultiKeyConfigModal.tsx new file mode 100644 index 00000000..334badc8 --- /dev/null +++ b/web/src/views/ModelManagement/components/MultiKeyConfigModal.tsx @@ -0,0 +1,122 @@ +import { forwardRef, useImperativeHandle, useState } from 'react'; +import { Form, Input, App, Button } from 'antd'; +import { useTranslation } from 'react-i18next'; +import type { ModelListItem, MultiKeyForm, MultiKeyConfigModalRef, MultiKeyConfigModalProps } from '../types'; +import RbModal from '@/components/RbModal' +import { addModelApiKey, deleteModelApiKey, getModelInfo } from '@/api/models' + +const MultiKeyConfigModal = forwardRef(({ refresh }, ref) => { + const { t } = useTranslation(); + const { message } = App.useApp(); + const [visible, setVisible] = useState(false); + const [model, setModel] = useState({} as ModelListItem); + const [form] = Form.useForm(); + const [loading, setLoading] = useState(false) + + const handleClose = () => { + setModel({} as ModelListItem); + refresh?.() + + form.resetFields(); + setLoading(false) + setVisible(false); + }; + + const handleOpen = (vo: ModelListItem) => { + setVisible(true); + getData(vo) + }; + + const getData = (vo: ModelListItem) => { + if (!vo.id) return + + getModelInfo(vo?.id) + .then(res => { + setModel(res as ModelListItem) + }) + } + const handleSave = () => { + form + .validateFields() + .then((values) => { + setLoading(true) + addModelApiKey(model.id, { + ...values, + model_config_id: model.id, + model_name: model.name, + provider: model.provider, + }).then(() => { + message.success(t('common.saveSuccess')) + form.resetFields(); + getData(model) + }) + .catch(() => { + setLoading(false) + }); + }) + .catch((err) => { + console.log('err', err) + }); + } + const handleDelete = (api_key_id: string) => { + deleteModelApiKey(api_key_id) + .then(() => { + message.success(t('common.deleteSuccess')) + getData(model) + }) + } + + useImperativeHandle(ref, () => ({ + handleOpen, + })); + + return ( + + {model.api_keys && model.api_keys.length > 0 && ( +
+ {model.api_keys.map((key) => ( +
+
+
{key.api_key}
+
{key.api_base}
+
+ +
+ ))} +
+ )} +
+ + + + + + + + + + + +
+
+ ); +}); + +export default MultiKeyConfigModal; \ No newline at end of file diff --git a/web/src/views/ModelManagement/index.tsx b/web/src/views/ModelManagement/index.tsx index 930a18e6..35f4c887 100644 --- a/web/src/views/ModelManagement/index.tsx +++ b/web/src/views/ModelManagement/index.tsx @@ -1,99 +1,123 @@ import { useState, useRef, type FC } from 'react'; -import { Row, Col, Button } from 'antd' +import { Button, Flex, Space, type SegmentedProps } from 'antd' import { useTranslation } from 'react-i18next'; -import clsx from 'clsx'; -import ConfigModal from './components/ConfigModal' -import type { Model, DescriptionItem, ConfigModalRef } from './types' -import RbCard from '@/components/RbCard/Card' +import GroupModelModal from './components/GroupModelModal' +import type { ModelListItem, GroupModelModalRef, CustomModelModalRef, ModelPlazaItem, BaseRef } from './types' import SearchInput from '@/components/SearchInput' -import PageScrollList, { type PageScrollListRef } from '@/components/PageScrollList' -import { getModelListUrl } from '@/api/models' -import { formatDateTime } from '@/utils/format'; +import PageTabs from '@/components/PageTabs' +import GroupModel from './Group' +import ModelList from './List' +import ModelSquare from './Square' +import CustomModelModal from './components/CustomModelModal' +import CustomSelect from '@/components/CustomSelect' +import { modelTypeUrl, modelProviderUrl } from '@/api/models' +const tabKeys = ['group', 'list', 'square'] const ModelManagement: FC = () => { const { t } = useTranslation(); + const [activeTab, setActiveTab] = useState('group'); const [query, setQuery] = useState({}) - const configModalRef = useRef(null) - const scrollListRef = useRef(null) + const configModalRef = useRef(null) + const customModelModalRef = useRef(null) + const groupRef = useRef(null) + const squareRef = useRef(null) - const formatData = (data: Model) => { - return [ - { - key: 'type', - label: t(`model.type`), - children: data.type || '-', - }, - { - key: 'provider', - label: t(`model.provider`), - children: data.api_keys[0].provider || '-', - }, - { - key: 'is_active', - label: t(`model.status`), - children: data.is_active ? t(`common.statusEnabled`) : t(`common.statusDisabled`), - }, - { - key: 'created', - label: t(`model.created`), - children: data.created_at ? formatDateTime(data.created_at, 'YYYY-MM-DD HH:mm:ss') : '-', - }, - ] + const formatTabItems = () => { + return tabKeys.map(value => ({ + value, + label: t(`modelNew.${value}`), + })) + } + const handleChangeTab = (value: SegmentedProps['value']) => { + setActiveTab(value as string); + setQuery({}) } - const handleEdit = (model?: Model) => { - configModalRef?.current?.handleOpen(model) + const handleEdit = (vo?: ModelListItem | ModelPlazaItem) => { + switch(activeTab) { + case 'group': + configModalRef?.current?.handleOpen(vo as ModelListItem) + break + case 'square': + customModelModalRef?.current?.handleOpen(vo as ModelPlazaItem) + break + } + } + const handleRefresh = () => { + switch (activeTab) { + case 'group': + groupRef.current?.getList() + break + case 'square': + squareRef.current?.getList() + break + } } const handleSearch = (value?: string) => { setQuery({ search: value }) } + const handleTypeChange = (value: string) => { + setQuery(pre => ({ ...pre, type: value })) + } + const handleProviderChange = (value: string) => { + setQuery(pre => ({ ...pre, provider: value })) + } return ( -
- - - + + + + + {activeTab === 'list' ? <> + items.map((item) => ({ label: t(`modelNew.${item}`), value: String(item) }))} + onChange={handleTypeChange} + className="rb:w-30" + allowClear={true} + placeholder={t('modelNew.type')} + /> + items.map((item) => ({ label: t(`modelNew.${item}`), value: String(item) }))} + onChange={handleProviderChange} + className="rb:w-30" + allowClear={true} + placeholder={t('modelNew.provider')} + /> + + : - - - - - + className="rb:w-70!" + />} + {activeTab === 'group' && } + {activeTab === 'square' && } + + - ( - - {formatData(item)?.map((description: DescriptionItem) => ( -
- {(description.label as string)} - {(description.children as string)} -
- ))} - -
- )} - /> - - + {activeTab === 'group' && } + {activeTab === 'list' && } + {activeTab === 'square' && } +
+ scrollListRef?.current?.refresh()} + refresh={handleRefresh} /> - + + ) } diff --git a/web/src/views/ModelManagement/types.ts b/web/src/views/ModelManagement/types.ts index 215e0d9f..1967f393 100644 --- a/web/src/views/ModelManagement/types.ts +++ b/web/src/views/ModelManagement/types.ts @@ -1,70 +1,139 @@ -// 模型表单数据类型 -export interface ModelFormData extends ApiKey { - name: string; - type: string; - api_keys: ApiKey; -} +export interface Query { + type?: string; + provider?: string; + is_active?: boolean; + is_public?: boolean; + is_composite?: boolean; + search?: string; + pagesize?: number; + page?: number; +} export interface DescriptionItem { key: string; label: string; children: string; } +export interface CompositeModelForm { + logo?: any; + name: string; + type: string; + description: string; + api_key_ids: ModelApiKey[] | string[]; +} +export interface GroupModelModalRef { + handleOpen: (model?: ModelListItem) => void; +} +export interface GroupModelModalProps { + refresh?: () => void; +} +export interface ModelListDetailRef { + handleOpen: (vo: ProviderModelItem) => void; +} -// 模型类型定义 -export interface Model { + +export interface ModelApiKey { + model_name: string; + description: string | null; + provider: string; + api_key: string; + api_base: string; + config: any; + is_active: boolean; + priority: string; + id: string; + usage_count: string; + last_used_at: number; + created_at: number; + updated_at: number; + model_config_ids: string[]; +} +export interface ModelListItem { + model_name?: string; + model_config_ids: string[]; + name: string; + type: string; + logo: string; + description: string; + provider: string; + config: any; + is_active: boolean; + is_public: boolean; + id: string; + created_at: number; + updated_at: number; + api_keys: ModelApiKey[] +} +export interface ProviderModelItem { + provider: string; + logo?: string; + tags: string[]; + models: ModelListItem[]; +} +export interface KeyConfigModalForm { + provider: string; + api_key: string; + api_base: string; +} +export interface KeyConfigModalRef { + handleOpen: (vo: ProviderModelItem) => void; +} +export interface KeyConfigModalProps { + refresh?: () => void; +} +export interface MultiKeyForm { + model_config_id?: string; + model_name: string; + provider: string; + api_key: string; + api_base: string; +} + +export interface MultiKeyConfigModalRef { + handleOpen: (vo: ModelListItem, provider?: string) => void; +} +export interface MultiKeyConfigModalProps { + refresh?: () => void; +} + + +export interface ModelPlaza { + provider: string; + models: ModelPlazaItem[]; +} +export interface ModelPlazaItem { id: string; name: string; type: string; - description?: string; - config: Record; - is_active: boolean; - is_public: boolean; - created_at: string | number; - updated_at: string | number; - api_keys: ApiKey[]; - - // provider: string; - // temperature: number, - // topP: number, - // status: string; - // vectorDimension: number; - // batchSize: number; - // truncateStrategy: string; - // created: string; - // updatedAt: string; - // descriptionItems?: Record[]; - // basicParameters?: string; - // normalization?: string; - // maxInputLength?: number; - // encodingFormat?: string; - // enablePooling?: boolean; - // poolingStrategy?: string; - // apiKey?: string; - // apiEndpoint?: string; - // timeout?: number; - // autoRetry?: boolean; - // retryCount?: number; -} -interface ApiKey { - model_name?: string; provider: string; - api_key?: string; - api_base?: string; - config?: Record; - is_active?: boolean; - priority?: string; - id: string; - model_config_id?: string; - usage_count?: string; - last_used_at?: string | null; - created_at?: string; - updated_at?: string; + logo: string; + description: string; + is_deprecated: boolean; + is_official: boolean; + tags: string[]; + add_count: number; + is_added: boolean; } -// 定义组件暴露的方法接口 -export interface ConfigModalRef { - handleOpen: (model?: Model) => void; +export interface ModelSquareDetailRef { + handleOpen: (vo: ModelPlaza) => void; } -export interface ConfigModalProps { +export interface CustomModelForm { + name: string; + type: string; + provider: string; + logo?: any; + description: string; + is_official: boolean; + tags: string[]; +} +export interface CustomModelModalRef { + handleOpen: (vo?: ModelPlazaItem) => void; +} +export interface CustomModelModalProps { refresh?: () => void; +} + + +export interface BaseRef { + getList: () => void; } \ No newline at end of file diff --git a/web/src/views/SelfReflectionEngine/index.tsx b/web/src/views/SelfReflectionEngine/index.tsx index 784f066c..30117bed 100644 --- a/web/src/views/SelfReflectionEngine/index.tsx +++ b/web/src/views/SelfReflectionEngine/index.tsx @@ -24,7 +24,7 @@ const configList = [ key: 'reflection_model_id', type: 'customSelect', url: getModelListUrl, - params: { type: 'chat,llm', page: 1, pagesize: 100 }, // chat,llm + params: { type: 'chat,llm', page: 1, pagesize: 100, is_active: true }, // chat,llm }, // 迭代周期 { diff --git a/web/src/views/SpaceConfig/index.tsx b/web/src/views/SpaceConfig/index.tsx index ad99e220..25490e91 100644 --- a/web/src/views/SpaceConfig/index.tsx +++ b/web/src/views/SpaceConfig/index.tsx @@ -66,7 +66,7 @@ const SpaceConfig: FC = () => { > { > { > (({ const [form] = Form.useForm(); const [loading, setLoading] = useState(false) const [editVo, setEditVo] = useState(null) - const [modelList, setModelList] = useState([]) + const [modelList, setModelList] = useState([]) const values = Form.useWatch([], form); @@ -80,9 +80,9 @@ const SpaceModal = forwardRef(({ }, []) const getModels = () => { - getModelList({ type: 'llm,chat', pagesize: 100, page: 1 }) + getModelList({ type: 'llm,chat', pagesize: 100, page: 1, is_active: true }) .then(res => { - const response = res as { items: Model[] } + const response = res as { items: ModelListItem[] } setModelList(response.items) }) } @@ -134,7 +134,7 @@ const SpaceModal = forwardRef(({ > (({ >