eaglelandsonce commited on
Commit
4a2750f
·
verified ·
1 Parent(s): ee795ac

Update pages/21_GraphRag.py

Browse files
Files changed (1) hide show
  1. pages/21_GraphRag.py +18 -11
pages/21_GraphRag.py CHANGED
@@ -1,19 +1,25 @@
1
  import streamlit as st
2
  import pandas as pd
3
- from transformers import AutoTokenizer, BertForSequenceClassification, BertConfig
4
  import torch
 
5
 
6
  @st.cache_resource
7
  def load_model():
8
  bert_model_name = "bert-base-uncased"
9
  tokenizer = AutoTokenizer.from_pretrained(bert_model_name)
10
-
11
- config = BertConfig.from_pretrained(bert_model_name)
12
- config.num_labels = 2 # Adjust based on your task
13
-
14
- model = BertForSequenceClassification.from_pretrained(bert_model_name, config=config)
15
-
16
- # If you have a pre-trained model, load it here
 
 
 
 
 
17
  # model.load_state_dict(torch.load('path_to_your_model.pth'))
18
 
19
  return tokenizer, model
@@ -23,11 +29,12 @@ def process_text(text, tokenizer, model):
23
  with torch.no_grad():
24
  outputs = model(**inputs)
25
  # Process outputs based on your specific task
26
- logits = outputs.logits
 
27
  probabilities = torch.softmax(logits, dim=1)
28
  return probabilities.tolist()[0]
29
 
30
- st.title("BERT Text Analysis")
31
 
32
  tokenizer, model = load_model()
33
 
@@ -57,4 +64,4 @@ if st.button("Analyze Text"):
57
  st.write("Please enter some text to analyze.")
58
 
59
  # Add a link to sample data
60
- st.markdown("[Download Sample CSV](https://raw.githubusercontent.com/your_username/your_repo/main/sample_data.csv)")
 
1
  import streamlit as st
2
  import pandas as pd
3
+ from transformers import AutoTokenizer, AutoModel
4
  import torch
5
+ import graphrag
6
 
7
  @st.cache_resource
8
  def load_model():
9
  bert_model_name = "bert-base-uncased"
10
  tokenizer = AutoTokenizer.from_pretrained(bert_model_name)
11
+ bert_model = AutoModel.from_pretrained(bert_model_name)
12
+
13
+ # Initialize Graphrag model
14
+ model = graphrag.GraphRAG(
15
+ bert_model,
16
+ num_labels=2, # Adjust based on your task
17
+ num_hidden_layers=2,
18
+ hidden_size=768,
19
+ intermediate_size=3072,
20
+ )
21
+
22
+ # If you have a pre-trained Graphrag model, load it here
23
  # model.load_state_dict(torch.load('path_to_your_model.pth'))
24
 
25
  return tokenizer, model
 
29
  with torch.no_grad():
30
  outputs = model(**inputs)
31
  # Process outputs based on your specific task
32
+ # This is a placeholder; adjust according to your model's output
33
+ logits = outputs.logits if hasattr(outputs, 'logits') else outputs
34
  probabilities = torch.softmax(logits, dim=1)
35
  return probabilities.tolist()[0]
36
 
37
+ st.title("Graphrag Text Analysis")
38
 
39
  tokenizer, model = load_model()
40
 
 
64
  st.write("Please enter some text to analyze.")
65
 
66
  # Add a link to sample data
67
+ st.markdown("[Download Sample CSV](https://raw.githubusercontent.com/your_username/your_repo/main/sample_data.csv)")