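# NOTE: captured from a hosted "Spaces" page whose status read "Runtime error";
# this line records that page residue and is not part of the program.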
# -*- coding: utf-8 -*-
"""Simple distilgpt2 text-generation chatbot with an optional persona, served via Gradio.

Exported from the Colab notebook Untitled6.ipynb (automatically generated by Colab).

Original file is located at
    https://colab.research.google.com/drive/1F6f_vJbssO7C2FM6FILWljFYacDmbVBY
"""
# Hugging Face auto-classes used to load a causal language model and its tokenizer.
from transformers import AutoModelForCausalLM, AutoTokenizer

# "distilgpt2" is a small, CPU-friendly checkpoint. The model and tokenizer
# are loaded once at import time and shared module-wide by the functions
# that follow.
model_name = "distilgpt2"
model = AutoModelForCausalLM.from_pretrained(model_name)
tokenizer = AutoTokenizer.from_pretrained(model_name)
def generate_response(prompt, *, max_new_tokens=50, temperature=0.7, top_p=0.9):
    """Sample a continuation of ``prompt`` from the module-level ``model``.

    Args:
        prompt: Text to continue. It must tokenize to at least one token.
        max_new_tokens: Maximum number of tokens to generate *after* the
            prompt. The budget is independent of the prompt's length.
            ``max_length`` would count the prompt too, so long or
            persona-prefixed prompts would leave little or no room.
        temperature: Sampling temperature; lower values are less random.
        top_p: Nucleus-sampling cumulative-probability cutoff.

    Returns:
        The decoded full sequence, with special tokens removed. This is the
        prompt text followed by the sampled continuation.

    Raises:
        ValueError: If ``prompt`` produces no tokens (e.g. the empty string).
    """
    inputs = tokenizer(prompt, return_tensors="pt")
    if inputs["input_ids"].shape[-1] == 0:
        raise ValueError("prompt must contain at least one token")

    outputs = model.generate(
        # Forward both input_ids and attention_mask. Because pad_token_id is
        # set to the EOS id below, generate() cannot reliably infer the mask
        # on its own.
        **inputs,
        max_new_tokens=max_new_tokens,
        do_sample=True,  # Sample instead of greedy decoding.
        temperature=temperature,
        top_p=top_p,
        # Reuse EOS as the pad id (the tokenizer presumably defines no pad
        # token of its own — GPT-2 style).
        pad_token_id=tokenizer.eos_token_id,
    )
    # clean_up_tokenization_spaces is passed explicitly to avoid a warning.
    return tokenizer.decode(
        outputs[0], skip_special_tokens=True, clean_up_tokenization_spaces=True
    )
# Example usage: complete one sample sentence when the script starts and
# print the result, which doubles as a quick check that loading worked.
prompt = "I went to Safeway and I bought a"
response = generate_response(prompt=prompt)
print(response)
def persona_response(prompt, persona="I am a helpful assistant"):
    """Generate a reply to ``prompt`` prefixed with a persona sentence.

    The persona becomes its own sentence in front of the prompt, e.g.
    ``"I am a helpful assistant. <prompt>"``. The combined text is passed to
    ``generate_response``.

    Args:
        prompt: The user's message.
        persona: Sentence describing the speaker. Surrounding whitespace is
            stripped. A period is appended only when the persona does not
            already end in ``.``, ``!`` or ``?``. Unconditionally adding one
            turned ``"I am a shopping assistant."`` into ``"...assistant.. "``.
            A blank or ``None`` persona is omitted rather than leaving a
            stray leading ``". "``.

    Returns:
        The result of ``generate_response`` for the combined prompt.
    """
    persona = (persona or "").strip()
    if persona and not persona.endswith((".", "!", "?")):
        persona += "."
    full_prompt = f"{persona} {prompt}" if persona else prompt
    return generate_response(full_prompt)
# Gradio provides the web UI built below.
import gradio as gr

# Persona used when the caller supplies none (or only whitespace).
DEFAULT_PERSONA = "I am a helpful assistant"


def chat_interface(user_input, persona=DEFAULT_PERSONA):
    """Gradio callback: answer ``user_input`` in the voice of ``persona``.

    Gradio passes every input component's value positionally. An untouched
    textbox therefore arrives as an empty string, and the keyword default
    alone never takes effect from the UI. A blank persona (empty,
    whitespace-only, or ``None``) is explicitly replaced with
    ``DEFAULT_PERSONA``.

    Args:
        user_input: The user's message.
        persona: Optional persona sentence.

    Returns:
        The reply text produced by ``persona_response``.
    """
    if not persona or not persona.strip():
        persona = DEFAULT_PERSONA
    return persona_response(user_input, persona)
# Build the web UI. Explicit, labelled textboxes replace the bare "text"
# shortcuts. The persona box is pre-filled with the default persona, so an
# untouched persona field does not send an empty persona to the model.
interface = gr.Interface(
    fn=chat_interface,
    inputs=[
        gr.Textbox(label="Message", placeholder="Type something to chat with the bot"),
        gr.Textbox(label="Persona", value="I am a helpful assistant"),
    ],
    outputs=gr.Textbox(label="Response"),
    title="Simple Chatbot",
    description="Type something to chat with the bot! Add a persona to change its style, like 'I am a shopping assistant.'",
)

# Launch the app. share=True requests a temporary public link, which is how
# the UI is reached from Colab.
# NOTE(review): Gradio reportedly ignores share=True (with a warning) when
# running on Hugging Face Spaces — confirm if this is deployed there.
interface.launch(share=True)