sailormars18 committed
Commit 3581479 · Parent(s): 990f381

Update app.py

Files changed (1): app.py (+64 -1)
app.py CHANGED
@@ -1,6 +1,69 @@
  import subprocess
- subprocess.run(["pip", "install","gradio"])
+ subprocess.run(["pip", "install","gradio","transformers"])
  import gradio as gr
+ import transformers
+ import torch  # used below by torch.ones and the device handling; missing from the commit
+ import re     # used below by the post-processing regexes; missing from the commit
+
+ # Pick a device once at import time; `device` was referenced but never defined in the commit
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+
+ # Define a function for generating text from a prompt using the fine-tuned GPT-2 model and its tokenizer
+ def generate_text(prompt, length=100, theme=None, **kwargs):
+
+     # NOTE: from_pretrained expects a Hub repo id or a local directory, not a
+     # blob URL like this one, so these files would need to be downloaded first
+     # (see the sketch after the diff)
+     model_url = "https://huggingface.co/spaces/sailormars18/Yelp-reviews-usingGPT2/blob/main/pytorch_model.bin"
+
+     # Load the model from the Hugging Face space
+     model = transformers.GPT2LMHeadModel.from_pretrained(model_url).to(device)
+
+     # Load the tokenizer from the Hugging Face space
+     tokenizer = transformers.GPT2Tokenizer.from_pretrained(model_url)
+
+     # If a theme is specified, prepend it to the prompt as a special token
+     if theme:
+         prompt = ' <{}> '.format(theme.strip()) + prompt.strip()
+
+     input_ids = tokenizer.encode(prompt, return_tensors='pt').to(device)
+     attention_mask = torch.ones(input_ids.shape, dtype=torch.long, device=device)
+     pad_token_id = tokenizer.eos_token_id
+
+     # Set the max length of the generated text based on the input parameter
+     max_length = length if length > 0 else 100
+
+     sample_outputs = model.generate(
+         input_ids,
+         attention_mask=attention_mask,
+         pad_token_id=pad_token_id,
+         do_sample=True,
+         max_length=max_length,
+         top_k=50,
+         top_p=0.95,
+         temperature=0.8,
+         num_return_sequences=1,
+         no_repeat_ngram_size=2,
+         repetition_penalty=1.5,
+     )
+     generated_text = tokenizer.decode(sample_outputs[0], skip_special_tokens=True)
+
+     # Post-processing of the generated text
+
+     # Remove any leading and trailing quotation marks
+     generated_text = generated_text.strip('"')
+
+     # Remove leading and trailing whitespace
+     generated_text = generated_text.strip()
+
+     # Find the theme token in the generated text and remove it
+     match = re.search(r'<([^>]+)>', generated_text)
+     if match:
+         generated_text = generated_text[:match.start()] + generated_text[match.end():]
+
+     # Remove any leading numeric characters and quotation marks
+     generated_text = re.sub(r'^\d+', '', generated_text)
+     generated_text = re.sub(r'^"', '', generated_text)
+
+     # Remove any newline characters from the generated text
+     generated_text = generated_text.replace('\n', '')
+
+     # Remove any other unwanted special characters
+     generated_text = re.sub(r'[^\w\s]+', '', generated_text)
+
+     return generated_text.strip().capitalize()
 
  # Define a Gradio interface for the generate_text function, allowing users to input a prompt and generate text based on it
  iface = gr.Interface(
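
As committed, the blob URL above will not load: transformers' from_pretrained expects a Hub repo id or a local directory, not a file URL. A minimal sketch of one way to make it work, assuming the Space also hosts config.json and the tokenizer files next to pytorch_model.bin (not verified here), is to snapshot the Space locally with huggingface_hub and load from that directory:

import transformers
from huggingface_hub import snapshot_download

# Download the Space's files once; returns the local directory path.
# repo_type="space" is needed because the weights live in a Space, not a model repo.
local_dir = snapshot_download(
    repo_id="sailormars18/Yelp-reviews-usingGPT2",
    repo_type="space",
)
model = transformers.GPT2LMHeadModel.from_pretrained(local_dir)
tokenizer = transformers.GPT2Tokenizer.from_pretrained(local_dir)

Loading once at import time like this, rather than inside generate_text, would also avoid re-initializing the model on every call.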
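
The diff cuts off at iface = gr.Interface(, so the committed interface definition is not shown. A plausible wiring, with input and output components inferred from the generate_text signature (prompt, length, theme) rather than taken from the commit:

# Hypothetical interface wiring; the component choices are assumptions,
# not the committed code.
iface = gr.Interface(
    fn=generate_text,
    inputs=[
        gr.Textbox(label="Prompt"),
        gr.Slider(minimum=20, maximum=300, value=100, step=1, label="Length"),
        gr.Textbox(label="Theme (optional)"),
    ],
    outputs=gr.Textbox(label="Generated review"),
)
iface.launch()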