import gradio as gr
from transformers import AutoModelForCausalLM, AutoTokenizer, GenerationConfig
import torch
# Model initialization
repo_name = "BeardedMonster/SabiYarn-125M-translate" # Model repository
tokenizer_name = "BeardedMonster/SabiYarn-125M" # Tokenizer repository
# Load the model and tokenizer
model = AutoModelForCausalLM.from_pretrained(repo_name, trust_remote_code=True)
tokenizer = AutoTokenizer.from_pretrained(tokenizer_name, trust_remote_code=True)
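# Note: trust_remote_code=True executes custom modeling/tokenizer code shipped in
# the Hub repositories above; only enable it for repositories you trust.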
# Move model to GPU if available
device = "cuda" if torch.cuda.is_available() else "cpu"
model.to(device)
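# Optional: switch to eval mode for inference (disables dropout and other
# training-time behavior).
model.eval()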
# Define generation configuration with a maximum length
generation_config = GenerationConfig(
    max_length=150,           # Maximum total length (prompt + generated tokens)
    max_new_tokens=50,        # Cap on newly generated tokens for translations
    num_beams=5,              # Moderate beam count: balance between speed and quality
    do_sample=False,          # Disable sampling so output is deterministic
    temperature=1.0,          # Neutral temperature since sampling is off
    top_k=None,               # Unused for deterministic (beam search) generation
    top_p=None,               # Unused for deterministic (beam search) generation
    repetition_penalty=4.0,   # Strong penalty to discourage repeated phrases in translations
    length_penalty=3.0,       # >1.0 favors longer beams; lower it if translations run too long
    early_stopping=True       # Stop once all beams finish to speed up generation
)
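# Note: with do_sample=False, the temperature/top_k/top_p settings above have no
# effect; decoding is driven by beam search (num_beams) and the penalties.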
def generate_text(prompt, language):
    # Wrap the prompt in the model's translation tags for the target language
    tagged_prompt = f"<translate> <{language.lower()}> {prompt} <translate>"

    # Tokenize and move tensors to the same device as the model
    inputs = tokenizer(tagged_prompt, return_tensors="pt", padding=True, truncation=True).to(device)

    print(f"Tagged Prompt: {tagged_prompt}")
    print(f"Inputs: {inputs}")
    print(f"Input IDs shape: {inputs['input_ids'].shape}")
    print(f"Attention Mask shape: {inputs['attention_mask'].shape}")

    # Generate with the settings defined in generation_config
    try:
        outputs = model.generate(
            **inputs,
            max_length=generation_config.max_length,
            num_beams=generation_config.num_beams,
            do_sample=generation_config.do_sample,
            temperature=generation_config.temperature,
            top_k=generation_config.top_k,
            top_p=generation_config.top_p,
            repetition_penalty=generation_config.repetition_penalty,
            length_penalty=generation_config.length_penalty,
            early_stopping=generation_config.early_stopping
        )
        # Decode and return the generated text
        return tokenizer.decode(outputs[0], skip_special_tokens=True)
    except Exception as e:
        print(f"Error during generation: {e}")
        return "An error occurred during text generation."
# Create Gradio interface
iface = gr.Interface(
    fn=generate_text,
    inputs=[
        gr.Textbox(label="Enter your prompt"),
        gr.Dropdown(choices=["yor", "ibo", "hau", "efi", "pcm", "urh"], label="Select Language")
    ],
    outputs=gr.Textbox(label="Generated Text"),
    title="Nigerian Language Generator",
    description="Generate text in Yoruba, Igbo, Hausa, Efik, Pidgin, or Urhobo using the Sabi Yarn model."
)
if __name__ == "__main__":
    iface.launch(share=True)  # Set share=True to create a public link