# text/app.py
# Author: Nikhil0987 — last commit a640ab4 "Update app.py" (verified)
import streamlit as st
import torch
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    DistilBertForSequenceClassification,
    DistilBertTokenizer,
)
# Load the model and tokenizer.
# A chatbot needs a *causal language model*: DistilBertForSequenceClassification
# is a classification head and cannot produce free text via `generate`.
# distilgpt2 keeps the small, distilled-model footprint of the original choice.
MODEL_NAME = "distilgpt2"


@st.cache_resource  # Load the weights once and reuse them across Streamlit reruns.
def load_model(model_name: str = MODEL_NAME):
    """Load and return ``(tokenizer, model)`` for a causal language model.

    Args:
        model_name: Hugging Face model identifier; defaults to ``MODEL_NAME``.

    Returns:
        A ``(tokenizer, model)`` tuple, with the model switched to eval mode.
    """
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    model = AutoModelForCausalLM.from_pretrained(model_name)
    model.eval()  # Inference only: disable dropout.
    return tokenizer, model


tokenizer, model = load_model()

# Input/Output areas
st.title("Simple Chatbot")
user_input = st.text_input("Enter your message:")

# Preprocess and generate a response when the user hits Enter.
message = user_input.strip() if user_input else ""
if message:
    if message.lower() == "quit":
        st.stop()

    # Encode the user input. Calling the tokenizer directly (rather than
    # `encode`) also yields the attention mask that `generate` expects.
    inputs = tokenizer(message, return_tensors="pt")
    prompt_length = inputs["input_ids"].shape[-1]

    # Generate a response (adjust parameters for control).
    output_sequences = model.generate(
        **inputs,
        # Counts only new tokens; `max_length` would include the prompt and
        # leave long prompts with little or no room for a reply.
        max_new_tokens=50,
        do_sample=True,  # Without sampling, `temperature` is ignored.
        temperature=0.8,  # Controls creativity: lower = more deterministic.
        # GPT-2 style tokenizers define no pad token; reuse EOS for padding.
        pad_token_id=tokenizer.eos_token_id,
    )

    # Decode only the newly generated tokens so the prompt is not echoed back.
    generated_text = tokenizer.decode(
        output_sequences[0][prompt_length:], skip_special_tokens=True
    )
    st.write(generated_text)