KYAGABA committed
Commit 747ba73 · 1 Parent(s): 3f5a760
Files changed (6)
  1. Dockerfile +16 -0
  2. Phronesis +0 -1
  3. README.md +13 -0
  4. app.py +121 -0
  5. model.py +56 -0
  6. requirements.txt +16 -0
Dockerfile ADDED
@@ -0,0 +1,16 @@
+ # Use an official Python runtime as a base image
+ FROM python:3.9-slim
+
+ # Set the working directory
+ WORKDIR /app
+
+ # Copy the current directory contents into the container at /app
+ COPY . /app
+
+ # Install the dependencies
+ RUN pip install --no-cache-dir -r requirements.txt
+
+ # Expose port 7860
+ EXPOSE 7860
+ # CMD ["uvicorn", "main:app", "--host", "0.0.0.0", "--port", "7860"]
+ CMD ["uvicorn", "app:app", "--host", "0.0.0.0", "--port", "7860"]
Phronesis DELETED
@@ -1 +0,0 @@
- Subproject commit 9c5facf38621ad02be7be79226d794f6c2f14dee
README.md ADDED
@@ -0,0 +1,13 @@
+ ---
+ title: Phronesis
+ emoji: 🌖
+ colorFrom: green
+ colorTo: gray
+ sdk: gradio
+ sdk_version: 5.4.0
+ app_file: app.py
+ pinned: false
+ short_description: 'Report generation and classification model'
+ ---
+
+ Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
app.py ADDED
@@ -0,0 +1,121 @@
+ # app.py
+ import os
+ import io
+ import gc
+ import uvicorn
+
+ import torch
+ from fastapi import FastAPI, HTTPException, Request, UploadFile
+ from fastapi.responses import JSONResponse
+ from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
+ from torchvision import models, transforms
+ from PIL import Image
+ import numpy as np
+ from huggingface_hub import hf_hub_download
+ import pydicom
+ from model import CombinedModel, ImageToTextProjector
+
+ app = FastAPI()
+
+ @app.get("/")
+ async def root(request: Request):
+     return {"message": "Welcome to Phronesis"}
+
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+
+ def dicom_to_png(dicom_data):
+     """Convert raw DICOM bytes to a grayscale PIL image."""
+     try:
+         # dcmread expects a file-like object, so wrap the raw bytes
+         dicom_file = pydicom.dcmread(io.BytesIO(dicom_data))
+         if not hasattr(dicom_file, 'PixelData'):
+             raise HTTPException(status_code=400, detail="No pixel data in DICOM file.")
+
+         # Rescale pixel values to 0-255; guard against a zero value range
+         pixel_array = dicom_file.pixel_array.astype(np.float32)
+         value_range = pixel_array.ptp() or 1.0
+         pixel_array = ((pixel_array - pixel_array.min()) / value_range) * 255.0
+         pixel_array = pixel_array.astype(np.uint8)
+
+         return Image.fromarray(pixel_array).convert("L")
+     except HTTPException:
+         raise
+     except Exception as e:
+         raise HTTPException(status_code=500, detail=f"Error converting DICOM to PNG: {e}")
+
+ # Set up secure model initialization
+ HF_TOKEN = os.getenv('HF_TOKEN')
+ if not HF_TOKEN:
+     raise ValueError("Missing Hugging Face token in environment variables.")
+
+ try:
+     report_generator_tokenizer = AutoTokenizer.from_pretrained(
+         "KYAGABA/combined-multimodal-model",
+         token=HF_TOKEN
+     )
+     video_model = models.video.r3d_18(weights="KINETICS400_V1")
+     video_model.fc = torch.nn.Linear(video_model.fc.in_features, 512)
+     report_generator = AutoModelForSeq2SeqLM.from_pretrained("GanjinZero/biobart-v2-base")
+     projector = ImageToTextProjector(512, report_generator.config.d_model)
+     num_classes = 4
+     combined_model = CombinedModel(video_model, report_generator, num_classes, projector, report_generator_tokenizer)
+     model_file = hf_hub_download("KYAGABA/combined-multimodal-model", "pytorch_model.bin", token=HF_TOKEN)
+     state_dict = torch.load(model_file, map_location=device)
+     combined_model.load_state_dict(state_dict)
+     combined_model.to(device)  # keep the model on the same device as the inputs
+     combined_model.eval()
+ except Exception as e:
+     raise SystemExit(f"Error loading models: {e}")
+
+ image_transform = transforms.Compose([
+     transforms.Resize((112, 112)),
+     transforms.ToTensor(),
+     transforms.Normalize(mean=[0.43216, 0.394666, 0.37645], std=[0.22803, 0.22145, 0.216989])
+ ])
+
+ class_names = ["acute", "normal", "chronic", "lacunar"]
+
+ @app.post("/predict/")
+ async def predict(files: list[UploadFile]):
+     print(f"Received {len(files)} files")
+     n_frames = 16
+     images = []
+
+     for file in files:
+         ext = file.filename.split('.')[-1].lower()
+         try:
+             if ext in ['dcm', 'ima']:
+                 dicom_img = dicom_to_png(await file.read())
+                 images.append(dicom_img.convert("RGB"))
+             elif ext in ['png', 'jpeg', 'jpg']:
+                 img = Image.open(io.BytesIO(await file.read())).convert("RGB")
+                 images.append(img)
+             else:
+                 raise HTTPException(status_code=400, detail="Unsupported file type.")
+         except HTTPException:
+             raise
+         except Exception as e:
+             raise HTTPException(status_code=500, detail=f"Error processing file {file.filename}: {e}")
+
+     if not images:
+         return JSONResponse(content={"error": "No valid images provided."}, status_code=400)
+
+     # Uniformly subsample to n_frames, or pad by repeating the last frame
+     if len(images) >= n_frames:
+         images_sampled = [images[i] for i in np.linspace(0, len(images) - 1, n_frames, dtype=int)]
+     else:
+         images_sampled = images + [images[-1]] * (n_frames - len(images))
+
+     image_tensors = [image_transform(img) for img in images_sampled]
+     # (T, C, H, W) -> (1, C, T, H, W), the layout r3d_18 expects
+     images_tensor = torch.stack(image_tensors).permute(1, 0, 2, 3).unsqueeze(0).to(device)
+
+     with torch.no_grad():
+         class_outputs, generated_report, _ = combined_model(images_tensor)
+         predicted_class = torch.argmax(class_outputs, dim=1).item()
+         predicted_class_name = class_names[predicted_class]
+
+     gc.collect()
+     if torch.cuda.is_available():
+         torch.cuda.empty_cache()
+
+     return {
+         "predicted_class": predicted_class_name,
+         "generated_report": generated_report[0] if generated_report else "No report generated."
+     }
+
+ if __name__ == "__main__":
+     uvicorn.run(app, host="0.0.0.0", port=int(os.environ.get("PORT", 7860)))
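
For reference, a minimal sketch of a client for the /predict/ endpoint, assuming the app is reachable at localhost:7860; the file names are placeholders, and any mix of .dcm/.ima/.png/.jpg uploads should work:

import requests

# Hypothetical input files; replace with real scan paths
paths = ["slice_001.dcm", "slice_002.dcm"]
files = [("files", (p, open(p, "rb"), "application/octet-stream")) for p in paths]
resp = requests.post("http://localhost:7860/predict/", files=files)
print(resp.json())  # {"predicted_class": "...", "generated_report": "..."}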
model.py ADDED
@@ -0,0 +1,56 @@
+ # model.py
+
+ import torch
+ import torch.nn as nn
+
+ class ImageToTextProjector(nn.Module):
+     """Projects video embeddings into the text model's embedding space."""
+
+     def __init__(self, image_embedding_dim, text_embedding_dim):
+         super(ImageToTextProjector, self).__init__()
+         self.fc = nn.Linear(image_embedding_dim, text_embedding_dim)
+         self.activation = nn.ReLU()
+         self.dropout = nn.Dropout(p=0.5)
+
+     def forward(self, x):
+         x = self.fc(x)
+         x = self.activation(x)
+         x = self.dropout(x)
+         return x
+
+ class CombinedModel(nn.Module):
+     """Joint video classifier and report generator."""
+
+     def __init__(self, video_model, report_generator, num_classes, projector, tokenizer):
+         super(CombinedModel, self).__init__()
+         self.video_model = video_model
+         self.report_generator = report_generator
+         self.classifier = nn.Linear(512, num_classes)
+         self.projector = projector
+         self.dropout = nn.Dropout(p=0.5)
+         self.tokenizer = tokenizer  # kept for decoding generated token ids
+
+     def forward(self, images, labels=None):
+         video_embeddings = self.video_model(images)
+         video_embeddings = self.dropout(video_embeddings)
+         class_outputs = self.classifier(video_embeddings)
+         # Feed the projected embedding to the text model as a one-token sequence
+         projected_embeddings = self.projector(video_embeddings)
+         encoder_inputs = projected_embeddings.unsqueeze(1)
+
+         if labels is not None:
+             # Training path: teacher-forced loss, no decoded text
+             outputs = self.report_generator(
+                 inputs_embeds=encoder_inputs,
+                 labels=labels
+             )
+             gen_loss = outputs.loss
+             generated_report = None
+         else:
+             # Inference path: beam-search generation, no loss
+             generated_report_ids = self.report_generator.generate(
+                 inputs_embeds=encoder_inputs,
+                 max_length=512,
+                 num_beams=4,
+                 early_stopping=True
+             )
+             generated_report = self.tokenizer.batch_decode(
+                 generated_report_ids, skip_special_tokens=True
+             )
+             gen_loss = None
+
+         return class_outputs, generated_report, gen_loss
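
To see the shapes CombinedModel expects, a standalone sketch with randomly initialized weights (so the decoded report is meaningless); it swaps the private KYAGABA/combined-multimodal-model tokenizer for the public GanjinZero/biobart-v2-base one purely to avoid needing a token:

import torch
from torchvision import models
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
from model import CombinedModel, ImageToTextProjector

tokenizer = AutoTokenizer.from_pretrained("GanjinZero/biobart-v2-base")
report_generator = AutoModelForSeq2SeqLM.from_pretrained("GanjinZero/biobart-v2-base")
video_model = models.video.r3d_18(weights=None)  # untrained backbone, shape checks only
video_model.fc = torch.nn.Linear(video_model.fc.in_features, 512)
projector = ImageToTextProjector(512, report_generator.config.d_model)
model = CombinedModel(video_model, report_generator, 4, projector, tokenizer)
model.eval()

# Dummy clip: batch 1, 3 channels, 16 frames, 112x112 pixels
clip = torch.randn(1, 3, 16, 112, 112)
with torch.no_grad():
    class_outputs, report, _ = model(clip)
print(class_outputs.shape)  # torch.Size([1, 4])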
requirements.txt ADDED
@@ -0,0 +1,16 @@
+ # Core dependencies
+ torch==2.0.1
+ torchvision==0.15.2
+ transformers==4.44.2
+ gradio==5.0
+ numpy==1.26.2
+ Pillow==10.0.1
+ fastapi
+ uvicorn  # imported by app.py and used in the Dockerfile CMD
+ # Additional dependencies
+ huggingface_hub==0.25.1  # compatible with both transformers and gradio
+ torchmetrics==1.5.1
+ nltk==3.8.1
+ scikit-learn==1.3.0
+ tqdm==4.66.1
+ sentencepiece==0.1.99
+ pydicom==2.4.1