File size: 800 Bytes
b58c818
1256a85
 
b58c818
ab5688d
 
cd3df30
ab5688d
b58c818
 
b4ace98
b58c818
ab5688d
 
cd3df30
ab5688d
 
cd3df30
b58c818
ab5688d
 
 
d2f7d16
ab5688d
e7f5cf5
2b90304
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
import transformers
import streamlit as st

from transformers import AutoTokenizer, AutoModelWithLMHead
from transformers import pipeline

#tokenizer = AutoTokenizer.from_pretrained("gpt2-medium")

# allow_output_mutation=True: by default the legacy ``st.cache`` hashes the
# cached return value to detect mutations; for a large model object that is
# slow and can fail on unhashable internals, so the check is disabled.
@st.cache(allow_output_mutation=True)
def load_model(model_name):
    """Load a pretrained model with a language-modeling head and cache it.

    Args:
        model_name: Model identifier passed unchanged to
            ``AutoModelWithLMHead.from_pretrained``.

    Returns:
        The model object returned by ``from_pretrained``. The same cached
        object is handed back for repeated calls with the same
        ``model_name``.
    """
    model = AutoModelWithLMHead.from_pretrained(model_name)
    return model
    
# Cached because this is called at module top level and Streamlit re-runs the
# whole script on every widget interaction; without the cache the pipeline
# would be rebuilt on each rerun. allow_output_mutation=True skips hashing the
# returned pipeline object on cache hits.
@st.cache(allow_output_mutation=True)
def load_text_gen_model():
    """Build the ``gpt2-medium`` text-generation pipeline once and reuse it.

    Returns:
        The object returned by
        ``pipeline("text-generation", model="gpt2-medium")``.
    """
    generator = pipeline("text-generation", model="gpt2-medium")
    return generator
    
#model = load_model("gpt2-medium")

# Obtain the text-generation pipeline used by the article action below.
text_generator = load_text_gen_model()

# Sidebar menu listing the actions the app offers.
action = st.sidebar.selectbox(
    "Pick an Action",
    ["Generate an Article", "Create an Image"],
)

# Only the article action is handled here; generation runs once the user has
# entered a non-empty prompt, and the raw pipeline output is written to the page.
if action == "Generate an Article":
    prompt = st.text_input("Enter a prompt")
    if prompt:
        st.write(text_generator(prompt, max_length=100, temperature=0.7))