jatingocodeo committed
Commit d98b4df · verified · 1 Parent(s): 43fe11a

Update app.py

Files changed (1):
  1. app.py +33 -51
app.py CHANGED
@@ -3,44 +3,9 @@ import torch
 from transformers import AutoModelForCausalLM, AutoTokenizer
 from peft import PeftModel, PeftConfig
 from PIL import Image
-import requests
-from io import BytesIO
 import torchvision.datasets as datasets
-import numpy as np
 
-# Load SigLIP for image embeddings
-from model.siglip import SigLIPModel
-
-def get_cifar_examples():
-    # Load CIFAR10 test set
-    cifar10_test = datasets.CIFAR10(root='./data', train=False, download=True)
-
-    # CIFAR10 classes
-    classes = ['airplane', 'automobile', 'bird', 'cat', 'deer',
-               'dog', 'frog', 'horse', 'ship', 'truck']
-
-    # Get one example from each class
-    examples = []
-    used_classes = set()
-
-    for idx in range(len(cifar10_test)):
-        img, label = cifar10_test[idx]
-        if classes[label] not in used_classes:
-            # Save the image temporarily
-            img_path = f"examples/{classes[label]}_example.jpg"
-            img.save(img_path)
-            examples.append(img_path)
-            used_classes.add(classes[label])
-
-        if len(used_classes) == 10:  # We have one example from each class
-            break
-
-    return examples
-
-def load_models():
-    # Load SigLIP model
-    siglip = SigLIPModel()
-
+def load_model():
     # Load base Phi model
     base_model = AutoModelForCausalLM.from_pretrained(
         "microsoft/Phi-3-mini-4k-instruct",
@@ -58,19 +23,16 @@ def load_models():
 
     tokenizer = AutoTokenizer.from_pretrained("jatingocodeo/phi-vlm")
 
-    return siglip, model, tokenizer
+    return model, tokenizer
 
-def generate_description(image, siglip, model, tokenizer):
+def generate_description(image, model, tokenizer):
     # Convert image to RGB if needed
     if image.mode != "RGB":
         image = image.convert("RGB")
 
-    # Resize image to match SigLIP's expected size
+    # Resize image to match training size
     image = image.resize((32, 32))
 
-    # Get image embedding from SigLIP
-    image_embedding = siglip.encode_image(image)
-
     # Prepare prompt
     prompt = """Below is an image. Please describe it in detail.
 
@@ -88,9 +50,8 @@ Description: """
 
     # Generate description
     with torch.no_grad():
-        outputs = model(
+        outputs = model.generate(
             **inputs,
-            image_embeddings=image_embedding.unsqueeze(0),
             max_new_tokens=100,
             temperature=0.7,
             do_sample=True,
@@ -101,16 +62,37 @@ Description: """
     generated_text = tokenizer.decode(outputs[0], skip_special_tokens=True)
     return generated_text.split("Description: ")[-1].strip()
 
-# Load models
-print("Loading models...")
-siglip, model, tokenizer = load_models()
+# Load model
+print("Loading model...")
+model, tokenizer = load_model()
+
+# Get CIFAR10 examples
+def get_cifar_examples():
+    cifar10_test = datasets.CIFAR10(root='./data', train=False, download=True)
+    classes = ['airplane', 'automobile', 'bird', 'cat', 'deer',
+               'dog', 'frog', 'horse', 'ship', 'truck']
+
+    examples = []
+    used_classes = set()
+
+    for idx in range(len(cifar10_test)):
+        img, label = cifar10_test[idx]
+        if classes[label] not in used_classes:
+            img_path = f"examples/{classes[label]}_example.jpg"
+            img.save(img_path)
+            examples.append(img_path)
+            used_classes.add(classes[label])
+
+        if len(used_classes) == 10:
+            break
+
+    return examples
 
 # Create Gradio interface
 def process_image(image):
-    description = generate_description(image, siglip, model, tokenizer)
-    return description
+    return generate_description(image, model, tokenizer)
 
-# Get CIFAR10 examples
+# Get examples
 examples = get_cifar_examples()
 
 # Define interface
@@ -121,7 +103,7 @@ iface = gr.Interface(
     title="Image Description Generator",
     description="""Upload an image and get a detailed description generated by our fine-tuned VLM model.
     Below are sample images from CIFAR10 dataset that you can try.""",
-    examples=[[ex] for ex in examples]  # Format examples for Gradio
+    examples=[[ex] for ex in examples]
 )
 
 # Launch the interface
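
For context, a minimal sketch of how the updated pieces wire together after this commit, reusing load_model(), generate_description(), and get_cifar_examples() as defined in app.py above. The function names and the title/examples arguments come from the diff; the fn, inputs, and outputs arguments of gr.Interface are assumptions, since the hunks truncate that call.

import gradio as gr

# Load once at startup, as in the diff; load_model() now returns two values.
model, tokenizer = load_model()

def process_image(image):
    # generate_description() converts/resizes the PIL image and prompts the
    # text model; after this commit no image embedding reaches model.generate().
    return generate_description(image, model, tokenizer)

examples = get_cifar_examples()

iface = gr.Interface(
    fn=process_image,             # assumed wiring; not visible in the hunks
    inputs=gr.Image(type="pil"),  # assumed: a PIL input matches the image.mode checks
    outputs="text",               # assumed output component
    title="Image Description Generator",
    examples=[[ex] for ex in examples],
)
iface.launch()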