ZDPLI commited on
Commit
981adbd
·
verified ·
1 Parent(s): a420201

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +20 -17
app.py CHANGED
@@ -8,6 +8,7 @@ from transformers import ViTForImageClassification
8
  from torch import nn
9
  from torch.cuda.amp import autocast
10
  import os
 
11
 
12
  # Global configuration
13
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
@@ -19,7 +20,7 @@ label_mapping = {
19
  1: "Меланоцитарный невус",
20
  2: "Базальноклеточная карцинома",
21
  3: "Актинический кератоз",
22
- 4: "Доброкачественная кератоза",
23
  5: "Дерматофиброма",
24
  6: "Сосудистые поражения"
25
  }
@@ -88,7 +89,9 @@ class ModelHandler:
88
  return {"error": "Модели не загружены"}
89
 
90
  inputs = transform_image(image)
91
- with autocast():
 
 
92
  outputs = self.efficientnet(inputs)
93
  probs = torch.nn.functional.softmax(outputs, dim=1)
94
 
@@ -100,8 +103,9 @@ class ModelHandler:
100
  return {"error": "Модели не загружены"}
101
 
102
  inputs = transform_image(image)
103
- with autocast():
104
- outputs = self.deit(inputs).logits
 
105
  probs = torch.nn.functional.softmax(outputs, dim=1)
106
 
107
  return self._format_predictions(probs)
@@ -112,23 +116,22 @@ class ModelHandler:
112
  return {"error": "Модели не загружены"}
113
 
114
  inputs = transform_image(image)
115
- with autocast():
 
116
  eff_probs = torch.nn.functional.softmax(self.efficientnet(inputs), dim=1)
117
- deit_probs = torch.nn.functional.softmax(self.deit(inputs).logits, dim=1)
118
  ensemble_probs = (eff_probs + deit_probs) / 2
119
 
120
  return self._format_predictions(ensemble_probs)
121
 
122
- def _format_predictions(self, probs):
123
- top5_probs, top5_indices = torch.topk(probs, 5)
124
- result = {}
125
- for i in range(5):
126
- idx = top5_indices[0][i].item()
127
- label = label_mapping.get(idx, f"Класс {idx}")
128
- # return raw prob, not percent:
129
- result[label] = float(top5_probs[0][i].item())
130
- return result
131
-
132
 
133
  # Initialize model handler
134
  model_handler = ModelHandler()
@@ -184,4 +187,4 @@ def create_interface():
184
  if __name__ == "__main__":
185
  interface = create_interface()
186
  print("🚀 Запуск интерфейса...")
187
- interface.launch(share=True)
 
8
  from torch import nn
9
  from torch.cuda.amp import autocast
10
  import os
11
+ from contextlib import nullcontext
12
 
13
  # Global configuration
14
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
 
20
  1: "Меланоцитарный невус",
21
  2: "Базальноклеточная карцинома",
22
  3: "Актинический кератоз",
23
+ 4: "Доброкачественная кератома",
24
  5: "Дерматофиброма",
25
  6: "Сосудистые поражения"
26
  }
 
89
  return {"error": "Модели не загружены"}
90
 
91
  inputs = transform_image(image)
92
+ # Handle autocast based on device
93
+ ctx = autocast() if device.type == 'cuda' else nullcontext()
94
+ with ctx:
95
  outputs = self.efficientnet(inputs)
96
  probs = torch.nn.functional.softmax(outputs, dim=1)
97
 
 
103
  return {"error": "Модели не загружены"}
104
 
105
  inputs = transform_image(image)
106
+ ctx = autocast() if device.type == 'cuda' else nullcontext()
107
+ with ctx:
108
+ outputs = self.deit(pixel_values=inputs).logits # Corrected parameter
109
  probs = torch.nn.functional.softmax(outputs, dim=1)
110
 
111
  return self._format_predictions(probs)
 
116
  return {"error": "Модели не загружены"}
117
 
118
  inputs = transform_image(image)
119
+ ctx = autocast() if device.type == 'cuda' else nullcontext()
120
+ with ctx:
121
  eff_probs = torch.nn.functional.softmax(self.efficientnet(inputs), dim=1)
122
+ deit_probs = torch.nn.functional.softmax(self.deit(pixel_values=inputs).logits, dim=1)
123
  ensemble_probs = (eff_probs + deit_probs) / 2
124
 
125
  return self._format_predictions(ensemble_probs)
126
 
127
+ def _format_predictions(self, probs): # Corrected indentation
128
+ top5_probs, top5_indices = torch.topk(probs, 5)
129
+ result = {}
130
+ for i in range(5):
131
+ idx = top5_indices[0][i].item()
132
+ label = label_mapping.get(idx, f"Класс {idx}")
133
+ result[label] = float(top5_probs[0][i].item())
134
+ return result
 
 
135
 
136
  # Initialize model handler
137
  model_handler = ModelHandler()
 
187
  if __name__ == "__main__":
188
  interface = create_interface()
189
  print("🚀 Запуск интерфейса...")
190
+ interface.launch(server_port=7860) # Explicitly set port if needed