# Schach-Notation / caption.py
# Commit author: Chesscorner
# Last commit: 02530de (verified) - "Rename app.py to caption.py"
# File size: 1.33 kB
"""Gradio app that captions an uploaded image with ``Chesscorner/git-chess-v3``.

The model and processor are loaded once at import time. Every image uploaded
through the UI is passed to :func:`predict_step`, and the decoded text is shown
in the "Generated Caption" box.
"""
from typing import Optional

import gradio as gr
import torch
from PIL import Image
from transformers import AutoModelForCausalLM, AutoProcessor

# Hub repository providing both the model weights and the matching processor.
MODEL_ID = "Chesscorner/git-chess-v3"

model = AutoModelForCausalLM.from_pretrained(MODEL_ID)
processor = AutoProcessor.from_pretrained(MODEL_ID)

# Run on the GPU when one is available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)

# Half precision is only enabled on CUDA.
use_fp16 = device.type == "cuda"
if use_fp16:
    model.half()

# Keyword arguments forwarded to model.generate(); adjust num_beams if needed.
gen_kwargs = {'max_length': 100, 'num_beams': 2}


def predict_step(image: Optional[Image.Image]) -> str:
    """Generate a caption for a single image.

    Args:
        image: The uploaded image as a PIL image (the UI uses
            ``gr.Image(type='pil')``), or ``None`` if no image is present.

    Returns:
        The first decoded sequence with surrounding whitespace stripped, or an
        empty string when ``image`` is ``None``.
    """
    if image is None:
        return ""

    pixel_values = processor(images=image, return_tensors="pt").pixel_values
    # Match the model's dtype as well as its device. After model.half(), the
    # processor's float32 pixel values would otherwise not match the fp16
    # weights.
    pixel_values = pixel_values.to(device=device, dtype=model.dtype)

    # Inference only: skip autograd bookkeeping.
    with torch.no_grad():
        output_ids = model.generate(pixel_values=pixel_values, **gen_kwargs)

    preds = processor.batch_decode(output_ids, skip_special_tokens=True)
    return preds[0].strip()


# The UI is built after predict_step is defined so the handler can be wired
# directly. This module must not import itself: a `from caption import ...`
# here would hit a partially initialised module and raise ImportError.
with gr.Blocks() as demo:
    image = gr.Image(type='pil', label='Image')
    label = gr.Text(label='Generated Caption')
    # Caption each newly uploaded image.
    image.upload(
        predict_step,
        [image],
        [label],
    )

if __name__ == '__main__':
    demo.launch(share=True)