import os

import gradio as gr
import tensorflow as tf
from huggingface_hub import login
from transformers import AutoTokenizer, TFAutoModelForCausalLM

# Use bfloat16 mixed precision and a multi-worker mirrored strategy for inference
policy = tf.keras.mixed_precision.Policy('mixed_bfloat16')
tf.keras.mixed_precision.set_global_policy(policy)
strategy = tf.distribute.MultiWorkerMirroredStrategy()

# Authenticate with the Hugging Face Hub using a token stored in the environment
login(os.environ.get("hf_token"))

name = "WICKED4950/GPT2-InstEsther0.25eV3.1"
tokenizer = AutoTokenizer.from_pretrained(name)
tokenizer.pad_token = tokenizer.eos_token  # GPT-2 has no pad token, so reuse EOS

with strategy.scope():
    model = TFAutoModelForCausalLM.from_pretrained(name)


def raw_pred(prompt, model, tokenizer, max_length=1024, temperature=0.2):
    """Generate a reply one token at a time, yielding each decoded token.

    When the <|SOH|> stop token is produced, the full decoded reply is yielded
    once more with a "<|Clean|>" prefix so the caller can replace its partial text.
    """
    input_ids = tokenizer.encode(prompt, return_tensors='tf')

    generated_ids = input_ids
    stop_token_id = tokenizer.encode("<|SOH|>", add_special_tokens=False)[0]  # ID for <|SOH|>
    all_generated_tokens = []  # IDs of every token generated so far

    with strategy.scope():
        for _ in range(max_length):  # Generate up to max_length tokens, one at a time
            outputs = model.generate(
                generated_ids,
                max_length=generated_ids.shape[1] + 1,  # Allow exactly one new token
                temperature=temperature,
                pad_token_id=tokenizer.eos_token_id,
                eos_token_id=stop_token_id,  # Stop generation at <|SOH|>
                do_sample=True,
                num_return_sequences=1,
            )

            # The newly generated token is the last position of the output sequence
            new_tokens = outputs[0, -1:]
            generated_ids = outputs  # Feed the extended sequence back in on the next step

            # Store the generated token ID
            all_generated_tokens.extend(new_tokens.numpy().tolist())

            # Decode and yield the new token as it is generated
            tokens_text = tokenizer.decode(new_tokens, skip_special_tokens=False)
            yield tokens_text

            # Stop once the model emits the <|SOH|> stop token
            if stop_token_id in new_tokens.numpy():
                final_text = tokenizer.decode(all_generated_tokens, skip_special_tokens=False)
                yield "<|Clean|>" + final_text
                break


def respond(message, history):
    # Build the model prompt from the last three (user, bot) turns.
    # <|SOH|> marks the start of a human turn and <|SOB|> the start of a bot turn.
    give_mod = ""
    history = history[-3:]
    for user_msg, bot_msg in history:
        give_mod += "<|SOH|>" + user_msg + "<|SOB|>" + bot_msg
    give_mod += "<|SOH|>" + message + "<|SOB|>"
    print(give_mod)

    # Stream the reply token by token; once the cleaned full reply arrives,
    # replace the accumulated partial text with it.
    response = ""
    for token in raw_pred(give_mod, model, tokenizer):
        if "<|Clean|>" in token:
            response = token
        else:
            response += token
        yield response.replace("<|SOH|>", "").replace("<|Clean|>", "")
    print(response)


# Gradio chat interface setup
demo = gr.ChatInterface(
    fn=respond,  # Streaming response handler
    title="Chat with Esther",
    description="A friendly chatbot ready to help and chat with you! 😊",
    theme="compact",  # UI theme
)

# Launch the interface
if __name__ == "__main__":
    demo.launch()