shahp7575 commited on
Commit
bdf7620
1 Parent(s): e1c8931

Upload app.py

Browse files
Files changed (1) hide show
  1. app.py +77 -0
app.py ADDED
@@ -0,0 +1,77 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import warnings
3
+ import requests
4
+ import torch
5
+ import streamlit as st
6
+ from streamlit_lottie import st_lottie
7
+ from transformers import AutoTokenizer, AutoModelWithLMHead
8
+
9
+ warnings.filterwarnings("ignore")
10
+
11
+ st.set_page_config(layout='centered', page_title='GPT2-Horoscopes')
12
+
13
+ def load_lottieurl(url: str):
14
+ # https://github.com/tylerjrichards/streamlit_goodreads_app/blob/master/books.py
15
+ r = requests.get(url)
16
+ if r.status_code != 200:
17
+ return None
18
+ return r.json()
19
+
20
+ lottie_book = load_lottieurl('https://assets2.lottiefiles.com/packages/lf20_WL3aE7.json')
21
+ st_lottie(lottie_book, speed=1, height=200, key="initial")
22
+
23
+ st.markdown('# GPT2-Horoscopes!')
24
+ st.markdown("""
25
+ Hello! This lovely app lets GPT-2 write awesome horoscopes for you. All you need to do
26
+ is select your sign and choose the horoscope category :)
27
+ """)
28
+ st.markdown("""
29
+ *If you are interested in the fine-tuned model, you can visit the [Model Hub](https://huggingface.co/shahp7575/gpt2-horoscopes) or
30
+ my [GitHub Repo](https://github.com/shahp7575/gpt2-horoscopes).*
31
+ """)
32
+
33
+
34
+ @st.cache(allow_output_mutation=True, max_entries=1)
35
+ def download_model():
36
+ tokenizer = AutoTokenizer.from_pretrained('shahp7575/gpt2-horoscopes')
37
+ model = AutoModelWithLMHead.from_pretrained('shahp7575/gpt2-horoscopes')
38
+ return model, tokenizer
39
+ model, tokenizer = download_model()
40
+
41
+ def make_prompt(category):
42
+ return f"<|category|> {category} <|horoscope|>"
43
+
44
+ def generate(prompt, model, tokenizer, temperature, num_outputs, top_k):
45
+
46
+ sample_outputs = model.generate(prompt,
47
+ #bos_token_id=random.randint(1,30000),
48
+ do_sample=True,
49
+ top_k=top_k,
50
+ max_length = 300,
51
+ top_p=0.95,
52
+ temperature=temperature,
53
+ num_return_sequences=num_outputs)
54
+
55
+ return sample_outputs
56
+
57
+ with st.beta_container():
58
+
59
+ horoscope = st.selectbox("Choose Your Sign: ", ('Aquarius', 'Pisces', 'Aries',
60
+ 'Taurus', 'Gemini', 'Cancer',
61
+ 'Leo', 'Virgo', 'Libra',
62
+ 'Scorpio', 'Sagittarius', 'Capricorn'), index=0)
63
+ choice = st.selectbox("Choose Category:", ('general', 'career', 'love', 'wellness', 'birthday'),
64
+ index=0, )
65
+
66
+ temp_slider = st.slider("Temperature (Higher Value = More randomness)", min_value=0.01, max_value=1.0, value=0.95)
67
+
68
+ if st.button('Generate Horoscopes!'):
69
+ prompt = make_prompt(choice)
70
+ prompt_encoded = torch.tensor(tokenizer.encode(prompt)).unsqueeze(0)
71
+ with st.spinner('Generating...'):
72
+ sample_output = generate(prompt_encoded, model, tokenizer, temperature=temp_slider, num_outputs=1, top_k=40)
73
+ final_out = tokenizer.decode(sample_output[0], skip_special_tokens=True)
74
+ st.write(final_out[len(choice)+2:])
75
+ else: pass
76
+
77
+