TakiTakiTa committed on
Commit bd6741d · verified · 1 Parent(s): 699d2be

Update app.py

Files changed (1)
  1. app.py +51 -13
app.py CHANGED
@@ -3,17 +3,38 @@ import spaces
 from transformers import AutoModelForCausalLM, AutoTokenizer
 import torch

-model_name = "Qwen/Qwen2.5-7B-Instruct"
+# Global variables to store the loaded model and tokenizer.
+model = None
+tokenizer = None

-model = AutoModelForCausalLM.from_pretrained(
-    model_name,
-    torch_dtype=torch.bfloat16,
-    device_map="auto"
-)
-tokenizer = AutoTokenizer.from_pretrained(model_name)
+@spaces.GPU
+def load_model(model_name: str):
+    """
+    Loads the model and tokenizer given the model name.
+    Returns a status message.
+    """
+    global model, tokenizer
+    try:
+        model = AutoModelForCausalLM.from_pretrained(
+            model_name,
+            torch_dtype=torch.bfloat16,
+            device_map="auto"
+        )
+        tokenizer = AutoTokenizer.from_pretrained(model_name)
+        return f"Model '{model_name}' loaded successfully."
+    except Exception as e:
+        return f"Failed to load model '{model_name}': {str(e)}"

 @spaces.GPU
 def generate(prompt, history):
+    """
+    Generates a response for the given prompt using the loaded model.
+    If the model is not loaded, informs the user to load it first.
+    """
+    if model is None or tokenizer is None:
+        return "Please load a model first by entering a model name and clicking the Load Model button."
+
+    # Prepare the chat history (here, a simple system prompt is added)
     messages = [
         {"role": "system", "content": "Je bent een vriendelijke, behulpzame assistent."},
         {"role": "user", "content": prompt}
@@ -29,16 +50,33 @@ def generate(prompt, history):
         **model_inputs,
         max_new_tokens=512
     )
+    # Remove the input tokens from the generated tokens.
     generated_ids = [
-        output_ids[len(input_ids):] for input_ids, output_ids in zip(model_inputs.input_ids, generated_ids)
+        output_ids[len(input_ids):]
+        for input_ids, output_ids in zip(model_inputs.input_ids, generated_ids)
     ]

     response = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)[0]
     return response

-
-chat_interface = gr.ChatInterface(
-    fn=generate,
-)
-chat_interface.launch(share=True)
+# Build the Gradio UI using Blocks.
+with gr.Blocks() as demo:
+    gr.Markdown("## Model Loader")
+    with gr.Row():
+        model_name_input = gr.Textbox(
+            label="Model Name",
+            value="simplescaling/s1-32B",
+            placeholder="Enter model name"
+        )
+        load_button = gr.Button("Load Model")
+    load_status = gr.Textbox(label="Status", interactive=False)
+
+    # When the button is clicked, load_model() is called.
+    load_button.click(fn=load_model, inputs=model_name_input, outputs=load_status)
+
+    gr.Markdown("## Chat Interface")
+    # The ChatInterface calls generate() which uses the loaded model.
+    chat_interface = gr.ChatInterface(fn=generate)

+# Launch the Gradio app (using share=True if you wish to share it publicly).
+demo.launch(share=True)
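
Note: the unchanged lines of generate() between the two hunks (where model_inputs and generated_ids are produced) are not shown in this diff. A minimal sketch of what that step typically looks like, assuming the standard transformers chat-template flow for instruct models; variable names follow the surrounding hunks, and the actual committed lines may differ:

```python
# Hypothetical sketch of the elided step inside generate(); assumes the
# module-level `model` and `tokenizer` set by load_model() in app.py.
text = tokenizer.apply_chat_template(
    messages,
    tokenize=False,
    add_generation_prompt=True  # append the assistant turn marker
)
# Tokenize the rendered prompt and move it onto the model's device.
model_inputs = tokenizer([text], return_tensors="pt").to(model.device)
```

The second hunk then picks up inside the model.generate(**model_inputs, max_new_tokens=512) call, whose output is sliced to drop the prompt tokens and decoded into the response.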