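"""Streamlit demo: autocomplete/generate tweets in the style of Colombian politicians
(Uribe, Petro) using fine-tuned GPT-2 causal language models."""
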
from transformers import AutoTokenizer, AutoModelForCausalLM
import torch
import streamlit as st



@st.cache_resource  # cache across Streamlit reruns so the models load only once (requires Streamlit >= 1.18)
def load_models():
    return {
        'uribe': {
            'tokenizer': AutoTokenizer.from_pretrained("jhonparra18/uribe-twitter-assistant-30ep"),
            'model': AutoModelForCausalLM.from_pretrained("jhonparra18/uribe-twitter-assistant-30ep")},
        'petro': {
            'tokenizer': AutoTokenizer.from_pretrained("jhonparra18/petro-twitter-assistant-30ep-large"),
            'model': AutoModelForCausalLM.from_pretrained("jhonparra18/petro-twitter-assistant-30ep-large")}}

MODELS = load_models()

def callback_input_text(new_text):
    # Drop the widget-bound session state and re-seed it so the input
    # text_area picks up the generated text on the next rerun.
    del st.session_state.input_user_txt
    st.session_state.input_user_txt = new_text

def text_completion(tokenizer, model, input_text: str, max_len: int = 100) -> str:
    tokenizer.padding_side = "left"  # pad on the left so generation continues from the last prompt token
    tokenizer.pad_token = tokenizer.eos_token  # GPT-2 has no pad token; reuse EOS
    inputs = tokenizer([input_text], return_tensors="pt", truncation=True, max_length=128)
    with torch.no_grad():  # generate() already disables gradients; kept as an explicit safeguard
        # max_new_tokens counts only generated tokens, so a prompt longer than the
        # budget (prompts may hold up to 128 tokens) cannot exhaust it prematurely
        outputs = model.generate(**inputs, do_sample=True, max_new_tokens=max_len, top_k=100, top_p=0.95)
    return tokenizer.batch_decode(outputs, skip_special_tokens=True)[0]
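
# Standalone usage sketch (outside the Streamlit UI), reusing the MODELS dict above:
#   tok, mdl = MODELS["petro"]["tokenizer"], MODELS["petro"]["model"]
#   print(text_completion(tok, mdl, "Mi gobierno no es corrupto", max_len=60))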




st.markdown("<h3 style='text-align: center; color: gray;'> &#128038 Tweet de Político Colombiano: Autocompletado/generación de texto a partir de GPT2</h3>", unsafe_allow_html=True)
st.text("")
st.markdown("<h5 style='text-align: center; color: gray;'>Causal Language Modeling, source code <a href='https://github.com/statscol/twitter-user-autocomplete-assistant'> here </a> </h5>", unsafe_allow_html=True)
st.text("")


# Layout: left column collects the prompt and generation settings, right column shows the result.
col1, col2 = st.columns(2)

with col1:
    with st.form("input_values"):
        politician = st.selectbox(
            "Selecciona el político",
            ("Uribe", "Petro")
        )
        st.text("")
        max_length_text = st.slider('Num Max Tokens', min_value=50, max_value=200, value=100, step=10, key="user_max_length")
        st.text("")
        # Placeholder so the input text_area can be re-rendered with the generated text later.
        input_user_text = st.empty()
        input_text_value = input_user_text.text_area('Input Text', 'Mi gobierno no es corrupto', key="input_user_txt", height=300)
        st.text("")
        complete_input = st.checkbox("Complete Input [Experimental]", value=False, help="Automáticamente rellenar el texto inicial con el resultado para una nueva iteración")
        go_button = st.form_submit_button('Generate')


with col2:

    if go_button:  # generate only on form submit, not on every widget interaction
        with st.spinner('Generating Text...'):
            selected = MODELS[politician.lower()]
            output_text = text_completion(selected['tokenizer'], selected['model'], input_text_value, max_length_text)
            st.text_area("Tweet:", output_text, height=500, key="output_text")
            if complete_input:
                # Feed the generated text back as the next prompt and re-render the input box.
                callback_input_text(output_text)
                input_user_text.text_area("Input Text", output_text, height=300)
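
# To run locally (assuming this file is saved as app.py):
#   streamlit run app.py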