diff --git a/api/app/celery_app.py b/api/app/celery_app.py index 002547f6..db78a368 100644 --- a/api/app/celery_app.py +++ b/api/app/celery_app.py @@ -76,6 +76,7 @@ celery_app.conf.update( # Document tasks → document_tasks queue (prefork worker) 'app.core.rag.tasks.parse_document': {'queue': 'document_tasks'}, 'app.core.rag.tasks.build_graphrag_for_kb': {'queue': 'document_tasks'}, + 'app.core.rag.tasks.sync_knowledge_for_kb': {'queue': 'document_tasks'}, # Beat/periodic tasks → periodic_tasks queue (dedicated periodic worker) 'app.tasks.workspace_reflection_task': {'queue': 'periodic_tasks'}, diff --git a/api/app/controllers/knowledge_controller.py b/api/app/controllers/knowledge_controller.py index 901208ba..01f89a3d 100644 --- a/api/app/controllers/knowledge_controller.py +++ b/api/app/controllers/knowledge_controller.py @@ -9,13 +9,16 @@ from sqlalchemy import or_ from sqlalchemy.orm import Session from app.celery_app import celery_app +from app.core.error_codes import BizCode from app.core.logging_config import get_api_logger from app.core.rag.common import settings +from app.core.rag.integrations.feishu.client import FeishuAPIClient +from app.core.rag.integrations.yuque.client import YuqueAPIClient from app.core.rag.llm.chat_model import Base from app.core.rag.nlp import rag_tokenizer, search from app.core.rag.prompts.generator import graph_entity_types from app.core.rag.vdb.elasticsearch.elasticsearch_vector import ElasticSearchVectorFactory -from app.core.response_utils import success +from app.core.response_utils import success, fail from app.db import get_db from app.dependencies import get_current_user from app.models import knowledge_model @@ -484,3 +487,99 @@ async def rebuild_knowledge_graph( except Exception as e: api_logger.error(f"Failed to rebuild knowledge graph: knowledge_id={knowledge_id} - {str(e)}") raise + + +@router.get("/check/yuque/auth", response_model=ApiResponse) +async def check_yuque_auth( + yuque_user_id: str, + yuque_token: str, + db: Session = Depends(get_db), + current_user: User = Depends(get_current_user) +): + """ + check yuque auth info + """ + api_logger.info(f"check yuque auth info, username: {current_user.username}") + + try: + api_client = YuqueAPIClient( + user_id=yuque_user_id, + token=yuque_token + ) + async with api_client as client: + repos = await client.get_user_repos() + if repos: + return success(data=repos, msg="Successfully auth yuque info") + return fail(BizCode.UNAUTHORIZED, msg="auth yuque info failed", error="user_id or token is incorrect") + except HTTPException: + raise + except Exception as e: + api_logger.error(f"auth yuque info failed: {str(e)}") + raise + + +@router.get("/check/feishu/auth", response_model=ApiResponse) +async def check_yuque_auth( + feishu_app_id: str, + feishu_app_secret: str, + feishu_folder_token: str, + db: Session = Depends(get_db), + current_user: User = Depends(get_current_user) +): + """ + check feishu auth info + """ + api_logger.info(f"check feishu auth info, username: {current_user.username}") + + try: + api_client = FeishuAPIClient( + app_id=feishu_app_id, + app_secret=feishu_app_secret + ) + async with api_client as client: + files = await client.list_all_folder_files(feishu_folder_token, recursive=True) + if files: + return success(data=files, msg="Successfully auth feishu info") + return fail(BizCode.UNAUTHORIZED, msg="auth feishu info failed", error="app_id or app_secret or feishu_folder_token is incorrect") + except HTTPException: + raise + except Exception as e: + api_logger.error(f"auth feishu info failed: {str(e)}") + raise + + +@router.post("/{knowledge_id}/sync", response_model=ApiResponse) +async def sync_knowledge( + knowledge_id: uuid.UUID, + db: Session = Depends(get_db), + current_user: User = Depends(get_current_user) +): + """ + sync knowledge base information based on knowledge_id + """ + api_logger.info(f"Obtain details of the knowledge base: knowledge_id={knowledge_id}, username: {current_user.username}") + + try: + # 1. Query knowledge base information from the database + api_logger.debug(f"Query knowledge base: {knowledge_id}") + db_knowledge = knowledge_service.get_knowledge_by_id(db, knowledge_id=knowledge_id, current_user=current_user) + if not db_knowledge: + api_logger.warning(f"The knowledge base does not exist or access is denied: knowledge_id={knowledge_id}") + raise HTTPException( + status_code=status.HTTP_404_NOT_FOUND, + detail="The knowledge base does not exist or access is denied" + ) + + # 2. sync knowledge + # from app.tasks import sync_knowledge_for_kb + # sync_knowledge_for_kb(kb_id) + task = celery_app.send_task("app.core.rag.tasks.sync_knowledge_for_kb", args=[knowledge_id]) + result = { + "task_id": task.id + } + return success(data=result, msg="Task accepted. sync knowledge is being processed in the background.") + except HTTPException: + raise + except Exception as e: + api_logger.error(f"Failed to sync knowledge: knowledge_id={knowledge_id} - {str(e)}") + raise diff --git a/api/app/core/rag/crawler/__init__.py b/api/app/core/rag/crawler/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/api/app/core/rag/crawler/__main__.py b/api/app/core/rag/crawler/__main__.py new file mode 100644 index 00000000..51a6870f --- /dev/null +++ b/api/app/core/rag/crawler/__main__.py @@ -0,0 +1,89 @@ +"""Command-line interface for web crawler.""" + +import argparse +import logging +import sys +from app.core.rag.crawler.web_crawler import WebCrawler + + +def setup_logging(verbose: bool = False): + """Set up logging configuration.""" + level = logging.DEBUG if verbose else logging.INFO + logging.basicConfig( + level=level, + format='%(asctime)s - %(name)s - %(levelname)s - %(message)s', + handlers=[ + logging.StreamHandler(sys.stdout) + ] + ) + + +def main(entry_url: str, + max_pages: int = 200, + delay_seconds: float = 1.0, + timeout_seconds: int = 10, + user_agent: str = "KnowledgeBaseCrawler/1.0"): + """Main entry point for the crawler.""" + # Create crawler + crawler = WebCrawler( + entry_url=entry_url, + max_pages=max_pages, + delay_seconds=delay_seconds, + timeout_seconds=timeout_seconds, + user_agent=user_agent + ) + + # Crawl and collect documents + documents = [] + try: + for doc in crawler.crawl(): + print(f"\n{'=' * 80}") + print(f"URL: {doc.url}") + print(f"Title: {doc.title}") + print(f"Content Length: {doc.content_length} characters") + print(f"Word Count: {doc.metadata.get('word_count', 0)} words") + print(f"{'=' * 80}\n") + + documents.append({ + 'url': doc.url, + 'title': doc.title, + 'content': doc.content, + 'content_length': doc.content_length, + 'crawl_timestamp': doc.crawl_timestamp.isoformat(), + 'http_status': doc.http_status, + 'metadata': doc.metadata + }) + + except KeyboardInterrupt: + print("\n\nCrawl interrupted by user.") + + except Exception as e: + print(f"\n\nError during crawl: {e}") + sys.exit(1) + + # Get summary + summary = crawler.get_summary() + print(f"\n{'=' * 80}") + print("CRAWL SUMMARY") + print(f"{'=' * 80}") + print(f"Total Pages Processed: {summary.total_pages_processed}") + print(f"Total Errors: {summary.total_errors}") + print(f"Total Skipped: {summary.total_skipped}") + print(f"Total URLs Discovered: {summary.total_urls_discovered}") + print(f"Duration: {summary.duration_seconds:.2f} seconds") + print(f"documents: {documents}") + + if summary.error_breakdown: + print(f"\nError Breakdown:") + for error_type, count in summary.error_breakdown.items(): + print(f" {error_type}: {count}") + + +if __name__ == '__main__': + entry_url = "https://www.xxx.com" + max_pages = 20 + delay_seconds = 1.0 + timeout_seconds = 10 + user_agent = "KnowledgeBaseCrawler/1.0" + + main(entry_url, max_pages, delay_seconds, timeout_seconds, user_agent) diff --git a/api/app/core/rag/crawler/content_extractor.py b/api/app/core/rag/crawler/content_extractor.py new file mode 100644 index 00000000..69dca53c --- /dev/null +++ b/api/app/core/rag/crawler/content_extractor.py @@ -0,0 +1,233 @@ +"""Content extractor for web crawler.""" + +from bs4 import BeautifulSoup +import re +import logging + +from app.core.rag.crawler.models import ExtractedContent + +logger = logging.getLogger(__name__) + + +class ContentExtractor: + """Extract clean, readable text from HTML pages.""" + + # Tags to remove completely + REMOVE_TAGS = ['script', 'style', 'nav', 'header', 'footer', 'aside'] + + # Tags that typically contain main content + MAIN_CONTENT_TAGS = ['article', 'main'] + + # Content extraction tags + CONTENT_TAGS = ['p', 'div', 'h1', 'h2', 'h3', 'h4', 'h5', 'h6', 'li', 'td', 'th', 'section'] + + def is_static_content(self, html: str) -> bool: + """ + Determine if the HTML represents static content. + + Detects JavaScript-rendered content by checking for minimal body + with heavy script tag presence. + + Args: + html: Raw HTML string + + Returns: + bool: True if static, False if JavaScript-rendered + """ + try: + soup = BeautifulSoup(html, 'lxml') + + # Count script tags + script_tags = soup.find_all('script') + script_count = len(script_tags) + + # Get body content (excluding scripts and styles) + body = soup.find('body') + if not body: + return False + + # Remove scripts and styles temporarily for text check + for tag in body.find_all(['script', 'style']): + tag.decompose() + + # Get text content + text = body.get_text(strip=True) + text_length = len(text) + + # If there's very little text but many scripts, likely JS-rendered + if script_count > 5 and text_length < 200: + logger.warning("Detected JavaScript-rendered content (many scripts, little text)") + return False + + # If there's no meaningful text, likely JS-rendered + if text_length < 50: + logger.warning("Detected JavaScript-rendered content (minimal text)") + return False + + return True + + except Exception as e: + logger.error(f"Error checking if content is static: {e}") + return True # Assume static on error + + def extract(self, html: str, url: str) -> ExtractedContent: + """ + Extract clean text content from HTML. + + Args: + html: Raw HTML string + url: Source URL (for context) + + Returns: + ExtractedContent: Contains title, text, metadata + """ + try: + soup = BeautifulSoup(html, 'lxml') + + # Check if content is static + is_static = self.is_static_content(html) + + # Extract title + title = self._extract_title(soup) + + # Remove unwanted tags + for tag_name in self.REMOVE_TAGS: + for tag in soup.find_all(tag_name): + tag.decompose() + + # Extract main content + text = self._extract_main_content(soup) + + # Normalize whitespace + text = self._normalize_whitespace(text) + + # Count words + word_count = len(text.split()) + + logger.info(f"Extracted {word_count} words from {url}") + + return ExtractedContent( + title=title, + text=text, + is_static=is_static, + word_count=word_count, + metadata={'url': url} + ) + + except Exception as e: + logger.error(f"Error extracting content from {url}: {e}") + return ExtractedContent( + title=url, + text="", + is_static=False, + word_count=0, + metadata={'url': url, 'error': str(e)} + ) + + def _extract_title(self, soup: BeautifulSoup) -> str: + """ + Extract title from HTML. + + Tries tag first, then first <h1>. + + Args: + soup: BeautifulSoup object + + Returns: + str: Page title + """ + # Try <title> tag + title_tag = soup.find('title') + if title_tag and title_tag.string: + return title_tag.string.strip() + + # Try first <h1> + h1_tag = soup.find('h1') + if h1_tag: + return h1_tag.get_text(strip=True) + + # Default to empty string + return "" + + def _extract_main_content(self, soup: BeautifulSoup) -> str: + """ + Extract main content from HTML. + + Prioritizes semantic HTML5 elements like <article> and <main>. + + Args: + soup: BeautifulSoup object + + Returns: + str: Extracted text content + """ + # Try to find main content area + main_content = None + + # Priority 1: <article> or <main> tags + for tag_name in self.MAIN_CONTENT_TAGS: + main_content = soup.find(tag_name) + if main_content: + logger.debug(f"Found main content in <{tag_name}> tag") + break + + # Priority 2: div with role="main" + if not main_content: + main_content = soup.find('div', role='main') + if main_content: + logger.debug("Found main content in div[role='main']") + + # Priority 3: Common class/id patterns + if not main_content: + for pattern in ['content', 'main', 'article', 'post']: + main_content = soup.find(['div', 'section'], class_=re.compile(pattern, re.I)) + if main_content: + logger.debug(f"Found main content with class pattern '{pattern}'") + break + + main_content = soup.find(['div', 'section'], id=re.compile(pattern, re.I)) + if main_content: + logger.debug(f"Found main content with id pattern '{pattern}'") + break + + # Fallback: use body + if not main_content: + main_content = soup.find('body') + logger.debug("Using <body> as main content (no specific content area found)") + + # Extract text from content tags + if main_content: + text_parts = [] + for tag in main_content.find_all(self.CONTENT_TAGS): + text = tag.get_text(strip=True) + if text: + text_parts.append(text) + + return '\n'.join(text_parts) + + return "" + + def _normalize_whitespace(self, text: str) -> str: + """ + Normalize whitespace in text. + + - Collapse multiple spaces to single space + - Reduce excessive newlines to maximum 2 + - Strip leading/trailing whitespace + + Args: + text: Text to normalize + + Returns: + str: Normalized text + """ + # Collapse multiple spaces to single space + text = re.sub(r' +', ' ', text) + + # Reduce excessive newlines to maximum 2 + text = re.sub(r'\n{3,}', '\n\n', text) + + # Strip leading/trailing whitespace + text = text.strip() + + return text diff --git a/api/app/core/rag/crawler/http_fetcher.py b/api/app/core/rag/crawler/http_fetcher.py new file mode 100644 index 00000000..b3a08098 --- /dev/null +++ b/api/app/core/rag/crawler/http_fetcher.py @@ -0,0 +1,302 @@ +"""HTTP fetcher for web crawler.""" + +import requests +import time +import logging +import re +from typing import Optional, Dict + + +from app.core.rag.crawler.models import FetchResult + +logger = logging.getLogger(__name__) + + +class HTTPFetcher: + """Handle HTTP requests with retries, error handling, and response validation.""" + + def __init__( + self, + timeout: int = 10, + max_retries: int = 3, + user_agent: str = "KnowledgeBaseCrawler/1.0" + ): + """ + Initialize HTTP fetcher. + + Args: + timeout: Request timeout in seconds + max_retries: Maximum number of retry attempts + user_agent: User-Agent header value + """ + self.timeout = timeout + self.max_retries = max_retries + self.user_agent = user_agent + + # Create session for connection pooling + self.session = requests.Session() + self.session.headers.update({ + 'User-Agent': user_agent + }) + + def fetch(self, url: str) -> FetchResult: + """ + Fetch a URL with retry logic and error handling. + + Args: + url: URL to fetch + + Returns: + FetchResult: Contains status_code, content, headers, error info + """ + last_error = None + + for attempt in range(self.max_retries): + try: + # Calculate backoff delay for retries + if attempt > 0: + backoff_delay = 2 ** (attempt - 1) # 1s, 2s, 4s + logger.info(f"Retry attempt {attempt + 1}/{self.max_retries} for {url} after {backoff_delay}s") + time.sleep(backoff_delay) + + # Make HTTP request + response = self.session.get( + url, + timeout=self.timeout, + allow_redirects=True + ) + + # Handle different status codes + if response.status_code == 429: + # Too Many Requests - backoff and retry + logger.warning(f"429 Too Many Requests for {url}, backing off") + if attempt < self.max_retries - 1: + continue + + if response.status_code == 503: + # Service Unavailable - pause and retry + logger.warning(f"503 Service Unavailable for {url}") + if attempt < self.max_retries - 1: + time.sleep(5) # Longer pause for 503 + continue + + # Success or client error (don't retry 4xx except 429) + if 200 <= response.status_code < 300: + logger.info(f"Successfully fetched {url} (status: {response.status_code})") + + # Get correctly encoded content + content = self._get_decoded_content(response) + + return FetchResult( + url=url, + final_url=response.url, + status_code=response.status_code, + content=content, + headers=dict(response.headers), + error=None, + success=True + ) + elif response.status_code == 404: + logger.info(f"404 Not Found: {url}") + return FetchResult( + url=url, + final_url=response.url, + status_code=response.status_code, + content=None, + headers=dict(response.headers), + error="Not Found", + success=False + ) + elif 400 <= response.status_code < 500: + logger.warning(f"Client error {response.status_code} for {url}") + return FetchResult( + url=url, + final_url=response.url, + status_code=response.status_code, + content=None, + headers=dict(response.headers), + error=f"Client error: {response.status_code}", + success=False + ) + elif 500 <= response.status_code < 600: + logger.error(f"Server error {response.status_code} for {url}") + last_error = f"Server error: {response.status_code}" + if attempt < self.max_retries - 1: + continue + return FetchResult( + url=url, + final_url=url, + status_code=response.status_code, + content=None, + headers={}, + error=last_error, + success=False + ) + + except requests.exceptions.Timeout: + last_error = "Request timeout" + logger.warning(f"Timeout fetching {url} (attempt {attempt + 1}/{self.max_retries})") + if attempt >= self.max_retries - 1: + break + continue + + except requests.exceptions.SSLError as e: + last_error = f"SSL/TLS error: {str(e)}" + logger.error(f"SSL/TLS error for {url}: {e}") + return FetchResult( + url=url, + final_url=url, + status_code=0, + content=None, + headers={}, + error=last_error, + success=False + ) + + except requests.exceptions.ConnectionError as e: + last_error = f"Connection error: {str(e)}" + logger.warning(f"Connection error for {url} (attempt {attempt + 1}/{self.max_retries}): {e}") + if attempt >= self.max_retries - 1: + break + continue + + except requests.exceptions.RequestException as e: + last_error = f"Request error: {str(e)}" + logger.error(f"Request error for {url}: {e}") + if attempt >= self.max_retries - 1: + break + continue + + # All retries exhausted + logger.error(f"Failed to fetch {url} after {self.max_retries} attempts: {last_error}") + return FetchResult( + url=url, + final_url=url, + status_code=0, + content=None, + headers={}, + error=last_error or "Unknown error", + success=False + ) + + def _get_decoded_content(self, response) -> str: + """ + Get correctly decoded content from response. + + Handles encoding detection and fallback strategies: + 1. Try encoding from HTML meta tags + 2. Try response.encoding (from Content-Type header or detected) + 3. Try UTF-8 + 4. Try common encodings (GB2312, GBK for Chinese, etc.) + 5. Fall back to latin-1 with error replacement + + Args: + response: requests.Response object + + Returns: + str: Decoded content + """ + # Try to detect encoding from HTML meta tags + meta_encoding = self._detect_encoding_from_meta(response.content) + if meta_encoding: + try: + content = response.content.decode(meta_encoding) + logger.info(f"Successfully decoded with meta tag encoding: {meta_encoding}") + return content + except (UnicodeDecodeError, LookupError) as e: + logger.warning(f"Failed to decode with meta encoding {meta_encoding}: {e}") + + # Try response.encoding (from Content-Type header or detected by requests) + if response.encoding and response.encoding.lower() != 'iso-8859-1': + # Note: requests defaults to ISO-8859-1 if no charset in Content-Type, + # so we skip it here and try UTF-8 first + try: + return response.text + except (UnicodeDecodeError, LookupError) as e: + logger.warning(f"Failed to decode with detected encoding {response.encoding}: {e}") + + # Try UTF-8 first (most common) + try: + return response.content.decode('utf-8') + except UnicodeDecodeError: + logger.debug("UTF-8 decoding failed, trying other encodings") + + # Try common encodings for different languages + encodings_to_try = [ + 'gbk', # Chinese (Simplified) + 'gb2312', # Chinese (Simplified, older) + 'gb18030', # Chinese (Simplified, extended) + 'big5', # Chinese (Traditional) + 'shift_jis', # Japanese + 'euc-jp', # Japanese + 'euc-kr', # Korean + 'iso-8859-1', # Western European + 'windows-1252', # Windows Western European + 'windows-1251', # Cyrillic + ] + + for encoding in encodings_to_try: + try: + content = response.content.decode(encoding) + logger.info(f"Successfully decoded with {encoding}") + return content + except (UnicodeDecodeError, LookupError): + continue + + # Last resort: use latin-1 with error replacement + logger.warning("All encoding attempts failed, using latin-1 with error replacement") + return response.content.decode('latin-1', errors='replace') + + def _detect_encoding_from_meta(self, content: bytes) -> Optional[str]: + """ + Detect encoding from HTML meta tags. + + Looks for: + - <meta charset="..."> + - <meta http-equiv="Content-Type" content="...; charset=..."> + + Args: + content: Raw response content (bytes) + + Returns: + Optional[str]: Detected encoding or None + """ + try: + # Only check first 2KB for performance + head = content[:2048] + + # Try to decode as ASCII/Latin-1 to search for meta tags + try: + head_str = head.decode('ascii', errors='ignore') + except: + head_str = head.decode('latin-1', errors='ignore') + + # Look for <meta charset="..."> + charset_match = re.search( + r'<meta[^>]+charset=["\']?([a-zA-Z0-9_-]+)', + head_str, + re.IGNORECASE + ) + if charset_match: + encoding = charset_match.group(1).lower() + logger.debug(f"Found charset in meta tag: {encoding}") + return encoding + + # Look for <meta http-equiv="Content-Type" content="...; charset=..."> + content_type_match = re.search( + r'<meta[^>]+http-equiv=["\']?content-type["\']?[^>]+content=["\']([^"\']+)', + head_str, + re.IGNORECASE + ) + if content_type_match: + content_value = content_type_match.group(1) + charset_match = re.search(r'charset=([a-zA-Z0-9_-]+)', content_value, re.IGNORECASE) + if charset_match: + encoding = charset_match.group(1).lower() + logger.debug(f"Found charset in Content-Type meta: {encoding}") + return encoding + + except Exception as e: + logger.debug(f"Error detecting encoding from meta tags: {e}") + + return None diff --git a/api/app/core/rag/crawler/models.py b/api/app/core/rag/crawler/models.py new file mode 100644 index 00000000..5d10963c --- /dev/null +++ b/api/app/core/rag/crawler/models.py @@ -0,0 +1,52 @@ +"""Data models for web crawler.""" + +from dataclasses import dataclass, field +from datetime import datetime +from typing import Dict, Any, Optional + + +@dataclass +class CrawledDocument: + """Represents a successfully processed web page with extracted content.""" + url: str + title: str + content: str + content_length: int + crawl_timestamp: datetime + http_status: int + metadata: Dict[str, Any] = field(default_factory=dict) + + +@dataclass +class FetchResult: + """Represents the result of an HTTP fetch operation.""" + url: str + final_url: str + status_code: int + content: Optional[str] + headers: Dict[str, str] + error: Optional[str] + success: bool + + +@dataclass +class ExtractedContent: + """Represents content extracted from HTML.""" + title: str + text: str + is_static: bool + word_count: int + metadata: Dict[str, Any] = field(default_factory=dict) + + +@dataclass +class CrawlSummary: + """Represents statistics from a completed crawl.""" + total_pages_processed: int + total_errors: int + total_skipped: int + total_urls_discovered: int + start_time: datetime + end_time: datetime + duration_seconds: float + error_breakdown: Dict[str, int] = field(default_factory=dict) diff --git a/api/app/core/rag/crawler/rate_limiter.py b/api/app/core/rag/crawler/rate_limiter.py new file mode 100644 index 00000000..e00fad36 --- /dev/null +++ b/api/app/core/rag/crawler/rate_limiter.py @@ -0,0 +1,57 @@ +"""Rate limiter for web crawler.""" + +import time +import logging + +logger = logging.getLogger(__name__) + + +class RateLimiter: + """Enforce delays between requests to be polite to servers.""" + + def __init__(self, delay_seconds: float = 1.0): + """ + Initialize rate limiter. + + Args: + delay_seconds: Minimum delay between requests + """ + self.delay_seconds = delay_seconds + self.last_request_time = 0.0 + self.max_delay = 60.0 # Cap maximum delay at 60 seconds + + def wait(self): + """ + Block until enough time has passed since last request. + Respects the configured delay. + """ + current_time = time.time() + elapsed = current_time - self.last_request_time + + if elapsed < self.delay_seconds: + sleep_time = self.delay_seconds - elapsed + logger.debug(f"Rate limiting: sleeping for {sleep_time:.2f} seconds") + time.sleep(sleep_time) + + self.last_request_time = time.time() + + def set_delay(self, delay_seconds: float): + """ + Update the delay (useful for respecting Crawl-delay from robots.txt). + + Args: + delay_seconds: New delay in seconds + """ + self.delay_seconds = min(delay_seconds, self.max_delay) + logger.info(f"Rate limiter delay updated to {self.delay_seconds} seconds") + + def backoff(self, multiplier: float = 2.0): + """ + Increase delay exponentially for backoff scenarios (429, 503 responses). + + Args: + multiplier: Factor to multiply current delay by + """ + old_delay = self.delay_seconds + self.delay_seconds = min(self.delay_seconds * multiplier, self.max_delay) + logger.warning(f"Rate limiter backing off: {old_delay:.2f}s -> {self.delay_seconds:.2f}s") diff --git a/api/app/core/rag/crawler/robots_parser.py b/api/app/core/rag/crawler/robots_parser.py new file mode 100644 index 00000000..882bc9c8 --- /dev/null +++ b/api/app/core/rag/crawler/robots_parser.py @@ -0,0 +1,118 @@ +"""Robots.txt parser for web crawler.""" + +from urllib.robotparser import RobotFileParser +from urllib.parse import urlparse, urljoin +from typing import Optional +import logging + +logger = logging.getLogger(__name__) + + +class RobotsParser: + """Parse and check robots.txt compliance for URLs.""" + + def __init__(self, user_agent: str, timeout: int = 10): + """ + Initialize robots.txt parser. + + Args: + user_agent: User agent string to check permissions for + timeout: Timeout for fetching robots.txt + """ + self.user_agent = user_agent + self.timeout = timeout + self._parsers = {} # Cache parsers by domain + + def _get_robots_url(self, url: str) -> str: + """ + Get the robots.txt URL for a given URL. + + Args: + url: URL to get robots.txt for + + Returns: + str: robots.txt URL + """ + parsed = urlparse(url) + robots_url = f"{parsed.scheme}://{parsed.netloc}/robots.txt" + return robots_url + + def _get_parser(self, url: str) -> RobotFileParser: + """ + Get or create a RobotFileParser for the domain. + + Args: + url: URL to get parser for + + Returns: + RobotFileParser: Parser for the domain + """ + robots_url = self._get_robots_url(url) + + # Return cached parser if available + if robots_url in self._parsers: + return self._parsers[robots_url] + + # Create new parser + parser = RobotFileParser() + parser.set_url(robots_url) + + try: + # Fetch and parse robots.txt + parser.read() + logger.info(f"Successfully fetched robots.txt from {robots_url}") + except Exception as e: + # If robots.txt cannot be fetched, assume all URLs are allowed + logger.warning(f"Could not fetch robots.txt from {robots_url}: {e}. Assuming all URLs allowed.") + # Create a permissive parser + parser = RobotFileParser() + parser.parse([]) # Empty robots.txt allows everything + + # Cache the parser + self._parsers[robots_url] = parser + return parser + + def can_fetch(self, url: str) -> bool: + """ + Check if the given URL can be fetched according to robots.txt. + + Args: + url: URL to check + + Returns: + bool: True if allowed, False if disallowed + """ + try: + parser = self._get_parser(url) + allowed = parser.can_fetch(self.user_agent, url) + + if not allowed: + logger.info(f"URL disallowed by robots.txt: {url}") + + return allowed + except Exception as e: + logger.error(f"Error checking robots.txt for {url}: {e}") + # On error, assume allowed + return True + + def get_crawl_delay(self, url: str) -> Optional[float]: + """ + Get the Crawl-delay directive from robots.txt if present. + + Args: + url: URL to get crawl delay for + + Returns: + Optional[float]: Delay in seconds, or None if not specified + """ + try: + parser = self._get_parser(url) + delay = parser.crawl_delay(self.user_agent) + + if delay is not None: + logger.info(f"Crawl-delay from robots.txt: {delay} seconds") + + return delay + except Exception as e: + logger.error(f"Error getting crawl delay for {url}: {e}") + return None diff --git a/api/app/core/rag/crawler/url_normalizer.py b/api/app/core/rag/crawler/url_normalizer.py new file mode 100644 index 00000000..7762a9d5 --- /dev/null +++ b/api/app/core/rag/crawler/url_normalizer.py @@ -0,0 +1,171 @@ +"""URL normalization and validation for web crawler.""" + +from typing import Optional, List +from urllib.parse import urlparse, urlunparse, parse_qs, urlencode, urljoin +from bs4 import BeautifulSoup + + +class URLNormalizer: + """Normalize and validate URLs for deduplication and domain checking.""" + + # Common tracking parameters to remove + TRACKING_PARAMS = { + 'utm_source', 'utm_medium', 'utm_campaign', 'utm_term', 'utm_content', + 'fbclid', 'gclid', 'msclkid', '_ga', 'mc_cid', 'mc_eid' + } + + def __init__(self, base_domain: str): + """ + Initialize URL normalizer with base domain. + + Args: + base_domain: The domain to use for same-domain checks + """ + parsed = urlparse(base_domain) + self.base_domain = parsed.netloc.lower() # example.com:8000 + self.base_scheme = parsed.scheme or 'https' # https + + def normalize(self, url: str) -> Optional[str]: + """ + Normalize a URL for deduplication. + + Normalization rules: + 1. Convert domain to lowercase + 2. Remove fragments (#section) + 3. Remove default ports (80 for http, 443 for https) + 4. Remove trailing slashes (except for root) + 5. Sort query parameters alphabetically + 6. Remove common tracking parameters + + Args: + url: URL to normalize + + Returns: + Optional[str]: Normalized URL, or None if invalid + """ + try: + parsed = urlparse(url) + + # Validate scheme + if parsed.scheme not in ('http', 'https'): + return None + + # Normalize domain to lowercase + netloc = parsed.netloc.lower() + + # Remove default ports + if ':' in netloc: + host, port = netloc.rsplit(':', 1) + if (parsed.scheme == 'http' and port == '80') or \ + (parsed.scheme == 'https' and port == '443'): + netloc = host + + # Normalize path + path = parsed.path + # Remove trailing slash except for root + if path != '/' and path.endswith('/'): + path = path.rstrip('/') + # Ensure path starts with / + if not path: + path = '/' + + # Process query parameters + query = '' + if parsed.query: + # Parse query parameters + params = parse_qs(parsed.query, keep_blank_values=True) + # Remove tracking parameters + filtered_params = { + k: v for k, v in params.items() + if k not in self.TRACKING_PARAMS + } + # Sort parameters alphabetically + if filtered_params: + sorted_params = sorted(filtered_params.items()) + query = urlencode(sorted_params, doseq=True) + + # Reconstruct URL without fragment + normalized = urlunparse(( + parsed.scheme, + netloc, + path, + parsed.params, + query, + '' # Remove fragment + )) + + return normalized + + except Exception: + return None + + def is_same_domain(self, url: str) -> bool: + """ + Check if URL belongs to the same domain as base_domain. + + Args: + url: URL to check + + Returns: + bool: True if same domain, False otherwise + """ + try: + parsed = urlparse(url) + domain = parsed.netloc.lower() + + # Remove port if present + if ':' in domain: + domain = domain.split(':')[0] + + # Check if domains match + return domain == self.base_domain or domain == self.base_domain.split(':')[0] + + except Exception: + return False + + def extract_links(self, html: str, base_url: str) -> List[str]: + """ + Extract and normalize all links from HTML. + + Args: + html: HTML content + base_url: Base URL for resolving relative links + + Returns: + List[str]: List of normalized absolute URLs + """ + links = [] + + try: + soup = BeautifulSoup(html, 'lxml') + + # Find all anchor tags + for anchor in soup.find_all('a', href=True): + href = anchor['href'] + + # Skip empty hrefs + if not href or href.strip() == '': + continue + + # Skip javascript: and mailto: links + if href.startswith(('javascript:', 'mailto:', 'tel:')): + continue + + normalized_url = None + # Check if href starts with http/https (absolute URL) + if href.startswith(('http://', 'https://')): + if self.is_same_domain(href): + normalized_url = self.normalize(href) + else: + # Convert relative URL to absolute + absolute_url = urljoin(base_url, href) + # Normalize the URL + normalized_url = self.normalize(absolute_url) + + if normalized_url: + links.append(normalized_url) + + except Exception: + pass + + return links diff --git a/api/app/core/rag/crawler/web_crawler.py b/api/app/core/rag/crawler/web_crawler.py new file mode 100644 index 00000000..3afa09b2 --- /dev/null +++ b/api/app/core/rag/crawler/web_crawler.py @@ -0,0 +1,215 @@ +"""Main web crawler orchestrator.""" + +from collections import deque +from datetime import datetime +from typing import Iterator, Optional, List, Set +from urllib.parse import urlparse +import logging + +from app.core.rag.crawler.url_normalizer import URLNormalizer +from app.core.rag.crawler.robots_parser import RobotsParser +from app.core.rag.crawler.rate_limiter import RateLimiter +from app.core.rag.crawler.http_fetcher import HTTPFetcher +from app.core.rag.crawler.content_extractor import ContentExtractor +from app.core.rag.crawler.models import CrawledDocument, CrawlSummary + +logger = logging.getLogger(__name__) + + +class WebCrawler: + """Main orchestrator for web crawling.""" + + def __init__( + self, + entry_url: str, + max_pages: int = 200, + delay_seconds: float = 1.0, + timeout_seconds: int = 10, + user_agent: str = "KnowledgeBaseCrawler/1.0", + include_patterns: Optional[List[str]] = None, + exclude_patterns: Optional[List[str]] = None, + content_extractor: Optional[ContentExtractor] = None + ): + """ + Initialize the web crawler. + + Args: + entry_url: Starting URL for the crawl + max_pages: Maximum number of pages to crawl (default: 200) + delay_seconds: Delay between requests in seconds (default: 1.0) + timeout_seconds: HTTP request timeout (default: 10) + user_agent: User-Agent header string + include_patterns: List of regex patterns for URLs to include + exclude_patterns: List of regex patterns for URLs to exclude + content_extractor: Custom content extractor (optional) + """ + # Validate entry URL + parsed = urlparse(entry_url) + if not parsed.scheme or not parsed.netloc: + raise ValueError(f"Invalid entry URL: {entry_url}") + + self.entry_url = entry_url + self.max_pages = max_pages + self.user_agent = user_agent + + # Extract domain from entry URL + self.domain = parsed.netloc + + # Initialize components + self.url_normalizer = URLNormalizer(entry_url) + self.robots_parser = RobotsParser(user_agent, timeout_seconds) + self.rate_limiter = RateLimiter(delay_seconds) + self.http_fetcher = HTTPFetcher(timeout_seconds, max_retries=3, user_agent=user_agent) + self.content_extractor = content_extractor or ContentExtractor() + + # State management + self.url_queue: deque = deque() + self.visited_urls: Set[str] = set() + self.pages_processed = 0 + + # Statistics + self.stats = { + 'success': 0, + 'errors': 0, + 'skipped': 0, + 'urls_discovered': 0, + 'error_breakdown': {} + } + self.start_time: Optional[datetime] = None + self.end_time: Optional[datetime] = None + + def crawl(self) -> Iterator[CrawledDocument]: + """ + Execute the crawl and yield documents as they are processed. + + Yields: + CrawledDocument: Structured document with extracted content + """ + logger.info(f"Starting crawl from {self.entry_url} (max_pages: {self.max_pages})") + self.start_time = datetime.now() + + # Add entry URL to queue + normalized_entry = self.url_normalizer.normalize(self.entry_url) + if normalized_entry: + self.url_queue.append(normalized_entry) + self.stats['urls_discovered'] += 1 + + # Check robots.txt and update rate limiter if needed + crawl_delay = self.robots_parser.get_crawl_delay(self.entry_url) + if crawl_delay: + self.rate_limiter.set_delay(crawl_delay) + + # Main crawl loop + while self.url_queue and self.pages_processed < self.max_pages: + url = self.url_queue.popleft() + + # Skip if already visited + if url in self.visited_urls: + continue + + # Mark as visited + self.visited_urls.add(url) + + # Check robots.txt permission + if not self.robots_parser.can_fetch(url): + logger.info(f"Skipping {url} (disallowed by robots.txt)") + self.stats['skipped'] += 1 + continue + + # Apply rate limiting + self.rate_limiter.wait() + + # Fetch URL + logger.info(f"Fetching {url} ({self.pages_processed + 1}/{self.max_pages})") + fetch_result = self.http_fetcher.fetch(url) + + # Handle fetch errors + if not fetch_result.success: + self._record_error(fetch_result.error or "Unknown error") + continue + + # Check Content-Type + content_type = fetch_result.headers.get('Content-Type', '').lower() + if not any(substring in content_type for substring in ['text/html', 'application/xhtml+xml']): + logger.warning(f"Skipping {url} (Content-Type: {content_type})") + self.stats['skipped'] += 1 + continue + + # Extract content + try: + extracted = self.content_extractor.extract(fetch_result.content, url) + + # Check if static content + if not extracted.is_static: + logger.warning(f"Skipping {url} (JavaScript-rendered content)") + self.stats['skipped'] += 1 + continue + + # Create document + document = CrawledDocument( + url=url, + title=extracted.title, + content=extracted.text, + content_length=len(extracted.text), + crawl_timestamp=datetime.now(), + http_status=fetch_result.status_code, + metadata={ + 'word_count': extracted.word_count, + 'final_url': fetch_result.final_url + } + ) + + # Update statistics + self.pages_processed += 1 + self.stats['success'] += 1 + + # Extract and queue links + links = self.url_normalizer.extract_links(fetch_result.content, url) + for link in links: + if link not in self.visited_urls and self.url_normalizer.is_same_domain(link): + if link not in self.url_queue: + self.url_queue.append(link) + self.stats['urls_discovered'] += 1 + + # Yield document + yield document + + except Exception as e: + logger.error(f"Error processing {url}: {e}") + self._record_error(f"Processing error: {str(e)}") + continue + + self.end_time = datetime.now() + logger.info(f"Crawl completed. Processed {self.pages_processed} pages.") + + def get_summary(self) -> CrawlSummary: + """ + Get summary statistics after crawl completion. + + Returns: + CrawlSummary: Statistics including success/error/skip counts + """ + if not self.start_time: + self.start_time = datetime.now() + if not self.end_time: + self.end_time = datetime.now() + + duration = (self.end_time - self.start_time).total_seconds() + + return CrawlSummary( + total_pages_processed=self.stats['success'], + total_errors=self.stats['errors'], + total_skipped=self.stats['skipped'], + total_urls_discovered=self.stats['urls_discovered'], + start_time=self.start_time, + end_time=self.end_time, + duration_seconds=duration, + error_breakdown=self.stats['error_breakdown'] + ) + + def _record_error(self, error: str): + """Record an error in statistics.""" + self.stats['errors'] += 1 + error_type = error.split(':')[0] if ':' in error else error + self.stats['error_breakdown'][error_type] = \ + self.stats['error_breakdown'].get(error_type, 0) + 1 diff --git a/api/app/core/rag/integrations/__init__.py b/api/app/core/rag/integrations/__init__.py new file mode 100644 index 00000000..c1c43854 --- /dev/null +++ b/api/app/core/rag/integrations/__init__.py @@ -0,0 +1 @@ +"""Integrations package for external services.""" diff --git a/api/app/core/rag/integrations/feishu/__init__.py b/api/app/core/rag/integrations/feishu/__init__.py new file mode 100644 index 00000000..d989b816 --- /dev/null +++ b/api/app/core/rag/integrations/feishu/__init__.py @@ -0,0 +1 @@ +"""Feishu integration module for document synchronization.""" diff --git a/api/app/core/rag/integrations/feishu/__main__.py b/api/app/core/rag/integrations/feishu/__main__.py new file mode 100644 index 00000000..79d5a48e --- /dev/null +++ b/api/app/core/rag/integrations/feishu/__main__.py @@ -0,0 +1,84 @@ +"""Command-line interface for feishu integration.""" + +import asyncio +import sys +from app.core.rag.integrations.feishu.client import FeishuAPIClient +from app.core.rag.integrations.feishu.models import FileInfo + + +def main(feishu_app_id: str, # Feishu application ID + feishu_app_secret: str, # Feishu application secret + feishu_folder_token: str, # Feishu Folder Token + save_dir: str, # save file directory + feishu_api_base_url: str = "https://open.feishu.cn/open-apis", # Feishu API base URL + timeout: int = 30, # Request timeout in seconds + max_retries: int = 3, # Maximum number of retries + recursive: bool = True # recursive: Whether to sync subfolders recursively, + ): + """Main entry point for the feishuAPIClient.""" + # Create feishuAPIClient + api_client = FeishuAPIClient( + app_id=feishu_app_id, + app_secret=feishu_app_secret, + api_base_url=feishu_api_base_url, + timeout=timeout, + max_retries=max_retries + ) + + # Get all files from folder + async def async_get_files(api_client: FeishuAPIClient, feishu_folder_token: str): + async with api_client as client: + if recursive: + files = await client.list_all_folder_files(feishu_folder_token, recursive=True) + else: + all_files = [] + page_token = None + while True: + files_page, page_token = await client.list_folder_files( + feishu_folder_token, page_token + ) + all_files.extend(files_page) + if not page_token: + break + files = all_files + return files + files = asyncio.run(async_get_files(api_client,feishu_folder_token)) + + # Filter out folders, only sync documents + # documents = [f for f in files if f.type in ["doc", "docx", "sheet", "bitable", "file", "slides"]] + documents = [f for f in files if f.type in ["doc", "docx", "sheet", "bitable", "file"]] + + try: + for doc in documents: + print(f"\n{'=' * 80}") + print(f"token: {doc.token}") + print(f"name: {doc.name}") + print(f"type: {doc.type}") + print(f"created_time: {doc.created_time}") + print(f"modified_time: {doc.modified_time}") + print(f"owner_id: {doc.owner_id}") + print(f"url: {doc.url}") + print(f"{'=' * 80}\n") + # download document from Feishu FileInfo + async def async_download_document(api_client: FeishuAPIClient, doc: FileInfo, save_dir: str): + async with api_client as client: + file_path = await client.download_document(document=doc, save_dir=save_dir) + return file_path + + file_path = asyncio.run(async_download_document(api_client, doc, save_dir)) + print(file_path) + + except KeyboardInterrupt: + print("\n\nfeishu integration interrupted by user.") + + except Exception as e: + print(f"\n\nError during feishu integration: {e}") + sys.exit(1) + + +if __name__ == '__main__': + feishu_app_id = "" + feishu_app_secret = "" + feishu_folder_token = "" + save_dir = "/Volumes/MacintoshBD/Repository/RedBearAI/MemoryBear/api/files/" + main(feishu_app_id, feishu_app_secret, feishu_folder_token, save_dir) diff --git a/api/app/core/rag/integrations/feishu/client.py b/api/app/core/rag/integrations/feishu/client.py new file mode 100644 index 00000000..0a3c4ea8 --- /dev/null +++ b/api/app/core/rag/integrations/feishu/client.py @@ -0,0 +1,452 @@ +"""Feishu API client for document operations.""" + +import asyncio +import os +import re +from typing import Optional, Tuple, List +from datetime import datetime, timedelta +import httpx +from cachetools import TTLCache +import urllib.parse + +from app.core.rag.integrations.feishu.exceptions import ( + FeishuAuthError, + FeishuAPIError, + FeishuNotFoundError, + FeishuPermissionError, + FeishuRateLimitError, + FeishuNetworkError, +) +from app.core.rag.integrations.feishu.models import FileInfo +from app.core.rag.integrations.feishu.retry import with_retry + + +class FeishuAPIClient: + """Feishu API client for document synchronization.""" + + def __init__( + self, + app_id: str, + app_secret: str, + api_base_url: str = "https://open.feishu.cn/open-apis", + timeout: int = 30, + max_retries: int = 3 + ): + """ + Initialize Feishu API client. + + Args: + app_id: Feishu application ID + app_secret: Feishu application secret + api_base_url: Feishu API base URL + timeout: Request timeout in seconds + max_retries: Maximum number of retries + """ + self.app_id = app_id + self.app_secret = app_secret + self.api_base_url = api_base_url + self.timeout = timeout + self.max_retries = max_retries + self._http_client: Optional[httpx.AsyncClient] = None + self._token_cache: TTLCache = TTLCache(maxsize=1, ttl=7200 - 300) # 2 hours - 5 minutes + self._token_lock = asyncio.Lock() + + async def __aenter__(self): + """Async context manager entry.""" + self._http_client = httpx.AsyncClient( + base_url=self.api_base_url, + timeout=self.timeout, + headers={"Content-Type": "application/json"} + ) + return self + + async def __aexit__(self, exc_type, exc_val, exc_tb): + """Async context manager exit.""" + if self._http_client: + await self._http_client.aclose() + + async def get_tenant_access_token(self) -> str: + """ + Get tenant access token with caching. + + Returns: + Access token string + + Raises: + FeishuAuthError: If authentication fails + """ + # Check cache first + cached_token = self._token_cache.get("access_token") + if cached_token: + return cached_token + + # Use lock to prevent concurrent token requests + async with self._token_lock: + # Double-check cache after acquiring lock + cached_token = self._token_cache.get("access_token") + if cached_token: + return cached_token + + # Request new token + try: + if not self._http_client: + raise FeishuAuthError("HTTP client not initialized") + + response = await self._http_client.post( + "/auth/v3/tenant_access_token/internal", + json={ + "app_id": self.app_id, + "app_secret": self.app_secret + } + ) + + data = response.json() + + if data.get("code") != 0: + error_msg = data.get("msg", "Unknown error") + raise FeishuAuthError( + f"Authentication failed: {error_msg}", + error_code=str(data.get("code")), + details=data + ) + + token = data.get("tenant_access_token") + if not token: + raise FeishuAuthError("No access token in response") + + # Cache the token + self._token_cache["access_token"] = token + + return token + + except httpx.HTTPError as e: + raise FeishuAuthError(f"HTTP error during authentication: {str(e)}") + except Exception as e: + if isinstance(e, FeishuAuthError): + raise + raise FeishuAuthError(f"Unexpected error during authentication: {str(e)}") + + @with_retry + async def list_folder_files( + self, + folder_token: str, + page_token: Optional[str] = None + ) -> Tuple[List[FileInfo], Optional[str]]: + """ + Get list of files in a folder with pagination support. + + Args: + folder_token: Folder token + page_token: Page token for pagination + + Returns: + Tuple of (list of FileInfo, next page token) + + Raises: + FeishuAPIError: If API call fails + FeishuNotFoundError: If folder not found + FeishuPermissionError: If permission denied + """ + try: + token = await self.get_tenant_access_token() + + if not self._http_client: + raise FeishuAPIError("HTTP client not initialized") + + # Build request parameters + params = {"page_size": 200, "folder_token": folder_token} + if page_token: + params["page_token"] = page_token + + # Make API request + response = await self._http_client.get( + f"/drive/v1/files", + params=params, + headers={"Authorization": f"Bearer {token}"} + ) + + data = response.json() + # print(f"get files: {data}") + + # Handle errors + if data.get("code") != 0: + error_code = data.get("code") + error_msg = data.get("msg", "Unknown error") + + if error_code == 404 or error_code == 230005: + raise FeishuNotFoundError( + f"Folder not found: {error_msg}", + error_code=str(error_code), + details=data + ) + elif error_code == 403 or error_code == 230003: + raise FeishuPermissionError( + f"Permission denied: {error_msg}", + error_code=str(error_code), + details=data + ) + else: + raise FeishuAPIError( + f"API error: {error_msg}", + error_code=str(error_code), + details=data + ) + + # Parse response + files_data = data.get("data", {}).get("files", []) + next_page_token = data.get("data", {}).get("next_page_token", None) + + # Convert to FileInfo objects + files = [] + for file_data in files_data: + try: + file_info = FileInfo( + token=file_data.get("token", ""), + name=file_data.get("name", ""), + type=file_data.get("type", ""), + created_time=datetime.fromtimestamp(int(file_data.get("created_time", 0))), + modified_time=datetime.fromtimestamp(int(file_data.get("modified_time", 0))), + owner_id=file_data.get("owner_id", ""), + url=file_data.get("url", "") + ) + files.append(file_info) + except (ValueError, TypeError) as e: + # Skip invalid file entries + continue + + return files, next_page_token + + except httpx.HTTPError as e: + raise FeishuAPIError(f"HTTP error: {str(e)}") + except Exception as e: + if isinstance(e, (FeishuAPIError, FeishuNotFoundError, FeishuPermissionError)): + raise + raise FeishuAPIError(f"Unexpected error: {str(e)}") + + async def list_all_folder_files( + self, + folder_token: str, + recursive: bool = True + ) -> List[FileInfo]: + """ + Get all files in a folder, handling pagination automatically. + + Args: + folder_token: Folder token + recursive: Whether to recursively get files from subfolders + + Returns: + List of all FileInfo objects + + Raises: + FeishuAPIError: If API call fails + """ + all_files = [] + page_token = None + + # Get all files with pagination + while True: + files, page_token = await self.list_folder_files(folder_token, page_token) + all_files.extend(files) + + if not page_token: + break + + # Recursively get files from subfolders if requested + if recursive: + subfolders = [f for f in all_files if f.type == "folder"] + for subfolder in subfolders: + try: + subfolder_files = await self.list_all_folder_files( + subfolder.token, + recursive=True + ) + all_files.extend(subfolder_files) + except Exception: + # Continue with other folders if one fails + continue + + return all_files + + @with_retry + async def download_document( + self, + document: FileInfo, + save_dir: str + ) -> str: + """ + download document content. + + Args: + document: Document FileInfo + save_dir: save dir + + Returns: + file_full_path + + Raises: + FeishuAPIError: If API call fails + FeishuNotFoundError: If document not found + FeishuPermissionError: If permission denied + """ + try: + token = await self.get_tenant_access_token() + + if not self._http_client: + raise FeishuAPIError("HTTP client not initialized") + + # Different API endpoints for different document types + if document.type == "doc" or document.type == "docx" or document.type == "sheet" or document.type == "bitable": + return await self._export_file(document, token, save_dir) + elif document.type == "file" or document.type == "slides": + return await self._download_file(document, token, save_dir) + else: + raise FeishuAPIError(f"Unsupported document type: {document.type}") + + except Exception as e: + if isinstance(e, (FeishuAPIError, FeishuNotFoundError, FeishuPermissionError)): + raise + raise FeishuAPIError(f"Unexpected error: {str(e)}") + + async def _export_file(self, document: FileInfo, access_token: str, save_dir: str) -> str: + """export file for feishu online file type.""" + try: + # 1.创建导出任务 + file_extension = "pdf" + match document.type: + case "doc": + file_extension = "doc" + case "docx": + file_extension = "docx" + case "sheet": + file_extension = "xlsx" + case "bitable": + file_extension = "xlsx" + case _: + file_extension = "pdf" + response = await self._http_client.post( + "/drive/v1/export_tasks", + json={ + "file_extension": file_extension, + "token": document.token, + "type": document.type + }, + headers={"Authorization": f"Bearer {access_token}"} + ) + data = response.json() + print(f"1.创建导出任务: {data}") + + if data.get("code") != 0: + error_code = data.get("code") + error_msg = data.get("msg", "Unknown error") + raise FeishuAPIError( + f"API error: {error_msg}", + error_code=str(error_code), + details=data + ) + + ticket = data.get("data", {}).get("ticket", None) + if not ticket: + raise FeishuAuthError("No ticket in response") + + # 2.轮序查询导出任务结果 + max_retries = 10 # 最大轮询次数 + poll_interval = 2 # 每次轮询间隔时间(秒) + file_token = None + for attempt in range(max_retries): + # 查询导出任务 + response = await self._http_client.get( + f"/drive/v1/export_tasks/{ticket}", + params={"token": document.token}, + headers={"Authorization": f"Bearer {access_token}"} + ) + data = response.json() + print(f"2. 尝试查询导出任务结果 (第{attempt + 1}次): {data}") + + if data.get("code") != 0: + error_code = data.get("code") + error_msg = data.get("msg", "Unknown error") + raise FeishuAPIError( + f"API error: {error_msg}", + error_code=str(error_code), + details=data, + ) + + # 检查导出任务结果 + file_token = data.get("data", {}).get("result", {}).get("file_token", None) + if file_token: + # 如果导出任务成功生成 file_token,则退出轮询 + break + + # 如果结果还没准备好,等待一段时间再进行下一次轮询 + await asyncio.sleep(poll_interval) + + if not file_token: + raise FeishuAPIError("Export task did not complete within the allowed time") + + # 3.下载导出任务 + response = await self._http_client.get( + f"/drive/v1/export_tasks/file/{file_token}/download", + headers={"Authorization": f"Bearer {access_token}"} + ) + response.raise_for_status() + print(f'3.下载导出任务: {response.headers.get("Content-Disposition")}') + + file_full_path = os.path.join(save_dir, document.name + "." + file_extension) + if os.path.exists(file_full_path): + os.remove(file_full_path) # Delete a single file + with open(file_full_path, "wb") as file: + file.write(response.content) + + return file_full_path + + except httpx.HTTPError as e: + raise FeishuAPIError(f"HTTP error: {str(e)}") + except Exception as e: + raise FeishuAPIError(f"Unexpected error during file download: {str(e)}") + + async def _download_file(self, document: FileInfo, access_token: str, save_dir: str) -> str: + """download file for file type.""" + try: + response = await self._http_client.get( + f"/drive/v1/files/{document.token}/download", + headers={"Authorization": f"Bearer {access_token}"} + ) + response.raise_for_status() + + filename_header = response.headers.get("Content-Disposition") + + # 最终的文件名(初始化为 None) + filename = None + if filename_header: + # 优先解析 filename* 格式 + match = re.search(r"filename\*=([^']*)''([^;]+)", filename_header) + if match: + # 使用 `filename*` 提取(已编码) + encoding = match.group(1) # 编码部分(如 UTF-8) + encoded_filename = match.group(2) # 文件名部分 + filename = urllib.parse.unquote(encoded_filename) # 解码 URL 编码的文件名 + + # 如果 `filename*` 不存在,回退到解析 `filename` + if not filename: + match = re.search(r'filename="([^"]+)"', filename_header) + if match: + filename = match.group(1) + # 如果文件名仍为 None,则使用默认文件名 + if not filename: + filename = f"{document.name}.pdf" + # 确保文件名合法,替换非法字符 + filename = re.sub(r'[\/:*?"<>|]', '_', filename) + + file_full_path = os.path.join(save_dir, filename) + if os.path.exists(file_full_path): + os.remove(file_full_path) # Delete a single file + with open(file_full_path, "wb") as file: + file.write(response.content) + + return file_full_path + + except httpx.HTTPError as e: + raise FeishuAPIError(f"HTTP error: {str(e)}") + except Exception as e: + raise FeishuAPIError(f"Unexpected error during file download: {str(e)}") diff --git a/api/app/core/rag/integrations/feishu/exceptions.py b/api/app/core/rag/integrations/feishu/exceptions.py new file mode 100644 index 00000000..26e42a07 --- /dev/null +++ b/api/app/core/rag/integrations/feishu/exceptions.py @@ -0,0 +1,46 @@ +"""Exception classes for Feishu integration.""" + + +class FeishuError(Exception): + """Base exception for all Feishu-related errors.""" + + def __init__(self, message: str, error_code: str = None, details: dict = None): + super().__init__(message) + self.message = message + self.error_code = error_code + self.details = details or {} + + +class FeishuAuthError(FeishuError): + """Authentication error with Feishu API.""" + pass + + +class FeishuAPIError(FeishuError): + """General API error from Feishu.""" + pass + + +class FeishuNotFoundError(FeishuError): + """Resource not found error (404).""" + pass + + +class FeishuPermissionError(FeishuError): + """Permission denied error (403).""" + pass + + +class FeishuRateLimitError(FeishuError): + """Rate limit exceeded error (429).""" + pass + + +class FeishuNetworkError(FeishuError): + """Network-related error (timeout, connection failure).""" + pass + + +class FeishuDataError(FeishuError): + """Data parsing or validation error.""" + pass diff --git a/api/app/core/rag/integrations/feishu/models.py b/api/app/core/rag/integrations/feishu/models.py new file mode 100644 index 00000000..b194afc1 --- /dev/null +++ b/api/app/core/rag/integrations/feishu/models.py @@ -0,0 +1,17 @@ +"""Data models for Feishu integration.""" + +from dataclasses import dataclass +from datetime import datetime +from typing import Dict, Any, List, Optional + + +@dataclass +class FileInfo: + """File information from Feishu.""" + token: str + name: str + type: str # doc/docx/sheet/bitable/file/slides/folder + created_time: datetime + modified_time: datetime + owner_id: str + url: str diff --git a/api/app/core/rag/integrations/feishu/retry.py b/api/app/core/rag/integrations/feishu/retry.py new file mode 100644 index 00000000..c1d9aff1 --- /dev/null +++ b/api/app/core/rag/integrations/feishu/retry.py @@ -0,0 +1,137 @@ +"""Retry strategy for Feishu API calls.""" + +import asyncio +import functools +from typing import Callable, TypeVar +import httpx + +from app.core.rag.integrations.feishu.exceptions import ( + FeishuAuthError, + FeishuPermissionError, + FeishuNotFoundError, + FeishuRateLimitError, + FeishuNetworkError, + FeishuDataError, + FeishuAPIError, +) + +T = TypeVar('T') + + +class RetryStrategy: + """Retry strategy for API calls.""" + + # Retryable error types + RETRYABLE_ERRORS = ( + FeishuNetworkError, + FeishuRateLimitError, + httpx.TimeoutException, + httpx.ConnectError, + httpx.ReadError, + ) + + # Non-retryable error types + NON_RETRYABLE_ERRORS = ( + FeishuAuthError, + FeishuPermissionError, + FeishuNotFoundError, + FeishuDataError, + ) + + # Retry configuration + MAX_RETRIES = 3 + BACKOFF_DELAYS = [1, 2, 4] # seconds + + @classmethod + def is_retryable(cls, error: Exception) -> bool: + """Check if an error is retryable.""" + # Check for specific retryable errors + if isinstance(error, cls.RETRYABLE_ERRORS): + return True + + # Check for non-retryable errors + if isinstance(error, cls.NON_RETRYABLE_ERRORS): + return False + + # Check for HTTP status codes + if isinstance(error, httpx.HTTPStatusError): + status_code = error.response.status_code + # Retry on 429 (rate limit), 503 (service unavailable), 502 (bad gateway) + if status_code in [429, 502, 503]: + return True + # Don't retry on 4xx errors (except 429) + if 400 <= status_code < 500: + return False + # Retry on 5xx errors + if 500 <= status_code < 600: + return True + + # Check for FeishuAPIError with specific codes + if isinstance(error, FeishuAPIError): + if error.error_code: + # Rate limit error codes + if error.error_code in ["99991400", "99991401"]: + return True + + return False + + @classmethod + async def execute_with_retry( + cls, + func: Callable[..., T], + *args, + **kwargs + ) -> T: + """ + Execute a function with retry logic. + + Args: + func: Async function to execute + *args: Positional arguments for the function + **kwargs: Keyword arguments for the function + + Returns: + Function result + + Raises: + Exception: The last exception if all retries fail + """ + last_exception = None + + for attempt in range(cls.MAX_RETRIES + 1): + try: + return await func(*args, **kwargs) + except Exception as e: + last_exception = e + + # Don't retry if not retryable + if not cls.is_retryable(e): + raise + + # Don't retry if this was the last attempt + if attempt >= cls.MAX_RETRIES: + raise + + # Wait before retrying + delay = cls.BACKOFF_DELAYS[attempt] if attempt < len(cls.BACKOFF_DELAYS) else cls.BACKOFF_DELAYS[-1] + await asyncio.sleep(delay) + + # Should not reach here, but raise last exception if we do + if last_exception: + raise last_exception + + +def with_retry(func: Callable[..., T]) -> Callable[..., T]: + """ + Decorator to add retry logic to async functions. + + Usage: + @with_retry + async def my_api_call(): + ... + """ + @functools.wraps(func) + async def wrapper(*args, **kwargs): + return await RetryStrategy.execute_with_retry(func, *args, **kwargs) + + return wrapper diff --git a/api/app/core/rag/integrations/yuque/__init__.py b/api/app/core/rag/integrations/yuque/__init__.py new file mode 100644 index 00000000..dc4f2a17 --- /dev/null +++ b/api/app/core/rag/integrations/yuque/__init__.py @@ -0,0 +1 @@ +"""Yuque integration module for document synchronization.""" diff --git a/api/app/core/rag/integrations/yuque/__main__.py b/api/app/core/rag/integrations/yuque/__main__.py new file mode 100644 index 00000000..3b87bbcd --- /dev/null +++ b/api/app/core/rag/integrations/yuque/__main__.py @@ -0,0 +1,77 @@ +"""Main entry point for Yuque integration testing.""" + +import asyncio +import sys +from app.core.rag.integrations.yuque.client import YuqueAPIClient +from app.core.rag.integrations.yuque.models import YuqueDocInfo + + +def main(yuque_user_id: str, # yuque User ID + yuque_token: str, # yuque Token + save_dir: str, # save file directory + ): + """Main entry point for the YuqueAPIClient.""" + # Create feishuAPIClient + api_client = YuqueAPIClient( + user_id=yuque_user_id, + token=yuque_token + ) + + # Get all files from all repos + async def async_get_files(api_client: YuqueAPIClient): + async with api_client as client: + print("\n=== Fetching repositories ===") + repos = await client.get_user_repos() + print(f"Found {len(repos)} repositories:") + all_files = [] + for repo in repos: + # Get documents from repository + print(f"\n=== Fetching documents from '{repo.name}' ===") + docs = await client.get_repo_docs(repo.id) + all_files.extend(docs) + return all_files + files = asyncio.run(async_get_files(api_client)) + + try: + for doc in files: + print(f"\n{'=' * 80}") + print(f"id: {doc.id}") + print(f"type: {doc.type}") + print(f"slug: {doc.slug}") + print(f"title: {doc.title}") + print(f"book_id: {doc.book_id}") + # print(f"format: {doc.format}") + # print(f"body: {doc.body}") + # print(f"body_draft: {doc.body_draft}") + # print(f"body_html: {doc.body_html}") + print(f"public: {doc.public}") + print(f"status: {doc.status}") + print(f"created_at: {doc.created_at}") + print(f"updated_at: {doc.updated_at}") + print(f"published_at: {doc.published_at}") + print(f"word_count: {doc.word_count}") + print(f"cover: {doc.cover}") + print(f"description: {doc.description}") + print(f"{'=' * 80}\n") + # download document from Feishu FileInfo + async def async_download_document(api_client: YuqueAPIClient, doc: YuqueDocInfo, save_dir: str): + async with api_client as client: + file_path = await client.download_document(doc, save_dir) + return file_path + + file_path = asyncio.run(async_download_document(api_client, doc, save_dir)) + print(file_path) + + except KeyboardInterrupt: + print("\n\nfeishu integration interrupted by user.") + + except Exception as e: + print(f"\n\nError during feishu integration: {e}") + sys.exit(1) + + +if __name__ == "__main__": + yuque_user_id = "" + yuque_token = "" + save_dir = "/Volumes/MacintoshBD/Repository/RedBearAI/MemoryBear/api/files/" + main(yuque_user_id, yuque_token, save_dir) diff --git a/api/app/core/rag/integrations/yuque/client.py b/api/app/core/rag/integrations/yuque/client.py new file mode 100644 index 00000000..444d9d31 --- /dev/null +++ b/api/app/core/rag/integrations/yuque/client.py @@ -0,0 +1,544 @@ +"""Yuque API client for document operations.""" + +import os +import re +from typing import Optional, List +from datetime import datetime, timedelta +import httpx +import urllib.parse +import json +from openpyxl import Workbook +from openpyxl.styles import Font, Alignment, PatternFill +from openpyxl.utils import get_column_letter +import zlib + +from app.core.rag.integrations.yuque.exceptions import ( + YuqueAuthError, + YuqueAPIError, + YuqueNotFoundError, + YuquePermissionError, + YuqueRateLimitError, + YuqueNetworkError, +) +from app.core.rag.integrations.yuque.models import YuqueDocInfo, YuqueRepoInfo +from app.core.rag.integrations.yuque.retry import with_retry + + +class YuqueAPIClient: + """Yuque API client for document synchronization.""" + + def __init__( + self, + user_id: str, + token: str, + api_base_url: str = "https://www.yuque.com/api/v2", + timeout: int = 30, + max_retries: int = 3 + ): + """ + Initialize Yuque API client. + + Args: + user_id: Yuque user ID or login name + token: Yuque personal access token + api_base_url: Yuque API base URL + timeout: Request timeout in seconds + max_retries: Maximum number of retries + """ + self.user_id = user_id + self.token = token + self.api_base_url = api_base_url + self.timeout = timeout + self.max_retries = max_retries + self._http_client: Optional[httpx.AsyncClient] = None + + async def __aenter__(self): + """Async context manager entry.""" + self._http_client = httpx.AsyncClient( + base_url=self.api_base_url, + timeout=self.timeout, + headers={ + "Content-Type": "application/json", + "X-Auth-Token": self.token, + "User-Agent": "Yuque-Integration-Client" + } + ) + return self + + async def __aexit__(self, exc_type, exc_val, exc_tb): + """Async context manager exit.""" + if self._http_client: + await self._http_client.aclose() + + def _handle_api_error(self, response: httpx.Response): + """Handle API error responses.""" + try: + data = response.json() + except Exception: + data = {} + + status_code = response.status_code + error_msg = data.get("message", "Unknown error") + + # Rate limit errors + if status_code == 429: + raise YuqueRateLimitError( + f"Rate limit exceeded: {error_msg}", + error_code=str(status_code), + details=data + ) + # Not found errors + elif status_code == 404: + raise YuqueNotFoundError( + f"Resource not found: {error_msg}", + error_code=str(status_code), + details=data + ) + # Permission errors + elif status_code == 403: + raise YuquePermissionError( + f"Permission denied: {error_msg}", + error_code=str(status_code), + details=data + ) + # Authentication errors + elif status_code == 401: + raise YuqueAuthError( + f"Authentication failed: {error_msg}", + error_code=str(status_code), + details=data + ) + # Generic API error + else: + raise YuqueAPIError( + f"API error: {error_msg}", + error_code=str(status_code), + details=data + ) + + @with_retry + async def get_user_repos(self) -> List[YuqueRepoInfo]: + """ + Get all repositories (知识库) for the user. + + Returns: + List of YuqueRepoInfo objects + + Raises: + YuqueAPIError: If API call fails + """ + try: + if not self._http_client: + raise YuqueAPIError("HTTP client not initialized") + + response = await self._http_client.get(f"/users/{self.user_id}/repos") + + if response.status_code != 200: + self._handle_api_error(response) + + data = response.json() + repos_data = data.get("data", []) + + repos = [] + for repo_data in repos_data: + try: + repo = YuqueRepoInfo( + id=repo_data.get("id"), + type=repo_data.get("type", ""), + name=repo_data.get("name", ""), + namespace=repo_data.get("namespace", ""), + slug=repo_data.get("slug", ""), + description=repo_data.get("description"), + public=repo_data.get("public", 0), + items_count=repo_data.get("items_count", 0), + created_at=datetime.fromisoformat(repo_data.get("created_at", "").replace("Z", "+00:00")), + updated_at=datetime.fromisoformat(repo_data.get("updated_at", "").replace("Z", "+00:00")) + ) + repos.append(repo) + except (ValueError, TypeError, KeyError) as e: + # Skip invalid repo entries + continue + + return repos + + except httpx.HTTPError as e: + raise YuqueAPIError(f"HTTP error: {str(e)}") + except Exception as e: + if isinstance(e, (YuqueAPIError, YuqueAuthError)): + raise + raise YuqueAPIError(f"Unexpected error: {str(e)}") + + @with_retry + async def get_repo_docs(self, book_id: int) -> List[YuqueDocInfo]: + """ + Get all documents in a repository. + + Args: + book_id: repository id + + Returns: + List of YuqueDocInfo objects (without body content) + + Raises: + YuqueAPIError: If API call fails + """ + try: + if not self._http_client: + raise YuqueAPIError("HTTP client not initialized") + + response = await self._http_client.get(f"/repos/{book_id}/docs") + + if response.status_code != 200: + self._handle_api_error(response) + + data = response.json() + docs_data = data.get("data", []) + + docs = [] + for doc_data in docs_data: + try: + published_at = doc_data.get("published_at") + doc = YuqueDocInfo( + id=doc_data.get("id"), + type=doc_data.get("type", ""), + slug=doc_data.get("slug", ""), + title=doc_data.get("title", ""), + book_id=doc_data.get("book_id"), + format=doc_data.get("format", "markdown"), + body=None, # Body not included in list API + body_draft=None, + body_html=None, + public=doc_data.get("public", 0), + status=doc_data.get("status", 0), + created_at=datetime.fromisoformat(doc_data.get("created_at", "").replace("Z", "+00:00")), + updated_at=datetime.fromisoformat(doc_data.get("updated_at", "").replace("Z", "+00:00")), + published_at=datetime.fromisoformat(published_at.replace("Z", "+00:00")) if published_at else None, + word_count=doc_data.get("word_count", 0), + cover=doc_data.get("cover"), + description=doc_data.get("description") + ) + docs.append(doc) + except (ValueError, TypeError, KeyError) as e: + # Skip invalid doc entries + continue + + return docs + + except httpx.HTTPError as e: + raise YuqueAPIError(f"HTTP error: {str(e)}") + except Exception as e: + if isinstance(e, (YuqueAPIError, YuqueNotFoundError)): + raise + raise YuqueAPIError(f"Unexpected error: {str(e)}") + + @with_retry + async def get_doc_detail(self, id: int) -> YuqueDocInfo: + """ + Get detailed document information including content. + + Args: + id: document ID + + Returns: + YuqueDocInfo object with full content + + Raises: + YuqueAPIError: If API call fails + """ + try: + if not self._http_client: + raise YuqueAPIError("HTTP client not initialized") + + response = await self._http_client.get( + f"/repos/docs/{id}", + params={"raw": 1} # Get raw markdown content + ) + + if response.status_code != 200: + self._handle_api_error(response) + + data = response.json() + doc_data = data.get("data", {}) + + published_at = doc_data.get("published_at") + doc = YuqueDocInfo( + id=doc_data.get("id"), + type=doc_data.get("type", ""), + slug=doc_data.get("slug", ""), + title=doc_data.get("title", ""), + book_id=doc_data.get("book_id"), + format=doc_data.get("format", "markdown"), + body=doc_data.get("body", ""), + body_draft=doc_data.get("body_draft"), + body_html=doc_data.get("body_html"), + public=doc_data.get("public", 0), + status=doc_data.get("status", 0), + created_at=datetime.fromisoformat(doc_data.get("created_at", "").replace("Z", "+00:00")), + updated_at=datetime.fromisoformat(doc_data.get("updated_at", "").replace("Z", "+00:00")), + published_at=datetime.fromisoformat(published_at.replace("Z", "+00:00")) if published_at else None, + word_count=doc_data.get("word_count", 0), + cover=doc_data.get("cover"), + description=doc_data.get("description") + ) + + return doc + + except httpx.HTTPError as e: + raise YuqueAPIError(f"HTTP error: {str(e)}") + except Exception as e: + if isinstance(e, (YuqueAPIError, YuqueNotFoundError)): + raise + raise YuqueAPIError(f"Unexpected error: {str(e)}") + + async def download_document( + self, + doc: YuqueDocInfo, + save_dir: str + ) -> str: + """ + Download document content to local file. + + Args: + doc: Document info (can be without body) + save_dir: Directory to save the file + + Returns: + Full path to the saved file + + Raises: + YuqueAPIError: If download fails + """ + try: + # Get full document content if not already loaded + if not doc.body: + doc = await self.get_doc_detail(doc.id) + + # Sanitize filename + filename = re.sub(r'[\/:*?"<>|]', '_', doc.title) + + # Determine file extension based on format + content = doc.body or "" + if doc.format == "markdown": + file_extension = "md" + elif doc.format == "lake": + file_extension = "md" # Save lake format as markdown + elif doc.format == "html": + file_extension = "html" + elif doc.format == "lakesheet": + file_extension = "xlsx" + + body_data = json.loads(doc.body) + sheet_data = body_data.get("sheet", "") + try: + sheet_raw = zlib.decompress(bytes(sheet_data, 'latin-1')) + except Exception as e: + print(f"Error decompressing sheet data: {e}") + raise ValueError("Invalid or unsupported sheet data format.") + try: + sheet_text = sheet_raw.decode("utf-8") # 假设是 UTF-8 编码 + except UnicodeDecodeError: + sheet_text = sheet_raw.decode("gbk") # 如果 UTF-8 解码失败,尝试 GBK + + file_full_path = os.path.join(save_dir, f"{filename}.{file_extension}") + self.generate_excel_from_sheet(sheet_text, file_full_path) + return file_full_path + else: + file_extension = "txt" + + file_full_path = os.path.join(save_dir, f"{filename}.{file_extension}") + # Remove existing file if it exists + if os.path.exists(file_full_path): + os.remove(file_full_path) + + # Write content to file + with open(file_full_path, "w", encoding="utf-8") as file: + file.write(content) + + return file_full_path + + except Exception as e: + if isinstance(e, YuqueAPIError): + raise + raise YuqueAPIError(f"Unexpected error during file download: {str(e)}") + + def generate_excel_from_sheet(self, sheet_text: str, save_path: str): + """ + 将解析的 sheet_text 数据转换为 Excel 文件。 + + Args: + sheet_text (str): JSON 格式的 sheet 数据。 + save_path (str): Excel 文件的保存路径。 + """ + try: + # 解析 JSON 数据 + sheets = json.loads(sheet_text) + + if not isinstance(sheets, list): + raise ValueError("sheet_text must be a JSON array of sheets.") + + # 创建一个新的 Excel 工作簿 + workbook = Workbook() + + for sheet_index, sheet_data in enumerate(sheets): + sheet_name = sheet_data.get("name", f"Sheet{sheet_index + 1}") + row_data = sheet_data.get("data", {}) + merge_cells = sheet_data.get("mergeCells", {}) + rows_styles = sheet_data.get("rows", []) + cols_styles = sheet_data.get("columns", []) + + # 创建 Sheet + if sheet_index == 0: + worksheet = workbook.active + worksheet.title = sheet_name + else: + worksheet = workbook.create_sheet(title=sheet_name) + + # 设置列宽 + for col_index, col_style in enumerate(cols_styles): + col_width = col_style.get("size", 82.125) / 7.0 + col_letter = get_column_letter(col_index + 1) # Excel 列从1开始 + worksheet.column_dimensions[col_letter].width = col_width + + # 设置行高 + for row_index, row_style in enumerate(rows_styles): + row_height = row_style.get("size", 24) / 1.5 + worksheet.row_dimensions[row_index + 1].height = row_height + + # 写入单元格数据 + for r_index, row in row_data.items(): + for c_index, cell in row.items(): + # 防御性检查:确保行号和列号都是有效的整数 + try: + row_number = int(r_index) + 1 + col_number = int(c_index) + 1 + except ValueError: + print(f"Invalid row or column index: r_index={r_index}, c_index={c_index}") + continue + + if col_number < 1 or col_number > 16384: # Excel 最大列数支持到 XFD,即 16384 列 + print(f"Invalid column index: c_index={c_index}") + continue + + cell_obj = worksheet.cell(row=row_number, column=col_number) + + # 处理值和公式 + cell_value = cell.get("value", "") + if isinstance(cell_value, dict): + # 检查是否为公式 + if cell_value.get("class") == "formula" and "formula" in cell_value: + cell_obj.value = f"={cell_value['formula']}" # 写入公式 + else: + cell_obj.value = cell_value.get("value", "") # 写入值 + else: + cell_obj.value = cell_value # 写入简单值 + + # 应用样式 + style = cell.get("style", {}) + self.apply_cell_style(cell_obj, style) + + # 合并单元格 + for key, merge_def in merge_cells.items(): + start_row = merge_def["row"] + 1 + start_col = merge_def["col"] + 1 + end_row = start_row + merge_def["rowCount"] - 1 + end_col = start_col + merge_def["colCount"] - 1 + worksheet.merge_cells( + start_row=start_row, start_column=start_col, end_row=end_row, end_column=end_col + ) + + # 保存 Excel 文件 + workbook.save(save_path) + print(f"Excel file successfully saved to: {save_path}") + + except Exception as e: + print(f"Error generating Excel file: {e}") + + + def apply_cell_style(self, cell, style): + """ + 应用单元格样式,包括字体、对齐、背景颜色等。 + + Args: + cell: openpyxl 的单元格对象。 + style: 字典格式的样式信息。 + """ + # 定义允许的对齐值 + allowed_horizontal_alignments = {"general", "left", "center", "centerContinuous", "right", "fill", "justify", + "distributed"} + allowed_vertical_alignments = {"top", "center", "justify", "distributed", "bottom"} + + # 处理字体 + font = Font( + size=style.get("fontSize", 11), + bold=style.get("fontWeight", False), + italic=style.get("fontStyle", "normal") == "italic", + underline="single" if style.get("underline", False) else None, + color=self.convert_color_to_hex(style.get("color", "#000000")), + ) + cell.font = font + + # 处理对齐方式 + horizontal_alignment = style.get("hAlign", "left") + vertical_alignment = style.get("vAlign", "top") + + # 如果对齐值无效,则使用默认值 + if horizontal_alignment not in allowed_horizontal_alignments: + horizontal_alignment = "left" + if vertical_alignment not in allowed_vertical_alignments: + vertical_alignment = "top" + + alignment = Alignment( + horizontal=horizontal_alignment, + vertical=vertical_alignment, + wrap_text=style.get("overflow") == "wrap", + ) + cell.alignment = alignment + + # 处理背景颜色 + background_color = style.get("backColor", None) + if background_color: + hex_color = self.convert_color_to_hex(background_color) + if hex_color: + cell.fill = PatternFill( + start_color=hex_color, + end_color=hex_color, + fill_type="solid" + ) + + def convert_color_to_hex(self, color): + """ + 将颜色从 `rgba(...)` 或 `rgb(...)` 转换为 aRGB 十六进制格式。 + + Args: + color (str): 原始颜色字符串,如 `rgba(255,255,0,1.00)` 或 `#FFFFFF`。 + + Returns: + str: 转换后的颜色字符串(符合 openpyxl 的格式),例如 `FFFF0000`。 + """ + try: + if not color: + return None + + # 如果是 `#RRGGBB` 或 `#AARRGGBB` 格式,直接返回 + if color.startswith("#"): + return color.lstrip("#").upper() + + # 如果是 `rgb(...)` 格式,例如 `rgb(255,255,0)` + if color.startswith("rgb("): + rgb_values = color.strip("rgb()").split(",") + red, green, blue = [int(v) for v in rgb_values] + return f"FF{red:02X}{green:02X}{blue:02X}" + + # 如果是 `rgba(...)` 格式,例如 `rgba(255,255,0,1.00)` + if color.startswith("rgba("): + rgba_values = color.strip("rgba()").split(",") + red, green, blue = [int(v) for v in rgba_values[:3]] + alpha = float(rgba_values[3]) + alpha_hex = int(alpha * 255) # 将透明度转换为 [00, FF] + return f"{alpha_hex:02X}{red:02X}{green:02X}{blue:02X}" + + # 返回默认颜色 + return None + except Exception as e: + print(f"Error parsing color '{color}': {e}") + return None diff --git a/api/app/core/rag/integrations/yuque/exceptions.py b/api/app/core/rag/integrations/yuque/exceptions.py new file mode 100644 index 00000000..e862323c --- /dev/null +++ b/api/app/core/rag/integrations/yuque/exceptions.py @@ -0,0 +1,46 @@ +"""Exception classes for Yuque integration.""" + + +class YuqueError(Exception): + """Base exception for all Yuque-related errors.""" + + def __init__(self, message: str, error_code: str = None, details: dict = None): + super().__init__(message) + self.message = message + self.error_code = error_code + self.details = details or {} + + +class YuqueAuthError(YuqueError): + """Authentication error with Yuque API.""" + pass + + +class YuqueAPIError(YuqueError): + """General API error from Yuque.""" + pass + + +class YuqueNotFoundError(YuqueError): + """Resource not found error (404).""" + pass + + +class YuquePermissionError(YuqueError): + """Permission denied error (403).""" + pass + + +class YuqueRateLimitError(YuqueError): + """Rate limit exceeded error (429).""" + pass + + +class YuqueNetworkError(YuqueError): + """Network-related error (timeout, connection failure).""" + pass + + +class YuqueDataError(YuqueError): + """Data parsing or validation error.""" + pass diff --git a/api/app/core/rag/integrations/yuque/models.py b/api/app/core/rag/integrations/yuque/models.py new file mode 100644 index 00000000..6230aa69 --- /dev/null +++ b/api/app/core/rag/integrations/yuque/models.py @@ -0,0 +1,42 @@ +"""Data models for Yuque integration.""" + +from dataclasses import dataclass +from datetime import datetime +from typing import Optional + + +@dataclass +class YuqueRepoInfo: + """Repository (知识库) information from Yuque.""" + id: int # 知识库 ID + type: str # 类型 (Book:文档, Design:图集, Sheet:表格, Resource:资源) + name: str # 名称 + namespace: str # 完整路径: user/repo format + slug: str # 路径 + description: Optional[str] # 简介 + public: int # 公开性 (0:私密, 1:公开, 2:企业内公开) + items_count: int # 文档数量 + created_at: datetime # 创建时间 + updated_at: datetime # 更新时间 + + +@dataclass +class YuqueDocInfo: + """Document information from Yuque.""" + id: int # 文档 ID + type: str # 文档类型 (Doc:普通文档, Sheet:表格, Thread:话题, Board:图集, Table:数据表) + slug: str # 路径 + title: str # 标题 + book_id: int # 归属知识库 ID + format: str # 内容格式 (markdown:Markdown 格式, lake:语雀 Lake 格式, html:HTML 标准格式, lakesheet:语雀表格) + body: Optional[str] # 正文原始内容 + body_draft: Optional[str] # 正文草稿内容 + body_html: Optional[str] # 正文 HTML 标准格式内容 + public: int # 公开性 (0:私密, 1:公开, 2:企业内公开) + status: int # 状态 (0:草稿, 1:发布) + created_at: datetime # 创建时间 + updated_at: datetime # 更新时间 + published_at: Optional[datetime] # 发布时间 + word_count: int # 内容字数 + cover: Optional[str] # 封面 + description: Optional[str] # 摘要 diff --git a/api/app/core/rag/integrations/yuque/retry.py b/api/app/core/rag/integrations/yuque/retry.py new file mode 100644 index 00000000..a68d6b47 --- /dev/null +++ b/api/app/core/rag/integrations/yuque/retry.py @@ -0,0 +1,134 @@ +"""Retry strategy for Yuque API calls.""" + +import asyncio +import functools +from typing import Callable, TypeVar +import httpx + +from app.core.rag.integrations.yuque.exceptions import ( + YuqueAuthError, + YuquePermissionError, + YuqueNotFoundError, + YuqueRateLimitError, + YuqueNetworkError, + YuqueDataError, + YuqueAPIError, +) + +T = TypeVar('T') + + +class RetryStrategy: + """Retry strategy for API calls.""" + + # Retryable error types + RETRYABLE_ERRORS = ( + YuqueNetworkError, + YuqueRateLimitError, + httpx.TimeoutException, + httpx.ConnectError, + httpx.ReadError, + ) + + # Non-retryable error types + NON_RETRYABLE_ERRORS = ( + YuqueAuthError, + YuquePermissionError, + YuqueNotFoundError, + YuqueDataError, + ) + + # Retry configuration + MAX_RETRIES = 3 + BACKOFF_DELAYS = [1, 2, 4] # seconds + + @classmethod + def is_retryable(cls, error: Exception) -> bool: + """Check if an error is retryable.""" + # Check for specific retryable errors + if isinstance(error, cls.RETRYABLE_ERRORS): + return True + + # Check for non-retryable errors + if isinstance(error, cls.NON_RETRYABLE_ERRORS): + return False + + # Check for HTTP status codes + if isinstance(error, httpx.HTTPStatusError): + status_code = error.response.status_code + # Retry on 429 (rate limit), 503 (service unavailable), 502 (bad gateway) + if status_code in [429, 502, 503]: + return True + # Don't retry on 4xx errors (except 429) + if 400 <= status_code < 500: + return False + # Retry on 5xx errors + if 500 <= status_code < 600: + return True + + # Check for YuqueRateLimitError + if isinstance(error, YuqueRateLimitError): + return True + + return False + + @classmethod + async def execute_with_retry( + cls, + func: Callable[..., T], + *args, + **kwargs + ) -> T: + """ + Execute a function with retry logic. + + Args: + func: Async function to execute + *args: Positional arguments for the function + **kwargs: Keyword arguments for the function + + Returns: + Function result + + Raises: + Exception: The last exception if all retries fail + """ + last_exception = None + + for attempt in range(cls.MAX_RETRIES + 1): + try: + return await func(*args, **kwargs) + except Exception as e: + last_exception = e + + # Don't retry if not retryable + if not cls.is_retryable(e): + raise + + # Don't retry if this was the last attempt + if attempt >= cls.MAX_RETRIES: + raise + + # Wait before retrying + delay = cls.BACKOFF_DELAYS[attempt] if attempt < len(cls.BACKOFF_DELAYS) else cls.BACKOFF_DELAYS[-1] + await asyncio.sleep(delay) + + # Should not reach here, but raise last exception if we do + if last_exception: + raise last_exception + + +def with_retry(func: Callable[..., T]) -> Callable[..., T]: + """ + Decorator to add retry logic to async functions. + + Usage: + @with_retry + async def my_api_call(): + ... + """ + @functools.wraps(func) + async def wrapper(*args, **kwargs): + return await RetryStrategy.execute_with_retry(func, *args, **kwargs) + + return wrapper diff --git a/api/app/models/file_model.py b/api/app/models/file_model.py index 842e3dc8..44a7d613 100644 --- a/api/app/models/file_model.py +++ b/api/app/models/file_model.py @@ -14,4 +14,5 @@ class File(Base): file_name = Column(String, index=True, nullable=False, comment="file name or folder name,default folder name is /") file_ext = Column(String, index=True, nullable=False, comment="file extension:folder|pdf") file_size = Column(Integer, default=0, comment="file size(byte)") + file_url = Column(String, index=True, nullable=True, comment="file comes from a website url") created_at = Column(DateTime, default=datetime.datetime.now) \ No newline at end of file diff --git a/api/app/models/knowledge_model.py b/api/app/models/knowledge_model.py index 8f0909d3..fbebe1b4 100644 --- a/api/app/models/knowledge_model.py +++ b/api/app/models/knowledge_model.py @@ -57,6 +57,17 @@ class Knowledge(Base): parser_id = Column(String, index=True, default="naive", comment="default parser ID") parser_config = Column(JSON, nullable=False, default={ + "entry_url": "https://ai.redbearai.com", + "max_pages": 20, + "delay_seconds": 1.0, + "timeout_seconds": 10, + "user_agent": "KnowledgeBaseCrawler/1.0", + "yuque_user_id": "User ID", + "yuque_token": "Token", + "feishu_app_id": "App ID", + "feishu_app_secret": "App Secret", + "feishu_folder_token": "Folder Token", + "sync_cron": "30 7 * * 1-5", "layout_recognize": "DeepDOC", "chunk_token_num": 128, "delimiter": "\n", diff --git a/api/app/schemas/file_schema.py b/api/app/schemas/file_schema.py index 00f1a148..7245671a 100644 --- a/api/app/schemas/file_schema.py +++ b/api/app/schemas/file_schema.py @@ -10,6 +10,8 @@ class FileBase(BaseModel): file_name: str file_ext: str file_size: int + file_url: str | None = None + created_at: datetime.datetime | None = None class FileCreate(FileBase): @@ -26,6 +28,7 @@ class FileUpdate(BaseModel): file_name: str | None = Field(None) file_ext: str | None = Field(None) file_size: str | None = Field(None) + file_url: str | None = Field(None) class File(FileBase): diff --git a/api/app/tasks.py b/api/app/tasks.py index a46a3a7b..29b0e485 100644 --- a/api/app/tasks.py +++ b/api/app/tasks.py @@ -7,6 +7,8 @@ import uuid from uuid import UUID from datetime import datetime, timezone from math import ceil +from pathlib import Path +import shutil from typing import Any, Dict, List, Optional import redis @@ -16,8 +18,13 @@ import trio # Import a unified Celery instance from app.celery_app import celery_app from app.core.config import settings +from app.core.rag.crawler.web_crawler import WebCrawler from app.core.rag.graphrag.general.index import init_graphrag, run_graphrag_for_kb from app.core.rag.graphrag.utils import get_llm_cache, set_llm_cache +from app.core.rag.integrations.feishu.client import FeishuAPIClient +from app.core.rag.integrations.feishu.models import FileInfo +from app.core.rag.integrations.yuque.client import YuqueAPIClient +from app.core.rag.integrations.yuque.models import YuqueDocInfo from app.core.rag.llm.chat_model import Base from app.core.rag.llm.cv_model import QWenCV from app.core.rag.llm.embedding_model import OpenAIEmbed @@ -29,7 +36,9 @@ from app.core.rag.vdb.elasticsearch.elasticsearch_vector import ( ) from app.db import get_db, get_db_context from app.models.document_model import Document +from app.models.file_model import File from app.models.knowledge_model import Knowledge +from app.schemas import file_schema, document_schema from app.services.memory_agent_service import MemoryAgentService @@ -382,6 +391,480 @@ def build_graphrag_for_kb(kb_id: uuid.UUID): db.close() +@celery_app.task(name="app.core.rag.tasks.sync_knowledge_for_kb") +def sync_knowledge_for_kb(kb_id: uuid.UUID): + """ + sync knowledge document and Document parsing, vectorization, and storage + """ + db = next(get_db()) # Manually call the generator + db_knowledge = None + try: + db_knowledge = db.query(Knowledge).filter(Knowledge.id == kb_id).first() + # 1. get vector_service + vector_service = ElasticSearchVectorFactory().init_vector(knowledge=db_knowledge) + + # 2. sync data + match db_knowledge.type: + case "Web": # Crawl webpages in batches through a web crawler + entry_url = db_knowledge.parser_config.get("entry_url", "") + max_pages = db_knowledge.parser_config.get("max_pages", 20) + delay_seconds = db_knowledge.parser_config.get("delay_seconds", 1.0) + timeout_seconds = db_knowledge.parser_config.get("timeout_seconds", 10) + user_agent = db_knowledge.parser_config.get("user_agent", "KnowledgeBaseCrawler/1.0") + # Create crawler + crawler = WebCrawler( + entry_url=entry_url, + max_pages=max_pages, + delay_seconds=delay_seconds, + timeout_seconds=timeout_seconds, + user_agent=user_agent + ) + try: + # 初始化存储已爬取 URLs 的集合 + file_urls = set() + # crawl entry_url by yield + for crawled_document in crawler.crawl(): + file_urls.add(crawled_document.url) + db_file = db.query(File).filter(File.kb_id == db_knowledge.id, + File.file_url == crawled_document.url).first() + if db_file: + if db_file.file_size == crawled_document.content_length: # same + continue + else: # --update + if crawled_document.content_length: + # 1. update file + db_file.file_name = f"{crawled_document.title}.txt" + db_file.file_ext=".txt" + db_file.file_size=crawled_document.content_length + db.commit() + db.refresh(db_file) + # Construct a save path:/files/{kb_id}/{parent_id}/{file.id}{file_extension} + save_dir = os.path.join(settings.FILE_PATH, str(db_knowledge.id), str(db_knowledge.parent_id)) + Path(save_dir).mkdir(parents=True, exist_ok=True) # Ensure that the directory exists + save_path = os.path.join(save_dir, f"{db_file.id}{db_file.file_ext}") + # update file + if os.path.exists(save_path): + os.remove(save_path) # Delete a single file + content_bytes = crawled_document.content.encode('utf-8') + with open(save_path, "wb") as f: + f.write(content_bytes) + # 2. update a document + db_document = db.query(Document).filter(Document.kb_id == db_knowledge.id, + Document.file_id == db_file.id).first() + if db_document: + db_document.file_name = db_file.file_name + db_document.file_ext = db_file.file_ext + db_document.file_size = db_file.file_size + db_document.updated_at = datetime.now() + db.commit() + db.refresh(db_document) + # 3. Document parsing, vectorization, and storage + parse_document(file_path=save_path, document_id=db_document.id) + else: # --add + if crawled_document.content_length: + # 1. upload file + upload_file = file_schema.FileCreate( + kb_id=db_knowledge.id, + created_by=db_knowledge.created_by, + parent_id=db_knowledge.id, + file_name=f"{crawled_document.title}.txt", + file_ext=".txt", + file_size=crawled_document.content_length, + file_url=crawled_document.url, + ) + db_file = File(**upload_file.model_dump()) + db.add(db_file) + db.commit() + # Construct a save path:/files/{kb_id}/{parent_id}/{file.id}{file_extension} + save_dir = os.path.join(settings.FILE_PATH, str(db_knowledge.id), str(db_knowledge.id)) + Path(save_dir).mkdir(parents=True, exist_ok=True) # Ensure that the directory exists + save_path = os.path.join(save_dir, f"{db_file.id}{db_file.file_ext}") + # Save file + content_bytes = crawled_document.content.encode('utf-8') + with open(save_path, "wb") as f: + f.write(content_bytes) + # 2. Create a document + create_document_data = document_schema.DocumentCreate( + kb_id=db_knowledge.id, + created_by=db_knowledge.created_by, + file_id=db_file.id, + file_name=db_file.file_name, + file_ext=db_file.file_ext, + file_size=db_file.file_size, + file_meta={}, + parser_id="naive", + parser_config={ + "layout_recognize": "DeepDOC", + "chunk_token_num": 128, + "delimiter": "\n", + "auto_keywords": 0, + "auto_questions": 0, + "html4excel": "false" + } + ) + db_document = Document(**create_document_data.model_dump()) + db.add(db_document) + db.commit() + # 3. Document parsing, vectorization, and storage + parse_document(file_path=save_path, document_id=db_document.id) + db_files = db.query(File).filter(File.kb_id == db_knowledge.id, File.file_url.notin_(file_urls)).all() + if db_files: # --delete + for db_file in db_files: + db_document = db.query(Document).filter(Document.kb_id == db_knowledge.id, + Document.file_id == db_file.id).first() + if db_document: + # 1. Delete vector index + vector_service.delete_by_metadata_field(key="document_id", value=str(db_document.id)) + # 2. Delete document + db.delete(db_document) + # 3. Delete file + file_path = Path( + settings.FILE_PATH, + str(db_file.kb_id), + str(db_file.parent_id), + f"{db_file.id}{db_file.file_ext}" + ) + if file_path.exists(): + file_path.unlink() # Delete a single file + db.delete(db_file) + # commit transaction + db.commit() + + except Exception as e: + print(f"\n\nError during crawl: {e}") + case "Third-party": # Integration of knowledge bases from three parties + yuque_user_id = db_knowledge.parser_config.get("yuque_user_id", "") + feishu_app_id = db_knowledge.parser_config.get("feishu_app_id", "") + if yuque_user_id: # Yuque Knowledge Base + yuque_token = db_knowledge.parser_config.get("yuque_token", "") + # Create yuqueAPIClient + api_client = YuqueAPIClient( + user_id=yuque_user_id, + token=yuque_token + ) + try: + # 初始化存储获取语雀 URLs 的集合 + file_urls = set() + + # Get all files from all repos + async def async_get_files(api_client: YuqueAPIClient): + async with api_client as client: + print("\n=== Fetching repositories ===") + repos = await client.get_user_repos() + print(f"Found {len(repos)} repositories:") + all_files = [] + for repo in repos: + # Get documents from repository + print(f"\n=== Fetching documents from '{repo.name}' ===") + docs = await client.get_repo_docs(repo.id) + all_files.extend(docs) + return all_files + + files = asyncio.run(async_get_files(api_client)) + for doc in files: + file_urls.add(doc.slug) + db_file = db.query(File).filter(File.kb_id == db_knowledge.id, + File.file_url == doc.slug).first() + if db_file: + if db_file.created_at == doc.updated_at: # same + continue + else: # --update + # 1. update file + # Construct a save path:/files/{kb_id}/{parent_id}/{file.id}{file_extension} + save_dir = os.path.join(settings.FILE_PATH, str(db_knowledge.id), str(db_knowledge.parent_id)) + Path(save_dir).mkdir(parents=True, exist_ok=True) # Ensure that the directory exists + + # download document from Feishu FileInfo + async def async_download_document(api_client: YuqueAPIClient, doc: YuqueDocInfo, save_dir: str): + async with api_client as client: + file_path = await client.download_document(doc, save_dir) + return file_path + + file_path = asyncio.run(async_download_document(api_client, doc, save_dir)) + + save_path = os.path.join(save_dir, f"{db_file.id}{db_file.file_ext}") + # update file + if os.path.exists(save_path): + os.remove(save_path) # Delete a single file + shutil.copyfile(file_path, save_path) + # update db_file + file_name = os.path.basename(file_path) + _, file_extension = os.path.splitext(file_name) + file_size = os.path.getsize(file_path) + db_file.file_name = file_name + db_file.file_ext = file_extension.lower() + db_file.file_size = file_size + db_file.created_at = doc.updated_at + db.commit() + db.refresh(db_file) + # 2. update a document + db_document = db.query(Document).filter(Document.kb_id == db_knowledge.id, + Document.file_id == db_file.id).first() + if db_document: + db_document.file_name = db_file.file_name + db_document.file_ext = db_file.file_ext + db_document.file_size = db_file.file_size + db_document.created_at = db_file.created_at + db_document.updated_at = datetime.now() + db.commit() + db.refresh(db_document) + # 3. Document parsing, vectorization, and storage + parse_document(file_path=save_path, document_id=db_document.id) + else: # --add + # 1. update file + # Construct a save path:/files/{kb_id}/{parent_id}/{file.id}{file_extension} + save_dir = os.path.join(settings.FILE_PATH, str(db_knowledge.id), str(db_knowledge.parent_id)) + Path(save_dir).mkdir(parents=True, exist_ok=True) # Ensure that the directory exists + + # download document from Feishu FileInfo + async def async_download_document(api_client: YuqueAPIClient, doc: YuqueDocInfo, save_dir: str): + async with api_client as client: + file_path = await client.download_document(doc, save_dir) + return file_path + + file_path = asyncio.run(async_download_document(api_client, doc, save_dir)) + # add db_file + file_name = os.path.basename(file_path) + _, file_extension = os.path.splitext(file_name) + file_size = os.path.getsize(file_path) + upload_file = file_schema.FileCreate( + kb_id=db_knowledge.id, + created_by=db_knowledge.created_by, + parent_id=db_knowledge.id, + file_name=file_name, + file_ext=file_extension.lower(), + file_size=file_size, + file_url=doc.slug, + created_at=doc.updated_at + ) + db_file = File(**upload_file.model_dump()) + db.add(db_file) + db.commit() + # Save file + save_path = os.path.join(save_dir, f"{db_file.id}{db_file.file_ext}") + # update file + if os.path.exists(save_path): + os.remove(save_path) # Delete a single file + shutil.copyfile(file_path, save_path) + # 2. Create a document + create_document_data = document_schema.DocumentCreate( + kb_id=db_knowledge.id, + created_by=db_knowledge.created_by, + file_id=db_file.id, + file_name=db_file.file_name, + file_ext=db_file.file_ext, + file_size=db_file.file_size, + file_meta={}, + parser_id="naive", + parser_config={ + "layout_recognize": "DeepDOC", + "chunk_token_num": 128, + "delimiter": "\n", + "auto_keywords": 0, + "auto_questions": 0, + "html4excel": "false" + } + ) + db_document = Document(**create_document_data.model_dump()) + db.add(db_document) + db.commit() + # 3. Document parsing, vectorization, and storage + parse_document(file_path=save_path, document_id=db_document.id) + db_files = db.query(File).filter(File.kb_id == db_knowledge.id, + File.file_url.notin_(file_urls)).all() + if db_files: # --delete + for db_file in db_files: + db_document = db.query(Document).filter(Document.kb_id == db_knowledge.id, + Document.file_id == db_file.id).first() + if db_document: + # 1. Delete vector index + vector_service.delete_by_metadata_field(key="document_id", + value=str(db_document.id)) + # 2. Delete document + db.delete(db_document) + # 3. Delete file + file_path = Path( + settings.FILE_PATH, + str(db_file.kb_id), + str(db_file.parent_id), + f"{db_file.id}{db_file.file_ext}" + ) + if file_path.exists(): + file_path.unlink() # Delete a single file + db.delete(db_file) + # commit transaction + db.commit() + + except Exception as e: + print(f"\n\nError during fetch feishu: {e}") + if feishu_app_id: # Feishu Knowledge Base + feishu_app_secret = db_knowledge.parser_config.get("feishu_app_secret", "") + feishu_folder_token = db_knowledge.parser_config.get("feishu_folder_token", "") + # Create feishuAPIClient + api_client = FeishuAPIClient( + app_id=feishu_app_id, + app_secret=feishu_app_secret + ) + try: + # 初始化存储获取飞书 URLs 的集合 + file_urls = set() + # Get all files from folder + async def async_get_files(api_client: FeishuAPIClient, feishu_folder_token: str): + async with api_client as client: + files = await client.list_all_folder_files(feishu_folder_token, recursive=True) + return files + files = asyncio.run(async_get_files(api_client, feishu_folder_token)) + # Filter out folders, only sync documents + documents = [f for f in files if f.type in ["doc", "docx", "sheet", "bitable", "file"]] + for doc in documents: + file_urls.add(doc.url) + db_file = db.query(File).filter(File.kb_id == db_knowledge.id, + File.file_url == doc.url).first() + if db_file: + if db_file.created_at == doc.modified_time: # same + continue + else: # --update + # 1. update file + # Construct a save path:/files/{kb_id}/{parent_id}/{file.id}{file_extension} + save_dir = os.path.join(settings.FILE_PATH, str(db_knowledge.id), + str(db_knowledge.parent_id)) + Path(save_dir).mkdir(parents=True, exist_ok=True) # Ensure that the directory exists + # download document from Feishu FileInfo + async def async_download_document(api_client: FeishuAPIClient, doc: FileInfo, save_dir: str): + async with api_client as client: + file_path = await client.download_document(document=doc, save_dir=save_dir) + return file_path + file_path = asyncio.run(async_download_document(api_client, doc, save_dir)) + + save_path = os.path.join(save_dir, f"{db_file.id}{db_file.file_ext}") + # update file + if os.path.exists(save_path): + os.remove(save_path) # Delete a single file + shutil.copyfile(file_path, save_path) + # update db_file + file_name = os.path.basename(file_path) + _, file_extension = os.path.splitext(file_name) + file_size = os.path.getsize(file_path) + db_file.file_name = file_name + db_file.file_ext = file_extension.lower() + db_file.file_size = file_size + db_file.created_at = doc.modified_time + db.commit() + db.refresh(db_file) + # 2. update a document + db_document = db.query(Document).filter(Document.kb_id == db_knowledge.id, + Document.file_id == db_file.id).first() + if db_document: + db_document.file_name = db_file.file_name + db_document.file_ext = db_file.file_ext + db_document.file_size = db_file.file_size + db_document.created_at = db_file.created_at + db_document.updated_at = datetime.now() + db.commit() + db.refresh(db_document) + # 3. Document parsing, vectorization, and storage + parse_document(file_path=save_path, document_id=db_document.id) + else: # --add + # 1. update file + # Construct a save path:/files/{kb_id}/{parent_id}/{file.id}{file_extension} + save_dir = os.path.join(settings.FILE_PATH, str(db_knowledge.id), + str(db_knowledge.parent_id)) + Path(save_dir).mkdir(parents=True, exist_ok=True) # Ensure that the directory exists + # download document from Feishu FileInfo + async def async_download_document(api_client: FeishuAPIClient, doc: FileInfo, save_dir: str): + async with api_client as client: + file_path = await client.download_document(document=doc, save_dir=save_dir) + return file_path + file_path = asyncio.run(async_download_document(api_client, doc, save_dir)) + # add db_file + file_name = os.path.basename(file_path) + _, file_extension = os.path.splitext(file_name) + file_size = os.path.getsize(file_path) + upload_file = file_schema.FileCreate( + kb_id=db_knowledge.id, + created_by=db_knowledge.created_by, + parent_id=db_knowledge.id, + file_name=file_name, + file_ext=file_extension.lower(), + file_size=file_size, + file_url=doc.url, + created_at = doc.modified_time + ) + db_file = File(**upload_file.model_dump()) + db.add(db_file) + db.commit() + # Save file + save_path = os.path.join(save_dir, f"{db_file.id}{db_file.file_ext}") + # update file + if os.path.exists(save_path): + os.remove(save_path) # Delete a single file + shutil.copyfile(file_path, save_path) + # 2. Create a document + create_document_data = document_schema.DocumentCreate( + kb_id=db_knowledge.id, + created_by=db_knowledge.created_by, + file_id=db_file.id, + file_name=db_file.file_name, + file_ext=db_file.file_ext, + file_size=db_file.file_size, + file_meta={}, + parser_id="naive", + parser_config={ + "layout_recognize": "DeepDOC", + "chunk_token_num": 128, + "delimiter": "\n", + "auto_keywords": 0, + "auto_questions": 0, + "html4excel": "false" + } + ) + db_document = Document(**create_document_data.model_dump()) + db.add(db_document) + db.commit() + # 3. Document parsing, vectorization, and storage + parse_document(file_path=save_path, document_id=db_document.id) + db_files = db.query(File).filter(File.kb_id == db_knowledge.id, + File.file_url.notin_(file_urls)).all() + if db_files: # --delete + for db_file in db_files: + db_document = db.query(Document).filter(Document.kb_id == db_knowledge.id, + Document.file_id == db_file.id).first() + if db_document: + # 1. Delete vector index + vector_service.delete_by_metadata_field(key="document_id", + value=str(db_document.id)) + # 2. Delete document + db.delete(db_document) + # 3. Delete file + file_path = Path( + settings.FILE_PATH, + str(db_file.kb_id), + str(db_file.parent_id), + f"{db_file.id}{db_file.file_ext}" + ) + if file_path.exists(): + file_path.unlink() # Delete a single file + db.delete(db_file) + # commit transaction + db.commit() + + except Exception as e: + print(f"\n\nError during fetch feishu: {e}") + case _: # General + print(f"General: No synchronization needed\n") + + + result = f"sync knowledge '{db_knowledge.name}' processed successfully." + return result + except Exception as e: + if 'db_knowledge' in locals(): + print(f"Failed to sync knowledge:{str(e)}\n") + result = f"sync knowledge '{db_knowledge.name}' failed." + return result + finally: + db.close() + + @celery_app.task(name="app.core.memory.agent.read_message", bind=True) def read_message_task(self, end_user_id: str, message: str, history: List[Dict[str, Any]], search_switch: str, config_id: str, storage_type:str, user_rag_memory_id:str) -> Dict[str, Any]: diff --git a/api/pyproject.toml b/api/pyproject.toml index 6d23a3b9..66b1a295 100644 --- a/api/pyproject.toml +++ b/api/pyproject.toml @@ -141,6 +141,8 @@ dependencies = [ "flower>=2.0.1", "aiofiles>=23.0.0", "owlready2>=0.46", + "lxml>=4.9.0", + "httpx>=0.28.0", ] [tool.pytest.ini_options] diff --git a/api/requirements.txt b/api/requirements.txt index 6cdae2d1..144c0db2 100644 --- a/api/requirements.txt +++ b/api/requirements.txt @@ -134,3 +134,5 @@ xlrd==2.0.2 oss2>=2.18.0 boto3>=1.28.0 aiofiles>=23.0.0 +lxml>=4.9.0 +httpx>=0.28.0