---
license: bsd-3-clause-clear
---

# WAFFLE: Multi-Modal Model for Automated Front-End Development

We developed WAFFLE, a fine-tuning approach that trains multi-modal LLMs (MLLMs) to generate HTML code from webpage screenshots or UI designs. WAFFLE uses a structure-aware attention mechanism to improve MLLMs' understanding of HTML's structure, and a contrastive fine-tuning approach to align MLLMs' understanding of UI images and HTML code. Models fine-tuned with WAFFLE show up to 9.00 pp (percentage points) higher HTML match, 0.0982 higher CW-SSIM, 32.99 higher CLIP, and 27.12 pp higher LLEM on our new benchmark, WebSight-Test, and on an existing benchmark, Design2Code.

## Updates

* 10/24/2024: Our preprint is available at: [arXiv](https://arxiv.org/abs/2410.18362), [huggingface](https://huggingface.co/papers/2410.18362)
* 10/24/2024: Our code (actively maintained) is available at: [code](https://github.com/lt-asset/Waffle)
* 10/24/2024: Our Waffle_VLM_WebSight (7B), fine-tuned with DoRA, is released at: [lt-asset/Waffle_VLM_WebSight](https://huggingface.co/lt-asset/Waffle_VLM_WebSight)

## Dependencies

- peft 0.11.1
- transformers 4.41.1
- pytorch 2.3.0
- selenium
- Python 3.10.14
- deepspeed 0.14.1
- datasets 2.19.1
- beautifulsoup4 4.12.3
- accelerate 0.30.1

## Quick Start

* Input UI design

Find a webpage screenshot or UI design:

![test-495.png](examples/test-495.png)

* Run Waffle_VLM_WebSight

```python
import torch
from PIL import Image
from transformers import AutoProcessor, AutoModelForCausalLM
from transformers.image_utils import to_numpy_array, PILImageResampling, ChannelDimension
from transformers.image_transforms import resize, to_channel_dimension_format
from utils import TreeBuilder


def convert_to_rgb(image):
    # RGB images are returned unchanged
    if image.mode == "RGB":
        return image
    # Composite other modes (e.g., RGBA) onto a white background, then drop the alpha channel
    image_rgba = image.convert("RGBA")
    background = Image.new("RGBA", image_rgba.size, (255, 255, 255))
    alpha_composite = Image.alpha_composite(background, image_rgba)
    alpha_composite = alpha_composite.convert("RGB")
    return alpha_composite

def inference_vlm_websight(image_path, html_path):

    def custom_transform(x):
        # Resize to 960x960, rescale to [0, 1], normalize, and convert to a channels-first tensor
        x = convert_to_rgb(x)
        x = to_numpy_array(x)
        x = resize(x, (960, 960), resample=PILImageResampling.BILINEAR)
        x = processor.image_processor.rescale(x, scale=1 / 255)
        x = processor.image_processor.normalize(
            x,
            mean=processor.image_processor.image_mean,
            std=processor.image_processor.image_std
        )
        x = to_channel_dimension_format(x, ChannelDimension.FIRST)
        x = torch.tensor(x)
        return x

    model_dir = "lt-asset/Waffle_VLM_WebSight"
    processor = AutoProcessor.from_pretrained(model_dir)
    model = AutoModelForCausalLM.from_pretrained(model_dir, torch_dtype=torch.bfloat16, trust_remote_code=True).cuda()

    # Use 2/8 = 1/4 of the attention heads for hierarchical attention (as described in the paper)
    assert model.config.web_attention_range == 2, "Waffle_VLM_WebSight is trained with hierarchical attention applied to 2 / 8 heads"
    model.eval()

    image_seq_len = model.config.perceiver_config.resampler_n_latents
    BOS_TOKEN = processor.tokenizer.bos_token
    # Prevent the model from generating the image placeholder tokens
    BAD_WORDS_IDS = processor.tokenizer(["<image>", "<fake_token_around_image>"], add_special_tokens=False).input_ids

    image = Image.open(image_path)
    # Prompt: BOS, then one <image> placeholder per resampler latent, wrapped in boundary tokens
    inputs = processor.tokenizer(
        f"{BOS_TOKEN}<fake_token_around_image>{'<image>' * image_seq_len}<fake_token_around_image>",
        return_tensors="pt",
        add_special_tokens=False,
    )
    inputs["pixel_values"] = processor.image_processor([image], transform=custom_transform).to(dtype=torch.bfloat16)
    inputs_for_generation = {k: v.cuda() for k, v in inputs.items()}

    # HTML tree state used by the hierarchical (structure-aware) attention during generation
    inputs_for_generation["web_attention_mask"] = None
    inputs_for_generation["html_tree"] = TreeBuilder(processor.tokenizer)
    inputs_for_generation["html_tree"].web_attention_mask = inputs_for_generation["web_attention_mask"]

    # Greedy decoding
    generated_ids = model.generate(
        **inputs_for_generation, bad_words_ids=BAD_WORDS_IDS, max_length=2048, num_return_sequences=1, do_sample=False
    )
    generated_text = processor.batch_decode(generated_ids, skip_special_tokens=True)[0]
    with open(html_path, 'w') as wp:
        wp.write(generated_text)


if __name__ == '__main__':
    inference_vlm_websight('examples/test-495.png',
                           'examples/example-495.html')
```

* Waffle_VLM_WebSight generated HTML code

[example-495.html](examples/example-495.html)

* Rendered Waffle_VLM_WebSight output

Render or preview the HTML to check its correctness:

![example-495.html](examples/example-495.png)

## Citation

```
@misc{liang2024wafflemultimodalmodelautomated,
      title={WAFFLE: Multi-Modal Model for Automated Front-End Development},
      author={Shanchao Liang and Nan Jiang and Shangshu Qian and Lin Tan},
      year={2024},
      eprint={2410.18362},
      archivePrefix={arXiv},
      primaryClass={cs.SE},
      url={https://arxiv.org/abs/2410.18362},
}
```
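The "Rendered Waffle_VLM_WebSight output" step in Quick Start can also be scripted. Below is a minimal sketch, assuming selenium 4 (listed under Dependencies) and a local Chrome installation; the helper names `html_file_url` and `render_html` and the default 1280x1024 window size are hypothetical choices for illustration, not part of the WAFFLE code.

```python
from pathlib import Path


def html_file_url(html_path):
    """Return an absolute file:// URL for a local HTML file (hypothetical helper)."""
    return Path(html_path).resolve().as_uri()


def render_html(html_path, png_path, width=1280, height=1024):
    """Open a local HTML file in headless Chrome and save a PNG screenshot.

    Hypothetical helper: assumes selenium 4 and a working Chrome/chromedriver setup.
    Returns True if the screenshot was written, False otherwise.
    """
    # Imported lazily so html_file_url can be used without selenium installed
    from selenium import webdriver

    options = webdriver.ChromeOptions()
    options.add_argument("--headless=new")  # run Chrome without a visible window
    driver = webdriver.Chrome(options=options)
    try:
        driver.set_window_size(width, height)
        driver.get(html_file_url(html_path))
        # save_screenshot returns False if the file could not be written
        return driver.save_screenshot(png_path)
    finally:
        driver.quit()  # always release the browser process
```

For example, `render_html("examples/example-495.html", "example-495-rendered.png")` would write a screenshot that can be compared side by side with the input design. The fixed window size keeps renders comparable across runs; adjust it to match the resolution of your input screenshots.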
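The structure-aware attention mentioned in the introduction restricts which HTML tokens can attend to each other. As a toy illustration only, assuming a rule in which each token may attend to earlier tokens belonging to its own element, its parent element, or its sibling elements, a boolean mask could be built as below. The function `structure_aware_mask` and the element/parent encoding are hypothetical; the actual implementation (the `TreeBuilder` and `web_attention_mask` used in Quick Start) may differ, and in the released model this kind of attention is applied to only 2 of 8 heads.

```python
def structure_aware_mask(token_elem, parent):
    """Toy tree-restricted causal attention mask (hypothetical sketch).

    token_elem[i]: id of the HTML element that token i belongs to.
    parent[e]: id of element e's parent, or None for the root.
    Returns mask where mask[i][j] is True if token i may attend to token j.
    """
    n = len(token_elem)
    mask = [[False] * n for _ in range(n)]
    for i in range(n):
        ei = token_elem[i]
        for j in range(i + 1):  # causal: only the current and earlier tokens
            ej = token_elem[j]
            same_element = ej == ei
            is_parent = ej == parent[ei]
            is_sibling = parent[ei] is not None and parent[ej] == parent[ei]
            mask[i][j] = same_element or is_parent or is_sibling
    return mask
```

For a tree `<html>` (0) > `<body>` (1) > {`<h1>` (2), `<p>` (3)}, calling `structure_aware_mask([0, 1, 2, 2, 3, 3], {0: None, 1: 0, 2: 1, 3: 1})` lets the first `<p>` token see the `<body>` token (parent) and both `<h1>` tokens (sibling), but not the `<html>` token (grandparent).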