import gradio as gr
from transformers import AutoModelForCausalLM, AutoTokenizer
import torch
import logging

# Configure logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Model and tokenizer setup
def setup_model_and_tokenizer():
    logger.info("Loading model and tokenizer...")
    model_name = "umairrrkhan/english-text-generation"
    model = AutoModelForCausalLM.from_pretrained(model_name)
    tokenizer = AutoTokenizer.from_pretrained(model_name)

    # Ensure pad_token is set
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token
    if model.config.pad_token_id is None:
        model.config.pad_token_id = tokenizer.pad_token_id

    logger.info("Model and tokenizer loaded successfully.")
    return model, tokenizer

model, tokenizer = setup_model_and_tokenizer()
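
# Optional, and not part of the original app (a hedged addition): the model is only
# used for inference here, so putting it in eval mode disables dropout. This relies
# only on the standard torch.nn.Module API that transformers models inherit.
model.eval()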
# Define text generation function
def generate_text(prompt):
    logger.info(f"Received prompt: {prompt}")
    inputs = tokenizer(prompt, return_tensors="pt", padding=True, truncation=True)

    try:
        logger.info("Generating text...")
        outputs = model.generate(
            inputs['input_ids'],
            attention_mask=inputs['attention_mask'],
            max_length=50,
            do_sample=True,
            temperature=0.7,
            top_k=50,
        )
        response = tokenizer.decode(outputs[0], skip_special_tokens=True)
        logger.info(f"Generated response: {response}")
        return response
    except Exception as e:
        logger.error(f"Error during text generation: {e}")
        return "An error occurred during text generation."
# Create Gradio interface
iface = gr.Interface(
    fn=generate_text,
    inputs="text",
    outputs="text",
    title="AI Text Generation Chatbot",
    description="Lowkey curious? Type a prompt and see what it generates!",
    examples=["Tell me a story about a robot.", "Write a poem about the moon."]
)
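
# Optional (a hedged suggestion, not in the original file): gr.Interface exposes a
# queue() method that serializes incoming requests, which can help when generation is
# slow on a CPU-only Space. Uncomment to enable; whether it is needed here is an
# assumption on my part.
# iface.queue()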
# Launch the interface
if __name__ == "__main__":
    logger.info("Launching Gradio interface...")
    iface.launch(debug=True, server_name="0.0.0.0", server_port=7860)
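
# Querying the running app from another process (a hedged sketch, not part of the
# original Space): gradio's companion `gradio_client` package can call the default
# "/predict" endpoint that gr.Interface exposes. Run something like the snippet below
# from a separate script or shell while the app above is running; the local URL
# follows from the launch() settings, and the prompt text is arbitrary.
#
#   from gradio_client import Client
#
#   client = Client("http://127.0.0.1:7860/")
#   result = client.predict("Tell me a story about a robot.", api_name="/predict")
#   print(result)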