eaglelandsonce committed
Commit eb6725a · verified · 1 Parent(s): 41f73cb

Update pages/21_GraphRag.py

Files changed (1):
  pages/21_GraphRag.py (+11 -15)
pages/21_GraphRag.py CHANGED
@@ -1,26 +1,23 @@
-# put code here
-
 import streamlit as st
 import pandas as pd
 from transformers import AutoTokenizer, AutoModel
-from graphrag import GraphragModel, GraphragConfig
 import torch
+from graphrag.models import GraphragForSequenceClassification
+from graphrag.configuration_graphrag import GraphragConfig
 
 @st.cache_resource
 def load_model():
     bert_model_name = "bert-base-uncased"
     tokenizer = AutoTokenizer.from_pretrained(bert_model_name)
-    bert_model = AutoModel.from_pretrained(bert_model_name)
-
-    config = GraphragConfig(
-        bert_model=bert_model,
-        num_labels=2,  # Adjust based on your task
-        num_hidden_layers=2,
-        hidden_size=768,
-        intermediate_size=3072,
-    )
-
-    model = GraphragModel(config)
+
+    config = GraphragConfig.from_pretrained(bert_model_name)
+    config.num_labels = 2  # Adjust based on your task
+
+    model = GraphragForSequenceClassification(config)
+
+    # If you have a pre-trained Graphrag model, load it here
+    # model.load_state_dict(torch.load('path_to_your_model.pth'))
+
     return tokenizer, model
 
 def process_text(text, tokenizer, model):
@@ -28,7 +25,6 @@ def process_text(text, tokenizer, model):
     with torch.no_grad():
         outputs = model(**inputs)
     # Process outputs based on your specific task
-    # This is a placeholder; adjust according to your model's output
     logits = outputs.logits
     probabilities = torch.softmax(logits, dim=1)
     return probabilities.tolist()[0]
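
For context, a minimal usage sketch of how the page might wire the updated load_model and process_text together. The UI below (title, text area, button, label names) is hypothetical and not part of this commit; it assumes the code runs inside the same pages/21_GraphRag.py, where both functions are defined, and a 2-label task as set by config.num_labels = 2.

    # Hypothetical usage sketch (not part of the commit): calls load_model()
    # and process_text() from the file above via a simple Streamlit UI.
    import streamlit as st

    tokenizer, model = load_model()

    st.title("GraphRAG Text Classification")
    text = st.text_area("Enter text to classify")

    if st.button("Classify") and text:
        # process_text returns softmax probabilities, one per label
        probabilities = process_text(text, tokenizer, model)
        for label_id, prob in enumerate(probabilities):
            st.write(f"Label {label_id}: {prob:.3f}")

Note the shape of the change itself: instead of wrapping a separately loaded AutoModel in a hand-built GraphragConfig, the new code derives the config from the BERT checkpoint name and instantiates a task-specific GraphragForSequenceClassification, leaving weight loading to an optional load_state_dict call.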