Spaces:
Sleeping
Sleeping
import streamlit as st | |
import graphrag | |
import networkx as nx | |
import matplotlib.pyplot as plt | |
from sentence_transformers import SentenceTransformer | |
import torch | |
import nltk | |
from nltk.tokenize import sent_tokenize, word_tokenize | |
# sent_tokenize needs NLTK's Punkt sentence-tokenizer data. Recent NLTK
# releases load it from the 'punkt_tab' package (the pickled 'punkt' data was
# retired for security reasons), while older releases use 'punkt'. Fetch both
# so tokenization works with either version; quiet=True suppresses the
# download progress output.
nltk.download('punkt', quiet=True)
nltk.download('punkt_tab', quiet=True)
@st.cache_resource
def load_models(model_name='all-MiniLM-L6-v2'):
    """Load and return the SentenceTransformer used for sentence embeddings.

    Streamlit re-executes this whole script on every widget interaction, so
    the function is wrapped in ``st.cache_resource``. The model is built once
    per ``model_name`` and then shared across reruns, instead of being
    reconstructed each time the user clicks a button.

    Args:
        model_name: Name of the SentenceTransformer model to load. Defaults
            to ``'all-MiniLM-L6-v2'``, the model this app was written for.

    Returns:
        The loaded ``SentenceTransformer`` instance.
    """
    return SentenceTransformer(model_name)
def text_to_graph(text, sentence_model, threshold=0.5):
    """Build an undirected sentence-similarity graph from ``text``.

    Each sentence returned by NLTK's ``sent_tokenize`` becomes a node keyed by
    its 0-based position. Each node carries two attributes: ``text`` (the
    sentence) and ``embedding`` (its row from ``sentence_model.encode``).
    Two sentences are joined by an edge when their cosine similarity is
    strictly greater than ``threshold``; the edge's ``weight`` is that
    similarity as a Python float.

    Args:
        text: Raw input text.
        sentence_model: Embedding model with an ``encode`` method. It is
            expected to map a list of sentences to a 2-D array with one row
            per sentence, which is SentenceTransformer's default behaviour.
        threshold: Cosine-similarity cut-off (exclusive) for adding an edge.
            Defaults to 0.5.

    Returns:
        tuple: ``(G, sentences)``, the ``networkx.Graph`` and the list of
        sentence strings, where node ``i`` corresponds to ``sentences[i]``.
    """
    sentences = sent_tokenize(text)
    G = nx.Graph()
    if not sentences:
        # Nothing to embed; skip calling encode() on an empty batch.
        return G, sentences

    # One batched encode call instead of one model invocation per sentence.
    embeddings = sentence_model.encode(sentences)
    for i, sentence in enumerate(sentences):
        G.add_node(i, text=sentence, embedding=embeddings[i])

    # Compute all pairwise cosine similarities at once. Normalise each
    # embedding to unit length; then sim[i, j] is the dot product of
    # rows i and j.
    vectors = torch.nn.functional.normalize(torch.as_tensor(embeddings), dim=1)
    sim = vectors @ vectors.T

    # Keep only the strict upper triangle (diagonal=1). The graph is
    # undirected, so (j, i) would duplicate (i, j), and the diagonal only
    # compares a sentence with itself. nonzero() yields pairs in row-major
    # order, which matches the original nested-loop order.
    above = torch.triu(sim > threshold, diagonal=1)
    for i, j in above.nonzero().tolist():
        G.add_edge(i, j, weight=sim[i, j].item())
    return G, sentences
def analyze_text(text, sentence_model, top_k=3):
    """Build the sentence graph for ``text`` and compute summary statistics.

    Args:
        text: Raw input text.
        sentence_model: Sentence-embedding model, passed through to
            ``text_to_graph``.
        top_k: Maximum number of top-ranked sentence indices to return.
            Defaults to 3.

    Returns:
        tuple: ``(G, sentences, num_nodes, num_edges, avg_degree,
        important_sentences)``.

        - ``avg_degree`` is the mean node degree, or 0.0 when the text
          contains no sentences.
        - ``important_sentences`` lists up to ``top_k`` node indices, sorted
          by descending PageRank score.
    """
    G, sentences = text_to_graph(text, sentence_model)

    num_nodes = G.number_of_nodes()
    num_edges = G.number_of_edges()
    # The degree sum of an undirected graph is 2 * |E|. Guard the division:
    # text with no sentences (e.g. only whitespace) would otherwise raise
    # ZeroDivisionError.
    avg_degree = 2 * num_edges / num_nodes if num_nodes else 0.0

    # Rank sentences by PageRank. networkx weights edges by their 'weight'
    # attribute by default, i.e. by the similarity scores. An empty graph
    # yields {} and therefore an empty list.
    pagerank = nx.pagerank(G)
    important_sentences = sorted(pagerank, key=pagerank.get, reverse=True)[:top_k]

    return G, sentences, num_nodes, num_edges, avg_degree, important_sentences
# --- Streamlit UI -----------------------------------------------------------
# Streamlit runs this top-level code on every interaction (rerun).
st.title("GraphRAG-based Text Analysis")

sentence_model = load_models()

text_input = st.text_area("Enter text for analysis:", height=200)

if st.button("Analyze Text"):
    # Use strip() so whitespace-only input is rejected here; such input
    # contains no sentences and would produce an empty graph.
    if text_input.strip():
        (G, sentences, num_nodes, num_edges,
         avg_degree, important_sentences) = analyze_text(text_input, sentence_model)

        st.write(f"Number of sentences: {num_nodes}")
        st.write(f"Number of connections: {num_edges}")
        st.write(f"Average connections per sentence: {avg_degree:.2f}")

        st.subheader("Most important sentences:")
        for i in important_sentences:
            st.write(f"- {sentences[i]}")

        # Visualize the graph on an explicit figure so it can be passed to
        # st.pyplot and then closed. pyplot keeps every figure it creates
        # open until it is closed, so without plt.close the figures would
        # accumulate across reruns.
        fig, ax = plt.subplots(figsize=(10, 6))
        # A fixed seed keeps the layout stable for the same input text.
        pos = nx.spring_layout(G, seed=0)
        nx.draw(G, pos, ax=ax, with_labels=False, node_size=30,
                node_color='lightblue', edge_color='gray')
        ax.set_title("Text as Graph")
        st.pyplot(fig)
        plt.close(fig)
    else:
        st.write("Please enter some text to analyze.")