"""FastAPI service that answers prompts about an image with Qwen2.5-VL-3B-Instruct.

The processor and model are loaded once at import time. ``GET /predict`` opens
an image from a path on the server, re-encodes it as a Base64 JPEG data URI,
and returns the model's text reply to the supplied prompt.
"""

import base64
import logging
from io import BytesIO

import torch
from fastapi import FastAPI, HTTPException, Query
from PIL import Image
from qwen_vl_utils import process_vision_info
from transformers import AutoProcessor, Qwen2_5_VLForConditionalGeneration

logger = logging.getLogger(__name__)

app = FastAPI()

checkpoint = "Qwen/Qwen2.5-VL-3B-Instruct"

# Pixel bounds handed to the processor: 256 and 1280 units of 28*28 pixels.
min_pixels = 256 * 28 * 28
max_pixels = 1280 * 28 * 28
processor = AutoProcessor.from_pretrained(
    checkpoint, min_pixels=min_pixels, max_pixels=max_pixels
)
# To try FlashAttention 2 (requires the flash-attn package), pass
# attn_implementation="flash_attention_2" to from_pretrained below.
model = Qwen2_5_VLForConditionalGeneration.from_pretrained(
    checkpoint,
    torch_dtype=torch.bfloat16,
    device_map="auto",
)


@app.get("/")
def read_root():
    """Return a static message confirming the API is up."""
    return {"message": "API is live. Use the /predict endpoint."}


def encode_image(image_path, max_size=(800, 800), quality=85):
    """Load a local image, shrink it, and return it as a Base64 JPEG string.

    Args:
        image_path (str): Path of the image file to open.
        max_size (tuple): Maximum (width, height) of the result. The image is
            shrunk in place with its aspect ratio preserved.
        quality (int): JPEG quality, 1-100. Higher means better quality but
            larger output.

    Returns:
        str | None: Base64-encoded JPEG bytes, or ``None`` if the image could
        not be opened or re-encoded. The error is logged with its traceback.
    """
    try:
        with Image.open(image_path) as img:
            # JPEG has no alpha channel, so flatten modes such as RGBA PNGs to RGB.
            img = img.convert("RGB")
            # thumbnail() resizes in place, keeps the aspect ratio, and never upscales.
            img.thumbnail(max_size, Image.LANCZOS)
            buffer = BytesIO()
            img.save(buffer, format="JPEG", quality=quality)
    except Exception:
        # Best-effort helper: report the failure and signal it with None.
        logger.exception("Error encoding image %s", image_path)
        return None
    return base64.b64encode(buffer.getvalue()).decode("utf-8")


@app.get("/predict")
def predict(
    image_url: str = Query(...),
    prompt: str = Query(...),
    max_new_tokens: int = Query(128, ge=1, le=1024),
):
    """Run the vision-language model on one image and one text prompt.

    Args:
        image_url: Path of the image file on the server (see the review note below).
        prompt: User question or instruction about the image.
        max_new_tokens: Upper bound on generated tokens. The default of 128
            matches the original fixed value. The cap of 1024 stops a client
            from requesting unbounded generation.

    Returns:
        dict: ``{"response": <decoded model reply>}``.

    Raises:
        HTTPException: 400 if the image cannot be loaded.
    """
    # NOTE(review): despite its name, image_url is opened as a filesystem path
    # on the server. Any client can therefore make the service read any image
    # file the process can access. Restrict it to an allowed directory (or
    # fetch real URLs) before exposing this API publicly.
    image = encode_image(image_url)
    if image is None:
        # Fail fast with a client error. Otherwise "data:image;base64,None"
        # would be passed to the vision pipeline and crash later with an opaque 500.
        raise HTTPException(
            status_code=400, detail=f"Could not load image: {image_url}"
        )

    messages = [
        {
            "role": "system",
            "content": "You are a helpful assistant with vision abilities.",
        },
        {
            "role": "user",
            "content": [
                {"type": "image", "image": f"data:image;base64,{image}"},
                {"type": "text", "text": prompt},
            ],
        },
    ]

    text = processor.apply_chat_template(
        messages, tokenize=False, add_generation_prompt=True
    )
    image_inputs, video_inputs = process_vision_info(messages)
    inputs = processor(
        text=[text],
        images=image_inputs,
        videos=video_inputs,
        padding=True,
        return_tensors="pt",
    ).to(model.device)

    with torch.no_grad():
        generated_ids = model.generate(**inputs, max_new_tokens=max_new_tokens)

    # Each output sequence starts with the prompt tokens. Slice them off so
    # only the newly generated reply is decoded.
    generated_ids_trimmed = [
        out_ids[len(in_ids):]
        for in_ids, out_ids in zip(inputs.input_ids, generated_ids)
    ]
    output_texts = processor.batch_decode(
        generated_ids_trimmed,
        skip_special_tokens=True,
        clean_up_tokenization_spaces=False,
    )
    return {"response": output_texts[0]}