Spaces:
Runtime error
Update app.py
app.py CHANGED
@@ -1,409 +1,147 @@
- import gradio as gr
- from transformers import AutoModelForCausalLM, AutoTokenizer, PreTrainedModel, PretrainedConfig
- import torch
- import torch.nn as nn
- import torch.nn.functional as F
- import math
  import os
- import …
- import …
- from …
- from huggingface_hub import …
- … (several removed lines, including the RMSNorm class definition, are not captured in this view)
-         return self.weight * x
-
- class LlamaAttention(nn.Module):
-     def __init__(self, config):
-         super().__init__()
-         self.hidden_size = config.hidden_size
-         self.num_heads = config.num_attention_heads
-         self.num_kv_heads = config.num_key_value_heads
-         self.head_dim = config.hidden_size // config.num_attention_heads
-
-         self.q_proj = nn.Linear(config.hidden_size, self.num_heads * self.head_dim, bias=False)
-         self.k_proj = nn.Linear(config.hidden_size, self.num_kv_heads * self.head_dim, bias=False)
-         self.v_proj = nn.Linear(config.hidden_size, self.num_kv_heads * self.head_dim, bias=False)
-         self.o_proj = nn.Linear(self.num_heads * self.head_dim, config.hidden_size, bias=False)
-
-     def forward(self, hidden_states, attention_mask=None):
-         batch_size, seq_length, _ = hidden_states.size()
-
-         q = self.q_proj(hidden_states).view(batch_size, seq_length, self.num_heads, self.head_dim)
-         k = self.k_proj(hidden_states).view(batch_size, seq_length, self.num_kv_heads, self.head_dim)
-         v = self.v_proj(hidden_states).view(batch_size, seq_length, self.num_kv_heads, self.head_dim)
-
-         if self.num_kv_heads < self.num_heads:
-             k = k.repeat_interleave(self.num_heads // self.num_kv_heads, dim=2)
-             v = v.repeat_interleave(self.num_heads // self.num_kv_heads, dim=2)
-
-         q = q.transpose(1, 2)
-         k = k.transpose(1, 2)
-         v = v.transpose(1, 2)
-
-         attention_scores = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(self.head_dim)
-
-         if attention_mask is not None:
-             attention_scores = attention_scores + attention_mask
-
-         attention_probs = F.softmax(attention_scores, dim=-1)
-         context = torch.matmul(attention_probs, v)
-
-         context = context.transpose(1, 2).contiguous()
-         context = context.view(batch_size, seq_length, -1)
-
-         return self.o_proj(context)
-
- class LlamaMLP(nn.Module):
-     def __init__(self, config):
-         super().__init__()
-         self.gate_proj = nn.Linear(config.hidden_size, config.intermediate_size, bias=False)
-         self.up_proj = nn.Linear(config.hidden_size, config.intermediate_size, bias=False)
-         self.down_proj = nn.Linear(config.intermediate_size, config.hidden_size, bias=False)
-         self.act_fn = nn.SiLU()
-
-     def forward(self, x):
-         gate = self.act_fn(self.gate_proj(x))
-         up = self.up_proj(x)
-         return self.down_proj(gate * up)
-
- class LlamaDecoderLayer(nn.Module):
-     def __init__(self, config):
-         super().__init__()
-         self.self_attn = LlamaAttention(config)
-         self.mlp = LlamaMLP(config)
-         self.input_layernorm = RMSNorm(config.hidden_size, config.rms_norm_eps)
-         self.post_attention_layernorm = RMSNorm(config.hidden_size, config.rms_norm_eps)
-
-     def forward(self, hidden_states, attention_mask=None):
-         residual = hidden_states
-         hidden_states = self.input_layernorm(hidden_states)
-         hidden_states = self.self_attn(hidden_states, attention_mask)
-         hidden_states = residual + hidden_states
-
-         residual = hidden_states
-         hidden_states = self.post_attention_layernorm(hidden_states)
-         hidden_states = self.mlp(hidden_states)
-         hidden_states = residual + hidden_states
-
-         return hidden_states
-
- class SmolLM2Config(PretrainedConfig):
-     model_type = "smollm2"

-     … (the __init__ signature and its default arguments are not captured in this view)
-     ):
-         self.vocab_size = vocab_size
-         self.hidden_size = hidden_size
-         self.intermediate_size = intermediate_size
-         self.num_hidden_layers = num_hidden_layers
-         self.num_attention_heads = num_attention_heads
-         self.num_key_value_heads = num_key_value_heads
-         self.hidden_act = hidden_act
-         self.max_position_embeddings = max_position_embeddings
-         self.initializer_range = initializer_range
-         self.rms_norm_eps = rms_norm_eps
-         self.use_cache = use_cache
-         self.rope_theta = rope_theta
-         super().__init__(
-             pad_token_id=pad_token_id,
-             bos_token_id=bos_token_id,
-             eos_token_id=eos_token_id,
-             tie_word_embeddings=tie_word_embeddings,
-             **kwargs
          )
-
- class SmolLM2ForCausalLM(PreTrainedModel):
-     config_class = SmolLM2Config
-     _no_split_modules = ["LlamaDecoderLayer"]
-
-     def __init__(self, config):
-         super().__init__(config)
-         self.config = config
-
-         self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size)
-         self.layers = nn.ModuleList([LlamaDecoderLayer(config) for _ in range(config.num_hidden_layers)])
-         self.norm = RMSNorm(config.hidden_size, config.rms_norm_eps)
-         self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)

-     … (the forward() signature and embedding lookup are not captured in this view)
-         # Create causal attention mask if none provided
-         if attention_mask is None:
-             attention_mask = torch.triu(
-                 torch.ones((input_ids.size(1), input_ids.size(1)), dtype=torch.bool, device=input_ids.device),
-                 diagonal=1
-             )
-             attention_mask = attention_mask.unsqueeze(0).unsqueeze(0)
-             attention_mask = attention_mask * -1e4
-
-         for layer in self.layers:
-             hidden_states = layer(hidden_states, attention_mask)
-
-         hidden_states = self.norm(hidden_states)
-         logits = self.lm_head(hidden_states)
-
-         loss = None
-         if labels is not None:
-             loss = F.cross_entropy(logits.view(-1, logits.size(-1)), labels.view(-1))
-
-         return logits if loss is None else (loss, logits)
-
-     def prepare_inputs_for_generation(self, input_ids, **kwargs):
-         return {
-             "input_ids": input_ids,
-             "attention_mask": kwargs.get("attention_mask", None)
-         }
-
- # Register the model architecture
- from transformers import AutoConfig, AutoModelForCausalLM
- AutoConfig.register("smollm2", SmolLM2Config)
- AutoModelForCausalLM.register(SmolLM2Config, SmolLM2ForCausalLM)
-
- # Load model and tokenizer
- model_id = "jatingocodeo/SmolLM2"
-
- def check_huggingface_access():
-     try:
-         from huggingface_hub import HfApi
-         api = HfApi()

-         # …
      try:
-         … (several removed lines are not captured in this view)
      try:
-         …
-         # …
-         …
-         login(hf_token)
-
-         print("\n1. Loading tokenizer...")
-         try:
-             tokenizer = AutoTokenizer.from_pretrained(
-                 model_id,
-                 use_auth_token=hf_token,
-                 trust_remote_code=True
-             )
-             print("✓ Tokenizer loaded successfully")
-             print(f"Tokenizer type: {type(tokenizer)}")
-             print(f"Vocabulary size: {len(tokenizer)}")
-         except Exception as e:
-             print(f"× Error loading tokenizer: {str(e)}")
-             raise
-
-         print("\n2. Adding special tokens...")
-         try:
-             special_tokens = {
-                 'pad_token': '[PAD]',
-                 'eos_token': '</s>',
-                 'bos_token': '<s>'
-             }
-             num_added = tokenizer.add_special_tokens(special_tokens)
-             print(f"✓ Added {num_added} special tokens")
-             print(f"Special tokens: {tokenizer.special_tokens_map}")
-         except Exception as e:
-             print(f"× Error adding special tokens: {str(e)}")
-             raise
-
-         print("\n3. Loading model...")
-         try:
-             model = AutoModelForCausalLM.from_pretrained(
-                 model_id,
-                 use_auth_token=hf_token,
-                 trust_remote_code=True,
-                 torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
-                 low_cpu_mem_usage=True
-             )
-             print("✓ Model loaded successfully")
-             print(f"Model type: {type(model)}")
-         except Exception as e:
-             print(f"× Error loading model: {str(e)}")
-             raise
-
-         print("\n4. Moving model to device...")
-         try:
-             device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
-             print(f"Selected device: {device}")
-             model = model.to(device)
-             print("✓ Model moved to device successfully")
-         except Exception as e:
-             print(f"× Error moving model to device: {str(e)}")
-             raise
-
-         print("\n=== Model loading completed successfully! ===")
-         return model, tokenizer
-
      except Exception as e:
-         …
-         print(f"Error type: {type(e).__name__}")
-         print(f"Error message: {str(e)}")
-         print("\nFull traceback:")
-         import traceback
-         traceback.print_exc()
-         print("\nAdditional debug info:")
-         print(f"Python version: {sys.version}")
-         print(f"PyTorch version: {torch.__version__}")
-         print(f"Transformers version: {transformers.__version__}")
-         print(f"CUDA available: {torch.cuda.is_available()}")
-         if torch.cuda.is_available():
-             print(f"CUDA version: {torch.version.cuda}")
-         print("\nEnvironment variables:")
-         print(f"HF_TOKEN set: {'HF_TOKEN' in os.environ}")
-         raise

-     … (the generate_text definition is not captured in this view)
-         print("\n=== Starting text generation ===")
-         print(f"Input prompt: {prompt}")
-         print(f"Parameters: max_length={max_length}, temperature={temperature}, top_k={top_k}")
-
-         if not hasattr(generate_text, "model"):
-             print("\n1. First call - loading model...")
-             generate_text.model, generate_text.tokenizer = load_model()
-
-         if not prompt.strip():
-             print("× Empty prompt received")
-             return "Please enter a prompt."
-
-         print("\n2. Processing prompt...")
-         if not prompt.startswith(generate_text.tokenizer.bos_token):
-             prompt = generate_text.tokenizer.bos_token + prompt
-             print("Added BOS token to prompt")
-
-         print("\n3. Encoding prompt...")
-         try:
-             input_ids = generate_text.tokenizer.encode(prompt, return_tensors="pt", truncation=True, max_length=2048)
-             print(f"Encoded shape: {input_ids.shape}")
-             input_ids = input_ids.to(generate_text.model.device)
-             print("✓ Encoding successful")
-         except Exception as e:
-             print(f"× Error encoding prompt: {str(e)}")
-             raise
-
-         print("\n4. Generating text...")
-         try:
-             with torch.no_grad():
-                 output_ids = generate_text.model.generate(
-                     input_ids,
-                     max_length=min(max_length + len(input_ids[0]), 2048),
-                     temperature=temperature,
-                     top_k=top_k,
-                     do_sample=True,
-                     pad_token_id=generate_text.tokenizer.pad_token_id,
-                     eos_token_id=generate_text.tokenizer.eos_token_id,
-                     num_return_sequences=1
-                 )
-             print(f"Generation shape: {output_ids.shape}")
-         except Exception as e:
-             print(f"× Error during generation: {str(e)}")
-             raise
-
-         print("\n5. Decoding output...")
-         try:
-             generated_text = generate_text.tokenizer.decode(output_ids[0], skip_special_tokens=True)
-             print("✓ Decoding successful")
-             print(f"Output length: {len(generated_text)}")
-         except Exception as e:
-             print(f"× Error decoding output: {str(e)}")
-             raise
-
-         print("\n=== Generation completed successfully! ===")
-         return generated_text.strip()
-
-     except Exception as e:
-         print("\n!!! ERROR IN TEXT GENERATION !!!")
-         print(f"Error type: {type(e).__name__}")
-         print(f"Error message: {str(e)}")
-         print("\nFull traceback:")
-         import traceback
-         traceback.print_exc()
-         return f"An error occurred: {str(e)}"

  # Create Gradio interface
- … (the opening of the interface definition is not captured in this view)
      fn=generate_text,
      inputs=[
-         gr.Textbox(
-             … (the textbox arguments and remaining inputs are not captured in this view)
      ],
-     outputs=gr.Textbox(label="Generated Text", lines=…
-     title="…
-     description="""…
      examples=[
-         ["…
-         ["…
-         ["…
-     ]
-     allow_flagging="never"
  )

  if __name__ == "__main__":
-     …
-     iface.launch(debug=True)
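Note: the removed file references a custom RMSNorm module throughout (e.g. RMSNorm(config.hidden_size, config.rms_norm_eps)), but only its final return self.weight * x line is visible in the diff above. The following is a minimal sketch of a typical Llama-style RMSNorm, offered only as an assumption about the body that was not captured, not as the Space's original code:

import torch
import torch.nn as nn

class RMSNorm(nn.Module):
    # Hypothetical reconstruction: scale activations by the inverse
    # root-mean-square, then apply a learned per-channel gain.
    def __init__(self, hidden_size, eps=1e-6):
        super().__init__()
        self.weight = nn.Parameter(torch.ones(hidden_size))
        self.eps = eps

    def forward(self, x):
        variance = x.pow(2).mean(-1, keepdim=True)
        x = x * torch.rsqrt(variance + self.eps)
        return self.weight * x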
  import os
+ import torch
+ import gradio as gr
+ from train_optimized import GPT, GPTConfig
+ from huggingface_hub import hf_hub_download
+ import json
+
+ # Cache for model and tokenizer
+ MODEL = None
+ CHARS = None
+ STOI = None
+ ITOS = None
+
+ def initialize():
+     global MODEL, CHARS, STOI, ITOS

+     if MODEL is None:
+         print("Loading model and tokenizer...")
+         # Download model files from HF Hub
+         config_path = hf_hub_download(repo_id="jatingocodeo/shakespeare-decoder", filename="config.json")
+         model_path = hf_hub_download(repo_id="jatingocodeo/shakespeare-decoder", filename="pytorch_model.bin")
+
+         # Load config
+         with open(config_path, 'r') as f:
+             config_dict = json.load(f)
+
+         # Initialize model with config
+         config = GPTConfig(
+             vocab_size=config_dict['vocab_size'],
+             n_layer=config_dict['n_layer'],
+             n_head=config_dict['n_head'],
+             n_embd=config_dict['n_embd'],
+             block_size=config_dict['block_size'],
+             dropout=config_dict['dropout'],
+             bias=config_dict['bias']
          )
+         model = GPT(config)

+         # Load model weights
+         state_dict = torch.load(model_path, map_location='cpu')
+         model.load_state_dict(state_dict)
+         model.eval()
+         MODEL = model

+         # Initialize tokenizer
          try:
+             # First try to download input.txt from the repository
+             input_path = hf_hub_download(repo_id="jatingocodeo/shakespeare-decoder", filename="input.txt")
+             with open(input_path, 'r', encoding='utf-8') as f:
+                 text = f.read()
+         except:
+             print("Warning: Could not load input.txt from repository. Using default character set.")
+             # Use a comprehensive set of characters that would be in Shakespeare's text
+             text = """
+ ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789.,!?-_:;'"()[]{}\n </>\t@#$%^&*+=~`|\\
+             """
+
+         CHARS = sorted(list(set(text)))
+         STOI = {ch:i for i,ch in enumerate(CHARS)}
+         ITOS = {i:ch for i,ch in enumerate(CHARS)}
+
+         print(f"Model loaded successfully! Vocabulary size: {len(CHARS)} characters")
+         print("Available characters:", ''.join(CHARS))
+
+ def generate_text(
+     prompt,
+     max_new_tokens=100,
+     temperature=0.8,
+     top_k=50
+ ):
+     # Initialize if not already done
+     if MODEL is None:
+         initialize()
+
+     # Encode the prompt
+     encode = lambda s: [STOI[c] for c in s]
+     decode = lambda l: ''.join([ITOS[i] for i in l])
+
      try:
+         # Convert prompt to tensor
+         x = torch.tensor(encode(prompt), dtype=torch.long)[None,...]
+
+         # Generate
+         with torch.no_grad():
+             y = MODEL.generate(x, max_new_tokens, temperature, top_k)[0]
+
+         # Decode and return
+         generated_text = decode(y.tolist())
+         return generated_text
+     except KeyError:
+         return "Error: The prompt contains characters that are not in the training data. Please use only standard English characters."
      except Exception as e:
+         return f"Error generating text: {str(e)}"

+ # Initialize on startup
+ initialize()

  # Create Gradio interface
+ demo = gr.Interface(
      fn=generate_text,
      inputs=[
+         gr.Textbox(
+             label="Prompt",
+             placeholder="Enter your prompt here...",
+             lines=5
+         ),
+         gr.Slider(
+             label="Max New Tokens",
+             minimum=10,
+             maximum=500,
+             value=100,
+             step=10
+         ),
+         gr.Slider(
+             label="Temperature",
+             minimum=0.1,
+             maximum=2.0,
+             value=0.8,
+             step=0.1
+         ),
+         gr.Slider(
+             label="Top-k",
+             minimum=1,
+             maximum=100,
+             value=50,
+             step=1
+         ),
      ],
+     outputs=gr.Textbox(label="Generated Text", lines=10),
+     title="Shakespeare GPT",
+     description="""
+     This is a GPT model trained on Shakespeare's text. Enter a prompt and the model will continue it in Shakespeare's style.
+
+     Parameters:
+     - Temperature: Higher values make the output more random, lower values make it more deterministic
+     - Top-k: Number of highest probability tokens to consider at each step
+     - Max New Tokens: Maximum number of tokens to generate
+     """,
      examples=[
+         ["To be, or not to be,", 100, 0.8, 50],
+         ["Friends, Romans, countrymen,", 150, 0.7, 40],
+         ["Now is the winter of", 200, 0.9, 30],
+     ]
  )

  if __name__ == "__main__":
+     demo.launch()
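The new app delegates sampling to GPT.generate from train_optimized.py, which is not part of this diff. As a rough illustration of what the Temperature and Top-k controls do at each step, a typical sampling routine looks like the sketch below; this is an assumption about the implementation, not the Space's actual code:

import torch
import torch.nn.functional as F

@torch.no_grad()
def sample_next_token(logits, temperature=0.8, top_k=50):
    # Hypothetical sampling step; the real GPT.generate lives in train_optimized.py.
    logits = logits / temperature  # <1.0 sharpens, >1.0 flattens the distribution
    if top_k is not None:
        v, _ = torch.topk(logits, min(top_k, logits.size(-1)))
        # Mask out everything below the k-th largest logit
        logits = logits.masked_fill(logits < v[..., [-1]], float('-inf'))
    probs = F.softmax(logits, dim=-1)
    return torch.multinomial(probs, num_samples=1)  # draw one token id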
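For a quick local smoke test of the new app.py, generate_text can be called directly with the same (prompt, max_new_tokens, temperature, top_k) values the Gradio examples use; prompts should contain only characters present in the training text, since encoding is character-level. A sketch, assuming the Space's files (app.py, train_optimized.py) are importable and the jatingocodeo/shakespeare-decoder checkpoint is reachable:

# Hypothetical smoke test; run from a checkout of the Space's repository.
from app import generate_text

print(generate_text("To be, or not to be,", max_new_tokens=100, temperature=0.8, top_k=50))
print(generate_text("Friends, Romans, countrymen,", max_new_tokens=150, temperature=0.7, top_k=40))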