MaxLSB committed on
Commit
9be0b0d
·
verified ·
1 Parent(s): 63d4a2a

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +24 -10
app.py CHANGED
@@ -23,15 +23,16 @@ if not hf_token:
23
  tokenizer = None
24
  model = None
25
 
26
- def load_model(model_name):
27
  """Loads the specified model and tokenizer."""
28
  global tokenizer, model
29
  if model_name not in MODEL_PATHS:
30
  raise ValueError(f"Unknown model: {model_name}")
31
 
32
  print(f"Loading {model_name}...")
33
- tokenizer = AutoTokenizer.from_pretrained(MODEL_PATHS[model_name], token=hf_token)
34
- model = AutoModelForCausalLM.from_pretrained(MODEL_PATHS[model_name], token=hf_token)
 
35
  model.eval()
36
  print(f"{model_name} loaded.")
37
 
@@ -42,23 +43,29 @@ load_model(initial_model)
42
 
43
  def respond(
44
  prompt: str,
45
- chat_history,
46
  model_choice: str,
47
  max_tokens: int,
48
  temperature: float,
49
  top_p: float,
50
  ):
51
  global tokenizer, model
 
52
  # Reload model if it's not the currently loaded one
53
- if model.config._name_or_path != MODEL_PATHS[model_choice]:
 
 
54
  load_model(model_choice)
55
 
 
56
  inputs = tokenizer(prompt, return_tensors="pt")
57
  streamer = TextIteratorStreamer(
58
  tokenizer,
59
  skip_prompt=False,
60
  skip_special_tokens=True,
61
  )
 
 
62
  generate_kwargs = dict(
63
  **inputs,
64
  streamer=streamer,
@@ -68,15 +75,22 @@ def respond(
68
  top_p=top_p,
69
  eos_token_id=tokenizer.eos_token_id,
70
  )
 
 
71
  thread = threading.Thread(target=model.generate, kwargs=generate_kwargs)
72
  thread.start()
 
 
73
  accumulated = ""
74
  for new_text in streamer:
75
  accumulated += new_text
76
  yield accumulated
77
 
78
 
79
- with gr.Blocks(css=css, fill_width=True) as demo:
 
 
 
80
  with gr.Row():
81
  with gr.Column(scale=1):
82
  model_dropdown = gr.Dropdown(
@@ -86,13 +100,13 @@ with gr.Blocks(css=css, fill_width=True) as demo:
86
  interactive=True
87
  )
88
  max_tokens_slider = gr.Slider(
89
- 1, 512, value=512, step=1, label="Max new tokens"
90
  )
91
  temperature_slider = gr.Slider(
92
- 0.1, 2.0, value=0.7, step=0.1, label="Temperature"
93
  )
94
  top_p_slider = gr.Slider(
95
- 0.1, 1.0, value=0.9, step=0.05, label="Top‑p"
96
  )
97
 
98
  with gr.Column(scale=3):
@@ -116,4 +130,4 @@ with gr.Blocks(css=css, fill_width=True) as demo:
116
 
117
  if __name__ == "__main__":
118
  demo.queue()
119
- demo.launch()
 
23
  tokenizer = None
24
  model = None
25
 
26
+ def load_model(model_name: str):
27
  """Loads the specified model and tokenizer."""
28
  global tokenizer, model
29
  if model_name not in MODEL_PATHS:
30
  raise ValueError(f"Unknown model: {model_name}")
31
 
32
  print(f"Loading {model_name}...")
33
+ repo = MODEL_PATHS[model_name]
34
+ tokenizer = AutoTokenizer.from_pretrained(repo, use_auth_token=hf_token)
35
+ model = AutoModelForCausalLM.from_pretrained(repo, use_auth_token=hf_token)
36
  model.eval()
37
  print(f"{model_name} loaded.")
38
 
 
43
 
44
  def respond(
45
  prompt: str,
46
+ chat_history: list,
47
  model_choice: str,
48
  max_tokens: int,
49
  temperature: float,
50
  top_p: float,
51
  ):
52
  global tokenizer, model
53
+
54
  # Reload model if it's not the currently loaded one
55
+ current_path = getattr(model.config, "_name_or_path", None)
56
+ desired_path = MODEL_PATHS[model_choice]
57
+ if current_path != desired_path:
58
  load_model(model_choice)
59
 
60
+ # Tokenize
61
  inputs = tokenizer(prompt, return_tensors="pt")
62
  streamer = TextIteratorStreamer(
63
  tokenizer,
64
  skip_prompt=False,
65
  skip_special_tokens=True,
66
  )
67
+
68
+ # Prepare generation kwargs
69
  generate_kwargs = dict(
70
  **inputs,
71
  streamer=streamer,
 
75
  top_p=top_p,
76
  eos_token_id=tokenizer.eos_token_id,
77
  )
78
+
79
+ # Launch generation in a background thread
80
  thread = threading.Thread(target=model.generate, kwargs=generate_kwargs)
81
  thread.start()
82
+
83
+ # Stream back to the UI
84
  accumulated = ""
85
  for new_text in streamer:
86
  accumulated += new_text
87
  yield accumulated
88
 
89
 
90
+ # If you have custom CSS, define it here; otherwise set to None or remove the css= line below
91
+ custom_css = None
92
+
93
+ with gr.Blocks(css=custom_css, fill_width=True) as demo:
94
  with gr.Row():
95
  with gr.Column(scale=1):
96
  model_dropdown = gr.Dropdown(
 
100
  interactive=True
101
  )
102
  max_tokens_slider = gr.Slider(
103
+ minimum=1, maximum=512, value=512, step=1, label="Max new tokens"
104
  )
105
  temperature_slider = gr.Slider(
106
+ minimum=0.1, maximum=2.0, value=0.7, step=0.1, label="Temperature"
107
  )
108
  top_p_slider = gr.Slider(
109
+ minimum=0.1, maximum=1.0, value=0.9, step=0.05, label="Top‑p"
110
  )
111
 
112
  with gr.Column(scale=3):
 
130
 
131
  if __name__ == "__main__":
132
  demo.queue()
133
+ demo.launch()