|
from typing import List, Dict, Any, Tuple, Optional |
|
import spacy |
|
import networkx as nx |
|
import matplotlib.pyplot as plt |
|
from io import BytesIO |
|
import base64 |
|
import re |
|
import json |
|
from langchain_core.messages import HumanMessage |
|
from langchain.chat_models import init_chat_model |
|
from dotenv import load_dotenv |
|
import os |
|
|
|
from pyvis.network import Network |
|
|
|
|
|
_ = load_dotenv() |
|
|
|
class LLMKnowledgeGraph:
    """Extract entities and subject-relation-object triples from text with a chat LLM.

    The model is prompted to answer with a JSON array. Both extraction methods
    are best-effort: any failure (LLM call, malformed answer) is reported on
    stdout and an empty list is returned instead of raising.
    """

    def __init__(self, model: str = "gemini-2.0-flash", model_provider: str = "google_genai"):
        """Initialize the LLM for knowledge graph generation.

        Args:
            model: Model name forwarded to ``init_chat_model``.
            model_provider: Provider name forwarded to ``init_chat_model``.
        """
        self.llm = init_chat_model(
            model=model,
            model_provider=model_provider,
            temperature=0.1,
            max_tokens=2000,
        )

        # The input text is appended directly after the trailing "Text: ".
        self.entity_prompt = """
        Extract all named entities from the following text and categorize them into the following types:
        - PERSON: People, including fictional
        - ORG: Companies, agencies, institutions, etc.
        - GPE: Countries, cities, states
        - DATE: Absolute or relative dates or periods
        - MONEY: Monetary values
        - PERCENT: Percentage values
        - QUANTITY: Measurements, weights, distances
        - EVENT: Named hurricanes, battles, wars, sports events, etc.
        - WORK_OF_ART: Titles of books, songs, etc.
        - LAW: Legal document titles
        - LANGUAGE: Any named language

        Return the entities in JSON format with the following structure:
        [
            {"text": "entity text", "label": "ENTITY_TYPE", "start": character_start, "end": character_end}
        ]

        Text: """

        self.relation_prompt = """
        Analyze the following text and extract relationships between entities in the form of subject-relation-object triples.
        For each relation, provide:
        - The subject (entity that is the source of the relation)
        - The relation type (e.g., 'works at', 'located in', 'part of')
        - The object (entity that is the target of the relation)

        Return the relations in JSON format with the following structure:
        [
            {"subject": "subject text", "relation": "relation type", "object": "object text"}
        ]

        Text: """

    @staticmethod
    def _response_text(response: Any) -> str:
        """Return the textual payload of an LLM response as one string.

        Handles a plain string ``content``, a list of string / ``{"text": ...}``
        parts, and falls back to ``str(response)`` when there is no usable
        ``content`` attribute.
        """
        content = getattr(response, "content", None)
        if content is None:
            return str(response)
        if isinstance(content, str):
            return content
        if isinstance(content, list):
            parts = []
            for part in content:
                if isinstance(part, str):
                    parts.append(part)
                elif isinstance(part, dict) and isinstance(part.get("text"), str):
                    parts.append(part["text"])
            return "".join(parts)
        return str(content)

    @staticmethod
    def _parse_json_array(content: str) -> List[Dict[str, Any]]:
        """Parse the JSON array embedded in ``content``.

        Everything outside the outermost ``[`` ... ``]`` pair is ignored, so
        Markdown code fences (with or without a language tag) and surrounding
        prose do not matter. Items that are not JSON objects are dropped.

        Raises:
            ValueError: If no array is present or the array is not valid JSON
                (``json.JSONDecodeError`` is a ``ValueError`` subclass).
        """
        start = content.find("[")
        end = content.rfind("]")
        if start == -1 or end < start:
            raise ValueError("LLM response does not contain a JSON array")
        parsed = json.loads(content[start:end + 1])
        return [item for item in parsed if isinstance(item, dict)]

    def _extract(self, prompt: str, text: str, what: str) -> List[Dict[str, Any]]:
        """Send ``prompt + text`` to the LLM and parse the JSON array it returns.

        ``what`` ("entities" / "relations") is only used in the error message.
        """
        # Bound before the try so the handler can tell "call failed" apart
        # from "call succeeded but the answer could not be parsed".
        response = None
        try:
            response = self.llm.invoke([HumanMessage(content=prompt + text)])
            return self._parse_json_array(self._response_text(response))
        except Exception as e:  # best-effort boundary: provider errors vary
            print(f"Error extracting {what} with LLM: {str(e)}")
            if response is not None:
                print(f"Response content: {getattr(response, 'content', str(response))}")
            return []

    def extract_entities_with_llm(self, text: str) -> List[Dict[str, Any]]:
        """Extract entities from text using LLM.

        Args:
            text: Text to analyze.

        Returns:
            The list of entity dicts produced by the model (the prompt asks for
            ``text``, ``label``, ``start`` and ``end`` keys), or ``[]`` on any
            failure.
        """
        return self._extract(self.entity_prompt, text, "entities")

    def extract_relations_with_llm(self, text: str) -> List[Dict[str, str]]:
        """Extract relations between entities using LLM.

        Args:
            text: Text to analyze.

        Returns:
            The list of relation dicts produced by the model (the prompt asks
            for ``subject``, ``relation`` and ``object`` keys), or ``[]`` on
            any failure.
        """
        return self._extract(self.relation_prompt, text, "relations")
|
|
|
def extract_relations(text: str, model_name: str = "gemini-2.0-flash", use_llm: bool = True) -> Dict[str, Any]:
    """
    Extract entities and their relations from text to build a knowledge graph.

    Args:
        text: Input text to process
        model_name: Name of the model to use (spaCy model or LLM). When
            ``use_llm`` is False and this is still the LLM default, the spaCy
            model ``en_core_web_sm`` is used instead, because the LLM name is
            not a spaCy model.
        use_llm: Whether to use LLM for relation extraction (default: True)

    Returns:
        Dictionary with two keys: ``"entities"`` (list of entity dicts) and
        ``"relations"`` (list of ``subject`` / ``relation`` / ``object`` dicts).
        Empty or whitespace-only text yields two empty lists without calling
        any model.

    Raises:
        subprocess.CalledProcessError: In spaCy mode, if the model is missing
            and downloading it fails.
    """
    if not text or not text.strip():
        return {"entities": [], "relations": []}

    if use_llm:
        kg_extractor = LLMKnowledgeGraph(model=model_name)
        return {
            "entities": kg_extractor.extract_entities_with_llm(text),
            "relations": kg_extractor.extract_relations_with_llm(text),
        }

    # The signature default names an LLM; never hand it to spaCy.
    spacy_model = "en_core_web_sm" if model_name == "gemini-2.0-flash" else model_name

    try:
        nlp = spacy.load(spacy_model)
    except OSError:
        # Model not installed: download it once with the running interpreter,
        # then retry. NOTE: this installs a package as a side effect.
        import subprocess
        import sys

        subprocess.check_call([sys.executable, "-m", "spacy", "download", spacy_model])
        nlp = spacy.load(spacy_model)

    doc = nlp(text)

    entities = [
        {"text": ent.text, "label": ent.label_, "start": ent.start_char, "end": ent.end_char}
        for ent in doc.ents
    ]

    # Simple subject-verb-object heuristic: for every nominal subject of a
    # verb, pair it with that verb's first direct object (if any).
    relations = []
    for token in doc:
        if token.dep_ != "nsubj" or token.head.pos_ != "VERB":
            continue
        verb = token.head
        direct_object = next((child for child in verb.children if child.dep_ == "dobj"), None)
        if direct_object is None:
            continue
        if token.text and verb.lemma_ and direct_object.text:
            relations.append({
                "subject": token.text,
                "relation": verb.lemma_,
                "object": direct_object.text,
            })

    return {
        "entities": entities,
        "relations": relations,
    }
|
|
|
def _node_id(value: Any) -> Optional[str]:
    """Normalize a raw entity/subject/object value into a node id.

    Returns the stripped string form of a string or number, or ``None`` when
    the value is missing, empty, or of a type that cannot name a node.
    """
    if isinstance(value, bool) or not isinstance(value, (str, int, float)):
        return None
    node_id = str(value).strip()
    return node_id or None


def build_nx_graph(entities: List[Dict], relations: List[Dict]) -> nx.DiGraph:
    """Build a NetworkX DiGraph from entities and relations. Ensure all nodes have a 'label'.

    Args:
        entities: Entity dicts; the node name is read from ``"text"`` (or
            ``"word"``) and the node label from ``"label"`` (or ``"type"``),
            defaulting to ``"ENTITY"``.
        relations: Relation dicts with ``"subject"``, ``"object"`` and an
            optional ``"relation"`` (default ``"related_to"``).

    Returns:
        A directed graph whose nodes carry ``label`` and ``type`` attributes
        and whose edges carry a ``label`` attribute. Records without a usable
        name (entities) or without both endpoints (relations) are skipped,
        since NetworkX does not accept ``None`` as a node.
    """
    G = nx.DiGraph()

    for entity in entities or []:
        if not isinstance(entity, dict):
            continue
        text = _node_id(entity.get("text") or entity.get("word"))
        if text is None:
            continue
        label = entity.get("label") or entity.get("type") or "ENTITY"
        G.add_node(text, label=label, type="entity")

    for rel in relations or []:
        if not isinstance(rel, dict):
            continue
        subj = _node_id(rel.get("subject"))
        obj = _node_id(rel.get("object"))
        if subj is None or obj is None:
            # An edge needs two real endpoints; drop incomplete triples.
            continue
        rel_label = rel.get("relation") or "related_to"
        # Endpoints not reported as entities still get a labelled node.
        for endpoint in (subj, obj):
            if endpoint not in G:
                G.add_node(endpoint, label="ENTITY", type="entity")
        G.add_edge(subj, obj, label=rel_label)

    return G
|
|
|
def visualize_knowledge_graph(entities: List[Dict], relations: List[Dict]) -> str:
    """
    Generate a static PNG visualization of the knowledge graph, returned as base64 string for HTML embedding.

    Args:
        entities: Entity dicts, passed to ``build_nx_graph``.
        relations: Relation dicts, passed to ``build_nx_graph``.

    Returns:
        A ``data:image/png;base64,...`` URI containing the rendered graph.
    """
    G = build_nx_graph(entities, relations)

    fig = plt.figure(figsize=(12, 8))
    try:
        pos = nx.spring_layout(G, k=0.5, iterations=50)

        # Sort the labels so each entity type gets the same colour on every
        # run (plain set iteration order is not stable across processes).
        entity_types = sorted({d.get('label', 'ENTITY') for _, d in G.nodes(data=True)}, key=str)
        color_map = {etype: plt.cm.tab20(i % 20) for i, etype in enumerate(entity_types)}
        node_colors = [color_map[d.get('label', 'ENTITY')] for _, d in G.nodes(data=True)]

        nx.draw_networkx_nodes(G, pos, node_size=2000, node_color=node_colors, alpha=0.8)
        nx.draw_networkx_edges(G, pos, edge_color='gray', arrows=True, arrowsize=20)
        nx.draw_networkx_labels(G, pos, font_size=10, font_weight='bold')

        edge_labels = {(u, v): d.get('label', '') for u, v, d in G.edges(data=True)}
        nx.draw_networkx_edge_labels(G, pos, edge_labels=edge_labels, font_size=8)

        buf = BytesIO()
        fig.savefig(buf, format='png', bbox_inches='tight')
    finally:
        # Always release the figure, even if layout or drawing failed.
        plt.close(fig)

    img_str = base64.b64encode(buf.getvalue()).decode('utf-8')
    return f"data:image/png;base64,{img_str}"
|
|
|
def visualize_knowledge_graph_interactive(entities: List[Dict], relations: List[Dict]) -> str:
    """
    Generate an interactive HTML visualization of the knowledge graph using pyvis.

    Returns HTML as a string for embedding in Gradio or web UI.

    Args:
        entities: Entity dicts, passed to ``build_nx_graph``.
        relations: Relation dicts, passed to ``build_nx_graph``.

    Returns:
        The contents of the generated page's ``<body>`` element, or the whole
        page if no ``<body>`` ... ``</body>`` pair can be located.
    """
    G = build_nx_graph(entities, relations)
    net = Network(height="600px", width="100%", directed=True, notebook=False)

    # Sorted so that each entity type keeps the same colour across runs.
    entity_types = sorted({d.get('label', 'ENTITY') for _, d in G.nodes(data=True)}, key=str)
    color_palette = ["#e3f2fd", "#e8f5e9", "#fff8e1", "#f3e5f5", "#e8eaf6", "#e0f7fa", "#f1f8e9", "#fce4ec", "#e8f5e9", "#f5f5f5", "#fafafa", "#e1f5fe", "#fff3e0", "#d7ccc8", "#f9fbe7", "#fbe9e7", "#ede7f6", "#e0f2f1"]
    color_map = {etype: color_palette[i % len(color_palette)] for i, etype in enumerate(entity_types)}

    for n, d in G.nodes(data=True):
        label = d.get('label', 'ENTITY')
        net.add_node(n, label=n, title=f"{n}<br>Type: {label}", color=color_map[label])

    for u, v, d in G.edges(data=True):
        edge_label = d.get('label', '')
        net.add_edge(u, v, label=edge_label, title=edge_label)

    # No trailing ';' here: recent pyvis versions JSON-parse everything from
    # the first '{' onwards, and a trailing semicolon makes that parse fail.
    net.set_options('''var options = { "edges": { "arrows": {"to": {"enabled": true}}, "color": {"color": "#888"} }, "nodes": { "font": {"size": 18} }, "physics": { "enabled": true } }''')

    # write_html() expects a file name, not a buffer; generate_html() returns
    # the rendered page as a string directly.
    page = net.generate_html(notebook=False)

    # NOTE(review): only the <body> is returned, so anything the page loads
    # from <head> is not part of the fragment - confirm the embedding side
    # supplies what the graph script needs.
    open_tag = '<body>'
    open_at = page.find(open_tag)
    close_at = page.find('</body>')
    if open_at == -1 or close_at == -1 or close_at < open_at:
        return page
    return page[open_at + len(open_tag):close_at]
|
|
|
def build_knowledge_graph(text: str, model_name: str = "gemini-2.0-flash", use_llm: bool = True) -> Dict[str, Any]:
    """
    Build a knowledge graph from text and attach a rendered image of it.

    Args:
        text: Input text to process
        model_name: Name of the model to use (spaCy model or LLM)
        use_llm: Whether to use LLM for relation extraction (default: True)

    Returns:
        The dictionary produced by ``extract_relations`` with an extra
        ``"visualization"`` key: the output of ``visualize_knowledge_graph``
        when both entities and relations were found, otherwise ``None``.
    """
    graph_data = extract_relations(text, model_name, use_llm)

    # A picture is only produced when there is something to connect.
    if not (graph_data.get("entities") and graph_data.get("relations")):
        graph_data["visualization"] = None
        return graph_data

    rendered = visualize_knowledge_graph(graph_data["entities"], graph_data["relations"])
    graph_data["visualization"] = rendered
    return graph_data
|
|