missvector commited on
Commit
71f8a1b
·
verified ·
1 Parent(s): 68c5c35

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +22 -7
app.py CHANGED
@@ -1,7 +1,8 @@
1
  import gradio as gr
2
  from huggingface_hub import InferenceClient
3
  import ast
4
- from graphviz import Digraph
 
5
 
6
  client = InferenceClient("Qwen/Qwen2.5-72B-Instruct")
7
 
@@ -81,15 +82,29 @@ def sampling(num_samples, num_associations):
81
  top_p=0.1
82
  ).choices[0].get('message')['content']))
83
 
84
- dot = Digraph(comment=f'SynthNet, {num_samples} samples, {num_associations} associations', graph_attr={'rankdir': 'LR'})
85
 
86
  for entry in triplets:
87
  source, label, target = entry['properties']
88
- dot.node(source, source)
89
- dot.node(target, target)
90
- dot.edge(source, target, label=label)
91
 
92
- dot.render('synthnet', format='png')
 
 
 
 
 
 
 
 
 
 
 
 
 
 
93
  return 'synthnet.png'
94
 
95
  demo = gr.Interface(
@@ -104,4 +119,4 @@ demo = gr.Interface(
104
  )
105
 
106
  if __name__ == "__main__":
107
- demo.launch(share=True)
 
1
  import gradio as gr
2
  from huggingface_hub import InferenceClient
3
  import ast
4
+ import networkx as nx
5
+ import matplotlib.pyplot as plt
6
 
7
  client = InferenceClient("Qwen/Qwen2.5-72B-Instruct")
8
 
 
82
  top_p=0.1
83
  ).choices[0].get('message')['content']))
84
 
85
+ G = nx.DiGraph()
86
 
87
  for entry in triplets:
88
  source, label, target = entry['properties']
89
+ G.add_node(source, label=source)
90
+ G.add_node(target, label=target)
91
+ G.add_edge(source, target, label=label)
92
 
93
+ pos = nx.spring_layout(G)
94
+ nx.draw_networkx_nodes(G, pos, node_size=500, node_color='lightblue')
95
+
96
+ edge_labels = nx.get_edge_attributes(G, 'label') # Get edge labels
97
+ nx.draw_networkx_edges(G, pos, arrowstyle='->', arrowsize=25)
98
+ nx.draw_networkx_edge_labels(G, pos, edge_labels=edge_labels)
99
+
100
+ node_labels = nx.get_node_attributes(G, 'label') # Get node labels
101
+ nx.draw_networkx_labels(G, pos, labels=node_labels, font_size=8, font_family="sans-serif")
102
+
103
+ plt.axis('off')
104
+ plt.tight_layout()
105
+
106
+ plt.savefig('synthnet.png')
107
+ plt.close()
108
  return 'synthnet.png'
109
 
110
  demo = gr.Interface(
 
119
  )
120
 
121
  if __name__ == "__main__":
122
+ demo.launch(share=True)