raccoote commited on
Commit
c17b3c4
·
verified ·
1 Parent(s): 8a0d2ed

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +12 -10
app.py CHANGED
@@ -1,24 +1,26 @@
1
  import gradio as gr
2
- from transformers import AutoModel, AutoTokenizer
 
 
3
 
4
- # Load the model and tokenizer
5
- model_name = "raccoote/angry-birds-v2" # Replace with the correct model name
6
  tokenizer = AutoTokenizer.from_pretrained(model_name)
7
- model = AutoModel.from_pretrained(model_name)
 
 
 
8
 
9
  def generate_text(prompt):
10
  inputs = tokenizer(prompt, return_tensors="pt")
11
- with torch.no_grad():
12
- outputs = model(**inputs)
13
- # Process the outputs to generate text (this will vary based on your model)
14
- # Here we just return the hidden states shape as a placeholder
15
- return outputs.last_hidden_state.shape
16
 
17
  # Create the Gradio interface
18
  iface = gr.Interface(fn=generate_text,
19
  inputs="text",
20
  outputs="text",
21
- title="LLaMA 3.1 Model with LoRA Adapters",
22
  description="Enter a prompt and get the model's output.")
23
 
24
  iface.launch()
 
1
  import gradio as gr
2
+ from transformers import AutoModelForCausalLM, AutoTokenizer
3
+ import torch
4
+ from peft import PeftModel
5
 
6
+ # Load the base model and tokenizer directly from the raccoote repository
7
+ model_name = "raccoote/angry-birds-v2"
8
  tokenizer = AutoTokenizer.from_pretrained(model_name)
9
+ base_model = AutoModelForCausalLM.from_pretrained(model_name)
10
+
11
+ # Load the LoRA adapter from the raccoote repository
12
+ adapter_model = PeftModel.from_pretrained(base_model, model_name)
13
 
14
  def generate_text(prompt):
15
  inputs = tokenizer(prompt, return_tensors="pt")
16
+ outputs = adapter_model.generate(**inputs, max_new_tokens=50)
17
+ return tokenizer.decode(outputs[0], skip_special_tokens=True)
 
 
 
18
 
19
  # Create the Gradio interface
20
  iface = gr.Interface(fn=generate_text,
21
  inputs="text",
22
  outputs="text",
23
+ title="LLaMA 3.1 with LoRA Adapters",
24
  description="Enter a prompt and get the model's output.")
25
 
26
  iface.launch()