TensorFlowClass / pages /21_GraphRag.py
eaglelandsonce's picture
Update pages/21_GraphRag.py
2ead64f verified
raw
history blame
2.79 kB
import streamlit as st
import graphrag
import networkx as nx
import matplotlib.pyplot as plt
from sentence_transformers import SentenceTransformer
import torch
import nltk
from nltk.tokenize import sent_tokenize, word_tokenize
# Fetch the sentence-tokenizer data used by ``sent_tokenize``. NLTK >= 3.9
# loads the "punkt_tab" resource rather than the older pickled "punkt" one,
# so download both to keep the page working across NLTK versions.
for _punkt_resource in ("punkt", "punkt_tab"):
    nltk.download(_punkt_resource, quiet=True)
@st.cache_resource
def load_models():
    """Return the SentenceTransformer used to embed individual sentences.

    The ``st.cache_resource`` decorator makes Streamlit keep the returned
    model across reruns instead of reloading it on every interaction.
    """
    model_name = 'all-MiniLM-L6-v2'
    return SentenceTransformer(model_name)
def text_to_graph(text, sentence_model, threshold=0.5):
    """Build an undirected sentence-similarity graph from ``text``.

    Each sentence produced by ``sent_tokenize`` becomes a node keyed by its
    position, with attributes ``text`` (the sentence string) and
    ``embedding`` (its vector from ``sentence_model.encode``). Two sentences
    are connected by an edge, weighted by their cosine similarity, when that
    similarity is strictly greater than ``threshold``.

    Args:
        text: Raw input text.
        sentence_model: Encoder with an ``encode`` method (here a
            ``SentenceTransformer``).
        threshold: Exclusive lower bound on cosine similarity for adding an
            edge. Defaults to 0.5, the value that used to be hard-coded.

    Returns:
        ``(G, sentences)``: the ``networkx.Graph`` and the list of sentence
        strings, where node ``i`` corresponds to ``sentences[i]``.
    """
    sentences = sent_tokenize(text)
    G = nx.Graph()
    if not sentences:
        # Nothing to embed; avoid calling encode() on an empty batch.
        return G, sentences

    # One batched encode call instead of one model call per sentence.
    embeddings = sentence_model.encode(sentences)
    for i, (sentence, embedding) in enumerate(zip(sentences, embeddings)):
        G.add_node(i, text=sentence, embedding=embedding)

    # All pairwise cosine similarities at once: L2-normalise each row, then a
    # single matrix product replaces building two tensors for every pair.
    vectors = torch.nn.functional.normalize(
        torch.as_tensor(embeddings, dtype=torch.float32), dim=1
    )
    similarity = vectors @ vectors.T

    # Visit only pairs i < j (strict upper triangle), in row-major order, so
    # edges are inserted in the same order as a nested i/j loop would.
    n = len(sentences)
    rows, cols = torch.triu_indices(n, n, offset=1)
    scores = similarity[rows, cols]
    keep = scores > threshold
    for i, j, weight in zip(
        rows[keep].tolist(), cols[keep].tolist(), scores[keep].tolist()
    ):
        G.add_edge(i, j, weight=weight)
    return G, sentences
def analyze_text(text, sentence_model, top_k=3):
    """Build the sentence graph for ``text`` and compute summary statistics.

    Args:
        text: Raw input text.
        sentence_model: Encoder passed through to ``text_to_graph``.
        top_k: Number of highest-ranked sentence indices to return. Defaults
            to 3, the value that used to be hard-coded.

    Returns:
        ``(G, sentences, num_nodes, num_edges, avg_degree, important_sentences)``.
        ``avg_degree`` is the mean number of connections per sentence (0.0
        when there are no sentences). ``important_sentences`` holds the
        indices of the ``top_k`` nodes with the highest PageRank, best first.
    """
    G, sentences = text_to_graph(text, sentence_model)

    num_nodes = G.number_of_nodes()
    num_edges = G.number_of_edges()
    # Each edge contributes 1 to the degree of both endpoints, so the degree
    # sum equals 2 * num_edges. Text that tokenizes to no sentences (e.g.
    # whitespace only) gives an empty graph; avoid dividing by zero there.
    avg_degree = 2 * num_edges / num_nodes if num_nodes else 0.0

    # Rank sentences by PageRank centrality (nx.pagerank uses the 'weight'
    # edge attribute by default and returns {} for an empty graph).
    pagerank = nx.pagerank(G)
    important_sentences = sorted(pagerank, key=pagerank.get, reverse=True)[:top_k]
    return G, sentences, num_nodes, num_edges, avg_degree, important_sentences
st.title("GraphRAG-based Text Analysis")
sentence_model = load_models()
text_input = st.text_area("Enter text for analysis:", height=200)

if st.button("Analyze Text"):
    # Whitespace-only input is truthy but contains no sentences, so check the
    # stripped text before running the analysis.
    if text_input and text_input.strip():
        G, sentences, num_nodes, num_edges, avg_degree, important_sentences = analyze_text(text_input, sentence_model)

        st.write(f"Number of sentences: {num_nodes}")
        st.write(f"Number of connections: {num_edges}")
        st.write(f"Average connections per sentence: {avg_degree:.2f}")

        st.subheader("Most important sentences:")
        for i in important_sentences:
            st.write(f"- {sentences[i]}")

        # Visualize graph on an explicit Figure. Passing the figure to
        # st.pyplot avoids the deprecated global-pyplot path, and closing it
        # afterwards keeps figures from piling up across Streamlit reruns.
        fig, ax = plt.subplots(figsize=(10, 6))
        # Fixed seed so the same text produces the same layout on every rerun.
        pos = nx.spring_layout(G, seed=42)
        nx.draw(G, pos, ax=ax, with_labels=False, node_size=30, node_color='lightblue', edge_color='gray')
        ax.set_title("Text as Graph")
        st.pyplot(fig)
        plt.close(fig)
    else:
        st.write("Please enter some text to analyze.")