# lab2-2024 / app.py
import os
import gradio as gr
import torch
from transformers import AutoTokenizer, AutoModel
from safetensors.torch import load_file

# Load the Hugging Face API token from an environment variable
token = os.getenv("HUGGINGFACE_API_TOKEN")
if not token:
    raise ValueError(
        "HUGGINGFACE_API_TOKEN is not set. Please add it in the Secrets section of your Space."
    )

# Configure device
device = "cuda" if torch.cuda.is_available() else "cpu"
# Load the tokenizer and model using the token
model_repo = "Grandediw/lora_model"
tokenizer = AutoTokenizer.from_pretrained(model_repo, token=token)
base_model = AutoModel.from_pretrained(model_repo, token=token)

# Load the LoRA adapter weights
lora_weights_path = "adapter_model.safetensors"
lora_weights = load_file(lora_weights_path)

# Move the model to the target device before patching, so the adapter
# tensors and the parameters end up on the same device
base_model = base_model.to(device)
base_model.eval()

# Apply LoRA weights to the base model: only parameters whose names
# exactly match a key in the adapter file are updated
for name, param in base_model.named_parameters():
    if name in lora_weights:
        param.data += lora_weights[name].to(param.device, dtype=param.dtype)
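
# Alternative (a sketch, not what this script currently does): the manual
# addition above only works when adapter keys line up with base-parameter
# names, and it adds raw A/B matrices instead of the composed low-rank
# update. Assuming the repo ships a standard adapter config and the `peft`
# package is installed, the usual route would be:
#
#   from peft import PeftModel
#   base_model = PeftModel.from_pretrained(base_model, model_repo, token=token)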

# Define the inference function
@torch.no_grad()
def infer(prompt):
    inputs = tokenizer(prompt, return_tensors="pt").to(device)
    outputs = base_model(**inputs)
    # Placeholder: AutoModel has no task head, so return the mean-pooled
    # hidden state; adapt this to your model's actual task
    return outputs.last_hidden_state.mean(dim=1).cpu().numpy()
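
# For text generation (a sketch, assuming the repo actually hosts a causal
# LM that AutoModelForCausalLM can load), infer would instead look like:
#
#   from transformers import AutoModelForCausalLM
#   lm = AutoModelForCausalLM.from_pretrained(model_repo, token=token).to(device)
#
#   def infer_generate(prompt):
#       inputs = tokenizer(prompt, return_tensors="pt").to(device)
#       output_ids = lm.generate(**inputs, max_new_tokens=128)
#       return tokenizer.decode(output_ids[0], skip_special_tokens=True)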

# Gradio interface
with gr.Blocks() as demo:
    gr.Markdown("## LoRA Model Inference")
    with gr.Row():
        prompt = gr.Textbox(label="Prompt", placeholder="Enter your prompt here...")
        generate_button = gr.Button("Generate")
    output = gr.Textbox(label="Output")
    generate_button.click(fn=infer, inputs=[prompt], outputs=[output])

if __name__ == "__main__":
    demo.launch()