From db46c186aac6c33f1b1dc198776d89dd3c6c15ec Mon Sep 17 00:00:00 2001 From: lixiangcheng1 Date: Fri, 6 Feb 2026 12:18:40 +0800 Subject: [PATCH] [ADD] Third-party synchronization 1. Third-party website data access - Website synchronization: build a knowledge base by crawling web pages in batches with a web crawler. Website synchronization uses crawler technology to automatically capture all pages under the same domain name, starting from a single entry URL. Currently, it supports up to 200 subpages. For compliance and security reasons, only static sites can be crawled; the feature is mainly intended for quickly building knowledge bases from documentation sites. 2. Feishu Knowledge Base: by configuring Feishu document permissions, a knowledge base can be built from Feishu documents, and the documents are not stored a second time. 3. Yuque Knowledge Base: by configuring Yuque document permissions, a knowledge base can be built from Yuque documents, and the documents are not stored a second time. --- api/app/celery_app.py | 1 + api/app/controllers/knowledge_controller.py | 101 +++- api/app/core/rag/crawler/__init__.py | 0 api/app/core/rag/crawler/__main__.py | 89 +++ api/app/core/rag/crawler/content_extractor.py | 233 ++++++++ api/app/core/rag/crawler/http_fetcher.py | 302 ++++++++++ api/app/core/rag/crawler/models.py | 52 ++ api/app/core/rag/crawler/rate_limiter.py | 57 ++ api/app/core/rag/crawler/robots_parser.py | 118 ++++ api/app/core/rag/crawler/url_normalizer.py | 171 ++++++ api/app/core/rag/crawler/web_crawler.py | 215 +++++++ api/app/core/rag/integrations/__init__.py | 1 + .../core/rag/integrations/feishu/__init__.py | 1 + .../core/rag/integrations/feishu/__main__.py | 84 +++ .../core/rag/integrations/feishu/client.py | 452 +++++++++++++++ .../rag/integrations/feishu/exceptions.py | 46 ++ .../core/rag/integrations/feishu/models.py | 17 + api/app/core/rag/integrations/feishu/retry.py | 137 +++++ 
.../core/rag/integrations/yuque/__init__.py | 1 + .../core/rag/integrations/yuque/__main__.py | 77 +++ api/app/core/rag/integrations/yuque/client.py | 544 ++++++++++++++++++ .../core/rag/integrations/yuque/exceptions.py | 46 ++ api/app/core/rag/integrations/yuque/models.py | 42 ++ api/app/core/rag/integrations/yuque/retry.py | 134 +++++ api/app/models/file_model.py | 1 + api/app/models/knowledge_model.py | 11 + api/app/schemas/file_schema.py | 3 + api/app/tasks.py | 483 ++++++++++++++++ api/pyproject.toml | 2 + api/requirements.txt | 2 + 30 files changed, 3422 insertions(+), 1 deletion(-) create mode 100644 api/app/core/rag/crawler/__init__.py create mode 100644 api/app/core/rag/crawler/__main__.py create mode 100644 api/app/core/rag/crawler/content_extractor.py create mode 100644 api/app/core/rag/crawler/http_fetcher.py create mode 100644 api/app/core/rag/crawler/models.py create mode 100644 api/app/core/rag/crawler/rate_limiter.py create mode 100644 api/app/core/rag/crawler/robots_parser.py create mode 100644 api/app/core/rag/crawler/url_normalizer.py create mode 100644 api/app/core/rag/crawler/web_crawler.py create mode 100644 api/app/core/rag/integrations/__init__.py create mode 100644 api/app/core/rag/integrations/feishu/__init__.py create mode 100644 api/app/core/rag/integrations/feishu/__main__.py create mode 100644 api/app/core/rag/integrations/feishu/client.py create mode 100644 api/app/core/rag/integrations/feishu/exceptions.py create mode 100644 api/app/core/rag/integrations/feishu/models.py create mode 100644 api/app/core/rag/integrations/feishu/retry.py create mode 100644 api/app/core/rag/integrations/yuque/__init__.py create mode 100644 api/app/core/rag/integrations/yuque/__main__.py create mode 100644 api/app/core/rag/integrations/yuque/client.py create mode 100644 api/app/core/rag/integrations/yuque/exceptions.py create mode 100644 api/app/core/rag/integrations/yuque/models.py create mode 100644 api/app/core/rag/integrations/yuque/retry.py diff 
--git a/api/app/celery_app.py b/api/app/celery_app.py index 002547f6..db78a368 100644 --- a/api/app/celery_app.py +++ b/api/app/celery_app.py @@ -76,6 +76,7 @@ celery_app.conf.update( # Document tasks → document_tasks queue (prefork worker) 'app.core.rag.tasks.parse_document': {'queue': 'document_tasks'}, 'app.core.rag.tasks.build_graphrag_for_kb': {'queue': 'document_tasks'}, + 'app.core.rag.tasks.sync_knowledge_for_kb': {'queue': 'document_tasks'}, # Beat/periodic tasks → periodic_tasks queue (dedicated periodic worker) 'app.tasks.workspace_reflection_task': {'queue': 'periodic_tasks'}, diff --git a/api/app/controllers/knowledge_controller.py b/api/app/controllers/knowledge_controller.py index 901208ba..01f89a3d 100644 --- a/api/app/controllers/knowledge_controller.py +++ b/api/app/controllers/knowledge_controller.py @@ -9,13 +9,16 @@ from sqlalchemy import or_ from sqlalchemy.orm import Session from app.celery_app import celery_app +from app.core.error_codes import BizCode from app.core.logging_config import get_api_logger from app.core.rag.common import settings +from app.core.rag.integrations.feishu.client import FeishuAPIClient +from app.core.rag.integrations.yuque.client import YuqueAPIClient from app.core.rag.llm.chat_model import Base from app.core.rag.nlp import rag_tokenizer, search from app.core.rag.prompts.generator import graph_entity_types from app.core.rag.vdb.elasticsearch.elasticsearch_vector import ElasticSearchVectorFactory -from app.core.response_utils import success +from app.core.response_utils import success, fail from app.db import get_db from app.dependencies import get_current_user from app.models import knowledge_model @@ -484,3 +487,99 @@ async def rebuild_knowledge_graph( except Exception as e: api_logger.error(f"Failed to rebuild knowledge graph: knowledge_id={knowledge_id} - {str(e)}") raise + + +@router.get("/check/yuque/auth", response_model=ApiResponse) +async def check_yuque_auth( + yuque_user_id: str, + yuque_token: str, + db: 
Session = Depends(get_db), + current_user: User = Depends(get_current_user) +): + """ + Verify Yuque credentials by listing the repositories of the given user. + """ + api_logger.info(f"check yuque auth info, username: {current_user.username}") + + try: + api_client = YuqueAPIClient( + user_id=yuque_user_id, + token=yuque_token + ) + async with api_client as client: + repos = await client.get_user_repos() + if repos: + return success(data=repos, msg="Successfully auth yuque info") + return fail(BizCode.UNAUTHORIZED, msg="auth yuque info failed", error="user_id or token is incorrect") + except HTTPException: + raise + except Exception as e: + api_logger.error(f"auth yuque info failed: {str(e)}") + raise + + +@router.get("/check/feishu/auth", response_model=ApiResponse) +async def check_feishu_auth( + feishu_app_id: str, + feishu_app_secret: str, + feishu_folder_token: str, + db: Session = Depends(get_db), + current_user: User = Depends(get_current_user) +): + """ + Verify Feishu credentials by recursively listing the files of the given folder. + """ + api_logger.info(f"check feishu auth info, username: {current_user.username}") + + try: + api_client = FeishuAPIClient( + app_id=feishu_app_id, + app_secret=feishu_app_secret + ) + async with api_client as client: + files = await client.list_all_folder_files(feishu_folder_token, recursive=True) + if files: + return success(data=files, msg="Successfully auth feishu info") + return fail(BizCode.UNAUTHORIZED, msg="auth feishu info failed", error="app_id or app_secret or feishu_folder_token is incorrect") + except HTTPException: + raise + except Exception as e: + api_logger.error(f"auth feishu info failed: {str(e)}") + raise + + +@router.post("/{knowledge_id}/sync", response_model=ApiResponse) +async def sync_knowledge( + knowledge_id: uuid.UUID, + db: Session = Depends(get_db), + current_user: User = Depends(get_current_user) +): + """ + Queue a background sync task for the knowledge base identified by knowledge_id. + """ + api_logger.info(f"Sync knowledge base: knowledge_id={knowledge_id}, username: {current_user.username}") + + try: + 
# 1. Query knowledge base information from the database + api_logger.debug(f"Query knowledge base: {knowledge_id}") + db_knowledge = knowledge_service.get_knowledge_by_id(db, knowledge_id=knowledge_id, current_user=current_user) + if not db_knowledge: + api_logger.warning(f"The knowledge base does not exist or access is denied: knowledge_id={knowledge_id}") + raise HTTPException( + status_code=status.HTTP_404_NOT_FOUND, + detail="The knowledge base does not exist or access is denied" + ) + + # 2. sync knowledge + # from app.tasks import sync_knowledge_for_kb + # sync_knowledge_for_kb(kb_id) + task = celery_app.send_task("app.core.rag.tasks.sync_knowledge_for_kb", args=[knowledge_id]) + result = { + "task_id": task.id + } + return success(data=result, msg="Task accepted. sync knowledge is being processed in the background.") + except HTTPException: + raise + except Exception as e: + api_logger.error(f"Failed to sync knowledge: knowledge_id={knowledge_id} - {str(e)}") + raise diff --git a/api/app/core/rag/crawler/__init__.py b/api/app/core/rag/crawler/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/api/app/core/rag/crawler/__main__.py b/api/app/core/rag/crawler/__main__.py new file mode 100644 index 00000000..51a6870f --- /dev/null +++ b/api/app/core/rag/crawler/__main__.py @@ -0,0 +1,89 @@ +"""Command-line interface for web crawler.""" + +import argparse +import logging +import sys +from app.core.rag.crawler.web_crawler import WebCrawler + + +def setup_logging(verbose: bool = False): + """Set up logging configuration.""" + level = logging.DEBUG if verbose else logging.INFO + logging.basicConfig( + level=level, + format='%(asctime)s - %(name)s - %(levelname)s - %(message)s', + handlers=[ + logging.StreamHandler(sys.stdout) + ] + ) + + +def main(entry_url: str, + max_pages: int = 200, + delay_seconds: float = 1.0, + timeout_seconds: int = 10, + user_agent: str = "KnowledgeBaseCrawler/1.0"): + """Main entry point for the crawler.""" + # Create 
crawler + crawler = WebCrawler( + entry_url=entry_url, + max_pages=max_pages, + delay_seconds=delay_seconds, + timeout_seconds=timeout_seconds, + user_agent=user_agent + ) + + # Crawl and collect documents + documents = [] + try: + for doc in crawler.crawl(): + print(f"\n{'=' * 80}") + print(f"URL: {doc.url}") + print(f"Title: {doc.title}") + print(f"Content Length: {doc.content_length} characters") + print(f"Word Count: {doc.metadata.get('word_count', 0)} words") + print(f"{'=' * 80}\n") + + documents.append({ + 'url': doc.url, + 'title': doc.title, + 'content': doc.content, + 'content_length': doc.content_length, + 'crawl_timestamp': doc.crawl_timestamp.isoformat(), + 'http_status': doc.http_status, + 'metadata': doc.metadata + }) + + except KeyboardInterrupt: + print("\n\nCrawl interrupted by user.") + + except Exception as e: + print(f"\n\nError during crawl: {e}") + sys.exit(1) + + # Get summary + summary = crawler.get_summary() + print(f"\n{'=' * 80}") + print("CRAWL SUMMARY") + print(f"{'=' * 80}") + print(f"Total Pages Processed: {summary.total_pages_processed}") + print(f"Total Errors: {summary.total_errors}") + print(f"Total Skipped: {summary.total_skipped}") + print(f"Total URLs Discovered: {summary.total_urls_discovered}") + print(f"Duration: {summary.duration_seconds:.2f} seconds") + print(f"documents: {documents}") + + if summary.error_breakdown: + print(f"\nError Breakdown:") + for error_type, count in summary.error_breakdown.items(): + print(f" {error_type}: {count}") + + +if __name__ == '__main__': + entry_url = "https://www.xxx.com" + max_pages = 20 + delay_seconds = 1.0 + timeout_seconds = 10 + user_agent = "KnowledgeBaseCrawler/1.0" + + main(entry_url, max_pages, delay_seconds, timeout_seconds, user_agent) diff --git a/api/app/core/rag/crawler/content_extractor.py b/api/app/core/rag/crawler/content_extractor.py new file mode 100644 index 00000000..69dca53c --- /dev/null +++ b/api/app/core/rag/crawler/content_extractor.py @@ -0,0 +1,233 @@ 
+"""Content extractor for web crawler.""" + +from bs4 import BeautifulSoup +import re +import logging + +from app.core.rag.crawler.models import ExtractedContent + +logger = logging.getLogger(__name__) + + +class ContentExtractor: + """Extract clean, readable text from HTML pages.""" + + # Tags to remove completely + REMOVE_TAGS = ['script', 'style', 'nav', 'header', 'footer', 'aside'] + + # Tags that typically contain main content + MAIN_CONTENT_TAGS = ['article', 'main'] + + # Content extraction tags + CONTENT_TAGS = ['p', 'div', 'h1', 'h2', 'h3', 'h4', 'h5', 'h6', 'li', 'td', 'th', 'section'] + + def is_static_content(self, html: str) -> bool: + """ + Determine if the HTML represents static content. + + Detects JavaScript-rendered content by checking for minimal body + with heavy script tag presence. + + Args: + html: Raw HTML string + + Returns: + bool: True if static, False if JavaScript-rendered + """ + try: + soup = BeautifulSoup(html, 'lxml') + + # Count script tags + script_tags = soup.find_all('script') + script_count = len(script_tags) + + # Get body content (excluding scripts and styles) + body = soup.find('body') + if not body: + return False + + # Remove scripts and styles temporarily for text check + for tag in body.find_all(['script', 'style']): + tag.decompose() + + # Get text content + text = body.get_text(strip=True) + text_length = len(text) + + # If there's very little text but many scripts, likely JS-rendered + if script_count > 5 and text_length < 200: + logger.warning("Detected JavaScript-rendered content (many scripts, little text)") + return False + + # If there's no meaningful text, likely JS-rendered + if text_length < 50: + logger.warning("Detected JavaScript-rendered content (minimal text)") + return False + + return True + + except Exception as e: + logger.error(f"Error checking if content is static: {e}") + return True # Assume static on error + + def extract(self, html: str, url: str) -> ExtractedContent: + """ + Extract clean text 
content from HTML. + + Args: + html: Raw HTML string + url: Source URL (for context) + + Returns: + ExtractedContent: Contains title, text, metadata + """ + try: + soup = BeautifulSoup(html, 'lxml') + + # Check if content is static + is_static = self.is_static_content(html) + + # Extract title + title = self._extract_title(soup) + + # Remove unwanted tags + for tag_name in self.REMOVE_TAGS: + for tag in soup.find_all(tag_name): + tag.decompose() + + # Extract main content + text = self._extract_main_content(soup) + + # Normalize whitespace + text = self._normalize_whitespace(text) + + # Count words + word_count = len(text.split()) + + logger.info(f"Extracted {word_count} words from {url}") + + return ExtractedContent( + title=title, + text=text, + is_static=is_static, + word_count=word_count, + metadata={'url': url} + ) + + except Exception as e: + logger.error(f"Error extracting content from {url}: {e}") + return ExtractedContent( + title=url, + text="", + is_static=False, + word_count=0, + metadata={'url': url, 'error': str(e)} + ) + + def _extract_title(self, soup: BeautifulSoup) -> str: + """ + Extract title from HTML. + + Tries the <title> tag first, then the first <h1>; returns "" if neither is found. + + Args: + soup: BeautifulSoup object + + Returns: + str: Page title + """ + # Try <title> tag + title_tag = soup.find('title') + if title_tag and title_tag.string: + return title_tag.string.strip() + + # Try first <h1> + h1_tag = soup.find('h1') + if h1_tag: + return h1_tag.get_text(strip=True) + + # Default to empty string + return "" + + def _extract_main_content(self, soup: BeautifulSoup) -> str: + """ + Extract main content from HTML. + + Prioritizes semantic HTML5 elements like <article> and <main>. 
+ + Args: + soup: BeautifulSoup object + + Returns: + str: Extracted text content + """ + # Try to find main content area + main_content = None + + # Priority 1: <article> or <main> tags + for tag_name in self.MAIN_CONTENT_TAGS: + main_content = soup.find(tag_name) + if main_content: + logger.debug(f"Found main content in <{tag_name}> tag") + break + + # Priority 2: div with role="main" + if not main_content: + main_content = soup.find('div', role='main') + if main_content: + logger.debug("Found main content in div[role='main']") + + # Priority 3: Common class/id patterns + if not main_content: + for pattern in ['content', 'main', 'article', 'post']: + main_content = soup.find(['div', 'section'], class_=re.compile(pattern, re.I)) + if main_content: + logger.debug(f"Found main content with class pattern '{pattern}'") + break + + main_content = soup.find(['div', 'section'], id=re.compile(pattern, re.I)) + if main_content: + logger.debug(f"Found main content with id pattern '{pattern}'") + break + + # Fallback: use body + if not main_content: + main_content = soup.find('body') + logger.debug("Using <body> as main content (no specific content area found)") + + # Extract text from content tags + if main_content: + text_parts = [] + for tag in main_content.find_all(self.CONTENT_TAGS): + text = tag.get_text(strip=True) + if text: + text_parts.append(text) + + return '\n'.join(text_parts) + + return "" + + def _normalize_whitespace(self, text: str) -> str: + """ + Normalize whitespace in text. 
+ + - Collapse multiple spaces to single space + - Reduce excessive newlines to maximum 2 + - Strip leading/trailing whitespace + + Args: + text: Text to normalize + + Returns: + str: Normalized text + """ + # Collapse multiple spaces to single space + text = re.sub(r' +', ' ', text) + + # Reduce excessive newlines to maximum 2 + text = re.sub(r'\n{3,}', '\n\n', text) + + # Strip leading/trailing whitespace + text = text.strip() + + return text diff --git a/api/app/core/rag/crawler/http_fetcher.py b/api/app/core/rag/crawler/http_fetcher.py new file mode 100644 index 00000000..b3a08098 --- /dev/null +++ b/api/app/core/rag/crawler/http_fetcher.py @@ -0,0 +1,302 @@ +"""HTTP fetcher for web crawler.""" + +import requests +import time +import logging +import re +from typing import Optional, Dict + + +from app.core.rag.crawler.models import FetchResult + +logger = logging.getLogger(__name__) + + +class HTTPFetcher: + """Handle HTTP requests with retries, error handling, and response validation.""" + + def __init__( + self, + timeout: int = 10, + max_retries: int = 3, + user_agent: str = "KnowledgeBaseCrawler/1.0" + ): + """ + Initialize HTTP fetcher. + + Args: + timeout: Request timeout in seconds + max_retries: Maximum number of retry attempts + user_agent: User-Agent header value + """ + self.timeout = timeout + self.max_retries = max_retries + self.user_agent = user_agent + + # Create session for connection pooling + self.session = requests.Session() + self.session.headers.update({ + 'User-Agent': user_agent + }) + + def fetch(self, url: str) -> FetchResult: + """ + Fetch a URL with retry logic and error handling. 
+ + Args: + url: URL to fetch + + Returns: + FetchResult: Contains status_code, content, headers, error info + """ + last_error = None + + for attempt in range(self.max_retries): + try: + # Calculate backoff delay for retries + if attempt > 0: + backoff_delay = 2 ** (attempt - 1) # 1s, 2s, 4s + logger.info(f"Retry attempt {attempt + 1}/{self.max_retries} for {url} after {backoff_delay}s") + time.sleep(backoff_delay) + + # Make HTTP request + response = self.session.get( + url, + timeout=self.timeout, + allow_redirects=True + ) + + # Handle different status codes + if response.status_code == 429: + # Too Many Requests - backoff and retry + logger.warning(f"429 Too Many Requests for {url}, backing off") + if attempt < self.max_retries - 1: + continue + + if response.status_code == 503: + # Service Unavailable - pause and retry + logger.warning(f"503 Service Unavailable for {url}") + if attempt < self.max_retries - 1: + time.sleep(5) # Longer pause for 503 + continue + + # Success or client error (don't retry 4xx except 429) + if 200 <= response.status_code < 300: + logger.info(f"Successfully fetched {url} (status: {response.status_code})") + + # Get correctly encoded content + content = self._get_decoded_content(response) + + return FetchResult( + url=url, + final_url=response.url, + status_code=response.status_code, + content=content, + headers=dict(response.headers), + error=None, + success=True + ) + elif response.status_code == 404: + logger.info(f"404 Not Found: {url}") + return FetchResult( + url=url, + final_url=response.url, + status_code=response.status_code, + content=None, + headers=dict(response.headers), + error="Not Found", + success=False + ) + elif 400 <= response.status_code < 500: + logger.warning(f"Client error {response.status_code} for {url}") + return FetchResult( + url=url, + final_url=response.url, + status_code=response.status_code, + content=None, + headers=dict(response.headers), + error=f"Client error: {response.status_code}", + 
success=False + ) + elif 500 <= response.status_code < 600: + logger.error(f"Server error {response.status_code} for {url}") + last_error = f"Server error: {response.status_code}" + if attempt < self.max_retries - 1: + continue + return FetchResult( + url=url, + final_url=url, + status_code=response.status_code, + content=None, + headers={}, + error=last_error, + success=False + ) + + except requests.exceptions.Timeout: + last_error = "Request timeout" + logger.warning(f"Timeout fetching {url} (attempt {attempt + 1}/{self.max_retries})") + if attempt >= self.max_retries - 1: + break + continue + + except requests.exceptions.SSLError as e: + last_error = f"SSL/TLS error: {str(e)}" + logger.error(f"SSL/TLS error for {url}: {e}") + return FetchResult( + url=url, + final_url=url, + status_code=0, + content=None, + headers={}, + error=last_error, + success=False + ) + + except requests.exceptions.ConnectionError as e: + last_error = f"Connection error: {str(e)}" + logger.warning(f"Connection error for {url} (attempt {attempt + 1}/{self.max_retries}): {e}") + if attempt >= self.max_retries - 1: + break + continue + + except requests.exceptions.RequestException as e: + last_error = f"Request error: {str(e)}" + logger.error(f"Request error for {url}: {e}") + if attempt >= self.max_retries - 1: + break + continue + + # All retries exhausted + logger.error(f"Failed to fetch {url} after {self.max_retries} attempts: {last_error}") + return FetchResult( + url=url, + final_url=url, + status_code=0, + content=None, + headers={}, + error=last_error or "Unknown error", + success=False + ) + + def _get_decoded_content(self, response) -> str: + """ + Get correctly decoded content from response. + + Handles encoding detection and fallback strategies: + 1. Try encoding from HTML meta tags + 2. Try response.encoding (from Content-Type header or detected) + 3. Try UTF-8 + 4. Try common encodings (GB2312, GBK for Chinese, etc.) + 5. 
Fall back to latin-1 with error replacement + + Args: + response: requests.Response object + + Returns: + str: Decoded content + """ + # Try to detect encoding from HTML meta tags + meta_encoding = self._detect_encoding_from_meta(response.content) + if meta_encoding: + try: + content = response.content.decode(meta_encoding) + logger.info(f"Successfully decoded with meta tag encoding: {meta_encoding}") + return content + except (UnicodeDecodeError, LookupError) as e: + logger.warning(f"Failed to decode with meta encoding {meta_encoding}: {e}") + + # Try response.encoding (from Content-Type header or detected by requests) + if response.encoding and response.encoding.lower() != 'iso-8859-1': + # Note: requests defaults to ISO-8859-1 if no charset in Content-Type, + # so we skip it here and try UTF-8 first + try: + return response.text + except (UnicodeDecodeError, LookupError) as e: + logger.warning(f"Failed to decode with detected encoding {response.encoding}: {e}") + + # Try UTF-8 first (most common) + try: + return response.content.decode('utf-8') + except UnicodeDecodeError: + logger.debug("UTF-8 decoding failed, trying other encodings") + + # Try common encodings for different languages + encodings_to_try = [ + 'gbk', # Chinese (Simplified) + 'gb2312', # Chinese (Simplified, older) + 'gb18030', # Chinese (Simplified, extended) + 'big5', # Chinese (Traditional) + 'shift_jis', # Japanese + 'euc-jp', # Japanese + 'euc-kr', # Korean + 'iso-8859-1', # Western European + 'windows-1252', # Windows Western European + 'windows-1251', # Cyrillic + ] + + for encoding in encodings_to_try: + try: + content = response.content.decode(encoding) + logger.info(f"Successfully decoded with {encoding}") + return content + except (UnicodeDecodeError, LookupError): + continue + + # Last resort: use latin-1 with error replacement + logger.warning("All encoding attempts failed, using latin-1 with error replacement") + return response.content.decode('latin-1', errors='replace') + + def 
_detect_encoding_from_meta(self, content: bytes) -> Optional[str]: + """ + Detect encoding from HTML meta tags. + + Looks for: + - <meta charset="..."> + - <meta http-equiv="Content-Type" content="...; charset=..."> + + Args: + content: Raw response content (bytes) + + Returns: + Optional[str]: Detected encoding (lowercased) or None + """ + try: + # Only check first 2KB for performance + head = content[:2048] + + # Decode leniently (non-ASCII bytes are dropped) just to search for meta tags + try: + head_str = head.decode('ascii', errors='ignore') + except Exception: + head_str = head.decode('latin-1', errors='ignore') + + # Look for <meta charset="..."> + charset_match = re.search( + r'<meta[^>]+charset=["\']?([a-zA-Z0-9_-]+)', + head_str, + re.IGNORECASE + ) + if charset_match: + encoding = charset_match.group(1).lower() + logger.debug(f"Found charset in meta tag: {encoding}") + return encoding + + # Look for <meta http-equiv="Content-Type" content="...; charset=..."> + content_type_match = re.search( + r'<meta[^>]+http-equiv=["\']?content-type["\']?[^>]+content=["\']([^"\']+)', + head_str, + re.IGNORECASE + ) + if content_type_match: + content_value = content_type_match.group(1) + charset_match = re.search(r'charset=([a-zA-Z0-9_-]+)', content_value, re.IGNORECASE) + if charset_match: + encoding = charset_match.group(1).lower() + logger.debug(f"Found charset in Content-Type meta: {encoding}") + return encoding + + except Exception as e: + logger.debug(f"Error detecting encoding from meta tags: {e}") + + return None diff --git a/api/app/core/rag/crawler/models.py b/api/app/core/rag/crawler/models.py new file mode 100644 index 00000000..5d10963c --- /dev/null +++ b/api/app/core/rag/crawler/models.py @@ -0,0 +1,52 @@ +"""Data models for web crawler.""" + +from dataclasses import dataclass, field +from datetime import datetime +from typing import Dict, Any, Optional + + +@dataclass +class CrawledDocument: + """Represents a successfully processed web page with extracted content.""" + url: str + title: str + 
content: str + content_length: int + crawl_timestamp: datetime + http_status: int + metadata: Dict[str, Any] = field(default_factory=dict) + + +@dataclass +class FetchResult: + """Represents the result of an HTTP fetch operation.""" + url: str + final_url: str + status_code: int + content: Optional[str] + headers: Dict[str, str] + error: Optional[str] + success: bool + + +@dataclass +class ExtractedContent: + """Represents content extracted from HTML.""" + title: str + text: str + is_static: bool + word_count: int + metadata: Dict[str, Any] = field(default_factory=dict) + + +@dataclass +class CrawlSummary: + """Represents statistics from a completed crawl.""" + total_pages_processed: int + total_errors: int + total_skipped: int + total_urls_discovered: int + start_time: datetime + end_time: datetime + duration_seconds: float + error_breakdown: Dict[str, int] = field(default_factory=dict) diff --git a/api/app/core/rag/crawler/rate_limiter.py b/api/app/core/rag/crawler/rate_limiter.py new file mode 100644 index 00000000..e00fad36 --- /dev/null +++ b/api/app/core/rag/crawler/rate_limiter.py @@ -0,0 +1,57 @@ +"""Rate limiter for web crawler.""" + +import time +import logging + +logger = logging.getLogger(__name__) + + +class RateLimiter: + """Enforce delays between requests to be polite to servers.""" + + def __init__(self, delay_seconds: float = 1.0): + """ + Initialize rate limiter. + + Args: + delay_seconds: Minimum delay between requests + """ + self.delay_seconds = delay_seconds + self.last_request_time = 0.0 + self.max_delay = 60.0 # Cap maximum delay at 60 seconds + + def wait(self): + """ + Block until enough time has passed since last request. + Respects the configured delay. 
+ """ + current_time = time.time() + elapsed = current_time - self.last_request_time + + if elapsed < self.delay_seconds: + sleep_time = self.delay_seconds - elapsed + logger.debug(f"Rate limiting: sleeping for {sleep_time:.2f} seconds") + time.sleep(sleep_time) + + self.last_request_time = time.time() + + def set_delay(self, delay_seconds: float): + """ + Update the delay (useful for respecting Crawl-delay from robots.txt). + + Args: + delay_seconds: New delay in seconds + """ + self.delay_seconds = min(delay_seconds, self.max_delay) + logger.info(f"Rate limiter delay updated to {self.delay_seconds} seconds") + + def backoff(self, multiplier: float = 2.0): + """ + Increase delay exponentially for backoff scenarios (429, 503 responses). + + Args: + multiplier: Factor to multiply current delay by + """ + old_delay = self.delay_seconds + self.delay_seconds = min(self.delay_seconds * multiplier, self.max_delay) + logger.warning(f"Rate limiter backing off: {old_delay:.2f}s -> {self.delay_seconds:.2f}s") diff --git a/api/app/core/rag/crawler/robots_parser.py b/api/app/core/rag/crawler/robots_parser.py new file mode 100644 index 00000000..882bc9c8 --- /dev/null +++ b/api/app/core/rag/crawler/robots_parser.py @@ -0,0 +1,118 @@ +"""Robots.txt parser for web crawler.""" + +from urllib.robotparser import RobotFileParser +from urllib.parse import urlparse, urljoin +from typing import Optional +import logging + +logger = logging.getLogger(__name__) + + +class RobotsParser: + """Parse and check robots.txt compliance for URLs.""" + + def __init__(self, user_agent: str, timeout: int = 10): + """ + Initialize robots.txt parser. + + Args: + user_agent: User agent string to check permissions for + timeout: Timeout for fetching robots.txt + """ + self.user_agent = user_agent + self.timeout = timeout + self._parsers = {} # Cache parsers by domain + + def _get_robots_url(self, url: str) -> str: + """ + Get the robots.txt URL for a given URL. 
+ + Args: + url: URL to get robots.txt for + + Returns: + str: robots.txt URL + """ + parsed = urlparse(url) + robots_url = f"{parsed.scheme}://{parsed.netloc}/robots.txt" + return robots_url + + def _get_parser(self, url: str) -> RobotFileParser: + """ + Get or create a RobotFileParser for the domain. + + Args: + url: URL to get parser for + + Returns: + RobotFileParser: Parser for the domain + """ + robots_url = self._get_robots_url(url) + + # Return cached parser if available + if robots_url in self._parsers: + return self._parsers[robots_url] + + # Create new parser + parser = RobotFileParser() + parser.set_url(robots_url) + + try: + # Fetch and parse robots.txt + parser.read() + logger.info(f"Successfully fetched robots.txt from {robots_url}") + except Exception as e: + # If robots.txt cannot be fetched, assume all URLs are allowed + logger.warning(f"Could not fetch robots.txt from {robots_url}: {e}. Assuming all URLs allowed.") + # Create a permissive parser + parser = RobotFileParser() + parser.parse([]) # Empty robots.txt allows everything + + # Cache the parser + self._parsers[robots_url] = parser + return parser + + def can_fetch(self, url: str) -> bool: + """ + Check if the given URL can be fetched according to robots.txt. + + Args: + url: URL to check + + Returns: + bool: True if allowed, False if disallowed + """ + try: + parser = self._get_parser(url) + allowed = parser.can_fetch(self.user_agent, url) + + if not allowed: + logger.info(f"URL disallowed by robots.txt: {url}") + + return allowed + except Exception as e: + logger.error(f"Error checking robots.txt for {url}: {e}") + # On error, assume allowed + return True + + def get_crawl_delay(self, url: str) -> Optional[float]: + """ + Get the Crawl-delay directive from robots.txt if present. 
class URLNormalizer:
    """Normalize and validate URLs for deduplication and domain checking."""

    # Common tracking parameters to remove
    TRACKING_PARAMS = {
        'utm_source', 'utm_medium', 'utm_campaign', 'utm_term', 'utm_content',
        'fbclid', 'gclid', 'msclkid', '_ga', 'mc_cid', 'mc_eid'
    }

    # href prefixes that never point at a crawlable page
    _NON_PAGE_PREFIXES = ('javascript:', 'mailto:', 'tel:')

    def __init__(self, base_domain: str):
        """
        Initialize URL normalizer with base domain.

        Args:
            base_domain: The domain to use for same-domain checks
        """
        parsed = urlparse(base_domain)
        self.base_domain = parsed.netloc.lower()  # example.com:8000
        self.base_scheme = parsed.scheme or 'https'  # https

    @staticmethod
    def _host_of(netloc: str) -> Optional[str]:
        """Return the bare lowercase host of a netloc (no userinfo, port or IPv6 brackets)."""
        try:
            return urlparse(f"//{netloc}").hostname
        except ValueError:
            return None

    def normalize(self, url: str) -> Optional[str]:
        """
        Normalize a URL for deduplication.

        Normalization rules:
        1. Convert domain to lowercase
        2. Remove fragments (#section)
        3. Remove default ports (80 for http, 443 for https)
        4. Remove trailing slashes (except for root)
        5. Sort query parameters alphabetically
        6. Remove common tracking parameters

        Args:
            url: URL to normalize

        Returns:
            Optional[str]: Normalized URL, or None if invalid (non-http(s)
            scheme, missing host, or unparsable input)
        """
        try:
            parsed = urlparse(url)

            # Validate scheme
            if parsed.scheme not in ('http', 'https'):
                return None

            # A URL without a host (e.g. "http:///path") cannot be fetched
            if not parsed.hostname:
                return None

            # Normalize domain to lowercase
            netloc = parsed.netloc.lower()

            # Remove default ports
            if ':' in netloc:
                host, port = netloc.rsplit(':', 1)
                if (parsed.scheme, port) in (('http', '80'), ('https', '443')):
                    netloc = host

            # Remove trailing slashes except for root; an empty path becomes "/"
            path = parsed.path
            if path != '/' and path.endswith('/'):
                path = path.rstrip('/')
            if not path:
                path = '/'

            # Drop tracking parameters and sort the rest alphabetically
            query = ''
            if parsed.query:
                params = parse_qs(parsed.query, keep_blank_values=True)
                kept = sorted(
                    (key, values) for key, values in params.items()
                    if key not in self.TRACKING_PARAMS
                )
                if kept:
                    query = urlencode(kept, doseq=True)

            # Reconstruct URL without fragment
            return urlunparse((
                parsed.scheme,
                netloc,
                path,
                parsed.params,
                query,
                ''  # Remove fragment
            ))

        except Exception:
            # Best effort: anything that cannot be parsed is treated as invalid
            return None

    def is_same_domain(self, url: str) -> bool:
        """
        Check if URL belongs to the same domain as base_domain.

        Hosts are compared without port, userinfo or IPv6 brackets, so
        ``https://user@example.com:8443/`` matches a base of ``example.com``.

        Args:
            url: URL to check

        Returns:
            bool: True if same domain, False otherwise (including URLs
            without a host)
        """
        try:
            host = urlparse(url).hostname
        except Exception:
            return False

        base_host = self._host_of(self.base_domain)
        if not host or not base_host:
            return False
        return host == base_host

    def extract_links(self, html: str, base_url: str) -> List[str]:
        """
        Extract and normalize same-domain links from HTML.

        Every href is first resolved against ``base_url`` and then checked
        with ``is_same_domain``, so absolute, relative and protocol-relative
        links are all filtered by the same rule.

        Args:
            html: HTML content
            base_url: Base URL for resolving relative links

        Returns:
            List[str]: List of normalized absolute URLs on the base domain
        """
        links: List[str] = []

        try:
            soup = BeautifulSoup(html, 'lxml')

            # Find all anchor tags
            for anchor in soup.find_all('a', href=True):
                href = anchor['href'].strip()

                # Skip empty hrefs
                if not href:
                    continue

                # Skip javascript:, mailto: and tel: links (scheme is case-insensitive)
                if href.lower().startswith(self._NON_PAGE_PREFIXES):
                    continue

                # urljoin leaves absolute URLs alone and resolves everything else
                absolute_url = urljoin(base_url, href)
                if not self.is_same_domain(absolute_url):
                    continue

                normalized_url = self.normalize(absolute_url)
                if normalized_url:
                    links.append(normalized_url)

        except Exception as exc:
            # Best effort: keep whatever was collected, but leave a trace so a
            # broken parser setup does not look like a page with no links.
            import logging
            logging.getLogger(__name__).warning(
                "Link extraction failed for %s: %s", base_url, exc
            )

        return links
def crawl(self) -> Iterator[CrawledDocument]:
    """
    Execute the crawl and yield documents as they are processed.

    URLs are taken from ``self.url_queue`` in FIFO order until the queue is
    empty or ``self.max_pages`` documents have been produced. Each URL is
    checked against robots.txt, rate limited, fetched, and skipped unless it
    is static HTML. Links found on a page are resolved against the URL the
    fetch actually ended on (after redirects) before being queued.

    ``self.end_time`` is recorded even when the consumer stops iterating
    early or an unexpected error escapes.

    Yields:
        CrawledDocument: Structured document with extracted content
    """
    logger.info(f"Starting crawl from {self.entry_url} (max_pages: {self.max_pages})")
    self.start_time = datetime.now()

    try:
        # Add entry URL to queue
        normalized_entry = self.url_normalizer.normalize(self.entry_url)
        if normalized_entry:
            self.url_queue.append(normalized_entry)
            self.stats['urls_discovered'] += 1

        # Check robots.txt and update rate limiter if needed
        crawl_delay = self.robots_parser.get_crawl_delay(self.entry_url)
        if crawl_delay:
            self.rate_limiter.set_delay(crawl_delay)

        # Mirror of the queue contents for O(1) "already queued?" checks;
        # scanning the deque for every discovered link is linear per link.
        queued: Set[str] = set(self.url_queue)

        # Main crawl loop
        while self.url_queue and self.pages_processed < self.max_pages:
            url = self.url_queue.popleft()
            queued.discard(url)

            # Skip if already visited
            if url in self.visited_urls:
                continue

            # Mark as visited
            self.visited_urls.add(url)

            # Check robots.txt permission
            if not self.robots_parser.can_fetch(url):
                logger.info(f"Skipping {url} (disallowed by robots.txt)")
                self.stats['skipped'] += 1
                continue

            # Apply rate limiting
            self.rate_limiter.wait()

            # Fetch URL
            logger.info(f"Fetching {url} ({self.pages_processed + 1}/{self.max_pages})")
            fetch_result = self.http_fetcher.fetch(url)

            # Handle fetch errors
            if not fetch_result.success:
                self._record_error(fetch_result.error or "Unknown error")
                continue

            # Check Content-Type
            content_type = fetch_result.headers.get('Content-Type', '').lower()
            if not any(substring in content_type for substring in ['text/html', 'application/xhtml+xml']):
                logger.warning(f"Skipping {url} (Content-Type: {content_type})")
                self.stats['skipped'] += 1
                continue

            # Extract content
            try:
                extracted = self.content_extractor.extract(fetch_result.content, url)

                # Check if static content
                if not extracted.is_static:
                    logger.warning(f"Skipping {url} (JavaScript-rendered content)")
                    self.stats['skipped'] += 1
                    continue

                # Create document
                document = CrawledDocument(
                    url=url,
                    title=extracted.title,
                    content=extracted.text,
                    content_length=len(extracted.text),
                    crawl_timestamp=datetime.now(),
                    http_status=fetch_result.status_code,
                    metadata={
                        'word_count': extracted.word_count,
                        'final_url': fetch_result.final_url
                    }
                )

                # Update statistics
                self.pages_processed += 1
                self.stats['success'] += 1

                # Resolve relative links against the URL the response came
                # from. normalize() strips trailing slashes, so "/docs" is
                # commonly redirected to "/docs/"; resolving "page.html"
                # against the pre-redirect URL would yield "/page.html".
                link_base = fetch_result.final_url or url
                links = self.url_normalizer.extract_links(fetch_result.content, link_base)
                for link in links:
                    if link in self.visited_urls or link in queued:
                        continue
                    if not self.url_normalizer.is_same_domain(link):
                        continue
                    self.url_queue.append(link)
                    queued.add(link)
                    self.stats['urls_discovered'] += 1

                # Yield document
                yield document

            except Exception as e:
                logger.error(f"Error processing {url}: {e}")
                self._record_error(f"Processing error: {str(e)}")
                continue

        logger.info(f"Crawl completed. Processed {self.pages_processed} pages.")
    finally:
        # Runs on normal completion, on generator close and on errors
        self.end_time = datetime.now()
def main(feishu_app_id: str,  # Feishu application ID
         feishu_app_secret: str,  # Feishu application secret
         feishu_folder_token: str,  # Feishu Folder Token
         save_dir: str,  # save file directory
         feishu_api_base_url: str = "https://open.feishu.cn/open-apis",  # Feishu API base URL
         timeout: int = 30,  # Request timeout in seconds
         max_retries: int = 3,  # Maximum number of retries
         recursive: bool = True  # recursive: Whether to sync subfolders recursively,
         ):
    """
    Main entry point for the feishuAPIClient.

    Lists the files of ``feishu_folder_token`` (recursively when ``recursive``
    is true), prints the metadata of every doc/docx/sheet/bitable/file entry
    and downloads each one into ``save_dir``.

    Listing and downloading share one client session and one event loop.
    Any failure, including a listing failure, is printed and ends the process
    with exit status 1; Ctrl-C prints a notice and returns.
    """
    # Create feishuAPIClient
    api_client = FeishuAPIClient(
        app_id=feishu_app_id,
        app_secret=feishu_app_secret,
        api_base_url=feishu_api_base_url,
        timeout=timeout,
        max_retries=max_retries
    )

    async def sync_documents():
        """List the folder and download every supported document in one session."""
        async with api_client as client:
            # Get all files from folder
            if recursive:
                files = await client.list_all_folder_files(feishu_folder_token, recursive=True)
            else:
                files = []
                page_token = None
                while True:
                    files_page, page_token = await client.list_folder_files(
                        feishu_folder_token, page_token
                    )
                    files.extend(files_page)
                    if not page_token:
                        break

            # Filter out folders, only sync documents ("slides" is not synced here)
            documents = [f for f in files if f.type in ("doc", "docx", "sheet", "bitable", "file")]

            for doc in documents:
                print(f"\n{'=' * 80}")
                print(f"token: {doc.token}")
                print(f"name: {doc.name}")
                print(f"type: {doc.type}")
                print(f"created_time: {doc.created_time}")
                print(f"modified_time: {doc.modified_time}")
                print(f"owner_id: {doc.owner_id}")
                print(f"url: {doc.url}")
                print(f"{'=' * 80}\n")
                # download document from Feishu FileInfo
                file_path = await client.download_document(document=doc, save_dir=save_dir)
                print(file_path)

    try:
        asyncio.run(sync_documents())

    except KeyboardInterrupt:
        print("\n\nfeishu integration interrupted by user.")

    except Exception as e:
        print(f"\n\nError during feishu integration: {e}")
        sys.exit(1)
datetime, timedelta +import httpx +from cachetools import TTLCache +import urllib.parse + +from app.core.rag.integrations.feishu.exceptions import ( + FeishuAuthError, + FeishuAPIError, + FeishuNotFoundError, + FeishuPermissionError, + FeishuRateLimitError, + FeishuNetworkError, +) +from app.core.rag.integrations.feishu.models import FileInfo +from app.core.rag.integrations.feishu.retry import with_retry + + +class FeishuAPIClient: + """Feishu API client for document synchronization.""" + + def __init__( + self, + app_id: str, + app_secret: str, + api_base_url: str = "https://open.feishu.cn/open-apis", + timeout: int = 30, + max_retries: int = 3 + ): + """ + Initialize Feishu API client. + + Args: + app_id: Feishu application ID + app_secret: Feishu application secret + api_base_url: Feishu API base URL + timeout: Request timeout in seconds + max_retries: Maximum number of retries + """ + self.app_id = app_id + self.app_secret = app_secret + self.api_base_url = api_base_url + self.timeout = timeout + self.max_retries = max_retries + self._http_client: Optional[httpx.AsyncClient] = None + self._token_cache: TTLCache = TTLCache(maxsize=1, ttl=7200 - 300) # 2 hours - 5 minutes + self._token_lock = asyncio.Lock() + + async def __aenter__(self): + """Async context manager entry.""" + self._http_client = httpx.AsyncClient( + base_url=self.api_base_url, + timeout=self.timeout, + headers={"Content-Type": "application/json"} + ) + return self + + async def __aexit__(self, exc_type, exc_val, exc_tb): + """Async context manager exit.""" + if self._http_client: + await self._http_client.aclose() + + async def get_tenant_access_token(self) -> str: + """ + Get tenant access token with caching. 
@with_retry
async def list_folder_files(
    self,
    folder_token: str,
    page_token: Optional[str] = None
) -> Tuple[List[FileInfo], Optional[str]]:
    """
    Get list of files in a folder with pagination support.

    Requests one page (up to 200 entries) of ``/drive/v1/files``. Entries
    whose fields cannot be converted (for example a timestamp that is not
    an integer or is outside the platform's supported range) are skipped
    instead of failing the whole page.

    Args:
        folder_token: Folder token
        page_token: Page token for pagination

    Returns:
        Tuple of (list of FileInfo, next page token)

    Raises:
        FeishuAPIError: If API call fails
        FeishuNotFoundError: If folder not found
        FeishuPermissionError: If permission denied
    """
    try:
        token = await self.get_tenant_access_token()

        if not self._http_client:
            raise FeishuAPIError("HTTP client not initialized")

        # Build request parameters
        params = {"page_size": 200, "folder_token": folder_token}
        if page_token:
            params["page_token"] = page_token

        # Make API request
        response = await self._http_client.get(
            "/drive/v1/files",
            params=params,
            headers={"Authorization": f"Bearer {token}"}
        )

        data = response.json()

        # Handle errors
        if data.get("code") != 0:
            error_code = data.get("code")
            error_msg = data.get("msg", "Unknown error")

            if error_code == 404 or error_code == 230005:
                raise FeishuNotFoundError(
                    f"Folder not found: {error_msg}",
                    error_code=str(error_code),
                    details=data
                )
            elif error_code == 403 or error_code == 230003:
                raise FeishuPermissionError(
                    f"Permission denied: {error_msg}",
                    error_code=str(error_code),
                    details=data
                )
            else:
                raise FeishuAPIError(
                    f"API error: {error_msg}",
                    error_code=str(error_code),
                    details=data
                )

        # Parse response; tolerate a null "data" or "files" member
        payload = data.get("data") or {}
        files_data = payload.get("files") or []
        next_page_token = payload.get("next_page_token")

        # Convert to FileInfo objects
        files = []
        for file_data in files_data:
            try:
                file_info = FileInfo(
                    token=file_data.get("token", ""),
                    name=file_data.get("name", ""),
                    type=file_data.get("type", ""),
                    created_time=datetime.fromtimestamp(int(file_data.get("created_time", 0))),
                    modified_time=datetime.fromtimestamp(int(file_data.get("modified_time", 0))),
                    owner_id=file_data.get("owner_id", ""),
                    url=file_data.get("url", "")
                )
            except (ValueError, TypeError, AttributeError, OverflowError, OSError):
                # Skip invalid file entries. datetime.fromtimestamp() raises
                # OverflowError/OSError for out-of-range values.
                continue
            files.append(file_info)

        return files, next_page_token

    except (FeishuAPIError, FeishuNotFoundError, FeishuPermissionError):
        # Already classified; keep error_code and details intact
        raise
    except httpx.HTTPError as e:
        raise FeishuAPIError(f"HTTP error: {str(e)}") from e
    except Exception as e:
        raise FeishuAPIError(f"Unexpected error: {str(e)}") from e
async def _export_file(self, document: FileInfo, access_token: str, save_dir: str) -> str:
    """
    Export a Feishu online document and save it under ``save_dir``.

    Steps: create an export task, poll it (up to 10 times, 2 seconds apart)
    until it reports a ``file_token``, then download the exported file.

    The local file name is the document name with path separators and other
    reserved characters replaced by ``_`` (same rule as ``_download_file``),
    so a title such as ``a/b`` cannot escape ``save_dir``.

    Args:
        document: Document FileInfo (type doc/docx/sheet/bitable)
        access_token: Tenant access token used for the Authorization header
        save_dir: Directory the exported file is written to

    Returns:
        Full path of the written file

    Raises:
        FeishuAPIError: If the API reports an error (``error_code`` and
            ``details`` are preserved), the task does not finish in time,
            or an HTTP/unexpected error occurs
    """
    # Export format per online document type; anything else is exported as PDF
    export_extensions = {
        "doc": "doc",
        "docx": "docx",
        "sheet": "xlsx",
        "bitable": "xlsx",
    }
    auth_headers = {"Authorization": f"Bearer {access_token}"}

    try:
        # 1. Create the export task
        file_extension = export_extensions.get(document.type, "pdf")
        response = await self._http_client.post(
            "/drive/v1/export_tasks",
            json={
                "file_extension": file_extension,
                "token": document.token,
                "type": document.type
            },
            headers=auth_headers
        )
        data = response.json()
        print(f"1.创建导出任务: {data}")

        if data.get("code") != 0:
            error_code = data.get("code")
            error_msg = data.get("msg", "Unknown error")
            raise FeishuAPIError(
                f"API error: {error_msg}",
                error_code=str(error_code),
                details=data
            )

        ticket = (data.get("data") or {}).get("ticket")
        if not ticket:
            # A missing ticket is an API response problem, not an auth failure
            raise FeishuAPIError("No ticket in response", details=data)

        # 2. Poll the export task until the exported file is available
        max_polls = 10  # maximum number of polls
        poll_interval = 2  # seconds between polls
        file_token = None
        for attempt in range(max_polls):
            response = await self._http_client.get(
                f"/drive/v1/export_tasks/{ticket}",
                params={"token": document.token},
                headers=auth_headers
            )
            data = response.json()
            print(f"2. 尝试查询导出任务结果 (第{attempt + 1}次): {data}")

            if data.get("code") != 0:
                error_code = data.get("code")
                error_msg = data.get("msg", "Unknown error")
                raise FeishuAPIError(
                    f"API error: {error_msg}",
                    error_code=str(error_code),
                    details=data,
                )

            # The task is done once it reports a file_token
            file_token = ((data.get("data") or {}).get("result") or {}).get("file_token")
            if file_token:
                break

            # Result not ready yet; wait before the next poll
            await asyncio.sleep(poll_interval)

        if not file_token:
            raise FeishuAPIError("Export task did not complete within the allowed time")

        # 3. Download the exported file
        response = await self._http_client.get(
            f"/drive/v1/export_tasks/file/{file_token}/download",
            headers=auth_headers
        )
        response.raise_for_status()
        print(f'3.下载导出任务: {response.headers.get("Content-Disposition")}')

        # The title comes from the remote service: strip path separators and
        # reserved characters before using it as a file name.
        safe_name = re.sub(r'[\/:*?"<>|]', '_', document.name or document.token)
        file_full_path = os.path.join(save_dir, f"{safe_name}.{file_extension}")
        if os.path.exists(file_full_path):
            os.remove(file_full_path)  # Delete a single file
        with open(file_full_path, "wb") as file:
            file.write(response.content)

        return file_full_path

    except FeishuAPIError:
        # Keep error_code/details so callers and the retry logic can classify it
        raise
    except httpx.HTTPError as e:
        raise FeishuAPIError(f"HTTP error: {str(e)}") from e
    except Exception as e:
        raise FeishuAPIError(f"Unexpected error during file download: {str(e)}") from e
class FeishuError(Exception):
    """Root of the Feishu exception hierarchy.

    Carries the human-readable message, an optional error code and an
    optional dict of details (an empty dict when none is given).
    """

    def __init__(self, message: str, error_code: str = None, details: dict = None):
        super().__init__(message)
        self.message = message
        self.error_code = error_code
        self.details = details if details else {}


class FeishuAuthError(FeishuError):
    """Raised when authenticating against the Feishu API fails."""


class FeishuAPIError(FeishuError):
    """Raised for general errors reported by the Feishu API."""


class FeishuNotFoundError(FeishuError):
    """Raised when the requested resource does not exist (404)."""


class FeishuPermissionError(FeishuError):
    """Raised when access to a resource is denied (403)."""


class FeishuRateLimitError(FeishuError):
    """Raised when the rate limit has been exceeded (429)."""


class FeishuNetworkError(FeishuError):
    """Raised for network problems such as timeouts or connection failures."""


class FeishuDataError(FeishuError):
    """Raised when data cannot be parsed or fails validation."""
class RetryStrategy:
    """Decides which failures of a Feishu API call are retried, and retries them."""

    # Errors that are always retried
    RETRYABLE_ERRORS = (
        FeishuNetworkError,
        FeishuRateLimitError,
        httpx.TimeoutException,
        httpx.ConnectError,
        httpx.ReadError,
    )

    # Errors that are never retried
    NON_RETRYABLE_ERRORS = (
        FeishuAuthError,
        FeishuPermissionError,
        FeishuNotFoundError,
        FeishuDataError,
    )

    # Retry configuration
    MAX_RETRIES = 3
    BACKOFF_DELAYS = [1, 2, 4]  # seconds

    @classmethod
    def is_retryable(cls, error: Exception) -> bool:
        """Return True when ``error`` is worth another attempt."""
        # Explicit allow-list first, then the explicit deny-list
        if isinstance(error, cls.RETRYABLE_ERRORS):
            return True
        if isinstance(error, cls.NON_RETRYABLE_ERRORS):
            return False

        if isinstance(error, httpx.HTTPStatusError):
            status = error.response.status_code
            # 429 (rate limit) and every 5xx are retried; other 4xx and any
            # remaining status are not.
            return status == 429 or 500 <= status < 600

        if isinstance(error, FeishuAPIError):
            # Rate limit error codes
            code = error.error_code
            return bool(code) and code in ["99991400", "99991401"]

        return False

    @classmethod
    async def execute_with_retry(
        cls,
        func: Callable[..., T],
        *args,
        **kwargs
    ) -> T:
        """
        Await ``func(*args, **kwargs)``, retrying retryable failures.

        Up to ``MAX_RETRIES`` retries follow the first attempt. The wait
        before retry ``i`` is ``BACKOFF_DELAYS[i]``; once the list is
        exhausted its last entry is reused.

        Args:
            func: Async function to execute
            *args: Positional arguments for the function
            **kwargs: Keyword arguments for the function

        Returns:
            Function result

        Raises:
            Exception: The failure itself when it is not retryable, or the
                last failure once all retries are used up
        """
        last_delay_index = len(cls.BACKOFF_DELAYS) - 1

        for attempt in range(cls.MAX_RETRIES + 1):
            try:
                return await func(*args, **kwargs)
            except Exception as exc:
                # Give up on non-retryable errors and after the final attempt
                if not cls.is_retryable(exc) or attempt >= cls.MAX_RETRIES:
                    raise

                # Wait before retrying
                pause = cls.BACKOFF_DELAYS[min(attempt, last_delay_index)]
                await asyncio.sleep(pause)
+ """ + @functools.wraps(func) + async def wrapper(*args, **kwargs): + return await RetryStrategy.execute_with_retry(func, *args, **kwargs) + + return wrapper diff --git a/api/app/core/rag/integrations/yuque/__init__.py b/api/app/core/rag/integrations/yuque/__init__.py new file mode 100644 index 00000000..dc4f2a17 --- /dev/null +++ b/api/app/core/rag/integrations/yuque/__init__.py @@ -0,0 +1 @@ +"""Yuque integration module for document synchronization.""" diff --git a/api/app/core/rag/integrations/yuque/__main__.py b/api/app/core/rag/integrations/yuque/__main__.py new file mode 100644 index 00000000..3b87bbcd --- /dev/null +++ b/api/app/core/rag/integrations/yuque/__main__.py @@ -0,0 +1,77 @@ +"""Main entry point for Yuque integration testing.""" + +import asyncio +import sys +from app.core.rag.integrations.yuque.client import YuqueAPIClient +from app.core.rag.integrations.yuque.models import YuqueDocInfo + + +def main(yuque_user_id: str, # yuque User ID + yuque_token: str, # yuque Token + save_dir: str, # save file directory + ): + """Main entry point for the YuqueAPIClient.""" + # Create feishuAPIClient + api_client = YuqueAPIClient( + user_id=yuque_user_id, + token=yuque_token + ) + + # Get all files from all repos + async def async_get_files(api_client: YuqueAPIClient): + async with api_client as client: + print("\n=== Fetching repositories ===") + repos = await client.get_user_repos() + print(f"Found {len(repos)} repositories:") + all_files = [] + for repo in repos: + # Get documents from repository + print(f"\n=== Fetching documents from '{repo.name}' ===") + docs = await client.get_repo_docs(repo.id) + all_files.extend(docs) + return all_files + files = asyncio.run(async_get_files(api_client)) + + try: + for doc in files: + print(f"\n{'=' * 80}") + print(f"id: {doc.id}") + print(f"type: {doc.type}") + print(f"slug: {doc.slug}") + print(f"title: {doc.title}") + print(f"book_id: {doc.book_id}") + # print(f"format: {doc.format}") + # print(f"body: {doc.body}") 
+ # print(f"body_draft: {doc.body_draft}") + # print(f"body_html: {doc.body_html}") + print(f"public: {doc.public}") + print(f"status: {doc.status}") + print(f"created_at: {doc.created_at}") + print(f"updated_at: {doc.updated_at}") + print(f"published_at: {doc.published_at}") + print(f"word_count: {doc.word_count}") + print(f"cover: {doc.cover}") + print(f"description: {doc.description}") + print(f"{'=' * 80}\n") + # download document from Feishu FileInfo + async def async_download_document(api_client: YuqueAPIClient, doc: YuqueDocInfo, save_dir: str): + async with api_client as client: + file_path = await client.download_document(doc, save_dir) + return file_path + + file_path = asyncio.run(async_download_document(api_client, doc, save_dir)) + print(file_path) + + except KeyboardInterrupt: + print("\n\nfeishu integration interrupted by user.") + + except Exception as e: + print(f"\n\nError during feishu integration: {e}") + sys.exit(1) + + +if __name__ == "__main__": + yuque_user_id = "" + yuque_token = "" + save_dir = "/Volumes/MacintoshBD/Repository/RedBearAI/MemoryBear/api/files/" + main(yuque_user_id, yuque_token, save_dir) diff --git a/api/app/core/rag/integrations/yuque/client.py b/api/app/core/rag/integrations/yuque/client.py new file mode 100644 index 00000000..444d9d31 --- /dev/null +++ b/api/app/core/rag/integrations/yuque/client.py @@ -0,0 +1,544 @@ +"""Yuque API client for document operations.""" + +import os +import re +from typing import Optional, List +from datetime import datetime, timedelta +import httpx +import urllib.parse +import json +from openpyxl import Workbook +from openpyxl.styles import Font, Alignment, PatternFill +from openpyxl.utils import get_column_letter +import zlib + +from app.core.rag.integrations.yuque.exceptions import ( + YuqueAuthError, + YuqueAPIError, + YuqueNotFoundError, + YuquePermissionError, + YuqueRateLimitError, + YuqueNetworkError, +) +from app.core.rag.integrations.yuque.models import YuqueDocInfo, YuqueRepoInfo 
+from app.core.rag.integrations.yuque.retry import with_retry + + +class YuqueAPIClient: + """Yuque API client for document synchronization.""" + + def __init__( + self, + user_id: str, + token: str, + api_base_url: str = "https://www.yuque.com/api/v2", + timeout: int = 30, + max_retries: int = 3 + ): + """ + Initialize Yuque API client. + + Args: + user_id: Yuque user ID or login name + token: Yuque personal access token + api_base_url: Yuque API base URL + timeout: Request timeout in seconds + max_retries: Maximum number of retries + """ + self.user_id = user_id + self.token = token + self.api_base_url = api_base_url + self.timeout = timeout + self.max_retries = max_retries + self._http_client: Optional[httpx.AsyncClient] = None + + async def __aenter__(self): + """Async context manager entry.""" + self._http_client = httpx.AsyncClient( + base_url=self.api_base_url, + timeout=self.timeout, + headers={ + "Content-Type": "application/json", + "X-Auth-Token": self.token, + "User-Agent": "Yuque-Integration-Client" + } + ) + return self + + async def __aexit__(self, exc_type, exc_val, exc_tb): + """Async context manager exit.""" + if self._http_client: + await self._http_client.aclose() + + def _handle_api_error(self, response: httpx.Response): + """Handle API error responses.""" + try: + data = response.json() + except Exception: + data = {} + + status_code = response.status_code + error_msg = data.get("message", "Unknown error") + + # Rate limit errors + if status_code == 429: + raise YuqueRateLimitError( + f"Rate limit exceeded: {error_msg}", + error_code=str(status_code), + details=data + ) + # Not found errors + elif status_code == 404: + raise YuqueNotFoundError( + f"Resource not found: {error_msg}", + error_code=str(status_code), + details=data + ) + # Permission errors + elif status_code == 403: + raise YuquePermissionError( + f"Permission denied: {error_msg}", + error_code=str(status_code), + details=data + ) + # Authentication errors + elif status_code == 
401: + raise YuqueAuthError( + f"Authentication failed: {error_msg}", + error_code=str(status_code), + details=data + ) + # Generic API error + else: + raise YuqueAPIError( + f"API error: {error_msg}", + error_code=str(status_code), + details=data + ) + + @with_retry + async def get_user_repos(self) -> List[YuqueRepoInfo]: + """ + Get all repositories (知识库) for the user. + + Returns: + List of YuqueRepoInfo objects + + Raises: + YuqueAPIError: If API call fails + """ + try: + if not self._http_client: + raise YuqueAPIError("HTTP client not initialized") + + response = await self._http_client.get(f"/users/{self.user_id}/repos") + + if response.status_code != 200: + self._handle_api_error(response) + + data = response.json() + repos_data = data.get("data", []) + + repos = [] + for repo_data in repos_data: + try: + repo = YuqueRepoInfo( + id=repo_data.get("id"), + type=repo_data.get("type", ""), + name=repo_data.get("name", ""), + namespace=repo_data.get("namespace", ""), + slug=repo_data.get("slug", ""), + description=repo_data.get("description"), + public=repo_data.get("public", 0), + items_count=repo_data.get("items_count", 0), + created_at=datetime.fromisoformat(repo_data.get("created_at", "").replace("Z", "+00:00")), + updated_at=datetime.fromisoformat(repo_data.get("updated_at", "").replace("Z", "+00:00")) + ) + repos.append(repo) + except (ValueError, TypeError, KeyError) as e: + # Skip invalid repo entries + continue + + return repos + + except httpx.HTTPError as e: + raise YuqueAPIError(f"HTTP error: {str(e)}") + except Exception as e: + if isinstance(e, (YuqueAPIError, YuqueAuthError)): + raise + raise YuqueAPIError(f"Unexpected error: {str(e)}") + + @with_retry + async def get_repo_docs(self, book_id: int) -> List[YuqueDocInfo]: + """ + Get all documents in a repository. 
+ + Args: + book_id: repository id + + Returns: + List of YuqueDocInfo objects (without body content) + + Raises: + YuqueAPIError: If API call fails + """ + try: + if not self._http_client: + raise YuqueAPIError("HTTP client not initialized") + + response = await self._http_client.get(f"/repos/{book_id}/docs") + + if response.status_code != 200: + self._handle_api_error(response) + + data = response.json() + docs_data = data.get("data", []) + + docs = [] + for doc_data in docs_data: + try: + published_at = doc_data.get("published_at") + doc = YuqueDocInfo( + id=doc_data.get("id"), + type=doc_data.get("type", ""), + slug=doc_data.get("slug", ""), + title=doc_data.get("title", ""), + book_id=doc_data.get("book_id"), + format=doc_data.get("format", "markdown"), + body=None, # Body not included in list API + body_draft=None, + body_html=None, + public=doc_data.get("public", 0), + status=doc_data.get("status", 0), + created_at=datetime.fromisoformat(doc_data.get("created_at", "").replace("Z", "+00:00")), + updated_at=datetime.fromisoformat(doc_data.get("updated_at", "").replace("Z", "+00:00")), + published_at=datetime.fromisoformat(published_at.replace("Z", "+00:00")) if published_at else None, + word_count=doc_data.get("word_count", 0), + cover=doc_data.get("cover"), + description=doc_data.get("description") + ) + docs.append(doc) + except (ValueError, TypeError, KeyError) as e: + # Skip invalid doc entries + continue + + return docs + + except httpx.HTTPError as e: + raise YuqueAPIError(f"HTTP error: {str(e)}") + except Exception as e: + if isinstance(e, (YuqueAPIError, YuqueNotFoundError)): + raise + raise YuqueAPIError(f"Unexpected error: {str(e)}") + + @with_retry + async def get_doc_detail(self, id: int) -> YuqueDocInfo: + """ + Get detailed document information including content. 
+ + Args: + id: document ID + + Returns: + YuqueDocInfo object with full content + + Raises: + YuqueAPIError: If API call fails + """ + try: + if not self._http_client: + raise YuqueAPIError("HTTP client not initialized") + + response = await self._http_client.get( + f"/repos/docs/{id}", + params={"raw": 1} # Get raw markdown content + ) + + if response.status_code != 200: + self._handle_api_error(response) + + data = response.json() + doc_data = data.get("data", {}) + + published_at = doc_data.get("published_at") + doc = YuqueDocInfo( + id=doc_data.get("id"), + type=doc_data.get("type", ""), + slug=doc_data.get("slug", ""), + title=doc_data.get("title", ""), + book_id=doc_data.get("book_id"), + format=doc_data.get("format", "markdown"), + body=doc_data.get("body", ""), + body_draft=doc_data.get("body_draft"), + body_html=doc_data.get("body_html"), + public=doc_data.get("public", 0), + status=doc_data.get("status", 0), + created_at=datetime.fromisoformat(doc_data.get("created_at", "").replace("Z", "+00:00")), + updated_at=datetime.fromisoformat(doc_data.get("updated_at", "").replace("Z", "+00:00")), + published_at=datetime.fromisoformat(published_at.replace("Z", "+00:00")) if published_at else None, + word_count=doc_data.get("word_count", 0), + cover=doc_data.get("cover"), + description=doc_data.get("description") + ) + + return doc + + except httpx.HTTPError as e: + raise YuqueAPIError(f"HTTP error: {str(e)}") + except Exception as e: + if isinstance(e, (YuqueAPIError, YuqueNotFoundError)): + raise + raise YuqueAPIError(f"Unexpected error: {str(e)}") + + async def download_document( + self, + doc: YuqueDocInfo, + save_dir: str + ) -> str: + """ + Download document content to local file. 
+ + Args: + doc: Document info (can be without body) + save_dir: Directory to save the file + + Returns: + Full path to the saved file + + Raises: + YuqueAPIError: If download fails + """ + try: + # Get full document content if not already loaded + if not doc.body: + doc = await self.get_doc_detail(doc.id) + + # Sanitize filename + filename = re.sub(r'[\/:*?"<>|]', '_', doc.title) + + # Determine file extension based on format + content = doc.body or "" + if doc.format == "markdown": + file_extension = "md" + elif doc.format == "lake": + file_extension = "md" # Save lake format as markdown + elif doc.format == "html": + file_extension = "html" + elif doc.format == "lakesheet": + file_extension = "xlsx" + + body_data = json.loads(doc.body) + sheet_data = body_data.get("sheet", "") + try: + sheet_raw = zlib.decompress(bytes(sheet_data, 'latin-1')) + except Exception as e: + print(f"Error decompressing sheet data: {e}") + raise ValueError("Invalid or unsupported sheet data format.") + try: + sheet_text = sheet_raw.decode("utf-8") # 假设是 UTF-8 编码 + except UnicodeDecodeError: + sheet_text = sheet_raw.decode("gbk") # 如果 UTF-8 解码失败,尝试 GBK + + file_full_path = os.path.join(save_dir, f"{filename}.{file_extension}") + self.generate_excel_from_sheet(sheet_text, file_full_path) + return file_full_path + else: + file_extension = "txt" + + file_full_path = os.path.join(save_dir, f"{filename}.{file_extension}") + # Remove existing file if it exists + if os.path.exists(file_full_path): + os.remove(file_full_path) + + # Write content to file + with open(file_full_path, "w", encoding="utf-8") as file: + file.write(content) + + return file_full_path + + except Exception as e: + if isinstance(e, YuqueAPIError): + raise + raise YuqueAPIError(f"Unexpected error during file download: {str(e)}") + + def generate_excel_from_sheet(self, sheet_text: str, save_path: str): + """ + 将解析的 sheet_text 数据转换为 Excel 文件。 + + Args: + sheet_text (str): JSON 格式的 sheet 数据。 + save_path (str): Excel 文件的保存路径。 
+ """ + try: + # 解析 JSON 数据 + sheets = json.loads(sheet_text) + + if not isinstance(sheets, list): + raise ValueError("sheet_text must be a JSON array of sheets.") + + # 创建一个新的 Excel 工作簿 + workbook = Workbook() + + for sheet_index, sheet_data in enumerate(sheets): + sheet_name = sheet_data.get("name", f"Sheet{sheet_index + 1}") + row_data = sheet_data.get("data", {}) + merge_cells = sheet_data.get("mergeCells", {}) + rows_styles = sheet_data.get("rows", []) + cols_styles = sheet_data.get("columns", []) + + # 创建 Sheet + if sheet_index == 0: + worksheet = workbook.active + worksheet.title = sheet_name + else: + worksheet = workbook.create_sheet(title=sheet_name) + + # 设置列宽 + for col_index, col_style in enumerate(cols_styles): + col_width = col_style.get("size", 82.125) / 7.0 + col_letter = get_column_letter(col_index + 1) # Excel 列从1开始 + worksheet.column_dimensions[col_letter].width = col_width + + # 设置行高 + for row_index, row_style in enumerate(rows_styles): + row_height = row_style.get("size", 24) / 1.5 + worksheet.row_dimensions[row_index + 1].height = row_height + + # 写入单元格数据 + for r_index, row in row_data.items(): + for c_index, cell in row.items(): + # 防御性检查:确保行号和列号都是有效的整数 + try: + row_number = int(r_index) + 1 + col_number = int(c_index) + 1 + except ValueError: + print(f"Invalid row or column index: r_index={r_index}, c_index={c_index}") + continue + + if col_number < 1 or col_number > 16384: # Excel 最大列数支持到 XFD,即 16384 列 + print(f"Invalid column index: c_index={c_index}") + continue + + cell_obj = worksheet.cell(row=row_number, column=col_number) + + # 处理值和公式 + cell_value = cell.get("value", "") + if isinstance(cell_value, dict): + # 检查是否为公式 + if cell_value.get("class") == "formula" and "formula" in cell_value: + cell_obj.value = f"={cell_value['formula']}" # 写入公式 + else: + cell_obj.value = cell_value.get("value", "") # 写入值 + else: + cell_obj.value = cell_value # 写入简单值 + + # 应用样式 + style = cell.get("style", {}) + self.apply_cell_style(cell_obj, style) + + # 
合并单元格 + for key, merge_def in merge_cells.items(): + start_row = merge_def["row"] + 1 + start_col = merge_def["col"] + 1 + end_row = start_row + merge_def["rowCount"] - 1 + end_col = start_col + merge_def["colCount"] - 1 + worksheet.merge_cells( + start_row=start_row, start_column=start_col, end_row=end_row, end_column=end_col + ) + + # 保存 Excel 文件 + workbook.save(save_path) + print(f"Excel file successfully saved to: {save_path}") + + except Exception as e: + print(f"Error generating Excel file: {e}") + + + def apply_cell_style(self, cell, style): + """ + 应用单元格样式,包括字体、对齐、背景颜色等。 + + Args: + cell: openpyxl 的单元格对象。 + style: 字典格式的样式信息。 + """ + # 定义允许的对齐值 + allowed_horizontal_alignments = {"general", "left", "center", "centerContinuous", "right", "fill", "justify", + "distributed"} + allowed_vertical_alignments = {"top", "center", "justify", "distributed", "bottom"} + + # 处理字体 + font = Font( + size=style.get("fontSize", 11), + bold=style.get("fontWeight", False), + italic=style.get("fontStyle", "normal") == "italic", + underline="single" if style.get("underline", False) else None, + color=self.convert_color_to_hex(style.get("color", "#000000")), + ) + cell.font = font + + # 处理对齐方式 + horizontal_alignment = style.get("hAlign", "left") + vertical_alignment = style.get("vAlign", "top") + + # 如果对齐值无效,则使用默认值 + if horizontal_alignment not in allowed_horizontal_alignments: + horizontal_alignment = "left" + if vertical_alignment not in allowed_vertical_alignments: + vertical_alignment = "top" + + alignment = Alignment( + horizontal=horizontal_alignment, + vertical=vertical_alignment, + wrap_text=style.get("overflow") == "wrap", + ) + cell.alignment = alignment + + # 处理背景颜色 + background_color = style.get("backColor", None) + if background_color: + hex_color = self.convert_color_to_hex(background_color) + if hex_color: + cell.fill = PatternFill( + start_color=hex_color, + end_color=hex_color, + fill_type="solid" + ) + + def convert_color_to_hex(self, color): + """ + 将颜色从 
`rgba(...)` 或 `rgb(...)` 转换为 aRGB 十六进制格式。 + + Args: + color (str): 原始颜色字符串,如 `rgba(255,255,0,1.00)` 或 `#FFFFFF`。 + + Returns: + str: 转换后的颜色字符串(符合 openpyxl 的格式),例如 `FFFF0000`。 + """ + try: + if not color: + return None + + # 如果是 `#RRGGBB` 或 `#AARRGGBB` 格式,直接返回 + if color.startswith("#"): + return color.lstrip("#").upper() + + # 如果是 `rgb(...)` 格式,例如 `rgb(255,255,0)` + if color.startswith("rgb("): + rgb_values = color.strip("rgb()").split(",") + red, green, blue = [int(v) for v in rgb_values] + return f"FF{red:02X}{green:02X}{blue:02X}" + + # 如果是 `rgba(...)` 格式,例如 `rgba(255,255,0,1.00)` + if color.startswith("rgba("): + rgba_values = color.strip("rgba()").split(",") + red, green, blue = [int(v) for v in rgba_values[:3]] + alpha = float(rgba_values[3]) + alpha_hex = int(alpha * 255) # 将透明度转换为 [00, FF] + return f"{alpha_hex:02X}{red:02X}{green:02X}{blue:02X}" + + # 返回默认颜色 + return None + except Exception as e: + print(f"Error parsing color '{color}': {e}") + return None diff --git a/api/app/core/rag/integrations/yuque/exceptions.py b/api/app/core/rag/integrations/yuque/exceptions.py new file mode 100644 index 00000000..e862323c --- /dev/null +++ b/api/app/core/rag/integrations/yuque/exceptions.py @@ -0,0 +1,46 @@ +"""Exception classes for Yuque integration.""" + + +class YuqueError(Exception): + """Base exception for all Yuque-related errors.""" + + def __init__(self, message: str, error_code: str = None, details: dict = None): + super().__init__(message) + self.message = message + self.error_code = error_code + self.details = details or {} + + +class YuqueAuthError(YuqueError): + """Authentication error with Yuque API.""" + pass + + +class YuqueAPIError(YuqueError): + """General API error from Yuque.""" + pass + + +class YuqueNotFoundError(YuqueError): + """Resource not found error (404).""" + pass + + +class YuquePermissionError(YuqueError): + """Permission denied error (403).""" + pass + + +class YuqueRateLimitError(YuqueError): + """Rate limit exceeded error 
(429).""" + pass + + +class YuqueNetworkError(YuqueError): + """Network-related error (timeout, connection failure).""" + pass + + +class YuqueDataError(YuqueError): + """Data parsing or validation error.""" + pass diff --git a/api/app/core/rag/integrations/yuque/models.py b/api/app/core/rag/integrations/yuque/models.py new file mode 100644 index 00000000..6230aa69 --- /dev/null +++ b/api/app/core/rag/integrations/yuque/models.py @@ -0,0 +1,42 @@ +"""Data models for Yuque integration.""" + +from dataclasses import dataclass +from datetime import datetime +from typing import Optional + + +@dataclass +class YuqueRepoInfo: + """Repository (知识库) information from Yuque.""" + id: int # 知识库 ID + type: str # 类型 (Book:文档, Design:图集, Sheet:表格, Resource:资源) + name: str # 名称 + namespace: str # 完整路径: user/repo format + slug: str # 路径 + description: Optional[str] # 简介 + public: int # 公开性 (0:私密, 1:公开, 2:企业内公开) + items_count: int # 文档数量 + created_at: datetime # 创建时间 + updated_at: datetime # 更新时间 + + +@dataclass +class YuqueDocInfo: + """Document information from Yuque.""" + id: int # 文档 ID + type: str # 文档类型 (Doc:普通文档, Sheet:表格, Thread:话题, Board:图集, Table:数据表) + slug: str # 路径 + title: str # 标题 + book_id: int # 归属知识库 ID + format: str # 内容格式 (markdown:Markdown 格式, lake:语雀 Lake 格式, html:HTML 标准格式, lakesheet:语雀表格) + body: Optional[str] # 正文原始内容 + body_draft: Optional[str] # 正文草稿内容 + body_html: Optional[str] # 正文 HTML 标准格式内容 + public: int # 公开性 (0:私密, 1:公开, 2:企业内公开) + status: int # 状态 (0:草稿, 1:发布) + created_at: datetime # 创建时间 + updated_at: datetime # 更新时间 + published_at: Optional[datetime] # 发布时间 + word_count: int # 内容字数 + cover: Optional[str] # 封面 + description: Optional[str] # 摘要 diff --git a/api/app/core/rag/integrations/yuque/retry.py b/api/app/core/rag/integrations/yuque/retry.py new file mode 100644 index 00000000..a68d6b47 --- /dev/null +++ b/api/app/core/rag/integrations/yuque/retry.py @@ -0,0 +1,134 @@ +"""Retry strategy for Yuque API calls.""" + +import asyncio +import 
functools +from typing import Callable, TypeVar +import httpx + +from app.core.rag.integrations.yuque.exceptions import ( + YuqueAuthError, + YuquePermissionError, + YuqueNotFoundError, + YuqueRateLimitError, + YuqueNetworkError, + YuqueDataError, + YuqueAPIError, +) + +T = TypeVar('T') + + +class RetryStrategy: + """Retry strategy for API calls.""" + + # Retryable error types + RETRYABLE_ERRORS = ( + YuqueNetworkError, + YuqueRateLimitError, + httpx.TimeoutException, + httpx.ConnectError, + httpx.ReadError, + ) + + # Non-retryable error types + NON_RETRYABLE_ERRORS = ( + YuqueAuthError, + YuquePermissionError, + YuqueNotFoundError, + YuqueDataError, + ) + + # Retry configuration + MAX_RETRIES = 3 + BACKOFF_DELAYS = [1, 2, 4] # seconds + + @classmethod + def is_retryable(cls, error: Exception) -> bool: + """Check if an error is retryable.""" + # Check for specific retryable errors + if isinstance(error, cls.RETRYABLE_ERRORS): + return True + + # Check for non-retryable errors + if isinstance(error, cls.NON_RETRYABLE_ERRORS): + return False + + # Check for HTTP status codes + if isinstance(error, httpx.HTTPStatusError): + status_code = error.response.status_code + # Retry on 429 (rate limit), 503 (service unavailable), 502 (bad gateway) + if status_code in [429, 502, 503]: + return True + # Don't retry on 4xx errors (except 429) + if 400 <= status_code < 500: + return False + # Retry on 5xx errors + if 500 <= status_code < 600: + return True + + # Check for YuqueRateLimitError + if isinstance(error, YuqueRateLimitError): + return True + + return False + + @classmethod + async def execute_with_retry( + cls, + func: Callable[..., T], + *args, + **kwargs + ) -> T: + """ + Execute a function with retry logic. 
+ + Args: + func: Async function to execute + *args: Positional arguments for the function + **kwargs: Keyword arguments for the function + + Returns: + Function result + + Raises: + Exception: The last exception if all retries fail + """ + last_exception = None + + for attempt in range(cls.MAX_RETRIES + 1): + try: + return await func(*args, **kwargs) + except Exception as e: + last_exception = e + + # Don't retry if not retryable + if not cls.is_retryable(e): + raise + + # Don't retry if this was the last attempt + if attempt >= cls.MAX_RETRIES: + raise + + # Wait before retrying + delay = cls.BACKOFF_DELAYS[attempt] if attempt < len(cls.BACKOFF_DELAYS) else cls.BACKOFF_DELAYS[-1] + await asyncio.sleep(delay) + + # Should not reach here, but raise last exception if we do + if last_exception: + raise last_exception + + +def with_retry(func: Callable[..., T]) -> Callable[..., T]: + """ + Decorator to add retry logic to async functions. + + Usage: + @with_retry + async def my_api_call(): + ... 
+ """ + @functools.wraps(func) + async def wrapper(*args, **kwargs): + return await RetryStrategy.execute_with_retry(func, *args, **kwargs) + + return wrapper diff --git a/api/app/models/file_model.py b/api/app/models/file_model.py index 842e3dc8..44a7d613 100644 --- a/api/app/models/file_model.py +++ b/api/app/models/file_model.py @@ -14,4 +14,5 @@ class File(Base): file_name = Column(String, index=True, nullable=False, comment="file name or folder name,default folder name is /") file_ext = Column(String, index=True, nullable=False, comment="file extension:folder|pdf") file_size = Column(Integer, default=0, comment="file size(byte)") + file_url = Column(String, index=True, nullable=True, comment="file comes from a website url") created_at = Column(DateTime, default=datetime.datetime.now) \ No newline at end of file diff --git a/api/app/models/knowledge_model.py b/api/app/models/knowledge_model.py index 8f0909d3..fbebe1b4 100644 --- a/api/app/models/knowledge_model.py +++ b/api/app/models/knowledge_model.py @@ -57,6 +57,17 @@ class Knowledge(Base): parser_id = Column(String, index=True, default="naive", comment="default parser ID") parser_config = Column(JSON, nullable=False, default={ + "entry_url": "https://ai.redbearai.com", + "max_pages": 20, + "delay_seconds": 1.0, + "timeout_seconds": 10, + "user_agent": "KnowledgeBaseCrawler/1.0", + "yuque_user_id": "User ID", + "yuque_token": "Token", + "feishu_app_id": "App ID", + "feishu_app_secret": "App Secret", + "feishu_folder_token": "Folder Token", + "sync_cron": "30 7 * * 1-5", "layout_recognize": "DeepDOC", "chunk_token_num": 128, "delimiter": "\n", diff --git a/api/app/schemas/file_schema.py b/api/app/schemas/file_schema.py index 00f1a148..7245671a 100644 --- a/api/app/schemas/file_schema.py +++ b/api/app/schemas/file_schema.py @@ -10,6 +10,8 @@ class FileBase(BaseModel): file_name: str file_ext: str file_size: int + file_url: str | None = None + created_at: datetime.datetime | None = None class 
FileCreate(FileBase): @@ -26,6 +28,7 @@ class FileUpdate(BaseModel): file_name: str | None = Field(None) file_ext: str | None = Field(None) file_size: str | None = Field(None) + file_url: str | None = Field(None) class File(FileBase): diff --git a/api/app/tasks.py b/api/app/tasks.py index a46a3a7b..29b0e485 100644 --- a/api/app/tasks.py +++ b/api/app/tasks.py @@ -7,6 +7,8 @@ import uuid from uuid import UUID from datetime import datetime, timezone from math import ceil +from pathlib import Path +import shutil from typing import Any, Dict, List, Optional import redis @@ -16,8 +18,13 @@ import trio # Import a unified Celery instance from app.celery_app import celery_app from app.core.config import settings +from app.core.rag.crawler.web_crawler import WebCrawler from app.core.rag.graphrag.general.index import init_graphrag, run_graphrag_for_kb from app.core.rag.graphrag.utils import get_llm_cache, set_llm_cache +from app.core.rag.integrations.feishu.client import FeishuAPIClient +from app.core.rag.integrations.feishu.models import FileInfo +from app.core.rag.integrations.yuque.client import YuqueAPIClient +from app.core.rag.integrations.yuque.models import YuqueDocInfo from app.core.rag.llm.chat_model import Base from app.core.rag.llm.cv_model import QWenCV from app.core.rag.llm.embedding_model import OpenAIEmbed @@ -29,7 +36,9 @@ from app.core.rag.vdb.elasticsearch.elasticsearch_vector import ( ) from app.db import get_db, get_db_context from app.models.document_model import Document +from app.models.file_model import File from app.models.knowledge_model import Knowledge +from app.schemas import file_schema, document_schema from app.services.memory_agent_service import MemoryAgentService @@ -382,6 +391,480 @@ def build_graphrag_for_kb(kb_id: uuid.UUID): db.close() +@celery_app.task(name="app.core.rag.tasks.sync_knowledge_for_kb") +def sync_knowledge_for_kb(kb_id: uuid.UUID): + """ + sync knowledge document and Document parsing, vectorization, and storage + """ + db 
= next(get_db()) # Manually call the generator + db_knowledge = None + try: + db_knowledge = db.query(Knowledge).filter(Knowledge.id == kb_id).first() + # 1. get vector_service + vector_service = ElasticSearchVectorFactory().init_vector(knowledge=db_knowledge) + + # 2. sync data + match db_knowledge.type: + case "Web": # Crawl webpages in batches through a web crawler + entry_url = db_knowledge.parser_config.get("entry_url", "") + max_pages = db_knowledge.parser_config.get("max_pages", 20) + delay_seconds = db_knowledge.parser_config.get("delay_seconds", 1.0) + timeout_seconds = db_knowledge.parser_config.get("timeout_seconds", 10) + user_agent = db_knowledge.parser_config.get("user_agent", "KnowledgeBaseCrawler/1.0") + # Create crawler + crawler = WebCrawler( + entry_url=entry_url, + max_pages=max_pages, + delay_seconds=delay_seconds, + timeout_seconds=timeout_seconds, + user_agent=user_agent + ) + try: + # 初始化存储已爬取 URLs 的集合 + file_urls = set() + # crawl entry_url by yield + for crawled_document in crawler.crawl(): + file_urls.add(crawled_document.url) + db_file = db.query(File).filter(File.kb_id == db_knowledge.id, + File.file_url == crawled_document.url).first() + if db_file: + if db_file.file_size == crawled_document.content_length: # same + continue + else: # --update + if crawled_document.content_length: + # 1. 
update file + db_file.file_name = f"{crawled_document.title}.txt" + db_file.file_ext=".txt" + db_file.file_size=crawled_document.content_length + db.commit() + db.refresh(db_file) + # Construct a save path:/files/{kb_id}/{parent_id}/{file.id}{file_extension} + save_dir = os.path.join(settings.FILE_PATH, str(db_knowledge.id), str(db_knowledge.parent_id)) + Path(save_dir).mkdir(parents=True, exist_ok=True) # Ensure that the directory exists + save_path = os.path.join(save_dir, f"{db_file.id}{db_file.file_ext}") + # update file + if os.path.exists(save_path): + os.remove(save_path) # Delete a single file + content_bytes = crawled_document.content.encode('utf-8') + with open(save_path, "wb") as f: + f.write(content_bytes) + # 2. update a document + db_document = db.query(Document).filter(Document.kb_id == db_knowledge.id, + Document.file_id == db_file.id).first() + if db_document: + db_document.file_name = db_file.file_name + db_document.file_ext = db_file.file_ext + db_document.file_size = db_file.file_size + db_document.updated_at = datetime.now() + db.commit() + db.refresh(db_document) + # 3. Document parsing, vectorization, and storage + parse_document(file_path=save_path, document_id=db_document.id) + else: # --add + if crawled_document.content_length: + # 1. 
upload file + upload_file = file_schema.FileCreate( + kb_id=db_knowledge.id, + created_by=db_knowledge.created_by, + parent_id=db_knowledge.id, + file_name=f"{crawled_document.title}.txt", + file_ext=".txt", + file_size=crawled_document.content_length, + file_url=crawled_document.url, + ) + db_file = File(**upload_file.model_dump()) + db.add(db_file) + db.commit() + # Construct a save path:/files/{kb_id}/{parent_id}/{file.id}{file_extension} + save_dir = os.path.join(settings.FILE_PATH, str(db_knowledge.id), str(db_knowledge.id)) + Path(save_dir).mkdir(parents=True, exist_ok=True) # Ensure that the directory exists + save_path = os.path.join(save_dir, f"{db_file.id}{db_file.file_ext}") + # Save file + content_bytes = crawled_document.content.encode('utf-8') + with open(save_path, "wb") as f: + f.write(content_bytes) + # 2. Create a document + create_document_data = document_schema.DocumentCreate( + kb_id=db_knowledge.id, + created_by=db_knowledge.created_by, + file_id=db_file.id, + file_name=db_file.file_name, + file_ext=db_file.file_ext, + file_size=db_file.file_size, + file_meta={}, + parser_id="naive", + parser_config={ + "layout_recognize": "DeepDOC", + "chunk_token_num": 128, + "delimiter": "\n", + "auto_keywords": 0, + "auto_questions": 0, + "html4excel": "false" + } + ) + db_document = Document(**create_document_data.model_dump()) + db.add(db_document) + db.commit() + # 3. Document parsing, vectorization, and storage + parse_document(file_path=save_path, document_id=db_document.id) + db_files = db.query(File).filter(File.kb_id == db_knowledge.id, File.file_url.notin_(file_urls)).all() + if db_files: # --delete + for db_file in db_files: + db_document = db.query(Document).filter(Document.kb_id == db_knowledge.id, + Document.file_id == db_file.id).first() + if db_document: + # 1. Delete vector index + vector_service.delete_by_metadata_field(key="document_id", value=str(db_document.id)) + # 2. Delete document + db.delete(db_document) + # 3. 
Delete file + file_path = Path( + settings.FILE_PATH, + str(db_file.kb_id), + str(db_file.parent_id), + f"{db_file.id}{db_file.file_ext}" + ) + if file_path.exists(): + file_path.unlink() # Delete a single file + db.delete(db_file) + # commit transaction + db.commit() + + except Exception as e: + print(f"\n\nError during crawl: {e}") + case "Third-party": # Integration of third-party knowledge bases (Yuque and Feishu) + yuque_user_id = db_knowledge.parser_config.get("yuque_user_id", "") + feishu_app_id = db_knowledge.parser_config.get("feishu_app_id", "") + if yuque_user_id: # Yuque Knowledge Base + yuque_token = db_knowledge.parser_config.get("yuque_token", "") + # Create the Yuque API client + api_client = YuqueAPIClient( + user_id=yuque_user_id, + token=yuque_token + ) + try: + # Initialize the set that stores the fetched Yuque URLs (the document slugs are what gets added) + file_urls = set() + + # Get all documents from all repos: the helper opens the client, lists the user's repositories and returns the concatenated document lists + async def async_get_files(api_client: YuqueAPIClient): + async with api_client as client: + print("\n=== Fetching repositories ===") + repos = await client.get_user_repos() + print(f"Found {len(repos)} repositories:") + all_files = [] + for repo in repos: + # Get documents from repository + print(f"\n=== Fetching documents from '{repo.name}' ===") + docs = await client.get_repo_docs(repo.id) + all_files.extend(docs) + return all_files + + files = asyncio.run(async_get_files(api_client)) + for doc in files: + file_urls.add(doc.slug) + db_file = db.query(File).filter(File.kb_id == db_knowledge.id, + File.file_url == doc.slug).first() + if db_file: + if db_file.created_at == doc.updated_at: # same timestamp, skip + continue + else: # --update + # 1. 
update file + # Construct a save path:/files/{kb_id}/{parent_id}/{file.id}{file_extension} (NOTE(review): the directory is built from db_knowledge.parent_id, while the delete step of this sync builds its path from db_file.parent_id; confirm they agree) + save_dir = os.path.join(settings.FILE_PATH, str(db_knowledge.id), str(db_knowledge.parent_id)) + Path(save_dir).mkdir(parents=True, exist_ok=True) # Ensure that the directory exists + + # Download the Yuque document described by doc (a YuqueDocInfo) into save_dir; the helper opens the client and returns the downloaded file path + async def async_download_document(api_client: YuqueAPIClient, doc: YuqueDocInfo, save_dir: str): + async with api_client as client: + file_path = await client.download_document(doc, save_dir) + return file_path + + file_path = asyncio.run(async_download_document(api_client, doc, save_dir)) + + save_path = os.path.join(save_dir, f"{db_file.id}{db_file.file_ext}") + # Replace the stored copy; save_path uses the file_ext recorded before the db_file update below + if os.path.exists(save_path): + os.remove(save_path) # Delete a single file + shutil.copyfile(file_path, save_path) + # update db_file + file_name = os.path.basename(file_path) + _, file_extension = os.path.splitext(file_name) + file_size = os.path.getsize(file_path) + db_file.file_name = file_name + db_file.file_ext = file_extension.lower() + db_file.file_size = file_size + db_file.created_at = doc.updated_at # stores the remote update time; the sync loop compares it with doc.updated_at to detect changes + db.commit() + db.refresh(db_file) + # 2. update a document + db_document = db.query(Document).filter(Document.kb_id == db_knowledge.id, + Document.file_id == db_file.id).first() + if db_document: + db_document.file_name = db_file.file_name + db_document.file_ext = db_file.file_ext + db_document.file_size = db_file.file_size + db_document.created_at = db_file.created_at + db_document.updated_at = datetime.now() + db.commit() + db.refresh(db_document) + # 3. Document parsing, vectorization, and storage + parse_document(file_path=save_path, document_id=db_document.id) + else: # --add + # 1. 
update file + # Construct a save path:/files/{kb_id}/{parent_id}/{file.id}{file_extension} (NOTE(review): the directory is built from db_knowledge.parent_id, while the File row created below gets parent_id=db_knowledge.id; confirm which one is intended) + save_dir = os.path.join(settings.FILE_PATH, str(db_knowledge.id), str(db_knowledge.parent_id)) + Path(save_dir).mkdir(parents=True, exist_ok=True) # Ensure that the directory exists + + # Download the Yuque document described by doc (a YuqueDocInfo) into save_dir; the helper opens the client and returns the downloaded file path + async def async_download_document(api_client: YuqueAPIClient, doc: YuqueDocInfo, save_dir: str): + async with api_client as client: + file_path = await client.download_document(doc, save_dir) + return file_path + + file_path = asyncio.run(async_download_document(api_client, doc, save_dir)) + # add db_file (name, extension and size are taken from the downloaded file) + file_name = os.path.basename(file_path) + _, file_extension = os.path.splitext(file_name) + file_size = os.path.getsize(file_path) + upload_file = file_schema.FileCreate( + kb_id=db_knowledge.id, + created_by=db_knowledge.created_by, + parent_id=db_knowledge.id, + file_name=file_name, + file_ext=file_extension.lower(), + file_size=file_size, + file_url=doc.slug, + created_at=doc.updated_at + ) + db_file = File(**upload_file.model_dump()) + db.add(db_file) + db.commit() + # Save file under the id of the File row that was just committed + save_path = os.path.join(save_dir, f"{db_file.id}{db_file.file_ext}") + # overwrite any existing copy at save_path + if os.path.exists(save_path): + os.remove(save_path) # Delete a single file + shutil.copyfile(file_path, save_path) + # 2. Create a document + create_document_data = document_schema.DocumentCreate( + kb_id=db_knowledge.id, + created_by=db_knowledge.created_by, + file_id=db_file.id, + file_name=db_file.file_name, + file_ext=db_file.file_ext, + file_size=db_file.file_size, + file_meta={}, + parser_id="naive", + parser_config={ + "layout_recognize": "DeepDOC", + "chunk_token_num": 128, + "delimiter": "\n", + "auto_keywords": 0, + "auto_questions": 0, + "html4excel": "false" + } + ) + db_document = Document(**create_document_data.model_dump()) + db.add(db_document) + db.commit() + # 3. 
Document parsing, vectorization, and storage + parse_document(file_path=save_path, document_id=db_document.id) + db_files = db.query(File).filter(File.kb_id == db_knowledge.id, + File.file_url.notin_(file_urls)).all() + if db_files: # --delete local files whose file_url was not collected during this sync + for db_file in db_files: + db_document = db.query(Document).filter(Document.kb_id == db_knowledge.id, + Document.file_id == db_file.id).first() + if db_document: + # 1. Delete vector index + vector_service.delete_by_metadata_field(key="document_id", + value=str(db_document.id)) + # 2. Delete document + db.delete(db_document) + # 3. Delete file + file_path = Path( + settings.FILE_PATH, + str(db_file.kb_id), + str(db_file.parent_id), + f"{db_file.id}{db_file.file_ext}" + ) + if file_path.exists(): + file_path.unlink() # Delete a single file + db.delete(db_file) + # commit transaction + db.commit() + + except Exception as e: + print(f"\n\nError during fetch feishu: {e}") # NOTE(review): this handler closes the Yuque branch although the message says feishu + if feishu_app_id: # Feishu Knowledge Base + feishu_app_secret = db_knowledge.parser_config.get("feishu_app_secret", "") + feishu_folder_token = db_knowledge.parser_config.get("feishu_folder_token", "") + # Create the Feishu API client + api_client = FeishuAPIClient( + app_id=feishu_app_id, + app_secret=feishu_app_secret + ) + try: + # Initialize the set that stores the fetched Feishu URLs + file_urls = set() + # Get all files from the folder: the helper opens the client and returns the listing requested with recursive=True + async def async_get_files(api_client: FeishuAPIClient, feishu_folder_token: str): + async with api_client as client: + files = await client.list_all_folder_files(feishu_folder_token, recursive=True) + return files + files = asyncio.run(async_get_files(api_client, feishu_folder_token)) + # Filter out folders, only sync entries whose type is one of the listed document types + documents = [f for f in files if f.type in ["doc", "docx", "sheet", "bitable", "file"]] + for doc in documents: + file_urls.add(doc.url) + db_file = db.query(File).filter(File.kb_id == db_knowledge.id, + File.file_url == doc.url).first() + if db_file: + if db_file.created_at == doc.modified_time: # same timestamp, skip + continue + else: # --update + # 
1. update file + # Construct a save path:/files/{kb_id}/{parent_id}/{file.id}{file_extension} (NOTE(review): the directory is built from db_knowledge.parent_id, while the delete step of this sync builds its path from db_file.parent_id; confirm they agree) + save_dir = os.path.join(settings.FILE_PATH, str(db_knowledge.id), + str(db_knowledge.parent_id)) + Path(save_dir).mkdir(parents=True, exist_ok=True) # Ensure that the directory exists + # Download the Feishu document described by doc (a FileInfo) into save_dir; the helper opens the client and returns the downloaded file path + async def async_download_document(api_client: FeishuAPIClient, doc: FileInfo, save_dir: str): + async with api_client as client: + file_path = await client.download_document(document=doc, save_dir=save_dir) + return file_path + file_path = asyncio.run(async_download_document(api_client, doc, save_dir)) + + save_path = os.path.join(save_dir, f"{db_file.id}{db_file.file_ext}") + # Replace the stored copy; save_path uses the file_ext recorded before the db_file update below + if os.path.exists(save_path): + os.remove(save_path) # Delete a single file + shutil.copyfile(file_path, save_path) + # update db_file + file_name = os.path.basename(file_path) + _, file_extension = os.path.splitext(file_name) + file_size = os.path.getsize(file_path) + db_file.file_name = file_name + db_file.file_ext = file_extension.lower() + db_file.file_size = file_size + db_file.created_at = doc.modified_time # stores the remote modified time; the sync loop compares it with doc.modified_time to detect changes + db.commit() + db.refresh(db_file) + # 2. update a document + db_document = db.query(Document).filter(Document.kb_id == db_knowledge.id, + Document.file_id == db_file.id).first() + if db_document: + db_document.file_name = db_file.file_name + db_document.file_ext = db_file.file_ext + db_document.file_size = db_file.file_size + db_document.created_at = db_file.created_at + db_document.updated_at = datetime.now() + db.commit() + db.refresh(db_document) + # 3. Document parsing, vectorization, and storage + parse_document(file_path=save_path, document_id=db_document.id) + else: # --add + # 1. 
update file + # Construct a save path:/files/{kb_id}/{parent_id}/{file.id}{file_extension} (NOTE(review): the directory is built from db_knowledge.parent_id, while the File row created below gets parent_id=db_knowledge.id; confirm which one is intended) + save_dir = os.path.join(settings.FILE_PATH, str(db_knowledge.id), + str(db_knowledge.parent_id)) + Path(save_dir).mkdir(parents=True, exist_ok=True) # Ensure that the directory exists + # Download the Feishu document described by doc (a FileInfo) into save_dir; the helper opens the client and returns the downloaded file path + async def async_download_document(api_client: FeishuAPIClient, doc: FileInfo, save_dir: str): + async with api_client as client: + file_path = await client.download_document(document=doc, save_dir=save_dir) + return file_path + file_path = asyncio.run(async_download_document(api_client, doc, save_dir)) + # add db_file (name, extension and size are taken from the downloaded file) + file_name = os.path.basename(file_path) + _, file_extension = os.path.splitext(file_name) + file_size = os.path.getsize(file_path) + upload_file = file_schema.FileCreate( + kb_id=db_knowledge.id, + created_by=db_knowledge.created_by, + parent_id=db_knowledge.id, + file_name=file_name, + file_ext=file_extension.lower(), + file_size=file_size, + file_url=doc.url, + created_at = doc.modified_time + ) + db_file = File(**upload_file.model_dump()) + db.add(db_file) + db.commit() + # Save file under the id of the File row that was just committed + save_path = os.path.join(save_dir, f"{db_file.id}{db_file.file_ext}") + # overwrite any existing copy at save_path + if os.path.exists(save_path): + os.remove(save_path) # Delete a single file + shutil.copyfile(file_path, save_path) + # 2. Create a document + create_document_data = document_schema.DocumentCreate( + kb_id=db_knowledge.id, + created_by=db_knowledge.created_by, + file_id=db_file.id, + file_name=db_file.file_name, + file_ext=db_file.file_ext, + file_size=db_file.file_size, + file_meta={}, + parser_id="naive", + parser_config={ + "layout_recognize": "DeepDOC", + "chunk_token_num": 128, + "delimiter": "\n", + "auto_keywords": 0, + "auto_questions": 0, + "html4excel": "false" + } + ) + db_document = Document(**create_document_data.model_dump()) + db.add(db_document) + db.commit() + # 3. 
Document parsing, vectorization, and storage + parse_document(file_path=save_path, document_id=db_document.id) + db_files = db.query(File).filter(File.kb_id == db_knowledge.id, + File.file_url.notin_(file_urls)).all() + if db_files: # --delete + for db_file in db_files: + db_document = db.query(Document).filter(Document.kb_id == db_knowledge.id, + Document.file_id == db_file.id).first() + if db_document: + # 1. Delete vector index + vector_service.delete_by_metadata_field(key="document_id", + value=str(db_document.id)) + # 2. Delete document + db.delete(db_document) + # 3. Delete file + file_path = Path( + settings.FILE_PATH, + str(db_file.kb_id), + str(db_file.parent_id), + f"{db_file.id}{db_file.file_ext}" + ) + if file_path.exists(): + file_path.unlink() # Delete a single file + db.delete(db_file) + # commit transaction + db.commit() + + except Exception as e: + print(f"\n\nError during fetch feishu: {e}") + case _: # General + print(f"General: No synchronization needed\n") + + + result = f"sync knowledge '{db_knowledge.name}' processed successfully." + return result + except Exception as e: + if 'db_knowledge' in locals(): + print(f"Failed to sync knowledge:{str(e)}\n") + result = f"sync knowledge '{db_knowledge.name}' failed." 
+ return result + finally: + db.close() + + @celery_app.task(name="app.core.memory.agent.read_message", bind=True) def read_message_task(self, end_user_id: str, message: str, history: List[Dict[str, Any]], search_switch: str, config_id: str, storage_type:str, user_rag_memory_id:str) -> Dict[str, Any]: diff --git a/api/pyproject.toml b/api/pyproject.toml index 6d23a3b9..66b1a295 100644 --- a/api/pyproject.toml +++ b/api/pyproject.toml @@ -141,6 +141,8 @@ dependencies = [ "flower>=2.0.1", "aiofiles>=23.0.0", "owlready2>=0.46", + "lxml>=4.9.0", + "httpx>=0.28.0", ] [tool.pytest.ini_options] diff --git a/api/requirements.txt b/api/requirements.txt index 6cdae2d1..144c0db2 100644 --- a/api/requirements.txt +++ b/api/requirements.txt @@ -134,3 +134,5 @@ xlrd==2.0.2 oss2>=2.18.0 boto3>=1.28.0 aiofiles>=23.0.0 +lxml>=4.9.0 +httpx>=0.28.0