File size: 896 Bytes
c376e61
 
 
5861710
 
 
 
 
 
 
 
 
 
 
 
 
c376e61
5861710
 
 
 
c376e61
 
 
49dcd1c
985d236
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
from types import SimpleNamespace

import streamlit as st
from streamlit.report_thread import get_report_ctx
from transformers import pipeline


def query_cache(q_emb=None):
    """Get or set a value cached on the current Streamlit session.

    The value is stored as the ``_query_state`` attribute of the session
    object, so it is kept across reruns of the script for the same session.

    Args:
        q_emb: Value to store. When ``None`` (the default) the cached value
            is left untouched and simply returned.

    Returns:
        The value currently cached for this session, or ``None`` if nothing
        has been stored yet.
    """
    # NOTE(review): relies on Streamlit internals (streamlit.report_thread and
    # the private Server._get_session_info) that later Streamlit releases
    # removed; st.session_state is the public replacement — confirm against
    # the pinned Streamlit version.
    ctx = get_report_ctx()
    server = st.server.server.Server.get_current()
    session = server._get_session_info(ctx.session_id).session
    # Compare against None rather than truthiness so falsy-but-valid values
    # (0, "", empty containers) can also be stored; the first call always
    # initialises the attribute so the read below cannot fail.
    if q_emb is not None or not hasattr(session, "_query_state"):
        session._query_state = q_emb
    return session._query_state
# Usage: query_cache() returns the value cached for this browser session and
# query_cache(value) replaces it.
#
# On the very first run of a session nothing has been cached yet, so
# query_cache() returns None; seed the cache with a namespace object that
# holds this app's per-session state.
q_emb = query_cache()
if q_emb is None:
    q_emb = query_cache(SimpleNamespace())

if not hasattr(q_emb, "user_text"):
    q_emb.user_text = "foo"

# Keep whatever the user typed so it survives the next rerun.
q_emb.user_text = st.text_input("Write something", value=q_emb.user_text)

if st.button("Write with transformer"):
    # Constructing the pipeline loads the model, which is slow; build it once
    # per session and reuse it on later clicks.
    if not hasattr(q_emb, "generator"):
        q_emb.generator = pipeline("text-generation")
    # NOTE(review): the prompt is fixed; it was presumably meant to be the
    # user's text (q_emb.user_text) — confirm intent.
    res = q_emb.generator("My name is Mario and")[0]["generated_text"]
    # Store the result in the session state so the text box shows it on the
    # next rerun.
    q_emb.user_text = res