yoonusajwardapiit commited on
Commit
141eb85
1 Parent(s): d2fde25

Upload app.py

Browse files
Files changed (1) hide show
  1. app.py +3 -4
app.py CHANGED
@@ -88,7 +88,7 @@ class BigramLanguageModel(nn.Module):
88
  logits = logits[:, -1, :] # Get the logits for the last token
89
  probs = nn.functional.softmax(logits, dim=-1)
90
  idx_next = torch.multinomial(probs, num_samples=1)
91
- idx_next = torch.clamp(idx_next, max=60) # Cap the generated index to max 60
92
  idx = torch.cat((idx, idx_next), dim=1)
93
  return idx
94
 
@@ -108,7 +108,7 @@ chars = sorted(list(set("abcdefghijklmnopqrstuvwxyz0123456789 .,!?-:;'\"\n")))
108
  stoi = {ch: i for i, ch in enumerate(chars)}
109
  itos = {i: ch for i, ch in enumerate(chars)}
110
  encode = lambda s: [stoi.get(c, stoi.get(c.lower(), -1)) for c in s if c in stoi or c.lower() in stoi] # Handles both cases
111
- decode = lambda l: ''.join([itos[i] for i in l])
112
 
113
  # Function to generate text using the model
114
  def generate_text(prompt):
@@ -146,8 +146,7 @@ interface = gr.Interface(
146
  inputs=gr.Textbox(lines=2, placeholder="Enter a location or prompt..."),
147
  outputs="text",
148
  title="Triptuner Model",
149
- description="Generate itineraries for locations in Sri Lanka's Central Province.",
150
- theme="compact", # Add a theme for better UI appearance
151
  )
152
 
153
  # Launch the interface
 
88
  logits = logits[:, -1, :] # Get the logits for the last token
89
  probs = nn.functional.softmax(logits, dim=-1)
90
  idx_next = torch.multinomial(probs, num_samples=1)
91
+ idx_next = torch.clamp(idx_next, min=0, max=60) # Strictly enforce index range [0, 60]
92
  idx = torch.cat((idx, idx_next), dim=1)
93
  return idx
94
 
 
108
  stoi = {ch: i for i, ch in enumerate(chars)}
109
  itos = {i: ch for i, ch in enumerate(chars)}
110
  encode = lambda s: [stoi.get(c, stoi.get(c.lower(), -1)) for c in s if c in stoi or c.lower() in stoi] # Handles both cases
111
+ decode = lambda l: ''.join([itos[i] for i in l if i < len(itos)]) # Ensures index is within bounds
112
 
113
  # Function to generate text using the model
114
  def generate_text(prompt):
 
146
  inputs=gr.Textbox(lines=2, placeholder="Enter a location or prompt..."),
147
  outputs="text",
148
  title="Triptuner Model",
149
+ description="Generate itineraries for locations in Sri Lanka's Central Province."
 
150
  )
151
 
152
  # Launch the interface