tigrica007 commited on
Commit
cab9826
·
verified ·
1 Parent(s): 51751d9

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +197 -0
app.py ADDED
@@ -0,0 +1,197 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ from transformers import ViTImageProcessor, ViTForImageClassification, pipeline
3
+ from PIL import Image
4
+ import torch
5
+ import pandas as pd
6
+ import os
7
+
8
+ # 1. Конфигурация моделей
9
+ MODEL_CONFIG = {
10
+ "image": {
11
+ "Пневмония": {
12
+ "processor": "nickmuchi/vit-finetuned-chest-xray-pneumonia",
13
+ "model": "nickmuchi/vit-finetuned-chest-xray-pneumonia",
14
+ "type": "image_classification"
15
+ },
16
+ "Опухоль мозга": {
17
+ "processor": "DunnBC22/vit-base-patch16-224-in21k_brain_tumor_diagnosis",
18
+ "model": "DunnBC22/vit-base-patch16-224-in21k_brain_tumor_diagnosis",
19
+ "type": "image_classification"
20
+ },
21
+ "Диабетическая ретинопатия": {
22
+ "processor": "Kontawat/vit-diabetic-retinopathy-classification",
23
+ "model": "Kontawat/vit-diabetic-retinopathy-classification",
24
+ "type": "image_classification"
25
+ }
26
+ },
27
+ "text": {
28
+ "NER (Bio_ClinicalBERT)": {
29
+ "model": "emilyalsentzer/Bio_ClinicalBERT",
30
+ "type": "ner"
31
+ },
32
+ "NER (BioBERT)": {
33
+ "model": "dmis-lab/biobert-v1.1",
34
+ "type": "ner"
35
+ }
36
+ }
37
+ }
38
+
39
+ # 2. Загрузка моделей
40
+ def load_model_and_processor(analysis_type, model_name):
41
+ if analysis_type == "Изображение":
42
+ config = MODEL_CONFIG["image"].get(model_name)
43
+ if not config:
44
+ return None, None
45
+ processor = ViTImageProcessor.from_pretrained(config["processor"])
46
+ model = ViTForImageClassification.from_pretrained(config["model"])
47
+ return processor, model
48
+ elif analysis_type == "Текст":
49
+ config = MODEL_CONFIG["text"].get(model_name)
50
+ if not config:
51
+ return None, None
52
+ nlp = pipeline(config["type"], model=config["model"], tokenizer=config["model"])
53
+ return None, nlp
54
+ return None, None
55
+
56
+ # 3. Функция для классификации изображений
57
+ def classify_image(image, model_name):
58
+ processor, model = load_model_and_processor("Изображение", model_name)
59
+ if not model or not processor:
60
+ return "Ошибка: Модель или процессор не найдены."
61
+
62
+ try:
63
+ inputs = processor(images=image, return_tensors="pt")
64
+ with torch.no_grad():
65
+ outputs = model(**inputs)
66
+ logits = outputs.logits
67
+ predicted_class_idx = logits.argmax(-1).item()
68
+ predicted_class = model.config.id2label[predicted_class_idx]
69
+ return f"Результат: {predicted_class}"
70
+ except Exception as e:
71
+ return f"Ошибка при обработке изображения: {str(e)}"
72
+
73
+ # 4. Функция для обработки текста (NER)
74
+ def extract_entities(text, model_name):
75
+ _, nlp = load_model_and_processor("Текст", model_name)
76
+ if not nlp:
77
+ return "Ошибка: Модель не найдена."
78
+
79
+ try:
80
+ ner_results = nlp(text)
81
+ entities = []
82
+ current_entity = ""
83
+ current_label = ""
84
+
85
+ for result in ner_results:
86
+ word = result['word']
87
+ entity = result['entity']
88
+ if entity.startswith('B-'):
89
+ if current_entity:
90
+ entities.append((current_entity, current_label))
91
+ current_entity = word
92
+ current_label = entity[2:]
93
+ elif entity.startswith('I-') and current_label == entity[2:]:
94
+ current_entity += " " + word
95
+ else:
96
+ if current_entity:
97
+ entities.append((current_entity, current_label))
98
+ current_entity = ""
99
+ current_label = ""
100
+
101
+ if current_entity:
102
+ entities.append((current_entity, current_label))
103
+
104
+ return "\n".join([f"{entity[0]}: {entity[1]}" for entity in entities]) if entities else "Сущности не найдены."
105
+ except Exception as e:
106
+ return f"Ошибка при обработке текста: {str(e)}"
107
+
108
+ # 5. Функция для обработки CSV-файла
109
+ def process_csv(file, model_name):
110
+ try:
111
+ df = pd.read_csv(file.name)
112
+ if not all(col in df.columns for col in ['id', 'text', 'entities']):
113
+ return "Ошибка: CSV должен содержать колонки id, text, entities"
114
+
115
+ results = []
116
+ for _, row in df.iterrows():
117
+ text = row['text']
118
+ true_entities = row['entities']
119
+ predicted_entities = extract_entities(text, model_name)
120
+ results.append({
121
+ "ID": row['id'],
122
+ "Текст": text,
123
+ "Ожидаемые сущности": true_entities,
124
+ "Предсказанные сущности": predicted_entities
125
+ })
126
+
127
+ results_df = pd.DataFrame(results)
128
+ output_file = "ner_results.csv"
129
+ results_df.to_csv(output_file, index=False)
130
+ return results_df.to_string(), output_file
131
+ except Exception as e:
132
+ return f"Ошибка при обработке CSV: {str(e)}", None
133
+
134
+ # 6. Gradio интерфейс
135
+ with gr.Blocks(fill_height=True) as demo:
136
+ with gr.Sidebar():
137
+ gr.Markdown("# Медицинский анализ")
138
+ gr.Markdown("Универсальное приложение для анализа медицинских изображений и текстов. Выберите тип анализа и модель.")
139
+
140
+ with gr.Row():
141
+ with gr.Column():
142
+ analysis_type = gr.Dropdown(
143
+ choices=["Изображение", "Текст"],
144
+ label="Тип анализа",
145
+ value="Изображение"
146
+ )
147
+ model_name = gr.Dropdown(
148
+ choices=list(MODEL_CONFIG["image"].keys()),
149
+ label="Выберите модель",
150
+ value="Пневмония"
151
+ )
152
+ image_input = gr.Image(type="pil", label="Загрузите изображение (для анализа изображений)")
153
+ text_input = gr.Textbox(label="Введите текст (для анализа текста)", visible=False)
154
+ csv_input = gr.File(label="Загрузите CSV-файл (для анализа текста)", visible=False)
155
+ analyze_button = gr.Button("Анализировать")
156
+
157
+ with gr.Column():
158
+ output = gr.Textbox(label="Результат")
159
+ csv_output = gr.File(label="Результаты обработки CSV")
160
+
161
+ # Динамическое обновление моделей и видимости входов
162
+ def update_inputs(analysis_type):
163
+ model_choices = list(MODEL_CONFIG[analysis_type.lower()].keys())
164
+ image_visible = analysis_type == "Изображение"
165
+ text_visible = analysis_type == "Текст"
166
+ csv_visible = analysis_type == "Текст"
167
+ return (
168
+ gr.update(choices=model_choices, value=model_choices[0]),
169
+ gr.update(visible=image_visible),
170
+ gr.update(visible=text_visible),
171
+ gr.update(visible=csv_visible)
172
+ )
173
+
174
+ analysis_type.change(
175
+ fn=update_inputs,
176
+ inputs=analysis_type,
177
+ outputs=[model_name, image_input, text_input, csv_input]
178
+ )
179
+
180
+ # Обработка нажатия кнопки
181
+ def analyze(analysis_type, model_name, image, text, csv_file):
182
+ if analysis_type == "Изображение" and image:
183
+ return classify_image(image, model_name), None
184
+ elif analysis_type == "Текст" and text:
185
+ return extract_entities(text, model_name), None
186
+ elif analysis_type == "Текст" and csv_file:
187
+ return process_csv(csv_file, model_name)
188
+ return "Ошибка: Загрузите данные и выберите тип анализа.", None
189
+
190
+ analyze_button.click(
191
+ fn=analyze,
192
+ inputs=[analysis_type, model_name, image_input, text_input, csv_input],
193
+ outputs=[output, csv_output]
194
+ )
195
+
196
+ # 7. Запуск приложения
197
+ demo.launch(server_name="0.0.0.0", server_port=7860)