sreyanghosh commited on
Commit
4705650
·
1 Parent(s): d5a79fc

Update: switch model loading from a manual base-model + PeftModel pair to AutoPeftModelForCausalLM

Browse files
Files changed (2) hide show
  1. app.py +14 -8
  2. requirements.txt +1 -1
app.py CHANGED
@@ -1,19 +1,25 @@
1
  import gradio as gr
2
  from transformers import AutoModelForCausalLM, AutoTokenizer
3
- from peft import PeftModel
4
  import torch
5
 
6
  # Load the model and tokenizer
7
  def load_model():
8
- base_model_name = "unsloth/llama-3.2-1b-instruct-bnb-4bit" # Replace with your base model name
9
  lora_model_name = "sreyanghosh/lora_model" # Replace with your LoRA model path
10
- tokenizer = AutoTokenizer.from_pretrained(base_model_name)
11
- model = AutoModelForCausalLM.from_pretrained(
12
- base_model_name,
13
- device_map="auto" if torch.cuda.is_available() else None,
14
- load_in_8bit=not torch.cuda.is_available(),
 
 
 
 
 
 
15
  )
16
- model = PeftModel.from_pretrained(model, lora_model_name)
17
  model.eval()
18
  return tokenizer, model
19
 
 
1
import gradio as gr
from transformers import AutoTokenizer
from peft import AutoPeftModelForCausalLM
import torch


def load_model():
    """Load the LoRA fine-tuned causal LM and its tokenizer.

    ``AutoPeftModelForCausalLM`` reads the adapter config stored with the
    LoRA checkpoint and resolves/downloads the matching base model itself,
    so no separate base-model load or ``PeftModel.from_pretrained`` merge
    step is needed.

    Returns:
        tuple: ``(tokenizer, model)`` where ``model`` is in eval mode.
    """
    lora_model_name = "sreyanghosh/lora_model"  # Replace with your LoRA model path

    model = AutoPeftModelForCausalLM.from_pretrained(
        lora_model_name,
        # 4-bit quantized loading requires bitsandbytes with a CUDA device;
        # on CPU-only hosts fall back to full precision instead of crashing.
        load_in_4bit=torch.cuda.is_available(),
    )
    tokenizer = AutoTokenizer.from_pretrained(lora_model_name)
    model.eval()  # inference only — disable dropout etc.
    return tokenizer, model
25
 
requirements.txt CHANGED
@@ -3,4 +3,4 @@ gradio
3
  transformers
4
  peft
5
  torch
6
- bitsandbytes==0.41.1
 
3
  transformers
4
  peft
5
  torch
6
+ bitsandbytes