yoonusajwardapiit committed on
Commit
837ecd8
1 Parent(s): 247aecf

Upload app.py

Browse files
Files changed (1) hide show
  1. app.py +4 -4
app.py CHANGED
@@ -81,7 +81,6 @@ class BigramLanguageModel(nn.Module):
81
  return logits, None
82
 
83
  def generate(self, idx, max_new_tokens):
84
- # Ensure we respect the block size of 32
85
  for _ in range(max_new_tokens):
86
  idx_cond = idx[:, -32:] # Truncate to the latest 32 tokens
87
  logits, _ = self(idx_cond)
@@ -103,10 +102,11 @@ def load_model():
103
  model = load_model()
104
 
105
  # Define a comprehensive character set based on training data
106
- chars = sorted(list(set("abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789 .,!?-:;'\"\n")))
 
107
  stoi = {ch: i for i, ch in enumerate(chars)}
108
  itos = {i: ch for i, ch in enumerate(chars)}
109
- encode = lambda s: [stoi[c] for c in s if c in stoi] # Ensure only known characters are encoded
110
  decode = lambda l: ''.join([itos[i] for i in l])
111
 
112
  # Function to generate text using the model
@@ -116,7 +116,7 @@ def generate_text(prompt):
116
  encoded_prompt = encode(prompt)
117
 
118
  # Check for out-of-vocabulary indices
119
- if any(idx >= 61 for idx in encoded_prompt):
120
  return "Error: Input contains characters not in the model vocabulary."
121
 
122
  # Ensure the prompt length fits within the block size
 
81
  return logits, None
82
 
83
  def generate(self, idx, max_new_tokens):
 
84
  for _ in range(max_new_tokens):
85
  idx_cond = idx[:, -32:] # Truncate to the latest 32 tokens
86
  logits, _ = self(idx_cond)
 
102
  model = load_model()
103
 
104
  # Define a comprehensive character set based on training data
105
+ # Convert all input to lowercase if the model is trained on lowercase data
106
+ chars = sorted(list(set("abcdefghijklmnopqrstuvwxyz0123456789 .,!?-:;'\"\n")))
107
  stoi = {ch: i for i, ch in enumerate(chars)}
108
  itos = {i: ch for i, ch in enumerate(chars)}
109
+ encode = lambda s: [stoi.get(c, stoi.get(c.lower(), -1)) for c in s if c in stoi or c.lower() in stoi] # Handles both cases
110
  decode = lambda l: ''.join([itos[i] for i in l])
111
 
112
  # Function to generate text using the model
 
116
  encoded_prompt = encode(prompt)
117
 
118
  # Check for out-of-vocabulary indices
119
+ if any(idx == -1 for idx in encoded_prompt):
120
  return "Error: Input contains characters not in the model vocabulary."
121
 
122
  # Ensure the prompt length fits within the block size