File size: 3,208 Bytes
e3619f7
fe0a282
085d77c
 
fe0a282
 
 
085d77c
 
 
e3619f7
fe0a282
085d77c
fe0a282
 
bcd6a17
085d77c
 
fe0a282
085d77c
 
e3619f7
fe0a282
afe2bef
61bd718
e3619f7
085d77c
 
fe0a282
085d77c
 
e3619f7
085d77c
 
fe0a282
085d77c
 
fe0a282
085d77c
 
fe0a282
085d77c
 
 
e3619f7
fe0a282
085d77c
 
 
fe0a282
085d77c
 
 
 
e3619f7
fe0a282
085d77c
 
fe0a282
085d77c
 
fe0a282
085d77c
 
fe0a282
8cf095c
fe0a282
 
c538203
fe0a282
085d77c
 
 
 
c538203
085d77c
 
fe0a282
 
ca143a1
e3619f7
fe0a282
 
 
 
e3619f7
 
 
fe0a282
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
import gradio as gr
from huggingface_hub import InferenceClient, login
import tensorflow as tf
from transformers import AutoTokenizer, TFAutoModelForCausalLM
import os

# Set up mixed precision and distribution strategy.
# 'mixed_bfloat16' computes in bfloat16 while keeping variables in float32; the
# global policy applies to Keras layers built after this call (the model below).
policy = tf.keras.mixed_precision.Policy('mixed_bfloat16')
tf.keras.mixed_precision.set_global_policy(policy)
strategy = tf.distribute.MultiWorkerMirroredStrategy()

# Log into Hugging Face only when a token is configured. `login()` called
# without a token falls back to an interactive prompt, which would block or
# fail when the app runs unattended on a server.
_hf_token = os.environ.get("hf_token")
if _hf_token:
    login(_hf_token)

# Load tokenizer and model
name = "WICKED4950/GPT2-InstEsther0.28eV3.1"
tokenizer = AutoTokenizer.from_pretrained(name)
# Reuse EOS as the pad token; presumably the GPT-2-style tokenizer has no pad
# token of its own (TODO confirm against the hub tokenizer config).
tokenizer.pad_token = tokenizer.eos_token

# Build the model's variables under the same strategy used later for generation.
with strategy.scope():
    model = TFAutoModelForCausalLM.from_pretrained(name)

# Raw Prediction Function
def raw_pred(input, model, tokenizer, max_length=1024, temperature=0.2):
    """Stream a sampled continuation of ``input`` one token at a time.

    Each loop iteration asks ``model.generate`` for exactly one more token,
    decodes it and yields the text immediately so the UI can stream it.
    Generation ends when the ``<|SOH|>`` marker is produced, after
    ``max_length`` new tokens, or when the sequence fills the model's
    positional window (when the model config exposes ``n_positions``).

    Args:
        input: Prompt text. (Shadows the builtin; name kept for callers.)
        model: TF causal LM exposing ``generate``.
        tokenizer: Tokenizer matching ``model``.
        max_length: Maximum number of NEW tokens to generate (one per
            iteration) -- not a total sequence length.
        temperature: Sampling temperature forwarded to ``generate``.

    Yields:
        The decoded text of each new token. If the stop marker is reached, one
        extra string ``"<|Clean|>" + full_text`` follows, where ``full_text``
        is every generated token decoded together (stop marker included).
    """
    input_ids = tokenizer.encode(input, return_tensors='tf')
    generated_ids = input_ids

    # "<|SOH|>" opens the next human turn, so the bot's reply ends when the
    # model emits it. NOTE(review): [0] assumes the marker encodes to a single
    # token (added to the vocabulary); otherwise this stops on its first piece.
    stop_token_id = tokenizer.encode("<|SOH|>", add_special_tokens=False)[0]
    # IDs of every generated token, re-decoded as a whole at the end --
    # presumably because decoding single tokens can split multi-byte characters.
    all_generated_tokens = []

    # Positional limit of the model; feeding a longer sequence would index past
    # the position embeddings. None when the config does not expose it.
    context_limit = getattr(getattr(model, "config", None), "n_positions", None)

    with strategy.scope():
        for _ in range(max_length):  # one new token per iteration
            if context_limit is not None and generated_ids.shape[1] >= context_limit:
                break

            # Ask for exactly one more token than we already have.
            outputs = model.generate(
                generated_ids,
                max_length=generated_ids.shape[1] + 1,
                temperature=temperature,
                pad_token_id=tokenizer.eos_token_id,
                eos_token_id=stop_token_id,
                do_sample=True,
                num_return_sequences=1
            )

            # The last position holds the freshly sampled token.
            new_tokens = outputs[0, -1:]
            generated_ids = outputs  # feed the extended sequence back next round

            new_token_ids = new_tokens.numpy().tolist()
            all_generated_tokens.extend(new_token_ids)
            yield tokenizer.decode(new_tokens, skip_special_tokens=False)

            # Stop once the stop marker appears and emit the cleanly re-decoded reply.
            if stop_token_id in new_token_ids:
                final_text = tokenizer.decode(all_generated_tokens, skip_special_tokens=False)
                yield "<|Clean|>" + final_text
                break

# Response Handler Function
def respond(message, history):
    """Gradio chat callback: build the prompt and stream the model's reply.

    The prompt is ``<|SOH|>{user}<|SOB|>{bot}`` for each of the last three
    history pairs, then ``<|SOH|>{message}<|SOB|>`` so the model continues
    with the bot's turn.

    Args:
        message: Latest user message; only its first character is uppercased.
        history: Prior turns as ``[user, bot]`` pairs (assumed Gradio "tuples"
            history format -- confirm against the installed Gradio version).

    Yields:
        The reply accumulated so far with ``<|SOH|>`` and ``<|Clean|>``
        removed. When ``raw_pred`` sends its ``<|Clean|>`` text, the
        accumulated reply is replaced by that fully re-decoded version.
    """
    recent = history[-3:]  # Limit history to last 3 exchanges
    # A missing bot reply would otherwise be written into the prompt as "None".
    turns = [
        f"<|SOH|>{chunk[0]}<|SOB|>{'' if chunk[1] is None else chunk[1]}"
        for chunk in recent
    ]
    # Uppercase only the first character; str.capitalize() would also
    # lowercase the rest of the message (e.g. "NYC" -> "nyc").
    turns.append(f"<|SOH|>{message[:1].upper() + message[1:]}<|SOB|>")
    give_mod = "".join(turns)
    print(give_mod)

    response = ""
    for token in raw_pred(give_mod, model, tokenizer):
        if "<|Clean|>" in token:
            # Final message from raw_pred: the whole reply decoded in one pass.
            response = token
            print(response)
        else:
            response += token
        yield response.replace("<|SOH|>", "").replace("<|Clean|>", "")

# Gradio Chat Interface Setup
# Chat UI: every user turn is handed to `respond(message, history)`, whose
# yielded strings are shown as the bot's (streaming) reply.
demo = gr.ChatInterface(
    respond,
    description="A friendly chatbot ready to help and chat with you! 😊",
    title="Chat with Esther",
    # NOTE(review): "compact" is not one of Gradio's built-in theme names in
    # recent releases; Gradio may not resolve it -- confirm the intended theme.
    theme="compact",
)

if __name__ == "__main__":
    # Start the web app only when executed as a script, with debug enabled.
    demo.launch(debug=True)