sounar committed (verified)
Commit a0ba541 · Parent: 380bfc5

Update app.py

Files changed (1): app.py (+59, -39)
app.py CHANGED
@@ -3,58 +3,78 @@
 #api_token = os.getenv("HF_TOKEN").strip()
 
 import torch
-from PIL import Image
+from flask import Flask, request, jsonify
 from transformers import AutoModel, AutoTokenizer, BitsAndBytesConfig
-import gradio as gr
+from PIL import Image
+import io
+import base64
 
+app = Flask(__name__)
 
-# Configuration for 4-bit quantization
+# Quantization configuration
 bnb_config = BitsAndBytesConfig(
     load_in_4bit=True,
     bnb_4bit_quant_type="nf4",
     bnb_4bit_use_double_quant=True,
-    bnb_4bit_compute_dtype=torch.float16,
+    bnb_4bit_compute_dtype=torch.float16
 )
 
-# Load the model without flash-attn
+# Load model
 model = AutoModel.from_pretrained(
     "ContactDoctor/Bio-Medical-MultiModal-Llama-3-8B-V1",
     quantization_config=bnb_config,
     device_map="auto",
     torch_dtype=torch.float16,
     trust_remote_code=True,
-    attn_implementation=None,  # Disable flash-attn
+    attn_implementation="flash_attention_2"
 )
 
 tokenizer = AutoTokenizer.from_pretrained(
     "ContactDoctor/Bio-Medical-MultiModal-Llama-3-8B-V1",
     trust_remote_code=True
 )
 
-
-# Define the function to handle the input
-def process_input(image, question):
-    image = Image.open(image).convert("RGB")
-    msgs = [{'role': 'user', 'content': [image, question]}]
-    res = model.chat(image=image, msgs=msgs, tokenizer=tokenizer, sampling=True, temperature=0.95, stream=True)
-
-    generated_text = ""
-    for new_text in res:
-        generated_text += new_text
-    return generated_text
-
-# Gradio interface
-iface = gr.Interface(
-    fn=process_input,
-    inputs=[
-        gr.Image(type="file", label="Upload Image"),
-        gr.Textbox(lines=2, label="Question")
-    ],
-    outputs=gr.Textbox(label="Generated Response"),
-    title="BioMedical MultiModal Llama",
-    description="Upload an image and ask a medical question."
-)
+def decode_base64_image(base64_string):
+    # Decode base64 image
+    image_data = base64.b64decode(base64_string)
+    image = Image.open(io.BytesIO(image_data)).convert('RGB')
+    return image
+
+@app.route('/analyze', methods=['POST'])
+def analyze_input():
+    data = request.json
+    question = data.get('question', '')
+    base64_image = data.get('image', None)
+
+    try:
+        # Process with image if provided
+        if base64_image:
+            image = decode_base64_image(base64_image)
+            inputs = model.prepare_inputs_for_generation(
+                input_ids=tokenizer(question, return_tensors="pt").input_ids,
+                images=[image]
+            )
+            outputs = model.generate(**inputs, max_new_tokens=256)
+        else:
+            # Text-only processing
+            inputs = tokenizer(question, return_tensors="pt")
+            outputs = model.generate(**inputs, max_new_tokens=256)
+
+        # Decode response
+        response = tokenizer.decode(outputs[0], skip_special_tokens=True)
+
+        return jsonify({
+            'status': 'success',
+            'response': response
+        })
+
+    except Exception as e:
+        return jsonify({
+            'status': 'error',
+            'message': str(e)
+        }), 500
 
-if __name__ == "__main__":
-    iface.launch()
 
+if __name__ == '__main__':
+    app.run(debug=True)
+
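For reference, a minimal client sketch against the new /analyze route: it base64-encodes a local image and POSTs it with a question as JSON, matching the 'image' and 'question' keys read by analyze_input. The host and port assume Flask's default app.run() settings, and "xray.png" is a placeholder for any local image file.

# Minimal client sketch for the new /analyze endpoint.
# Assumptions: the app is running locally on Flask's default port 5000,
# and "xray.png" is a placeholder for any local image file.
import base64
import requests

with open("xray.png", "rb") as f:
    image_b64 = base64.b64encode(f.read()).decode("utf-8")

resp = requests.post(
    "http://127.0.0.1:5000/analyze",
    json={"question": "What does this scan show?", "image": image_b64},
)
print(resp.json())

Omitting the 'image' key from the JSON body exercises the text-only branch of analyze_input instead.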