synthnet / app.py
missvector's picture
Update app.py
71f8a1b verified
import ast
import json

import gradio as gr
import matplotlib.pyplot as plt
import networkx as nx
from huggingface_hub import InferenceClient
# Shared Hugging Face Inference client; sampling() sends all of its chat
# requests to the Qwen2.5-72B-Instruct model through this one instance.
client = InferenceClient(model="Qwen/Qwen2.5-72B-Instruct")
def _chat_json(system_prompt, user_prompt, schema_properties, max_tokens):
    """Run one schema-constrained chat completion and return the decoded JSON reply.

    Args:
        system_prompt: System message describing the expected JSON layout.
        user_prompt: User message carrying the actual request.
        schema_properties: The ``properties`` mapping of the JSON schema sent
            as the ``response_format`` grammar.
        max_tokens: Generation cap for this request.

    Returns:
        The reply content decoded with ``json.loads``.

    Raises:
        json.JSONDecodeError: if the reply is not valid JSON, e.g. because it
            was cut off by ``max_tokens``.
    """
    completion = client.chat.completions.create(
        messages=[
            {"role": "system", "content": system_prompt},
            {"role": "user", "content": user_prompt},
        ],
        response_format={
            "type": "json",
            "value": {"properties": schema_properties},
        },
        stream=False,
        max_tokens=max_tokens,
        temperature=0.7,
        top_p=0.1,
    )
    # The grammar makes the model emit JSON, so decode it as JSON:
    # ast.literal_eval rejects true/false/null and mis-reads escapes like "\/".
    return json.loads(completion.choices[0].get('message')['content'])


def _draw_graph(triplets, path):
    """Render [subject, predicate, object] triplets as a labelled digraph PNG at *path*.

    Returns *path*. The figure is always closed, even if drawing fails, so a
    failed request cannot leave stale artists behind for the next one.
    """
    graph = nx.DiGraph()
    for subject, predicate, obj in triplets:
        graph.add_node(subject, label=subject)
        graph.add_node(obj, label=obj)
        graph.add_edge(subject, obj, label=predicate)

    fig, ax = plt.subplots()
    try:
        pos = nx.spring_layout(graph)
        nx.draw_networkx_nodes(graph, pos, node_size=500, node_color='lightblue', ax=ax)
        nx.draw_networkx_edges(graph, pos, arrowstyle='->', arrowsize=25, ax=ax)
        nx.draw_networkx_edge_labels(
            graph, pos, edge_labels=nx.get_edge_attributes(graph, 'label'), ax=ax
        )
        nx.draw_networkx_labels(
            graph,
            pos,
            labels=nx.get_node_attributes(graph, 'label'),
            font_size=8,
            font_family="sans-serif",
            ax=ax,
        )
        ax.axis('off')
        fig.tight_layout()
        fig.savefig(path)
    finally:
        plt.close(fig)
    return path


def sampling(num_samples, num_associations):
    """Build a word-association graph with the LLM and render it to a PNG.

    Pipeline: ask the model for ``num_samples`` words; for each word ask for
    ``num_associations`` associations; for each (word, association) pair ask
    for a [subject, predicate, object] triplet; draw all triplets as a directed
    graph whose edges are labelled with the predicate.

    Args:
        num_samples: Number of seed words requested (slider value, cast to int).
        num_associations: Number of associations requested per word
            (slider value, cast to int).

    Returns:
        The path of the written image, ``'synthnet.png'`` in the working directory.

    Raises:
        json.JSONDecodeError, KeyError: if the word or association replies are
            malformed. Malformed or truncated triplet replies are skipped instead.
    """
    # Slider values may arrive as floats; keep prompts reading "5", not "5.0".
    num_samples = int(num_samples)
    num_associations = int(num_associations)

    outputs = _chat_json(
        "generate one json object, no explanation or additional text, use the following structure:\n"
        "words: []\n"
        f"{num_samples} samples in a list",
        f"synthesize {num_samples} random but widespread words for semantic modeling",
        {"words": {"type": "array", "items": {"type": "string"}}},
        max_tokens=1024,
    )

    fields = {}
    for word in outputs['words']:
        fields[word] = _chat_json(
            'generate one json object, no explanation or additional text, use the following structure:\n'
            'associations: []',
            f"synthesize {num_associations} associations for the word {word}",
            {"associations": {"type": "array", "items": {"type": "string"}}},
            max_tokens=2000,
        )

    triplets = []
    for cluster, field in fields.items():
        for association in field['associations']:
            try:
                reply = _chat_json(
                    "generate one json object, no explanation or additional text, use the following structure:\n"
                    "properties: [subject, predicate, object]\n"
                    "use chain-of-thought for predictions",
                    f"form triplet based on semantics: generate predicate between the word {cluster} (subject) and the word {association} (object); return list with [subject, predicate, object]",
                    {"properties": {"type": "array", "items": {"type": "string"}}},
                    max_tokens=128,
                )
            except json.JSONDecodeError:
                # The 128-token cap can cut the reply off mid-object; drop this
                # edge instead of discarding every call made so far.
                continue
            triple = reply.get('properties') if isinstance(reply, dict) else None
            # The schema does not fix the array length, so the model may return
            # something other than [subject, predicate, object]; skip it.
            if isinstance(triple, list) and len(triple) == 3:
                triplets.append(triple)

    # NOTE(review): fixed output filename — concurrent requests would share it.
    return _draw_graph(triplets, 'synthnet.png')
# Gradio UI: the two sliders are passed positionally to
# sampling(num_samples, num_associations), and the PNG path it returns is
# displayed by an Image component in "filepath" mode.
# NOTE(review): sampling() issues one request per word and one per
# association (if the model honours the requested counts), so the 256 x 256
# slider maxima allow tens of thousands of sequential API calls — consider
# lower limits.
demo = gr.Interface(
    inputs=[
        gr.Slider(minimum=1, maximum=256, label="Number of Samples"),
        gr.Slider(minimum=1, maximum=256, label="Number of Associations to each Sample"),
    ],
    fn=sampling,
    outputs=gr.Image(type="filepath"),
    title="SynthNet",
    description="Select a number of samples and associations to each sample to generate a graph.",
)
if __name__ == "__main__":
    # share=True requests a public Gradio share link in addition to the local
    # server (per Gradio's launch() documentation).
    demo.launch(share=True)