import streamlit as st
from transformers import GraphormerForGraphClassification
from datasets import Dataset
from transformers.models.graphormer.collating_graphormer import preprocess_item, GraphormerDataCollator
import torch
import networkx as nx
import matplotlib.pyplot as plt
from collections import Counter
@st.cache_resource  # cache the loaded model across Streamlit reruns
def load_model():
    # Graphormer has no tokenizer; graph inputs are built below with preprocess_item / GraphormerDataCollator.
    model = GraphormerForGraphClassification.from_pretrained(
        "clefourrier/pcqm4mv2_graphormer_base",  # Graphormer base checkpoint pretrained on PCQM4Mv2
        num_classes=2,  # binary classification (positive/negative sentiment)
        ignore_mismatched_sizes=True,  # replace the pretrained head with a freshly initialized 2-class head
    )
    model.eval()
    return model
def text_to_graph(text):
    words = text.split()
    G = nx.Graph()
    for i, word in enumerate(words):
        G.add_node(i, word=word)
        if i > 0:
            G.add_edge(i - 1, i)
    edge_index = [
        [e[0] for e in G.edges()] + [e[1] for e in G.edges()],
        [e[1] for e in G.edges()] + [e[0] for e in G.edges()],
    ]
    return {
        "edge_index": edge_index,
        "num_nodes": len(G.nodes()),
        "node_feat": [[ord(word[0])] for word in words],  # use ASCII value of first letter as feature
        "edge_attr": [[1] for _ in range(len(G.edges()) * 2)],  # all edges have the same attribute
        "y": [1],  # placeholder label, ignored during inference
    }
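# text_to_graph builds a chain graph (one node per word, edges between consecutive words,
# each edge listed in both directions). For example, text_to_graph("great movie") returns:
#   {"edge_index": [[0, 1], [1, 0]],   # the single edge 0-1, in both directions
#    "num_nodes": 2,
#    "node_feat": [[103], [109]],      # ord('g'), ord('m')
#    "edge_attr": [[1], [1]],
#    "y": [1]}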
def analyze_text(text, model):
    graph = text_to_graph(text)
    # Wrap the single graph in a Dataset, add Graphormer's structural encodings with
    # preprocess_item, then collate into the tensors the model's forward pass expects
    # (input_nodes, input_edges, attn_bias, spatial_pos, in_degree, out_degree, ...).
    dataset = Dataset.from_list([graph])
    dataset_processed = dataset.map(preprocess_item, batched=False)
    inputs = GraphormerDataCollator()(list(dataset_processed))
    inputs = {k: v.to(model.device) for k, v in inputs.items()}
    with torch.no_grad():
        outputs = model(**inputs)
    logits = outputs.logits
    probabilities = torch.softmax(logits, dim=1)
    sentiment = "Positive" if probabilities[0][1] > probabilities[0][0] else "Negative"
    confidence = probabilities[0][1].item() if sentiment == "Positive" else probabilities[0][0].item()
    return sentiment, confidence, graph
st.title("Graph-based Text Analysis")
model = load_model()
text_input = st.text_area("Enter text for analysis:", height=200)
if st.button("Analyze Text"):
    if text_input:
        sentiment, confidence, graph = analyze_text(text_input, model)
        st.write(f"Sentiment: {sentiment}")
        st.write(f"Confidence: {confidence:.2f}")

        # Additional analysis
        word_count = len(text_input.split())
        st.write(f"Word count: {word_count}")

        # Most common words
        words = [word.lower() for word in text_input.split() if word.isalnum()]
        word_freq = Counter(words).most_common(5)
        st.write("Top 5 most common words:")
        for word, freq in word_freq:
            st.write(f"- {word}: {freq}")

        # Visualize graph
        G = nx.Graph()
        G.add_edges_from(zip(graph["edge_index"][0], graph["edge_index"][1]))
        fig = plt.figure(figsize=(10, 6))
        nx.draw(G, with_labels=False, node_size=30, node_color='lightblue', edge_color='gray')
        plt.title("Text as Graph")
        st.pyplot(fig)  # pass the figure object explicitly rather than the pyplot module
    else:
        st.write("Please enter some text to analyze.")