Remsky commited on
Commit
4289090
0 Parent(s):

Super-squash branch 'main' using huggingface_hub

Browse files
Files changed (8) hide show
  1. .gitattributes +35 -0
  2. README.md +13 -0
  3. app.py +94 -0
  4. lib/__init__.py +0 -0
  5. lib/graph_extract.py +142 -0
  6. lib/samples.py +46 -0
  7. lib/visualize.py +111 -0
  8. requirements.txt +7 -0
.gitattributes ADDED
@@ -0,0 +1,35 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ *.7z filter=lfs diff=lfs merge=lfs -text
2
+ *.arrow filter=lfs diff=lfs merge=lfs -text
3
+ *.bin filter=lfs diff=lfs merge=lfs -text
4
+ *.bz2 filter=lfs diff=lfs merge=lfs -text
5
+ *.ckpt filter=lfs diff=lfs merge=lfs -text
6
+ *.ftz filter=lfs diff=lfs merge=lfs -text
7
+ *.gz filter=lfs diff=lfs merge=lfs -text
8
+ *.h5 filter=lfs diff=lfs merge=lfs -text
9
+ *.joblib filter=lfs diff=lfs merge=lfs -text
10
+ *.lfs.* filter=lfs diff=lfs merge=lfs -text
11
+ *.mlmodel filter=lfs diff=lfs merge=lfs -text
12
+ *.model filter=lfs diff=lfs merge=lfs -text
13
+ *.msgpack filter=lfs diff=lfs merge=lfs -text
14
+ *.npy filter=lfs diff=lfs merge=lfs -text
15
+ *.npz filter=lfs diff=lfs merge=lfs -text
16
+ *.onnx filter=lfs diff=lfs merge=lfs -text
17
+ *.ot filter=lfs diff=lfs merge=lfs -text
18
+ *.parquet filter=lfs diff=lfs merge=lfs -text
19
+ *.pb filter=lfs diff=lfs merge=lfs -text
20
+ *.pickle filter=lfs diff=lfs merge=lfs -text
21
+ *.pkl filter=lfs diff=lfs merge=lfs -text
22
+ *.pt filter=lfs diff=lfs merge=lfs -text
23
+ *.pth filter=lfs diff=lfs merge=lfs -text
24
+ *.rar filter=lfs diff=lfs merge=lfs -text
25
+ *.safetensors filter=lfs diff=lfs merge=lfs -text
26
+ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
27
+ *.tar.* filter=lfs diff=lfs merge=lfs -text
28
+ *.tar filter=lfs diff=lfs merge=lfs -text
29
+ *.tflite filter=lfs diff=lfs merge=lfs -text
30
+ *.tgz filter=lfs diff=lfs merge=lfs -text
31
+ *.wasm filter=lfs diff=lfs merge=lfs -text
32
+ *.xz filter=lfs diff=lfs merge=lfs -text
33
+ *.zip filter=lfs diff=lfs merge=lfs -text
34
+ *.zst filter=lfs diff=lfs merge=lfs -text
35
+ *tfevents* filter=lfs diff=lfs merge=lfs -text
README.md ADDED
@@ -0,0 +1,13 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ---
2
+ title: Triplex Knowledge Graph Visualizer
3
+ emoji: 🕸️
4
+ colorFrom: blue
5
+ colorTo: red
6
+ sdk: gradio
7
+ app_file: app.py
8
+ pinned: true
9
+ models:
10
+ - SciPhi/Triplex
11
+ preload_from_hub:
12
+ - SciPhi/Triplex
13
+ ---
app.py ADDED
@@ -0,0 +1,94 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+
2
+ import random
3
+
4
+ import gradio as gr
5
+ import spaces
6
+
7
+ from lib.graph_extract import triplextract, parse_triples
8
+ from lib.visualize import create_cytoscape_plot
9
+ from lib.samples import snippets
10
+
11
+ WORD_LIMIT = 300
12
+
13
+ @spaces.GPU
14
+ def process_text(text, entity_types, predicates):
15
+ if not text:
16
+ return None, "Please enter some text."
17
+
18
+ words = text.split()
19
+ if len(words) > WORD_LIMIT:
20
+ return None, f"Please limit your input to {WORD_LIMIT} words. Current word count: {len(words)}"
21
+
22
+ entity_types = [et.strip() for et in entity_types.split(",") if et.strip()]
23
+ predicates = [p.strip() for p in predicates.split(",") if p.strip()]
24
+
25
+ if not entity_types:
26
+ return None, "Please enter at least one entity type."
27
+ if not predicates:
28
+ return None, "Please enter at least one predicate."
29
+
30
+ try:
31
+ prediction = triplextract(text, entity_types, predicates)
32
+ if prediction.startswith("Error"):
33
+ return None, prediction
34
+
35
+ entities, relationships = parse_triples(prediction)
36
+
37
+ if not entities and not relationships:
38
+ return (
39
+ None,
40
+ "No entities or relationships found. Try different text or check your input.",
41
+ )
42
+
43
+ fig = create_cytoscape_plot(entities, relationships)
44
+ return (
45
+ fig,
46
+ f"Entities: {entities}\nRelationships: {relationships}\n\nRaw output:\n{prediction}",
47
+ )
48
+ except Exception as e:
49
+ print(f"Error in process_text: {e}")
50
+ return None, f"An error occurred: {str(e)}"
51
+
52
+ def update_inputs(sample_name):
53
+ sample = snippets[sample_name]
54
+ return sample.text_input, sample.entity_types, sample.predicates
55
+
56
+ with gr.Blocks(theme=gr.themes.Monochrome()) as demo:
57
+ gr.Markdown("# Knowledge Graph Extractor")
58
+
59
+ default_sample_name = random.choice(list(snippets.keys()))
60
+ default_sample = snippets[default_sample_name]
61
+
62
+ with gr.Row():
63
+ with gr.Column(scale=1):
64
+ sample_dropdown = gr.Dropdown(
65
+ choices=list(snippets.keys()),
66
+ label="Select Sample",
67
+ value=default_sample_name
68
+ )
69
+ input_text = gr.Textbox(
70
+ label="Input Text",
71
+ lines=5,
72
+ value=default_sample.text_input
73
+ )
74
+ entity_types = gr.Textbox(label="Entity Types", value=default_sample.entity_types)
75
+ predicates = gr.Textbox(label="Predicates", value=default_sample.predicates)
76
+ submit_btn = gr.Button("Extract Knowledge Graph")
77
+ with gr.Column(scale=2):
78
+ output_graph = gr.Plot(label="Knowledge Graph")
79
+ error_message = gr.Textbox(label="Textual Output")
80
+
81
+ sample_dropdown.change(
82
+ update_inputs,
83
+ inputs=[sample_dropdown],
84
+ outputs=[input_text, entity_types, predicates]
85
+ )
86
+
87
+ submit_btn.click(
88
+ process_text,
89
+ inputs=[input_text, entity_types, predicates],
90
+ outputs=[output_graph, error_message],
91
+ )
92
+
93
+ if __name__ == "__main__":
94
+ demo.launch()
lib/__init__.py ADDED
File without changes
lib/graph_extract.py ADDED
@@ -0,0 +1,142 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import json
2
+ import re
3
+ from transformers import AutoModelForCausalLM, AutoTokenizer, GenerationConfig
4
+ import torch
5
+ import warnings
6
+ import spaces
7
+
8
+ flash_attn_installed = False
9
+ try:
10
+ import subprocess
11
+ print("Installing flash-attn...")
12
+ subprocess.run(
13
+ "pip install flash-attn --no-build-isolation",
14
+ env={"FLASH_ATTENTION_SKIP_CUDA_BUILD": "TRUE"},
15
+ shell=True,
16
+ )
17
+ flash_attn_installed = True
18
+ except Exception as e:
19
+ print(f"Error installing flash-attn: {e}")
20
+
21
+
22
+ # Suppress specific warnings
23
+ warnings.filterwarnings(
24
+ "ignore",
25
+ message="You have modified the pretrained model configuration to control generation.",
26
+ )
27
+ warnings.filterwarnings(
28
+ "ignore",
29
+ message="You are not running the flash-attention implementation, expect numerical differences.",
30
+ )
31
+
32
+ print("Initializing application...")
33
+
34
+ model = AutoModelForCausalLM.from_pretrained(
35
+ "sciphi/triplex",
36
+ trust_remote_code=True,
37
+ attn_implementation="flash_attention_2" if flash_attn_installed else None,
38
+ torch_dtype=torch.bfloat16,
39
+ device_map="auto",
40
+ low_cpu_mem_usage=True,#advised if any device map given
41
+ ).eval()
42
+
43
+ tokenizer = AutoTokenizer.from_pretrained(
44
+ "sciphi/triplex",
45
+ trust_remote_code=True,
46
+ attn_implementation="flash_attention_2",
47
+ torch_dtype=torch.bfloat16,
48
+ )
49
+
50
+
51
+ print("Model and tokenizer loaded successfully.")
52
+
53
+ # Set up generation config
54
+ generation_config = GenerationConfig.from_pretrained("sciphi/triplex")
55
+ generation_config.max_length = 2048
56
+ generation_config.pad_token_id = tokenizer.eos_token_id
57
+ @spaces.GPU
58
+ def triplextract(text, entity_types, predicates):
59
+ input_format = """Perform Named Entity Recognition (NER) and extract knowledge graph triplets from the text. NER identifies named entities of given entity types, and triple extraction identifies relationships between entities using specified predicates. Return the result as a JSON object with an "entities_and_triples" key containing an array of entities and triples.
60
+ **Entity Types:**
61
+ {entity_types}
62
+ **Predicates:**
63
+ {predicates}
64
+ **Text:**
65
+ {text}
66
+ """
67
+ message = input_format.format(
68
+ entity_types = json.dumps({"entity_types": entity_types}),
69
+ predicates = json.dumps({"predicates": predicates}),
70
+ text = text)
71
+
72
+ # message = input_format.format(
73
+ # entity_types=entity_types, predicates=predicates, text=text
74
+ # )
75
+
76
+ messages = [{"role": "user", "content": message}]
77
+
78
+ print("Tokenizing input...")
79
+ input_ids = tokenizer.apply_chat_template(
80
+ messages, add_generation_prompt=True, return_tensors="pt"
81
+ ).to(model.device)
82
+
83
+ attention_mask = input_ids.ne(tokenizer.pad_token_id)
84
+
85
+ print("Generating output...")
86
+ try:
87
+ with torch.no_grad():
88
+ output = model.generate(
89
+ input_ids=input_ids,
90
+ attention_mask=attention_mask,
91
+ generation_config=generation_config,
92
+ do_sample=True,
93
+ )
94
+
95
+ decoded_output = tokenizer.decode(output[0], skip_special_tokens=True)
96
+ print("Decoding output completed.")
97
+
98
+ return decoded_output
99
+ except torch.cuda.OutOfMemoryError as e:
100
+ print(f"CUDA out of memory error: {e}")
101
+ return "Error: CUDA out of memory."
102
+ except Exception as e:
103
+ print(f"Error in generation: {e}")
104
+ return f"Error in generation: {str(e)}"
105
+
106
+ def parse_triples(prediction):
107
+ entities = {}
108
+ relationships = []
109
+
110
+ try:
111
+ data = json.loads(prediction)
112
+ items = data.get("entities_and_triples", [])
113
+ except json.JSONDecodeError:
114
+ json_match = re.search(r"```json\s*(.*?)\s*```", prediction, re.DOTALL)
115
+ if json_match:
116
+ try:
117
+ data = json.loads(json_match.group(1))
118
+ items = data.get("entities_and_triples", [])
119
+ except json.JSONDecodeError:
120
+ items = re.findall(r"\[(.*?)\]", prediction)
121
+ else:
122
+ items = re.findall(r"\[(.*?)\]", prediction)
123
+
124
+ for item in items:
125
+ if isinstance(item, str):
126
+ if ":" in item:
127
+ id, entity = item.split(",", 1)
128
+ id = id.strip("[]").strip()
129
+ entity_type, entity_value = entity.split(":", 1)
130
+ entities[id] = {
131
+ "type": entity_type.strip(),
132
+ "value": entity_value.strip(),
133
+ }
134
+ else:
135
+ parts = item.split()
136
+ if len(parts) >= 3:
137
+ source = parts[0].strip("[]")
138
+ relation = " ".join(parts[1:-1])
139
+ target = parts[-1].strip("[]")
140
+ relationships.append((source, relation.strip(), target))
141
+
142
+ return entities, relationships
lib/samples.py ADDED
@@ -0,0 +1,46 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from collections import namedtuple
2
+
3
+ Snippet = namedtuple('Snippet', ['text_input', 'entity_types', 'predicates'])
4
+
5
+ snippets = {
6
+ 'paris': Snippet(
7
+ text_input="""Paris is the capital of France. It has a population of 2.16 million people.
8
+ The Eiffel Tower, located in Paris, is a famous landmark with a height of 324 meters.
9
+ Paris is known for its romantic atmosphere.""",
10
+ entity_types="LOCATION, POPULATION, STYLE",
11
+ predicates="HAS, IS"
12
+ ),
13
+
14
+ 'dickens': Snippet(
15
+ text_input="""It was the best of times, it was the worst of times, it was the age of wisdom,
16
+ it was the age of foolishness, it was the epoch of belief, it was the epoch of incredulity,
17
+ it was the season of Light, it was the season of Darkness, it was the spring of hope,
18
+ it was the winter of despair, we had everything before us, we had nothing before us,
19
+ we were all going direct to Heaven, we were all going direct the other way – in short,
20
+ the period was so far like the present period, that some of its noisiest authorities
21
+ insisted on its being received, for good or for evil, in the superlative degree of comparison only.""",
22
+ entity_types="TIME, EMOTION, LOCATION, EVENT, OUTCOME, PLACE",
23
+ predicates="WAS, HAD, WERE"
24
+ ),
25
+
26
+ 'tech_company': Snippet(
27
+ text_input="""Apple Inc. was founded by Steve Jobs, Steve Wozniak, and Ronald Wayne in 1976.
28
+ Headquartered in Cupertino, California, Apple designs and produces consumer electronics,
29
+ software, and online services. The company's flagship products include the iPhone smartphone,
30
+ iPad tablet, and Mac personal computer. As of 2023, Apple has over 150,000 employees worldwide
31
+ and generates annual revenue exceeding $350 billion.""",
32
+ entity_types="COMPANY, PERSON, PRODUCT, LOCATION, DATE, NUMBER, EVENT, SUBJECT",
33
+ predicates="FOUNDED, HEADQUARTERED_IN, PRODUCES, HAS, EMPLOYEES, "
34
+ ),
35
+
36
+ 'climate_change': Snippet(
37
+ text_input="""Global warming is causing significant changes to Earth's climate. The average global
38
+ temperature has increased by approximately 1.1°C since the pre-industrial era. This warming is
39
+ primarily caused by human activities, particularly the emission of greenhouse gases like carbon dioxide.
40
+ The Paris Agreement, signed in 2015, aims to limit global temperature increase to well below 2°C above
41
+ pre-industrial levels. To achieve this goal, many countries are implementing policies to reduce carbon
42
+ emissions and transition to renewable energy sources.""",
43
+ entity_types="PHENOMENON, PLANET, TEMPERATURE, CAUSE, CHEMICAL, AGREEMENT, DATE, GOAL, POLICY",
44
+ predicates="CAUSES, INCREASED_BY, CAUSED_BY, SIGNED_IN, AIMS_TO, IMPLEMENTING"
45
+ )
46
+ }
lib/visualize.py ADDED
@@ -0,0 +1,111 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import plotly.graph_objects as go
2
+ import networkx as nx
3
+
4
+ import plotly.graph_objects as go
5
+ import networkx as nx
6
+
7
+ def create_cytoscape_plot(entities, relationships):
8
+ G = nx.DiGraph() # Use DiGraph for directed edges
9
+
10
+ for entity_id, entity_data in entities.items():
11
+ G.add_node(entity_id, **entity_data)
12
+
13
+ for source, relation, target in relationships:
14
+ G.add_edge(source, target, relation=relation)
15
+
16
+ pos = nx.spring_layout(G, k=0.5, iterations=50) # Adjust layout parameters
17
+
18
+ edge_trace = go.Scatter(
19
+ x=[],
20
+ y=[],
21
+ line=dict(width=1, color="#888"),
22
+ hoverinfo="text",
23
+ mode="lines",
24
+ text=[],
25
+ )
26
+
27
+ node_trace = go.Scatter(
28
+ x=[],
29
+ y=[],
30
+ mode="markers+text",
31
+ hoverinfo="text",
32
+ marker=dict(
33
+ showscale=True,
34
+ colorscale="Viridis",
35
+ reversescale=True,
36
+ color=[],
37
+ size=15,
38
+ colorbar=dict(
39
+ thickness=15,
40
+ title="Node Connections",
41
+ xanchor="left",
42
+ titleside="right",
43
+ ),
44
+ line_width=2,
45
+ ),
46
+ text=[],
47
+ textposition="top center",
48
+ )
49
+
50
+ edge_labels = []
51
+
52
+ for edge in G.edges():
53
+ x0, y0 = pos[edge[0]]
54
+ x1, y1 = pos[edge[1]]
55
+ edge_trace["x"] += (x0, x1, None)
56
+ edge_trace["y"] += (y0, y1, None)
57
+
58
+ # Calculate midpoint for edge label
59
+ mid_x, mid_y = (x0 + x1) / 2, (y0 + y1) / 2
60
+ edge_labels.append(
61
+ go.Scatter(
62
+ x=[mid_x],
63
+ y=[mid_y],
64
+ mode="text",
65
+ text=[G.edges[edge]["relation"]],
66
+ textposition="middle center",
67
+ hoverinfo="none",
68
+ showlegend=False,
69
+ textfont=dict(size=8),
70
+ )
71
+ )
72
+
73
+ for node in G.nodes():
74
+ x, y = pos[node]
75
+ node_trace["x"] += (x,)
76
+ node_trace["y"] += (y,)
77
+ node_info = f"{entities[node]['value']} ({entities[node]['type']})"
78
+ node_trace["text"] += (node_info,)
79
+ node_trace["marker"]["color"] += (len(list(G.neighbors(node))),)
80
+
81
+ fig = go.Figure(
82
+ data=[edge_trace, node_trace] + edge_labels,
83
+ layout=go.Layout(
84
+ title="Knowledge Graph",
85
+ titlefont_size=16,
86
+ showlegend=False,
87
+ hovermode="closest",
88
+ margin=dict(b=20, l=5, r=5, t=40),
89
+ annotations=[],
90
+ xaxis=dict(showgrid=False, zeroline=False, showticklabels=False),
91
+ yaxis=dict(showgrid=False, zeroline=False, showticklabels=False),
92
+ width=800,
93
+ height=600,
94
+ ),
95
+ )
96
+
97
+ # Enable dragging of nodes
98
+ fig.update_layout(
99
+ newshape=dict(line_color="#009900"),
100
+ # Enable zoom
101
+ xaxis=dict(
102
+ scaleanchor="y",
103
+ scaleratio=1,
104
+ ),
105
+ yaxis=dict(
106
+ scaleanchor="x",
107
+ scaleratio=1,
108
+ ),
109
+ )
110
+
111
+ return fig
requirements.txt ADDED
@@ -0,0 +1,7 @@
 
 
 
 
 
 
 
 
1
+ gradio==4.39.0
2
+ plotly==5.23.0
3
+ matplotlib==3.7.2
4
+ torch==2.0.1
5
+ transformers==4.43.3
6
+ accelerate==0.33.0
7
+ networkx