File size: 1,372 Bytes
de32ef6
 
 
b61931d
de32ef6
 
 
125e91c
de32ef6
b61931d
de32ef6
 
 
 
 
413eb14
0937020
de32ef6
 
 
 
413eb14
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
import streamlit as st
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM

# Streamlit re-executes this whole script on every widget interaction, so the
# tokenizer is loaded through a cached resource loader instead of being rebuilt
# on each rerun. `st.cache_resource` is used when this Streamlit version has it;
# older versions fall back to the legacy `st.cache` with
# `allow_output_mutation=True`, so Streamlit does not try to hash the returned
# tokenizer object on every cache hit.
_cache_tokenizer = getattr(st, "cache_resource", None) or st.cache(allow_output_mutation=True)


@_cache_tokenizer
def _load_tokenizer():
    """Load the 't5-base' tokenizer used to encode prompts and decode replies."""
    return AutoTokenizer.from_pretrained('t5-base')


tokenizer = _load_tokenizer()

# Cache decorator for the model. `st.cache_resource` is Streamlit's cache for
# global resources such as ML models and is preferred when available. On older
# Streamlit versions, fall back to the legacy `st.cache` with
# `allow_output_mutation=True`; without it, Streamlit hashes the returned model
# (all of its weights) on every rerun to detect mutation.
_cache_model = getattr(st, "cache_resource", None) or st.cache(allow_output_mutation=True)


@_cache_model
def load_model():
    """Load the seq2seq conversation model once and reuse it across reruns.

    Returns:
        The ``Supiri/t5-base-conversation`` checkpoint, loaded with
        ``AutoModelForSeq2SeqLM.from_pretrained``.
    """
    return AutoModelForSeq2SeqLM.from_pretrained("Supiri/t5-base-conversation")

model = load_model()

# User-tunable generation settings.
num_beams = st.slider('Number of beams', min_value=1, max_value=10, value=6)
num_beam_groups = st.slider('Number of beam groups', min_value=1, max_value=10, value=2)
diversity_penalty = st.slider('Diversity penalty', min_value=0.1, max_value=5.0, value=2.5)

context = st.text_area('Personality', value="Hinata was soft-spoken and polite, always addressing people with proper honorifics. She is kind, always thinking of others more than for herself, caring for their feelings and well-being. She doesn't like being confrontational for any reason. This led to her being meek or timid to others, as her overwhelming kindness can render her unable to respond or act for fear of offending somebody.")
query = st.text_input('Question', value="What's your name?")

# Group (diverse) beam search in `model.generate` requires the beams to split
# evenly into groups. Any other slider combination, including more groups than
# beams, makes `generate` raise ValueError and crashes the app, so report the
# problem and end this rerun instead.
if num_beams % num_beam_groups != 0:
    st.error(
        f"'Number of beams' ({num_beams}) must be divisible by "
        f"'Number of beam groups' ({num_beam_groups})."
    )
    st.stop()

# Encode the personality and the question as a text pair, then generate a reply.
input_ids = tokenizer(f"personality: {context}", f"inquiry: {query}", return_tensors='pt').input_ids
outputs = model.generate(
    input_ids,
    num_beams=num_beams,
    diversity_penalty=diversity_penalty,
    num_beam_groups=num_beam_groups,
)

# Label the reply with the first word of the personality text (the character's
# name in the default prompt). Splitting on any whitespace tolerates leading
# spaces or newlines, and an empty personality yields an empty label.
words = context.split()
speaker = words[0] if words else ""
st.write(f"{speaker}:\t", tokenizer.decode(outputs[0], skip_special_tokens=True))