# Source: TensorFlowClass / pages/21_GraphRag.py
# Uploaded by eaglelandsonce; commit 41f73cb (verified): "Update pages/21_GraphRag.py"
# File size: 2.06 kB
# put code here
import streamlit as st
import pandas as pd
from transformers import AutoTokenizer, AutoModel
from graphrag import GraphragModel, GraphragConfig
import torch
@st.cache_resource
def load_model():
    """Create the tokenizer and Graphrag model used by this page.

    Loads the ``bert-base-uncased`` tokenizer and encoder through the
    ``transformers`` auto classes, hands the encoder to a ``GraphragConfig``
    and builds a ``GraphragModel`` from that config. The function is wrapped
    in ``st.cache_resource``.

    Returns:
        tuple: ``(tokenizer, model)``.
    """
    checkpoint = "bert-base-uncased"
    tokenizer = AutoTokenizer.from_pretrained(checkpoint)
    encoder = AutoModel.from_pretrained(checkpoint)

    # NOTE(review): only the BERT encoder is loaded from a pretrained
    # checkpoint here; GraphragModel itself is built from the config alone.
    # Confirm whether its remaining weights need to be loaded separately.
    config_kwargs = {
        "bert_model": encoder,
        "num_labels": 2,  # Adjust based on your task
        "num_hidden_layers": 2,
        "hidden_size": 768,
        "intermediate_size": 3072,
    }
    return tokenizer, GraphragModel(GraphragConfig(**config_kwargs))
def process_text(text, tokenizer, model):
    """Run one piece of text through the model and return its probabilities.

    Args:
        text: The text to analyse. Values that are not strings are accepted
            so that cells read from a CSV can be passed in directly: ``None``
            and float NaN (what pandas produces for an empty cell) become the
            empty string, anything else is converted with ``str()``.
        tokenizer: Callable invoked as
            ``tokenizer(text, return_tensors="pt", truncation=True,
            max_length=512)``; its result is unpacked into the model call.
        model: Callable invoked as ``model(**inputs)``; the result must
            expose a ``logits`` tensor.

    Returns:
        list: Softmax of ``logits`` along dimension 1, for the first (only)
        row of the batch, as plain Python floats.
    """
    # The tokenizer only accepts strings; normalise CSV artefacts first so a
    # single blank or numeric cell does not abort a whole batch run.
    if text is None or (isinstance(text, float) and text != text):
        # ``text != text`` is true only for NaN.
        text = ""
    elif not isinstance(text, str):
        text = str(text)

    inputs = tokenizer(text, return_tensors="pt", truncation=True, max_length=512)

    # Inference only: skip gradient tracking.
    with torch.no_grad():
        outputs = model(**inputs)

    # Process outputs based on your specific task.
    # This is a placeholder; adjust according to your model's output.
    logits = outputs.logits
    probabilities = torch.softmax(logits, dim=1)
    return probabilities.tolist()[0]
st.title("Graphrag Text Analysis")

# Obtain the tokenizer/model pair before rendering the widgets that use it.
tokenizer, model = load_model()

# File uploader: batch mode, one prediction per row of the uploaded CSV.
uploaded_file = st.file_uploader("Choose a CSV file", type="csv")
if uploaded_file is not None:
    data = pd.read_csv(uploaded_file)
    st.write(data.head())

    # The batch run reads the 'text' column; report a missing column to the
    # user instead of letting the lookup fail with a KeyError traceback.
    if "text" not in data.columns:
        st.error("The uploaded CSV must contain a 'text' column.")
    elif st.button("Process Data"):
        data["results"] = [
            process_text(text, tokenizer, model) for text in data["text"]
        ]
        st.write(data)
# Single-prediction mode: analyse whatever is typed into the text area.
text_input = st.text_area("Enter text for analysis:")
analyze_clicked = st.button("Analyze Text")
if analyze_clicked and not text_input:
    # Button pressed with an empty box: prompt instead of calling the model.
    st.write("Please enter some text to analyze.")
elif analyze_clicked:
    result = process_text(text_input, tokenizer, model)
    st.write(f"Analysis Result: {result}")
# Add a link to sample data
# NOTE(review): the URL below still contains the "your_username/your_repo"
# template placeholders — TODO point it at the real sample file.
st.markdown("[Download Sample CSV](https://raw.githubusercontent.com/your_username/your_repo/main/sample_data.csv)")