Spaces:
Sleeping
Sleeping
File size: 1,950 Bytes
177e69b 20ed9e7 177e69b ce47c87 20ed9e7 177e69b 20ed9e7 ce47c87 177e69b ce47c87 177e69b ce47c87 177e69b |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 |
import io

import torch
from fastapi import FastAPI, File, UploadFile, HTTPException
from fastapi.concurrency import run_in_threadpool
from fastapi.middleware.cors import CORSMiddleware
from PIL import Image, UnidentifiedImageError
from transformers import AutoProcessor, Blip2ForConditionalGeneration
# The ASGI application that serves the captioning endpoint.
app = FastAPI()

# Cross-origin policy: every origin, method and header is accepted, and
# credentialed requests are allowed.
# NOTE(review): a wildcard origin combined with allow_credentials=True means any
# website may issue credentialed requests to this service. The captioning
# endpoint in this file reads no cookies or auth headers, so credentials could
# likely be disabled — confirm with the front-end before tightening.
_cors_options = {
    "allow_origins": ["*"],
    "allow_credentials": True,
    "allow_methods": ["*"],
    "allow_headers": ["*"],
}
app.add_middleware(CORSMiddleware, **_cors_options)
# Load the model and processor once, at import time, so every request reuses them.
# Any failure aborts startup with a RuntimeError rather than serving a broken app.
try:
    # Choose the device first so the weights can be loaded in a dtype suited to it:
    # half precision on GPU (the checkpoint name indicates fp16 weights), full
    # precision on CPU, where fp16 kernels are slow or missing for many ops.
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = Blip2ForConditionalGeneration.from_pretrained(
        "ybelkada/blip2-opt-2.7b-fp16-sharded",
        torch_dtype=torch.float16 if device.type == "cuda" else torch.float32,
    )
    # Attach the PEFT adapter identified by 'blip-cpu-model'
    # (presumably a local directory of fine-tuned weights — confirm).
    model.load_adapter("blip-cpu-model")
    processor = AutoProcessor.from_pretrained("Salesforce/blip2-opt-2.7b")
    model.to(device)
except Exception as e:
    # Chain the original exception so its traceback survives in the startup logs.
    raise RuntimeError(f"Failed to load the model or processor: {e}") from e
def _caption_image(image):
    """Run the BLIP-2 model on a decoded PIL image and return its caption.

    This is blocking, compute-bound work, so the endpoint calls it in a worker
    thread. ``torch.no_grad`` is entered here rather than by the caller because
    PyTorch's grad mode is thread-local.
    """
    # Cast pixel values to the model's own dtype instead of a hard-coded float16,
    # so the inputs match however the weights were loaded (fp16 or fp32).
    inputs = processor(images=image, return_tensors="pt").to(device, model.dtype)
    with torch.no_grad():
        caption_ids = model.generate(**inputs, max_length=128)
    # Generated text can carry leading/trailing whitespace; return it trimmed.
    return processor.decode(caption_ids[0], skip_special_tokens=True).strip()


@app.post("/generate-caption/")
async def generate_caption(file: UploadFile = File(...)):
    """Generate a caption for an uploaded image.

    Args:
        file: The uploaded image file (multipart form field ``file``).

    Returns:
        A JSON-serialisable dict ``{"caption": <str>}``.

    Raises:
        HTTPException: 400 if the upload cannot be decoded as an image;
            500 for any other image-handling or caption-generation failure.
    """
    data = await file.read()
    try:
        image = Image.open(io.BytesIO(data))
        # Image.open only parses the header; force a full decode now so truncated
        # or corrupt files are reported as a client error here instead of failing
        # later inside the model pipeline.
        image.load()
    except (UnidentifiedImageError, OSError) as e:
        # UnidentifiedImageError (unknown format) is itself an OSError subclass;
        # OSError also covers decode failures such as truncated data.
        raise HTTPException(status_code=400, detail="Uploaded file is not a valid image.") from e
    except Exception as e:
        # Catch any other unexpected errors related to image processing
        raise HTTPException(status_code=500, detail=f"An unexpected error occurred while processing the image: {e}") from e

    try:
        # Offload inference so the event loop keeps serving other requests while
        # the model runs. Note: concurrent requests may now generate in parallel.
        caption = await run_in_threadpool(_caption_image, image)
    except Exception as e:
        # Catch any errors during the caption generation process
        raise HTTPException(status_code=500, detail=f"An error occurred while generating the caption: {e}") from e
    return {"caption": caption}