# Spaces: Sleeping (page-header residue from the hosting site; not part of the program)
import base64
import io
import os

import gradio as gr
import torch
from huggingface_hub import login
from peft import get_peft_model, LoraConfig, TaskType
from PIL import Image
from transformers import AutoProcessor, AutoModelForCausalLM
# Step 1: Log in to Hugging Face
hf_token = os.getenv("HF_TOKEN")
if hf_token:
    login(token=hf_token)
else:
    # Without a token, login() falls back to an interactive prompt, which
    # cannot be answered in a headless deployment. Skip it and let the
    # model download below report any authorization problem instead.
    print("HF_TOKEN is not set; skipping Hugging Face login.")

# Step 2: Load the private model and processor
model_name = "anushettypsl/paligemma_vqav2"  # Replace with the actual model link
processor = AutoProcessor.from_pretrained(model_name)
base_model = AutoModelForCausalLM.from_pretrained(model_name)

# Step 3: Set up PEFT configuration (if needed)
lora_config = LoraConfig(
    r=16,  # Rank
    lora_alpha=32,  # Scaling factor
    lora_dropout=0.1,  # Dropout
    task_type=TaskType.CAUSAL_LM,  # Adjust according to your model's task
)

# Step 4: Get the PEFT model
peft_model = get_peft_model(base_model, lora_config)
# The model is only used for generation in this file, so switch every
# submodule (including the dropout configured above) to inference mode.
peft_model.eval()
# Step 5: Define the prediction function
def predict(image_base64: str, prompt: str) -> str:
    """Generate text for a base64-encoded image and a text prompt.

    Args:
        image_base64: Base64 text of an encoded image file. A data-URL
            header such as ``data:image/png;base64,`` in front of the
            payload is accepted and ignored.
        prompt: Text handed to the processor together with the image.

    Returns:
        The first generated sequence decoded to text, with special tokens
        skipped.

    Raises:
        binascii.Error: If the payload is not correctly padded base64.
        PIL.UnidentifiedImageError: If the decoded bytes are not an image
            format PIL can open.
    """
    # The base64 alphabet has no comma, so everything up to the last comma
    # can only be a data-URL header; keep just the payload after it.
    payload = image_base64.rpartition(",")[2]
    image_data = base64.b64decode(payload)
    image = Image.open(io.BytesIO(image_data))

    # Build the model inputs from the prompt and the image
    inputs = processor(text=prompt, images=image, return_tensors="pt")

    # Generate without tracking gradients; this is inference only
    with torch.no_grad():
        output = peft_model.generate(**inputs)

    # NOTE(review): the whole first sequence is decoded; if generate()
    # returns the prompt tokens as well, they appear in the result. Confirm
    # this is the intended output.
    return processor.decode(output[0], skip_special_tokens=True)
# Step 6: Create the Gradio interface
# Both inputs are plain text boxes: the first carries the base64 image
# string, the second the prompt. They are passed to predict() in this order.
image_box = gr.Textbox(
    label="Image (Base64)",
    placeholder="Enter base64 encoded image here...",
    lines=10,
)
prompt_box = gr.Textbox(
    label="Prompt",
    placeholder="Enter your prompt here...",
)

interface = gr.Interface(
    fn=predict,
    inputs=[image_box, prompt_box],
    outputs="text",  # predict() returns a single string
    title="Image and Prompt to Text Model",
    description="Enter a base64 encoded image and a prompt to generate a descriptive text.",
)

# Step 7: Launch the Gradio app
interface.launch()