eltorio committed on
Commit 1d6cff4
Parent(s): ca9f0b9

add descriptions

Files changed (1): app.py +69 -20
app.py CHANGED
@@ -1,41 +1,90 @@
+# Copyright 2024 Ronan Le Meillat
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+# http://www.apache.org/licenses/LICENSE-2.0
+# Import necessary libraries
 import gradio as gr
 from transformers import AutoProcessor, Idefics3ForConditionalGeneration, image_utils
 import torch
+
+# Determine the device (GPU or CPU) to run the model on
 device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
-print(f"Using device: {device}")
-model_id="eltorio/IDEFICS3_ROCO"
-# model = AutoModelForImageTextToText.from_pretrained(model_id).to(device)
-base_model_path="HuggingFaceM4/Idefics3-8B-Llama3" #or change to local path
+print(f"Using device: {device}") # Log the device being used
+
+# Define the model ID and base model path
+model_id = "eltorio/IDEFICS3_ROCO"
+base_model_path = "HuggingFaceM4/Idefics3-8B-Llama3" # or change to local path
+
+# Initialize the processor from the base model path
 processor = AutoProcessor.from_pretrained(base_model_path, trust_remote_code=True)
+
+# Initialize the model from the base model path and set the torch dtype to bfloat16
 model = Idefics3ForConditionalGeneration.from_pretrained(
-    base_model_path, torch_dtype=torch.bfloat16
-).to(device)
+    base_model_path, torch_dtype=torch.bfloat16
+).to(device) # Move the model to the specified device
 
-model.load_adapter(model_id,device_map="auto")
+# Load the adapter from the model ID and automatically map it to the device
+model.load_adapter(model_id, device_map="auto")
 
+# Define a function to infer a description from an image
 def infere(image):
+    """
+    Generate a description of a medical image.
+
+    Args:
+    - image (PIL Image): The medical image to describe.
+
+    Returns:
+    - generated_texts (List[str]): A list containing the generated description.
+    """
+
+    # Define a chat template for the model to respond to
     messages = [
         {
-            "role": "system",
-            "content": [
-                {"type": "text", "text": "You are a valuable medical doctor and you are looking at an image of your patient."},
-            ]
+            "role": "system",
+            "content": [
+                {"type": "text", "text": "You are a valuable medical doctor and you are looking at an image of your patient."},
+            ]
         },
-        {
-            "role": "user",
-            "content": [
-                {"type": "image"},
-                {"type": "text", "text": "What do we see in this image?"},
-            ]
-        },
+        {
+            "role": "user",
+            "content": [
+                {"type": "image"},
+                {"type": "text", "text": "What do we see in this image?"},
+            ]
+        },
     ]
+
+    # Apply the chat template and add a generation prompt
     prompt = processor.apply_chat_template(messages, add_generation_prompt=True)
+
+    # Preprocess the input image and text
    inputs = processor(text=prompt, images=[image], return_tensors="pt")
-    # print(f"inputs: {inputs}")
+
+    # Move the inputs to the specified device
     inputs = {k: v.to(device) for k, v in inputs.items()}
+
+    # Generate a description with the model
     generated_ids = model.generate(**inputs, max_new_tokens=100)
+
+    # Decode the generated IDs into text
     generated_texts = processor.batch_decode(generated_ids, skip_special_tokens=True)
+
     return generated_texts
 
-radiotest = gr.Interface(fn=infere, inputs="image", outputs="text")
+# Define the title, description, and device description for the Gradio interface
+title = f"<a href='https://huggingface.co/eltorio/IDEFICS3_ROCO'>IDEFICS3_ROCO</a>: Medical Image to Text <b>running on {device}</b>"
+desc = "This model generates a description of a medical image."
+
+device_desc = f"This model is running on {device} 🚀." if device == torch.device('cuda') else f"🐢 This model is running on {device} it will be very (very) slow. If you can donate some GPU time it will be usable 🐢. <a href='https://huggingface.co/eltorio/IDEFICS3_ROCO/discussions'>Please contact us.</a>"
+
+# Define the long description for the Gradio interface
+long_desc = f"This model is based on the <a href='https://huggingface.co/eltorio/IDEFICS3_ROCO'>IDEFICS3_ROCO model</a>, which is a multimodal model that can generate text from images. It has been fine-tuned on <a href='https://huggingface.co/datasets/eltorio/ROCO-radiology'>eltorio/ROCO-radiology</a>&nbsp;a dataset of medical images and can generate descriptions of medical images. Try uploading an image of a medical image and see what the model generates!<br><b>{device_desc}</b><br> 2024 - Ronan Le Meillat"
+
+# Create a Gradio interface with the infere function and specified title and descriptions
+radiotest = gr.Interface(fn=infere, inputs="image", outputs="text", title=title,
+                         description=desc, article=long_desc)
+
+# Launch the Gradio interface and share it
+radiotest.launch(share=True)
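
For reviewers who want to sanity-check the updated infere() outside the Gradio UI, a minimal sketch: it assumes the snippet runs in the same process as app.py (for example, temporarily appended just before the radiotest.launch call), and the image filename is hypothetical, not a file in this repo.

from PIL import Image  # Pillow is already pulled in by the gradio/transformers stack

# "sample_xray.png" is a placeholder path; substitute any local test image
test_image = Image.open("sample_xray.png").convert("RGB")
# infere() returns a list of decoded sequences; since the full generated_ids
# are decoded, the output text includes the chat prompt as well
print(infere(test_image)[0])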