# Factory-POC / app.py
from fastapi import FastAPI, UploadFile, File, Form
from transformers import AutoProcessor, AutoModelForCausalLM
from PIL import Image
import torch
import io
import os
from typing import Union
# Patch to remove flash-attn dependency
from transformers.dynamic_module_utils import get_imports
def fixed_get_imports(filename: Union[str, os.PathLike]) -> list[str]:
    """Workaround: drop the flash_attn import from Florence-2's remote modeling code."""
    if not str(filename).endswith("/modeling_florence2.py"):
        return get_imports(filename)
    imports = get_imports(filename)
    if "flash_attn" in imports:
        imports.remove("flash_attn")
    return imports
device = "cuda" if torch.cuda.is_available() else "cpu"
# Apply the patch while loading the model so flash_attn is not required
from unittest.mock import patch

with patch("transformers.dynamic_module_utils.get_imports", fixed_get_imports):
    model = AutoModelForCausalLM.from_pretrained("numberPlate_model_2", trust_remote_code=True).to(device)
    processor = AutoProcessor.from_pretrained("numberPlate_model_2", trust_remote_code=True)
# Initialize FastAPI
app = FastAPI()
def process_image(image, task_token):
    # Build model inputs from the task prompt and the image
    inputs = processor(text=task_token, images=image, return_tensors="pt", padding=True).to(device)
    generated_ids = model.generate(
        input_ids=inputs["input_ids"],
        pixel_values=inputs["pixel_values"],
        max_new_tokens=256,
        num_beams=2,
        do_sample=False,
    )
    # Keep special tokens: post_process_generation needs them to parse the task output
    generated_text = processor.batch_decode(generated_ids, skip_special_tokens=False)[0]
    parsed_result = processor.post_process_generation(
        generated_text, task=task_token, image_size=(image.width, image.height)
    )
    return parsed_result
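
# Hypothetical direct usage for local testing (the file name is an assumption;
# "<OD>" is Florence-2's object-detection task prompt):
#   img = Image.open("sample_plate.jpg").convert("RGB")
#   print(process_image(img, "<OD>"))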
@app.post("/process-image/")
async def process_image_endpoint(file: UploadFile = File(...), task_token: str = Form(" ")):
    # task_token is expected to be a Florence-2 task prompt, e.g. "<OD>" or "<OCR>"
    image = Image.open(io.BytesIO(await file.read())).convert("RGB")
    result = process_image(image, task_token)
    return result
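
# Minimal local entry point, a sketch: assumes uvicorn is installed; the host
# and port (7860 is the usual Hugging Face Spaces port) are assumptions.
if __name__ == "__main__":
    import uvicorn

    uvicorn.run(app, host="0.0.0.0", port=7860)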