eaglelandsonce committed on
Commit
41f73cb
·
verified ·
1 Parent(s): f00ee8c

Update pages/21_GraphRag.py

Browse files
Files changed (1) hide show
  1. pages/21_GraphRag.py +66 -1
pages/21_GraphRag.py CHANGED
@@ -1 +1,66 @@
1
- # put code here
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # put code here
2
+
3
+ import streamlit as st
4
+ import pandas as pd
5
+ from transformers import AutoTokenizer, AutoModel
6
+ from graphrag import GraphragModel, GraphragConfig
7
+ import torch
8
+
9
@st.cache_resource
def load_model():
    """Build the tokenizer and the Graphrag model used by this page.

    Wrapped in ``st.cache_resource`` so the objects are created once and
    reused instead of being rebuilt on every Streamlit rerun.

    Returns:
        A ``(tokenizer, model)`` tuple: the ``bert-base-uncased`` tokenizer
        and a ``GraphragModel`` configured around the matching BERT encoder.
    """
    bert_model_name = "bert-base-uncased"
    tokenizer = AutoTokenizer.from_pretrained(bert_model_name)
    bert_model = AutoModel.from_pretrained(bert_model_name)

    config = GraphragConfig(
        bert_model=bert_model,
        num_labels=2,  # Adjust based on your task
        num_hidden_layers=2,
        hidden_size=768,
        intermediate_size=3072,
    )

    model = GraphragModel(config)
    # The model is only used for inference on this page. A freshly built
    # torch module starts in training mode, where dropout makes predictions
    # non-deterministic, so switch it to evaluation mode.
    # NOTE(review): assumes GraphragModel is a torch.nn.Module -- confirm.
    model.eval()
    return tokenizer, model
25
+
26
def process_text(text, tokenizer, model):
    """Run *text* through the model and return its class probabilities.

    Args:
        text: The text to analyse; handed to ``tokenizer`` unchanged.
        tokenizer: Callable invoked as
            ``tokenizer(text, return_tensors="pt", truncation=True,
            max_length=512)``; its result is unpacked as keyword arguments
            for ``model``.
        model: Callable whose result exposes a ``logits`` tensor.

    Returns:
        The first row of ``softmax(logits, dim=1)`` as a plain list of floats.
    """
    encoded = tokenizer(
        text,
        return_tensors="pt",
        truncation=True,
        max_length=512,
    )

    # Inference only, so gradient tracking is switched off.
    with torch.no_grad():
        prediction = model(**encoded)
        # Placeholder post-processing: adapt this to what the model's
        # output actually contains for the task at hand.
        first_row = torch.softmax(prediction.logits, dim=1)[0]
        return first_row.tolist()
35
+
36
# ---------------------------------------------------------------------------
# Page body: Streamlit re-executes everything below on each user interaction.
# ---------------------------------------------------------------------------
st.title("Graphrag Text Analysis")

tokenizer, model = load_model()

# File uploader
uploaded_file = st.file_uploader("Choose a CSV file", type="csv")

if uploaded_file is not None:
    # A malformed, empty, or wrongly encoded upload should produce a readable
    # message on the page rather than an unhandled traceback.
    try:
        data = pd.read_csv(uploaded_file)
    except (
        pd.errors.EmptyDataError,
        pd.errors.ParserError,
        UnicodeDecodeError,
    ) as exc:
        data = None
        st.error(f"Could not read the uploaded CSV: {exc}")

    if data is not None:
        st.write(data.head())

        if st.button("Process Data"):
            # The batch path reads the 'text' column; check for it up front
            # so a differently shaped CSV does not raise KeyError.
            if "text" not in data.columns:
                st.error("The uploaded CSV must contain a 'text' column.")
            else:
                # Missing cells arrive as NaN (a float) and numeric cells as
                # numbers; normalise every cell to a string for the tokenizer.
                # Missing cells are scored as the empty string.
                texts = data["text"].fillna("").astype(str)
                data["results"] = [
                    process_text(text, tokenizer, model) for text in texts
                ]
                st.write(data)

# Text input for single prediction
text_input = st.text_area("Enter text for analysis:")
if st.button("Analyze Text"):
    # Whitespace-only input counts as "nothing entered".
    if text_input and text_input.strip():
        result = process_text(text_input, tokenizer, model)
        st.write(f"Analysis Result: {result}")
    else:
        st.write("Please enter some text to analyze.")

# Add a link to sample data
# TODO: the URL below is still a template placeholder (your_username/your_repo);
# point it at a real sample file before publishing.
st.markdown("[Download Sample CSV](https://raw.githubusercontent.com/your_username/your_repo/main/sample_data.csv)")