File size: 2,562 Bytes
0faaa54
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
import os
import torch
import config
from utils import (
    load_dataset,
    get_model_instance,
    load_checkpoint,
    can_load_checkpoint,
    normalize_text,
)
from PIL import Image
import torchvision.transforms as transforms

# Define device
# Inference runs on CPU only in this script; change to 'cuda' manually if desired.
DEVICE: str = 'cpu'

# Define image transformations (adjust based on training setup)
# NOTE(review): the mean/std values match the standard ImageNet statistics —
# confirm they are the same ones used when the model was trained.
TRANSFORMS = transforms.Compose([
    transforms.Resize((224, 224)),  # Replace with your model's expected input size
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])


def load_model():
    """
    Build the captioning model ready for inference.

    Restores the vocabulary from the dataset, instantiates the network,
    and loads a checkpoint when one is available on disk.

    Returns:
        The model switched to evaluation mode.
    """
    print("Loading dataset and vocabulary...")
    vocab = load_dataset().vocab  # dataset exposes its vocabulary as 'vocab'

    print("Initializing the model...")
    model = get_model_instance(vocab)

    # Restore trained weights only when a checkpoint file exists.
    if not can_load_checkpoint():
        print("No checkpoint found, starting with untrained model.")
    else:
        print("Loading checkpoint...")
        load_checkpoint(model)

    model.eval()  # inference mode: freezes dropout / batch-norm behavior
    print("Model is ready for inference.")
    return model


def preprocess_image(image_path):
    """
    Load one image from disk and turn it into a normalized model input.

    Args:
        image_path: Filesystem path to the image file.

    Returns:
        A float tensor with a leading batch axis of size 1, moved to DEVICE.
    """
    print(f"Preprocessing image: {image_path}")
    rgb = Image.open(image_path).convert("RGB")  # normalize grayscale/alpha variants to RGB
    batch = TRANSFORMS(rgb).unsqueeze(0)  # resize + normalize, then prepend batch dim
    return batch.to(DEVICE)


def generate_report(model, image_path):
    """
    Run the model on a single image and return the generated report text.

    Args:
        model: Captioning model exposing a 'generate_caption' method.
        image_path: Path to the input image on disk.

    Returns:
        The generated tokens joined into one space-separated string.
    """
    batch = preprocess_image(image_path)

    print("Generating report...")
    # Inference only — no autograd bookkeeping needed.
    with torch.no_grad():
        # Assuming the model has a 'generate_caption' method
        tokens = model.generate_caption(batch, max_length=25)
    report = " ".join(tokens)

    print(f"Generated report: {report}")
    return report


if __name__ == "__main__":
    # Fix: the original assigned CHECKPOINT_PATH = config.CHECKPOINT_FILE here
    # but never used it — checkpoint loading happens inside load_model()
    # (via utils.can_load_checkpoint/load_checkpoint, which presumably read
    # config themselves — TODO confirm). The dead assignment is removed.

    # Path to the input image
    IMAGE_PATH = "./dataset/images/CXR1178_IM-0121-1001.png"  # Replace with your image path

    # Load the model (checkpoint restored inside if one exists)
    model = load_model()

    # Check existence up front so a missing file prints a clear message
    # instead of raising inside PIL during preprocessing.
    if os.path.exists(IMAGE_PATH):
        report = generate_report(model, IMAGE_PATH)
        print("Final Report:", report)
    else:
        print(f"Image not found at path: {IMAGE_PATH}")