import os

import gradio as gr
import tensorflow as tf
from huggingface_hub import login
from transformers import AutoTokenizer, TFAutoModelForCausalLM

# Run the model in bfloat16 mixed precision under a distribution strategy
policy = tf.keras.mixed_precision.Policy('mixed_bfloat16')
tf.keras.mixed_precision.set_global_policy(policy)
strategy = tf.distribute.MultiWorkerMirroredStrategy()

# Authenticate with the Hugging Face Hub using the hf_token secret
login(os.environ.get("hf_token"))

name = "WICKED4950/GPT2-InstEsther0.21eV3.1"
tokenizer = AutoTokenizer.from_pretrained(name)
tokenizer.pad_token = tokenizer.eos_token  # GPT-2 has no pad token, so reuse EOS
with strategy.scope():
    model = TFAutoModelForCausalLM.from_pretrained(name)

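# raw_pred is a generator: each model.generate() call produces exactly one new
# token, which is decoded and yielded immediately so the caller can stream it.
# Generation stops once the <|SOH|> marker (the start of the next human turn)
# appears, at which point the full decoded reply is yielded with a <|Clean|> prefix.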
def raw_pred(prompt, model, tokenizer, max_length=50, temperature=0.2):
    input_ids = tokenizer.encode(prompt, return_tensors='tf')

    # Initialize variables
    generated_ids = input_ids
    stop_token_id = tokenizer.encode("<|SOH|>", add_special_tokens=False)[0]  # ID for <|SOH|>
    all_generated_tokens = []  # To store generated token IDs
    tokens_yielded = []  # To store tokens as they are yielded

    with strategy.scope():
        for _ in range(max_length):  # Generate up to max_length tokens, one per step
            # Generate exactly one new token
            outputs = model.generate(
                generated_ids,
                max_length=generated_ids.shape[1] + 1,  # Allow a single additional token
                temperature=temperature,
                pad_token_id=tokenizer.eos_token_id,
                eos_token_id=stop_token_id,  # Stop generation at <|SOH|>
                do_sample=True,
                num_return_sequences=1
            )

            # Get the newly generated token (the last position)
            new_tokens = outputs[0, -1:]
            generated_ids = outputs  # Update the generated_ids with the new tokens

            # Store the generated tokens as numbers (IDs)
            all_generated_tokens.extend(new_tokens.numpy().tolist())

            # Decode and yield the tokens as they are generated (as numbers)
            tokens_text = tokenizer.decode(new_tokens, skip_special_tokens=False)
            tokens_yielded.append(tokens_text)
            yield tokens_text

            # Stop if the generated tokens include <|SOH|>
            if stop_token_id in new_tokens.numpy():
                final_text = tokenizer.decode(all_generated_tokens, skip_special_tokens=False)
                yield ("<|Clean|>" + final_text)
                break

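# respond adapts the model to Gradio's ChatInterface: it flattens the chat
# history into the <|SOH|>/<|SOB|> prompt format the model was trained on and
# yields progressively longer responses so the reply streams in the UI.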
def respond(message, history):
    # Rebuild the conversation in the model's <|SOH|>/<|SOB|> prompt format
    give_mod = ""
    for user_msg, bot_msg in history:
        give_mod += "<|SOH|>" + user_msg + "<|SOB|>" + bot_msg
    give_mod += "<|SOH|>" + message + "<|SOB|>"
    print(give_mod)
    response = ""
    for token in raw_pred(give_mod, model, tokenizer):
        if "<|Clean|>" in token:
            # Final chunk from raw_pred: the full decoded reply prefixed with <|Clean|>
            response = token
        else:
            response += token
        # Stream the partial reply to the UI, stripping the internal markers
        yield response.replace("<|SOH|>", "").replace("<|Clean|>", "")
    print(response)
# Gradio chat interface: because respond is a generator, the UI streams the
# reply token by token as it is yielded
demo = gr.ChatInterface(
    respond
)

if __name__ == "__main__":
    demo.launch()