FlameF0X commited on
Commit
e4f4e0b
·
verified ·
1 Parent(s): 0e929c3

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +74 -65
app.py CHANGED
@@ -32,102 +32,111 @@ def load_model_and_tokenizer(model_name):
32
 
33
  return model_cache[model_name], tokenizer_cache[model_name]
34
 
35
- def respond(
36
- message,
37
- history: list[tuple[str, str]],
38
- model_choice,
39
- system_message,
40
- max_tokens,
41
- temperature,
42
- top_p,
43
- ):
44
  # Load selected model and tokenizer
45
- model_name = model_choice
46
- model, tokenizer = load_model_and_tokenizer(model_name)
47
 
48
  # Build conversation messages
49
  messages = [{"role": "system", "content": system_message}]
50
- for user_msg, assistant_msg in history:
51
- if user_msg:
52
- messages.append({"role": "user", "content": user_msg})
53
- if assistant_msg:
54
- messages.append({"role": "assistant", "content": assistant_msg})
55
 
56
- messages.append({"role": "user", "content": message})
 
57
 
58
- # Format prompt using chat template
59
- prompt = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
60
- inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
61
 
62
- # Set up streaming
63
- streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
64
 
65
- # Configure generation parameters
66
- do_sample = temperature > 0 or top_p < 1.0
67
- generation_kwargs = dict(
68
- **inputs,
69
- streamer=streamer,
70
- max_new_tokens=max_tokens,
71
- temperature=temperature,
72
- top_p=top_p,
73
- do_sample=do_sample,
74
- pad_token_id=tokenizer.pad_token_id
75
- )
76
 
77
- # Start generation in separate thread
78
- thread = Thread(target=model.generate, kwargs=generation_kwargs)
79
- thread.start()
80
 
81
- # Stream response
82
- response = ""
83
- for token in streamer:
84
- response += token
85
- yield response
86
-
87
- # Define available models
88
- available_models = [
89
- "GoofyLM/BrainrotLM-Assistant-362M",
90
- "GoofyLM/BrainrotLM2-Assistant-362M"
91
- ]
92
 
93
- # Create Gradio interface
94
  with gr.Blocks() as demo:
95
  gr.Markdown("# BrainrotLM Chat Interface")
96
 
97
  with gr.Row():
98
- with gr.Column(scale=4):
99
  chatbot = gr.Chatbot(height=600)
100
- msg = gr.Textbox(label="Message", placeholder="Type your message here...", lines=3)
101
- clear = gr.Button("Clear")
102
-
 
 
 
 
 
 
 
 
 
103
  with gr.Column(scale=1):
104
  model_dropdown = gr.Dropdown(
105
  choices=available_models,
106
  value=available_models[0],
107
  label="Select Model"
108
  )
 
109
  system_message = gr.Textbox(
110
- value="Your name is BrainrotLM, an AI assistant trained by GoofyLM.",
111
  label="System message",
112
  lines=4
113
  )
 
114
  max_tokens = gr.Slider(1, 512, value=72, label="Max new tokens")
115
  temperature = gr.Slider(0.1, 2.0, value=0.65, label="Temperature")
116
  top_p = gr.Slider(0.1, 1.0, value=0.95, label="Top-p (nucleus sampling)")
117
 
118
- chat_interface = gr.ChatInterface(
119
- fn=respond,
120
- chatbot=chatbot,
121
- textbox=msg,
122
- clear_btn=clear,
123
- additional_inputs=[
124
- model_dropdown,
125
- system_message,
126
- max_tokens,
127
- temperature,
128
- top_p,
129
- ],
130
  )
 
 
 
 
 
 
 
131
 
132
  if __name__ == "__main__":
133
  demo.launch()
 
32
 
33
  return model_cache[model_name], tokenizer_cache[model_name]
34
 
35
+ # Define available models
36
+ available_models = [
37
+ "GoofyLM/BrainrotLM-Assistant-362M",
38
+ "GoofyLM/BrainrotLM2-Assistant-362M"
39
+ ]
40
+
41
+ def respond(message, chat_history, model_choice, system_message, max_tokens, temperature, top_p):
 
 
42
  # Load selected model and tokenizer
43
+ model, tokenizer = load_model_and_tokenizer(model_choice)
 
44
 
45
  # Build conversation messages
46
  messages = [{"role": "system", "content": system_message}]
47
+ for user_msg, assistant_msg in chat_history:
48
+ messages.append({"role": "user", "content": user_msg})
49
+ if assistant_msg: # This might be None during streaming
50
+ messages.append({"role": "assistant", "content": assistant_msg})
 
51
 
52
+ # Add the current message
53
+ messages.append({"role": "user", "content": message})
54
 
55
+ # Format prompt using chat template
56
+ prompt = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
57
+ inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
58
 
59
+ # Set up streaming
60
+ streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
61
 
62
+ # Configure generation parameters
63
+ generation_kwargs = dict(
64
+ **inputs,
65
+ streamer=streamer,
66
+ max_new_tokens=max_tokens,
67
+ temperature=temperature,
68
+ top_p=top_p,
69
+ do_sample=(temperature > 0 or top_p < 1.0),
70
+ pad_token_id=tokenizer.pad_token_id
71
+ )
 
72
 
73
+ # Start generation in a separate thread
74
+ thread = Thread(target=model.generate, kwargs=generation_kwargs)
75
+ thread.start()
76
 
77
+ # Stream the response
78
+ partial_message = ""
79
+ for new_token in streamer:
80
+ partial_message += new_token
81
+ yield chat_history + [(message, partial_message)]
82
+
83
+ return chat_history + [(message, partial_message)]
 
 
 
 
84
 
85
+ # Create the Gradio interface
86
  with gr.Blocks() as demo:
87
  gr.Markdown("# BrainrotLM Chat Interface")
88
 
89
  with gr.Row():
90
+ with gr.Column(scale=3):
91
  chatbot = gr.Chatbot(height=600)
92
+
93
+ with gr.Row():
94
+ msg = gr.Textbox(
95
+ label="Message",
96
+ placeholder="Type your message here...",
97
+ lines=3,
98
+ show_label=False
99
+ )
100
+ submit = gr.Button("Send", variant="primary")
101
+
102
+ clear = gr.Button("Clear Conversation")
103
+
104
  with gr.Column(scale=1):
105
  model_dropdown = gr.Dropdown(
106
  choices=available_models,
107
  value=available_models[0],
108
  label="Select Model"
109
  )
110
+
111
  system_message = gr.Textbox(
112
+ value="Your name is BrainrotLM, an AI assistant trained by GoofyLM.",
113
  label="System message",
114
  lines=4
115
  )
116
+
117
  max_tokens = gr.Slider(1, 512, value=72, label="Max new tokens")
118
  temperature = gr.Slider(0.1, 2.0, value=0.65, label="Temperature")
119
  top_p = gr.Slider(0.1, 1.0, value=0.95, label="Top-p (nucleus sampling)")
120
 
121
+ # Set up event handlers
122
+ submit_event = msg.submit(
123
+ respond,
124
+ inputs=[msg, chatbot, model_dropdown, system_message, max_tokens, temperature, top_p],
125
+ outputs=chatbot
126
+ )
127
+
128
+ submit_click = submit.click(
129
+ respond,
130
+ inputs=[msg, chatbot, model_dropdown, system_message, max_tokens, temperature, top_p],
131
+ outputs=chatbot
 
132
  )
133
+
134
+ # Clear message box after sending
135
+ submit_event.then(lambda: "", None, msg)
136
+ submit_click.then(lambda: "", None, msg)
137
+
138
+ # Clear conversation button
139
+ clear.click(lambda: None, None, chatbot)
140
 
141
  if __name__ == "__main__":
142
  demo.launch()