# ChatNiti — a simple Gradio chatbot demo (Hugging Face Space).
from functools import lru_cache

import gradio as gr
import torch
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer
# Load pre-trained model and tokenizer
@lru_cache(maxsize=1)
def _load_model_uncached(model_name):
    """Load tokenizer + model and move the model to GPU when available.

    Cached so the multi-GB checkpoint is read from disk only once per
    process instead of on every chat message. `lru_cache` does not cache
    raised exceptions, so a transient failure is retried on the next call.
    """
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    model = AutoModelForSeq2SeqLM.from_pretrained(model_name)
    device = "cuda" if torch.cuda.is_available() else "cpu"
    model = model.to(device)
    return tokenizer, model, device


def load_model(model_name):
    """Return (tokenizer, model, device) for *model_name*.

    Args:
        model_name: Hugging Face model identifier, e.g. "facebook/mbart-large-50".

    Returns:
        (tokenizer, model, device) on success, or (None, None, None) when
        loading fails (bad name, no network, out of memory, ...).
    """
    try:
        return _load_model_uncached(model_name)
    except Exception as e:
        # Broad catch is deliberate: this is the UI boundary, and the caller
        # shows a friendly apology message when (None, None, None) comes back.
        print(f"Error loading model: {e}")
        return None, None, None
# Function to generate chat responses
def chat_with_niti(message, history):
    """Generate one reply for the Gradio ChatInterface.

    Args:
        message: the user's latest message (str).
        history: prior chat turns supplied by Gradio; currently unused.

    Returns:
        ChatNiti's reply as a plain string (always returns, never raises).
    """
    tokenizer, model, device = load_model("facebook/mbart-large-50")
    if tokenizer is None or model is None:
        return "Sorry, I'm having trouble loading the model. Please try again later."
    try:
        # Add a prompt for better responses
        prompt = f"User: {message}\nChatNiti:"
        input_ids = tokenizer.encode(prompt, return_tensors="pt").to(device)
        output = model.generate(
            input_ids,
            max_length=100,
            # do_sample is required for temperature to take effect; with the
            # default greedy decoding the temperature setting is silently ignored.
            do_sample=True,
            temperature=0.7,
            num_return_sequences=1,
            pad_token_id=tokenizer.eos_token_id,
        )
        response = tokenizer.decode(output[0], skip_special_tokens=True)
        # Keep only the text after the final "ChatNiti:" marker.
        return response.split("ChatNiti:")[-1].strip()
    except Exception as e:
        # UI boundary: never crash the chat widget on a generation error.
        print(f"Error generating response: {e}")
        return "Sorry, I encountered an error while generating a response."
# Create Gradio chat interface
demo = gr.ChatInterface(
    fn=chat_with_niti,
    title="ChatNiti - Your AI Chatbot",
    description="Ask ChatNiti anything in Hindi, Hinglish, or English!",
)

# Launch the interface only when executed as a script, so importing this
# module (e.g. from tests or another app) does not start a web server.
if __name__ == "__main__":
    demo.launch()