import gradio as gr
from transformers import AutoModelForCausalLM, AutoTokenizer, GenerationConfig
import torch
# Model initialization
repo_name = "BeardedMonster/SabiYarn-125M-translate"  # Model repository
tokenizer_name = "BeardedMonster/SabiYarn-125M"  # Tokenizer repository

# Load the model and tokenizer
model = AutoModelForCausalLM.from_pretrained(repo_name, trust_remote_code=True)
tokenizer = AutoTokenizer.from_pretrained(tokenizer_name, trust_remote_code=True)
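
# Some causal-LM tokenizers ship without a pad token, which would make the
# padding=True tokenizer call below fail. Falling back to the EOS token is a
# common workaround (an assumption about this tokenizer, not something the
# model card confirms):
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token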
# Move the model to GPU if available
device = "cuda" if torch.cuda.is_available() else "cpu"
model.to(device)
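
# This app only runs inference, so switch off dropout and other train-time
# behavior; model.eval() is a standard PyTorch step, not specific to this model.
model.eval()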
# Define the generation configuration
generation_config = GenerationConfig(
    max_length=150,          # Upper bound on total sequence length (prompt + output)
    max_new_tokens=50,       # Cap on generated tokens; takes precedence over max_length
    num_beams=5,             # Moderate beam count for a balance between speed and quality
    do_sample=False,         # Disable sampling so output is deterministic
    temperature=1.0,         # Ignored when do_sample=False
    top_k=None,              # Ignored when do_sample=False
    top_p=None,              # Ignored when do_sample=False
    repetition_penalty=4.0,  # Strong penalty to discourage repeated phrases
    length_penalty=3.0,      # Strongly favors longer beams; lower it if translations run long
    early_stopping=True      # Stop once all beams finish, to speed up generation
)
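
# Note: transformers gives max_new_tokens precedence over max_length when both
# are set. As an alternative to passing the config on every generate() call,
# it can also be attached to the model once:
#   model.generation_config = generation_config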
def generate_text(prompt, language):
    # Wrap the prompt in the model's translation control tags
    tagged_prompt = f"<translate> <{language.lower()}> {prompt} <translate>"

    # Tokenize
    inputs = tokenizer(tagged_prompt, return_tensors="pt", padding=True, truncation=True).to(device)

    # Debug output
    print(f"Tagged Prompt: {tagged_prompt}")
    print(f"Inputs: {inputs}")
    print(f"Input IDs shape: {inputs['input_ids'].shape}")
    print(f"Attention Mask shape: {inputs['attention_mask'].shape}")
    # Generate, passing the config directly instead of re-spelling each field
    try:
        outputs = model.generate(**inputs, generation_config=generation_config)
        # Decode and return the generated text
        return tokenizer.decode(outputs[0], skip_special_tokens=True)
    except Exception as e:
        print(f"Error during generation: {e}")
        return "An error occurred during text generation."
# Create the Gradio interface
iface = gr.Interface(
    fn=generate_text,
    inputs=[
        gr.Textbox(label="Enter your prompt"),
        gr.Dropdown(choices=["yor", "ibo", "hau", "efi", "pcm", "urh"], label="Select Language")
    ],
    outputs=gr.Textbox(label="Generated Text"),
    title="Nigerian Language Generator",
    description="Generate text in Yoruba, Igbo, Hausa, Efik, Pidgin, or Urhobo using the Sabi Yarn model."
)

if __name__ == "__main__":
    iface.launch(share=True)  # share=True creates a public link
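
# --- Optional: querying the running Space from another process ---
# A minimal sketch using gradio_client; "user/space-name" is a placeholder for
# wherever this app is actually hosted, not a confirmed Space id. gr.Interface
# exposes its default endpoint as "/predict".
#
#   from gradio_client import Client
#   client = Client("user/space-name")
#   result = client.predict("How are you?", "yor", api_name="/predict")
#   print(result)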