Spaces:
Running
Running
File size: 7,664 Bytes
8adb95d 10b1be2 8adb95d 10b1be2 8adb95d 10b1be2 30f2e61 853bf92 30f2e61 3c00078 853bf92 3c00078 853bf92 30f2e61 853bf92 30f2e61 8adb95d 59e713c 660f406 59e713c 660f406 8adb95d 10b1be2 083f3dd 1365361 59e713c 8adb95d fef897d c5e4478 8adb95d 10b1be2 c5e4478 59e713c 8adb95d |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 |
import json
import os
import subprocess

import cellscape
import gradio as gr
from gradio_moleculeview import moleculeview
def html_output(input_file):
with open(input_file, "r") as f:
svg = f.read().replace("<svg", "<svg id='svgElement'")
x = (
"""<!DOCTYPE html>
<html>
<head>
<meta http-equiv="content-type" content="text/html; charset=UTF-8" />
<style>
body{
font-family:sans-serif
}
.mol-container {
width: 100%;
height: 600px;
position: relative;
}
.mol-container select{
background-image:None;
}
</style>
</head>
<body>
"""
+ svg
+ """
<button id="copySvgBtn" style="padding:4px;border:1px solid gray;margin:3px">Copy SVG to Clipboard</button>
<button id="copyPngBtn" style="padding:4px;border:1px solid gray;margin:3px">Copy PNG to Clipboard</button>
<!-- Buttons for Download -->
<button id="downloadSvgBtn" style="padding:4px;border:1px solid gray;margin:3px">Download SVG</button>
<button id="downloadPngBtn" style="padding:4px;border:1px solid gray;margin:3px">Download PNG</button>
<script type="text/javascript">
function copySvgToClipboard() {
const svgElement = document.getElementById('svgElement');
const svgData = new XMLSerializer().serializeToString(svgElement);
const blob = new Blob([svgData], { type: 'image/svg+xml' });
const clipboardItem = [new ClipboardItem({ 'image/svg+xml': blob })];
navigator.clipboard.write(clipboardItem).then(() => {
alert("SVG copied to clipboard!");
}).catch(err => {
console.error("Could not copy SVG to clipboard: ", err);
});
}
// Function to convert SVG to PNG and copy to clipboard
function copyPngToClipboard() {
const svgElement = document.getElementById('svgElement');
const svgData = new XMLSerializer().serializeToString(svgElement);
const canvas = document.createElement('canvas');
const ctx = canvas.getContext('2d');
const img = new Image();
img.onload = function () {
canvas.width = svgElement.clientWidth;
canvas.height = svgElement.clientHeight;
ctx.drawImage(img, 0, 0);
canvas.toBlob(blob => {
const clipboardItem = [new ClipboardItem({ 'image/png': blob })];
navigator.clipboard.write(clipboardItem).then(() => {
alert("PNG copied to clipboard!");
}).catch(err => {
console.error("Could not copy PNG to clipboard: ", err);
});
}, 'image/png');
};
img.src = 'data:image/svg+xml;base64,' + btoa(svgData);
}
// Function to download SVG
function downloadSvg() {
const svgElement = document.getElementById('svgElement');
const svgData = new XMLSerializer().serializeToString(svgElement);
const blob = new Blob([svgData], { type: 'image/svg+xml' });
const url = URL.createObjectURL(blob);
const a = document.createElement('a');
a.href = url;
a.download = 'image.svg';
document.body.appendChild(a);
a.click();
document.body.removeChild(a);
URL.revokeObjectURL(url);
}
// Function to download PNG
function downloadPng() {
const svgElement = document.getElementById('svgElement');
const svgData = new XMLSerializer().serializeToString(svgElement);
const canvas = document.createElement('canvas');
const ctx = canvas.getContext('2d');
const img = new Image();
img.onload = function () {
canvas.width = svgElement.clientWidth;
canvas.height = svgElement.clientHeight;
ctx.drawImage(img, 0, 0);
canvas.toBlob(blob => {
const url = URL.createObjectURL(blob);
const a = document.createElement('a');
a.href = url;
a.download = 'image.png';
document.body.appendChild(a);
a.click();
document.body.removeChild(a);
URL.revokeObjectURL(url);
}, 'image/png');
};
img.src = 'data:image/svg+xml;base64,' + btoa(svgData);
}
// Button event listeners
document.getElementById('copySvgBtn').addEventListener('click', copySvgToClipboard);
document.getElementById('copyPngBtn').addEventListener('click', copyPngToClipboard);
document.getElementById('downloadSvgBtn').addEventListener('click', downloadSvg);
document.getElementById('downloadPngBtn').addEventListener('click', downloadPng);
</script>
</body></html>"""
)
return f"""<iframe style="width: 100%; height: 600px" name="result" allow="midi; geolocation; microphone; camera;
display-capture; encrypted-media;" sandbox="allow-modals allow-forms
allow-scripts allow-same-origin allow-popups
allow-top-navigation-by-user-activation allow-downloads" allowfullscreen=""
allowpaymentrequest="" frameborder="0" srcdoc='{x}'></iframe>"""
def _cellscape_command(pdb_path, style, contour_level, colors):
    """Return the cellscape argv that renders *pdb_path* to ``outline_all.svg`` in *style*."""
    if style == "Goodsell3D":
        style_args = [
            "--outline", "residue", "--color_by", "chain",
            "--depth_shading", "--depth_lines",
            "--colors", *colors, "--depth", "flat",
        ]
    elif style == "Contour":
        style_args = [
            "--outline", "chain", "--color_by", "chain",
            "--depth_contour_interval", str(contour_level),
            "--colors", *colors, "--depth", "contours",
        ]
    else:
        style_args = ["--outline", "chain", "--colors", *colors, "--depth", "flat"]
    return [
        "cellscape", "cartoon", "--pdb", pdb_path,
        *style_args,
        "--back_outline", "--view", "view_matrix", "--save", "outline_all.svg",
    ]


def predict(input_mol, style, contour_level, view_str, chains):
    """Render the uploaded structure to SVG with cellscape, then simplify it with Inkscape.

    Args:
        input_mol: Uploaded file object; only its ``.name`` (path on disk) is used.
        style: "Goodsell3D" or "Contour"; any other value renders the flat style.
        contour_level: Value passed to ``--depth_contour_interval`` (Contour style only).
        view_str: JSON text whose decoded value is written verbatim to the
            ``view_matrix`` file. NOTE(review): the decoded value must be a str
            (``f.write`` rejects anything else), so this is presumably a
            JSON-encoded string produced by the viewer; confirm.
        chains: JSON object text mapping chain id -> color.

    Returns:
        A tuple of (iframe HTML showing ``outline_all.svg``, path ``"pdb_opt.svg"``
        of the Inkscape-simplified SVG).

    Raises:
        subprocess.CalledProcessError: If cellscape exits with a non-zero status.
    """
    # cellscape reads the camera orientation from this file via ``--view``.
    with open("view_matrix", "w", encoding="utf-8") as f:
        f.write(json.loads(view_str))

    chain_dict = json.loads(chains)
    # One color per chain, ordered by sorted chain id.
    colors = [str(chain_dict[chain]) for chain in sorted(chain_dict)]

    # Pass an argument list rather than a shell string: the colors and the
    # upload path come from the client and must never be interpreted by a
    # shell (this also keeps '#rrggbb' colors and paths with spaces intact).
    # check=True so a failed render is reported instead of serving a stale
    # outline_all.svg left over from an earlier request.
    subprocess.run(
        _cellscape_command(input_mol.name, style, contour_level, colors),
        check=True,
    )

    # Sizes are logged in MiB.
    print(os.stat("outline_all.svg").st_size / (1024 * 1024))
    # Path simplification is best-effort; its exit status is not checked.
    subprocess.run(
        [
            "inkscape",
            "outline_all.svg",
            "--actions=select-all;path-simplify;export-plain-svg",
            "--export-filename",
            "pdb_opt.svg",
        ]
    )
    if os.path.exists("pdb_opt.svg"):
        print(os.stat("pdb_opt.svg").st_size / (1024 * 1024))
    return html_output("outline_all.svg"), "pdb_opt.svg"
def show_contour_level(style):
    """Return a fresh "Contour level" slider, visible only when *style* is "Contour".

    Used as the ``style.change`` handler so the slider appears only for the
    contour rendering style.
    """
    return gr.Slider(
        minimum=1,
        maximum=50,
        step=1,
        value=10,
        label="Contour level",
        visible=style == "Contour",
    )
# Client-side snippet run when "Vectorize" is clicked. It reads the current
# view matrix and chain colors from DOM elements, echoes the selected style,
# and returns Date.now() so that the timestamp textbox changes on every click.
# NOTE(review): the element ids ('viewMatrixResult', 'chains') are assumed to
# be created by the moleculeview component; 'chains' differs from the
# 'chainsResult' placeholder below. Confirm which id the component uses.
_COLLECT_CLIENT_STATE_JS = (
    "(style) => ["
    "document.getElementById('viewMatrixResult').value, "
    "document.getElementById('chains').value, "
    "style, "
    "Date.now()"
    "]"
)

with gr.Blocks() as demo:
    gr.Markdown("# PDB2Vector")

    # Rendering style; the contour slider is only shown for "Contour".
    style = gr.Radio(value="Flat", choices=["Flat", "Contour", "Goodsell3D"], label="Style")
    contour_level = gr.Slider(
        minimum=1, maximum=50, step=1, value=10, label="Contour level", visible=False
    )
    style.change(show_contour_level, inputs=style, outputs=contour_level)

    inp = moleculeview(label="Molecule3D")

    # Hidden fields populated by the client-side snippet.
    view_str = gr.Textbox("viewMatrixResult", label="View Matrix", visible=False)
    chains = gr.Textbox("chainsResult", label="Chains", visible=False)
    hidden_style = gr.Textbox(visible=False)
    timestamp = gr.Textbox(visible=False)

    btn = gr.Button("Vectorize")
    html = gr.HTML("")
    out_file = gr.File(label="Download SVG")

    # Two-step pipeline: the click runs only JavaScript (no Python fn) to copy
    # browser-side state into the hidden fields; the resulting timestamp change
    # then triggers the server-side render.
    btn.click(
        None,
        inputs=style,
        outputs=[view_str, chains, hidden_style, timestamp],
        js=_COLLECT_CLIENT_STATE_JS,
    )
    timestamp.change(
        predict,
        inputs=[inp, style, contour_level, view_str, chains],
        outputs=[html, out_file],
    )

if __name__ == "__main__":
    demo.launch()
|