Spaces:
Sleeping
Sleeping
Update app.py
Browse files
app.py
CHANGED
@@ -8,6 +8,7 @@ from transformers import ViTForImageClassification
|
|
8 |
from torch import nn
|
9 |
from torch.cuda.amp import autocast
|
10 |
import os
|
|
|
11 |
|
12 |
# Global configuration
|
13 |
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
@@ -19,7 +20,7 @@ label_mapping = {
|
|
19 |
1: "Меланоцитарный невус",
|
20 |
2: "Базальноклеточная карцинома",
|
21 |
3: "Актинический кератоз",
|
22 |
-
4: "Доброкачественная
|
23 |
5: "Дерматофиброма",
|
24 |
6: "Сосудистые поражения"
|
25 |
}
|
@@ -88,7 +89,9 @@ class ModelHandler:
|
|
88 |
return {"error": "Модели не загружены"}
|
89 |
|
90 |
inputs = transform_image(image)
|
91 |
-
|
|
|
|
|
92 |
outputs = self.efficientnet(inputs)
|
93 |
probs = torch.nn.functional.softmax(outputs, dim=1)
|
94 |
|
@@ -100,8 +103,9 @@ class ModelHandler:
|
|
100 |
return {"error": "Модели не загружены"}
|
101 |
|
102 |
inputs = transform_image(image)
|
103 |
-
|
104 |
-
|
|
|
105 |
probs = torch.nn.functional.softmax(outputs, dim=1)
|
106 |
|
107 |
return self._format_predictions(probs)
|
@@ -112,23 +116,22 @@ class ModelHandler:
|
|
112 |
return {"error": "Модели не загружены"}
|
113 |
|
114 |
inputs = transform_image(image)
|
115 |
-
|
|
|
116 |
eff_probs = torch.nn.functional.softmax(self.efficientnet(inputs), dim=1)
|
117 |
-
deit_probs = torch.nn.functional.softmax(self.deit(inputs).logits, dim=1)
|
118 |
ensemble_probs = (eff_probs + deit_probs) / 2
|
119 |
|
120 |
return self._format_predictions(ensemble_probs)
|
121 |
|
122 |
-
|
123 |
-
|
124 |
-
|
125 |
-
|
126 |
-
|
127 |
-
|
128 |
-
|
129 |
-
|
130 |
-
return result
|
131 |
-
|
132 |
|
133 |
# Initialize model handler
|
134 |
model_handler = ModelHandler()
|
@@ -184,4 +187,4 @@ def create_interface():
|
|
184 |
if __name__ == "__main__":
|
185 |
interface = create_interface()
|
186 |
print("🚀 Запуск интерфейса...")
|
187 |
-
interface.launch(
|
|
|
8 |
from torch import nn
|
9 |
from torch.cuda.amp import autocast
|
10 |
import os
|
11 |
+
from contextlib import nullcontext
|
12 |
|
13 |
# Global configuration
|
14 |
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
|
|
20 |
1: "Меланоцитарный невус",
|
21 |
2: "Базальноклеточная карцинома",
|
22 |
3: "Актинический кератоз",
|
23 |
+
4: "Доброкачественная кератома",
|
24 |
5: "Дерматофиброма",
|
25 |
6: "Сосудистые поражения"
|
26 |
}
|
|
|
89 |
return {"error": "Модели не загружены"}
|
90 |
|
91 |
inputs = transform_image(image)
|
92 |
+
# Handle autocast based on device
|
93 |
+
ctx = autocast() if device.type == 'cuda' else nullcontext()
|
94 |
+
with ctx:
|
95 |
outputs = self.efficientnet(inputs)
|
96 |
probs = torch.nn.functional.softmax(outputs, dim=1)
|
97 |
|
|
|
103 |
return {"error": "Модели не загружены"}
|
104 |
|
105 |
inputs = transform_image(image)
|
106 |
+
ctx = autocast() if device.type == 'cuda' else nullcontext()
|
107 |
+
with ctx:
|
108 |
+
outputs = self.deit(pixel_values=inputs).logits # Corrected parameter
|
109 |
probs = torch.nn.functional.softmax(outputs, dim=1)
|
110 |
|
111 |
return self._format_predictions(probs)
|
|
|
116 |
return {"error": "Модели не загружены"}
|
117 |
|
118 |
inputs = transform_image(image)
|
119 |
+
ctx = autocast() if device.type == 'cuda' else nullcontext()
|
120 |
+
with ctx:
|
121 |
eff_probs = torch.nn.functional.softmax(self.efficientnet(inputs), dim=1)
|
122 |
+
deit_probs = torch.nn.functional.softmax(self.deit(pixel_values=inputs).logits, dim=1)
|
123 |
ensemble_probs = (eff_probs + deit_probs) / 2
|
124 |
|
125 |
return self._format_predictions(ensemble_probs)
|
126 |
|
127 |
+
def _format_predictions(self, probs): # Corrected indentation
|
128 |
+
top5_probs, top5_indices = torch.topk(probs, 5)
|
129 |
+
result = {}
|
130 |
+
for i in range(5):
|
131 |
+
idx = top5_indices[0][i].item()
|
132 |
+
label = label_mapping.get(idx, f"Класс {idx}")
|
133 |
+
result[label] = float(top5_probs[0][i].item())
|
134 |
+
return result
|
|
|
|
|
135 |
|
136 |
# Initialize model handler
|
137 |
model_handler = ModelHandler()
|
|
|
187 |
if __name__ == "__main__":
|
188 |
interface = create_interface()
|
189 |
print("🚀 Запуск интерфейса...")
|
190 |
+
interface.launch(server_port=7860) # Explicitly set port if needed
|