sounar committed
Commit 8eaf273 · verified · 1 Parent(s): a986796

Update app.py

Files changed (1)
  1. app.py +33 -31
app.py CHANGED
@@ -1,8 +1,8 @@
  import os
  import torch
- from transformers import AutoModel, AutoTokenizer, BitsAndBytesConfig
- from PIL import Image
+ from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig
  import gradio as gr
+ from PIL import Image
  import base64
  import io

@@ -17,8 +17,8 @@ bnb_config = BitsAndBytesConfig(
      bnb_4bit_compute_dtype=torch.float16
  )

- # Load model with revision pinning
- model = AutoModel.from_pretrained(
+ # Load model for causal language modeling
+ model = AutoModelForCausalLM.from_pretrained(
      "ContactDoctor/Bio-Medical-MultiModal-Llama-3-8B-V1",
      quantization_config=bnb_config,
      device_map="auto",
@@ -37,38 +37,40 @@ tokenizer = AutoTokenizer.from_pretrained(

  def analyze_input(image_data, question):
      try:
-         # Prepare the prompt
+         if not question.strip():
+             return {
+                 "status": "error",
+                 "message": "Question is required."
+             }
+
+         # Handle the input image (if any)
          if image_data is not None:
-             prompt = f"Given the medical image and the question: {question}\nPlease provide a detailed analysis."
-         else:
-             prompt = f"Medical question: {question}\nAnswer: "
+             return {
+                 "status": "error",
+                 "message": "Image support is not implemented yet."
+             }
+
+         # Prepare prompt for text-only input
+         prompt = f"Medical question: {question}\nAnswer: "

          # Tokenize input
-         tokenized = tokenizer(prompt, return_tensors="pt")
-         input_ids = tokenized.input_ids.to(model.device)
-
-         # Calculate target size (for generation length)
-         tgt_size = input_ids.size(1) + 256  # original length + max new tokens
-
-         # Prepare model inputs
-         model_inputs = {
-             "input_ids": input_ids,
-             "pixel_values": None,  # Set to None for text-only queries
-             "tgt_sizes": [tgt_size]  # Add target size for generation
-         }
-
+         inputs = tokenizer(prompt, return_tensors="pt")
+         input_ids = inputs.input_ids.to(model.device)
+
          # Generate response
          outputs = model.generate(
-             model_inputs=model_inputs,
+             input_ids=input_ids,
+             max_length=256,  # Limit the length of the generated text
+             eos_token_id=tokenizer.eos_token_id,  # Ensure generation stops correctly
+             pad_token_id=tokenizer.pad_token_id,
+             temperature=0.7,  # Control randomness
+             top_p=0.9,  # Nucleus sampling
+             top_k=50  # Top-k sampling
          )
-
-         # Decode and clean up response
+
+         # Decode response
          response = tokenizer.decode(outputs[0], skip_special_tokens=True)
-
-         # Remove the prompt from the response
-         if prompt in response:
-             response = response[len(prompt):].strip()
-
+
          return {
              "status": "success",
              "response": response
@@ -88,7 +90,7 @@ demo = gr.Interface(
      ],
      outputs=gr.JSON(label="Analysis"),
      title="Medical Query Analysis",
-     description="Ask medical questions. For now, please focus on text-based queries without images.",
+     description="Ask medical questions. For now, please focus on text-based queries.",
      flagging_mode="never"
  )

@@ -97,4 +99,4 @@ demo.launch(
      share=True,
      server_name="0.0.0.0",
      server_port=7860
- )
+ )
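
For reference, a minimal sketch of how the updated analyze_input is expected to respond after this commit. It assumes analyze_input (along with model and tokenizer) from app.py is already defined in the current Python session, since importing app.py directly would also call demo.launch(); the example questions are placeholders.

# Empty or whitespace-only question -> validation error from the new guard clause
print(analyze_input(None, "   "))
# {'status': 'error', 'message': 'Question is required.'}

# Any image input -> explicit "not implemented" error
print(analyze_input("<image data>", "Describe this scan."))
# {'status': 'error', 'message': 'Image support is not implemented yet.'}

# Text-only question -> generated answer from the quantized model
print(analyze_input(None, "What are common symptoms of iron-deficiency anemia?"))
# {'status': 'success', 'response': '...'}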