AiCoderv2 commited on
Commit
8111aa7
·
verified ·
1 Parent(s): d941e57

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +18 -10
app.py CHANGED
@@ -2,23 +2,31 @@ import torch
2
  from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline
3
  import gradio as gr
4
 
5
- # Load GPT-2 XL model
6
- model_name = "gpt2-xl"
7
  tokenizer = AutoTokenizer.from_pretrained(model_name)
8
  model = AutoModelForCausalLM.from_pretrained(model_name)
9
 
10
- # Create generator pipeline
11
- generator = pipeline('text-generation', model=model, tokenizer=tokenizer)
12
 
13
  def generate_data(prompt, amount):
14
- responses = []
15
- for _ in range(amount):
16
- output = generator(prompt, max_length=100, num_return_sequences=1)[0]['generated_text']
17
- responses.append(output.strip())
18
- return responses
 
 
 
 
 
 
 
 
19
 
20
  with gr.Blocks() as demo:
21
- gr.Markdown("### GPT-2 XL Data Generator\nDescribe the data you'd like the AI to generate.")
22
  prompt_input = gr.Textbox(label="Prompt / Data Type", placeholder="Describe the data you want")
23
  amount_input = gr.Slider(1, 10, value=3, step=1, label="Number of Data Items")
24
  output_box = gr.Textbox(label="Generated Data", lines=15)
 
2
  from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline
3
  import gradio as gr
4
 
5
+ # Load smaller GPT-2 model
6
+ model_name = "gpt2" # smaller and faster
7
  tokenizer = AutoTokenizer.from_pretrained(model_name)
8
  model = AutoModelForCausalLM.from_pretrained(model_name)
9
 
10
+ # Create generator pipeline for CPU
11
+ generator = pipeline('text-generation', model=model, tokenizer=tokenizer, device=-1)
12
 
13
  def generate_data(prompt, amount):
14
+ # Generate multiple samples in batch
15
+ responses = generator(
16
+ prompt,
17
+ max_length=50, # keep short for speed
18
+ num_return_sequences=amount,
19
+ do_sample=False, # greedy for speed
20
+ temperature=0.7,
21
+ top_k=50,
22
+ top_p=0.95,
23
+ pad_token_id=tokenizer.eos_token_id,
24
+ num_beams=1 # greedy
25
+ )
26
+ return [resp['generated_text'].strip() for resp in responses]
27
 
28
  with gr.Blocks() as demo:
29
+ gr.Markdown("### Faster Data Generator with GPT-2\nDescribe what data you want to generate.")
30
  prompt_input = gr.Textbox(label="Prompt / Data Type", placeholder="Describe the data you want")
31
  amount_input = gr.Slider(1, 10, value=3, step=1, label="Number of Data Items")
32
  output_box = gr.Textbox(label="Generated Data", lines=15)