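"""SynthNet: a Gradio demo that synthesizes a small semantic graph with an LLM.

The app prompts Qwen/Qwen2.5-72B-Instruct (via huggingface_hub's InferenceClient)
for a set of seed words, then for associations to each word, then for a predicate
linking each (word, association) pair, and finally renders the resulting
[subject, predicate, object] triplets as a Graphviz PNG.
"""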
import json

import gradio as gr
from graphviz import Digraph
from huggingface_hub import InferenceClient

# Chat model used to generate words, associations, and linking predicates.
client = InferenceClient("Qwen/Qwen2.5-72B-Instruct")

def sampling(num_samples, num_associations):
  # Step 1: ask the model for a set of seed words, constrained by the
  # response_format grammar to a JSON object of the form {"words": [...]}.
  response = client.chat.completions.create(
      messages=[
          {"role": "system",
           "content": "generate one json object, no explanation or additional text, "
                      "use the following structure:\n"
                      "words: []\n"
                      f"{num_samples} samples in a list"},
          {"role": "user",
           "content": f"synthesize {num_samples} random but widespread words for semantic modeling"},
      ],
      response_format={
          "type": "json",
          "value": {
              "properties": {
                  "words": {"type": "array", "items": {"type": "string"}},
              }
          }
      },
      stream=False,
      max_tokens=1024,
      temperature=0.7,
      top_p=0.1,
  )
  outputs = json.loads(response.choices[0].message.content)

  # Step 2: for every seed word, request a list of associated words
  # ({"associations": [...]}).
  fields = {}
  for word in outputs['words']:
    response = client.chat.completions.create(
        messages=[
            {"role": "system",
             "content": "generate one json object, no explanation or additional text, "
                        "use the following structure:\n"
                        "associations: []"},
            {"role": "user",
             "content": f"synthesize {num_associations} associations for the word {word}"},
        ],
        response_format={
            "type": "json",
            "value": {
                "properties": {
                    "associations": {"type": "array", "items": {"type": "string"}}
                }
            }
        },
        stream=False,
        max_tokens=2000,
        temperature=0.7,
        top_p=0.1,
    )
    fields[word] = json.loads(response.choices[0].message.content)

  # Step 3: for every (word, association) pair, ask the model for a predicate
  # linking them and collect the result as a [subject, predicate, object] triplet.
  triplets = []
  for cluster in fields:
    for association in fields[cluster]['associations']:
      response = client.chat.completions.create(
          messages=[
              {"role": "system",
               "content": "generate one json object, no explanation or additional text, "
                          "use the following structure:\n"
                          "properties: [subject, predicate, object]\n"
                          "use chain-of-thought for predictions"},
              {"role": "user",
               "content": f"form triplet based on semantics: generate predicate between the word {cluster} (subject) "
                          f"and the word {association} (object); return list with [subject, predicate, object]"},
          ],
          response_format={
              "type": "json",
              "value": {
                  "properties": {
                      "properties": {"type": "array", "items": {"type": "string"}}
                  }
              }
          },
          stream=False,
          max_tokens=128,
          temperature=0.7,
          top_p=0.1,
      )
      triplets.append(json.loads(response.choices[0].message.content))

  # Step 4: render the triplets as a left-to-right directed graph.
  dot = Digraph(comment=f'SynthNet, {num_samples} samples, {num_associations} associations',
                graph_attr={'rankdir': 'LR'})

  for entry in triplets:
    source, label, target = entry['properties']
    dot.node(source, source)
    dot.node(target, target)
    dot.edge(source, target, label=label)

  # render() writes the DOT source to 'synthnet' and the image to 'synthnet.png'.
  dot.render('synthnet', format='png')
  return 'synthnet.png'

demo = gr.Interface(
    fn=sampling,
    inputs=[
        gr.Slider(minimum=1, maximum=256, step=1, label="Number of Samples"),
        gr.Slider(minimum=1, maximum=256, step=1, label="Number of Associations to each Sample"),
    ],
    outputs=gr.Image(type="filepath"),
    title="SynthNet",
    description="Select a number of samples and associations to each sample to generate a graph.",
)


if __name__ == "__main__":
    demo.launch()