sailormars18 commited on
Commit
b014abf
·
1 Parent(s): 945e7fe

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +11 -2
app.py CHANGED
@@ -5,16 +5,25 @@ import gradio as gr
5
  import torch
6
  import transformers
7
 
 
 
 
 
 
8
  # Define a function for generating text based on a prompt using the fine-tuned GPT-2 model and the tokenizer
9
  def generate_text(prompt, length=100, theme=None, **kwargs):
10
 
11
  model_url = "https://huggingface.co/spaces/sailormars18/Yelp-reviews-usingGPT2/blob/main/pytorch_model.bin"
 
 
12
 
13
  # Load the model from the Hugging Face space
14
- model = transformers.GPT2LMHeadModel.from_pretrained(model_url).to(device)
 
15
 
16
  # Load the tokenizer from the Hugging Face space
17
- tokenizer = transformers.GPT2Tokenizer.from_pretrained(model_url)
 
18
 
19
  # If a theme is specified, add it to the prompt as a prefix for a special token
20
  if theme:
 
5
  import torch
6
  import transformers
7
 
8
+ import json
9
+ from transformers import GPT2LMHeadModel, GPT2Tokenizer
10
+
11
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
12
+
13
  # Define a function for generating text based on a prompt using the fine-tuned GPT-2 model and the tokenizer
14
  def generate_text(prompt, length=100, theme=None, **kwargs):
15
 
16
  model_url = "https://huggingface.co/spaces/sailormars18/Yelp-reviews-usingGPT2/blob/main/pytorch_model.bin"
17
+ config_name = "https://huggingface.co/spaces/sailormars18/Yelp-reviews-usingGPT2/blob/main/config.json"
18
+ generation_config_name = "https://huggingface.co/spaces/sailormars18/Yelp-reviews-usingGPT2/blob/main/generation_config.json"
19
 
20
  # Load the model from the Hugging Face space
21
+ model = transformers.GPT2LMHeadModel.from_pretrained(model_url, config_name=config_name,
22
+ generation_config_name=generation_config_name).to(device)
23
 
24
  # Load the tokenizer from the Hugging Face space
25
+ tokenizer = transformers.GPT2Tokenizer.from_pretrained(model_url, config_name=config_name,
26
+ generation_config_name=generation_config_name)
27
 
28
  # If a theme is specified, add it to the prompt as a prefix for a special token
29
  if theme: