KeerthiVM committed
Commit 167ea92 · 1 Parent(s): 3b2ce40
Files changed (3)
  1. 1.jpg +0 -0
  2. 2.jpg +0 -0
  3. test.py +162 -0
1.jpg ADDED
2.jpg ADDED
test.py ADDED
@@ -0,0 +1,162 @@
+ import torch
+ import torch.nn.functional as F
+ import matplotlib.pyplot as plt
+ import math
+ from PIL import Image
+ from torchvision import transforms
+ from SkinGPT import SkinGPTClassifier
+
+
+ class SkinGPTTester:
+     def __init__(self, model_path="finetuned_dermnet_version1.pth"):
+         # NOTE: model_path is kept for reference only; SkinGPTClassifier
+         # currently loads its own weights internally.
+         self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
+         self.classifier = SkinGPTClassifier()
+         # Standard ImageNet preprocessing, matching the encoder's input size.
+         self.transform = transforms.Compose([
+             transforms.Resize((224, 224)),
+             transforms.ToTensor(),
+             transforms.Normalize(mean=[0.485, 0.456, 0.406],
+                                  std=[0.229, 0.224, 0.225])
+         ])
+
+     def visualize_attention(self, image_path):
+         """Visualize Q-Former attention over the image patches."""
+         image = Image.open(image_path).convert('RGB')
+         image_tensor = self.transform(image).unsqueeze(0).to(self.device)
+
+         with torch.no_grad():
+             # Run the encoder so the Q-Former caches its attention weights.
+             _ = self.classifier.model.encode_image(image_tensor)
+             # last_attention[0]: (heads, queries, patches); average over heads.
+             attention = self.classifier.model.q_former.last_attention[0].mean(dim=0)
+
+         # Average over the query tokens, then fold the patch axis into a
+         # square grid (assumes a square patch layout with no CLS token).
+         attention = attention.mean(dim=0)
+         h = w = int(math.sqrt(attention.shape[0]))
+         attention = attention.reshape(h, w).cpu()
+
+         plt.figure(figsize=(15, 5))
+
+         # Original image
+         plt.subplot(1, 3, 1)
+         plt.imshow(image)
+         plt.title('Original Image')
+         plt.axis('off')
+
+         # Attention map
+         plt.subplot(1, 3, 2)
+         plt.imshow(attention, cmap='hot')
+         plt.title('Attention Map')
+         plt.axis('off')
+
+         # Overlay: stretch the patch-grid map across the full image extent.
+         plt.subplot(1, 3, 3)
+         plt.imshow(image)
+         plt.imshow(attention, alpha=0.5, cmap='hot',
+                    extent=(0, image.width, image.height, 0))
+         plt.title('Attention Overlay')
+         plt.axis('off')
+
+         plt.tight_layout()
+         plt.savefig('attention_visualization.png')
+         plt.close()
+
+     def check_feature_similarity(self, image_path1, image_path2):
+         """Compare the encoder embeddings of two images."""
+         image1 = Image.open(image_path1).convert('RGB')
+         image2 = Image.open(image_path2).convert('RGB')
+
+         with torch.no_grad():
+             emb1 = self.classifier.model.encode_image(
+                 self.transform(image1).unsqueeze(0).to(self.device)
+             )
+             emb2 = self.classifier.model.encode_image(
+                 self.transform(image2).unsqueeze(0).to(self.device)
+             )
+
+         # Cosine similarity between the query-averaged embeddings.
+         similarity = F.cosine_similarity(emb1.mean(dim=1), emb2.mean(dim=1))
+
+         print("\nFeature Similarity Analysis:")
+         print(f"Image 1: {image_path1}")
+         print(f"Image 2: {image_path2}")
+         print(f"Cosine Similarity: {similarity.item():.4f}")
+         print(f"Embedding shapes: {emb1.shape}, {emb2.shape}")
+         print(f"Embedding means: {emb1.mean().item():.4f}, {emb2.mean().item():.4f}")
+         print(f"Embedding stds: {emb1.std().item():.4f}, {emb2.std().item():.4f}")
+
+         return similarity.item()
+
+     def validate_response(self, image_path, diagnosis):
+         """Check that the diagnosis is grounded in high-attention regions."""
+         image = Image.open(image_path).convert('RGB')
+         image_tensor = self.transform(image).unsqueeze(0).to(self.device)
+
+         with torch.no_grad():
+             # Encode first so last_attention reflects this image rather than
+             # whatever was processed previously.
+             _ = self.classifier.model.encode_image(image_tensor)
+             attention = self.classifier.model.q_former.last_attention[0].mean(dim=0)
+
+         # Fold the patch axis into a square grid, as in visualize_attention.
+         attention = attention.mean(dim=0)
+         h = w = int(math.sqrt(attention.shape[0]))
+         attention = attention.reshape(h, w)
+
+         # Patches more than one standard deviation above the mean attention.
+         high_attention_regions = (attention > attention.mean() + attention.std()).nonzero()
+
+         print("\nResponse Validation:")
+         print(f"Image: {image_path}")
+         print(f"Diagnosis: {diagnosis}")
+         print(f"Number of high-attention regions: {len(high_attention_regions)}")
+
+         return high_attention_regions
+
+     def debug_generation(self, image_path, prompt=None):
+         """Inspect the image embeddings, then run the full generation pipeline."""
+         image = Image.open(image_path).convert('RGB')
+         image_tensor = self.transform(image).unsqueeze(0).to(self.device)
+
+         with torch.no_grad():
+             image_embeds = self.classifier.model.encode_image(image_tensor)
+
+         print("\nGeneration Debug Information:")
+         print(f"Image embedding shape: {image_embeds.shape}")
+         print(f"Image embedding mean: {image_embeds.mean().item():.4f}")
+         print(f"Image embedding std: {image_embeds.std().item():.4f}")
+
+         result = self.classifier.predict(image, user_input=prompt)
+
+         print("\nGenerated Diagnosis:")
+         print(result["diagnosis"])
+
+         return result
+
+
+ def main():
+     tester = SkinGPTTester()
+
+     # Sample images committed alongside this script.
+     test_image = "1.jpg"
+     similar_image = "2.jpg"
+
+     print("Running comprehensive tests...")
+
+     print("\n1. Visualizing attention maps...")
+     tester.visualize_attention(test_image)
+
+     print("\n2. Checking feature similarity...")
+     tester.check_feature_similarity(test_image, similar_image)
+
+     print("\n3. Debugging generation process...")
+     result = tester.debug_generation(test_image, "Describe the skin condition in detail.")
+
+     print("\n4. Validating response...")
+     tester.validate_response(test_image, result["diagnosis"])
+
+     print("\nAll tests completed!")
+
+
+ if __name__ == "__main__":
+     main()
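
The suite runs end to end with `python test.py` from the repo root. A minimal sketch of driving the checks individually from a REPL, assuming the repo's SkinGPT module is importable and the two sample images from this commit are in the working directory:

    # Hypothetical REPL session against the files in this commit.
    from test import SkinGPTTester

    tester = SkinGPTTester()
    tester.visualize_attention("1.jpg")  # writes attention_visualization.png
    score = tester.check_feature_similarity("1.jpg", "2.jpg")
    print(f"1.jpg vs 2.jpg similarity: {score:.4f}")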