File size: 4,419 Bytes
fe36943
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
import gradio as gr
from transformers import AutoImageProcessor, AutoModelForImageClassification, AutoModelForCausalLM, AutoTokenizer
from PIL import Image
import torch
import json

class SinogramAnalysisSystem:
    """End-to-end sinogram analysis pipeline.

    Combines two ViT image classifiers (tumor presence and tumor radius)
    with the Hymba-1.5B causal LM, which turns the classifier outputs into
    a short narrative medical report.
    """

    def __init__(self):
        print("Initializing system...")
        # Prefer GPU when available; all three models are moved to this device.
        self.device = "cuda" if torch.cuda.is_available() else "cpu"

        # Load the two ViT-based analysis models and their matching processors.
        print("Loading tumor detection models...")
        self.tumor_classifier = AutoModelForImageClassification.from_pretrained(
            "SIATCN/vit_tumor_classifier"
        ).to(self.device)
        self.tumor_processor = AutoImageProcessor.from_pretrained("SIATCN/vit_tumor_classifier")

        self.size_classifier = AutoModelForImageClassification.from_pretrained(
            "SIATCN/vit_tumor_radius_detection_finetuned"
        ).to(self.device)
        self.size_processor = AutoImageProcessor.from_pretrained(
            "SIATCN/vit_tumor_radius_detection_finetuned"
        )

        # Load the Hymba language model used for report generation.
        # trust_remote_code is required by this repo's custom modeling code.
        print("Loading Hymba model...")
        repo_name = "nvidia/Hymba-1.5B-Base"
        self.tokenizer = AutoTokenizer.from_pretrained(repo_name, trust_remote_code=True)
        self.llm = AutoModelForCausalLM.from_pretrained(repo_name, trust_remote_code=True)
        self.llm = self.llm.to(self.device).to(torch.bfloat16)

        print("System ready!")

    @staticmethod
    def _format_size(tumor_size):
        """Render the size-classifier label for display.

        The size model emits one of "no-tumor", "0.5", "1.0", "1.5"; only the
        numeric labels are centimeter values, so appending " cm" to "no-tumor"
        (as the original code did) produced the artifact "no-tumor cm".
        """
        if tumor_size == "no-tumor":
            return "N/A (no tumor detected)"
        return f"{tumor_size} cm"

    def process_sinogram(self, image):
        """Normalize input to a 224x224 RGB PIL image.

        Accepts either a filesystem path or a PIL image (Gradio supplies the
        latter). 224x224 matches the ViT models' expected input resolution.
        """
        if isinstance(image, str):
            image = Image.open(image)
        if image.mode != 'RGB':
            image = image.convert('RGB')
        return image.resize((224, 224))

    @torch.no_grad()
    def analyze_sinogram(self, processed_image):
        """Run both classifiers; return (has_tumor: bool, tumor_size: str)."""
        # Tumor presence: index 1 is assumed to be the "tumor" class —
        # TODO confirm against the model card's id2label mapping.
        inputs = self.tumor_processor(processed_image, return_tensors="pt").to(self.device)
        outputs = self.tumor_classifier(**inputs)
        tumor_present = outputs.logits.softmax(dim=-1)[0].cpu()
        # Fix: return a plain Python bool, not a 0-dim torch tensor.
        has_tumor = bool(tumor_present[1] > tumor_present[0])

        # Tumor radius: map argmax over four classes to its label.
        size_inputs = self.size_processor(processed_image, return_tensors="pt").to(self.device)
        size_outputs = self.size_classifier(**size_inputs)
        size_pred = size_outputs.logits.softmax(dim=-1)[0].cpu()
        sizes = ["no-tumor", "0.5", "1.0", "1.5"]
        tumor_size = sizes[size_pred.argmax().item()]

        return has_tumor, tumor_size

    def generate_report(self, tumor_present, tumor_size):
        """Ask the Hymba LM for a short clinical interpretation of the findings."""
        prompt = f"""As a medical professional, provide a brief analysis of these sinogram findings:

Findings:
- Tumor Detection: {'Positive' if tumor_present else 'Negative'}
- Tumor Size: {self._format_size(tumor_size)}

Please provide:
1. Brief interpretation
2. Clinical recommendations
3. Follow-up plan"""

        # Generate response using Hymba.
        inputs = self.tokenizer(prompt, return_tensors="pt").to(self.device)
        # Fix: max_new_tokens bounds only the generated continuation;
        # max_length counted the prompt too and could starve the output.
        outputs = self.llm.generate(
            **inputs,
            max_new_tokens=512,
            do_sample=True,
            temperature=0.7,
            use_cache=True
        )

        # Decode only the continuation, skipping the echoed prompt tokens.
        response = self.tokenizer.decode(
            outputs[0][inputs['input_ids'].shape[1]:],
            skip_special_tokens=True
        )

        return response.strip()

    def analyze_image(self, image):
        """Gradio entry point: image in, formatted analysis text out."""
        try:
            # Process sinogram
            processed = self.process_sinogram(image)
            tumor_present, tumor_size = self.analyze_sinogram(processed)

            # Generate medical report
            report = self.generate_report(tumor_present, tumor_size)

            # Format results
            return f"""
SINOGRAM ANALYSIS:
• Tumor Detection: {'Positive' if tumor_present else 'Negative'}
• Size Assessment: {self._format_size(tumor_size)}

MEDICAL REPORT:
{report}
"""
        except Exception as e:
            # Surface the error in the UI textbox rather than crashing the app.
            return f"Error during analysis: {str(e)}"

def create_interface():
    """Build and return the Gradio Interface for the sinogram analyzer."""
    analysis_system = SinogramAnalysisSystem()

    image_input = gr.Image(type="pil", label="Upload Sinogram")
    results_output = gr.Textbox(label="Analysis Results", lines=15)

    return gr.Interface(
        fn=analysis_system.analyze_image,
        inputs=[image_input],
        outputs=[results_output],
        title="Sinogram Analysis System",
        description="Upload a sinogram for tumor detection and medical assessment.",
    )

if __name__ == "__main__":
    # Build the app and serve it; share=True exposes a public Gradio link.
    create_interface().launch(debug=True, share=True)