import os
import gradio as gr
import torch
from transformers import AutoTokenizer, AutoModel
from safetensors.torch import load_file
# Load the Hugging Face API token from environment variable
token = os.getenv("HUGGINGFACE_API_TOKEN")
if not token:
    raise ValueError("HUGGINGFACE_API_TOKEN is not set. Please add it in the Secrets section of your Space.")

# Configure device
device = "cuda" if torch.cuda.is_available() else "cpu"

# Load the tokenizer and model using the token
model_repo = "Grandediw/lora_model"
tokenizer = AutoTokenizer.from_pretrained(model_repo, token=token)
base_model = AutoModel.from_pretrained(model_repo, token=token)

# Move the model to the target device BEFORE merging the adapter weights.
# The original order moved the model only after the merge loop, so when a
# GPU was available the loop added CUDA tensors to CPU parameters — a
# device-mismatch RuntimeError.
base_model = base_model.to(device)

# Load the LoRA adapter weights
lora_weights_path = "adapter_model.safetensors"
lora_weights = load_file(lora_weights_path)

# Apply LoRA weights to the base model.
# NOTE(review): this naive merge assumes the safetensors file stores full
# per-parameter deltas keyed by the base model's parameter names. A raw LoRA
# checkpoint usually stores low-rank A/B matrices under different keys
# (e.g. "...lora_A.weight"), in which case nothing matches here and the model
# is silently left unchanged — verify the checkpoint format (or load it via
# peft.PeftModel instead).
with torch.no_grad():  # in-place weight surgery must not build an autograd graph
    for name, param in base_model.named_parameters():
        if name in lora_weights:
            # Target param.device (not the module-level `device`) so the add
            # is always on the same device as the parameter itself.
            param.data += lora_weights[name].to(param.device, dtype=param.dtype)
# Define the inference function
def infer(prompt):
    """Run the model on *prompt* and return a pooled embedding.

    Args:
        prompt: Free-form input text from the Gradio Textbox.

    Returns:
        A 2D numpy array of shape (batch, hidden_size): the mean of the
        final hidden states over the sequence dimension. Gradio stringifies
        it for the Textbox output component.
    """
    inputs = tokenizer(prompt, return_tensors="pt").to(device)
    # Inference only: no_grad skips autograd bookkeeping, which the original
    # left enabled — wasted memory and latency on every request.
    with torch.no_grad():
        outputs = base_model(**inputs)
    # Placeholder return, modify based on your specific model task
    return outputs.last_hidden_state.mean(dim=1).cpu().numpy()
# Assemble the Gradio front end: a prompt box plus a trigger button on one
# row, with the model output rendered below.
with gr.Blocks() as demo:
    gr.Markdown("## LoRA Model Inference")
    with gr.Row():
        prompt_box = gr.Textbox(label="Prompt", placeholder="Enter your prompt here...")
        run_button = gr.Button("Generate")
    result_box = gr.Textbox(label="Output")
    # Wire the button to the inference function defined above.
    run_button.click(fn=infer, inputs=[prompt_box], outputs=[result_box])

# Launch the app only when executed as a script, not on import.
if __name__ == "__main__":
    demo.launch()