TakiTakiTa committed
Commit 83c9d49 · verified · 1 parent: 973ac63

Update app.py

Replace the single global model/tokenizer pair with a loaded_models dictionary keyed by model name. load_model() now also returns the loaded model's name, which is kept in a hidden gr.State and passed to generate() through the chat interface's additional inputs.

Files changed (1)
  1. app.py +24 -18
app.py CHANGED
@@ -3,17 +3,16 @@ import spaces
 from transformers import AutoModelForCausalLM, AutoTokenizer
 import torch
 
-# Global variables to store the loaded model and tokenizer.
-model = None
-tokenizer = None
+# Global dictionary to store loaded models, keyed by model name.
+loaded_models = {}
 
 @spaces.GPU
 def load_model(model_name: str):
     """
-    Loads the model and tokenizer given the model name.
-    Returns a status message.
+    Loads the model and tokenizer and stores them in a global dictionary.
+    Returns a status message and the name of the loaded model.
     """
-    global model, tokenizer
+    global loaded_models
     try:
         model = AutoModelForCausalLM.from_pretrained(
             model_name,
@@ -21,20 +20,24 @@ def load_model(model_name: str):
             device_map="auto"
         )
         tokenizer = AutoTokenizer.from_pretrained(model_name)
-        return f"Model '{model_name}' loaded successfully."
+        loaded_models[model_name] = (model, tokenizer)
+        return f"Model '{model_name}' loaded successfully.", model_name
     except Exception as e:
-        return f"Failed to load model '{model_name}': {str(e)}"
+        return f"Failed to load model '{model_name}': {str(e)}", ""
 
 @spaces.GPU
-def generate(prompt, history):
+def generate(prompt, history, current_model_name):
     """
     Generates a response for the given prompt using the loaded model.
-    If the model is not loaded, informs the user to load it first.
+    If the model (based on the current model name) isn't loaded, it informs the user.
     """
-    if model is None or tokenizer is None:
+    global loaded_models
+    if current_model_name == "" or current_model_name not in loaded_models:
         return "Please load a model first by entering a model name and clicking the Load Model button."
-
-    # Prepare the chat history (here, a simple system prompt is added)
+
+    model, tokenizer = loaded_models[current_model_name]
+
+    # Prepare the messages (with a system prompt)
     messages = [
         {"role": "system", "content": "Je bent een vriendelijke, behulpzame assistent."},
         {"role": "user", "content": prompt}
@@ -44,6 +47,7 @@ def generate(prompt, history):
         tokenize=False,
         add_generation_prompt=True
     )
+
    model_inputs = tokenizer([text], return_tensors="pt").to(model.device)
 
     generated_ids = model.generate(
@@ -70,13 +74,15 @@ with gr.Blocks() as demo:
     )
     load_button = gr.Button("Load Model")
     load_status = gr.Textbox(label="Status", interactive=False)
+    # Hidden state to store the currently loaded model's name.
+    model_state = gr.State("")
 
-    # When the button is clicked, load_model() is called.
-    load_button.click(fn=load_model, inputs=model_name_input, outputs=load_status)
+    # When the button is clicked, load_model() returns both a status message and the model name.
+    load_button.click(fn=load_model, inputs=model_name_input, outputs=[load_status, model_state])
 
     gr.Markdown("## Chat Interface")
-    # The ChatInterface calls generate() which uses the loaded model.
-    chat_interface = gr.ChatInterface(fn=generate)
+    # The chat interface passes the hidden model_state into the generate function.
+    chat_interface = gr.ChatInterface(fn=generate, additional_inputs=[model_state])
 
-    # Launch the Gradio app (using share=True if you wish to share it publicly).
+    # Launch the Gradio app (share=True to get a public link if desired).
     demo.launch(share=True)
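
The heart of this change is the state-threading pattern: load_model() writes the loaded model's name into a hidden, per-session gr.State, and gr.ChatInterface forwards that state to generate() as an extra argument after (message, history). A minimal, self-contained sketch of the same wiring, decoupled from model loading so it runs anywhere Gradio is installed (the set_name/echo names are illustrative, not part of this Space):

import gradio as gr

def set_name(name):
    # Returned values map positionally onto outputs=[status, name_state].
    return f"Name '{name}' stored.", name

def echo(message, history, current_name):
    # Values from additional_inputs arrive after (message, history).
    if not current_name:
        return "Please store a name first."
    return f"{current_name} received: {message}"

with gr.Blocks() as demo:
    name_input = gr.Textbox(label="Name")
    set_button = gr.Button("Store Name")
    status = gr.Textbox(label="Status", interactive=False)
    name_state = gr.State("")  # hidden per-session value

    set_button.click(fn=set_name, inputs=name_input, outputs=[status, name_state])
    gr.ChatInterface(fn=echo, additional_inputs=[name_state])

demo.launch()

Because gr.State is per-session, each visitor tracks their own current model name, while the loaded_models dictionary itself lives at module level and is shared across sessions.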