rabiyulfahim and HamidRezaAttar committed
Commit 564cc15 (0 parents)

Duplicate from HamidRezaAttar/gpt2-home

Co-authored-by: HamidReza Fatollah Zadeh Attar <[email protected]>

.gitattributes ADDED
@@ -0,0 +1,27 @@
+ *.7z filter=lfs diff=lfs merge=lfs -text
+ *.arrow filter=lfs diff=lfs merge=lfs -text
+ *.bin filter=lfs diff=lfs merge=lfs -text
+ *.bin.* filter=lfs diff=lfs merge=lfs -text
+ *.bz2 filter=lfs diff=lfs merge=lfs -text
+ *.ftz filter=lfs diff=lfs merge=lfs -text
+ *.gz filter=lfs diff=lfs merge=lfs -text
+ *.h5 filter=lfs diff=lfs merge=lfs -text
+ *.joblib filter=lfs diff=lfs merge=lfs -text
+ *.lfs.* filter=lfs diff=lfs merge=lfs -text
+ *.model filter=lfs diff=lfs merge=lfs -text
+ *.msgpack filter=lfs diff=lfs merge=lfs -text
+ *.onnx filter=lfs diff=lfs merge=lfs -text
+ *.ot filter=lfs diff=lfs merge=lfs -text
+ *.parquet filter=lfs diff=lfs merge=lfs -text
+ *.pb filter=lfs diff=lfs merge=lfs -text
+ *.pt filter=lfs diff=lfs merge=lfs -text
+ *.pth filter=lfs diff=lfs merge=lfs -text
+ *.rar filter=lfs diff=lfs merge=lfs -text
+ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
+ *.tar.* filter=lfs diff=lfs merge=lfs -text
+ *.tflite filter=lfs diff=lfs merge=lfs -text
+ *.tgz filter=lfs diff=lfs merge=lfs -text
+ *.xz filter=lfs diff=lfs merge=lfs -text
+ *.zip filter=lfs diff=lfs merge=lfs -text
+ *.zstandard filter=lfs diff=lfs merge=lfs -text
+ *tfevents* filter=lfs diff=lfs merge=lfs -text
README.md ADDED
@@ -0,0 +1,13 @@
+ ---
+ title: Gpt2 Home
+ emoji: 🏢
+ colorFrom: pink
+ colorTo: blue
+ sdk: streamlit
+ app_file: app.py
+ pinned: false
+ license: apache-2.0
+ duplicated_from: HamidRezaAttar/gpt2-home
+ ---
+
+ Check out the configuration reference at https://huggingface.co/docs/hub/spaces#reference
__pycache__/examples.cpython-39.pyc ADDED
Binary file (446 Bytes).
 
__pycache__/meta.cpython-39.pyc ADDED
Binary file (292 Bytes).
 
__pycache__/utils.cpython-39.pyc ADDED
Binary file (1.22 kB).
 
app.py ADDED
@@ -0,0 +1,140 @@
+ import streamlit as st
+ from transformers import pipeline, set_seed
+ from transformers import AutoTokenizer
+ from normalizer import Normalizer
+ import random
+
+ import meta
+ import examples
+ from utils import (
+     remote_css,
+     local_css
+ )
+
+
+ class TextGeneration:
+     def __init__(self):
+         self.debug = False
+         self.dummy_output = None
+         self.tokenizer = None
+         self.generator = None
+         self.task = "text-generation"
+         self.model_name_or_path = "HamidRezaAttar/gpt2-product-description-generator"
+         set_seed(42)
+
+     def load(self):
+         if not self.debug:
+             self.tokenizer = AutoTokenizer.from_pretrained(self.model_name_or_path)
+             self.generator = pipeline(self.task, model=self.model_name_or_path, tokenizer=self.model_name_or_path)
+
+     def generate(self, prompt, generation_kwargs):
+         if not self.debug:
+             generation_kwargs["num_return_sequences"] = 1
+
+             max_length = len(self.tokenizer(prompt)["input_ids"]) + generation_kwargs["max_length"]
+             generation_kwargs["max_length"] = max_length
+
+             generation_kwargs["return_full_text"] = False
+
+             return self.generator(
+                 prompt,
+                 **generation_kwargs,
+             )[0]["generated_text"]
+
+         return self.dummy_output
+
+
+ @st.cache(allow_output_mutation=True)
+ def load_text_generator():
+     generator = TextGeneration()
+     generator.load()
+     return generator
+
+
+ def main():
+     st.set_page_config(
+         page_title="GPT2 - Home",
+         page_icon="🏡",
+         layout="wide",
+         initial_sidebar_state="expanded"
+     )
+     remote_css("https://fonts.googleapis.com/css2?family=Roboto:wght@300&display=swap")
+     local_css("assets/ltr.css")
+     generator = load_text_generator()
+
+     st.sidebar.markdown(meta.SIDEBAR_INFO)
+
+     max_length = st.sidebar.slider(
+         label='Max Length',
+         help="The maximum length of the sequence to be generated.",
+         min_value=1,
+         max_value=128,
+         value=50,
+         step=1
+     )
+     top_k = st.sidebar.slider(
+         label='Top-k',
+         help="The number of highest probability vocabulary tokens to keep for top-k filtering.",
+         min_value=40,
+         max_value=80,
+         value=50,
+         step=1
+     )
+     top_p = st.sidebar.slider(
+         label='Top-p',
+         help="Only the most probable tokens with probabilities that add up to `top_p` or higher are kept for "
+              "generation.",
+         min_value=0.0,
+         max_value=1.0,
+         value=0.95,
+         step=0.01
+     )
+     temperature = st.sidebar.slider(
+         label='Temperature',
+         help="The value used to modulate the next token probabilities.",
+         min_value=0.1,
+         max_value=10.0,
+         value=1.0,
+         step=0.05
+     )
+     do_sample = st.sidebar.selectbox(
+         label='Sampling ?',
+         options=(True, False),
+         help="Whether or not to use sampling; use greedy decoding otherwise.",
+     )
+     generation_kwargs = {
+         "max_length": max_length,
+         "top_k": top_k,
+         "top_p": top_p,
+         "temperature": temperature,
+         "do_sample": do_sample,
+     }
+
+     st.markdown(meta.HEADER_INFO)
+     prompts = list(examples.EXAMPLES.keys()) + ["Custom"]
+     prompt = st.selectbox('Examples', prompts, index=len(prompts) - 1)
+
+     if prompt == "Custom":
+         prompt_box = meta.PROMPT_BOX
+     else:
+         prompt_box = random.choice(examples.EXAMPLES[prompt])
+
+     text = st.text_area("Enter text", prompt_box)
+     generation_kwargs_ph = st.empty()
+     cleaner = Normalizer()
+     if st.button("Generate !"):
+         with st.spinner(text="Generating ..."):
+             generation_kwargs_ph.markdown(", ".join([f"`{k}`: {v}" for k, v in generation_kwargs.items()]))
+             if text:
+                 generated_text = generator.generate(text, generation_kwargs)
+                 generated_text = cleaner.clean_txt(generated_text)
+                 st.markdown(
+                     f'<p class="ltr ltr-box">'
+                     f'<span class="result-text">{text} </span>'
+                     f'<span class="result-text generated-text">{generated_text}</span>'
+                     f'</p>',
+                     unsafe_allow_html=True
+                 )
+
+ if __name__ == '__main__':
+     main()
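For reference, here is a minimal standalone sketch of the same generation call that app.py wires into the Streamlit UI, run outside the app (not part of the commit). The prompt is the bundled "Bed" example, and the generation values mirror the sidebar defaults above.

# Standalone sketch of the pipeline call made in TextGeneration.generate().
# Prompt and parameter values are taken from examples.py and the sidebar
# defaults in app.py; everything else mirrors the code above.
from transformers import AutoTokenizer, pipeline, set_seed

model_name = "HamidRezaAttar/gpt2-product-description-generator"
set_seed(42)

tokenizer = AutoTokenizer.from_pretrained(model_name)
generator = pipeline("text-generation", model=model_name, tokenizer=model_name)

prompt = "Maximize your bedroom space without sacrificing style with the storage bed."
max_new = 50  # sidebar "Max Length" default

generated = generator(
    prompt,
    max_length=len(tokenizer(prompt)["input_ids"]) + max_new,
    top_k=50,
    top_p=0.95,
    temperature=1.0,
    do_sample=True,
    num_return_sequences=1,
    return_full_text=False,
)[0]["generated_text"]
print(generated)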
assets/ltr.css ADDED
@@ -0,0 +1,22 @@
+ .ltr,
+ textarea {
+   font-family: Roboto !important;
+   text-align: left;
+   direction: ltr !important;
+ }
+ .ltr-box {
+   border-bottom: 1px solid #ddd;
+   padding-bottom: 20px;
+ }
+ .rtl {
+   text-align: left;
+   direction: ltr !important;
+ }
+
+ span.result-text {
+   padding: 3px 3px;
+   line-height: 32px;
+ }
+ span.generated-text {
+   background-color: rgb(118 200 147 / 13%);
+ }
examples.py ADDED
@@ -0,0 +1,11 @@
+ EXAMPLES = {
+     "Table": [
+         "Handcrafted of solid acacia in weathered gray, our round Jozy drop-leaf dining table is a space-saving."
+     ],
+     "Bed": [
+         "Maximize your bedroom space without sacrificing style with the storage bed."
+     ],
+     "Sofa": [
+         "Our plush and luxurious Emmett modular sofa brings custom comfort to your living space."
+     ]
+ }
meta.py ADDED
@@ -0,0 +1,8 @@
+ HEADER_INFO = """
+ # GPT2 - Home
+ English GPT-2 home product description generator demo.
+ """.strip()
+ SIDEBAR_INFO = """
+ # Configuration
+ """.strip()
+ PROMPT_BOX = "Enter your text..."
normalizer.py ADDED
@@ -0,0 +1,17 @@
+ class Normalizer:
+
+     def __init__(self):
+         pass
+
+     def remove_repetitions(self, text):
+         first_occurrences = []
+         for sentence in text.split("."):
+             if sentence not in first_occurrences:
+                 first_occurrences.append(sentence)
+         return '.'.join(first_occurrences)
+
+     def trim_last_sentence(self, text):
+         return text[:text.rfind(".") + 1]
+
+     def clean_txt(self, text):
+         return self.trim_last_sentence(self.remove_repetitions(text))
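A quick usage sketch for the Normalizer above (not part of the commit). The input string is invented to show how an exactly repeated sentence and an unfinished trailing sentence are dropped; note that remove_repetitions compares the raw pieces returned by split('.'), so duplicates only collapse when they match character for character, leading space included.

# Illustrative only: the input text is made up for demonstration.
from normalizer import Normalizer

cleaner = Normalizer()
raw = "A cozy sofa for your home. Crafted from solid oak. Crafted from solid oak. It also"
print(cleaner.clean_txt(raw))
# -> "A cozy sofa for your home. Crafted from solid oak."
#    (the exact duplicate " Crafted from solid oak" is removed, and the
#     trailing fragment " It also" is trimmed at the last period)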
requirements.txt ADDED
@@ -0,0 +1,4 @@
+ torch
+ Pillow==9.0.1
+ streamlit==1.5.1
+ transformers==4.16.2
utils.py ADDED
@@ -0,0 +1,36 @@
+ import streamlit as st
+ import json
+ from PIL import Image
+
+
+ def load_image(image_path, image_resize=None):
+     image = Image.open(image_path)
+     if isinstance(image_resize, tuple):
+         image = image.resize(image_resize)
+     return image
+
+
+ def load_text(text_path):
+     text = ''
+     with open(text_path) as f:
+         text = f.read()
+
+     return text
+
+
+ def load_json(json_path):
+     jdata = ''
+     with open(json_path) as f:
+         jdata = json.load(f)
+
+     return jdata
+
+
+ def local_css(css_path):
+     with open(css_path) as f:
+         st.markdown(f'<style>{f.read()}</style>', unsafe_allow_html=True)
+
+
+ def remote_css(css_url):
+     st.markdown(f'<link href="{css_url}" rel="stylesheet">', unsafe_allow_html=True)
+
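To close, a short sketch of how these helpers are called from a Streamlit script, mirroring the calls in app.py (not part of the commit). The font URL and CSS path are the ones app.py uses; "assets/logo.png" is a hypothetical path included only to show load_image, since the commit ships no image assets.

import streamlit as st

from utils import load_image, local_css, remote_css

# Same font and stylesheet app.py loads.
remote_css("https://fonts.googleapis.com/css2?family=Roboto:wght@300&display=swap")
local_css("assets/ltr.css")

# "assets/logo.png" is a placeholder path used only for illustration.
logo = load_image("assets/logo.png", image_resize=(100, 100))
st.image(logo)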