Adityadn commited on
Commit
5fac3fd
·
verified ·
1 Parent(s): 79bc2b3

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +14 -22
app.py CHANGED
@@ -1,35 +1,27 @@
1
  import gradio as gr
2
- from transformers import GPT2LMHeadModel, GPT2Tokenizer
3
 
4
- # Load pre-trained model and tokenizer
5
- model_name = "gpt2" # You can use other models like gpt-2-large or gpt-3 for better performance
6
- model = GPT2LMHeadModel.from_pretrained(model_name)
7
- tokenizer = GPT2Tokenizer.from_pretrained(model_name)
8
 
9
- # Function to generate keywords based on a prompt
10
- def generate_keywords(prompt):
11
- # Encode input prompt with a more direct instruction for only keywords
12
- prompt_with_instruction = prompt + " Only provide a list of keywords, no additional text."
13
- inputs = tokenizer.encode(prompt_with_instruction, return_tensors="pt")
14
 
15
- # Generate output from model
16
  outputs = model.generate(inputs, max_length=50, num_return_sequences=1, no_repeat_ngram_size=2, top_k=50, top_p=0.95)
17
 
18
- # Decode generated tokens
19
- generated_text = tokenizer.decode(outputs[0], skip_special_tokens=True)
20
-
21
- # Clean up the text to remove unnecessary parts
22
- # Remove anything after 'Only provide a list of keywords'
23
- clean_text = generated_text.split("Only provide a list of keywords")[0].strip()
24
-
25
- # Return the keywords only
26
- return clean_text
27
 
28
  # Gradio interface
29
  iface = gr.Interface(fn=generate_keywords,
30
- inputs=gr.Textbox(label="Enter Ad Prompt", placeholder="E.g., Generate ad keywords for wireless headphones"),
31
  outputs=gr.Textbox(label="Generated Keywords"),
32
  live=True)
33
 
34
- # Launch interface
35
  iface.launch()
 
1
  import gradio as gr
2
+ from transformers import AutoTokenizer, AutoModelForCausalLM
3
 
4
+ # Load the pre-trained LLaMA model and tokenizer
5
+ tokenizer = AutoTokenizer.from_pretrained("facebook/llama-7b")
6
+ model = AutoModelForCausalLM.from_pretrained("facebook/llama-7b")
 
7
 
8
+ # Function to generate keywords from input text
9
+ def generate_keywords(text):
10
+ # Encode the input text
11
+ inputs = tokenizer.encode(text, return_tensors="pt")
 
12
 
13
+ # Generate the output from the model
14
  outputs = model.generate(inputs, max_length=50, num_return_sequences=1, no_repeat_ngram_size=2, top_k=50, top_p=0.95)
15
 
16
+ # Decode and return the generated keywords
17
+ keywords = tokenizer.decode(outputs[0], skip_special_tokens=True)
18
+ return keywords.strip()
 
 
 
 
 
 
19
 
20
  # Gradio interface
21
  iface = gr.Interface(fn=generate_keywords,
22
+ inputs=gr.Textbox(label="Enter Prompt", placeholder="E.g., Generate ad keywords for wireless headphones"),
23
  outputs=gr.Textbox(label="Generated Keywords"),
24
  live=True)
25
 
26
+ # Launch the Gradio interface
27
  iface.launch()