anmoldograpsl commited on
Commit
b84c0d1
·
verified ·
1 Parent(s): f89d22f

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +46 -29
app.py CHANGED
@@ -1,42 +1,59 @@
1
- from huggingface_hub import login
2
  import os
 
 
 
 
 
3
  from peft import PeftModel, PeftConfig
4
  from transformers import AutoProcessor, PaliGemmaForConditionalGeneration
5
  from PIL import Image
6
- import requests
7
- import torch
8
- import io
9
- import base64
10
- import cv2
11
-
12
- access_token = os.environ["HF_TOKEN"]
13
  login(token=access_token)
14
-
 
15
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
16
  dtype = torch.bfloat16
17
-
 
18
  config = PeftConfig.from_pretrained("anushettypsl/paligemma_vqav2")
19
- # base_model = AutoModelForCausalLM.from_pretrained("google/paligemma-3b-pt-448")
20
  base_model = PaliGemmaForConditionalGeneration.from_pretrained("google/paligemma-3b-pt-448")
21
  model = PeftModel.from_pretrained(base_model, "anushettypsl/paligemma_vqav2", device_map=device)
22
  processor = AutoProcessor.from_pretrained("google/paligemma-3b-pt-448", device_map=device)
 
23
  model.to(device)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
24
 
25
- image = cv2.imread('/content/15_BC_G2_6358_40x_2_jpg.rf.97595fa4965f66ad45be8fd055331933.jpg')
26
-
27
- # Convert the image to base64 encoding
28
- image_bytes = cv2.imencode('.jpg', image)[1]
29
- base64_string = base64.b64encode(image_bytes).decode('utf-8')
30
-
31
- input_image = Image.open(io.BytesIO(image_bytes)).convert('RGB')
32
-
33
- model_inputs = processor(
34
- text=input_text, images=input_image, return_tensors="pt").to(device)
35
- input_len = model_inputs["input_ids"].shape[-1]
36
- model.to(device)
37
- with torch.inference_mode():
38
- generation = model.generate(
39
- **model_inputs, max_new_tokens=100, do_sample=False)
40
- generation = generation[0][input_len:]
41
- decoded = processor.decode(generation, skip_special_tokens=True)
42
- print(decoded)
 
 
1
  import os
2
+ import base64
3
+ import io
4
+ import cv2
5
+ import torch
6
+ import gradio as gr
7
  from peft import PeftModel, PeftConfig
8
  from transformers import AutoProcessor, PaliGemmaForConditionalGeneration
9
  from PIL import Image
10
+ from huggingface_hub import login
11
+
12
+ # Step 1: Log in to Hugging Face
13
+ access_token = os.environ["HF_TOKEN"] # Ensure your Hugging Face token is stored in an environment variable
 
 
 
14
  login(token=access_token)
15
+
16
+ # Step 2: Setup device and load model
17
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
18
  dtype = torch.bfloat16
19
+
20
+ # Load configuration and model
21
  config = PeftConfig.from_pretrained("anushettypsl/paligemma_vqav2")
 
22
  base_model = PaliGemmaForConditionalGeneration.from_pretrained("google/paligemma-3b-pt-448")
23
  model = PeftModel.from_pretrained(base_model, "anushettypsl/paligemma_vqav2", device_map=device)
24
  processor = AutoProcessor.from_pretrained("google/paligemma-3b-pt-448", device_map=device)
25
+
26
  model.to(device)
27
+
28
+ # Step 3: Define prediction function
29
+ def predict(input_image, input_text):
30
+ # Convert the uploaded image to RGB format
31
+ input_image = input_image.convert('RGB')
32
+
33
+ # Prepare the model inputs
34
+ model_inputs = processor(text=input_text, images=input_image, return_tensors="pt").to(device)
35
+
36
+ # Perform inference
37
+ with torch.inference_mode():
38
+ generation = model.generate(**model_inputs, max_new_tokens=100, do_sample=False)
39
+
40
+ # Decode the output
41
+ decoded_output = processor.decode(generation[0], skip_special_tokens=True)
42
+ return decoded_output
43
+
44
+ # Step 4: Create the Gradio interface
45
+ interface = gr.Interface(
46
+ fn=predict,
47
+ inputs=[
48
+ gr.Image(type="pil", label="Upload Image"), # Image input
49
+ gr.Textbox(label="Input Prompt", placeholder="Enter your prompt here...") # Text input
50
+ ],
51
+ outputs="text", # Text output
52
+ title="Image and Prompt to Text Model",
53
+ description="Upload an image and provide a prompt to generate a descriptive text."
54
+ )
55
+
56
+ # Step 5: Launch the Gradio app
57
+ interface.launch()
58
+ has context menu
59