KeerthiVM commited on
Commit
975c276
·
1 Parent(s): 7e973b1
Files changed (2) hide show
  1. SkinGPT.py +7 -2
  2. app.py +0 -7
SkinGPT.py CHANGED
@@ -66,9 +66,13 @@ class Blip2QFormer(nn.Module):
66
  msg = self.load_state_dict(state_dict, strict=False)
67
 
68
  def forward(self, visual_features):
 
 
 
 
69
  # Project visual features
70
- with autocast(enabled=False):
71
- visual_embeds = self.vision_proj(visual_features.float())
72
  # visual_embeds = self.vision_proj(visual_features.float())
73
  visual_attention_mask = torch.ones(
74
  visual_embeds.size()[:-1],
@@ -215,6 +219,7 @@ class SkinGPT4(nn.Module):
215
  x = blk(x)
216
  x = self.vit.norm(x)
217
  vit_features = self.ln_vision(x)
 
218
 
219
  # Q-Former forward pass
220
  with torch.no_grad():
 
66
  msg = self.load_state_dict(state_dict, strict=False)
67
 
68
  def forward(self, visual_features):
69
+
70
+ print(
71
+ f"Visual features stats - min: {visual_features.min().item():.4f}, max: {visual_features.max().item():.4f}")
72
+
73
  # Project visual features
74
+ visual_embeds = self.vision_proj(visual_features.float())
75
+ print(f"Projected embeds stats - min: {visual_embeds.min().item():.4f}, max: {visual_embeds.max().item():.4f}")
76
  # visual_embeds = self.vision_proj(visual_features.float())
77
  visual_attention_mask = torch.ones(
78
  visual_embeds.size()[:-1],
 
219
  x = blk(x)
220
  x = self.vit.norm(x)
221
  vit_features = self.ln_vision(x)
222
+ print(f"vit features (first 5): {vit_features[0, 0, :5]}")
223
 
224
  # Q-Former forward pass
225
  with torch.no_grad():
app.py CHANGED
@@ -6,12 +6,6 @@ torch.manual_seed(42)
6
  random.seed(42)
7
  np.random.seed(42)
8
 
9
- torch.backends.cudnn.deterministic = True
10
- torch.backends.cudnn.benchmark = False
11
-
12
- if torch.cuda.is_available():
13
- torch.use_deterministic_algorithms(True)
14
- torch.cuda.manual_seed_all(42)
15
 
16
  import streamlit as st
17
  import io
@@ -49,7 +43,6 @@ warnings.filterwarnings("ignore")
49
 
50
  # @st.cache_resource
51
  def get_classifier():
52
- torch.use_deterministic_algorithms(True)
53
  classifier = SkinGPTClassifier()
54
  for module in [classifier.model.vit,
55
  classifier.model.q_former,
 
6
  random.seed(42)
7
  np.random.seed(42)
8
 
 
 
 
 
 
 
9
 
10
  import streamlit as st
11
  import io
 
43
 
44
  # @st.cache_resource
45
  def get_classifier():
 
46
  classifier = SkinGPTClassifier()
47
  for module in [classifier.model.vit,
48
  classifier.model.q_former,