sounar committed on
Commit
e8eeeb2
·
verified ·
1 Parent(s): 9da0a3e

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +24 -37
app.py CHANGED
@@ -1,6 +1,6 @@
1
  import os
2
  import torch
3
- from transformers import AutoModel, AutoTokenizer, BitsAndBytesConfig
4
  from PIL import Image
5
  import gradio as gr
6
  import base64
@@ -17,8 +17,8 @@ bnb_config = BitsAndBytesConfig(
17
  bnb_4bit_compute_dtype=torch.float16
18
  )
19
 
20
- # Load model with revision pinning
21
- model = AutoModel.from_pretrained(
22
  "ContactDoctor/Bio-Medical-MultiModal-Llama-3-8B-V1",
23
  quantization_config=bnb_config,
24
  device_map="auto",
@@ -37,45 +37,32 @@ tokenizer = AutoTokenizer.from_pretrained(
37
 
38
  def analyze_input(image_data, question):
39
  try:
40
- # Handle base64 image if provided
41
- if isinstance(image_data, str) and image_data.startswith('data:image'):
42
- base64_data = image_data.split(',')[1]
43
- image_bytes = base64.b64decode(base64_data)
44
- image = Image.open(io.BytesIO(image_bytes)).convert('RGB')
45
- # Handle direct image input
46
- elif image_data is not None:
47
- image = Image.fromarray(image_data).convert('RGB')
48
  else:
49
- image = None
50
 
51
- # Process with or without image
52
- if image is not None:
53
- # Prepare inputs for multimodal generation
54
- model_inputs = {
55
- "input_ids": tokenizer(question, return_tensors="pt").input_ids.to(model.device),
56
- "images": [image]
57
- }
58
- else:
59
- # Prepare inputs for text-only generation
60
- model_inputs = {
61
- "input_ids": tokenizer(question, return_tensors="pt").input_ids.to(model.device)
62
- }
63
-
64
- # Generate response with proper inputs
65
- generation_config = {
66
- "max_new_tokens": 256,
67
- "do_sample": True,
68
- "temperature": 0.7,
69
- "top_p": 0.9,
70
- }
71
 
 
72
  outputs = model.generate(
73
- model_inputs=model_inputs,
74
- **generation_config
 
 
 
 
75
  )
76
 
 
77
  response = tokenizer.decode(outputs[0], skip_special_tokens=True)
78
 
 
 
 
 
79
  return {
80
  "status": "success",
81
  "response": response
@@ -90,12 +77,12 @@ def analyze_input(image_data, question):
90
  demo = gr.Interface(
91
  fn=analyze_input,
92
  inputs=[
93
- gr.Image(type="numpy", label="Medical Image"),
94
  gr.Textbox(label="Question", placeholder="Enter your medical query...")
95
  ],
96
  outputs=gr.JSON(label="Analysis"),
97
- title="Bio-Medical MultiModal Analysis",
98
- description="Ask questions with or without an image",
99
  flagging_mode="never"
100
  )
101
 
 
1
  import os
2
  import torch
3
+ from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig
4
  from PIL import Image
5
  import gradio as gr
6
  import base64
 
17
  bnb_4bit_compute_dtype=torch.float16
18
  )
19
 
20
+ # Load model with revision pinning - using CausalLM for text generation
21
+ model = AutoModelForCausalLM.from_pretrained(
22
  "ContactDoctor/Bio-Medical-MultiModal-Llama-3-8B-V1",
23
  quantization_config=bnb_config,
24
  device_map="auto",
 
37
 
38
  def analyze_input(image_data, question):
39
  try:
40
+ # Prepare the prompt
41
+ if image_data is not None:
42
+ prompt = f"Given the medical image and the question: {question}\nPlease provide a detailed analysis."
 
 
 
 
 
43
  else:
44
+ prompt = f"Medical question: {question}\nAnswer: "
45
 
46
+ # Tokenize input
47
+ inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
48
 
49
+ # Generate response
50
  outputs = model.generate(
51
+ **inputs,
52
+ max_new_tokens=256,
53
+ do_sample=True,
54
+ temperature=0.7,
55
+ top_p=0.9,
56
+ pad_token_id=tokenizer.eos_token_id
57
  )
58
 
59
+ # Decode and clean up response
60
  response = tokenizer.decode(outputs[0], skip_special_tokens=True)
61
 
62
+ # Remove the prompt from the response
63
+ if prompt in response:
64
+ response = response[len(prompt):].strip()
65
+
66
  return {
67
  "status": "success",
68
  "response": response
 
77
  demo = gr.Interface(
78
  fn=analyze_input,
79
  inputs=[
80
+ gr.Image(type="numpy", label="Medical Image (Optional)"),
81
  gr.Textbox(label="Question", placeholder="Enter your medical query...")
82
  ],
83
  outputs=gr.JSON(label="Analysis"),
84
+ title="Medical Query Analysis",
85
+ description="Ask medical questions with or without images. For general medical queries, no image is needed.",
86
  flagging_mode="never"
87
  )
88