sachin
committed on
Commit
·
75bbaa5
1
Parent(s):
badf26d
add-ocr
Browse files- src/server/main.py +97 -1
src/server/main.py
CHANGED
@@ -9,7 +9,7 @@ from fastapi import Depends, FastAPI, File, HTTPException, Query, Request, Uploa
|
|
9 |
from fastapi.middleware.cors import CORSMiddleware
|
10 |
from fastapi.responses import JSONResponse, RedirectResponse, StreamingResponse
|
11 |
from PIL import Image
|
12 |
-
from pydantic import BaseModel, field_validator
|
13 |
from pydantic_settings import BaseSettings
|
14 |
from slowapi import Limiter
|
15 |
from slowapi.util import get_remote_address
|
@@ -26,6 +26,10 @@ from starlette.responses import StreamingResponse
|
|
26 |
from logging_config import logger
|
27 |
from tts_config import SPEED, ResponseFormat, config as tts_config
|
28 |
import torchaudio
|
|
|
|
|
|
|
|
|
29 |
|
30 |
# Device setup
|
31 |
if torch.cuda.is_available():
|
@@ -296,6 +300,14 @@ class SynthesizeRequest(BaseModel):
|
|
296 |
# Request schema with a single required string field, `text`.
# NOTE(review): judging by the name, this is presumably the payload for a
# Kannada text-to-speech endpoint - confirm against the route that uses it.
class KannadaSynthesizeRequest(BaseModel):
    # Required; no default and no validation constraints are declared here.
    text: str
|
298 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
299 |
# TTS Functions
|
300 |
def load_audio_from_url(url: str):
|
301 |
response = requests.get(url)
|
@@ -762,6 +774,90 @@ async def visual_query(
|
|
762 |
logger.error(f"Error processing request: {str(e)}")
|
763 |
raise HTTPException(status_code=500, detail=f"An error occurred: {str(e)}")
|
764 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
765 |
@app.post("/v1/chat_v2", response_model=ChatResponse)
|
766 |
@limiter.limit(settings.chat_rate_limit)
|
767 |
async def chat_v2(
|
|
|
9 |
from fastapi.middleware.cors import CORSMiddleware
|
10 |
from fastapi.responses import JSONResponse, RedirectResponse, StreamingResponse
|
11 |
from PIL import Image
|
12 |
+
from pydantic import BaseModel, field_validator, Field
|
13 |
from pydantic_settings import BaseSettings
|
14 |
from slowapi import Limiter
|
15 |
from slowapi.util import get_remote_address
|
|
|
26 |
from logging_config import logger
|
27 |
from tts_config import SPEED, ResponseFormat, config as tts_config
|
28 |
import torchaudio
|
29 |
+
import base64
|
30 |
+
from io import BytesIO
|
31 |
+
from pypdf import PdfReader
|
32 |
+
from olmocr.data.renderpdf import render_pdf_to_base64png
|
33 |
|
34 |
# Device setup
|
35 |
if torch.cuda.is_available():
|
|
|
300 |
# Request schema with a single required string field, `text`.
# NOTE(review): judging by the name, this is presumably the payload for a
# Kannada text-to-speech endpoint - confirm against the route that uses it.
class KannadaSynthesizeRequest(BaseModel):
    # Required; no default and no validation constraints are declared here.
    text: str
|
302 |
|
303 |
+
class ExtractTextRequest(BaseModel):
    """Schema describing the page selector for PDF text extraction.

    The field metadata declared here (notably ``description``) is read via
    ``ExtractTextRequest.model_fields["page_number"]`` so that the endpoint
    parameter and this model stay documented with the same text.
    """

    # 1-based page index; `ge=1` rejects zero and negative values.
    page_number: int = Field(
        default=1,
        description="The page number to extract text from (1-based indexing). Must be a positive integer.",
        ge=1,
        # Pydantic v2 deprecates arbitrary extra keyword arguments such as
        # `example=`; `examples` (a list) is the supported way to attach
        # sample values to the generated JSON schema.
        examples=[1],
    )
|
310 |
+
|
311 |
# TTS Functions
|
312 |
def load_audio_from_url(url: str):
|
313 |
response = requests.get(url)
|
|
|
774 |
logger.error(f"Error processing request: {str(e)}")
|
775 |
raise HTTPException(status_code=500, detail=f"An error occurred: {str(e)}")
|
776 |
|
777 |
+
@app.post(
    "/v1/extract-text-visual-query/",
    response_model=dict,
    summary="Extract text from a PDF page using visual query",
    description=(
        "Extracts text from a specific page of a PDF file by rendering it as an image and processing it with the internal vision query model. "
        "The query 'describe the image' is used to generate a description of the page content."
    ),
    response_description="A JSON object containing the extracted text from the specified page."
)
async def extract_text_visual_query(
    file: UploadFile = File(..., description="The PDF file to process. Must be a valid PDF."),
    page_number: int = Body(
        default=1,
        embed=True,
        description=ExtractTextRequest.model_fields["page_number"].description,
        ge=1,
        example=1
    )
):
    """
    Extract text from a specific page of a PDF file using the internal vision query model.

    The upload is written to a temporary ``.pdf`` file, the requested page is
    rendered to a base64 PNG (longest side 1024 px), decoded into a PIL image
    and passed to ``llm_manager.vision_query`` with the fixed prompt
    ``"describe the image"``. The temporary file is always removed before the
    function returns or raises.

    Args:
        file (UploadFile): The PDF file to process. Only the filename
            extension is checked (case-insensitive ``.pdf``).
        page_number (int): The page number to extract text from (1-based indexing). Defaults to 1.

    Returns:
        JSONResponse: A dictionary containing:
            - page_content: The extracted text from the specified page via the vision query model.

    Raises:
        HTTPException: 400 if the upload does not have a ``.pdf`` filename;
            500 if rendering, image decoding, the vision query, or any other
            step fails. The status code raised by a step is preserved rather
            than being re-wrapped as a generic 500.

    Example:
        ```json
        {"page_content": "Here’s a summary of the page in one sentence:\\n\\nThe page displays..."}
        ```
    """
    # Validate file type. `filename` may be None for some multipart clients,
    # so fall back to an empty string instead of raising AttributeError.
    if not (file.filename or "").lower().endswith(".pdf"):
        raise HTTPException(status_code=400, detail="Only PDF files are supported.")

    temp_file_path = None
    try:
        # Save the uploaded PDF to a temporary file. The path is recorded
        # before writing so the `finally` clause can remove the file even if
        # reading or writing the upload fails part-way.
        with tempfile.NamedTemporaryFile(delete=False, suffix=".pdf") as temp_file:
            temp_file_path = temp_file.name
            temp_file.write(await file.read())

        # Render the specified page to an image
        try:
            image_base64 = render_pdf_to_base64png(
                temp_file_path, page_number, target_longest_image_dim=1024
            )
        except Exception as e:
            raise HTTPException(status_code=500, detail=f"Failed to render PDF page: {str(e)}") from e

        # Decode base64 image to PIL Image
        try:
            image_bytes = base64.b64decode(image_base64)
            image = Image.open(BytesIO(image_bytes))
        except Exception as e:
            raise HTTPException(status_code=500, detail=f"Failed to process image: {str(e)}") from e

        # Process image with vision query
        try:
            page_content = await llm_manager.vision_query(image, "describe the image")
        except Exception as e:
            raise HTTPException(status_code=500, detail=f"Vision query processing failed: {str(e)}") from e

        return JSONResponse(content={"page_content": page_content})

    except HTTPException:
        # HTTPException is itself an Exception; re-raise it untouched so the
        # intended status code and detail reach the client instead of being
        # wrapped by the generic handler below.
        raise
    except Exception as e:
        logger.error(f"Error processing request: {str(e)}")
        raise HTTPException(status_code=500, detail=f"An error occurred: {str(e)}") from e
    finally:
        # Clean up the temporary file on success and on every failure path.
        if temp_file_path is not None:
            try:
                os.remove(temp_file_path)
            except OSError as cleanup_error:
                # Best-effort cleanup: never mask the real result or error.
                logger.warning(f"Failed to remove temporary file {temp_file_path}: {cleanup_error}")
|
860 |
+
|
861 |
@app.post("/v1/chat_v2", response_model=ChatResponse)
|
862 |
@limiter.limit(settings.chat_rate_limit)
|
863 |
async def chat_v2(
|