TakiTakiTa committed
Commit ec40b9a · verified · 1 Parent(s): 83c9d49

Update app.py

Files changed (1):
  1. app.py +14 -23
app.py CHANGED
@@ -5,14 +5,12 @@ import torch
 
 # Global dictionary to store loaded models, keyed by model name.
 loaded_models = {}
+# Global variable to track the currently loaded model's name.
+current_model_name = ""
 
 @spaces.GPU
 def load_model(model_name: str):
-    """
-    Loads the model and tokenizer and stores them in a global dictionary.
-    Returns a status message and the name of the loaded model.
-    """
-    global loaded_models
+    global loaded_models, current_model_name
     try:
         model = AutoModelForCausalLM.from_pretrained(
             model_name,
@@ -21,23 +19,20 @@ def load_model(model_name: str):
         )
         tokenizer = AutoTokenizer.from_pretrained(model_name)
         loaded_models[model_name] = (model, tokenizer)
-        return f"Model '{model_name}' loaded successfully.", model_name
+        current_model_name = model_name  # update global state
+        return f"Model '{model_name}' loaded successfully."
     except Exception as e:
-        return f"Failed to load model '{model_name}': {str(e)}", ""
+        return f"Failed to load model '{model_name}': {str(e)}"
 
 @spaces.GPU
-def generate(prompt, history, current_model_name):
-    """
-    Generates a response for the given prompt using the loaded model.
-    If the model (based on the current model name) isn't loaded, it informs the user.
-    """
-    global loaded_models
+def generate(prompt, history):
+    global loaded_models, current_model_name
     if current_model_name == "" or current_model_name not in loaded_models:
         return "Please load a model first by entering a model name and clicking the Load Model button."
 
     model, tokenizer = loaded_models[current_model_name]
 
-    # Prepare the messages (with a system prompt)
+    # Prepare the messages (with a system prompt and the user's prompt)
     messages = [
         {"role": "system", "content": "Je bent een vriendelijke, behulpzame assistent."},
         {"role": "user", "content": prompt}
@@ -47,7 +42,6 @@ def generate(prompt, history, current_model_name):
         tokenize=False,
         add_generation_prompt=True
     )
-
     model_inputs = tokenizer([text], return_tensors="pt").to(model.device)
 
     generated_ids = model.generate(
@@ -70,19 +64,16 @@ with gr.Blocks() as demo:
     model_name_input = gr.Textbox(
         label="Model Name",
         value="agentica-org/DeepScaleR-1.5B-Preview",
-        placeholder="Enter model name"
+        placeholder="Enter model name (e.g., agentica-org/DeepScaleR-1.5B-Preview)"
    )
     load_button = gr.Button("Load Model")
     load_status = gr.Textbox(label="Status", interactive=False)
-    # Hidden state to store the currently loaded model's name.
-    model_state = gr.State("")
 
-    # When the button is clicked, load_model() returns both a status message and the model name.
-    load_button.click(fn=load_model, inputs=model_name_input, outputs=[load_status, model_state])
+    # When the Load Model button is clicked, load_model is called.
+    load_button.click(fn=load_model, inputs=model_name_input, outputs=load_status)
 
     gr.Markdown("## Chat Interface")
-    # The chat interface now passes the hidden model_state into the generate function.
-    chat_interface = gr.ChatInterface(fn=generate, extra_inputs=[model_state])
+    # Create the chat interface without extra_inputs.
+    chat_interface = gr.ChatInterface(fn=generate)
 
-    # Launch the Gradio app (share=True to get a public link if desired).
     demo.launch(share=True)
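For orientation: the first hunk's context line (@@ -5,14 +5,12 @@ import torch) implies that lines 1-4 of app.py are the imports. A minimal sketch of that preamble, inferred purely from the identifiers the file uses; the exact order of the first three lines is an assumption:

# Presumed preamble of app.py (lines 1-4, not shown in this diff); treat as a sketch.
import spaces                # Hugging Face Spaces package that provides @spaces.GPU
import gradio as gr
from transformers import AutoModelForCausalLM, AutoTokenizer
import torch                 # line 4, per the hunk context above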
 
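The middle of generate (old lines 53-69) is unchanged and therefore elided between the third and fourth hunks. Under the standard transformers chat-template pattern it presumably reads roughly like the sketch below; max_new_tokens and the prompt-token trimming are assumptions, not part of this commit:

# Hedged sketch of the elided generation and decode step (not changed here).
generated_ids = model.generate(
    **model_inputs,
    max_new_tokens=512,      # assumed cap; the real value is not visible in the diff
)
# Drop the prompt tokens so only the newly generated completion is decoded.
generated_ids = [
    output_ids[len(input_ids):]
    for input_ids, output_ids in zip(model_inputs.input_ids, generated_ids)
]
response = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)[0]
return response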
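A note on the apparent motivation: gr.ChatInterface calls its fn as fn(message, history), and as far as I can tell its keyword for passing extra components is additional_inputs, not extra_inputs, so the removed line would have raised a TypeError at startup. Keeping the state-passing design would have meant a rename along these lines (hypothetical, with generate keeping its third parameter):

# Hypothetical alternative kept for comparison; assumes Gradio's documented
# additional_inputs keyword and the old generate(prompt, history, current_model_name).
chat_interface = gr.ChatInterface(fn=generate, additional_inputs=[model_state])

The commit instead routes the model name through module-level globals, which matches the fn(message, history) signature at the cost of sharing one current_model_name across all sessions of the Space.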