import streamlit as st
from transformers import AutoTokenizer, AutoModel
import torch
import networkx as nx
import matplotlib.pyplot as plt
from collections import Counter
import graphrag  # Import the graphrag library


@st.cache_resource
def load_model():
    tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
    bert_model = AutoModel.from_pretrained("bert-base-uncased")
    # Initialize GraphRAG model
    # Note: You may need to adjust these parameters based on GraphRAG's actual interface
    graph_rag_model = graphrag.GraphRAG(
        bert_model,
        num_labels=2,  # For binary sentiment classification
        num_hidden_layers=2,
        hidden_size=768,
        intermediate_size=3072,
    )
    return tokenizer, graph_rag_model


def text_to_graph(text):
    # Build a simple chain graph: one node per word, edges between adjacent words
    words = text.split()
    G = nx.Graph()
    for i, word in enumerate(words):
        G.add_node(i, word=word)
        if i > 0:
            G.add_edge(i - 1, i)
    # Bidirectional edge index in COO format (both directions of each undirected edge)
    edge_index = [
        [e[0] for e in G.edges()] + [e[1] for e in G.edges()],
        [e[1] for e in G.edges()] + [e[0] for e in G.edges()],
    ]
    return {
        "edge_index": edge_index,
        "num_nodes": len(G.nodes()),
        "node_feat": [[ord(word[0])] for word in words],  # Use ASCII value of first letter as feature
        "edge_attr": [[1] for _ in range(len(G.edges()) * 2)],  # All edges share the same attribute
    }


def analyze_text(text, tokenizer, model):
    # Tokenize the text
    inputs = tokenizer(text, return_tensors="pt", truncation=True, max_length=512)

    # Create graph representation
    graph = text_to_graph(text)

    # Combine tokenized input with graph representation
    # Note: You may need to adjust this based on GraphRAG's actual input requirements
    combined_input = {
        "input_ids": inputs["input_ids"],
        "attention_mask": inputs["attention_mask"],
        "edge_index": torch.tensor(graph["edge_index"], dtype=torch.long),
        "node_feat": torch.tensor(graph["node_feat"], dtype=torch.float),
        "edge_attr": torch.tensor(graph["edge_attr"], dtype=torch.float),
        "num_nodes": graph["num_nodes"],
    }

    # Perform inference
    with torch.no_grad():
        outputs = model(**combined_input)

    # Process outputs
    # Note: Adjust this based on GraphRAG's actual output format
    logits = outputs.logits if hasattr(outputs, "logits") else outputs
    probabilities = torch.softmax(logits, dim=1)
    sentiment = "Positive" if probabilities[0][1] > probabilities[0][0] else "Negative"
    confidence = probabilities[0][1].item() if sentiment == "Positive" else probabilities[0][0].item()

    return sentiment, confidence, graph


st.title("GraphRAG-based Text Analysis")

tokenizer, model = load_model()

text_input = st.text_area("Enter text for analysis:", height=200)

if st.button("Analyze Text"):
    if text_input:
        sentiment, confidence, graph = analyze_text(text_input, tokenizer, model)
        st.write(f"Sentiment: {sentiment}")
        st.write(f"Confidence: {confidence:.2f}")

        # Additional analysis
        word_count = len(text_input.split())
        st.write(f"Word count: {word_count}")

        # Most common words
        words = [word.lower() for word in text_input.split() if word.isalnum()]
        word_freq = Counter(words).most_common(5)
        st.write("Top 5 most common words:")
        for word, freq in word_freq:
            st.write(f"- {word}: {freq}")

        # Visualize graph (pass an explicit Figure; newer Streamlit versions
        # deprecate calling st.pyplot without one)
        G = nx.Graph()
        G.add_edges_from(zip(graph["edge_index"][0], graph["edge_index"][1]))
        fig, ax = plt.subplots(figsize=(10, 6))
        nx.draw(G, ax=ax, with_labels=False, node_size=30, node_color="lightblue", edge_color="gray")
        ax.set_title("Text as Graph")
        st.pyplot(fig)
    else:
        st.write("Please enter some text to analyze.")
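

# ---------------------------------------------------------------------------
# Optional sketch: the app above assumes that the `graphrag` package exposes a
# `GraphRAG(bert_model, num_labels=..., ...)` class whose output is either a
# logits tensor or an object with a `.logits` attribute. If your installed
# graphrag version does not provide that interface, a minimal stand-in such as
# the one below lets you dry-run the UI. `StubGraphRAG` and its behaviour are
# assumptions for illustration only, not the real graphrag API: it accepts the
# graph tensors but ignores them, and applies a linear classification head to
# BERT's [CLS] representation.
# ---------------------------------------------------------------------------
import torch.nn as nn


class StubGraphRAG(nn.Module):
    """Hypothetical stand-in matching the interface assumed in load_model()."""

    def __init__(self, bert_model, num_labels=2, **kwargs):
        super().__init__()
        self.bert = bert_model
        self.classifier = nn.Linear(self.bert.config.hidden_size, num_labels)

    def forward(self, input_ids, attention_mask, **graph_inputs):
        # edge_index, node_feat, edge_attr, and num_nodes are accepted but
        # unused in this stub.
        outputs = self.bert(input_ids=input_ids, attention_mask=attention_mask)
        pooled = outputs.last_hidden_state[:, 0]  # [CLS] token representation
        return self.classifier(pooled)  # raw logits; analyze_text() handles this shape


# To use the stub, define or import it before load_model() is called and
# replace `graphrag.GraphRAG(...)` with `StubGraphRAG(bert_model, num_labels=2)`.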