Update app.py
app.py
CHANGED
@@ -3,12 +3,18 @@ import torch
 from transformers import AutoModelForCausalLM, AutoTokenizer
 import spaces
 
-# Initialize
+# Initialize device
 device = "cuda" if torch.cuda.is_available() else "cpu"
-
-model = AutoModelForCausalLM.from_pretrained(…)
+
+# Load model names from an external file
+with open("models.txt", "r") as f:
+    model_list = [line.strip() for line in f.readlines()]
+
+# Load default model
+current_model_name = model_list[0]
+model = AutoModelForCausalLM.from_pretrained(current_model_name, torch_dtype=torch.float16)
 model.to(device)
-tokenizer = AutoTokenizer.from_pretrained(…)
+tokenizer = AutoTokenizer.from_pretrained(current_model_name)
 
 # Default values for system and user input
 test_instruction_string = """
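For reference, the new code assumes a models.txt file sitting next to app.py, with one Hugging Face model ID per line; the first entry becomes the default. That file is not part of this diff, so the entries below are placeholders rather than the Space's actual list:

    mistralai/Mistral-7B-Instruct-v0.2
    Qwen/Qwen2-7B-Instruct

Note that readlines() keeps blank lines, so a trailing newline in the file would leave an empty string in model_list; filtering with [line.strip() for line in f if line.strip()] guards against that.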
@@ -53,6 +59,12 @@ PRIOR CONCURRENT THERAPY:
 * No prior radiotherapy to \> 30% of the bone marrow or more than standard adjuvant pelvic radiotherapy for rectal cancer <Conditions:>Lung Cancer, Unspecified Adult Solid Tumor, Protocol Specific, <Interventions:>indocyanine green, lidocaine, vinorelbine ditartrate, high performance liquid chromatography, intracellular fluorescence polarization analysis, liquid chromatography, mass spectrometry, pharmacological study <StudyType:>INTERVENTIONAL <PrimaryOutcomes:>Area Under the Curve, Number of Participants With Grade 3 and 4 Toxicities <OverallStatus:>COMPLETED
 """
 
+def load_model(model_name):
+    global model, tokenizer
+    model = AutoModelForCausalLM.from_pretrained(model_name, torch_dtype=torch.float16)
+    model.to(device)
+    tokenizer = AutoTokenizer.from_pretrained(model_name)
+
 @spaces.GPU
 def generate_response(system_instruction, user_input):
     # Format the prompt using the messages structure
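A possible refinement, not part of this commit: load_model calls from_pretrained on every switch, so flipping back to a previously used model re-instantiates it from disk. A minimal sketch of a cache, assuming the Space has enough memory to keep several fp16 models resident (the _model_cache name is hypothetical; device, model, and tokenizer come from the surrounding app.py):

    _model_cache = {}  # hypothetical: maps model name -> (model, tokenizer)

    def load_model(model_name):
        global model, tokenizer
        if model_name not in _model_cache:
            m = AutoModelForCausalLM.from_pretrained(model_name, torch_dtype=torch.float16).to(device)
            t = AutoTokenizer.from_pretrained(model_name)
            _model_cache[model_name] = (m, t)
        model, tokenizer = _model_cache[model_name]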
@@ -67,15 +79,18 @@ def generate_response(system_instruction, user_input):
     with torch.no_grad():
         generated_ids = model.generate(model_inputs, max_new_tokens=1000, do_sample=True)
     # Extract only the bot's response, omitting the prompt
-    response = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)[0].split("\n")[-1]
+    response = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)[0].split("\n")[-1].strip()
 
     return response
 
 # Gradio interface setup
 with gr.Blocks() as demo:
-    gr.Markdown("# Clinical Trial Chatbot")
+    gr.Markdown("# Clinical Trial Chatbot with Model Selection")
 
     with gr.Row():
+        # Dropdown for selecting model
+        model_dropdown = gr.Dropdown(choices=model_list, value=current_model_name, label="Select Model")
+
         # Left sidebar for inputs
         with gr.Column():
             system_instruction = gr.Textbox(
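One caveat on the extraction line: split("\n")[-1] keeps only the last line of the decoded text, so a multi-line answer from the model collapses to its final line. A common alternative is to decode only the newly generated tokens; a sketch, assuming model_inputs is the tokenized prompt tensor of shape (1, prompt_len), which this hunk does not show being built:

    # Slice off the prompt tokens, then decode just the model's continuation.
    new_tokens = generated_ids[0][model_inputs.shape[-1]:]
    response = tokenizer.decode(new_tokens, skip_special_tokens=True).strip()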
@@ -96,7 +111,8 @@ with gr.Blocks() as demo:
         label="Bot Response", interactive=False, placeholder="Response will appear here."
     )
 
-    # Link submit button to function
+    # Link model selection and submit button to functions
+    model_dropdown.change(lambda m: load_model(m), inputs=model_dropdown, outputs=[])
     submit_btn.click(generate_response, [system_instruction, user_input], response_display)
 
 # Launch the app
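As a closing note on the wiring: the lambda in model_dropdown.change(lambda m: load_model(m), ...) is redundant, since load_model already takes the selected model name as its only positional argument. Registering the handler directly behaves the same:

    model_dropdown.change(load_model, inputs=model_dropdown, outputs=[])

With outputs=[] there is no visible feedback while the new checkpoint loads; returning a status string into a visible component would be one way to surface progress, though that is beyond this commit.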