import base64
from io import BytesIO
from typing import List, Union

import torch
from fastapi import FastAPI, File, HTTPException, Query, UploadFile
from PIL import Image
from pydantic import BaseModel
from qwen_vl_utils import process_vision_info
from transformers import (
    AutoProcessor,
    Qwen2_5_VLForConditionalGeneration,
    Qwen2VLForConditionalGeneration,
)

app = FastAPI()
# Define request model
class PredictRequest(BaseModel):
    # Accept a single Base64 string or a list of them; predict() normalizes
    # either form to a list before building the message.
    image_base64: Union[str, List[str]]
    prompt: str
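# Example payload (illustrative):
# {"image_base64": "<base64-encoded JPEG>", "prompt": "Describe this image."}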
# checkpoint = "Qwen/Qwen2-VL-2B-Instruct"
# min_pixels = 256 * 28 * 28
# max_pixels = 1280 * 28 * 28
# processor = AutoProcessor.from_pretrained(
#     checkpoint, min_pixels=min_pixels, max_pixels=max_pixels
# )
# model = Qwen2VLForConditionalGeneration.from_pretrained(
#     checkpoint,
#     torch_dtype=torch.bfloat16,
#     device_map="auto",
#     # attn_implementation="flash_attention_2",
# )
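# min_pixels / max_pixels bound the visual-token budget per image: the
# processor resizes each image into this pixel range, and one visual token
# corresponds to a 28x28 pixel patch, so this allows roughly 256-1280
# visual tokens per image.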
checkpoint = "Qwen/Qwen2.5-VL-7B-Instruct"
min_pixels = 256 * 28 * 28
max_pixels = 1280 * 28 * 28
processor = AutoProcessor.from_pretrained(
    checkpoint, min_pixels=min_pixels, max_pixels=max_pixels
)
model = Qwen2_5_VLForConditionalGeneration.from_pretrained(
    checkpoint,
    torch_dtype=torch.bfloat16,
    device_map="auto",
    # attn_implementation="flash_attention_2",
)
@app.get("/")
def read_root():
    return {"message": "API is live. Use the /predict endpoint."}
# def encode_image(image_path, max_size=(800, 800), quality=85):
#     """
#     Converts an image from a local file path to a Base64-encoded string with optimized size.
#     Args:
#         image_path (str): The path to the image file.
#         max_size (tuple): The maximum width and height of the resized image.
#         quality (int): The compression quality (1-100, higher means better quality but bigger size).
#     Returns:
#         str: Base64-encoded representation of the optimized image.
#     """
#     try:
#         with Image.open(image_path) as img:
#             # Convert to RGB (avoid issues with PNG transparency)
#             img = img.convert("RGB")
#             # Resize while maintaining aspect ratio
#             img.thumbnail(max_size, Image.LANCZOS)
#             # Save to buffer with compression
#             buffer = BytesIO()
#             img.save(
#                 buffer, format="JPEG", quality=quality
#             )  # Save as JPEG to reduce size
#             return base64.b64encode(buffer.getvalue()).decode("utf-8")
#     except Exception as e:
#         print(f"❌ Error encoding image {image_path}: {e}")
#         return None
def encode_image(image_data: BytesIO, max_size=(800, 800), quality=85):
    """
    Converts an image from file data to a Base64-encoded string with optimized size.
    """
    try:
        with Image.open(image_data) as img:
            img = img.convert("RGB")  # avoid issues with PNG transparency
            img.thumbnail(max_size, Image.LANCZOS)  # resize, keeping aspect ratio
            buffer = BytesIO()
            img.save(buffer, format="JPEG", quality=quality)  # compress as JPEG
            return base64.b64encode(buffer.getvalue()).decode("utf-8")
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Error encoding image: {e}")
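# Example (illustrative sketch; the file name is an assumption): the same
# helper can encode a local file by wrapping its bytes in a BytesIO.
# with open("example.png", "rb") as f:
#     encoded = encode_image(BytesIO(f.read()))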
@app.post("/upload")  # route path is an assumption; no decorator was given
async def upload_and_encode_image(file: UploadFile = File(...)):
    """
    Endpoint to upload an image file and return its Base64-encoded representation.
    """
    try:
        image_data = BytesIO(await file.read())
        encoded_string = encode_image(image_data)
        return {"filename": file.filename, "encoded_image": encoded_string}
    except Exception as e:
        raise HTTPException(status_code=400, detail=f"Invalid file: {e}")
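# Example request (illustrative; host, port, and file name are assumptions):
#   curl -X POST http://localhost:8000/upload -F "file=@example.jpg"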
@app.post("/predict")
def predict(data: PredictRequest):
    """
    Generates a description for one or more images using the Qwen2.5-VL model.
    Args:
        data (PredictRequest): The request containing encoded images and a prompt.
    Returns:
        dict: The generated description of the image(s).
    """
    # Ensure image_base64 is a list (even if a single image is provided)
    image_list = (
        data.image_base64
        if isinstance(data.image_base64, list)
        else [data.image_base64]
    )
    # Create the input message structure with multiple images
    messages = [
        {
            "role": "user",
            "content": [
                {"type": "image", "image": f"data:image;base64,{image}"}
                for image in image_list
            ]
            + [{"type": "text", "text": data.prompt}],
        }
    ]
    # Prepare inputs for the model
    text = processor.apply_chat_template(
        messages, tokenize=False, add_generation_prompt=True
    )
    image_inputs, video_inputs = process_vision_info(messages)
    inputs = processor(
        text=[text],
        images=image_inputs,
        videos=video_inputs,
        padding=True,
        return_tensors="pt",
    ).to(model.device)
    # Generate the output without tracking gradients (inference only)
    with torch.no_grad():
        generated_ids = model.generate(**inputs, max_new_tokens=2056)
    # Trim the prompt tokens so only newly generated tokens are decoded
    generated_ids_trimmed = [
        out_ids[len(in_ids) :]
        for in_ids, out_ids in zip(inputs.input_ids, generated_ids)
    ]
    output_text = processor.batch_decode(
        generated_ids_trimmed,
        skip_special_tokens=True,
        clean_up_tokenization_spaces=False,
    )
    return {"response": output_text[0] if output_text else "No description generated."}
# @app.get("/predict")
# def predict(image_url: str = Query(...), prompt: str = Query(...)):
#     image = encode_image(image_url)
#     messages = [
#         {
#             "role": "system",
#             "content": "You are a helpful assistant with vision abilities.",
#         },
#         {
#             "role": "user",
#             "content": [
#                 {"type": "image", "image": f"data:image;base64,{image}"},
#                 {"type": "text", "text": prompt},
#             ],
#         },
#     ]
#     text = processor.apply_chat_template(
#         messages, tokenize=False, add_generation_prompt=True
#     )
#     image_inputs, video_inputs = process_vision_info(messages)
#     inputs = processor(
#         text=[text],
#         images=image_inputs,
#         videos=video_inputs,
#         padding=True,
#         return_tensors="pt",
#     ).to(model.device)
#     with torch.no_grad():
#         generated_ids = model.generate(**inputs, max_new_tokens=128)
#     generated_ids_trimmed = [
#         out_ids[len(in_ids) :]
#         for in_ids, out_ids in zip(inputs.input_ids, generated_ids)
#     ]
#     output_texts = processor.batch_decode(
#         generated_ids_trimmed,
#         skip_special_tokens=True,
#         clean_up_tokenization_spaces=False,
#     )
#     return {"response": output_texts[0]}