eaglelandsonce committed on
Commit
510db06
·
verified ·
1 Parent(s): 4bf193c

Update pages/21_GraphRag.py

Browse files
Files changed (1) hide show
  1. pages/21_GraphRag.py +38 -20
pages/21_GraphRag.py CHANGED
@@ -1,21 +1,27 @@
1
  import streamlit as st
2
- from transformers import GraphormerForGraphClassification, GraphormerFeatureExtractor
3
- from datasets import Dataset
4
- from transformers.models.graphormer.collating_graphormer import preprocess_item, GraphormerDataCollator
5
  import torch
6
  import networkx as nx
7
  import matplotlib.pyplot as plt
8
  from collections import Counter
 
9
 
10
  @st.cache_resource
11
  def load_model():
12
- model = GraphormerForGraphClassification.from_pretrained(
13
- "clefourrier/pcqm4mv2_graphormer_base",
14
- num_classes=2, # Binary classification (positive/negative sentiment)
15
- ignore_mismatched_sizes=True,
 
 
 
 
 
 
 
16
  )
17
- feature_extractor = GraphormerFeatureExtractor.from_pretrained("clefourrier/pcqm4mv2_graphormer_base")
18
- return model, feature_extractor
19
 
20
  def text_to_graph(text):
21
  words = text.split()
@@ -33,36 +39,48 @@ def text_to_graph(text):
33
  "num_nodes": len(G.nodes()),
34
  "node_feat": [[ord(word[0])] for word in words], # Use ASCII value of first letter as feature
35
  "edge_attr": [[1] for _ in range(len(G.edges()) * 2)], # All edges have the same attribute
36
- "y": [1] # Placeholder label, will be ignored during inference
37
  }
38
 
39
def analyze_text(text, model, feature_extractor):
    """Classify the sentiment of ``text`` using its graph representation.

    Args:
        text: Raw text to analyse.
        model: Graph classification model; called with the collated batch
            as keyword arguments and expected to expose ``.logits``.
        feature_extractor: Accepted for interface compatibility; not used
            inside this function.

    Returns:
        A ``(sentiment, confidence, graph)`` tuple where ``sentiment`` is
        ``"Positive"`` or ``"Negative"``, ``confidence`` is the softmax
        probability of the chosen class, and ``graph`` is the dict produced
        by ``text_to_graph``.
    """
    graph = text_to_graph(text)

    # Wrap the single graph in a dataset so it can be run through the
    # Graphormer preprocessing step item by item.
    raw_dataset = Dataset.from_dict({"train": [graph]})
    prepared = raw_dataset.map(preprocess_item, batched=False)

    collator = GraphormerDataCollator()
    batch = collator(prepared["train"])
    # Move every collated tensor onto the device the model lives on.
    batch = {name: tensor.to(model.device) for name, tensor in batch.items()}

    with torch.no_grad():
        result = model(**batch)

    probabilities = torch.softmax(result.logits, dim=1)
    negative_prob = probabilities[0][0]
    positive_prob = probabilities[0][1]

    # Ties fall through to "Negative", matching a strict greater-than test.
    if positive_prob > negative_prob:
        return "Positive", positive_prob.item(), graph
    return "Negative", negative_prob.item(), graph
56
 
57
- st.title("Graph-based Text Analysis")
58
 
59
- model, feature_extractor = load_model()
60
 
61
  text_input = st.text_area("Enter text for analysis:", height=200)
62
 
63
  if st.button("Analyze Text"):
64
  if text_input:
65
- sentiment, confidence, graph = analyze_text(text_input, model, feature_extractor)
66
  st.write(f"Sentiment: {sentiment}")
67
  st.write(f"Confidence: {confidence:.2f}")
68
 
 
1
  import streamlit as st
2
+ from transformers import AutoTokenizer, AutoModel
 
 
3
  import torch
4
  import networkx as nx
5
  import matplotlib.pyplot as plt
6
  from collections import Counter
7
+ import graphrag # Import the graphrag library
8
 
9
  @st.cache_resource
10
  def load_model():
11
+ tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
12
+ bert_model = AutoModel.from_pretrained("bert-base-uncased")
13
+
14
+ # Initialize GraphRAG model
15
+ # Note: You may need to adjust these parameters based on GraphRAG's actual interface
16
+ graph_rag_model = graphrag.GraphRAG(
17
+ bert_model,
18
+ num_labels=2, # For binary sentiment classification
19
+ num_hidden_layers=2,
20
+ hidden_size=768,
21
+ intermediate_size=3072,
22
  )
23
+
24
+ return tokenizer, graph_rag_model
25
 
26
  def text_to_graph(text):
27
  words = text.split()
 
39
  "num_nodes": len(G.nodes()),
40
  "node_feat": [[ord(word[0])] for word in words], # Use ASCII value of first letter as feature
41
  "edge_attr": [[1] for _ in range(len(G.edges()) * 2)], # All edges have the same attribute
 
42
  }
43
 
44
def analyze_text(text, tokenizer, model):
    """Classify the sentiment of ``text`` from its tokens and word graph.

    Args:
        text: Raw text to analyse.
        tokenizer: Callable tokenizer; invoked with ``return_tensors="pt"``,
            truncation enabled and ``max_length=512``.
        model: Model called with the combined token and graph inputs as
            keyword arguments.

    Returns:
        A ``(sentiment, confidence, graph)`` tuple where ``sentiment`` is
        ``"Positive"`` or ``"Negative"``, ``confidence`` is the softmax
        probability of the chosen class, and ``graph`` is the dict produced
        by ``text_to_graph``.
    """
    # Tokenize first, then build the graph view of the same text.
    encoded = tokenizer(text, return_tensors="pt", truncation=True, max_length=512)
    graph = text_to_graph(text)

    # Merge token tensors with the graph tensors.
    # NOTE(review): this layout may need adjusting to GraphRAG's actual
    # input requirements; it is not confirmed here.
    model_inputs = dict(
        input_ids=encoded["input_ids"],
        attention_mask=encoded["attention_mask"],
        edge_index=torch.tensor(graph["edge_index"], dtype=torch.long),
        node_feat=torch.tensor(graph["node_feat"], dtype=torch.float),
        edge_attr=torch.tensor(graph["edge_attr"], dtype=torch.float),
        num_nodes=graph["num_nodes"],
    )

    # Inference only: no gradients are needed.
    with torch.no_grad():
        result = model(**model_inputs)

    # Accept either an object carrying ``.logits`` or the raw logits.
    # NOTE(review): adjust to GraphRAG's actual output format if it differs.
    logits = getattr(result, "logits", result)
    probabilities = torch.softmax(logits, dim=1)
    negative_prob = probabilities[0][0]
    positive_prob = probabilities[0][1]

    # Ties fall through to "Negative", matching a strict greater-than test.
    if positive_prob > negative_prob:
        return "Positive", positive_prob.item(), graph
    return "Negative", negative_prob.item(), graph
74
 
75
+ st.title("GraphRAG-based Text Analysis")
76
 
77
+ tokenizer, model = load_model()
78
 
79
  text_input = st.text_area("Enter text for analysis:", height=200)
80
 
81
  if st.button("Analyze Text"):
82
  if text_input:
83
+ sentiment, confidence, graph = analyze_text(text_input, tokenizer, model)
84
  st.write(f"Sentiment: {sentiment}")
85
  st.write(f"Confidence: {confidence:.2f}")
86