Spaces:
Paused
Paused
File size: 5,510 Bytes
325f8af 9129c29 325f8af |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 |
import gradio as gr
from transformers import pipeline
from config import ModelArgs
from inference import remove_prefix
from model import Llama
import torch
from inference_sft import topk_sampling
import os
import subprocess
import re
from tokenizer import Tokenizer
import torch.nn.functional as F
import shutil
# Checkpoint file for each model variant the demo can serve, keyed by model type.
model_paths = {
    "Pretrained": "weights/pretrained/models--YuvrajSingh9886--StoryLlama/snapshots/8c285708701e5b8eed82e70a26d1cc1207e97831/snapshot_4650.pt",
}

# Fetch the weights for every registered variant when this module is imported.
# The download script receives the lower-cased model type; check=True makes a
# non-zero exit status raise subprocess.CalledProcessError and abort start-up.
for variant in model_paths:
    download_cmd = ["python", "download_model_weight.py", "--model_type", variant.lower()]
    subprocess.run(download_cmd, check=True)

# Module-wide tokenizer used by the decoding helpers below
# (whatever Tokenizer.ready_tokenizer() returns).
tk = Tokenizer().ready_tokenizer()
def beam_search(model, prompt, device, max_length=50, beam_width=5, top_k=50, temperature=1.0):
    """Generate a continuation of `prompt` with beam search.

    Args:
        model: Callable taking a batch of token ids and returning logits that
            are indexed as ``outputs[:, -1, :]`` (assumed to be
            (batch, sequence, vocab) -- TODO confirm against the model).
        prompt: Text to continue; encoded with the module-level tokenizer `tk`.
        device: Device the token ids and beam scores are placed on.
        max_length: Number of new tokens appended to the prompt.
        beam_width: Number of hypotheses kept after every step.
        top_k: Candidate tokens considered per hypothesis at each step
            (clamped to the vocabulary size).
        temperature: Divisor applied to the logits before the log-softmax.

    Returns:
        The decoded text of the highest-scoring beam (prompt tokens included),
        with special tokens skipped.
    """
    input_ids = tk.encode(prompt, return_tensors='pt').to(device)

    # Every beam starts as a copy of the prompt.
    beams = input_ids.repeat(beam_width, 1)

    # Only the first copy is "live" at step 0. If all copies started with the
    # same score, the first top-k selection would pick `beam_width` duplicates
    # of the single best token and the search would degenerate into greedy
    # decoding. Starting the duplicates at -inf makes the first expansion pick
    # `beam_width` distinct tokens instead.
    beam_scores = torch.full((beam_width,), float("-inf"), device=device)
    beam_scores[0] = 0.0

    # NOTE(review): the sequences grow by one token per step and are never
    # cropped; confirm the model accepts prompt length + max_length tokens.
    for _ in range(max_length):
        with torch.no_grad():
            outputs = model(beams)

        # Logits of the last position only, scaled by the temperature.
        logits = outputs[:, -1, :]
        log_probs = F.log_softmax(logits / temperature, dim=-1)

        # torch.topk raises if k exceeds the size of the dimension.
        k = min(int(top_k), log_probs.size(-1))
        topk_log_probs, topk_indices = torch.topk(log_probs, k, dim=-1)

        # Expand every beam by each of its k candidate tokens.
        expanded_beams = beams.repeat_interleave(k, dim=0)
        new_tokens = topk_indices.view(-1, 1)
        candidate_beams = torch.cat([expanded_beams, new_tokens], dim=1)

        # Cumulative log-probability of every candidate sequence.
        expanded_scores = beam_scores.repeat_interleave(k)
        candidate_scores = expanded_scores + topk_log_probs.view(-1)

        # Keep the `beam_width` best candidates for the next step.
        top_scores, top_indices = candidate_scores.topk(beam_width)
        beams = candidate_beams[top_indices]
        beam_scores = top_scores

    best_idx = beam_scores.argmax()
    best_sequence = beams[best_idx]
    return tk.decode(best_sequence, skip_special_tokens=True)
# Function to load the selected model
def load_model(model_type):
    """Build a Llama model and load the checkpoint registered for `model_type`.

    The checkpoint file is expected to already exist at the path stored in
    `model_paths`; this function does not download anything.

    Args:
        model_type: Key into the module-level `model_paths` mapping.

    Returns:
        The model, moved to `ModelArgs.device`, with the checkpoint weights
        loaded and switched to evaluation mode.

    Raises:
        KeyError: If `model_type` is not a key of `model_paths`.
    """
    model_path = model_paths[model_type]

    model = Llama(
        device=ModelArgs.device,
        embeddings_dims=ModelArgs.embeddings_dims,
        no_of_decoder_layers=ModelArgs.no_of_decoder_layers,
        block_size=ModelArgs.block_size,
        vocab_size=ModelArgs.vocab_size,
        dropout=ModelArgs.dropout
    )
    model = model.to(ModelArgs.device)

    # map_location remaps the stored tensors onto the configured device, so a
    # checkpoint saved from a GPU run still loads on a CPU-only host.
    # NOTE(security): weights_only=False unpickles arbitrary Python objects;
    # only load checkpoints from a trusted source.
    dict_model = torch.load(model_path, map_location=ModelArgs.device, weights_only=False)

    # Strip the '_orig_mod.' key prefix (presumably left by torch.compile --
    # confirm) so the keys match the plain model built above.
    dict_model['MODEL_STATE'] = remove_prefix(dict_model['MODEL_STATE'], '_orig_mod.')
    model.load_state_dict(dict_model['MODEL_STATE'])

    model.eval()
    return model
# download_models()
# Model served by answer_question; loaded eagerly when the module is imported.
current_model = load_model("Pretrained")
# Key of `model_paths` that `current_model` was loaded from. Kept in sync by
# answer_question so the checkpoint is only re-read when the selection changes.
_current_model_type = "Pretrained"

# UI labels whose text differs from the corresponding `model_paths` key.
_MODEL_LABEL_TO_TYPE = {"Base (Pretrained)": "Pretrained"}


def answer_question(model_type, prompt, temperature, top_k, max_length):
    """Generate text for `prompt` with the selected model using top-k sampling.

    Args:
        model_type: UI label or `model_paths` key of the model to use.
        prompt: Text passed to the sampler.
        temperature: Sampling temperature forwarded to `topk_sampling`.
        top_k: Number of candidate tokens forwarded to `topk_sampling`.
        max_length: Generation length forwarded to `topk_sampling`.

    Returns:
        Whatever `topk_sampling` returns for the prompt.

    Raises:
        KeyError: If `model_type` does not resolve to a key of `model_paths`
            (raised by `load_model`).
    """
    global current_model, _current_model_type

    model_type = _MODEL_LABEL_TO_TYPE.get(model_type, model_type)

    # Compare model *types*. The previous check looked the model object up as
    # a key of `model_paths`, which never matched, so the checkpoint was
    # reloaded from disk on every single request.
    if model_type != _current_model_type:
        current_model = load_model(model_type)
        _current_model_type = model_type

    with torch.no_grad():
        # int(): guard against the UI delivering whole numbers as floats.
        generated_text = topk_sampling(
            current_model,
            prompt,
            max_length=int(max_length),
            top_k=int(top_k),
            temperature=temperature,
            device=ModelArgs.device,
        )
    return generated_text
# Gradio UI: the inputs are passed positionally to answer_question in the
# order listed (model, prompt, temperature, top-k, max length).
iface = gr.Interface(
    fn=answer_question,
    inputs=[
        # The default must be one of `choices`; "Pretrained" was not in the
        # list. answer_question maps this label to its `model_paths` key.
        gr.Dropdown(choices=["Base (Pretrained)"], value="Base (Pretrained)", label="Select Model"),
        gr.Textbox(label="Prompt", lines=5),
        gr.Slider(minimum=0.1, maximum=1.0, value=0.8, step=0.1, label="Temperature"),
        gr.Slider(minimum=50, maximum=ModelArgs.vocab_size, value=50, step=1, label="Top-k"),
        gr.Slider(minimum=10, maximum=ModelArgs.block_size, value=256, step=1, label="Max Length")
    ],
    outputs=gr.Textbox(label="Answer"),
    title="StoryLlama",
    description="Enter a prompt, select a model (Pretrained) and the model will generate a story!."
)

# Launch the Gradio app only when the file is run as a script.
if __name__ == "__main__":
    iface.launch()
|