# Ss-mol / main.py
# FastAPI service wrapping the allenai/Molmo-7B-D-0924 vision-language model.
# (Header reconstructed from Hugging Face page chrome; revision 3cdb7c9.)
import io

import requests
import torch
from fastapi import FastAPI, HTTPException
from PIL import Image
from pydantic import BaseModel
from transformers import AutoModelForCausalLM, AutoProcessor
# Create the FastAPI application.
app = FastAPI()

# Pick the compute device up front so the model can land on it directly.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Load the Molmo-7B-D processor and model once at startup.
# trust_remote_code is required: Molmo ships custom modeling code.
MODEL_ID = "allenai/Molmo-7B-D-0924"
processor = AutoProcessor.from_pretrained(MODEL_ID, trust_remote_code=True)
model = AutoModelForCausalLM.from_pretrained(MODEL_ID, trust_remote_code=True).to(device)
# Request body structure
class GenerateRequest(BaseModel):
    """Request body for POST /generate/: an image URL plus a text prompt."""
    # URL of the image to download and feed to the model.
    image_url: str
    # Text prompt/question about the image.
    text_input: str
# API root endpoint
@app.get("/")
def root():
    """Health-check endpoint confirming the service is up."""
    status = {"message": "Molmo-7B-D API is up and running!"}
    return status
# Text generation endpoint
@app.post("/generate/")
def generate_text(request: GenerateRequest):
    """Generate text for an image + prompt pair with Molmo-7B-D.

    Downloads the image from ``request.image_url``, preprocesses it together
    with ``request.text_input``, runs generation, and returns only the newly
    generated text (the echoed prompt is stripped).

    Raises:
        HTTPException 400: the image could not be downloaded or decoded.
        HTTPException 500: model/processing failure.
    """
    # Fetch and decode the image. A missing timeout would hang a worker
    # forever on a dead host, and without raise_for_status() an HTML 404
    # page would be handed to PIL. Bad input is a client error (400).
    try:
        response = requests.get(request.image_url, timeout=30)
        response.raise_for_status()
        # BytesIO over .content avoids the stream=True/.raw pitfall where
        # gzip-encoded bodies reach PIL undecoded.
        image = Image.open(io.BytesIO(response.content)).convert("RGB")
    except (requests.RequestException, OSError) as e:
        raise HTTPException(status_code=400, detail=f"Could not load image: {e}")

    try:
        # Preprocess, then move every tensor output to the model's device.
        inputs = processor(images=[image], text=request.text_input, return_tensors="pt")
        inputs = {k: (v.to(device) if hasattr(v, "to") else v) for k, v in inputs.items()}

        # Pass ALL processor outputs — not just input_ids — so the image
        # tensors actually reach the model. inference_mode() skips gradient
        # bookkeeping we never need at serving time.
        with torch.inference_mode():
            output_ids = model.generate(**inputs, max_new_tokens=200)

        # Decode only the newly generated tokens, not the echoed prompt.
        prompt_len = inputs["input_ids"].shape[1]
        generated_text = processor.tokenizer.decode(
            output_ids[0][prompt_len:], skip_special_tokens=True
        )
        return {"generated_text": generated_text}
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e))