# graphRAG.py — GraphRAG pipeline.
# Uploaded via huggingface_hub by Mayara Ayat (commit f7ab812, verified).
from openai import OpenAI
import os
import networkx as nx
from dotenv import load_dotenv
from constants import DOCUMENTS
from tqdm import tqdm
from cdlib import algorithms
import matplotlib.pyplot as plt
# Load environment variables (notably OPENAI_API_KEY) from the example env file.
load_dotenv(".env.example")
# Shared OpenAI client used by every LLM call in this pipeline.
client = OpenAI(api_key=os.getenv("OPENAI_API_KEY"))
def draw_graph(graph):
    """Render *graph* with a spring layout, save it to graph.png, and show it.

    Nodes are labelled; edges carry their "label" attribute as text.
    """
    layout = nx.spring_layout(graph)  # 2-D positions for every node
    plt.figure(figsize=(12, 8))
    node_style = dict(
        with_labels=True,
        node_color="skyblue",
        edge_color="gray",
        node_size=1500,
        font_size=10,
        font_weight="bold",
    )
    nx.draw(graph, layout, **node_style)
    labels = nx.get_edge_attributes(graph, "label")
    nx.draw_networkx_edge_labels(graph, layout, edge_labels=labels, font_size=8)
    plt.title("Graph Visualization of Extracted Entities and Relationships")
    plt.savefig("graph.png")
    plt.show()
# Source texts -> chunks
def get_chunks(documents, chunk_size=1000, overlap_size=200):
    """Split each document into overlapping character chunks.

    Args:
        documents: iterable of strings.
        chunk_size: maximum length of each chunk (characters).
        overlap_size: characters shared between consecutive chunks; must be
            smaller than chunk_size.

    Returns:
        list[str]: all chunks, in document order. The final chunk of a
        document may be shorter than chunk_size.

    Raises:
        ValueError: if overlap_size >= chunk_size (the step would be <= 0,
            which previously surfaced as an opaque range() error).
    """
    if overlap_size >= chunk_size:
        raise ValueError("overlap_size must be smaller than chunk_size")
    step = chunk_size - overlap_size
    chunks = []
    for doc in documents:
        for start in range(0, len(doc), step):
            chunks.append(doc[start : start + chunk_size])
    return chunks
# print(get_chunks(DOCUMENTS))
# Chunks -> Element instances
def extract_elements(chunks):
    """Ask the LLM to pull entities and relationships out of each chunk.

    Returns one raw model response string per input chunk.
    """
    elements = []
    total = len(chunks)
    for position, text in enumerate(chunks, start=1):
        print(f"Processing chunk {position}/{total}")
        completion = client.chat.completions.create(
            model="gpt-4o",
            messages=[
                {
                    "role": "system",
                    "content": "Extract entities and relationships from the following text.",
                },
                {"role": "user", "content": text},
            ],
        )
        extracted = completion.choices[0].message.content
        print(extracted)
        elements.append(extracted)
    return elements
# print(extract_elements(get_chunks(DOCUMENTS)))
# Element instances -> Element summaries
def summarize_elements(elements):
    """Condense each raw extraction into a structured entity/relationship summary.

    The prompt instructs the model to emit relationships as "a -> rel -> b"
    lines under a "Relationships:" heading, which build_graph() later parses.
    """
    summaries = []
    total = len(elements)
    for position, element in enumerate(elements, start=1):
        print(f"Summarizing element {position}/{total}")
        completion = client.chat.completions.create(
            model="gpt-4o",
            messages=[
                {
                    "role": "system",
                    "content": 'Summarize the following entities and relationships in a structured format. Use "->" to represent relationships, after the "Relationships:" word.',
                },
                {"role": "user", "content": element},
            ],
        )
        structured = completion.choices[0].message.content
        print("Element summary:", structured)
        summaries.append(structured)
    return summaries
# print(summarize_elements(extract_elements(get_chunks(DOCUMENTS))))
# Element summaries -> Graph communities
def build_graph(summaries):
    """Parse element summaries into a NetworkX graph.

    Each summary is expected to contain an "Entities" section (one entity per
    line, optionally numbered "1. …") and a "Relationships" section with lines
    of the form ``source -> relation -> target`` (the relation segment may be
    absent: ``source -> target``).

    Args:
        summaries: list of summary strings from summarize_elements().

    Returns:
        networkx.Graph with one node per entity and edges whose "label"
        attribute holds the relation text (empty string when none given).
    """
    G = nx.Graph()
    for index, summary in enumerate(summaries):
        print(f"Summary index {index + 1}/{len(summaries)}")
        lines = summary.split("\n")
        entities_section = False
        relationships_section = False
        entities = []
        for line in tqdm(lines):
            if line.startswith("### Entities:") or line.startswith("**Entities:**"):
                entities_section = True
                relationships_section = False
                continue
            elif line.startswith("### Relationships:") or line.startswith(
                "**Relationships:**"
            ):
                entities_section = False
                relationships_section = True
                continue
            if entities_section and line.strip():
                # Strip a leading "N." list number. The old check
                # (line[0].isdigit() and line[1] == ".") raised IndexError on
                # single-character lines and missed multi-digit numbering.
                prefix, dot, rest = line.partition(".")
                if dot and prefix.strip().isdigit():
                    line = rest.strip()
                entity = line.strip().replace("**", "")
                entities.append(entity)
                G.add_node(entity)
            elif relationships_section and line.strip():
                parts = line.split("->")
                # BUG FIX: was `len(parts) == 2`, which silently dropped every
                # "source -> relation -> target" line (3 parts after split) and
                # always produced an empty label for the 2-part form.
                if len(parts) >= 2:
                    source = parts[0].strip()
                    target = parts[-1].strip()
                    relation = " -> ".join(p.strip() for p in parts[1:-1])
                    G.add_edge(source, target, label=relation)
    return G
# Graph communities -> Graph summaries
def detect_communities(graph):
    """Partition the graph into communities using the Leiden algorithm.

    Each connected component is processed independently; single-node
    components become their own community. Components where Leiden raises
    are reported and skipped (best-effort, matching the original behavior).

    Args:
        graph: networkx.Graph built by build_graph().

    Returns:
        list[list]: node lists, one per detected community.
    """
    communities = []
    # Materialize the component list once: the original recomputed
    # list(nx.connected_components(graph)) on every iteration just to print
    # its length, which is quadratic in the number of components.
    components = list(nx.connected_components(graph))
    total = len(components)
    for index, component in enumerate(components):
        print(f"Component index {index} of {total}:")
        subgraph = graph.subgraph(component)
        if len(subgraph.nodes) > 1:
            try:
                sub_communities = algorithms.leiden(subgraph)
                for community in sub_communities.communities:
                    communities.append(list(community))
            except Exception as e:
                # Leave this component out rather than aborting the whole run.
                print(f"Error processing community {index}: {e}")
        else:
            communities.append(list(subgraph.nodes))
    print("Communities from detect_communities:", communities)
    return communities
# summarize communities
def summarize_communities(communities, graph):
    """Summarize each community with the LLM.

    For every community, builds a textual description of its entities and
    "source -> relation -> target" edges and asks the model for a summary.

    Args:
        communities: list of node lists from detect_communities().
        graph: the full networkx.Graph (edge "label" attrs hold relations).

    Returns:
        list[str]: one summary per community.
    """
    community_summaries = []
    for index, community in enumerate(communities):
        print(f"Summarize Community index {index+1}/{len(communities)}:")
        subgraph = graph.subgraph(community)
        nodes = list(subgraph.nodes)
        edges = list(subgraph.edges(data=True))
        description = "Entities: " + ", ".join(nodes) + "\nRelationships: "
        relationships = []
        for source, target, data in edges:
            # BUG FIX: the safe data.get("label", "") was computed but the
            # f-string used data['label'], which raised KeyError on any edge
            # without a label. Use the defaulted value instead.
            relation = data.get("label", "")
            relationships.append(f"{source} -> {relation} -> {target}")
        description += ", ".join(relationships)
        response = client.chat.completions.create(
            model="gpt-4o",
            messages=[
                {
                    "role": "system",
                    "content": "Summarize the following community of entities and relationships.",
                },
                {"role": "user", "content": description},
            ],
        )
        summary = response.choices[0].message.content.strip()
        community_summaries.append(summary)
    return community_summaries
# Community Summaries → Community Answers → Global Answer
def generate_answer(community_summaries, query):
    """Map-reduce answering: query each community summary, then merge.

    Asks the model to answer *query* against every community summary
    individually, then combines the partial answers into one final response.
    """
    intermediate_answers = []
    total = len(community_summaries)
    for position, summary in enumerate(community_summaries, start=1):
        print(f"Answering community {position}/{total}:")
        partial = client.chat.completions.create(
            model="gpt-4o",
            messages=[
                {
                    "role": "system",
                    "content": "Answer the following query based on the provided summary.",
                },
                {"role": "user", "content": f"Query: {query} Summary: {summary}"},
            ],
        ).choices[0].message.content
        print("Intermediate answer:", partial)
        intermediate_answers.append(partial)
    # Reduce step: fold the per-community answers into a single response.
    combined = client.chat.completions.create(
        model="gpt-4o",
        messages=[
            {
                "role": "system",
                "content": "Combine these answers into a final, concise response.",
            },
            {
                "role": "user",
                "content": f"Intermediate answers: {intermediate_answers}",
            },
        ],
    )
    return combined.choices[0].message.content
def graphrag_pipeline(documents, query):
    """Run the full GraphRAG pipeline over *documents* and answer *query*.

    Stages: chunk -> extract elements -> summarize -> build graph ->
    draw -> detect communities -> summarize communities -> answer.
    """
    elements = extract_elements(get_chunks(documents))
    summaries = summarize_elements(elements)
    graph = build_graph(summaries)
    print(f"Number of entities in the graph: {graph.number_of_nodes()}")
    draw_graph(graph)
    communities = detect_communities(graph)
    print(communities)
    community_summaries = summarize_communities(communities, graph)
    return generate_answer(community_summaries, query)
if __name__ == "__main__":
    # Guarded entry point: previously this ran unconditionally at import
    # time, so merely importing the module kicked off the entire
    # API-calling pipeline.
    query = "What factors in these articles can impact medical inflation in the UK in the short term?"
    # "What are the main themes in these documents?"
    print("Query:", query)
    answer = graphrag_pipeline(DOCUMENTS, query)
    print("Answer:", answer)