Spaces:
Sleeping
Sleeping
File size: 1,989 Bytes
41f73cb ee795ac 41f73cb eb6725a ee795ac eb6725a ee795ac eb6725a ee795ac eb6725a 41f73cb ee795ac 41f73cb ee795ac |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 |
import streamlit as st
import pandas as pd
from transformers import AutoTokenizer, BertForSequenceClassification, BertConfig
import torch
@st.cache_resource
def load_model():
    """Load the BERT tokenizer and a two-label sequence classifier.

    Both objects come from the ``bert-base-uncased`` checkpoint. The function
    is wrapped in ``st.cache_resource`` so Streamlit can reuse the loaded
    objects across script reruns instead of loading them again.

    Returns:
        A ``(tokenizer, model)`` tuple.
    """
    checkpoint = "bert-base-uncased"
    text_tokenizer = AutoTokenizer.from_pretrained(checkpoint)

    # Number of output classes; adjust to match the task.
    model_config = BertConfig.from_pretrained(checkpoint)
    model_config.num_labels = 2

    classifier = BertForSequenceClassification.from_pretrained(
        checkpoint, config=model_config
    )
    # To use fine-tuned weights, load them into the classifier here, e.g.
    # classifier.load_state_dict(torch.load('path_to_your_model.pth'))
    return text_tokenizer, classifier
def process_text(text, tokenizer, model):
    """Classify one piece of text and return its class probabilities.

    Args:
        text: The text to classify. Non-string values are tolerated:
            ``None`` and float NaN (what pandas yields for an empty CSV
            cell) are treated as an empty string, and any other value is
            converted with ``str()``.
        tokenizer: Callable that turns text into model inputs. It is called
            with ``return_tensors="pt"``, ``truncation=True`` and
            ``max_length=512``.
        model: Callable that accepts the tokenizer output as keyword
            arguments and returns an object exposing a ``logits`` tensor.

    Returns:
        A list of floats: the softmax (over dim 1) of the logits for the
        first row of the batch.
    """
    if not isinstance(text, str):
        # NaN is the only float that compares unequal to itself.
        is_missing = text is None or (isinstance(text, float) and text != text)
        text = "" if is_missing else str(text)

    inputs = tokenizer(text, return_tensors="pt", truncation=True, max_length=512)

    # Inference only: skip gradient tracking.
    with torch.no_grad():
        outputs = model(**inputs)

    probabilities = torch.softmax(outputs.logits, dim=1)
    return probabilities.tolist()[0]
# Streamlit page: batch classification of an uploaded CSV plus a free-text
# box for single predictions.
st.title("BERT Text Analysis")
tokenizer, model = load_model()

# File uploader
uploaded_file = st.file_uploader("Choose a CSV file", type="csv")
if uploaded_file is not None:
    data = pd.read_csv(uploaded_file)
    st.write(data.head())
    if st.button("Process Data"):
        if "text" not in data.columns:
            # Report the problem in the UI instead of failing with a KeyError.
            st.error("The uploaded CSV must contain a 'text' column.")
        else:
            # Empty cells are read as NaN and numeric cells as numbers;
            # normalise everything to strings before tokenizing.
            texts = data["text"].fillna("").astype(str)
            data["results"] = [
                process_text(text, tokenizer, model) for text in texts
            ]
            st.write(data)

# Text input for single prediction
text_input = st.text_area("Enter text for analysis:")
if st.button("Analyze Text"):
    if text_input:
        result = process_text(text_input, tokenizer, model)
        st.write(f"Analysis Result: {result}")
    else:
        st.write("Please enter some text to analyze.")

# Add a link to sample data
st.markdown("[Download Sample CSV](https://raw.githubusercontent.com/your_username/your_repo/main/sample_data.csv)")
|