from typing import List, Dict, Any, Tuple, Optional

import spacy
import networkx as nx
import matplotlib.pyplot as plt
from io import BytesIO
import base64
import re
import json

from langchain_core.messages import HumanMessage
from langchain.chat_models import init_chat_model
from dotenv import load_dotenv
import os

# Interactive visualization
from pyvis.network import Network

# Load environment variables
_ = load_dotenv()


class LLMKnowledgeGraph:
    def __init__(self, model: str = "gemini-2.0-flash", model_provider: str = "google_genai"):
        """Initialize the LLM for knowledge graph generation."""
        self.llm = init_chat_model(
            model=model,
            model_provider=model_provider,
            temperature=0.1,  # Lower temperature for more deterministic results
            max_tokens=2000,
        )

        self.entity_prompt = """
Extract all named entities from the following text and categorize them into the following types:
- PERSON: People, including fictional
- ORG: Companies, agencies, institutions, etc.
- GPE: Countries, cities, states
- DATE: Absolute or relative dates or periods
- MONEY: Monetary values
- PERCENT: Percentage values
- QUANTITY: Measurements, weights, distances
- EVENT: Named hurricanes, battles, wars, sports events, etc.
- WORK_OF_ART: Titles of books, songs, etc.
- LAW: Legal document titles
- LANGUAGE: Any named language

Return the entities in JSON format with the following structure:
[
  {"text": "entity text", "label": "ENTITY_TYPE", "start": character_start, "end": character_end}
]

Text: """

        self.relation_prompt = """
Analyze the following text and extract relationships between entities in the form of subject-relation-object triples.
For each relation, provide:
- The subject (entity that is the source of the relation)
- The relation type (e.g., 'works at', 'located in', 'part of')
- The object (entity that is the target of the relation)

Return the relations in JSON format with the following structure:
[
  {"subject": "subject text", "relation": "relation type", "object": "object text"}
]

Text: """

    def extract_entities_with_llm(self, text: str) -> List[Dict[str, Any]]:
        """Extract entities from text using the LLM."""
        response = None  # keep a reference so the error handler below can report it
        try:
            response = self.llm.invoke([HumanMessage(content=self.entity_prompt + text)])

            # Handle case where response might be a string or a message object
            if hasattr(response, 'content'):
                content = response.content
            else:
                content = str(response)

            # Clean the response to ensure it's valid JSON
            content = content.strip()
            if content.startswith('```json'):
                content = content[content.find('['):content.rfind(']') + 1]
            elif content.startswith('['):
                content = content[:content.rfind(']') + 1]

            entities = json.loads(content)
            return entities
        except Exception as e:
            print(f"Error extracting entities with LLM: {str(e)}")
            print(f"Response content: {getattr(response, 'content', str(response))}")
            return []

    def extract_relations_with_llm(self, text: str) -> List[Dict[str, str]]:
        """Extract relations between entities using the LLM."""
        response = None  # keep a reference so the error handler below can report it
        try:
            response = self.llm.invoke([HumanMessage(content=self.relation_prompt + text)])

            # Handle case where response might be a string or a message object
            if hasattr(response, 'content'):
                content = response.content
            else:
                content = str(response)

            # Clean the response to ensure it's valid JSON
            content = content.strip()
            if content.startswith('```json'):
                content = content[content.find('['):content.rfind(']') + 1]
            elif content.startswith('['):
                content = content[:content.rfind(']') + 1]

            relations = json.loads(content)
            return relations
        except Exception as e:
            print(f"Error extracting relations with LLM: {str(e)}")
            print(f"Response content: {getattr(response, 'content', str(response))}")
            return []


def extract_relations(text: str, model_name: str = "gemini-2.0-flash", use_llm: bool = True) -> Dict[str, Any]:
    """
    Extract entities and their relations from text to build a knowledge graph.

    Args:
        text: Input text to process
        model_name: Name of the model to use (spaCy model or LLM)
        use_llm: Whether to use the LLM for extraction (default: True)

    Returns:
        Dictionary containing the entities and relations for the knowledge graph
    """
    if use_llm:
        # Use the LLM for both entity and relation extraction
        kg_extractor = LLMKnowledgeGraph(model=model_name)
        entities = kg_extractor.extract_entities_with_llm(text)
        relations = kg_extractor.extract_relations_with_llm(text)
    else:
        # Fallback to spaCy for entity and relation extraction
        try:
            nlp = spacy.load(model_name)
        except OSError:
            # If the model is not installed, download it
            import subprocess
            import sys
            subprocess.check_call([sys.executable, "-m", "spacy", "download", model_name])
            nlp = spacy.load(model_name)

        # Process the text
        doc = nlp(text)

        # Extract entities
        entities = [{"text": ent.text, "label": ent.label_, "start": ent.start_char, "end": ent.end_char}
                    for ent in doc.ents]

        # Extract relations (subject-verb-object)
        relations = []
        for sent in doc.sents:
            for token in sent:
                if token.dep_ in ("ROOT", "nsubj", "dobj"):
                    subj = ""
                    obj = ""
                    relation = ""

                    # Find subject
                    if token.dep_ == "nsubj" and token.head.pos_ == "VERB":
                        subj = token.text
                        relation = token.head.lemma_

                        # Find object
                        for child in token.head.children:
                            if child.dep_ == "dobj":
                                obj = child.text
                                break

                    if subj and obj and relation:
                        relations.append({
                            "subject": subj,
                            "relation": relation,
                            "object": obj
                        })

    return {
        "entities": entities,
        "relations": relations
    }


def build_nx_graph(entities: List[Dict], relations: List[Dict]) -> nx.DiGraph:
    """Build a NetworkX DiGraph from entities and relations. Ensure all nodes have a 'label'."""
    G = nx.DiGraph()

    # Add entities as nodes
    for entity in entities:
        label = entity.get("label") or entity.get("type") or "ENTITY"
        text = entity.get("text") or entity.get("word")
        if not text:
            continue  # skip malformed entity dicts with no usable text
        G.add_node(text, label=label, type="entity")

    # Add edges and ensure both endpoints exist with a label
    for rel in relations:
        subj = rel.get("subject")
        obj = rel.get("object")
        rel_label = rel.get("relation", "related_to")
        if subj is None or obj is None:
            continue  # skip incomplete triples
        if subj not in G:
            G.add_node(subj, label="ENTITY", type="entity")
        if obj not in G:
            G.add_node(obj, label="ENTITY", type="entity")
        G.add_edge(subj, obj, label=rel_label)

    return G


def visualize_knowledge_graph(entities: List[Dict], relations: List[Dict]) -> str:
    """
    Generate a static PNG visualization of the knowledge graph,
    returned as a base64 string for HTML embedding.
    """
""" G = build_nx_graph(entities, relations) plt.figure(figsize=(12, 8)) pos = nx.spring_layout(G, k=0.5, iterations=50) # Color nodes by entity type entity_types = list(set([d.get('label', 'ENTITY') for n, d in G.nodes(data=True)])) color_map = {etype: plt.cm.tab20(i % 20) for i, etype in enumerate(entity_types)} node_colors = [color_map[d.get('label', 'ENTITY')] for n, d in G.nodes(data=True)] nx.draw_networkx_nodes(G, pos, node_size=2000, node_color=node_colors, alpha=0.8) nx.draw_networkx_edges(G, pos, edge_color='gray', arrows=True, arrowsize=20) nx.draw_networkx_labels(G, pos, font_size=10, font_weight='bold') edge_labels = {(u, v): d['label'] for u, v, d in G.edges(data=True)} nx.draw_networkx_edge_labels(G, pos, edge_labels=edge_labels, font_size=8) buf = BytesIO() plt.savefig(buf, format='png', bbox_inches='tight') plt.close() img_str = base64.b64encode(buf.getvalue()).decode('utf-8') return f"data:image/png;base64,{img_str}" def visualize_knowledge_graph_interactive(entities: List[Dict], relations: List[Dict]) -> str: """ Generate an interactive HTML visualization of the knowledge graph using pyvis. Returns HTML as a string for embedding in Gradio or web UI. """ G = build_nx_graph(entities, relations) net = Network(height="600px", width="100%", directed=True, notebook=False) # Color map for entity types entity_types = list(set([d.get('label', 'ENTITY') for n, d in G.nodes(data=True)])) color_palette = ["#e3f2fd", "#e8f5e9", "#fff8e1", "#f3e5f5", "#e8eaf6", "#e0f7fa", "#f1f8e9", "#fce4ec", "#e8f5e9", "#f5f5f5", "#fafafa", "#e1f5fe", "#fff3e0", "#d7ccc8", "#f9fbe7", "#fbe9e7", "#ede7f6", "#e0f2f1"] color_map = {etype: color_palette[i % len(color_palette)] for i, etype in enumerate(entity_types)} for n, d in G.nodes(data=True): label = d.get('label', 'ENTITY') net.add_node(n, label=n, title=f"{n}
Type: {label}", color=color_map[label]) for u, v, d in G.edges(data=True): net.add_edge(u, v, label=d['label'], title=d['label']) net.set_options('''var options = { "edges": { "arrows": {"to": {"enabled": true}}, "color": {"color": "#888"} }, "nodes": { "font": {"size": 18} }, "physics": { "enabled": true } };''') html_buf = BytesIO() net.write_html(html_buf) html_buf.seek(0) html = html_buf.read().decode('utf-8') # Remove , wrappers to allow embedding in Gradio body_start = html.find('') + len('') body_end = html.find('') body_content = html[body_start:body_end] return body_content def build_knowledge_graph(text: str, model_name: str = "gemini-2.0-flash", use_llm: bool = True) -> Dict[str, Any]: """ Main function to build a knowledge graph from text. Args: text: Input text to process model_name: Name of the model to use (spaCy model or LLM) use_llm: Whether to use LLM for relation extraction (default: True) Returns: Dictionary containing the knowledge graph data and visualization """ # Extract entities and relations result = extract_relations(text, model_name, use_llm) # Generate visualization if result.get("entities") and result.get("relations"): visualization = visualize_knowledge_graph(result["entities"], result["relations"]) result["visualization"] = visualization else: result["visualization"] = None return result