import os
import io
import torch
import uvicorn
import requests
from PIL import Image
from fastapi import FastAPI, HTTPException
from pydantic import BaseModel
from typing import Optional, Dict, List
import fitz
from concurrent.futures import ThreadPoolExecutor
import asyncio

# Set environment variables (must happen before vLLM is imported)
if torch.version.cuda == '11.8':
    os.environ["TRITON_PTXAS_PATH"] = "/usr/local/cuda-11.8/bin/ptxas"
os.environ['VLLM_USE_V1'] = '0'
os.environ["CUDA_VISIBLE_DEVICES"] = '0'

from config import MODEL_PATH, CROP_MODE, MAX_CONCURRENCY, NUM_WORKERS
from vllm import LLM, SamplingParams
from vllm.model_executor.models.registry import ModelRegistry
from deepseek_ocr import DeepseekOCRForCausalLM
from process.ngram_norepeat import NoRepeatNGramLogitsProcessor
from process.image_process import DeepseekOCRProcessor

# Register the custom architecture with vLLM
ModelRegistry.register_model("DeepseekOCRForCausalLM", DeepseekOCRForCausalLM)

# Initialize model
print("Loading model...")
llm = LLM(
    model=MODEL_PATH,
    hf_overrides={"architectures": ["DeepseekOCRForCausalLM"]},
    block_size=256,                          # Memory block size for the KV cache
    enforce_eager=False,                     # Allow CUDA graph capture (eager mode off) for better performance
    trust_remote_code=True,                  # Allow execution of code shipped with the model repository
    max_model_len=8192,                      # Maximum sequence length the model can handle
    swap_space=0,                            # No KV-cache swapping to CPU; keep everything on GPU
    max_num_seqs=max(MAX_CONCURRENCY, 100),  # Maximum number of sequences processed concurrently
    tensor_parallel_size=1,                  # Number of GPUs for tensor parallelism (1 = single GPU)
    gpu_memory_utilization=0.9,              # Use up to 90% of GPU memory for model execution
    disable_mm_preprocessor_cache=True       # Disable the multimodal preprocessor cache to avoid stale entries
)

# Configure sampling parameters.
# NoRepeatNGramLogitsProcessor prevents repetition in generated text by tracking n-gram patterns.
logits_processors = [NoRepeatNGramLogitsProcessor(
    ngram_size=20, window_size=50, whitelist_token_ids={128821, 128822}
)]
sampling_params = SamplingParams(
    temperature=0.0,                      # Deterministic output (greedy decoding)
    max_tokens=8192,                      # Maximum number of tokens to generate
    logits_processors=logits_processors,  # Suppress repetitive n-grams
    skip_special_tokens=False,            # Keep special tokens in the output
    include_stop_str_in_output=True,      # Include stop strings in the output
)

# Initialize FastAPI app
app = FastAPI(title="DeepSeek-OCR API", version="1.0.0")

class InputData(BaseModel):
    """
    Input model describing the documents to process.

    images: optional list of image URLs
    pdfs: optional list of PDF URLs

    At least one of the two fields must be provided in a request.
    """
    images: Optional[List[str]] = None
    pdfs: Optional[List[str]] = None

class RequestData(BaseModel):
    """Main request model: the input documents plus an optional prompt."""
    input: InputData
    prompt: str = '<image>\nFree OCR.'  # Default prompt
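# For reference, a request body for /ocr_batch has the following shape
# (the URLs below are placeholders, not real endpoints):
#
#   {
#     "input": {
#       "images": ["https://example.com/scan.png"],
#       "pdfs": ["https://example.com/report.pdf"]
#     },
#     "prompt": "<image>\nFree OCR."
#   }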
class ResponseData(BaseModel):
    """Response model: one OCR result string per input document."""
    output: List[str]

def download_file(url: str) -> bytes:
    """Download a file from a URL."""
    try:
        response = requests.get(url, timeout=30)
        response.raise_for_status()
        return response.content
    except Exception as e:
        raise HTTPException(status_code=400, detail=f"Failed to download file from URL: {str(e)}")

def is_pdf_file(content: bytes) -> bool:
    """Check whether the content is a PDF file."""
    return content.startswith(b'%PDF')

def load_image_from_bytes(image_bytes: bytes) -> Image.Image:
    """Load an image from raw bytes."""
    try:
        image = Image.open(io.BytesIO(image_bytes))
        return image.convert('RGB')
    except Exception as e:
        raise HTTPException(status_code=400, detail=f"Failed to load image: {str(e)}")

def pdf_to_images(pdf_bytes: bytes, dpi: int = 144) -> List[Image.Image]:
    """Render each PDF page to an RGB image at the given DPI."""
    try:
        images = []
        pdf_document = fitz.open(stream=pdf_bytes, filetype="pdf")
        zoom = dpi / 72.0  # PDF pages are defined at 72 DPI
        matrix = fitz.Matrix(zoom, zoom)
        for page_num in range(pdf_document.page_count):
            page = pdf_document[page_num]
            pixmap = page.get_pixmap(matrix=matrix, alpha=False)
            img_data = pixmap.tobytes("png")
            img = Image.open(io.BytesIO(img_data))
            images.append(img.convert('RGB'))
        pdf_document.close()
        return images
    except Exception as e:
        raise HTTPException(status_code=400, detail=f"Failed to convert PDF to images: {str(e)}")

def process_single_image_sync(image: Image.Image, prompt: str) -> Dict:
    """Build a vLLM input dict for one image (synchronous, CPU-bound work)."""
    try:
        return {
            "prompt": prompt,
            "multi_modal_data": {
                "image": DeepseekOCRProcessor().tokenize_with_images(
                    images=[image], bos=True, eos=True, cropping=CROP_MODE
                )
            },
        }
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Failed to process image: {str(e)}")

async def process_items_async(items_urls: List[str], is_pdf: bool, prompt: str) -> tuple[List[Dict], List[int]]:
    """
    Process a list of image or PDF URLs asynchronously.
    Downloads files concurrently, then processes images/PDF pages in a thread pool.
    Returns a tuple: (batch_inputs, num_results_per_input).
    """
    loop = asyncio.get_running_loop()

    # 1. Download all files concurrently
    download_tasks = [loop.run_in_executor(None, download_file, url) for url in items_urls]
    contents = await asyncio.gather(*download_tasks)

    # 2. Prepare processing arguments (detect PDF vs. image, count pages)
    processing_args = []
    num_results_per_input = []
    for url, content in zip(items_urls, contents):
        if is_pdf:
            if not is_pdf_file(content):
                raise HTTPException(status_code=400, detail=f"Provided file is not a PDF: {url}")
            images = pdf_to_images(content)
            num_results_per_input.append(len(images))
            # Each page is processed separately
            processing_args.extend([(img, prompt) for img in images])
        else:
            if is_pdf_file(content):
                # Handle the case where an image URL actually points to a PDF
                images = pdf_to_images(content)
                num_results_per_input.append(len(images))
                processing_args.extend([(img, prompt) for img in images])
            else:
                image = load_image_from_bytes(content)
                num_results_per_input.append(1)
                processing_args.append((image, prompt))
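    # Note: tokenize_with_images does CPU-bound image preprocessing and
    # tokenization, so it is offloaded to worker threads below to keep the
    # event loop responsive.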
    # 3. Process images/PDF pages in parallel using a ThreadPoolExecutor
    with ThreadPoolExecutor(max_workers=NUM_WORKERS) as executor:
        # Submit all processing tasks
        process_tasks = [
            loop.run_in_executor(executor, process_single_image_sync, img, p)
            for img, p in processing_args
        ]
        # Wait for all of them to complete
        processed_results = await asyncio.gather(*process_tasks)

    return processed_results, num_results_per_input

async def run_inference(batch_inputs: List[Dict]) -> List:
    """Run inference on a batch of prepared inputs."""
    if not batch_inputs:
        return []
    try:
        # Note: llm.generate is a blocking call, so concurrent HTTP requests are
        # serialized here while vLLM batches the sequences internally.
        outputs_list = llm.generate(
            batch_inputs,
            sampling_params=sampling_params
        )
        return outputs_list
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Failed to run inference: {str(e)}")

@app.post("/ocr_batch", response_model=ResponseData)
async def ocr_batch_inference(request: RequestData):
    """
    Main OCR batch-processing endpoint.

    Accepts a list of image URLs and/or PDF URLs, runs OCR on every image and
    PDF page, and returns one result string per input document.
    """
    print(f"Received request data: {request}")
    try:
        input_data = request.input
        prompt = request.prompt

        if not input_data.images and not input_data.pdfs:
            raise HTTPException(status_code=400, detail="Either 'images' or 'pdfs' (or both) must be provided as lists.")

        all_batch_inputs = []
        final_output_parts = []

        # Process images if provided
        if input_data.images:
            batch_inputs_images, counts_images = await process_items_async(input_data.images, is_pdf=False, prompt=prompt)
            all_batch_inputs.extend(batch_inputs_images)
            final_output_parts.append(counts_images)

        # Process PDFs if provided
        if input_data.pdfs:
            batch_inputs_pdfs, counts_pdfs = await process_items_async(input_data.pdfs, is_pdf=True, prompt=prompt)
            all_batch_inputs.extend(batch_inputs_pdfs)
            final_output_parts.append(counts_pdfs)

        if not all_batch_inputs:
            raise HTTPException(status_code=400, detail="No valid images or PDF pages were processed from the input URLs.")

        # Run inference on the combined batch
        outputs_list = await run_inference(all_batch_inputs)

        # Reconstruct the final output list from the per-input page counts
        final_outputs = []
        output_idx = 0
        all_counts = [count for sublist in final_output_parts for count in sublist]

        for count in all_counts:
            # Take 'count' outputs for this input document
            input_outputs = outputs_list[output_idx : output_idx + count]
            output_texts = []
            for output in input_outputs:
                content = output.outputs[0].text
                content = content.replace('<|end▁of▁sentence|>', '')
                output_texts.append(content)

            if count > 1:
                # Multi-page PDF (or an image URL that pointed to a PDF): join pages
                final_outputs.append("\n<--- Page Split --->\n".join(output_texts))
            else:
                # Single image or single-page PDF
                final_outputs.append(output_texts[0] if output_texts else "")

            output_idx += count  # Advance to the next input's outputs

        return ResponseData(output=final_outputs)

    except HTTPException:
        raise
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Internal server error: {str(e)}")

@app.get("/health")
async def health_check():
    """Health check endpoint."""
    return {"status": "healthy"}

@app.get("/")
async def root():
    """Root endpoint."""
    return {"message": "DeepSeek-OCR API is running (batch endpoint available at /ocr_batch)"}

if __name__ == "__main__":
    uvicorn.run(app, host="0.0.0.0", port=8000, workers=1)
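# Example client call (a minimal sketch; assumes the server is running locally
# on port 8000 and that the URL below is replaced with a real document):
#
#   import requests
#
#   resp = requests.post(
#       "http://localhost:8000/ocr_batch",
#       json={
#           "input": {"pdfs": ["https://example.com/invoice.pdf"]},
#           "prompt": "<image>\nFree OCR.",
#       },
#       timeout=300,
#   )
#   resp.raise_for_status()
#   for text in resp.json()["output"]:
#       print(text)
#
# Multi-page PDFs come back as a single string with pages joined by
# "\n<--- Page Split --->\n".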