anmoldograpsl committed on
Commit
10271df
·
verified ·
1 Parent(s): ca0ce8e

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +58 -40
app.py CHANGED
@@ -1,42 +1,60 @@
1
- from huggingface_hub import login
2
  import os
3
- from peft import PeftModel, PeftConfig
4
- from transformers import AutoProcessor, PaliGemmaForConditionalGeneration
5
- from PIL import Image
6
- import requests
7
- import torch
8
- import io
9
  import base64
10
- import cv2
11
-
12
- access_token = os.environ["HF_TOKEN"]
13
- login(token=access_token)
14
-
15
- device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
16
- dtype = torch.bfloat16
17
-
18
- config = PeftConfig.from_pretrained("anushettypsl/paligemma_vqav2")
19
- # base_model = AutoModelForCausalLM.from_pretrained("google/paligemma-3b-pt-448")
20
- base_model = PaliGemmaForConditionalGeneration.from_pretrained("google/paligemma-3b-pt-448")
21
- model = PeftModel.from_pretrained(base_model, "anushettypsl/paligemma_vqav2", device_map=device)
22
- processor = AutoProcessor.from_pretrained("google/paligemma-3b-pt-448", device_map=device)
23
- model.to(device)
24
-
25
- image = cv2.imread('/content/15_BC_G2_6358_40x_2_jpg.rf.97595fa4965f66ad45be8fd055331933.jpg')
26
-
27
- # Convert the image to base64 encoding
28
- image_bytes = cv2.imencode('.jpg', image)[1]
29
- base64_string = base64.b64encode(image_bytes).decode('utf-8')
30
-
31
- input_image = Image.open(io.BytesIO(image_bytes)).convert('RGB')
32
-
33
- model_inputs = processor(
34
- text=input_text, images=input_image, return_tensors="pt").to(device)
35
- input_len = model_inputs["input_ids"].shape[-1]
36
- model.to(device)
37
- with torch.inference_mode():
38
- generation = model.generate(
39
- **model_inputs, max_new_tokens=100, do_sample=False)
40
- generation = generation[0][input_len:]
41
- decoded = processor.decode(generation, skip_special_tokens=True)
42
- print(decoded)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  import os
 
 
 
 
 
 
2
  import base64
3
+ import torch
4
+ import gradio as gr
5
+ from PIL import Image
6
+ from transformers import AutoProcessor, AutoModelForCausalLM
7
+ from peft import get_peft_model, LoraConfig, TaskType
8
+ from huggingface_hub import login
9
+
10
+ # Step 1: Log in to Hugging Face
11
+ hf_token = os.getenv("HF_TOKEN")
12
+ login(token=hf_token)
13
+
14
+ # Step 2: Load the private model and processor
15
+ model_name = "anushettypsl/paligemma_vqav2" # Replace with the actual model link
16
+ processor = AutoProcessor.from_pretrained(model_name)
17
+ base_model = AutoModelForCausalLM.from_pretrained(model_name)
18
+
19
+ # Step 3: Set up PEFT configuration (if needed)
20
+ lora_config = LoraConfig(
21
+ r=16, # Rank
22
+ lora_alpha=32, # Scaling factor
23
+ lora_dropout=0.1, # Dropout
24
+ task_type=TaskType.CAUSAL_LM, # Adjust according to your model's task
25
+ )
26
+
27
+ # Step 4: Get the PEFT model
28
+ peft_model = get_peft_model(base_model, lora_config)
29
+
30
+ # Step 5: Define the prediction function
31
+ def predict(image_base64, prompt):
32
+ # Decode the base64 image
33
+ image_data = base64.b64decode(image_base64)
34
+ image = Image.open(io.BytesIO(image_data))
35
+
36
+ # Process the image
37
+ inputs = processor( text=prompt,images=image, return_tensors="pt")
38
+
39
+ # Generate output using the model
40
+ with torch.no_grad():
41
+ output = peft_model.generate(**inputs)
42
+
43
+ # Decode the output to text
44
+ generated_text = processor.decode(output[0], skip_special_tokens=True)
45
+ return generated_text
46
+
47
+ # Step 6: Create the Gradio interface
48
+ interface = gr.Interface(
49
+ fn=predict,
50
+ inputs=[
51
+ gr.Textbox(label="Image (Base64)", placeholder="Enter base64 encoded image here...", lines=10), # Base64 input for image
52
+ gr.Textbox(label="Prompt", placeholder="Enter your prompt here...") # Prompt input
53
+ ],
54
+ outputs="text", # Text output
55
+ title="Image and Prompt to Text Model",
56
+ description="Enter a base64 encoded image and a prompt to generate a descriptive text."
57
+ )
58
+
59
+ # Step 7: Launch the Gradio app
60
+ interface.launch()