sounar committed
Commit a986796 · verified · 1 Parent(s): c600b9f

Update app.py

Files changed (1): app.py +9 -12
app.py CHANGED
@@ -1,6 +1,6 @@
 import os
 import torch
-from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig
+from transformers import AutoModel, AutoTokenizer, BitsAndBytesConfig
 from PIL import Image
 import gradio as gr
 import base64
@@ -18,7 +18,7 @@ bnb_config = BitsAndBytesConfig(
 )
 
 # Load model with revision pinning
-model = AutoModelForCausalLM.from_pretrained(
+model = AutoModel.from_pretrained(
     "ContactDoctor/Bio-Medical-MultiModal-Llama-3-8B-V1",
     quantization_config=bnb_config,
     device_map="auto",
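
The switch from AutoModelForCausalLM to AutoModel in the two hunks above is the kind of change that usually goes together with trust_remote_code=True: multimodal checkpoints like this one ship a custom model class that the generic causal-LM loader cannot resolve. A minimal loading sketch under that assumption; the 4-bit settings are illustrative, since this diff does not show the full bnb_config:

import torch
from transformers import AutoModel, AutoTokenizer, BitsAndBytesConfig

model_id = "ContactDoctor/Bio-Medical-MultiModal-Llama-3-8B-V1"

# Illustrative 4-bit quantization settings (the commit's own bnb_config
# is defined outside the hunks shown here)
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_compute_dtype=torch.float16,
)

tokenizer = AutoTokenizer.from_pretrained(model_id, trust_remote_code=True)
model = AutoModel.from_pretrained(
    model_id,
    quantization_config=bnb_config,
    device_map="auto",
    trust_remote_code=True,  # lets AutoModel resolve the repo's custom class
)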
@@ -44,25 +44,22 @@ def analyze_input(image_data, question):
     prompt = f"Medical question: {question}\nAnswer: "
 
     # Tokenize input
-    input_ids = tokenizer(prompt, return_tensors="pt").input_ids.to(model.device)
+    tokenized = tokenizer(prompt, return_tensors="pt")
+    input_ids = tokenized.input_ids.to(model.device)
+
+    # Calculate target size (for generation length)
+    tgt_size = input_ids.size(1) + 256  # original length + max new tokens
 
     # Prepare model inputs
     model_inputs = {
         "input_ids": input_ids,
-        "pixel_values": None  # Set to None for text-only queries
+        "pixel_values": None,  # Set to None for text-only queries
+        "tgt_sizes": [tgt_size]  # Add target size for generation
     }
 
     # Generate response
-    generation_config = {
-        "max_new_tokens": 256,
-        "do_sample": True,
-        "temperature": 0.7,
-        "top_p": 0.9,
-    }
-
     outputs = model.generate(
         model_inputs=model_inputs,
-        **generation_config
     )
 
     # Decode and clean up response
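
Taken together, the hunk above leaves the text-only path looking roughly like the sketch below. The helper name answer_text_only is hypothetical; the assumption that the checkpoint's custom generate() accepts a model_inputs dict and a tgt_sizes list is the commit's own, not something confirmed by the model's documentation.

# A minimal sketch of the patched text-only path, assuming the model's
# custom generate() accepts the model_inputs dict shown in the diff
def answer_text_only(question: str) -> str:
    prompt = f"Medical question: {question}\nAnswer: "

    # Tokenize and move to the model's device
    tokenized = tokenizer(prompt, return_tensors="pt")
    input_ids = tokenized.input_ids.to(model.device)

    # Prompt length plus the 256-token budget the removed config used
    tgt_size = input_ids.size(1) + 256

    model_inputs = {
        "input_ids": input_ids,
        "pixel_values": None,     # text-only: no image features
        "tgt_sizes": [tgt_size],
    }

    outputs = model.generate(model_inputs=model_inputs)
    return tokenizer.decode(outputs[0], skip_special_tokens=True)

Note that the commit also drops the explicit sampling settings (max_new_tokens, do_sample, temperature, top_p), so generation now falls back to whatever defaults the checkpoint ships; if the old behavior is wanted back, the usual pattern is to pass those kwargs to generate() directly, assuming the custom generate forwards them.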
 