# Hugging Face Space app for the "Esther" GPT-2 chatbot.
# (Removed scraped page-status residue: "Spaces:" / "Sleeping".)
import os

import gradio as gr
import tensorflow as tf
from huggingface_hub import InferenceClient
from huggingface_hub import login, create_repo, upload_file
from transformers import AutoTokenizer, TFAutoModelForCausalLM
# Run all Keras compute in bfloat16 while keeping variables in float32.
policy = tf.keras.mixed_precision.Policy('mixed_bfloat16')
tf.keras.mixed_precision.set_global_policy(policy)

# Distribution strategy for model creation/inference; on a single host this
# mirrors across the local devices. NOTE(review): MultiWorkerMirroredStrategy
# normally expects a TF_CONFIG cluster spec — confirm this is intended over
# plain MirroredStrategy for a single Space.
strategy = tf.distribute.MultiWorkerMirroredStrategy()

# Authenticate against the Hugging Face Hub with the token from the
# environment (None if the "hf_token" variable is unset).
login(os.environ.get("hf_token"))

# Hub repo id for the model/tokenizer pair.
name = "WICKED4950/GPT2-InstEsther0.25eV3.1"
tokenizer = AutoTokenizer.from_pretrained(name)
# GPT-2 defines no pad token; reuse EOS so padding is well-defined.
tokenizer.pad_token = tokenizer.eos_token

# Load the model under the strategy's scope so its variables are placed
# according to the distribution strategy.
with strategy.scope():
    model = TFAutoModelForCausalLM.from_pretrained(name)
def raw_pred(input, model, tokenizer, max_length=1024, temperature=0.2):
    """Stream sampled tokens from the model one at a time.

    Yields each newly generated token as decoded text. When the model emits
    the ``<|SOH|>`` end-of-response marker, one final string prefixed with
    ``"<|Clean|>"`` is yielded containing the full decoded response, then
    generation stops.

    Args:
        input: Prompt text. (Name shadows the ``input`` builtin; kept
            unchanged for backward compatibility with existing callers.)
        model: TF causal LM exposing ``.generate()``.
        tokenizer: Tokenizer matching ``model``.
        max_length: Maximum number of NEW tokens to generate.
        temperature: Sampling temperature forwarded to ``generate``.
    """
    input_ids = tokenizer.encode(input, return_tensors='tf')
    generated_ids = input_ids
    # Token id of the <|SOH|> marker, used as the stop sentinel.
    stop_token_id = tokenizer.encode("<|SOH|>", add_special_tokens=False)[0]
    all_generated_tokens = []  # every generated id, for the final clean decode

    with strategy.scope():
        # One token per iteration (was `max_length // 1`, a no-op division;
        # old comments claimed chunks of 3, which never matched the code).
        for _ in range(max_length):
            outputs = model.generate(
                generated_ids,
                max_length=generated_ids.shape[1] + 1,  # exactly one new token
                temperature=temperature,
                pad_token_id=tokenizer.eos_token_id,
                eos_token_id=stop_token_id,  # halt generation at <|SOH|>
                do_sample=True,
                num_return_sequences=1,
            )
            new_tokens = outputs[0, -1:]  # the single token just produced
            generated_ids = outputs  # feed the full sequence back next step
            all_generated_tokens.extend(new_tokens.numpy().tolist())

            # Stream the decoded token to the caller immediately.
            yield tokenizer.decode(new_tokens, skip_special_tokens=False)

            # On the stop marker, emit the cleanly re-decoded full response
            # (prefixed so the caller can distinguish it) and finish.
            if stop_token_id in new_tokens.numpy():
                final_text = tokenizer.decode(
                    all_generated_tokens, skip_special_tokens=False
                )
                yield "<|Clean|>" + final_text
                break
def respond(message, history):
    """Format the chat history into the model's prompt and stream the reply.

    Gradio streaming handler: yields progressively longer partial responses
    with the internal <|SOH|>/<|Clean|> markers stripped out.
    """
    # Only the three most recent (user, bot) exchanges are kept as context.
    recent = history[-3:]
    parts = []
    for turn in recent:
        parts.append("<|SOH|>" + turn[0] + "<|SOB|>" + turn[1])
    parts.append("<|SOH|>" + message + "<|SOB|>")
    give_mod = "".join(parts)
    print(give_mod)

    response = ""
    for token in raw_pred(give_mod, model, tokenizer):
        if "<|Clean|>" in token:
            # The final token carries the fully re-decoded response;
            # replace the incrementally accumulated text with it.
            response = token
        else:
            response += token
        # Hide internal markers from the partial reply shown to the user.
        yield response.replace("<|SOH|>", "").replace("<|Clean|>", "")
    print(response)
# Gradio chat UI wired to the streaming `respond` handler.
demo = gr.ChatInterface(
    fn=respond,  # streaming response generator
    title="Chat with Esther",
    description="A friendly chatbot ready to help and chat with you! π",
    theme="compact",  # options: "compact", "default", "dark"
)

# Launch only when executed as a script. The previous unconditional
# demo.launch() duplicated this guarded call, starting the server twice
# when the module ran as __main__ (and on any import of the module).
if __name__ == "__main__":
    demo.launch()