yoonusajwardapiit commited on
Commit
247aecf
1 Parent(s): 9d37c49

Upload app.py

Browse files
Files changed (1) hide show
  1. app.py +9 -4
app.py CHANGED
@@ -81,10 +81,11 @@ class BigramLanguageModel(nn.Module):
81
  return logits, None
82
 
83
  def generate(self, idx, max_new_tokens):
 
84
  for _ in range(max_new_tokens):
85
- idx_cond = idx[:, -32:] # Ensure context length does not exceed block size
86
  logits, _ = self(idx_cond)
87
- logits = logits[:, -1, :]
88
  probs = nn.functional.softmax(logits, dim=-1)
89
  idx_next = torch.multinomial(probs, num_samples=1)
90
  idx = torch.cat((idx, idx_next), dim=1)
@@ -105,7 +106,7 @@ model = load_model()
105
  chars = sorted(list(set("abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789 .,!?-:;'\"\n")))
106
  stoi = {ch: i for i, ch in enumerate(chars)}
107
  itos = {i: ch for i, ch in enumerate(chars)}
108
- encode = lambda s: [stoi[c] for c in s if c in stoi] # Ensures only known characters are encoded
109
  decode = lambda l: ''.join([itos[i] for i in l])
110
 
111
  # Function to generate text using the model
@@ -114,6 +115,10 @@ def generate_text(prompt):
114
  print(f"Received prompt: {prompt}")
115
  encoded_prompt = encode(prompt)
116
 
 
 
 
 
117
  # Ensure the prompt length fits within the block size
118
  if len(encoded_prompt) > 32:
119
  encoded_prompt = encoded_prompt[:32] # Truncate to fit block size
@@ -142,4 +147,4 @@ interface = gr.Interface(
142
  )
143
 
144
  # Launch the interface
145
- interface.launch(share=True)
 
81
  return logits, None
82
 
83
  def generate(self, idx, max_new_tokens):
84
+ # Ensure we respect the block size of 32
85
  for _ in range(max_new_tokens):
86
+ idx_cond = idx[:, -32:] # Truncate to the latest 32 tokens
87
  logits, _ = self(idx_cond)
88
+ logits = logits[:, -1, :] # Get the logits for the last token
89
  probs = nn.functional.softmax(logits, dim=-1)
90
  idx_next = torch.multinomial(probs, num_samples=1)
91
  idx = torch.cat((idx, idx_next), dim=1)
 
106
  chars = sorted(list(set("abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789 .,!?-:;'\"\n")))
107
  stoi = {ch: i for i, ch in enumerate(chars)}
108
  itos = {i: ch for i, ch in enumerate(chars)}
109
+ encode = lambda s: [stoi[c] for c in s if c in stoi] # Ensure only known characters are encoded
110
  decode = lambda l: ''.join([itos[i] for i in l])
111
 
112
  # Function to generate text using the model
 
115
  print(f"Received prompt: {prompt}")
116
  encoded_prompt = encode(prompt)
117
 
118
+ # Check for out-of-vocabulary indices
119
+ if any(idx >= 61 for idx in encoded_prompt):
120
+ return "Error: Input contains characters not in the model vocabulary."
121
+
122
  # Ensure the prompt length fits within the block size
123
  if len(encoded_prompt) > 32:
124
  encoded_prompt = encoded_prompt[:32] # Truncate to fit block size
 
147
  )
148
 
149
  # Launch the interface
150
+ interface.launch()