# Mini-GPT / app.py
import gradio as gr
import torch
# GPTLanguageModel, decode, and context (the starting prompt tensor)
# are expected to come from model.py via this star import.
from model import *

DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

# Load the trained weights and generate one long sample up front;
# the UI below only shows a prefix of this pre-generated text.
model = GPTLanguageModel().to(DEVICE)
model.load_state_dict(torch.load("mini-gpt.pth", map_location=DEVICE), strict=False)
answer = decode(model.generate(context, max_new_tokens=1000)[0].tolist())
def display(number):
    # Show a prefix of the pre-generated text; the slider selects how much.
    # Cast to int so the slider value can safely be used as a slice index.
    return answer[: int(number) + 1]
input_slider = gr.Slider(minimum=500, maximum=1000, value=500, step=1, label="Select the maximum number of tokens/words:")
output_text = gr.Textbox()
demo = gr.Interface(fn=display, inputs=input_slider, outputs=output_text)
demo.launch()
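
# ---------------------------------------------------------------------------
# For reference, a sketch of the interface this file assumes model.py exposes
# through the star import above. These names and signatures are inferred from
# the calls in this file and from typical character-level mini-GPT training
# scripts; the actual definitions live in model.py and may differ.
#
#   class GPTLanguageModel(torch.nn.Module):
#       def generate(self, idx, max_new_tokens):
#           # (B, T) tensor of token ids -> (B, T + max_new_tokens)
#           ...
#
#   def decode(ids: list[int]) -> str:
#       # map token ids back to text
#       ...
#
#   # e.g. an "empty" starting prompt for generation
#   context = torch.zeros((1, 1), dtype=torch.long, device=DEVICE)
# ---------------------------------------------------------------------------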