# pdb2vector / app.py
# Last commit: 30f2e61 "Update app.py" by simonduerr (verified); file size 7.66 kB.
import json
import os
import subprocess
from html import escape

import cellscape
import gradio as gr
from gradio_moleculeview import moleculeview
def html_output(input_file):
    """Wrap an SVG file in a small HTML page and embed it in an ``<iframe>``.

    The page shows the drawing followed by four buttons (copy SVG / copy PNG
    to the clipboard, download SVG / download PNG). Their inline JavaScript
    finds the drawing through ``id='svgElement'``, which is injected into the
    root ``<svg>`` tag here.

    Args:
        input_file: Path to the SVG file to display.

    Returns:
        An ``<iframe>`` HTML string whose ``srcdoc`` attribute holds the
        HTML-escaped page.
    """
    with open(input_file, "r", encoding="utf-8") as f:
        # Tag only the root <svg>; nested <svg> elements must not receive a
        # duplicate id.
        svg = f.read().replace("<svg", "<svg id='svgElement'", 1)
    head = """<!DOCTYPE html>
<html>
<head>
<meta http-equiv="content-type" content="text/html; charset=UTF-8" />
<style>
body{
font-family:sans-serif
}
.mol-container {
width: 100%;
height: 600px;
position: relative;
}
.mol-container select{
background-image:None;
}
</style>
</head>
<body>
"""
    tail = """
<button id="copySvgBtn" style="padding:4px;border:1px solid gray;margin:3px">Copy SVG to Clipboard</button>
<button id="copyPngBtn" style="padding:4px;border:1px solid gray;margin:3px">Copy PNG to Clipboard</button>
<!-- Buttons for Download -->
<button id="downloadSvgBtn" style="padding:4px;border:1px solid gray;margin:3px">Download SVG</button>
<button id="downloadPngBtn" style="padding:4px;border:1px solid gray;margin:3px">Download PNG</button>
<script type="text/javascript">
function copySvgToClipboard() {
const svgElement = document.getElementById('svgElement');
const svgData = new XMLSerializer().serializeToString(svgElement);
const blob = new Blob([svgData], { type: 'image/svg+xml' });
const clipboardItem = [new ClipboardItem({ 'image/svg+xml': blob })];
navigator.clipboard.write(clipboardItem).then(() => {
alert("SVG copied to clipboard!");
}).catch(err => {
console.error("Could not copy SVG to clipboard: ", err);
});
}
// Function to convert SVG to PNG and copy to clipboard
function copyPngToClipboard() {
const svgElement = document.getElementById('svgElement');
const svgData = new XMLSerializer().serializeToString(svgElement);
const canvas = document.createElement('canvas');
const ctx = canvas.getContext('2d');
const img = new Image();
img.onload = function () {
canvas.width = svgElement.clientWidth;
canvas.height = svgElement.clientHeight;
ctx.drawImage(img, 0, 0);
canvas.toBlob(blob => {
const clipboardItem = [new ClipboardItem({ 'image/png': blob })];
navigator.clipboard.write(clipboardItem).then(() => {
alert("PNG copied to clipboard!");
}).catch(err => {
console.error("Could not copy PNG to clipboard: ", err);
});
}, 'image/png');
};
img.src = 'data:image/svg+xml;base64,' + btoa(svgData);
}
// Function to download SVG
function downloadSvg() {
const svgElement = document.getElementById('svgElement');
const svgData = new XMLSerializer().serializeToString(svgElement);
const blob = new Blob([svgData], { type: 'image/svg+xml' });
const url = URL.createObjectURL(blob);
const a = document.createElement('a');
a.href = url;
a.download = 'image.svg';
document.body.appendChild(a);
a.click();
document.body.removeChild(a);
URL.revokeObjectURL(url);
}
// Function to download PNG
function downloadPng() {
const svgElement = document.getElementById('svgElement');
const svgData = new XMLSerializer().serializeToString(svgElement);
const canvas = document.createElement('canvas');
const ctx = canvas.getContext('2d');
const img = new Image();
img.onload = function () {
canvas.width = svgElement.clientWidth;
canvas.height = svgElement.clientHeight;
ctx.drawImage(img, 0, 0);
canvas.toBlob(blob => {
const url = URL.createObjectURL(blob);
const a = document.createElement('a');
a.href = url;
a.download = 'image.png';
document.body.appendChild(a);
a.click();
document.body.removeChild(a);
URL.revokeObjectURL(url);
}, 'image/png');
};
img.src = 'data:image/svg+xml;base64,' + btoa(svgData);
}
// Button event listeners
document.getElementById('copySvgBtn').addEventListener('click', copySvgToClipboard);
document.getElementById('copyPngBtn').addEventListener('click', copyPngToClipboard);
document.getElementById('downloadSvgBtn').addEventListener('click', downloadSvg);
document.getElementById('downloadPngBtn').addEventListener('click', downloadPng);
</script>
</body></html>"""
    page = head + svg + tail
    # The page is full of quotes, '&' and '<' (the injected id, the JS string
    # literals). Embedded raw, the first quote would terminate the srcdoc
    # attribute. Escape everything; the browser decodes the entities back
    # into the original document.
    srcdoc = escape(page, quote=True)
    return f"""<iframe style="width: 100%; height: 600px" name="result" allow="midi; geolocation; microphone; camera;
display-capture; encrypted-media;" sandbox="allow-modals allow-forms
allow-scripts allow-same-origin allow-popups
allow-top-navigation-by-user-activation allow-downloads" allowfullscreen=""
allowpaymentrequest="" frameborder="0" srcdoc="{srcdoc}"></iframe>"""
def predict(input_mol, style, contour_level, view_str, chains):
    """Render the loaded structure to SVG with cellscape, then simplify it.

    Args:
        input_mol: Value of the ``moleculeview`` input. Its ``.name`` is passed
            to cellscape as the ``--pdb`` path.
        style: Rendering style: "Goodsell3D", "Contour", or anything else for
            the flat style.
        contour_level: Passed as ``--depth_contour_interval``; only used for
            the "Contour" style.
        view_str: JSON text. Its decoded value is written to ``view_matrix``,
            which is passed to cellscape via ``--view``.
        chains: JSON object whose values (colors) are passed to ``--colors``
            in sorted-key order.

    Returns:
        ``(iframe_html, svg_path)``: an HTML preview of ``outline_all.svg`` and
        the path of the SVG offered for download. That is ``pdb_opt.svg`` after
        inkscape path simplification, or ``outline_all.svg`` if that step fails.

    Raises:
        ValueError: If no molecule has been loaded.
        subprocess.CalledProcessError: If cellscape exits with a non-zero status.
    """
    if input_mol is None:
        raise ValueError("Load a molecule before vectorizing.")

    # Write the view matrix to the file read by cellscape's --view option.
    with open("view_matrix", "w", encoding="utf-8") as f:
        f.write(json.loads(view_str))

    chain_dict = json.loads(chains)
    colors = [str(chain_dict[chain]) for chain in sorted(chain_dict)]

    # Remove outputs of a previous request so a failed run can never serve them.
    # NOTE(review): fixed file names in the working directory are shared by all
    # concurrent requests; consider a per-request temporary directory.
    for stale in ("outline_all.svg", "pdb_opt.svg"):
        try:
            os.remove(stale)
        except FileNotFoundError:
            pass

    # Arguments are passed as a list (no shell), so the upload path and the
    # client-supplied colors cannot inject shell syntax or break on spaces.
    cmd = ["cellscape", "cartoon", "--pdb", input_mol.name]
    if style == "Goodsell3D":
        cmd += ["--outline", "residue", "--color_by", "chain",
                "--depth_shading", "--depth_lines",
                "--colors", *colors, "--depth", "flat"]
    elif style == "Contour":
        cmd += ["--outline", "chain", "--color_by", "chain",
                "--depth_contour_interval", str(contour_level),
                "--colors", *colors, "--depth", "contours"]
    else:
        cmd += ["--outline", "chain", "--colors", *colors, "--depth", "flat"]
    cmd += ["--back_outline", "--view", "view_matrix", "--save", "outline_all.svg"]
    subprocess.run(cmd, check=True)

    # Size of the raw render in MiB.
    print(os.stat("outline_all.svg").st_size / (1024 * 1024))

    # Path simplification is best-effort: on failure, offer the raw render.
    try:
        simplified = subprocess.run(
            ["inkscape", "outline_all.svg",
             "--actions=select-all;path-simplify;export-plain-svg",
             "--export-filename", "pdb_opt.svg"],
        ).returncode == 0
    except OSError:  # e.g. the inkscape executable is not installed
        simplified = False
    if simplified and os.path.exists("pdb_opt.svg"):
        download = "pdb_opt.svg"
    else:
        print("inkscape simplification failed; offering outline_all.svg instead")
        download = "outline_all.svg"

    # Size (MiB) of the file actually offered for download.
    print(os.stat(download).st_size / (1024 * 1024))
    return html_output("outline_all.svg"), download
def show_contour_level(style):
    """Return the contour-level slider, visible only for the "Contour" style.

    Every call returns the same slider configuration (range 1-50, step 1,
    value 10). Only ``visible`` depends on ``style``.
    """
    is_contour = style == "Contour"
    return gr.Slider(
        minimum=1,
        maximum=50,
        step=1,
        value=10,
        label="Contour level",
        visible=is_contour,
    )
# Build the Gradio UI. Components are laid out in the order they are created
# inside this Blocks context.
with gr.Blocks() as demo:
    gr.Markdown("# PDB2Vector")
    # Rendering style. The contour slider starts hidden, and
    # show_contour_level toggles it whenever the style changes.
    style = gr.Radio(value="Flat", choices=["Flat", "Contour", "Goodsell3D"], label="Style")
    contour_level = gr.Slider(minimum=1,maximum=50,step=1, value=10, label="Contour level", visible=False)
    style.change(show_contour_level, style, contour_level)
    inp = moleculeview(label="Molecule3D")
    # Hidden textboxes filled by the client-side JS of btn.click below.
    view_str = gr.Textbox("viewMatrixResult", label="View Matrix", visible=False)
    chains = gr.Textbox("chainsResult", label="Chains", visible=False)
    # Receives a copy of the style value; predict reads `style` directly, not this.
    hidden_style = gr.Textbox(visible=False)
    timestamp = gr.Textbox(visible=False)
    btn = gr.Button("Vectorize")
    # Outputs of predict: the iframe preview and the downloadable SVG.
    html = gr.HTML("")
    out_file = gr.File(label="Download SVG")
    # Two-stage trigger. The click registers no Python function (None); its
    # `js` callback copies the view matrix and chain data from the page into
    # the hidden textboxes and writes Date.now() into `timestamp`. The new
    # timestamp value then fires `timestamp.change`, which runs predict on
    # the server.
    # NOTE(review): the view matrix is read from element id 'viewMatrixResult',
    # but the chains are read from id 'chains', while the chains textbox
    # placeholder is "chainsResult". Confirm which id the moleculeview
    # frontend actually creates.
    btn.click(None, style, [view_str, chains, hidden_style, timestamp], js="(style) => [document.getElementById('viewMatrixResult').value, document.getElementById('chains').value, style, Date.now()]")
    timestamp.change(predict, [inp, style, contour_level, view_str, chains], [html, out_file])
    # NOTE(review): no handler is registered for changes to `chains`; rendering
    # only runs after a Vectorize click.
if __name__ == "__main__":
    demo.launch()