anjikum committed on
Commit
dbb4d61
·
verified ·
1 Parent(s): de2601d

Upload 3 files

Browse files
Files changed (3) hide show
  1. app.py +63 -0
  2. gpt_model_quantized.pt +3 -0
  3. requirements.txt +1 -0
app.py ADDED
@@ -0,0 +1,63 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ import torch
3
+ import tiktoken
4
+ import numpy as np
5
+ from train import GPT, GPTConfig # Make sure to upload train.py to the Space
6
+
7
def load_quantized_model():
    """Build a GPT model and populate it from the quantized checkpoint.

    The checkpoint ``gpt_model_quantized.pt`` maps parameter names to either
    a plain value (used as-is) or a dict with ``'data'`` and ``'scale'``
    entries, which is dequantized as ``data.astype(float32) * scale``.

    Returns:
        The ``GPT`` instance with weights loaded, switched to eval mode.
        The model is left on the CPU; callers move it to another device.
    """
    model = GPT(GPTConfig())

    # map_location="cpu": a checkpoint saved from a GPU process would
    # otherwise fail to deserialize on a CPU-only host.
    # weights_only=False: the quantized entries appear to be NumPy arrays
    # (they are used via ``.astype``), which the restricted unpickler that
    # PyTorch >= 2.6 enables by default rejects.
    # NOTE(security): this performs a full unpickle -- only ever load
    # checkpoint files that ship with this app, never untrusted uploads.
    quantized_dict = torch.load(
        "gpt_model_quantized.pt",
        map_location="cpu",
        weights_only=False,
    )

    # Dequantize model
    state_dict = {}
    for key, value in quantized_dict.items():
        if isinstance(value, dict):
            dequantized = value['data'].astype(np.float32) * value['scale']
            # Pin float32 explicitly: depending on the type of ``scale``,
            # NumPy may promote the product to float64.
            state_dict[key] = torch.tensor(dequantized, dtype=torch.float32)
        else:
            state_dict[key] = value

    model.load_state_dict(state_dict)
    model.eval()
    return model
24
+
25
# Filled on first use so the checkpoint is read and dequantized once per
# process instead of once per request.
_GENERATION_CACHE = {}


def _get_generation_resources():
    """Return the cached ``(model, tokenizer, device)``, building it on first use."""
    if not _GENERATION_CACHE:
        device = 'cuda' if torch.cuda.is_available() else 'cpu'
        model = load_quantized_model().to(device)
        tokenizer = tiktoken.get_encoding('gpt2')
        # Populate in a single step at the end so a failure above leaves
        # the cache empty and the next call retries from scratch.
        _GENERATION_CACHE.update(model=model, tokenizer=tokenizer, device=device)
    return (
        _GENERATION_CACHE['model'],
        _GENERATION_CACHE['tokenizer'],
        _GENERATION_CACHE['device'],
    )


def generate_text(input_text):
    """Generate a continuation of ``input_text`` with the quantized GPT model.

    Args:
        input_text: Prompt text from the UI.

    Returns:
        The decoded token sequence returned by ``model.generate`` (called
        with ``max_new_tokens=500``), a short notice when the prompt is
        empty, or an ``"Error generating text: ..."`` message if anything
        fails. Exceptions are never propagated to the caller.
    """
    # An empty prompt encodes to zero tokens, which would produce a float
    # tensor of shape (1, 0) and fail inside the model with an opaque error.
    if not input_text:
        return "Please enter some text to generate a continuation."

    try:
        model, tokenizer, device = _get_generation_resources()

        # Tokenize input; token ids must be integer typed.
        input_tokens = torch.tensor(
            [tokenizer.encode(input_text)], dtype=torch.long, device=device
        )

        # Generate
        with torch.no_grad():
            output_tokens = model.generate(input_tokens, max_new_tokens=500)[0].tolist()

        # Decode and return
        return tokenizer.decode(output_tokens)
    except Exception as e:
        # UI boundary: report the failure as text, but keep the traceback
        # in the process logs so it can be diagnosed.
        import logging

        logging.getLogger(__name__).exception("Text generation failed")
        return f"Error generating text: {e}"
47
+
48
# Assemble the Gradio UI: a prompt box in, a continuation box out.
_prompt_box = gr.Textbox(lines=5, label="Input Text")
_result_box = gr.Textbox(lines=10, label="Generated Text")

# Each example is a one-element row because the interface has one input.
_example_prompts = [
    [prompt]
    for prompt in (
        "The quick brown fox",
        "In a world where AI",
        "Once upon a time",
    )
]

iface = gr.Interface(
    fn=generate_text,
    inputs=_prompt_box,
    outputs=_result_box,
    title="GPT Text Generator",
    description="Enter some text and the model will generate a continuation.",
    examples=_example_prompts,
)

# Start serving the interface.
iface.launch()
gpt_model_quantized.pt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:84ddd50901ce047532e6587705beb5768740afc9879c4652094b842dd913c41c
3
+ size 255934032
requirements.txt ADDED
@@ -0,0 +1 @@
 
 
1
+