File size: 2,777 Bytes
69cc5ab
50d93bb
69cc5ab
50d93bb
 
69cc5ab
 
50d93bb
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
import gradio as gr
from transformers import pipeline

# Hugging Face Hub checkpoint used for every answer produced by this demo.
text_generation_model = "cahya/indochat-tiny"
# Build the pipeline once at import time; get_answer() calls it on every request,
# so the loaded model is shared across requests instead of being reloaded.
text_generation = pipeline(task="text-generation", model=text_generation_model)


def get_answer(user_input, decoding_methods, top_k, top_p, temperature, repetition_penalty, penalty_alpha,
               *, min_length=50, max_length=200, num_beams=4):
    """Generate an assistant reply to ``user_input`` with the module-level text-generation pipeline.

    The user's text is wrapped in a ``"User: ...\\nAssistant: "`` prompt, sent to
    ``text_generation``, and the prompt is removed from the returned text.

    Args:
        user_input: The question typed by the user.
        decoding_methods: One of "Beam Search", "Sampling" or "Contrastive Search"
            (the dropdown choices). Any other value (e.g. ``None`` when nothing is
            selected) falls back to greedy decoding.
        top_k: Number of highest-probability tokens kept when sampling; also the
            candidate-set size for contrastive search.
        top_p: Nucleus-sampling cumulative probability (sampling only).
        temperature: Softmax temperature (sampling only).
        repetition_penalty: Penalty for repeating tokens (applied in every mode).
        penalty_alpha: Degeneration-penalty weight (contrastive search only).
        min_length: Minimum total length in tokens (prompt + answer) passed to generation.
        max_length: Maximum total length in tokens (prompt + answer) passed to generation.
        num_beams: Beam width used when "Beam Search" is selected.

    Returns:
        The generated answer without the prompt and without leading whitespace.
    """
    # Forward only the knobs the chosen strategy actually uses. With do_sample=False
    # the sampling parameters (top_p, temperature) have no effect, so they are omitted.
    generation_kwargs = {"do_sample": False, "repetition_penalty": repetition_penalty}
    if decoding_methods == "Beam Search":
        # Without an explicit num_beams, transformers defaults to a single beam
        # (i.e. greedy search), which is what this option silently did before.
        generation_kwargs["num_beams"] = num_beams
    elif decoding_methods == "Sampling":
        # TopKLogitsWarper rejects non-int top_k, so coerce the slider value.
        generation_kwargs.update(do_sample=True, top_k=int(top_k), top_p=top_p, temperature=temperature)
    elif decoding_methods == "Contrastive Search":
        # Contrastive search is selected by penalty_alpha > 0 together with top_k > 1.
        # NOTE(review): needs a transformers release that supports penalty_alpha.
        generation_kwargs.update(penalty_alpha=penalty_alpha, top_k=int(top_k))
    print(user_input, decoding_methods, generation_kwargs)
    prompt = f"User: {user_input}\nAssistant: "
    generated_text = text_generation(prompt, min_length=min_length, max_length=max_length,
                                     num_return_sequences=1, **generation_kwargs)
    answer = generated_text[0]["generated_text"]
    # The pipeline returns the prompt followed by the continuation. Slice at exactly
    # len(prompt) (the previous "+1" chopped the first character of the answer) and
    # drop any whitespace the model may emit right after "Assistant: ".
    return answer[len(prompt):].lstrip()


# Build the IndoChat UI: input controls on the left, generated answer on the right.
with gr.Blocks() as demo:
    with gr.Row():
        gr.Markdown(
            "## IndoChat")
    with gr.Row():
        with gr.Column():
            # Use the gr.* components (as the Dropdown/Button/Textbox below already do)
            # instead of the deprecated gr.inputs.* namespace; `default=` becomes `value=`.
            user_input = gr.Textbox(placeholder="",
                                    label="Ask me something in Indonesian or English",
                                    value="Bagaimana cara mendidik anak supaya tidak berbohong?")
            # Give the dropdown a label and a default so the handler never receives None.
            decoding_methods = gr.Dropdown(["Beam Search", "Sampling", "Contrastive Search"],
                                           value="Beam Search", label="Decoding method")
            top_k = gr.Slider(label="Top K: The number of highest probability vocabulary tokens to keep",
                              value=40, maximum=50, minimum=1, step=1)
            top_p = gr.Slider(label="Top P", value=0.9, step=0.05, minimum=0.1, maximum=1.0)
            temperature = gr.Slider(label="Temperature", value=0.5, step=0.05, minimum=0.1, maximum=1.0)
            repetition_penalty = gr.Slider(label="Repetition Penalty", value=1.1, step=0.05, minimum=1.0, maximum=2.0)
            # Contrastive search balances model confidence (weight 1 - alpha) against the
            # degeneration penalty (weight alpha), so alpha must lie in [0, 1].
            penalty_alpha = gr.Slider(label="The penalty alpha for contrastive search",
                                      value=0.6, step=0.05, minimum=0.0, maximum=1.0)
            with gr.Row():
                button_generate_story = gr.Button("Submit")
        with gr.Column():
            generated_answer = gr.Textbox(label="Answer")
    with gr.Row():
        gr.Markdown("![visitor badge](https://visitor-badge.glitch.me/badge?page_id=cahya_indochat)")

    # Inputs are passed positionally, in this order, to get_answer's parameters.
    button_generate_story.click(get_answer, inputs=[user_input, decoding_methods, top_k, top_p, temperature,
                                                    repetition_penalty, penalty_alpha], outputs=[generated_answer])

demo.launch(enable_queue=False)