AliArshad committed
Commit fe36943 · verified · 1 Parent(s): 70df9b3

Create app.py

Files changed (1)
  1. app.py +128 -0
app.py ADDED
@@ -0,0 +1,128 @@
+ import gradio as gr
+ from transformers import AutoImageProcessor, AutoModelForImageClassification, AutoModelForCausalLM, AutoTokenizer
+ from PIL import Image
+ import torch
+ import json
+
+ class SinogramAnalysisSystem:
+     def __init__(self):
+         print("Initializing system...")
+         self.device = "cuda" if torch.cuda.is_available() else "cpu"
+
+         # Load analysis models
+         print("Loading tumor detection models...")
+         self.tumor_classifier = AutoModelForImageClassification.from_pretrained(
+             "SIATCN/vit_tumor_classifier"
+         ).to(self.device)
+         self.tumor_processor = AutoImageProcessor.from_pretrained("SIATCN/vit_tumor_classifier")
+
+         self.size_classifier = AutoModelForImageClassification.from_pretrained(
+             "SIATCN/vit_tumor_radius_detection_finetuned"
+         ).to(self.device)
+         self.size_processor = AutoImageProcessor.from_pretrained(
+             "SIATCN/vit_tumor_radius_detection_finetuned"
+         )
+
+         # Load Hymba model
+         print("Loading Hymba model...")
+         repo_name = "nvidia/Hymba-1.5B-Base"
+         self.tokenizer = AutoTokenizer.from_pretrained(repo_name, trust_remote_code=True)
+         self.llm = AutoModelForCausalLM.from_pretrained(repo_name, trust_remote_code=True)
+         self.llm = self.llm.to(self.device).to(torch.bfloat16)
+
+         print("System ready!")
+
+     def process_sinogram(self, image):
+         if isinstance(image, str):
+             image = Image.open(image)
+         if image.mode != 'RGB':
+             image = image.convert('RGB')
+         return image.resize((224, 224))
+
+     @torch.no_grad()
+     def analyze_sinogram(self, processed_image):
+         # Detect tumor
+         inputs = self.tumor_processor(processed_image, return_tensors="pt").to(self.device)
+         outputs = self.tumor_classifier(**inputs)
+         tumor_present = outputs.logits.softmax(dim=-1)[0].cpu()
+         has_tumor = tumor_present[1] > tumor_present[0]
+
+         # Assess size
+         size_inputs = self.size_processor(processed_image, return_tensors="pt").to(self.device)
+         size_outputs = self.size_classifier(**size_inputs)
+         size_pred = size_outputs.logits.softmax(dim=-1)[0].cpu()
+         sizes = ["no-tumor", "0.5", "1.0", "1.5"]
+         tumor_size = sizes[size_pred.argmax().item()]
+
+         return has_tumor, tumor_size
+
+     def generate_report(self, tumor_present, tumor_size):
+         prompt = f"""As a medical professional, provide a brief analysis of these sinogram findings:
+
+ Findings:
+ - Tumor Detection: {'Positive' if tumor_present else 'Negative'}
+ - Tumor Size: {tumor_size} cm
+
+ Please provide:
+ 1. Brief interpretation
+ 2. Clinical recommendations
+ 3. Follow-up plan"""
+
+         # Generate response using Hymba
+         inputs = self.tokenizer(prompt, return_tensors="pt").to(self.device)
+         outputs = self.llm.generate(
+             **inputs,
+             max_length=512,
+             do_sample=True,
+             temperature=0.7,
+             use_cache=True
+         )
+
+         response = self.tokenizer.decode(
+             outputs[0][inputs['input_ids'].shape[1]:],
+             skip_special_tokens=True
+         )
+
+         return response.strip()
+
+     def analyze_image(self, image):
+         try:
+             # Process sinogram
+             processed = self.process_sinogram(image)
+             tumor_present, tumor_size = self.analyze_sinogram(processed)
+
+             # Generate medical report
+             report = self.generate_report(tumor_present, tumor_size)
+
+             # Format results
+             return f"""
+ SINOGRAM ANALYSIS:
+ • Tumor Detection: {'Positive' if tumor_present else 'Negative'}
+ • Size Assessment: {tumor_size} cm
+
+ MEDICAL REPORT:
+ {report}
+ """
+         except Exception as e:
+             return f"Error during analysis: {str(e)}"
+
+ def create_interface():
+     system = SinogramAnalysisSystem()
+
+     iface = gr.Interface(
+         fn=system.analyze_image,
+         inputs=[
+             gr.Image(type="pil", label="Upload Sinogram")
+         ],
+         outputs=[
+             gr.Textbox(label="Analysis Results", lines=15)
+         ],
+         title="Sinogram Analysis System",
+         description="Upload a sinogram for tumor detection and medical assessment."
+     )
+
+     return iface
+
+ if __name__ == "__main__":
+     interface = create_interface()
+     interface.launch(debug=True, share=True)
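
For reference, a minimal sketch of how the committed class could be exercised outside the Gradio UI, assuming the file is saved as app.py on the Python path and a local test image exists (the filename example_sinogram.png is hypothetical):

    from app import SinogramAnalysisSystem   # hypothetical import; assumes this commit's file is saved as app.py
    from PIL import Image

    system = SinogramAnalysisSystem()         # loads both ViT classifiers and the Hymba model
    result = system.analyze_image(Image.open("example_sinogram.png"))  # hypothetical local sinogram image
    print(result)

Running python app.py directly instead launches the Gradio interface defined in create_interface().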