import streamlit as st
import pandas as pd
import torch
from transformers import AutoTokenizer

from graphrag.models import GraphragForSequenceClassification
from graphrag.configuration_graphrag import GraphragConfig


@st.cache_resource
def load_model():
    """Load the tokenizer and a Graphrag classification model once per session."""
    bert_model_name = "bert-base-uncased"
    tokenizer = AutoTokenizer.from_pretrained(bert_model_name)

    config = GraphragConfig.from_pretrained(bert_model_name)
    config.num_labels = 2  # Adjust to the number of classes in your task
    model = GraphragForSequenceClassification(config)

    # If you have pre-trained Graphrag weights, load them here:
    # model.load_state_dict(torch.load("path_to_your_model.pth"))
    model.eval()
    return tokenizer, model


def process_text(text, tokenizer, model):
    """Tokenize a single string and return class probabilities."""
    inputs = tokenizer(text, return_tensors="pt", truncation=True, max_length=512)
    with torch.no_grad():
        outputs = model(**inputs)

    # Post-process the outputs for your specific task; here we assume
    # sequence classification and return softmax probabilities.
    logits = outputs.logits
    probabilities = torch.softmax(logits, dim=-1)
    return probabilities.tolist()[0]


st.title("Graphrag Text Analysis")

tokenizer, model = load_model()

# File uploader for batch analysis
uploaded_file = st.file_uploader("Choose a CSV file", type="csv")
if uploaded_file is not None:
    data = pd.read_csv(uploaded_file)
    st.write(data.head())

    if st.button("Process Data"):
        results = []
        for text in data["text"]:  # Assumes the CSV has a 'text' column
            results.append(process_text(text, tokenizer, model))
        data["results"] = results
        st.write(data)

# Text input for a single prediction
text_input = st.text_area("Enter text for analysis:")
if st.button("Analyze Text"):
    if text_input:
        result = process_text(text_input, tokenizer, model)
        st.write(f"Analysis Result: {result}")
    else:
        st.write("Please enter some text to analyze.")

# Link to sample data
st.markdown("[Download Sample CSV](https://raw.githubusercontent.com/your_username/your_repo/main/sample_data.csv)")
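
# Optional sketch (not part of the original script): instead of linking to a
# placeholder GitHub URL, the app could generate and serve a small sample CSV
# itself via st.download_button. The 'text' column name matches what the
# "Process Data" loop above expects; the example sentences are purely illustrative.
sample_csv = pd.DataFrame(
    {"text": ["Graph-based retrieval improves grounding.",
              "This is a second example sentence."]}
).to_csv(index=False)
st.download_button(
    "Download generated sample CSV",
    data=sample_csv,
    file_name="sample_data.csv",
    mime="text/csv",
)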