import torch
import torch.nn as nn
from torch.nn import functional as F
import tiktoken
import gradio as gr
import asyncio
# Try to import spaces; fall back to a no-op decorator if not available
try:
    import spaces
    use_spaces_gpu = True
except ImportError:
    use_spaces_gpu = False

    # No-op stand-in so that both @spaces.GPU and @spaces.GPU(duration=...)
    # work when running outside Hugging Face Spaces
    def dummy_gpu_decorator(*args, **kwargs):
        if args and callable(args[0]):
            return args[0]  # used directly as @spaces.GPU

        def wrapper(func):
            return func  # used with arguments, e.g. @spaces.GPU(duration=60)

        return wrapper

    # SimpleNamespace avoids the method-binding bug of a bare class instance,
    # where spaces.GPU(func) would receive the instance as an extra argument
    import types
    spaces = types.SimpleNamespace(GPU=dummy_gpu_decorator)
# ... (keep the model architecture classes as they are)
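# For reference, a minimal sketch of the interface the code below assumes
# from those classes. The names and defaults here are assumptions, not the
# actual architecture:
#
#   class GPTConfig:
#       block_size: int = 1024   # maximum context length
#       vocab_size: int = 50257  # GPT-2 BPE vocabulary
#       ...
#
#   class GPT(nn.Module):
#       def forward(self, idx, targets=None):
#           ...
#           return logits, loss  # logits shape: (batch, seq_len, vocab_size)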
# Update the load_model function
def load_model(model_path):
    config = GPTConfig()
    model = GPT(config)
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    checkpoint = torch.load(model_path, map_location=device)
    # Support both full training checkpoints and bare state dicts
    if 'model_state_dict' in checkpoint:
        model.load_state_dict(checkpoint['model_state_dict'])
    else:
        model.load_state_dict(checkpoint)
    model.eval()
    model.to(device)
    return model
# Load the model and tokenizer
model = load_model('gpt_model.pth')  # Replace with the actual path to your checkpoint file
enc = tiktoken.get_encoding('gpt2')
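# Optional sanity check: the GPT-2 BPE round trip should be lossless
# (the sample string is arbitrary and only illustrative)
assert enc.decode(enc.encode("hello world")) == "hello world"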
# Update the generate_text function
@spaces.GPU(duration=60)  # Adjust duration as needed (60s is an assumed value)
async def generate_text(prompt, max_length=432, temperature=0.8, top_k=40):
    device = next(model.parameters()).device
    input_ids = torch.tensor(enc.encode(prompt)).unsqueeze(0).to(device)
    # Gradio sliders may pass floats; the sampling loop needs ints
    max_length, top_k = int(max_length), int(top_k)
    generated = []
    with torch.no_grad():
        for _ in range(max_length):
            # Note: input_ids grows every step; if the model has a fixed
            # context window, crop it here to the last block_size tokens
            outputs, _ = model(input_ids)
            next_token_logits = outputs[:, -1, :]
            next_token_logits = next_token_logits / temperature
            # Top-k sampling: keep the k most likely tokens, renormalize,
            # and draw the next token from that distribution
            top_k_logits, top_k_indices = torch.topk(next_token_logits, top_k, dim=-1)
            next_token_probs = F.softmax(top_k_logits, dim=-1)
            next_token_index = torch.multinomial(next_token_probs, num_samples=1)
            next_token = top_k_indices.gather(-1, next_token_index)
            input_ids = torch.cat([input_ids, next_token], dim=-1)
            generated.append(next_token.item())
            next_token_str = enc.decode([next_token.item()])
            yield next_token_str
            # Stop at a newline once a reasonable amount of text exists
            if next_token.item() == enc.encode('\n')[0] and len(generated) > 100:
                break
            await asyncio.sleep(0.02)  # Slightly faster typing effect
    if len(generated) == max_length:
        yield "... (output truncated due to length)"
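# Example of consuming the async generator directly (a sketch for local
# testing only; inside the Space, gradio_generate below drives it instead):
#
# async def _demo():
#     async for tok in generate_text("Once upon a time", max_length=50):
#         print(tok, end="", flush=True)
#
# asyncio.run(_demo())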
# Update the gradio_generate function
@spaces.GPU(duration=60)  # Adjust duration as needed (assumed value)
async def gradio_generate(prompt, max_length, temperature, top_k):
    output = ""
    # Re-yield the accumulated text so Gradio updates the textbox live
    async for token in generate_text(prompt, max_length, temperature, top_k):
        output += token
        yield output
# The rest of your Gradio interface code remains the same
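# For completeness, a minimal sketch of what that interface might look like.
# The component choices, labels, and value ranges below are assumptions,
# not the original interface:
#
# demo = gr.Interface(
#     fn=gradio_generate,
#     inputs=[
#         gr.Textbox(label="Prompt"),
#         gr.Slider(32, 432, value=432, step=1, label="Max length"),
#         gr.Slider(0.1, 2.0, value=0.8, label="Temperature"),
#         gr.Slider(1, 100, value=40, step=1, label="Top-k"),
#     ],
#     outputs=gr.Textbox(label="Generated text"),
# )
# demo.launch()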