TensorFlowClass / pages /21_GraphRag.py
eaglelandsonce's picture
Update pages/21_GraphRag.py
4a2750f verified
raw
history blame
2.17 kB
import streamlit as st
import pandas as pd
from transformers import AutoTokenizer, AutoModel
import torch
import graphrag
@st.cache_resource
def load_model():
    """Create the tokenizer and Graphrag model used by this page.

    Decorated with ``st.cache_resource`` so the objects are built once and
    reused instead of being reloaded on every Streamlit rerun.

    Returns:
        A ``(tokenizer, model)`` tuple: the ``bert-base-uncased`` tokenizer and
        a ``graphrag.GraphRAG`` instance wrapping the matching BERT encoder.
    """
    checkpoint = "bert-base-uncased"
    bert_tokenizer = AutoTokenizer.from_pretrained(checkpoint)
    encoder = AutoModel.from_pretrained(checkpoint)

    # Architecture settings for the Graphrag head; num_labels should match
    # the task. hidden_size / intermediate_size mirror bert-base dimensions.
    # NOTE(review): confirm the installed `graphrag` package actually exposes
    # GraphRAG with this constructor signature.
    head_config = dict(
        num_labels=2,
        num_hidden_layers=2,
        hidden_size=768,
        intermediate_size=3072,
    )
    graph_model = graphrag.GraphRAG(encoder, **head_config)

    # To use pre-trained Graphrag weights, load them at this point, e.g.:
    #     graph_model.load_state_dict(torch.load("path_to_your_model.pth"))
    return bert_tokenizer, graph_model
def process_text(text, tokenizer, model):
    """Run *model* on *text* and return softmax probabilities for the first item.

    Args:
        text: Input string to analyze.
        tokenizer: Callable (e.g. a Hugging Face tokenizer) that returns a
            mapping of PyTorch tensors, passed to ``model`` as keyword args.
        model: Callable whose result is one of: an object with a ``logits``
            attribute, a tuple/list whose first element is the logits tensor,
            or the logits tensor itself.

    Returns:
        The softmax over the last axis of the first batch element, converted
        with ``Tensor.tolist()`` (a list of floats for 2-D logits).
    """
    inputs = tokenizer(text, return_tensors="pt", truncation=True, max_length=512)
    with torch.no_grad():
        outputs = model(**inputs)

    # Placeholder mapping from common model output formats to a logits
    # tensor; adjust according to the wrapped model's actual output.
    if hasattr(outputs, "logits"):
        logits = outputs.logits
    elif isinstance(outputs, (tuple, list)):
        # Models called with return_dict=False return a plain tuple whose
        # first element is the primary output.
        logits = outputs[0]
    else:
        logits = outputs

    # Normalize over the last (class) axis: identical to dim=1 for
    # (batch, num_labels) logits, but correct for other ranks too.
    probabilities = torch.softmax(logits, dim=-1)
    return probabilities.tolist()[0]
st.title("Graphrag Text Analysis")

tokenizer, model = load_model()

# File uploader for batch analysis of a CSV that has a 'text' column.
uploaded_file = st.file_uploader("Choose a CSV file", type="csv")
if uploaded_file is not None:
    try:
        data = pd.read_csv(uploaded_file)
    except (pd.errors.EmptyDataError, pd.errors.ParserError, UnicodeDecodeError) as err:
        # Report unreadable uploads in the page instead of crashing the whole
        # script, which would also hide the single-text form below.
        st.error(f"Could not read the CSV file: {err}")
    else:
        st.write(data.head())
        if "text" not in data.columns:
            st.error("The CSV file must contain a 'text' column.")
        elif st.button("Process Data"):
            # Empty cells are read as NaN, which is not valid tokenizer text;
            # leave those rows without a result instead of aborting the batch.
            data["results"] = [
                None if pd.isna(text) else process_text(str(text), tokenizer, model)
                for text in data["text"]
            ]
            st.write(data)

# Text input for single prediction.
text_input = st.text_area("Enter text for analysis:")
if st.button("Analyze Text"):
    # Whitespace-only input is treated the same as empty input.
    if text_input.strip():
        result = process_text(text_input, tokenizer, model)
        st.write(f"Analysis Result: {result}")
    else:
        st.write("Please enter some text to analyze.")

# Link to sample data.
# TODO: the URL below is a placeholder; point it at the real sample CSV.
st.markdown("[Download Sample CSV](https://raw.githubusercontent.com/your_username/your_repo/main/sample_data.csv)")