File size: 3,931 Bytes
331412c 039cd66 331412c 9c7dc56 331412c 039cd66 9c7dc56 039cd66 9c7dc56 039cd66 9c7dc56 331412c 9c7dc56 331412c 9c7dc56 039cd66 9c7dc56 039cd66 331412c bbd4f95 331412c |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 |
import math

import gradio as gr
import torch
# Markdown template placed at the top of the report. It shows a snippet that
# rebuilds both tensors with torch.arange(...).view(...) and states the shape
# produced by `@`. The {n1}, {dim1}, {n2}, {dim2} and {out_shape} placeholders
# are filled in with str.format by generate_example.
EXAMPLE_MD = (
    "\n"
    "```python\n"
    "import torch\n"
    "t1 = torch.arange({n1}).view({dim1})\n"
    "t2 = torch.arange({n2}).view({dim2})\n"
    "(t1 @ t2).shape = {out_shape}\n"
    "```\n"
)
def generate_example(dim1: list, dim2: list):
    """Build the markdown example showing the real shape of ``t1 @ t2``.

    Each tensor is materialised as ``torch.arange(n).view(shape)``, where
    ``n`` is the product of the shape's entries. The result shape of the
    matmul is then recorded. Any ``RuntimeError`` raised while building the
    tensors or multiplying them (e.g. a negative size or incompatible shapes)
    is reported as ``"error"`` instead of propagating.

    Args:
        dim1: shape of the left operand, as a list of ints.
        dim2: shape of the right operand, as a list of ints.

    Returns:
        ``(dim1, dim2, code)``: the input shapes (unchanged) and the
        ``EXAMPLE_MD`` template filled in with the sizes and result shape.
    """
    n1 = math.prod(dim1)
    n2 = math.prod(dim2)
    try:
        # Construction sits inside the try: torch.arange rejects negative
        # sizes and .view rejects invalid shapes, both with RuntimeError.
        # NOTE(review): these are real tensors with n1 / n2 elements, so very
        # large shapes cost proportional memory.
        t1 = torch.arange(n1).view(dim1)
        t2 = torch.arange(n2).view(dim2)
        out_shape = list((t1 @ t2).shape)
    except RuntimeError:
        out_shape = "error"
    code = EXAMPLE_MD.format(
        n1=str(n1), dim1=str(dim1), n2=str(n2), dim2=str(dim2), out_shape=str(out_shape)
    )
    return dim1, dim2, code
def sanitize_dimention(dim):
    """Parse a user-typed tensor shape such as ``"9,2,1,3,3"`` or ``"[5, 3, 7]"``.

    Square brackets are dropped. Entries may be separated by commas,
    whitespace, or both.

    Args:
        dim: raw text from the Gradio textbox; may be ``None``.

    Returns:
        The shape as a list of ints.

    Raises:
        gr.Error: when the text is missing or empty, contains a token that is
            not an integer, or contains a 0. Gradio shows the message to the
            user.
    """
    empty_message = "one of the dimensions is empty, please fill it"
    # gr.Error only reaches the user when it is *raised*. Merely constructing
    # it lets invalid input fall through to a crash or a bogus table.
    if dim is None:
        raise gr.Error(empty_message)
    tokens = dim.replace("[", "").replace("]", "").replace(",", " ").split()
    if not tokens:
        raise gr.Error(empty_message)
    try:
        out = [int(token) for token in tokens]
    except ValueError:
        raise gr.Error(
            f"could not read {dim!r} as a list of integers, e.g. 2,3,4"
        ) from None
    if 0 in out:
        raise gr.Error(
            "Found the number 0 in one of the dimensions which is not allowed, consider using 1 instead"
        )
    return out
def create_row(dim):
    """Render a single markdown table row with one cell per entry of ``dim``.

    The row is newline-terminated; every cell is followed by ``" | "``.
    """
    cells = "".join(str(value) + " | " for value in dim)
    return "| " + cells + "\n"
def create_header(n_dim, checks=None):
    """Render a markdown table header row followed by its ``|---|`` separator.

    Args:
        n_dim: number of columns (used for the separator row).
        checks: header cell labels (strings). Defaults to ``n_dim`` copies
            of the HTML comment ``<!-- -->``, presumably so the header cells
            render blank.

    Returns:
        Two newline-terminated lines: the header row, then the separator row.
    """
    if checks is None:
        checks = ["<!-- -->"] * n_dim
    header_row = "| " + "".join(label + " | " for label in checks)
    separator_row = "|---" * n_dim + "|"
    return header_row + "\n" + separator_row + "\n"
def generate_table(dim1, dim2, checks=None):
    """Render a markdown table with a header and one row per shape.

    The column count is taken from ``dim1``. ``checks`` (optional) supplies
    the header labels and is passed straight to ``create_header``.
    """
    header = create_header(len(dim1), checks)
    body = "".join(create_row(shape) for shape in (dim1, dim2))
    return header + body
def alignment_and_fill_with_ones(dim1, dim2):
    """Left-pad the shorter of two shapes with 1s so both have equal length.

    Only the shorter shape is rebuilt (as a new list). The other shape, or
    both when the lengths already match, is returned as the same object.

    Returns:
        ``(dim1, dim2)`` after padding.
    """
    target = max(len(dim1), len(dim2))
    if len(dim1) < target:
        dim1 = [1] * (target - len(dim1)) + list(dim1)
    elif len(dim2) < target:
        dim2 = [1] * (target - len(dim2)) + list(dim2)
    return dim1, dim2
def check_validity(dim1, dim2):
    """Return a "V"/"X" verdict per column of two shapes for ``dim1 @ dim2``.

    Assumes both shapes already have the same length (the caller aligns them
    first; TODO confirm for any new callers).

    * Empty shapes give an empty list.
    * Two 1-D shapes are a dot product: a single "V" when the sizes match,
      otherwise "X".
    * Otherwise, each batch column (all but the last two) is "V" when the
      sizes are equal and "X" if not. The last two columns are both "V" when
      ``dim1[-1] == dim2[-2]`` (inner dimensions agree) and both "X" if not.
    """
    if not dim1:
        return []
    if len(dim1) == 1:
        # torch.matmul: 1-D @ 1-D is a dot product, valid iff the sizes match
        return ["V" if dim1[0] == dim2[0] else "X"]
    out = ["V" if a == b else "X" for a, b in zip(dim1[:-2], dim2[:-2])]
    out.extend(["V", "V"] if dim1[-1] == dim2[-2] else ["X", "X"])
    return out
def substitute_ones_with_concat(dim1, dim2):
    """Broadcast size-1 batch dimensions of two equal-length shapes.

    At every position except the last two (the matrix dimensions), a 1 in
    one shape is replaced by the other shape's size at that position. When
    both are 1 they stay 1. The inputs are left untouched; new lists are
    returned.

    Returns:
        ``(new_dim1, new_dim2)``.
    """
    out1, out2 = list(dim1), list(dim2)
    for i in range(len(out1) - 2):
        if out1[i] == 1:
            out1[i] = out2[i]
        elif out2[i] == 1:
            out2[i] = out1[i]
    return out1, out2
def predict(dim1, dim2):
    """Explain, step by step, whether ``t1 @ t2`` works for two typed shapes.

    Args:
        dim1: raw text for the left tensor's shape, e.g. ``"9,2,1,3,3"``.
        dim2: raw text for the right tensor's shape, e.g. ``"5,3,7"``.

    Returns:
        Markdown: the example snippet from ``generate_example``, followed by
        three tables (rank alignment, substitution of size-1 batch
        dimensions, and per-column validity verdicts).
    """
    dim1 = sanitize_dimention(dim1)
    dim2 = sanitize_dimention(dim2)
    dim1, dim2, code = generate_example(dim1, dim2)

    # torch.matmul treats a 1-D second operand of size k as a (k, 1) column.
    # Left-padding alone would turn it into (1, k), making Step 3 disagree
    # with the real result shape shown in `code`, so append the 1 here.
    # A 1-D first operand is already handled correctly by the left-padding.
    note = ""
    if len(dim2) == 1:
        dim2 = dim2 + [1]
        note = (
            "The second tensor is 1-D, so a trailing 1 is appended to it "
            "(as torch.matmul does).\n\n"
        )

    # Table 1: align ranks by left-padding with ones
    dim1, dim2 = alignment_and_fill_with_ones(dim1, dim2)
    table1 = generate_table(dim1, dim2)
    # Table 2: broadcast size-1 batch dimensions against the other shape
    dim1, dim2 = substitute_ones_with_concat(dim1, dim2)
    table2 = generate_table(dim1, dim2)
    # Table 3: per-column validity verdicts used as header labels
    checks = check_validity(dim1, dim2)
    table3 = generate_table(dim1, dim2, checks)

    out = code
    out += "\n# Step1 (alignment and prepend with ones)\n" + note + table1
    out += "\n# Step2 (substitute columns that have 1 with concat)\nexcept for last 2 dimensions\n" + table2
    out += "\n# Step3 (check if matrix multiplication is valid)\n"
    out += "* last dimension of dim1 should equal the second-to-last dimension of dim2\n"
    out += "* all the other dimensions should be equal to one another\n\n" + table3
    return out
# Gradio UI: two free-text shape inputs rendered into one markdown report,
# with two pre-filled example pairs. The app is launched as soon as this
# module runs.
demo = gr.Interface(
    fn=predict,
    inputs=["text", "text"],
    outputs=["markdown"],
    examples=[
        ["9,2,1,3,3", "5,3,7"],
        ["1,2,3", "5,2,7"],
    ],
)
demo.launch(debug=True)
|