diff --git a/.gitignore b/.gitignore index 26a04150097f82bd7679e3b44bf41d2f48093f1b..4daec0c7768118aa00c7825720fb88053c8d1637 100644 --- a/.gitignore +++ b/.gitignore @@ -1,3 +1,5 @@ +unsloth_compiled_cache/ +*.ipynb star-vector/ SVGDreamer/ *.parquet diff --git a/dl.py b/dl.py new file mode 100644 index 0000000000000000000000000000000000000000..710edf3234d4a22037ff48a50ebfe8f420c12240 --- /dev/null +++ b/dl.py @@ -0,0 +1,203 @@ +import logging +import re +import cairosvg +import torch +from transformers import AutoModelForCausalLM +from lxml import etree +import kagglehub +from gen_image import ImageGenerator +from starvector.data.util import process_and_rasterize_svg + +svg_constraints = kagglehub.package_import('metric/svg-constraints') + +class DLModel: + def __init__(self, model_id="starvector/starvector-8b-im2svg", device="cuda"): + """ + Initialize the SVG generation pipeline using StarVector. + + Args: + model_id (str): The model identifier for the StarVector model. + device (str): The device to run the model on, either "cuda" or "cpu". + """ + self.image_generator = ImageGenerator(model_id="stabilityai/stable-diffusion-2-1-base", device=device) + self.default_svg = """""" + self.constraints = svg_constraints.SVGConstraints() + self.timeout_seconds = 90 + + # Load StarVector model + self.device = device + self.starvector = AutoModelForCausalLM.from_pretrained( + model_id, + torch_dtype=torch.float16, + trust_remote_code=True + ) + self.processor = self.starvector.model.processor + self.starvector.to(device) + self.starvector.eval() + + def predict(self, description): + """ + Generate an SVG from a text description. + + Args: + description (str): The text description to generate an image from. + + Returns: + str: The generated SVG content. + """ + try: + # Step 1: Generate image using diffusion model + images = self.image_generator.generate(description) + image = images[0] + + # Save the generated image + image_path = "diff_image.png" + image.save(image_path) + logging.info(f"Intermediate image saved to {image_path}") + + # Step 2: Convert image to SVG using StarVector + processed_image = self.processor(image, return_tensors="pt")['pixel_values'].to(self.device) + if not processed_image.shape[0] == 1: + processed_image = processed_image.squeeze(0) + + batch = {"image": processed_image} + with torch.no_grad(): + raw_svg = self.starvector.generate_im2svg(batch, max_length=4000)[0] + raw_svg, _ = process_and_rasterize_svg(raw_svg) + + if 'viewBox' not in raw_svg: + raw_svg = raw_svg.replace(' str: + """Enforces constraints on an SVG string, removing disallowed elements + and attributes. + + Parameters + ---------- + svg_string : str + The SVG string to process. + + Returns + ------- + str + The processed SVG string, or the default SVG if constraints + cannot be satisfied. + """ + logging.info('Sanitizing SVG...') + + try: + # Remove XML declaration if it exists + svg_string = re.sub(r'<\?xml[^>]+\?>', '', svg_string).strip() + + parser = etree.XMLParser(remove_blank_text=True, remove_comments=True) + root = etree.fromstring(svg_string, parser=parser) + except etree.ParseError as e: + logging.error('SVG Parse Error: %s. 
Returning default SVG.', e) + logging.error('SVG string: %s', svg_string) + return self.default_svg + + elements_to_remove = [] + for element in root.iter(): + tag_name = etree.QName(element.tag).localname + + # Remove disallowed elements + if tag_name not in self.constraints.allowed_elements: + elements_to_remove.append(element) + continue # Skip attribute checks for removed elements + + # Remove disallowed attributes + attrs_to_remove = [] + for attr in element.attrib: + attr_name = etree.QName(attr).localname + if ( + attr_name + not in self.constraints.allowed_elements[tag_name] + and attr_name + not in self.constraints.allowed_elements['common'] + ): + attrs_to_remove.append(attr) + + for attr in attrs_to_remove: + logging.debug( + 'Attribute "%s" for element "%s" not allowed. Removing.', + attr, + tag_name, + ) + del element.attrib[attr] + + # Check and remove invalid href attributes + for attr, value in element.attrib.items(): + if etree.QName(attr).localname == 'href' and not value.startswith('#'): + logging.debug( + 'Removing invalid href attribute in element "%s".', tag_name + ) + del element.attrib[attr] + + # Validate path elements to help ensure SVG conversion + if tag_name == 'path': + d_attribute = element.get('d') + if not d_attribute: + logging.warning('Path element is missing "d" attribute. Removing path.') + elements_to_remove.append(element) + continue # Skip further checks for this removed element + # Use regex to validate 'd' attribute format + path_regex = re.compile( + r'^' # Start of string + r'(?:' # Non-capturing group for each command + numbers block + r'[MmZzLlHhVvCcSsQqTtAa]' # Valid SVG path commands (adjusted to exclude extra letters) + r'\s*' # Optional whitespace after command + r'(?:' # Non-capturing group for optional numbers + r'-?\d+(?:\.\d+)?(?:[Ee][+-]?\d+)?' # First number + r'(?:[\s,]+-?\d+(?:\.\d+)?(?:[Ee][+-]?\d+)?)*' # Subsequent numbers with mandatory separator(s) + r')?' # Numbers are optional (e.g. for Z command) + r'\s*' # Optional whitespace after numbers/command block + r')+' # One or more command blocks + r'\s*' # Optional trailing whitespace + r'$' # End of string + ) + if not path_regex.match(d_attribute): + logging.warning( + 'Path element has malformed "d" attribute format. Removing path.' 
+ ) + elements_to_remove.append(element) + continue + logging.debug('Path element "d" attribute validated (regex check).') + + # Remove elements marked for removal + for element in elements_to_remove: + if element.getparent() is not None: + element.getparent().remove(element) + logging.debug('Removed element: %s', element.tag) + + try: + cleaned_svg_string = etree.tostring(root, encoding='unicode', xml_declaration=False) + return cleaned_svg_string + except ValueError as e: + logging.error( + 'SVG could not be sanitized to meet constraints: %s', e + ) + return self.default_svg + +# Example usage +if __name__ == "__main__": + model = DLModel() + svg = model.predict("a purple forest at dusk") + # Convert SVG to PNG + try: + # Create a PNG in memory + png_data = cairosvg.svg2png(bytestring=svg.encode('utf-8')) + + # Save the PNG to a file + with open("output.png", "wb") as f: + f.write(png_data) + print("SVG saved as output.png") + except Exception as e: + print(f"Error converting SVG to PNG: {e}") \ No newline at end of file diff --git a/gen_image.py b/gen_image.py index 7b229a87dce891d163c7a0dd6f2898056b9636ff..f82f93acecef8e0630fa201e1e6d69365ec04a2c 100644 --- a/gen_image.py +++ b/gen_image.py @@ -35,7 +35,7 @@ class ImageGenerator: num_images (int, optional): Number of images to generate. Returns: - PIL.Image.Image: The generated image. + list[PIL.Image.Image]: The generated images. """ prompt = f"{prompt}, {self.positive_prompt}" if negative_prompt is None: @@ -51,7 +51,7 @@ class ImageGenerator: for i, image in enumerate(images): image.save(f".cache/{output_path.replace('.png', f'_{i}.png')}") - return image + return images # Example usage if __name__ == "__main__": diff --git a/ml.py b/ml.py new file mode 100644 index 0000000000000000000000000000000000000000..a7a8c49cd84fdd59f9138fefb8e8d88c04f2c9a3 --- /dev/null +++ b/ml.py @@ -0,0 +1,246 @@ +import os +import tempfile +import logging +import re +import subprocess +import cairosvg +from lxml import etree +import kagglehub +from gen_image import ImageGenerator +import vtracer + +svg_constraints = kagglehub.package_import('metric/svg-constraints') + +class MLModel: + def __init__(self, model_id="stabilityai/stable-diffusion-2-1-base", device="cuda"): + """ + Initialize the SVG generation pipeline. + + Args: + model_id (str): The model identifier for the stable diffusion model. + device (str): The device to run the model on, either "cuda" or "cpu". + """ + self.image_generator = ImageGenerator(model_id=model_id, device=device) + self.default_svg = """""" + self.constraints = svg_constraints.SVGConstraints() + self.timeout_seconds = 90 + + def predict(self, description, simplify=True, color_precision=6, + gradient_step=10, filter_speckle=4, path_precision=8): + """ + Generate an SVG from a text description. + + Args: + description (str): The text description to generate an image from. + simplify (bool): Whether to simplify the SVG paths. + color_precision (int): Color quantization precision. + gradient_step (int): Gradient step for color quantization (not used by vtracer). + filter_speckle (int): Filter speckle size. + path_precision (int): Path fitting precision. + + Returns: + str: The generated SVG content. 
+ """ + try: + # Step 1: Generate image using diffusion model + images = self.image_generator.generate(description) + image = images[0] + + # Step 2: Save image to a temporary file + with tempfile.NamedTemporaryFile(suffix='.png', delete=False) as temp_img: + temp_img_path = temp_img.name + image.save(temp_img_path) + + # Step 3: Convert image to SVG using vtracer + with tempfile.NamedTemporaryFile(suffix='.svg', delete=False) as temp_svg: + temp_svg_path = temp_svg.name + + # Process the image with vtracer using parameters directly + vtracer.convert_image_to_svg_py( + temp_img_path, + temp_svg_path, + colormode='color', + hierarchical='stacked' if simplify else 'cutout', + mode='spline', + filter_speckle=filter_speckle, + color_precision=color_precision, + path_precision=path_precision, + corner_threshold=60, + length_threshold=4.0, + max_iterations=10, + splice_threshold=45 + ) + + # Step 4: Read the generated SVG + with open(temp_svg_path, 'r') as f: + svg_content = f.read() + + # Clean up temporary files + os.unlink(temp_img_path) + os.unlink(temp_svg_path) + + # Step 5: Enforce constraints + svg_content = self.enforce_constraints(svg_content) + + return svg_content + except Exception as e: + logging.error(f"Error generating SVG: {e}") + return self.default_svg + + def enforce_constraints(self, svg_string: str) -> str: + """Enforces constraints on an SVG string, removing disallowed elements + and attributes. + + Parameters + ---------- + svg_string : str + The SVG string to process. + + Returns + ------- + str + The processed SVG string, or the default SVG if constraints + cannot be satisfied. + """ + logging.info('Sanitizing SVG...') + + try: + # Remove XML declaration if it exists + svg_string = re.sub(r'<\?xml[^>]+\?>', '', svg_string).strip() + + parser = etree.XMLParser(remove_blank_text=True, remove_comments=True) + root = etree.fromstring(svg_string, parser=parser) + except etree.ParseError as e: + logging.error('SVG Parse Error: %s. Returning default SVG.', e) + logging.error('SVG string: %s', svg_string) + return self.default_svg + + elements_to_remove = [] + for element in root.iter(): + tag_name = etree.QName(element.tag).localname + + # Remove disallowed elements + if tag_name not in self.constraints.allowed_elements: + elements_to_remove.append(element) + continue # Skip attribute checks for removed elements + + # Remove disallowed attributes + attrs_to_remove = [] + for attr in element.attrib: + attr_name = etree.QName(attr).localname + if ( + attr_name + not in self.constraints.allowed_elements[tag_name] + and attr_name + not in self.constraints.allowed_elements['common'] + ): + attrs_to_remove.append(attr) + + for attr in attrs_to_remove: + logging.debug( + 'Attribute "%s" for element "%s" not allowed. Removing.', + attr, + tag_name, + ) + del element.attrib[attr] + + # Check and remove invalid href attributes + for attr, value in element.attrib.items(): + if etree.QName(attr).localname == 'href' and not value.startswith('#'): + logging.debug( + 'Removing invalid href attribute in element "%s".', tag_name + ) + del element.attrib[attr] + + # Validate path elements to help ensure SVG conversion + if tag_name == 'path': + d_attribute = element.get('d') + if not d_attribute: + logging.warning('Path element is missing "d" attribute. 
Removing path.') + elements_to_remove.append(element) + continue # Skip further checks for this removed element + # Use regex to validate 'd' attribute format + path_regex = re.compile( + r'^' # Start of string + r'(?:' # Non-capturing group for each command + numbers block + r'[MmZzLlHhVvCcSsQqTtAa]' # Valid SVG path commands (adjusted to exclude extra letters) + r'\s*' # Optional whitespace after command + r'(?:' # Non-capturing group for optional numbers + r'-?\d+(?:\.\d+)?(?:[Ee][+-]?\d+)?' # First number + r'(?:[\s,]+-?\d+(?:\.\d+)?(?:[Ee][+-]?\d+)?)*' # Subsequent numbers with mandatory separator(s) + r')?' # Numbers are optional (e.g. for Z command) + r'\s*' # Optional whitespace after numbers/command block + r')+' # One or more command blocks + r'\s*' # Optional trailing whitespace + r'$' # End of string + ) + if not path_regex.match(d_attribute): + logging.warning( + 'Path element has malformed "d" attribute format. Removing path.' + ) + elements_to_remove.append(element) + continue + logging.debug('Path element "d" attribute validated (regex check).') + + # Remove elements marked for removal + for element in elements_to_remove: + if element.getparent() is not None: + element.getparent().remove(element) + logging.debug('Removed element: %s', element.tag) + + try: + cleaned_svg_string = etree.tostring(root, encoding='unicode', xml_declaration=False) + return cleaned_svg_string + except ValueError as e: + logging.error( + 'SVG could not be sanitized to meet constraints: %s', e + ) + return self.default_svg + + def optimize_svg(self, svg_content): + """ + Optimize the SVG content using SVGO. + + Args: + svg_content (str): The SVG content to optimize. + + Returns: + str: The optimized SVG content. + """ + try: + with tempfile.NamedTemporaryFile(suffix='.svg', delete=False) as temp_svg: + temp_svg_path = temp_svg.name + temp_svg.write(svg_content.encode('utf-8')) + + with tempfile.NamedTemporaryFile(suffix='.svg', delete=False) as temp_out: + temp_out_path = temp_out.name + + subprocess.run(["svgo", temp_svg_path, "-o", temp_out_path], check=True) + + with open(temp_out_path, 'r') as f: + optimized_svg = f.read() + + os.unlink(temp_svg_path) + os.unlink(temp_out_path) + + return optimized_svg + except (FileNotFoundError, subprocess.CalledProcessError): + print("Warning: SVGO not found or failed. 
Returning unoptimized SVG.") + return svg_content + + +# Example usage +if __name__ == "__main__": + model = MLModel() + svg = model.predict("a purple forest at dusk") + # Convert SVG to PNG + try: + # Create a PNG in memory + png_data = cairosvg.svg2png(bytestring=svg.encode('utf-8')) + + # Save the PNG to a file + with open("output.png", "wb") as f: + f.write(png_data) + print("SVG saved as output.png") + except Exception as e: + print(f"Error converting SVG to PNG: {e}") \ No newline at end of file diff --git a/naive.py b/naive.py new file mode 100644 index 0000000000000000000000000000000000000000..c796f8817708d34cfc3ddcaa0f50ee3b82d13690 --- /dev/null +++ b/naive.py @@ -0,0 +1,246 @@ +import concurrent +import io +import logging +import re + +import cairosvg +import kagglehub +import torch +from lxml import etree +from unsloth import FastLanguageModel +from unsloth.chat_templates import get_chat_template + +svg_constraints = kagglehub.package_import('metric/svg-constraints') + +class NaiveModel: + def __init__(self, model_name="unsloth/phi-4-unsloth-bnb-4bit", max_seq_length=2048, device="cuda"): + self.device = device + self.max_seq_length = max_seq_length + self.load_in_4bit = True + + # Load the Unsloth Phi-4 model + self.model, self.tokenizer = FastLanguageModel.from_pretrained( + model_name=model_name, + max_seq_length=self.max_seq_length, + load_in_4bit=self.load_in_4bit + ) + + # Set up chat template + self.tokenizer = get_chat_template( + self.tokenizer, + chat_template="phi-4", + ) + + # Prepare model for inference + FastLanguageModel.for_inference(self.model) + + self.prompt_template = """Generate SVG code to visually represent the following text description, while respecting the given constraints. + +* **Allowed Elements:** `svg`, `path`, `circle`, `rect`, `ellipse`, `line`, `polyline`, `polygon`, `g`, `linearGradient`, `radialGradient`, `stop`, `defs` +* **Allowed Attributes:** `viewBox`, `width`, `height`, `fill`, `stroke`, `stroke-width`, `d`, `cx`, `cy`, `r`, `x`, `y`, `rx`, `ry`, `x1`, `y1`, `x2`, `y2`, `points`, `transform`, `opacity` + + +Please ensure that the generated SVG code is well-formed, valid, and strictly adheres to these constraints. Focus on a clear and concise representation of the input description within the given limitations. Always give the complete SVG code with nothing omitted. Never use an ellipsis. 
+ +"A red circle with a blue square inside" +```svg + + + + +``` + +"{}" +""" + self.default_svg = """""" + self.constraints = svg_constraints.SVGConstraints() + self.timeout_seconds = 90 + + def predict(self, description: str, max_new_tokens=512) -> str: + def generate_svg(): + try: + # Format the prompt + prompt = self.prompt_template.format(description) + + # Create messages in the format expected by the chat template + messages = [ + {"role": "user", "content": prompt}, + ] + + # Tokenize the messages + inputs = self.tokenizer.apply_chat_template( + messages, + tokenize=True, + add_generation_prompt=True, + return_tensors="pt", + ).to(self.device) + + # Generate the output + outputs = self.model.generate( + input_ids=inputs, + max_new_tokens=max_new_tokens, + use_cache=True, + temperature=1.0, + min_p=0.1, + do_sample=True, + ) + + # Decode the output + output_decoded = self.tokenizer.decode(outputs[0], skip_special_tokens=True) + + # Extract only the generated text (skip the prompt) + generated_text = output_decoded.split("```svg")[-1].split("```")[0] if "```svg" in output_decoded else "" + + logging.debug('Output decoded from model: %s', output_decoded) + + matches = re.findall(r"", output_decoded, re.DOTALL | re.IGNORECASE) + if matches: + svg = matches[-1] + else: + return self.default_svg + + logging.debug('Unprocessed SVG: %s', svg) + svg = self.enforce_constraints(svg) + logging.debug('Processed SVG: %s', svg) + + # Ensure the generated code can be converted by cairosvg + cairosvg.svg2png(bytestring=svg.encode('utf-8')) + return svg + except Exception as e: + logging.error('Exception during SVG generation: %s', e) + return self.default_svg + + # Execute SVG generation in a new thread to enforce time constraints + with concurrent.futures.ThreadPoolExecutor(max_workers=1) as executor: + future = executor.submit(generate_svg) + try: + return future.result(timeout=self.timeout_seconds) + except concurrent.futures.TimeoutError: + logging.warning("Prediction timed out after %s seconds.", self.timeout_seconds) + return self.default_svg + except Exception as e: + logging.error(f"An unexpected error occurred: {e}") + return self.default_svg + + def enforce_constraints(self, svg_string: str) -> str: + """Enforces constraints on an SVG string, removing disallowed elements + and attributes. + + Parameters + ---------- + svg_string : str + The SVG string to process. + + Returns + ------- + str + The processed SVG string, or the default SVG if constraints + cannot be satisfied. + """ + logging.info('Sanitizing SVG...') + + try: + parser = etree.XMLParser(remove_blank_text=True, remove_comments=True) + root = etree.fromstring(svg_string, parser=parser) + except etree.ParseError as e: + logging.error('SVG Parse Error: %s. 
Returning default SVG.', e) + logging.error('SVG string: %s', svg_string) + return self.default_svg + + elements_to_remove = [] + for element in root.iter(): + tag_name = etree.QName(element.tag).localname + + # Remove disallowed elements + if tag_name not in self.constraints.allowed_elements: + elements_to_remove.append(element) + continue # Skip attribute checks for removed elements + + # Remove disallowed attributes + attrs_to_remove = [] + for attr in element.attrib: + attr_name = etree.QName(attr).localname + if ( + attr_name + not in self.constraints.allowed_elements[tag_name] + and attr_name + not in self.constraints.allowed_elements['common'] + ): + attrs_to_remove.append(attr) + + for attr in attrs_to_remove: + logging.debug( + 'Attribute "%s" for element "%s" not allowed. Removing.', + attr, + tag_name, + ) + del element.attrib[attr] + + # Check and remove invalid href attributes + for attr, value in element.attrib.items(): + if etree.QName(attr).localname == 'href' and not value.startswith('#'): + logging.debug( + 'Removing invalid href attribute in element "%s".', tag_name + ) + del element.attrib[attr] + + # Validate path elements to help ensure SVG conversion + if tag_name == 'path': + d_attribute = element.get('d') + if not d_attribute: + logging.warning('Path element is missing "d" attribute. Removing path.') + elements_to_remove.append(element) + continue # Skip further checks for this removed element + # Use regex to validate 'd' attribute format + path_regex = re.compile( + r'^' # Start of string + r'(?:' # Non-capturing group for each command + numbers block + r'[MmZzLlHhVvCcSsQqTtAa]' # Valid SVG path commands (adjusted to exclude extra letters) + r'\s*' # Optional whitespace after command + r'(?:' # Non-capturing group for optional numbers + r'-?\d+(?:\.\d+)?(?:[Ee][+-]?\d+)?' # First number + r'(?:[\s,]+-?\d+(?:\.\d+)?(?:[Ee][+-]?\d+)?)*' # Subsequent numbers with mandatory separator(s) + r')?' # Numbers are optional (e.g. for Z command) + r'\s*' # Optional whitespace after numbers/command block + r')+' # One or more command blocks + r'\s*' # Optional trailing whitespace + r'$' # End of string + ) + if not path_regex.match(d_attribute): + logging.warning( + 'Path element has malformed "d" attribute format. Removing path.' 
+ ) + elements_to_remove.append(element) + continue + logging.debug('Path element "d" attribute validated (regex check).') + + # Remove elements marked for removal + for element in elements_to_remove: + if element.getparent() is not None: + element.getparent().remove(element) + logging.debug('Removed element: %s', element.tag) + + try: + cleaned_svg_string = etree.tostring(root, encoding='unicode') + return cleaned_svg_string + except ValueError as e: + logging.error( + 'SVG could not be sanitized to meet constraints: %s', e + ) + return self.default_svg + + +if __name__ == "__main__": + model = NaiveModel() + svg = model.predict("a purple forest at dusk") + # Convert SVG to PNG + try: + # Create a PNG in memory + png_data = cairosvg.svg2png(bytestring=svg.encode('utf-8')) + + # Save the PNG to a file + with open("output.png", "wb") as f: + f.write(png_data) + print("SVG saved as output.png") + except Exception as e: + print(f"Error converting SVG to PNG: {e}") diff --git a/requirements.txt b/requirements.txt index e78cc67e8bb612da2b172090faaec78b5a629bd0..e26d6df9c3e216a64f43b3d1f999200783c96faf 100644 --- a/requirements.txt +++ b/requirements.txt @@ -19,6 +19,12 @@ dotenv diffusers safetensors xformers +unsloth +tf-keras +vtracer +deepspeed +torch==2.5.1 +torchvision==0.20.1 # pip install 'tensorflow[and-cuda]' # pip install git+https://github.com/openai/CLIP.git diff --git a/starter.ipynb b/starter.ipynb deleted file mode 100644 index f227308e22d2681614faf4cece853e1d0ee3a7c4..0000000000000000000000000000000000000000 --- a/starter.ipynb +++ /dev/null @@ -1,333 +0,0 @@ -{ - "cells": [ - { - "cell_type": "code", - "execution_count": 2, - "metadata": {}, - "outputs": [ - { - "name": "stderr", - "output_type": "stream", - "text": [ - "/home/user/miniconda3/envs/dwl/lib/python3.11/site-packages/tqdm/auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n", - " from .autonotebook import tqdm as notebook_tqdm\n" - ] - }, - { - "data": { - "text/html": [ - "
\n", - "shape: (5, 2)
iddescription
strstr
"02d892""a purple forest at dusk"
"0dcd2e""gray wool coat with a faux fur…
"1e9ac1""a lighthouse overlooking the o…
"2b25db""burgundy corduroy pants with p…
"4e6a54""orange corduroy overalls"
" - ], - "text/plain": [ - "shape: (5, 2)\n", - "┌────────┬─────────────────────────────────┐\n", - "│ id ┆ description │\n", - "│ --- ┆ --- │\n", - "│ str ┆ str │\n", - "╞════════╪═════════════════════════════════╡\n", - "│ 02d892 ┆ a purple forest at dusk │\n", - "│ 0dcd2e ┆ gray wool coat with a faux fur… │\n", - "│ 1e9ac1 ┆ a lighthouse overlooking the o… │\n", - "│ 2b25db ┆ burgundy corduroy pants with p… │\n", - "│ 4e6a54 ┆ orange corduroy overalls │\n", - "└────────┴─────────────────────────────────┘" - ] - }, - "execution_count": 2, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "# We can load and explore the competition's train set to get a feel for the data.\n", - "# We're not going to export this cell as it's not needed for our exported inferenceable model.\n", - "\n", - "import kagglehub\n", - "import polars as pl\n", - "\n", - "train_path = kagglehub.competition_download('drawing-with-llms', 'train.csv')\n", - "train = pl.read_csv(train_path)\n", - "\n", - "train.head()" - ] - }, - { - "cell_type": "code", - "execution_count": 3, - "metadata": {}, - "outputs": [], - "source": [ - "class Model:\n", - " def __init__(self):\n", - " '''Optional constructor, performs any setup logic, model instantiation, etc.'''\n", - " pass\n", - " \n", - " def predict(self, prompt: str) -> str:\n", - " '''Generates SVG which produces an image described by the prompt.\n", - "\n", - " Args:\n", - " prompt (str): A prompt describing an image\n", - " Returns:\n", - " String of valid SVG code.\n", - " '''\n", - " # Renders a simple circle regardless of input\n", - " return ''" - ] - }, - { - "cell_type": "code", - "execution_count": 4, - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "\n" - ] - }, - { - "data": { - "image/svg+xml": [ - "" - ], - "text/plain": [ - "" - ] - }, - "metadata": {}, - "output_type": "display_data" - } - ], - "source": [ - "from IPython.display import SVG\n", - "\n", - "model = Model()\n", - "svg = model.predict('a goose winning a gold medal')\n", - "\n", - "print(svg)\n", - "display(SVG(svg))" - ] - }, - { - "cell_type": "code", - "execution_count": 6, - "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "['RN50',\n", - " 'RN101',\n", - " 'RN50x4',\n", - " 'RN50x16',\n", - " 'RN50x64',\n", - " 'ViT-B/32',\n", - " 'ViT-B/16',\n", - " 'ViT-L/14',\n", - " 'ViT-L/14@336px']" - ] - }, - "execution_count": 6, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "import clip\n", - "clip.available_models()" - ] - }, - { - "cell_type": "code", - "execution_count": 7, - "metadata": {}, - "outputs": [ - { - "name": "stderr", - "output_type": "stream", - "text": [ - "2025-04-20 13:55:34.589770: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:467] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered\n", - "WARNING: All log messages before absl::InitializeLog() is called are written to STDERR\n", - "E0000 00:00:1745171734.600777 13214 cuda_dnn.cc:8579] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered\n", - "E0000 00:00:1745171734.603957 13214 cuda_blas.cc:1407] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered\n", - "W0000 00:00:1745171734.615566 13214 computation_placer.cc:177] computation placer already registered. 
Please check linkage and avoid linking the same target more than once.\n", - "W0000 00:00:1745171734.615584 13214 computation_placer.cc:177] computation placer already registered. Please check linkage and avoid linking the same target more than once.\n", - "W0000 00:00:1745171734.615585 13214 computation_placer.cc:177] computation placer already registered. Please check linkage and avoid linking the same target more than once.\n", - "W0000 00:00:1745171734.615586 13214 computation_placer.cc:177] computation placer already registered. Please check linkage and avoid linking the same target more than once.\n", - "2025-04-20 13:55:34.618659: I tensorflow/core/platform/cpu_feature_guard.cc:210] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.\n", - "To enable the following instructions: AVX2 FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.\n", - "Using a slow image processor as `use_fast` is unset and a slow processor was saved with this model. `use_fast=True` will be the default behavior in v4.52, even if the model was saved with a slow processor. This will result in minor differences in outputs. You'll still be able to use a slow processor with `use_fast=False`.\n", - "Loading checkpoint shards: 100%|██████████| 4/4 [00:18<00:00, 4.68s/it]\n" - ] - } - ], - "source": [ - "import pandas as pd\n", - "import importlib\n", - "metric = importlib.import_module('metric')\n", - "importlib.reload(metric)\n", - "\n", - "vqa_evaluator = metric.VQAEvaluator()\n", - "aesthetic_evaluator = metric.AestheticEvaluator()" - ] - }, - { - "cell_type": "code", - "execution_count": 11, - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "VQA Score: 0.9996758976500401\n", - "Aesthetic Score: 0.5749330520629883\n", - "Final Fidelity Score: 0.8709845773271212\n" - ] - } - ], - "source": [ - "# score gpt4o generated images\n", - "import ast\n", - "import numpy as np\n", - "from PIL import Image\n", - "\n", - "# Load the first sample from descriptions.csv\n", - "descriptions_df = pd.read_csv('data/descriptions.csv')\n", - "first_description = descriptions_df.iloc[1]\n", - "\n", - "eval_df = pd.read_csv('data/eval.csv')\n", - "first_eval = eval_df.iloc[1]\n", - "\n", - "# Load the image\n", - "image_path = 'data/gray_coat.png' # Assuming the image is saved with this name\n", - "image = Image.open(image_path)\n", - "\n", - "# Prepare the inputs for scoring - need to parse the string representations\n", - "questions = ast.literal_eval(first_eval['question'])\n", - "choices = ast.literal_eval(first_eval['choices'])\n", - "answers = ast.literal_eval(first_eval['answer'])\n", - "\n", - "# Calculate VQA score - don't wrap in additional lists\n", - "vqa_score = vqa_evaluator.score(questions, choices, answers, image)\n", - "\n", - "# Calculate aesthetic score\n", - "aesthetic_score = aesthetic_evaluator.score(image)\n", - "\n", - "# Apply image processing as done in the metric.score function\n", - "image_processor = metric.ImageProcessor(image=image, seed=0).apply()\n", - "processed_image = image_processor.image.copy()\n", - "\n", - "# Calculate final fidelity score\n", - "instance_score = metric.harmonic_mean(vqa_score, aesthetic_score, beta=0.5)\n", - "\n", - "print(f\"VQA Score: {vqa_score}\")\n", - "print(f\"Aesthetic Score: {aesthetic_score}\")\n", - "print(f\"Final Fidelity Score: {instance_score}\")" - ] - }, - { - "cell_type": "code", - "execution_count": 13, - "metadata": {}, - 
"outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "No duplicate IDs found in data/descriptions.csv\n", - "Sorted rows by ID\n", - "Fixed and sorted CSV saved to data/descriptions.csv\n", - "No duplicate IDs found in data/eval.csv\n", - "Sorted data/eval.csv by ID\n" - ] - } - ], - "source": [ - "# Fix duplicate IDs in descriptions.csv and order rows by id\n", - "def fix_duplicate_ids(csv_path):\n", - " \"\"\"\n", - " Fix duplicate IDs in a CSV file by assigning new unique IDs to duplicates.\n", - " Then order rows by ID.\n", - " \"\"\"\n", - " # Read the CSV file\n", - " df = pd.read_csv(csv_path)\n", - " \n", - " # Check for duplicate IDs\n", - " duplicate_mask = df['id'].duplicated(keep='first')\n", - " duplicate_count = duplicate_mask.sum()\n", - " \n", - " if duplicate_count > 0:\n", - " print(f\"Found {duplicate_count} duplicate IDs in {csv_path}\")\n", - " \n", - " # Get the maximum ID value\n", - " max_id = df['id'].max()\n", - " \n", - " # Assign new IDs to duplicates\n", - " new_ids = list(range(max_id + 1, max_id + 1 + duplicate_count))\n", - " df.loc[duplicate_mask, 'id'] = new_ids\n", - " \n", - " print(f\"Assigned new IDs to duplicates\")\n", - " else:\n", - " print(f\"No duplicate IDs found in {csv_path}\")\n", - " \n", - " # Sort the dataframe by ID\n", - " df = df.sort_values(by='id')\n", - " print(f\"Sorted rows by ID\")\n", - " \n", - " # Save the fixed and sorted CSV\n", - " df.to_csv(csv_path, index=False)\n", - " print(f\"Fixed and sorted CSV saved to {csv_path}\")\n", - " \n", - " # Return the fixed dataframe\n", - " return df\n", - "\n", - "# Fix descriptions.csv\n", - "fixed_descriptions_df = fix_duplicate_ids('data/descriptions.csv')\n", - "\n", - "# Fix eval.csv if needed\n", - "# First check if eval.csv has the same issue\n", - "eval_df = pd.read_csv('data/eval.csv')\n", - "duplicate_eval_ids = eval_df['id'].duplicated(keep='first').sum()\n", - "\n", - "if duplicate_eval_ids > 0:\n", - " fixed_eval_df = fix_duplicate_ids('data/eval.csv')\n", - "else:\n", - " print(\"No duplicate IDs found in data/eval.csv\")\n", - " # Still sort by ID even if no duplicates\n", - " eval_df = eval_df.sort_values(by='id')\n", - " eval_df.to_csv('data/eval.csv', index=False)\n", - " print(\"Sorted data/eval.csv by ID\")\n" - ] - } - ], - "metadata": { - "kernelspec": { - "display_name": "dwl", - "language": "python", - "name": "python3" - }, - "language_info": { - "codemirror_mode": { - "name": "ipython", - "version": 3 - }, - "file_extension": ".py", - "mimetype": "text/x-python", - "name": "python", - "nbconvert_exporter": "python", - "pygments_lexer": "ipython3", - "version": "3.11.11" - } - }, - "nbformat": 4, - "nbformat_minor": 2 -} diff --git a/starvector/__init__.py b/starvector/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/starvector/adapter.py b/starvector/adapter.py new file mode 100644 index 0000000000000000000000000000000000000000..7a5a0fe2de0a98472f67576a0fa32c47c68dedff --- /dev/null +++ b/starvector/adapter.py @@ -0,0 +1,53 @@ +import torch.nn as nn +import torch.nn.init as init +import torch + +class Swish(nn.Module): + def __init__(self): + super(Swish, self).__init__() + + def forward(self, x): + return x * torch.sigmoid(x) + +class Adapter(nn.Module): + def __init__(self, input_size, output_size, adapter_norm="layer_norm", init_type="glorot", query_length=32, dropout_prob=0.1): + super().__init__() + self.query_length = query_length + 
self.dropout_prob = dropout_prob + self.adapter_norm = adapter_norm + + self.dropout = nn.Dropout(p=self.dropout_prob) + + self.c_fc = nn.Linear(input_size, input_size*2) + self.act = Swish() + self.c_proj = nn.Linear(input_size*2, output_size) + + if adapter_norm == "layer_norm": + self.norm = nn.LayerNorm([self.query_length, output_size]) + elif adapter_norm == "batch_norm": + self.norm = nn.BatchNorm1d(self.query_length) + + self.init_type = init_type.lower() + self._initialize_weights() + + def forward(self, hidden_states): + hidden_states = self.dropout(hidden_states) + hidden_states = self.c_fc(hidden_states) + hidden_states = self.act(hidden_states) + hidden_states = self.c_proj(hidden_states) + hidden_states = self.norm(hidden_states) + return hidden_states + + def _initialize_weights(self): + for m in self.modules(): + if isinstance(m, nn.Linear): + if self.init_type == "glorot": + init.xavier_uniform_(m.weight) + if m.bias is not None: + init.constant_(m.bias, 0) + elif self.init_type == "normal": + init.normal_(m.weight, mean=0, std=0.01) + if m.bias is not None: + init.constant_(m.bias, 0) + else: + raise ValueError("Invalid initialization type specified.") diff --git a/starvector/clip_model.py b/starvector/clip_model.py new file mode 100644 index 0000000000000000000000000000000000000000..4eb2349dc3a5521b0a59d896025f6a7251374897 --- /dev/null +++ b/starvector/clip_model.py @@ -0,0 +1,191 @@ +# Adapted from LAVIS-Salesforce: LAVIS/lavis/models/clip_vit.py + +from collections import OrderedDict +from itertools import repeat +import collections.abc +import math +import torch +import torch.nn.functional as F +from torch import nn +from fairscale.nn.checkpoint.checkpoint_activations import checkpoint_wrapper + +def convert_weights_to_precision(model: nn.Module, precision: torch.dtype): + """Convert applicable model parameters to the specified precision""" + + def _convert_weights_to_precision(l): + if isinstance(l, (nn.Conv1d, nn.Conv2d, nn.Linear)): + l.weight.data = l.weight.data.to(precision) + if l.bias is not None: + l.bias.data = l.bias.data.to(precision) + + elif isinstance(l, (nn.MultiheadAttention)): + for attr in [*[f"{s}_proj_weight" for s in ["in", "q", "k", "v"]], "in_proj_bias", "bias_k", "bias_v"]: + tensor = getattr(l, attr) + if tensor is not None: + tensor.data = tensor.data.to(precision) + else: + for _, p in l.named_parameters(): + p.data = p.data.to(precision) + + model.apply(_convert_weights_to_precision) + +class Bottleneck(nn.Module): + expansion = 4 + + def __init__(self, inplanes, planes, stride=1): + super().__init__() + + # all conv layers have stride 1. 
an avgpool is performed after the second convolution when stride > 1 + self.conv1 = nn.Conv2d(inplanes, planes, 1, bias=False) + self.bn1 = nn.BatchNorm2d(planes) + self.relu1 = nn.ReLU(inplace=True) + + self.conv2 = nn.Conv2d(planes, planes, 3, padding=1, bias=False) + self.bn2 = nn.BatchNorm2d(planes) + self.relu2 = nn.ReLU(inplace=True) + + self.avgpool = nn.AvgPool2d(stride) if stride > 1 else nn.Identity() + + self.conv3 = nn.Conv2d(planes, planes * self.expansion, 1, bias=False) + self.bn3 = nn.BatchNorm2d(planes * self.expansion) + self.relu3 = nn.ReLU(inplace=True) + + self.downsample = None + self.stride = stride + + if stride > 1 or inplanes != planes * Bottleneck.expansion: + # downsampling layer is prepended with an avgpool, and the subsequent convolution has stride 1 + self.downsample = nn.Sequential(OrderedDict([ + ("-1", nn.AvgPool2d(stride)), + ("0", nn.Conv2d(inplanes, planes * self.expansion, 1, stride=1, bias=False)), + ("1", nn.BatchNorm2d(planes * self.expansion)) + ])) + + def forward(self, x: torch.Tensor): + identity = x + + out = self.relu1(self.bn1(self.conv1(x))) + out = self.relu2(self.bn2(self.conv2(out))) + out = self.avgpool(out) + out = self.bn3(self.conv3(out)) + + if self.downsample is not None: + identity = self.downsample(x) + + out += identity + out = self.relu3(out) + return out + + +class AttentionPool2d(nn.Module): + def __init__(self, spacial_dim: int, embed_dim: int, num_heads: int, output_dim: int = None): + super().__init__() + self.positional_embedding = nn.Parameter(torch.randn(spacial_dim ** 2 + 1, embed_dim) / embed_dim ** 0.5) + self.k_proj = nn.Linear(embed_dim, embed_dim) + self.q_proj = nn.Linear(embed_dim, embed_dim) + self.v_proj = nn.Linear(embed_dim, embed_dim) + self.c_proj = nn.Linear(embed_dim, output_dim or embed_dim) + self.num_heads = num_heads + + def forward(self, x): + x = x.reshape(x.shape[0], x.shape[1], x.shape[2] * x.shape[3]).permute(2, 0, 1) # NCHW -> (HW)NC + x = torch.cat([x.mean(dim=0, keepdim=True), x], dim=0) # (HW+1)NC + x = x + self.positional_embedding[:, None, :].to(x.dtype) # (HW+1)NC + x, _ = F.multi_head_attention_forward( + query=x, key=x, value=x, + embed_dim_to_check=x.shape[-1], + num_heads=self.num_heads, + q_proj_weight=self.q_proj.weight, + k_proj_weight=self.k_proj.weight, + v_proj_weight=self.v_proj.weight, + in_proj_weight=None, + in_proj_bias=torch.cat([self.q_proj.bias, self.k_proj.bias, self.v_proj.bias]), + bias_k=None, + bias_v=None, + add_zero_attn=False, + dropout_p=0, + out_proj_weight=self.c_proj.weight, + out_proj_bias=self.c_proj.bias, + use_separate_proj_weight=True, + training=self.training, + need_weights=False + ) + + return x[0] + + +class LayerNorm(nn.LayerNorm): + """Subclass torch's LayerNorm to handle fp16.""" + + def forward(self, x: torch.Tensor): + orig_type = x.dtype + layernorm_dtype = self.weight.dtype + ret = super().forward(x.type(layernorm_dtype)) + return ret.type(orig_type) + +class QuickGELU(nn.Module): + def forward(self, x: torch.Tensor): + return x * torch.sigmoid(1.702 * x) + +class ResidualAttentionBlock(nn.Module): + def __init__(self, d_model: int, n_head: int, attn_mask: torch.Tensor = None, use_grad_checkpointing=False): + super().__init__() + + self.attn = nn.MultiheadAttention(d_model, n_head) + self.ln_1 = LayerNorm(d_model) + self.mlp = nn.Sequential(OrderedDict([ + ("c_fc", nn.Linear(d_model, d_model * 4)), + ("gelu", QuickGELU()), + ("c_proj", nn.Linear(d_model * 4, d_model)) + ])) + self.ln_2 = LayerNorm(d_model) + self.attn_mask = attn_mask + + if 
use_grad_checkpointing: + self.attn = checkpoint_wrapper(self.attn) + self.mlp = checkpoint_wrapper(self.mlp) + + def attention(self, x: torch.Tensor): + self.attn_mask = self.attn_mask.to(dtype=x.dtype, device=x.device) if self.attn_mask is not None else None + return self.attn(x, x, x, need_weights=False, attn_mask=self.attn_mask)[0] + + def forward(self, x: torch.Tensor): + x = x + self.attention(self.ln_1(x)) + x = x + self.mlp(self.ln_2(x)) + return x + +class Transformer(nn.Module): + def __init__(self, width: int, layers: int, heads: int, attn_mask: torch.Tensor = None, use_grad_checkpointing=False): + super().__init__() + self.width = width + self.layers = layers + self.resblocks = nn.Sequential(*[ResidualAttentionBlock(width, heads, attn_mask, use_grad_checkpointing and i>12) for i in range(layers)]) + + def forward(self, x: torch.Tensor): + return self.resblocks(x) + +class VisionTransformer(nn.Module): + def __init__(self, input_resolution: int, patch_size: int, width: int, layers: int, heads: int, use_grad_checkpointing: bool): + super().__init__() + self.input_resolution = input_resolution + self.num_features = width + self.num_heads = heads + self.num_patches = (input_resolution // patch_size) ** 2 + self.conv1 = nn.Conv2d(in_channels=3, out_channels=width, kernel_size=patch_size, stride=patch_size, bias=False) + scale = width ** -0.5 + self.class_embedding = nn.Parameter(scale * torch.randn(width)) + self.positional_embedding = nn.Parameter(scale * torch.randn(self.num_patches + 1, width)) + self.ln_pre = LayerNorm(width) + self.transformer = Transformer(width, layers, heads, use_grad_checkpointing=use_grad_checkpointing) + + def forward(self, x: torch.Tensor): + x = self.conv1(x) # shape = [*, width, grid, grid] + x = x.reshape(x.shape[0], x.shape[1], -1) # shape = [*, width, grid ** 2] + x = x.permute(0, 2, 1) # shape = [*, grid ** 2, width] + x = torch.cat([self.class_embedding.to(x.dtype) + torch.zeros(x.shape[0], 1, x.shape[-1], dtype=x.dtype, device=x.device), x], dim=1) # shape = [*, grid ** 2 + 1, width] + x = x + self.positional_embedding.to(x.dtype) + x = self.ln_pre(x) + x = x.permute(1, 0, 2) # NLD -> LND + x = self.transformer(x) + x = x.permute(1, 0, 2) # LND -> NLD + return x diff --git a/starvector/data/augmentation.py b/starvector/data/augmentation.py new file mode 100644 index 0000000000000000000000000000000000000000..e138e2fdbe257a1741a0c397c4c0339323951afb --- /dev/null +++ b/starvector/data/augmentation.py @@ -0,0 +1,250 @@ + +import numpy as np +from svgpathtools import ( + Path, Arc, CubicBezier, QuadraticBezier, + svgstr2paths) +import os +from noise import pnoise1 +import re +import matplotlib.colors as mcolors +from bs4 import BeautifulSoup +from starvector.data.util import rasterize_svg + +class SVGTransforms: + def __init__(self, transformations): + self.transformations = transformations + self.noise_std = self.transformations.get('noise_std', False) + self.noise_type = self.transformations.get('noise_type', False) + self.rotate = self.transformations.get('rotate', False) + self.shift_re = self.transformations.get('shift_re', False) + self.shift_im = self.transformations.get('shift_im', False) + self.scale = self.transformations.get('scale', False) + self.color_noise = self.transformations.get('color_noise', False) + self.p = self.transformations.get('p', 0.5) + self.color_change = self.transformations.get('color_change', False) + self.colors = self.transformations.get('colors', ['#ff0000', '#0000ff', '#000000']) + + def 
sample_transformations(self): + if self.rotate: + a, b = self.rotate['from'], self.rotate['to'] + rotation_angle = np.random.uniform(a, b) + self.rotation_angle = rotation_angle + + if self.shift_re or self.shift_im: + self.shift_real = np.random.uniform(self.shift_re['from'], self.shift_re['to']) + self.shift_imag = np.random.uniform(self.shift_im['from'], self.shift_im['to']) + + if self.scale: + self.scale = np.random.uniform(self.scale['from'], self.scale['to']) + + if self.color_noise: + self.color_noise_std = np.random.uniform(self.color_noise['from'], self.color_noise['to']) + + + def paths2str(self, groupped_paths, svg_opening_tag=''): + + keys_to_exclude = ['d', 'cx', 'cy', 'rx', 'ry'] + all_groups_srt = '' + for group, elements in groupped_paths.items(): + group_attributes, paths_and_attributes = elements.get('attrs', {}), elements.get('paths', []) + group_attr_str = ' '.join(f'{key}="{value}"' for key, value in group_attributes.items()) + path_strings = [] + path_str = '' + for path, attributes in paths_and_attributes: + path_attr_str = '' + d_str = path.d() + + for key, value in attributes.items(): + if key not in keys_to_exclude: + path_attr_str += f' {key}="{value}"' + + path_strings.append(f'') + path_str = "\n".join(path_strings) + if 'no_group'in group: + group_str = path_str + else: + group_str = f'\n{path_str}\n\n' + all_groups_srt += group_str + svg = f'{svg_opening_tag}\n{all_groups_srt}' + return svg + + def add_noise(self, seg): + noise_scale = np.random.uniform(self.noise_std['from'], self.noise_std['to']) + if self.noise_type == 'gaussian': + noise_sample = np.random.normal(loc=0.0, scale=noise_scale) + \ + 1j * np.random.normal(loc=0.0, scale=noise_scale) + elif self.noise_type == 'perlin': + noise_sample = complex(pnoise1(np.random.random(), octaves=2), pnoise1(np.random.random(), octaves=2))*noise_scale + + if isinstance(seg, CubicBezier): + seg.control1 = seg.control1 + noise_sample + seg.control2 = seg.control2 + noise_sample + elif isinstance(seg, QuadraticBezier): + seg.control = seg.control + noise_sample + elif isinstance(seg, Arc): + seg.radius = seg.radius + noise_sample + + + return seg + + def do_rotate(self, path, viewbox_width, viewbox_height): + if self.rotate: + new_path = path.rotated(self.rotation_angle, complex(viewbox_width/2, viewbox_height/2)) + return new_path + else: + return path + + def do_shift(self, path): + if self.shift_re or self.shift_im: + return path.translated(complex(self.shift_real, self.shift_imag)) + else: + return path + + def do_scale(self, path): + if self.scale: + return path.scaled(self.scale) + else: + return path + + def add_color_noise(self, source_color): + # Convert color to RGB + if source_color.startswith("#"): + base_color = mcolors.hex2color(source_color) + else: + base_color = mcolors.hex2color(mcolors.CSS4_COLORS.get(source_color, '#FFFFFF')) + + # Add noise to each RGB component + noise = np.random.normal(0, self.color_noise_std, 3) + noisy_color = np.clip(np.array(base_color) + noise, 0, 1) + + # Convert the RGB color back to hex + hex_color = mcolors.rgb2hex(noisy_color) + + return hex_color + + def do_color_change(self, attr): + if 'fill' in attr: + if self.color_noise or self.color_change: + fill_value = attr['fill'] + if fill_value == 'none': + new_fill_value = 'none' + else: + if self.color_noise: + new_fill_value = self.add_color_noise(fill_value) + elif self.color_change: + new_fill_value = np.random.choice(self.colors) + attr['fill'] = new_fill_value + return attr + + def clean_attributes(self, 
attr): + attr_out = {} + if 'fill' in attr: + attr_out = attr + elif 'style' in attr: + fill_values = re.findall('fill:[^;]+', attr['style']) + if fill_values: + fill_value = fill_values[0].replace('fill:', '').strip() + attr_out['fill'] = fill_value + else: + attr_out = attr + else: + attr_out = attr + + return attr_out + + def get_viewbox_size(self, svg): + # Try to extract viewBox attribute + match = re.search(r'viewBox="([^"]+)"', svg) + if match: + viewbox = match.group(1) + else: + # If viewBox is not found, try to extract width and height attributes + match = re.search(r'width="([^"]+)px" height="([^"]+)px"', svg) + if match: + width, height = match.groups() + viewbox = f"0 0 {width} {height}" + else: + viewbox = "0 0 256 256" # Default if neither viewBox nor width/height are found + + viewbox = [float(x) for x in viewbox.split()] + viewbox_width, viewbox_height = viewbox[2], viewbox[3] + return viewbox_width, viewbox_height + + def augment(self, svg): + if os.path.isfile(svg): + # open svg file + with open(svg, 'r') as f: + svg = f.read() + + # Sample transformations for this sample + self.sample_transformations() + + + # Parse the SVG content + soup = BeautifulSoup(svg, 'xml') + + # Get opening tag + svg_opening_tag = re.findall(']+>', svg)[0] + + viewbox_width, viewbox_height = self.get_viewbox_size(svg) + + # Get all svg parents + groups = soup.findAll() + + # Create the groups of paths based on their original tag + grouped_paths = {} + for i, g in enumerate(groups): + if g.name == 'g': + group_id = group_id = g.get('id') if g.get('id') else f'none_{i}' + group_attrs = g.attrs + + elif g.name == 'svg' or g.name == 'metadata' or g.name == 'defs': + continue + + else: + group_id = f'no_group_{i}' + group_attrs = {} + + group_svg_string = f'{svg_opening_tag}{str(g)}' + try: + paths, attributes = svgstr2paths(group_svg_string) + except: + return svg, rasterize_svg(svg) + if not paths: + continue + + paths_and_attributes = [] + + # Rotation, shift, scale, noise addition + new_paths = [] + new_attributes = [] + for path, attribute in zip(paths, attributes): + attr = self.clean_attributes(attribute) + + new_path = self.do_rotate(path, viewbox_width, viewbox_height) + new_path = self.do_shift(new_path) + new_path = self.do_scale(new_path) + + if self.noise_std: + # Add noise to path to deform svg + noisy_path = [] + for seg in new_path: + noisy_seg = self.add_noise(seg) + noisy_path.append(noisy_seg) + new_paths.append(Path(*noisy_path)) + else: + new_paths.append(new_path) + + # Color change + attr = self.do_color_change(attr) + paths_and_attributes.append((new_path, attr)) + + grouped_paths[group_id] = { + 'paths': paths_and_attributes, + 'attrs': group_attrs + } + + svg = self.paths2str(grouped_paths, svg_opening_tag) + image = rasterize_svg(svg) + + return svg, image diff --git a/starvector/data/base.py b/starvector/data/base.py new file mode 100644 index 0000000000000000000000000000000000000000..33fee34512badbba8aef048a20651c8418e5a0c1 --- /dev/null +++ b/starvector/data/base.py @@ -0,0 +1,71 @@ +from torch.utils.data import Dataset +from starvector.data.util import ImageTrainProcessor, use_placeholder, rasterize_svg +from starvector.util import instantiate_from_config +import numpy as np +from datasets import load_dataset + +class SVGDatasetBase(Dataset): + def __init__(self, dataset_name, split, im_size, num_samples=-1, **kwargs): + self.split = split + self.im_size = im_size + + transforms = kwargs.get('transforms', False) + if transforms: + self.transforms = 
instantiate_from_config(transforms) + self.p = self.transforms.p + else: + self.transforms = None + self.p = 0.0 + + normalization = kwargs.get('normalize', False) + if normalization: + mean = tuple(normalization.get('mean', None)) + std = tuple(normalization.get('std', None)) + else: + mean = None + std = None + + self.processor = ImageTrainProcessor(size=self.im_size, mean=mean, std=std) + self.data = load_dataset(dataset_name, split=split) + + print(f"Loaded {len(self.data)} samples from {dataset_name} {split} split") + + def __len__(self): + return len(self.data_json) + + def get_svg_and_image(self, svg_str, sample_id): + do_augment = np.random.choice([True, False], p=[self.p, 1 - self.p]) + svg, image = None, None + + # Try to augment the image if conditions are met + if self.transforms is not None and do_augment: + try: + svg, image = self.transforms.augment(svg_str) + except Exception as e: + print(f"Error augmenting {sample_id} due to {str(e)}, trying to rasterize SVG") + + # If augmentation failed or wasn't attempted, try to rasterize the SVG + if svg is None or image is None: + try: + svg, image = svg_str, rasterize_svg(svg_str, self.im_size) + except Exception as e: + print(f"Error rasterizing {sample_id} due to {str(e)}, using placeholder image") + svg = use_placeholder() + image = rasterize_svg(svg, self.im_size) + + # If the image is completely white, use a placeholder image + if np.array(image).mean() == 255.0: + print(f"Image is full white, using placeholder image for {sample_id}") + svg = use_placeholder() + image = rasterize_svg(svg) + + # Process the image + if 'siglip' in self.image_processor: + image = self.processor(image).pixel_values[0] + else: + image = self.processor(image) + + return svg, image + + def __getitem__(self, idx): + raise NotImplementedError("This method should be implemented by subclasses") diff --git a/starvector/data/dataset.py b/starvector/data/dataset.py new file mode 100644 index 0000000000000000000000000000000000000000..f58c5c1d937856f58958321cee26d09439ef68e7 --- /dev/null +++ b/starvector/data/dataset.py @@ -0,0 +1,42 @@ +import os +from starvector.data.base import SVGDatasetBase +from starvector.data.augmentation import SVGTransforms +from starvector.data.util import ImageTrainProcessor +from transformers import AutoProcessor + +class SVGDataset(SVGDatasetBase): + def __init__(self, dataset_name, split, im_size, num_samples=None, **kwargs): + super().__init__(dataset_name, split, im_size, num_samples, **kwargs) + + self.color_changer = SVGTransforms({'color_change' : True, 'colors' : ['#ff0000', '#0000ff', '#00ff00', '#ffff00', '#000000']}) + select_dataset_name = kwargs.get('select_dataset_name', False) + + if select_dataset_name: + self.data = self.data.filter(lambda example: example["model_name"]==select_dataset_name) + + self.num_samples = num_samples + if self.num_samples != -1: + self.data = self.data.select(range(self.num_samples)) + + self.image_processor = kwargs.get('image_processor', None) + if 'siglip' in self.image_processor: + model_name = {'siglip_512': 'google/siglip-base-patch16-512', + 'siglip_384': 'google/siglip-large-patch16-384', + 'siglip_256': 'google/siglip-base-patch16-256'}[self.image_processor] + self.processor = AutoProcessor.from_pretrained(model_name).image_processor + else: + self.processor = ImageTrainProcessor(size=self.im_size) + def __len__(self): + return len(self.data) + + def __getitem__(self, idx): + svg_str = self.data[idx]['Svg'] + sample_id = self.data[idx]['Filename'] + svg, image = 
self.get_svg_and_image(svg_str, sample_id) + caption = self.data[idx].get('Caption', "") + return { + 'svg': svg, + 'image': image, + 'id': sample_id, + 'caption': caption + } \ No newline at end of file diff --git a/starvector/data/emojisvg.py b/starvector/data/emojisvg.py new file mode 100644 index 0000000000000000000000000000000000000000..fbf867d1ed188765836f8ebac1d7df714f041834 --- /dev/null +++ b/starvector/data/emojisvg.py @@ -0,0 +1,27 @@ +import os +from starvector.data.base import SVGDatasetBase + + +class EmojiSVGDataset(SVGDatasetBase): + def __init__(self, dataset_name, split, im_size, num_samples=None, **kwargs): + super().__init__(dataset_name, split, im_size, **kwargs) + + self.num_samples = num_samples + if self.num_samples != -1: + self.data = self.data.select(range(self.num_samples)) + + def __len__(self): + return len(self.data) + + def __getitem__(self, idx): + + svg_str = self.data[idx]['Svg'] + sample_id = self.data[idx]['Filename'] + svg, image = self.get_svg_and_image(svg_str, sample_id) + caption = self.data[idx].get('Caption', "") + return { + 'svg': svg, + 'image': image, + 'id': sample_id, + 'caption': caption + } \ No newline at end of file diff --git a/starvector/data/figrsvg.py b/starvector/data/figrsvg.py new file mode 100644 index 0000000000000000000000000000000000000000..32280a4d0383a06863ceb036e0287974f5d8d36a --- /dev/null +++ b/starvector/data/figrsvg.py @@ -0,0 +1,27 @@ +import os +from starvector.data.base import SVGDatasetBase +from transformers import AutoProcessor +from starvector.data.util import ImageTrainProcessor + +class FigrSVGDataset(SVGDatasetBase): + def __init__(self, dataset_name, split, im_size, num_samples=-1, **kwargs): + super().__init__(dataset_name, split, im_size, **kwargs) + + self.num_samples = num_samples + if self.num_samples != -1: + self.data = self.data.select(range(self.num_samples)) + + def __len__(self): + return len(self.data) + + def __getitem__(self, idx): + svg_str = self.data[idx]['Svg'] + sample_id = self.data[idx]['Id'] + svg, image = self.get_svg_and_image(svg_str, sample_id) + caption = self.data[idx].get('Caption', "") + return { + 'svg': svg, + 'image': image, + 'id': sample_id, + 'caption': caption + } diff --git a/starvector/data/fontsvg.py b/starvector/data/fontsvg.py new file mode 100644 index 0000000000000000000000000000000000000000..1c90da9906f9af03b856fb751bdf879e1bad4401 --- /dev/null +++ b/starvector/data/fontsvg.py @@ -0,0 +1,28 @@ +import os +from starvector.data.base import SVGDatasetBase +from transformers import AutoProcessor +from starvector.data.util import ImageTrainProcessor + +class FontSVGDataset(SVGDatasetBase): + def __init__(self, dataset_name, split, im_size, num_samples=-1, **kwargs): + super().__init__(dataset_name, split, im_size, **kwargs) + + self.num_samples = num_samples + if self.num_samples != -1: + self.data = self.data.select(range(self.num_samples)) + + def __len__(self): + return len(self.data) + + def __getitem__(self, idx): + + svg_str = self.data[idx]['Svg'] + sample_id = self.data[idx]['Filename'] + svg, image = self.get_svg_and_image(svg_str, sample_id) + caption = self.data[idx].get('Caption', "") + return { + 'svg': svg, + 'image': image, + 'id': sample_id, + 'caption': caption + } diff --git a/starvector/data/iconsvg.py b/starvector/data/iconsvg.py new file mode 100644 index 0000000000000000000000000000000000000000..45881997dd4f308e75035b5b7db0d4ad8a37ab03 --- /dev/null +++ b/starvector/data/iconsvg.py @@ -0,0 +1,38 @@ +import os +from starvector.data.base import 
SVGDatasetBase +from starvector.data.util import ImageTrainProcessor +from transformers import AutoProcessor + +class SVGIconsDataset(SVGDatasetBase): + def __init__(self, dataset_name, split, im_size, num_samples=-1, **kwargs): + super().__init__(dataset_name, split, im_size, **kwargs) + + self.num_samples = num_samples + if self.num_samples != -1: + self.data = self.data.select(range(self.num_samples)) + + self.image_processor = kwargs.get('image_processor', None) + if 'siglip' in self.image_processor: + model_name = {'siglip_512': 'google/siglip-base-patch16-512', + 'siglip_384': 'google/siglip-large-patch16-384', + 'siglip_256': 'google/siglip-base-patch16-256'}[self.image_processor] + self.processor = AutoProcessor.from_pretrained(model_name).image_processor + else: + self.processor = ImageTrainProcessor(size=self.im_size) + + + def __len__(self): + return len(self.data) + + def __getitem__(self, idx): + + svg_str = self.data[idx]['Svg'] + sample_id = self.data[idx]['Filename'] + svg, image = self.get_svg_and_image(svg_str, sample_id) + caption = self.data[idx].get('Caption', "") + return { + 'svg': svg, + 'image': image, + 'id': sample_id, + 'caption': caption + } diff --git a/starvector/data/stacksvg.py b/starvector/data/stacksvg.py new file mode 100644 index 0000000000000000000000000000000000000000..23d3ea76b847a437d84855b0aa9f1721dfd82c42 --- /dev/null +++ b/starvector/data/stacksvg.py @@ -0,0 +1,59 @@ +import os +from starvector.data.base import SVGDatasetBase +from starvector.data.augmentation import SVGTransforms +import random +from transformers import AutoProcessor +from starvector.data.util import ImageTrainProcessor + +text2svg_captions = [ + "Draw an SVG of ", + "Draw an SVG image of ", + "Draw an SVG picture of ", + "Generate an SVG of ", + "Create an SVG of ", + "Design an SVG of ", + "Make an SVG of ", +] + +class SVGStackDataset(SVGDatasetBase): + def __init__(self, dataset_name, split, im_size, num_samples=-1, **kwargs): + super().__init__(dataset_name, split, im_size, num_samples, **kwargs) + self.color_changer = SVGTransforms({'color_change' : True, 'colors' : ['#ff0000', '#0000ff', '#00ff00', '#ffff00', '#000000']}) + + # Text2SVG specific + self.random_caption = kwargs.get('random_caption', True) + select_dataset_name = kwargs.get('select_dataset_name', False) + if select_dataset_name: + self.data = self.data.filter(lambda example: example["model_name"]==select_dataset_name) + + self.num_samples = num_samples + if self.num_samples != -1: + self.data = self.data.select(range(self.num_samples)) + + self.image_processor = kwargs.get('image_processor', None) + if self.image_processor and 'siglip' in self.image_processor: + model_name = {'siglip_512': 'google/siglip-base-patch16-512', + 'siglip_384': 'google/siglip-large-patch16-384', + 'siglip_256': 'google/siglip-base-patch16-256'}[self.image_processor] + self.processor = AutoProcessor.from_pretrained(model_name).image_processor + else: + self.processor = ImageTrainProcessor(size=self.im_size) + + + def __len__(self): + return len(self.data) + + def __getitem__(self, idx): + svg_str = self.data[idx]['Svg'] + sample_id = self.data[idx]['Filename'] + svg, image = self.get_svg_and_image(svg_str, sample_id) + + # Randomly choose between 'caption_blip' and 'caption_llava' + caption_column = random.choice(['caption_blip2', 'caption_llava']) + caption = random.choice(text2svg_captions) + self.data[idx].get(caption_column, "") + return { + 'svg': svg, + 'image': image, + 'id': sample_id, + 'caption': caption, + } diff --git 
a/starvector/data/util.py b/starvector/data/util.py new file mode 100644 index 0000000000000000000000000000000000000000..48635a6e01e7ab9e22e2c98fd328bcf478e0d7f8 --- /dev/null +++ b/starvector/data/util.py @@ -0,0 +1,389 @@ +from PIL import Image +from torchvision import transforms +from torchvision.transforms.functional import InterpolationMode, pad +import numpy as np +import matplotlib.pyplot as plt +from bs4 import BeautifulSoup +import re +from svgpathtools import svgstr2paths +import numpy as np +from PIL import Image +import cairosvg +from io import BytesIO +import numpy as np +import textwrap +import os +import base64 +import io + + + +CIRCLE_SVG = "" +VOID_SVF = "" + +def load_transforms(): + transforms = { + 'train': None, + 'eval': None + } + return transforms + +class ImageBaseProcessor(): + def __init__(self, mean=None, std=None): + if mean is None: + mean = (0.48145466, 0.4578275, 0.40821073) + if std is None: + std = (0.26862954, 0.26130258, 0.27577711) + + self.normalize = transforms.Normalize(mean=mean, std=std) + +class ImageTrainProcessor(ImageBaseProcessor): + def __init__(self, mean=None, std=None, size=224, **kwargs): + super().__init__(mean, std) + + self.size = size + + self.transform = transforms.Compose([ + transforms.Lambda(lambda img: self._rgba_to_rgb_white(img) if img.mode == "RGBA" else img), + transforms.Lambda(lambda img: self._pad_to_square(img)), + transforms.Resize(self.size, interpolation=InterpolationMode.BICUBIC), + transforms.ToTensor(), + self.normalize + ]) + + def __call__(self, item): + return self.transform(item) + + def _pad_to_square(self, img): + # Calculate padding to make the image square + width, height = img.size + max_dim = max(width, height) + padding = [(max_dim - width) // 2, (max_dim - height) // 2] + padding += [max_dim - width - padding[0], max_dim - height - padding[1]] + return pad(img, padding, fill=255) # Assuming white padding + + def _rgba_to_rgb_white(self, img): + background = Image.new("RGB", img.size, (255, 255, 255)) + background.paste(img, mask=img.split()[3]) + return background + + +def encode_image_base64(pil_image): + if pil_image.mode == 'RGBA': + pil_image = pil_image.convert('RGB') # Convert RGBA to RGB + buffered = io.BytesIO() + pil_image.save(buffered, format="JPEG") + base64_image = base64.b64encode(buffered.getvalue()).decode("utf-8") + return base64_image + +# -------------- Generation utils -------------- +def is_valid_svg(svg_text): + try: + svgstr2paths(svg_text) + return True + except Exception as e: + print(f"Invalid SVG: {str(e)}") + return False + +def clean_svg(svg_text, output_width=None, output_height=None): + soup = BeautifulSoup(svg_text, 'xml') # Read as soup to parse as xml + svg_bs4 = soup.prettify() # Prettify to get a string + + # Store the original signal handler + import signal + original_handler = signal.getsignal(signal.SIGALRM) + + try: + # Set a timeout to prevent hanging + def timeout_handler(signum, frame): + raise TimeoutError("SVG processing timed out") + + # Set timeout + signal.signal(signal.SIGALRM, timeout_handler) + signal.alarm(5) + + # Try direct conversion without BeautifulSoup + svg_cairo = cairosvg.svg2svg(svg_bs4, output_width=output_width, output_height=output_height).decode() + + except TimeoutError: + print("SVG conversion timed out, using fallback method") + svg_cairo = """""" + finally: + # Always cancel the alarm and restore original handler, regardless of success or failure + signal.alarm(0) + signal.signal(signal.SIGALRM, original_handler) + + svg_clean = 
"\n".join([line for line in svg_cairo.split("\n") if not line.strip().startswith("]*\/>" + all_tags = re.findall(all_tags_pattern, svg_content) + self_closing_matches = re.findall(self_closing_pattern, svg_content) + self_closing_tags = [] + + for match in self_closing_matches: + tag = re.search(all_tags_pattern, match) + if tag: + self_closing_tags.append(tag.group(1)) + unclosed_tags = [] + + for tag in all_tags: + if all_tags.count(tag) > self_closing_tags.count(tag) + svg_content.count(''): + unclosed_tags.append(tag) + unclosed_tags = list(dict.fromkeys(unclosed_tags)) + + return unclosed_tags + + +# -------------- Plotting utils -------------- +def plot_images_side_by_side_with_metrics(image1, image2, l2_dist, CD, post_processed, out_path): + array1 = np.array(image1).astype(np.float32) + array2 = np.array(image2).astype(np.float32) + diff = np.abs(array1 - array2).astype(np.uint8) + + fig, axes = plt.subplots(1, 3, figsize=(10, 5)) + axes[0].imshow(image1) + axes[0].set_title('generated_svg') + axes[0].axis('off') + axes[1].imshow(image2) + axes[1].set_title('gt') + axes[1].axis('off') + axes[2].imshow(diff) + axes[2].set_title('Difference') + axes[2].axis('off') + plt.suptitle(f"MSE: {l2_dist:.4f}, CD: {CD:.4f}, post-processed: {str(post_processed)}", fontsize=16, y=1.05) + plt.savefig(out_path, bbox_inches='tight', pad_inches=0.1) + image = Image.open(out_path) + plt.close(fig) + return image + +def plot_images_side_by_side(image1, image2, out_path): + array1 = np.array(image1).astype(np.float32) + array2 = np.array(image2).astype(np.float32) + diff = np.abs(array1 - array2).astype(np.uint8) + + fig, axes = plt.subplots(1, 3, figsize=(10, 5)) + axes[0].imshow(image1) + axes[0].set_title('generated_svg') + axes[0].axis('off') + axes[1].imshow(image2) + axes[1].set_title('gt') + axes[1].axis('off') + axes[2].imshow(diff) + axes[2].set_title('Difference') + axes[2].axis('off') + plt.savefig(out_path, bbox_inches='tight', pad_inches=0.1) + image = Image.open(out_path) + plt.close(fig) + return image + +def plot_images_side_by_side_temperatures(samples_temp, metrics, sample_dir, outpath_filename): + # Create a plot with the original image and different temperature results + num_temps = len(samples_temp) + fig, axes = plt.subplots(2, num_temps + 1, figsize=(15, 4), gridspec_kw={'height_ratios': [10, 2]}) + + # Plot the original image + gt_image_path = os.path.join(sample_dir, f'temp_{list(samples_temp.keys())[0]}', f'{outpath_filename}_or.png') + gt_image = Image.open(gt_image_path) + axes[0, 0].imshow(gt_image) + axes[0, 0].set_title('Original') + axes[0, 0].axis('off') + axes[1, 0].text(0.5, 0.5, 'Original', horizontalalignment='center', verticalalignment='center', fontsize=16) + axes[1, 0].axis('off') + + # Plot the generated images for different temperatures and metrics + for idx, (temp, sample) in enumerate(samples_temp.items()): + gen_image_path = os.path.join(sample_dir, f'temp_{temp}', f'{outpath_filename}.png') + gen_image = Image.open(gen_image_path) + axes[0, idx + 1].imshow(gen_image) + axes[0, idx + 1].set_title(f'Temp {temp}') + axes[0, idx + 1].axis('off') + axes[1, idx + 1].text(0.5, 0.5, f'MSE: {metrics[temp]["mse"]:.2f}\nCD: {metrics[temp]["cd"]:.2f}', + horizontalalignment='center', verticalalignment='center', fontsize=12) + axes[1, idx + 1].axis('off') + + # Save the comparison plot + comparison_path = os.path.join(sample_dir, f'{outpath_filename}_comparison.png') + plt.tight_layout() + plt.savefig(comparison_path) + plt.close() + +def plot_images_and_prompt(prompt, 
svg_raster, gt_svg_raster, out_path): + # First col shows caption, second col shows generated svg, third col shows gt svg + fig, axes = plt.subplots(1, 3, figsize=(10, 5)) + + # Split the prompt into multiple lines if it exceeds a certain length + prompt_lines = textwrap.wrap(prompt, width=30) + prompt_text = '\n'.join(prompt_lines) + + # Display the prompt in the first cell + axes[0].text(0, 0.5, prompt_text, fontsize=12, ha='left', wrap=True) + axes[0].axis('off') + axes[1].imshow(svg_raster) + axes[1].set_title('generated_svg') + axes[1].axis('off') + axes[2].imshow(gt_svg_raster) + axes[2].set_title('gt') + axes[2].axis('off') + plt.savefig(out_path, bbox_inches='tight', pad_inches=0.1) + image = Image.open(out_path) + plt.close(fig) + return image + +def plot_images_and_prompt_with_metrics(prompt, svg_raster, gt_svg_raster, clip_score, post_processed, out_path): + # First col shows caption, second col shows generated svg, third col shows gt svg + fig, axes = plt.subplots(1, 3, figsize=(10, 5)) + + # Split the prompt into multiple lines if it exceeds a certain length + prompt_lines = textwrap.wrap(prompt, width=30) + prompt_text = '\n'.join(prompt_lines) + + # Display the prompt in the first cell + axes[0].text(0, 0.5, prompt_text, fontsize=12, ha='left', wrap=True) + axes[0].axis('off') + axes[1].imshow(svg_raster) + axes[1].set_title('generated_svg') + axes[1].axis('off') + axes[2].imshow(gt_svg_raster) + axes[2].set_title('gt') + axes[2].axis('off') + plt.suptitle(f"CLIP Score: {clip_score:.4f}, post-processed: {str(post_processed)}", fontsize=16, y=1.05) + plt.savefig(out_path, bbox_inches='tight', pad_inches=0.1) + image = Image.open(out_path) + plt.close(fig) + return image + +def plot_images_and_prompt_temperatures(prompt, samples_temp, metrics, sample_dir, outpath_filename): + # Calculate the number of temperature variations + num_temps = len(samples_temp) + + # Create a plot with text, the original image, and different temperature results + fig, axes = plt.subplots(1, num_temps + 2, figsize=(5 + 3 * (num_temps + 1), 6)) + + # Split the prompt into multiple lines if it exceeds a certain length + prompt_lines = textwrap.wrap(prompt, width=30) + prompt_text = '\n'.join(prompt_lines) + + # Display the prompt in the first cell + axes[0].text(0, 0.5, prompt_text, fontsize=12, ha='left', wrap=True) + axes[0].axis('off') + + # Plot the GT (ground truth) image in the second cell + gt_image_path = os.path.join(sample_dir, f'temp_{list(samples_temp.keys())[0]}', f'{outpath_filename}_or.png') + gt_image = Image.open(gt_image_path) + axes[1].imshow(gt_image) + axes[1].set_title('GT Image') + axes[1].axis('off') + + # Plot the generated images for different temperatures and display metrics + for idx, (temp, sample) in enumerate(samples_temp.items()): + gen_image_path = os.path.join(sample_dir, f'temp_{temp}', f'{outpath_filename}.png') + gen_image = Image.open(gen_image_path) + axes[idx + 2].imshow(gen_image) + axes[idx + 2].set_title(f'Temp {temp}') + axes[idx + 2].axis('off') + clip_score = metrics[temp]["clip_score"] + axes[idx + 2].text(0.5, -0.1, f'CLIP: {clip_score:.4f}', horizontalalignment='center', verticalalignment='center', fontsize=12, transform=axes[idx + 2].transAxes) + + # Save the comparison plot + comparison_path = os.path.join(sample_dir, f'{outpath_filename}_comparison.png') + plt.tight_layout() + plt.savefig(comparison_path) + plt.close() + + return comparison_path + + +def plot_image_tensor(image): + import numpy as np + from PIL import Image + tensor = 
image[0].cpu().float() + tensor = tensor.permute(1, 2, 0) + array = (tensor.numpy() * 255).astype(np.uint8) + im = Image.fromarray(array) + im.save("tmp/output_image.jpg") + + +def plot_grid_samples(images, num_cols=5, out_path = 'grid.png'): + # Calculate the number of rows required for the grid + num_images = len(images) + num_rows = (num_images + num_cols - 1) // num_cols + + # Create a new figure + fig, axes = plt.subplots(num_rows, num_cols, figsize=(12, 8)) + + # Loop through the image files and plot them + for i, image in enumerate(images): + row = i // num_cols + col = i % num_cols + + # Open and display the image using Pillow + if type(image) == str: + img = Image.open(image) + else: + img = image + axes[row, col].imshow(img) + # axes[row, col].set_title(os.path.basename(image_file)) + axes[row, col].axis('off') + + # Remove empty subplots + for i in range(num_images, num_rows * num_cols): + row = i // num_cols + col = i % num_cols + fig.delaxes(axes[row, col]) + + # Adjust spacing between subplots + plt.tight_layout() + + # save image + plt.savefig(out_path, dpi=300) + image = Image.open(out_path) + plt.close(fig) + + return image \ No newline at end of file diff --git a/starvector/image_encoder.py b/starvector/image_encoder.py new file mode 100644 index 0000000000000000000000000000000000000000..3f85a4cad987d91cf47285c9dd3099e6c7d6b24d --- /dev/null +++ b/starvector/image_encoder.py @@ -0,0 +1,119 @@ +import os +import torch +import torch.nn as nn +import os +from omegaconf import OmegaConf +from starvector.model.image_encoder.clip_model import convert_weights_to_precision +from starvector.data.util import ImageTrainProcessor + +class ImageEncoder(nn.Module): + def __init__(self, config, **kwargs): + super(ImageEncoder, self).__init__() + + image_size = config.image_size + torch_dtype = kwargs.get('model_precision', config.torch_dtype) + self.image_encoder_type = config.image_encoder_type + if self.image_encoder_type == 'clip': + self.visual_encoder, self.ln_vision = self.build_clip_encoder(image_size=image_size) + convert_weights_to_precision(self, torch_dtype) + self.processor = ImageTrainProcessor(size=config.image_size) + + elif self.image_encoder_type == 'vqgan': + self.visual_encoder = self.build_vqgan_encoder() + self.ln_vision = None + self.processor = ImageTrainProcessor(size=config.image_size) + + elif self.image_encoder_type == 'convnext': + self.visual_encoder = self.build_vqgan_encoder() + self.ln_vision = None + self.processor = ImageTrainProcessor(size=config.image_size) + + elif 'siglip' in self.image_encoder_type: + if self.image_encoder_type == 'siglip_512': + model_name = "google/siglip-base-patch16-512" + elif self.image_encoder_type == 'siglip_384': + model_name = "google/siglip-large-patch16-384" + elif self.image_encoder_type == 'siglip_256': + model_name = "google/siglip-base-patch16-256" + + from transformers import AutoProcessor, AutoModel + + self.visual_encoder = AutoModel.from_pretrained( + model_name, torch_dtype = torch_dtype + ).vision_model + + self.processor = AutoProcessor.from_pretrained( + model_name, torch_dtype = torch_dtype + ) + + def build_clip_encoder(self, image_size): + from starvector.model.image_encoder.clip_model import VisionTransformer, LayerNorm + visual_encoder = VisionTransformer( + input_resolution=image_size, + patch_size=14, + width=1024, + layers=23, + heads=16, + use_grad_checkpointing=False) + + ln_vision = LayerNorm(visual_encoder.num_features) + return visual_encoder, ln_vision + + def build_vqgan_encoder(self): + from 
taming.modules.diffusionmodules.model import Encoder + VQGAN_CHECKPOINT = "/path/to/vqgan_checkpoint" # You can download the checkpoint from https://github.com/EleutherAI/vqgan-clip/blob/main/README.md + vqgan_chkp_path = VQGAN_CHECKPOINT + files_in_directory = os.listdir(vqgan_chkp_path + '/configs') + vqgan_config_file = [file for file in files_in_directory if file.endswith('project.yaml')][0] + vqgan_config = OmegaConf.load(os.path.join(vqgan_chkp_path, 'configs', vqgan_config_file)) + visual_encoder = Encoder(**vqgan_config.model.params.ddconfig) + + # Load checkpoint weights + checkpoint = torch.load(os.path.join(vqgan_chkp_path, 'checkpoints', 'last.ckpt'))['state_dict'] + + # Create a new state_dict with modified keys + new_state_dict = {} + for key, value in checkpoint.items(): + if key.startswith('encoder.'): + new_key = key[len('encoder.'):] + new_state_dict[new_key] = value + + # Load weights + visual_encoder.load_state_dict(new_state_dict) + return visual_encoder + + def build_convnext_encoder(self): + import open_clip + model, _, _ = open_clip.create_model_and_transforms('convnext_base_w', pretrained='laion2b_s13b_b82k') + return model.visual + + def forward(self, image): + if self.image_encoder_type == 'clip': + embeds = self.visual_encoder(image) + out = self.ln_vision(embeds) + elif self.image_encoder_type == 'open-clip': + out = self.visual_encoder(image)[1] + out = self.ln_vision(out) + elif self.image_encoder_type == 'vqgan': + out = self.visual_encoder(image) + size = out.size() + out = out.view(size[0], size[1], -1) + out = out.permute(0, 2, 1) + elif self.image_encoder_type == 'convnext': + out = self.visual_encoder.trunk.forward_features(image) + size = out.size() + out = out.view(size[0], size[1], -1) + out = out.permute(0, 2, 1) + elif 'siglip' in self.image_encoder_type: + out = self.visual_encoder(image)["last_hidden_state"] + return out + + def process_images(self, images): + if self.image_encoder_type == 'clip': + res = [] + for image in images: + res.append(self.processor(image).unsqueeze(0)) # B, 3, H, W + return res + else: + return self.processor(images=images, return_tensors="pt").pixel_values.unsqueeze(0) + \ No newline at end of file diff --git a/starvector/metrics/base_metric.py b/starvector/metrics/base_metric.py new file mode 100644 index 0000000000000000000000000000000000000000..d8d07d8472bc70220d8bc2188d41dcb20d8e92a1 --- /dev/null +++ b/starvector/metrics/base_metric.py @@ -0,0 +1,51 @@ +from starvector.metrics.util import AverageMeter +from tqdm import tqdm +import math + +class BaseMetric: + def __init__(self): + self.meter = AverageMeter() + + def reset(self): + self.meter.reset() + + def calculate_score(self, batch, update=True): + """ + Batch: {"gt_im": [PIL Image], "gen_im": [Image]} + """ + values = [] + batch_size = len(next(iter(batch.values()))) + for index in tqdm(range(batch_size)): + kwargs = {} + for key in ["gt_im", "gen_im", "gt_svg", "gen_svg", "caption"]: + if key in batch: + kwargs[key] = batch[key][index] + try: + measure = self.metric(**kwargs) + except Exception as e: + print("Error calculating metric: {}".format(e)) + continue + if math.isnan(measure): + continue + values.append(measure) + + if not values: + print("No valid values found for metric calculation.") + return float("nan") + + score = sum(values) / len(values) + if update: + self.meter.update(score, len(values)) + return self.meter.avg, values + else: + return score, values + + def metric(self, **kwargs): + """ + This method should be overridden by subclasses to 
provide the specific metric computation. + """ + raise NotImplementedError("The metric method must be implemented by subclasses.") + + def get_average_score(self): + return self.meter.avg + diff --git a/starvector/metrics/compute_LPIPS.py b/starvector/metrics/compute_LPIPS.py new file mode 100644 index 0000000000000000000000000000000000000000..b30c42cfdf21febf2b30a83dd51d690423294321 --- /dev/null +++ b/starvector/metrics/compute_LPIPS.py @@ -0,0 +1,56 @@ +from torchvision.transforms import ToTensor, Normalize +import torch +from torch.utils.data import DataLoader +from starvector.metrics.base_metric import BaseMetric +import lpips +from tqdm import tqdm + + +class LPIPSDistanceCalculator(BaseMetric): + def __init__(self, config=None, device='cuda'): + super().__init__() + self.class_name = self.__class__.__name__ + self.config = config + self.model = lpips.LPIPS(net='vgg').to(device) + self.metric = self.LPIPS + self.to_tensor = ToTensor() + self.normalize = Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]) + self.device = device + + def LPIPS(self, tensor_image1, tensor_image2): + tensor_image1, tensor_image2 = tensor_image1.to(self.device), tensor_image2.to(self.device) + return self.model(tensor_image1, tensor_image2) + + def to_tensor_transform(self, pil_img): + return self.normalize(self.to_tensor(pil_img)) + + def collate_fn(self, batch): + gt_imgs, gen_imgs = zip(*batch) + tensor_gt_imgs = torch.stack([self.to_tensor_transform(img) for img in gt_imgs]) + tensor_gen_imgs = torch.stack([self.to_tensor_transform(img) for img in gen_imgs]) + return tensor_gt_imgs, tensor_gen_imgs + + def calculate_score(self, batch, batch_size=8, update=True): + gt_images = batch['gt_im'] + gen_images = batch['gen_im'] + + # Create DataLoader with custom collate function + data_loader = DataLoader(list(zip(gt_images, gen_images)), batch_size=batch_size, collate_fn=self.collate_fn, shuffle=False) + + values = [] + for tensor_gt_batch, tensor_gen_batch in tqdm(data_loader): + # Compute LPIPS + lpips_values = self.LPIPS(tensor_gt_batch, tensor_gen_batch) + values.extend([lpips_values.squeeze().cpu().detach().tolist()] if lpips_values.numel() == 1 else lpips_values.squeeze().cpu().detach().tolist()) + + if not values: + print("No valid values found for metric calculation.") + return float("nan") + + avg_score = sum(values) / len(values) + if update: + self.meter.update(avg_score, len(values)) + return self.meter.avg, values + else: + return avg_score, values + \ No newline at end of file diff --git a/starvector/metrics/compute_SSIM.py b/starvector/metrics/compute_SSIM.py new file mode 100644 index 0000000000000000000000000000000000000000..e0dfb75435d78261197e12276cb53ca540ea7687 --- /dev/null +++ b/starvector/metrics/compute_SSIM.py @@ -0,0 +1,35 @@ +from starvector.metrics.base_metric import BaseMetric +from skimage.metrics import structural_similarity as ssim +import numpy as np + +class SSIMDistanceCalculator(BaseMetric): + def __init__(self, config=None): + super().__init__() + self.class_name = self.__class__.__name__ + self.config = config + self.metric = self.compute_SSIM + + def compute_SSIM(self, **kwargs): + image1 = kwargs.get('gt_im') + image2 = kwargs.get('gen_im') + win_size = kwargs.get('win_size', 11) # Increase win_size for more accuracy + channel_axis = kwargs.get('channel_axis', -1) # Default channel_axis to -1 + sigma = kwargs.get('sigma', 1.5) # Add sigma parameter for Gaussian filter + + # Convert images to numpy arrays if they aren't already + img1_np = 
np.array(image1) + img2_np = np.array(image2) + + # Check if images are grayscale or RGB + if len(img1_np.shape) == 3 and img1_np.shape[2] == 3: + # Compute SSIM for RGB images + score, _ = ssim(img1_np, img2_np, win_size=win_size, channel_axis=channel_axis, sigma=sigma, full=True) + else: + # Convert to grayscale if not already + if len(img1_np.shape) == 3: + img1_np = np.mean(img1_np, axis=2) + img2_np = np.mean(img2_np, axis=2) + + score, _ = ssim(img1_np, img2_np, win_size=win_size, sigma=sigma, full=True) + + return score \ No newline at end of file diff --git a/starvector/metrics/compute_clip_score.py b/starvector/metrics/compute_clip_score.py new file mode 100644 index 0000000000000000000000000000000000000000..186fb2f85bef817a53f5e00a6a82a310451fe15b --- /dev/null +++ b/starvector/metrics/compute_clip_score.py @@ -0,0 +1,55 @@ +from torchvision.transforms import ToTensor +import torch.nn.functional as F +from starvector.metrics.base_metric import BaseMetric +import torch +from torchmetrics.multimodal.clip_score import CLIPScore +from torch.utils.data import DataLoader +from tqdm import tqdm +import torchvision.transforms as transforms +from torchmetrics.functional.multimodal.clip_score import _clip_score_update + +class CLIPScoreCalculator(BaseMetric): + def __init__(self): + super().__init__() + self.class_name = self.__class__.__name__ + self.clip_score = CLIPScore(model_name_or_path="openai/clip-vit-base-patch32") + self.clip_score.to('cuda') + + def CLIP_Score(self, images, captions): + all_scores = _clip_score_update(images, captions, self.clip_score.model, self.clip_score.processor) + return all_scores + + def collate_fn(self, batch): + gen_imgs, captions = zip(*batch) + tensor_gen_imgs = [transforms.ToTensor()(img) for img in gen_imgs] + return tensor_gen_imgs, captions + + def calculate_score(self, batch, batch_size=512, update=True): + gen_images = batch['gen_im'] + captions = batch['caption'] + + # Create DataLoader with custom collate function + data_loader = DataLoader(list(zip(gen_images, captions)), collate_fn=self.collate_fn, batch_size=batch_size, shuffle=False, num_workers=4, pin_memory=True) + + all_scores = [] + for batch_eval in tqdm(data_loader): + images, captions = batch_eval + images = [img.to('cuda', non_blocking=True) * 255 for img in images] + list_scores = self.CLIP_Score(images, captions)[0].detach().cpu().tolist() + all_scores.extend(list_scores) + + if not all_scores: + print("No valid scores found for metric calculation.") + return float("nan"), [] + + avg_score = sum(all_scores) / len(all_scores) + if update: + self.meter.update(avg_score, len(all_scores)) + return self.meter.avg, all_scores + else: + return avg_score, all_scores + +if __name__ == '__main__': + import multiprocessing + multiprocessing.set_start_method('spawn') + # Rest of your code... 
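+     # Illustrative usage sketch (hypothetical values; assumes a CUDA device is
+     # available, since the calculator moves the CLIP model to 'cuda' on init):
+     from PIL import Image
+     calculator = CLIPScoreCalculator()
+     batch = {
+         'gen_im': [Image.new('RGB', (224, 224), 'white')],  # stand-in for a rasterized generated SVG
+         'caption': ['a red circle on a white background'],
+     }
+     avg_score, per_sample_scores = calculator.calculate_score(batch, batch_size=8)
+     print(f"Average CLIP score: {avg_score:.4f}")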
\ No newline at end of file diff --git a/starvector/metrics/compute_dino_score.py b/starvector/metrics/compute_dino_score.py new file mode 100644 index 0000000000000000000000000000000000000000..99a8364c4d5fd6a69caed545e95128ef34f9c56a --- /dev/null +++ b/starvector/metrics/compute_dino_score.py @@ -0,0 +1,55 @@ +import torch +from torch.utils.data import DataLoader +from starvector.metrics.base_metric import BaseMetric +from tqdm import tqdm +from transformers import AutoModel, AutoImageProcessor +from PIL import Image +import torch.nn as nn + +class DINOScoreCalculator(BaseMetric): + def __init__(self, config=None, device='cuda'): + super().__init__() + self.class_name = self.__class__.__name__ + self.config = config + self.model, self.processor = self.get_DINOv2_model("base") + self.model = self.model.to(device) + self.device = device + + self.metric = self.calculate_DINOv2_similarity_score + + def get_DINOv2_model(self, model_size): + if model_size == "small": + model_size = "facebook/dinov2-small" + elif model_size == "base": + model_size = "facebook/dinov2-base" + elif model_size == "large": + model_size = "facebook/dinov2-large" + else: + raise ValueError(f"model_size should be either 'small', 'base' or 'large', got {model_size}") + return AutoModel.from_pretrained(model_size), AutoImageProcessor.from_pretrained(model_size) + + def process_input(self, image, processor): + if isinstance(image, str): + image = Image.open(image) + if isinstance(image, Image.Image): + with torch.no_grad(): + inputs = processor(images=image, return_tensors="pt").to(self.device) + outputs = self.model(**inputs) + features = outputs.last_hidden_state.mean(dim=1) + elif isinstance(image, torch.Tensor): + features = image.unsqueeze(0) if image.dim() == 1 else image + else: + raise ValueError("Input must be a file path, PIL Image, or tensor of features") + return features + + def calculate_DINOv2_similarity_score(self, **kwargs): + image1 = kwargs.get('gt_im') + image2 = kwargs.get('gen_im') + features1 = self.process_input(image1, self.processor) + features2 = self.process_input(image2, self.processor) + + cos = nn.CosineSimilarity(dim=1) + sim = cos(features1, features2).item() + sim = (sim + 1) / 2 + + return sim diff --git a/starvector/metrics/compute_fid.py b/starvector/metrics/compute_fid.py new file mode 100644 index 0000000000000000000000000000000000000000..413fca4a4c14e66a30b4aafee21220d4e02a41c0 --- /dev/null +++ b/starvector/metrics/compute_fid.py @@ -0,0 +1,145 @@ +# Refer https://torchmetrics.readthedocs.io/en/stable/image/frechet_inception_distance.html +# from torchmetrics.image.fid import FrechetInceptionDistance +from PIL import Image +from starvector.metrics.base_metric import BaseMetric +import torch +from torchvision import transforms +import clip +from torch.nn.functional import adaptive_avg_pool2d +from starvector.metrics.inception import InceptionV3 +import numpy as np +from tqdm import tqdm +from scipy import linalg +import torchvision.transforms as TF + +class FIDCalculator(BaseMetric): + def __init__(self, model_name = 'InceptionV3',): + self.class_name = self.__class__.__name__ + self.device = 'cuda' if torch.cuda.is_available() else 'cpu' + self.model_name = model_name + if self.model_name == 'ViT-B/32': + self.dims = 512 + model, preprocess = clip.load('ViT-B/32') + + elif self.model_name == 'InceptionV3': + self.dims = 2048 + block_idx = InceptionV3.BLOCK_INDEX_BY_DIM[self.dims] + model = InceptionV3([block_idx]).to(self.device) + preprocess = TF.Compose([TF.ToTensor()]) + + 
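+         # Keep the chosen backbone and its preprocessing: InceptionV3 produces
+         # 2048-d pooled features, the CLIP ViT-B/32 branch 512-d image embeddings.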
self.model = model.cuda() + self.preprocess = preprocess + + def calculate_frechet_distance(self, mu1, sigma1, mu2, sigma2, eps=1e-6): + """Numpy implementation of the Frechet Distance. + The Frechet distance between two multivariate Gaussians X_1 ~ N(mu_1, C_1) + and X_2 ~ N(mu_2, C_2) is + d^2 = ||mu_1 - mu_2||^2 + Tr(C_1 + C_2 - 2*sqrt(C_1*C_2)). + + Stable version by Dougal J. Sutherland. + + Params: + -- mu1 : Numpy array containing the activations of a layer of the + inception net (like returned by the function 'get_predictions') + for generated samples. + -- mu2 : The sample mean over activations, precalculated on an + representative data set. + -- sigma1: The covariance matrix over activations for generated samples. + -- sigma2: The covariance matrix over activations, precalculated on an + representative data set. + + Returns: + -- : The Frechet Distance. + """ + + mu1 = np.atleast_1d(mu1) + mu2 = np.atleast_1d(mu2) + + sigma1 = np.atleast_2d(sigma1) + sigma2 = np.atleast_2d(sigma2) + + assert mu1.shape == mu2.shape, \ + 'Training and test mean vectors have different lengths' + assert sigma1.shape == sigma2.shape, \ + 'Training and test covariances have different dimensions' + + diff = mu1 - mu2 + + # Product might be almost singular + covmean, _ = linalg.sqrtm(sigma1.dot(sigma2), disp=False) + if not np.isfinite(covmean).all(): + msg = ('fid calculation produces singular product; ' + 'adding %s to diagonal of cov estimates') % eps + print(msg) + offset = np.eye(sigma1.shape[0]) * eps + covmean = linalg.sqrtm((sigma1 + offset).dot(sigma2 + offset)) + + # Numerical error might give slight imaginary component + if np.iscomplexobj(covmean): + if not np.allclose(np.diagonal(covmean).imag, 0, atol=1e-3): + m = np.max(np.abs(covmean.imag)) + raise ValueError('Imaginary component {}'.format(m)) + covmean = covmean.real + + tr_covmean = np.trace(covmean) + + return (diff.dot(diff) + np.trace(sigma1) + + np.trace(sigma2) - 2 * tr_covmean) + + def get_activations(self, images): + dataset = ImageDataset(images, self.preprocess) + dataloader = torch.utils.data.DataLoader(dataset, batch_size=50, shuffle=False, num_workers=4) + pred_arr = np.empty((len(images), self.dims)) + start_idx = 0 + for batch in tqdm(dataloader): + batch = batch.to(self.device) + + with torch.no_grad(): + if self.model_name == 'ViT-B/32': + pred = self.model.encode_image(batch).cpu().numpy() + elif self.model_name == 'InceptionV3': + pred = self.model(batch)[0] + + # If model output is not scalar, apply global spatial average pooling. + # This happens if you choose a dimensionality not equal 2048. 
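+             # (earlier blocks return a spatial feature map, so it is pooled to 1x1 before flattening)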
+ if pred.size(2) != 1 or pred.size(3) != 1: + pred = adaptive_avg_pool2d(pred, output_size=(1, 1)) + + pred = pred.squeeze(3).squeeze(2).cpu().numpy() + pred_arr[start_idx:start_idx + pred.shape[0]] = pred + start_idx = start_idx + pred.shape[0] + + return pred_arr + + def calculate_activation_statistics(self, images): + act = self.get_activations(images) + mu = np.mean(act, axis=0) + sigma = np.cov(act, rowvar=False) + return mu, sigma + + def pil_images_to_tensor(self, images_list): + """Convert a list of PIL Images to a torch.Tensor.""" + tensors_list = [self.preprocess(img) for img in images_list] + return torch.stack(tensors_list).cuda() # BxCxHxW format + + def calculate_score(self, batch): + m1, s1 = self.calculate_activation_statistics(batch['gt_im']) + m2, s2 = self.calculate_activation_statistics(batch['gen_im']) + fid_value = self.calculate_frechet_distance(m1, s1, m2, s2) + return fid_value + + def reset(self): + pass + +class ImageDataset(torch.utils.data.Dataset): + def __init__(self, images, processor=None): + self.images = images + self.processor = processor + + def __len__(self): + return len(self.images) + + def __getitem__(self, i): + img = self.images[i] + img = self.processor(img) + return img \ No newline at end of file diff --git a/starvector/metrics/compute_l2.py b/starvector/metrics/compute_l2.py new file mode 100644 index 0000000000000000000000000000000000000000..ecca16a88f6d39d4b1666da4762202ce1e69d0bb --- /dev/null +++ b/starvector/metrics/compute_l2.py @@ -0,0 +1,37 @@ +from torchvision.transforms import ToTensor +import torch.nn.functional as F +from starvector.metrics.base_metric import BaseMetric +import torch + +class L2DistanceCalculator(BaseMetric): + def __init__(self, config=None, masked_l2=False): + super().__init__() + self.class_name = self.__class__.__name__ + self.config = config + self.metric = self.l2_distance + self.masked_l2 = masked_l2 + + def l2_distance(self, **kwargs): + image1 = kwargs.get('gt_im') + image2 = kwargs.get('gen_im') + image1_tensor = ToTensor()(image1) + image2_tensor = ToTensor()(image2) + + if self.masked_l2: + # Create binary masks: 0 for white pixels, 1 for non-white pixels + mask1 = (image1_tensor != 1).any(dim=0).float() + mask2 = (image2_tensor != 1).any(dim=0).float() + + # Create a combined mask for overlapping non-white pixels + combined_mask = mask1 * mask2 + + # Apply the combined mask to both images + image1_tensor = image1_tensor * combined_mask.unsqueeze(0) + image2_tensor = image2_tensor * combined_mask.unsqueeze(0) + + # Compute mean squared error + mse = F.mse_loss(image1_tensor, image2_tensor) + return mse.item() + + + diff --git a/starvector/metrics/count_token_length.py b/starvector/metrics/count_token_length.py new file mode 100644 index 0000000000000000000000000000000000000000..8210771ec6fc7a148c069770c126610756a466e8 --- /dev/null +++ b/starvector/metrics/count_token_length.py @@ -0,0 +1,54 @@ +import torch +from torch.utils.data import DataLoader +from starvector.metrics.base_metric import BaseMetric +from tqdm import tqdm +from starvector.metrics.util import AverageMeter + +from transformers import AutoTokenizer + +class CountTokenLength(BaseMetric): + def __init__(self, config=None, device='cuda'): + super().__init__() + self.tokenizer = AutoTokenizer.from_pretrained("bigcode/starcoder2-7b") + self.metric = self.calculate_token_length + self.meter_gt_tokens = AverageMeter() + self.meter_gen_tokens = AverageMeter() + self.meter_diff = AverageMeter() + + def calculate_token_length(self, **kwargs): + 
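+         # Tokenize both SVG strings with the StarCoder2 tokenizer and report the
+         # ground-truth length, the generated length, and their difference (gen - gt).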
svg = kwargs.get('gt_svg') + tokens = self.tokenizer.encode(svg) + gen_svg = kwargs.get('gen_svg') + gen_tokens = self.tokenizer.encode(gen_svg) + diff = len(gen_tokens) - len(tokens) + return len(tokens), len(gen_tokens), diff + + def calculate_score(self, batch, update=None): + gt_svgs = batch['gt_svg'] + gen_svgs = batch['gen_svg'] + values = [] + for gt_svg, gen_svg in tqdm(zip(gt_svgs, gen_svgs), total=len(gt_svgs), desc="Processing SVGs"): + gt_tokens, gen_tokens, diff = self.calculate_token_length(gt_svg=gt_svg, gen_svg=gen_svg) + self.meter_gt_tokens.update(gt_tokens, 1) + self.meter_gen_tokens.update(gen_tokens, 1) + self.meter_diff.update(diff, 1) + values.append({ + 'gt_tokens': gt_tokens, + 'gen_tokens': gen_tokens, + 'diff': diff + }) + avg_score = { + 'gt_tokens': self.meter_gt_tokens.avg, + 'gen_tokens': self.meter_gen_tokens.avg, + 'diff': self.meter_diff.avg + } + if not values: + print("No valid values found for metric calculation.") + return float("nan") + + return avg_score, values + + def reset(self): + self.meter_gt_tokens.reset() + self.meter_gen_tokens.reset() + self.meter_diff.reset() diff --git a/starvector/metrics/inception.py b/starvector/metrics/inception.py new file mode 100644 index 0000000000000000000000000000000000000000..cc56870522a1ee298abbfff9edd7243a0bf8e7dd --- /dev/null +++ b/starvector/metrics/inception.py @@ -0,0 +1,341 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F +import torchvision + +try: + from torchvision.models.utils import load_state_dict_from_url +except ImportError: + from torch.utils.model_zoo import load_url as load_state_dict_from_url + +# Inception weights ported to Pytorch from +# http://download.tensorflow.org/models/image/imagenet/inception-2015-12-05.tgz +FID_WEIGHTS_URL = 'https://github.com/mseitzer/pytorch-fid/releases/download/fid_weights/pt_inception-2015-12-05-6726825d.pth' # noqa: E501 + + +class InceptionV3(nn.Module): + """Pretrained InceptionV3 network returning feature maps""" + + # Index of default block of inception to return, + # corresponds to output of final average pooling + DEFAULT_BLOCK_INDEX = 3 + + # Maps feature dimensionality to their output blocks indices + BLOCK_INDEX_BY_DIM = { + 64: 0, # First max pooling features + 192: 1, # Second max pooling featurs + 768: 2, # Pre-aux classifier features + 2048: 3 # Final average pooling features + } + + def __init__(self, + output_blocks=(DEFAULT_BLOCK_INDEX,), + resize_input=True, + normalize_input=True, + requires_grad=False, + use_fid_inception=True): + """Build pretrained InceptionV3 + + Parameters + ---------- + output_blocks : list of int + Indices of blocks to return features of. Possible values are: + - 0: corresponds to output of first max pooling + - 1: corresponds to output of second max pooling + - 2: corresponds to output which is fed to aux classifier + - 3: corresponds to output of final average pooling + resize_input : bool + If true, bilinearly resizes input to width and height 299 before + feeding input to model. As the network without fully connected + layers is fully convolutional, it should be able to handle inputs + of arbitrary size, so resizing might not be strictly needed + normalize_input : bool + If true, scales the input from range (0, 1) to the range the + pretrained Inception network expects, namely (-1, 1) + requires_grad : bool + If true, parameters of the model require gradients. 
Possibly useful + for finetuning the network + use_fid_inception : bool + If true, uses the pretrained Inception model used in Tensorflow's + FID implementation. If false, uses the pretrained Inception model + available in torchvision. The FID Inception model has different + weights and a slightly different structure from torchvision's + Inception model. If you want to compute FID scores, you are + strongly advised to set this parameter to true to get comparable + results. + """ + super(InceptionV3, self).__init__() + + self.resize_input = resize_input + self.normalize_input = normalize_input + self.output_blocks = sorted(output_blocks) + self.last_needed_block = max(output_blocks) + + assert self.last_needed_block <= 3, \ + 'Last possible output block index is 3' + + self.blocks = nn.ModuleList() + + if use_fid_inception: + inception = fid_inception_v3() + else: + inception = _inception_v3(weights='DEFAULT') + + # Block 0: input to maxpool1 + block0 = [ + inception.Conv2d_1a_3x3, + inception.Conv2d_2a_3x3, + inception.Conv2d_2b_3x3, + nn.MaxPool2d(kernel_size=3, stride=2) + ] + self.blocks.append(nn.Sequential(*block0)) + + # Block 1: maxpool1 to maxpool2 + if self.last_needed_block >= 1: + block1 = [ + inception.Conv2d_3b_1x1, + inception.Conv2d_4a_3x3, + nn.MaxPool2d(kernel_size=3, stride=2) + ] + self.blocks.append(nn.Sequential(*block1)) + + # Block 2: maxpool2 to aux classifier + if self.last_needed_block >= 2: + block2 = [ + inception.Mixed_5b, + inception.Mixed_5c, + inception.Mixed_5d, + inception.Mixed_6a, + inception.Mixed_6b, + inception.Mixed_6c, + inception.Mixed_6d, + inception.Mixed_6e, + ] + self.blocks.append(nn.Sequential(*block2)) + + # Block 3: aux classifier to final avgpool + if self.last_needed_block >= 3: + block3 = [ + inception.Mixed_7a, + inception.Mixed_7b, + inception.Mixed_7c, + nn.AdaptiveAvgPool2d(output_size=(1, 1)) + ] + self.blocks.append(nn.Sequential(*block3)) + + for param in self.parameters(): + param.requires_grad = requires_grad + + def forward(self, inp): + """Get Inception feature maps + + Parameters + ---------- + inp : torch.autograd.Variable + Input tensor of shape Bx3xHxW. Values are expected to be in + range (0, 1) + + Returns + ------- + List of torch.autograd.Variable, corresponding to the selected output + block, sorted ascending by index + """ + outp = [] + x = inp + + if self.resize_input: + x = F.interpolate(x, + size=(299, 299), + mode='bilinear', + align_corners=False) + + if self.normalize_input: + x = 2 * x - 1 # Scale from range (0, 1) to range (-1, 1) + + for idx, block in enumerate(self.blocks): + x = block(x) + if idx in self.output_blocks: + outp.append(x) + + if idx == self.last_needed_block: + break + + return outp + + +def _inception_v3(*args, **kwargs): + """Wraps `torchvision.models.inception_v3`""" + try: + version = tuple(map(int, torchvision.__version__.split('.')[:2])) + except ValueError: + # Just a caution against weird version strings + version = (0,) + + # Skips default weight inititialization if supported by torchvision + # version. See https://github.com/mseitzer/pytorch-fid/issues/28. + if version >= (0, 6): + kwargs['init_weights'] = False + + # Backwards compatibility: `weights` argument was handled by `pretrained` + # argument prior to version 0.13. 
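+     # Map weights='DEFAULT' / weights=None onto the legacy pretrained=True/False
+     # flag and drop the 'weights' key, which older torchvision does not accept.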
+ if version < (0, 13) and 'weights' in kwargs: + if kwargs['weights'] == 'DEFAULT': + kwargs['pretrained'] = True + elif kwargs['weights'] is None: + kwargs['pretrained'] = False + else: + raise ValueError( + 'weights=={} not supported in torchvision {}'.format( + kwargs['weights'], torchvision.__version__ + ) + ) + del kwargs['weights'] + + return torchvision.models.inception_v3(*args, **kwargs) + + +def fid_inception_v3(): + """Build pretrained Inception model for FID computation + + The Inception model for FID computation uses a different set of weights + and has a slightly different structure than torchvision's Inception. + + This method first constructs torchvision's Inception and then patches the + necessary parts that are different in the FID Inception model. + """ + inception = _inception_v3(num_classes=1008, + aux_logits=False, + weights=None) + inception.Mixed_5b = FIDInceptionA(192, pool_features=32) + inception.Mixed_5c = FIDInceptionA(256, pool_features=64) + inception.Mixed_5d = FIDInceptionA(288, pool_features=64) + inception.Mixed_6b = FIDInceptionC(768, channels_7x7=128) + inception.Mixed_6c = FIDInceptionC(768, channels_7x7=160) + inception.Mixed_6d = FIDInceptionC(768, channels_7x7=160) + inception.Mixed_6e = FIDInceptionC(768, channels_7x7=192) + inception.Mixed_7b = FIDInceptionE_1(1280) + inception.Mixed_7c = FIDInceptionE_2(2048) + + state_dict = load_state_dict_from_url(FID_WEIGHTS_URL, progress=True) + inception.load_state_dict(state_dict) + return inception + + +class FIDInceptionA(torchvision.models.inception.InceptionA): + """InceptionA block patched for FID computation""" + def __init__(self, in_channels, pool_features): + super(FIDInceptionA, self).__init__(in_channels, pool_features) + + def forward(self, x): + branch1x1 = self.branch1x1(x) + + branch5x5 = self.branch5x5_1(x) + branch5x5 = self.branch5x5_2(branch5x5) + + branch3x3dbl = self.branch3x3dbl_1(x) + branch3x3dbl = self.branch3x3dbl_2(branch3x3dbl) + branch3x3dbl = self.branch3x3dbl_3(branch3x3dbl) + + # Patch: Tensorflow's average pool does not use the padded zero's in + # its average calculation + branch_pool = F.avg_pool2d(x, kernel_size=3, stride=1, padding=1, + count_include_pad=False) + branch_pool = self.branch_pool(branch_pool) + + outputs = [branch1x1, branch5x5, branch3x3dbl, branch_pool] + return torch.cat(outputs, 1) + + +class FIDInceptionC(torchvision.models.inception.InceptionC): + """InceptionC block patched for FID computation""" + def __init__(self, in_channels, channels_7x7): + super(FIDInceptionC, self).__init__(in_channels, channels_7x7) + + def forward(self, x): + branch1x1 = self.branch1x1(x) + + branch7x7 = self.branch7x7_1(x) + branch7x7 = self.branch7x7_2(branch7x7) + branch7x7 = self.branch7x7_3(branch7x7) + + branch7x7dbl = self.branch7x7dbl_1(x) + branch7x7dbl = self.branch7x7dbl_2(branch7x7dbl) + branch7x7dbl = self.branch7x7dbl_3(branch7x7dbl) + branch7x7dbl = self.branch7x7dbl_4(branch7x7dbl) + branch7x7dbl = self.branch7x7dbl_5(branch7x7dbl) + + # Patch: Tensorflow's average pool does not use the padded zero's in + # its average calculation + branch_pool = F.avg_pool2d(x, kernel_size=3, stride=1, padding=1, + count_include_pad=False) + branch_pool = self.branch_pool(branch_pool) + + outputs = [branch1x1, branch7x7, branch7x7dbl, branch_pool] + return torch.cat(outputs, 1) + + +class FIDInceptionE_1(torchvision.models.inception.InceptionE): + """First InceptionE block patched for FID computation""" + def __init__(self, in_channels): + super(FIDInceptionE_1, 
self).__init__(in_channels) + + def forward(self, x): + branch1x1 = self.branch1x1(x) + + branch3x3 = self.branch3x3_1(x) + branch3x3 = [ + self.branch3x3_2a(branch3x3), + self.branch3x3_2b(branch3x3), + ] + branch3x3 = torch.cat(branch3x3, 1) + + branch3x3dbl = self.branch3x3dbl_1(x) + branch3x3dbl = self.branch3x3dbl_2(branch3x3dbl) + branch3x3dbl = [ + self.branch3x3dbl_3a(branch3x3dbl), + self.branch3x3dbl_3b(branch3x3dbl), + ] + branch3x3dbl = torch.cat(branch3x3dbl, 1) + + # Patch: Tensorflow's average pool does not use the padded zero's in + # its average calculation + branch_pool = F.avg_pool2d(x, kernel_size=3, stride=1, padding=1, + count_include_pad=False) + branch_pool = self.branch_pool(branch_pool) + + outputs = [branch1x1, branch3x3, branch3x3dbl, branch_pool] + return torch.cat(outputs, 1) + + +class FIDInceptionE_2(torchvision.models.inception.InceptionE): + """Second InceptionE block patched for FID computation""" + def __init__(self, in_channels): + super(FIDInceptionE_2, self).__init__(in_channels) + + def forward(self, x): + branch1x1 = self.branch1x1(x) + + branch3x3 = self.branch3x3_1(x) + branch3x3 = [ + self.branch3x3_2a(branch3x3), + self.branch3x3_2b(branch3x3), + ] + branch3x3 = torch.cat(branch3x3, 1) + + branch3x3dbl = self.branch3x3dbl_1(x) + branch3x3dbl = self.branch3x3dbl_2(branch3x3dbl) + branch3x3dbl = [ + self.branch3x3dbl_3a(branch3x3dbl), + self.branch3x3dbl_3b(branch3x3dbl), + ] + branch3x3dbl = torch.cat(branch3x3dbl, 1) + + # Patch: The FID Inception model uses max pooling instead of average + # pooling. This is likely an error in this specific Inception + # implementation, as other Inception models use average pooling here + # (which matches the description in the paper). + branch_pool = F.max_pool2d(x, kernel_size=3, stride=1, padding=1) + branch_pool = self.branch_pool(branch_pool) + + outputs = [branch1x1, branch3x3, branch3x3dbl, branch_pool] + return torch.cat(outputs, 1) \ No newline at end of file diff --git a/starvector/metrics/metrics.py b/starvector/metrics/metrics.py new file mode 100644 index 0000000000000000000000000000000000000000..86d7cdcd6dd5798e8bb35b8fe472f687d96f3b4d --- /dev/null +++ b/starvector/metrics/metrics.py @@ -0,0 +1,127 @@ +from starvector.metrics.compute_l2 import L2DistanceCalculator +from starvector.metrics.compute_LPIPS import LPIPSDistanceCalculator +from starvector.metrics.compute_SSIM import SSIMDistanceCalculator +from starvector.metrics.compute_fid import FIDCalculator +from starvector.metrics.compute_clip_score import CLIPScoreCalculator +from starvector.data.util import rasterize_svg +from starvector.metrics.util import AverageMeter +from starvector.metrics.compute_dino_score import DINOScoreCalculator +from starvector.metrics.count_token_length import CountTokenLength +import os +from tqdm import tqdm + +class SVGMetrics: + def __init__(self, config=None): + self.class_name = self.__class__.__name__ + + default_config = { + 'L2': True, + 'Masked-L2': False, + 'LPIPS': False, + 'SSIM': False, + 'FID': False, + 'FID_clip': False, + 'CLIPScore': False, + 'CountTokenLength': False, + 'ratio_post_processed': True, + 'ratio_non_compiling': True, + 'DinoScore': True, + } + self.config = config or default_config + + self.metrics = { + 'L2': L2DistanceCalculator, + 'Masked-L2': lambda: L2DistanceCalculator(masked_l2=True), + 'LPIPS': LPIPSDistanceCalculator, + 'SSIM': SSIMDistanceCalculator, + 'FID': lambda: FIDCalculator(model_name='InceptionV3'), + 'FID_clip': lambda: FIDCalculator(model_name='ViT-B/32'), + 
'CLIPScore': CLIPScoreCalculator, + 'CountTokenLength': CountTokenLength, + 'ratio_post_processed': AverageMeter, + 'ratio_non_compiling': AverageMeter, + 'DinoScore': DINOScoreCalculator, + } + + self.active_metrics = {k: v() for k, v in self.metrics.items() if self.config.get(k)} + + def reset(self): + for metric in self.active_metrics.values(): + metric.reset() + + def batch_contains_raster(self, batch): + return "gt_im" in batch and "gen_im" in batch + + def batch_contains_svg(self, batch): + return "gt_svg" in batch and "gen_svg" in batch + + def calculate_metrics(self, batch, update=True): + if not self.batch_contains_raster(batch): + batch["gt_im"] = [rasterize_svg(svg) for svg in batch["gt_svg"]] + batch["gen_im"] = [rasterize_svg(svg) for svg in batch["gen_svg"]] + + avg_results_dict = {} + all_results_dict = {} + + def get_sample_id(json_item): + return json_item.get('outpath_filename') or json_item.get('sample_id') + + # initialize all_results_dict + for i, json_item in enumerate(batch['json']): + sample_id = get_sample_id(json_item) + if sample_id is None: + raise ValueError(f"Could not find 'outpath_filename' or 'sample_id' in batch['json'][{i}]") + all_results_dict[sample_id] = {} + + for metric_name, metric in self.active_metrics.items(): + print(f"Calculating {metric_name}...") + + # Handle metrics that return both average and per-sample results + if metric_name in ['L2', 'Masked-L2', 'SSIM', 'CLIPScore', 'LPIPS', 'CountTokenLength', 'DinoScore']: + avg_result, list_result = metric.calculate_score(batch, update=update) + avg_results_dict[metric_name] = avg_result + + # Store individual results + for i, result in enumerate(list_result): + sample_id = get_sample_id(batch['json'][i]) + all_results_dict[sample_id][metric_name] = result + + # Handle FID metrics that only return average + elif metric_name in ['FID', 'FID_clip']: + avg_results_dict[metric_name] = metric.calculate_score(batch) + + # Handle other metrics (ratio metrics) + else: + self._handle_ratio_metric(metric_name, metric, batch, avg_results_dict, all_results_dict) + + metric.reset() + print("Average results: \n", avg_results_dict) + return avg_results_dict, all_results_dict + + def calculate_fid(self, batch): + if not self.batch_contains_raster(batch): + batch["gt_im"] = [rasterize_svg(svg) for svg in batch["gt_svg"]] + batch["gen_im"] = [rasterize_svg(svg) for svg in batch["gen_svg"]] + + return self.active_metrics['FID'].calculate_score(batch).item() + + def get_average_metrics(self): + metrics = {} + for metric_name, metric in self.active_metrics.items(): + if hasattr(metric, 'avg'): + metrics[metric_name] = metric.avg + elif hasattr(metric, 'get_average_score'): + metrics[metric_name] = metric.get_average_score() + return metrics + + def _handle_ratio_metric(self, metric_name, metric, batch, avg_results_dict, all_results_dict): + """Helper method to handle ratio-based metrics.""" + metric_key = metric_name.replace('avg_', '').replace('ratio_', '') + + for item in batch['json']: + sample_id = get_sample_id(item) + value = item[metric_key] + all_results_dict[sample_id][metric_name] = value + metric.update(value, 1) + + avg_results_dict[metric_name] = metric.avg \ No newline at end of file diff --git a/starvector/metrics/util.py b/starvector/metrics/util.py new file mode 100644 index 0000000000000000000000000000000000000000..1faac0ed299c21092234a19ef570584981573cc7 --- /dev/null +++ b/starvector/metrics/util.py @@ -0,0 +1,20 @@ + +# -------------- Metrics -------------- +class AverageMeter(object): + 
"""Computes and stores the average and current value""" + + def __init__(self): + self.reset() + + def reset(self): + self.val = 0 + self.avg = 0 + self.sum = 0 + self.count = 0 + + def update(self, val, n=1): + self.val = val + self.sum += val * n + self.count += n + self.avg = self.sum / self.count + \ No newline at end of file diff --git a/starvector/model/adapters/adapter.py b/starvector/model/adapters/adapter.py new file mode 100644 index 0000000000000000000000000000000000000000..7a5a0fe2de0a98472f67576a0fa32c47c68dedff --- /dev/null +++ b/starvector/model/adapters/adapter.py @@ -0,0 +1,53 @@ +import torch.nn as nn +import torch.nn.init as init +import torch + +class Swish(nn.Module): + def __init__(self): + super(Swish, self).__init__() + + def forward(self, x): + return x * torch.sigmoid(x) + +class Adapter(nn.Module): + def __init__(self, input_size, output_size, adapter_norm="layer_norm", init_type="glorot", query_length=32, dropout_prob=0.1): + super().__init__() + self.query_length = query_length + self.dropout_prob = dropout_prob + self.adapter_norm = adapter_norm + + self.dropout = nn.Dropout(p=self.dropout_prob) + + self.c_fc = nn.Linear(input_size, input_size*2) + self.act = Swish() + self.c_proj = nn.Linear(input_size*2, output_size) + + if adapter_norm == "layer_norm": + self.norm = nn.LayerNorm([self.query_length, output_size]) + elif adapter_norm == "batch_norm": + self.norm = nn.BatchNorm1d(self.query_length) + + self.init_type = init_type.lower() + self._initialize_weights() + + def forward(self, hidden_states): + hidden_states = self.dropout(hidden_states) + hidden_states = self.c_fc(hidden_states) + hidden_states = self.act(hidden_states) + hidden_states = self.c_proj(hidden_states) + hidden_states = self.norm(hidden_states) + return hidden_states + + def _initialize_weights(self): + for m in self.modules(): + if isinstance(m, nn.Linear): + if self.init_type == "glorot": + init.xavier_uniform_(m.weight) + if m.bias is not None: + init.constant_(m.bias, 0) + elif self.init_type == "normal": + init.normal_(m.weight, mean=0, std=0.01) + if m.bias is not None: + init.constant_(m.bias, 0) + else: + raise ValueError("Invalid initialization type specified.") diff --git a/starvector/model/builder.py b/starvector/model/builder.py new file mode 100644 index 0000000000000000000000000000000000000000..785a2dfdbd131853760813d1ccf1a01da5d94fa7 --- /dev/null +++ b/starvector/model/builder.py @@ -0,0 +1,49 @@ + +from starvector.model.starvector_arch import StarVectorForCausalLM, StarVectorConfig +from starvector.data.base import ImageTrainProcessor +from starvector.util import dtype_mapping +from transformers import AutoConfig + +def load_pretrained_model(model_path, device="cuda", **kwargs): + model = StarVectorForCausalLM.from_pretrained(model_path, **kwargs).to(device) + tokenizer = model.model.svg_transformer.tokenizer + image_processor = ImageTrainProcessor() + context_len = model.model.query_length + model.model.max_length + return tokenizer, model, image_processor, context_len + +def model_builder(config): + model_name = config.model.get("model_name", False) + + args = { + "task": config.model.task, + "train_image_encoder": config.training.train_image_encoder, + "ignore_mismatched_sizes": True, + "starcoder_model_name": config.model.starcoder_model_name, + "train_LLM": config.training.train_LLM, + "torch_dtype": dtype_mapping[config.training.model_precision], + "transformer_layer_cls": config.model.get("transformer_layer_cls", False), + "use_cache": config.model.use_cache, + } 
+ if model_name: + model = StarVectorForCausalLM.from_pretrained(model_name, **args) + else: + starcoder_model_config = AutoConfig.from_pretrained(config.model.starcoder_model_name) + + starvector_config = StarVectorConfig( + max_length_train=config.model.max_length, + image_encoder_type=config.model.image_encoder_type, + use_flash_attn=config.model.use_flash_attn, + adapter_norm=config.model.adapter_norm, + starcoder_model_name=config.model.starcoder_model_name, + torch_dtype=dtype_mapping[config.training.model_precision], + num_attention_heads=starcoder_model_config.num_attention_heads, + num_hidden_layers=starcoder_model_config.num_hidden_layers, + vocab_size=starcoder_model_config.vocab_size, + hidden_size=starcoder_model_config.hidden_size, + num_kv_heads=getattr(starcoder_model_config, "num_key_value_heads", None), + ) + model = StarVectorForCausalLM(starvector_config, **args) + + return model + + \ No newline at end of file diff --git a/starvector/model/gpt_bigcode/__init__.py b/starvector/model/gpt_bigcode/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..1678bafc28fa8227101eb712cd41968269493c2b --- /dev/null +++ b/starvector/model/gpt_bigcode/__init__.py @@ -0,0 +1,65 @@ +# Copyright 2023 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import TYPE_CHECKING + +from transformers.utils import ( + OptionalDependencyNotAvailable, + _LazyModule, + is_torch_available, +) + + +_import_structure = { + "configuration_gpt_bigcode": ["GPT_BIGCODE_PRETRAINED_CONFIG_ARCHIVE_MAP", "GPTBigCodeConfig"], +} + +try: + if not is_torch_available(): + raise OptionalDependencyNotAvailable() +except OptionalDependencyNotAvailable: + pass +else: + _import_structure["modeling_gpt_bigcode"] = [ + "GPT_BIGCODE_PRETRAINED_MODEL_ARCHIVE_LIST", + "GPTBigCodeForSequenceClassification", + "GPTBigCodeForTokenClassification", + "GPTBigCodeForCausalLM", + "GPTBigCodeModel", + "GPTBigCodePreTrainedModel", + ] + +if TYPE_CHECKING: + from .configuration_gpt_bigcode import GPT_BIGCODE_PRETRAINED_CONFIG_ARCHIVE_MAP, GPTBigCodeConfig + + try: + if not is_torch_available(): + raise OptionalDependencyNotAvailable() + except OptionalDependencyNotAvailable: + pass + else: + from .modeling_gpt_bigcode import ( + GPT_BIGCODE_PRETRAINED_MODEL_ARCHIVE_LIST, + GPTBigCodeForCausalLM, + GPTBigCodeForSequenceClassification, + GPTBigCodeForTokenClassification, + GPTBigCodeModel, + GPTBigCodePreTrainedModel, + ) + + +else: + import sys + + sys.modules[__name__] = _LazyModule(__name__, globals()["__file__"], _import_structure, module_spec=__spec__) diff --git a/starvector/model/gpt_bigcode/configuration_gpt_bigcode.py b/starvector/model/gpt_bigcode/configuration_gpt_bigcode.py new file mode 100644 index 0000000000000000000000000000000000000000..ececb6332f9bad2b9af559f1a586f688c2996c78 --- /dev/null +++ b/starvector/model/gpt_bigcode/configuration_gpt_bigcode.py @@ -0,0 +1,143 @@ +# coding=utf-8 +# Copyright 2023 The BigCode team and HuggingFace Inc. 
team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" GPTBigCode configuration""" +from transformers.configuration_utils import PretrainedConfig +from transformers.utils import logging + + + +logger = logging.get_logger(__name__) + + + + +class GPTBigCodeConfig(PretrainedConfig): + """ + This is the configuration class to store the configuration of a [`GPTBigCodeModel`]. It is used to instantiate a + GPTBigCode model according to the specified arguments, defining the model architecture. Instantiating a + configuration with the defaults will yield a similar configuration to that of the GPTBigCode + [gpt_bigcode](https://huggingface.co/gpt_bigcode) architecture. + + Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the + documentation from [`PretrainedConfig`] for more information. + + + Args: + vocab_size (`int`, *optional*, defaults to 50257): + Vocabulary size of the GPT-2 model. Defines the number of different tokens that can be represented by the + `inputs_ids` passed when calling [`GPTBigCodeModel`]. + n_positions (`int`, *optional*, defaults to 1024): + The maximum sequence length that this model might ever be used with. Typically set this to something large + just in case (e.g., 512 or 1024 or 2048). + n_embd (`int`, *optional*, defaults to 768): + Dimensionality of the embeddings and hidden states. + n_layer (`int`, *optional*, defaults to 12): + Number of hidden layers in the Transformer encoder. + n_head (`int`, *optional*, defaults to 12): + Number of attention heads for each attention layer in the Transformer encoder. + n_inner (`int`, *optional*, defaults to None): + Dimensionality of the inner feed-forward layers. `None` will set it to 4 times n_embd + activation_function (`str`, *optional*, defaults to `"gelu_pytorch_tanh"`): + Activation function, to be selected in the list `["relu", "silu", "gelu", "tanh", "gelu_new", + "gelu_pytorch_tanh"]`. + resid_pdrop (`float`, *optional*, defaults to 0.1): + The dropout probability for all fully connected layers in the embeddings, encoder, and pooler. + embd_pdrop (`float`, *optional*, defaults to 0.1): + The dropout ratio for the embeddings. + attn_pdrop (`float`, *optional*, defaults to 0.1): + The dropout ratio for the attention. + layer_norm_epsilon (`float`, *optional*, defaults to 1e-5): + The epsilon to use in the layer normalization layers. + initializer_range (`float`, *optional*, defaults to 0.02): + The standard deviation of the truncated_normal_initializer for initializing all weight matrices. + scale_attn_weights (`bool`, *optional*, defaults to `True`): + Scale attention weights by dividing by sqrt(hidden_size).. + use_cache (`bool`, *optional*, defaults to `True`): + Whether or not the model should return the last key/values attentions (not used by all models). + attention_softmax_in_fp32 (`bool`, *optional*, defaults to `True`): + Whether to call the fused softmax in float32. 
+ scale_attention_softmax_in_fp32 (`bool`, *optional*, defaults to `True`): + Whether to scale the attention softmax in float32. + attention_type (`bool`, *optional*, defaults to `True`): + Whether to use Multi-Query Attion (`True`) or Multi-Head Attention (`False`). + Example: + + ```python + >>> from transformers import GPTBigCodeConfig, GPTBigCodeModel + + >>> # Initializing a GPTBigCode configuration + >>> configuration = GPTBigCodeConfig() + + >>> # Initializing a model (with random weights) from the configuration + >>> model = GPTBigCodeModel(configuration) + + >>> # Accessing the model configuration + >>> configuration = model.config + ```""" + + model_type = "gpt_bigcode" + keys_to_ignore_at_inference = ["past_key_values"] + attribute_map = { + "hidden_size": "n_embd", + "max_position_embeddings": "n_positions", + "num_attention_heads": "n_head", + "num_hidden_layers": "n_layer", + } + + def __init__( + self, + vocab_size=50257, + n_positions=1024, + n_embd=768, + n_layer=12, + n_head=12, + n_inner=None, + activation_function="gelu_pytorch_tanh", + resid_pdrop=0.1, + embd_pdrop=0.1, + attn_pdrop=0.1, + layer_norm_epsilon=1e-5, + initializer_range=0.02, + scale_attn_weights=True, + use_cache=True, + bos_token_id=50256, + eos_token_id=50256, + attention_softmax_in_fp32=True, + scale_attention_softmax_in_fp32=True, + multi_query=True, + **kwargs, + ): + self.vocab_size = vocab_size + self.n_positions = n_positions + self.n_embd = n_embd + self.n_layer = n_layer + self.n_head = n_head + self.n_inner = n_inner + self.activation_function = activation_function + self.resid_pdrop = resid_pdrop + self.embd_pdrop = embd_pdrop + self.attn_pdrop = attn_pdrop + self.layer_norm_epsilon = layer_norm_epsilon + self.initializer_range = initializer_range + self.scale_attn_weights = scale_attn_weights + self.use_cache = use_cache + self.attention_softmax_in_fp32 = attention_softmax_in_fp32 + self.scale_attention_softmax_in_fp32 = scale_attention_softmax_in_fp32 + self.multi_query = multi_query + + self.bos_token_id = bos_token_id + self.eos_token_id = eos_token_id + + super().__init__(bos_token_id=bos_token_id, eos_token_id=eos_token_id, **kwargs) diff --git a/starvector/model/gpt_bigcode/modeling_gpt_bigcode.py b/starvector/model/gpt_bigcode/modeling_gpt_bigcode.py new file mode 100644 index 0000000000000000000000000000000000000000..b8334b2cbe65bb20288849eb0cf9747f48a48f0d --- /dev/null +++ b/starvector/model/gpt_bigcode/modeling_gpt_bigcode.py @@ -0,0 +1,1502 @@ +# coding=utf-8 +# Copyright 2023 The Bigcode team and HuggingFace Inc. team. +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
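For orientation, a small sketch of how the vendored `GPTBigCodeConfig` above can be exercised once the package is importable; the sizes are illustrative, not the StarCoder defaults:

```python
from starvector.model.gpt_bigcode.configuration_gpt_bigcode import GPTBigCodeConfig

cfg = GPTBigCodeConfig(n_embd=256, n_layer=4, n_head=8, multi_query=True)

# attribute_map exposes the GPT-2 style fields under the HF-standard names.
assert cfg.hidden_size == cfg.n_embd == 256
assert cfg.num_hidden_layers == cfg.n_layer == 4
assert cfg.num_attention_heads == cfg.n_head == 8
assert cfg.multi_query  # selects the Multi-Query Attention path in the model below
```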
+"""PyTorch GPTBigCode model.""" +import math +from typing import List, Optional, Tuple, Union + +import torch +import torch.nn.functional as F +import torch.utils.checkpoint +from torch import nn +from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss + +from transformers.activations import ACT2FN +from transformers.modeling_attn_mask_utils import AttentionMaskConverter +from transformers.modeling_outputs import ( + BaseModelOutputWithPastAndCrossAttentions, + CausalLMOutputWithCrossAttentions, + SequenceClassifierOutputWithPast, + TokenClassifierOutput, +) +from transformers.modeling_utils import PreTrainedModel +from transformers.pytorch_utils import is_torch_greater_or_equal_than_2_2 +from transformers.utils import ( + add_code_sample_docstrings, + add_start_docstrings, + add_start_docstrings_to_model_forward, + is_flash_attn_2_available, + is_flash_attn_greater_or_equal_2_10, + logging, +) +from starvector.model.gpt_bigcode.configuration_gpt_bigcode import GPTBigCodeConfig + + +if is_flash_attn_2_available(): + from flash_attn import flash_attn_func, flash_attn_varlen_func + from flash_attn.bert_padding import index_first_axis, pad_input, unpad_input # noqa + + +logger = logging.get_logger(__name__) + +_CHECKPOINT_FOR_DOC = "bigcode/gpt_bigcode-santacoder" +_CONFIG_FOR_DOC = "GPTBigCodeConfig" + + + +# Fused kernels +# Use separate functions for each case because conditionals prevent kernel fusion. +# TODO: Could have better fused kernels depending on scaling, dropout and head mask. +# Is it doable without writing 32 functions? +@torch.jit.script +def upcast_masked_softmax( + x: torch.Tensor, mask: torch.Tensor, mask_value: torch.Tensor, scale: float, softmax_dtype: torch.dtype +): + input_dtype = x.dtype + x = x.to(softmax_dtype) * scale + x = torch.where(mask, x, mask_value) + x = torch.nn.functional.softmax(x, dim=-1).to(input_dtype) + return x + + +@torch.jit.script +def upcast_softmax(x: torch.Tensor, scale: float, softmax_dtype: torch.dtype): + input_dtype = x.dtype + x = x.to(softmax_dtype) * scale + x = torch.nn.functional.softmax(x, dim=-1).to(input_dtype) + return x + + +@torch.jit.script +def masked_softmax(x: torch.Tensor, mask: torch.Tensor, mask_value: torch.Tensor): + x = torch.where(mask, x, mask_value) + x = torch.nn.functional.softmax(x, dim=-1) + return x + + +# Copied from transformers.models.llama.modeling_llama._get_unpad_data +def _get_unpad_data(attention_mask): + seqlens_in_batch = attention_mask.sum(dim=-1, dtype=torch.int32) + indices = torch.nonzero(attention_mask.flatten(), as_tuple=False).flatten() + max_seqlen_in_batch = seqlens_in_batch.max().item() + cu_seqlens = F.pad(torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.int32), (1, 0)) + return ( + indices, + cu_seqlens, + max_seqlen_in_batch, + ) + + +class GPTBigCodeAttention(nn.Module): + def __init__(self, config, is_cross_attention=False, layer_idx=None): + super().__init__() + self.config = config + + self.mask_value = None + self.multi_query = config.multi_query + self.embed_dim = config.hidden_size + self.num_heads = config.num_attention_heads + self.head_dim = self.embed_dim // self.num_heads + self.kv_heads = 1 if self.multi_query else self.num_heads + self.kv_dim = self.kv_heads * self.head_dim + self.split_size = self.embed_dim + self.is_causal = True + + if self.head_dim * self.num_heads != self.embed_dim: + raise ValueError( + f"`embed_dim` must be divisible by num_heads (got `embed_dim`: {self.embed_dim} and `num_heads`:" + f" {self.num_heads})." 
+ ) + + self.scale_attn_weights = config.scale_attn_weights + self.is_cross_attention = is_cross_attention + + self.layer_idx = layer_idx + self.attention_softmax_in_fp32 = config.attention_softmax_in_fp32 + self.scale_attention_softmax_in_fp32 = ( + config.scale_attention_softmax_in_fp32 and config.attention_softmax_in_fp32 + ) + self.attn_pdrop = config.attn_pdrop + + if self.is_cross_attention: + if self.multi_query: + raise NotImplementedError("Multi-Query Attention not supported for cross_attention") + + self.c_attn = nn.Linear(self.embed_dim, 2 * self.embed_dim) + self.q_attn = nn.Linear(self.embed_dim, self.embed_dim) + else: + self.c_attn = nn.Linear(self.embed_dim, self.embed_dim + 2 * self.kv_dim) + + self.c_proj = nn.Linear(self.embed_dim, self.embed_dim) + + self.attn_dropout = nn.Dropout(config.attn_pdrop) + self.resid_dropout = nn.Dropout(config.resid_pdrop) + + def _get_mask_value(self, device, dtype): + # torch.where expects a tensor. We use a cache to avoid recreating it every time. + if self.mask_value is None or self.mask_value.dtype != dtype or self.mask_value.device != device: + self.mask_value = torch.full([], torch.finfo(dtype).min, dtype=dtype, device=device) + return self.mask_value + + def _attn(self, query, key, value, attention_mask=None, head_mask=None): + dtype = query.dtype + softmax_dtype = torch.float32 if self.attention_softmax_in_fp32 else dtype + upcast = dtype != softmax_dtype + + unscale = self.layer_idx + 1 if self.scale_attention_softmax_in_fp32 and upcast else 1 + scale_factor = unscale**-1 + if self.scale_attn_weights: + scale_factor /= self.head_dim**0.5 + + # MQA models: (batch_size, query_length, num_heads * head_dim) + # MHA models: (batch_size, num_heads, query_length, head_dim) + query_shape = query.shape + batch_size = query_shape[0] + key_length = key.size(-1) + if self.multi_query: + # (batch_size, query_length, num_heads, head_dim) x (batch_size, head_dim, key_length) + # -> (batch_size, query_length, num_heads, key_length) + query_length = query_shape[1] + attn_shape = (batch_size, query_length, self.num_heads, key_length) + attn_view = (batch_size, query_length * self.num_heads, key_length) + # No copy needed for MQA 2, or when layer_past is provided. + query = query.reshape(batch_size, query_length * self.num_heads, self.head_dim) + else: + # (batch_size, num_heads, query_length, head_dim) x (batch_size, num_heads, head_dim, key_length) + # -> (batch_size, num_heads, query_length, key_length) + query_length = query_shape[2] + attn_shape = (batch_size, self.num_heads, query_length, key_length) + attn_view = (batch_size * self.num_heads, query_length, key_length) + # Always copies + query = query.reshape(batch_size * self.num_heads, query_length, self.head_dim) + # No copy when layer_past is provided. + key = key.reshape(batch_size * self.num_heads, self.head_dim, key_length) + + attn_weights = torch.empty(attn_view, device=query.device, dtype=query.dtype) + if query.device.type == "cpu": + # This is needed because of a bug in pytorch https://github.com/pytorch/pytorch/issues/80588. + # The bug was fixed in https://github.com/pytorch/pytorch/pull/96086, + # but the fix has not been released as of pytorch version 2.0.0. + attn_weights = torch.zeros_like(attn_weights) + beta = 1 + else: + beta = 0 + attn_weights = torch.baddbmm(attn_weights, query, key, beta=beta, alpha=scale_factor).view(attn_shape) + + if upcast: + # Use a fused kernel to prevent a large overhead from casting and scaling. 
+ # Sub-optimal when the key length is not a multiple of 8. + if attention_mask is None: + attn_weights = upcast_softmax(attn_weights, unscale, softmax_dtype) + else: + mask_value = self._get_mask_value(attn_weights.device, softmax_dtype) + attn_weights = upcast_masked_softmax(attn_weights, attention_mask, mask_value, unscale, softmax_dtype) + else: + if attention_mask is not None: + mask_value = self._get_mask_value(attn_weights.device, softmax_dtype) + + # The fused kernel is very slow when the key length is not a multiple of 8, so we skip fusion. + attn_weights = torch.where(attention_mask, attn_weights, mask_value) + + attn_weights = torch.nn.functional.softmax(attn_weights, dim=-1) + + attn_weights = self.attn_dropout(attn_weights) + + # Mask heads if we want to + if head_mask is not None: + if self.multi_query: + head_mask = head_mask.transpose(1, 2) + attn_weights = attn_weights * head_mask + + if self.multi_query: + attn_output = torch.bmm(attn_weights.view(attn_view), value).view(query_shape) + else: + attn_output = torch.matmul(attn_weights, value) + + return attn_output, attn_weights + + def forward( + self, + hidden_states: torch.Tensor, + layer_past: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.Tensor] = None, + head_mask: Optional[torch.Tensor] = None, + encoder_hidden_states: Optional[torch.Tensor] = None, + encoder_attention_mask: Optional[torch.Tensor] = None, + use_cache: Optional[bool] = False, + output_attentions: Optional[bool] = False, + ) -> Union[ + Tuple[torch.Tensor, Optional[torch.Tensor]], + Tuple[torch.Tensor, Optional[torch.Tensor], Tuple[torch.Tensor, ...]], + ]: + if encoder_hidden_states is not None: + if not hasattr(self, "q_attn") or not self.is_cross_attention: + raise ValueError( + "If class is used as cross attention, the weights `q_attn` have to be defined. " + "Please make sure to instantiate class with `GPTBigCodeAttention(..., is_cross_attention=True)`." + ) + + query = self.q_attn(hidden_states) + key_value = self.c_attn(encoder_hidden_states) + attention_mask = encoder_attention_mask + elif self.multi_query: + query, key_value = self.c_attn(hidden_states).split((self.embed_dim, 2 * self.kv_dim), dim=2) + else: + # Note: We split as (self.num_heads, 3, self.head_dim) instead of (3, self.num_heads, self.head_dim), + # i.e., the memory layout is not the same as GPT2. + # This makes the concatenation with past_key_value more efficient. + query, key_value = ( + self.c_attn(hidden_states) + .view(*hidden_states.shape[:2], self.num_heads, 3 * self.head_dim) + .transpose(1, 2) + .split((self.head_dim, 2 * self.head_dim), dim=3) + ) + + if layer_past is not None: + key_value = torch.cat((layer_past, key_value), dim=-2) + present = key_value if use_cache else None + + key, value = key_value.split((self.head_dim, self.head_dim), dim=-1) + + attn_output, attn_weights = self._attn(query, key.transpose(-1, -2), value, attention_mask, head_mask) + + if not self.multi_query: + attn_output = attn_output.transpose(1, 2).reshape(hidden_states.shape) + attn_output = self.c_proj(attn_output) + attn_output = self.resid_dropout(attn_output) + + outputs = (attn_output, present) + if output_attentions: + if self.multi_query: + # Transpose to return weights in the usual format (batch_size, num_heads, query_length, key_length) + attn_weights = attn_weights.transpose(1, 2) + outputs += (attn_weights,) + + return outputs # a, present, (attentions) + + +class GPTBigCodeFlashAttention2(GPTBigCodeAttention): + """ + GPTBigCode flash attention module. 
This module inherits from `GPTBigCodeAttention` as the weights of the module + stays untouched. The only required change would be on the forward pass where it needs to correctly call the public + API of flash attention and deal with padding tokens in case the input contains any of them. + """ + + # Copied from transformers.models.llama.modeling_llama.LlamaFlashAttention2.__init__ + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + + # TODO: Should be removed once Flash Attention for RoCm is bumped to 2.1. + # flash_attn<2.1 generates top-left aligned causal mask, while what is needed here is bottom-right alignement, that was made default for flash_attn>=2.1. This attribute is used to handle this difference. Reference: https://github.com/Dao-AILab/flash-attention/releases/tag/v2.1.0. + # Beware that with flash_attn<2.1, using q_seqlen != k_seqlen (except for the case q_seqlen == 1) produces a wrong mask (top-left). + self._flash_attn_uses_top_left_mask = not is_flash_attn_greater_or_equal_2_10() + + def forward( + self, + hidden_states: torch.Tensor, + layer_past: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.Tensor] = None, + head_mask: Optional[torch.Tensor] = None, + encoder_hidden_states: Optional[torch.Tensor] = None, + encoder_attention_mask: Optional[torch.Tensor] = None, + use_cache: Optional[bool] = False, + output_attentions: Optional[bool] = False, + ) -> Union[ + Tuple[torch.Tensor, Optional[torch.Tensor]], + Tuple[torch.Tensor, Optional[torch.Tensor], Tuple[torch.Tensor, ...]], + ]: + if encoder_hidden_states is not None: + if not hasattr(self, "q_attn") or not self.is_cross_attention: + raise ValueError( + "If class is used as cross attention, the weights `q_attn` have to be defined. " + "Please make sure to instantiate class with `GPTBigCodeAttention(..., is_cross_attention=True)`." + ) + + query = self.q_attn(hidden_states) + key_value = self.c_attn(encoder_hidden_states) + attention_mask = encoder_attention_mask + elif self.multi_query: + query, key_value = self.c_attn(hidden_states).split((self.embed_dim, 2 * self.kv_dim), dim=2) + else: + # Note: We split as (self.num_heads, 3, self.head_dim) instead of (3, self.num_heads, self.head_dim), + # i.e., the memory layout is not the same as GPT2. + # This makes the concatenation with past_key_value more efficient. 
+ query, key_value = ( + self.c_attn(hidden_states) + .view(*hidden_states.shape[:2], self.num_heads, 3 * self.head_dim) + .transpose(1, 2) + .split((self.head_dim, 2 * self.head_dim), dim=3) + ) + + if layer_past is not None: + key_value = torch.cat((layer_past, key_value), dim=-2) + present = key_value if use_cache else None + + key, value = key_value.split((self.head_dim, self.head_dim), dim=-1) + + # Flash attention requires the input to have the shape + # batch_size x seq_length x head_dim x hidden_dim + if self.multi_query: + batch_size, query_length, _ = query.shape + query = query.reshape(batch_size, query_length, self.num_heads, self.head_dim) + key = key.unsqueeze(2) + value = value.unsqueeze(2) + else: + query_length = query.shape[2] + batch_size, _, tgt, _ = key.shape + query = query.transpose(1, 2).reshape(batch_size, query_length, self.num_heads, self.head_dim) + key = key.transpose(1, 2).reshape(batch_size, tgt, self.num_heads, self.head_dim) + value = value.transpose(1, 2).reshape(batch_size, tgt, self.num_heads, self.head_dim) + + attn_dropout = self.attn_pdrop if self.training else 0.0 + + # In PEFT, usually we cast the layer norms in float32 for training stability reasons + # therefore the input hidden states gets silently casted in float32. Hence, we need + # cast them back in float16 just to be sure everything works as expected. + input_dtype = query.dtype + if input_dtype == torch.float32: + if torch.is_autocast_enabled(): + target_dtype = torch.get_autocast_gpu_dtype() + # Handle the case where the model is quantized + elif hasattr(self.config, "_pre_quantization_dtype"): + target_dtype = self.config._pre_quantization_dtype + else: + target_dtype = self.c_attn.weight.dtype + + logger.warning_once( + f"The input hidden states seems to be silently casted in float32, this might be related to" + f" the fact you have upcasted embedding or layer norm layers in float32. We will cast back the input in" + f" {target_dtype}." + ) + query = query.to(target_dtype) + key = key.to(target_dtype) + value = value.to(target_dtype) + + attn_output = self._flash_attention_forward( + query, key, value, attention_mask, query_length, dropout=attn_dropout + ) + + attn_weights_reshaped = attn_output.reshape(batch_size, query_length, self.num_heads * self.head_dim) + attn_output = self.c_proj(attn_weights_reshaped) + attn_output = self.resid_dropout(attn_output) + + outputs = (attn_output, present) + + if output_attentions: + if self.multi_query: + # Transpose to return weights in the usual format (batch_size, num_heads, query_length, key_length) + attn_weights_reshaped = attn_weights_reshaped.transpose(1, 2) + else: + attn_weights_reshaped = None + + outputs += (attn_weights_reshaped,) + + return outputs # a, present, (attentions) + + # Copied from transformers.models.llama.modeling_llama.LlamaFlashAttention2._flash_attention_forward + def _flash_attention_forward( + self, query_states, key_states, value_states, attention_mask, query_length, dropout=0.0, softmax_scale=None + ): + """ + Calls the forward method of Flash Attention - if the input hidden states contain at least one padding token + first unpad the input, then computes the attention scores and pad the final attention scores. 
+ + Args: + query_states (`torch.Tensor`): + Input query states to be passed to Flash Attention API + key_states (`torch.Tensor`): + Input key states to be passed to Flash Attention API + value_states (`torch.Tensor`): + Input value states to be passed to Flash Attention API + attention_mask (`torch.Tensor`): + The padding mask - corresponds to a tensor of size `(batch_size, seq_len)` where 0 stands for the + position of padding tokens and 1 for the position of non-padding tokens. + dropout (`float`): + Attention dropout + softmax_scale (`float`, *optional*): + The scaling of QK^T before applying softmax. Default to 1 / sqrt(head_dim) + """ + if not self._flash_attn_uses_top_left_mask: + causal = self.is_causal + else: + # TODO: Remove the `query_length != 1` check once Flash Attention for RoCm is bumped to 2.1. For details, please see the comment in LlamaFlashAttention2 __init__. + causal = self.is_causal and query_length != 1 + + # Contains at least one padding token in the sequence + if attention_mask is not None: + batch_size = query_states.shape[0] + query_states, key_states, value_states, indices_q, cu_seq_lens, max_seq_lens = self._upad_input( + query_states, key_states, value_states, attention_mask, query_length + ) + + cu_seqlens_q, cu_seqlens_k = cu_seq_lens + max_seqlen_in_batch_q, max_seqlen_in_batch_k = max_seq_lens + + attn_output_unpad = flash_attn_varlen_func( + query_states, + key_states, + value_states, + cu_seqlens_q=cu_seqlens_q, + cu_seqlens_k=cu_seqlens_k, + max_seqlen_q=max_seqlen_in_batch_q, + max_seqlen_k=max_seqlen_in_batch_k, + dropout_p=dropout, + softmax_scale=softmax_scale, + causal=causal, + ) + + attn_output = pad_input(attn_output_unpad, indices_q, batch_size, query_length) + else: + attn_output = flash_attn_func( + query_states, key_states, value_states, dropout, softmax_scale=softmax_scale, causal=causal + ) + + return attn_output + + # Copied from transformers.models.llama.modeling_llama.LlamaFlashAttention2._upad_input + def _upad_input(self, query_layer, key_layer, value_layer, attention_mask, query_length): + indices_k, cu_seqlens_k, max_seqlen_in_batch_k = _get_unpad_data(attention_mask) + batch_size, kv_seq_len, num_key_value_heads, head_dim = key_layer.shape + + key_layer = index_first_axis( + key_layer.reshape(batch_size * kv_seq_len, num_key_value_heads, head_dim), indices_k + ) + value_layer = index_first_axis( + value_layer.reshape(batch_size * kv_seq_len, num_key_value_heads, head_dim), indices_k + ) + if query_length == kv_seq_len: + query_layer = index_first_axis( + query_layer.reshape(batch_size * kv_seq_len, self.num_heads, head_dim), indices_k + ) + cu_seqlens_q = cu_seqlens_k + max_seqlen_in_batch_q = max_seqlen_in_batch_k + indices_q = indices_k + elif query_length == 1: + max_seqlen_in_batch_q = 1 + cu_seqlens_q = torch.arange( + batch_size + 1, dtype=torch.int32, device=query_layer.device + ) # There is a memcpy here, that is very bad. + indices_q = cu_seqlens_q[:-1] + query_layer = query_layer.squeeze(1) + else: + # The -q_len: slice assumes left padding. 
+ attention_mask = attention_mask[:, -query_length:] + query_layer, indices_q, cu_seqlens_q, max_seqlen_in_batch_q = unpad_input(query_layer, attention_mask) + + return ( + query_layer, + key_layer, + value_layer, + indices_q, + (cu_seqlens_q, cu_seqlens_k), + (max_seqlen_in_batch_q, max_seqlen_in_batch_k), + ) + + +class GPTBigCodeSdpaAttention(GPTBigCodeAttention): + def _attn(self, query, key, value, attention_mask=None, head_mask=None): + if head_mask is not None: + # The super dispatch is done in the forward. + raise ValueError( + "PyTorch SDPA does not support head_mask. Please open an issue in Transformers repository." + ) + + scale = None + if not self.scale_attn_weights: + scale = 1 + + # MQA models: (batch_size, query_length, num_heads * head_dim) + # MHA models: (batch_size, num_heads, query_length, head_dim) + query_shape = query.shape + batch_size = query_shape[0] + key.shape[-2] + + if self.multi_query: + query_length = query_shape[1] + + # SDPA requires the dimension [..., sequence_length, head_dim]. + query = query.view(batch_size, query_length, self.num_heads, self.head_dim).transpose(1, 2) + + # Without these unsqueeze, SDPA complains as the query and key/value have a different number of dimensions. + key = key.unsqueeze(1) + value = value.unsqueeze(1) + + # Although these expand are not numerically useful, PyTorch can not dispatch to memory-efficient backend + # and flash attention backend (No available kernel. Aborting execution.) from the shapes + # query = [batch_size, num_heads, query_length, head_dim] + # key = [batch_size, 1, past_length, head_dim] + # value = [batch_size, 1, past_length, head_dim] + # + # torch==2.1.2 is bugged with non-contiguous inputs with custom attn_mask (https://github.com/pytorch/pytorch/issues/112577), hence the check. + if is_torch_greater_or_equal_than_2_2: + key = key.expand(-1, self.num_heads, -1, -1) + value = value.expand(-1, self.num_heads, -1, -1) + else: + query_length = query_shape[-1] + + # See the comment above. + if query.device.type == "cuda" and attention_mask is not None: + query = query.contiguous() + key = key.contiguous() + value = value.contiguous() + + sdpa_result = torch.nn.functional.scaled_dot_product_attention( + query, + key, + value, + attn_mask=attention_mask, + dropout_p=self.attn_pdrop if self.training else 0.0, + # The query_length > 1 is necessary to match with AttentionMaskConverter.to_causal_4d that does not create a causal mask in case query_length == 1. 
+ is_causal=self.is_causal and attention_mask is None and query_length > 1, + scale=scale, + ) + + if self.multi_query: + # (batch_size, num_heads, seq_len, head_dim) --> (batch_size, seq_len, num_heads, head_dim) + sdpa_result = sdpa_result.transpose(1, 2) + + # Reshape is kind of expensive here, as it does a memory copy, + # but I did not manage to make away without it (logits do not match when using view) + # (batch_size, seq_len, num_heads, head_dim) --> (batch_size, seq_len, num_heads * head_dim) + sdpa_result = sdpa_result.reshape(query_shape) + + return sdpa_result, None + + def forward( + self, + hidden_states: torch.Tensor, + layer_past: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.Tensor] = None, + head_mask: Optional[torch.Tensor] = None, + encoder_hidden_states: Optional[torch.Tensor] = None, + encoder_attention_mask: Optional[torch.Tensor] = None, + use_cache: Optional[bool] = False, + output_attentions: Optional[bool] = False, + ) -> Union[ + Tuple[torch.Tensor, Optional[torch.Tensor]], + Tuple[torch.Tensor, Optional[torch.Tensor], Tuple[torch.Tensor, ...]], + ]: + if encoder_hidden_states is not None: + if not hasattr(self, "q_attn") or not self.is_cross_attention: + raise ValueError( + "If class is used as cross attention, the weights `q_attn` have to be defined. " + "Please make sure to instantiate class with `GPTBigCodeAttention(..., is_cross_attention=True)`." + ) + + query = self.q_attn(hidden_states) + key_value = self.c_attn(encoder_hidden_states) + attention_mask = encoder_attention_mask + elif self.multi_query: + query, key_value = self.c_attn(hidden_states).split((self.embed_dim, 2 * self.kv_dim), dim=2) + else: + # Note: We split as (self.num_heads, 3, self.head_dim) instead of (3, self.num_heads, self.head_dim), + # i.e., the memory layout is not the same as GPT2. + # This makes the concatenation with past_key_value more efficient. + query, key_value = ( + self.c_attn(hidden_states) + .view(*hidden_states.shape[:2], self.num_heads, 3 * self.head_dim) + .transpose(1, 2) + .split((self.head_dim, 2 * self.head_dim), dim=3) + ) + + if layer_past is not None: + key_value = torch.cat((layer_past, key_value), dim=-2) + present = key_value if use_cache else None + + key, value = key_value.split((self.head_dim, self.head_dim), dim=-1) + + if not output_attentions and head_mask is None: + # Difference with the original implementation: there is no need to transpose the key here, + # as SDPA expects seq_length to be at index -2 for the key as well + attn_output, attn_weights = self._attn(query, key, value, attention_mask, head_mask) + else: + # TODO: Improve this warning with e.g. `model.config._attn_implementation = "manual"` once this is implemented. + logger.warning_once( + "GPTBigCodeModel is using GPTBigCodeSdpaAttention, but `torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True` and `head_mask` not None." + ' Falling back to the manual attention implementation, but specifying the manual implementation will be required from Transformers version v5.0.0 onwards. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.' 
+ ) + attn_output, attn_weights = super()._attn(query, key.transpose(-1, -2), value, attention_mask, head_mask) + + if not self.multi_query: + attn_output = attn_output.transpose(1, 2).reshape(hidden_states.shape) + attn_output = self.c_proj(attn_output) + attn_output = self.resid_dropout(attn_output) + + outputs = (attn_output, present) + if output_attentions: + if self.multi_query: + # Transpose to return weights in the usual format (batch_size, num_heads, query_length, key_length) + attn_weights = attn_weights.transpose(1, 2) + outputs += (attn_weights,) + + return outputs + + +class GPTBigCodeMLP(nn.Module): + def __init__(self, intermediate_size, config): + super().__init__() + embed_dim = config.hidden_size + self.c_fc = nn.Linear(embed_dim, intermediate_size) + self.c_proj = nn.Linear(intermediate_size, embed_dim) + self.act = ACT2FN[config.activation_function] + self.dropout = nn.Dropout(config.resid_pdrop) + + # Copied from transformers.models.gpt2.modeling_gpt2.GPT2MLP.forward + def forward(self, hidden_states: Optional[Tuple[torch.FloatTensor]]) -> torch.FloatTensor: + hidden_states = self.c_fc(hidden_states) + hidden_states = self.act(hidden_states) + hidden_states = self.c_proj(hidden_states) + hidden_states = self.dropout(hidden_states) + return hidden_states + + +GPTBIGCODE_ATTENTION_CLASSES = { + "eager": GPTBigCodeAttention, + "flash_attention_2": GPTBigCodeFlashAttention2, + "sdpa": GPTBigCodeSdpaAttention, +} + + +class GPTBigCodeBlock(nn.Module): + def __init__(self, config, layer_idx=None): + super().__init__() + hidden_size = config.hidden_size + self.inner_dim = config.n_inner if config.n_inner is not None else 4 * hidden_size + + self.ln_1 = nn.LayerNorm(hidden_size, eps=config.layer_norm_epsilon) + + self.attn = GPTBIGCODE_ATTENTION_CLASSES[config._attn_implementation](config, layer_idx=layer_idx) + + self.ln_2 = nn.LayerNorm(hidden_size, eps=config.layer_norm_epsilon) + + if config.add_cross_attention: + if config.multi_query: + raise NotImplementedError("Cross-attention not implemented for MQA") + + self.crossattention = GPTBIGCODE_ATTENTION_CLASSES[config._attn_implementation]( + config, is_cross_attention=True, layer_idx=layer_idx + ) + + self.ln_cross_attn = nn.LayerNorm(hidden_size, eps=config.layer_norm_epsilon) + + self.mlp = GPTBigCodeMLP(self.inner_dim, config) + + def forward( + self, + hidden_states: Optional[Tuple[torch.Tensor]], + layer_past: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.Tensor] = None, + head_mask: Optional[torch.Tensor] = None, + encoder_hidden_states: Optional[torch.Tensor] = None, + encoder_attention_mask: Optional[torch.Tensor] = None, + use_cache: Optional[bool] = False, + output_attentions: Optional[bool] = False, + ) -> Union[ + Tuple[torch.Tensor], Tuple[torch.Tensor, torch.Tensor], Tuple[torch.Tensor, torch.Tensor, torch.Tensor] + ]: + residual = hidden_states + hidden_states = self.ln_1(hidden_states) + attn_outputs = self.attn( + hidden_states, + layer_past=layer_past, + attention_mask=attention_mask, + head_mask=head_mask, + use_cache=use_cache, + output_attentions=output_attentions, + ) + attn_output = attn_outputs[0] # output_attn: a, present, (attentions) + outputs = attn_outputs[1:] + # residual connection + hidden_states = attn_output + residual + + if encoder_hidden_states is not None: + # add one self-attention block for cross-attention + if not hasattr(self, "crossattention"): + raise ValueError( + f"If `encoder_hidden_states` are passed, {self} has to be instantiated with " + "cross-attention 
layers by setting `config.add_cross_attention=True`" + ) + residual = hidden_states + hidden_states = self.ln_cross_attn(hidden_states) + cross_attn_outputs = self.crossattention( + hidden_states, + attention_mask=attention_mask, + head_mask=head_mask, + encoder_hidden_states=encoder_hidden_states, + encoder_attention_mask=encoder_attention_mask, + output_attentions=output_attentions, + ) + attn_output = cross_attn_outputs[0] + # residual connection + hidden_states = residual + attn_output + outputs = outputs + cross_attn_outputs[2:] # add cross attentions if we output attention weights + + residual = hidden_states + hidden_states = self.ln_2(hidden_states) + feed_forward_hidden_states = self.mlp(hidden_states) + # residual connection + hidden_states = residual + feed_forward_hidden_states + + if use_cache: + outputs = (hidden_states,) + outputs + else: + outputs = (hidden_states,) + outputs[1:] + + return outputs # hidden_states, present, (attentions, cross_attentions) + + +class GPTBigCodePreTrainedModel(PreTrainedModel): + """ + An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained + models. + """ + + config_class = GPTBigCodeConfig + base_model_prefix = "transformer" + supports_gradient_checkpointing = True + _no_split_modules = ["GPTBigCodeBlock"] + _skip_keys_device_placement = "past_key_values" + _supports_flash_attn_2 = True + _supports_sdpa = True + + def __init__(self, *inputs, **kwargs): + super().__init__(*inputs, **kwargs) + + def _init_weights(self, module): + """Initialize the weights.""" + if isinstance(module, (GPTBigCodeMLP, GPTBigCodeAttention)): + # Reinitialize selected weights subject to the OpenAI GPT-2 Paper Scheme: + # > A modified initialization which accounts for the accumulation on the residual path with model depth. Scale + # > the weights of residual layers at initialization by a factor of 1/√N where N is the # of residual layers. + # > -- GPT-2 :: https://openai.com/blog/better-language-models/ + # + # Reference (Megatron-LM): https://github.com/NVIDIA/Megatron-LM/blob/main/megatron/model/gpt_model.py + module.c_proj.weight.data.normal_( + mean=0.0, std=(self.config.initializer_range / math.sqrt(2 * self.config.n_layer)) + ) + module.c_proj._is_hf_initialized = True + elif isinstance(module, nn.Linear): + # Slightly different from the TF version which uses truncated_normal for initialization + # cf https://github.com/pytorch/pytorch/pull/5617 + module.weight.data.normal_(mean=0.0, std=self.config.initializer_range) + if module.bias is not None: + module.bias.data.zero_() + elif isinstance(module, nn.Embedding): + module.weight.data.normal_(mean=0.0, std=self.config.initializer_range) + if module.padding_idx is not None: + module.weight.data[module.padding_idx].zero_() + elif isinstance(module, nn.LayerNorm): + module.bias.data.zero_() + module.weight.data.fill_(1.0) + + +GPT_BIGCODE_START_DOCSTRING = r""" + + This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the + library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads + etc.) + + This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass. + Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage + and behavior. + + Parameters: + config ([`GPTBigCodeConfig`]): Model configuration class with all the parameters of the model. 
+ Initializing with a config file does not load the weights associated with the model, only the + configuration. Check out the [`~PreTrainedModel.from_pretrained`] method to load the model weights. +""" + +GPT_BIGCODE_INPUTS_DOCSTRING = r""" + Args: + input_ids (`torch.Tensor` of shape `(batch_size, input_ids_length)`): + `input_ids_length` = `sequence_length` if `past_key_values` is `None` else + `past_key_values[0][0].shape[-2]` (`sequence_length` of input past key value states). Indices of input + sequence tokens in the vocabulary. + + If `past_key_values` is used, only `input_ids` that do not have their past calculated should be passed as + `input_ids`. + + Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and + [`PreTrainedTokenizer.__call__`] for details. + + [What are input IDs?](../glossary#input-ids) + past_key_values (`Tuple[torch.Tensor]` of length `config.n_layers`): + Contains precomputed hidden-states (key and values in the attention blocks) as computed by the model (see + `past_key_values` output below). Can be used to speed up sequential decoding. The `input_ids` which have + their past given to this model should not be passed as `input_ids` as they have already been computed. + attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*): + Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`: + + - 1 for tokens that are **not masked**, + - 0 for tokens that are **masked**. + + If `past_key_values` is used, `attention_mask` needs to contain the masking strategy that was used for + `past_key_values`. In other words, the `attention_mask` always has to have the length: + `len(past_key_values) + len(input_ids)` + + [What are attention masks?](../glossary#attention-mask) + token_type_ids (`torch.Tensor` of shape `(batch_size, input_ids_length)`, *optional*): + Segment token indices to indicate first and second portions of the inputs. Indices are selected in `[0, + 1]`: + + - 0 corresponds to a *sentence A* token, + - 1 corresponds to a *sentence B* token. + + [What are token type IDs?](../glossary#token-type-ids) + position_ids (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*): + Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0, + config.max_position_embeddings - 1]`. + + [What are position IDs?](../glossary#position-ids) + head_mask (`torch.Tensor` of shape `(num_heads,)` or `(num_layers, num_heads)`, *optional*): + Mask to nullify selected heads of the self-attention modules. Mask values selected in `[0, 1]`: + + - 1 indicates the head is **not masked**, + - 0 indicates the head is **masked**. + + inputs_embeds (`torch.Tensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*): + Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This + is useful if you want more control over how to convert `input_ids` indices into associated vectors than the + model's internal embedding lookup matrix. + + If `past_key_values` is used, optionally only the last `inputs_embeds` have to be input (see + `past_key_values`). + use_cache (`bool`, *optional*): + If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see + `past_key_values`). + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned + tensors for more detail. 
+ output_hidden_states (`bool`, *optional*): + Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for + more detail. + return_dict (`bool`, *optional*): + Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. +""" + + +@add_start_docstrings( + "The bare GPT_BIGCODE Model transformer outputting raw hidden-states without any specific head on top.", + GPT_BIGCODE_START_DOCSTRING, +) +class GPTBigCodeModel(GPTBigCodePreTrainedModel): + def __init__(self, config): + super().__init__(config) + self.multi_query = config.multi_query + self.embed_dim = config.hidden_size + + self.wte = nn.Embedding(config.vocab_size, self.embed_dim) + self.wpe = nn.Embedding(config.max_position_embeddings, self.embed_dim) + + self.drop = nn.Dropout(config.embd_pdrop) + self.h = nn.ModuleList([GPTBigCodeBlock(config, layer_idx=i) for i in range(config.num_hidden_layers)]) + self.ln_f = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_epsilon) + + max_positions = config.max_position_embeddings + self.register_buffer( + "bias", torch.tril(torch.ones((max_positions, max_positions), dtype=torch.bool)), persistent=False + ) + + self.gradient_checkpointing = False + + self._use_sdpa = config._attn_implementation == "sdpa" + self._use_flash_attention_2 = config._attn_implementation == "flash_attention_2" + + # Initialize weights and apply final processing + self.post_init() + + def get_input_embeddings(self): + return self.wte + + def set_input_embeddings(self, new_embeddings): + self.wte = new_embeddings + + @add_start_docstrings_to_model_forward(GPT_BIGCODE_INPUTS_DOCSTRING) + @add_code_sample_docstrings( + checkpoint=_CHECKPOINT_FOR_DOC, + output_type=BaseModelOutputWithPastAndCrossAttentions, + config_class=_CONFIG_FOR_DOC, + ) + def forward( + self, + input_ids: Optional[torch.Tensor] = None, + past_key_values: Optional[List[torch.Tensor]] = None, + attention_mask: Optional[torch.Tensor] = None, + token_type_ids: Optional[torch.Tensor] = None, + position_ids: Optional[torch.Tensor] = None, + head_mask: Optional[torch.Tensor] = None, + inputs_embeds: Optional[torch.Tensor] = None, + encoder_hidden_states: Optional[torch.Tensor] = None, + encoder_attention_mask: Optional[torch.Tensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[Tuple, BaseModelOutputWithPastAndCrossAttentions]: + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + use_cache = use_cache if use_cache is not None else self.config.use_cache + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + if input_ids is not None and inputs_embeds is not None: + raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time") + elif input_ids is not None: + self.warn_if_padding_and_no_attention_mask(input_ids, attention_mask) + input_shape = input_ids.size() + input_ids = input_ids.view(-1, input_shape[-1]) + batch_size = input_ids.shape[0] + elif inputs_embeds is not None: + input_shape = inputs_embeds.size()[:-1] + batch_size = inputs_embeds.shape[0] + else: + raise ValueError("You have to specify either input_ids or inputs_embeds") + + if batch_size <= 0: + raise ValueError("batch_size has to be defined and > 0") 
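To make the forward contract documented above concrete, here is a hedged end-to-end sketch of the bare model with random weights; the tiny config is illustrative only and is not how StarVector wires the decoder up:

```python
import torch
from starvector.model.gpt_bigcode.configuration_gpt_bigcode import GPTBigCodeConfig
from starvector.model.gpt_bigcode.modeling_gpt_bigcode import GPTBigCodeModel

cfg = GPTBigCodeConfig(n_embd=128, n_layer=2, n_head=4, n_positions=64)
model = GPTBigCodeModel(cfg).eval()

input_ids = torch.randint(0, cfg.vocab_size, (1, 8))
attention_mask = torch.ones_like(input_ids)

with torch.no_grad():
    out = model(input_ids=input_ids, attention_mask=attention_mask, use_cache=True)

print(out.last_hidden_state.shape)  # torch.Size([1, 8, 128])
print(len(out.past_key_values))     # 2: one packed key/value tensor per layer
```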
+ + device = input_ids.device if input_ids is not None else inputs_embeds.device + + if token_type_ids is not None: + token_type_ids = token_type_ids.view(-1, input_shape[-1]) + + if past_key_values is None: + past_length = 0 + past_key_values = tuple([None] * len(self.h)) + else: + past_length = past_key_values[0].size(-2) + + if attention_mask is not None and len(attention_mask.shape) == 2 and position_ids is None: + # create position_ids on the fly for batch generation + position_ids = attention_mask.long().cumsum(-1) - 1 + position_ids.masked_fill_(attention_mask == 0, 1) + if past_length > 0: + position_ids = position_ids[:, past_length : input_shape[-1] + past_length :] + elif position_ids is None: + position_ids = torch.arange(past_length, input_shape[-1] + past_length, dtype=torch.long, device=device) + position_ids = position_ids.unsqueeze(0) + + # Self-attention mask. + query_length = input_shape[-1] + key_length = past_length + query_length + self_attention_mask = self.bias[None, key_length - query_length : key_length, :key_length] + + if self._use_flash_attention_2: + # 2d mask is passed through the layers + attention_mask = attention_mask.bool() if (attention_mask is not None and 0 in attention_mask) else None + encoder_attention_mask = ( + encoder_attention_mask.bool() + if (encoder_attention_mask is not None and 0 in encoder_attention_mask) + else None + ) + else: + # 4d mask is passed through the layers + if attention_mask is not None: + self_attention_mask = self_attention_mask * attention_mask.view(batch_size, 1, -1).to( + dtype=torch.bool, device=self_attention_mask.device + ) + + # MQA models: (batch_size, query_length, n_heads, key_length) + # MHA models: (batch_size, n_heads, query_length, key_length) + self_attention_mask = self_attention_mask.unsqueeze(2 if self.multi_query else 1) + + if self._use_sdpa and head_mask is None and not output_attentions: + # SDPA with a custom mask is much faster in fp16/fp32 dtype rather than bool. Cast here to floating point instead of at every layer. + dtype = self.wte.weight.dtype + min_dtype = torch.finfo(dtype).min + self_attention_mask = torch.where( + self_attention_mask, + torch.full([], 0.0, dtype=dtype, device=self_attention_mask.device), + torch.full([], min_dtype, dtype=dtype, device=self_attention_mask.device), + ) + + # output_attentions=True can not be supported when using SDPA, and we fall back on + # the manual implementation that requires a 4D causal mask in all cases. + if self.multi_query: + # gpt_bigcode using MQA has the bad taste to use a causal mask with shape + # [batch_size, target_length, 1, source_length], not compatible with SDPA, hence this transpose. + self_attention_mask = self_attention_mask.transpose(1, 2) + + if query_length > 1 and attention_mask is not None and attention_mask.device.type == "cuda": + # From PyTorch 2.1 onwards, F.scaled_dot_product_attention with the memory-efficient attention backend + # produces nans if sequences are completely unattended in the attention mask. 
Details: https://github.com/pytorch/pytorch/issues/110213 + self_attention_mask = AttentionMaskConverter._unmask_unattended( + self_attention_mask, min_dtype=min_dtype + ) + + attention_mask = self_attention_mask + + # If a 2D or 3D attention mask is provided for the cross-attention + # we need to make broadcastable to [batch_size, num_heads, seq_length, seq_length] + if ( + self.config.add_cross_attention + and encoder_hidden_states is not None + and encoder_attention_mask is not None + ): + if encoder_attention_mask.dim() == 2: + encoder_attention_mask.unsqueeze(1) + assert encoder_attention_mask.dim() == 3 + encoder_attention_mask = encoder_attention_mask.bool().unsqueeze(2 if self.multi_query else 1) + else: + encoder_attention_mask = None + + # Prepare head mask if needed + # 1.0 in head_mask indicate we keep the head + # attention_probs has shape bsz x n_heads x N x N + # head_mask has shape n_layer x batch x n_heads x N x N + head_mask = self.get_head_mask(head_mask, self.config.n_layer) + + if inputs_embeds is None: + inputs_embeds = self.wte(input_ids) + position_embeds = self.wpe(position_ids) + hidden_states = inputs_embeds + position_embeds + + if token_type_ids is not None: + token_type_embeds = self.wte(token_type_ids) + hidden_states = hidden_states + token_type_embeds + + hidden_states = self.drop(hidden_states) + + output_shape = input_shape + (hidden_states.size(-1),) + + presents = [] if use_cache else None + all_self_attentions = () if output_attentions else None + all_cross_attentions = () if output_attentions and self.config.add_cross_attention else None + all_hidden_states = () if output_hidden_states else None + for i, (block, layer_past) in enumerate(zip(self.h, past_key_values)): + if output_hidden_states: + all_hidden_states = all_hidden_states + (hidden_states,) + + if self.gradient_checkpointing and self.training: + outputs = self._gradient_checkpointing_func( + block.__call__, + hidden_states, + None, + attention_mask, + head_mask[i], + encoder_hidden_states, + encoder_attention_mask, + use_cache, + output_attentions, + ) + else: + outputs = block( + hidden_states, + layer_past=layer_past, + attention_mask=attention_mask, + head_mask=head_mask[i], + encoder_hidden_states=encoder_hidden_states, + encoder_attention_mask=encoder_attention_mask, + use_cache=use_cache, + output_attentions=output_attentions, + ) + + hidden_states = outputs[0] + if use_cache: + presents.append(outputs[1]) + + if output_attentions: + all_self_attentions = all_self_attentions + (outputs[2 if use_cache else 1],) + if self.config.add_cross_attention: + all_cross_attentions = all_cross_attentions + (outputs[3 if use_cache else 2],) + + hidden_states = self.ln_f(hidden_states) + + hidden_states = hidden_states.view(output_shape) + # Add last hidden state + if output_hidden_states: + all_hidden_states = all_hidden_states + (hidden_states,) + + if not return_dict: + return tuple( + v + for v in [hidden_states, presents, all_hidden_states, all_self_attentions, all_cross_attentions] + if v is not None + ) + + return BaseModelOutputWithPastAndCrossAttentions( + last_hidden_state=hidden_states, + past_key_values=presents, + hidden_states=all_hidden_states, + attentions=all_self_attentions, + cross_attentions=all_cross_attentions, + ) + + +@add_start_docstrings( + """ + The GPT_BIGCODE Model transformer with a language modeling head on top (linear layer with weights tied to the input + embeddings). 
+ """, + GPT_BIGCODE_START_DOCSTRING, +) +class GPTBigCodeForCausalLM(GPTBigCodePreTrainedModel): + _tied_weights_keys = ["lm_head.weight"] + + def __init__(self, config): + super().__init__(config) + self.transformer = GPTBigCodeModel(config) + self.lm_head = nn.Linear(config.n_embd, config.vocab_size, bias=False) + + # Initialize weights and apply final processing + self.post_init() + + def get_output_embeddings(self): + return self.lm_head + + def set_output_embeddings(self, new_embeddings): + self.lm_head = new_embeddings + + def prepare_inputs_for_generation(self, input_ids, past_key_values=None, inputs_embeds=None, **kwargs): + token_type_ids = kwargs.get("token_type_ids", None) + # Omit tokens covered by past_key_values + if past_key_values: + if self.config.multi_query: + past_length = past_key_values[0].shape[1] + else: + past_length = past_key_values[0].shape[2] + + # Some generation methods already pass only the last input ID + if input_ids.shape[1] > past_length: + remove_prefix_length = past_length + else: + # Default to old behavior: keep only final ID + remove_prefix_length = input_ids.shape[1] - 1 + + input_ids = input_ids[:, remove_prefix_length:] + if token_type_ids is not None: + token_type_ids = token_type_ids[:, -input_ids.shape[1] :] + + attention_mask = kwargs.get("attention_mask", None) + position_ids = kwargs.get("position_ids", None) + + if attention_mask is not None and position_ids is None: + # create position_ids on the fly for batch generation + position_ids = attention_mask.long().cumsum(-1) - 1 + position_ids.masked_fill_(attention_mask == 0, 1) + if past_key_values: + position_ids = position_ids[:, -input_ids.shape[1] :] + else: + position_ids = None + + # if `inputs_embeds` are passed, we only want to use them in the 1st generation step + if inputs_embeds is not None and past_key_values is None: + model_inputs = {"inputs_embeds": inputs_embeds} + else: + model_inputs = {"input_ids": input_ids} + + model_inputs.update( + { + "past_key_values": past_key_values, + "use_cache": kwargs.get("use_cache"), + "position_ids": position_ids, + "attention_mask": attention_mask, + "token_type_ids": token_type_ids, + } + ) + return model_inputs + + @add_start_docstrings_to_model_forward(GPT_BIGCODE_INPUTS_DOCSTRING) + @add_code_sample_docstrings( + checkpoint=_CHECKPOINT_FOR_DOC, + output_type=CausalLMOutputWithCrossAttentions, + config_class=_CONFIG_FOR_DOC, + ) + def forward( + self, + input_ids: Optional[torch.Tensor] = None, + past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None, + attention_mask: Optional[torch.Tensor] = None, + token_type_ids: Optional[torch.Tensor] = None, + position_ids: Optional[torch.Tensor] = None, + head_mask: Optional[torch.Tensor] = None, + inputs_embeds: Optional[torch.Tensor] = None, + encoder_hidden_states: Optional[torch.Tensor] = None, + encoder_attention_mask: Optional[torch.Tensor] = None, + labels: Optional[torch.Tensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[Tuple, CausalLMOutputWithCrossAttentions]: + r""" + labels (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*): + Labels for language modeling. Note that the labels **are shifted** inside the model, i.e. 
you can set + `labels = input_ids` Indices are selected in `[-100, 0, ..., config.vocab_size]` All labels set to `-100` + are ignored (masked), the loss is only computed for labels in `[0, ..., config.vocab_size]` + """ + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + transformer_outputs = self.transformer( + input_ids, + past_key_values=past_key_values, + attention_mask=attention_mask, + token_type_ids=token_type_ids, + position_ids=position_ids, + head_mask=head_mask, + inputs_embeds=inputs_embeds, + encoder_hidden_states=encoder_hidden_states, + encoder_attention_mask=encoder_attention_mask, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + hidden_states = transformer_outputs[0] + + lm_logits = self.lm_head(hidden_states) + + loss = None + if labels is not None: + # Shift so that tokens < n predict n + shift_logits = lm_logits[..., :-1, :].contiguous() + shift_labels = labels[..., 1:].contiguous().to(shift_logits.device) + # Flatten the tokens + loss_fct = CrossEntropyLoss() + loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1)) + + if not return_dict: + output = (lm_logits,) + transformer_outputs[1:] + return ((loss,) + output) if loss is not None else output + + return CausalLMOutputWithCrossAttentions( + loss=loss, + logits=lm_logits, + past_key_values=transformer_outputs.past_key_values, + hidden_states=transformer_outputs.hidden_states, + attentions=transformer_outputs.attentions, + cross_attentions=transformer_outputs.cross_attentions, + ) + + @staticmethod + def _reorder_cache( + past_key_values: Tuple[Tuple[torch.Tensor]], beam_idx: torch.Tensor + ) -> Tuple[Tuple[torch.Tensor]]: + """ + This function is used to re-order the `past_key_values` cache if [`~PreTrainedModel.beam_search`] or + [`~PreTrainedModel.beam_sample`] is called. This is required to match `past_key_values` with the correct + beam_idx at every generation step. + """ + return tuple(layer_past.index_select(0, beam_idx.to(layer_past.device)) for layer_past in past_key_values) + + +@add_start_docstrings( + """ + The GPTBigCode Model transformer with a sequence classification head on top (linear layer). + + [`GPTBigCodeForSequenceClassification`] uses the last token in order to do the classification, as other causal + models (e.g. GPT-1) do. + + Since it does classification on the last token, it requires to know the position of the last token. If a + `pad_token_id` is defined in the configuration, it finds the last token that is not a padding token in each row. If + no `pad_token_id` is defined, it simply takes the last value in each row of the batch. Since it cannot guess the + padding tokens when `inputs_embeds` are passed instead of `input_ids`, it does the same (take the last value in + each row of the batch). 
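The causal-LM loss above shifts logits and labels by one position so that each position is scored against the next token, with `-100` labels ignored. A self-contained sketch of that shift on random tensors:

import torch
from torch.nn import CrossEntropyLoss

batch, seq_len, vocab = 2, 6, 50
lm_logits = torch.randn(batch, seq_len, vocab)
labels = torch.randint(0, vocab, (batch, seq_len))
labels[:, :2] = -100  # e.g. a masked prompt prefix; -100 is CrossEntropyLoss's default ignore_index

# Same shift as in the forward pass: position t is scored against token t+1
shift_logits = lm_logits[..., :-1, :].contiguous()
shift_labels = labels[..., 1:].contiguous()
loss = CrossEntropyLoss()(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1))
print(loss.item())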
+ """, + GPT_BIGCODE_START_DOCSTRING, +) +class GPTBigCodeForSequenceClassification(GPTBigCodePreTrainedModel): + def __init__(self, config): + super().__init__(config) + self.num_labels = config.num_labels + self.transformer = GPTBigCodeModel(config) + self.score = nn.Linear(config.n_embd, self.num_labels, bias=False) + + # Initialize weights and apply final processing + self.post_init() + + @add_start_docstrings_to_model_forward(GPT_BIGCODE_INPUTS_DOCSTRING) + def forward( + self, + input_ids: Optional[torch.Tensor] = None, + past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None, + attention_mask: Optional[torch.Tensor] = None, + token_type_ids: Optional[torch.Tensor] = None, + position_ids: Optional[torch.Tensor] = None, + head_mask: Optional[torch.Tensor] = None, + inputs_embeds: Optional[torch.Tensor] = None, + labels: Optional[torch.Tensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[Tuple, SequenceClassifierOutputWithPast]: + r""" + labels (`torch.Tensor` of shape `(batch_size,)`, *optional*): + Labels for computing the sequence classification/regression loss. Indices should be in `[0, ..., + config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If + `config.num_labels > 1` a classification loss is computed (Cross-Entropy). + """ + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + transformer_outputs = self.transformer( + input_ids, + past_key_values=past_key_values, + attention_mask=attention_mask, + token_type_ids=token_type_ids, + position_ids=position_ids, + head_mask=head_mask, + inputs_embeds=inputs_embeds, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + hidden_states = transformer_outputs[0] + logits = self.score(hidden_states) + + if input_ids is not None: + batch_size, sequence_length = input_ids.shape[:2] + else: + batch_size, sequence_length = inputs_embeds.shape[:2] + + assert ( + self.config.pad_token_id is not None or batch_size == 1 + ), "Cannot handle batch sizes > 1 if no padding token is defined." + if self.config.pad_token_id is None: + sequence_lengths = -1 + else: + if input_ids is not None: + # if no pad token found, use modulo instead of reverse indexing for ONNX compatibility + sequence_lengths = torch.eq(input_ids, self.config.pad_token_id).int().argmax(-1) - 1 + sequence_lengths = sequence_lengths % input_ids.shape[-1] + sequence_lengths = sequence_lengths.to(logits.device) + else: + sequence_lengths = -1 + logger.warning( + f"{self.__class__.__name__} will not detect padding tokens in `inputs_embeds`. 
Results may be " + "unexpected if using padding tokens in conjunction with `inputs_embeds.`" + ) + + pooled_logits = logits[torch.arange(batch_size, device=logits.device), sequence_lengths] + + loss = None + if labels is not None: + labels = labels.to(logits.device) + + if self.config.problem_type is None: + if self.num_labels == 1: + self.config.problem_type = "regression" + elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int): + self.config.problem_type = "single_label_classification" + else: + self.config.problem_type = "multi_label_classification" + + if self.config.problem_type == "regression": + loss_fct = MSELoss() + if self.num_labels == 1: + loss = loss_fct(pooled_logits.squeeze(), labels.squeeze()) + else: + loss = loss_fct(pooled_logits, labels) + elif self.config.problem_type == "single_label_classification": + loss_fct = CrossEntropyLoss() + loss = loss_fct(pooled_logits.view(-1, self.num_labels), labels.view(-1)) + elif self.config.problem_type == "multi_label_classification": + loss_fct = BCEWithLogitsLoss() + loss = loss_fct(pooled_logits, labels) + if not return_dict: + output = (pooled_logits,) + transformer_outputs[1:] + return ((loss,) + output) if loss is not None else output + + return SequenceClassifierOutputWithPast( + loss=loss, + logits=pooled_logits, + past_key_values=transformer_outputs.past_key_values, + hidden_states=transformer_outputs.hidden_states, + attentions=transformer_outputs.attentions, + ) + + +@add_start_docstrings( + """ + GPT_BIGCODE Model with a token classification head on top (a linear layer on top of the hidden-states output) e.g. + for Named-Entity-Recognition (NER) tasks. + """, + GPT_BIGCODE_START_DOCSTRING, +) +class GPTBigCodeForTokenClassification(GPTBigCodePreTrainedModel): + def __init__(self, config): + super().__init__(config) + self.num_labels = config.num_labels + + self.transformer = GPTBigCodeModel(config) + if hasattr(config, "classifier_dropout") and config.classifier_dropout is not None: + classifier_dropout = config.classifier_dropout + elif hasattr(config, "hidden_dropout") and config.hidden_dropout is not None: + classifier_dropout = config.hidden_dropout + else: + classifier_dropout = 0.1 + self.dropout = nn.Dropout(classifier_dropout) + self.classifier = nn.Linear(config.hidden_size, config.num_labels) + + # Initialize weights and apply final processing + self.post_init() + + @add_start_docstrings_to_model_forward(GPT_BIGCODE_INPUTS_DOCSTRING) + def forward( + self, + input_ids: Optional[torch.Tensor] = None, + past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None, + attention_mask: Optional[torch.Tensor] = None, + token_type_ids: Optional[torch.Tensor] = None, + position_ids: Optional[torch.Tensor] = None, + head_mask: Optional[torch.Tensor] = None, + inputs_embeds: Optional[torch.Tensor] = None, + labels: Optional[torch.Tensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[Tuple, TokenClassifierOutput]: + r""" + labels (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*): + Labels for computing the sequence classification/regression loss. Indices should be in `[0, ..., + config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If + `config.num_labels > 1` a classification loss is computed (Cross-Entropy). 
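The sequence-classification head above pools the logits at the last non-padding token, using the `argmax`-on-equality trick with a modulo fallback for rows that contain no padding at all. A self-contained sketch of that indexing:

import torch

pad_token_id = 0
input_ids = torch.tensor([[5, 8, 3, 0, 0],    # padded row
                          [7, 2, 9, 4, 6]])   # row with no padding at all
logits = torch.randn(2, 5, 3)                 # (batch, seq_len, num_labels)

sequence_lengths = torch.eq(input_ids, pad_token_id).int().argmax(-1) - 1
sequence_lengths = sequence_lengths % input_ids.shape[-1]  # no-pad rows wrap from -1 to the last index
pooled_logits = logits[torch.arange(2), sequence_lengths]

print(sequence_lengths)      # tensor([2, 4])
print(pooled_logits.shape)   # torch.Size([2, 3])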
+ """ + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + transformer_outputs = self.transformer( + input_ids, + past_key_values=past_key_values, + attention_mask=attention_mask, + token_type_ids=token_type_ids, + position_ids=position_ids, + head_mask=head_mask, + inputs_embeds=inputs_embeds, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + hidden_states = transformer_outputs[0] + hidden_states = self.dropout(hidden_states) + logits = self.classifier(hidden_states) + + loss = None + if labels is not None: + loss_fct = CrossEntropyLoss() + loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1).to(logits.device)) + + if not return_dict: + output = (logits,) + transformer_outputs[2:] + return ((loss,) + output) if loss is not None else output + + return TokenClassifierOutput( + loss=loss, + logits=logits, + hidden_states=transformer_outputs.hidden_states, + attentions=transformer_outputs.attentions, + ) diff --git a/starvector/model/image_encoder/clip_model.py b/starvector/model/image_encoder/clip_model.py new file mode 100644 index 0000000000000000000000000000000000000000..4eb2349dc3a5521b0a59d896025f6a7251374897 --- /dev/null +++ b/starvector/model/image_encoder/clip_model.py @@ -0,0 +1,191 @@ +# Adapted from LAVIS-Salesforce: LAVIS/lavis/models/clip_vit.py + +from collections import OrderedDict +from itertools import repeat +import collections.abc +import math +import torch +import torch.nn.functional as F +from torch import nn +from fairscale.nn.checkpoint.checkpoint_activations import checkpoint_wrapper + +def convert_weights_to_precision(model: nn.Module, precision: torch.dtype): + """Convert applicable model parameters to the specified precision""" + + def _convert_weights_to_precision(l): + if isinstance(l, (nn.Conv1d, nn.Conv2d, nn.Linear)): + l.weight.data = l.weight.data.to(precision) + if l.bias is not None: + l.bias.data = l.bias.data.to(precision) + + elif isinstance(l, (nn.MultiheadAttention)): + for attr in [*[f"{s}_proj_weight" for s in ["in", "q", "k", "v"]], "in_proj_bias", "bias_k", "bias_v"]: + tensor = getattr(l, attr) + if tensor is not None: + tensor.data = tensor.data.to(precision) + else: + for _, p in l.named_parameters(): + p.data = p.data.to(precision) + + model.apply(_convert_weights_to_precision) + +class Bottleneck(nn.Module): + expansion = 4 + + def __init__(self, inplanes, planes, stride=1): + super().__init__() + + # all conv layers have stride 1. 
an avgpool is performed after the second convolution when stride > 1 + self.conv1 = nn.Conv2d(inplanes, planes, 1, bias=False) + self.bn1 = nn.BatchNorm2d(planes) + self.relu1 = nn.ReLU(inplace=True) + + self.conv2 = nn.Conv2d(planes, planes, 3, padding=1, bias=False) + self.bn2 = nn.BatchNorm2d(planes) + self.relu2 = nn.ReLU(inplace=True) + + self.avgpool = nn.AvgPool2d(stride) if stride > 1 else nn.Identity() + + self.conv3 = nn.Conv2d(planes, planes * self.expansion, 1, bias=False) + self.bn3 = nn.BatchNorm2d(planes * self.expansion) + self.relu3 = nn.ReLU(inplace=True) + + self.downsample = None + self.stride = stride + + if stride > 1 or inplanes != planes * Bottleneck.expansion: + # downsampling layer is prepended with an avgpool, and the subsequent convolution has stride 1 + self.downsample = nn.Sequential(OrderedDict([ + ("-1", nn.AvgPool2d(stride)), + ("0", nn.Conv2d(inplanes, planes * self.expansion, 1, stride=1, bias=False)), + ("1", nn.BatchNorm2d(planes * self.expansion)) + ])) + + def forward(self, x: torch.Tensor): + identity = x + + out = self.relu1(self.bn1(self.conv1(x))) + out = self.relu2(self.bn2(self.conv2(out))) + out = self.avgpool(out) + out = self.bn3(self.conv3(out)) + + if self.downsample is not None: + identity = self.downsample(x) + + out += identity + out = self.relu3(out) + return out + + +class AttentionPool2d(nn.Module): + def __init__(self, spacial_dim: int, embed_dim: int, num_heads: int, output_dim: int = None): + super().__init__() + self.positional_embedding = nn.Parameter(torch.randn(spacial_dim ** 2 + 1, embed_dim) / embed_dim ** 0.5) + self.k_proj = nn.Linear(embed_dim, embed_dim) + self.q_proj = nn.Linear(embed_dim, embed_dim) + self.v_proj = nn.Linear(embed_dim, embed_dim) + self.c_proj = nn.Linear(embed_dim, output_dim or embed_dim) + self.num_heads = num_heads + + def forward(self, x): + x = x.reshape(x.shape[0], x.shape[1], x.shape[2] * x.shape[3]).permute(2, 0, 1) # NCHW -> (HW)NC + x = torch.cat([x.mean(dim=0, keepdim=True), x], dim=0) # (HW+1)NC + x = x + self.positional_embedding[:, None, :].to(x.dtype) # (HW+1)NC + x, _ = F.multi_head_attention_forward( + query=x, key=x, value=x, + embed_dim_to_check=x.shape[-1], + num_heads=self.num_heads, + q_proj_weight=self.q_proj.weight, + k_proj_weight=self.k_proj.weight, + v_proj_weight=self.v_proj.weight, + in_proj_weight=None, + in_proj_bias=torch.cat([self.q_proj.bias, self.k_proj.bias, self.v_proj.bias]), + bias_k=None, + bias_v=None, + add_zero_attn=False, + dropout_p=0, + out_proj_weight=self.c_proj.weight, + out_proj_bias=self.c_proj.bias, + use_separate_proj_weight=True, + training=self.training, + need_weights=False + ) + + return x[0] + + +class LayerNorm(nn.LayerNorm): + """Subclass torch's LayerNorm to handle fp16.""" + + def forward(self, x: torch.Tensor): + orig_type = x.dtype + layernorm_dtype = self.weight.dtype + ret = super().forward(x.type(layernorm_dtype)) + return ret.type(orig_type) + +class QuickGELU(nn.Module): + def forward(self, x: torch.Tensor): + return x * torch.sigmoid(1.702 * x) + +class ResidualAttentionBlock(nn.Module): + def __init__(self, d_model: int, n_head: int, attn_mask: torch.Tensor = None, use_grad_checkpointing=False): + super().__init__() + + self.attn = nn.MultiheadAttention(d_model, n_head) + self.ln_1 = LayerNorm(d_model) + self.mlp = nn.Sequential(OrderedDict([ + ("c_fc", nn.Linear(d_model, d_model * 4)), + ("gelu", QuickGELU()), + ("c_proj", nn.Linear(d_model * 4, d_model)) + ])) + self.ln_2 = LayerNorm(d_model) + self.attn_mask = attn_mask + + if 
use_grad_checkpointing: + self.attn = checkpoint_wrapper(self.attn) + self.mlp = checkpoint_wrapper(self.mlp) + + def attention(self, x: torch.Tensor): + self.attn_mask = self.attn_mask.to(dtype=x.dtype, device=x.device) if self.attn_mask is not None else None + return self.attn(x, x, x, need_weights=False, attn_mask=self.attn_mask)[0] + + def forward(self, x: torch.Tensor): + x = x + self.attention(self.ln_1(x)) + x = x + self.mlp(self.ln_2(x)) + return x + +class Transformer(nn.Module): + def __init__(self, width: int, layers: int, heads: int, attn_mask: torch.Tensor = None, use_grad_checkpointing=False): + super().__init__() + self.width = width + self.layers = layers + self.resblocks = nn.Sequential(*[ResidualAttentionBlock(width, heads, attn_mask, use_grad_checkpointing and i>12) for i in range(layers)]) + + def forward(self, x: torch.Tensor): + return self.resblocks(x) + +class VisionTransformer(nn.Module): + def __init__(self, input_resolution: int, patch_size: int, width: int, layers: int, heads: int, use_grad_checkpointing: bool): + super().__init__() + self.input_resolution = input_resolution + self.num_features = width + self.num_heads = heads + self.num_patches = (input_resolution // patch_size) ** 2 + self.conv1 = nn.Conv2d(in_channels=3, out_channels=width, kernel_size=patch_size, stride=patch_size, bias=False) + scale = width ** -0.5 + self.class_embedding = nn.Parameter(scale * torch.randn(width)) + self.positional_embedding = nn.Parameter(scale * torch.randn(self.num_patches + 1, width)) + self.ln_pre = LayerNorm(width) + self.transformer = Transformer(width, layers, heads, use_grad_checkpointing=use_grad_checkpointing) + + def forward(self, x: torch.Tensor): + x = self.conv1(x) # shape = [*, width, grid, grid] + x = x.reshape(x.shape[0], x.shape[1], -1) # shape = [*, width, grid ** 2] + x = x.permute(0, 2, 1) # shape = [*, grid ** 2, width] + x = torch.cat([self.class_embedding.to(x.dtype) + torch.zeros(x.shape[0], 1, x.shape[-1], dtype=x.dtype, device=x.device), x], dim=1) # shape = [*, grid ** 2 + 1, width] + x = x + self.positional_embedding.to(x.dtype) + x = self.ln_pre(x) + x = x.permute(1, 0, 2) # NLD -> LND + x = self.transformer(x) + x = x.permute(1, 0, 2) # LND -> NLD + return x diff --git a/starvector/model/image_encoder/image_encoder.py b/starvector/model/image_encoder/image_encoder.py new file mode 100644 index 0000000000000000000000000000000000000000..c5a96dfc001fcf194611906437174f9469b440b1 --- /dev/null +++ b/starvector/model/image_encoder/image_encoder.py @@ -0,0 +1,120 @@ +import os +import torch +import torch.nn as nn +import os +from omegaconf import OmegaConf +from starvector.model.image_encoder.clip_model import convert_weights_to_precision +from starvector.data.util import ImageTrainProcessor + +class ImageEncoder(nn.Module): + def __init__(self, config, **kwargs): + super(ImageEncoder, self).__init__() + + image_size = config.image_size + torch_dtype = kwargs.get('model_precision', config.torch_dtype) + # torch_dtype = torch.float32 + self.image_encoder_type = config.image_encoder_type + if self.image_encoder_type == 'clip': + self.visual_encoder, self.ln_vision = self.build_clip_encoder(image_size=image_size) + convert_weights_to_precision(self, torch_dtype) + self.processor = ImageTrainProcessor(size=config.image_size) + + elif self.image_encoder_type == 'vqgan': + self.visual_encoder = self.build_vqgan_encoder() + self.ln_vision = None + self.processor = ImageTrainProcessor(size=config.image_size) + + elif self.image_encoder_type == 'convnext': + 
self.visual_encoder = self.build_vqgan_encoder() + self.ln_vision = None + self.processor = ImageTrainProcessor(size=config.image_size) + + elif 'siglip' in self.image_encoder_type: + if self.image_encoder_type == 'siglip_512': + model_name = "google/siglip-base-patch16-512" + elif self.image_encoder_type == 'siglip_384': + model_name = "google/siglip-large-patch16-384" + elif self.image_encoder_type == 'siglip_256': + model_name = "google/siglip-base-patch16-256" + + from transformers import AutoProcessor, AutoModel + + self.visual_encoder = AutoModel.from_pretrained( + model_name, torch_dtype = torch_dtype + ).vision_model + + self.processor = AutoProcessor.from_pretrained( + model_name, torch_dtype = torch_dtype + ) + + def build_clip_encoder(self, image_size): + from starvector.model.image_encoder.clip_model import VisionTransformer, LayerNorm + visual_encoder = VisionTransformer( + input_resolution=image_size, + patch_size=14, + width=1024, + layers=23, + heads=16, + use_grad_checkpointing=False) + + ln_vision = LayerNorm(visual_encoder.num_features) + return visual_encoder, ln_vision + + def build_vqgan_encoder(self): + from taming.modules.diffusionmodules.model import Encoder + VQGAN_CHECKPOINT = "/path/to/vqgan_checkpoint" # You can download the checkpoint from https://github.com/EleutherAI/vqgan-clip/blob/main/README.md + vqgan_chkp_path = VQGAN_CHECKPOINT + files_in_directory = os.listdir(vqgan_chkp_path + '/configs') + vqgan_config_file = [file for file in files_in_directory if file.endswith('project.yaml')][0] + vqgan_config = OmegaConf.load(os.path.join(vqgan_chkp_path, 'configs', vqgan_config_file)) + visual_encoder = Encoder(**vqgan_config.model.params.ddconfig) + + # Load checkpoint weights + checkpoint = torch.load(os.path.join(vqgan_chkp_path, 'checkpoints', 'last.ckpt'))['state_dict'] + + # Create a new state_dict with modified keys + new_state_dict = {} + for key, value in checkpoint.items(): + if key.startswith('encoder.'): + new_key = key[len('encoder.'):] + new_state_dict[new_key] = value + + # Load weights + visual_encoder.load_state_dict(new_state_dict) + return visual_encoder + + def build_convnext_encoder(self): + import open_clip + model, _, _ = open_clip.create_model_and_transforms('convnext_base_w', pretrained='laion2b_s13b_b82k') + return model.visual + + def forward(self, image): + if self.image_encoder_type == 'clip': + embeds = self.visual_encoder(image) + out = self.ln_vision(embeds) + elif self.image_encoder_type == 'open-clip': + out = self.visual_encoder(image)[1] + out = self.ln_vision(out) + elif self.image_encoder_type == 'vqgan': + out = self.visual_encoder(image) + size = out.size() + out = out.view(size[0], size[1], -1) + out = out.permute(0, 2, 1) + elif self.image_encoder_type == 'convnext': + out = self.visual_encoder.trunk.forward_features(image) + size = out.size() + out = out.view(size[0], size[1], -1) + out = out.permute(0, 2, 1) + elif 'siglip' in self.image_encoder_type: + out = self.visual_encoder(image)["last_hidden_state"] + return out + + def process_images(self, images): + if self.image_encoder_type == 'clip': + res = [] + for image in images: + res.append(self.processor(image).unsqueeze(0)) # B, 3, H, W + return res + else: + return self.processor(images=images, return_tensors="pt").pixel_values.unsqueeze(0) + \ No newline at end of file diff --git a/starvector/model/llm/starcoder.py b/starvector/model/llm/starcoder.py new file mode 100644 index 0000000000000000000000000000000000000000..a09b3c6d5c1bfd04dc1ea02c6fd7f25602d16e03 --- 
/dev/null +++ b/starvector/model/llm/starcoder.py @@ -0,0 +1,51 @@ +import torch.nn as nn +from transformers import ( + AutoConfig, + AutoModelForCausalLM, + AutoTokenizer, + ) + +class StarCoderModel(nn.Module): + def __init__(self, config, **kwargs): + super(StarCoderModel, self).__init__() + + self.init_tokenizer(config.starcoder_model_name) + + self.max_length = config.max_length + model_config = AutoConfig.from_pretrained(config.starcoder_model_name, trust_remote_code=True) + kwargs = {} + kwargs['trust_remote_code'] = True + kwargs['torch_dtype'] = config.torch_dtype + + # Configure special tokens for generation + model_config.eos_token_id = self.tokenizer.eos_token_id + model_config.pad_token_id = self.tokenizer.pad_token_id + model_config.bos_token_id = self.tokenizer.bos_token_id + try: + model_config.flash_attention = config.use_flash_attn + model_config._attn_implementation = "flash_attention_2" + except ImportError: + config.use_flash_attn = False + + # model = GPTBigCodeForCausalLM(config=model_config) + model = AutoModelForCausalLM.from_pretrained(config.starcoder_model_name, config=model_config, **kwargs) + model.resize_token_embeddings(len(self.tokenizer)) + self.transformer = model + + # Prompt the model after image + self.prompt = '
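The StarCoderModel wrapper above loads a Hugging Face causal LM with trust_remote_code and resizes its token embeddings to match a tokenizer that has been extended with extra special tokens. A minimal sketch of that resize pattern; the checkpoint name and the added token below are placeholders, not the tokens this repository actually registers:

from transformers import AutoModelForCausalLM, AutoTokenizer

checkpoint = "bigcode/starcoderbase-1b"  # placeholder; the real name comes from config.starcoder_model_name
tokenizer = AutoTokenizer.from_pretrained(checkpoint)
tokenizer.add_special_tokens({"additional_special_tokens": ["<hypothetical-token>"]})  # illustrative only

model = AutoModelForCausalLM.from_pretrained(checkpoint, trust_remote_code=True)
model.resize_token_embeddings(len(tokenizer))  # grow the embedding (and tied lm_head) rows to cover the new ids
print(model.get_input_embeddings().weight.shape[0] == len(tokenizer))  # True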