# pages/21_GraphRag.py
import streamlit as st
from transformers import GraphormerForGraphClassification
from datasets import Dataset
from transformers.models.graphormer.collating_graphormer import preprocess_item, GraphormerDataCollator
import torch
import networkx as nx
import matplotlib.pyplot as plt
from collections import Counter
@st.cache_resource
def load_model():
    model = GraphormerForGraphClassification.from_pretrained(
        "clefourrier/pcqm4mv2_graphormer_base",
        num_classes=2,  # Binary classification (positive/negative sentiment)
        ignore_mismatched_sizes=True,
    )
    model.eval()  # inference only: disable dropout
    return model
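# Note: clefourrier/pcqm4mv2_graphormer_base was pretrained on the PCQM4Mv2
# molecular-graph benchmark; with num_classes=2 and ignore_mismatched_sizes=True,
# from_pretrained drops the original head and initializes a fresh two-class
# classification head (no sentiment fine-tuning happens in this script).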
def text_to_graph(text):
    # Build a simple chain graph: one node per word, an edge between consecutive words.
    words = text.split()
    G = nx.Graph()
    for i, word in enumerate(words):
        G.add_node(i, word=word)
        if i > 0:
            G.add_edge(i - 1, i)

    # Encode edges in both directions, in the (2, num_edges) layout expected by
    # Graphormer's preprocessing.
    edge_index = [
        [e[0] for e in G.edges()] + [e[1] for e in G.edges()],
        [e[1] for e in G.edges()] + [e[0] for e in G.edges()],
    ]

    return {
        "edge_index": edge_index,
        "num_nodes": len(G.nodes()),
        "node_feat": [[ord(word[0])] for word in words],  # ASCII value of first letter as node feature
        "edge_attr": [[1] for _ in range(len(G.edges()) * 2)],  # all edges share the same attribute
        "y": [1],  # placeholder label, ignored during inference
    }
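# For illustration, the hypothetical input "I love graphs" would yield:
#   {"edge_index": [[0, 1, 1, 2], [1, 2, 0, 1]],   # edges 0-1 and 1-2, both directions
#    "num_nodes": 3,
#    "node_feat": [[73], [108], [103]],            # ord('I'), ord('l'), ord('g')
#    "edge_attr": [[1], [1], [1], [1]],
#    "y": [1]}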
def analyze_text(text, model):
    graph = text_to_graph(text)
    # Wrap the single graph in a one-row Dataset so it can reuse Graphormer's
    # standard preprocessing (preprocess_item) and batching (GraphormerDataCollator).
    dataset = Dataset.from_dict({key: [value] for key, value in graph.items()})
    dataset_processed = dataset.map(preprocess_item, batched=False)
    inputs = GraphormerDataCollator()([dataset_processed[0]])
    inputs = {k: v.to(model.device) for k, v in inputs.items()}
    with torch.no_grad():
        outputs = model(**inputs)
    logits = outputs.logits
    probabilities = torch.softmax(logits, dim=1)
    # Treat class index 1 as "Positive" and index 0 as "Negative".
    sentiment = "Positive" if probabilities[0][1] > probabilities[0][0] else "Negative"
    confidence = probabilities[0][1].item() if sentiment == "Positive" else probabilities[0][0].item()
    return sentiment, confidence, graph
st.title("Graph-based Text Analysis")
model = load_model()
text_input = st.text_area("Enter text for analysis:", height=200)
if st.button("Analyze Text"):
    if text_input:
        sentiment, confidence, graph = analyze_text(text_input, model)
        st.write(f"Sentiment: {sentiment}")
        st.write(f"Confidence: {confidence:.2f}")
        # Additional analysis
        word_count = len(text_input.split())
        st.write(f"Word count: {word_count}")

        # Most common words (tokens containing punctuation are skipped by isalnum)
        words = [word.lower() for word in text_input.split() if word.isalnum()]
        word_freq = Counter(words).most_common(5)
        st.write("Top 5 most common words:")
        for word, freq in word_freq:
            st.write(f"- {word}: {freq}")
        # Visualize the word-chain graph returned by analyze_text
        G = nx.Graph()
        G.add_edges_from(zip(graph["edge_index"][0], graph["edge_index"][1]))
        fig, ax = plt.subplots(figsize=(10, 6))
        nx.draw(G, ax=ax, with_labels=False, node_size=30, node_color='lightblue', edge_color='gray')
        ax.set_title("Text as Graph")
        st.pyplot(fig)  # pass the Figure explicitly instead of the pyplot module
    else:
        st.write("Please enter some text to analyze.")