fix added
Browse files- SkinGPT.py +7 -2
- app.py +0 -7
SkinGPT.py
CHANGED
@@ -66,9 +66,13 @@ class Blip2QFormer(nn.Module):
|
|
66 |
msg = self.load_state_dict(state_dict, strict=False)
|
67 |
|
68 |
def forward(self, visual_features):
|
|
|
|
|
|
|
|
|
69 |
# Project visual features
|
70 |
-
|
71 |
-
|
72 |
# visual_embeds = self.vision_proj(visual_features.float())
|
73 |
visual_attention_mask = torch.ones(
|
74 |
visual_embeds.size()[:-1],
|
@@ -215,6 +219,7 @@ class SkinGPT4(nn.Module):
|
|
215 |
x = blk(x)
|
216 |
x = self.vit.norm(x)
|
217 |
vit_features = self.ln_vision(x)
|
|
|
218 |
|
219 |
# Q-Former forward pass
|
220 |
with torch.no_grad():
|
|
|
66 |
msg = self.load_state_dict(state_dict, strict=False)
|
67 |
|
68 |
def forward(self, visual_features):
|
69 |
+
|
70 |
+
print(
|
71 |
+
f"Visual features stats - min: {visual_features.min().item():.4f}, max: {visual_features.max().item():.4f}")
|
72 |
+
|
73 |
# Project visual features
|
74 |
+
visual_embeds = self.vision_proj(visual_features.float())
|
75 |
+
print(f"Projected embeds stats - min: {visual_embeds.min().item():.4f}, max: {visual_embeds.max().item():.4f}")
|
76 |
# visual_embeds = self.vision_proj(visual_features.float())
|
77 |
visual_attention_mask = torch.ones(
|
78 |
visual_embeds.size()[:-1],
|
|
|
219 |
x = blk(x)
|
220 |
x = self.vit.norm(x)
|
221 |
vit_features = self.ln_vision(x)
|
222 |
+
print(f"vit features (first 5): {vit_features[0, 0, :5]}")
|
223 |
|
224 |
# Q-Former forward pass
|
225 |
with torch.no_grad():
|
app.py
CHANGED
@@ -6,12 +6,6 @@ torch.manual_seed(42)
|
|
6 |
random.seed(42)
|
7 |
np.random.seed(42)
|
8 |
|
9 |
-
torch.backends.cudnn.deterministic = True
|
10 |
-
torch.backends.cudnn.benchmark = False
|
11 |
-
|
12 |
-
if torch.cuda.is_available():
|
13 |
-
torch.use_deterministic_algorithms(True)
|
14 |
-
torch.cuda.manual_seed_all(42)
|
15 |
|
16 |
import streamlit as st
|
17 |
import io
|
@@ -49,7 +43,6 @@ warnings.filterwarnings("ignore")
|
|
49 |
|
50 |
# @st.cache_resource
|
51 |
def get_classifier():
|
52 |
-
torch.use_deterministic_algorithms(True)
|
53 |
classifier = SkinGPTClassifier()
|
54 |
for module in [classifier.model.vit,
|
55 |
classifier.model.q_former,
|
|
|
6 |
random.seed(42)
|
7 |
np.random.seed(42)
|
8 |
|
|
|
|
|
|
|
|
|
|
|
|
|
9 |
|
10 |
import streamlit as st
|
11 |
import io
|
|
|
43 |
|
44 |
# @st.cache_resource
|
45 |
def get_classifier():
|
|
|
46 |
classifier = SkinGPTClassifier()
|
47 |
for module in [classifier.model.vit,
|
48 |
classifier.model.q_former,
|