sailormars18 commited on
Commit
c0357a8
·
1 Parent(s): 90fa324

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +10 -3
app.py CHANGED
@@ -16,12 +16,20 @@ def generate_text(prompt, length=100, theme=None, **kwargs):
16
  model_url = "https://huggingface.co/spaces/sailormars18/Yelp-reviews-usingGPT2/blob/main/pytorch_model.bin"
17
  config_url = "https://huggingface.co/spaces/sailormars18/Yelp-reviews-usingGPT2/blob/main/config.json"
18
  generation_config_url = "https://huggingface.co/spaces/sailormars18/Yelp-reviews-usingGPT2/blob/main/generation_config.json"
 
 
 
 
 
 
 
 
19
 
20
  # Load the model from the Hugging Face space
21
- model = transformers.GPT2LMHeadModel.from_pretrained(model_url, config=config_url).to(device)
22
 
23
  # Load the tokenizer from the Hugging Face space
24
- tokenizer = transformers.GPT2Tokenizer.from_pretrained(model_url, config=config_url)
25
 
26
  # If a theme is specified, add it to the prompt as a prefix for a special token
27
  if theme:
@@ -36,7 +44,6 @@ def generate_text(prompt, length=100, theme=None, **kwargs):
36
 
37
  sample_outputs = model.generate(
38
  input_ids,
39
- config=generation_config_url,
40
  attention_mask=attention_mask,
41
  pad_token_id=pad_token_id,
42
  do_sample=True,
 
16
  model_url = "https://huggingface.co/spaces/sailormars18/Yelp-reviews-usingGPT2/blob/main/pytorch_model.bin"
17
  config_url = "https://huggingface.co/spaces/sailormars18/Yelp-reviews-usingGPT2/blob/main/config.json"
18
  generation_config_url = "https://huggingface.co/spaces/sailormars18/Yelp-reviews-usingGPT2/blob/main/generation_config.json"
19
+
20
+ # Load the model configuration from the config.json file
21
+ with open(config_url, 'r') as f:
22
+ config = json.load(f)
23
+
24
+ # Load the generation configuration from the generation_config.json file
25
+ with open(generation_config_url, 'r') as f:
26
+ generation_config = json.load(f)
27
 
28
  # Load the model from the Hugging Face space
29
+ model = transformers.GPT2LMHeadModel.from_pretrained(model_url, config=config, **generation_config).to(device)
30
 
31
  # Load the tokenizer from the Hugging Face space
32
+ tokenizer = transformers.GPT2Tokenizer.from_pretrained('gpt2')
33
 
34
  # If a theme is specified, add it to the prompt as a prefix for a special token
35
  if theme:
 
44
 
45
  sample_outputs = model.generate(
46
  input_ids,
 
47
  attention_mask=attention_mask,
48
  pad_token_id=pad_token_id,
49
  do_sample=True,