File size: 3,561 Bytes
e275bad
 
cb04c7f
e275bad
 
932c085
e275bad
 
8e036eb
932c085
 
 
 
8e036eb
05a7b40
 
 
 
 
 
 
 
 
8e036eb
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
a9b548b
8e036eb
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
348b668
a9b548b
348b668
 
e275bad
 
 
 
 
7fe3527
481b521
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
a9b548b
481b521
 
 
 
 
 
 
 
 
 
 
 
 
 
348b668
8aeba17
5cba70c
 
 
 
 
8aeba17
481b521
 
e275bad
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
import gradio as gr
from huggingface_hub import InferenceClient
import os

"""
Copied from inference in colab notebook
"""

from transformers import pipeline

# Hugging Face Hub id of the fine-tuned checkpoint served by this app
# (presumably a T5-small English->French model, per its name and the
# default task prefix used in the UI — confirm against the model card).
model_path = "Mat17892/t5small_enfr_opus"

from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, TextIteratorStreamer
import threading

# Load model and tokenizer globally (once, at import time) so every request
# reuses them instead of reloading them for each call to respond().
tokenizer = AutoTokenizer.from_pretrained(model_path)
model = AutoModelForSeq2SeqLM.from_pretrained(model_path)

def respond(
    message: str,
    system_message: str,
    max_tokens: int = 128,
    temperature: float = 1.0,
    top_p: float = 1.0,
):
    """Translate ``message`` and stream the growing translation.

    The model input is ``system_message`` (a task prefix such as
    ``"translate English to French:"``) followed by a single space and
    ``message``. Generation runs on a background thread. This generator yields
    the accumulated decoded text after every chunk the streamer produces, so a
    Gradio output component can update progressively.

    Args:
        message: Source text entered by the user.
        system_message: Task prefix prepended to ``message``.
        max_tokens: Maximum number of new tokens to generate.
        temperature: Sampling temperature (sampling is always enabled).
        top_p: Nucleus-sampling probability mass.

    Yields:
        The translation generated so far. For empty or whitespace-only input,
        a single empty string is yielded and the model is not called.
    """
    # Nothing to translate: clear the output instead of asking the model to
    # "translate" only the task prefix.
    if not message or not message.strip():
        yield ""
        return

    # Keep the attention mask alongside the ids so generate() does not have to
    # infer it from the ids (and warn about doing so).
    encoded = tokenizer(system_message + " " + message, return_tensors="pt")

    # Yields decoded text chunks as generate() produces tokens.
    streamer = TextIteratorStreamer(tokenizer, skip_special_tokens=True)

    # generate() blocks until done, so run it on a worker thread while this
    # generator consumes the streamer.
    generation_thread = threading.Thread(
        target=model.generate,
        kwargs={
            "input_ids": encoded.input_ids,
            "attention_mask": encoded.attention_mask,
            # UI sliders may hand over a float; max_new_tokens must be an int.
            "max_new_tokens": int(max_tokens),
            "do_sample": True,
            "temperature": temperature,
            "top_p": top_p,
            "streamer": streamer,
        },
    )
    generation_thread.start()

    generated_text = ""
    for chunk in streamer:
        generated_text += chunk
        yield generated_text

    # The streamer only stops once generate() has finished; reap the thread.
    generation_thread.join()


"""
For information on how to customize the ChatInterface, peruse the gradio docs: https://www.gradio.app/docs/chatinterface
"""

# Define the interface
with gr.Blocks() as demo:
    gr.Markdown("# Google Translate-like Interface")

    with gr.Row():
        with gr.Column():
            source_textbox = gr.Textbox(
                placeholder="Enter text in English...",
                label="Source Text (English)",
                lines=5,
            )
        with gr.Column():
            translated_textbox = gr.Textbox(
                placeholder="Translation will appear here...",
                label="Translated Text (French)",
                lines=5,
                interactive=False,
            )

    translate_button = gr.Button("Translate")

    with gr.Accordion("Advanced Settings", open=False):
        system_message_input = gr.Textbox(
            value="translate English to French:",
            label="System message",
        )
        max_tokens_slider = gr.Slider(
            minimum=1, maximum=2048, value=512, step=1, label="Max new tokens"
        )
        temperature_slider = gr.Slider(
            minimum=0.1, maximum=4.0, value=0.7, step=0.1, label="Temperature"
        )
        top_p_slider = gr.Slider(
            minimum=0.1, maximum=1.0, value=0.95, step=0.05, label="Top-p (nucleus sampling)"
        )

    # Define functionality
    translate_button.click(
        respond,
        inputs=[
            source_textbox,
            system_message_input,
            max_tokens_slider,
            temperature_slider,
            top_p_slider,
        ],
        outputs=translated_textbox,
    )

if __name__ == "__main__":
    demo.launch()