Spaces:
Running
Running
import asyncio
import logging
import os
import time
from base64 import b64encode
from io import BytesIO
from typing import Optional

import pytesseract
from fastapi import FastAPI, HTTPException, UploadFile, File, Form, Depends
from fastapi.responses import StreamingResponse, JSONResponse, HTMLResponse
from fastapi.staticfiles import StaticFiles
from fastapi.templating import Jinja2Templates
from huggingface_hub import InferenceClient
from PIL import Image, ImageEnhance
from pydantic import BaseModel, ConfigDict
# Windows only: uncomment to point pytesseract at a local Tesseract binary.
# pytesseract.pytesseract.tesseract_cmd = r"F:\Python-files\tesseract\tesseract.exe"

# ASGI application object.
app = FastAPI()

# Root logging at DEBUG level, plus this module's own logger.
logging.basicConfig(level=logging.DEBUG)
logger = logging.getLogger(__name__)

# Hugging Face model id used as the fallback default.
DEFAULT_MODEL = "meta-llama/Meta-Llama-3-8B-Instruct"

# Jinja2 template loader rooted at the "templates" directory.
templates = Jinja2Templates(directory="templates")
class TextRequest(BaseModel):
    """Request body for a text-only query."""

    # An empty protected_namespaces tuple allows a field named ``model_name``
    # without pydantic's "model_" namespace warning.
    model_config = ConfigDict(protected_namespaces=())

    # Prompt text supplied by the caller.
    query: str
    # Whether the caller asked for a streamed response.
    stream: bool = False
    # Optional model identifier; None when the caller does not specify one.
    model_name: Optional[str] = None
class ImageTextRequest(BaseModel):
    """Form fields for a query that is accompanied by an uploaded image."""

    # An empty protected_namespaces tuple allows a field named ``model_name``
    # without pydantic's "model_" namespace warning.
    model_config = ConfigDict(protected_namespaces=())

    # Prompt text supplied by the caller.
    query: str
    # Whether the caller asked for a streamed response.
    stream: bool = False
    # Optional model identifier; None when the caller does not specify one.
    model_name: Optional[str] = None

    # Must be a classmethod: it is referenced as ``ImageTextRequest.as_form``
    # in a ``Depends(...)`` and instantiates ``cls``. Without the decorator,
    # ``cls`` would be treated as an ordinary request parameter.
    @classmethod
    def as_form(
        cls,
        query: str = Form(...),
        stream: bool = Form(False),
        model_name: Optional[str] = Form(None),
        image: UploadFile = File(...),  # Make image required for i2t2t
    ):
        """Build the model from multipart form fields.

        Returns:
            A ``(request, image)`` tuple: the validated model instance and the
            uploaded file, which is passed through untouched.
        """
        request = cls(
            query=query,
            stream=stream,
            model_name=model_name,
        )
        return request, image
def get_client(model_name: Optional[str] = None):
    """Get inference client for specified model or default model.

    Args:
        model_name: Model identifier. ``None``, an empty string, or a
            whitespace-only string selects ``DEFAULT_MODEL``. Surrounding
            whitespace is trimmed.

    Returns:
        An ``InferenceClient`` bound to the resolved model.

    Raises:
        HTTPException: Status 400 if the client cannot be constructed.
    """
    # Resolve the model id before the try block so the error message below can
    # always reference it; previously it could be unbound in the handler.
    requested = (model_name or "").strip()
    model_path = requested or DEFAULT_MODEL
    try:
        return InferenceClient(model=model_path)
    except Exception as e:
        raise HTTPException(
            status_code=400,
            detail=f"Error initializing model {model_path}: {str(e)}"
        ) from e
def generate_text_response(query: str, model_name: Optional[str] = None):
    """Stream the model's answer to a text query, one token string at a time.

    Args:
        query: The user's question; embedded into a single user message.
        model_name: Optional model identifier passed to ``get_client``.

    Yields:
        Non-empty text fragments from the streamed chat completion. If
        anything fails, a single ``"Error generating response: ..."`` string
        is yielded instead of raising.
    """
    messages = [{
        "role": "user",
        "content": f"[SYSTEM] You are ASSISTANT who answer question asked by user in short and concise manner. [USER] {query}"
    }]
    try:
        client = get_client(model_name)
        for message in client.chat_completion(
            messages,
            max_tokens=2048,
            stream=True
        ):
            # Stream chunks can have no choices or a delta whose content is
            # None. Skip those so consumers that concatenate or encode the
            # yielded values never receive a non-string.
            if not message.choices:
                continue
            token = message.choices[0].delta.content
            if token:
                yield token
    except Exception as e:
        # Best-effort: report the failure in-band, since the stream may
        # already have started.
        yield f"Error generating response: {str(e)}"
def generate_image_text_response(query: str, image_data: str, model_name: Optional[str] = None):
    """Stream the model's answer to a query about an image.

    Args:
        query: The user's question about the image.
        image_data: Base64-encoded image bytes, embedded into a ``data:`` URL.
        model_name: Optional model identifier passed to ``get_client``.

    Yields:
        Non-empty text fragments from the streamed chat completion. If
        anything fails, the error is logged and a single
        ``"Error generating response: ..."`` string is yielded instead of
        raising.
    """
    messages = [
        {
            "role": "user",
            "content": [
                {"type": "text", "text": f"[SYSTEM] You are ASSISTANT who answer question asked by user in short and concise manner. [USER] {query}"},
                {"type": "image_url", "image_url": {"url": f"data:image/*;base64,{image_data}"}}
            ]
        }
    ]
    # Log a summary only: dumping the full message would write the whole
    # base64 image into the log on every request.
    logger.debug(
        "Sending image query to API (query=%r, base64 image length=%d)",
        query,
        len(image_data),
    )
    try:
        client = get_client(model_name)
        for message in client.chat_completion(messages, max_tokens=2048, stream=True):
            logger.debug("Received message chunk: %s", message)
            # Stream chunks can have no choices or a delta whose content is
            # None. Skip those so consumers never receive a non-string.
            if not message.choices:
                continue
            token = message.choices[0].delta.content
            if token:
                yield token
    except Exception as e:
        logger.error(f"Error in generate_image_text_response: {str(e)}")
        yield f"Error generating response: {str(e)}"
def preprocess_image(img):
    """Prepare an image for OCR.

    The image is converted to mode 'L', its contrast is enhanced by a factor
    of 2.0, and its sharpness by a factor of 1.5, in that order.

    Args:
        img: Image object exposing ``convert`` and accepted by
            ``ImageEnhance``.

    Returns:
        The converted and enhanced image.
    """
    grayscale = img.convert('L')
    contrasted = ImageEnhance.Contrast(grayscale).enhance(2.0)
    return ImageEnhance.Sharpness(contrasted).enhance(1.5)
# NOTE(review): no route decorator is visible in this excerpt; confirm how this
# handler is registered with the app.
async def root():
    """Return a static welcome payload."""
    welcome = {"message": "Welcome to FastAPI server!"}
    return welcome
# NOTE(review): no route decorator is visible in this excerpt; the log message
# below refers to this handler as the /t2t endpoint.
async def text_to_text(request: TextRequest):
    """Answer a text query, either streamed or as one JSON object.

    Args:
        request: Parsed request with ``query``, ``stream`` and ``model_name``.

    Returns:
        A ``StreamingResponse`` over the generated chunks when
        ``request.stream`` is true, otherwise ``{"response": <full text>}``.

    Raises:
        HTTPException: Status 500 if building the response fails.
    """
    try:
        chunks = generate_text_response(request.query, request.model_name)
        if request.stream:
            return StreamingResponse(chunks, media_type="text/event-stream")
        # Drop None/empty chunks so one stray chunk cannot raise a TypeError
        # during concatenation.
        return {"response": "".join(chunk for chunk in chunks if chunk)}
    except Exception as e:
        logger.error(f"Error in /t2t endpoint: {str(e)}")
        raise HTTPException(status_code=500, detail=str(e)) from e
# NOTE(review): no route decorator is visible in this excerpt; the log message
# below refers to this handler as the /i2t2t endpoint.
async def image_text_to_text(form_data: tuple[ImageTextRequest, UploadFile] = Depends(ImageTextRequest.as_form)):
    """Answer a query about an uploaded image, streamed or as one JSON object.

    The upload is decoded, converted to RGB if needed, re-encoded as PNG and
    base64-encoded before being passed to the generator.

    Args:
        form_data: ``(form, image)`` pair produced by
            ``ImageTextRequest.as_form``.

    Returns:
        A ``StreamingResponse`` when ``form.stream`` is true, otherwise
        ``{"response": <full text>}``.

    Raises:
        HTTPException: Status 422 if the upload cannot be processed as an
            image; status 500 for any other failure.
    """
    form, image = form_data
    try:
        # Process image
        contents = await image.read()
        try:
            logger.debug("Attempting to open image")
            img = Image.open(BytesIO(contents))
            if img.mode != 'RGB':
                img = img.convert('RGB')
            buffer = BytesIO()
            img.save(buffer, format="PNG")
            image_data = b64encode(buffer.getvalue()).decode('utf-8')
            logger.debug("Image processed and encoded to base64")
        except Exception as img_error:
            logger.error(f"Error processing image: {str(img_error)}")
            raise HTTPException(
                status_code=422,
                detail=f"Error processing image: {str(img_error)}"
            ) from img_error

        chunks = generate_image_text_response(form.query, image_data, form.model_name)
        if form.stream:
            return StreamingResponse(chunks, media_type="text/event-stream")
        # Drop None/empty chunks so one stray chunk cannot raise a TypeError
        # during concatenation.
        return {"response": "".join(chunk for chunk in chunks if chunk)}
    except HTTPException:
        # Let deliberate HTTP errors (the 422 above) through unchanged; the
        # generic handler below would otherwise turn them into a 500.
        raise
    except Exception as e:
        logger.error(f"Error in /i2t2t endpoint: {str(e)}")
        raise HTTPException(status_code=500, detail=str(e)) from e
# NOTE(review): no route decorator is visible in this excerpt; confirm how this
# handler is registered with the app.
async def ocr_endpoint(image: UploadFile = File(...)):
    """Extract text from an uploaded image with Tesseract.

    The upload is opened, run through ``preprocess_image`` and passed to
    ``pytesseract.image_to_string``. OCR is attempted up to three times with a
    one-second pause between attempts.

    Args:
        image: The uploaded image file.

    Returns:
        ``{"text": <extracted text>}``.

    Raises:
        HTTPException: Status 500 with an "Error extracting text" detail if
            every OCR attempt fails, or an "Error processing image" detail
            for any other failure.
    """
    try:
        # Read and process the image
        contents = await image.read()
        img = Image.open(BytesIO(contents))
        # Preprocess the image
        img = preprocess_image(img)
        # Perform OCR with timeout and retries
        max_retries = 3
        text = ""
        for attempt in range(max_retries):
            try:
                text = pytesseract.image_to_string(
                    img,
                    timeout=30,  # 30 second timeout
                    config='--oem 3 --psm 6'
                )
                break
            except Exception as e:
                if attempt == max_retries - 1:
                    raise HTTPException(
                        status_code=500,
                        detail=f"Error extracting text: {str(e)}"
                    ) from e
                # Non-blocking pause: time.sleep would stall the event loop
                # for every other request while this one waits to retry.
                await asyncio.sleep(1)
        return {"text": text}
    except HTTPException:
        # Keep the specific OCR failure detail instead of re-wrapping it in
        # the generic handler below.
        raise
    except Exception as e:
        raise HTTPException(
            status_code=500,
            detail=f"Error processing image: {str(e)}"
        ) from e
# NOTE(review): no route decorator is visible in this excerpt; confirm how this
# handler is registered with the app.
async def api_guide():
    """Serve a static HTML page documenting the /t2t, /i2t2t and /tes endpoints.

    The page shows one card per endpoint with cURL, Python, JavaScript and
    Node.js samples; inline JavaScript swaps the sample when a language tab is
    clicked and implements the copy buttons.

    Returns:
        An ``HTMLResponse`` containing the page.
    """
    # The page is a raw string on purpose. In a normal string the trailing
    # backslash of each cURL line in the <pre> blocks is a line continuation
    # and disappears. The "\\" at the end of lines in the JavaScript template
    # literals would likewise reach the browser as a single backslash before a
    # newline, which JavaScript also treats as a continuation. As a raw string
    # both render as a visible "\" followed by a line break.
    # NOTE(review): the Tailwind stylesheet URL was garbled ("[email protected]")
    # in the reviewed copy; restored as tailwindcss@2.2.19 - confirm version.
    html_content = r'''
<!DOCTYPE html>
<html lang="en">
<head>
<meta charset="UTF-8">
<meta name="viewport" content="width=device-width, initial-scale=1.0">
<title>API Documentation</title>
<link href="https://cdn.jsdelivr.net/npm/tailwindcss@2.2.19/dist/tailwind.min.css" rel="stylesheet">
<link rel="stylesheet" href="https://cdnjs.cloudflare.com/ajax/libs/prism/1.24.1/themes/prism-tomorrow.min.css">
<style>
.copy-button {
position: absolute;
top: 8px;
right: 8px;
padding: 4px 8px;
background: #2d3748;
border: 1px solid #4a5568;
border-radius: 4px;
color: #cbd5e0;
font-size: 12px;
cursor: pointer;
transition: all 0.2s;
}
.copy-button:hover {
background: #4a5568;
}
.code-block {
position: relative;
margin: 1rem 0;
}
.endpoint-card {
background: #1a202c;
border-radius: 8px;
margin-bottom: 2rem;
padding: 1.5rem;
}
.language-tab {
cursor: pointer;
padding: 0.5rem 1rem;
border-radius: 4px 4px 0 0;
}
.language-tab.active {
background: #2d3748;
color: #fff;
}
</style>
</head>
<body class="bg-gray-900 text-gray-100 min-h-screen p-8">
<div class="max-w-6xl mx-auto">
<h1 class="text-4xl font-bold mb-8">API Documentation</h1>
<!-- T2T Endpoint -->
<div class="endpoint-card">
<h2 class="text-2xl font-semibold mb-4">Text-to-Text Endpoint</h2>
<p class="mb-4 text-gray-400">Endpoint for general text queries</p>
<p class="mb-2 text-gray-300"><span class="font-mono bg-gray-800 px-2 py-1 rounded">POST /t2t</span></p>
<div class="code-block">
<div class="flex mb-2">
<div class="language-tab active" data-lang="curl">cURL</div>
<div class="language-tab" data-lang="python">Python</div>
<div class="language-tab" data-lang="javascript">JavaScript</div>
<div class="language-tab" data-lang="node">Node.js</div>
</div>
<pre><code class="language-bash">curl -X POST "http://localhost:8000/t2t" \
-H "Content-Type: application/json" \
-d '{"query": "What is FastAPI?", "stream": false}'</code></pre>
<button class="copy-button">Copy</button>
</div>
</div>
<!-- I2T2T Endpoint -->
<div class="endpoint-card">
<h2 class="text-2xl font-semibold mb-4">Image and Text to Text Endpoint</h2>
<p class="mb-4 text-gray-400">Endpoint for queries about images</p>
<p class="mb-2 text-gray-300"><span class="font-mono bg-gray-800 px-2 py-1 rounded">POST /i2t2t</span></p>
<div class="code-block">
<div class="flex mb-2">
<div class="language-tab active" data-lang="curl">cURL</div>
<div class="language-tab" data-lang="python">Python</div>
<div class="language-tab" data-lang="javascript">JavaScript</div>
<div class="language-tab" data-lang="node">Node.js</div>
</div>
<pre><code class="language-bash">curl -X POST "http://localhost:8000/i2t2t" \
-F "query=Describe this image" \
-F "stream=false" \
-F "image=@/path/to/your/image.jpg"</code></pre>
<button class="copy-button">Copy</button>
</div>
</div>
<!-- TES Endpoint -->
<div class="endpoint-card">
<h2 class="text-2xl font-semibold mb-4">OCR Endpoint</h2>
<p class="mb-4 text-gray-400">Extract text from images using OCR</p>
<p class="mb-2 text-gray-300"><span class="font-mono bg-gray-800 px-2 py-1 rounded">POST /tes</span></p>
<div class="code-block">
<div class="flex mb-2">
<div class="language-tab active" data-lang="curl">cURL</div>
<div class="language-tab" data-lang="python">Python</div>
<div class="language-tab" data-lang="javascript">JavaScript</div>
<div class="language-tab" data-lang="node">Node.js</div>
</div>
<pre><code class="language-bash">curl -X POST "http://localhost:8000/tes" \
-F "image=@/path/to/your/image.jpg"</code></pre>
<button class="copy-button">Copy</button>
</div>
</div>
</div>
<script src="https://cdnjs.cloudflare.com/ajax/libs/prism/1.24.1/prism.min.js"></script>
<script src="https://cdnjs.cloudflare.com/ajax/libs/prism/1.24.1/components/prism-python.min.js"></script>
<script src="https://cdnjs.cloudflare.com/ajax/libs/prism/1.24.1/components/prism-javascript.min.js"></script>
<script src="https://cdnjs.cloudflare.com/ajax/libs/prism/1.24.1/components/prism-bash.min.js"></script>
<script>
const codeExamples = {
't2t': {
'curl': `curl -X POST "http://localhost:8000/t2t" \\
-H "Content-Type: application/json" \\
-d '{"query": "What is FastAPI?", "stream": false}'`,
'python': `import requests
url = "http://localhost:8000/t2t"
payload = {
"query": "What is FastAPI?",
"stream": False
}
response = requests.post(url, json=payload)
print(response.json())`,
'javascript': `// Using fetch
fetch("http://localhost:8000/t2t", {
method: "POST",
headers: {
"Content-Type": "application/json",
},
body: JSON.stringify({
query: "What is FastAPI?",
stream: false
})
})
.then(response => response.json())
.then(data => console.log(data));`,
'node': `const axios = require('axios');
async function makeRequest() {
try {
const response = await axios.post('http://localhost:8000/t2t', {
query: "What is FastAPI?",
stream: false
});
console.log(response.data);
} catch (error) {
console.error(error);
}
}
makeRequest();`
},
'i2t2t': {
'curl': `curl -X POST "http://localhost:8000/i2t2t" \\
-F "query=Describe this image" \\
-F "stream=false" \\
-F "image=@/path/to/your/image.jpg"`,
'python': `import requests
url = "http://localhost:8000/i2t2t"
files = {
'image': ('image.jpg', open('path/to/image.jpg', 'rb')),
}
data = {
'query': 'Describe this image',
'stream': 'false'
}
response = requests.post(url, files=files, data=data)
print(response.json())`,
'javascript': `const formData = new FormData();
formData.append('image', imageFile);
formData.append('query', 'Describe this image');
formData.append('stream', 'false');
fetch("http://localhost:8000/i2t2t", {
method: "POST",
body: formData
})
.then(response => response.json())
.then(data => console.log(data));`,
'node': `const axios = require('axios');
const FormData = require('form-data');
const fs = require('fs');
async function makeRequest() {
try {
const formData = new FormData();
formData.append('image', fs.createReadStream('path/to/image.jpg'));
formData.append('query', 'Describe this image');
formData.append('stream', 'false');
const response = await axios.post('http://localhost:8000/i2t2t', formData, {
headers: formData.getHeaders()
});
console.log(response.data);
} catch (error) {
console.error(error);
}
}
makeRequest();`
},
'tes': {
'curl': `curl -X POST "http://localhost:8000/tes" \\
-F "image=@/path/to/your/image.jpg"`,
'python': `import requests
url = "http://localhost:8000/tes"
files = {
'image': ('image.jpg', open('path/to/image.jpg', 'rb'))
}
response = requests.post(url, files=files)
print(response.json())`,
'javascript': `const formData = new FormData();
formData.append('image', imageFile);
fetch("http://localhost:8000/tes", {
method: "POST",
body: formData
})
.then(response => response.json())
.then(data => console.log(data));`,
'node': `const axios = require('axios');
const FormData = require('form-data');
const fs = require('fs');
async function makeRequest() {
try {
const formData = new FormData();
formData.append('image', fs.createReadStream('path/to/image.jpg'));
const response = await axios.post('http://localhost:8000/tes', formData, {
headers: formData.getHeaders()
});
console.log(response.data);
} catch (error) {
console.error(error);
}
}
makeRequest();`
}
};
// Handle language tab switching
document.querySelectorAll('.language-tab').forEach(tab => {
tab.addEventListener('click', () => {
const lang = tab.dataset.lang;
const codeBlock = tab.closest('.endpoint-card');
const endpoint = codeBlock.querySelector('h2').textContent.toLowerCase().includes('ocr') ? 'tes' :
codeBlock.querySelector('h2').textContent.toLowerCase().includes('image') ? 'i2t2t' : 't2t';
// Update active tab
codeBlock.querySelectorAll('.language-tab').forEach(t => t.classList.remove('active'));
tab.classList.add('active');
// Update code content
const code = codeBlock.querySelector('code');
code.textContent = codeExamples[endpoint][lang];
code.className = `language-${lang === 'curl' ? 'bash' : lang}`;
Prism.highlightElement(code);
});
});
// Handle copy buttons
document.querySelectorAll('.copy-button').forEach(button => {
button.addEventListener('click', () => {
const code = button.previousElementSibling.textContent;
navigator.clipboard.writeText(code);
// Show feedback
const originalText = button.textContent;
button.textContent = 'Copied!';
setTimeout(() => {
button.textContent = originalText;
}, 2000);
});
});
</script>
</body>
</html>
'''
    return HTMLResponse(content=html_content)