sounar committed on
Commit
8e90fc6
·
verified ·
1 Parent(s): acfc179

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +82 -24
app.py CHANGED
@@ -1,33 +1,91 @@
1
- from transformers import AutoTokenizer, AutoModelForCausalLM
2
  import torch
 
 
 
 
3
 
4
- # Load the model
5
- model_name = "ContactDoctor/Bio-Medical-MultiModal-Llama-3-8B-V1"
6
- tokenizer = AutoTokenizer.from_pretrained(model_name)
7
- model = AutoModelForCausalLM.from_pretrained(model_name)
8
 
9
- def generate_response(input_text):
10
- # Tokenize input text
11
- inputs = tokenizer(input_text, return_tensors="pt").to("cuda" if torch.cuda.is_available() else "cpu")
12
- # Generate response
13
- outputs = model.generate(inputs["input_ids"], max_length=150, temperature=0.7)
14
- response = tokenizer.decode(outputs[0], skip_special_tokens=True)
15
- return response
16
 
 
 
 
 
 
 
 
 
 
 
17
 
18
- from flask import Flask, request, jsonify
19
- from predict import generate_response # import from the predict file
 
 
 
20
 
21
- app = Flask(__name__)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
22
 
23
- @app.route("/predict", methods=["POST"])
24
- def predict():
25
- data = request.get_json()
26
- input_text = data.get("text")
27
- if not input_text:
28
- return jsonify({"error": "No input text provided"}), 400
29
- response = generate_response(input_text)
30
- return jsonify({"response": response})
 
 
 
 
 
 
 
31
 
 
32
  if __name__ == "__main__":
33
- app.run(port=5000)
 
 
 
 
 
1
+ import os
2
  import torch
3
+ from transformers import AutoModel, AutoTokenizer, BitsAndBytesConfig
4
+ import gradio as gr
5
+ from PIL import Image
6
+ from torchvision.transforms import ToTensor
7
 
8
# Read the Hugging Face access token from the environment.
# os.getenv returns None when HF_TOKEN is unset, so calling .strip() on it
# directly would crash at import time. A missing or blank value is normalised
# to None, which is then passed as `token=` to the from_pretrained calls below.
api_token = (os.getenv("HF_TOKEN") or "").strip() or None
 
 
10
 
11
+ # Quantization configuration
12
+ bnb_config = BitsAndBytesConfig(
13
+ load_in_4bit=True,
14
+ bnb_4bit_quant_type="nf4",
15
+ bnb_4bit_use_double_quant=True,
16
+ bnb_4bit_compute_dtype=torch.float16
17
+ )
18
 
19
+ # Initialize model and tokenizer
20
+ model = AutoModel.from_pretrained(
21
+ "ContactDoctor/Bio-Medical-MultiModal-Llama-3-8B-V1",
22
+ quantization_config=bnb_config,
23
+ device_map="auto",
24
+ torch_dtype=torch.float16,
25
+ trust_remote_code=True,
26
+ attn_implementation="flash_attention_2",
27
+ token=api_token
28
+ )
29
 
30
+ tokenizer = AutoTokenizer.from_pretrained(
31
+ "ContactDoctor/Bio-Medical-MultiModal-Llama-3-8B-V1",
32
+ trust_remote_code=True,
33
+ token=api_token
34
+ )
35
 
36
def analyze_input(image, question):
    """Answer a question about an optional image using the loaded chat model.

    Args:
        image: PIL image (the Gradio input is declared with ``type="pil"``),
            or ``None`` when nothing was uploaded. Converted to RGB before
            being sent to the model.
        question: Text prompt sent to the model alongside the image.

    Returns:
        ``{"status": "success", "response": <generated text>}`` when
        generation completes, or ``{"status": "error", "message": <reason>}``
        when both inputs are empty or the model call raises.
    """
    # Nothing to send to the model: fail fast instead of running generation
    # on an empty prompt.
    if image is None and not (question and str(question).strip()):
        return {
            "status": "error",
            "message": "Please provide an image or a question.",
        }

    try:
        if image is not None:
            # Convert to RGB if image is provided
            image = image.convert('RGB')

        # Only include the parts that were actually supplied, so the message
        # content never carries a None placeholder.
        content = [part for part in (image, question) if part is not None]
        msgs = [{'role': 'user', 'content': content}]

        # Generate response using the model's streaming chat method
        response_stream = model.chat(
            image=image,
            msgs=msgs,
            tokenizer=tokenizer,
            sampling=True,
            temperature=0.95,
            stream=True,
        )

        # Collect streamed pieces and join once at the end; each piece is
        # also echoed to stdout as it arrives.
        chunks = []
        for new_text in response_stream:
            chunks.append(new_text)
            print(new_text, flush=True, end='')

        return {"status": "success", "response": "".join(chunks)}

    except Exception as e:
        # UI boundary: report any failure back to the caller as an error
        # payload, and print the full traceback for the server log.
        import traceback
        error_trace = traceback.format_exc()
        print(f"Error occurred: {error_trace}")
        return {"status": "error", "message": str(e)}
68
 
69
+ # Create Gradio interface
70
+ demo = gr.Interface(
71
+ fn=analyze_input,
72
+ inputs=[
73
+ gr.Image(type="pil", label="Upload Medical Image"),
74
+ gr.Textbox(
75
+ label="Medical Question",
76
+ placeholder="Give the modality, organ, analysis, abnormalities (if any), treatment (if abnormalities are present)?",
77
+ value="Give the modality, organ, analysis, abnormalities (if any), treatment (if abnormalities are present)?"
78
+ )
79
+ ],
80
+ outputs=gr.JSON(label="Analysis"),
81
+ title="Medical Image Analysis Assistant",
82
+ description="Upload a medical image and ask questions about it. The AI will analyze the image and provide detailed responses."
83
+ )
84
 
85
+ # Launch the Gradio app
86
  if __name__ == "__main__":
87
+ demo.launch(
88
+ share=True,
89
+ server_name="0.0.0.0",
90
+ server_port=7860
91
+ )