eaglelandsonce commited on
Commit
8436e7c
·
verified ·
1 Parent(s): 4a2750f

Update pages/21_GraphRag.py

Browse files
Files changed (1) hide show
  1. pages/21_GraphRag.py +45 -12
pages/21_GraphRag.py CHANGED
@@ -4,6 +4,22 @@ from transformers import AutoTokenizer, AutoModel
4
  import torch
5
  import graphrag
6
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
7
  @st.cache_resource
8
  def load_model():
9
  bert_model_name = "bert-base-uncased"
@@ -11,31 +27,48 @@ def load_model():
11
  bert_model = AutoModel.from_pretrained(bert_model_name)
12
 
13
  # Initialize Graphrag model
14
- model = graphrag.GraphRAG(
15
- bert_model,
16
- num_labels=2, # Adjust based on your task
17
- num_hidden_layers=2,
18
- hidden_size=768,
19
- intermediate_size=3072,
20
- )
 
 
 
 
 
 
 
 
 
 
21
 
22
- # If you have a pre-trained Graphrag model, load it here
23
- # model.load_state_dict(torch.load('path_to_your_model.pth'))
24
 
25
  return tokenizer, model
26
 
27
  def process_text(text, tokenizer, model):
 
 
 
28
  inputs = tokenizer(text, return_tensors="pt", truncation=True, max_length=512)
29
  with torch.no_grad():
30
  outputs = model(**inputs)
31
  # Process outputs based on your specific task
32
  # This is a placeholder; adjust according to your model's output
33
- logits = outputs.logits if hasattr(outputs, 'logits') else outputs
 
 
 
 
 
 
34
  probabilities = torch.softmax(logits, dim=1)
35
  return probabilities.tolist()[0]
36
 
37
- st.title("Graphrag Text Analysis")
38
-
39
  tokenizer, model = load_model()
40
 
41
  # File uploader
 
4
  import torch
5
  import graphrag
6
 
7
+ # Diagnostic Section
8
+ st.title("Graphrag Module Investigation")
9
+
10
+ st.write("Graphrag version:", graphrag.__version__)
11
+ st.write("Contents of graphrag module:")
12
+ st.write(dir(graphrag))
13
+
14
+ for item in dir(graphrag):
15
+ st.write(f"Type of {item}: {type(getattr(graphrag, item))}")
16
+ if callable(getattr(graphrag, item)):
17
+ st.write(f"Docstring of {item}:")
18
+ st.write(getattr(graphrag, item).__doc__)
19
+
20
+ # Main Application Section
21
+ st.title("Graphrag Text Analysis")
22
+
23
  @st.cache_resource
24
  def load_model():
25
  bert_model_name = "bert-base-uncased"
 
27
  bert_model = AutoModel.from_pretrained(bert_model_name)
28
 
29
  # Initialize Graphrag model
30
+ # Note: This part may need to be adjusted based on the actual structure of graphrag
31
+ model = None
32
+ for item in dir(graphrag):
33
+ if 'model' in item.lower() or 'rag' in item.lower():
34
+ model_class = getattr(graphrag, item)
35
+ if callable(model_class):
36
+ try:
37
+ model = model_class(
38
+ bert_model,
39
+ num_labels=2, # Adjust based on your task
40
+ num_hidden_layers=2,
41
+ hidden_size=768,
42
+ intermediate_size=3072,
43
+ )
44
+ break
45
+ except Exception as e:
46
+ st.write(f"Tried initializing {item}, but got error: {str(e)}")
47
 
48
+ if model is None:
49
+ st.error("Could not initialize any Graphrag model. Please check the module structure.")
50
 
51
  return tokenizer, model
52
 
53
  def process_text(text, tokenizer, model):
54
+ if model is None:
55
+ return "Model not initialized"
56
+
57
  inputs = tokenizer(text, return_tensors="pt", truncation=True, max_length=512)
58
  with torch.no_grad():
59
  outputs = model(**inputs)
60
  # Process outputs based on your specific task
61
  # This is a placeholder; adjust according to your model's output
62
+ if hasattr(outputs, 'logits'):
63
+ logits = outputs.logits
64
+ elif isinstance(outputs, torch.Tensor):
65
+ logits = outputs
66
+ else:
67
+ return "Unexpected output format"
68
+
69
  probabilities = torch.softmax(logits, dim=1)
70
  return probabilities.tolist()[0]
71
 
 
 
72
  tokenizer, model = load_model()
73
 
74
  # File uploader