Erick Garcia Espinosa commited on
Commit
b73d7a5
1 Parent(s): 6b394c3

Add application file and dependencies

Browse files
Files changed (1) hide show
  1. app.py +18 -22
app.py CHANGED
@@ -5,22 +5,22 @@ from torchvision import transforms
5
  from PIL import Image
6
  from timm import create_model
7
 
8
- # Definir el diccionario de mapeo de clases a 铆ndices
9
- class_to_idx = {'Monkeypox': 0, 'Melanoma': 1, 'Herpes': 2, 'Sarampion': 3, 'Varicela': 4}
10
 
11
- # Transformaci贸n de datos
12
  transform = transforms.Compose([
13
  transforms.Resize((224, 224)),
14
  transforms.ToTensor(),
15
  ])
16
 
17
- # Funci贸n para cargar y preprocesar una imagen
18
  def load_image(image_path):
19
  image = Image.open(image_path).convert('RGB')
20
- image = transform(image).unsqueeze(0) # A帽adir dimensi贸n del batch
21
  return image
22
 
23
- # Cargar el modelo
24
  model_name = 'vit_base_patch16_224'
25
  pretrained = True
26
  num_classes = len(class_to_idx)
@@ -28,33 +28,29 @@ model = create_model(model_name, pretrained=pretrained, num_classes=num_classes)
28
  model.load_state_dict(torch.load('ARTmodelo5ns_vit_weights_epoch6.pth', map_location='cpu', weights_only=True))
29
  model.eval()
30
 
31
- # Definir la funci贸n de predicci贸n
32
- def predict_image(img):
33
- # Convertir la imagen a PIL.Image si es un numpy array
34
- if isinstance(img, np.ndarray):
35
- img = Image.fromarray(img)
36
 
37
- # Convertir la imagen a tensor y a帽adir dimensi贸n del batch
38
- img_tensor = transform(img).unsqueeze(0)
39
-
40
- # Realizar la predicci贸n
41
  with torch.no_grad():
42
  output = model(img_tensor)
43
  _, predicted = torch.max(output, 1)
44
 
45
-
46
  predicted_label = list(class_to_idx.keys())[predicted.item()]
47
 
48
  return predicted_label
49
 
50
- # Crear la interfaz de Gradio
51
  iface = gr.Interface(
52
  fn=predict_image,
53
- inputs=gr.Image(type="filepath", label="Sube una imagen"),
54
- outputs=gr.Label(label="Predicci贸n"),
55
- title="Clasificaci贸n de Im谩genes de Lesiones Cut谩neas",
56
- description="Carga una imagen de una lesi贸n cut谩nea para obtener una predicci贸n."
57
  )
58
 
59
- # Lanzar la interfaz de Gradio
60
  iface.launch()
 
5
  from PIL import Image
6
  from timm import create_model
7
 
8
+ # Define class to index mapping dictionary
9
+ class_to_idx = {'Monkeypox': 0, 'Melanoma': 1, 'Herpes': 2, 'Measles': 3, 'Chickenpox': 4}
10
 
11
+ # Data transformation
12
  transform = transforms.Compose([
13
  transforms.Resize((224, 224)),
14
  transforms.ToTensor(),
15
  ])
16
 
17
+ # Function to load and preprocess an image
18
  def load_image(image_path):
19
  image = Image.open(image_path).convert('RGB')
20
+ image = transform(image).unsqueeze(0) # Add batch dimension
21
  return image
22
 
23
+ # Load the model
24
  model_name = 'vit_base_patch16_224'
25
  pretrained = True
26
  num_classes = len(class_to_idx)
 
28
  model.load_state_dict(torch.load('ARTmodelo5ns_vit_weights_epoch6.pth', map_location='cpu', weights_only=True))
29
  model.eval()
30
 
31
+ # Define the prediction function
32
+ def predict_image(image_path):
33
+ # Load and transform the image from the file path
34
+ img_tensor = load_image(image_path)
 
35
 
36
+ # Perform the prediction
 
 
 
37
  with torch.no_grad():
38
  output = model(img_tensor)
39
  _, predicted = torch.max(output, 1)
40
 
41
+ # Get the predicted label
42
  predicted_label = list(class_to_idx.keys())[predicted.item()]
43
 
44
  return predicted_label
45
 
46
+ # Create the Gradio interface
47
  iface = gr.Interface(
48
  fn=predict_image,
49
+ inputs=gr.Image(type="filepath", label="Upload an image"),
50
+ outputs=gr.Label(label="Prediction"),
51
+ title="Skin Lesion Image Classification",
52
+ description="Upload an image of a skin lesion to get a prediction."
53
  )
54
 
55
+ # Launch the Gradio interface
56
  iface.launch()