# Hugging Face Space page banner captured with the source (not code): "Spaces: Sleeping"
import re

import gradio as gr
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
# Load the OpenAssistant SFT model (Pythia-12B based) and its tokenizer.
MODEL_NAME = "OpenAssistant/oasst-sft-1-pythia-12b"
# Choose the device *before* loading so the weight dtype can match it.
device = "cuda" if torch.cuda.is_available() else "cpu"
# float16 halves GPU memory, but half precision on CPU is slow and, in older
# PyTorch versions, unsupported for core ops -- the CPU fallback could then
# fail inside generate(). Use float32 when no GPU is present.
# NOTE(review): 12B parameters in float32 is roughly 48 GB of RAM -- confirm
# the CPU path is actually viable on the target hardware.
MODEL_DTYPE = torch.float16 if device == "cuda" else torch.float32
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
model = AutoModelForCausalLM.from_pretrained(MODEL_NAME, torch_dtype=MODEL_DTYPE)
model.to(device)
# Instruction text placed in front of every prompt sent to the model; it tells
# the model how to answer questions about its creator / CEO.
SYSTEM_PROMPT = (
    "NORTHERN_AI is an AI assistant. If asked about who created it or who is the CEO,\n"
    "it should respond that it was created by AR.BALTEE who is also the CEO."
)
# Phrases that identify a question about who built or owns the assistant.
# Matched case-insensitively on word boundaries so that words which merely
# contain a keyword (e.g. "herbaceous" contains "ceo") do not trigger the
# canned answer.
_CREATOR_QUESTION_RE = re.compile(
    r"\b(?:who made you|who created you|creators?|ceo|who owns)\b",
    re.IGNORECASE,
)

# Fixed reply for creator/CEO questions; returned without running the model.
_CREATOR_ANSWER = "I was created by AR.BALTEE, who is also the CEO of NORTHERN_AI."


def get_ai_response(message, *, max_new_tokens=150):
    """Return NORTHERN_AI's reply to a single user message.

    Creator/CEO questions get a fixed answer without touching the model.
    Any other message is wrapped in ``SYSTEM_PROMPT`` using a "User:"/"AI:"
    turn format and a reply is sampled from the module-level ``model`` on
    ``device``.

    Args:
        message: The user's message text.
        max_new_tokens: Upper bound on the number of tokens generated for the
            reply. This budget excludes the prompt, so long user messages do
            not shrink (or exhaust) the space left for the answer.

    Returns:
        The reply text. On any error the error is printed and a generic
        apology string is returned instead; this function does not raise.
    """
    try:
        if _CREATOR_QUESTION_RE.search(message):
            return _CREATOR_ANSWER

        input_text = f"{SYSTEM_PROMPT}\n\nUser: {message}\nAI:"
        inputs = tokenizer(input_text, return_tensors="pt").to(device)
        input_ids = inputs["input_ids"]

        with torch.no_grad():
            outputs = model.generate(
                input_ids=input_ids,
                # Pass the mask explicitly rather than letting generate() infer it.
                attention_mask=inputs.get("attention_mask"),
                max_new_tokens=max_new_tokens,
                temperature=0.7,
                top_p=0.9,
                do_sample=True,
                pad_token_id=tokenizer.eos_token_id,
            )

        # outputs[0] begins with the prompt tokens; decode only what was generated.
        generated = outputs[0][input_ids.shape[-1]:]
        response = tokenizer.decode(generated, skip_special_tokens=True)
        # The model may keep writing further transcript turns ("User: ...",
        # "AI: ..."); keep only the first reply instead of the last one.
        response = response.split("\nUser:", 1)[0]
        return response.strip()
    except Exception as e:
        print(f"Error generating response: {e}")
        return "Sorry, I encountered an error while generating a response. Please try again."
# Custom CSS for the chat UI, passed to gr.Blocks(css=...) below.
# The #ids and .classes target the elem_id / elem_classes / HTML ids given to
# the components in the interface definition: header-container, logo, title,
# chatbot, footer, textbox, button. `.gradio-container` presumably targets
# Gradio's outer app wrapper -- confirm against the installed Gradio version.
# `!important` is used throughout so these rules win over Gradio's theme CSS.
css = """
.gradio-container {
max-width: 800px !important;
margin: 0 auto !important;
background: linear-gradient(135deg, #f0f4f8, #d9e2ec) !important;
padding: 20px !important;
border-radius: 15px !important;
box-shadow: 0 4px 6px rgba(0, 0, 0, 0.1) !important;
}
#header-container {
display: flex !important;
align-items: center !important;
margin-bottom: 1.5rem !important;
background-color: transparent !important;
padding: 0.5rem 1rem !important;
}
#logo {
background-color: #0066ff !important;
color: white !important;
border-radius: 50% !important;
width: 40px !important;
height: 40px !important;
display: flex !important;
align-items: center !important;
justify-content: center !important;
font-weight: bold !important;
margin-right: 10px !important;
font-size: 20px !important;
}
#title {
margin: 0 !important;
font-size: 24px !important;
font-weight: 600 !important;
color: #333 !important;
}
#chatbot {
background-color: white !important;
border-radius: 15px !important;
padding: 20px !important;
box-shadow: 0 2px 4px rgba(0, 0, 0, 0.1) !important;
height: 400px !important;
overflow-y: auto !important;
}
#footer {
font-size: 12px !important;
color: #666 !important;
text-align: center !important;
margin-top: 1.5rem !important;
padding: 0.5rem !important;
}
.textbox {
border-radius: 15px !important;
border: 1px solid #ddd !important;
padding: 10px !important;
font-size: 14px !important;
width: 100% !important;
}
.button {
background-color: #0066ff !important;
color: white !important;
border-radius: 15px !important;
padding: 10px 20px !important;
font-size: 14px !important;
border: none !important;
cursor: pointer !important;
transition: background-color 0.3s ease !important;
}
.button:hover {
background-color: #0052cc !important;
}
"""
# Build the Gradio interface: header, chat display, input row, footer.
with gr.Blocks(css=css) as demo:
    with gr.Column():
        # Header: round "N" logo badge next to the app title.
        with gr.Row(elem_id="header-container"):
            gr.HTML('<div id="logo">N</div>')
            gr.HTML('<h1 id="title">NORTHERN_AI</h1>')

        # Conversation display and the message input row.
        chatbot = gr.Chatbot(elem_id="chatbot")
        with gr.Row():
            msg = gr.Textbox(
                placeholder="Message NORTHERN_AI...",
                show_label=False,
                container=False,
                elem_classes="textbox",
            )
            submit_btn = gr.Button("Send", elem_classes="button")

        gr.HTML('<div id="footer">Powered by open-source technology</div>')

    # Per-session conversation history: a list of (user_message, bot_reply) tuples.
    state = gr.State([])

    def respond(message, chat_history):
        """Handle one user turn and return the updated UI values.

        Args:
            message: Text from the input box.
            chat_history: This session's history list, read from ``state``.

        Returns:
            A 3-tuple ``("", history, history)``: clears the textbox and sends
            the updated history to both the chatbot display and ``state``.
        """
        # Ignore empty and whitespace-only submissions.
        if not message or not message.strip():
            return "", chat_history, chat_history

        chat_history.append((message, None))
        try:
            bot_message = get_ai_response(message)
            chat_history[-1] = (message, bot_message)
        except Exception as e:
            # Last-resort guard; get_ai_response normally handles its own errors.
            print(f"Error generating response: {e}")
            # Drop the unanswered turn so no (message, None) entry is left behind.
            chat_history.pop()
        return "", chat_history, chat_history

    # Pressing Enter and clicking "Send" do the same thing. `state` is listed
    # as an output so its update is explicit rather than relying on the list
    # being mutated in place.
    msg.submit(respond, [msg, state], [msg, chatbot, state])
    submit_btn.click(respond, [msg, state], [msg, chatbot, state])

# Start the server only when this file is executed as a script.
if __name__ == "__main__":
    demo.launch()