nafisneehal commited on
Commit
f981774
·
verified ·
1 Parent(s): 1d783db

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +23 -7
app.py CHANGED
@@ -3,12 +3,18 @@ import torch
3
  from transformers import AutoModelForCausalLM, AutoTokenizer
4
  import spaces
5
 
6
- # Initialize model
7
  device = "cuda" if torch.cuda.is_available() else "cpu"
8
- model_name = "linjc16/Panacea-7B-Chat"
9
- model = AutoModelForCausalLM.from_pretrained(model_name, torch_dtype=torch.float16)
 
 
 
 
 
 
10
  model.to(device)
11
- tokenizer = AutoTokenizer.from_pretrained(model_name)
12
 
13
  # Default values for system and user input
14
  test_instruction_string = """
@@ -53,6 +59,12 @@ PRIOR CONCURRENT THERAPY:
53
  * No prior radiotherapy to \> 30% of the bone marrow or more than standard adjuvant pelvic radiotherapy for rectal cancer <Conditions:>Lung Cancer, Unspecified Adult Solid Tumor, Protocol Specific, <Interventions:>indocyanine green, lidocaine, vinorelbine ditartrate, high performance liquid chromatography, intracellular fluorescence polarization analysis, liquid chromatography, mass spectrometry, pharmacological study <StudyType:>INTERVENTIONAL <PrimaryOutcomes:>Area Under the Curve, Number of Participants With Grade 3 and 4 Toxicities <OverallStatus:>COMPLETED
54
  """
55
 
 
 
 
 
 
 
56
  @spaces.GPU
57
  def generate_response(system_instruction, user_input):
58
  # Format the prompt using the messages structure
@@ -67,15 +79,18 @@ def generate_response(system_instruction, user_input):
67
  with torch.no_grad():
68
  generated_ids = model.generate(model_inputs, max_new_tokens=1000, do_sample=True)
69
  # Extract only the bot's response, omitting the prompt
70
- response = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)[0].split("\n")[-2:]
71
 
72
  return response
73
 
74
  # Gradio interface setup
75
  with gr.Blocks() as demo:
76
- gr.Markdown("# Clinical Trial Chatbot")
77
 
78
  with gr.Row():
 
 
 
79
  # Left sidebar for inputs
80
  with gr.Column():
81
  system_instruction = gr.Textbox(
@@ -96,7 +111,8 @@ with gr.Blocks() as demo:
96
  label="Bot Response", interactive=False, placeholder="Response will appear here."
97
  )
98
 
99
- # Link submit button to the generate_response function
 
100
  submit_btn.click(generate_response, [system_instruction, user_input], response_display)
101
 
102
  # Launch the app
 
3
  from transformers import AutoModelForCausalLM, AutoTokenizer
4
  import spaces
5
 
6
+ # Initialize device
7
  device = "cuda" if torch.cuda.is_available() else "cpu"
8
+
9
+ # Load model names from an external file
10
+ with open("models.txt", "r") as f:
11
+ model_list = [line.strip() for line in f.readlines()]
12
+
13
+ # Load default model
14
+ current_model_name = model_list[0]
15
+ model = AutoModelForCausalLM.from_pretrained(current_model_name, torch_dtype=torch.float16)
16
  model.to(device)
17
+ tokenizer = AutoTokenizer.from_pretrained(current_model_name)
18
 
19
  # Default values for system and user input
20
  test_instruction_string = """
 
59
  * No prior radiotherapy to \> 30% of the bone marrow or more than standard adjuvant pelvic radiotherapy for rectal cancer <Conditions:>Lung Cancer, Unspecified Adult Solid Tumor, Protocol Specific, <Interventions:>indocyanine green, lidocaine, vinorelbine ditartrate, high performance liquid chromatography, intracellular fluorescence polarization analysis, liquid chromatography, mass spectrometry, pharmacological study <StudyType:>INTERVENTIONAL <PrimaryOutcomes:>Area Under the Curve, Number of Participants With Grade 3 and 4 Toxicities <OverallStatus:>COMPLETED
60
  """
61
 
62
+ def load_model(model_name):
63
+ global model, tokenizer
64
+ model = AutoModelForCausalLM.from_pretrained(model_name, torch_dtype=torch.float16)
65
+ model.to(device)
66
+ tokenizer = AutoTokenizer.from_pretrained(model_name)
67
+
68
  @spaces.GPU
69
  def generate_response(system_instruction, user_input):
70
  # Format the prompt using the messages structure
 
79
  with torch.no_grad():
80
  generated_ids = model.generate(model_inputs, max_new_tokens=1000, do_sample=True)
81
  # Extract only the bot's response, omitting the prompt
82
+ response = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)[0].split("\n")[-1].strip()
83
 
84
  return response
85
 
86
  # Gradio interface setup
87
  with gr.Blocks() as demo:
88
+ gr.Markdown("# Clinical Trial Chatbot with Model Selection")
89
 
90
  with gr.Row():
91
+ # Dropdown for selecting model
92
+ model_dropdown = gr.Dropdown(choices=model_list, value=current_model_name, label="Select Model")
93
+
94
  # Left sidebar for inputs
95
  with gr.Column():
96
  system_instruction = gr.Textbox(
 
111
  label="Bot Response", interactive=False, placeholder="Response will appear here."
112
  )
113
 
114
+ # Link model selection and submit button to functions
115
+ model_dropdown.change(lambda m: load_model(m), inputs=model_dropdown, outputs=[])
116
  submit_btn.click(generate_response, [system_instruction, user_input], response_display)
117
 
118
  # Launch the app