far555na commited on
Commit
ce2a0c5
·
verified ·
1 Parent(s): 0ce21a0

Update ig.py

Browse files
Files changed (1) hide show
  1. ig.py +14 -3
ig.py CHANGED
@@ -1,10 +1,21 @@
1
  from transformers import AutoProcessor, AutoModelForCausalLM
 
2
  import gradio as gr
3
  import torch
4
 
5
- # Load the processor and model
6
  processor = AutoProcessor.from_pretrained("microsoft/git-base")
7
- model = AutoModelForCausalLM.from_pretrained("./")
 
 
 
 
 
 
 
 
 
 
8
 
9
  def predict(image):
10
  try:
@@ -28,7 +39,7 @@ def predict(image):
28
  print("Error during prediction:", str(e))
29
  return "Error: " + str(e)
30
 
31
- # https://www.gradio.app/guides
32
  with gr.Blocks() as demo:
33
  image = gr.Image(type="pil")
34
  predict_btn = gr.Button("Predict", variant="primary")
 
1
  from transformers import AutoProcessor, AutoModelForCausalLM
2
+ from peft import PeftModel, PeftConfig
3
  import gradio as gr
4
  import torch
5
 
6
+ # Load the processor
7
  processor = AutoProcessor.from_pretrained("microsoft/git-base")
8
+
9
+ # Load the base model (the pre-trained model you're adapting with LoRA)
10
+ base_model = AutoModelForCausalLM.from_pretrained("microsoft/git-base")
11
+
12
+ # Load the adapter configuration
13
+ adapter_config_path = "./" # Path to your adapter_config.json
14
+ adapter_model_path = "./" # Path to your adapter_model.safetensors
15
+
16
+ # Load the LoRA adapter using Peft
17
+ peft_config = PeftConfig.from_pretrained(adapter_config_path)
18
+ model = PeftModel.from_pretrained(base_model, adapter_model_path, config=peft_config)
19
 
20
  def predict(image):
21
  try:
 
39
  print("Error during prediction:", str(e))
40
  return "Error: " + str(e)
41
 
42
+ # Gradio Interface
43
  with gr.Blocks() as demo:
44
  image = gr.Image(type="pil")
45
  predict_btn = gr.Button("Predict", variant="primary")