eaglelandsonce committed on
Commit
ee795ac
·
verified ·
1 Parent(s): 2eb3bfa

Update pages/21_GraphRag.py

Browse files
Files changed (1) hide show
  1. pages/21_GraphRag.py +6 -8
pages/21_GraphRag.py CHANGED
@@ -1,21 +1,19 @@
1
  import streamlit as st
2
  import pandas as pd
3
- from transformers import AutoTokenizer, AutoModel
4
  import torch
5
- from graphrag.models import GraphragForSequenceClassification
6
- from graphrag.configuration_graphrag import GraphragConfig
7
 
8
  @st.cache_resource
9
  def load_model():
10
  bert_model_name = "bert-base-uncased"
11
  tokenizer = AutoTokenizer.from_pretrained(bert_model_name)
12
 
13
- config = GraphragConfig.from_pretrained(bert_model_name)
14
  config.num_labels = 2 # Adjust based on your task
15
 
16
- model = GraphragForSequenceClassification(config)
17
 
18
- # If you have a pre-trained Graphrag model, load it here
19
  # model.load_state_dict(torch.load('path_to_your_model.pth'))
20
 
21
  return tokenizer, model
@@ -29,7 +27,7 @@ def process_text(text, tokenizer, model):
29
  probabilities = torch.softmax(logits, dim=1)
30
  return probabilities.tolist()[0]
31
 
32
- st.title("Graphrag Text Analysis")
33
 
34
  tokenizer, model = load_model()
35
 
@@ -59,4 +57,4 @@ if st.button("Analyze Text"):
59
  st.write("Please enter some text to analyze.")
60
 
61
  # Add a link to sample data
62
- st.markdown("[Download Sample CSV](https://raw.githubusercontent.com/your_username/your_repo/main/sample_data.csv)")
 
1
  import streamlit as st
2
  import pandas as pd
3
+ from transformers import AutoTokenizer, BertForSequenceClassification, BertConfig
4
  import torch
 
 
5
 
6
  @st.cache_resource
7
  def load_model():
8
  bert_model_name = "bert-base-uncased"
9
  tokenizer = AutoTokenizer.from_pretrained(bert_model_name)
10
 
11
+ config = BertConfig.from_pretrained(bert_model_name)
12
  config.num_labels = 2 # Adjust based on your task
13
 
14
+ model = BertForSequenceClassification.from_pretrained(bert_model_name, config=config)
15
 
16
+ # If you have a pre-trained model, load it here
17
  # model.load_state_dict(torch.load('path_to_your_model.pth'))
18
 
19
  return tokenizer, model
 
27
  probabilities = torch.softmax(logits, dim=1)
28
  return probabilities.tolist()[0]
29
 
30
+ st.title("BERT Text Analysis")
31
 
32
  tokenizer, model = load_model()
33
 
 
57
  st.write("Please enter some text to analyze.")
58
 
59
  # Add a link to sample data
60
+ st.markdown("[Download Sample CSV](https://raw.githubusercontent.com/your_username/your_repo/main/sample_data.csv)")