"""SynthNet: a Gradio app that builds a small semantic graph with an LLM.

The model is asked for a list of seed words, then for associations of each
word, then for a predicate linking each (word, association) pair. The
resulting subject-predicate-object triplets are rendered with Graphviz.
"""

import ast
import json
import tempfile

import gradio as gr
from huggingface_hub import InferenceClient
from graphviz import Digraph

client = InferenceClient("Qwen/Qwen2.5-72B-Instruct")


def _parse_json_object(text):
    """Parse a model reply that is expected to hold one JSON object.

    JSON is tried first because every request asks for JSON (which may contain
    ``true``/``false``/``null`` that ``ast.literal_eval`` rejects). Python
    literal syntax is kept as a fallback so single-quoted replies still parse.
    Raises ``ValueError``/``SyntaxError`` if neither parser accepts the text.
    """
    text = text.strip()
    # Tolerate a Markdown code fence wrapped around the payload.
    if text.startswith("```"):
        text = text.strip("`")
        if text.lower().startswith("json"):
            text = text[4:]
        text = text.strip()
    try:
        return json.loads(text)
    except json.JSONDecodeError:
        return ast.literal_eval(text)


def _ask_json(system, user, schema, max_tokens):
    """Send one schema-constrained chat request and return the parsed object.

    Args:
        system: system prompt text.
        user: user prompt text.
        schema: JSON schema passed as the ``value`` of ``response_format``.
        max_tokens: generation budget for this request.

    Returns:
        The reply parsed into a ``dict``.

    Raises:
        ValueError: if the reply parses to something other than an object.
    """
    response = client.chat.completions.create(
        messages=[
            {"role": "system", "content": system},
            {"role": "user", "content": user},
        ],
        response_format={"type": "json", "value": schema},
        stream=False,
        max_tokens=max_tokens,
        temperature=0.7,
        top_p=0.1,
    )
    result = _parse_json_object(response.choices[0].message.content)
    if not isinstance(result, dict):
        raise ValueError(
            f"expected a JSON object from the model, got {type(result).__name__}"
        )
    return result


def sampling(num_samples, num_associations):
    """Generate a semantic graph and return the path of the rendered PNG.

    Args:
        num_samples: number of seed words to request from the model.
        num_associations: number of associations to request per seed word.

    Returns:
        Filesystem path of the rendered PNG (for ``gr.Image(type="filepath")``).

    Note:
        This issues 1 + W + W * A model calls (W words returned, A associations
        each), so large slider values are slow.
    """
    # Gradio sliders can deliver floats; prompts should say "5", not "5.0".
    num_samples = int(num_samples)
    num_associations = int(num_associations)

    words = _ask_json(
        "generate one json object, no explanation or additional text, use the following structure:\n"
        "words: []\n"
        f"{num_samples} samples in a list",
        f"synthesize {num_samples} random but widespread words for semantic modeling",
        {"properties": {"words": {"type": "array", "items": {"type": "string"}}}},
        max_tokens=1024,
    ).get("words")
    if not words:
        raise gr.Error("The model did not return any words.")

    fields = {}
    for word in words:
        fields[word] = _ask_json(
            'generate one json object, no explanation or additional text, use the following structure:\n'
            'associations: []',
            f"synthesize {num_associations} associations for the word {word}",
            {"properties": {"associations": {"type": "array", "items": {"type": "string"}}}},
            max_tokens=2000,
        ).get("associations", [])

    triplets = []
    for cluster, associations in fields.items():
        for association in associations:
            entry = _ask_json(
                "generate one json object, no explanation or additional text, use the following structure:\n"
                "properties: [subject, predicate, object]\n"
                "use chain-of-thought for predictions",
                f"form triplet based on semantics: generate predicate between the word {cluster} (subject) and the word {association} (object); return list with [subject, predicate, object]",
                {"properties": {"properties": {"type": "array", "items": {"type": "string"}}}},
                max_tokens=128,
            )
            triple = entry.get("properties")
            # A reply that is not exactly [subject, predicate, object] is
            # skipped instead of aborting the whole (possibly very long) run.
            if isinstance(triple, list) and len(triple) == 3:
                triplets.append([str(part) for part in triple])

    dot = Digraph(
        comment=f'SynthNet, {num_samples} samples, {num_associations} associations',
        graph_attr={'rankdir': 'LR'},
    )

    # Raw terms are not used as Graphviz node names: a ':' in an edge endpoint
    # is parsed as a port. Each unique term gets a synthetic id and is shown
    # via its label instead.
    node_ids = {}

    def node_id(term):
        if term not in node_ids:
            node_ids[term] = f"n{len(node_ids)}"
            dot.node(node_ids[term], term)
        return node_ids[term]

    for source, label, target in triplets:
        dot.edge(node_id(source), node_id(target), label=label)

    # Render into a fresh directory per call so concurrent requests do not
    # overwrite each other's image in the working directory.
    output_dir = tempfile.mkdtemp(prefix="synthnet_")
    return dot.render("synthnet", directory=output_dir, format="png", cleanup=True)


demo = gr.Interface(
    inputs=[
        gr.Slider(minimum=1, maximum=256, label="Number of Samples"),
        gr.Slider(minimum=1, maximum=256, label="Number of Associations to each Sample"),
    ],
    fn=sampling,
    outputs=gr.Image(type="filepath"),
    title="SynthNet",
    description="Select a number of samples and associations to each sample to generate a graph.",
)

if __name__ == "__main__":
    demo.launch()