import os

import gradio as gr
import tensorflow as tf
from huggingface_hub import login
from transformers import AutoTokenizer, TFAutoModelForCausalLM
# Run in bfloat16 mixed precision under a multi-worker mirrored distribution strategy
policy = tf.keras.mixed_precision.Policy('mixed_bfloat16')
tf.keras.mixed_precision.set_global_policy(policy)
strategy = tf.distribute.MultiWorkerMirroredStrategy()

# Authenticate with the Hugging Face Hub using the token from the environment
login(os.environ.get("hf_token"))

name = "WICKED4950/GPT2-InstEsther0.25eV3.1"
tokenizer = AutoTokenizer.from_pretrained(name)
tokenizer.pad_token = tokenizer.eos_token  # GPT-2 has no pad token, so reuse EOS
with strategy.scope():
    model = TFAutoModelForCausalLM.from_pretrained(name)

def raw_pred(prompt, model, tokenizer, max_length=1024, temperature=0.2):
    input_ids = tokenizer.encode(prompt, return_tensors='tf')

    # Initialize variables
    generated_ids = input_ids
    stop_token_id = tokenizer.encode("<|SOH|>", add_special_tokens=False)[0]  # ID for <|SOH|>
    all_generated_tokens = []  # To store generated token IDs
    tokens_yielded = []  # To store tokens as they are yielded

    with strategy.scope():
        for _ in range(max_length):  # Generate one token per iteration
            # Generate a single new token
            outputs = model.generate(
                generated_ids,
                max_length=generated_ids.shape[1] + 1,  # Allow exactly one more token
                temperature=temperature,
                pad_token_id=tokenizer.eos_token_id,
                eos_token_id=stop_token_id,  # Stop generation at <|SOH|>
                do_sample=True,
                num_return_sequences=1
            )

            # Get the newly generated token (last position)
            new_tokens = outputs[0, -1:]
            generated_ids = outputs  # Update the running sequence with the new token

            # Store the generated tokens as numbers (IDs)
            all_generated_tokens.extend(new_tokens.numpy().tolist())

            # Decode and yield the tokens as they are generated (as numbers)
            tokens_text = tokenizer.decode(new_tokens, skip_special_tokens=False)
            tokens_yielded.append(tokens_text)
            yield tokens_text

            # Stop if the generated tokens include <|SOH|>
            if stop_token_id in new_tokens.numpy():
                final_text = tokenizer.decode(all_generated_tokens, skip_special_tokens=False)
                yield ("<|Clean|>" + final_text)
                break
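
# Illustrative, standalone use of raw_pred (a sketch, assuming the model and tokenizer
# loaded above, and a prompt that ends with the "<|SOB|>" marker so the bot replies):
#
#     for piece in raw_pred("<|SOH|>Hello there!<|SOB|>", model, tokenizer):
#         print(piece, end="", flush=True)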

def respond(message, history):
    # Prepare input for the model
    give_mod = ""
    history = history[-3:]  # Keep only the last 3 exchanges to bound the prompt length
    for chunk in history:
        give_mod = give_mod + "<|SOH|>" + chunk[0] + "<|SOB|>" + chunk[1]
    give_mod = give_mod + "<|SOH|>" + message + "<|SOB|>"
    print(give_mod)
    response = ""
    for token in raw_pred(give_mod, model, tokenizer):
        if "<|Clean|>" in token:
            response = token
        else:
            response += token
        yield response.replace("<|SOH|>","").replace("<|Clean|>","")
    print(response)
# Gradio Chat Interface Setup
demo = gr.ChatInterface(
    fn=respond,  # Response handler function
    title="Chat with Esther",  # Add a title
    description="A friendly chatbot ready to help and chat with you! 😊",  # Brief description
    theme="compact",  # Options: "compact", "default", "dark"
)

# Launch the interface
if __name__ == "__main__":
    demo.launch()