# Author: Sandy0909
# Commit 450ed84: "Update app.py" (1.03 kB)
import streamlit as st
from transformers import BertTokenizer, BertForSequenceClassification
import torch
BERT_PATH = "ahmedrachid/FinancialBERT"
MODEL_PATH = "Sandy0909/finance_sentiment"
MAX_LEN = 512


@st.cache_resource
def _load_tokenizer_and_model():
    """Build the tokenizer and the fine-tuned classifier once and reuse them.

    Streamlit re-executes this whole script on every widget interaction
    (e.g. each click of "Predict"). Without caching, the tokenizer and the
    BERT model would be rebuilt and the weights reloaded on every rerun.
    ``st.cache_resource`` keeps a single shared instance instead.

    Returns:
        A ``(tokenizer, model)`` tuple; the model is on CPU in eval mode.
    """
    tok = BertTokenizer.from_pretrained(BERT_PATH)
    # Start from the FinancialBERT base with a 3-label classification head,
    # then replace its weights with the fine-tuned state dict below.
    clf = BertForSequenceClassification.from_pretrained(BERT_PATH, num_labels=3)
    # NOTE(review): torch.load expects a local file path, but MODEL_PATH looks
    # like a Hugging Face Hub repo id; this only works if a file exists at that
    # relative path -- confirm. torch.load also unpickles the file, so the
    # checkpoint must come from a trusted source.
    state_dict = torch.load(MODEL_PATH, map_location=torch.device('cpu'))
    clf.load_state_dict(state_dict)
    clf.eval()  # inference mode (e.g. disables dropout)
    return tok, clf


# Initialize tokenizer and model (cached across Streamlit reruns)
tokenizer, model = _load_tokenizer_and_model()
def predict_sentiment(sentence):
    """Classify *sentence* and return its sentiment label as a string.

    The text is tokenized (truncated to ``MAX_LEN`` tokens). The module-level
    ``model`` is then run without gradient tracking, and the label of the
    highest-scoring logit in the first (only) row is returned.

    Returns:
        One of 'negative', 'neutral' or 'positive', for class index 0, 1 or 2.
        NOTE(review): assumes the checkpoint was trained with this label
        order -- confirm.
    """
    encoded = tokenizer(
        sentence,
        return_tensors="pt",
        truncation=True,
        padding=True,
        max_length=MAX_LEN,
    )
    with torch.no_grad():
        output = model(**encoded)
    best_index = output.logits.argmax(dim=-1)[0].item()
    labels = ('negative', 'neutral', 'positive')
    return labels[best_index]
# Streamlit app: read a sentence from the user and show its predicted sentiment.
st.title("Financial Sentiment Analysis")
sentence = st.text_area("Enter a financial sentence:", "")
if st.button("Predict"):
    if not sentence.strip():
        # predict_sentiment always returns one of the three labels, even for
        # blank input, so predicting here would show a meaningless result.
        st.warning("Please enter a sentence to analyze.")
    else:
        sentiment = predict_sentiment(sentence)
        st.write(f"The predicted sentiment is: {sentiment}")