eaglelandsonce commited on
Commit
4bf193c
·
verified ·
1 Parent(s): cb06d03

Update pages/21_GraphRag.py

Browse files
Files changed (1) hide show
  1. pages/21_GraphRag.py +6 -6
pages/21_GraphRag.py CHANGED
@@ -1,5 +1,5 @@
1
  import streamlit as st
2
- from transformers import GraphormerForGraphClassification, GraphormerTokenizer
3
  from datasets import Dataset
4
  from transformers.models.graphormer.collating_graphormer import preprocess_item, GraphormerDataCollator
5
  import torch
@@ -14,8 +14,8 @@ def load_model():
14
  num_classes=2, # Binary classification (positive/negative sentiment)
15
  ignore_mismatched_sizes=True,
16
  )
17
- tokenizer = GraphormerTokenizer.from_pretrained("clefourrier/pcqm4mv2_graphormer_base")
18
- return model, tokenizer
19
 
20
  def text_to_graph(text):
21
  words = text.split()
@@ -36,7 +36,7 @@ def text_to_graph(text):
36
  "y": [1] # Placeholder label, will be ignored during inference
37
  }
38
 
39
- def analyze_text(text, model, tokenizer):
40
  graph = text_to_graph(text)
41
  dataset = Dataset.from_dict({"train": [graph]})
42
  dataset_processed = dataset.map(preprocess_item, batched=False)
@@ -56,13 +56,13 @@ def analyze_text(text, model, tokenizer):
56
 
57
  st.title("Graph-based Text Analysis")
58
 
59
- model, tokenizer = load_model()
60
 
61
  text_input = st.text_area("Enter text for analysis:", height=200)
62
 
63
  if st.button("Analyze Text"):
64
  if text_input:
65
- sentiment, confidence, graph = analyze_text(text_input, model, tokenizer)
66
  st.write(f"Sentiment: {sentiment}")
67
  st.write(f"Confidence: {confidence:.2f}")
68
 
 
1
  import streamlit as st
2
+ from transformers import GraphormerForGraphClassification, GraphormerFeatureExtractor
3
  from datasets import Dataset
4
  from transformers.models.graphormer.collating_graphormer import preprocess_item, GraphormerDataCollator
5
  import torch
 
14
  num_classes=2, # Binary classification (positive/negative sentiment)
15
  ignore_mismatched_sizes=True,
16
  )
17
+ feature_extractor = GraphormerFeatureExtractor.from_pretrained("clefourrier/pcqm4mv2_graphormer_base")
18
+ return model, feature_extractor
19
 
20
  def text_to_graph(text):
21
  words = text.split()
 
36
  "y": [1] # Placeholder label, will be ignored during inference
37
  }
38
 
39
+ def analyze_text(text, model, feature_extractor):
40
  graph = text_to_graph(text)
41
  dataset = Dataset.from_dict({"train": [graph]})
42
  dataset_processed = dataset.map(preprocess_item, batched=False)
 
56
 
57
  st.title("Graph-based Text Analysis")
58
 
59
+ model, feature_extractor = load_model()
60
 
61
  text_input = st.text_area("Enter text for analysis:", height=200)
62
 
63
  if st.button("Analyze Text"):
64
  if text_input:
65
+ sentiment, confidence, graph = analyze_text(text_input, model, feature_extractor)
66
  st.write(f"Sentiment: {sentiment}")
67
  st.write(f"Confidence: {confidence:.2f}")
68