from openai import OpenAI
import os
import networkx as nx
from dotenv import load_dotenv
from constants import DOCUMENTS
from tqdm import tqdm
from cdlib import algorithms
import matplotlib.pyplot as plt

load_dotenv(".env.example")

client = OpenAI(api_key=os.getenv("OPENAI_API_KEY"))


def draw_graph(graph):
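    """Draw the knowledge graph with labeled edges and save the figure to graph.png."""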
    pos = nx.spring_layout(graph)
    plt.figure(figsize=(12, 8))
    nx.draw(
        graph,
        pos,
        with_labels=True,
        node_color="skyblue",
        edge_color="gray",
        node_size=1500,
        font_size=10,
        font_weight="bold",
    )
    edge_labels = nx.get_edge_attributes(graph, "label")
    nx.draw_networkx_edge_labels(graph, pos, edge_labels=edge_labels, font_size=8)
    plt.title("Graph Visualization of Extracted Entities and Relationships")
    plt.savefig("graph.png")
    plt.show()


def get_chunks(documents, chunk_size=1000, overlap_size=200):
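    """Split each document into overlapping character chunks."""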
    chunks = []
    for doc in documents:
        for i in range(0, len(doc), chunk_size - overlap_size):
            chunks.append(doc[i : i + chunk_size])
    return chunks


def extract_elements(chunks):
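    """Ask the chat model to extract entities and relationships from each text chunk."""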
    elements = []
    for index, chunk in enumerate(chunks):
        print(f"Processing chunk {index + 1}/{len(chunks)}")
        response = client.chat.completions.create(
            model="gpt-4o",
            messages=[
                {
                    "role": "system",
                    "content": "Extract entities and relationships from the following text.",
                },
                {"role": "user", "content": chunk},
            ],
        )
        entities_and_relations = response.choices[0].message.content
        print(entities_and_relations)
        elements.append(entities_and_relations)
    return elements


def summarize_elements(elements):
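    """Have the chat model rewrite each extraction into a structured summary with "->" relationships."""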
    summaries = []
    for index, element in enumerate(elements):
        print(f"Summarizing element {index + 1}/{len(elements)}")
        response = client.chat.completions.create(
            model="gpt-4o",
            messages=[
                {
                    "role": "system",
                    "content": 'Summarize the following entities and relationships in a structured format. Use "->" to represent relationships, after the "Relationships:" word.',
                },
                {"role": "user", "content": element},
            ],
        )
        summary = response.choices[0].message.content
        print("Element summary:", summary)
        summaries.append(summary)
    return summaries


def build_graph(summaries):
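    """Parse the "Entities:" and "Relationships:" sections of each summary into a NetworkX graph."""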
    G = nx.Graph()
    for index, summary in enumerate(summaries):
        print(f"Summary index {index + 1}/{len(summaries)}")
        lines = summary.split("\n")
        entities_section = False
        relationships_section = False
        entities = []
        for line in tqdm(lines):
            if line.startswith("### Entities:") or line.startswith("**Entities:**"):
                entities_section = True
                relationships_section = False
                continue
            elif line.startswith("### Relationships:") or line.startswith(
                "**Relationships:**"
            ):
                entities_section = False
                relationships_section = True
                continue
            if entities_section and line.strip():
                # Strip a leading "1." style enumeration if present.
                if len(line) > 1 and line[0].isdigit() and line[1] == ".":
                    line = line.split(".", 1)[1].strip()
                entity = line.strip()
                entity = entity.replace("**", "")
                entities.append(entity)
                G.add_node(entity)
            elif relationships_section and line.strip():
                # Expect "source -> relation -> target"; keep any intermediate hops as the edge label.
                parts = line.split("->")
                if len(parts) >= 2:
                    source = parts[0].strip()
                    target = parts[-1].strip()
                    relation = " -> ".join(parts[1:-1]).strip()
                    G.add_edge(source, target, label=relation)
    return G


def detect_communities(graph):
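    """Split the graph into connected components and run Leiden community detection on each."""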
    communities = []
    components = list(nx.connected_components(graph))
    for index, component in enumerate(components):
        print(f"Component index {index} of {len(components)}:")
        subgraph = graph.subgraph(component)
        if len(subgraph.nodes) > 1:
            try:
                sub_communities = algorithms.leiden(subgraph)
                for community in sub_communities.communities:
                    communities.append(list(community))
            except Exception as e:
                print(f"Error processing component {index}: {e}")
        else:
            # A single isolated node forms its own community.
            communities.append(list(subgraph.nodes))
    print("Communities from detect_communities:", communities)
    return communities


def summarize_communities(communities, graph):
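    """Describe each community's entities and relationships and summarize it with the chat model."""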
    community_summaries = []
    for index, community in enumerate(communities):
        print(f"Summarize Community index {index + 1}/{len(communities)}:")
        subgraph = graph.subgraph(community)
        nodes = list(subgraph.nodes)
        edges = list(subgraph.edges(data=True))
        description = "Entities: " + ", ".join(nodes) + "\nRelationships: "
        relationships = []
        for source, target, data in edges:
            # Fall back to an empty label so unlabeled edges do not raise a KeyError.
            relation = data.get("label", "")
            relationships.append(f"{source} -> {relation} -> {target}")
        description += ", ".join(relationships)

        response = client.chat.completions.create(
            model="gpt-4o",
            messages=[
                {
                    "role": "system",
                    "content": "Summarize the following community of entities and relationships.",
                },
                {"role": "user", "content": description},
            ],
        )
        summary = response.choices[0].message.content.strip()
        community_summaries.append(summary)
    return community_summaries


def generate_answer(community_summaries, query):
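    """Answer the query against each community summary, then combine the partial answers into one response."""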
    intermediate_answers = []
    for index, summary in enumerate(community_summaries):
        print(f"Answering community {index + 1}/{len(community_summaries)}:")
        response = client.chat.completions.create(
            model="gpt-4o",
            messages=[
                {
                    "role": "system",
                    "content": "Answer the following query based on the provided summary.",
                },
                {"role": "user", "content": f"Query: {query} Summary: {summary}"},
            ],
        )
        print("Intermediate answer:", response.choices[0].message.content)
        intermediate_answers.append(response.choices[0].message.content)

    final_response = client.chat.completions.create(
        model="gpt-4o",
        messages=[
            {
                "role": "system",
                "content": "Combine these answers into a final, concise response.",
            },
            {
                "role": "user",
                "content": f"Intermediate answers: {intermediate_answers}",
            },
        ],
    )
    final_answer = final_response.choices[0].message.content
    return final_answer


def graphrag_pipeline(documents, query):
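    """Run the end-to-end GraphRAG pipeline: chunk, extract, summarize,
    build the graph, detect and summarize communities, and answer the query."""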
    chunks = get_chunks(documents)
    elements = extract_elements(chunks)
    summaries = summarize_elements(elements)
    graph = build_graph(summaries)
    num_entities = graph.number_of_nodes()
    print(f"Number of entities in the graph: {num_entities}")
    draw_graph(graph)
    communities = detect_communities(graph)
    print(communities)
    community_summaries = summarize_communities(communities, graph)
    final_answer = generate_answer(community_summaries, query)
    return final_answer


if __name__ == "__main__":
    query = "What factors in these articles can impact medical inflation in the UK in the short term?"
    print("Query:", query)
    answer = graphrag_pipeline(DOCUMENTS, query)
    print("Answer:", answer)
|