File size: 3,439 Bytes
40e9898
 
 
 
 
36338f2
40e9898
fb12737
40e9898
 
 
36338f2
40e9898
 
36338f2
 
e4461ed
36338f2
 
 
40e9898
 
 
36338f2
 
 
 
 
40e9898
36338f2
2b02259
40e9898
36338f2
e4461ed
 
 
 
 
 
 
36338f2
 
2b02259
40e9898
e4461ed
 
2b02259
40e9898
36338f2
e4461ed
 
 
36338f2
2b02259
e4461ed
36338f2
 
e4461ed
 
 
 
 
 
 
 
 
 
36338f2
 
2b02259
e4461ed
36338f2
 
e4461ed
 
 
 
 
 
 
 
 
 
36338f2
e4461ed
 
 
 
 
 
36338f2
 
 
e4461ed
36338f2
 
e4461ed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
""" Script for streamlit demo
    @author: AbinayaM02
"""

# Import necessary libraries
from transformers import AutoTokenizer, AutoModelWithLMHead, pipeline
import streamlit as st
import json

# Read the config. It supplies the model/data menu options, the model names
# passed to load_model, and the example prompts. Decode it explicitly as UTF-8:
# the prompts are presumably Tamil text, and the platform default encoding
# (e.g. cp1252 on Windows) would fail to decode them.
with open("config.json", encoding="utf-8") as f:
    config = json.load(f)

# Set page layout (must be the first Streamlit command that renders anything).
st.set_page_config(
    page_title="Tamil Language Models",
    page_icon="✍️",
    layout="wide",
    initial_sidebar_state="expanded",
)

# Load the model
@st.cache(allow_output_mutation=True)
def load_model(model_name):
    """Load a language model and its matching tokenizer.

    The ``st.cache`` decorator memoizes the result per ``model_name`` across
    Streamlit reruns. ``allow_output_mutation=True`` stops Streamlit from
    hashing the returned objects to detect mutation.

    Args:
        model_name: identifier handed to ``from_pretrained``. In this script
            it comes from ``config``; presumably a Hugging Face Hub model id.

    Returns:
        A ``(model, tokenizer)`` tuple.
    """
    # Show a spinner while the (potentially slow) downloads/loads happen.
    # The model is loaded before the tokenizer.
    with st.spinner('Waiting for the model to load.....'):
        loaded = (
            AutoModelWithLMHead.from_pretrained(model_name),
            AutoTokenizer.from_pretrained(model_name),
        )
    return loaded

# Per-dataset settings for the text-generation page, keyed by the
# "Select data" option: (description, slider label, slider minimum).
GENERATION_SETTINGS = {
    "Oscar": (
        'A simple demo using gpt-2-tamil model trained on Oscar dataset!',
        'Length of the sentence to generate',
        25,
    ),
    "Oscar + Indic Corpus": (
        'A simple demo using gpt-2-tamil model trained on Oscar + IndicNLP dataset',
        'Length of the sentence',
        5,
    ),
}


def _generate_text(model, tokenizer, prompt_text, max_len):
    """Run GPT2 text generation on ``prompt_text`` and write the result.

    Args:
        model: language model returned by ``load_model``.
        tokenizer: tokenizer returned alongside ``model``.
        prompt_text: text to continue. Blank input is rejected with a warning
            instead of being sent to the model.
        max_len: ``max_length`` passed to the text-generation pipeline.
    """
    if not prompt_text.strip():
        st.warning('Please pick an example or enter some Tamil text first.')
        return
    try:
        with st.spinner('Generating...'):
            generator = pipeline('text-generation', model=model, tokenizer=tokenizer)
            seqs = generator(prompt_text, max_length=max_len)[0]['generated_text']
        st.write(seqs)
    except Exception as e:
        # UI boundary: show the error in the app rather than crashing the
        # script. st.exception expects the exception object, not a string.
        st.exception(e)


def _render_text_generation(model_name, examples, description, slider_label, slider_min):
    """Render the GPT2 text-generation page and generate on button press.

    Args:
        model_name: model identifier passed to ``load_model``.
        examples: options for the "Examples" select box; the last one is
            selected by default.
        description: markdown shown under the page header.
        slider_label: label of the generation-length slider.
        slider_min: minimum value of the generation-length slider.
    """
    st.header('Tamil text generation with GPT2')
    st.markdown(description)
    model, tokenizer = load_model(model_name)
    prompt = st.selectbox('Examples', examples, index=len(examples) - 1)
    # Pre-fill the text box with the chosen example so it can be edited.
    # "Custom" starts empty for free-form input. The box content, not the
    # raw selection, is what gets generated from.
    text = st.text_input(
        'Add your custom text in Tamil',
        "" if prompt == "Custom" else prompt,
        max_chars=1000)
    max_len = st.slider(slider_label, slider_min, 300, 100)
    if st.button('Generate'):
        _generate_text(model, tokenizer, text, max_len)


# Side bar
st.sidebar.image("images/tamil_logo.jpg", width=300)

# Choose the model based on selection
st.sidebar.title("கதை சொல்லி!")
page = st.sidebar.selectbox(label="Select model",
                            options=config["models"],
                            help="Select the model to generate the text")
data = st.sidebar.selectbox(label="Select data",
                            options=config[page],
                            help="Select the data on which the model is trained")

# Main page
st.title("Tamil Language Demos")
st.markdown(
    "Built as part of the Flax/Jax Community week, this demo uses [GPT2 trained on Oscar dataset](https://huggingface.co/flax-community/gpt-2-tamil) "
    "and [GPT2 trained on Oscar & IndicNLP dataset](https://huggingface.co/abinayam/gpt-2-tamil) "
    "to show language generation!"
)

# Example prompts from the config, plus a "Custom" entry for free-form input.
prompts = config["examples"] + ["Custom"]

if page == 'Text Generation' and data in GENERATION_SETTINGS:
    _render_text_generation(config[data], prompts, *GENERATION_SETTINGS[data])
else:
    # No generate button on this page, so nothing below may depend on one.
    st.header('Tamil News classification with Finetuned GPT2')
    st.markdown('In progress')