text / app.py
Nikhil0987's picture
Update app.py
99bddf0 verified
raw
history blame
1.09 kB
import streamlit as st
import torch
from transformers import DistilBertTokenizer, DistilBertForSequenceClassification
# Load the model and tokenizer
@st.cache_resource  # Cache model for efficiency
def load_model(model_name: str = "distilbert-base-uncased"):
    """Load the DistilBERT tokenizer and sequence-classification model.

    ``st.cache_resource`` caches the result keyed on the arguments, so the
    weights are loaded once per distinct ``model_name`` rather than on every
    Streamlit rerun.

    Args:
        model_name: Hugging Face Hub id or local path of the checkpoint.
            The default, ``"distilbert-base-uncased"``, is a *base*
            checkpoint. Its sequence-classification head is not fine-tuned
            (transformers initializes it fresh), so its predictions are not
            meaningful. Pass a fine-tuned classification checkpoint for real
            results.

    Returns:
        A ``(tokenizer, model)`` tuple.
    """
    tokenizer = DistilBertTokenizer.from_pretrained(model_name)
    model = DistilBertForSequenceClassification.from_pretrained(model_name)
    return tokenizer, model
tokenizer, model = load_model()

# Input/Output areas
st.title("Simple Chatbot")
user_input = st.text_input("Enter your message:")

# Classify the message and show the result when the user hits Enter.
# Streamlit reruns this whole script on every interaction. An empty input
# (first load or a cleared box) skips inference.
if user_input:
    if user_input.strip().lower() == "quit":
        st.stop()  # Halts this script run; nothing below is rendered.

    # truncation=True clips the input to the tokenizer's maximum model length,
    # so very long messages don't fail inside the model.
    encoded_input = tokenizer(user_input, return_tensors="pt", truncation=True)

    # Inference only: disable autograd so no gradient graph is built.
    with torch.no_grad():
        outputs = model(**encoded_input)

    logits = outputs.logits
    probabilities = torch.softmax(logits, dim=-1)
    predicted_class_id = logits.argmax(-1).item()
    confidence = probabilities[0, predicted_class_id].item()

    # Label names come from the checkpoint's config rather than a hand-written
    # map, so they always match the model's actual number of classes.
    label = model.config.id2label.get(predicted_class_id, str(predicted_class_id))
    st.write(f"Predicted class: {label} (confidence {confidence:.2%})")