sounar commited on
Commit
2bdc9ef
·
verified ·
1 Parent(s): 8327db6

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +37 -86
app.py CHANGED
@@ -1,6 +1,6 @@
1
  import os
2
  import torch
3
- from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig, GenerationConfig
4
  import gradio as gr
5
  from PIL import Image
6
  from torchvision.transforms import ToTensor
@@ -16,102 +16,49 @@ bnb_config = BitsAndBytesConfig(
16
  bnb_4bit_compute_dtype=torch.float16
17
  )
18
 
19
- # Model name
20
- model_name = "ContactDoctor/Bio-Medical-MultiModal-Llama-3-8B-V1"
21
-
22
- # Initialize tokenizer
23
- tokenizer = AutoTokenizer.from_pretrained(
24
- model_name,
25
- trust_remote_code=True,
26
- token=api_token
27
- )
28
-
29
- # Set up tokenizer with default tokens
30
- default_tokens = {
31
- "pad_token": "[PAD]",
32
- "eos_token": "</s>",
33
- "bos_token": "<s>",
34
- "unk_token": "<unk>",
35
- }
36
-
37
- for token_name, token_value in default_tokens.items():
38
- if getattr(tokenizer, token_name) is None:
39
- setattr(tokenizer, token_name, token_value)
40
- token_id_name = f"{token_name}_id"
41
- if getattr(tokenizer, token_id_name) is None:
42
- token_id = tokenizer.convert_tokens_to_ids(token_value)
43
- setattr(tokenizer, token_id_name, token_id)
44
-
45
- # Create generation config
46
- generation_config = GenerationConfig(
47
- pad_token_id=tokenizer.pad_token_id,
48
- eos_token_id=tokenizer.eos_token_id,
49
- bos_token_id=tokenizer.bos_token_id,
50
- max_new_tokens=256,
51
- )
52
-
53
- # Load the model
54
- model = AutoModelForCausalLM.from_pretrained(
55
- model_name,
56
  quantization_config=bnb_config,
57
  device_map="auto",
58
  torch_dtype=torch.float16,
59
  trust_remote_code=True,
60
- token=api_token,
61
- generation_config=generation_config
62
  )
63
 
64
- # Ensure model configs are set
65
- model.config.pad_token_id = tokenizer.pad_token_id
66
- model.config.eos_token_id = tokenizer.eos_token_id
67
- model.config.bos_token_id = tokenizer.bos_token_id
68
-
69
- # Preprocess image
70
- def preprocess_image(image):
71
- transform = ToTensor()
72
- return transform(image).unsqueeze(0).to(model.device)
73
 
74
- # Handle queries
75
  def analyze_input(image, question):
76
  try:
77
- # Debug print
78
- print(f"Tokenizer config:")
79
- print(f"EOS token: {tokenizer.eos_token} (id: {tokenizer.eos_token_id})")
80
- print(f"PAD token: {tokenizer.pad_token} (id: {tokenizer.pad_token_id})")
81
- print(f"BOS token: {tokenizer.bos_token} (id: {tokenizer.bos_token_id})")
82
-
83
- # Process the image if provided
84
- pixel_values = None
85
  if image is not None:
 
86
  image = image.convert('RGB')
87
- pixel_values = preprocess_image(image)
88
-
89
- # Tokenize the question
90
- inputs = tokenizer(
91
- question,
92
- return_tensors="pt",
93
- padding=True,
94
- truncation=True,
95
- max_length=512
96
- ).to(model.device)
97
 
98
- # Add image if provided
99
- if pixel_values is not None:
100
- inputs['pixel_values'] = pixel_values
101
 
102
- # Generate response
103
- outputs = model.generate(
104
- **inputs,
105
- generation_config=generation_config,
106
- max_new_tokens=256,
107
- do_sample=True,
108
- temperature=0.7,
109
- top_p=0.9,
110
  )
111
 
112
- # Decode response
113
- response = tokenizer.decode(outputs[0], skip_special_tokens=True)
114
- return {"status": "success", "response": response}
 
 
 
 
115
 
116
  except Exception as e:
117
  import traceback
@@ -123,12 +70,16 @@ def analyze_input(image, question):
123
  demo = gr.Interface(
124
  fn=analyze_input,
125
  inputs=[
126
- gr.Image(type="pil", label="Upload Medical Image (Optional)"),
127
- gr.Textbox(label="Medical Question")
 
 
 
 
128
  ],
129
  outputs=gr.JSON(label="Analysis"),
130
- title="ContactDoctor Medical Assistant",
131
- description="Upload a medical image and/or enter a question to receive detailed AI-powered responses."
132
  )
133
 
134
  # Launch the Gradio app
 
1
  import os
2
  import torch
3
+ from transformers import AutoModel, AutoTokenizer, BitsAndBytesConfig
4
  import gradio as gr
5
  from PIL import Image
6
  from torchvision.transforms import ToTensor
 
16
  bnb_4bit_compute_dtype=torch.float16
17
  )
18
 
19
+ # Initialize model and tokenizer
20
+ model = AutoModel.from_pretrained(
21
+ "ContactDoctor/Bio-Medical-MultiModal-Llama-3-8B-V1",
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
22
  quantization_config=bnb_config,
23
  device_map="auto",
24
  torch_dtype=torch.float16,
25
  trust_remote_code=True,
26
+ attn_implementation="flash_attention_2",
27
+ token=api_token
28
  )
29
 
30
+ tokenizer = AutoTokenizer.from_pretrained(
31
+ "ContactDoctor/Bio-Medical-MultiModal-Llama-3-8B-V1",
32
+ trust_remote_code=True,
33
+ token=api_token
34
+ )
 
 
 
 
35
 
 
36
  def analyze_input(image, question):
37
  try:
 
 
 
 
 
 
 
 
38
  if image is not None:
39
+ # Convert to RGB if image is provided
40
  image = image.convert('RGB')
 
 
 
 
 
 
 
 
 
 
41
 
42
+ # Prepare messages in the format expected by the model
43
+ msgs = [{'role': 'user', 'content': [image, question]}]
 
44
 
45
+ # Generate response using the chat method
46
+ response_stream = model.chat(
47
+ image=image,
48
+ msgs=msgs,
49
+ tokenizer=tokenizer,
50
+ sampling=True,
51
+ temperature=0.95,
52
+ stream=True
53
  )
54
 
55
+ # Collect the streamed response
56
+ generated_text = ""
57
+ for new_text in response_stream:
58
+ generated_text += new_text
59
+ print(new_text, flush=True, end='')
60
+
61
+ return {"status": "success", "response": generated_text}
62
 
63
  except Exception as e:
64
  import traceback
 
70
  demo = gr.Interface(
71
  fn=analyze_input,
72
  inputs=[
73
+ gr.Image(type="pil", label="Upload Medical Image"),
74
+ gr.Textbox(
75
+ label="Medical Question",
76
+ placeholder="Give the modality, organ, analysis, abnormalities (if any), treatment (if abnormalities are present)?",
77
+ value="Give the modality, organ, analysis, abnormalities (if any), treatment (if abnormalities are present)?"
78
+ )
79
  ],
80
  outputs=gr.JSON(label="Analysis"),
81
+ title="Medical Image Analysis Assistant",
82
+ description="Upload a medical image and ask questions about it. The AI will analyze the image and provide detailed responses."
83
  )
84
 
85
  # Launch the Gradio app