# test_space/app.py
import os
import io
import base64
import torch
import gradio as gr
from PIL import Image
from transformers import AutoProcessor, PaliGemmaForConditionalGeneration
from peft import get_peft_model, LoraConfig, TaskType
from huggingface_hub import login
# Step 1: Log in to Hugging Face
hf_token = os.getenv("HF_TOKEN")
login(token=hf_token)
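# Note: on Hugging Face Spaces, HF_TOKEN is typically configured as a
# repository secret so that os.getenv("HF_TOKEN") resolves at runtime.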
# Step 2: Load the private model and processor
model_name = "anushettypsl/paligemma_vqav2"  # Replace with the actual model repo ID
processor = AutoProcessor.from_pretrained(model_name)
# PaliGemma is a vision-language model, so it is loaded with its dedicated
# conditional-generation class rather than AutoModelForCausalLM.
base_model = PaliGemmaForConditionalGeneration.from_pretrained(model_name)
# Step 3: Set up PEFT configuration (if needed)
# Note: a freshly initialised LoRA adapter is an identity mapping at inference
# time; this step only matters if trained adapter weights are loaded on top.
lora_config = LoraConfig(
    r=16,                          # Rank
    lora_alpha=32,                 # Scaling factor
    lora_dropout=0.1,              # Dropout
    # Assumed target modules: the attention projections commonly used for
    # Gemma-based models; adjust to match your checkpoint.
    target_modules=["q_proj", "k_proj", "v_proj", "o_proj"],
    task_type=TaskType.CAUSAL_LM,  # Adjust according to your model's task
)
# Step 4: Get the PEFT model
peft_model = get_peft_model(base_model, lora_config)
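# Optional sanity checks (assumed additions, not part of the original app):
peft_model.eval()  # inference only, so disable LoRA dropout
peft_model.print_trainable_parameters()  # report how many parameters the adapter adds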
# Step 5: Define the prediction function
def predict(image_base64, prompt):
    # Decode the base64 image
    image_data = base64.b64decode(image_base64)
    image = Image.open(io.BytesIO(image_data)).convert("RGB")
    # Process the image and prompt together
    inputs = processor(text=prompt, images=image, return_tensors="pt")
    # Generate output using the model
    with torch.no_grad():
        output = peft_model.generate(**inputs, max_new_tokens=100)  # cap the response length
    # Decode the output to text
    generated_text = processor.decode(output[0], skip_special_tokens=True)
    return generated_text
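# Illustrative helper (an assumed addition, not part of the original app):
# encode a local image file into the base64 string that predict() expects.
# The name `encode_image_file` is hypothetical.
def encode_image_file(path):
    with open(path, "rb") as f:
        return base64.b64encode(f.read()).decode("utf-8")

# Example: predict(encode_image_file("cat.jpg"), "What is in this image?")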
# Step 6: Create the Gradio interface
interface = gr.Interface(
    fn=predict,
    inputs=[
        gr.Textbox(label="Image (Base64)", placeholder="Enter base64 encoded image here...", lines=10),  # Base64 input for image
        gr.Textbox(label="Prompt", placeholder="Enter your prompt here..."),  # Prompt input
    ],
    outputs="text",  # Text output
    title="Image and Prompt to Text Model",
    description="Enter a base64 encoded image and a prompt to generate a descriptive text.",
)
# Step 7: Launch the Gradio app
interface.launch()
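# Sketch of a remote call against the running Space via gradio_client
# (the Space id here is hypothetical; a gr.Interface exposes its endpoint
# as "/predict" by default):
#   from gradio_client import Client
#   client = Client("anmoldograpsl/test_space")
#   result = client.predict(encode_image_file("cat.jpg"), "Describe the image.", api_name="/predict")
#   print(result)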