eaglelandsonce committed on
Commit
c753736
·
verified ·
1 Parent(s): 510db06

Update pages/21_GraphRag.py

Browse files
Files changed (1) hide show
  1. pages/21_GraphRag.py +34 -44
pages/21_GraphRag.py CHANGED
@@ -4,21 +4,45 @@ import torch
4
  import networkx as nx
5
  import matplotlib.pyplot as plt
6
  from collections import Counter
7
- import graphrag # Import the graphrag library
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
8
 
9
  @st.cache_resource
10
  def load_model():
11
  tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
12
  bert_model = AutoModel.from_pretrained("bert-base-uncased")
13
 
14
- # Initialize GraphRAG model
15
- # Note: You may need to adjust these parameters based on GraphRAG's actual interface
16
- graph_rag_model = graphrag.GraphRAG(
17
  bert_model,
18
  num_labels=2, # For binary sentiment classification
19
- num_hidden_layers=2,
20
- hidden_size=768,
21
- intermediate_size=3072,
22
  )
23
 
24
  return tokenizer, graph_rag_model
@@ -49,7 +73,7 @@ def analyze_text(text, tokenizer, model):
49
  graph = text_to_graph(text)
50
 
51
  # Combine tokenized input with graph representation
52
- # Note: You may need to adjust this based on GraphRAG's actual input requirements
53
  combined_input = {
54
  "input_ids": inputs["input_ids"],
55
  "attention_mask": inputs["attention_mask"],
@@ -64,7 +88,7 @@ def analyze_text(text, tokenizer, model):
64
  outputs = model(**combined_input)
65
 
66
  # Process outputs
67
- # Note: Adjust this based on GraphRAG's actual output format
68
  logits = outputs.logits if hasattr(outputs, 'logits') else outputs
69
  probabilities = torch.softmax(logits, dim=1)
70
  sentiment = "Positive" if probabilities[0][1] > probabilities[0][0] else "Negative"
@@ -72,38 +96,4 @@ def analyze_text(text, tokenizer, model):
72
 
73
  return sentiment, confidence, graph
74
 
75
- st.title("GraphRAG-based Text Analysis")
76
-
77
- tokenizer, model = load_model()
78
-
79
- text_input = st.text_area("Enter text for analysis:", height=200)
80
-
81
- if st.button("Analyze Text"):
82
- if text_input:
83
- sentiment, confidence, graph = analyze_text(text_input, tokenizer, model)
84
- st.write(f"Sentiment: {sentiment}")
85
- st.write(f"Confidence: {confidence:.2f}")
86
-
87
- # Additional analysis
88
- word_count = len(text_input.split())
89
- st.write(f"Word count: {word_count}")
90
-
91
- # Most common words
92
- words = [word.lower() for word in text_input.split() if word.isalnum()]
93
- word_freq = Counter(words).most_common(5)
94
-
95
- st.write("Top 5 most common words:")
96
- for word, freq in word_freq:
97
- st.write(f"- {word}: {freq}")
98
-
99
- # Visualize graph
100
- G = nx.Graph()
101
- G.add_edges_from(zip(graph["edge_index"][0], graph["edge_index"][1]))
102
-
103
- plt.figure(figsize=(10, 6))
104
- nx.draw(G, with_labels=False, node_size=30, node_color='lightblue', edge_color='gray')
105
- plt.title("Text as Graph")
106
- st.pyplot(plt)
107
-
108
- else:
109
- st.write("Please enter some text to analyze.")
 
4
  import networkx as nx
5
  import matplotlib.pyplot as plt
6
  from collections import Counter
7
+ import graphrag
8
+ import inspect
9
+
10
+ st.title("GraphRAG Module Exploration and Text Analysis")
11
+
12
+ # Diagnostic section
13
+ st.header("GraphRAG Module Contents")
14
+ graphrag_contents = dir(graphrag)
15
+ st.write("Available attributes and methods in graphrag module:")
16
+ for item in graphrag_contents:
17
+ st.write(f"- {item}")
18
+ attr = getattr(graphrag, item)
19
+ if inspect.isclass(attr) or inspect.isfunction(attr):
20
+ st.write(f" Signature: {inspect.signature(attr)}")
21
+ st.write(f" Docstring: {attr.__doc__}")
22
+
23
+ # Attempt to find a suitable model class
24
+ model_class = None
25
+ for item in graphrag_contents:
26
+ if 'model' in item.lower():
27
+ model_class = getattr(graphrag, item)
28
+ st.write(f"Found potential model class: {item}")
29
+ break
30
+
31
+ if model_class is None:
32
+ st.error("Could not find a suitable model class in graphrag module.")
33
+ st.stop()
34
 
35
  @st.cache_resource
36
  def load_model():
37
  tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
38
  bert_model = AutoModel.from_pretrained("bert-base-uncased")
39
 
40
+ # Initialize graphrag model
41
+ # Note: This is a placeholder. Adjust based on the actual model class found
42
+ graph_rag_model = model_class(
43
  bert_model,
44
  num_labels=2, # For binary sentiment classification
45
+ # Add or remove parameters based on the actual model's requirements
 
 
46
  )
47
 
48
  return tokenizer, graph_rag_model
 
73
  graph = text_to_graph(text)
74
 
75
  # Combine tokenized input with graph representation
76
+ # Note: This is a placeholder. Adjust based on the actual model's input requirements
77
  combined_input = {
78
  "input_ids": inputs["input_ids"],
79
  "attention_mask": inputs["attention_mask"],
 
88
  outputs = model(**combined_input)
89
 
90
  # Process outputs
91
+ # Note: Adjust this based on the actual model's output format
92
  logits = outputs.logits if hasattr(outputs, 'logits') else outputs
93
  probabilities = torch.softmax(logits, dim=1)
94
  sentiment = "Positive" if probabilities[0][1] > probabilities[0][0] else "Negative"
 
96
 
97
  return sentiment, confidence, graph
98
 
99
+ # Rest of the Streamlit app (text input, analysis button, etc.) remains the same...