Grandediw committed on
Commit
087c7d2
·
verified ·
1 Parent(s): 40b886d

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +18 -39
app.py CHANGED
@@ -4,21 +4,21 @@ import torch
4
  from transformers import AutoTokenizer, AutoModel
5
  from safetensors.torch import load_file
6
 
7
- # Load the Hugging Face API token
8
  token = os.getenv("HUGGINGFACE_API_TOKEN")
9
  if not token:
10
  raise ValueError("HUGGINGFACE_API_TOKEN is not set. Please add it in the Secrets section of your Space.")
11
 
12
- # Configure device and data type
13
  device = "cuda" if torch.cuda.is_available() else "cpu"
14
 
15
- # Load the tokenizer and model
16
  model_repo = "Grandediw/lora_model"
17
- tokenizer = AutoTokenizer.from_pretrained(model_repo, use_auth_token=True)
18
- base_model = AutoModel.from_pretrained(model_repo, use_auth_token=True)
19
 
20
- # Load LoRA adapter weights
21
- lora_weights_path = "adapter_model.safetensors" # Ensure this file is present in the same directory
22
  lora_weights = load_file(lora_weights_path)
23
 
24
  # Apply LoRA weights to the base model
@@ -29,45 +29,24 @@ for name, param in base_model.named_parameters():
29
  # Move the model to the device
30
  base_model = base_model.to(device)
31
 
32
- # Inference function
33
- def infer(prompt, negative_prompt=None):
34
  inputs = tokenizer(prompt, return_tensors="pt").to(device)
35
  outputs = base_model(**inputs)
36
- return outputs.last_hidden_state.mean(dim=1).cpu().detach().numpy() # Placeholder return
 
37
 
38
- # Gradio Interface
39
- css = """
40
- #interface-container {
41
- margin: 0 auto;
42
- max-width: 700px;
43
- padding: 15px;
44
- border-radius: 10px;
45
- background-color: #f9f9f9;
46
- box-shadow: 0px 4px 10px rgba(0, 0, 0, 0.1);
47
- }
48
- #header {
49
- text-align: center;
50
- font-size: 1.5em;
51
- font-weight: bold;
52
- margin-bottom: 20px;
53
- color: #333;
54
- }
55
- """
56
 
57
- with gr.Blocks(css=css) as demo:
58
- with gr.Box(elem_id="interface-container"):
59
- gr.Markdown("<div id='header'>LoRA Model Inference</div>")
60
-
61
- # Input for prompt and run button
62
  prompt = gr.Textbox(label="Prompt", placeholder="Enter your prompt here...")
63
- run_button = gr.Button("Generate Output", variant="primary")
64
 
65
- # Display output
66
- output = gr.Textbox(label="Output")
67
 
68
- # Connect button with inference
69
- run_button.click(fn=infer, inputs=[prompt], outputs=[output])
70
 
71
- # Launch the app
72
  if __name__ == "__main__":
73
  demo.launch()
 
4
  from transformers import AutoTokenizer, AutoModel
5
  from safetensors.torch import load_file
6
 
7
+ # Load the Hugging Face API token from environment variable
8
  token = os.getenv("HUGGINGFACE_API_TOKEN")
9
  if not token:
10
  raise ValueError("HUGGINGFACE_API_TOKEN is not set. Please add it in the Secrets section of your Space.")
11
 
12
+ # Configure device
13
  device = "cuda" if torch.cuda.is_available() else "cpu"
14
 
15
+ # Load the tokenizer and model using the token
16
  model_repo = "Grandediw/lora_model"
17
+ tokenizer = AutoTokenizer.from_pretrained(model_repo, token=token)
18
+ base_model = AutoModel.from_pretrained(model_repo, token=token)
19
 
20
+ # Load the LoRA adapter weights
21
+ lora_weights_path = "adapter_model.safetensors" # Ensure this file exists in the same directory
22
  lora_weights = load_file(lora_weights_path)
23
 
24
  # Apply LoRA weights to the base model
 
29
  # Move the model to the device
30
  base_model = base_model.to(device)
31
 
32
+ # Define the inference function
33
+ def infer(prompt):
34
  inputs = tokenizer(prompt, return_tensors="pt").to(device)
35
  outputs = base_model(**inputs)
36
+ # Placeholder return, modify based on your specific model task
37
+ return outputs.last_hidden_state.mean(dim=1).cpu().detach().numpy()
38
 
39
+ # Gradio interface
40
+ with gr.Blocks() as demo:
41
+ gr.Markdown("## LoRA Model Inference")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
42
 
43
+ with gr.Row():
 
 
 
 
44
  prompt = gr.Textbox(label="Prompt", placeholder="Enter your prompt here...")
45
+ generate_button = gr.Button("Generate")
46
 
47
+ output = gr.Textbox(label="Output")
 
48
 
49
+ generate_button.click(fn=infer, inputs=[prompt], outputs=[output])
 
50
 
 
51
  if __name__ == "__main__":
52
  demo.launch()