darpanaswal commited on
Commit
889c466
·
verified ·
1 Parent(s): 49270db

Upload 3 files

Browse files
Files changed (3) hide show
  1. app.py +152 -0
  2. model.pth +3 -0
  3. requirements.txt +58 -0
app.py ADDED
@@ -0,0 +1,152 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ import torch
3
+ import cv2
4
+ import numpy as np
5
+ import torchvision.transforms as transforms
6
+ from PIL import Image, ImageEnhance
7
+ from torchvision import models
8
+ import torch.nn as nn
9
+ import matplotlib.pyplot as plt
10
+ import torch.nn.functional as F
11
+ import ssl
12
+ import certifi
13
+ import os
14
+
15
+ ssl._create_default_https_context = lambda: ssl.create_default_context(cafile=certifi.where())
16
+
17
+ # Set device
18
+ device = "cpu"
19
+
20
+ # Number of classes
21
+ num_classes = 6
22
+
23
+ # Load the pre-trained ResNet model
24
+ model = models.resnet152(pretrained=True)
25
+ for param in model.parameters():
26
+ param.requires_grad = False # Freeze feature extractor
27
+
28
+ # Modify the classifier for 6 classes with an additional hidden layer
29
+ model.fc = nn.Sequential(
30
+ nn.Linear(model.fc.in_features, 512),
31
+ nn.ReLU(),
32
+ nn.Linear(512, num_classes)
33
+ )
34
+
35
+ # Load trained weights
36
+ model.load_state_dict(torch.load('model.pth', map_location=torch.device('cpu')))
37
+ model.eval()
38
+
39
+ # Class labels
40
+ class_labels = ['bird', 'cat', 'deer', 'dog', 'frog', 'horse']
41
+
42
+ # Image transformation function
43
+ def transform_image(image):
44
+ """Preprocess the input image."""
45
+ transform = transforms.Compose([
46
+ transforms.Resize((32, 32)),
47
+ transforms.ToTensor(),
48
+ transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
49
+ ])
50
+ img_tensor = transform(image).unsqueeze(0).to(device)
51
+ return img_tensor
52
+
53
+ # Apply feature filters
54
+ def apply_filters(image, brightness, contrast, hue):
55
+ """Adjust Brightness, Contrast, and Hue of the input image."""
56
+ image = image.convert("RGB") # Ensure RGB mode
57
+
58
+ # Adjust brightness
59
+ enhancer = ImageEnhance.Brightness(image)
60
+ image = enhancer.enhance(brightness)
61
+
62
+ # Adjust contrast
63
+ enhancer = ImageEnhance.Contrast(image)
64
+ image = enhancer.enhance(contrast)
65
+
66
+ # Adjust hue (convert to HSV, modify, and convert back)
67
+ image = np.array(image)
68
+ hsv_image = cv2.cvtColor(image, cv2.COLOR_RGB2HSV).astype(np.float32)
69
+ hsv_image[..., 0] = (hsv_image[..., 0] + hue * 180) % 180 # Adjust hue
70
+ image = cv2.cvtColor(hsv_image.astype(np.uint8), cv2.COLOR_HSV2RGB)
71
+
72
+ return Image.fromarray(image)
73
+
74
+ # Superimposition function
75
+ def superimpose_images(base_image, overlay_image, alpha):
76
+ """Superimpose overlay_image onto base_image with a given alpha blend."""
77
+ if overlay_image is None:
78
+ return base_image # No overlay, return base image as is
79
+
80
+ # Resize overlay image to match base image
81
+ overlay_image = overlay_image.resize(base_image.size)
82
+
83
+ # Convert to numpy arrays
84
+ base_array = np.array(base_image).astype(float)
85
+ overlay_array = np.array(overlay_image).astype(float)
86
+
87
+ # Blend images
88
+ blended_array = (1 - alpha) * base_array + alpha * overlay_array
89
+ blended_array = np.clip(blended_array, 0, 255).astype(np.uint8)
90
+
91
+ return Image.fromarray(blended_array)
92
+
93
+ # Prediction function
94
+ def predict(image, brightness, contrast, hue, overlay_image, alpha):
95
+ """Apply filters, superimpose, classify image, and visualize results."""
96
+ if image is None:
97
+ return None, None, None
98
+
99
+ # Apply feature filters
100
+ processed_image = apply_filters(image, brightness, contrast, hue)
101
+
102
+ # Superimpose overlay image
103
+ final_image = superimpose_images(processed_image, overlay_image, alpha)
104
+
105
+ # Convert PIL Image to Tensor
106
+ image_tensor = transform_image(final_image)
107
+
108
+ with torch.no_grad():
109
+ output = model(image_tensor)
110
+ probabilities = F.softmax(output, dim=1).cpu().numpy()[0]
111
+
112
+ # Generate Bar Chart
113
+ with plt.xkcd():
114
+ fig, ax = plt.subplots(figsize=(5, 3))
115
+ ax.bar(class_labels, probabilities, color='skyblue')
116
+ ax.set_ylabel("Probability")
117
+ ax.set_title("Class Probabilities")
118
+ ax.set_ylim([0, 1])
119
+ for i, v in enumerate(probabilities):
120
+ ax.text(i, v + 0.02, f"{v:.2f}", ha='center', fontsize=10)
121
+
122
+ return final_image, fig
123
+
124
+ # Gradio Interface
125
+ with gr.Blocks() as interface:
126
+ gr.Markdown("<h2 style='text-align: center;'>Image Classifier with Superimposition & Adjustable Filters</h2>")
127
+
128
+ with gr.Row():
129
+ with gr.Column():
130
+ image_input = gr.Image(type="pil", label="Upload Base Image")
131
+ overlay_input = gr.Image(type="pil", label="Upload Overlay Image (Optional)")
132
+ brightness = gr.Slider(0.5, 2.0, value=1.0, label="Brightness")
133
+ contrast = gr.Slider(0.5, 2.0, value=1.0, label="Contrast")
134
+ hue = gr.Slider(-0.5, 0.5, value=0.0, label="Hue")
135
+ alpha = gr.Slider(0.0, 1.0, value=0.5, label="Overlay Weight (Alpha)")
136
+
137
+ with gr.Column():
138
+ processed_image = gr.Image(label="Final Processed Image")
139
+ bar_chart = gr.Plot(label="Class Probabilities")
140
+
141
+ inputs = [image_input, brightness, contrast, hue, overlay_input, alpha]
142
+ outputs = [processed_image, bar_chart]
143
+
144
+ # Event listeners for real-time updates
145
+ image_input.change(predict, inputs=inputs, outputs=outputs)
146
+ overlay_input.change(predict, inputs=inputs, outputs=outputs)
147
+ brightness.change(predict, inputs=inputs, outputs=outputs)
148
+ contrast.change(predict, inputs=inputs, outputs=outputs)
149
+ hue.change(predict, inputs=inputs, outputs=outputs)
150
+ alpha.change(predict, inputs=inputs, outputs=outputs)
151
+
152
+ interface.launch()
model.pth ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:76febe50dff559d23c30d9ef496bbb3fa3c5cbe8e75b7086660efd2b1addb09a
3
+ size 237644690
requirements.txt ADDED
@@ -0,0 +1,58 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ aiofiles==23.2.1
2
+ annotated-types==0.7.0
3
+ anyio==4.8.0
4
+ certifi==2025.1.31
5
+ charset-normalizer==3.4.1
6
+ click==8.1.8
7
+ fastapi==0.115.8
8
+ ffmpy==0.5.0
9
+ filelock==3.17.0
10
+ fsspec==2025.2.0
11
+ gradio==5.15.0
12
+ gradio_client==1.7.0
13
+ h11==0.14.0
14
+ httpcore==1.0.7
15
+ httpx==0.28.1
16
+ huggingface-hub==0.28.1
17
+ idna==3.10
18
+ Jinja2==3.1.5
19
+ markdown-it-py==3.0.0
20
+ MarkupSafe==2.1.5
21
+ mdurl==0.1.2
22
+ mpmath==1.3.0
23
+ networkx==3.4.2
24
+ numpy==2.2.2
25
+ opencv-python==4.11.0.86
26
+ orjson==3.10.15
27
+ packaging==24.2
28
+ pandas==2.2.3
29
+ pillow==11.1.0
30
+ pydantic==2.10.6
31
+ pydantic_core==2.27.2
32
+ pydub==0.25.1
33
+ Pygments==2.19.1
34
+ python-dateutil==2.9.0.post0
35
+ python-multipart==0.0.20
36
+ pytz==2025.1
37
+ PyYAML==6.0.2
38
+ requests==2.32.3
39
+ rich==13.9.4
40
+ ruff==0.9.5
41
+ safehttpx==0.1.6
42
+ semantic-version==2.10.0
43
+ setuptools==75.8.0
44
+ shellingham==1.5.4
45
+ six==1.17.0
46
+ sniffio==1.3.1
47
+ starlette==0.45.3
48
+ sympy==1.13.1
49
+ tomlkit==0.13.2
50
+ torch==2.6.0
51
+ torchvision==0.21.0
52
+ tqdm==4.67.1
53
+ typer==0.15.1
54
+ typing_extensions==4.12.2
55
+ tzdata==2025.1
56
+ urllib3==2.3.0
57
+ uvicorn==0.34.0
58
+ websockets==14.2