Spaces:
Sleeping
Sleeping
Create app.py
Browse files
app.py
ADDED
@@ -0,0 +1,197 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import gradio as gr
|
2 |
+
from transformers import ViTImageProcessor, ViTForImageClassification, pipeline
|
3 |
+
from PIL import Image
|
4 |
+
import torch
|
5 |
+
import pandas as pd
|
6 |
+
import os
|
7 |
+
|
8 |
+
# 1. Конфигурация моделей
|
9 |
+
def _vit_classifier(repo_id):
    """Build an image entry whose processor and weights share one repo id."""
    return {
        "processor": repo_id,
        "model": repo_id,
        "type": "image_classification",
    }


def _ner_tagger(repo_id):
    """Build a text entry of type ``"ner"`` for the given repo id."""
    return {"model": repo_id, "type": "ner"}


# Registry of selectable models: section key -> UI label -> loading details.
# The insertion order of the labels is the order shown in the model dropdown.
# NOTE(review): the two text checkpoints look like base encoders rather than
# token-classification fine-tunes -- confirm they actually ship an NER head.
MODEL_CONFIG = {
    "image": {
        "Пневмония": _vit_classifier(
            "nickmuchi/vit-finetuned-chest-xray-pneumonia"
        ),
        "Опухоль мозга": _vit_classifier(
            "DunnBC22/vit-base-patch16-224-in21k_brain_tumor_diagnosis"
        ),
        "Диабетическая ретинопатия": _vit_classifier(
            "Kontawat/vit-diabetic-retinopathy-classification"
        ),
    },
    "text": {
        "NER (Bio_ClinicalBERT)": _ner_tagger("emilyalsentzer/Bio_ClinicalBERT"),
        "NER (BioBERT)": _ner_tagger("dmis-lab/biobert-v1.1"),
    },
}
|
38 |
+
|
39 |
+
# 2. Загрузка моделей
|
40 |
+
# Already-instantiated models, keyed by (analysis_type, model_name).
# The key space is bounded by the entries of MODEL_CONFIG.
_LOADED_MODELS = {}


def load_model_and_processor(analysis_type, model_name):
    """Return the ``(processor, model)`` pair for the selected model.

    Successfully loaded models are kept in a module-level cache so repeated
    requests (for example one call per CSV row) reuse the same objects
    instead of instantiating the checkpoint again.

    Args:
        analysis_type: UI label of the analysis kind, "Изображение" or "Текст".
        model_name: UI label of the model inside that section of MODEL_CONFIG.

    Returns:
        ``(ViTImageProcessor, ViTForImageClassification)`` for image models,
        ``(None, pipeline)`` for text models, and ``(None, None)`` when the
        analysis type or the model name is unknown.
    """
    cache_key = (analysis_type, model_name)
    if cache_key in _LOADED_MODELS:
        return _LOADED_MODELS[cache_key]

    if analysis_type == "Изображение":
        config = MODEL_CONFIG["image"].get(model_name)
        if not config:
            return None, None
        loaded = (
            ViTImageProcessor.from_pretrained(config["processor"]),
            ViTForImageClassification.from_pretrained(config["model"]),
        )
    elif analysis_type == "Текст":
        config = MODEL_CONFIG["text"].get(model_name)
        if not config:
            return None, None
        # The same repo id is used for both the weights and the tokenizer.
        nlp = pipeline(config["type"], model=config["model"], tokenizer=config["model"])
        loaded = (None, nlp)
    else:
        return None, None

    # Only successful loads are cached; failures above raise or return early.
    _LOADED_MODELS[cache_key] = loaded
    return loaded
|
55 |
+
|
56 |
+
# 3. Функция для классификации изображений
|
57 |
+
def classify_image(image, model_name):
    """Classify one image with the selected model and return a display string.

    Args:
        image: Image handed to the processor via ``images=``.
        model_name: UI label of the image model to use.

    Returns:
        "Результат: <label>" with the highest-scoring class label, or an
        error message when the model is missing or processing raises.
    """
    processor, model = load_model_and_processor("Изображение", model_name)
    if not (model and processor):
        return "Ошибка: Модель или процессор не найдены."

    try:
        batch = processor(images=image, return_tensors="pt")
        # Inference only: no gradients are needed.
        with torch.no_grad():
            logits = model(**batch).logits
        best_idx = logits.argmax(-1).item()
        return f"Результат: {model.config.id2label[best_idx]}"
    except Exception as exc:
        # UI boundary: report the failure as text instead of raising.
        return f"Ошибка при обработке изображения: {exc}"
|
72 |
+
|
73 |
+
# 4. Функция для обработки текста (NER)
|
74 |
+
def _merge_bio_tokens(ner_results):
    """Collapse token-level B-/I- predictions into ``(text, label)`` pairs.

    Each item of ``ner_results`` is read through its ``'word'`` and
    ``'entity'`` keys. A ``B-`` tag opens a new entity and an ``I-`` tag with
    the same label extends the open one. An ``I-`` tag with no matching open
    entity starts a new entity rather than being discarded. Any other tag
    closes the open entity.
    """
    spans = []  # mutable [text, label] pairs, in order of appearance
    open_span = None
    for token in ner_results:
        word = token["word"]
        tag = token["entity"]
        prefix, label = tag[:2], tag[2:]

        if prefix == "I-" and open_span is not None and open_span[1] == label:
            if word.startswith("##"):
                # WordPiece continuation: glue onto the previous piece.
                open_span[0] += word[2:]
            else:
                open_span[0] += " " + word
        elif prefix in ("B-", "I-"):
            open_span = [word, label]
            spans.append(open_span)
        else:
            open_span = None

    # Entities whose text ended up empty are not reported.
    return [(span_text, label) for span_text, label in spans if span_text]


def extract_entities(text, model_name):
    """Run the selected NER model on ``text`` and format the entities found.

    Args:
        text: Input text passed to the NER pipeline.
        model_name: UI label of the text model to use.

    Returns:
        One "<entity>: <label>" line per entity, "Сущности не найдены." when
        nothing was recognised, or an error message when the model is missing
        or processing raises.
    """
    _, nlp = load_model_and_processor("Текст", model_name)
    if not nlp:
        return "Ошибка: Модель не найдена."

    try:
        entities = _merge_bio_tokens(nlp(text))
        if not entities:
            return "Сущности не найдены."
        return "\n".join(f"{name}: {label}" for name, label in entities)
    except Exception as e:
        # UI boundary: report the failure as text instead of raising.
        return f"Ошибка при обработке текста: {str(e)}"
|
107 |
+
|
108 |
+
# 5. Функция для обработки CSV-файла
|
109 |
+
def process_csv(file, model_name):
    """Run NER over every row of an uploaded CSV and save the comparison.

    The CSV must contain the columns ``id``, ``text`` and ``entities``. The
    result table is written to ``ner_results.csv`` in the working directory.

    Args:
        file: Uploaded file, either a filesystem path (``str`` or
            ``os.PathLike``) or an object exposing the path as ``.name``.
        model_name: UI label of the text model to use.

    Returns:
        Always a 2-tuple. On success it is ``(table_as_text, output_path)``;
        on any failure, including missing columns, it is
        ``(error_message, None)``.
    """
    required_columns = ("id", "text", "entities")
    try:
        # Accept both a plain path and an upload wrapper carrying `.name`.
        csv_path = file if isinstance(file, (str, os.PathLike)) else file.name
        df = pd.read_csv(csv_path)
        if not all(col in df.columns for col in required_columns):
            # Must be a pair like every other return: the caller binds the
            # result to two output components.
            return "Ошибка: CSV должен содержать колонки id, text, entities", None

        results = [
            {
                "ID": row["id"],
                "Текст": row["text"],
                "Ожидаемые сущности": row["entities"],
                "Предсказанные сущности": extract_entities(row["text"], model_name),
            }
            for _, row in df.iterrows()
        ]

        results_df = pd.DataFrame(results)
        output_file = "ner_results.csv"
        results_df.to_csv(output_file, index=False)
        return results_df.to_string(), output_file
    except Exception as e:
        # UI boundary: report the failure as text instead of raising.
        return f"Ошибка при обработке CSV: {str(e)}", None
|
133 |
+
|
134 |
+
# 6. Gradio интерфейс
|
135 |
+
with gr.Blocks(fill_height=True) as demo:
    with gr.Sidebar():
        gr.Markdown("# Медицинский анализ")
        gr.Markdown("Универсальное приложение для анализа медицинских изображений и текстов. Выберите тип анализа и модель.")

    with gr.Row():
        with gr.Column():
            analysis_type = gr.Dropdown(
                choices=["Изображение", "Текст"],
                label="Тип анализа",
                value="Изображение"
            )
            model_name = gr.Dropdown(
                choices=list(MODEL_CONFIG["image"].keys()),
                label="Выберите модель",
                value="Пневмония"
            )
            image_input = gr.Image(type="pil", label="Загрузите изображение (для анализа изображений)")
            # Text and CSV inputs start hidden because the default type is image.
            text_input = gr.Textbox(label="Введите текст (для анализа текста)", visible=False)
            csv_input = gr.File(label="Загрузите CSV-файл (для анализа текста)", visible=False)
            analyze_button = gr.Button("Анализировать")

        with gr.Column():
            output = gr.Textbox(label="Результат")
            csv_output = gr.File(label="Результаты обработки CSV")

    # UI label of each analysis type -> its section key in MODEL_CONFIG.
    # The labels are Russian while the config keys are "image" / "text", so
    # lower-casing the label cannot be used as the lookup key.
    config_sections = {"Изображение": "image", "Текст": "text"}

    # Refresh the model list and input visibility when the analysis type changes.
    def update_inputs(analysis_type):
        """Return updates for the model dropdown and the three input widgets."""
        model_choices = list(MODEL_CONFIG[config_sections[analysis_type]])
        image_visible = analysis_type == "Изображение"
        text_visible = analysis_type == "Текст"
        csv_visible = analysis_type == "Текст"
        return (
            gr.update(choices=model_choices, value=model_choices[0]),
            gr.update(visible=image_visible),
            gr.update(visible=text_visible),
            gr.update(visible=csv_visible)
        )

    analysis_type.change(
        fn=update_inputs,
        inputs=analysis_type,
        outputs=[model_name, image_input, text_input, csv_input]
    )

    # Handle the "analyze" button click.
    def analyze(analysis_type, model_name, image, text, csv_file):
        """Dispatch to the matching handler and return ``(message, csv_path)``.

        For text analysis a non-empty text box takes precedence over an
        uploaded CSV file.
        """
        # Explicit None check: do not rely on the truthiness of an image object.
        if analysis_type == "Изображение" and image is not None:
            return classify_image(image, model_name), None
        elif analysis_type == "Текст" and text:
            return extract_entities(text, model_name), None
        elif analysis_type == "Текст" and csv_file:
            return process_csv(csv_file, model_name)
        return "Ошибка: Загрузите данные и выберите тип анализа.", None

    analyze_button.click(
        fn=analyze,
        inputs=[analysis_type, model_name, image_input, text_input, csv_input],
        outputs=[output, csv_output]
    )
|
195 |
+
|
196 |
+
# 7. Запуск приложения
|
197 |
+
# Start the Gradio server on all network interfaces at port 7860.
# NOTE(review): presumably required by the hosting container -- confirm.
demo.launch(server_name="0.0.0.0", server_port=7860)
|