# Spaces:
# Sleeping
# Sleeping
# File size: 1,991 Bytes
# 3605bde 036580a 10271df |
# 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 |
import base64
import io
import os

import gradio as gr
import torch
from huggingface_hub import login
from peft import get_peft_model, LoraConfig, TaskType
from PIL import Image
from transformers import AutoProcessor, AutoModelForCausalLM
# Step 1: Log in to Hugging Face
# The token comes from the HF_TOKEN environment variable. login() is only
# called when a token is actually set: login(token=None) falls back to an
# interactive prompt, which cannot be answered in a headless deployment.
# Without a token, the loads below rely on whatever credentials are already
# available to huggingface_hub.
hf_token = os.getenv("HF_TOKEN")
if hf_token:
    login(token=hf_token)

# Step 2: Load the private model and processor
model_name = "anushettypsl/paligemma_vqav2"  # Replace with the actual model link
processor = AutoProcessor.from_pretrained(model_name)
# NOTE(review): the repo name suggests a PaliGemma checkpoint; confirm that
# AutoModelForCausalLM resolves it in the installed transformers version.
base_model = AutoModelForCausalLM.from_pretrained(model_name)

# Step 3: Set up PEFT configuration (if needed)
lora_config = LoraConfig(
    r=16,  # Rank
    lora_alpha=32,  # Scaling factor
    lora_dropout=0.1,  # Dropout
    task_type=TaskType.CAUSAL_LM,  # Adjust according to your model's task
)

# Step 4: Get the PEFT model
# NOTE(review): this wraps base_model with an adapter built from lora_config;
# nothing in this file loads trained adapter weights -- confirm that is the
# intent for an inference-only app.
peft_model = get_peft_model(base_model, lora_config)
# Step 5: Define the prediction function
def predict(image_base64: str, prompt: str) -> str:
    """Generate text for ``prompt`` about a base64-encoded image.

    Args:
        image_base64: Base64 text of an encoded image file. A leading
            ``data:<mime>;base64,`` data-URL header is accepted and dropped.
        prompt: Text handed to the processor together with the image.

    Returns:
        The first generated sequence, decoded with special tokens skipped.

    Raises:
        binascii.Error: If the base64 payload is incorrectly padded.
        PIL.UnidentifiedImageError: If the decoded bytes are not an image
            that PIL can open.
    """
    # Data URLs are a common way to copy an image as text. b64decode() in its
    # default non-validating mode would not reject the header; it would
    # silently fold the header's alphabet characters into the payload.
    if image_base64.startswith("data:"):
        image_base64 = image_base64.partition(",")[2]

    # Decode the base64 image
    image_data = base64.b64decode(image_base64)
    image = Image.open(io.BytesIO(image_data))
    # Palette, grayscale and alpha images open in other modes; normalise to
    # three-channel RGB (presumably what the image processor expects --
    # TODO confirm against the processor config).
    image = image.convert("RGB")

    # Process the image
    inputs = processor(text=prompt, images=image, return_tensors="pt")

    # Generate output using the model (no gradients needed for inference)
    with torch.no_grad():
        output = peft_model.generate(**inputs)

    # Decode the output to text
    generated_text = processor.decode(output[0], skip_special_tokens=True)
    return generated_text
# Step 6: Create the Gradio interface
# Two text inputs are passed positionally to predict(image_base64, prompt);
# the returned string is shown as plain text.
interface = gr.Interface(
    fn=predict,
    inputs=[
        # Base64 input for image
        gr.Textbox(
            label="Image (Base64)",
            placeholder="Enter base64 encoded image here...",
            lines=10,
        ),
        # Prompt input
        gr.Textbox(label="Prompt", placeholder="Enter your prompt here..."),
    ],
    outputs="text",  # Text output
    title="Image and Prompt to Text Model",
    description="Enter a base64 encoded image and a prompt to generate a descriptive text.",
)

# Step 7: Launch the Gradio app
interface.launch()