sounar commited on
Commit
2974476
·
verified ·
1 Parent(s): 4f5fa66

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +42 -28
app.py CHANGED
@@ -1,10 +1,12 @@
1
- import gradio as gr
2
  import torch
3
  from transformers import AutoModel, AutoTokenizer, BitsAndBytesConfig
4
  from PIL import Image
5
- import os
 
 
6
 
7
- # Get API token from environment variables
8
  api_token = os.getenv("HF_TOKEN").strip()
9
 
10
  # Quantization configuration
@@ -15,7 +17,7 @@ bnb_config = BitsAndBytesConfig(
15
  bnb_4bit_compute_dtype=torch.float16
16
  )
17
 
18
- # Load the model and tokenizer
19
  model = AutoModel.from_pretrained(
20
  "ContactDoctor/Bio-Medical-MultiModal-Llama-3-8B-V1",
21
  quantization_config=bnb_config,
@@ -31,42 +33,54 @@ tokenizer = AutoTokenizer.from_pretrained(
31
  token=api_token
32
  )
33
 
34
- # Function to handle inputs
35
- def process_query(image, question):
36
  try:
37
- if image:
38
- # Process image and text
39
- image = image.convert('RGB')
 
 
 
 
 
 
 
 
 
 
 
40
  inputs = model.prepare_inputs_for_generation(
41
  input_ids=tokenizer(question, return_tensors="pt").input_ids,
42
  images=[image]
43
  )
44
- outputs = model.generate(**inputs, max_new_tokens=256)
45
  else:
46
- # Process text-only
47
  inputs = tokenizer(question, return_tensors="pt")
48
- outputs = model.generate(**inputs, max_new_tokens=256)
49
 
50
- # Decode response
51
  response = tokenizer.decode(outputs[0], skip_special_tokens=True)
52
- return response
53
-
 
 
 
54
  except Exception as e:
55
- return f"Error: {str(e)}"
 
 
 
56
 
57
- # Define Gradio interface
58
- interface = gr.Interface(
59
- fn=process_query,
60
  inputs=[
61
- gr.Image(type="pil", label="Upload an Image (Optional)"),
62
- gr.Textbox(label="Enter a Question")
63
  ],
64
- outputs="text",
65
- title="ContactDoctor Multimodal Medical Assistant",
66
- description="Provide an image and/or question to get AI-powered medical advice.",
67
- enable_api=True # Enable API for external calls
68
  )
69
 
70
- # Launch the app
71
- if __name__ == "__main__":
72
- interface.launch()
 
1
+ import os
2
  import torch
3
  from transformers import AutoModel, AutoTokenizer, BitsAndBytesConfig
4
  from PIL import Image
5
+ import gradio as gr
6
+ import base64
7
+ import io
8
 
9
+ # Get API token from environment variable
10
  api_token = os.getenv("HF_TOKEN").strip()
11
 
12
  # Quantization configuration
 
17
  bnb_4bit_compute_dtype=torch.float16
18
  )
19
 
20
+ # Load model
21
  model = AutoModel.from_pretrained(
22
  "ContactDoctor/Bio-Medical-MultiModal-Llama-3-8B-V1",
23
  quantization_config=bnb_config,
 
33
  token=api_token
34
  )
35
 
36
+ def analyze_input(image_data=None, question=""):
 
37
  try:
38
+ # Handle base64 image if provided
39
+ if isinstance(image_data, str) and image_data.startswith('data:image'):
40
+ # Extract base64 data after the comma
41
+ base64_data = image_data.split(',')[1]
42
+ image_bytes = base64.b64decode(base64_data)
43
+ image = Image.open(io.BytesIO(image_bytes)).convert('RGB')
44
+ # Handle direct image input
45
+ elif image_data is not None:
46
+ image = Image.fromarray(image_data).convert('RGB')
47
+ else:
48
+ image = None
49
+
50
+ # Process with or without image
51
+ if image is not None:
52
  inputs = model.prepare_inputs_for_generation(
53
  input_ids=tokenizer(question, return_tensors="pt").input_ids,
54
  images=[image]
55
  )
 
56
  else:
 
57
  inputs = tokenizer(question, return_tensors="pt")
 
58
 
59
+ outputs = model.generate(**inputs, max_new_tokens=256)
60
  response = tokenizer.decode(outputs[0], skip_special_tokens=True)
61
+
62
+ return {
63
+ "status": "success",
64
+ "response": response
65
+ }
66
  except Exception as e:
67
+ return {
68
+ "status": "error",
69
+ "message": str(e)
70
+ }
71
 
72
+ # Create Gradio interface
73
+ demo = gr.Interface(
74
+ fn=analyze_input,
75
  inputs=[
76
+ gr.Image(type="numpy", label="Medical Image (Optional)", optional=True),
77
+ gr.Textbox(label="Question", placeholder="Enter your medical query...")
78
  ],
79
+ outputs=gr.JSON(label="Analysis"),
80
+ title="Bio-Medical MultiModal Analysis",
81
+ description="Ask questions with or without an image",
82
+ allow_flagging="never",
83
  )
84
 
85
+ # Launch with API access enabled
86
+ demo.launch(share=True, server_name="0.0.0.0", server_port=7860, enable_queue=True)