sagar007 committed on
Commit
181cfde
·
verified ·
1 Parent(s): e60730b

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +18 -239
app.py CHANGED
@@ -3,130 +3,29 @@ import torch.nn as nn
3
  from torch.nn import functional as F
4
  import tiktoken
5
  import gradio as gr
6
- import torch
7
- import torch.nn as nn
8
- from torch.nn import functional as F
9
- import tiktoken
10
- import gradio as gr
11
- import asyncio
12
- import gradio as gr
13
  import asyncio
14
 
15
- # Add the post-processing function here
16
- def post_process_text(text):
17
- # Ensure the text starts with a capital letter
18
- text = text.capitalize()
19
-
20
- # Remove any incomplete sentences at the end
21
- sentences = text.split('.')
22
- complete_sentences = sentences[:-1] if len(sentences) > 1 else sentences
23
-
24
- # Rejoin sentences and add a period if missing
25
- processed_text = '. '.join(complete_sentences)
26
- if not processed_text.endswith('.'):
27
- processed_text += '.'
28
-
29
- return processed_text
30
- # Define the model architecture
31
- class GPTConfig:
32
- def __init__(self):
33
- self.block_size = 1024
34
- self.vocab_size = 50304
35
- self.n_layer = 12
36
- self.n_head = 12
37
- self.n_embd = 768
38
-
39
- class CausalSelfAttention(nn.Module):
40
- def __init__(self, config):
41
- super().__init__()
42
- assert config.n_embd % config.n_head == 0
43
- self.c_attn = nn.Linear(config.n_embd, 3 * config.n_embd)
44
- self.c_proj = nn.Linear(config.n_embd, config.n_embd)
45
- self.n_head = config.n_head
46
- self.n_embd = config.n_embd
47
- self.register_buffer("bias", torch.tril(torch.ones(config.block_size, config.block_size)).view(1, 1, config.block_size, config.block_size))
48
-
49
- def forward(self, x):
50
- B, T, C = x.size()
51
- q, k, v = self.c_attn(x).split(self.n_embd, dim=2)
52
- k = k.view(B, T, self.n_head, C // self.n_head).transpose(1, 2)
53
- q = q.view(B, T, self.n_head, C // self.n_head).transpose(1, 2)
54
- v = v.view(B, T, self.n_head, C // self.n_head).transpose(1, 2)
55
- y = F.scaled_dot_product_attention(q, k, v, attn_mask=None, dropout_p=0.0, is_causal=True)
56
- y = y.transpose(1, 2).contiguous().view(B, T, C)
57
- return self.c_proj(y)
58
-
59
- class MLP(nn.Module):
60
- def __init__(self, config):
61
- super().__init__()
62
- self.c_fc = nn.Linear(config.n_embd, 4 * config.n_embd)
63
- self.gelu = nn.GELU()
64
- self.c_proj = nn.Linear(4 * config.n_embd, config.n_embd)
65
-
66
- def forward(self, x):
67
- return self.c_proj(self.gelu(self.c_fc(x)))
68
-
69
- class Block(nn.Module):
70
- def __init__(self, config):
71
- super().__init__()
72
- self.ln_1 = nn.LayerNorm(config.n_embd)
73
- self.attn = CausalSelfAttention(config)
74
- self.ln_2 = nn.LayerNorm(config.n_embd)
75
- self.mlp = MLP(config)
76
 
77
- def forward(self, x):
78
- x = x + self.attn(self.ln_1(x))
79
- x = x + self.mlp(self.ln_2(x))
80
- return x
81
-
82
- class GPT(nn.Module):
83
- def __init__(self, config):
84
- super().__init__()
85
- self.config = config
86
- self.transformer = nn.ModuleDict(dict(
87
- wte = nn.Embedding(config.vocab_size, config.n_embd),
88
- wpe = nn.Embedding(config.block_size, config.n_embd),
89
- h = nn.ModuleList([Block(config) for _ in range(config.n_layer)]),
90
- ln_f = nn.LayerNorm(config.n_embd),
91
- ))
92
- self.lm_head = nn.Linear(config.n_embd, config.vocab_size, bias=False)
93
- self.transformer.wte.weight = self.lm_head.weight
94
- self.apply(self._init_weights)
95
-
96
- def _init_weights(self, module):
97
- if isinstance(module, nn.Linear):
98
- torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
99
- if module.bias is not None:
100
- torch.nn.init.zeros_(module.bias)
101
- elif isinstance(module, nn.Embedding):
102
- torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
103
-
104
- def forward(self, idx, targets=None):
105
- device = idx.device
106
- b, t = idx.size()
107
- assert t <= self.config.block_size, f"Cannot forward sequence of length {t}, block size is only {self.config.block_size}"
108
- pos = torch.arange(0, t, dtype=torch.long, device=device).unsqueeze(0)
109
-
110
- tok_emb = self.transformer.wte(idx)
111
- pos_emb = self.transformer.wpe(pos)
112
- x = tok_emb + pos_emb
113
- for block in self.transformer.h:
114
- x = block(x)
115
- x = self.transformer.ln_f(x)
116
- logits = self.lm_head(x)
117
-
118
- loss = None
119
- if targets is not None:
120
- loss = F.cross_entropy(logits.view(-1, logits.size(-1)), targets.view(-1), ignore_index=-1)
121
-
122
- return logits, loss
123
 
 
124
  @spaces.GPU
125
  def load_model(model_path):
126
  config = GPTConfig()
127
  model = GPT(config)
128
 
129
- checkpoint = torch.load(model_path, map_location=torch.device('cuda'))
 
130
 
131
  if 'model_state_dict' in checkpoint:
132
  model.load_state_dict(checkpoint['model_state_dict'])
@@ -134,7 +33,7 @@ def load_model(model_path):
134
  model.load_state_dict(checkpoint)
135
 
136
  model.eval()
137
- model.to('cuda')
138
  return model
139
 
140
  # Load the model
@@ -144,7 +43,8 @@ enc = tiktoken.get_encoding('gpt2')
144
  # Update the generate_text function
145
  @spaces.GPU(duration=60) # Adjust duration as needed
146
  async def generate_text(prompt, max_length=432, temperature=0.8, top_k=40):
147
- input_ids = torch.tensor(enc.encode(prompt)).unsqueeze(0).cuda()
 
148
  generated = []
149
 
150
  with torch.no_grad():
@@ -179,125 +79,4 @@ async def gradio_generate(prompt, max_length, temperature, top_k):
179
  output += token
180
  yield output
181
 
182
- # Custom CSS for the animation effect
183
- import gradio as gr
184
- import asyncio
185
-
186
- # Your existing imports and model code here...
187
-
188
- css = """
189
- <style>
190
- body {
191
- background-color: #0f1624;
192
- color: #e0e0e0;
193
- font-family: 'Courier New', monospace;
194
- background-image:
195
- radial-gradient(white, rgba(255,255,255,.2) 2px, transparent 40px),
196
- radial-gradient(white, rgba(255,255,255,.15) 1px, transparent 30px),
197
- radial-gradient(white, rgba(255,255,255,.1) 2px, transparent 40px),
198
- radial-gradient(rgba(255,255,255,.4), rgba(255,255,255,.1) 2px, transparent 30px);
199
- background-size: 550px 550px, 350px 350px, 250px 250px, 150px 150px;
200
- background-position: 0 0, 40px 60px, 130px 270px, 70px 100px;
201
- animation: backgroundScroll 60s linear infinite;
202
- }
203
- @keyframes backgroundScroll {
204
- 0% { background-position: 0 0, 40px 60px, 130px 270px, 70px 100px; }
205
- 100% { background-position: 550px 550px, 590px 610px, 680px 820px, 620px 650px; }
206
- }
207
- .container { max-width: 800px; margin: 0 auto; padding: 20px; }
208
- .header {
209
- text-align: center;
210
- margin-bottom: 30px;
211
- font-family: 'Copperplate', fantasy;
212
- color: #ffd700;
213
- text-shadow: 0 0 10px #ffd700, 0 0 20px #ffd700, 0 0 30px #ffd700;
214
- }
215
- .chat-box {
216
- background-color: rgba(42, 42, 42, 0.7);
217
- border-radius: 15px;
218
- padding: 20px;
219
- margin-bottom: 20px;
220
- box-shadow: 0 0 20px rgba(255, 215, 0, 0.3);
221
- }
222
- .user-input {
223
- background-color: rgba(58, 58, 58, 0.8);
224
- border: 2px solid #ffd700;
225
- color: #ffffff;
226
- padding: 10px;
227
- border-radius: 5px;
228
- width: 100%;
229
- transition: all 0.3s ease;
230
- }
231
- .user-input:focus {
232
- box-shadow: 0 0 15px #ffd700;
233
- }
234
- .generate-btn {
235
- background-color: #ffd700;
236
- color: #0f1624;
237
- border: none;
238
- padding: 10px 20px;
239
- border-radius: 5px;
240
- cursor: pointer;
241
- font-weight: bold;
242
- transition: all 0.3s ease;
243
- }
244
- .generate-btn:hover {
245
- background-color: #ffec8b;
246
- transform: scale(1.05);
247
- }
248
- .output-box {
249
- background-color: rgba(42, 42, 42, 0.7);
250
- border-radius: 15px;
251
- padding: 20px;
252
- margin-top: 20px;
253
- min-height: 100px;
254
- border: 1px solid #ffd700;
255
- white-space: pre-wrap;
256
- font-family: 'Georgia', serif;
257
- line-height: 1.6;
258
- box-shadow: inset 0 0 10px rgba(255, 215, 0, 0.3);
259
- }
260
- .gr-slider {
261
- --slider-color: #ffd700;
262
- }
263
- .gr-box {
264
- border-color: #ffd700;
265
- background-color: rgba(42, 42, 42, 0.7);
266
- }
267
- </style>
268
- """
269
-
270
- with gr.Blocks(css=css) as demo:
271
- gr.HTML("<div class='header'><h1>🌟 Enchanted Tales Generator 🌟</h1></div>")
272
-
273
- with gr.Row():
274
- with gr.Column(scale=3):
275
- prompt = gr.Textbox(
276
- placeholder="Begin your magical journey here (e.g., 'In a realm beyond the mists of time...')",
277
- label="Story Incantation",
278
- elem_classes="user-input"
279
- )
280
- with gr.Column(scale=1):
281
- generate_btn = gr.Button("Weave the Tale", elem_classes="generate-btn")
282
-
283
- with gr.Row():
284
- max_length = gr.Slider(minimum=50, maximum=500, value=432, step=1, label="Scroll Length")
285
- temperature = gr.Slider(minimum=0.1, maximum=1.0, value=0.8, step=0.1, label="Magical Intensity")
286
- top_k = gr.Slider(minimum=1, maximum=100, value=40, step=1, label="Arcane Diversity")
287
-
288
- output = gr.Markdown(elem_classes="output-box")
289
-
290
- generate_btn.click(
291
- gradio_generate,
292
- inputs=[prompt, max_length, temperature, top_k],
293
- outputs=output
294
- )
295
-
296
- gr.HTML("""
297
- <div style="text-align: center; margin-top: 20px; font-style: italic; color: #ffd700;">
298
- "In the realm of imagination, every word is a spell, every sentence a charm."
299
- </div>
300
- """)
301
-
302
- if __name__ == "__main__":
303
- demo.launch()
 
3
  from torch.nn import functional as F
4
  import tiktoken
5
  import gradio as gr
 
 
 
 
 
 
 
6
  import asyncio
7
 
8
# Import the `spaces` package when it is installed; otherwise install a
# stand-in so that both `@spaces.GPU` and `@spaces.GPU(duration=...)` keep
# working as no-op decorators.
try:
    import spaces
    use_spaces_gpu = True
except ImportError:
    from types import SimpleNamespace

    use_spaces_gpu = False

    def dummy_gpu_decorator(func=None, **_options):
        """Return the decorated function unchanged.

        Supports the bare form (``@spaces.GPU``) as well as the
        parameterised form (``@spaces.GPU(duration=60)``); keyword options
        are accepted and ignored.
        """
        if func is None:
            # Called with options only: hand back the real (no-op) decorator.
            return lambda wrapped: wrapped
        return func

    # A plain namespace keeps GPU as an instance attribute, so it is never
    # bound as a method and `func` really is the function being decorated.
    spaces = SimpleNamespace(GPU=dummy_gpu_decorator)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
18
 
19
+ # ... (keep the model architecture classes as they are)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
20
 
21
+ # Update the load_model function
22
  @spaces.GPU
23
  def load_model(model_path):
24
  config = GPTConfig()
25
  model = GPT(config)
26
 
27
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
28
+ checkpoint = torch.load(model_path, map_location=device)
29
 
30
  if 'model_state_dict' in checkpoint:
31
  model.load_state_dict(checkpoint['model_state_dict'])
 
33
  model.load_state_dict(checkpoint)
34
 
35
  model.eval()
36
+ model.to(device)
37
  return model
38
 
39
  # Load the model
 
43
  # Update the generate_text function
44
  @spaces.GPU(duration=60) # Adjust duration as needed
45
  async def generate_text(prompt, max_length=432, temperature=0.8, top_k=40):
46
+ device = next(model.parameters()).device
47
+ input_ids = torch.tensor(enc.encode(prompt)).unsqueeze(0).to(device)
48
  generated = []
49
 
50
  with torch.no_grad():
 
79
  output += token
80
  yield output
81
 
82
+ # The rest of your Gradio interface code remains the same