synthnet / app.py
missvector's picture
Update space
1f58b80
raw
history blame
5.49 kB
import ast
import json

import gradio as gr
from graphviz import Digraph
from huggingface_hub import InferenceClient
# Shared inference client used for every chat completion request in this module.
client = InferenceClient(model="Qwen/Qwen2.5-72B-Instruct")
def _parse_json_object(content):
    """Decode the text of a model reply into a Python object.

    The model is asked for JSON, so ``json.loads`` is the primary parser
    (it understands ``true``/``false``/``null`` and JSON string escapes).
    ``ast.literal_eval`` is kept as a fallback so replies written as Python
    literals (e.g. single-quoted keys) are still accepted.
    """
    try:
        return json.loads(content)
    except json.JSONDecodeError:
        return ast.literal_eval(content)


def _request_json(system_prompt, user_prompt, array_key, max_tokens):
    """Run one chat completion and return the decoded JSON object.

    Args:
        system_prompt: Text of the system message.
        user_prompt: Text of the user message.
        array_key: Name of the string-array property requested through
            ``response_format``.
        max_tokens: Upper bound on generated tokens for this request.

    Returns:
        The decoded reply; expected to be a dict holding ``array_key``.
    """
    response = client.chat.completions.create(
        messages=[
            {"role": "system", "content": system_prompt},
            {"role": "user", "content": user_prompt},
        ],
        response_format={
            "type": "json",
            "value": {
                "properties": {
                    array_key: {"type": "array", "items": {"type": "string"}},
                }
            },
        },
        stream=False,
        max_tokens=max_tokens,
        temperature=0.7,
        top_p=0.1,
    )
    return _parse_json_object(response.choices[0].get('message')['content'])


def sampling(num_samples, num_associations):
    """Generate a semantic graph of words and associations and render it.

    The function asks the model for ``num_samples`` words, then for
    ``num_associations`` associations per word, then for one
    ``[subject, predicate, object]`` triplet per (word, association) pair.
    That is one request for the words, one per word, and one per pair, so
    large slider values result in a very large number of requests.

    Args:
        num_samples: Number of seed words to request.
        num_associations: Number of associations to request per seed word.

    Returns:
        The string ``'synthnet.png'``, the image file produced by Graphviz
        in the current working directory.

    Raises:
        KeyError: If a reply lacks the expected ``words``, ``associations``
            or ``properties`` key.
        ValueError: If a reply cannot be parsed, or a triplet does not have
            exactly three elements.
    """
    # UI sliders may hand over floats; normalise so prompts read "5", not "5.0".
    num_samples = int(num_samples)
    num_associations = int(num_associations)

    words = _request_json(
        "generate one json object, no explanation or additional text, use the following structure:\n"
        "words: []\n"
        f"{num_samples} samples in a list",
        f"synthesize {num_samples} random but widespread words for semantic modeling",
        "words",
        1024,
    )['words']

    # Keyed by word, so a word the model repeats keeps only its last reply.
    fields = {
        word: _request_json(
            'generate one json object, no explanation or additional text, use the following structure:\n'
            'associations: []',
            f"synthesize {num_associations} associations for the word {word}",
            "associations",
            2000,
        )
        for word in words
    }

    triplets = [
        _request_json(
            "generate one json object, no explanation or additional text, use the following structure:\n"
            "properties: [subject, predicate, object]\n"
            "use chain-of-thought for predictions",
            f"form triplet based on semantics: generate predicate between the word {cluster} (subject) and the word {association} (object); return list with [subject, predicate, object]",
            "properties",
            128,
        )
        for cluster, reply in fields.items()
        for association in reply['associations']
    ]

    dot = Digraph(
        comment=f'SynthNet, {num_samples} samples, {num_associations} associations',
        graph_attr={'rankdir': 'LR'},
    )
    for entry in triplets:
        source, label, target = entry['properties']
        dot.node(source, source)
        dot.node(target, target)
        dot.edge(source, target, label=label)

    # NOTE(review): the output name is fixed, so simultaneous requests would
    # write to the same file; consider a per-request path if that matters.
    dot.render('synthnet', format='png')
    return 'synthnet.png'
# Input controls: how many seed words to generate and how many
# associations to request for each of them.
samples_slider = gr.Slider(minimum=1, maximum=256, label="Number of Samples")
associations_slider = gr.Slider(
    minimum=1,
    maximum=256,
    label="Number of Associations to each Sample",
)

# The callback returns a path to the rendered image, hence type="filepath".
graph_output = gr.Image(type="filepath")

demo = gr.Interface(
    fn=sampling,
    inputs=[samples_slider, associations_slider],
    outputs=graph_output,
    title="SynthNet",
    description="Select a number of samples and associations to each sample to generate a graph.",
)

# Start the web UI only when this file is executed as a script.
if __name__ == "__main__":
    demo.launch()