jatingocodeo committed on
Commit
ee3c970
·
verified ·
1 Parent(s): ae3435b

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +12 -7
app.py CHANGED
@@ -1,16 +1,22 @@
1
  import gradio as gr
2
  import torch
3
  from transformers import AutoModelForCausalLM, AutoTokenizer
 
4
 
5
  # Load model and tokenizer
6
  def load_model(model_id):
 
 
7
  tokenizer = AutoTokenizer.from_pretrained(model_id, trust_remote_code=True)
8
- model = AutoModelForCausalLM.from_pretrained(
9
- model_id,
10
  torch_dtype=torch.float16,
11
  device_map="auto",
12
  trust_remote_code=True
13
  )
 
 
 
14
  return model, tokenizer
15
 
16
  def generate_response(instruction, model, tokenizer, max_length=200, temperature=0.7, top_p=0.9):
@@ -39,10 +45,7 @@ def generate_response(instruction, model, tokenizer, max_length=200, temperature
39
  return response_parts[1].strip()
40
  return response.strip()
41
 
42
- def create_demo():
43
- # Use your uploaded model
44
- model_id = "jatingocodeo/phi2-finetuned-openassistant"
45
-
46
  # Load model and tokenizer
47
  model, tokenizer = load_model(model_id)
48
 
@@ -106,5 +109,7 @@ def create_demo():
106
  return demo
107
 
108
  if __name__ == "__main__":
109
- demo = create_demo()
 
 
110
  demo.launch()
 
1
  import gradio as gr
2
  import torch
3
  from transformers import AutoModelForCausalLM, AutoTokenizer
4
+ from peft import PeftModel
5
 
6
# Load model and tokenizer
def load_model(model_id, base_model_id="microsoft/phi-2"):
    """Load the fine-tuned chat model and its tokenizer.

    The checkpoint at ``model_id`` is a LoRA adapter, so the base model is
    loaded first and the adapter weights are applied on top of it.

    Args:
        model_id: Hub repo ID of the fine-tuned adapter. The tokenizer is
            also loaded from this repo (adapter repos typically include it —
            TODO confirm for the repo actually deployed).
        base_model_id: Hub repo ID of the base model the adapter was trained
            against. Defaults to ``"microsoft/phi-2"``; exposed as a parameter
            so the loader is reusable for adapters of other base models.

    Returns:
        A ``(model, tokenizer)`` tuple ready for text generation.
    """
    tokenizer = AutoTokenizer.from_pretrained(model_id, trust_remote_code=True)
    base_model = AutoModelForCausalLM.from_pretrained(
        base_model_id,
        torch_dtype=torch.float16,  # half precision keeps memory usage down
        device_map="auto",          # let accelerate place layers on available devices
        trust_remote_code=True,
    )
    # Attach the LoRA adapter weights on top of the base model.
    # NOTE(review): PeftModel.from_pretrained only *attaches* the adapter; it
    # does not merge it into the base weights. Call model.merge_and_unload()
    # afterwards if a standalone merged model is actually required.
    model = PeftModel.from_pretrained(base_model, model_id)
    return model, tokenizer
21
 
22
  def generate_response(instruction, model, tokenizer, max_length=200, temperature=0.7, top_p=0.9):
 
45
  return response_parts[1].strip()
46
  return response.strip()
47
 
48
+ def create_demo(model_id):
 
 
 
49
  # Load model and tokenizer
50
  model, tokenizer = load_model(model_id)
51
 
 
109
  return demo
110
 
111
if __name__ == "__main__":
    # Hub repo ID of the fine-tuned model (username/model-name) —
    # replace with your own before deploying.
    repo_id = "your-username/phi2-finetuned-oasst"
    app = create_demo(repo_id)
    app.launch()