fatima02 committed on
Commit 3ce1322 · verified · 1 Parent(s): 8b336a0

Update app.py

Files changed (1)
  1. app.py +250 -55
app.py CHANGED
@@ -1,64 +1,259 @@
  import gradio as gr
- from huggingface_hub import InferenceClient

- """
- For more information on `huggingface_hub` Inference API support, please check the docs: https://huggingface.co/docs/huggingface_hub/v0.22.2/en/guides/inference
- """
- client = InferenceClient("HuggingFaceH4/zephyr-7b-beta")


- def respond(
-     message,
-     history: list[tuple[str, str]],
-     system_message,
-     max_tokens,
-     temperature,
-     top_p,
- ):
-     messages = [{"role": "system", "content": system_message}]

-     for val in history:
-         if val[0]:
-             messages.append({"role": "user", "content": val[0]})
-         if val[1]:
-             messages.append({"role": "assistant", "content": val[1]})

-     messages.append({"role": "user", "content": message})
-
-     response = ""
-
-     for message in client.chat_completion(
-         messages,
-         max_tokens=max_tokens,
-         stream=True,
-         temperature=temperature,
-         top_p=top_p,
-     ):
-         token = message.choices[0].delta.content
-
-         response += token
-         yield response
-
-
- """
- For information on how to customize the ChatInterface, peruse the gradio docs: https://www.gradio.app/docs/chatinterface
- """
- demo = gr.ChatInterface(
-     respond,
-     additional_inputs=[
-         gr.Textbox(value="You are a friendly Chatbot.", label="System message"),
-         gr.Slider(minimum=1, maximum=2048, value=512, step=1, label="Max new tokens"),
-         gr.Slider(minimum=0.1, maximum=4.0, value=0.7, step=0.1, label="Temperature"),
-         gr.Slider(
-             minimum=0.1,
-             maximum=1.0,
-             value=0.95,
-             step=0.05,
-             label="Top-p (nucleus sampling)",
-         ),
-     ],
- )


  if __name__ == "__main__":
-     demo.launch()
 
  import gradio as gr
+ import numpy as np
+ import cv2
+ from tensorflow.keras.applications import ResNet50
+ from tensorflow.keras.applications.resnet50 import preprocess_input
+ from tensorflow.keras.preprocessing import image
+ from skimage.metrics import structural_similarity as ssim
+ import os
+ import tempfile
+ from PIL import Image

+ class ImageCharacterClassifier:
+     def __init__(self, similarity_threshold=0.5):
+         # Initialize ResNet50 model without top classification layer
+         self.model = ResNet50(weights='imagenet', include_top=False, pooling='avg')
+         self.similarity_threshold = similarity_threshold

+     def load_and_preprocess_image(self, image_path, target_size=(224, 224)):
+         # Load and preprocess image for ResNet50
+         img = image.load_img(image_path, target_size=target_size)
+         img_array = image.img_to_array(img)
+         img_array = np.expand_dims(img_array, axis=0)
+         img_array = preprocess_input(img_array)
+         return img_array

+     def extract_features(self, image_path):
+         # Extract deep features using ResNet50
+         preprocessed_img = self.load_and_preprocess_image(image_path)
+         features = self.model.predict(preprocessed_img)
+         return features

+     def calculate_ssim(self, img1_path, img2_path):
+         # Calculate SSIM between two images
+         img1 = cv2.imread(img1_path)
+         img2 = cv2.imread(img2_path)
+
+         if img1 is None or img2 is None:
+             return 0.0
+
+         # Convert to grayscale if images are in color
+         if len(img1.shape) == 3:
+             img1 = cv2.cvtColor(img1, cv2.COLOR_BGR2GRAY)
+         if len(img2.shape) == 3:
+             img2 = cv2.cvtColor(img2, cv2.COLOR_BGR2GRAY)
+
+         # Resize images to same dimensions
+         img2 = cv2.resize(img2, (img1.shape[1], img1.shape[0]))
+
+         score = ssim(img1, img2)
+         return score

+ def process_images(reference_image, comparison_images, similarity_threshold):
+     try:
+         if reference_image is None:
+             return "Please upload a reference image.", []
+
+         if not comparison_images:
+             return "Please upload comparison images.", []
+
+         # Create temporary directory for saving uploaded files
+         with tempfile.TemporaryDirectory() as temp_dir:
+             # Initialize classifier with the threshold
+             classifier = ImageCharacterClassifier(similarity_threshold=similarity_threshold)
+
+             # Save reference image
+             ref_path = os.path.join(temp_dir, "reference.jpg")
+             cv2.imwrite(ref_path, cv2.cvtColor(reference_image, cv2.COLOR_RGB2BGR))
+
+             results = []
+             html_output = """
+             <div style='text-align: center; margin-bottom: 20px;'>
+                 <h2 style='color: #2c3e50;'>Results</h2>
+                 <p style='color: #7f8c8d;'>Reference image compared with uploaded images</p>
+             </div>
+             """
+
+             # Extract reference features once
+             ref_features = classifier.extract_features(ref_path)
+
+             # Process each comparison image
+             for i, comp_image in enumerate(comparison_images):
+                 try:
+                     # Save comparison image
+                     comp_path = os.path.join(temp_dir, f"comparison_{i}.jpg")
+
+                     try:
+                         # First attempt: Try using PIL
+                         with Image.open(comp_image.name) as img:
+                             img = img.convert('RGB')
+                             img_array = np.array(img)
+                             cv2.imwrite(comp_path, cv2.cvtColor(img_array, cv2.COLOR_RGB2BGR))
+                     except Exception as e1:
+                         print(f"PIL failed: {str(e1)}")
+                         # Second attempt: Try using OpenCV directly
+                         img = cv2.imread(comp_image.name)
+                         if img is not None:
+                             cv2.imwrite(comp_path, img)
+                         else:
+                             raise ValueError(f"Could not read image: {comp_image.name}")
+
+                     # Calculate SSIM for structural similarity
+                     ssim_score = classifier.calculate_ssim(ref_path, comp_path)
+
+                     # Extract features for physical feature comparison
+                     comp_features = classifier.extract_features(comp_path)
+
+                     # Calculate feature differences
+                     feature_diff = np.abs(ref_features - comp_features)
+                     max_feature_diff = np.max(feature_diff)
+
+                     # Calculate cosine similarity for overall similarity
+                     feature_similarity = np.dot(ref_features.flatten(),
+                                                 comp_features.flatten()) / (
+                         np.linalg.norm(ref_features) * np.linalg.norm(comp_features))
+
+                     # Stricter similarity criteria
+                     is_similar = False
+                     reason = ""
+
+                     if max_feature_diff > 0.5:  # Threshold for major feature differences
+                         is_similar = False
+                         reason = "Physical features missing or different"
+                     elif ssim_score < 0.3:  # Minimum structural similarity required
+                         is_similar = False
+                         reason = "Significant structural differences"
+                     elif feature_similarity < similarity_threshold:
+                         is_similar = False
+                         reason = "Overall similarity too low"
+                     else:
+                         is_similar = True
+                         reason = "Images are similar"
+
+                     # Create HTML output with improved styling and reason
+                     status_color = "#27ae60" if is_similar else "#c0392b"  # Green or Red
+                     status_text = "SIMILAR" if is_similar else "NOT SIMILAR"
+                     status_icon = "✓" if is_similar else "✗"
+
+                     html_output += f"""
+                     <div style='
+                         margin: 15px 0;
+                         padding: 15px;
+                         border-radius: 8px;
+                         background-color: {status_color}1a;
+                         border: 2px solid {status_color};
+                         display: flex;
+                         align-items: center;
+                         justify-content: space-between;
+                     '>
+                         <div style='display: flex; align-items: center;'>
+                             <span style='
+                                 font-size: 24px;
+                                 margin-right: 10px;
+                                 color: {status_color};
+                             '>{status_icon}</span>
+                             <div>
+                                 <span style='color: #2c3e50; font-weight: bold; display: block;'>
+                                     {os.path.basename(comp_image.name)}
+                                 </span>
+                                 <span style='color: {status_color}; font-size: 12px;'>
+                                     {reason}
+                                 </span>
+                             </div>
+                         </div>
+                         <div style='
+                             color: {status_color};
+                             font-weight: bold;
+                             font-size: 16px;
+                         '>{status_text}</div>
+                     </div>
+                     """
+
+                     # Read the processed image back for display
+                     display_img = cv2.imread(comp_path)
+                     if display_img is not None:
+                         display_img = cv2.cvtColor(display_img, cv2.COLOR_BGR2RGB)
+                         results.append(display_img)
+
+                 except Exception as e:
+                     print(f"Error processing {comp_image.name}: {str(e)}")
+                     html_output += f"""
+                     <div style='
+                         margin: 15px 0;
+                         padding: 15px;
+                         border-radius: 8px;
+                         background-color: #e74c3c1a;
+                         border: 2px solid #e74c3c;
+                     '>
+                         <h3 style='color: #e74c3c; margin: 0;'>
+                             Error processing: {os.path.basename(comp_image.name)}
+                         </h3>
+                         <p style='color: #e74c3c; margin: 5px 0 0 0;'>{str(e)}</p>
+                     </div>
+                     """
+
+             return html_output, results
+
+     except Exception as e:
+         print(f"Main error: {str(e)}")
+         return f"""
+         <div style='
+             padding: 15px;
+             border-radius: 8px;
+             background-color: #e74c3c1a;
+             border: 2px solid #e74c3c;
+         '>
+             <h3 style='color: #e74c3c; margin: 0;'>Error</h3>
+             <p style='color: #e74c3c; margin: 5px 0 0 0;'>{str(e)}</p>
+         </div>
+         """, []

+ # Update the interface creation
+ def create_interface():
+     with gr.Blocks() as interface:
+         gr.Markdown("# Image Similarity Classifier")
+         gr.Markdown("Upload a reference image and up to 10 comparison images to check similarity.")
+
+         with gr.Row():
+             with gr.Column():
+                 reference_input = gr.Image(
+                     label="Reference Image",
+                     type="numpy",
+                     image_mode="RGB"
+                 )
+                 comparison_input = gr.File(
+                     label="Comparison Images (Upload up to 10)",
+                     file_count="multiple",
+                     file_types=["image"],
+                     maximum=10
+                 )
+                 threshold_slider = gr.Slider(
+                     minimum=0.0,
+                     maximum=1.0,
+                     value=0.5,
+                     step=0.05,
+                     label="Similarity Threshold"
+                 )
+                 submit_button = gr.Button("Compare Images", variant="primary")
+
+             with gr.Column():
+                 output_html = gr.HTML(label="Results")
+                 output_gallery = gr.Gallery(
+                     label="Processed Images",
+                     columns=5,
+                     show_label=True,
+                     height="auto"
+                 )
+
+         submit_button.click(
+             fn=process_images,
+             inputs=[reference_input, comparison_input, threshold_slider],
+             outputs=[output_html, output_gallery]
+         )
+
+     return interface

+ # Launch the app
  if __name__ == "__main__":
+     interface = create_interface()
+     interface.launch(share=True)