import gradio as gr
from huggingface_hub import InferenceClient
import json
import networkx as nx
import matplotlib.pyplot as plt

client = InferenceClient("Qwen/Qwen2.5-72B-Instruct")
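# Pipeline: sample seed words, generate associations for each word, turn every
# (word, association) pair into a (subject, predicate, object) triplet, and
# render the triplets as a directed graph saved to synthnet.png.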
def sampling(num_samples, num_associations):
    # Stage 1: ask the model for a list of seed words.
    # json.loads is used instead of ast.literal_eval so any valid JSON the model
    # returns (including null/true/false) is parsed safely.
    outputs = json.loads(client.chat.completions.create(
        messages=[
            {"role": "system", "content": "generate one json object, no explanation or additional text, use the following structure:\n"
                                          "words: []\n"
                                          f"{num_samples} samples in a list"
             },
            {"role": "user",
             "content": f"synthesize {num_samples} random but widespread words for semantic modeling"},
        ],
        response_format={
            "type": "json",
            "value": {
                "properties": {
                    "words": {"type": "array", "items": {"type": "string"}},
                }
            }
        },
        stream=False,
        max_tokens=1024,
        temperature=0.7,
        top_p=0.1
    ).choices[0].message.content)
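    # Stage 2: for every sampled word, ask the model for num_associations associations.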
    fields = {}
    for word in outputs['words']:
        fields[word] = json.loads(client.chat.completions.create(
            messages=[
                {"role": "system", "content": 'generate one json object, no explanation or additional text, use the following structure:\n'
                                              'associations: []'
                 },
                {"role": "user",
                 "content": f"synthesize {num_associations} associations for the word {word}"},
            ],
            response_format={
                "type": "json",
                "value": {
                    "properties": {
                        "associations": {"type": "array", "items": {"type": "string"}}
                    }
                }
            },
            stream=False,
            max_tokens=2000,
            temperature=0.7,
            top_p=0.1
        ).choices[0].message.content)
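    # Stage 3: for each (word, association) pair, ask the model for a connecting
    # predicate, yielding [subject, predicate, object] triplets.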
    triplets = []
    for cluster in fields:
        for association in fields[cluster]['associations']:
            triplets.append(json.loads(client.chat.completions.create(
                messages=[
                    {"role": "system", "content": "generate one json object, no explanation or additional text, use the following structure:\n"
                                                  "properties: [subject, predicate, object]\n"
                                                  "use chain-of-thought for predictions"
                     },
                    {"role": "user",
                     "content": f"form triplet based on semantics: generate predicate between the word {cluster} (subject) and the word {association} (object); return list with [subject, predicate, object]"},
                ],
                response_format={
                    "type": "json",
                    "value": {
                        "properties": {
                            "properties": {"type": "array", "items": {"type": "string"}}
                        }
                    }
                },
                stream=False,
                max_tokens=128,
                temperature=0.7,
                top_p=0.1
            ).choices[0].message.content))
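    # Build a directed graph from the triplets and render it with a spring layout.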
    G = nx.DiGraph()
    for entry in triplets:
        source, label, target = entry['properties']
        G.add_node(source, label=source)
        G.add_node(target, label=target)
        G.add_edge(source, target, label=label)
    plt.figure(figsize=(12, 12))  # fresh figure so repeated runs don't draw on top of each other
    pos = nx.spring_layout(G)
    nx.draw_networkx_nodes(G, pos, node_size=500, node_color='lightblue')
    edge_labels = nx.get_edge_attributes(G, 'label')  # edge labels carry the predicates
    nx.draw_networkx_edges(G, pos, arrowstyle='->', arrowsize=25)
    nx.draw_networkx_edge_labels(G, pos, edge_labels=edge_labels)
    node_labels = nx.get_node_attributes(G, 'label')  # node labels carry the words
    nx.draw_networkx_labels(G, pos, labels=node_labels, font_size=8, font_family="sans-serif")
    plt.axis('off')
    plt.tight_layout()
    plt.savefig('synthnet.png')
    plt.close()
    return 'synthnet.png'
demo = gr.Interface(
    inputs=[
        gr.Slider(minimum=1, maximum=256, step=1, label="Number of Samples"),
        gr.Slider(minimum=1, maximum=256, step=1, label="Number of Associations to each Sample"),
    ],
    fn=sampling,
    outputs=gr.Image(type="filepath"),
    title="SynthNet",
    description="Select the number of samples and the number of associations per sample to generate a graph.",
)

if __name__ == "__main__":
    demo.launch(share=True)