import os
import gradio as gr
import torch
from transformers import AutoTokenizer, AutoModel
from safetensors.torch import load_file

# Load the Hugging Face API token from environment variable
token = os.getenv("HUGGINGFACE_API_TOKEN")
if not token:
    raise ValueError("HUGGINGFACE_API_TOKEN is not set. Please add it in the Secrets section of your Space.")

# Configure device: prefer CUDA when available, otherwise CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"

# Load the tokenizer and model using the token
model_repo = "Grandediw/lora_model"
tokenizer = AutoTokenizer.from_pretrained(model_repo, token=token)
base_model = AutoModel.from_pretrained(model_repo, token=token)

# Load the LoRA adapter weights from a local safetensors file.
lora_weights_path = "adapter_model.safetensors"
lora_weights = load_file(lora_weights_path)

# Merge the adapter deltas into the base model's parameters.
# BUG FIX: the delta tensor must live on the same device as the parameter
# it is added to. The original code moved it to `device` while the model
# was still on CPU, which raises a device-mismatch error on CUDA machines.
# NOTE(review): this merge assumes the safetensors file stores full-size
# per-parameter deltas keyed by the base model's parameter names. A
# standard PEFT LoRA adapter instead stores low-rank `lora_A`/`lora_B`
# factors, which would never match here — confirm the adapter format.
with torch.no_grad():  # in-place parameter edits must not be tracked by autograd
    for name, param in base_model.named_parameters():
        if name in lora_weights:
            param.add_(lora_weights[name].to(param.device, dtype=param.dtype))

# Move the (merged) model to the target device and switch to inference mode
# (disables dropout and similar train-time behavior).
base_model = base_model.to(device)
base_model.eval()

# Define the inference function
def infer(prompt):
    """Run the model on *prompt* and return a mean-pooled embedding.

    Args:
        prompt: Text to tokenize and feed through the model.

    Returns:
        A numpy array — the final hidden states averaged over the sequence
        dimension (shape: batch x hidden). Gradio stringifies this for the
        Textbox output; replace with task-specific post-processing as needed.
    """
    inputs = tokenizer(prompt, return_tensors="pt").to(device)
    # Inference only: disable autograd so no graph is built per request.
    with torch.no_grad():
        outputs = base_model(**inputs)
    # Placeholder return, modify based on your specific model task
    return outputs.last_hidden_state.mean(dim=1).cpu().numpy()

# Build the Gradio UI: a prompt box, a trigger button, and an output box.
with gr.Blocks() as demo:
    gr.Markdown("## LoRA Model Inference")

    # Input row: text prompt next to the generate trigger.
    with gr.Row():
        prompt_box = gr.Textbox(
            label="Prompt",
            placeholder="Enter your prompt here...",
        )
        run_button = gr.Button("Generate")

    # The model's (stringified) result is rendered here.
    result_box = gr.Textbox(label="Output")

    # Wire the button to the inference function.
    run_button.click(fn=infer, inputs=[prompt_box], outputs=[result_box])

# Launch the app only when executed as a script.
if __name__ == "__main__":
    demo.launch()