Spaces:
Sleeping
Sleeping
import streamlit as st | |
import pandas as pd | |
from transformers import AutoTokenizer, AutoModel | |
import torch | |
import graphrag | |
def load_model(bert_model_name="bert-base-uncased", num_labels=2):
    """Load a BERT tokenizer/encoder and wrap the encoder in a GraphRAG model.

    Args:
        bert_model_name: Hugging Face model id (or local path) used for both the
            tokenizer and the encoder. Defaults to ``"bert-base-uncased"``.
        num_labels: Number of output labels passed to ``graphrag.GraphRAG``;
            adjust to your task. Defaults to 2.

    Returns:
        A ``(tokenizer, model)`` tuple.
    """
    tokenizer = AutoTokenizer.from_pretrained(bert_model_name)
    bert_model = AutoModel.from_pretrained(bert_model_name)

    # NOTE(review): confirm that the installed `graphrag` package really exposes
    # a `GraphRAG` class accepting these arguments.
    model = graphrag.GraphRAG(
        bert_model,
        num_labels=num_labels,
        num_hidden_layers=2,
        # Take the sizes from the encoder's own config so they always match it
        # (768 / 3072 for bert-base-uncased) instead of hard-coding them.
        hidden_size=bert_model.config.hidden_size,
        intermediate_size=bert_model.config.intermediate_size,
    )

    # If you have a pre-trained GraphRAG checkpoint, load it here, e.g.:
    # model.load_state_dict(torch.load("path_to_your_model.pth"))

    # The model is only used for inference (under torch.no_grad), so turn off
    # dropout and other training-only behaviour when it is a torch module.
    if isinstance(model, torch.nn.Module):
        model.eval()
    return tokenizer, model
def process_text(text, tokenizer, model):
    """Run ``model`` on a single text and return its class probabilities.

    Args:
        text: Input string; the tokenizer truncates it to 512 tokens.
        tokenizer: Callable (e.g. a Hugging Face tokenizer) that returns a
            mapping of PyTorch tensors when called with ``return_tensors="pt"``.
        model: Callable that accepts the tokenizer's outputs as keyword arguments.

    Returns:
        The softmax (over the last axis) of the first row of the model's logits,
        converted to plain Python numbers.
    """
    inputs = tokenizer(text, return_tensors="pt", truncation=True, max_length=512)
    with torch.no_grad():
        outputs = model(**inputs)

    # Accept the usual output forms: an object with `.logits`, a tuple/list
    # whose first element is the logits, or a bare tensor.
    # This is a placeholder; adjust according to your model's output.
    if hasattr(outputs, "logits"):
        logits = outputs.logits
    elif isinstance(outputs, (tuple, list)):
        logits = outputs[0]
    else:
        logits = outputs

    # Softmax over the last axis, i.e. the label axis for (batch, num_labels)
    # logits; identical to dim=1 in that case, but also works for 1-D logits.
    probabilities = torch.softmax(logits, dim=-1)
    return probabilities[0].tolist()
st.title("Graphrag Text Analysis")

# Streamlit re-executes this whole script on every widget interaction; wrapping
# load_model in st.cache_resource loads the tokenizer/model once and reuses them
# across reruns instead of reloading BERT on every click.
tokenizer, model = st.cache_resource(load_model)()

# File uploader
uploaded_file = st.file_uploader("Choose a CSV file", type="csv")
data = None
if uploaded_file is not None:
    try:
        data = pd.read_csv(uploaded_file)
    except (pd.errors.ParserError, pd.errors.EmptyDataError, UnicodeDecodeError) as err:
        st.error(f"Could not read the uploaded CSV: {err}")

if data is not None:
    st.write(data.head())
    if st.button("Process Data"):
        if "text" not in data.columns:
            st.error("The CSV must contain a 'text' column.")
        else:
            # Empty cells come back from pandas as NaN (a float), which the
            # tokenizer cannot process; such rows get no result (None).
            data["results"] = [
                None if pd.isna(text) else process_text(str(text), tokenizer, model)
                for text in data["text"]
            ]
            st.write(data)

# Text input for single prediction
text_input = st.text_area("Enter text for analysis:")
if st.button("Analyze Text"):
    # Treat whitespace-only input the same as empty input.
    if text_input.strip():
        result = process_text(text_input, tokenizer, model)
        st.write(f"Analysis Result: {result}")
    else:
        st.write("Please enter some text to analyze.")

# Link to sample data.
# TODO: replace the your_username/your_repo placeholder with the real repository.
st.markdown("[Download Sample CSV](https://raw.githubusercontent.com/your_username/your_repo/main/sample_data.csv)")