Spaces:
Sleeping
Sleeping
yoonusajwardapiit
commited on
Upload app.py
Browse files
app.py
CHANGED
@@ -81,11 +81,12 @@ class BigramLanguageModel(nn.Module):
|
|
81 |
logits = self.lm_head(x)
|
82 |
return logits, None
|
83 |
|
84 |
-
def generate(self, idx, max_new_tokens):
|
85 |
for _ in range(max_new_tokens):
|
86 |
idx_cond = idx[:, -32:] # Truncate to the latest 32 tokens
|
87 |
logits, _ = self(idx_cond)
|
88 |
logits = logits[:, -1, :] # Get the logits for the last token
|
|
|
89 |
probs = nn.functional.softmax(logits, dim=-1)
|
90 |
idx_next = torch.multinomial(probs, num_samples=1)
|
91 |
idx_next = torch.clamp(idx_next, min=0, max=60) # Strictly enforce index range [0, 60]
|
@@ -129,13 +130,17 @@ def generate_text(prompt):
|
|
129 |
print(f"Encoded prompt: {context}")
|
130 |
|
131 |
with torch.no_grad():
|
132 |
-
generated = model.generate(context, max_new_tokens=20) #
|
133 |
print(f"Generated tensor: {generated}")
|
134 |
|
135 |
result = decode(generated[0].tolist())
|
136 |
print(f"Decoded result: {result}")
|
|
|
|
|
|
|
|
|
137 |
print(f"Processing time: {time.time() - start_time:.2f}s")
|
138 |
-
return
|
139 |
except Exception as e:
|
140 |
print(f"Error during generation: {e}")
|
141 |
return f"Error: {str(e)}"
|
|
|
81 |
logits = self.lm_head(x)
|
82 |
return logits, None
|
83 |
|
84 |
+
def generate(self, idx, max_new_tokens, temperature=0.7):
|
85 |
for _ in range(max_new_tokens):
|
86 |
idx_cond = idx[:, -32:] # Truncate to the latest 32 tokens
|
87 |
logits, _ = self(idx_cond)
|
88 |
logits = logits[:, -1, :] # Get the logits for the last token
|
89 |
+
logits = logits / temperature # Apply temperature control
|
90 |
probs = nn.functional.softmax(logits, dim=-1)
|
91 |
idx_next = torch.multinomial(probs, num_samples=1)
|
92 |
idx_next = torch.clamp(idx_next, min=0, max=60) # Strictly enforce index range [0, 60]
|
|
|
130 |
print(f"Encoded prompt: {context}")
|
131 |
|
132 |
with torch.no_grad():
|
133 |
+
generated = model.generate(context, max_new_tokens=20, temperature=0.7) # Adjust temperature
|
134 |
print(f"Generated tensor: {generated}")
|
135 |
|
136 |
result = decode(generated[0].tolist())
|
137 |
print(f"Decoded result: {result}")
|
138 |
+
|
139 |
+
# Post-process to clean up and make output more readable
|
140 |
+
cleaned_result = result.replace('\n', ' ').strip()
|
141 |
+
print(f"Cleaned result: {cleaned_result}")
|
142 |
print(f"Processing time: {time.time() - start_time:.2f}s")
|
143 |
+
return cleaned_result
|
144 |
except Exception as e:
|
145 |
print(f"Error during generation: {e}")
|
146 |
return f"Error: {str(e)}"
|