sounar committed (verified)
Commit 9698346 · Parent(s): 8eaf273

Update app.py

Files changed (1): app.py +30 -53
app.py CHANGED
@@ -1,10 +1,7 @@
 import os
 import torch
-from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig
+from transformers import AutoModel, AutoTokenizer, BitsAndBytesConfig
 import gradio as gr
-from PIL import Image
-import base64
-import io
 
 # Get API token from environment variable
 api_token = os.getenv("HF_TOKEN").strip()
@@ -17,84 +14,64 @@ bnb_config = BitsAndBytesConfig(
     bnb_4bit_compute_dtype=torch.float16
 )
 
-# Load model for causal language modeling
-model = AutoModelForCausalLM.from_pretrained(
+# Load the model and tokenizer
+model = AutoModel.from_pretrained(
     "ContactDoctor/Bio-Medical-MultiModal-Llama-3-8B-V1",
     quantization_config=bnb_config,
     device_map="auto",
     torch_dtype=torch.float16,
     trust_remote_code=True,
-    token=api_token,
-    revision="main"
+    token=api_token
 )
 
 tokenizer = AutoTokenizer.from_pretrained(
     "ContactDoctor/Bio-Medical-MultiModal-Llama-3-8B-V1",
     trust_remote_code=True,
-    token=api_token,
-    revision="main"
+    token=api_token
 )
 
-def analyze_input(image_data, question):
+def analyze_input(image, question):
     try:
-        if not question.strip():
-            return {
-                "status": "error",
-                "message": "Question is required."
-            }
-
-        # Handle the input image (if any)
-        if image_data is not None:
-            return {
-                "status": "error",
-                "message": "Image support is not implemented yet."
-            }
-
-        # Prepare prompt for text-only input
-        prompt = f"Medical question: {question}\nAnswer: "
-
-        # Tokenize input
-        inputs = tokenizer(prompt, return_tensors="pt")
-        input_ids = inputs.input_ids.to(model.device)
-
-        # Generate response
-        outputs = model.generate(
-            input_ids=input_ids,
-            max_length=256,  # Limit the length of the generated text
-            eos_token_id=tokenizer.eos_token_id,  # Ensure generation stops correctly
-            pad_token_id=tokenizer.pad_token_id,
-            temperature=0.7,  # Control randomness
-            top_p=0.9,  # Nucleus sampling
-            top_k=50  # Top-k sampling
-        )
-
-        # Decode response
+        # Prepare inputs
+        if image:
+            prompt = f"Given the medical image and question: {question}\nPlease provide a detailed analysis."
+            # Convert image to RGB
+            image = image.convert('RGB')
+            # Custom model_inputs for multimodal generation
+            model_inputs = {
+                "input_ids": tokenizer(prompt, return_tensors="pt").input_ids.to(model.device),
+                "images": [image]
+            }
+        else:
+            prompt = f"Medical question: {question}\nAnswer:"
+            model_inputs = {
+                "input_ids": tokenizer(prompt, return_tensors="pt").input_ids.to(model.device),
+                "images": None
+            }
+
+        # Generate response using model's custom method
+        outputs = model.generate(model_inputs=model_inputs, max_new_tokens=256)
+
+        # Decode and clean response
         response = tokenizer.decode(outputs[0], skip_special_tokens=True)
-
-        return {
-            "status": "success",
-            "response": response
-        }
+        return {"status": "success", "response": response}
+
     except Exception as e:
-        return {
-            "status": "error",
-            "message": str(e)
-        }
+        return {"status": "error", "message": str(e)}
 
 # Create Gradio interface
 demo = gr.Interface(
     fn=analyze_input,
     inputs=[
-        gr.Image(type="numpy", label="Medical Image (Optional)"),
-        gr.Textbox(label="Question", placeholder="Enter your medical query...")
+        gr.Image(type="pil", label="Upload Medical Image (Optional)"),
+        gr.Textbox(label="Medical Question")
     ],
     outputs=gr.JSON(label="Analysis"),
-    title="Medical Query Analysis",
-    description="Ask medical questions. For now, please focus on text-based queries.",
-    flagging_mode="never"
+    title="ContactDoctor Medical Assistant",
+    description="Upload a medical image and/or enter a question to receive detailed AI-powered responses."
 )
 
-# Launch the interface
+# Launch the Gradio app
 demo.launch(
     share=True,
     server_name="0.0.0.0",
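
Two details the diff leaves implicit: os.getenv("HF_TOKEN").strip() raises an AttributeError when HF_TOKEN is unset, and only the last field of bnb_config falls inside the hunk. Below is a minimal sketch of a guarded token lookup and a typical 4-bit quantization config; the load_in_4bit and bnb_4bit_quant_type values are assumptions, since those lines are outside the hunk shown above.

import os
import torch
from transformers import BitsAndBytesConfig

# Fail with a clear message instead of an AttributeError when the token is missing.
api_token = os.getenv("HF_TOKEN")
if not api_token:
    raise RuntimeError("HF_TOKEN environment variable is not set")
api_token = api_token.strip()

# Typical 4-bit setup; only bnb_4bit_compute_dtype is visible in the hunk above,
# the other fields are assumed NF4 defaults.
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.float16,
)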
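Once the app is running, the JSON endpoint can also be exercised programmatically with gradio_client. This is a sketch under assumptions not taken from the commit: the default /predict endpoint name of a gr.Interface and a local launch on port 7860; point the URL at the deployed Space instead.

from gradio_client import Client

# URL and endpoint name are assumptions; adjust for the actual deployment.
client = Client("http://localhost:7860")

# Text-only query: the optional image input is left as None.
result = client.predict(
    None,
    "What are the common causes of chest pain?",
    api_name="/predict",
)
print(result)  # expected to be the {"status": ..., "response"/"message": ...} dict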