File size: 4,485 Bytes
747ba73
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
#app.py
import os
import io
import uvicorn

import torch
from fastapi import FastAPI, File, UploadFile, HTTPException
from fastapi.responses import JSONResponse
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
from torchvision import models, transforms
from PIL import Image
import numpy as np
from huggingface_hub import hf_hub_download
import pydicom
import gc
from model import CombinedModel, ImageToTextProjector

from fastapi import FastAPI, Request

# Single application object; the route handlers in this module register on it.
app = FastAPI()


@app.get("/")
async def root(request: Request):
    # Static landing payload; the incoming request is accepted but not inspected.
    greeting = {"message": "Welcome to Phronesis"}
    return greeting


# Run on the GPU when torch can see one, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

def dicom_to_png(dicom_data):
    """Decode a DICOM payload into an 8-bit grayscale PIL image.

    Args:
        dicom_data: Raw DICOM content as ``bytes``/``bytearray``, or anything
            ``pydicom.dcmread`` accepts directly (a path or a binary
            file-like object).

    Returns:
        A mode ``"L"`` PIL image whose intensities are min-max scaled to
        the 0-255 range.

    Raises:
        HTTPException: 400 when the dataset carries no ``PixelData``
            element; 500 when reading or converting fails for any other
            reason.
    """
    try:
        # dcmread expects a path or a file-like object, not a bytes buffer,
        # so in-memory uploads have to be wrapped first.
        if isinstance(dicom_data, (bytes, bytearray)):
            dicom_data = io.BytesIO(dicom_data)

        dicom_file = pydicom.dcmread(dicom_data)
        if not hasattr(dicom_file, 'PixelData'):
            raise HTTPException(status_code=400, detail="No pixel data in DICOM file.")

        pixel_array = dicom_file.pixel_array.astype(np.float32)
        lowest = pixel_array.min()
        # max - min instead of ndarray.ptp(), which NumPy 2.0 removed.
        value_range = pixel_array.max() - lowest
        if value_range > 0:
            pixel_array = ((pixel_array - lowest) / value_range) * 255.0
        else:
            # Constant image: scaling would divide by zero and produce NaNs,
            # so emit a uniformly black frame instead.
            pixel_array = np.zeros_like(pixel_array)
        pixel_array = pixel_array.astype(np.uint8)

        img = Image.fromarray(pixel_array).convert("L")
        return img
    except HTTPException:
        # Keep the deliberate 400 above from being rewrapped as a 500.
        raise
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Error converting DICOM to PNG: {e}") from e

# Set up secure model initialization
# The Hugging Face token is read from the environment and is required to
# fetch the tokenizer and the checkpoint below; fail fast when it is absent.
HF_TOKEN = os.getenv('HF_TOKEN')
if not HF_TOKEN:
    raise ValueError("Missing Hugging Face token in environment variables.")

try:
    # HF_TOKEN is guaranteed non-empty by the guard above.
    report_generator_tokenizer = AutoTokenizer.from_pretrained(
        "KYAGABA/combined-multimodal-model",
        token=HF_TOKEN,
    )
    # r3d_18 video backbone whose final layer is replaced by a 512-d linear
    # layer; 512 is also the projector's input width below.
    video_model = models.video.r3d_18(weights="KINETICS400_V1")
    video_model.fc = torch.nn.Linear(video_model.fc.in_features, 512)
    report_generator = AutoModelForSeq2SeqLM.from_pretrained("GanjinZero/biobart-v2-base")
    projector = ImageToTextProjector(512, report_generator.config.d_model)
    num_classes = 4  # presumably must equal len(class_names) — confirm
    combined_model = CombinedModel(video_model, report_generator, num_classes, projector, report_generator_tokenizer)
    model_file = hf_hub_download("KYAGABA/combined-multimodal-model", "pytorch_model.bin", token=HF_TOKEN)
    # NOTE(review): torch.load unpickles the checkpoint; only safe while the
    # hub repository is trusted. Consider weights_only=True once the
    # checkpoint is confirmed to contain tensors only.
    state_dict = torch.load(model_file, map_location=device)
    combined_model.load_state_dict(state_dict)
    # load_state_dict copies into the existing parameters and does not
    # relocate the module, so move it explicitly to the device that the
    # request handler sends its input tensors to.
    combined_model.to(device)
    combined_model.eval()
except Exception as e:
    raise SystemExit(f"Error loading models: {e}") from e

# Per-frame preprocessing: resize to 112x112, convert to a tensor, then
# normalise each channel.
# NOTE(review): the mean/std values look like the Kinetics-400 statistics
# that go with the r3d_18 weights — confirm.
image_transform = transforms.Compose(
    [
        transforms.Resize((112, 112)),
        transforms.ToTensor(),
        transforms.Normalize(
            mean=[0.43216, 0.394666, 0.37645],
            std=[0.22803, 0.22145, 0.216989],
        ),
    ]
)

# Label lookup table; position i is the name reported for class index i.
class_names = [
    "acute",
    "normal",
    "chronic",
    "lacunar",
]

def _select_frames(images, n_frames):
    """Return exactly ``n_frames`` images: subsample evenly, or pad with the last one."""
    if len(images) >= n_frames:
        picks = np.linspace(0, len(images) - 1, n_frames, dtype=int)
        return [images[i] for i in picks]
    return images + [images[-1]] * (n_frames - len(images))


@app.post("/predict/")
async def predict(files: list[UploadFile]):
    """Classify an uploaded image series and generate a text report.

    Args:
        files: Uploaded frames. Supported extensions are ``dcm``/``ima``
            (decoded through ``dicom_to_png``) and ``png``/``jpeg``/``jpg``.

    Returns:
        A dict with ``predicted_class`` (one of ``class_names``) and
        ``generated_report``; or a 400 ``JSONResponse`` when no image was
        supplied.

    Raises:
        HTTPException: 400 for an unsupported or missing file extension;
            500 when a file cannot be read or decoded.
    """
    print(f"Received {len(files)} files")
    n_frames = 16
    images = []

    for file in files:
        # filename may be None for some multipart clients; treat that as an
        # unsupported type instead of crashing on .split().
        filename = file.filename or ""
        ext = filename.split('.')[-1].lower()
        try:
            if ext in ('dcm', 'ima'):
                # Wrap the bytes so pydicom receives a file-like object.
                payload = io.BytesIO(await file.read())
                dicom_img = dicom_to_png(payload)
                images.append(dicom_img.convert("RGB"))
            elif ext in ('png', 'jpeg', 'jpg'):
                img = Image.open(io.BytesIO(await file.read())).convert("RGB")
                images.append(img)
            else:
                raise HTTPException(status_code=400, detail="Unsupported file type.")
        except HTTPException:
            # Preserve deliberate status codes (e.g. the 400 above) instead
            # of rewrapping them as a generic 500.
            raise
        except Exception as e:
            raise HTTPException(
                status_code=500,
                detail=f"Error processing file {file.filename}: {e}",
            ) from e

    if not images:
        return JSONResponse(content={"error": "No valid images provided."}, status_code=400)

    images_sampled = _select_frames(images, n_frames)

    image_tensors = [image_transform(img) for img in images_sampled]
    # Stack frames on a new leading axis, swap the first two axes, then add a
    # batch axis. Assuming image_transform yields (C, H, W) tensors, the
    # result is (1, C, n_frames, H, W) — TODO confirm against CombinedModel.
    images_tensor = torch.stack(image_tensors).permute(1, 0, 2, 3).unsqueeze(0).to(device)

    try:
        with torch.no_grad():
            class_outputs, generated_report, _ = combined_model(images_tensor)
            predicted_class = torch.argmax(class_outputs, dim=1).item()
            predicted_class_name = class_names[predicted_class]
    finally:
        # Release per-request memory even when inference fails.
        gc.collect()
        if torch.cuda.is_available():
            torch.cuda.empty_cache()

    return {
        "predicted_class": predicted_class_name,
        "generated_report": generated_report[0] if generated_report else "No report generated."
    }

if __name__ == "__main__":
    # Listen on all interfaces; the port comes from $PORT, defaulting to 7860.
    listen_port = int(os.environ.get("PORT", 7860))
    uvicorn.run(app, host="0.0.0.0", port=listen_port)