Spaces:
Sleeping
Sleeping
Update pages/21_GraphRag.py
Browse files- pages/21_GraphRag.py +6 -6
pages/21_GraphRag.py
CHANGED
@@ -1,5 +1,5 @@
|
|
1 |
import streamlit as st
|
2 |
-
from transformers import GraphormerForGraphClassification,
|
3 |
from datasets import Dataset
|
4 |
from transformers.models.graphormer.collating_graphormer import preprocess_item, GraphormerDataCollator
|
5 |
import torch
|
@@ -14,8 +14,8 @@ def load_model():
|
|
14 |
num_classes=2, # Binary classification (positive/negative sentiment)
|
15 |
ignore_mismatched_sizes=True,
|
16 |
)
|
17 |
-
|
18 |
-
return model,
|
19 |
|
20 |
def text_to_graph(text):
|
21 |
words = text.split()
|
@@ -36,7 +36,7 @@ def text_to_graph(text):
|
|
36 |
"y": [1] # Placeholder label, will be ignored during inference
|
37 |
}
|
38 |
|
39 |
-
def analyze_text(text, model,
|
40 |
graph = text_to_graph(text)
|
41 |
dataset = Dataset.from_dict({"train": [graph]})
|
42 |
dataset_processed = dataset.map(preprocess_item, batched=False)
|
@@ -56,13 +56,13 @@ def analyze_text(text, model, tokenizer):
|
|
56 |
|
57 |
st.title("Graph-based Text Analysis")
|
58 |
|
59 |
-
model,
|
60 |
|
61 |
text_input = st.text_area("Enter text for analysis:", height=200)
|
62 |
|
63 |
if st.button("Analyze Text"):
|
64 |
if text_input:
|
65 |
-
sentiment, confidence, graph = analyze_text(text_input, model,
|
66 |
st.write(f"Sentiment: {sentiment}")
|
67 |
st.write(f"Confidence: {confidence:.2f}")
|
68 |
|
|
|
1 |
import streamlit as st
|
2 |
+
from transformers import GraphormerForGraphClassification, GraphormerFeatureExtractor
|
3 |
from datasets import Dataset
|
4 |
from transformers.models.graphormer.collating_graphormer import preprocess_item, GraphormerDataCollator
|
5 |
import torch
|
|
|
14 |
num_classes=2, # Binary classification (positive/negative sentiment)
|
15 |
ignore_mismatched_sizes=True,
|
16 |
)
|
17 |
+
feature_extractor = GraphormerFeatureExtractor.from_pretrained("clefourrier/pcqm4mv2_graphormer_base")
|
18 |
+
return model, feature_extractor
|
19 |
|
20 |
def text_to_graph(text):
|
21 |
words = text.split()
|
|
|
36 |
"y": [1] # Placeholder label, will be ignored during inference
|
37 |
}
|
38 |
|
39 |
+
def analyze_text(text, model, feature_extractor):
|
40 |
graph = text_to_graph(text)
|
41 |
dataset = Dataset.from_dict({"train": [graph]})
|
42 |
dataset_processed = dataset.map(preprocess_item, batched=False)
|
|
|
56 |
|
57 |
st.title("Graph-based Text Analysis")
|
58 |
|
59 |
+
model, feature_extractor = load_model()
|
60 |
|
61 |
text_input = st.text_area("Enter text for analysis:", height=200)
|
62 |
|
63 |
if st.button("Analyze Text"):
|
64 |
if text_input:
|
65 |
+
sentiment, confidence, graph = analyze_text(text_input, model, feature_extractor)
|
66 |
st.write(f"Sentiment: {sentiment}")
|
67 |
st.write(f"Confidence: {confidence:.2f}")
|
68 |
|