diff --git a/.gitignore b/.gitignore
index 26a04150097f82bd7679e3b44bf41d2f48093f1b..4daec0c7768118aa00c7825720fb88053c8d1637 100644
--- a/.gitignore
+++ b/.gitignore
@@ -1,3 +1,5 @@
+unsloth_compiled_cache/
+*.ipynb
 star-vector/
 SVGDreamer/
 *.parquet
diff --git a/dl.py b/dl.py
new file mode 100644
index 0000000000000000000000000000000000000000..710edf3234d4a22037ff48a50ebfe8f420c12240
--- /dev/null
+++ b/dl.py
@@ -0,0 +1,203 @@
+import logging
+import re
+import cairosvg
+import torch
+from transformers import AutoModelForCausalLM
+from lxml import etree
+import kagglehub
+from gen_image import ImageGenerator
+from starvector.data.util import process_and_rasterize_svg
+
+svg_constraints = kagglehub.package_import('metric/svg-constraints')
+
+class DLModel:
+    def __init__(self, model_id="starvector/starvector-8b-im2svg", device="cuda"):
+        """
+        Initialize the SVG generation pipeline using StarVector.
+
+        Args:
+            model_id (str): The model identifier for the StarVector model.
+            device (str): The device to run the model on, either "cuda" or "cpu".
+        """
+        self.image_generator = ImageGenerator(model_id="stabilityai/stable-diffusion-2-1-base", device=device)
+        self.default_svg = """"""
+        self.constraints = svg_constraints.SVGConstraints()
+        self.timeout_seconds = 90
+
+        # Load StarVector model
+        self.device = device
+        self.starvector = AutoModelForCausalLM.from_pretrained(
+            model_id,
+            torch_dtype=torch.float16,
+            trust_remote_code=True
+        )
+        self.processor = self.starvector.model.processor
+        self.starvector.to(device)
+        self.starvector.eval()
+
+    def predict(self, description):
+        """
+        Generate an SVG from a text description.
+
+        Args:
+            description (str): The text description to generate an image from.
+
+        Returns:
+            str: The generated SVG content.
+        """
+        try:
+            # Step 1: Generate image using diffusion model
+            images = self.image_generator.generate(description)
+            image = images[0]
+
+            # Save the generated image
+            image_path = "diff_image.png"
+            image.save(image_path)
+            logging.info(f"Intermediate image saved to {image_path}")
+
+            # Step 2: Convert image to SVG using StarVector
+            processed_image = self.processor(image, return_tensors="pt")['pixel_values'].to(self.device)
+            if not processed_image.shape[0] == 1:
+                processed_image = processed_image.squeeze(0)
+
+            batch = {"image": processed_image}
+            with torch.no_grad():
+                raw_svg = self.starvector.generate_im2svg(batch, max_length=4000)[0]
+                raw_svg, _ = process_and_rasterize_svg(raw_svg)
+
+            if 'viewBox' not in raw_svg:
+                raw_svg = raw_svg.replace('<svg', '<svg viewBox="0 0 384 384"')
+
+            # Step 3: Enforce constraints on the generated SVG
+            svg_content = self.enforce_constraints(raw_svg)
+
+            return svg_content
+        except Exception as e:
+            logging.error(f"Error generating SVG: {e}")
+            return self.default_svg
+
+    def enforce_constraints(self, svg_string: str) -> str:
+        """Enforces constraints on an SVG string, removing disallowed elements
+        and attributes.
+
+        Parameters
+        ----------
+        svg_string : str
+            The SVG string to process.
+
+        Returns
+        -------
+        str
+            The processed SVG string, or the default SVG if constraints
+            cannot be satisfied.
+        """
+        logging.info('Sanitizing SVG...')
+
+        try:
+            # Remove XML declaration if it exists
+            svg_string = re.sub(r'<\?xml[^>]+\?>', '', svg_string).strip()
+
+            parser = etree.XMLParser(remove_blank_text=True, remove_comments=True)
+            root = etree.fromstring(svg_string, parser=parser)
+        except etree.ParseError as e:
+            logging.error('SVG Parse Error: %s. Returning default SVG.', e)
+            logging.error('SVG string: %s', svg_string)
+            return self.default_svg
+
+        elements_to_remove = []
+        for element in root.iter():
+            tag_name = etree.QName(element.tag).localname
+
+            # Remove disallowed elements
+            if tag_name not in self.constraints.allowed_elements:
+                elements_to_remove.append(element)
+                continue  # Skip attribute checks for removed elements
+
+            # Remove disallowed attributes
+            attrs_to_remove = []
+            for attr in element.attrib:
+                attr_name = etree.QName(attr).localname
+                if (
+                    attr_name
+                    not in self.constraints.allowed_elements[tag_name]
+                    and attr_name
+                    not in self.constraints.allowed_elements['common']
+                ):
+                    attrs_to_remove.append(attr)
+
+            for attr in attrs_to_remove:
+                logging.debug(
+                    'Attribute "%s" for element "%s" not allowed. Removing.',
+                    attr,
+                    tag_name,
+                )
+                del element.attrib[attr]
+
+            # Check and remove invalid href attributes
+            for attr, value in element.attrib.items():
+                if etree.QName(attr).localname == 'href' and not value.startswith('#'):
+                    logging.debug(
+                        'Removing invalid href attribute in element "%s".', tag_name
+                    )
+                    del element.attrib[attr]
+
+            # Validate path elements to help ensure SVG conversion
+            if tag_name == 'path':
+                d_attribute = element.get('d')
+                if not d_attribute:
+                    logging.warning('Path element is missing "d" attribute. Removing path.')
+                    elements_to_remove.append(element)
+                    continue  # Skip further checks for this removed element
+                # Use regex to validate 'd' attribute format
+                path_regex = re.compile(
+                    r'^'  # Start of string
+                    r'(?:'  # Non-capturing group for each command + numbers block
+                    r'[MmZzLlHhVvCcSsQqTtAa]'  # Valid SVG path commands (adjusted to exclude extra letters)
+                    r'\s*'  # Optional whitespace after command
+                    r'(?:'  # Non-capturing group for optional numbers
+                    r'-?\d+(?:\.\d+)?(?:[Ee][+-]?\d+)?'  # First number
+                    r'(?:[\s,]+-?\d+(?:\.\d+)?(?:[Ee][+-]?\d+)?)*'  # Subsequent numbers with mandatory separator(s)
+                    r')?'  # Numbers are optional (e.g. for Z command)
+                    r'\s*'  # Optional whitespace after numbers/command block
+                    r')+'  # One or more command blocks
+                    r'\s*'  # Optional trailing whitespace
+                    r'$'  # End of string
+                )
+                if not path_regex.match(d_attribute):
+                    logging.warning(
+                        'Path element has malformed "d" attribute format. Removing path.'
+ ) + elements_to_remove.append(element) + continue + logging.debug('Path element "d" attribute validated (regex check).') + + # Remove elements marked for removal + for element in elements_to_remove: + if element.getparent() is not None: + element.getparent().remove(element) + logging.debug('Removed element: %s', element.tag) + + try: + cleaned_svg_string = etree.tostring(root, encoding='unicode', xml_declaration=False) + return cleaned_svg_string + except ValueError as e: + logging.error( + 'SVG could not be sanitized to meet constraints: %s', e + ) + return self.default_svg + +# Example usage +if __name__ == "__main__": + model = DLModel() + svg = model.predict("a purple forest at dusk") + # Convert SVG to PNG + try: + # Create a PNG in memory + png_data = cairosvg.svg2png(bytestring=svg.encode('utf-8')) + + # Save the PNG to a file + with open("output.png", "wb") as f: + f.write(png_data) + print("SVG saved as output.png") + except Exception as e: + print(f"Error converting SVG to PNG: {e}") \ No newline at end of file diff --git a/gen_image.py b/gen_image.py index 7b229a87dce891d163c7a0dd6f2898056b9636ff..f82f93acecef8e0630fa201e1e6d69365ec04a2c 100644 --- a/gen_image.py +++ b/gen_image.py @@ -35,7 +35,7 @@ class ImageGenerator: num_images (int, optional): Number of images to generate. Returns: - PIL.Image.Image: The generated image. + list[PIL.Image.Image]: The generated images. """ prompt = f"{prompt}, {self.positive_prompt}" if negative_prompt is None: @@ -51,7 +51,7 @@ class ImageGenerator: for i, image in enumerate(images): image.save(f".cache/{output_path.replace('.png', f'_{i}.png')}") - return image + return images # Example usage if __name__ == "__main__": diff --git a/ml.py b/ml.py new file mode 100644 index 0000000000000000000000000000000000000000..a7a8c49cd84fdd59f9138fefb8e8d88c04f2c9a3 --- /dev/null +++ b/ml.py @@ -0,0 +1,246 @@ +import os +import tempfile +import logging +import re +import subprocess +import cairosvg +from lxml import etree +import kagglehub +from gen_image import ImageGenerator +import vtracer + +svg_constraints = kagglehub.package_import('metric/svg-constraints') + +class MLModel: + def __init__(self, model_id="stabilityai/stable-diffusion-2-1-base", device="cuda"): + """ + Initialize the SVG generation pipeline. + + Args: + model_id (str): The model identifier for the stable diffusion model. + device (str): The device to run the model on, either "cuda" or "cpu". + """ + self.image_generator = ImageGenerator(model_id=model_id, device=device) + self.default_svg = """""" + self.constraints = svg_constraints.SVGConstraints() + self.timeout_seconds = 90 + + def predict(self, description, simplify=True, color_precision=6, + gradient_step=10, filter_speckle=4, path_precision=8): + """ + Generate an SVG from a text description. + + Args: + description (str): The text description to generate an image from. + simplify (bool): Whether to simplify the SVG paths. + color_precision (int): Color quantization precision. + gradient_step (int): Gradient step for color quantization (not used by vtracer). + filter_speckle (int): Filter speckle size. + path_precision (int): Path fitting precision. + + Returns: + str: The generated SVG content. 
+ """ + try: + # Step 1: Generate image using diffusion model + images = self.image_generator.generate(description) + image = images[0] + + # Step 2: Save image to a temporary file + with tempfile.NamedTemporaryFile(suffix='.png', delete=False) as temp_img: + temp_img_path = temp_img.name + image.save(temp_img_path) + + # Step 3: Convert image to SVG using vtracer + with tempfile.NamedTemporaryFile(suffix='.svg', delete=False) as temp_svg: + temp_svg_path = temp_svg.name + + # Process the image with vtracer using parameters directly + vtracer.convert_image_to_svg_py( + temp_img_path, + temp_svg_path, + colormode='color', + hierarchical='stacked' if simplify else 'cutout', + mode='spline', + filter_speckle=filter_speckle, + color_precision=color_precision, + path_precision=path_precision, + corner_threshold=60, + length_threshold=4.0, + max_iterations=10, + splice_threshold=45 + ) + + # Step 4: Read the generated SVG + with open(temp_svg_path, 'r') as f: + svg_content = f.read() + + # Clean up temporary files + os.unlink(temp_img_path) + os.unlink(temp_svg_path) + + # Step 5: Enforce constraints + svg_content = self.enforce_constraints(svg_content) + + return svg_content + except Exception as e: + logging.error(f"Error generating SVG: {e}") + return self.default_svg + + def enforce_constraints(self, svg_string: str) -> str: + """Enforces constraints on an SVG string, removing disallowed elements + and attributes. + + Parameters + ---------- + svg_string : str + The SVG string to process. + + Returns + ------- + str + The processed SVG string, or the default SVG if constraints + cannot be satisfied. + """ + logging.info('Sanitizing SVG...') + + try: + # Remove XML declaration if it exists + svg_string = re.sub(r'<\?xml[^>]+\?>', '', svg_string).strip() + + parser = etree.XMLParser(remove_blank_text=True, remove_comments=True) + root = etree.fromstring(svg_string, parser=parser) + except etree.ParseError as e: + logging.error('SVG Parse Error: %s. Returning default SVG.', e) + logging.error('SVG string: %s', svg_string) + return self.default_svg + + elements_to_remove = [] + for element in root.iter(): + tag_name = etree.QName(element.tag).localname + + # Remove disallowed elements + if tag_name not in self.constraints.allowed_elements: + elements_to_remove.append(element) + continue # Skip attribute checks for removed elements + + # Remove disallowed attributes + attrs_to_remove = [] + for attr in element.attrib: + attr_name = etree.QName(attr).localname + if ( + attr_name + not in self.constraints.allowed_elements[tag_name] + and attr_name + not in self.constraints.allowed_elements['common'] + ): + attrs_to_remove.append(attr) + + for attr in attrs_to_remove: + logging.debug( + 'Attribute "%s" for element "%s" not allowed. Removing.', + attr, + tag_name, + ) + del element.attrib[attr] + + # Check and remove invalid href attributes + for attr, value in element.attrib.items(): + if etree.QName(attr).localname == 'href' and not value.startswith('#'): + logging.debug( + 'Removing invalid href attribute in element "%s".', tag_name + ) + del element.attrib[attr] + + # Validate path elements to help ensure SVG conversion + if tag_name == 'path': + d_attribute = element.get('d') + if not d_attribute: + logging.warning('Path element is missing "d" attribute. 
Removing path.') + elements_to_remove.append(element) + continue # Skip further checks for this removed element + # Use regex to validate 'd' attribute format + path_regex = re.compile( + r'^' # Start of string + r'(?:' # Non-capturing group for each command + numbers block + r'[MmZzLlHhVvCcSsQqTtAa]' # Valid SVG path commands (adjusted to exclude extra letters) + r'\s*' # Optional whitespace after command + r'(?:' # Non-capturing group for optional numbers + r'-?\d+(?:\.\d+)?(?:[Ee][+-]?\d+)?' # First number + r'(?:[\s,]+-?\d+(?:\.\d+)?(?:[Ee][+-]?\d+)?)*' # Subsequent numbers with mandatory separator(s) + r')?' # Numbers are optional (e.g. for Z command) + r'\s*' # Optional whitespace after numbers/command block + r')+' # One or more command blocks + r'\s*' # Optional trailing whitespace + r'$' # End of string + ) + if not path_regex.match(d_attribute): + logging.warning( + 'Path element has malformed "d" attribute format. Removing path.' + ) + elements_to_remove.append(element) + continue + logging.debug('Path element "d" attribute validated (regex check).') + + # Remove elements marked for removal + for element in elements_to_remove: + if element.getparent() is not None: + element.getparent().remove(element) + logging.debug('Removed element: %s', element.tag) + + try: + cleaned_svg_string = etree.tostring(root, encoding='unicode', xml_declaration=False) + return cleaned_svg_string + except ValueError as e: + logging.error( + 'SVG could not be sanitized to meet constraints: %s', e + ) + return self.default_svg + + def optimize_svg(self, svg_content): + """ + Optimize the SVG content using SVGO. + + Args: + svg_content (str): The SVG content to optimize. + + Returns: + str: The optimized SVG content. + """ + try: + with tempfile.NamedTemporaryFile(suffix='.svg', delete=False) as temp_svg: + temp_svg_path = temp_svg.name + temp_svg.write(svg_content.encode('utf-8')) + + with tempfile.NamedTemporaryFile(suffix='.svg', delete=False) as temp_out: + temp_out_path = temp_out.name + + subprocess.run(["svgo", temp_svg_path, "-o", temp_out_path], check=True) + + with open(temp_out_path, 'r') as f: + optimized_svg = f.read() + + os.unlink(temp_svg_path) + os.unlink(temp_out_path) + + return optimized_svg + except (FileNotFoundError, subprocess.CalledProcessError): + print("Warning: SVGO not found or failed. 
Returning unoptimized SVG.") + return svg_content + + +# Example usage +if __name__ == "__main__": + model = MLModel() + svg = model.predict("a purple forest at dusk") + # Convert SVG to PNG + try: + # Create a PNG in memory + png_data = cairosvg.svg2png(bytestring=svg.encode('utf-8')) + + # Save the PNG to a file + with open("output.png", "wb") as f: + f.write(png_data) + print("SVG saved as output.png") + except Exception as e: + print(f"Error converting SVG to PNG: {e}") \ No newline at end of file diff --git a/naive.py b/naive.py new file mode 100644 index 0000000000000000000000000000000000000000..c796f8817708d34cfc3ddcaa0f50ee3b82d13690 --- /dev/null +++ b/naive.py @@ -0,0 +1,246 @@ +import concurrent +import io +import logging +import re + +import cairosvg +import kagglehub +import torch +from lxml import etree +from unsloth import FastLanguageModel +from unsloth.chat_templates import get_chat_template + +svg_constraints = kagglehub.package_import('metric/svg-constraints') + +class NaiveModel: + def __init__(self, model_name="unsloth/phi-4-unsloth-bnb-4bit", max_seq_length=2048, device="cuda"): + self.device = device + self.max_seq_length = max_seq_length + self.load_in_4bit = True + + # Load the Unsloth Phi-4 model + self.model, self.tokenizer = FastLanguageModel.from_pretrained( + model_name=model_name, + max_seq_length=self.max_seq_length, + load_in_4bit=self.load_in_4bit + ) + + # Set up chat template + self.tokenizer = get_chat_template( + self.tokenizer, + chat_template="phi-4", + ) + + # Prepare model for inference + FastLanguageModel.for_inference(self.model) + + self.prompt_template = """Generate SVG code to visually represent the following text description, while respecting the given constraints. + +* **Allowed Elements:** `svg`, `path`, `circle`, `rect`, `ellipse`, `line`, `polyline`, `polygon`, `g`, `linearGradient`, `radialGradient`, `stop`, `defs` +* **Allowed Attributes:** `viewBox`, `width`, `height`, `fill`, `stroke`, `stroke-width`, `d`, `cx`, `cy`, `r`, `x`, `y`, `rx`, `ry`, `x1`, `y1`, `x2`, `y2`, `points`, `transform`, `opacity` + + +Please ensure that the generated SVG code is well-formed, valid, and strictly adheres to these constraints. Focus on a clear and concise representation of the input description within the given limitations. Always give the complete SVG code with nothing omitted. Never use an ellipsis. 
+ +"A red circle with a blue square inside" +```svg + + + + +``` + +"{}" +""" + self.default_svg = """""" + self.constraints = svg_constraints.SVGConstraints() + self.timeout_seconds = 90 + + def predict(self, description: str, max_new_tokens=512) -> str: + def generate_svg(): + try: + # Format the prompt + prompt = self.prompt_template.format(description) + + # Create messages in the format expected by the chat template + messages = [ + {"role": "user", "content": prompt}, + ] + + # Tokenize the messages + inputs = self.tokenizer.apply_chat_template( + messages, + tokenize=True, + add_generation_prompt=True, + return_tensors="pt", + ).to(self.device) + + # Generate the output + outputs = self.model.generate( + input_ids=inputs, + max_new_tokens=max_new_tokens, + use_cache=True, + temperature=1.0, + min_p=0.1, + do_sample=True, + ) + + # Decode the output + output_decoded = self.tokenizer.decode(outputs[0], skip_special_tokens=True) + + # Extract only the generated text (skip the prompt) + generated_text = output_decoded.split("```svg")[-1].split("```")[0] if "```svg" in output_decoded else "" + + logging.debug('Output decoded from model: %s', output_decoded) + + matches = re.findall(r"", output_decoded, re.DOTALL | re.IGNORECASE) + if matches: + svg = matches[-1] + else: + return self.default_svg + + logging.debug('Unprocessed SVG: %s', svg) + svg = self.enforce_constraints(svg) + logging.debug('Processed SVG: %s', svg) + + # Ensure the generated code can be converted by cairosvg + cairosvg.svg2png(bytestring=svg.encode('utf-8')) + return svg + except Exception as e: + logging.error('Exception during SVG generation: %s', e) + return self.default_svg + + # Execute SVG generation in a new thread to enforce time constraints + with concurrent.futures.ThreadPoolExecutor(max_workers=1) as executor: + future = executor.submit(generate_svg) + try: + return future.result(timeout=self.timeout_seconds) + except concurrent.futures.TimeoutError: + logging.warning("Prediction timed out after %s seconds.", self.timeout_seconds) + return self.default_svg + except Exception as e: + logging.error(f"An unexpected error occurred: {e}") + return self.default_svg + + def enforce_constraints(self, svg_string: str) -> str: + """Enforces constraints on an SVG string, removing disallowed elements + and attributes. + + Parameters + ---------- + svg_string : str + The SVG string to process. + + Returns + ------- + str + The processed SVG string, or the default SVG if constraints + cannot be satisfied. + """ + logging.info('Sanitizing SVG...') + + try: + parser = etree.XMLParser(remove_blank_text=True, remove_comments=True) + root = etree.fromstring(svg_string, parser=parser) + except etree.ParseError as e: + logging.error('SVG Parse Error: %s. 
Returning default SVG.', e) + logging.error('SVG string: %s', svg_string) + return self.default_svg + + elements_to_remove = [] + for element in root.iter(): + tag_name = etree.QName(element.tag).localname + + # Remove disallowed elements + if tag_name not in self.constraints.allowed_elements: + elements_to_remove.append(element) + continue # Skip attribute checks for removed elements + + # Remove disallowed attributes + attrs_to_remove = [] + for attr in element.attrib: + attr_name = etree.QName(attr).localname + if ( + attr_name + not in self.constraints.allowed_elements[tag_name] + and attr_name + not in self.constraints.allowed_elements['common'] + ): + attrs_to_remove.append(attr) + + for attr in attrs_to_remove: + logging.debug( + 'Attribute "%s" for element "%s" not allowed. Removing.', + attr, + tag_name, + ) + del element.attrib[attr] + + # Check and remove invalid href attributes + for attr, value in element.attrib.items(): + if etree.QName(attr).localname == 'href' and not value.startswith('#'): + logging.debug( + 'Removing invalid href attribute in element "%s".', tag_name + ) + del element.attrib[attr] + + # Validate path elements to help ensure SVG conversion + if tag_name == 'path': + d_attribute = element.get('d') + if not d_attribute: + logging.warning('Path element is missing "d" attribute. Removing path.') + elements_to_remove.append(element) + continue # Skip further checks for this removed element + # Use regex to validate 'd' attribute format + path_regex = re.compile( + r'^' # Start of string + r'(?:' # Non-capturing group for each command + numbers block + r'[MmZzLlHhVvCcSsQqTtAa]' # Valid SVG path commands (adjusted to exclude extra letters) + r'\s*' # Optional whitespace after command + r'(?:' # Non-capturing group for optional numbers + r'-?\d+(?:\.\d+)?(?:[Ee][+-]?\d+)?' # First number + r'(?:[\s,]+-?\d+(?:\.\d+)?(?:[Ee][+-]?\d+)?)*' # Subsequent numbers with mandatory separator(s) + r')?' # Numbers are optional (e.g. for Z command) + r'\s*' # Optional whitespace after numbers/command block + r')+' # One or more command blocks + r'\s*' # Optional trailing whitespace + r'$' # End of string + ) + if not path_regex.match(d_attribute): + logging.warning( + 'Path element has malformed "d" attribute format. Removing path.' 
+ ) + elements_to_remove.append(element) + continue + logging.debug('Path element "d" attribute validated (regex check).') + + # Remove elements marked for removal + for element in elements_to_remove: + if element.getparent() is not None: + element.getparent().remove(element) + logging.debug('Removed element: %s', element.tag) + + try: + cleaned_svg_string = etree.tostring(root, encoding='unicode') + return cleaned_svg_string + except ValueError as e: + logging.error( + 'SVG could not be sanitized to meet constraints: %s', e + ) + return self.default_svg + + +if __name__ == "__main__": + model = NaiveModel() + svg = model.predict("a purple forest at dusk") + # Convert SVG to PNG + try: + # Create a PNG in memory + png_data = cairosvg.svg2png(bytestring=svg.encode('utf-8')) + + # Save the PNG to a file + with open("output.png", "wb") as f: + f.write(png_data) + print("SVG saved as output.png") + except Exception as e: + print(f"Error converting SVG to PNG: {e}") diff --git a/requirements.txt b/requirements.txt index e78cc67e8bb612da2b172090faaec78b5a629bd0..e26d6df9c3e216a64f43b3d1f999200783c96faf 100644 --- a/requirements.txt +++ b/requirements.txt @@ -19,6 +19,12 @@ dotenv diffusers safetensors xformers +unsloth +tf-keras +vtracer +deepspeed +torch==2.5.1 +torchvision==0.20.1 # pip install 'tensorflow[and-cuda]' # pip install git+https://github.com/openai/CLIP.git diff --git a/starter.ipynb b/starter.ipynb deleted file mode 100644 index f227308e22d2681614faf4cece853e1d0ee3a7c4..0000000000000000000000000000000000000000 --- a/starter.ipynb +++ /dev/null @@ -1,333 +0,0 @@ -{ - "cells": [ - { - "cell_type": "code", - "execution_count": 2, - "metadata": {}, - "outputs": [ - { - "name": "stderr", - "output_type": "stream", - "text": [ - "/home/user/miniconda3/envs/dwl/lib/python3.11/site-packages/tqdm/auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. 
See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n", - " from .autonotebook import tqdm as notebook_tqdm\n" - ] - }, - { - "data": { - "text/html": [ - "\n", - "shape: (5, 2)iddescriptionstrstr"02d892""a purple forest at dusk""0dcd2e""gray wool coat with a faux fur…"1e9ac1""a lighthouse overlooking the o…"2b25db""burgundy corduroy pants with p…"4e6a54""orange corduroy overalls"" - ], - "text/plain": [ - "shape: (5, 2)\n", - "┌────────┬─────────────────────────────────┐\n", - "│ id ┆ description │\n", - "│ --- ┆ --- │\n", - "│ str ┆ str │\n", - "╞════════╪═════════════════════════════════╡\n", - "│ 02d892 ┆ a purple forest at dusk │\n", - "│ 0dcd2e ┆ gray wool coat with a faux fur… │\n", - "│ 1e9ac1 ┆ a lighthouse overlooking the o… │\n", - "│ 2b25db ┆ burgundy corduroy pants with p… │\n", - "│ 4e6a54 ┆ orange corduroy overalls │\n", - "└────────┴─────────────────────────────────┘" - ] - }, - "execution_count": 2, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "# We can load and explore the competition's train set to get a feel for the data.\n", - "# We're not going to export this cell as it's not needed for our exported inferenceable model.\n", - "\n", - "import kagglehub\n", - "import polars as pl\n", - "\n", - "train_path = kagglehub.competition_download('drawing-with-llms', 'train.csv')\n", - "train = pl.read_csv(train_path)\n", - "\n", - "train.head()" - ] - }, - { - "cell_type": "code", - "execution_count": 3, - "metadata": {}, - "outputs": [], - "source": [ - "class Model:\n", - " def __init__(self):\n", - " '''Optional constructor, performs any setup logic, model instantiation, etc.'''\n", - " pass\n", - " \n", - " def predict(self, prompt: str) -> str:\n", - " '''Generates SVG which produces an image described by the prompt.\n", - "\n", - " Args:\n", - " prompt (str): A prompt describing an image\n", - " Returns:\n", - " String of valid SVG code.\n", - " '''\n", - " # Renders a simple circle regardless of input\n", - " return ''" - ] - }, - { - "cell_type": "code", - "execution_count": 4, - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "\n" - ] - }, - { - "data": { - "image/svg+xml": [ - "" - ], - "text/plain": [ - "" - ] - }, - "metadata": {}, - "output_type": "display_data" - } - ], - "source": [ - "from IPython.display import SVG\n", - "\n", - "model = Model()\n", - "svg = model.predict('a goose winning a gold medal')\n", - "\n", - "print(svg)\n", - "display(SVG(svg))" - ] - }, - { - "cell_type": "code", - "execution_count": 6, - "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "['RN50',\n", - " 'RN101',\n", - " 'RN50x4',\n", - " 'RN50x16',\n", - " 'RN50x64',\n", - " 'ViT-B/32',\n", - " 'ViT-B/16',\n", - " 'ViT-L/14',\n", - " 'ViT-L/14@336px']" - ] - }, - "execution_count": 6, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "import clip\n", - "clip.available_models()" - ] - }, - { - "cell_type": "code", - "execution_count": 7, - "metadata": {}, - "outputs": [ - { - "name": "stderr", - "output_type": "stream", - "text": [ - "2025-04-20 13:55:34.589770: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:467] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered\n", - "WARNING: All log messages before absl::InitializeLog() is called are written to STDERR\n", - "E0000 00:00:1745171734.600777 13214 cuda_dnn.cc:8579] Unable to register cuDNN factory: Attempting to register 
factory for plugin cuDNN when one has already been registered\n", - "E0000 00:00:1745171734.603957 13214 cuda_blas.cc:1407] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered\n", - "W0000 00:00:1745171734.615566 13214 computation_placer.cc:177] computation placer already registered. Please check linkage and avoid linking the same target more than once.\n", - "W0000 00:00:1745171734.615584 13214 computation_placer.cc:177] computation placer already registered. Please check linkage and avoid linking the same target more than once.\n", - "W0000 00:00:1745171734.615585 13214 computation_placer.cc:177] computation placer already registered. Please check linkage and avoid linking the same target more than once.\n", - "W0000 00:00:1745171734.615586 13214 computation_placer.cc:177] computation placer already registered. Please check linkage and avoid linking the same target more than once.\n", - "2025-04-20 13:55:34.618659: I tensorflow/core/platform/cpu_feature_guard.cc:210] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.\n", - "To enable the following instructions: AVX2 FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.\n", - "Using a slow image processor as `use_fast` is unset and a slow processor was saved with this model. `use_fast=True` will be the default behavior in v4.52, even if the model was saved with a slow processor. This will result in minor differences in outputs. You'll still be able to use a slow processor with `use_fast=False`.\n", - "Loading checkpoint shards: 100%|██████████| 4/4 [00:18<00:00, 4.68s/it]\n" - ] - } - ], - "source": [ - "import pandas as pd\n", - "import importlib\n", - "metric = importlib.import_module('metric')\n", - "importlib.reload(metric)\n", - "\n", - "vqa_evaluator = metric.VQAEvaluator()\n", - "aesthetic_evaluator = metric.AestheticEvaluator()" - ] - }, - { - "cell_type": "code", - "execution_count": 11, - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "VQA Score: 0.9996758976500401\n", - "Aesthetic Score: 0.5749330520629883\n", - "Final Fidelity Score: 0.8709845773271212\n" - ] - } - ], - "source": [ - "# score gpt4o generated images\n", - "import ast\n", - "import numpy as np\n", - "from PIL import Image\n", - "\n", - "# Load the first sample from descriptions.csv\n", - "descriptions_df = pd.read_csv('data/descriptions.csv')\n", - "first_description = descriptions_df.iloc[1]\n", - "\n", - "eval_df = pd.read_csv('data/eval.csv')\n", - "first_eval = eval_df.iloc[1]\n", - "\n", - "# Load the image\n", - "image_path = 'data/gray_coat.png' # Assuming the image is saved with this name\n", - "image = Image.open(image_path)\n", - "\n", - "# Prepare the inputs for scoring - need to parse the string representations\n", - "questions = ast.literal_eval(first_eval['question'])\n", - "choices = ast.literal_eval(first_eval['choices'])\n", - "answers = ast.literal_eval(first_eval['answer'])\n", - "\n", - "# Calculate VQA score - don't wrap in additional lists\n", - "vqa_score = vqa_evaluator.score(questions, choices, answers, image)\n", - "\n", - "# Calculate aesthetic score\n", - "aesthetic_score = aesthetic_evaluator.score(image)\n", - "\n", - "# Apply image processing as done in the metric.score function\n", - "image_processor = metric.ImageProcessor(image=image, seed=0).apply()\n", - "processed_image = image_processor.image.copy()\n", - "\n", - "# Calculate 
final fidelity score\n", - "instance_score = metric.harmonic_mean(vqa_score, aesthetic_score, beta=0.5)\n", - "\n", - "print(f\"VQA Score: {vqa_score}\")\n", - "print(f\"Aesthetic Score: {aesthetic_score}\")\n", - "print(f\"Final Fidelity Score: {instance_score}\")" - ] - }, - { - "cell_type": "code", - "execution_count": 13, - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "No duplicate IDs found in data/descriptions.csv\n", - "Sorted rows by ID\n", - "Fixed and sorted CSV saved to data/descriptions.csv\n", - "No duplicate IDs found in data/eval.csv\n", - "Sorted data/eval.csv by ID\n" - ] - } - ], - "source": [ - "# Fix duplicate IDs in descriptions.csv and order rows by id\n", - "def fix_duplicate_ids(csv_path):\n", - " \"\"\"\n", - " Fix duplicate IDs in a CSV file by assigning new unique IDs to duplicates.\n", - " Then order rows by ID.\n", - " \"\"\"\n", - " # Read the CSV file\n", - " df = pd.read_csv(csv_path)\n", - " \n", - " # Check for duplicate IDs\n", - " duplicate_mask = df['id'].duplicated(keep='first')\n", - " duplicate_count = duplicate_mask.sum()\n", - " \n", - " if duplicate_count > 0:\n", - " print(f\"Found {duplicate_count} duplicate IDs in {csv_path}\")\n", - " \n", - " # Get the maximum ID value\n", - " max_id = df['id'].max()\n", - " \n", - " # Assign new IDs to duplicates\n", - " new_ids = list(range(max_id + 1, max_id + 1 + duplicate_count))\n", - " df.loc[duplicate_mask, 'id'] = new_ids\n", - " \n", - " print(f\"Assigned new IDs to duplicates\")\n", - " else:\n", - " print(f\"No duplicate IDs found in {csv_path}\")\n", - " \n", - " # Sort the dataframe by ID\n", - " df = df.sort_values(by='id')\n", - " print(f\"Sorted rows by ID\")\n", - " \n", - " # Save the fixed and sorted CSV\n", - " df.to_csv(csv_path, index=False)\n", - " print(f\"Fixed and sorted CSV saved to {csv_path}\")\n", - " \n", - " # Return the fixed dataframe\n", - " return df\n", - "\n", - "# Fix descriptions.csv\n", - "fixed_descriptions_df = fix_duplicate_ids('data/descriptions.csv')\n", - "\n", - "# Fix eval.csv if needed\n", - "# First check if eval.csv has the same issue\n", - "eval_df = pd.read_csv('data/eval.csv')\n", - "duplicate_eval_ids = eval_df['id'].duplicated(keep='first').sum()\n", - "\n", - "if duplicate_eval_ids > 0:\n", - " fixed_eval_df = fix_duplicate_ids('data/eval.csv')\n", - "else:\n", - " print(\"No duplicate IDs found in data/eval.csv\")\n", - " # Still sort by ID even if no duplicates\n", - " eval_df = eval_df.sort_values(by='id')\n", - " eval_df.to_csv('data/eval.csv', index=False)\n", - " print(\"Sorted data/eval.csv by ID\")\n" - ] - } - ], - "metadata": { - "kernelspec": { - "display_name": "dwl", - "language": "python", - "name": "python3" - }, - "language_info": { - "codemirror_mode": { - "name": "ipython", - "version": 3 - }, - "file_extension": ".py", - "mimetype": "text/x-python", - "name": "python", - "nbconvert_exporter": "python", - "pygments_lexer": "ipython3", - "version": "3.11.11" - } - }, - "nbformat": 4, - "nbformat_minor": 2 -} diff --git a/starvector/__init__.py b/starvector/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/starvector/adapter.py b/starvector/adapter.py new file mode 100644 index 0000000000000000000000000000000000000000..7a5a0fe2de0a98472f67576a0fa32c47c68dedff --- /dev/null +++ b/starvector/adapter.py @@ -0,0 +1,53 @@ +import torch.nn as nn +import torch.nn.init as init +import torch + +class 
Swish(nn.Module): + def __init__(self): + super(Swish, self).__init__() + + def forward(self, x): + return x * torch.sigmoid(x) + +class Adapter(nn.Module): + def __init__(self, input_size, output_size, adapter_norm="layer_norm", init_type="glorot", query_length=32, dropout_prob=0.1): + super().__init__() + self.query_length = query_length + self.dropout_prob = dropout_prob + self.adapter_norm = adapter_norm + + self.dropout = nn.Dropout(p=self.dropout_prob) + + self.c_fc = nn.Linear(input_size, input_size*2) + self.act = Swish() + self.c_proj = nn.Linear(input_size*2, output_size) + + if adapter_norm == "layer_norm": + self.norm = nn.LayerNorm([self.query_length, output_size]) + elif adapter_norm == "batch_norm": + self.norm = nn.BatchNorm1d(self.query_length) + + self.init_type = init_type.lower() + self._initialize_weights() + + def forward(self, hidden_states): + hidden_states = self.dropout(hidden_states) + hidden_states = self.c_fc(hidden_states) + hidden_states = self.act(hidden_states) + hidden_states = self.c_proj(hidden_states) + hidden_states = self.norm(hidden_states) + return hidden_states + + def _initialize_weights(self): + for m in self.modules(): + if isinstance(m, nn.Linear): + if self.init_type == "glorot": + init.xavier_uniform_(m.weight) + if m.bias is not None: + init.constant_(m.bias, 0) + elif self.init_type == "normal": + init.normal_(m.weight, mean=0, std=0.01) + if m.bias is not None: + init.constant_(m.bias, 0) + else: + raise ValueError("Invalid initialization type specified.") diff --git a/starvector/clip_model.py b/starvector/clip_model.py new file mode 100644 index 0000000000000000000000000000000000000000..4eb2349dc3a5521b0a59d896025f6a7251374897 --- /dev/null +++ b/starvector/clip_model.py @@ -0,0 +1,191 @@ +# Adapted from LAVIS-Salesforce: LAVIS/lavis/models/clip_vit.py + +from collections import OrderedDict +from itertools import repeat +import collections.abc +import math +import torch +import torch.nn.functional as F +from torch import nn +from fairscale.nn.checkpoint.checkpoint_activations import checkpoint_wrapper + +def convert_weights_to_precision(model: nn.Module, precision: torch.dtype): + """Convert applicable model parameters to the specified precision""" + + def _convert_weights_to_precision(l): + if isinstance(l, (nn.Conv1d, nn.Conv2d, nn.Linear)): + l.weight.data = l.weight.data.to(precision) + if l.bias is not None: + l.bias.data = l.bias.data.to(precision) + + elif isinstance(l, (nn.MultiheadAttention)): + for attr in [*[f"{s}_proj_weight" for s in ["in", "q", "k", "v"]], "in_proj_bias", "bias_k", "bias_v"]: + tensor = getattr(l, attr) + if tensor is not None: + tensor.data = tensor.data.to(precision) + else: + for _, p in l.named_parameters(): + p.data = p.data.to(precision) + + model.apply(_convert_weights_to_precision) + +class Bottleneck(nn.Module): + expansion = 4 + + def __init__(self, inplanes, planes, stride=1): + super().__init__() + + # all conv layers have stride 1. 
an avgpool is performed after the second convolution when stride > 1 + self.conv1 = nn.Conv2d(inplanes, planes, 1, bias=False) + self.bn1 = nn.BatchNorm2d(planes) + self.relu1 = nn.ReLU(inplace=True) + + self.conv2 = nn.Conv2d(planes, planes, 3, padding=1, bias=False) + self.bn2 = nn.BatchNorm2d(planes) + self.relu2 = nn.ReLU(inplace=True) + + self.avgpool = nn.AvgPool2d(stride) if stride > 1 else nn.Identity() + + self.conv3 = nn.Conv2d(planes, planes * self.expansion, 1, bias=False) + self.bn3 = nn.BatchNorm2d(planes * self.expansion) + self.relu3 = nn.ReLU(inplace=True) + + self.downsample = None + self.stride = stride + + if stride > 1 or inplanes != planes * Bottleneck.expansion: + # downsampling layer is prepended with an avgpool, and the subsequent convolution has stride 1 + self.downsample = nn.Sequential(OrderedDict([ + ("-1", nn.AvgPool2d(stride)), + ("0", nn.Conv2d(inplanes, planes * self.expansion, 1, stride=1, bias=False)), + ("1", nn.BatchNorm2d(planes * self.expansion)) + ])) + + def forward(self, x: torch.Tensor): + identity = x + + out = self.relu1(self.bn1(self.conv1(x))) + out = self.relu2(self.bn2(self.conv2(out))) + out = self.avgpool(out) + out = self.bn3(self.conv3(out)) + + if self.downsample is not None: + identity = self.downsample(x) + + out += identity + out = self.relu3(out) + return out + + +class AttentionPool2d(nn.Module): + def __init__(self, spacial_dim: int, embed_dim: int, num_heads: int, output_dim: int = None): + super().__init__() + self.positional_embedding = nn.Parameter(torch.randn(spacial_dim ** 2 + 1, embed_dim) / embed_dim ** 0.5) + self.k_proj = nn.Linear(embed_dim, embed_dim) + self.q_proj = nn.Linear(embed_dim, embed_dim) + self.v_proj = nn.Linear(embed_dim, embed_dim) + self.c_proj = nn.Linear(embed_dim, output_dim or embed_dim) + self.num_heads = num_heads + + def forward(self, x): + x = x.reshape(x.shape[0], x.shape[1], x.shape[2] * x.shape[3]).permute(2, 0, 1) # NCHW -> (HW)NC + x = torch.cat([x.mean(dim=0, keepdim=True), x], dim=0) # (HW+1)NC + x = x + self.positional_embedding[:, None, :].to(x.dtype) # (HW+1)NC + x, _ = F.multi_head_attention_forward( + query=x, key=x, value=x, + embed_dim_to_check=x.shape[-1], + num_heads=self.num_heads, + q_proj_weight=self.q_proj.weight, + k_proj_weight=self.k_proj.weight, + v_proj_weight=self.v_proj.weight, + in_proj_weight=None, + in_proj_bias=torch.cat([self.q_proj.bias, self.k_proj.bias, self.v_proj.bias]), + bias_k=None, + bias_v=None, + add_zero_attn=False, + dropout_p=0, + out_proj_weight=self.c_proj.weight, + out_proj_bias=self.c_proj.bias, + use_separate_proj_weight=True, + training=self.training, + need_weights=False + ) + + return x[0] + + +class LayerNorm(nn.LayerNorm): + """Subclass torch's LayerNorm to handle fp16.""" + + def forward(self, x: torch.Tensor): + orig_type = x.dtype + layernorm_dtype = self.weight.dtype + ret = super().forward(x.type(layernorm_dtype)) + return ret.type(orig_type) + +class QuickGELU(nn.Module): + def forward(self, x: torch.Tensor): + return x * torch.sigmoid(1.702 * x) + +class ResidualAttentionBlock(nn.Module): + def __init__(self, d_model: int, n_head: int, attn_mask: torch.Tensor = None, use_grad_checkpointing=False): + super().__init__() + + self.attn = nn.MultiheadAttention(d_model, n_head) + self.ln_1 = LayerNorm(d_model) + self.mlp = nn.Sequential(OrderedDict([ + ("c_fc", nn.Linear(d_model, d_model * 4)), + ("gelu", QuickGELU()), + ("c_proj", nn.Linear(d_model * 4, d_model)) + ])) + self.ln_2 = LayerNorm(d_model) + self.attn_mask = attn_mask + + if 
use_grad_checkpointing: + self.attn = checkpoint_wrapper(self.attn) + self.mlp = checkpoint_wrapper(self.mlp) + + def attention(self, x: torch.Tensor): + self.attn_mask = self.attn_mask.to(dtype=x.dtype, device=x.device) if self.attn_mask is not None else None + return self.attn(x, x, x, need_weights=False, attn_mask=self.attn_mask)[0] + + def forward(self, x: torch.Tensor): + x = x + self.attention(self.ln_1(x)) + x = x + self.mlp(self.ln_2(x)) + return x + +class Transformer(nn.Module): + def __init__(self, width: int, layers: int, heads: int, attn_mask: torch.Tensor = None, use_grad_checkpointing=False): + super().__init__() + self.width = width + self.layers = layers + self.resblocks = nn.Sequential(*[ResidualAttentionBlock(width, heads, attn_mask, use_grad_checkpointing and i>12) for i in range(layers)]) + + def forward(self, x: torch.Tensor): + return self.resblocks(x) + +class VisionTransformer(nn.Module): + def __init__(self, input_resolution: int, patch_size: int, width: int, layers: int, heads: int, use_grad_checkpointing: bool): + super().__init__() + self.input_resolution = input_resolution + self.num_features = width + self.num_heads = heads + self.num_patches = (input_resolution // patch_size) ** 2 + self.conv1 = nn.Conv2d(in_channels=3, out_channels=width, kernel_size=patch_size, stride=patch_size, bias=False) + scale = width ** -0.5 + self.class_embedding = nn.Parameter(scale * torch.randn(width)) + self.positional_embedding = nn.Parameter(scale * torch.randn(self.num_patches + 1, width)) + self.ln_pre = LayerNorm(width) + self.transformer = Transformer(width, layers, heads, use_grad_checkpointing=use_grad_checkpointing) + + def forward(self, x: torch.Tensor): + x = self.conv1(x) # shape = [*, width, grid, grid] + x = x.reshape(x.shape[0], x.shape[1], -1) # shape = [*, width, grid ** 2] + x = x.permute(0, 2, 1) # shape = [*, grid ** 2, width] + x = torch.cat([self.class_embedding.to(x.dtype) + torch.zeros(x.shape[0], 1, x.shape[-1], dtype=x.dtype, device=x.device), x], dim=1) # shape = [*, grid ** 2 + 1, width] + x = x + self.positional_embedding.to(x.dtype) + x = self.ln_pre(x) + x = x.permute(1, 0, 2) # NLD -> LND + x = self.transformer(x) + x = x.permute(1, 0, 2) # LND -> NLD + return x diff --git a/starvector/data/augmentation.py b/starvector/data/augmentation.py new file mode 100644 index 0000000000000000000000000000000000000000..e138e2fdbe257a1741a0c397c4c0339323951afb --- /dev/null +++ b/starvector/data/augmentation.py @@ -0,0 +1,250 @@ + +import numpy as np +from svgpathtools import ( + Path, Arc, CubicBezier, QuadraticBezier, + svgstr2paths) +import os +from noise import pnoise1 +import re +import matplotlib.colors as mcolors +from bs4 import BeautifulSoup +from starvector.data.util import rasterize_svg + +class SVGTransforms: + def __init__(self, transformations): + self.transformations = transformations + self.noise_std = self.transformations.get('noise_std', False) + self.noise_type = self.transformations.get('noise_type', False) + self.rotate = self.transformations.get('rotate', False) + self.shift_re = self.transformations.get('shift_re', False) + self.shift_im = self.transformations.get('shift_im', False) + self.scale = self.transformations.get('scale', False) + self.color_noise = self.transformations.get('color_noise', False) + self.p = self.transformations.get('p', 0.5) + self.color_change = self.transformations.get('color_change', False) + self.colors = self.transformations.get('colors', ['#ff0000', '#0000ff', '#000000']) + + def 
sample_transformations(self): + if self.rotate: + a, b = self.rotate['from'], self.rotate['to'] + rotation_angle = np.random.uniform(a, b) + self.rotation_angle = rotation_angle + + if self.shift_re or self.shift_im: + self.shift_real = np.random.uniform(self.shift_re['from'], self.shift_re['to']) + self.shift_imag = np.random.uniform(self.shift_im['from'], self.shift_im['to']) + + if self.scale: + self.scale = np.random.uniform(self.scale['from'], self.scale['to']) + + if self.color_noise: + self.color_noise_std = np.random.uniform(self.color_noise['from'], self.color_noise['to']) + + + def paths2str(self, groupped_paths, svg_opening_tag=''): + + keys_to_exclude = ['d', 'cx', 'cy', 'rx', 'ry'] + all_groups_srt = '' + for group, elements in groupped_paths.items(): + group_attributes, paths_and_attributes = elements.get('attrs', {}), elements.get('paths', []) + group_attr_str = ' '.join(f'{key}="{value}"' for key, value in group_attributes.items()) + path_strings = [] + path_str = '' + for path, attributes in paths_and_attributes: + path_attr_str = '' + d_str = path.d() + + for key, value in attributes.items(): + if key not in keys_to_exclude: + path_attr_str += f' {key}="{value}"' + + path_strings.append(f'') + path_str = "\n".join(path_strings) + if 'no_group'in group: + group_str = path_str + else: + group_str = f'\n{path_str}\n\n' + all_groups_srt += group_str + svg = f'{svg_opening_tag}\n{all_groups_srt}' + return svg + + def add_noise(self, seg): + noise_scale = np.random.uniform(self.noise_std['from'], self.noise_std['to']) + if self.noise_type == 'gaussian': + noise_sample = np.random.normal(loc=0.0, scale=noise_scale) + \ + 1j * np.random.normal(loc=0.0, scale=noise_scale) + elif self.noise_type == 'perlin': + noise_sample = complex(pnoise1(np.random.random(), octaves=2), pnoise1(np.random.random(), octaves=2))*noise_scale + + if isinstance(seg, CubicBezier): + seg.control1 = seg.control1 + noise_sample + seg.control2 = seg.control2 + noise_sample + elif isinstance(seg, QuadraticBezier): + seg.control = seg.control + noise_sample + elif isinstance(seg, Arc): + seg.radius = seg.radius + noise_sample + + + return seg + + def do_rotate(self, path, viewbox_width, viewbox_height): + if self.rotate: + new_path = path.rotated(self.rotation_angle, complex(viewbox_width/2, viewbox_height/2)) + return new_path + else: + return path + + def do_shift(self, path): + if self.shift_re or self.shift_im: + return path.translated(complex(self.shift_real, self.shift_imag)) + else: + return path + + def do_scale(self, path): + if self.scale: + return path.scaled(self.scale) + else: + return path + + def add_color_noise(self, source_color): + # Convert color to RGB + if source_color.startswith("#"): + base_color = mcolors.hex2color(source_color) + else: + base_color = mcolors.hex2color(mcolors.CSS4_COLORS.get(source_color, '#FFFFFF')) + + # Add noise to each RGB component + noise = np.random.normal(0, self.color_noise_std, 3) + noisy_color = np.clip(np.array(base_color) + noise, 0, 1) + + # Convert the RGB color back to hex + hex_color = mcolors.rgb2hex(noisy_color) + + return hex_color + + def do_color_change(self, attr): + if 'fill' in attr: + if self.color_noise or self.color_change: + fill_value = attr['fill'] + if fill_value == 'none': + new_fill_value = 'none' + else: + if self.color_noise: + new_fill_value = self.add_color_noise(fill_value) + elif self.color_change: + new_fill_value = np.random.choice(self.colors) + attr['fill'] = new_fill_value + return attr + + def clean_attributes(self, 
attr): + attr_out = {} + if 'fill' in attr: + attr_out = attr + elif 'style' in attr: + fill_values = re.findall('fill:[^;]+', attr['style']) + if fill_values: + fill_value = fill_values[0].replace('fill:', '').strip() + attr_out['fill'] = fill_value + else: + attr_out = attr + else: + attr_out = attr + + return attr_out + + def get_viewbox_size(self, svg): + # Try to extract viewBox attribute + match = re.search(r'viewBox="([^"]+)"', svg) + if match: + viewbox = match.group(1) + else: + # If viewBox is not found, try to extract width and height attributes + match = re.search(r'width="([^"]+)px" height="([^"]+)px"', svg) + if match: + width, height = match.groups() + viewbox = f"0 0 {width} {height}" + else: + viewbox = "0 0 256 256" # Default if neither viewBox nor width/height are found + + viewbox = [float(x) for x in viewbox.split()] + viewbox_width, viewbox_height = viewbox[2], viewbox[3] + return viewbox_width, viewbox_height + + def augment(self, svg): + if os.path.isfile(svg): + # open svg file + with open(svg, 'r') as f: + svg = f.read() + + # Sample transformations for this sample + self.sample_transformations() + + + # Parse the SVG content + soup = BeautifulSoup(svg, 'xml') + + # Get opening tag + svg_opening_tag = re.findall(']+>', svg)[0] + + viewbox_width, viewbox_height = self.get_viewbox_size(svg) + + # Get all svg parents + groups = soup.findAll() + + # Create the groups of paths based on their original tag + grouped_paths = {} + for i, g in enumerate(groups): + if g.name == 'g': + group_id = group_id = g.get('id') if g.get('id') else f'none_{i}' + group_attrs = g.attrs + + elif g.name == 'svg' or g.name == 'metadata' or g.name == 'defs': + continue + + else: + group_id = f'no_group_{i}' + group_attrs = {} + + group_svg_string = f'{svg_opening_tag}{str(g)}' + try: + paths, attributes = svgstr2paths(group_svg_string) + except: + return svg, rasterize_svg(svg) + if not paths: + continue + + paths_and_attributes = [] + + # Rotation, shift, scale, noise addition + new_paths = [] + new_attributes = [] + for path, attribute in zip(paths, attributes): + attr = self.clean_attributes(attribute) + + new_path = self.do_rotate(path, viewbox_width, viewbox_height) + new_path = self.do_shift(new_path) + new_path = self.do_scale(new_path) + + if self.noise_std: + # Add noise to path to deform svg + noisy_path = [] + for seg in new_path: + noisy_seg = self.add_noise(seg) + noisy_path.append(noisy_seg) + new_paths.append(Path(*noisy_path)) + else: + new_paths.append(new_path) + + # Color change + attr = self.do_color_change(attr) + paths_and_attributes.append((new_path, attr)) + + grouped_paths[group_id] = { + 'paths': paths_and_attributes, + 'attrs': group_attrs + } + + svg = self.paths2str(grouped_paths, svg_opening_tag) + image = rasterize_svg(svg) + + return svg, image diff --git a/starvector/data/base.py b/starvector/data/base.py new file mode 100644 index 0000000000000000000000000000000000000000..33fee34512badbba8aef048a20651c8418e5a0c1 --- /dev/null +++ b/starvector/data/base.py @@ -0,0 +1,71 @@ +from torch.utils.data import Dataset +from starvector.data.util import ImageTrainProcessor, use_placeholder, rasterize_svg +from starvector.util import instantiate_from_config +import numpy as np +from datasets import load_dataset + +class SVGDatasetBase(Dataset): + def __init__(self, dataset_name, split, im_size, num_samples=-1, **kwargs): + self.split = split + self.im_size = im_size + + transforms = kwargs.get('transforms', False) + if transforms: + self.transforms = 
instantiate_from_config(transforms) + self.p = self.transforms.p + else: + self.transforms = None + self.p = 0.0 + + normalization = kwargs.get('normalize', False) + if normalization: + mean = tuple(normalization.get('mean', None)) + std = tuple(normalization.get('std', None)) + else: + mean = None + std = None + + self.processor = ImageTrainProcessor(size=self.im_size, mean=mean, std=std) + self.data = load_dataset(dataset_name, split=split) + + print(f"Loaded {len(self.data)} samples from {dataset_name} {split} split") + + def __len__(self): + return len(self.data_json) + + def get_svg_and_image(self, svg_str, sample_id): + do_augment = np.random.choice([True, False], p=[self.p, 1 - self.p]) + svg, image = None, None + + # Try to augment the image if conditions are met + if self.transforms is not None and do_augment: + try: + svg, image = self.transforms.augment(svg_str) + except Exception as e: + print(f"Error augmenting {sample_id} due to {str(e)}, trying to rasterize SVG") + + # If augmentation failed or wasn't attempted, try to rasterize the SVG + if svg is None or image is None: + try: + svg, image = svg_str, rasterize_svg(svg_str, self.im_size) + except Exception as e: + print(f"Error rasterizing {sample_id} due to {str(e)}, using placeholder image") + svg = use_placeholder() + image = rasterize_svg(svg, self.im_size) + + # If the image is completely white, use a placeholder image + if np.array(image).mean() == 255.0: + print(f"Image is full white, using placeholder image for {sample_id}") + svg = use_placeholder() + image = rasterize_svg(svg) + + # Process the image + if 'siglip' in self.image_processor: + image = self.processor(image).pixel_values[0] + else: + image = self.processor(image) + + return svg, image + + def __getitem__(self, idx): + raise NotImplementedError("This method should be implemented by subclasses") diff --git a/starvector/data/dataset.py b/starvector/data/dataset.py new file mode 100644 index 0000000000000000000000000000000000000000..f58c5c1d937856f58958321cee26d09439ef68e7 --- /dev/null +++ b/starvector/data/dataset.py @@ -0,0 +1,42 @@ +import os +from starvector.data.base import SVGDatasetBase +from starvector.data.augmentation import SVGTransforms +from starvector.data.util import ImageTrainProcessor +from transformers import AutoProcessor + +class SVGDataset(SVGDatasetBase): + def __init__(self, dataset_name, split, im_size, num_samples=None, **kwargs): + super().__init__(dataset_name, split, im_size, num_samples, **kwargs) + + self.color_changer = SVGTransforms({'color_change' : True, 'colors' : ['#ff0000', '#0000ff', '#00ff00', '#ffff00', '#000000']}) + select_dataset_name = kwargs.get('select_dataset_name', False) + + if select_dataset_name: + self.data = self.data.filter(lambda example: example["model_name"]==select_dataset_name) + + self.num_samples = num_samples + if self.num_samples != -1: + self.data = self.data.select(range(self.num_samples)) + + self.image_processor = kwargs.get('image_processor', None) + if 'siglip' in self.image_processor: + model_name = {'siglip_512': 'google/siglip-base-patch16-512', + 'siglip_384': 'google/siglip-large-patch16-384', + 'siglip_256': 'google/siglip-base-patch16-256'}[self.image_processor] + self.processor = AutoProcessor.from_pretrained(model_name).image_processor + else: + self.processor = ImageTrainProcessor(size=self.im_size) + def __len__(self): + return len(self.data) + + def __getitem__(self, idx): + svg_str = self.data[idx]['Svg'] + sample_id = self.data[idx]['Filename'] + svg, image = 
self.get_svg_and_image(svg_str, sample_id) + caption = self.data[idx].get('Caption', "") + return { + 'svg': svg, + 'image': image, + 'id': sample_id, + 'caption': caption + } \ No newline at end of file diff --git a/starvector/data/emojisvg.py b/starvector/data/emojisvg.py new file mode 100644 index 0000000000000000000000000000000000000000..fbf867d1ed188765836f8ebac1d7df714f041834 --- /dev/null +++ b/starvector/data/emojisvg.py @@ -0,0 +1,27 @@ +import os +from starvector.data.base import SVGDatasetBase + + +class EmojiSVGDataset(SVGDatasetBase): + def __init__(self, dataset_name, split, im_size, num_samples=None, **kwargs): + super().__init__(dataset_name, split, im_size, **kwargs) + + self.num_samples = num_samples + if self.num_samples != -1: + self.data = self.data.select(range(self.num_samples)) + + def __len__(self): + return len(self.data) + + def __getitem__(self, idx): + + svg_str = self.data[idx]['Svg'] + sample_id = self.data[idx]['Filename'] + svg, image = self.get_svg_and_image(svg_str, sample_id) + caption = self.data[idx].get('Caption', "") + return { + 'svg': svg, + 'image': image, + 'id': sample_id, + 'caption': caption + } \ No newline at end of file diff --git a/starvector/data/figrsvg.py b/starvector/data/figrsvg.py new file mode 100644 index 0000000000000000000000000000000000000000..32280a4d0383a06863ceb036e0287974f5d8d36a --- /dev/null +++ b/starvector/data/figrsvg.py @@ -0,0 +1,27 @@ +import os +from starvector.data.base import SVGDatasetBase +from transformers import AutoProcessor +from starvector.data.util import ImageTrainProcessor + +class FigrSVGDataset(SVGDatasetBase): + def __init__(self, dataset_name, split, im_size, num_samples=-1, **kwargs): + super().__init__(dataset_name, split, im_size, **kwargs) + + self.num_samples = num_samples + if self.num_samples != -1: + self.data = self.data.select(range(self.num_samples)) + + def __len__(self): + return len(self.data) + + def __getitem__(self, idx): + svg_str = self.data[idx]['Svg'] + sample_id = self.data[idx]['Id'] + svg, image = self.get_svg_and_image(svg_str, sample_id) + caption = self.data[idx].get('Caption', "") + return { + 'svg': svg, + 'image': image, + 'id': sample_id, + 'caption': caption + } diff --git a/starvector/data/fontsvg.py b/starvector/data/fontsvg.py new file mode 100644 index 0000000000000000000000000000000000000000..1c90da9906f9af03b856fb751bdf879e1bad4401 --- /dev/null +++ b/starvector/data/fontsvg.py @@ -0,0 +1,28 @@ +import os +from starvector.data.base import SVGDatasetBase +from transformers import AutoProcessor +from starvector.data.util import ImageTrainProcessor + +class FontSVGDataset(SVGDatasetBase): + def __init__(self, dataset_name, split, im_size, num_samples=-1, **kwargs): + super().__init__(dataset_name, split, im_size, **kwargs) + + self.num_samples = num_samples + if self.num_samples != -1: + self.data = self.data.select(range(self.num_samples)) + + def __len__(self): + return len(self.data) + + def __getitem__(self, idx): + + svg_str = self.data[idx]['Svg'] + sample_id = self.data[idx]['Filename'] + svg, image = self.get_svg_and_image(svg_str, sample_id) + caption = self.data[idx].get('Caption', "") + return { + 'svg': svg, + 'image': image, + 'id': sample_id, + 'caption': caption + } diff --git a/starvector/data/iconsvg.py b/starvector/data/iconsvg.py new file mode 100644 index 0000000000000000000000000000000000000000..45881997dd4f308e75035b5b7db0d4ad8a37ab03 --- /dev/null +++ b/starvector/data/iconsvg.py @@ -0,0 +1,38 @@ +import os +from starvector.data.base import 
SVGDatasetBase +from starvector.data.util import ImageTrainProcessor +from transformers import AutoProcessor + +class SVGIconsDataset(SVGDatasetBase): + def __init__(self, dataset_name, split, im_size, num_samples=-1, **kwargs): + super().__init__(dataset_name, split, im_size, **kwargs) + + self.num_samples = num_samples + if self.num_samples != -1: + self.data = self.data.select(range(self.num_samples)) + + self.image_processor = kwargs.get('image_processor', None) + if 'siglip' in self.image_processor: + model_name = {'siglip_512': 'google/siglip-base-patch16-512', + 'siglip_384': 'google/siglip-large-patch16-384', + 'siglip_256': 'google/siglip-base-patch16-256'}[self.image_processor] + self.processor = AutoProcessor.from_pretrained(model_name).image_processor + else: + self.processor = ImageTrainProcessor(size=self.im_size) + + + def __len__(self): + return len(self.data) + + def __getitem__(self, idx): + + svg_str = self.data[idx]['Svg'] + sample_id = self.data[idx]['Filename'] + svg, image = self.get_svg_and_image(svg_str, sample_id) + caption = self.data[idx].get('Caption', "") + return { + 'svg': svg, + 'image': image, + 'id': sample_id, + 'caption': caption + } diff --git a/starvector/data/stacksvg.py b/starvector/data/stacksvg.py new file mode 100644 index 0000000000000000000000000000000000000000..23d3ea76b847a437d84855b0aa9f1721dfd82c42 --- /dev/null +++ b/starvector/data/stacksvg.py @@ -0,0 +1,59 @@ +import os +from starvector.data.base import SVGDatasetBase +from starvector.data.augmentation import SVGTransforms +import random +from transformers import AutoProcessor +from starvector.data.util import ImageTrainProcessor + +text2svg_captions = [ + "Draw an SVG of ", + "Draw an SVG image of ", + "Draw an SVG picture of ", + "Generate an SVG of ", + "Create an SVG of ", + "Design an SVG of ", + "Make an SVG of ", +] + +class SVGStackDataset(SVGDatasetBase): + def __init__(self, dataset_name, split, im_size, num_samples=-1, **kwargs): + super().__init__(dataset_name, split, im_size, num_samples, **kwargs) + self.color_changer = SVGTransforms({'color_change' : True, 'colors' : ['#ff0000', '#0000ff', '#00ff00', '#ffff00', '#000000']}) + + # Text2SVG specific + self.random_caption = kwargs.get('random_caption', True) + select_dataset_name = kwargs.get('select_dataset_name', False) + if select_dataset_name: + self.data = self.data.filter(lambda example: example["model_name"]==select_dataset_name) + + self.num_samples = num_samples + if self.num_samples != -1: + self.data = self.data.select(range(self.num_samples)) + + self.image_processor = kwargs.get('image_processor', None) + if self.image_processor and 'siglip' in self.image_processor: + model_name = {'siglip_512': 'google/siglip-base-patch16-512', + 'siglip_384': 'google/siglip-large-patch16-384', + 'siglip_256': 'google/siglip-base-patch16-256'}[self.image_processor] + self.processor = AutoProcessor.from_pretrained(model_name).image_processor + else: + self.processor = ImageTrainProcessor(size=self.im_size) + + + def __len__(self): + return len(self.data) + + def __getitem__(self, idx): + svg_str = self.data[idx]['Svg'] + sample_id = self.data[idx]['Filename'] + svg, image = self.get_svg_and_image(svg_str, sample_id) + + # Randomly choose between 'caption_blip' and 'caption_llava' + caption_column = random.choice(['caption_blip2', 'caption_llava']) + caption = random.choice(text2svg_captions) + self.data[idx].get(caption_column, "") + return { + 'svg': svg, + 'image': image, + 'id': sample_id, + 'caption': caption, + } diff --git 
a/starvector/data/util.py b/starvector/data/util.py new file mode 100644 index 0000000000000000000000000000000000000000..48635a6e01e7ab9e22e2c98fd328bcf478e0d7f8 --- /dev/null +++ b/starvector/data/util.py @@ -0,0 +1,389 @@ +from PIL import Image +from torchvision import transforms +from torchvision.transforms.functional import InterpolationMode, pad +import numpy as np +import matplotlib.pyplot as plt +from bs4 import BeautifulSoup +import re +from svgpathtools import svgstr2paths +import numpy as np +from PIL import Image +import cairosvg +from io import BytesIO +import numpy as np +import textwrap +import os +import base64 +import io + + + +CIRCLE_SVG = "" +VOID_SVF = "" + +def load_transforms(): + transforms = { + 'train': None, + 'eval': None + } + return transforms + +class ImageBaseProcessor(): + def __init__(self, mean=None, std=None): + if mean is None: + mean = (0.48145466, 0.4578275, 0.40821073) + if std is None: + std = (0.26862954, 0.26130258, 0.27577711) + + self.normalize = transforms.Normalize(mean=mean, std=std) + +class ImageTrainProcessor(ImageBaseProcessor): + def __init__(self, mean=None, std=None, size=224, **kwargs): + super().__init__(mean, std) + + self.size = size + + self.transform = transforms.Compose([ + transforms.Lambda(lambda img: self._rgba_to_rgb_white(img) if img.mode == "RGBA" else img), + transforms.Lambda(lambda img: self._pad_to_square(img)), + transforms.Resize(self.size, interpolation=InterpolationMode.BICUBIC), + transforms.ToTensor(), + self.normalize + ]) + + def __call__(self, item): + return self.transform(item) + + def _pad_to_square(self, img): + # Calculate padding to make the image square + width, height = img.size + max_dim = max(width, height) + padding = [(max_dim - width) // 2, (max_dim - height) // 2] + padding += [max_dim - width - padding[0], max_dim - height - padding[1]] + return pad(img, padding, fill=255) # Assuming white padding + + def _rgba_to_rgb_white(self, img): + background = Image.new("RGB", img.size, (255, 255, 255)) + background.paste(img, mask=img.split()[3]) + return background + + +def encode_image_base64(pil_image): + if pil_image.mode == 'RGBA': + pil_image = pil_image.convert('RGB') # Convert RGBA to RGB + buffered = io.BytesIO() + pil_image.save(buffered, format="JPEG") + base64_image = base64.b64encode(buffered.getvalue()).decode("utf-8") + return base64_image + +# -------------- Generation utils -------------- +def is_valid_svg(svg_text): + try: + svgstr2paths(svg_text) + return True + except Exception as e: + print(f"Invalid SVG: {str(e)}") + return False + +def clean_svg(svg_text, output_width=None, output_height=None): + soup = BeautifulSoup(svg_text, 'xml') # Read as soup to parse as xml + svg_bs4 = soup.prettify() # Prettify to get a string + + # Store the original signal handler + import signal + original_handler = signal.getsignal(signal.SIGALRM) + + try: + # Set a timeout to prevent hanging + def timeout_handler(signum, frame): + raise TimeoutError("SVG processing timed out") + + # Set timeout + signal.signal(signal.SIGALRM, timeout_handler) + signal.alarm(5) + + # Try direct conversion without BeautifulSoup + svg_cairo = cairosvg.svg2svg(svg_bs4, output_width=output_width, output_height=output_height).decode() + + except TimeoutError: + print("SVG conversion timed out, using fallback method") + svg_cairo = """""" + finally: + # Always cancel the alarm and restore original handler, regardless of success or failure + signal.alarm(0) + signal.signal(signal.SIGALRM, original_handler) + + svg_clean = 
"\n".join([line for line in svg_cairo.split("\n") if not line.strip().startswith("]*\/>" + all_tags = re.findall(all_tags_pattern, svg_content) + self_closing_matches = re.findall(self_closing_pattern, svg_content) + self_closing_tags = [] + + for match in self_closing_matches: + tag = re.search(all_tags_pattern, match) + if tag: + self_closing_tags.append(tag.group(1)) + unclosed_tags = [] + + for tag in all_tags: + if all_tags.count(tag) > self_closing_tags.count(tag) + svg_content.count('' + tag + '>'): + unclosed_tags.append(tag) + unclosed_tags = list(dict.fromkeys(unclosed_tags)) + + return unclosed_tags + + +# -------------- Plotting utils -------------- +def plot_images_side_by_side_with_metrics(image1, image2, l2_dist, CD, post_processed, out_path): + array1 = np.array(image1).astype(np.float32) + array2 = np.array(image2).astype(np.float32) + diff = np.abs(array1 - array2).astype(np.uint8) + + fig, axes = plt.subplots(1, 3, figsize=(10, 5)) + axes[0].imshow(image1) + axes[0].set_title('generated_svg') + axes[0].axis('off') + axes[1].imshow(image2) + axes[1].set_title('gt') + axes[1].axis('off') + axes[2].imshow(diff) + axes[2].set_title('Difference') + axes[2].axis('off') + plt.suptitle(f"MSE: {l2_dist:.4f}, CD: {CD:.4f}, post-processed: {str(post_processed)}", fontsize=16, y=1.05) + plt.savefig(out_path, bbox_inches='tight', pad_inches=0.1) + image = Image.open(out_path) + plt.close(fig) + return image + +def plot_images_side_by_side(image1, image2, out_path): + array1 = np.array(image1).astype(np.float32) + array2 = np.array(image2).astype(np.float32) + diff = np.abs(array1 - array2).astype(np.uint8) + + fig, axes = plt.subplots(1, 3, figsize=(10, 5)) + axes[0].imshow(image1) + axes[0].set_title('generated_svg') + axes[0].axis('off') + axes[1].imshow(image2) + axes[1].set_title('gt') + axes[1].axis('off') + axes[2].imshow(diff) + axes[2].set_title('Difference') + axes[2].axis('off') + plt.savefig(out_path, bbox_inches='tight', pad_inches=0.1) + image = Image.open(out_path) + plt.close(fig) + return image + +def plot_images_side_by_side_temperatures(samples_temp, metrics, sample_dir, outpath_filename): + # Create a plot with the original image and different temperature results + num_temps = len(samples_temp) + fig, axes = plt.subplots(2, num_temps + 1, figsize=(15, 4), gridspec_kw={'height_ratios': [10, 2]}) + + # Plot the original image + gt_image_path = os.path.join(sample_dir, f'temp_{list(samples_temp.keys())[0]}', f'{outpath_filename}_or.png') + gt_image = Image.open(gt_image_path) + axes[0, 0].imshow(gt_image) + axes[0, 0].set_title('Original') + axes[0, 0].axis('off') + axes[1, 0].text(0.5, 0.5, 'Original', horizontalalignment='center', verticalalignment='center', fontsize=16) + axes[1, 0].axis('off') + + # Plot the generated images for different temperatures and metrics + for idx, (temp, sample) in enumerate(samples_temp.items()): + gen_image_path = os.path.join(sample_dir, f'temp_{temp}', f'{outpath_filename}.png') + gen_image = Image.open(gen_image_path) + axes[0, idx + 1].imshow(gen_image) + axes[0, idx + 1].set_title(f'Temp {temp}') + axes[0, idx + 1].axis('off') + axes[1, idx + 1].text(0.5, 0.5, f'MSE: {metrics[temp]["mse"]:.2f}\nCD: {metrics[temp]["cd"]:.2f}', + horizontalalignment='center', verticalalignment='center', fontsize=12) + axes[1, idx + 1].axis('off') + + # Save the comparison plot + comparison_path = os.path.join(sample_dir, f'{outpath_filename}_comparison.png') + plt.tight_layout() + plt.savefig(comparison_path) + plt.close() + +def 
plot_images_and_prompt(prompt, svg_raster, gt_svg_raster, out_path): + # First col shows caption, second col shows generated svg, third col shows gt svg + fig, axes = plt.subplots(1, 3, figsize=(10, 5)) + + # Split the prompt into multiple lines if it exceeds a certain length + prompt_lines = textwrap.wrap(prompt, width=30) + prompt_text = '\n'.join(prompt_lines) + + # Display the prompt in the first cell + axes[0].text(0, 0.5, prompt_text, fontsize=12, ha='left', wrap=True) + axes[0].axis('off') + axes[1].imshow(svg_raster) + axes[1].set_title('generated_svg') + axes[1].axis('off') + axes[2].imshow(gt_svg_raster) + axes[2].set_title('gt') + axes[2].axis('off') + plt.savefig(out_path, bbox_inches='tight', pad_inches=0.1) + image = Image.open(out_path) + plt.close(fig) + return image + +def plot_images_and_prompt_with_metrics(prompt, svg_raster, gt_svg_raster, clip_score, post_processed, out_path): + # First col shows caption, second col shows generated svg, third col shows gt svg + fig, axes = plt.subplots(1, 3, figsize=(10, 5)) + + # Split the prompt into multiple lines if it exceeds a certain length + prompt_lines = textwrap.wrap(prompt, width=30) + prompt_text = '\n'.join(prompt_lines) + + # Display the prompt in the first cell + axes[0].text(0, 0.5, prompt_text, fontsize=12, ha='left', wrap=True) + axes[0].axis('off') + axes[1].imshow(svg_raster) + axes[1].set_title('generated_svg') + axes[1].axis('off') + axes[2].imshow(gt_svg_raster) + axes[2].set_title('gt') + axes[2].axis('off') + plt.suptitle(f"CLIP Score: {clip_score:.4f}, post-processed: {str(post_processed)}", fontsize=16, y=1.05) + plt.savefig(out_path, bbox_inches='tight', pad_inches=0.1) + image = Image.open(out_path) + plt.close(fig) + return image + +def plot_images_and_prompt_temperatures(prompt, samples_temp, metrics, sample_dir, outpath_filename): + # Calculate the number of temperature variations + num_temps = len(samples_temp) + + # Create a plot with text, the original image, and different temperature results + fig, axes = plt.subplots(1, num_temps + 2, figsize=(5 + 3 * (num_temps + 1), 6)) + + # Split the prompt into multiple lines if it exceeds a certain length + prompt_lines = textwrap.wrap(prompt, width=30) + prompt_text = '\n'.join(prompt_lines) + + # Display the prompt in the first cell + axes[0].text(0, 0.5, prompt_text, fontsize=12, ha='left', wrap=True) + axes[0].axis('off') + + # Plot the GT (ground truth) image in the second cell + gt_image_path = os.path.join(sample_dir, f'temp_{list(samples_temp.keys())[0]}', f'{outpath_filename}_or.png') + gt_image = Image.open(gt_image_path) + axes[1].imshow(gt_image) + axes[1].set_title('GT Image') + axes[1].axis('off') + + # Plot the generated images for different temperatures and display metrics + for idx, (temp, sample) in enumerate(samples_temp.items()): + gen_image_path = os.path.join(sample_dir, f'temp_{temp}', f'{outpath_filename}.png') + gen_image = Image.open(gen_image_path) + axes[idx + 2].imshow(gen_image) + axes[idx + 2].set_title(f'Temp {temp}') + axes[idx + 2].axis('off') + clip_score = metrics[temp]["clip_score"] + axes[idx + 2].text(0.5, -0.1, f'CLIP: {clip_score:.4f}', horizontalalignment='center', verticalalignment='center', fontsize=12, transform=axes[idx + 2].transAxes) + + # Save the comparison plot + comparison_path = os.path.join(sample_dir, f'{outpath_filename}_comparison.png') + plt.tight_layout() + plt.savefig(comparison_path) + plt.close() + + return comparison_path + + +def plot_image_tensor(image): + import numpy as np + from PIL import 
Image + tensor = image[0].cpu().float() + tensor = tensor.permute(1, 2, 0) + array = (tensor.numpy() * 255).astype(np.uint8) + im = Image.fromarray(array) + im.save("tmp/output_image.jpg") + + +def plot_grid_samples(images, num_cols=5, out_path = 'grid.png'): + # Calculate the number of rows required for the grid + num_images = len(images) + num_rows = (num_images + num_cols - 1) // num_cols + + # Create a new figure + fig, axes = plt.subplots(num_rows, num_cols, figsize=(12, 8)) + + # Loop through the image files and plot them + for i, image in enumerate(images): + row = i // num_cols + col = i % num_cols + + # Open and display the image using Pillow + if type(image) == str: + img = Image.open(image) + else: + img = image + axes[row, col].imshow(img) + # axes[row, col].set_title(os.path.basename(image_file)) + axes[row, col].axis('off') + + # Remove empty subplots + for i in range(num_images, num_rows * num_cols): + row = i // num_cols + col = i % num_cols + fig.delaxes(axes[row, col]) + + # Adjust spacing between subplots + plt.tight_layout() + + # save image + plt.savefig(out_path, dpi=300) + image = Image.open(out_path) + plt.close(fig) + + return image \ No newline at end of file diff --git a/starvector/image_encoder.py b/starvector/image_encoder.py new file mode 100644 index 0000000000000000000000000000000000000000..3f85a4cad987d91cf47285c9dd3099e6c7d6b24d --- /dev/null +++ b/starvector/image_encoder.py @@ -0,0 +1,119 @@ +import os +import torch +import torch.nn as nn +import os +from omegaconf import OmegaConf +from starvector.model.image_encoder.clip_model import convert_weights_to_precision +from starvector.data.util import ImageTrainProcessor + +class ImageEncoder(nn.Module): + def __init__(self, config, **kwargs): + super(ImageEncoder, self).__init__() + + image_size = config.image_size + torch_dtype = kwargs.get('model_precision', config.torch_dtype) + self.image_encoder_type = config.image_encoder_type + if self.image_encoder_type == 'clip': + self.visual_encoder, self.ln_vision = self.build_clip_encoder(image_size=image_size) + convert_weights_to_precision(self, torch_dtype) + self.processor = ImageTrainProcessor(size=config.image_size) + + elif self.image_encoder_type == 'vqgan': + self.visual_encoder = self.build_vqgan_encoder() + self.ln_vision = None + self.processor = ImageTrainProcessor(size=config.image_size) + + elif self.image_encoder_type == 'convnext': + self.visual_encoder = self.build_vqgan_encoder() + self.ln_vision = None + self.processor = ImageTrainProcessor(size=config.image_size) + + elif 'siglip' in self.image_encoder_type: + if self.image_encoder_type == 'siglip_512': + model_name = "google/siglip-base-patch16-512" + elif self.image_encoder_type == 'siglip_384': + model_name = "google/siglip-large-patch16-384" + elif self.image_encoder_type == 'siglip_256': + model_name = "google/siglip-base-patch16-256" + + from transformers import AutoProcessor, AutoModel + + self.visual_encoder = AutoModel.from_pretrained( + model_name, torch_dtype = torch_dtype + ).vision_model + + self.processor = AutoProcessor.from_pretrained( + model_name, torch_dtype = torch_dtype + ) + + def build_clip_encoder(self, image_size): + from starvector.model.image_encoder.clip_model import VisionTransformer, LayerNorm + visual_encoder = VisionTransformer( + input_resolution=image_size, + patch_size=14, + width=1024, + layers=23, + heads=16, + use_grad_checkpointing=False) + + ln_vision = LayerNorm(visual_encoder.num_features) + return visual_encoder, ln_vision + + def 
build_vqgan_encoder(self): + from taming.modules.diffusionmodules.model import Encoder + VQGAN_CHECKPOINT = "/path/to/vqgan_checkpoint" # You can download the checkpoint from https://github.com/EleutherAI/vqgan-clip/blob/main/README.md + vqgan_chkp_path = VQGAN_CHECKPOINT + files_in_directory = os.listdir(vqgan_chkp_path + '/configs') + vqgan_config_file = [file for file in files_in_directory if file.endswith('project.yaml')][0] + vqgan_config = OmegaConf.load(os.path.join(vqgan_chkp_path, 'configs', vqgan_config_file)) + visual_encoder = Encoder(**vqgan_config.model.params.ddconfig) + + # Load checkpoint weights + checkpoint = torch.load(os.path.join(vqgan_chkp_path, 'checkpoints', 'last.ckpt'))['state_dict'] + + # Create a new state_dict with modified keys + new_state_dict = {} + for key, value in checkpoint.items(): + if key.startswith('encoder.'): + new_key = key[len('encoder.'):] + new_state_dict[new_key] = value + + # Load weights + visual_encoder.load_state_dict(new_state_dict) + return visual_encoder + + def build_convnext_encoder(self): + import open_clip + model, _, _ = open_clip.create_model_and_transforms('convnext_base_w', pretrained='laion2b_s13b_b82k') + return model.visual + + def forward(self, image): + if self.image_encoder_type == 'clip': + embeds = self.visual_encoder(image) + out = self.ln_vision(embeds) + elif self.image_encoder_type == 'open-clip': + out = self.visual_encoder(image)[1] + out = self.ln_vision(out) + elif self.image_encoder_type == 'vqgan': + out = self.visual_encoder(image) + size = out.size() + out = out.view(size[0], size[1], -1) + out = out.permute(0, 2, 1) + elif self.image_encoder_type == 'convnext': + out = self.visual_encoder.trunk.forward_features(image) + size = out.size() + out = out.view(size[0], size[1], -1) + out = out.permute(0, 2, 1) + elif 'siglip' in self.image_encoder_type: + out = self.visual_encoder(image)["last_hidden_state"] + return out + + def process_images(self, images): + if self.image_encoder_type == 'clip': + res = [] + for image in images: + res.append(self.processor(image).unsqueeze(0)) # B, 3, H, W + return res + else: + return self.processor(images=images, return_tensors="pt").pixel_values.unsqueeze(0) + \ No newline at end of file diff --git a/starvector/metrics/base_metric.py b/starvector/metrics/base_metric.py new file mode 100644 index 0000000000000000000000000000000000000000..d8d07d8472bc70220d8bc2188d41dcb20d8e92a1 --- /dev/null +++ b/starvector/metrics/base_metric.py @@ -0,0 +1,51 @@ +from starvector.metrics.util import AverageMeter +from tqdm import tqdm +import math + +class BaseMetric: + def __init__(self): + self.meter = AverageMeter() + + def reset(self): + self.meter.reset() + + def calculate_score(self, batch, update=True): + """ + Batch: {"gt_im": [PIL Image], "gen_im": [Image]} + """ + values = [] + batch_size = len(next(iter(batch.values()))) + for index in tqdm(range(batch_size)): + kwargs = {} + for key in ["gt_im", "gen_im", "gt_svg", "gen_svg", "caption"]: + if key in batch: + kwargs[key] = batch[key][index] + try: + measure = self.metric(**kwargs) + except Exception as e: + print("Error calculating metric: {}".format(e)) + continue + if math.isnan(measure): + continue + values.append(measure) + + if not values: + print("No valid values found for metric calculation.") + return float("nan") + + score = sum(values) / len(values) + if update: + self.meter.update(score, len(values)) + return self.meter.avg, values + else: + return score, values + + def metric(self, **kwargs): + """ + This method should 
be overridden by subclasses to provide the specific metric computation. + """ + raise NotImplementedError("The metric method must be implemented by subclasses.") + + def get_average_score(self): + return self.meter.avg + diff --git a/starvector/metrics/compute_LPIPS.py b/starvector/metrics/compute_LPIPS.py new file mode 100644 index 0000000000000000000000000000000000000000..b30c42cfdf21febf2b30a83dd51d690423294321 --- /dev/null +++ b/starvector/metrics/compute_LPIPS.py @@ -0,0 +1,56 @@ +from torchvision.transforms import ToTensor, Normalize +import torch +from torch.utils.data import DataLoader +from starvector.metrics.base_metric import BaseMetric +import lpips +from tqdm import tqdm + + +class LPIPSDistanceCalculator(BaseMetric): + def __init__(self, config=None, device='cuda'): + super().__init__() + self.class_name = self.__class__.__name__ + self.config = config + self.model = lpips.LPIPS(net='vgg').to(device) + self.metric = self.LPIPS + self.to_tensor = ToTensor() + self.normalize = Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]) + self.device = device + + def LPIPS(self, tensor_image1, tensor_image2): + tensor_image1, tensor_image2 = tensor_image1.to(self.device), tensor_image2.to(self.device) + return self.model(tensor_image1, tensor_image2) + + def to_tensor_transform(self, pil_img): + return self.normalize(self.to_tensor(pil_img)) + + def collate_fn(self, batch): + gt_imgs, gen_imgs = zip(*batch) + tensor_gt_imgs = torch.stack([self.to_tensor_transform(img) for img in gt_imgs]) + tensor_gen_imgs = torch.stack([self.to_tensor_transform(img) for img in gen_imgs]) + return tensor_gt_imgs, tensor_gen_imgs + + def calculate_score(self, batch, batch_size=8, update=True): + gt_images = batch['gt_im'] + gen_images = batch['gen_im'] + + # Create DataLoader with custom collate function + data_loader = DataLoader(list(zip(gt_images, gen_images)), batch_size=batch_size, collate_fn=self.collate_fn, shuffle=False) + + values = [] + for tensor_gt_batch, tensor_gen_batch in tqdm(data_loader): + # Compute LPIPS + lpips_values = self.LPIPS(tensor_gt_batch, tensor_gen_batch) + values.extend([lpips_values.squeeze().cpu().detach().tolist()] if lpips_values.numel() == 1 else lpips_values.squeeze().cpu().detach().tolist()) + + if not values: + print("No valid values found for metric calculation.") + return float("nan") + + avg_score = sum(values) / len(values) + if update: + self.meter.update(avg_score, len(values)) + return self.meter.avg, values + else: + return avg_score, values + \ No newline at end of file diff --git a/starvector/metrics/compute_SSIM.py b/starvector/metrics/compute_SSIM.py new file mode 100644 index 0000000000000000000000000000000000000000..e0dfb75435d78261197e12276cb53ca540ea7687 --- /dev/null +++ b/starvector/metrics/compute_SSIM.py @@ -0,0 +1,35 @@ +from starvector.metrics.base_metric import BaseMetric +from skimage.metrics import structural_similarity as ssim +import numpy as np + +class SSIMDistanceCalculator(BaseMetric): + def __init__(self, config=None): + super().__init__() + self.class_name = self.__class__.__name__ + self.config = config + self.metric = self.compute_SSIM + + def compute_SSIM(self, **kwargs): + image1 = kwargs.get('gt_im') + image2 = kwargs.get('gen_im') + win_size = kwargs.get('win_size', 11) # Increase win_size for more accuracy + channel_axis = kwargs.get('channel_axis', -1) # Default channel_axis to -1 + sigma = kwargs.get('sigma', 1.5) # Add sigma parameter for Gaussian filter + + # Convert images to numpy arrays if they aren't 
already + img1_np = np.array(image1) + img2_np = np.array(image2) + + # Check if images are grayscale or RGB + if len(img1_np.shape) == 3 and img1_np.shape[2] == 3: + # Compute SSIM for RGB images + score, _ = ssim(img1_np, img2_np, win_size=win_size, channel_axis=channel_axis, sigma=sigma, full=True) + else: + # Convert to grayscale if not already + if len(img1_np.shape) == 3: + img1_np = np.mean(img1_np, axis=2) + img2_np = np.mean(img2_np, axis=2) + + score, _ = ssim(img1_np, img2_np, win_size=win_size, sigma=sigma, full=True) + + return score \ No newline at end of file diff --git a/starvector/metrics/compute_clip_score.py b/starvector/metrics/compute_clip_score.py new file mode 100644 index 0000000000000000000000000000000000000000..186fb2f85bef817a53f5e00a6a82a310451fe15b --- /dev/null +++ b/starvector/metrics/compute_clip_score.py @@ -0,0 +1,55 @@ +from torchvision.transforms import ToTensor +import torch.nn.functional as F +from starvector.metrics.base_metric import BaseMetric +import torch +from torchmetrics.multimodal.clip_score import CLIPScore +from torch.utils.data import DataLoader +from tqdm import tqdm +import torchvision.transforms as transforms +from torchmetrics.functional.multimodal.clip_score import _clip_score_update + +class CLIPScoreCalculator(BaseMetric): + def __init__(self): + super().__init__() + self.class_name = self.__class__.__name__ + self.clip_score = CLIPScore(model_name_or_path="openai/clip-vit-base-patch32") + self.clip_score.to('cuda') + + def CLIP_Score(self, images, captions): + all_scores = _clip_score_update(images, captions, self.clip_score.model, self.clip_score.processor) + return all_scores + + def collate_fn(self, batch): + gen_imgs, captions = zip(*batch) + tensor_gen_imgs = [transforms.ToTensor()(img) for img in gen_imgs] + return tensor_gen_imgs, captions + + def calculate_score(self, batch, batch_size=512, update=True): + gen_images = batch['gen_im'] + captions = batch['caption'] + + # Create DataLoader with custom collate function + data_loader = DataLoader(list(zip(gen_images, captions)), collate_fn=self.collate_fn, batch_size=batch_size, shuffle=False, num_workers=4, pin_memory=True) + + all_scores = [] + for batch_eval in tqdm(data_loader): + images, captions = batch_eval + images = [img.to('cuda', non_blocking=True) * 255 for img in images] + list_scores = self.CLIP_Score(images, captions)[0].detach().cpu().tolist() + all_scores.extend(list_scores) + + if not all_scores: + print("No valid scores found for metric calculation.") + return float("nan"), [] + + avg_score = sum(all_scores) / len(all_scores) + if update: + self.meter.update(avg_score, len(all_scores)) + return self.meter.avg, all_scores + else: + return avg_score, all_scores + +if __name__ == '__main__': + import multiprocessing + multiprocessing.set_start_method('spawn') + # Rest of your code... 
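+ # Illustrative usage sketch (editor's addition, not from the original patch); the
+ # file names and captions below are placeholders, and a CUDA device is assumed
+ # because the scorer is moved to 'cuda' in __init__:
+ #
+ #   from PIL import Image
+ #   calc = CLIPScoreCalculator()
+ #   batch = {
+ #       "gen_im": [Image.open(p).convert("RGB") for p in ("gen_0.png", "gen_1.png")],
+ #       "caption": ["a red circle icon", "a blue star logo"],
+ #   }
+ #   avg_clip, per_sample = calc.calculate_score(batch, batch_size=2)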
\ No newline at end of file diff --git a/starvector/metrics/compute_dino_score.py b/starvector/metrics/compute_dino_score.py new file mode 100644 index 0000000000000000000000000000000000000000..99a8364c4d5fd6a69caed545e95128ef34f9c56a --- /dev/null +++ b/starvector/metrics/compute_dino_score.py @@ -0,0 +1,55 @@ +import torch +from torch.utils.data import DataLoader +from starvector.metrics.base_metric import BaseMetric +from tqdm import tqdm +from transformers import AutoModel, AutoImageProcessor +from PIL import Image +import torch.nn as nn + +class DINOScoreCalculator(BaseMetric): + def __init__(self, config=None, device='cuda'): + super().__init__() + self.class_name = self.__class__.__name__ + self.config = config + self.model, self.processor = self.get_DINOv2_model("base") + self.model = self.model.to(device) + self.device = device + + self.metric = self.calculate_DINOv2_similarity_score + + def get_DINOv2_model(self, model_size): + if model_size == "small": + model_size = "facebook/dinov2-small" + elif model_size == "base": + model_size = "facebook/dinov2-base" + elif model_size == "large": + model_size = "facebook/dinov2-large" + else: + raise ValueError(f"model_size should be either 'small', 'base' or 'large', got {model_size}") + return AutoModel.from_pretrained(model_size), AutoImageProcessor.from_pretrained(model_size) + + def process_input(self, image, processor): + if isinstance(image, str): + image = Image.open(image) + if isinstance(image, Image.Image): + with torch.no_grad(): + inputs = processor(images=image, return_tensors="pt").to(self.device) + outputs = self.model(**inputs) + features = outputs.last_hidden_state.mean(dim=1) + elif isinstance(image, torch.Tensor): + features = image.unsqueeze(0) if image.dim() == 1 else image + else: + raise ValueError("Input must be a file path, PIL Image, or tensor of features") + return features + + def calculate_DINOv2_similarity_score(self, **kwargs): + image1 = kwargs.get('gt_im') + image2 = kwargs.get('gen_im') + features1 = self.process_input(image1, self.processor) + features2 = self.process_input(image2, self.processor) + + cos = nn.CosineSimilarity(dim=1) + sim = cos(features1, features2).item() + sim = (sim + 1) / 2 + + return sim diff --git a/starvector/metrics/compute_fid.py b/starvector/metrics/compute_fid.py new file mode 100644 index 0000000000000000000000000000000000000000..413fca4a4c14e66a30b4aafee21220d4e02a41c0 --- /dev/null +++ b/starvector/metrics/compute_fid.py @@ -0,0 +1,145 @@ +# Refer https://torchmetrics.readthedocs.io/en/stable/image/frechet_inception_distance.html +# from torchmetrics.image.fid import FrechetInceptionDistance +from PIL import Image +from starvector.metrics.base_metric import BaseMetric +import torch +from torchvision import transforms +import clip +from torch.nn.functional import adaptive_avg_pool2d +from starvector.metrics.inception import InceptionV3 +import numpy as np +from tqdm import tqdm +from scipy import linalg +import torchvision.transforms as TF + +class FIDCalculator(BaseMetric): + def __init__(self, model_name = 'InceptionV3',): + self.class_name = self.__class__.__name__ + self.device = 'cuda' if torch.cuda.is_available() else 'cpu' + self.model_name = model_name + if self.model_name == 'ViT-B/32': + self.dims = 512 + model, preprocess = clip.load('ViT-B/32') + + elif self.model_name == 'InceptionV3': + self.dims = 2048 + block_idx = InceptionV3.BLOCK_INDEX_BY_DIM[self.dims] + model = InceptionV3([block_idx]).to(self.device) + preprocess = TF.Compose([TF.ToTensor()]) + + 
self.model = model.cuda() + self.preprocess = preprocess + + def calculate_frechet_distance(self, mu1, sigma1, mu2, sigma2, eps=1e-6): + """Numpy implementation of the Frechet Distance. + The Frechet distance between two multivariate Gaussians X_1 ~ N(mu_1, C_1) + and X_2 ~ N(mu_2, C_2) is + d^2 = ||mu_1 - mu_2||^2 + Tr(C_1 + C_2 - 2*sqrt(C_1*C_2)). + + Stable version by Dougal J. Sutherland. + + Params: + -- mu1 : Numpy array containing the activations of a layer of the + inception net (like returned by the function 'get_predictions') + for generated samples. + -- mu2 : The sample mean over activations, precalculated on an + representative data set. + -- sigma1: The covariance matrix over activations for generated samples. + -- sigma2: The covariance matrix over activations, precalculated on an + representative data set. + + Returns: + -- : The Frechet Distance. + """ + + mu1 = np.atleast_1d(mu1) + mu2 = np.atleast_1d(mu2) + + sigma1 = np.atleast_2d(sigma1) + sigma2 = np.atleast_2d(sigma2) + + assert mu1.shape == mu2.shape, \ + 'Training and test mean vectors have different lengths' + assert sigma1.shape == sigma2.shape, \ + 'Training and test covariances have different dimensions' + + diff = mu1 - mu2 + + # Product might be almost singular + covmean, _ = linalg.sqrtm(sigma1.dot(sigma2), disp=False) + if not np.isfinite(covmean).all(): + msg = ('fid calculation produces singular product; ' + 'adding %s to diagonal of cov estimates') % eps + print(msg) + offset = np.eye(sigma1.shape[0]) * eps + covmean = linalg.sqrtm((sigma1 + offset).dot(sigma2 + offset)) + + # Numerical error might give slight imaginary component + if np.iscomplexobj(covmean): + if not np.allclose(np.diagonal(covmean).imag, 0, atol=1e-3): + m = np.max(np.abs(covmean.imag)) + raise ValueError('Imaginary component {}'.format(m)) + covmean = covmean.real + + tr_covmean = np.trace(covmean) + + return (diff.dot(diff) + np.trace(sigma1) + + np.trace(sigma2) - 2 * tr_covmean) + + def get_activations(self, images): + dataset = ImageDataset(images, self.preprocess) + dataloader = torch.utils.data.DataLoader(dataset, batch_size=50, shuffle=False, num_workers=4) + pred_arr = np.empty((len(images), self.dims)) + start_idx = 0 + for batch in tqdm(dataloader): + batch = batch.to(self.device) + + with torch.no_grad(): + if self.model_name == 'ViT-B/32': + pred = self.model.encode_image(batch).cpu().numpy() + elif self.model_name == 'InceptionV3': + pred = self.model(batch)[0] + + # If model output is not scalar, apply global spatial average pooling. + # This happens if you choose a dimensionality not equal 2048. 
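+ # (Editor's note) With the default block index 3, InceptionV3 ends in
+ # AdaptiveAvgPool2d((1, 1)), so `pred` is already (B, 2048, 1, 1) and the pooling
+ # below only kicks in for blocks of other dimensionalities.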
+ if pred.size(2) != 1 or pred.size(3) != 1: + pred = adaptive_avg_pool2d(pred, output_size=(1, 1)) + + pred = pred.squeeze(3).squeeze(2).cpu().numpy() + pred_arr[start_idx:start_idx + pred.shape[0]] = pred + start_idx = start_idx + pred.shape[0] + + return pred_arr + + def calculate_activation_statistics(self, images): + act = self.get_activations(images) + mu = np.mean(act, axis=0) + sigma = np.cov(act, rowvar=False) + return mu, sigma + + def pil_images_to_tensor(self, images_list): + """Convert a list of PIL Images to a torch.Tensor.""" + tensors_list = [self.preprocess(img) for img in images_list] + return torch.stack(tensors_list).cuda() # BxCxHxW format + + def calculate_score(self, batch): + m1, s1 = self.calculate_activation_statistics(batch['gt_im']) + m2, s2 = self.calculate_activation_statistics(batch['gen_im']) + fid_value = self.calculate_frechet_distance(m1, s1, m2, s2) + return fid_value + + def reset(self): + pass + +class ImageDataset(torch.utils.data.Dataset): + def __init__(self, images, processor=None): + self.images = images + self.processor = processor + + def __len__(self): + return len(self.images) + + def __getitem__(self, i): + img = self.images[i] + img = self.processor(img) + return img \ No newline at end of file diff --git a/starvector/metrics/compute_l2.py b/starvector/metrics/compute_l2.py new file mode 100644 index 0000000000000000000000000000000000000000..ecca16a88f6d39d4b1666da4762202ce1e69d0bb --- /dev/null +++ b/starvector/metrics/compute_l2.py @@ -0,0 +1,37 @@ +from torchvision.transforms import ToTensor +import torch.nn.functional as F +from starvector.metrics.base_metric import BaseMetric +import torch + +class L2DistanceCalculator(BaseMetric): + def __init__(self, config=None, masked_l2=False): + super().__init__() + self.class_name = self.__class__.__name__ + self.config = config + self.metric = self.l2_distance + self.masked_l2 = masked_l2 + + def l2_distance(self, **kwargs): + image1 = kwargs.get('gt_im') + image2 = kwargs.get('gen_im') + image1_tensor = ToTensor()(image1) + image2_tensor = ToTensor()(image2) + + if self.masked_l2: + # Create binary masks: 0 for white pixels, 1 for non-white pixels + mask1 = (image1_tensor != 1).any(dim=0).float() + mask2 = (image2_tensor != 1).any(dim=0).float() + + # Create a combined mask for overlapping non-white pixels + combined_mask = mask1 * mask2 + + # Apply the combined mask to both images + image1_tensor = image1_tensor * combined_mask.unsqueeze(0) + image2_tensor = image2_tensor * combined_mask.unsqueeze(0) + + # Compute mean squared error + mse = F.mse_loss(image1_tensor, image2_tensor) + return mse.item() + + + diff --git a/starvector/metrics/count_token_length.py b/starvector/metrics/count_token_length.py new file mode 100644 index 0000000000000000000000000000000000000000..8210771ec6fc7a148c069770c126610756a466e8 --- /dev/null +++ b/starvector/metrics/count_token_length.py @@ -0,0 +1,54 @@ +import torch +from torch.utils.data import DataLoader +from starvector.metrics.base_metric import BaseMetric +from tqdm import tqdm +from starvector.metrics.util import AverageMeter + +from transformers import AutoTokenizer + +class CountTokenLength(BaseMetric): + def __init__(self, config=None, device='cuda'): + super().__init__() + self.tokenizer = AutoTokenizer.from_pretrained("bigcode/starcoder2-7b") + self.metric = self.calculate_token_length + self.meter_gt_tokens = AverageMeter() + self.meter_gen_tokens = AverageMeter() + self.meter_diff = AverageMeter() + + def calculate_token_length(self, **kwargs): + 
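+ # (Editor's note) Returns a 3-tuple: token count of the ground-truth SVG, token
+ # count of the generated SVG (both measured with the StarCoder2 tokenizer loaded
+ # above), and their difference (generated minus ground truth).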
svg = kwargs.get('gt_svg') + tokens = self.tokenizer.encode(svg) + gen_svg = kwargs.get('gen_svg') + gen_tokens = self.tokenizer.encode(gen_svg) + diff = len(gen_tokens) - len(tokens) + return len(tokens), len(gen_tokens), diff + + def calculate_score(self, batch, update=None): + gt_svgs = batch['gt_svg'] + gen_svgs = batch['gen_svg'] + values = [] + for gt_svg, gen_svg in tqdm(zip(gt_svgs, gen_svgs), total=len(gt_svgs), desc="Processing SVGs"): + gt_tokens, gen_tokens, diff = self.calculate_token_length(gt_svg=gt_svg, gen_svg=gen_svg) + self.meter_gt_tokens.update(gt_tokens, 1) + self.meter_gen_tokens.update(gen_tokens, 1) + self.meter_diff.update(diff, 1) + values.append({ + 'gt_tokens': gt_tokens, + 'gen_tokens': gen_tokens, + 'diff': diff + }) + avg_score = { + 'gt_tokens': self.meter_gt_tokens.avg, + 'gen_tokens': self.meter_gen_tokens.avg, + 'diff': self.meter_diff.avg + } + if not values: + print("No valid values found for metric calculation.") + return float("nan") + + return avg_score, values + + def reset(self): + self.meter_gt_tokens.reset() + self.meter_gen_tokens.reset() + self.meter_diff.reset() diff --git a/starvector/metrics/inception.py b/starvector/metrics/inception.py new file mode 100644 index 0000000000000000000000000000000000000000..cc56870522a1ee298abbfff9edd7243a0bf8e7dd --- /dev/null +++ b/starvector/metrics/inception.py @@ -0,0 +1,341 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F +import torchvision + +try: + from torchvision.models.utils import load_state_dict_from_url +except ImportError: + from torch.utils.model_zoo import load_url as load_state_dict_from_url + +# Inception weights ported to Pytorch from +# http://download.tensorflow.org/models/image/imagenet/inception-2015-12-05.tgz +FID_WEIGHTS_URL = 'https://github.com/mseitzer/pytorch-fid/releases/download/fid_weights/pt_inception-2015-12-05-6726825d.pth' # noqa: E501 + + +class InceptionV3(nn.Module): + """Pretrained InceptionV3 network returning feature maps""" + + # Index of default block of inception to return, + # corresponds to output of final average pooling + DEFAULT_BLOCK_INDEX = 3 + + # Maps feature dimensionality to their output blocks indices + BLOCK_INDEX_BY_DIM = { + 64: 0, # First max pooling features + 192: 1, # Second max pooling featurs + 768: 2, # Pre-aux classifier features + 2048: 3 # Final average pooling features + } + + def __init__(self, + output_blocks=(DEFAULT_BLOCK_INDEX,), + resize_input=True, + normalize_input=True, + requires_grad=False, + use_fid_inception=True): + """Build pretrained InceptionV3 + + Parameters + ---------- + output_blocks : list of int + Indices of blocks to return features of. Possible values are: + - 0: corresponds to output of first max pooling + - 1: corresponds to output of second max pooling + - 2: corresponds to output which is fed to aux classifier + - 3: corresponds to output of final average pooling + resize_input : bool + If true, bilinearly resizes input to width and height 299 before + feeding input to model. As the network without fully connected + layers is fully convolutional, it should be able to handle inputs + of arbitrary size, so resizing might not be strictly needed + normalize_input : bool + If true, scales the input from range (0, 1) to the range the + pretrained Inception network expects, namely (-1, 1) + requires_grad : bool + If true, parameters of the model require gradients. 
Possibly useful + for finetuning the network + use_fid_inception : bool + If true, uses the pretrained Inception model used in Tensorflow's + FID implementation. If false, uses the pretrained Inception model + available in torchvision. The FID Inception model has different + weights and a slightly different structure from torchvision's + Inception model. If you want to compute FID scores, you are + strongly advised to set this parameter to true to get comparable + results. + """ + super(InceptionV3, self).__init__() + + self.resize_input = resize_input + self.normalize_input = normalize_input + self.output_blocks = sorted(output_blocks) + self.last_needed_block = max(output_blocks) + + assert self.last_needed_block <= 3, \ + 'Last possible output block index is 3' + + self.blocks = nn.ModuleList() + + if use_fid_inception: + inception = fid_inception_v3() + else: + inception = _inception_v3(weights='DEFAULT') + + # Block 0: input to maxpool1 + block0 = [ + inception.Conv2d_1a_3x3, + inception.Conv2d_2a_3x3, + inception.Conv2d_2b_3x3, + nn.MaxPool2d(kernel_size=3, stride=2) + ] + self.blocks.append(nn.Sequential(*block0)) + + # Block 1: maxpool1 to maxpool2 + if self.last_needed_block >= 1: + block1 = [ + inception.Conv2d_3b_1x1, + inception.Conv2d_4a_3x3, + nn.MaxPool2d(kernel_size=3, stride=2) + ] + self.blocks.append(nn.Sequential(*block1)) + + # Block 2: maxpool2 to aux classifier + if self.last_needed_block >= 2: + block2 = [ + inception.Mixed_5b, + inception.Mixed_5c, + inception.Mixed_5d, + inception.Mixed_6a, + inception.Mixed_6b, + inception.Mixed_6c, + inception.Mixed_6d, + inception.Mixed_6e, + ] + self.blocks.append(nn.Sequential(*block2)) + + # Block 3: aux classifier to final avgpool + if self.last_needed_block >= 3: + block3 = [ + inception.Mixed_7a, + inception.Mixed_7b, + inception.Mixed_7c, + nn.AdaptiveAvgPool2d(output_size=(1, 1)) + ] + self.blocks.append(nn.Sequential(*block3)) + + for param in self.parameters(): + param.requires_grad = requires_grad + + def forward(self, inp): + """Get Inception feature maps + + Parameters + ---------- + inp : torch.autograd.Variable + Input tensor of shape Bx3xHxW. Values are expected to be in + range (0, 1) + + Returns + ------- + List of torch.autograd.Variable, corresponding to the selected output + block, sorted ascending by index + """ + outp = [] + x = inp + + if self.resize_input: + x = F.interpolate(x, + size=(299, 299), + mode='bilinear', + align_corners=False) + + if self.normalize_input: + x = 2 * x - 1 # Scale from range (0, 1) to range (-1, 1) + + for idx, block in enumerate(self.blocks): + x = block(x) + if idx in self.output_blocks: + outp.append(x) + + if idx == self.last_needed_block: + break + + return outp + + +def _inception_v3(*args, **kwargs): + """Wraps `torchvision.models.inception_v3`""" + try: + version = tuple(map(int, torchvision.__version__.split('.')[:2])) + except ValueError: + # Just a caution against weird version strings + version = (0,) + + # Skips default weight inititialization if supported by torchvision + # version. See https://github.com/mseitzer/pytorch-fid/issues/28. + if version >= (0, 6): + kwargs['init_weights'] = False + + # Backwards compatibility: `weights` argument was handled by `pretrained` + # argument prior to version 0.13. 
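+ # (Editor's note) On torchvision >= 0.13 the `weights` keyword is accepted
+ # natively and passed through unchanged; the translation below only applies to
+ # older releases.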
+ if version < (0, 13) and 'weights' in kwargs: + if kwargs['weights'] == 'DEFAULT': + kwargs['pretrained'] = True + elif kwargs['weights'] is None: + kwargs['pretrained'] = False + else: + raise ValueError( + 'weights=={} not supported in torchvision {}'.format( + kwargs['weights'], torchvision.__version__ + ) + ) + del kwargs['weights'] + + return torchvision.models.inception_v3(*args, **kwargs) + + +def fid_inception_v3(): + """Build pretrained Inception model for FID computation + + The Inception model for FID computation uses a different set of weights + and has a slightly different structure than torchvision's Inception. + + This method first constructs torchvision's Inception and then patches the + necessary parts that are different in the FID Inception model. + """ + inception = _inception_v3(num_classes=1008, + aux_logits=False, + weights=None) + inception.Mixed_5b = FIDInceptionA(192, pool_features=32) + inception.Mixed_5c = FIDInceptionA(256, pool_features=64) + inception.Mixed_5d = FIDInceptionA(288, pool_features=64) + inception.Mixed_6b = FIDInceptionC(768, channels_7x7=128) + inception.Mixed_6c = FIDInceptionC(768, channels_7x7=160) + inception.Mixed_6d = FIDInceptionC(768, channels_7x7=160) + inception.Mixed_6e = FIDInceptionC(768, channels_7x7=192) + inception.Mixed_7b = FIDInceptionE_1(1280) + inception.Mixed_7c = FIDInceptionE_2(2048) + + state_dict = load_state_dict_from_url(FID_WEIGHTS_URL, progress=True) + inception.load_state_dict(state_dict) + return inception + + +class FIDInceptionA(torchvision.models.inception.InceptionA): + """InceptionA block patched for FID computation""" + def __init__(self, in_channels, pool_features): + super(FIDInceptionA, self).__init__(in_channels, pool_features) + + def forward(self, x): + branch1x1 = self.branch1x1(x) + + branch5x5 = self.branch5x5_1(x) + branch5x5 = self.branch5x5_2(branch5x5) + + branch3x3dbl = self.branch3x3dbl_1(x) + branch3x3dbl = self.branch3x3dbl_2(branch3x3dbl) + branch3x3dbl = self.branch3x3dbl_3(branch3x3dbl) + + # Patch: Tensorflow's average pool does not use the padded zero's in + # its average calculation + branch_pool = F.avg_pool2d(x, kernel_size=3, stride=1, padding=1, + count_include_pad=False) + branch_pool = self.branch_pool(branch_pool) + + outputs = [branch1x1, branch5x5, branch3x3dbl, branch_pool] + return torch.cat(outputs, 1) + + +class FIDInceptionC(torchvision.models.inception.InceptionC): + """InceptionC block patched for FID computation""" + def __init__(self, in_channels, channels_7x7): + super(FIDInceptionC, self).__init__(in_channels, channels_7x7) + + def forward(self, x): + branch1x1 = self.branch1x1(x) + + branch7x7 = self.branch7x7_1(x) + branch7x7 = self.branch7x7_2(branch7x7) + branch7x7 = self.branch7x7_3(branch7x7) + + branch7x7dbl = self.branch7x7dbl_1(x) + branch7x7dbl = self.branch7x7dbl_2(branch7x7dbl) + branch7x7dbl = self.branch7x7dbl_3(branch7x7dbl) + branch7x7dbl = self.branch7x7dbl_4(branch7x7dbl) + branch7x7dbl = self.branch7x7dbl_5(branch7x7dbl) + + # Patch: Tensorflow's average pool does not use the padded zero's in + # its average calculation + branch_pool = F.avg_pool2d(x, kernel_size=3, stride=1, padding=1, + count_include_pad=False) + branch_pool = self.branch_pool(branch_pool) + + outputs = [branch1x1, branch7x7, branch7x7dbl, branch_pool] + return torch.cat(outputs, 1) + + +class FIDInceptionE_1(torchvision.models.inception.InceptionE): + """First InceptionE block patched for FID computation""" + def __init__(self, in_channels): + super(FIDInceptionE_1, 
self).__init__(in_channels) + + def forward(self, x): + branch1x1 = self.branch1x1(x) + + branch3x3 = self.branch3x3_1(x) + branch3x3 = [ + self.branch3x3_2a(branch3x3), + self.branch3x3_2b(branch3x3), + ] + branch3x3 = torch.cat(branch3x3, 1) + + branch3x3dbl = self.branch3x3dbl_1(x) + branch3x3dbl = self.branch3x3dbl_2(branch3x3dbl) + branch3x3dbl = [ + self.branch3x3dbl_3a(branch3x3dbl), + self.branch3x3dbl_3b(branch3x3dbl), + ] + branch3x3dbl = torch.cat(branch3x3dbl, 1) + + # Patch: Tensorflow's average pool does not use the padded zero's in + # its average calculation + branch_pool = F.avg_pool2d(x, kernel_size=3, stride=1, padding=1, + count_include_pad=False) + branch_pool = self.branch_pool(branch_pool) + + outputs = [branch1x1, branch3x3, branch3x3dbl, branch_pool] + return torch.cat(outputs, 1) + + +class FIDInceptionE_2(torchvision.models.inception.InceptionE): + """Second InceptionE block patched for FID computation""" + def __init__(self, in_channels): + super(FIDInceptionE_2, self).__init__(in_channels) + + def forward(self, x): + branch1x1 = self.branch1x1(x) + + branch3x3 = self.branch3x3_1(x) + branch3x3 = [ + self.branch3x3_2a(branch3x3), + self.branch3x3_2b(branch3x3), + ] + branch3x3 = torch.cat(branch3x3, 1) + + branch3x3dbl = self.branch3x3dbl_1(x) + branch3x3dbl = self.branch3x3dbl_2(branch3x3dbl) + branch3x3dbl = [ + self.branch3x3dbl_3a(branch3x3dbl), + self.branch3x3dbl_3b(branch3x3dbl), + ] + branch3x3dbl = torch.cat(branch3x3dbl, 1) + + # Patch: The FID Inception model uses max pooling instead of average + # pooling. This is likely an error in this specific Inception + # implementation, as other Inception models use average pooling here + # (which matches the description in the paper). + branch_pool = F.max_pool2d(x, kernel_size=3, stride=1, padding=1) + branch_pool = self.branch_pool(branch_pool) + + outputs = [branch1x1, branch3x3, branch3x3dbl, branch_pool] + return torch.cat(outputs, 1) \ No newline at end of file diff --git a/starvector/metrics/metrics.py b/starvector/metrics/metrics.py new file mode 100644 index 0000000000000000000000000000000000000000..86d7cdcd6dd5798e8bb35b8fe472f687d96f3b4d --- /dev/null +++ b/starvector/metrics/metrics.py @@ -0,0 +1,127 @@ +from starvector.metrics.compute_l2 import L2DistanceCalculator +from starvector.metrics.compute_LPIPS import LPIPSDistanceCalculator +from starvector.metrics.compute_SSIM import SSIMDistanceCalculator +from starvector.metrics.compute_fid import FIDCalculator +from starvector.metrics.compute_clip_score import CLIPScoreCalculator +from starvector.data.util import rasterize_svg +from starvector.metrics.util import AverageMeter +from starvector.metrics.compute_dino_score import DINOScoreCalculator +from starvector.metrics.count_token_length import CountTokenLength +import os +from tqdm import tqdm + +class SVGMetrics: + def __init__(self, config=None): + self.class_name = self.__class__.__name__ + + default_config = { + 'L2': True, + 'Masked-L2': False, + 'LPIPS': False, + 'SSIM': False, + 'FID': False, + 'FID_clip': False, + 'CLIPScore': False, + 'CountTokenLength': False, + 'ratio_post_processed': True, + 'ratio_non_compiling': True, + 'DinoScore': True, + } + self.config = config or default_config + + self.metrics = { + 'L2': L2DistanceCalculator, + 'Masked-L2': lambda: L2DistanceCalculator(masked_l2=True), + 'LPIPS': LPIPSDistanceCalculator, + 'SSIM': SSIMDistanceCalculator, + 'FID': lambda: FIDCalculator(model_name='InceptionV3'), + 'FID_clip': lambda: FIDCalculator(model_name='ViT-B/32'), + 
'CLIPScore': CLIPScoreCalculator, + 'CountTokenLength': CountTokenLength, + 'ratio_post_processed': AverageMeter, + 'ratio_non_compiling': AverageMeter, + 'DinoScore': DINOScoreCalculator, + } + + self.active_metrics = {k: v() for k, v in self.metrics.items() if self.config.get(k)} + + def reset(self): + for metric in self.active_metrics.values(): + metric.reset() + + def batch_contains_raster(self, batch): + return "gt_im" in batch and "gen_im" in batch + + def batch_contains_svg(self, batch): + return "gt_svg" in batch and "gen_svg" in batch + + def calculate_metrics(self, batch, update=True): + if not self.batch_contains_raster(batch): + batch["gt_im"] = [rasterize_svg(svg) for svg in batch["gt_svg"]] + batch["gen_im"] = [rasterize_svg(svg) for svg in batch["gen_svg"]] + + avg_results_dict = {} + all_results_dict = {} + + def get_sample_id(json_item): + return json_item.get('outpath_filename') or json_item.get('sample_id') + + # initialize all_results_dict + for i, json_item in enumerate(batch['json']): + sample_id = get_sample_id(json_item) + if sample_id is None: + raise ValueError(f"Could not find 'outpath_filename' or 'sample_id' in batch['json'][{i}]") + all_results_dict[sample_id] = {} + + for metric_name, metric in self.active_metrics.items(): + print(f"Calculating {metric_name}...") + + # Handle metrics that return both average and per-sample results + if metric_name in ['L2', 'Masked-L2', 'SSIM', 'CLIPScore', 'LPIPS', 'CountTokenLength', 'DinoScore']: + avg_result, list_result = metric.calculate_score(batch, update=update) + avg_results_dict[metric_name] = avg_result + + # Store individual results + for i, result in enumerate(list_result): + sample_id = get_sample_id(batch['json'][i]) + all_results_dict[sample_id][metric_name] = result + + # Handle FID metrics that only return average + elif metric_name in ['FID', 'FID_clip']: + avg_results_dict[metric_name] = metric.calculate_score(batch) + + # Handle other metrics (ratio metrics) + else: + self._handle_ratio_metric(metric_name, metric, batch, avg_results_dict, all_results_dict) + + metric.reset() + print("Average results: \n", avg_results_dict) + return avg_results_dict, all_results_dict + + def calculate_fid(self, batch): + if not self.batch_contains_raster(batch): + batch["gt_im"] = [rasterize_svg(svg) for svg in batch["gt_svg"]] + batch["gen_im"] = [rasterize_svg(svg) for svg in batch["gen_svg"]] + + return self.active_metrics['FID'].calculate_score(batch).item() + + def get_average_metrics(self): + metrics = {} + for metric_name, metric in self.active_metrics.items(): + if hasattr(metric, 'avg'): + metrics[metric_name] = metric.avg + elif hasattr(metric, 'get_average_score'): + metrics[metric_name] = metric.get_average_score() + return metrics + + def _handle_ratio_metric(self, metric_name, metric, batch, avg_results_dict, all_results_dict): + """Helper method to handle ratio-based metrics.""" + metric_key = metric_name.replace('avg_', '').replace('ratio_', '') + + for item in batch['json']: + sample_id = get_sample_id(item) + value = item[metric_key] + all_results_dict[sample_id][metric_name] = value + metric.update(value, 1) + + avg_results_dict[metric_name] = metric.avg \ No newline at end of file diff --git a/starvector/metrics/util.py b/starvector/metrics/util.py new file mode 100644 index 0000000000000000000000000000000000000000..1faac0ed299c21092234a19ef570584981573cc7 --- /dev/null +++ b/starvector/metrics/util.py @@ -0,0 +1,20 @@ + +# -------------- Metrics -------------- +class AverageMeter(object): + 
"""Computes and stores the average and current value""" + + def __init__(self): + self.reset() + + def reset(self): + self.val = 0 + self.avg = 0 + self.sum = 0 + self.count = 0 + + def update(self, val, n=1): + self.val = val + self.sum += val * n + self.count += n + self.avg = self.sum / self.count + \ No newline at end of file diff --git a/starvector/model/adapters/adapter.py b/starvector/model/adapters/adapter.py new file mode 100644 index 0000000000000000000000000000000000000000..7a5a0fe2de0a98472f67576a0fa32c47c68dedff --- /dev/null +++ b/starvector/model/adapters/adapter.py @@ -0,0 +1,53 @@ +import torch.nn as nn +import torch.nn.init as init +import torch + +class Swish(nn.Module): + def __init__(self): + super(Swish, self).__init__() + + def forward(self, x): + return x * torch.sigmoid(x) + +class Adapter(nn.Module): + def __init__(self, input_size, output_size, adapter_norm="layer_norm", init_type="glorot", query_length=32, dropout_prob=0.1): + super().__init__() + self.query_length = query_length + self.dropout_prob = dropout_prob + self.adapter_norm = adapter_norm + + self.dropout = nn.Dropout(p=self.dropout_prob) + + self.c_fc = nn.Linear(input_size, input_size*2) + self.act = Swish() + self.c_proj = nn.Linear(input_size*2, output_size) + + if adapter_norm == "layer_norm": + self.norm = nn.LayerNorm([self.query_length, output_size]) + elif adapter_norm == "batch_norm": + self.norm = nn.BatchNorm1d(self.query_length) + + self.init_type = init_type.lower() + self._initialize_weights() + + def forward(self, hidden_states): + hidden_states = self.dropout(hidden_states) + hidden_states = self.c_fc(hidden_states) + hidden_states = self.act(hidden_states) + hidden_states = self.c_proj(hidden_states) + hidden_states = self.norm(hidden_states) + return hidden_states + + def _initialize_weights(self): + for m in self.modules(): + if isinstance(m, nn.Linear): + if self.init_type == "glorot": + init.xavier_uniform_(m.weight) + if m.bias is not None: + init.constant_(m.bias, 0) + elif self.init_type == "normal": + init.normal_(m.weight, mean=0, std=0.01) + if m.bias is not None: + init.constant_(m.bias, 0) + else: + raise ValueError("Invalid initialization type specified.") diff --git a/starvector/model/builder.py b/starvector/model/builder.py new file mode 100644 index 0000000000000000000000000000000000000000..785a2dfdbd131853760813d1ccf1a01da5d94fa7 --- /dev/null +++ b/starvector/model/builder.py @@ -0,0 +1,49 @@ + +from starvector.model.starvector_arch import StarVectorForCausalLM, StarVectorConfig +from starvector.data.base import ImageTrainProcessor +from starvector.util import dtype_mapping +from transformers import AutoConfig + +def load_pretrained_model(model_path, device="cuda", **kwargs): + model = StarVectorForCausalLM.from_pretrained(model_path, **kwargs).to(device) + tokenizer = model.model.svg_transformer.tokenizer + image_processor = ImageTrainProcessor() + context_len = model.model.query_length + model.model.max_length + return tokenizer, model, image_processor, context_len + +def model_builder(config): + model_name = config.model.get("model_name", False) + + args = { + "task": config.model.task, + "train_image_encoder": config.training.train_image_encoder, + "ignore_mismatched_sizes": True, + "starcoder_model_name": config.model.starcoder_model_name, + "train_LLM": config.training.train_LLM, + "torch_dtype": dtype_mapping[config.training.model_precision], + "transformer_layer_cls": config.model.get("transformer_layer_cls", False), + "use_cache": config.model.use_cache, + } 
+ if model_name: + model = StarVectorForCausalLM.from_pretrained(model_name, **args) + else: + starcoder_model_config = AutoConfig.from_pretrained(config.model.starcoder_model_name) + + starvector_config = StarVectorConfig( + max_length_train=config.model.max_length, + image_encoder_type=config.model.image_encoder_type, + use_flash_attn=config.model.use_flash_attn, + adapter_norm=config.model.adapter_norm, + starcoder_model_name=config.model.starcoder_model_name, + torch_dtype=dtype_mapping[config.training.model_precision], + num_attention_heads=starcoder_model_config.num_attention_heads, + num_hidden_layers=starcoder_model_config.num_hidden_layers, + vocab_size=starcoder_model_config.vocab_size, + hidden_size=starcoder_model_config.hidden_size, + num_kv_heads=getattr(starcoder_model_config, "num_key_value_heads", None), + ) + model = StarVectorForCausalLM(starvector_config, **args) + + return model + + \ No newline at end of file diff --git a/starvector/model/gpt_bigcode/__init__.py b/starvector/model/gpt_bigcode/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..1678bafc28fa8227101eb712cd41968269493c2b --- /dev/null +++ b/starvector/model/gpt_bigcode/__init__.py @@ -0,0 +1,65 @@ +# Copyright 2023 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import TYPE_CHECKING + +from transformers.utils import ( + OptionalDependencyNotAvailable, + _LazyModule, + is_torch_available, +) + + +_import_structure = { + "configuration_gpt_bigcode": ["GPT_BIGCODE_PRETRAINED_CONFIG_ARCHIVE_MAP", "GPTBigCodeConfig"], +} + +try: + if not is_torch_available(): + raise OptionalDependencyNotAvailable() +except OptionalDependencyNotAvailable: + pass +else: + _import_structure["modeling_gpt_bigcode"] = [ + "GPT_BIGCODE_PRETRAINED_MODEL_ARCHIVE_LIST", + "GPTBigCodeForSequenceClassification", + "GPTBigCodeForTokenClassification", + "GPTBigCodeForCausalLM", + "GPTBigCodeModel", + "GPTBigCodePreTrainedModel", + ] + +if TYPE_CHECKING: + from .configuration_gpt_bigcode import GPT_BIGCODE_PRETRAINED_CONFIG_ARCHIVE_MAP, GPTBigCodeConfig + + try: + if not is_torch_available(): + raise OptionalDependencyNotAvailable() + except OptionalDependencyNotAvailable: + pass + else: + from .modeling_gpt_bigcode import ( + GPT_BIGCODE_PRETRAINED_MODEL_ARCHIVE_LIST, + GPTBigCodeForCausalLM, + GPTBigCodeForSequenceClassification, + GPTBigCodeForTokenClassification, + GPTBigCodeModel, + GPTBigCodePreTrainedModel, + ) + + +else: + import sys + + sys.modules[__name__] = _LazyModule(__name__, globals()["__file__"], _import_structure, module_spec=__spec__) diff --git a/starvector/model/gpt_bigcode/configuration_gpt_bigcode.py b/starvector/model/gpt_bigcode/configuration_gpt_bigcode.py new file mode 100644 index 0000000000000000000000000000000000000000..ececb6332f9bad2b9af559f1a586f688c2996c78 --- /dev/null +++ b/starvector/model/gpt_bigcode/configuration_gpt_bigcode.py @@ -0,0 +1,143 @@ +# coding=utf-8 +# Copyright 2023 The BigCode team and HuggingFace Inc. 
team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" GPTBigCode configuration""" +from transformers.configuration_utils import PretrainedConfig +from transformers.utils import logging + + + +logger = logging.get_logger(__name__) + + + + +class GPTBigCodeConfig(PretrainedConfig): + """ + This is the configuration class to store the configuration of a [`GPTBigCodeModel`]. It is used to instantiate a + GPTBigCode model according to the specified arguments, defining the model architecture. Instantiating a + configuration with the defaults will yield a similar configuration to that of the GPTBigCode + [gpt_bigcode](https://huggingface.co/gpt_bigcode) architecture. + + Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the + documentation from [`PretrainedConfig`] for more information. + + + Args: + vocab_size (`int`, *optional*, defaults to 50257): + Vocabulary size of the GPT-2 model. Defines the number of different tokens that can be represented by the + `inputs_ids` passed when calling [`GPTBigCodeModel`]. + n_positions (`int`, *optional*, defaults to 1024): + The maximum sequence length that this model might ever be used with. Typically set this to something large + just in case (e.g., 512 or 1024 or 2048). + n_embd (`int`, *optional*, defaults to 768): + Dimensionality of the embeddings and hidden states. + n_layer (`int`, *optional*, defaults to 12): + Number of hidden layers in the Transformer encoder. + n_head (`int`, *optional*, defaults to 12): + Number of attention heads for each attention layer in the Transformer encoder. + n_inner (`int`, *optional*, defaults to None): + Dimensionality of the inner feed-forward layers. `None` will set it to 4 times n_embd + activation_function (`str`, *optional*, defaults to `"gelu_pytorch_tanh"`): + Activation function, to be selected in the list `["relu", "silu", "gelu", "tanh", "gelu_new", + "gelu_pytorch_tanh"]`. + resid_pdrop (`float`, *optional*, defaults to 0.1): + The dropout probability for all fully connected layers in the embeddings, encoder, and pooler. + embd_pdrop (`float`, *optional*, defaults to 0.1): + The dropout ratio for the embeddings. + attn_pdrop (`float`, *optional*, defaults to 0.1): + The dropout ratio for the attention. + layer_norm_epsilon (`float`, *optional*, defaults to 1e-5): + The epsilon to use in the layer normalization layers. + initializer_range (`float`, *optional*, defaults to 0.02): + The standard deviation of the truncated_normal_initializer for initializing all weight matrices. + scale_attn_weights (`bool`, *optional*, defaults to `True`): + Scale attention weights by dividing by sqrt(hidden_size).. + use_cache (`bool`, *optional*, defaults to `True`): + Whether or not the model should return the last key/values attentions (not used by all models). + attention_softmax_in_fp32 (`bool`, *optional*, defaults to `True`): + Whether to call the fused softmax in float32. 
+ scale_attention_softmax_in_fp32 (`bool`, *optional*, defaults to `True`): + Whether to scale the attention softmax in float32. + attention_type (`bool`, *optional*, defaults to `True`): + Whether to use Multi-Query Attion (`True`) or Multi-Head Attention (`False`). + Example: + + ```python + >>> from transformers import GPTBigCodeConfig, GPTBigCodeModel + + >>> # Initializing a GPTBigCode configuration + >>> configuration = GPTBigCodeConfig() + + >>> # Initializing a model (with random weights) from the configuration + >>> model = GPTBigCodeModel(configuration) + + >>> # Accessing the model configuration + >>> configuration = model.config + ```""" + + model_type = "gpt_bigcode" + keys_to_ignore_at_inference = ["past_key_values"] + attribute_map = { + "hidden_size": "n_embd", + "max_position_embeddings": "n_positions", + "num_attention_heads": "n_head", + "num_hidden_layers": "n_layer", + } + + def __init__( + self, + vocab_size=50257, + n_positions=1024, + n_embd=768, + n_layer=12, + n_head=12, + n_inner=None, + activation_function="gelu_pytorch_tanh", + resid_pdrop=0.1, + embd_pdrop=0.1, + attn_pdrop=0.1, + layer_norm_epsilon=1e-5, + initializer_range=0.02, + scale_attn_weights=True, + use_cache=True, + bos_token_id=50256, + eos_token_id=50256, + attention_softmax_in_fp32=True, + scale_attention_softmax_in_fp32=True, + multi_query=True, + **kwargs, + ): + self.vocab_size = vocab_size + self.n_positions = n_positions + self.n_embd = n_embd + self.n_layer = n_layer + self.n_head = n_head + self.n_inner = n_inner + self.activation_function = activation_function + self.resid_pdrop = resid_pdrop + self.embd_pdrop = embd_pdrop + self.attn_pdrop = attn_pdrop + self.layer_norm_epsilon = layer_norm_epsilon + self.initializer_range = initializer_range + self.scale_attn_weights = scale_attn_weights + self.use_cache = use_cache + self.attention_softmax_in_fp32 = attention_softmax_in_fp32 + self.scale_attention_softmax_in_fp32 = scale_attention_softmax_in_fp32 + self.multi_query = multi_query + + self.bos_token_id = bos_token_id + self.eos_token_id = eos_token_id + + super().__init__(bos_token_id=bos_token_id, eos_token_id=eos_token_id, **kwargs) diff --git a/starvector/model/gpt_bigcode/modeling_gpt_bigcode.py b/starvector/model/gpt_bigcode/modeling_gpt_bigcode.py new file mode 100644 index 0000000000000000000000000000000000000000..b8334b2cbe65bb20288849eb0cf9747f48a48f0d --- /dev/null +++ b/starvector/model/gpt_bigcode/modeling_gpt_bigcode.py @@ -0,0 +1,1502 @@ +# coding=utf-8 +# Copyright 2023 The Bigcode team and HuggingFace Inc. team. +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
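# --- Illustrative sketch (not part of the diff above). GPTBigCodeConfig keeps
# the GPT-2 style field names (n_embd, n_head, n_layer, n_positions) and maps
# the generic HuggingFace names onto them via `attribute_map`, so both
# spellings read the same value. The numbers are arbitrary example values.
from starvector.model.gpt_bigcode.configuration_gpt_bigcode import GPTBigCodeConfig

config = GPTBigCodeConfig(n_embd=2048, n_head=16, n_layer=24, multi_query=True)
assert config.hidden_size == config.n_embd == 2048
assert config.num_attention_heads == config.n_head == 16
assert config.num_hidden_layers == config.n_layer == 24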
+"""PyTorch GPTBigCode model.""" +import math +from typing import List, Optional, Tuple, Union + +import torch +import torch.nn.functional as F +import torch.utils.checkpoint +from torch import nn +from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss + +from transformers.activations import ACT2FN +from transformers.modeling_attn_mask_utils import AttentionMaskConverter +from transformers.modeling_outputs import ( + BaseModelOutputWithPastAndCrossAttentions, + CausalLMOutputWithCrossAttentions, + SequenceClassifierOutputWithPast, + TokenClassifierOutput, +) +from transformers.modeling_utils import PreTrainedModel +from transformers.pytorch_utils import is_torch_greater_or_equal_than_2_2 +from transformers.utils import ( + add_code_sample_docstrings, + add_start_docstrings, + add_start_docstrings_to_model_forward, + is_flash_attn_2_available, + is_flash_attn_greater_or_equal_2_10, + logging, +) +from starvector.model.gpt_bigcode.configuration_gpt_bigcode import GPTBigCodeConfig + + +if is_flash_attn_2_available(): + from flash_attn import flash_attn_func, flash_attn_varlen_func + from flash_attn.bert_padding import index_first_axis, pad_input, unpad_input # noqa + + +logger = logging.get_logger(__name__) + +_CHECKPOINT_FOR_DOC = "bigcode/gpt_bigcode-santacoder" +_CONFIG_FOR_DOC = "GPTBigCodeConfig" + + + +# Fused kernels +# Use separate functions for each case because conditionals prevent kernel fusion. +# TODO: Could have better fused kernels depending on scaling, dropout and head mask. +# Is it doable without writing 32 functions? +@torch.jit.script +def upcast_masked_softmax( + x: torch.Tensor, mask: torch.Tensor, mask_value: torch.Tensor, scale: float, softmax_dtype: torch.dtype +): + input_dtype = x.dtype + x = x.to(softmax_dtype) * scale + x = torch.where(mask, x, mask_value) + x = torch.nn.functional.softmax(x, dim=-1).to(input_dtype) + return x + + +@torch.jit.script +def upcast_softmax(x: torch.Tensor, scale: float, softmax_dtype: torch.dtype): + input_dtype = x.dtype + x = x.to(softmax_dtype) * scale + x = torch.nn.functional.softmax(x, dim=-1).to(input_dtype) + return x + + +@torch.jit.script +def masked_softmax(x: torch.Tensor, mask: torch.Tensor, mask_value: torch.Tensor): + x = torch.where(mask, x, mask_value) + x = torch.nn.functional.softmax(x, dim=-1) + return x + + +# Copied from transformers.models.llama.modeling_llama._get_unpad_data +def _get_unpad_data(attention_mask): + seqlens_in_batch = attention_mask.sum(dim=-1, dtype=torch.int32) + indices = torch.nonzero(attention_mask.flatten(), as_tuple=False).flatten() + max_seqlen_in_batch = seqlens_in_batch.max().item() + cu_seqlens = F.pad(torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.int32), (1, 0)) + return ( + indices, + cu_seqlens, + max_seqlen_in_batch, + ) + + +class GPTBigCodeAttention(nn.Module): + def __init__(self, config, is_cross_attention=False, layer_idx=None): + super().__init__() + self.config = config + + self.mask_value = None + self.multi_query = config.multi_query + self.embed_dim = config.hidden_size + self.num_heads = config.num_attention_heads + self.head_dim = self.embed_dim // self.num_heads + self.kv_heads = 1 if self.multi_query else self.num_heads + self.kv_dim = self.kv_heads * self.head_dim + self.split_size = self.embed_dim + self.is_causal = True + + if self.head_dim * self.num_heads != self.embed_dim: + raise ValueError( + f"`embed_dim` must be divisible by num_heads (got `embed_dim`: {self.embed_dim} and `num_heads`:" + f" {self.num_heads})." 
+ ) + + self.scale_attn_weights = config.scale_attn_weights + self.is_cross_attention = is_cross_attention + + self.layer_idx = layer_idx + self.attention_softmax_in_fp32 = config.attention_softmax_in_fp32 + self.scale_attention_softmax_in_fp32 = ( + config.scale_attention_softmax_in_fp32 and config.attention_softmax_in_fp32 + ) + self.attn_pdrop = config.attn_pdrop + + if self.is_cross_attention: + if self.multi_query: + raise NotImplementedError("Multi-Query Attention not supported for cross_attention") + + self.c_attn = nn.Linear(self.embed_dim, 2 * self.embed_dim) + self.q_attn = nn.Linear(self.embed_dim, self.embed_dim) + else: + self.c_attn = nn.Linear(self.embed_dim, self.embed_dim + 2 * self.kv_dim) + + self.c_proj = nn.Linear(self.embed_dim, self.embed_dim) + + self.attn_dropout = nn.Dropout(config.attn_pdrop) + self.resid_dropout = nn.Dropout(config.resid_pdrop) + + def _get_mask_value(self, device, dtype): + # torch.where expects a tensor. We use a cache to avoid recreating it every time. + if self.mask_value is None or self.mask_value.dtype != dtype or self.mask_value.device != device: + self.mask_value = torch.full([], torch.finfo(dtype).min, dtype=dtype, device=device) + return self.mask_value + + def _attn(self, query, key, value, attention_mask=None, head_mask=None): + dtype = query.dtype + softmax_dtype = torch.float32 if self.attention_softmax_in_fp32 else dtype + upcast = dtype != softmax_dtype + + unscale = self.layer_idx + 1 if self.scale_attention_softmax_in_fp32 and upcast else 1 + scale_factor = unscale**-1 + if self.scale_attn_weights: + scale_factor /= self.head_dim**0.5 + + # MQA models: (batch_size, query_length, num_heads * head_dim) + # MHA models: (batch_size, num_heads, query_length, head_dim) + query_shape = query.shape + batch_size = query_shape[0] + key_length = key.size(-1) + if self.multi_query: + # (batch_size, query_length, num_heads, head_dim) x (batch_size, head_dim, key_length) + # -> (batch_size, query_length, num_heads, key_length) + query_length = query_shape[1] + attn_shape = (batch_size, query_length, self.num_heads, key_length) + attn_view = (batch_size, query_length * self.num_heads, key_length) + # No copy needed for MQA 2, or when layer_past is provided. + query = query.reshape(batch_size, query_length * self.num_heads, self.head_dim) + else: + # (batch_size, num_heads, query_length, head_dim) x (batch_size, num_heads, head_dim, key_length) + # -> (batch_size, num_heads, query_length, key_length) + query_length = query_shape[2] + attn_shape = (batch_size, self.num_heads, query_length, key_length) + attn_view = (batch_size * self.num_heads, query_length, key_length) + # Always copies + query = query.reshape(batch_size * self.num_heads, query_length, self.head_dim) + # No copy when layer_past is provided. + key = key.reshape(batch_size * self.num_heads, self.head_dim, key_length) + + attn_weights = torch.empty(attn_view, device=query.device, dtype=query.dtype) + if query.device.type == "cpu": + # This is needed because of a bug in pytorch https://github.com/pytorch/pytorch/issues/80588. + # The bug was fixed in https://github.com/pytorch/pytorch/pull/96086, + # but the fix has not been released as of pytorch version 2.0.0. + attn_weights = torch.zeros_like(attn_weights) + beta = 1 + else: + beta = 0 + attn_weights = torch.baddbmm(attn_weights, query, key, beta=beta, alpha=scale_factor).view(attn_shape) + + if upcast: + # Use a fused kernel to prevent a large overhead from casting and scaling. 
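# --- Illustrative sketch (not part of the diff above). The fused TorchScript
# kernel upcast_masked_softmax defined earlier in this file should match the
# unfused "upcast, scale, mask, softmax, downcast" sequence; the shapes and
# the scale value below are arbitrary example inputs.
import torch
from starvector.model.gpt_bigcode.modeling_gpt_bigcode import upcast_masked_softmax

x = torch.randn(2, 4, 8, dtype=torch.float16)
mask = torch.rand(2, 4, 8) > 0.3                 # True = attend, False = masked out
mask_value = torch.full([], torch.finfo(torch.float32).min)
scale = 2.0

fused = upcast_masked_softmax(x, mask, mask_value, scale, torch.float32)

reference = torch.where(mask, x.float() * scale, mask_value)
reference = torch.softmax(reference, dim=-1).to(torch.float16)
assert torch.allclose(fused, reference, atol=1e-3)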
+ # Sub-optimal when the key length is not a multiple of 8. + if attention_mask is None: + attn_weights = upcast_softmax(attn_weights, unscale, softmax_dtype) + else: + mask_value = self._get_mask_value(attn_weights.device, softmax_dtype) + attn_weights = upcast_masked_softmax(attn_weights, attention_mask, mask_value, unscale, softmax_dtype) + else: + if attention_mask is not None: + mask_value = self._get_mask_value(attn_weights.device, softmax_dtype) + + # The fused kernel is very slow when the key length is not a multiple of 8, so we skip fusion. + attn_weights = torch.where(attention_mask, attn_weights, mask_value) + + attn_weights = torch.nn.functional.softmax(attn_weights, dim=-1) + + attn_weights = self.attn_dropout(attn_weights) + + # Mask heads if we want to + if head_mask is not None: + if self.multi_query: + head_mask = head_mask.transpose(1, 2) + attn_weights = attn_weights * head_mask + + if self.multi_query: + attn_output = torch.bmm(attn_weights.view(attn_view), value).view(query_shape) + else: + attn_output = torch.matmul(attn_weights, value) + + return attn_output, attn_weights + + def forward( + self, + hidden_states: torch.Tensor, + layer_past: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.Tensor] = None, + head_mask: Optional[torch.Tensor] = None, + encoder_hidden_states: Optional[torch.Tensor] = None, + encoder_attention_mask: Optional[torch.Tensor] = None, + use_cache: Optional[bool] = False, + output_attentions: Optional[bool] = False, + ) -> Union[ + Tuple[torch.Tensor, Optional[torch.Tensor]], + Tuple[torch.Tensor, Optional[torch.Tensor], Tuple[torch.Tensor, ...]], + ]: + if encoder_hidden_states is not None: + if not hasattr(self, "q_attn") or not self.is_cross_attention: + raise ValueError( + "If class is used as cross attention, the weights `q_attn` have to be defined. " + "Please make sure to instantiate class with `GPTBigCodeAttention(..., is_cross_attention=True)`." + ) + + query = self.q_attn(hidden_states) + key_value = self.c_attn(encoder_hidden_states) + attention_mask = encoder_attention_mask + elif self.multi_query: + query, key_value = self.c_attn(hidden_states).split((self.embed_dim, 2 * self.kv_dim), dim=2) + else: + # Note: We split as (self.num_heads, 3, self.head_dim) instead of (3, self.num_heads, self.head_dim), + # i.e., the memory layout is not the same as GPT2. + # This makes the concatenation with past_key_value more efficient. + query, key_value = ( + self.c_attn(hidden_states) + .view(*hidden_states.shape[:2], self.num_heads, 3 * self.head_dim) + .transpose(1, 2) + .split((self.head_dim, 2 * self.head_dim), dim=3) + ) + + if layer_past is not None: + key_value = torch.cat((layer_past, key_value), dim=-2) + present = key_value if use_cache else None + + key, value = key_value.split((self.head_dim, self.head_dim), dim=-1) + + attn_output, attn_weights = self._attn(query, key.transpose(-1, -2), value, attention_mask, head_mask) + + if not self.multi_query: + attn_output = attn_output.transpose(1, 2).reshape(hidden_states.shape) + attn_output = self.c_proj(attn_output) + attn_output = self.resid_dropout(attn_output) + + outputs = (attn_output, present) + if output_attentions: + if self.multi_query: + # Transpose to return weights in the usual format (batch_size, num_heads, query_length, key_length) + attn_weights = attn_weights.transpose(1, 2) + outputs += (attn_weights,) + + return outputs # a, present, (attentions) + + +class GPTBigCodeFlashAttention2(GPTBigCodeAttention): + """ + GPTBigCode flash attention module. 
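# --- Illustrative sketch (not part of the diff above). Shape walk-through of
# the multi_query branch in GPTBigCodeAttention._attn: with a single shared
# key/value head, all query heads are folded into one batched matmul against
# a (batch, head_dim, key_length) key. Sizes are arbitrary example values.
import torch

batch_size, query_length, key_length = 2, 5, 7
num_heads, head_dim = 4, 16

query = torch.randn(batch_size, query_length, num_heads * head_dim)
key = torch.randn(batch_size, head_dim, key_length)      # key already transposed

attn_view = (batch_size, query_length * num_heads, key_length)
attn_shape = (batch_size, query_length, num_heads, key_length)

q = query.reshape(batch_size, query_length * num_heads, head_dim)
scores = torch.empty(attn_view)
scores = torch.baddbmm(scores, q, key, beta=0, alpha=head_dim ** -0.5).view(attn_shape)
print(scores.shape)      # torch.Size([2, 5, 4, 7])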
This module inherits from `GPTBigCodeAttention` as the weights of the module + stays untouched. The only required change would be on the forward pass where it needs to correctly call the public + API of flash attention and deal with padding tokens in case the input contains any of them. + """ + + # Copied from transformers.models.llama.modeling_llama.LlamaFlashAttention2.__init__ + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + + # TODO: Should be removed once Flash Attention for RoCm is bumped to 2.1. + # flash_attn<2.1 generates top-left aligned causal mask, while what is needed here is bottom-right alignement, that was made default for flash_attn>=2.1. This attribute is used to handle this difference. Reference: https://github.com/Dao-AILab/flash-attention/releases/tag/v2.1.0. + # Beware that with flash_attn<2.1, using q_seqlen != k_seqlen (except for the case q_seqlen == 1) produces a wrong mask (top-left). + self._flash_attn_uses_top_left_mask = not is_flash_attn_greater_or_equal_2_10() + + def forward( + self, + hidden_states: torch.Tensor, + layer_past: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.Tensor] = None, + head_mask: Optional[torch.Tensor] = None, + encoder_hidden_states: Optional[torch.Tensor] = None, + encoder_attention_mask: Optional[torch.Tensor] = None, + use_cache: Optional[bool] = False, + output_attentions: Optional[bool] = False, + ) -> Union[ + Tuple[torch.Tensor, Optional[torch.Tensor]], + Tuple[torch.Tensor, Optional[torch.Tensor], Tuple[torch.Tensor, ...]], + ]: + if encoder_hidden_states is not None: + if not hasattr(self, "q_attn") or not self.is_cross_attention: + raise ValueError( + "If class is used as cross attention, the weights `q_attn` have to be defined. " + "Please make sure to instantiate class with `GPTBigCodeAttention(..., is_cross_attention=True)`." + ) + + query = self.q_attn(hidden_states) + key_value = self.c_attn(encoder_hidden_states) + attention_mask = encoder_attention_mask + elif self.multi_query: + query, key_value = self.c_attn(hidden_states).split((self.embed_dim, 2 * self.kv_dim), dim=2) + else: + # Note: We split as (self.num_heads, 3, self.head_dim) instead of (3, self.num_heads, self.head_dim), + # i.e., the memory layout is not the same as GPT2. + # This makes the concatenation with past_key_value more efficient. 
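# --- Illustrative sketch (not part of the diff above). Under multi_query the
# fused c_attn projection yields per-head queries but a single shared key/value
# pair, and the layer cache is one tensor of shape (batch, key_length,
# 2 * head_dim) that grows along dim=-2 at each decoding step. Example sizes only.
import torch

embed_dim, num_heads = 64, 4
head_dim = embed_dim // num_heads
kv_dim = head_dim                                    # one key/value head under MQA
c_attn = torch.nn.Linear(embed_dim, embed_dim + 2 * kv_dim)

hidden_states = torch.randn(2, 1, embed_dim)         # one new token per sequence
query, key_value = c_attn(hidden_states).split((embed_dim, 2 * kv_dim), dim=2)

layer_past = torch.randn(2, 9, 2 * kv_dim)           # 9 cached key/value positions
key_value = torch.cat((layer_past, key_value), dim=-2)
key, value = key_value.split((head_dim, head_dim), dim=-1)
print(query.shape, key.shape, value.shape)           # (2, 1, 64) (2, 10, 16) (2, 10, 16)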
+ query, key_value = ( + self.c_attn(hidden_states) + .view(*hidden_states.shape[:2], self.num_heads, 3 * self.head_dim) + .transpose(1, 2) + .split((self.head_dim, 2 * self.head_dim), dim=3) + ) + + if layer_past is not None: + key_value = torch.cat((layer_past, key_value), dim=-2) + present = key_value if use_cache else None + + key, value = key_value.split((self.head_dim, self.head_dim), dim=-1) + + # Flash attention requires the input to have the shape + # batch_size x seq_length x head_dim x hidden_dim + if self.multi_query: + batch_size, query_length, _ = query.shape + query = query.reshape(batch_size, query_length, self.num_heads, self.head_dim) + key = key.unsqueeze(2) + value = value.unsqueeze(2) + else: + query_length = query.shape[2] + batch_size, _, tgt, _ = key.shape + query = query.transpose(1, 2).reshape(batch_size, query_length, self.num_heads, self.head_dim) + key = key.transpose(1, 2).reshape(batch_size, tgt, self.num_heads, self.head_dim) + value = value.transpose(1, 2).reshape(batch_size, tgt, self.num_heads, self.head_dim) + + attn_dropout = self.attn_pdrop if self.training else 0.0 + + # In PEFT, usually we cast the layer norms in float32 for training stability reasons + # therefore the input hidden states gets silently casted in float32. Hence, we need + # cast them back in float16 just to be sure everything works as expected. + input_dtype = query.dtype + if input_dtype == torch.float32: + if torch.is_autocast_enabled(): + target_dtype = torch.get_autocast_gpu_dtype() + # Handle the case where the model is quantized + elif hasattr(self.config, "_pre_quantization_dtype"): + target_dtype = self.config._pre_quantization_dtype + else: + target_dtype = self.c_attn.weight.dtype + + logger.warning_once( + f"The input hidden states seems to be silently casted in float32, this might be related to" + f" the fact you have upcasted embedding or layer norm layers in float32. We will cast back the input in" + f" {target_dtype}." + ) + query = query.to(target_dtype) + key = key.to(target_dtype) + value = value.to(target_dtype) + + attn_output = self._flash_attention_forward( + query, key, value, attention_mask, query_length, dropout=attn_dropout + ) + + attn_weights_reshaped = attn_output.reshape(batch_size, query_length, self.num_heads * self.head_dim) + attn_output = self.c_proj(attn_weights_reshaped) + attn_output = self.resid_dropout(attn_output) + + outputs = (attn_output, present) + + if output_attentions: + if self.multi_query: + # Transpose to return weights in the usual format (batch_size, num_heads, query_length, key_length) + attn_weights_reshaped = attn_weights_reshaped.transpose(1, 2) + else: + attn_weights_reshaped = None + + outputs += (attn_weights_reshaped,) + + return outputs # a, present, (attentions) + + # Copied from transformers.models.llama.modeling_llama.LlamaFlashAttention2._flash_attention_forward + def _flash_attention_forward( + self, query_states, key_states, value_states, attention_mask, query_length, dropout=0.0, softmax_scale=None + ): + """ + Calls the forward method of Flash Attention - if the input hidden states contain at least one padding token + first unpad the input, then computes the attention scores and pad the final attention scores. 
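# --- Illustrative sketch (not part of the diff above). Flash attention expects
# (batch, seq_len, num_heads, head_dim); in the multi_query case the shared
# key/value keep a size-1 head axis (unsqueeze(2)) that the kernel broadcasts
# across all query heads. Sizes are arbitrary example values.
import torch

batch_size, query_length, num_heads, head_dim = 2, 6, 4, 16

query = torch.randn(batch_size, query_length, num_heads * head_dim)
key = torch.randn(batch_size, query_length, head_dim)
value = torch.randn(batch_size, query_length, head_dim)

query = query.reshape(batch_size, query_length, num_heads, head_dim)
key = key.unsqueeze(2)       # (batch, seq, 1, head_dim)
value = value.unsqueeze(2)
print(query.shape, key.shape)    # torch.Size([2, 6, 4, 16]) torch.Size([2, 6, 1, 16])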
+ + Args: + query_states (`torch.Tensor`): + Input query states to be passed to Flash Attention API + key_states (`torch.Tensor`): + Input key states to be passed to Flash Attention API + value_states (`torch.Tensor`): + Input value states to be passed to Flash Attention API + attention_mask (`torch.Tensor`): + The padding mask - corresponds to a tensor of size `(batch_size, seq_len)` where 0 stands for the + position of padding tokens and 1 for the position of non-padding tokens. + dropout (`float`): + Attention dropout + softmax_scale (`float`, *optional*): + The scaling of QK^T before applying softmax. Default to 1 / sqrt(head_dim) + """ + if not self._flash_attn_uses_top_left_mask: + causal = self.is_causal + else: + # TODO: Remove the `query_length != 1` check once Flash Attention for RoCm is bumped to 2.1. For details, please see the comment in LlamaFlashAttention2 __init__. + causal = self.is_causal and query_length != 1 + + # Contains at least one padding token in the sequence + if attention_mask is not None: + batch_size = query_states.shape[0] + query_states, key_states, value_states, indices_q, cu_seq_lens, max_seq_lens = self._upad_input( + query_states, key_states, value_states, attention_mask, query_length + ) + + cu_seqlens_q, cu_seqlens_k = cu_seq_lens + max_seqlen_in_batch_q, max_seqlen_in_batch_k = max_seq_lens + + attn_output_unpad = flash_attn_varlen_func( + query_states, + key_states, + value_states, + cu_seqlens_q=cu_seqlens_q, + cu_seqlens_k=cu_seqlens_k, + max_seqlen_q=max_seqlen_in_batch_q, + max_seqlen_k=max_seqlen_in_batch_k, + dropout_p=dropout, + softmax_scale=softmax_scale, + causal=causal, + ) + + attn_output = pad_input(attn_output_unpad, indices_q, batch_size, query_length) + else: + attn_output = flash_attn_func( + query_states, key_states, value_states, dropout, softmax_scale=softmax_scale, causal=causal + ) + + return attn_output + + # Copied from transformers.models.llama.modeling_llama.LlamaFlashAttention2._upad_input + def _upad_input(self, query_layer, key_layer, value_layer, attention_mask, query_length): + indices_k, cu_seqlens_k, max_seqlen_in_batch_k = _get_unpad_data(attention_mask) + batch_size, kv_seq_len, num_key_value_heads, head_dim = key_layer.shape + + key_layer = index_first_axis( + key_layer.reshape(batch_size * kv_seq_len, num_key_value_heads, head_dim), indices_k + ) + value_layer = index_first_axis( + value_layer.reshape(batch_size * kv_seq_len, num_key_value_heads, head_dim), indices_k + ) + if query_length == kv_seq_len: + query_layer = index_first_axis( + query_layer.reshape(batch_size * kv_seq_len, self.num_heads, head_dim), indices_k + ) + cu_seqlens_q = cu_seqlens_k + max_seqlen_in_batch_q = max_seqlen_in_batch_k + indices_q = indices_k + elif query_length == 1: + max_seqlen_in_batch_q = 1 + cu_seqlens_q = torch.arange( + batch_size + 1, dtype=torch.int32, device=query_layer.device + ) # There is a memcpy here, that is very bad. + indices_q = cu_seqlens_q[:-1] + query_layer = query_layer.squeeze(1) + else: + # The -q_len: slice assumes left padding. 
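# --- Illustrative sketch (not part of the diff above). What _get_unpad_data
# (defined near the top of this file) computes for a small right-padded mask:
# the flat indices of real tokens, cumulative sequence lengths for the varlen
# flash-attention call, and the longest sequence in the batch.
import torch
from starvector.model.gpt_bigcode.modeling_gpt_bigcode import _get_unpad_data

attention_mask = torch.tensor([[1, 1, 1, 0],
                               [1, 1, 0, 0]])
indices, cu_seqlens, max_seqlen = _get_unpad_data(attention_mask)
print(indices)      # tensor([0, 1, 2, 4, 5])
print(cu_seqlens)   # tensor([0, 3, 5], dtype=torch.int32)
print(max_seqlen)   # 3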
+ attention_mask = attention_mask[:, -query_length:] + query_layer, indices_q, cu_seqlens_q, max_seqlen_in_batch_q = unpad_input(query_layer, attention_mask) + + return ( + query_layer, + key_layer, + value_layer, + indices_q, + (cu_seqlens_q, cu_seqlens_k), + (max_seqlen_in_batch_q, max_seqlen_in_batch_k), + ) + + +class GPTBigCodeSdpaAttention(GPTBigCodeAttention): + def _attn(self, query, key, value, attention_mask=None, head_mask=None): + if head_mask is not None: + # The super dispatch is done in the forward. + raise ValueError( + "PyTorch SDPA does not support head_mask. Please open an issue in Transformers repository." + ) + + scale = None + if not self.scale_attn_weights: + scale = 1 + + # MQA models: (batch_size, query_length, num_heads * head_dim) + # MHA models: (batch_size, num_heads, query_length, head_dim) + query_shape = query.shape + batch_size = query_shape[0] + key.shape[-2] + + if self.multi_query: + query_length = query_shape[1] + + # SDPA requires the dimension [..., sequence_length, head_dim]. + query = query.view(batch_size, query_length, self.num_heads, self.head_dim).transpose(1, 2) + + # Without these unsqueeze, SDPA complains as the query and key/value have a different number of dimensions. + key = key.unsqueeze(1) + value = value.unsqueeze(1) + + # Although these expand are not numerically useful, PyTorch can not dispatch to memory-efficient backend + # and flash attention backend (No available kernel. Aborting execution.) from the shapes + # query = [batch_size, num_heads, query_length, head_dim] + # key = [batch_size, 1, past_length, head_dim] + # value = [batch_size, 1, past_length, head_dim] + # + # torch==2.1.2 is bugged with non-contiguous inputs with custom attn_mask (https://github.com/pytorch/pytorch/issues/112577), hence the check. + if is_torch_greater_or_equal_than_2_2: + key = key.expand(-1, self.num_heads, -1, -1) + value = value.expand(-1, self.num_heads, -1, -1) + else: + query_length = query_shape[-1] + + # See the comment above. + if query.device.type == "cuda" and attention_mask is not None: + query = query.contiguous() + key = key.contiguous() + value = value.contiguous() + + sdpa_result = torch.nn.functional.scaled_dot_product_attention( + query, + key, + value, + attn_mask=attention_mask, + dropout_p=self.attn_pdrop if self.training else 0.0, + # The query_length > 1 is necessary to match with AttentionMaskConverter.to_causal_4d that does not create a causal mask in case query_length == 1. 
+ is_causal=self.is_causal and attention_mask is None and query_length > 1, + scale=scale, + ) + + if self.multi_query: + # (batch_size, num_heads, seq_len, head_dim) --> (batch_size, seq_len, num_heads, head_dim) + sdpa_result = sdpa_result.transpose(1, 2) + + # Reshape is kind of expensive here, as it does a memory copy, + # but I did not manage to make away without it (logits do not match when using view) + # (batch_size, seq_len, num_heads, head_dim) --> (batch_size, seq_len, num_heads * head_dim) + sdpa_result = sdpa_result.reshape(query_shape) + + return sdpa_result, None + + def forward( + self, + hidden_states: torch.Tensor, + layer_past: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.Tensor] = None, + head_mask: Optional[torch.Tensor] = None, + encoder_hidden_states: Optional[torch.Tensor] = None, + encoder_attention_mask: Optional[torch.Tensor] = None, + use_cache: Optional[bool] = False, + output_attentions: Optional[bool] = False, + ) -> Union[ + Tuple[torch.Tensor, Optional[torch.Tensor]], + Tuple[torch.Tensor, Optional[torch.Tensor], Tuple[torch.Tensor, ...]], + ]: + if encoder_hidden_states is not None: + if not hasattr(self, "q_attn") or not self.is_cross_attention: + raise ValueError( + "If class is used as cross attention, the weights `q_attn` have to be defined. " + "Please make sure to instantiate class with `GPTBigCodeAttention(..., is_cross_attention=True)`." + ) + + query = self.q_attn(hidden_states) + key_value = self.c_attn(encoder_hidden_states) + attention_mask = encoder_attention_mask + elif self.multi_query: + query, key_value = self.c_attn(hidden_states).split((self.embed_dim, 2 * self.kv_dim), dim=2) + else: + # Note: We split as (self.num_heads, 3, self.head_dim) instead of (3, self.num_heads, self.head_dim), + # i.e., the memory layout is not the same as GPT2. + # This makes the concatenation with past_key_value more efficient. + query, key_value = ( + self.c_attn(hidden_states) + .view(*hidden_states.shape[:2], self.num_heads, 3 * self.head_dim) + .transpose(1, 2) + .split((self.head_dim, 2 * self.head_dim), dim=3) + ) + + if layer_past is not None: + key_value = torch.cat((layer_past, key_value), dim=-2) + present = key_value if use_cache else None + + key, value = key_value.split((self.head_dim, self.head_dim), dim=-1) + + if not output_attentions and head_mask is None: + # Difference with the original implementation: there is no need to transpose the key here, + # as SDPA expects seq_length to be at index -2 for the key as well + attn_output, attn_weights = self._attn(query, key, value, attention_mask, head_mask) + else: + # TODO: Improve this warning with e.g. `model.config._attn_implementation = "manual"` once this is implemented. + logger.warning_once( + "GPTBigCodeModel is using GPTBigCodeSdpaAttention, but `torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True` and `head_mask` not None." + ' Falling back to the manual attention implementation, but specifying the manual implementation will be required from Transformers version v5.0.0 onwards. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.' 
+ ) + attn_output, attn_weights = super()._attn(query, key.transpose(-1, -2), value, attention_mask, head_mask) + + if not self.multi_query: + attn_output = attn_output.transpose(1, 2).reshape(hidden_states.shape) + attn_output = self.c_proj(attn_output) + attn_output = self.resid_dropout(attn_output) + + outputs = (attn_output, present) + if output_attentions: + if self.multi_query: + # Transpose to return weights in the usual format (batch_size, num_heads, query_length, key_length) + attn_weights = attn_weights.transpose(1, 2) + outputs += (attn_weights,) + + return outputs + + +class GPTBigCodeMLP(nn.Module): + def __init__(self, intermediate_size, config): + super().__init__() + embed_dim = config.hidden_size + self.c_fc = nn.Linear(embed_dim, intermediate_size) + self.c_proj = nn.Linear(intermediate_size, embed_dim) + self.act = ACT2FN[config.activation_function] + self.dropout = nn.Dropout(config.resid_pdrop) + + # Copied from transformers.models.gpt2.modeling_gpt2.GPT2MLP.forward + def forward(self, hidden_states: Optional[Tuple[torch.FloatTensor]]) -> torch.FloatTensor: + hidden_states = self.c_fc(hidden_states) + hidden_states = self.act(hidden_states) + hidden_states = self.c_proj(hidden_states) + hidden_states = self.dropout(hidden_states) + return hidden_states + + +GPTBIGCODE_ATTENTION_CLASSES = { + "eager": GPTBigCodeAttention, + "flash_attention_2": GPTBigCodeFlashAttention2, + "sdpa": GPTBigCodeSdpaAttention, +} + + +class GPTBigCodeBlock(nn.Module): + def __init__(self, config, layer_idx=None): + super().__init__() + hidden_size = config.hidden_size + self.inner_dim = config.n_inner if config.n_inner is not None else 4 * hidden_size + + self.ln_1 = nn.LayerNorm(hidden_size, eps=config.layer_norm_epsilon) + + self.attn = GPTBIGCODE_ATTENTION_CLASSES[config._attn_implementation](config, layer_idx=layer_idx) + + self.ln_2 = nn.LayerNorm(hidden_size, eps=config.layer_norm_epsilon) + + if config.add_cross_attention: + if config.multi_query: + raise NotImplementedError("Cross-attention not implemented for MQA") + + self.crossattention = GPTBIGCODE_ATTENTION_CLASSES[config._attn_implementation]( + config, is_cross_attention=True, layer_idx=layer_idx + ) + + self.ln_cross_attn = nn.LayerNorm(hidden_size, eps=config.layer_norm_epsilon) + + self.mlp = GPTBigCodeMLP(self.inner_dim, config) + + def forward( + self, + hidden_states: Optional[Tuple[torch.Tensor]], + layer_past: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.Tensor] = None, + head_mask: Optional[torch.Tensor] = None, + encoder_hidden_states: Optional[torch.Tensor] = None, + encoder_attention_mask: Optional[torch.Tensor] = None, + use_cache: Optional[bool] = False, + output_attentions: Optional[bool] = False, + ) -> Union[ + Tuple[torch.Tensor], Tuple[torch.Tensor, torch.Tensor], Tuple[torch.Tensor, torch.Tensor, torch.Tensor] + ]: + residual = hidden_states + hidden_states = self.ln_1(hidden_states) + attn_outputs = self.attn( + hidden_states, + layer_past=layer_past, + attention_mask=attention_mask, + head_mask=head_mask, + use_cache=use_cache, + output_attentions=output_attentions, + ) + attn_output = attn_outputs[0] # output_attn: a, present, (attentions) + outputs = attn_outputs[1:] + # residual connection + hidden_states = attn_output + residual + + if encoder_hidden_states is not None: + # add one self-attention block for cross-attention + if not hasattr(self, "crossattention"): + raise ValueError( + f"If `encoder_hidden_states` are passed, {self} has to be instantiated with " + "cross-attention 
layers by setting `config.add_cross_attention=True`" + ) + residual = hidden_states + hidden_states = self.ln_cross_attn(hidden_states) + cross_attn_outputs = self.crossattention( + hidden_states, + attention_mask=attention_mask, + head_mask=head_mask, + encoder_hidden_states=encoder_hidden_states, + encoder_attention_mask=encoder_attention_mask, + output_attentions=output_attentions, + ) + attn_output = cross_attn_outputs[0] + # residual connection + hidden_states = residual + attn_output + outputs = outputs + cross_attn_outputs[2:] # add cross attentions if we output attention weights + + residual = hidden_states + hidden_states = self.ln_2(hidden_states) + feed_forward_hidden_states = self.mlp(hidden_states) + # residual connection + hidden_states = residual + feed_forward_hidden_states + + if use_cache: + outputs = (hidden_states,) + outputs + else: + outputs = (hidden_states,) + outputs[1:] + + return outputs # hidden_states, present, (attentions, cross_attentions) + + +class GPTBigCodePreTrainedModel(PreTrainedModel): + """ + An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained + models. + """ + + config_class = GPTBigCodeConfig + base_model_prefix = "transformer" + supports_gradient_checkpointing = True + _no_split_modules = ["GPTBigCodeBlock"] + _skip_keys_device_placement = "past_key_values" + _supports_flash_attn_2 = True + _supports_sdpa = True + + def __init__(self, *inputs, **kwargs): + super().__init__(*inputs, **kwargs) + + def _init_weights(self, module): + """Initialize the weights.""" + if isinstance(module, (GPTBigCodeMLP, GPTBigCodeAttention)): + # Reinitialize selected weights subject to the OpenAI GPT-2 Paper Scheme: + # > A modified initialization which accounts for the accumulation on the residual path with model depth. Scale + # > the weights of residual layers at initialization by a factor of 1/√N where N is the # of residual layers. + # > -- GPT-2 :: https://openai.com/blog/better-language-models/ + # + # Reference (Megatron-LM): https://github.com/NVIDIA/Megatron-LM/blob/main/megatron/model/gpt_model.py + module.c_proj.weight.data.normal_( + mean=0.0, std=(self.config.initializer_range / math.sqrt(2 * self.config.n_layer)) + ) + module.c_proj._is_hf_initialized = True + elif isinstance(module, nn.Linear): + # Slightly different from the TF version which uses truncated_normal for initialization + # cf https://github.com/pytorch/pytorch/pull/5617 + module.weight.data.normal_(mean=0.0, std=self.config.initializer_range) + if module.bias is not None: + module.bias.data.zero_() + elif isinstance(module, nn.Embedding): + module.weight.data.normal_(mean=0.0, std=self.config.initializer_range) + if module.padding_idx is not None: + module.weight.data[module.padding_idx].zero_() + elif isinstance(module, nn.LayerNorm): + module.bias.data.zero_() + module.weight.data.fill_(1.0) + + +GPT_BIGCODE_START_DOCSTRING = r""" + + This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the + library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads + etc.) + + This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass. + Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage + and behavior. + + Parameters: + config ([`GPTBigCodeConfig`]): Model configuration class with all the parameters of the model. 
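# --- Illustrative sketch (not part of the diff above). _init_weights applies
# the GPT-2 residual-scaling rule: the c_proj output projections on the
# attention and MLP branches are initialized with a standard deviation shrunk
# by 1/sqrt(2 * n_layer), because every layer adds two residual contributions.
import math

initializer_range = 0.02      # config default
n_layer = 12                  # config default
c_proj_std = initializer_range / math.sqrt(2 * n_layer)
print(round(c_proj_std, 5))   # 0.00408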
+ Initializing with a config file does not load the weights associated with the model, only the + configuration. Check out the [`~PreTrainedModel.from_pretrained`] method to load the model weights. +""" + +GPT_BIGCODE_INPUTS_DOCSTRING = r""" + Args: + input_ids (`torch.Tensor` of shape `(batch_size, input_ids_length)`): + `input_ids_length` = `sequence_length` if `past_key_values` is `None` else + `past_key_values[0][0].shape[-2]` (`sequence_length` of input past key value states). Indices of input + sequence tokens in the vocabulary. + + If `past_key_values` is used, only `input_ids` that do not have their past calculated should be passed as + `input_ids`. + + Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and + [`PreTrainedTokenizer.__call__`] for details. + + [What are input IDs?](../glossary#input-ids) + past_key_values (`Tuple[torch.Tensor]` of length `config.n_layers`): + Contains precomputed hidden-states (key and values in the attention blocks) as computed by the model (see + `past_key_values` output below). Can be used to speed up sequential decoding. The `input_ids` which have + their past given to this model should not be passed as `input_ids` as they have already been computed. + attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*): + Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`: + + - 1 for tokens that are **not masked**, + - 0 for tokens that are **masked**. + + If `past_key_values` is used, `attention_mask` needs to contain the masking strategy that was used for + `past_key_values`. In other words, the `attention_mask` always has to have the length: + `len(past_key_values) + len(input_ids)` + + [What are attention masks?](../glossary#attention-mask) + token_type_ids (`torch.Tensor` of shape `(batch_size, input_ids_length)`, *optional*): + Segment token indices to indicate first and second portions of the inputs. Indices are selected in `[0, + 1]`: + + - 0 corresponds to a *sentence A* token, + - 1 corresponds to a *sentence B* token. + + [What are token type IDs?](../glossary#token-type-ids) + position_ids (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*): + Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0, + config.max_position_embeddings - 1]`. + + [What are position IDs?](../glossary#position-ids) + head_mask (`torch.Tensor` of shape `(num_heads,)` or `(num_layers, num_heads)`, *optional*): + Mask to nullify selected heads of the self-attention modules. Mask values selected in `[0, 1]`: + + - 1 indicates the head is **not masked**, + - 0 indicates the head is **masked**. + + inputs_embeds (`torch.Tensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*): + Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This + is useful if you want more control over how to convert `input_ids` indices into associated vectors than the + model's internal embedding lookup matrix. + + If `past_key_values` is used, optionally only the last `inputs_embeds` have to be input (see + `past_key_values`). + use_cache (`bool`, *optional*): + If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see + `past_key_values`). + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned + tensors for more detail. 
+ output_hidden_states (`bool`, *optional*): + Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for + more detail. + return_dict (`bool`, *optional*): + Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. +""" + + +@add_start_docstrings( + "The bare GPT_BIGCODE Model transformer outputting raw hidden-states without any specific head on top.", + GPT_BIGCODE_START_DOCSTRING, +) +class GPTBigCodeModel(GPTBigCodePreTrainedModel): + def __init__(self, config): + super().__init__(config) + self.multi_query = config.multi_query + self.embed_dim = config.hidden_size + + self.wte = nn.Embedding(config.vocab_size, self.embed_dim) + self.wpe = nn.Embedding(config.max_position_embeddings, self.embed_dim) + + self.drop = nn.Dropout(config.embd_pdrop) + self.h = nn.ModuleList([GPTBigCodeBlock(config, layer_idx=i) for i in range(config.num_hidden_layers)]) + self.ln_f = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_epsilon) + + max_positions = config.max_position_embeddings + self.register_buffer( + "bias", torch.tril(torch.ones((max_positions, max_positions), dtype=torch.bool)), persistent=False + ) + + self.gradient_checkpointing = False + + self._use_sdpa = config._attn_implementation == "sdpa" + self._use_flash_attention_2 = config._attn_implementation == "flash_attention_2" + + # Initialize weights and apply final processing + self.post_init() + + def get_input_embeddings(self): + return self.wte + + def set_input_embeddings(self, new_embeddings): + self.wte = new_embeddings + + @add_start_docstrings_to_model_forward(GPT_BIGCODE_INPUTS_DOCSTRING) + @add_code_sample_docstrings( + checkpoint=_CHECKPOINT_FOR_DOC, + output_type=BaseModelOutputWithPastAndCrossAttentions, + config_class=_CONFIG_FOR_DOC, + ) + def forward( + self, + input_ids: Optional[torch.Tensor] = None, + past_key_values: Optional[List[torch.Tensor]] = None, + attention_mask: Optional[torch.Tensor] = None, + token_type_ids: Optional[torch.Tensor] = None, + position_ids: Optional[torch.Tensor] = None, + head_mask: Optional[torch.Tensor] = None, + inputs_embeds: Optional[torch.Tensor] = None, + encoder_hidden_states: Optional[torch.Tensor] = None, + encoder_attention_mask: Optional[torch.Tensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[Tuple, BaseModelOutputWithPastAndCrossAttentions]: + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + use_cache = use_cache if use_cache is not None else self.config.use_cache + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + if input_ids is not None and inputs_embeds is not None: + raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time") + elif input_ids is not None: + self.warn_if_padding_and_no_attention_mask(input_ids, attention_mask) + input_shape = input_ids.size() + input_ids = input_ids.view(-1, input_shape[-1]) + batch_size = input_ids.shape[0] + elif inputs_embeds is not None: + input_shape = inputs_embeds.size()[:-1] + batch_size = inputs_embeds.shape[0] + else: + raise ValueError("You have to specify either input_ids or inputs_embeds") + + if batch_size <= 0: + raise ValueError("batch_size has to be defined and > 0") 
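# --- Illustrative sketch (not part of the diff above). The "bias" buffer
# registered in GPTBigCodeModel.__init__ is a lower-triangular boolean matrix;
# the forward pass slices out the rows belonging to the new query positions.
# Example: 2 cached tokens followed by 3 new query tokens.
import torch

max_positions = 8
bias = torch.tril(torch.ones((max_positions, max_positions), dtype=torch.bool))

past_length, query_length = 2, 3
key_length = past_length + query_length
self_attention_mask = bias[None, key_length - query_length : key_length, :key_length]
print(self_attention_mask.int()[0])
# tensor([[1, 1, 1, 0, 0],
#         [1, 1, 1, 1, 0],
#         [1, 1, 1, 1, 1]], dtype=torch.int32)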
+ + device = input_ids.device if input_ids is not None else inputs_embeds.device + + if token_type_ids is not None: + token_type_ids = token_type_ids.view(-1, input_shape[-1]) + + if past_key_values is None: + past_length = 0 + past_key_values = tuple([None] * len(self.h)) + else: + past_length = past_key_values[0].size(-2) + + if attention_mask is not None and len(attention_mask.shape) == 2 and position_ids is None: + # create position_ids on the fly for batch generation + position_ids = attention_mask.long().cumsum(-1) - 1 + position_ids.masked_fill_(attention_mask == 0, 1) + if past_length > 0: + position_ids = position_ids[:, past_length : input_shape[-1] + past_length :] + elif position_ids is None: + position_ids = torch.arange(past_length, input_shape[-1] + past_length, dtype=torch.long, device=device) + position_ids = position_ids.unsqueeze(0) + + # Self-attention mask. + query_length = input_shape[-1] + key_length = past_length + query_length + self_attention_mask = self.bias[None, key_length - query_length : key_length, :key_length] + + if self._use_flash_attention_2: + # 2d mask is passed through the layers + attention_mask = attention_mask.bool() if (attention_mask is not None and 0 in attention_mask) else None + encoder_attention_mask = ( + encoder_attention_mask.bool() + if (encoder_attention_mask is not None and 0 in encoder_attention_mask) + else None + ) + else: + # 4d mask is passed through the layers + if attention_mask is not None: + self_attention_mask = self_attention_mask * attention_mask.view(batch_size, 1, -1).to( + dtype=torch.bool, device=self_attention_mask.device + ) + + # MQA models: (batch_size, query_length, n_heads, key_length) + # MHA models: (batch_size, n_heads, query_length, key_length) + self_attention_mask = self_attention_mask.unsqueeze(2 if self.multi_query else 1) + + if self._use_sdpa and head_mask is None and not output_attentions: + # SDPA with a custom mask is much faster in fp16/fp32 dtype rather than bool. Cast here to floating point instead of at every layer. + dtype = self.wte.weight.dtype + min_dtype = torch.finfo(dtype).min + self_attention_mask = torch.where( + self_attention_mask, + torch.full([], 0.0, dtype=dtype, device=self_attention_mask.device), + torch.full([], min_dtype, dtype=dtype, device=self_attention_mask.device), + ) + + # output_attentions=True can not be supported when using SDPA, and we fall back on + # the manual implementation that requires a 4D causal mask in all cases. + if self.multi_query: + # gpt_bigcode using MQA has the bad taste to use a causal mask with shape + # [batch_size, target_length, 1, source_length], not compatible with SDPA, hence this transpose. + self_attention_mask = self_attention_mask.transpose(1, 2) + + if query_length > 1 and attention_mask is not None and attention_mask.device.type == "cuda": + # From PyTorch 2.1 onwards, F.scaled_dot_product_attention with the memory-efficient attention backend + # produces nans if sequences are completely unattended in the attention mask. 
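# --- Illustrative sketch (not part of the diff above). How the forward pass
# derives position_ids from a left-padded 2D attention mask when none are
# given: cumsum yields 0-based positions for real tokens, and padded slots are
# clamped to 1 so the position-embedding lookup stays in range. Example mask.
import torch

attention_mask = torch.tensor([[0, 0, 1, 1, 1],
                               [1, 1, 1, 1, 1]])
position_ids = attention_mask.long().cumsum(-1) - 1
position_ids.masked_fill_(attention_mask == 0, 1)
print(position_ids)
# tensor([[1, 1, 0, 1, 2],
#         [0, 1, 2, 3, 4]])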
Details: https://github.com/pytorch/pytorch/issues/110213 + self_attention_mask = AttentionMaskConverter._unmask_unattended( + self_attention_mask, min_dtype=min_dtype + ) + + attention_mask = self_attention_mask + + # If a 2D or 3D attention mask is provided for the cross-attention + # we need to make broadcastable to [batch_size, num_heads, seq_length, seq_length] + if ( + self.config.add_cross_attention + and encoder_hidden_states is not None + and encoder_attention_mask is not None + ): + if encoder_attention_mask.dim() == 2: + encoder_attention_mask.unsqueeze(1) + assert encoder_attention_mask.dim() == 3 + encoder_attention_mask = encoder_attention_mask.bool().unsqueeze(2 if self.multi_query else 1) + else: + encoder_attention_mask = None + + # Prepare head mask if needed + # 1.0 in head_mask indicate we keep the head + # attention_probs has shape bsz x n_heads x N x N + # head_mask has shape n_layer x batch x n_heads x N x N + head_mask = self.get_head_mask(head_mask, self.config.n_layer) + + if inputs_embeds is None: + inputs_embeds = self.wte(input_ids) + position_embeds = self.wpe(position_ids) + hidden_states = inputs_embeds + position_embeds + + if token_type_ids is not None: + token_type_embeds = self.wte(token_type_ids) + hidden_states = hidden_states + token_type_embeds + + hidden_states = self.drop(hidden_states) + + output_shape = input_shape + (hidden_states.size(-1),) + + presents = [] if use_cache else None + all_self_attentions = () if output_attentions else None + all_cross_attentions = () if output_attentions and self.config.add_cross_attention else None + all_hidden_states = () if output_hidden_states else None + for i, (block, layer_past) in enumerate(zip(self.h, past_key_values)): + if output_hidden_states: + all_hidden_states = all_hidden_states + (hidden_states,) + + if self.gradient_checkpointing and self.training: + outputs = self._gradient_checkpointing_func( + block.__call__, + hidden_states, + None, + attention_mask, + head_mask[i], + encoder_hidden_states, + encoder_attention_mask, + use_cache, + output_attentions, + ) + else: + outputs = block( + hidden_states, + layer_past=layer_past, + attention_mask=attention_mask, + head_mask=head_mask[i], + encoder_hidden_states=encoder_hidden_states, + encoder_attention_mask=encoder_attention_mask, + use_cache=use_cache, + output_attentions=output_attentions, + ) + + hidden_states = outputs[0] + if use_cache: + presents.append(outputs[1]) + + if output_attentions: + all_self_attentions = all_self_attentions + (outputs[2 if use_cache else 1],) + if self.config.add_cross_attention: + all_cross_attentions = all_cross_attentions + (outputs[3 if use_cache else 2],) + + hidden_states = self.ln_f(hidden_states) + + hidden_states = hidden_states.view(output_shape) + # Add last hidden state + if output_hidden_states: + all_hidden_states = all_hidden_states + (hidden_states,) + + if not return_dict: + return tuple( + v + for v in [hidden_states, presents, all_hidden_states, all_self_attentions, all_cross_attentions] + if v is not None + ) + + return BaseModelOutputWithPastAndCrossAttentions( + last_hidden_state=hidden_states, + past_key_values=presents, + hidden_states=all_hidden_states, + attentions=all_self_attentions, + cross_attentions=all_cross_attentions, + ) + + +@add_start_docstrings( + """ + The GPT_BIGCODE Model transformer with a language modeling head on top (linear layer with weights tied to the input + embeddings). 
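# --- Illustrative usage sketch (not part of the diff above). A tiny randomly
# initialized GPTBigCodeModel run end to end; the hyperparameters are small
# example values, not StarVector settings, and the "sdpa"/"eager" dispatch is
# left to whatever the installed transformers version selects by default.
import torch
from starvector.model.gpt_bigcode.configuration_gpt_bigcode import GPTBigCodeConfig
from starvector.model.gpt_bigcode.modeling_gpt_bigcode import GPTBigCodeModel

config = GPTBigCodeConfig(vocab_size=100, n_positions=32, n_embd=64,
                          n_layer=2, n_head=4, multi_query=True)
model = GPTBigCodeModel(config).eval()

input_ids = torch.randint(0, 100, (2, 10))
with torch.no_grad():
    out = model(input_ids, use_cache=True)
print(out.last_hidden_state.shape)   # torch.Size([2, 10, 64])
print(len(out.past_key_values))      # 2 -- one cache tensor per layer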
+ """, + GPT_BIGCODE_START_DOCSTRING, +) +class GPTBigCodeForCausalLM(GPTBigCodePreTrainedModel): + _tied_weights_keys = ["lm_head.weight"] + + def __init__(self, config): + super().__init__(config) + self.transformer = GPTBigCodeModel(config) + self.lm_head = nn.Linear(config.n_embd, config.vocab_size, bias=False) + + # Initialize weights and apply final processing + self.post_init() + + def get_output_embeddings(self): + return self.lm_head + + def set_output_embeddings(self, new_embeddings): + self.lm_head = new_embeddings + + def prepare_inputs_for_generation(self, input_ids, past_key_values=None, inputs_embeds=None, **kwargs): + token_type_ids = kwargs.get("token_type_ids", None) + # Omit tokens covered by past_key_values + if past_key_values: + if self.config.multi_query: + past_length = past_key_values[0].shape[1] + else: + past_length = past_key_values[0].shape[2] + + # Some generation methods already pass only the last input ID + if input_ids.shape[1] > past_length: + remove_prefix_length = past_length + else: + # Default to old behavior: keep only final ID + remove_prefix_length = input_ids.shape[1] - 1 + + input_ids = input_ids[:, remove_prefix_length:] + if token_type_ids is not None: + token_type_ids = token_type_ids[:, -input_ids.shape[1] :] + + attention_mask = kwargs.get("attention_mask", None) + position_ids = kwargs.get("position_ids", None) + + if attention_mask is not None and position_ids is None: + # create position_ids on the fly for batch generation + position_ids = attention_mask.long().cumsum(-1) - 1 + position_ids.masked_fill_(attention_mask == 0, 1) + if past_key_values: + position_ids = position_ids[:, -input_ids.shape[1] :] + else: + position_ids = None + + # if `inputs_embeds` are passed, we only want to use them in the 1st generation step + if inputs_embeds is not None and past_key_values is None: + model_inputs = {"inputs_embeds": inputs_embeds} + else: + model_inputs = {"input_ids": input_ids} + + model_inputs.update( + { + "past_key_values": past_key_values, + "use_cache": kwargs.get("use_cache"), + "position_ids": position_ids, + "attention_mask": attention_mask, + "token_type_ids": token_type_ids, + } + ) + return model_inputs + + @add_start_docstrings_to_model_forward(GPT_BIGCODE_INPUTS_DOCSTRING) + @add_code_sample_docstrings( + checkpoint=_CHECKPOINT_FOR_DOC, + output_type=CausalLMOutputWithCrossAttentions, + config_class=_CONFIG_FOR_DOC, + ) + def forward( + self, + input_ids: Optional[torch.Tensor] = None, + past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None, + attention_mask: Optional[torch.Tensor] = None, + token_type_ids: Optional[torch.Tensor] = None, + position_ids: Optional[torch.Tensor] = None, + head_mask: Optional[torch.Tensor] = None, + inputs_embeds: Optional[torch.Tensor] = None, + encoder_hidden_states: Optional[torch.Tensor] = None, + encoder_attention_mask: Optional[torch.Tensor] = None, + labels: Optional[torch.Tensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[Tuple, CausalLMOutputWithCrossAttentions]: + r""" + labels (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*): + Labels for language modeling. Note that the labels **are shifted** inside the model, i.e. 
you can set + `labels = input_ids` Indices are selected in `[-100, 0, ..., config.vocab_size]` All labels set to `-100` + are ignored (masked), the loss is only computed for labels in `[0, ..., config.vocab_size]` + """ + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + transformer_outputs = self.transformer( + input_ids, + past_key_values=past_key_values, + attention_mask=attention_mask, + token_type_ids=token_type_ids, + position_ids=position_ids, + head_mask=head_mask, + inputs_embeds=inputs_embeds, + encoder_hidden_states=encoder_hidden_states, + encoder_attention_mask=encoder_attention_mask, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + hidden_states = transformer_outputs[0] + + lm_logits = self.lm_head(hidden_states) + + loss = None + if labels is not None: + # Shift so that tokens < n predict n + shift_logits = lm_logits[..., :-1, :].contiguous() + shift_labels = labels[..., 1:].contiguous().to(shift_logits.device) + # Flatten the tokens + loss_fct = CrossEntropyLoss() + loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1)) + + if not return_dict: + output = (lm_logits,) + transformer_outputs[1:] + return ((loss,) + output) if loss is not None else output + + return CausalLMOutputWithCrossAttentions( + loss=loss, + logits=lm_logits, + past_key_values=transformer_outputs.past_key_values, + hidden_states=transformer_outputs.hidden_states, + attentions=transformer_outputs.attentions, + cross_attentions=transformer_outputs.cross_attentions, + ) + + @staticmethod + def _reorder_cache( + past_key_values: Tuple[Tuple[torch.Tensor]], beam_idx: torch.Tensor + ) -> Tuple[Tuple[torch.Tensor]]: + """ + This function is used to re-order the `past_key_values` cache if [`~PreTrainedModel.beam_search`] or + [`~PreTrainedModel.beam_sample`] is called. This is required to match `past_key_values` with the correct + beam_idx at every generation step. + """ + return tuple(layer_past.index_select(0, beam_idx.to(layer_past.device)) for layer_past in past_key_values) + + +@add_start_docstrings( + """ + The GPTBigCode Model transformer with a sequence classification head on top (linear layer). + + [`GPTBigCodeForSequenceClassification`] uses the last token in order to do the classification, as other causal + models (e.g. GPT-1) do. + + Since it does classification on the last token, it requires to know the position of the last token. If a + `pad_token_id` is defined in the configuration, it finds the last token that is not a padding token in each row. If + no `pad_token_id` is defined, it simply takes the last value in each row of the batch. Since it cannot guess the + padding tokens when `inputs_embeds` are passed instead of `input_ids`, it does the same (take the last value in + each row of the batch). 
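The labels docstring above says the labels are shifted inside the model. The standalone snippet below (not part of the patch, with made-up values) shows exactly what that shift and the -100 masking do: the logits at position t are scored against the token at position t+1, and -100 targets are ignored by the loss.

import torch
from torch.nn import CrossEntropyLoss

batch, seq_len, vocab = 1, 5, 11
lm_logits = torch.randn(batch, seq_len, vocab)
labels = torch.tensor([[7, 3, 9, -100, 4]])          # -100 marks positions to ignore

shift_logits = lm_logits[..., :-1, :].contiguous()   # drop the last step's prediction
shift_labels = labels[..., 1:].contiguous()          # drop the first token as a target

loss = CrossEntropyLoss()(shift_logits.view(-1, vocab), shift_labels.view(-1))
print(loss)                                          # scalar loss over the non-masked targets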
+ """, + GPT_BIGCODE_START_DOCSTRING, +) +class GPTBigCodeForSequenceClassification(GPTBigCodePreTrainedModel): + def __init__(self, config): + super().__init__(config) + self.num_labels = config.num_labels + self.transformer = GPTBigCodeModel(config) + self.score = nn.Linear(config.n_embd, self.num_labels, bias=False) + + # Initialize weights and apply final processing + self.post_init() + + @add_start_docstrings_to_model_forward(GPT_BIGCODE_INPUTS_DOCSTRING) + def forward( + self, + input_ids: Optional[torch.Tensor] = None, + past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None, + attention_mask: Optional[torch.Tensor] = None, + token_type_ids: Optional[torch.Tensor] = None, + position_ids: Optional[torch.Tensor] = None, + head_mask: Optional[torch.Tensor] = None, + inputs_embeds: Optional[torch.Tensor] = None, + labels: Optional[torch.Tensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[Tuple, SequenceClassifierOutputWithPast]: + r""" + labels (`torch.Tensor` of shape `(batch_size,)`, *optional*): + Labels for computing the sequence classification/regression loss. Indices should be in `[0, ..., + config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If + `config.num_labels > 1` a classification loss is computed (Cross-Entropy). + """ + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + transformer_outputs = self.transformer( + input_ids, + past_key_values=past_key_values, + attention_mask=attention_mask, + token_type_ids=token_type_ids, + position_ids=position_ids, + head_mask=head_mask, + inputs_embeds=inputs_embeds, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + hidden_states = transformer_outputs[0] + logits = self.score(hidden_states) + + if input_ids is not None: + batch_size, sequence_length = input_ids.shape[:2] + else: + batch_size, sequence_length = inputs_embeds.shape[:2] + + assert ( + self.config.pad_token_id is not None or batch_size == 1 + ), "Cannot handle batch sizes > 1 if no padding token is defined." + if self.config.pad_token_id is None: + sequence_lengths = -1 + else: + if input_ids is not None: + # if no pad token found, use modulo instead of reverse indexing for ONNX compatibility + sequence_lengths = torch.eq(input_ids, self.config.pad_token_id).int().argmax(-1) - 1 + sequence_lengths = sequence_lengths % input_ids.shape[-1] + sequence_lengths = sequence_lengths.to(logits.device) + else: + sequence_lengths = -1 + logger.warning( + f"{self.__class__.__name__} will not detect padding tokens in `inputs_embeds`. 
Results may be " + "unexpected if using padding tokens in conjunction with `inputs_embeds.`" + ) + + pooled_logits = logits[torch.arange(batch_size, device=logits.device), sequence_lengths] + + loss = None + if labels is not None: + labels = labels.to(logits.device) + + if self.config.problem_type is None: + if self.num_labels == 1: + self.config.problem_type = "regression" + elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int): + self.config.problem_type = "single_label_classification" + else: + self.config.problem_type = "multi_label_classification" + + if self.config.problem_type == "regression": + loss_fct = MSELoss() + if self.num_labels == 1: + loss = loss_fct(pooled_logits.squeeze(), labels.squeeze()) + else: + loss = loss_fct(pooled_logits, labels) + elif self.config.problem_type == "single_label_classification": + loss_fct = CrossEntropyLoss() + loss = loss_fct(pooled_logits.view(-1, self.num_labels), labels.view(-1)) + elif self.config.problem_type == "multi_label_classification": + loss_fct = BCEWithLogitsLoss() + loss = loss_fct(pooled_logits, labels) + if not return_dict: + output = (pooled_logits,) + transformer_outputs[1:] + return ((loss,) + output) if loss is not None else output + + return SequenceClassifierOutputWithPast( + loss=loss, + logits=pooled_logits, + past_key_values=transformer_outputs.past_key_values, + hidden_states=transformer_outputs.hidden_states, + attentions=transformer_outputs.attentions, + ) + + +@add_start_docstrings( + """ + GPT_BIGCODE Model with a token classification head on top (a linear layer on top of the hidden-states output) e.g. + for Named-Entity-Recognition (NER) tasks. + """, + GPT_BIGCODE_START_DOCSTRING, +) +class GPTBigCodeForTokenClassification(GPTBigCodePreTrainedModel): + def __init__(self, config): + super().__init__(config) + self.num_labels = config.num_labels + + self.transformer = GPTBigCodeModel(config) + if hasattr(config, "classifier_dropout") and config.classifier_dropout is not None: + classifier_dropout = config.classifier_dropout + elif hasattr(config, "hidden_dropout") and config.hidden_dropout is not None: + classifier_dropout = config.hidden_dropout + else: + classifier_dropout = 0.1 + self.dropout = nn.Dropout(classifier_dropout) + self.classifier = nn.Linear(config.hidden_size, config.num_labels) + + # Initialize weights and apply final processing + self.post_init() + + @add_start_docstrings_to_model_forward(GPT_BIGCODE_INPUTS_DOCSTRING) + def forward( + self, + input_ids: Optional[torch.Tensor] = None, + past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None, + attention_mask: Optional[torch.Tensor] = None, + token_type_ids: Optional[torch.Tensor] = None, + position_ids: Optional[torch.Tensor] = None, + head_mask: Optional[torch.Tensor] = None, + inputs_embeds: Optional[torch.Tensor] = None, + labels: Optional[torch.Tensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[Tuple, TokenClassifierOutput]: + r""" + labels (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*): + Labels for computing the sequence classification/regression loss. Indices should be in `[0, ..., + config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If + `config.num_labels > 1` a classification loss is computed (Cross-Entropy). 
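A small sketch (not part of the patch, token ids made up) of the last-non-padding-token selection used above, showing why the modulo keeps the ONNX-friendly argmax trick correct when a row contains no padding at all.

import torch

pad_token_id = 0
input_ids = torch.tensor([
    [5, 7, 9, 0, 0],   # three real tokens -> last real index is 2
    [4, 6, 2, 8, 1],   # no padding       -> argmax returns 0, modulo wraps -1 to the last index
])

sequence_lengths = torch.eq(input_ids, pad_token_id).int().argmax(-1) - 1
sequence_lengths = sequence_lengths % input_ids.shape[-1]
print(sequence_lengths)   # tensor([2, 4])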
+ """ + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + transformer_outputs = self.transformer( + input_ids, + past_key_values=past_key_values, + attention_mask=attention_mask, + token_type_ids=token_type_ids, + position_ids=position_ids, + head_mask=head_mask, + inputs_embeds=inputs_embeds, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + hidden_states = transformer_outputs[0] + hidden_states = self.dropout(hidden_states) + logits = self.classifier(hidden_states) + + loss = None + if labels is not None: + loss_fct = CrossEntropyLoss() + loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1).to(logits.device)) + + if not return_dict: + output = (logits,) + transformer_outputs[2:] + return ((loss,) + output) if loss is not None else output + + return TokenClassifierOutput( + loss=loss, + logits=logits, + hidden_states=transformer_outputs.hidden_states, + attentions=transformer_outputs.attentions, + ) diff --git a/starvector/model/image_encoder/clip_model.py b/starvector/model/image_encoder/clip_model.py new file mode 100644 index 0000000000000000000000000000000000000000..4eb2349dc3a5521b0a59d896025f6a7251374897 --- /dev/null +++ b/starvector/model/image_encoder/clip_model.py @@ -0,0 +1,191 @@ +# Adapted from LAVIS-Salesforce: LAVIS/lavis/models/clip_vit.py + +from collections import OrderedDict +from itertools import repeat +import collections.abc +import math +import torch +import torch.nn.functional as F +from torch import nn +from fairscale.nn.checkpoint.checkpoint_activations import checkpoint_wrapper + +def convert_weights_to_precision(model: nn.Module, precision: torch.dtype): + """Convert applicable model parameters to the specified precision""" + + def _convert_weights_to_precision(l): + if isinstance(l, (nn.Conv1d, nn.Conv2d, nn.Linear)): + l.weight.data = l.weight.data.to(precision) + if l.bias is not None: + l.bias.data = l.bias.data.to(precision) + + elif isinstance(l, (nn.MultiheadAttention)): + for attr in [*[f"{s}_proj_weight" for s in ["in", "q", "k", "v"]], "in_proj_bias", "bias_k", "bias_v"]: + tensor = getattr(l, attr) + if tensor is not None: + tensor.data = tensor.data.to(precision) + else: + for _, p in l.named_parameters(): + p.data = p.data.to(precision) + + model.apply(_convert_weights_to_precision) + +class Bottleneck(nn.Module): + expansion = 4 + + def __init__(self, inplanes, planes, stride=1): + super().__init__() + + # all conv layers have stride 1. 
an avgpool is performed after the second convolution when stride > 1 + self.conv1 = nn.Conv2d(inplanes, planes, 1, bias=False) + self.bn1 = nn.BatchNorm2d(planes) + self.relu1 = nn.ReLU(inplace=True) + + self.conv2 = nn.Conv2d(planes, planes, 3, padding=1, bias=False) + self.bn2 = nn.BatchNorm2d(planes) + self.relu2 = nn.ReLU(inplace=True) + + self.avgpool = nn.AvgPool2d(stride) if stride > 1 else nn.Identity() + + self.conv3 = nn.Conv2d(planes, planes * self.expansion, 1, bias=False) + self.bn3 = nn.BatchNorm2d(planes * self.expansion) + self.relu3 = nn.ReLU(inplace=True) + + self.downsample = None + self.stride = stride + + if stride > 1 or inplanes != planes * Bottleneck.expansion: + # downsampling layer is prepended with an avgpool, and the subsequent convolution has stride 1 + self.downsample = nn.Sequential(OrderedDict([ + ("-1", nn.AvgPool2d(stride)), + ("0", nn.Conv2d(inplanes, planes * self.expansion, 1, stride=1, bias=False)), + ("1", nn.BatchNorm2d(planes * self.expansion)) + ])) + + def forward(self, x: torch.Tensor): + identity = x + + out = self.relu1(self.bn1(self.conv1(x))) + out = self.relu2(self.bn2(self.conv2(out))) + out = self.avgpool(out) + out = self.bn3(self.conv3(out)) + + if self.downsample is not None: + identity = self.downsample(x) + + out += identity + out = self.relu3(out) + return out + + +class AttentionPool2d(nn.Module): + def __init__(self, spacial_dim: int, embed_dim: int, num_heads: int, output_dim: int = None): + super().__init__() + self.positional_embedding = nn.Parameter(torch.randn(spacial_dim ** 2 + 1, embed_dim) / embed_dim ** 0.5) + self.k_proj = nn.Linear(embed_dim, embed_dim) + self.q_proj = nn.Linear(embed_dim, embed_dim) + self.v_proj = nn.Linear(embed_dim, embed_dim) + self.c_proj = nn.Linear(embed_dim, output_dim or embed_dim) + self.num_heads = num_heads + + def forward(self, x): + x = x.reshape(x.shape[0], x.shape[1], x.shape[2] * x.shape[3]).permute(2, 0, 1) # NCHW -> (HW)NC + x = torch.cat([x.mean(dim=0, keepdim=True), x], dim=0) # (HW+1)NC + x = x + self.positional_embedding[:, None, :].to(x.dtype) # (HW+1)NC + x, _ = F.multi_head_attention_forward( + query=x, key=x, value=x, + embed_dim_to_check=x.shape[-1], + num_heads=self.num_heads, + q_proj_weight=self.q_proj.weight, + k_proj_weight=self.k_proj.weight, + v_proj_weight=self.v_proj.weight, + in_proj_weight=None, + in_proj_bias=torch.cat([self.q_proj.bias, self.k_proj.bias, self.v_proj.bias]), + bias_k=None, + bias_v=None, + add_zero_attn=False, + dropout_p=0, + out_proj_weight=self.c_proj.weight, + out_proj_bias=self.c_proj.bias, + use_separate_proj_weight=True, + training=self.training, + need_weights=False + ) + + return x[0] + + +class LayerNorm(nn.LayerNorm): + """Subclass torch's LayerNorm to handle fp16.""" + + def forward(self, x: torch.Tensor): + orig_type = x.dtype + layernorm_dtype = self.weight.dtype + ret = super().forward(x.type(layernorm_dtype)) + return ret.type(orig_type) + +class QuickGELU(nn.Module): + def forward(self, x: torch.Tensor): + return x * torch.sigmoid(1.702 * x) + +class ResidualAttentionBlock(nn.Module): + def __init__(self, d_model: int, n_head: int, attn_mask: torch.Tensor = None, use_grad_checkpointing=False): + super().__init__() + + self.attn = nn.MultiheadAttention(d_model, n_head) + self.ln_1 = LayerNorm(d_model) + self.mlp = nn.Sequential(OrderedDict([ + ("c_fc", nn.Linear(d_model, d_model * 4)), + ("gelu", QuickGELU()), + ("c_proj", nn.Linear(d_model * 4, d_model)) + ])) + self.ln_2 = LayerNorm(d_model) + self.attn_mask = attn_mask + + if 
use_grad_checkpointing: + self.attn = checkpoint_wrapper(self.attn) + self.mlp = checkpoint_wrapper(self.mlp) + + def attention(self, x: torch.Tensor): + self.attn_mask = self.attn_mask.to(dtype=x.dtype, device=x.device) if self.attn_mask is not None else None + return self.attn(x, x, x, need_weights=False, attn_mask=self.attn_mask)[0] + + def forward(self, x: torch.Tensor): + x = x + self.attention(self.ln_1(x)) + x = x + self.mlp(self.ln_2(x)) + return x + +class Transformer(nn.Module): + def __init__(self, width: int, layers: int, heads: int, attn_mask: torch.Tensor = None, use_grad_checkpointing=False): + super().__init__() + self.width = width + self.layers = layers + self.resblocks = nn.Sequential(*[ResidualAttentionBlock(width, heads, attn_mask, use_grad_checkpointing and i>12) for i in range(layers)]) + + def forward(self, x: torch.Tensor): + return self.resblocks(x) + +class VisionTransformer(nn.Module): + def __init__(self, input_resolution: int, patch_size: int, width: int, layers: int, heads: int, use_grad_checkpointing: bool): + super().__init__() + self.input_resolution = input_resolution + self.num_features = width + self.num_heads = heads + self.num_patches = (input_resolution // patch_size) ** 2 + self.conv1 = nn.Conv2d(in_channels=3, out_channels=width, kernel_size=patch_size, stride=patch_size, bias=False) + scale = width ** -0.5 + self.class_embedding = nn.Parameter(scale * torch.randn(width)) + self.positional_embedding = nn.Parameter(scale * torch.randn(self.num_patches + 1, width)) + self.ln_pre = LayerNorm(width) + self.transformer = Transformer(width, layers, heads, use_grad_checkpointing=use_grad_checkpointing) + + def forward(self, x: torch.Tensor): + x = self.conv1(x) # shape = [*, width, grid, grid] + x = x.reshape(x.shape[0], x.shape[1], -1) # shape = [*, width, grid ** 2] + x = x.permute(0, 2, 1) # shape = [*, grid ** 2, width] + x = torch.cat([self.class_embedding.to(x.dtype) + torch.zeros(x.shape[0], 1, x.shape[-1], dtype=x.dtype, device=x.device), x], dim=1) # shape = [*, grid ** 2 + 1, width] + x = x + self.positional_embedding.to(x.dtype) + x = self.ln_pre(x) + x = x.permute(1, 0, 2) # NLD -> LND + x = self.transformer(x) + x = x.permute(1, 0, 2) # LND -> NLD + return x diff --git a/starvector/model/image_encoder/image_encoder.py b/starvector/model/image_encoder/image_encoder.py new file mode 100644 index 0000000000000000000000000000000000000000..c5a96dfc001fcf194611906437174f9469b440b1 --- /dev/null +++ b/starvector/model/image_encoder/image_encoder.py @@ -0,0 +1,120 @@ +import os +import torch +import torch.nn as nn +import os +from omegaconf import OmegaConf +from starvector.model.image_encoder.clip_model import convert_weights_to_precision +from starvector.data.util import ImageTrainProcessor + +class ImageEncoder(nn.Module): + def __init__(self, config, **kwargs): + super(ImageEncoder, self).__init__() + + image_size = config.image_size + torch_dtype = kwargs.get('model_precision', config.torch_dtype) + # torch_dtype = torch.float32 + self.image_encoder_type = config.image_encoder_type + if self.image_encoder_type == 'clip': + self.visual_encoder, self.ln_vision = self.build_clip_encoder(image_size=image_size) + convert_weights_to_precision(self, torch_dtype) + self.processor = ImageTrainProcessor(size=config.image_size) + + elif self.image_encoder_type == 'vqgan': + self.visual_encoder = self.build_vqgan_encoder() + self.ln_vision = None + self.processor = ImageTrainProcessor(size=config.image_size) + + elif self.image_encoder_type == 'convnext': + 
self.visual_encoder = self.build_vqgan_encoder() + self.ln_vision = None + self.processor = ImageTrainProcessor(size=config.image_size) + + elif 'siglip' in self.image_encoder_type: + if self.image_encoder_type == 'siglip_512': + model_name = "google/siglip-base-patch16-512" + elif self.image_encoder_type == 'siglip_384': + model_name = "google/siglip-large-patch16-384" + elif self.image_encoder_type == 'siglip_256': + model_name = "google/siglip-base-patch16-256" + + from transformers import AutoProcessor, AutoModel + + self.visual_encoder = AutoModel.from_pretrained( + model_name, torch_dtype = torch_dtype + ).vision_model + + self.processor = AutoProcessor.from_pretrained( + model_name, torch_dtype = torch_dtype + ) + + def build_clip_encoder(self, image_size): + from starvector.model.image_encoder.clip_model import VisionTransformer, LayerNorm + visual_encoder = VisionTransformer( + input_resolution=image_size, + patch_size=14, + width=1024, + layers=23, + heads=16, + use_grad_checkpointing=False) + + ln_vision = LayerNorm(visual_encoder.num_features) + return visual_encoder, ln_vision + + def build_vqgan_encoder(self): + from taming.modules.diffusionmodules.model import Encoder + VQGAN_CHECKPOINT = "/path/to/vqgan_checkpoint" # You can download the checkpoint from https://github.com/EleutherAI/vqgan-clip/blob/main/README.md + vqgan_chkp_path = VQGAN_CHECKPOINT + files_in_directory = os.listdir(vqgan_chkp_path + '/configs') + vqgan_config_file = [file for file in files_in_directory if file.endswith('project.yaml')][0] + vqgan_config = OmegaConf.load(os.path.join(vqgan_chkp_path, 'configs', vqgan_config_file)) + visual_encoder = Encoder(**vqgan_config.model.params.ddconfig) + + # Load checkpoint weights + checkpoint = torch.load(os.path.join(vqgan_chkp_path, 'checkpoints', 'last.ckpt'))['state_dict'] + + # Create a new state_dict with modified keys + new_state_dict = {} + for key, value in checkpoint.items(): + if key.startswith('encoder.'): + new_key = key[len('encoder.'):] + new_state_dict[new_key] = value + + # Load weights + visual_encoder.load_state_dict(new_state_dict) + return visual_encoder + + def build_convnext_encoder(self): + import open_clip + model, _, _ = open_clip.create_model_and_transforms('convnext_base_w', pretrained='laion2b_s13b_b82k') + return model.visual + + def forward(self, image): + if self.image_encoder_type == 'clip': + embeds = self.visual_encoder(image) + out = self.ln_vision(embeds) + elif self.image_encoder_type == 'open-clip': + out = self.visual_encoder(image)[1] + out = self.ln_vision(out) + elif self.image_encoder_type == 'vqgan': + out = self.visual_encoder(image) + size = out.size() + out = out.view(size[0], size[1], -1) + out = out.permute(0, 2, 1) + elif self.image_encoder_type == 'convnext': + out = self.visual_encoder.trunk.forward_features(image) + size = out.size() + out = out.view(size[0], size[1], -1) + out = out.permute(0, 2, 1) + elif 'siglip' in self.image_encoder_type: + out = self.visual_encoder(image)["last_hidden_state"] + return out + + def process_images(self, images): + if self.image_encoder_type == 'clip': + res = [] + for image in images: + res.append(self.processor(image).unsqueeze(0)) # B, 3, H, W + return res + else: + return self.processor(images=images, return_tensors="pt").pixel_values.unsqueeze(0) + \ No newline at end of file diff --git a/starvector/model/llm/starcoder.py b/starvector/model/llm/starcoder.py new file mode 100644 index 0000000000000000000000000000000000000000..a09b3c6d5c1bfd04dc1ea02c6fd7f25602d16e03 --- 
/dev/null +++ b/starvector/model/llm/starcoder.py @@ -0,0 +1,51 @@ +import torch.nn as nn +from transformers import ( + AutoConfig, + AutoModelForCausalLM, + AutoTokenizer, + ) + +class StarCoderModel(nn.Module): + def __init__(self, config, **kwargs): + super(StarCoderModel, self).__init__() + + self.init_tokenizer(config.starcoder_model_name) + + self.max_length = config.max_length + model_config = AutoConfig.from_pretrained(config.starcoder_model_name, trust_remote_code=True) + kwargs = {} + kwargs['trust_remote_code'] = True + kwargs['torch_dtype'] = config.torch_dtype + + # Configure special tokens for generation + model_config.eos_token_id = self.tokenizer.eos_token_id + model_config.pad_token_id = self.tokenizer.pad_token_id + model_config.bos_token_id = self.tokenizer.bos_token_id + try: + model_config.flash_attention = config.use_flash_attn + model_config._attn_implementation = "flash_attention_2" + except ImportError: + config.use_flash_attn = False + + # model = GPTBigCodeForCausalLM(config=model_config) + model = AutoModelForCausalLM.from_pretrained(config.starcoder_model_name, config=model_config, **kwargs) + model.resize_token_embeddings(len(self.tokenizer)) + self.transformer = model + + # Prompt the model after image + self.prompt = '" + end_sequence = self.svg_transformer.tokenizer("", add_special_tokens=False)['input_ids'] + stopping_criteria = StoppingCriteriaList([StoppingCriteriaSub(stops=[end_sequence])]) + return { + 'inputs_embeds': base_kwargs['inputs_embeds'], + 'attention_mask': base_kwargs['attention_mask'], + 'do_sample': base_kwargs.get('use_nucleus_sampling', True), + 'top_p': base_kwargs.get('top_p', 0.9), + 'temperature': base_kwargs.get('temperature', 1), + 'num_beams': base_kwargs.get('num_beams', 2), + 'max_length': base_kwargs.get('max_length', 30), + 'min_length': base_kwargs.get('min_length', 1), + 'repetition_penalty': base_kwargs.get('repetition_penalty', 1.0), + 'length_penalty': base_kwargs.get('length_penalty', 1.0), + 'use_cache': base_kwargs.get('use_cache', True), + 'stopping_criteria': stopping_criteria + } + + def generate_im2svg(self, batch, **kwargs): + """Base implementation of image to SVG generation""" + inputs_embeds, attention_mask, prompt_tokens = self._prepare_generation_inputs( + batch, kwargs.get('prompt'), batch["image"].device + ) + + generation_kwargs = self._get_generation_kwargs( + {**kwargs, 'inputs_embeds': inputs_embeds, 'attention_mask': attention_mask} + ) + # Let subclasses override these defaults if needed + generation_kwargs.update(self._get_im2svg_specific_kwargs(kwargs)) + + outputs = self.svg_transformer.transformer.generate(**generation_kwargs) + outputs = torch.cat([prompt_tokens.input_ids, outputs], dim=1) + raw_svg = self.svg_transformer.tokenizer.batch_decode(outputs, skip_special_tokens=True) + + return raw_svg + + def generate_im2svg_grpo(self, batch, **kwargs): + """Base implementation of image to SVG generation""" + inputs_embeds, attention_mask, prompt_tokens = self._prepare_generation_inputs( + batch, kwargs.get('prompt'), batch["image"].device + ) + + generation_kwargs = self._get_generation_kwargs( + {**kwargs, 'inputs_embeds': inputs_embeds, 'attention_mask': attention_mask} + ) + # Let subclasses override these defaults if needed + generation_kwargs.update(self._get_im2svg_specific_kwargs(kwargs)) + + num_return_sequences = kwargs.get('num_return_sequences', 1) + if num_return_sequences > 1: + generation_kwargs['num_return_sequences'] = num_return_sequences + generation_kwargs['num_beams'] = 1 + + 
outputs = self.svg_transformer.transformer.generate(**generation_kwargs) + outputs = torch.cat([prompt_tokens.input_ids.repeat(num_return_sequences, 1), outputs], dim=1) + raw_svg = self.svg_transformer.tokenizer.batch_decode(outputs, skip_special_tokens=True) + + return { + "raw_svg": raw_svg, + "outputs": outputs, + "inputs_embeds": inputs_embeds, + } + + + def _get_im2svg_specific_kwargs(self, kwargs): + """Default implementation of im2svg specific generation kwargs. + Subclasses can override this to customize generation behavior.""" + return { + 'early_stopping': True, + 'pad_token_id': self.svg_transformer.tokenizer.pad_token_id + } + + def generate_text2svg(self, batch, **kwargs): + """Base implementation of text to SVG generation""" + device = batch["image"].device + prompt = batch["caption"] + + prompt_tokens = self._tokenize( + prompt, + max_length=kwargs.get('max_length', 30), + device=device, + add_special_tokens=False + ) + + trigger_token = self._tokenize( + [self.svg_transformer.svg_start_token for _ in batch["caption"]], + max_length=None, + device=device, + add_special_tokens=False + ) + + input_tokens = torch.cat([prompt_tokens.input_ids, trigger_token.input_ids], dim=1) + attention_mask = torch.cat([prompt_tokens.attention_mask, trigger_token.attention_mask], dim=1) + inputs_embeds = self._get_embeddings(input_tokens) + max_length = kwargs.get('max_length', 30) - input_tokens.size(1) + + generation_kwargs = self._get_generation_kwargs( + {**kwargs, 'inputs_embeds': inputs_embeds, 'attention_mask': attention_mask}, + input_tokens.size(1) + ) + # Let subclasses override these defaults if needed + generation_kwargs.update(self._get_text2svg_specific_kwargs(kwargs)) + generation_kwargs['max_length'] = max_length + + outputs = self.svg_transformer.transformer.generate(**generation_kwargs) + return outputs + + def _get_text2svg_specific_kwargs(self, kwargs): + """Default implementation of text2svg specific generation kwargs. 
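The generation helpers above stop decoding once the tokenizer's SVG end sequence appears, via a StoppingCriteriaSub that is defined elsewhere in the repo and not shown in this patch. As an illustrative stand-in only (not the repo's actual implementation), a stop-on-token-sequence criterion can be written like this:

import torch
from transformers import StoppingCriteria, StoppingCriteriaList

class StopOnSequence(StoppingCriteria):
    def __init__(self, stop_ids):
        super().__init__()
        self.stop_ids = torch.tensor(stop_ids)

    def __call__(self, input_ids, scores, **kwargs):
        # Stop once the most recent tokens match the target id sequence.
        if input_ids.shape[1] < self.stop_ids.shape[0]:
            return False
        tail = input_ids[0, -self.stop_ids.shape[0]:].to(self.stop_ids.device)
        return bool(torch.equal(tail, self.stop_ids))

# e.g. stopping_criteria = StoppingCriteriaList([StopOnSequence(end_sequence)])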
+ Subclasses can override this to customize generation behavior.""" + return { + 'eos_token_id': self.svg_transformer.tokenizer.eos_token_id, + 'early_stopping': True, + 'length_penalty': kwargs.get('length_penalty', 1.0) + } diff --git a/starvector/model/models/starvector_v1.py b/starvector/model/models/starvector_v1.py new file mode 100644 index 0000000000000000000000000000000000000000..af35172a1fcf3cfd4de544fdc4b4a43089402f0c --- /dev/null +++ b/starvector/model/models/starvector_v1.py @@ -0,0 +1,22 @@ +import torch +import torch.nn as nn +from starvector.model.models.starvector_base import StarVectorBase +from transformers import AutoProcessor + +class StarVectorStarCoder(StarVectorBase): + def __init__(self, config, **kwargs): + super().__init__(config, **kwargs) + + self.processor = AutoProcessor.from_pretrained(config._name_or_path) + + def _get_svg_transformer(self, config, **kwargs): + from starvector.model.llm.starcoder import StarCoderModel # This uses StarCoder (V1) + return StarCoderModel(config, **kwargs) + + def _get_embeddings(self, input_ids): + """V1 specific embedding method""" + return self.svg_transformer.transformer.transformer.wte(input_ids) + + def _get_svg_text(self, svg_list): + """V1 specific SVG text preparation""" + return [t + self.svg_transformer.tokenizer.eos_token for t in svg_list] diff --git a/starvector/model/models/starvector_v2.py b/starvector/model/models/starvector_v2.py new file mode 100644 index 0000000000000000000000000000000000000000..342fa0b02ac1fff813fdfa1a798aa4d6152d0eb7 --- /dev/null +++ b/starvector/model/models/starvector_v2.py @@ -0,0 +1,63 @@ +import torch +import torch.nn as nn +from torch.distributed.fsdp.wrap import _module_wrap_policy, _or_policy +from functools import partial +from starvector.model.models.starvector_base import StarVectorBase +from transformers import AutoImageProcessor + +class StarVectorStarCoder2(StarVectorBase): + def __init__(self, config, **kwargs): + super().__init__(config, **kwargs) + + self.processor = AutoImageProcessor.from_pretrained(config._name_or_path, trust_remote_code=True) + + def _get_svg_transformer(self, config, **kwargs): + from starvector.model.llm.starcoder2 import StarCoderModel # This is a different model than V1, uses StarCoder2 + return StarCoderModel(config, **kwargs) + + + def get_fsdp_wrapping_policy(self): + """V2 specific FSDP wrapping policy""" + from starvector.model.image_encoder.image_encoder import ImageEncoder + + image_encoder_wrapping_policy = partial( + _module_wrap_policy, + module_classes={ImageEncoder}, + ) + + llm_fsdp_wrapping_policy = self.svg_transformer.get_fsdp_wrapping_policy() + from starvector.model.adapters.adapter import Adapter + + adapter_wrapping_policy = partial( + _module_wrap_policy, + module_classes={Adapter}, + ) + + return partial( + _or_policy, + policies=[ + image_encoder_wrapping_policy, + llm_fsdp_wrapping_policy, + adapter_wrapping_policy, + ], + ) + + def _get_embeddings(self, input_ids): + """V2 specific embedding method""" + return self.svg_transformer.transformer.model.embed_tokens(input_ids) + + def _get_svg_text(self, svg_list): + """V2 specific SVG text preparation""" + return [t + self.svg_transformer.svg_end_token + self.svg_transformer.tokenizer.eos_token for t in svg_list] + + def _get_im2svg_specific_kwargs(self, kwargs): + """V2 specific generation kwargs""" + return { + # 'eos_token_id': self.svg_transformer.svg_end_token_id, + } + + def _get_text2svg_specific_kwargs(self, kwargs): + """V2 specific text2svg generation kwargs""" + 
return { + 'eos_token_id': self.svg_transformer.tokenizer.eos_token_id, + } diff --git a/starvector/model/starvector_arch.py b/starvector/model/starvector_arch.py new file mode 100644 index 0000000000000000000000000000000000000000..d18561f1a91e05c6582c8eebce07effe1022bccf --- /dev/null +++ b/starvector/model/starvector_arch.py @@ -0,0 +1,194 @@ +from transformers import ( + PretrainedConfig, + PreTrainedModel +) +from torch.nn import CrossEntropyLoss +from transformers.models.gpt_bigcode.modeling_gpt_bigcode import CausalLMOutputWithCrossAttentions +from typing import Optional, Tuple, Union +import torch + +from transformers.processing_utils import ProcessorMixin +from torchvision import transforms +from torchvision.transforms.functional import InterpolationMode, pad +from transformers.feature_extraction_sequence_utils import BatchFeature +from transformers import AutoProcessor + +class SimpleStarVectorProcessor(ProcessorMixin): + attributes = ["tokenizer"] # Only include tokenizer in attributes + valid_kwargs = ["size", "mean", "std"] # Add other parameters as valid kwargs + image_processor_class = "AutoImageProcessor" + tokenizer_class = "AutoTokenizer" + + def __init__(self, + tokenizer=None, # Make tokenizer the first argument + size=224, + mean=None, + std=None, + **kwargs, + ): + if mean is None: + mean = (0.48145466, 0.4578275, 0.40821073) + if std is None: + std = (0.26862954, 0.26130258, 0.27577711) + + # Store these as instance variables + self.mean = mean + self.std = std + self.size = size + self.normalize = transforms.Normalize(mean=mean, std=std) + + self.transform = transforms.Compose([ + transforms.Lambda(lambda img: img.convert("RGB") if img.mode == "RGBA" else img), + transforms.Lambda(lambda img: self._pad_to_square(img)), + transforms.Resize(size, interpolation=InterpolationMode.BICUBIC), + transforms.ToTensor(), + self.normalize + ]) + + # Initialize parent class with tokenizer + super().__init__(tokenizer=tokenizer) + + + def __call__(self, images=None, text=None, max_length=None, **kwargs) -> BatchFeature: + """ + Process images and/or text inputs. 
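The image branch of the processor above pads to a square, resizes with bicubic interpolation, and normalizes with CLIP statistics. The following self-contained sketch reproduces that recipe outside the repo; the stand-in image and the simplified padding helper are illustrative, not the processor's own code.

from PIL import Image
from torchvision import transforms
from torchvision.transforms.functional import InterpolationMode

mean = (0.48145466, 0.4578275, 0.40821073)
std = (0.26862954, 0.26130258, 0.27577711)

img = Image.new("RGB", (300, 180), "white")   # stand-in image

def pad_to_square(im, fill=255):
    # Paste the image onto a white square canvas, centered.
    w, h = im.size
    side = max(w, h)
    canvas = Image.new("RGB", (side, side), (fill, fill, fill))
    canvas.paste(im, ((side - w) // 2, (side - h) // 2))
    return canvas

transform = transforms.Compose([
    transforms.Lambda(pad_to_square),
    transforms.Resize(224, interpolation=InterpolationMode.BICUBIC),
    transforms.ToTensor(),
    transforms.Normalize(mean=mean, std=std),
])
print(transform(img).shape)   # torch.Size([3, 224, 224])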
+ + Args: + images: Optional image input(s) + text: Optional text input(s) + **kwargs: Additional arguments + """ + if images is None and text is None: + raise ValueError("You have to specify at least one of `images` or `text`.") + + image_inputs = {} + if images is not None: + if isinstance(images, (list, tuple)): + images_ = torch.stack([self.transform(img) for img in images]) + else: + images_ = self.transform(images) + image_inputs = {"pixel_values": images_} + + text_inputs = {} + if text is not None: + text_inputs = self.tokenizer( + text, truncation=True, + add_special_tokens=True, + padding='longest', + max_length=max_length, + return_tensors="pt" + ) + + return BatchFeature(data={**text_inputs, **image_inputs}) + + def _pad_to_square(self, img): + # Calculate padding to make the image square + width, height = img.size + max_dim = max(width, height) + padding = [(max_dim - width) // 2, (max_dim - height) // 2] + padding += [max_dim - width - padding[0], max_dim - height - padding[1]] + return pad(img, padding, fill=255) # Assuming white padding + + +AutoProcessor.register(SimpleStarVectorProcessor, SimpleStarVectorProcessor) + + +class StarVectorConfig(PretrainedConfig): + model_type = "starvector" + + def __init__( + self, + starcoder_model_name: str = "bigcode/starcoderbase-1b", + image_encoder_type: str = "clip", + adapter_norm: str = "layer_norm", + image_size: int = 224, + max_length: int = 8192, + max_length_train: int = 8192, + use_flash_attn: bool = True, + use_cache: bool = True, + num_attention_heads: int = 16, + num_hidden_layers: int = 24, + vocab_size: int = 49152, + hidden_size: int = 2048, + num_kv_heads: int = 4, + torch_dtype: str = "bfloat16", + **kwargs, + ): + kwargs["torch_dtype"] = torch_dtype + self.starcoder_model_name = starcoder_model_name + self.image_encoder_type = image_encoder_type + self.adapter_norm = adapter_norm + self.image_size = image_size + self.max_length = max_length + self.max_length_train = max_length_train + self.use_flash_attn = use_flash_attn + self.use_cache = use_cache + self.num_attention_heads = num_attention_heads + self.num_hidden_layers = num_hidden_layers + self.vocab_size = vocab_size + self.hidden_size = hidden_size + self.num_kv_heads = num_kv_heads + super().__init__(**kwargs) + +class StarVectorForCausalLM(PreTrainedModel): + config_class = StarVectorConfig + _no_split_modules = [] + + def __init__(self, config: StarVectorConfig, **kwargs): + super().__init__(config) + starcoder_model_name = config.starcoder_model_name + if 'starcoder2' in starcoder_model_name: + from starvector.model.models.starvector_v2 import StarVectorStarCoder2 + self.model = StarVectorStarCoder2(config=config, **kwargs) + else: + from starvector.model.models.starvector_v1 import StarVectorStarCoder + self.model = StarVectorStarCoder(config=config, **kwargs) + + + @property + def supports_gradient_checkpointing(self): + # If the underlying transformer (e.g., the one in StarCoderModel) + # supports gradient checkpointing, delegate to it. + if hasattr(self.model, 'svg_transformer'): + return getattr(self.model.svg_transformer, 'supports_gradient_checkpointing', False) + return False + + def gradient_checkpointing_enable(self): + # Optionally, forward this call to the internal transformer. 
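A usage sketch for the __call__ contract above, assuming the SimpleStarVectorProcessor class defined in this file is in scope. The GPT-2 tokenizer is only a stand-in so the snippet is self-contained (the released model ships its own tokenizer), and the prompt text and image are made up.

from PIL import Image
from transformers import AutoTokenizer

tok = AutoTokenizer.from_pretrained("gpt2")   # stand-in tokenizer for illustration
tok.pad_token = tok.eos_token

processor = SimpleStarVectorProcessor(tokenizer=tok, size=224)
batch = processor(
    images=[Image.new("RGB", (300, 180), "white")],   # made-up image
    text=["<svg"],                                    # made-up prompt
    max_length=16,
)
print(batch["pixel_values"].shape)   # torch.Size([1, 3, 224, 224])
print(batch["input_ids"].shape)      # (1, number_of_prompt_tokens)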
+ if hasattr(self.model, 'svg_transformer') and hasattr(self.model.svg_transformer, 'gradient_checkpointing_enable'): + self.model.svg_transformer.gradient_checkpointing_enable() + + def forward(self, vision_embeds, input_ids, num_generations, attention_mask, num_logits_to_keep) -> Union[Tuple, CausalLMOutputWithCrossAttentions]: + completion_embeds = self.model._get_embeddings(input_ids) + inputs_embeds = torch.cat([vision_embeds.repeat(num_generations, 1, 1), completion_embeds], dim=1) + + transformer_outputs = self.model.svg_transformer.transformer.transformer( + inputs_embeds=inputs_embeds, + attention_mask=attention_mask, + ) + hidden_states = transformer_outputs[0] + + if num_logits_to_keep > 0: + lm_logits = self.model.svg_transformer.transformer.lm_head(hidden_states[:, -num_logits_to_keep:, :]) + else: + lm_logits = self.model.svg_transformer.transformer.lm_head(hidden_states) + + loss = None + return CausalLMOutputWithCrossAttentions( + loss=loss, + logits=lm_logits, + past_key_values=transformer_outputs.past_key_values, + hidden_states=transformer_outputs.hidden_states, + attentions=transformer_outputs.attentions, + cross_attentions=transformer_outputs.cross_attentions, + ) + + def generate_im2svg(self, batch, **kwargs): + return self.model.generate_im2svg(batch, **kwargs) + + def generate_im2text(self, batch, **kwargs): + return self.model.generate_im2text(batch, **kwargs) + + def process_images(self, images): + return self.model.image_encoder.process_images(images) + diff --git a/starvector/serve/__init__.py b/starvector/serve/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/starvector/serve/constants.py b/starvector/serve/constants.py new file mode 100644 index 0000000000000000000000000000000000000000..1f44887d603941fde989327b32fb93aaaf84dbb6 --- /dev/null +++ b/starvector/serve/constants.py @@ -0,0 +1,16 @@ +CONTROLLER_HEART_BEAT_EXPIRATION = 30 +WORKER_HEART_BEAT_INTERVAL = 15 + +LOGDIR = "." + +# Model Constants +IGNORE_INDEX = -100 +IMAGE_TOKEN_INDEX = -200 +DEFAULT_IMAGE_TOKEN = "" +DEFAULT_IMAGE_PATCH_TOKEN = "" +DEFAULT_IM_START_TOKEN = "" +DEFAULT_IM_END_TOKEN = "" +IMAGE_PLACEHOLDER = "" + +CLIP_QUERY_LENGTH = 257 + diff --git a/starvector/serve/controller.py b/starvector/serve/controller.py new file mode 100644 index 0000000000000000000000000000000000000000..cd5d747bbd8aedeb90cd314c05c9a8fcf38ee426 --- /dev/null +++ b/starvector/serve/controller.py @@ -0,0 +1,293 @@ +""" +A controller manages distributed workers. +It sends worker addresses to clients. 
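To make the tensor bookkeeping in the forward above concrete, here is a shape-only sketch (all sizes are made up): the image embeddings are repeated once per sampled completion, prepended to the completion embeddings, and only the trailing num_logits_to_keep positions are scored.

import torch

hidden, vis_len, comp_len, num_generations = 32, 4, 6, 2

vision_embeds = torch.randn(1, vis_len, hidden)
completion_embeds = torch.randn(num_generations, comp_len, hidden)

inputs_embeds = torch.cat(
    [vision_embeds.repeat(num_generations, 1, 1), completion_embeds], dim=1
)
print(inputs_embeds.shape)   # torch.Size([2, 10, 32])

num_logits_to_keep = comp_len
hidden_states = torch.randn(num_generations, vis_len + comp_len, hidden)
print(hidden_states[:, -num_logits_to_keep:, :].shape)   # torch.Size([2, 6, 32]) -> scored positions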
+""" +import argparse +import asyncio +import dataclasses +from enum import Enum, auto +import json +import logging +import time +from typing import List, Union +import threading + +from fastapi import FastAPI, Request +from fastapi.responses import StreamingResponse +import numpy as np +import requests +import uvicorn + +from starvector.serve.constants import CONTROLLER_HEART_BEAT_EXPIRATION +from starvector.serve.util import build_logger, server_error_msg + +logger = build_logger("controller", "controller.log") + +class DispatchMethod(Enum): + LOTTERY = auto() + SHORTEST_QUEUE = auto() + + @classmethod + def from_str(cls, name): + if name == "lottery": + return cls.LOTTERY + elif name == "shortest_queue": + return cls.SHORTEST_QUEUE + else: + raise ValueError(f"Invalid dispatch method") + + +@dataclasses.dataclass +class WorkerInfo: + model_names: List[str] + speed: int + queue_length: int + check_heart_beat: bool + last_heart_beat: str + + +def heart_beat_controller(controller): + while True: + time.sleep(CONTROLLER_HEART_BEAT_EXPIRATION) + controller.remove_stable_workers_by_expiration() + + +class Controller: + def __init__(self, dispatch_method: str): + # Dict[str -> WorkerInfo] + self.worker_info = {} + self.dispatch_method = DispatchMethod.from_str(dispatch_method) + + self.heart_beat_thread = threading.Thread( + target=heart_beat_controller, args=(self,)) + self.heart_beat_thread.start() + + logger.info("Init controller") + + def register_worker(self, worker_name: str, check_heart_beat: bool, + worker_status: dict): + if worker_name not in self.worker_info: + logger.info(f"Register a new worker: {worker_name}") + else: + logger.info(f"Register an existing worker: {worker_name}") + + if not worker_status: + worker_status = self.get_worker_status(worker_name) + if not worker_status: + return False + + self.worker_info[worker_name] = WorkerInfo( + worker_status["model_names"], worker_status["speed"], worker_status["queue_length"], + check_heart_beat, time.time()) + + logger.info(f"Register done: {worker_name}, {worker_status}") + return True + + def get_worker_status(self, worker_name: str): + try: + r = requests.post(worker_name + "/worker_get_status", timeout=5) + except requests.exceptions.RequestException as e: + logger.error(f"Get status fails: {worker_name}, {e}") + return None + + if r.status_code != 200: + logger.error(f"Get status fails: {worker_name}, {r}") + return None + + return r.json() + + def remove_worker(self, worker_name: str): + del self.worker_info[worker_name] + + def refresh_all_workers(self): + old_info = dict(self.worker_info) + self.worker_info = {} + + for w_name, w_info in old_info.items(): + if not self.register_worker(w_name, w_info.check_heart_beat, None): + logger.info(f"Remove stale worker: {w_name}") + + def list_models(self): + model_names = set() + + for w_name, w_info in self.worker_info.items(): + model_names.update(w_info.model_names) + + return list(model_names) + + def get_worker_address(self, model_name: str): + if self.dispatch_method == DispatchMethod.LOTTERY: + worker_names = [] + worker_speeds = [] + for w_name, w_info in self.worker_info.items(): + if model_name in w_info.model_names: + worker_names.append(w_name) + worker_speeds.append(w_info.speed) + worker_speeds = np.array(worker_speeds, dtype=np.float32) + norm = np.sum(worker_speeds) + if norm < 1e-4: + return "" + worker_speeds = worker_speeds / norm + if True: # Directly return address + pt = np.random.choice(np.arange(len(worker_names)), + p=worker_speeds) + worker_name = 
worker_names[pt] + return worker_name + + # Check status before returning + while True: + pt = np.random.choice(np.arange(len(worker_names)), + p=worker_speeds) + worker_name = worker_names[pt] + + if self.get_worker_status(worker_name): + break + else: + self.remove_worker(worker_name) + worker_speeds[pt] = 0 + norm = np.sum(worker_speeds) + if norm < 1e-4: + return "" + worker_speeds = worker_speeds / norm + continue + return worker_name + elif self.dispatch_method == DispatchMethod.SHORTEST_QUEUE: + worker_names = [] + worker_qlen = [] + for w_name, w_info in self.worker_info.items(): + if model_name in w_info.model_names: + worker_names.append(w_name) + worker_qlen.append(w_info.queue_length / w_info.speed) + if len(worker_names) == 0: + return "" + min_index = np.argmin(worker_qlen) + w_name = worker_names[min_index] + self.worker_info[w_name].queue_length += 1 + logger.info(f"names: {worker_names}, queue_lens: {worker_qlen}, ret: {w_name}") + return w_name + else: + raise ValueError(f"Invalid dispatch method: {self.dispatch_method}") + + def receive_heart_beat(self, worker_name: str, queue_length: int): + if worker_name not in self.worker_info: + logger.info(f"Receive unknown heart beat. {worker_name}") + return False + + self.worker_info[worker_name].queue_length = queue_length + self.worker_info[worker_name].last_heart_beat = time.time() + logger.info(f"Receive heart beat. {worker_name}") + return True + + def remove_stable_workers_by_expiration(self): + expire = time.time() - CONTROLLER_HEART_BEAT_EXPIRATION + to_delete = [] + for worker_name, w_info in self.worker_info.items(): + if w_info.check_heart_beat and w_info.last_heart_beat < expire: + to_delete.append(worker_name) + + for worker_name in to_delete: + self.remove_worker(worker_name) + + def worker_api_generate_stream(self, params): + worker_addr = self.get_worker_address(params["model"]) + if not worker_addr: + logger.info(f"no worker: {params['model']}") + ret = { + "text": server_error_msg, + "error_code": 2, + } + yield json.dumps(ret).encode() + b"\0" + + try: + response = requests.post(worker_addr + "/worker_generate_stream", + json=params, stream=True, timeout=5) + for chunk in response.iter_lines(decode_unicode=False, delimiter=b"\0"): + if chunk: + yield chunk + b"\0" + except requests.exceptions.RequestException as e: + logger.info(f"worker timeout: {worker_addr}") + ret = { + "text": server_error_msg, + "error_code": 3, + } + yield json.dumps(ret).encode() + b"\0" + + + # Let the controller act as a worker to achieve hierarchical + # management. This can be used to connect isolated sub networks. 
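The comment above notes that the controller can also present itself as a worker. On the other side of the protocol, a worker talks to the HTTP endpoints registered below. The client-side sketch that follows is not part of the patch: the worker address and model name are made up, and the 15-second interval mirrors WORKER_HEART_BEAT_INTERVAL from constants.py.

import time
import requests

CONTROLLER = "http://localhost:21001"   # default controller port used below
WORKER = "http://localhost:40000"       # made-up worker address

# Register the worker with an explicit status so the controller can dispatch to it.
requests.post(CONTROLLER + "/register_worker", json={
    "worker_name": WORKER,
    "check_heart_beat": True,
    "worker_status": {"model_names": ["starvector-demo"], "speed": 1, "queue_length": 0},
})

# Report the queue length periodically so the controller does not expire the worker.
for _ in range(3):                      # a real worker would loop until shutdown
    requests.post(CONTROLLER + "/receive_heart_beat",
                  json={"worker_name": WORKER, "queue_length": 0})
    time.sleep(15)                      # WORKER_HEART_BEAT_INTERVAL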
+ def worker_api_get_status(self): + model_names = set() + speed = 0 + queue_length = 0 + + for w_name in self.worker_info: + worker_status = self.get_worker_status(w_name) + if worker_status is not None: + model_names.update(worker_status["model_names"]) + speed += worker_status["speed"] + queue_length += worker_status["queue_length"] + + return { + "model_names": list(model_names), + "speed": speed, + "queue_length": queue_length, + } + + +app = FastAPI() + +@app.post("/register_worker") +async def register_worker(request: Request): + data = await request.json() + controller.register_worker( + data["worker_name"], data["check_heart_beat"], + data.get("worker_status", None)) + +@app.post("/refresh_all_workers") +async def refresh_all_workers(): + models = controller.refresh_all_workers() + + +@app.post("/list_models") +async def list_models(): + models = controller.list_models() + return {"models": models} + + +@app.post("/get_worker_address") +async def get_worker_address(request: Request): + data = await request.json() + addr = controller.get_worker_address(data["model"]) + return {"address": addr} + +@app.post("/receive_heart_beat") +async def receive_heart_beat(request: Request): + data = await request.json() + exist = controller.receive_heart_beat( + data["worker_name"], data["queue_length"]) + return {"exist": exist} + + +@app.post("/worker_generate_stream") +async def worker_api_generate_stream(request: Request): + params = await request.json() + generator = controller.worker_api_generate_stream(params) + return StreamingResponse(generator) + + +@app.post("/worker_get_status") +async def worker_api_get_status(request: Request): + return controller.worker_api_get_status() + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument("--host", type=str, default="localhost") + parser.add_argument("--port", type=int, default=21001) + parser.add_argument("--dispatch-method", type=str, choices=[ + "lottery", "shortest_queue"], default="shortest_queue") + args = parser.parse_args() + logger.info(f"args: {args}") + + controller = Controller(args.dispatch_method) + uvicorn.run(app, host=args.host, port=args.port, log_level="info") diff --git a/starvector/serve/conversation.py b/starvector/serve/conversation.py new file mode 100644 index 0000000000000000000000000000000000000000..467e240c1e27c0aeae95aa3946e6a6f1a9c70938 --- /dev/null +++ b/starvector/serve/conversation.py @@ -0,0 +1,211 @@ +import dataclasses +from typing import List +from PIL import Image +import concurrent.futures +from bs4 import BeautifulSoup +import cairosvg +from io import BytesIO + +@dataclasses.dataclass +class Conversation: + """A class that keeps all conversation history.""" + system: str + image_prompt: str + roles: List[str] + messages: List[List[str]] + offset: int + version: str = "Unknown" + stop_sampling: bool = False + skip_next: bool = False + display_images: bool = False + task: str = "Im2SVG" + + def set_task(self, task): + self.task = task + + def get_image_prompt(self): + return self.image_prompt + + def get_images(self, return_pil=False): + images = [] + for i, (role, msg) in enumerate(self.messages[self.offset:]): + if i % 2 == 0: + if type(msg) is tuple: + import base64 + from io import BytesIO + from PIL import Image + image, image_process_mode = msg + if image_process_mode == "Pad": + def expand2square(pil_img, background_color=(255, 255, 255)): + width, height = pil_img.size + if width == height: + return pil_img + elif width > height: + result = Image.new(pil_img.mode, 
(width, width), background_color) + result.paste(pil_img, (0, (width - height) // 2)) + return result + else: + result = Image.new(pil_img.mode, (height, height), background_color) + result.paste(pil_img, ((height - width) // 2, 0)) + return result + image = expand2square(image) + elif image_process_mode in ["Default", "Crop"]: + pass + elif image_process_mode == "Resize": + image = image.resize((224, 224)) + else: + raise ValueError(f"Invalid image_process_mode: {image_process_mode}") + max_hw, min_hw = max(image.size), min(image.size) + aspect_ratio = max_hw / min_hw + max_len, min_len = 800, 400 + shortest_edge = int(min(max_len / aspect_ratio, min_len, min_hw)) + longest_edge = int(shortest_edge * aspect_ratio) + W, H = image.size + if longest_edge != max(image.size): + if H > W: + H, W = longest_edge, shortest_edge + else: + H, W = shortest_edge, longest_edge + image = image.resize((W, H)) + if return_pil: + images.append(image) + else: + buffered = BytesIO() + image.save(buffered, format="PNG") + img_b64_str = base64.b64encode(buffered.getvalue()).decode() + images.append(img_b64_str) + return images + + def append_message(self, role, message): + self.messages.append([role, message]) + + def download_files(self): + svg_string = self.messages[-1][-1][:-1] + image = self.render_svg(svg_string) + svg_out = clean_svg(svg_string) + + return image, svg_out + + def rasterize_svg(self, svg_string, resolution=224, dpi = 128, scale=2): + try: + svg_raster_bytes = cairosvg.svg2png( + bytestring=svg_string, + background_color='white', + output_width=resolution, + output_height=resolution, + dpi=dpi, + scale=scale) + svg_raster = Image.open(BytesIO(svg_raster_bytes)) + except: + try: + svg = self.clean_svg(svg_string) + svg_raster_bytes = cairosvg.svg2png( + bytestring=svg, + background_color='white', + output_width=resolution, + output_height=resolution, + dpi=dpi, + scale=scale) + svg_raster = Image.open(BytesIO(svg_raster_bytes)) + except: + svg_raster = Image.new('RGB', (resolution, resolution), color = 'white') + return svg_raster + + def clean_svg(self, svg_text, output_width=None, output_height=None): + soup = BeautifulSoup(svg_text, 'xml') # Read as soup to parse as xml + svg_bs4 = soup.prettify() # Prettify to get a string + svg_cairo = cairosvg.svg2svg(svg_bs4, output_width=output_width, output_height=output_height).decode() + svg_clean = "\n".join([line for line in svg_cairo.split("\n") if not line.strip().startswith(" W: + H, W = longest_edge, shortest_edge + else: + H, W = shortest_edge, longest_edge + image = image.resize((W, H)) + buffered = BytesIO() + image.save(buffered, format="JPEG") + img_b64_str = base64.b64encode(buffered.getvalue()).decode() + img_str = f'' + msg = img_str + ret.append([msg, None]) + else: + ret.append([msg, None]) + else: + ret[-1][-1] = msg + return ret + + def copy(self): + return Conversation( + system=self.system, + image_prompt=self.image_prompt, + roles=self.roles, + messages=[[x, y] for x, y in self.messages], + offset=self.offset, + version=self.version + + ) + def dict(self): + if len(self.get_images()) > 0: + return { + "system": self.system, + "image_prompt": self.image_prompt, + "roles": self.roles, + "messages": [[x, y[0] if type(y) is tuple else y] for x, y in self.messages], + "offset": self.offset, + } + return { + "system": self.system, + "image_prompt": self.image_prompt, + "roles": self.roles, + "messages": self.messages, + "offset": self.offset, + } + +starvector_v1 = Conversation( + system="StarVector", + # prompt='', + image_prompt=' 
0 else "" + ) + return state, dropdown_update + +def vote_last_response(state, vote_type, model_selector, request: gr.Request): + with open(get_conv_log_filename(), "a") as fout: + data = { + "tstamp": round(time.time(), 4), + "type": vote_type, + "model": model_selector, + "state": state.dict(), + "ip": request.client.host, + } + fout.write(json.dumps(data) + "\n") + +def upvote_last_response(state, model_selector, request: gr.Request): + logger.info(f"upvote. ip: {request.client.host}") + vote_last_response(state, "upvote", model_selector, request) + return ("",) + (disable_btn,) * 3 + +def downvote_last_response(state, model_selector, request: gr.Request): + logger.info(f"downvote. ip: {request.client.host}") + vote_last_response(state, "downvote", model_selector, request) + return ("",) + (disable_btn,) * 3 + +def flag_last_response(state, model_selector, request: gr.Request): + logger.info(f"flag. ip: {request.client.host}") + vote_last_response(state, "flag", model_selector, request) + return ("",) + (disable_btn,) * 3 + +def regenerate(state, image_process_mode, request: gr.Request): + logger.info(f"regenerate. ip: {request.client.host}") + state.messages[-1][-1] = None + prev_human_msg = state.messages[-2] + if type(prev_human_msg[1]) in (tuple, list): + prev_human_msg[1] = (prev_human_msg[1][:2], image_process_mode) + state.skip_next = False + return (state, None, None, None) + (disable_btn,) * 6 + +def clear_history(request: gr.Request): + logger.info(f"clear_history. ip: {request.client.host}") + state = default_conversation.copy() + return (state, None, None) + (disable_btn,) * 6 + +def send_image(state, image, image_process_mode, request: gr.Request): + logger.info(f"send_image. ip: {request.client.host}.") + state.stop_sampling = False + if image is None: + state.skip_next = True + return (state, None, None, image) + (no_change_btn,) * 6 + + if image is not None: + text = (image, image_process_mode) + state.append_message(state.roles[0], text) + state.append_message(state.roles[1], "▌") + state.skip_next = False + msg = state.to_gradio_svg_code()[0][1] + return (state, msg, state.to_gradio_svg_render(), image) + (no_change_btn,) * 6 + +def stop_sampling(state, image, request: gr.Request): + logger.info(f"stop_sampling. ip: {request.client.host}") + state.stop_sampling = True + return (state, None, None, image) + (disable_btn,) * 6 + +def http_bot(state, model_selector, num_beams, temperature, len_penalty, top_p, max_new_tokens, request: gr.Request): + logger.info(f"http_bot. 
ip: {request.client.host}") + start_tstamp = time.time() + model_name = model_selector + + if state.skip_next: + # This generate call is skipped due to invalid inputs + yield (state, None, None) + (no_change_btn,) * 6 + return + + # Query worker address + controller_url = args.controller_url + ret = requests.post(controller_url + "/get_worker_address", + json={"model": model_name}) + worker_addr = ret.json()["address"] + logger.info(f"model_name: {model_name}, worker_addr: {worker_addr}") + + # No available worker + if worker_addr == "": + state.messages[-1][-1] = server_error_msg + yield (state, None, None, disable_btn, disable_btn, disable_btn, enable_btn, enable_btn, disable_btn) + return + + # Construct prompt + prompt = state.get_prompt() + + # Make requests + pload = { + "model": model_name, + "prompt": prompt, + "num_beams": int(num_beams), + "temperature": float(temperature), + "len_penalty": float(len_penalty), + "top_p": float(top_p), + "max_new_tokens": min(int(max_new_tokens), 8192-CLIP_QUERY_LENGTH), + } + logger.info(f"==== request ====\n{pload}") + + pload['images'] = state.get_images() + + state.messages[-1][-1] = "▌" + yield (state, state.messages[-1][-1], state.to_gradio_svg_render()) + (disable_btn, disable_btn, disable_btn, disable_btn, disable_btn, enable_btn) + + try: + # Stream output + if state.stop_sampling: + state.messages[1][-1] = "▌" + yield (state, state.messages[-1][-1], state.to_gradio_svg_render()) + (disable_btn, disable_btn, disable_btn, enable_btn, enable_btn, disable_btn) + return + + response = requests.post(worker_addr + "/worker_generate_stream", + headers=headers, json=pload, stream=True, timeout=100) + for chunk in response.iter_lines(decode_unicode=False, delimiter=b"\0"): + if chunk: + data = json.loads(chunk.decode()) + if data["error_code"] == 0: + # output = data["text"].strip().replace('<', '<').replace('>', '>') # trick to avoid the SVG getting rendered + output = data["text"].strip() + state.messages[-1][-1] = output + "▌" + st = state.to_gradio_svg_code() + yield (state, st[-1][1], state.to_gradio_svg_render()) + (disable_btn, disable_btn, disable_btn, enable_btn, enable_btn, enable_btn) + else: + output = data["text"] + f" (error_code: {data['error_code']})" + state.messages[-1][-1] = output + + yield (state, st[-1][1], state.to_gradio_svg_render()) + (disable_btn, disable_btn, disable_btn, enable_btn, enable_btn, disable_btn) + return + time.sleep(0.03) + except requests.exceptions.RequestException as e: + state.messages[-1][-1] = server_error_msg + yield (state, None, None) + (disable_btn, disable_btn, disable_btn, enable_btn, enable_btn, disable_btn) + return + + yield (state, state.messages[-1][-1], state.to_gradio_svg_render()) + (enable_btn,) * 6 + + finish_tstamp = time.time() + logger.info(f"{output}") + + with open(get_conv_log_filename(), "a") as fout: + data = { + "tstamp": round(finish_tstamp, 4), + "type": "chat", + "model": model_name, + "start": round(start_tstamp, 4), + "finish": round(finish_tstamp, 4), + "svg": state.messages[-1][-1], + "ip": request.client.host, + } + fout.write(json.dumps(data) + "\n") + +title_markdown = (""" +# 💫 StarVector: Generating Scalable Vector Graphics Code from Images and Text +[[Project Page](https://starvector.github.io)] [[Code](https://github.com/joanrod/star-vector)] [[Model](https://huggingface.co/joanrodai/starvector-1.4b)] | 📚 [[StarVector](https://arxiv.org/abs/2312.11556)] +""") + +sub_title_markdown = (""" Throw an image and vectorize it! 
The model expects vector-like images to generate the corresponding svg code.""") +tos_markdown = (""" +### Terms of use +By using this service, users are required to agree to the following terms: +The service is a research preview intended for non-commercial use only. It only provides limited safety measures and may generate offensive content. It must not be used for any illegal, harmful, violent, racist, or sexual purposes. The service may collect user dialogue data for future research. +Please click the "Flag" button if you get any inappropriate answer! We will collect those to keep improving our moderator. +For an optimal experience, please use desktop computers for this demo, as mobile devices may compromise its quality. +""") + + +learn_more_markdown = (""" +### License +The service is a research preview intended for non-commercial use only. Please contact us if you find any potential violation. +""") + +block_css = """ + +#buttons button { + min-width: min(120px,100%); +} + +.gradio-container{ + max-width: 1200px!important +} + +#svg_render{ + padding: 20px !important; +} + +#svg_code{ + height: 200px !important; + overflow: scroll !important; + white-space: unset !important; + flex-shrink: unset !important; +} + + +h1{display: flex;align-items: center;justify-content: center;gap: .25em} +*{transition: width 0.5s ease, flex-grow 0.5s ease} +""" + +def build_demo(embed_mode, concurrency_count=10): + with gr.Blocks(title="StarVector", theme=gr.themes.Default(), css=block_css) as demo: + state = gr.State() + if not embed_mode: + gr.Markdown(title_markdown) + gr.Markdown(sub_title_markdown) + with gr.Row(): + with gr.Column(scale=3): + with gr.Row(elem_id="model_selector_row"): + model_selector = gr.Dropdown( + choices=models, + value=models[0] if len(models) > 0 else "", + interactive=True, + show_label=False, + container=False) + imagebox = gr.Image(type="pil") + image_process_mode = gr.Radio( + ["Resize", "Pad", "Default"], + value="Pad", + label="Preprocess for non-square image", visible=False) + + cur_dir = os.path.dirname(os.path.abspath(__file__)) + gr.Examples(examples=[ + [f"{cur_dir}/examples/sample-4.png"], + [f"{cur_dir}/examples/sample-7.png"], + [f"{cur_dir}/examples/sample-16.png"], + [f"{cur_dir}/examples/sample-17.png"], + [f"{cur_dir}/examples/sample-18.png"], + [f"{cur_dir}/examples/sample-0.png"], + [f"{cur_dir}/examples/sample-1.png"], + [f"{cur_dir}/examples/sample-6.png"], + ], inputs=[imagebox]) + + with gr.Column(scale=1, min_width=50): + submit_btn = gr.Button(value="Send", variant="primary") + + with gr.Accordion("Parameters", open=True) as parameter_row: + num_beams = gr.Slider(minimum=1, maximum=10, value=1, step=1, interactive=True, label="Num Beams", visible=False,) + temperature = gr.Slider(minimum=0.0, maximum=2.0, value=0.8, step=0.05, interactive=True, label="Temperature",) + len_penalty = gr.Slider(minimum=0.0, maximum=2.0, value=0.6, step=0.05, interactive=True, label="Length Penalty",) + top_p = gr.Slider(minimum=0.0, maximum=1.0, value=0.9, step=0.05, interactive=True, label="Top P",) + max_output_tokens = gr.Slider(minimum=0, maximum=8192, value=2000, step=64, interactive=True, label="Max output tokens",) + + with gr.Column(scale=8): + with gr.Row(): + svg_code = gr.Code(label="SVG Code", elem_id='svg_code', min_width=200, interactive=False, lines=5) + with gr.Row(): + gr.Image(width=50, height=256, label="Rendered SVG", elem_id='svg_render') + with gr.Row(elem_id="buttons") as button_row: + upvote_btn = gr.Button(value="👍 Upvote", 
interactive=False) + downvote_btn = gr.Button(value="👎 Downvote", interactive=False) + flag_btn = gr.Button(value="⚠️ Flag", interactive=False) + stop_btn = gr.Button(value="⏹️ Stop Generation", interactive=False, visible=False) + regenerate_btn = gr.Button(value="🔄 Regenerate", interactive=False, visible=False) + clear_btn = gr.Button(value="🗑️ Clear", interactive=False) + + if not embed_mode: + gr.Markdown(tos_markdown) + gr.Markdown(learn_more_markdown) + url_params = gr.JSON(visible=False) + + # Register listeners + btn_list = [upvote_btn, downvote_btn, flag_btn, regenerate_btn, clear_btn, stop_btn] + upvote_btn.click( + upvote_last_response, + [state, model_selector], + [upvote_btn, downvote_btn, flag_btn], + queue=False + ) + downvote_btn.click( + downvote_last_response, + [state, model_selector], + [upvote_btn, downvote_btn, flag_btn], + queue=False + ) + flag_btn.click( + flag_last_response, + [state, model_selector], + [upvote_btn, downvote_btn, flag_btn], + queue=False + ) + + regenerate_btn.click( + regenerate, + [state, image_process_mode], + [state, svg_code, svg_render, imagebox] + btn_list, + queue=False + ).then( + http_bot, + [state, model_selector, num_beams, temperature, len_penalty, top_p, max_output_tokens], + [state, svg_code, svg_render] + btn_list, + concurrency_limit=concurrency_count + ) + + submit_btn.click( + send_image, + [state, imagebox, image_process_mode], + [state, svg_code, svg_render, imagebox] + btn_list, + queue=False + ).then( + http_bot, + [state, model_selector, num_beams, temperature, len_penalty, top_p, max_output_tokens], + [state, svg_code, svg_render] + btn_list, + concurrency_limit=concurrency_count + ) + + clear_btn.click( + clear_history, + None, + [state, svg_code, svg_render] + btn_list, + queue=False + ) + + stop_btn.click( + stop_sampling, + [state, imagebox], + [state, imagebox] + btn_list, + queue=False + ).then( + clear_history, + None, + [state, svg_code, svg_render] + btn_list, + queue=False + ) + + if args.model_list_mode == "once": + demo.load( + load_demo, + [url_params], + [state, model_selector], + _js=get_window_url_params, + ) + elif args.model_list_mode == "reload": + demo.load( + load_demo_refresh_model_list, + None, + [state, model_selector], + queue=False + ) + else: + raise ValueError(f"Unknown model list mode: {args.model_list_mode}") + + return demo + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument("--host", type=str, default="0.0.0.0") + parser.add_argument("--port", type=int) + parser.add_argument("--controller-url", type=str, default="http://localhost:21001") + parser.add_argument("--concurrency-count", type=int, default=15) + parser.add_argument("--model-list-mode", type=str, default="once", choices=["once", "reload"]) + parser.add_argument("--share", action="store_true") + parser.add_argument("--moderate", action="store_true") + parser.add_argument("--embed", action="store_true") + args = parser.parse_args() + logger.info(f"args: {args}") + + models = get_model_list() + + logger.info(args) + demo = build_demo(args.embed, concurrency_count=args.concurrency_count) + demo.queue( + api_open=False + ).launch( + server_name=args.host, + server_port=args.port, + share=args.share + ) \ No newline at end of file diff --git a/starvector/serve/gradio_web_server.py b/starvector/serve/gradio_web_server.py new file mode 100644 index 0000000000000000000000000000000000000000..71abad36b255d87fac27b254c2e56c3a87615a4d --- /dev/null +++ b/starvector/serve/gradio_web_server.py @@ -0,0 +1,562 @@ 
+import argparse +import datetime +import json +import os +import time +import gradio as gr +import requests +from starvector.serve.conversation import default_conversation +from starvector.serve.constants import LOGDIR, CLIP_QUERY_LENGTH +from starvector.serve.util import (build_logger, server_error_msg) + +logger = build_logger("gradio_web_server", "gradio_web_server.log") +headers = {"User-Agent": "StarVector Client"} + +no_change_btn = gr.Button.update() +enable_btn = gr.Button.update(interactive=True) +disable_btn = gr.Button.update(interactive=False) + +priority = { + "starvector-1b-im2svg": "aaaaaaa", +} + +def get_conv_log_filename(): + t = datetime.datetime.now() + name = os.path.join(LOGDIR, f"{t.year}-{t.month:02d}-{t.day:02d}-conv.json") + return name + +def get_model_list(): + ret = requests.post(args.controller_url + "/refresh_all_workers") + assert ret.status_code == 200 + ret = requests.post(args.controller_url + "/list_models") + models = ret.json()["models"] + models.sort(key=lambda x: priority.get(x, x)) + logger.info(f"Models: {models}") + return models + +def load_demo(url_params, request: gr.Request): + logger.info(f"load_demo. ip: {request.client.host}. params: {url_params}") + + dropdown_update = gr.Dropdown.update(visible=True) + if "model" in url_params: + model = url_params["model"] + if model in models: + dropdown_update = gr.Dropdown.update( + value=model, visible=True) + + state = default_conversation.copy() + return state, dropdown_update + +mapping_model_task = { + 'Image2SVG': 'im2svg', + 'Text2SVG': 'text2svg' +} + +def get_models_dropdown_from_task(task): + models = get_model_list() + models = [model for model in models if mapping_model_task[task] in model] + dropdown_update = gr.Dropdown.update( + choices=models, + value=models[0] if len(models) > 0 else "" + ) + return dropdown_update + + +def load_demo_refresh_model_list(task, request: gr.Request): + logger.info(f"load_demo. ip: {request.client.host}") + dropdown_update = get_models_dropdown_from_task(task) + state = default_conversation.copy() + return state, dropdown_update + +def vote_last_response(state, vote_type, model_selector, request: gr.Request): + with open(get_conv_log_filename(), "a") as fout: + data = { + "tstamp": round(time.time(), 4), + "type": vote_type, + "model": model_selector, + "state": state.dict(), + "ip": request.client.host, + } + fout.write(json.dumps(data) + "\n") + +def upvote_last_response(state, model_selector, request: gr.Request): + logger.info(f"upvote. ip: {request.client.host}") + vote_last_response(state, "upvote", model_selector, request) + return ("",) + (disable_btn,) * 7 + +def downvote_last_response(state, model_selector, request: gr.Request): + logger.info(f"downvote. ip: {request.client.host}") + vote_last_response(state, "downvote", model_selector, request) + return ("",) + (disable_btn,) * 7 + +def flag_last_response(state, model_selector, request: gr.Request): + logger.info(f"flag. ip: {request.client.host}") + vote_last_response(state, "flag", model_selector, request) + return ("",) + (disable_btn,) * 7 + +def regenerate(state, image_process_mode, request: gr.Request): + logger.info(f"regenerate. 
ip: {request.client.host}") + state.messages[-1][-1] = None + prev_human_msg = state.messages[-2] + if type(prev_human_msg[1]) in (tuple, list): + prev_human_msg[1] = (prev_human_msg[1][:2], image_process_mode) + state.skip_next = False + return (state, None, None, None) + (disable_btn,) * 7 + +def clear_history(request: gr.Request): + logger.info(f"clear_history. ip: {request.client.host}") + state = default_conversation.copy() + return (state, None, None) + (disable_btn,) * 7 + +def send_data(state, image, image_process_mode, text_caption, task, request: gr.Request): + logger.info(f"send_data. ip: {request.client.host}.") + if task == 'Image2SVG': + if image is None: + state.skip_next = True + return (state, None, None, image) + (no_change_btn,) * 7 + + if image is not None: + image_message = (image, image_process_mode) + state.append_message(state.roles[0], image_message) + state.append_message(state.roles[1], "▌") + state.skip_next = False + msg = state.to_gradio_svg_code()[0][1] + return (state, msg, state.to_gradio_svg_render(), image) + (no_change_btn,) * 7 + else: + if text_caption is None: + state.skip_next = True + return (state, None, None, image) + (no_change_btn,) * 7 + + state.append_message(state.roles[0], text_caption) + state.append_message(state.roles[1], "▌") + state.skip_next = False + msg = state.to_gradio_svg_code()[0][1] + return (state, msg, state.to_gradio_svg_render(), image) + (no_change_btn,) * 7 + +def download_files(state, request: gr.Request): + logger.info(f"download_files. ip: {request.client.host}") + svg_str, image = state.download_files() + + # TODO: Figure out how to download the SVG in the users browser, idk how to do it now + +def update_task(task): + dropdown_update = get_models_dropdown_from_task(task) + + if task == "Text2SVG": + return 1.0, 0.9, 0.95, dropdown_update + else: + return 0.6, 0.9, 0.95, dropdown_update + + +def stop_sampling(state, image, request: gr.Request): + logger.info(f"stop_sampling. ip: {request.client.host}") + state.stop_sampling = True + return (state, None, None, image) + (disable_btn,) * 7 + +def http_bot(state, task_selector, text_caption, model_selector, num_beams, temperature, len_penalty, top_p, max_new_tokens, request: gr.Request): + logger.info(f"http_bot. 
ip: {request.client.host}") + start_tstamp = time.time() + model_name = model_selector + + if state.skip_next: + # This generate call is skipped due to invalid inputs + yield (state, None, None) + (no_change_btn,) * 7 + return + + # Query worker address + controller_url = args.controller_url + ret = requests.post(controller_url + "/get_worker_address", + json={"model": model_name}) + worker_addr = ret.json()["address"] + logger.info(f"model_name: {model_name}, worker_addr: {worker_addr}") + + # No available worker + if worker_addr == "": + state.messages[-1][-1] = server_error_msg + yield (state, None, None, disable_btn, disable_btn, disable_btn, enable_btn, enable_btn, disable_btn, disable_btn) + return + + # Construct prompt + if task_selector == "Image2SVG": + prompt = state.get_image_prompt() + else: + prompt = text_caption + + # Make requests + pload = { + "model": model_name, + "prompt": prompt, + "num_beams": int(num_beams), + "temperature": float(temperature), + "len_penalty": float(len_penalty), + "top_p": float(top_p), + "max_new_tokens": min(int(max_new_tokens), 8192-CLIP_QUERY_LENGTH), + } + logger.info(f"==== request ====\n{pload}") + + pload['images'] = state.get_images() + + state.messages[-1][-1] = "▌" + yield (state, state.messages[-1][-1], state.to_gradio_svg_render()) + (disable_btn, disable_btn, disable_btn, disable_btn, disable_btn, enable_btn, enable_btn) + + try: + # Stream output + if state.stop_sampling: + state.messages[1][-1] = "▌" + yield (state, state.messages[-1][-1], state.to_gradio_svg_render()) + (disable_btn, disable_btn, disable_btn, enable_btn, enable_btn, disable_btn, enable_btn) + return + + response = requests.post(worker_addr + "/worker_generate_stream", + headers=headers, json=pload, stream=True, timeout=10) + for chunk in response.iter_lines(decode_unicode=False, delimiter=b"\0"): + if chunk: + data = json.loads(chunk.decode()) + if data["error_code"] == 0: + # output = data["text"].strip().replace('<', '<').replace('>', '>') # trick to avoid the SVG getting rendered + output = data["text"].strip() + state.messages[-1][-1] = output + "▌" + st = state.to_gradio_svg_code() + yield (state, st[-1][1], state.to_gradio_svg_render()) + (disable_btn, disable_btn, disable_btn, enable_btn, enable_btn, enable_btn, enable_btn) + else: + output = data["text"] + f" (error_code: {data['error_code']})" + state.messages[-1][-1] = output + + yield (state, st[-1][1], state.to_gradio_svg_render()) + (disable_btn, disable_btn, disable_btn, enable_btn, enable_btn, disable_btn, disable_btn) + return + time.sleep(0.03) + except requests.exceptions.RequestException as e: + state.messages[-1][-1] = server_error_msg + yield (state, None, None) + (disable_btn, disable_btn, disable_btn, enable_btn, enable_btn, disable_btn, disable_btn) + return + + yield (state, state.messages[-1][-1], state.to_gradio_svg_render()) + (enable_btn,) * 7 + + finish_tstamp = time.time() + logger.info(f"{output}") + + with open(get_conv_log_filename(), "a") as fout: + data = { + "tstamp": round(finish_tstamp, 4), + "type": "chat", + "model": model_name, + "start": round(start_tstamp, 4), + "finish": round(finish_tstamp, 4), + "svg": state.messages[-1][-1], + "ip": request.client.host, + } + fout.write(json.dumps(data) + "\n") + +title_markdown = (""" +# 💫 StarVector: Generating Scalable Vector Graphics Code from Images and Text + +[[Project Page](https://starvector.github.io)] [[Code](https://github.com/joanrod/star-vector)] [[Model](https://huggingface.co/joanrodai/starvector-1.4b)] | 📚 
[[StarVector](https://arxiv.org/abs/2312.11556)]""") + +sub_title_markdown = ("""**How does it work?** Select the task you want to perform, and the model will be automatically set. For **Text2SVG**, introduce a prompt in Text Caption. For **Image2SVG**, select an image and vectorize it. \ +**Note**: The current model works on vector-like images like icons and or vector-like designs.""") +tos_markdown = (""" +### Terms of use +By using this service, users are required to agree to the following terms: +The service is a research preview intended for non-commercial use only. It only provides limited safety measures and may generate offensive content. It must not be used for any illegal, harmful, violent, racist, or sexual purposes. The service may collect user dialogue data for future research. +Please click the "Flag" button if you get any inappropriate answer! We will collect those to keep improving our moderator. +For an optimal experience, please use desktop computers for this demo, as mobile devices may compromise its quality. +""") + +learn_more_markdown = (""" +### License +The service is a research preview intended for non-commercial use only. Please contact us if you find any potential violation. +""") + +block_css = """ + +#buttons button { + min-width: min(120px,100%); +} + +.gradio-container{ + max-width: 1200px!important +} + +.ͼ1 .cm-content { + white-space: unset !important; + flex-shrink: unset !important; +} + +.ͼ2p .cm-scroller { + max-height: 200px; + overflow: scroll; +} + +#svg_render{ + padding: 20px !important; +} + +#submit_btn{ + max-height: 40px; +} + +.selector{ + max-height: 100px; +} +h1{display: flex;align-items: center;justify-content: center;gap: .25em} +*{transition: width 0.5s ease, flex-grow 0.5s ease} +""" +def build_demo(embed_mode): + svg_render = gr.Image(label="Rendered SVG", elem_id='svg_render', height=300) + svg_code = gr.Code(label="SVG Code", elem_id='svg_code', interactive=True, lines=5) + + with gr.Blocks(title="StarVector", theme=gr.themes.Default(), css=block_css) as demo: + state = gr.State() + if not embed_mode: + gr.Markdown(title_markdown) + gr.Markdown(sub_title_markdown) + with gr.Row(): + with gr.Column(scale=4): + task_selector = gr.Dropdown( + choices=["Image2SVG", "Text2SVG"], + value="Image2SVG", + label="Task", + interactive=True, + show_label=True, + container=True, + elem_id="task_selector", + elem_classes=["selector"], + ) + model_selector = gr.Dropdown( + choices=models, + value=models[0] if len(models) > 0 else "", + label="Model", + interactive=True, + show_label=True, + container=True, + elem_classes=["selector"], + ) + + imagebox = gr.Image(type="pil", visible=True, elem_id="imagebox") + image_process_mode = gr.Radio( + ["Resize", "Pad", "Default"], + value="Pad", + label="Preprocess for non-square image", visible=False) + + # Text input + text_caption = gr.Textbox(label="Text Caption", visible=True, value="The icon of a yellow star", elem_id="text_caption") + + cur_dir = os.path.dirname(os.path.abspath(__file__)) + gr.Examples(examples=[ + [f"{cur_dir}/examples/sample-4.png"], + [f"{cur_dir}/examples/sample-7.png"], + [f"{cur_dir}/examples/sample-16.png"], + [f"{cur_dir}/examples/sample-17.png"], + [f"{cur_dir}/examples/sample-18.png"], + [f"{cur_dir}/examples/sample-0.png"], + [f"{cur_dir}/examples/sample-1.png"], + [f"{cur_dir}/examples/sample-6.png"], + ], inputs=[imagebox], elem_id="examples") + + submit_btn = gr.Button(value="Send", variant="primary", elem_id="submit_btn", interactive=True) + + with 
gr.Accordion("Parameters", open=False): + num_beams = gr.Slider(minimum=1, maximum=10, value=1, step=1, interactive=True, label="Num Beams", visible=False,) + temperature = gr.Slider(minimum=0.0, maximum=2.0, value=0.9, step=0.05, interactive=True, label="Temperature",) + len_penalty = gr.Slider(minimum=0.0, maximum=2.0, value=0.6, step=0.05, interactive=True, label="Length Penalty",) + top_p = gr.Slider(minimum=0.0, maximum=1.0, value=0.95, step=0.05, interactive=True, label="Top P",) + max_output_tokens = gr.Slider(minimum=0, maximum=8192, value=8192, step=64, interactive=True, label="Max output tokens",) + + with gr.Column(scale=9): + with gr.Row(): + svg_code.render() + with gr.Row(): + svg_render.render() + + with gr.Row(elem_id="buttons") as button_row: + upvote_btn = gr.Button(value="👍 Upvote", interactive=False) + downvote_btn = gr.Button(value="👎 Downvote", interactive=False) + flag_btn = gr.Button(value="⚠️ Flag", interactive=False) + stop_btn = gr.Button(value="⏹️ Stop Generation", interactive=False, visible=False) + regenerate_btn = gr.Button(value="🔄 Regenerate", interactive=False, visible=False) + clear_btn = gr.Button(value="🗑️ Clear", interactive=False) + download_btn = gr.Button(value="Download SVG", interactive=False, visible=False) + + if not embed_mode: + gr.Markdown(tos_markdown) + gr.Markdown(learn_more_markdown) + url_params = gr.JSON(visible=False) + + # Register listeners + btn_list = [upvote_btn, downvote_btn, flag_btn, regenerate_btn, clear_btn, stop_btn, download_btn] + upvote_btn.click( + upvote_last_response, + [state, model_selector], + [upvote_btn, downvote_btn, flag_btn], + queue=False + ) + downvote_btn.click( + downvote_last_response, + [state, model_selector], + [upvote_btn, downvote_btn, flag_btn], + queue=False + ) + flag_btn.click( + flag_last_response, + [state, model_selector], + [upvote_btn, downvote_btn, flag_btn], + queue=False + ) + + regenerate_btn.click( + regenerate, + [state, image_process_mode], + [state, svg_code, svg_render, imagebox] + btn_list, + queue=False + ).then( + http_bot, + [state, task_selector, text_caption, model_selector, num_beams, temperature, len_penalty, top_p, max_output_tokens], + [state, svg_code, svg_render] + btn_list) + + submit_btn.click( + send_data, + [state, imagebox, image_process_mode, text_caption, task_selector], + [state, svg_code, svg_render, imagebox] + btn_list, + queue=False + ).then( + http_bot, + [state, task_selector, text_caption, model_selector, num_beams, temperature, len_penalty, top_p, max_output_tokens], + [state, svg_code, svg_render] + btn_list + ) + + clear_btn.click( + clear_history, + None, + [state, svg_code, svg_render] + btn_list, + queue=False + ) + + stop_btn.click( + stop_sampling, + [state, imagebox], + [state, imagebox] + btn_list, + queue=False + ).then( + clear_history, + None, + [state, svg_code, svg_render] + btn_list, + queue=False + ) + + download_btn.click( + download_files, + [state], + None, + queue=False + ) + task_selector.change( + update_task, + inputs=[task_selector], + outputs=[len_penalty, temperature, top_p, model_selector], + queue=False, + _js=""" + function(task) { + var imageBoxElement = document.getElementById("imagebox"); + var textCaptionElement = document.getElementById("text_caption"); + var examplesElement = document.getElementById("examples"); + if (task === "Text2SVG") { + imageBoxElement.style.display = "none"; + textCaptionElement.style.display = "block"; + examplesElement.style.display = "none"; + } else if (task === "Image2SVG") { + 
imageBoxElement.style.display = "block"; + textCaptionElement.style.display = "none"; + examplesElement.style.display = "block"; + } + return task; + } + """ + ) + + if args.model_list_mode == "once": + demo.load( + load_demo, + [url_params, task_selector], + [state, model_selector], + _js=""" + function() { + const params = new URLSearchParams(window.location.search); + url_params = Object.fromEntries(params); + console.log(url_params); + return url_params; + + } + """, + queue=False + ) + elif args.model_list_mode == "reload": + demo.load( + load_demo_refresh_model_list, + [task_selector], + [state, model_selector], + _js=""" + function(task) { + var textCaptionElement = document.getElementById("text_caption"); + var autoScrollBottom = true; + textCaptionElement.style.display = "none"; + function updateScroll(){ + if (autoScrollBottom) { + var element = document.getElementsByClassName("cm-scroller")[0]; + element.scrollTop = element.scrollHeight; + } + } + function handleScroll() { + var element = document.getElementsByClassName("cm-scroller")[0]; + //if (element.scrollHeight - element.scrollTop === element.clientHeight) { + if (element.scrollHeight - (element.scrollTop + element.clientHeight) < 0.2*(element.scrollTop)) { + // User has scrolled to the bottom, enable auto-scrolling + autoScrollBottom = true; + console.log("bottom"); + } else { + console.log("not bottom"); + // User has scrolled away from the bottom, disable auto-scrolling + autoScrollBottom = false; + } + } + setInterval(updateScroll,500); + var element = document.getElementsByClassName("cm-scroller")[0]; + element.addEventListener("scroll", handleScroll); + + return task; + } + + """, + queue=False, + ) + + else: + raise ValueError(f"Unknown model list mode: {args.model_list_mode}") + + return demo + +if __name__ == "__main__": + + parser = argparse.ArgumentParser() + parser.add_argument("--host", type=str, default="0.0.0.0") + parser.add_argument("--port", type=int) + parser.add_argument("--controller-url", type=str, default="http://localhost:21001") + parser.add_argument("--concurrency-count", type=int, default=10) + parser.add_argument("--model-list-mode", type=str, default="once", + choices=["once", "reload"]) + parser.add_argument("--share", action="store_true") + parser.add_argument("--moderate", action="store_true") + parser.add_argument("--embed", action="store_true") + args = parser.parse_args() + logger.info(f"args: {args}") + + models = get_model_list() + + logger.info(args) + demo = build_demo(args.embed) + demo.queue( + concurrency_count=args.concurrency_count, + api_open=False + ).launch( + server_name=args.host, + server_port=args.port, + share=args.share + ) \ No newline at end of file diff --git a/starvector/serve/model_worker.py b/starvector/serve/model_worker.py new file mode 100644 index 0000000000000000000000000000000000000000..4c7a45aefcc92d042b448e0c83321909a4628c47 --- /dev/null +++ b/starvector/serve/model_worker.py @@ -0,0 +1,269 @@ +""" +A model worker executes the model. 
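# Illustrative sketch (not part of the diff): the serving stack above is glued together over
# plain HTTP. The Gradio frontend asks the controller (default port 21001) which worker serves
# a given model, and each model worker (default port 21002) registers itself and heartbeats back.
# The round-trip below mirrors get_model_list() and http_bot(), assuming a controller and at
# least one worker are already running with the defaults used in this diff.
import requests

CONTROLLER_URL = "http://localhost:21001"  # default --controller-url / --controller-address

# Re-poll the registered workers, then list the models they serve.
requests.post(CONTROLLER_URL + "/refresh_all_workers", timeout=5)
models = requests.post(CONTROLLER_URL + "/list_models", timeout=5).json()["models"]
print("available models:", models)

# Resolve a worker address for one model, as http_bot() does before streaming a generation.
if models:
    ret = requests.post(CONTROLLER_URL + "/get_worker_address",
                        json={"model": models[0]}, timeout=5)
    print("worker address:", ret.json()["address"])  # empty string means no worker is available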
+""" +import argparse +import asyncio +import json +import time +import threading +import uuid +from fastapi import FastAPI, Request, BackgroundTasks +from fastapi.responses import StreamingResponse +import requests +import torch +import uvicorn +from functools import partial +from starvector.serve.constants import WORKER_HEART_BEAT_INTERVAL, CLIP_QUERY_LENGTH +from starvector.serve.util import (build_logger, server_error_msg, + pretty_print_semaphore) +from starvector.model.builder import load_pretrained_model +from starvector.serve.util import process_images, load_image_from_base64 +from threading import Thread +from transformers import TextIteratorStreamer + +GB = 1 << 30 + +worker_id = str(uuid.uuid4())[:6] +logger = build_logger("model_worker", f"model_worker_{worker_id}.log") +global_counter = 0 +model_semaphore = None + +def heart_beat_worker(controller): + while True: + time.sleep(WORKER_HEART_BEAT_INTERVAL) + controller.send_heart_beat() + +class ModelWorker: + def __init__(self, controller_addr, worker_addr, + worker_id, no_register, + model_path, model_base, model_name, + load_8bit, load_4bit, device): + self.controller_addr = controller_addr + self.worker_addr = worker_addr + self.worker_id = worker_id + if model_path.endswith("/"): + model_path = model_path[:-1] + if model_name is None: + model_paths = model_path.split("/") + if model_paths[-1].startswith('checkpoint-'): + self.model_name = model_paths[-2] + "_" + model_paths[-1] + else: + self.model_name = model_paths[-1] + else: + self.model_name = model_name + + if "text2svg" in self.model_name.lower(): + self.task = "Text2SVG" + elif "im2svg" in self.model_name.lower(): + self.task = "Image2SVG" + + self.device = device + logger.info(f"Loading the model {self.model_name} on worker {worker_id} ...") + self.tokenizer, self.model, self.image_processor, self.context_len = load_pretrained_model( + model_path, device=self.device, load_in_8bit=load_8bit, load_in_4bit=load_4bit) + self.model.to(torch.bfloat16) + self.is_multimodal = 'starvector' in self.model_name.lower() + + if not no_register: + self.register_to_controller() + self.heart_beat_thread = threading.Thread( + target=heart_beat_worker, args=(self,)) + self.heart_beat_thread.start() + + def register_to_controller(self): + logger.info("Register to controller") + + url = self.controller_addr + "/register_worker" + data = { + "worker_name": self.worker_addr, + "check_heart_beat": True, + "worker_status": self.get_status() + } + r = requests.post(url, json=data) + assert r.status_code == 200 + + def send_heart_beat(self): + logger.info(f"Send heart beat. Models: {[self.model_name]}. " + f"Semaphore: {pretty_print_semaphore(model_semaphore)}. 
" + f"global_counter: {global_counter}") + + url = self.controller_addr + "/receive_heart_beat" + + while True: + try: + ret = requests.post(url, json={ + "worker_name": self.worker_addr, + "queue_length": self.get_queue_length()}, timeout=5) + exist = ret.json()["exist"] + break + except requests.exceptions.RequestException as e: + logger.error(f"heart beat error: {e}") + time.sleep(5) + + if not exist: + self.register_to_controller() + + def get_queue_length(self): + if model_semaphore is None: + return 0 + else: + return args.limit_model_concurrency - model_semaphore._value + (len( + model_semaphore._waiters) if model_semaphore._waiters is not None else 0) + + def get_status(self): + return { + "model_names": [self.model_name], + "speed": 1, + "queue_length": self.get_queue_length(), + } + + @torch.inference_mode() + def generate_stream(self, params): + tokenizer, model, image_processor, task = self.tokenizer, self.model, self.image_processor, self.task + + num_beams = int(params.get("num_beams", 1)) + temperature = float(params.get("temperature", 1.0)) + len_penalty = float(params.get("len_penalty", 1.0)) + top_p = float(params.get("top_p", 1.0)) + max_context_length = getattr(model.config, 'max_position_embeddings', 8192) + streamer = TextIteratorStreamer(tokenizer, skip_prompt=False, skip_special_tokens=True, timeout=15) + prompt = params["prompt"] + + if task == "Image2SVG": + images = params.get("images", None) + for b64_image in images: + if b64_image is not None and self.is_multimodal: + image = load_image_from_base64(b64_image) + image = process_images(image, image_processor) + image = image.to(self.model.device, dtype=torch.float16) + else: + image = None + + max_new_tokens = min(int(params.get("max_new_tokens", 256)), 8192) + max_new_tokens = min(max_new_tokens, max_context_length - CLIP_QUERY_LENGTH) + pre_pend = prompt + batch = {} + batch["image"] = image + generate_method = model.model.generate_im2svg + else: + max_new_tokens = min(int(params.get("max_new_tokens", 128)), 8192) + pre_pend = "" + batch = {} + batch['caption'] = [prompt] + # White PIL image + batch['image'] = torch.zeros((3, 256, 256), dtype=torch.float16).to(self.model.device) + generate_method = model.model.generate_text2svg + + if max_new_tokens < 1: + yield json.dumps({"text": prompt + "Exceeds max token length. 
Please start a new conversation, thanks.", "error_code": 0}).encode() + b"\0" + return + + thread = Thread(target=generate_method, kwargs=dict( + batch=batch, + prompt=prompt, + use_nucleus_sampling=True, + num_beams=num_beams, + temperature=temperature, + length_penalty=len_penalty, + top_p=top_p, + max_length=max_new_tokens, + streamer=streamer, + )) + thread.start() + + generated_text = pre_pend + for new_text in streamer: + if new_text == " ": + continue + generated_text += new_text + # if generated_text.endswith(stop_str): + # generated_text = generated_text[:-len(stop_str)] + yield json.dumps({"text": generated_text, "error_code": 0}).encode() + b"\0" + + def generate_stream_gate(self, params): + try: + for x in self.generate_stream(params): + yield x + except ValueError as e: + print("Caught ValueError:", e) + ret = { + "text": server_error_msg, + "error_code": 1, + } + yield json.dumps(ret).encode() + b"\0" + except torch.cuda.CudaError as e: + print("Caught torch.cuda.CudaError:", e) + ret = { + "text": server_error_msg, + "error_code": 1, + } + yield json.dumps(ret).encode() + b"\0" + except Exception as e: + print("Caught Unknown Error", e) + ret = { + "text": server_error_msg, + "error_code": 1, + } + yield json.dumps(ret).encode() + b"\0" + +app = FastAPI() + +def release_model_semaphore(fn=None): + model_semaphore.release() + if fn is not None: + fn() + +@app.post("/worker_generate_stream") +async def generate_stream(request: Request): + global model_semaphore, global_counter + global_counter += 1 + params = await request.json() + + if model_semaphore is None: + model_semaphore = asyncio.Semaphore(args.limit_model_concurrency) + await model_semaphore.acquire() + worker.send_heart_beat() + generator = worker.generate_stream_gate(params) + background_tasks = BackgroundTasks() + background_tasks.add_task(partial(release_model_semaphore, fn=worker.send_heart_beat)) + return StreamingResponse(generator, background=background_tasks) + +@app.post("/worker_get_status") +async def get_status(request: Request): + return worker.get_status() + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument("--host", type=str, default="localhost") + parser.add_argument("--port", type=int, default=21002) + parser.add_argument("--worker-address", type=str, + default="http://localhost:21002") + parser.add_argument("--controller-address", type=str, + default="http://localhost:21001") + parser.add_argument("--model-path", type=str, default="joanrodai/starvector-1.4b") + parser.add_argument("--model-base", type=str, default=None) + parser.add_argument("--model-name", type=str) + parser.add_argument("--device", type=str, default="cuda") + parser.add_argument("--multi-modal", action="store_true", help="Multimodal mode is automatically detected with model name, please make sure `starvector` is included in the model path.") + parser.add_argument("--limit-model-concurrency", type=int, default=5) + parser.add_argument("--stream-interval", type=int, default=1) + parser.add_argument("--no-register", action="store_true") + parser.add_argument("--load-8bit", action="store_true") + parser.add_argument("--load-4bit", action="store_true") + args = parser.parse_args() + logger.info(f"args: {args}") + + if args.multi_modal: + logger.warning("Multimodal mode is automatically detected with model name, please make sure `starvector` is included in the model path.") + + worker = ModelWorker(args.controller_address, + args.worker_address, + worker_id, + args.no_register, + args.model_path, + 
args.model_base, + args.model_name, + args.load_8bit, + args.load_4bit, + args.device) + uvicorn.run(app, host=args.host, port=args.port, log_level="info") \ No newline at end of file diff --git a/starvector/serve/register_worker.py b/starvector/serve/register_worker.py new file mode 100644 index 0000000000000000000000000000000000000000..2c2c40295e0351f25709ba25554c9329f15bf0d2 --- /dev/null +++ b/starvector/serve/register_worker.py @@ -0,0 +1,26 @@ +""" +Manually register workers. + +Usage: +python3 -m fastchat.serve.register_worker --controller http://localhost:21001 --worker-name http://localhost:21002 +""" + +import argparse + +import requests + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument("--controller-address", type=str) + parser.add_argument("--worker-name", type=str) + parser.add_argument("--check-heart-beat", action="store_true") + args = parser.parse_args() + + url = args.controller_address + "/register_worker" + data = { + "worker_name": args.worker_name, + "check_heart_beat": args.check_heart_beat, + "worker_status": None, + } + r = requests.post(url, json=data) + assert r.status_code == 200 diff --git a/starvector/serve/util.py b/starvector/serve/util.py new file mode 100644 index 0000000000000000000000000000000000000000..c0cb9ec88957e11071fd72d5b6af624d8b3547a6 --- /dev/null +++ b/starvector/serve/util.py @@ -0,0 +1,129 @@ +import logging +import logging.handlers +import os +import sys +import requests +from starvector.serve.constants import LOGDIR +from PIL import Image +from io import BytesIO +import base64 + +server_error_msg = "**NETWORK ERROR DUE TO HIGH TRAFFIC. PLEASE REGENERATE OR REFRESH THIS PAGE.**" +moderation_msg = "YOUR INPUT VIOLATES OUR CONTENT MODERATION GUIDELINES. PLEASE TRY AGAIN." + +handler = None + +def build_logger(logger_name, logger_filename): + global handler + + formatter = logging.Formatter( + fmt="%(asctime)s | %(levelname)s | %(name)s | %(message)s", + datefmt="%Y-%m-%d %H:%M:%S", + ) + + # Set the format of root handlers + if not logging.getLogger().handlers: + logging.basicConfig(level=logging.INFO) + logging.getLogger().handlers[0].setFormatter(formatter) + + # Redirect stdout and stderr to loggers + stdout_logger = logging.getLogger("stdout") + stdout_logger.setLevel(logging.INFO) + sl = StreamToLogger(stdout_logger, logging.INFO) + sys.stdout = sl + + stderr_logger = logging.getLogger("stderr") + stderr_logger.setLevel(logging.ERROR) + sl = StreamToLogger(stderr_logger, logging.ERROR) + sys.stderr = sl + + # Get logger + logger = logging.getLogger(logger_name) + logger.setLevel(logging.INFO) + + # Add a file handler for all loggers + if handler is None: + os.makedirs(LOGDIR, exist_ok=True) + filename = os.path.join(LOGDIR, logger_filename) + handler = logging.handlers.TimedRotatingFileHandler( + filename, when='D', utc=True, encoding='UTF-8') + handler.setFormatter(formatter) + + for name, item in logging.root.manager.loggerDict.items(): + if isinstance(item, logging.Logger): + item.addHandler(handler) + + return logger + +class StreamToLogger(object): + """ + Fake file-like stream object that redirects writes to a logger instance. 
+ """ + def __init__(self, logger, log_level=logging.INFO): + self.terminal = sys.stdout + self.logger = logger + self.log_level = log_level + self.linebuf = '' + + def __getattr__(self, attr): + return getattr(self.terminal, attr) + + def write(self, buf): + temp_linebuf = self.linebuf + buf + self.linebuf = '' + for line in temp_linebuf.splitlines(True): + # From the io.TextIOWrapper docs: + # On output, if newline is None, any '\n' characters written + # are translated to the system default line separator. + # By default sys.stdout.write() expects '\n' newlines and then + # translates them so this is still cross platform. + if line[-1] == '\n': + self.logger.log(self.log_level, line.rstrip()) + else: + self.linebuf += line + + def flush(self): + if self.linebuf != '': + self.logger.log(self.log_level, self.linebuf.rstrip()) + self.linebuf = '' + +def disable_torch_init(): + """ + Disable the redundant torch default initialization to accelerate model creation. + """ + import torch + setattr(torch.nn.Linear, "reset_parameters", lambda self: None) + setattr(torch.nn.LayerNorm, "reset_parameters", lambda self: None) + +def violates_moderation(text): + """ + Check whether the text violates OpenAI moderation API. + """ + url = "https://api.openai.com/v1/moderations" + headers = {"Content-Type": "application/json", + "Authorization": "Bearer " + os.environ["OPENAI_API_KEY"]} + text = text.replace("\n", "") + data = "{" + '"input": ' + f'"{text}"' + "}" + data = data.encode("utf-8") + try: + ret = requests.post(url, headers=headers, data=data, timeout=5) + flagged = ret.json()["results"][0]["flagged"] + except requests.exceptions.RequestException as e: + flagged = False + except KeyError as e: + flagged = False + + return flagged + +def pretty_print_semaphore(semaphore): + if semaphore is None: + return "None" + return f"Semaphore(value={semaphore._value}, locked={semaphore.locked()})" + +def load_image_from_base64(image): + return Image.open(BytesIO(base64.b64decode(image))) + +def process_images(image, image_processor): + im = image_processor(image) + im = im.unsqueeze(0) + return im \ No newline at end of file diff --git a/starvector/serve/vllm_api_gradio/controller.py b/starvector/serve/vllm_api_gradio/controller.py new file mode 100644 index 0000000000000000000000000000000000000000..0a0636e0bbe3f65b09b79780d90589afe8c205eb --- /dev/null +++ b/starvector/serve/vllm_api_gradio/controller.py @@ -0,0 +1,292 @@ +""" +A controller manages distributed workers. +It sends worker addresses to clients. 
+""" +import argparse +import asyncio +import dataclasses +from enum import Enum, auto +import json +import logging +import time +from typing import List, Union +import threading + +from fastapi import FastAPI, Request +from fastapi.responses import StreamingResponse +import numpy as np +import requests +import uvicorn + +from starvector.serve.constants import CONTROLLER_HEART_BEAT_EXPIRATION +from starvector.serve.util import build_logger, server_error_msg + +logger = build_logger("controller", "controller.log") + +class DispatchMethod(Enum): + LOTTERY = auto() + SHORTEST_QUEUE = auto() + + @classmethod + def from_str(cls, name): + if name == "lottery": + return cls.LOTTERY + elif name == "shortest_queue": + return cls.SHORTEST_QUEUE + else: + raise ValueError(f"Invalid dispatch method") + + +@dataclasses.dataclass +class WorkerInfo: + model_names: List[str] + speed: int + queue_length: int + check_heart_beat: bool + last_heart_beat: str + + +def heart_beat_controller(controller): + while True: + time.sleep(CONTROLLER_HEART_BEAT_EXPIRATION) + controller.remove_stable_workers_by_expiration() + + +class Controller: + def __init__(self, dispatch_method: str): + # Dict[str -> WorkerInfo] + self.worker_info = {} + self.dispatch_method = DispatchMethod.from_str(dispatch_method) + + self.heart_beat_thread = threading.Thread( + target=heart_beat_controller, args=(self,)) + self.heart_beat_thread.start() + + logger.info("Init controller") + + def register_worker(self, worker_name: str, check_heart_beat: bool, + worker_status: dict): + if worker_name not in self.worker_info: + logger.info(f"Register a new worker: {worker_name}") + else: + logger.info(f"Register an existing worker: {worker_name}") + + if not worker_status: + worker_status = self.get_worker_status(worker_name) + if not worker_status: + return False + + self.worker_info[worker_name] = WorkerInfo( + worker_status["model_names"], worker_status["speed"], worker_status["queue_length"], + check_heart_beat, time.time()) + + logger.info(f"Register done: {worker_name}, {worker_status}") + return True + + def get_worker_status(self, worker_name: str): + try: + r = requests.post(worker_name + "/worker_get_status", timeout=5) + except requests.exceptions.RequestException as e: + logger.error(f"Get status fails: {worker_name}, {e}") + return None + + if r.status_code != 200: + logger.error(f"Get status fails: {worker_name}, {r}") + return None + + return r.json() + + def remove_worker(self, worker_name: str): + del self.worker_info[worker_name] + + def refresh_all_workers(self): + old_info = dict(self.worker_info) + self.worker_info = {} + + for w_name, w_info in old_info.items(): + if not self.register_worker(w_name, w_info.check_heart_beat, None): + logger.info(f"Remove stale worker: {w_name}") + + def list_models(self): + model_names = set() + + for w_name, w_info in self.worker_info.items(): + model_names.update(w_info.model_names) + + return list(model_names) + + def get_worker_address(self, model_name: str): + if self.dispatch_method == DispatchMethod.LOTTERY: + worker_names = [] + worker_speeds = [] + for w_name, w_info in self.worker_info.items(): + if model_name in w_info.model_names: + worker_names.append(w_name) + worker_speeds.append(w_info.speed) + worker_speeds = np.array(worker_speeds, dtype=np.float32) + norm = np.sum(worker_speeds) + if norm < 1e-4: + return "" + worker_speeds = worker_speeds / norm + if True: # Directly return address + pt = np.random.choice(np.arange(len(worker_names)), + p=worker_speeds) + worker_name = 
worker_names[pt] + return worker_name + + # Check status before returning + while True: + pt = np.random.choice(np.arange(len(worker_names)), + p=worker_speeds) + worker_name = worker_names[pt] + + if self.get_worker_status(worker_name): + break + else: + self.remove_worker(worker_name) + worker_speeds[pt] = 0 + norm = np.sum(worker_speeds) + if norm < 1e-4: + return "" + worker_speeds = worker_speeds / norm + continue + return worker_name + elif self.dispatch_method == DispatchMethod.SHORTEST_QUEUE: + worker_names = [] + worker_qlen = [] + for w_name, w_info in self.worker_info.items(): + if model_name in w_info.model_names: + worker_names.append(w_name) + worker_qlen.append(w_info.queue_length / w_info.speed) + if len(worker_names) == 0: + return "" + min_index = np.argmin(worker_qlen) + w_name = worker_names[min_index] + self.worker_info[w_name].queue_length += 1 + logger.info(f"names: {worker_names}, queue_lens: {worker_qlen}, ret: {w_name}") + return w_name + else: + raise ValueError(f"Invalid dispatch method: {self.dispatch_method}") + + def receive_heart_beat(self, worker_name: str, queue_length: int): + if worker_name not in self.worker_info: + logger.info(f"Receive unknown heart beat. {worker_name}") + return False + + self.worker_info[worker_name].queue_length = queue_length + self.worker_info[worker_name].last_heart_beat = time.time() + logger.info(f"Receive heart beat. {worker_name}") + return True + + def remove_stable_workers_by_expiration(self): + expire = time.time() - CONTROLLER_HEART_BEAT_EXPIRATION + to_delete = [] + for worker_name, w_info in self.worker_info.items(): + if w_info.check_heart_beat and w_info.last_heart_beat < expire: + to_delete.append(worker_name) + + for worker_name in to_delete: + self.remove_worker(worker_name) + + def worker_api_generate_stream(self, params): + worker_addr = self.get_worker_address(params["model"]) + if not worker_addr: + logger.info(f"no worker: {params['model']}") + ret = { + "text": server_error_msg, + "error_code": 2, + } + yield json.dumps(ret).encode() + b"\0" + + try: + response = requests.post(worker_addr + "/worker_generate_stream", + json=params, stream=True, timeout=10) + for chunk in response.iter_lines(decode_unicode=False, delimiter=b"\0"): + if chunk: + yield chunk + b"\0" + except requests.exceptions.RequestException as e: + logger.info(f"worker timeout: {worker_addr}") + ret = { + "text": server_error_msg, + "error_code": 3, + } + yield json.dumps(ret).encode() + b"\0" + + + # Let the controller act as a worker to achieve hierarchical + # management. This can be used to connect isolated sub networks. 
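# Illustrative sketch (not part of the diff): both the model worker and this controller relay
# generations as a stream of b"\0"-delimited JSON chunks, each holding the cumulative text plus
# an error_code (0 on success). The consumer below mirrors http_bot(); the worker address and
# model name are placeholders and would normally come from /get_worker_address.
import json
import requests

worker_addr = "http://localhost:21002"   # assumed; usually resolved via the controller
pload = {
    "model": "starvector-1b-im2svg",     # assumed model name, for illustration only
    "prompt": "<image prompt or text caption>",
    "num_beams": 1,
    "temperature": 0.9,
    "len_penalty": 0.6,
    "top_p": 0.95,
    "max_new_tokens": 2000,
    "images": [],                        # base64-encoded images for Image2SVG requests
}

response = requests.post(worker_addr + "/worker_generate_stream",
                         json=pload, stream=True, timeout=100)
svg_so_far = ""
for chunk in response.iter_lines(decode_unicode=False, delimiter=b"\0"):
    if not chunk:
        continue
    data = json.loads(chunk.decode())
    if data["error_code"] != 0:
        print("worker error:", data["text"])
        break
    svg_so_far = data["text"]            # cumulative output, as rendered by the Gradio UI
print(svg_so_far)                        # final SVG string once the stream ends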
+ def worker_api_get_status(self): + model_names = set() + speed = 0 + queue_length = 0 + + for w_name in self.worker_info: + worker_status = self.get_worker_status(w_name) + if worker_status is not None: + model_names.update(worker_status["model_names"]) + speed += worker_status["speed"] + queue_length += worker_status["queue_length"] + + return { + "model_names": list(model_names), + "speed": speed, + "queue_length": queue_length, + } + +app = FastAPI() + +@app.post("/register_worker") +async def register_worker(request: Request): + data = await request.json() + controller.register_worker( + data["worker_name"], data["check_heart_beat"], + data.get("worker_status", None)) + +@app.post("/refresh_all_workers") +async def refresh_all_workers(): + models = controller.refresh_all_workers() + + +@app.post("/list_models") +async def list_models(): + models = controller.list_models() + return {"models": models} + + +@app.post("/get_worker_address") +async def get_worker_address(request: Request): + data = await request.json() + addr = controller.get_worker_address(data["model"]) + return {"address": addr} + +@app.post("/receive_heart_beat") +async def receive_heart_beat(request: Request): + data = await request.json() + exist = controller.receive_heart_beat( + data["worker_name"], data["queue_length"]) + return {"exist": exist} + + +@app.post("/worker_generate_stream") +async def worker_api_generate_stream(request: Request): + params = await request.json() + generator = controller.worker_api_generate_stream(params) + return StreamingResponse(generator) + + +@app.post("/worker_get_status") +async def worker_api_get_status(request: Request): + return controller.worker_api_get_status() + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument("--host", type=str, default="localhost") + parser.add_argument("--port", type=int, default=21001) + parser.add_argument("--dispatch-method", type=str, choices=[ + "lottery", "shortest_queue"], default="shortest_queue") + args = parser.parse_args() + logger.info(f"args: {args}") + + controller = Controller(args.dispatch_method) + uvicorn.run(app, host=args.host, port=args.port, log_level="info") diff --git a/starvector/serve/vllm_api_gradio/gradio_vllm.py b/starvector/serve/vllm_api_gradio/gradio_vllm.py new file mode 100644 index 0000000000000000000000000000000000000000..14953e7565e6ee8ba8b26e94e91e93d4dd0975f2 --- /dev/null +++ b/starvector/serve/vllm_api_gradio/gradio_vllm.py @@ -0,0 +1,79 @@ +import argparse + +import gradio as gr +from openai import OpenAI + +# Argument parser setup +parser = argparse.ArgumentParser( + description="Chatbot Interface with Customizable Parameters" +) +parser.add_argument( + "--model-url", type=str, default="http://localhost:8000/v1", help="Model URL" +) +parser.add_argument( + "-m", + "--model", + type=str, + default="ServiceNow/starvector-1.4b-im2svg-v6", + help="Model name for the chatbot", +) +parser.add_argument( + "--temp", type=float, default=0.8, help="Temperature for text generation" +) +parser.add_argument( + "--stop-token-ids", type=str, default="", help="Comma-separated stop token IDs" +) +parser.add_argument("--host", type=str, default=None) +parser.add_argument("--port", type=int, default=8001) + +# Parse the arguments +args = parser.parse_args() + +# Set OpenAI's API key and API base to use vLLM's API server. 
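# Illustrative sketch (not part of the diff): besides the streaming Gradio wrapper below, the
# same vLLM OpenAI-compatible endpoint can be queried directly. The model name and URL are the
# argparse defaults above; whether a plain chat prompt yields SVG depends on the deployed model
# and its chat template, so treat this as a usage sketch rather than a guaranteed recipe.
from openai import OpenAI

client = OpenAI(api_key="EMPTY", base_url="http://localhost:8000/v1")  # defaults from above

completion = client.chat.completions.create(
    model="ServiceNow/starvector-1.4b-im2svg-v6",   # default --model above
    messages=[{"role": "user", "content": "Generate SVG code for a yellow star icon."}],
    temperature=0.8,
    stream=False,                                    # non-streaming variant of the call below
)
print(completion.choices[0].message.content)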
+openai_api_key = "EMPTY" +openai_api_base = args.model_url + +# Create an OpenAI client to interact with the API server +client = OpenAI( + api_key=openai_api_key, + base_url=openai_api_base, +) + + +def predict(message, history): + # Convert chat history to OpenAI format + history_openai_format = [ + {"role": "system", "content": "You are a great ai assistant."} + ] + for human, assistant in history: + history_openai_format.append({"role": "user", "content": human}) + history_openai_format.append({"role": "assistant", "content": assistant}) + history_openai_format.append({"role": "user", "content": message}) + + # Create a chat completion request and send it to the API server + stream = client.chat.completions.create( + model=args.model, # Model name to use + messages=history_openai_format, # Chat history + temperature=args.temp, # Temperature for text generation + stream=True, # Stream response + extra_body={ + "repetition_penalty": 1, + "stop_token_ids": ( + [int(id.strip()) for id in args.stop_token_ids.split(",") if id.strip()] + if args.stop_token_ids + else [] + ), + }, + ) + + # Read and return generated text from response stream + partial_message = "" + for chunk in stream: + partial_message += chunk.choices[0].delta.content or "" + yield partial_message + + +# Create and launch a chat interface with Gradio +gr.ChatInterface(predict).queue().launch( + server_name=args.host, server_port=args.port, share=True +) diff --git a/starvector/serve/vllm_api_gradio/gradio_web_server.py b/starvector/serve/vllm_api_gradio/gradio_web_server.py new file mode 100644 index 0000000000000000000000000000000000000000..0c6cbc020a09bd12a4938f74fd38f282a9bb8847 --- /dev/null +++ b/starvector/serve/vllm_api_gradio/gradio_web_server.py @@ -0,0 +1,745 @@ +import argparse +import datetime +import json +import os +import time +import gradio as gr +import requests +from starvector.serve.conversation import default_conversation +from starvector.serve.constants import LOGDIR, CLIP_QUERY_LENGTH +from starvector.serve.util import (build_logger, server_error_msg) + +logger = build_logger("gradio_web_server", "gradio_web_server.log") +headers = {"User-Agent": "StarVector Client"} + +no_change_btn = gr.Button.update() +enable_btn = gr.Button.update(interactive=True) +disable_btn = gr.Button.update(interactive=False) + +priority = { + "starvector-1b": "aaaaaaa", +} + +def get_conv_log_filename(): + t = datetime.datetime.now() + name = os.path.join(LOGDIR, f"{t.year}-{t.month:02d}-{t.day:02d}-conv.json") + return name + +def get_model_list(): + ret = requests.post(args.controller_url + "/refresh_all_workers") + assert ret.status_code == 200 + ret = requests.post(args.controller_url + "/list_models") + models = ret.json()["models"] + models.sort(key=lambda x: priority.get(x, x)) + logger.info(f"Models: {models}") + return models + +def load_demo(url_params, request: gr.Request): + logger.info(f"load_demo. ip: {request.client.host}. 
params: {url_params}") + + dropdown_update = gr.Dropdown.update(visible=True) + if "model" in url_params: + model = url_params["model"] + if model in models: + dropdown_update = gr.Dropdown.update( + value=model, visible=True) + + state = default_conversation.copy() + return state, dropdown_update + +mapping_model_task = { + 'Image2SVG': 'im2svg', + 'Text2SVG': 'text2svg' +} + +def get_models_dropdown_from_task(task): + models = get_model_list() + models = [model for model in models if mapping_model_task[task] in model] + dropdown_update = gr.Dropdown.update( + choices=models, + value=models[0] if len(models) > 0 else "" + ) + return dropdown_update + + +def load_demo_refresh_model_list(task, request: gr.Request): + logger.info(f"load_demo. ip: {request.client.host}") + dropdown_update = get_models_dropdown_from_task(task) + state = default_conversation.copy() + return state, dropdown_update + +def vote_last_response(state, vote_type, model_selector, request: gr.Request): + with open(get_conv_log_filename(), "a") as fout: + data = { + "tstamp": round(time.time(), 4), + "type": vote_type, + "model": model_selector, + "state": state.dict(), + "ip": request.client.host, + } + fout.write(json.dumps(data) + "\n") + +def upvote_last_response(state, model_selector, request: gr.Request): + logger.info(f"upvote. ip: {request.client.host}") + vote_last_response(state, "upvote", model_selector, request) + return ("",) + (disable_btn,) * 7 + +def downvote_last_response(state, model_selector, request: gr.Request): + logger.info(f"downvote. ip: {request.client.host}") + vote_last_response(state, "downvote", model_selector, request) + return ("",) + (disable_btn,) * 7 + +def flag_last_response(state, model_selector, request: gr.Request): + logger.info(f"flag. ip: {request.client.host}") + vote_last_response(state, "flag", model_selector, request) + return ("",) + (disable_btn,) * 7 + +def regenerate(state, image_process_mode, request: gr.Request): + logger.info(f"regenerate. ip: {request.client.host}") + state.messages[-1][-1] = None + prev_human_msg = state.messages[-2] + if type(prev_human_msg[1]) in (tuple, list): + prev_human_msg[1] = (prev_human_msg[1][:2], image_process_mode) + state.skip_next = False + return (state, None, None, None) + (disable_btn,) * 7 + +def clear_history(request: gr.Request): + logger.info(f"clear_history. ip: {request.client.host}") + state = default_conversation.copy() + return (state, None, None) + (disable_btn,) * 7 + +def send_data(state, image, image_process_mode, text_caption, task, request: gr.Request): + logger.info(f"send_data. 
ip: {request.client.host}.") + if task == 'Image2SVG': + if image is None: + state.skip_next = True + return (state, None, None, image) + (no_change_btn,) * 7 + + # Reset the conversation state when a new image is uploaded + state = default_conversation.copy() + + if image is not None: + image_message = (image, image_process_mode) + state.append_message(state.roles[0], image_message) + state.append_message(state.roles[1], "▌") + state.skip_next = False + msg = state.to_gradio_svg_code()[0][1] + return (state, msg, state.to_gradio_svg_render(), image) + (no_change_btn,) * 7 + else: + if text_caption is None: + state.skip_next = True + return (state, None, None, image) + (no_change_btn,) * 7 + + # Reset the conversation state for new text inputs too + state = default_conversation.copy() + + state.append_message(state.roles[0], text_caption) + state.append_message(state.roles[1], "▌") + state.skip_next = False + msg = state.to_gradio_svg_code()[0][1] + return (state, msg, state.to_gradio_svg_render(), image) + (no_change_btn,) * 7 + +def download_files(state, request: gr.Request): + logger.info(f"download_files. ip: {request.client.host}") + svg_str, image = state.download_files() + + # TODO: Figure out how to download the SVG in the users browser, idk how to do it now + +def update_task(task): + dropdown_update = get_models_dropdown_from_task(task) + + if task == "Text2SVG": + return 1.0, 0.9, 0.95, dropdown_update + else: + return 0.6, 0.9, 0.95, dropdown_update + + +def stop_sampling(state, image, request: gr.Request): + logger.info(f"stop_sampling. ip: {request.client.host}") + state.stop_sampling = True + return (state, None, None, image) + (disable_btn,) * 7 + +def http_bot(state, task_selector, text_caption, model_selector, num_beams, temperature, len_penalty, top_p, max_new_tokens, request: gr.Request): + logger.info(f"http_bot. 
ip: {request.client.host}") + start_tstamp = time.time() + model_name = model_selector + + if state.skip_next: + # This generate call is skipped due to invalid inputs + yield (state, None, None) + (no_change_btn,) * 7 + return + + # Query worker address + controller_url = args.controller_url + ret = requests.post(controller_url + "/get_worker_address", + json={"model": model_name}) + worker_addr = ret.json()["address"] + logger.info(f"model_name: {model_name}, worker_addr: {worker_addr}") + + # No available worker + if worker_addr == "": + state.messages[-1][-1] = server_error_msg + yield (state, None, None, disable_btn, disable_btn, disable_btn, enable_btn, enable_btn, disable_btn, disable_btn) + return + + # Construct prompt + if task_selector == "Image2SVG": + prompt = state.get_image_prompt() + else: + prompt = text_caption + + # Make requests + pload = { + "model": model_name, + "prompt": prompt, + "num_beams": int(num_beams), + "temperature": float(temperature), + "len_penalty": float(len_penalty), + "top_p": float(top_p), + "max_new_tokens": min(int(max_new_tokens), 8192-CLIP_QUERY_LENGTH), + } + logger.info(f"==== request ====\n{pload}") + + pload['images'] = state.get_images() + + state.messages[-1][-1] = "▌" + yield (state, state.messages[-1][-1], state.to_gradio_svg_render()) + (disable_btn, disable_btn, disable_btn, disable_btn, disable_btn, enable_btn, enable_btn) + + try: + # Stream output + if state.stop_sampling: + state.messages[1][-1] = "▌" + yield (state, state.messages[-1][-1], state.to_gradio_svg_render()) + (disable_btn, disable_btn, disable_btn, enable_btn, enable_btn, disable_btn, enable_btn) + return + + response = requests.post(worker_addr + "/worker_generate_stream", + headers=headers, json=pload, stream=True, timeout=10) + for chunk in response.iter_lines(decode_unicode=False, delimiter=b"\0"): + if chunk: + data = json.loads(chunk.decode()) + if data["error_code"] == 0: + # output = data["text"].strip().replace('<', '<').replace('>', '>') # trick to avoid the SVG getting rendered + output = data["text"].strip() + state.messages[-1][-1] = output + "▌" + st = state.to_gradio_svg_code() + # Explicitly set the string value without HTML escaping + yield (state, st[-1][1], state.to_gradio_svg_render()) + (disable_btn, disable_btn, disable_btn, enable_btn, enable_btn, enable_btn, enable_btn) + else: + output = data["text"] + f" (error_code: {data['error_code']})" + state.messages[-1][-1] = output + st = state.to_gradio_svg_code() + + yield (state, st[-1][1], state.to_gradio_svg_render()) + (disable_btn, disable_btn, disable_btn, enable_btn, enable_btn, disable_btn, disable_btn) + return + time.sleep(0.01) + except requests.exceptions.RequestException as e: + state.messages[-1][-1] = server_error_msg + yield (state, None, None) + (disable_btn, disable_btn, disable_btn, enable_btn, enable_btn, disable_btn, disable_btn) + return + + yield (state, state.messages[-1][-1], state.to_gradio_svg_render()) + (enable_btn,) * 7 + + finish_tstamp = time.time() + logger.info(f"{output}") + + with open(get_conv_log_filename(), "a") as fout: + data = { + "tstamp": round(finish_tstamp, 4), + "type": "chat", + "model": model_name, + "start": round(start_tstamp, 4), + "finish": round(finish_tstamp, 4), + "svg": state.messages[-1][-1], + "ip": request.client.host, + } + fout.write(json.dumps(data) + "\n") + + # Fix: Replace 'btn_list' with (enable_btn,) * 7 + return (state, state.messages[-1][-1], state.to_gradio_svg_render()) + (enable_btn,) * 7 + +title_markdown = (""" +# 💫 StarVector: 
Generating Scalable Vector Graphics Code from Images and Text + + + + + 🌎 Project Page + + + + + + GitHub + + + + + 🤗 StarVector-1B + + + + + 🤗 StarVector-8B + + + + + ⭐ SVG-Stack + + + + + 🏆 SVG-Bench + + + + + 📚 arXiv + + + +""") + +sub_title_markdown = ("""**How does it work?** Select the task you want to perform, and the model will be automatically set. For **Text2SVG**, introduce a prompt in Text Caption. For **Image2SVG**, select an image and vectorize it. +**Limitations**: The current model works on vector-like images like icons and or vector-like designs. Images with low resolution may not be vectorized well.""") +tos_markdown = (""" +### Terms of use +By using this service, users are required to agree to the following terms: +The service is a research preview intended for non-commercial use only. It only provides limited safety measures and may generate offensive content. It must not be used for any illegal, harmful, violent, racist, or sexual purposes. The service may collect user dialogue data for future research. +Please click the "Flag" button if you get any inappropriate answer! We will collect those to keep improving our moderator. +For an optimal experience, please use desktop computers for this demo, as mobile devices may compromise its quality. +""") + +learn_more_markdown = (""" +### License +The service is a research preview intended for non-commercial use only. Please contact us if you find any potential violation. +""") + +block_css = """ + +#buttons button { + min-width: min(120px,100%); +} + +.gradio-container{ + max-width: 1200px!important +} + +.ͼ1 .cm-content { + white-space: unset !important; + flex-shrink: unset !important; +} + +.ͼ2p .cm-scroller { + max-height: 160px; + overflow-y: auto !important; + overflow-x: auto !important; +} + +#svg_code .cm-editor { + height: 115px; + overflow: hidden; +} + +#svg_code .cm-scroller { + overflow: auto !important; +} + +/* New styles to make svg code textbox smaller but still readable */ +#svg_code textarea { + height: 115px !important; +} + +#svg_render{ + padding: 20px !important; +} + +#submit_btn{ + max-height: 40px; +} + +.selector{ + max-height: 100px; +} +h1{display: flex;align-items: center;justify-content: center;gap: .25em} +*{transition: width 0.5s ease, flex-grow 0.5s ease} + +/* Custom SVG code display area */ +#custom_svg_code { + height: 100px; + overflow: auto; + border: 1px solid #ccc; + padding: 10px; + background-color: #f5f5f5; + font-family: monospace; + white-space: pre; + font-size: 14px; + line-height: 1.4; + border-radius: 4px; +} + +/* Ensure the content is properly displayed */ +#custom_svg_code code { + display: block; + overflow-x: auto; +} +""" + +# Create a smarter version of the JavaScript for auto-scrolling +code_scroll_js = """ +console.log('SVG Auto-scroll script loaded'); + +// Track if user has manually scrolled up +let userHasScrolledUp = false; +let lastKnownScrollHeight = 0; +let lastKnownScrollTop = 0; + +function setupAutoScroll() { + // Find the SVG code textbox + const svgCodeElements = document.querySelectorAll('#svg_code textarea, #svg_code .cm-content'); + console.log('SVG code elements found:', svgCodeElements.length); + + // Add scroll event listeners to detect manual scrolling + svgCodeElements.forEach(el => { + if (el) { + el.addEventListener('scroll', () => { + const isAtBottom = Math.abs((el.scrollHeight - el.scrollTop - el.clientHeight)) < 30; + + // If user scrolls up, stop auto-scrolling + if (!isAtBottom && lastKnownScrollTop > el.scrollTop) { + console.log('User 
scrolled up, pausing auto-scroll'); + userHasScrolledUp = true; + } + + // If user scrolls to bottom, resume auto-scrolling + if (isAtBottom) { + console.log('User scrolled to bottom, resuming auto-scroll'); + userHasScrolledUp = false; + } + + lastKnownScrollTop = el.scrollTop; + lastKnownScrollHeight = el.scrollHeight; + }); + } + }); + + // Set up an interval to scroll to bottom only if user hasn't scrolled up + setInterval(() => { + svgCodeElements.forEach(el => { + if (el && el.scrollHeight > 0) { + // Only auto-scroll if content has changed or user hasn't scrolled up + if (!userHasScrolledUp || lastKnownScrollHeight !== el.scrollHeight) { + console.log('Auto-scrolling, scrollHeight:', el.scrollHeight); + el.scrollTop = el.scrollHeight; + lastKnownScrollHeight = el.scrollHeight; + lastKnownScrollTop = el.scrollTop; + } + } + }); + }, 500); + + console.log('Smart auto-scroll setup complete'); +} + +// Also observe for content changes to handle new content being added +function observeContentChanges() { + const svgCodeContainer = document.getElementById('svg_code'); + if (svgCodeContainer) { + const observer = new MutationObserver((mutations) => { + // If new content is added and user hasn't scrolled up, scroll to bottom + if (!userHasScrolledUp) { + const svgCodeElements = document.querySelectorAll('#svg_code textarea, #svg_code .cm-content'); + svgCodeElements.forEach(el => { + if (el && el.scrollHeight > 0) { + el.scrollTop = el.scrollHeight; + } + }); + } + }); + + observer.observe(svgCodeContainer, { + childList: true, + subtree: true, + characterData: true + }); + } +} + +// Try to run immediately +setupAutoScroll(); +observeContentChanges(); + +// Also try when DOM is fully loaded +document.addEventListener('DOMContentLoaded', () => { + setupAutoScroll(); + observeContentChanges(); +}); + +// And on window load +window.addEventListener('load', () => { + setupAutoScroll(); + observeContentChanges(); +}); +""" + +def build_demo(embed_mode): + svg_render = gr.Image(label="Rendered SVG", elem_id='svg_render', height=300) + + # Use a Textbox instead of Code component + svg_code = gr.Textbox( + label="SVG Code", + elem_id="svg_code", + lines=9, + value="", + max_lines=9, + show_copy_button=True, + ) + + with gr.Blocks( + title="StarVector", + theme=gr.themes.Default(), + css=block_css, + head=f"" # Use head parameter instead of HTML component + ) as demo: + # Add a dummy component that we'll use to trigger our JavaScript + dummy = gr.Number(value=0, visible=False) + + state = gr.State() + if not embed_mode: + gr.Markdown(title_markdown) + gr.Markdown(sub_title_markdown) + with gr.Row(): + with gr.Column(scale=4): + task_selector = gr.Dropdown( + choices=["Image2SVG", "Text2SVG"], + value="Image2SVG", + label="Task", + interactive=True, + show_label=True, + container=True, + elem_id="task_selector", + elem_classes=["selector"], + ) + model_selector = gr.Dropdown( + choices=models, + value=models[0] if len(models) > 0 else "", + label="Model", + interactive=True, + show_label=True, + container=True, + elem_classes=["selector"], + ) + + imagebox = gr.Image(type="pil", visible=True, elem_id="imagebox") + + # Move the submit button here - right after the imagebox + submit_btn = gr.Button(value="Send", variant="primary", elem_id="submit_btn", interactive=True) + + image_process_mode = gr.Radio( + ["Resize", "Pad", "Default"], + value="Pad", + label="Preprocess for non-square image", visible=False) + + # Text input + text_caption = gr.Textbox(label="Text Caption", visible=True, value="The 
icon of a yellow star", elem_id="text_caption") + + cur_dir = os.path.dirname(os.path.dirname(os.path.abspath(__file__))) + gr.Examples(examples=[ + [f"{cur_dir}/examples/sample-4.png"], + [f"{cur_dir}/examples/sample-7.png"], + [f"{cur_dir}/examples/sample-16.png"], + [f"{cur_dir}/examples/sample-17.png"], + [f"{cur_dir}/examples/sample-18.png"], + [f"{cur_dir}/examples/sample-0.png"], + [f"{cur_dir}/examples/sample-1.png"], + [f"{cur_dir}/examples/sample-6.png"], + ], inputs=[imagebox], elem_id="examples") + + # Remove the submit button from here since we moved it above + + with gr.Accordion("Parameters", open=False): + num_beams = gr.Slider(minimum=1, maximum=10, value=1, step=1, interactive=True, label="Num Beams", visible=False,) + temperature = gr.Slider(minimum=0.0, maximum=2.0, value=0.2, step=0.05, interactive=True, label="Temperature",) + len_penalty = gr.Slider(minimum=0.0, maximum=2.0, value=1.0, step=0.05, interactive=True, label="Length Penalty",) + top_p = gr.Slider(minimum=0.0, maximum=1.0, value=0.95, step=0.05, interactive=True, label="Top P",) + max_output_tokens = gr.Slider(minimum=0, maximum=8192, value=8192, step=64, interactive=True, label="Max output tokens",) + + with gr.Column(scale=9): + with gr.Row(): + svg_code.render() + with gr.Row(): + svg_render.render() + + with gr.Row(elem_id="buttons") as button_row: + upvote_btn = gr.Button(value="👍 Upvote", interactive=False) + downvote_btn = gr.Button(value="👎 Downvote", interactive=False) + flag_btn = gr.Button(value="⚠️ Flag", interactive=False) + stop_btn = gr.Button(value="⏹️ Stop Generation", interactive=False, visible=False) + regenerate_btn = gr.Button(value="🔄 Regenerate", interactive=False, visible=False) + clear_btn = gr.Button(value="🗑️ Clear", interactive=False) + download_btn = gr.Button(value="Download SVG", interactive=False, visible=False) + + if not embed_mode: + gr.Markdown(tos_markdown) + gr.Markdown(learn_more_markdown) + url_params = gr.JSON(visible=False) + + # Register listeners + btn_list = [upvote_btn, downvote_btn, flag_btn, regenerate_btn, clear_btn, stop_btn, download_btn] + upvote_btn.click( + upvote_last_response, + [state, model_selector], + [upvote_btn, downvote_btn, flag_btn], + queue=False + ) + downvote_btn.click( + downvote_last_response, + [state, model_selector], + [upvote_btn, downvote_btn, flag_btn], + queue=False + ) + flag_btn.click( + flag_last_response, + [state, model_selector], + [upvote_btn, downvote_btn, flag_btn], + queue=False + ) + + regenerate_btn.click( + regenerate, + [state, image_process_mode], + [state, svg_code, svg_render, imagebox] + btn_list, + queue=False + ).then( + http_bot, + [state, task_selector, text_caption, model_selector, num_beams, temperature, len_penalty, top_p, max_output_tokens], + [state, svg_code, svg_render] + btn_list + ) + + submit_btn.click( + send_data, + [state, imagebox, image_process_mode, text_caption, task_selector], + [state, svg_code, svg_render, imagebox] + btn_list, + queue=False + ).then( + http_bot, + [state, task_selector, text_caption, model_selector, num_beams, temperature, len_penalty, top_p, max_output_tokens], + [state, svg_code, svg_render] + btn_list + ) + + clear_btn.click( + clear_history, + None, + [state, svg_code, svg_render] + btn_list, + queue=False + ) + + stop_btn.click( + stop_sampling, + [state, imagebox], + [state, imagebox] + btn_list, + queue=False + ).then( + clear_history, + None, + [state, svg_code, svg_render] + btn_list, + queue=False + ) + + download_btn.click( + download_files, + [state], + None, 
+ queue=False + ) + task_selector.change( + update_task, + inputs=[task_selector], + outputs=[len_penalty, temperature, top_p, model_selector], + queue=False, + _js=""" + function(task) { + var imageBoxElement = document.getElementById("imagebox"); + var textCaptionElement = document.getElementById("text_caption"); + var examplesElement = document.getElementById("examples"); + if (task === "Text2SVG") { + imageBoxElement.style.display = "none"; + textCaptionElement.style.display = "block"; + examplesElement.style.display = "none"; + } else if (task === "Image2SVG") { + imageBoxElement.style.display = "block"; + textCaptionElement.style.display = "none"; + examplesElement.style.display = "block"; + } + return task; + } + """ + ) + + if args.model_list_mode == "once": + demo.load( + load_demo, + [url_params, task_selector], + [state, model_selector], + _js=""" + function() { + const params = new URLSearchParams(window.location.search); + url_params = Object.fromEntries(params); + console.log(url_params); + return url_params; + + } + + """, + queue=False + ) + elif args.model_list_mode == "reload": + demo.load( + load_demo_refresh_model_list, + [task_selector], + [state, model_selector], + _js=""" + function(task) { + var textCaptionElement = document.getElementById("text_caption"); + textCaptionElement.style.display = "none"; + return task; + } + """, + queue=False, + ) + + else: + raise ValueError(f"Unknown model list mode: {args.model_list_mode}") + + # Trigger our JavaScript whenever the page loads + demo.load(lambda: 0, outputs=dummy, _js=f"() => {{ {code_scroll_js}; return 0; }}") + + return demo + +if __name__ == "__main__": + + parser = argparse.ArgumentParser() + parser.add_argument("--host", type=str, default="0.0.0.0") + parser.add_argument("--port", type=int) + parser.add_argument("--controller-url", type=str, default="http://localhost:21001") + parser.add_argument("--concurrency-count", type=int, default=10) + parser.add_argument("--model-list-mode", type=str, default="once", + choices=["once", "reload"]) + parser.add_argument("--share", action="store_true") + parser.add_argument("--moderate", action="store_true") + parser.add_argument("--embed", action="store_true") + args = parser.parse_args() + logger.info(f"args: {args}") + + models = get_model_list() + + logger.info(args) + demo = build_demo(args.embed) + demo.queue( + concurrency_count=args.concurrency_count, + api_open=False + ).launch( + server_name=args.host, + server_port=args.port, + share=args.share + ) \ No newline at end of file diff --git a/starvector/serve/vllm_api_gradio/model_worker.py b/starvector/serve/vllm_api_gradio/model_worker.py new file mode 100644 index 0000000000000000000000000000000000000000..22532310927b4a83158ac433f45bc23120175b2d --- /dev/null +++ b/starvector/serve/vllm_api_gradio/model_worker.py @@ -0,0 +1,301 @@ +""" +A model worker executes the model. 
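+The worker registers itself with the controller, sends periodic heartbeats, and
+serves /worker_generate_stream by forwarding each request to a vLLM
+OpenAI-compatible endpoint (/v1/chat/completions), streaming the generated SVG
+code back to the Gradio server as null-delimited JSON chunks.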
+""" +import argparse +import asyncio +import json +import time +import threading +import uuid +from fastapi import FastAPI, Request, BackgroundTasks +from fastapi.responses import StreamingResponse +import requests +import torch +import uvicorn +from functools import partial +from starvector.serve.constants import WORKER_HEART_BEAT_INTERVAL, CLIP_QUERY_LENGTH +from starvector.serve.util import (build_logger, server_error_msg, + pretty_print_semaphore) +from starvector.serve.util import process_images, load_image_from_base64 +from threading import Thread +from transformers import TextIteratorStreamer +from openai import OpenAI + +GB = 1 << 30 + +worker_id = str(uuid.uuid4())[:6] +logger = build_logger("model_worker", f"model_worker_{worker_id}.log") +global_counter = 0 +model_semaphore = None + +def heart_beat_worker(controller): + while True: + time.sleep(WORKER_HEART_BEAT_INTERVAL) + controller.send_heart_beat() + +class ModelWorker: + def __init__(self, controller_addr, worker_addr, vllm_base_url, + worker_id, no_register, model_name, openai_api_key): + + self.controller_addr = controller_addr + self.worker_addr = worker_addr + self.worker_id = worker_id + self.vllm_base_url = vllm_base_url + self.model_name = model_name + self.openai_api_key = openai_api_key + + self.client = OpenAI( + api_key=openai_api_key, + base_url=vllm_base_url, + ) + + if "text2svg" in self.model_name.lower(): + self.task = "Text2SVG" + elif "im2svg" in self.model_name.lower(): + self.task = "Image2SVG" + + logger.info(f"Loading the model {self.model_name} on worker {worker_id} ...") + + self.is_multimodal = 'starvector' in self.model_name.lower() + + if not no_register: + self.register_to_controller() + self.heart_beat_thread = threading.Thread( + target=heart_beat_worker, args=(self,)) + self.heart_beat_thread.start() + + def register_to_controller(self): + logger.info("Register to controller") + + url = self.controller_addr + "/register_worker" + data = { + "worker_name": self.worker_addr, + "check_heart_beat": True, + "worker_status": self.get_status() + } + r = requests.post(url, json=data) + assert r.status_code == 200 + + def send_heart_beat(self): + logger.info(f"Send heart beat. Models: {[self.model_name]}. " + f"Semaphore: {pretty_print_semaphore(model_semaphore)}. 
" + f"global_counter: {global_counter}") + + url = self.controller_addr + "/receive_heart_beat" + + while True: + try: + ret = requests.post(url, json={ + "worker_name": self.worker_addr, + "queue_length": self.get_queue_length()}, timeout=30) + exist = ret.json()["exist"] + break + except requests.exceptions.RequestException as e: + logger.error(f"heart beat error: {e}") + time.sleep(5) + + if not exist: + self.register_to_controller() + + def get_queue_length(self): + if model_semaphore is None: + return 0 + else: + return args.limit_model_concurrency - model_semaphore._value + (len( + model_semaphore._waiters) if model_semaphore._waiters is not None else 0) + + def get_status(self): + return { + "model_names": [self.model_name], + "speed": 1, + "queue_length": self.get_queue_length(), + } + + def generate_stream(self, params): + + num_beams = int(params.get("num_beams", 1)) + temperature = float(params.get("temperature", 1.0)) + len_penalty = float(params.get("len_penalty", 1.0)) + top_p = float(params.get("top_p", 1.0)) + max_context_length = 1000 + + # prompt = params["prompt"] + prompt = " 0 else None + + if not image_base_64: + yield json.dumps({"text": "Error: No image provided for Image2SVG task", "error_code": 1}).encode() + b"\0" + return + + max_new_tokens = min(int(params.get("max_new_tokens", 256)), 8192) + max_new_tokens = min(max_new_tokens, max_context_length - CLIP_QUERY_LENGTH) + + # Use the chat completions endpoint + vllm_endpoint = f"{self.vllm_base_url}/v1/chat/completions" + + # Use a model name that vLLM recognizes + # The full path including the organization is important + model_name_for_vllm = params['model'] + + # Format payload for the chat completions endpoint + request_payload = { + "model": model_name_for_vllm, + "messages": [ + { + "role": "user", + "content": [ + {"type": "text", "text": ""}, + {"type": "image_url", "image_url": {"url": f"data:image/jpeg;base64,{image_base_64}"}} + ] + } + ], + "max_tokens": 7500, + "temperature": temperature, + "top_p": top_p, + "stream": True + } + + # Log the request for debugging + logger.info(f"Request to vLLM: {vllm_endpoint}") + logger.info(f"Using model: {model_name_for_vllm}") + + # Use requests instead of OpenAI client + response = requests.post( + vllm_endpoint, + json=request_payload, + stream=True, + headers={"Content-Type": "application/json"} + ) + + # Log the response status for debugging + logger.info(f"Response status: {response.status_code}") + + if response.status_code != 200: + try: + error_detail = response.json() + logger.error(f"Error from vLLM server: {error_detail}") + except json.JSONDecodeError: + logger.error(f"Error from vLLM server: {response.text}") + + yield json.dumps({"text": f"Error communicating with model server: {response.status_code}", "error_code": 1}).encode() + b"\0" + return + + # Process the streaming response + output_text = "" + for line in response.iter_lines(): + if line: + # Skip the "data: " prefix if present + if line.startswith(b"data: "): + line = line[6:] + + if line.strip() == b"[DONE]": + break + + try: + data = json.loads(line) + if "choices" in data and len(data["choices"]) > 0: + delta = data["choices"][0].get("delta", {}) + content = delta.get("content", "") + if content: + output_text += content + yield json.dumps({"text": output_text, "error_code": 0}).encode() + b"\0" + except json.JSONDecodeError: + logger.error(f"Failed to parse line as JSON: {line}") + continue + + # Send final output if not already sent + if output_text: + yield json.dumps({"text": 
output_text, "error_code": 0}).encode() + b"\0" + + elif self.task == "Text2SVG": + # Implementation for Text2SVG task would go here + yield json.dumps({"text": "Text2SVG task not implemented yet", "error_code": 1}).encode() + b"\0" + return + + def generate_stream_gate(self, params): + try: + for x in self.generate_stream(params): + yield x + except ValueError as e: + print("Caught ValueError:", e) + ret = { + "text": server_error_msg, + "error_code": 1, + } + yield json.dumps(ret).encode() + b"\0" + except torch.cuda.CudaError as e: + print("Caught torch.cuda.CudaError:", e) + ret = { + "text": server_error_msg, + "error_code": 1, + } + yield json.dumps(ret).encode() + b"\0" + except Exception as e: + print("Caught Unknown Error", e) + ret = { + "text": server_error_msg, + "error_code": 1, + } + yield json.dumps(ret).encode() + b"\0" + +app = FastAPI() + +def release_model_semaphore(fn=None): + model_semaphore.release() + if fn is not None: + fn() + +@app.post("/worker_generate_stream") +async def generate_stream(request: Request): + global model_semaphore, global_counter + global_counter += 1 + params = await request.json() + + if model_semaphore is None: + model_semaphore = asyncio.Semaphore(args.limit_model_concurrency) + await model_semaphore.acquire() + worker.send_heart_beat() + generator = worker.generate_stream_gate(params) + background_tasks = BackgroundTasks() + background_tasks.add_task(partial(release_model_semaphore, fn=worker.send_heart_beat)) + return StreamingResponse(generator, background=background_tasks) + +@app.post("/worker_get_status") +async def get_status(request: Request): + return worker.get_status() + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument("--host", type=str, default="localhost") + parser.add_argument("--port", type=int, default=21002) + parser.add_argument("--worker-address", type=str, + default="http://localhost:21002") + parser.add_argument("--controller-address", type=str, + default="http://localhost:21001") + parser.add_argument("--model-name", type=str) + parser.add_argument("--multi-modal", action="store_true", help="Multimodal mode is automatically detected with model name, please make sure `starvector` is included in the model path.") + parser.add_argument("--limit-model-concurrency", type=int, default=5) + parser.add_argument("--stream-interval", type=int, default=1) + parser.add_argument("--no-register", action="store_true") + parser.add_argument("--openai-api-key", type=str, default="EMPTY") + parser.add_argument("--vllm-base-url", type=str, default="http://localhost:8000") + + + args = parser.parse_args() + logger.info(f"args: {args}") + + if args.multi_modal: + logger.warning("Multimodal mode is automatically detected with model name, please make sure `starvector` is included in the model path.") + + worker = ModelWorker(args.controller_address, + args.worker_address, + args.vllm_base_url, + worker_id, + args.no_register, + args.model_name, + args.openai_api_key, + ) + uvicorn.run(app, host=args.host, port=args.port, log_level="info") \ No newline at end of file diff --git a/starvector/serve/vllm_api_gradio/scroll.js b/starvector/serve/vllm_api_gradio/scroll.js new file mode 100644 index 0000000000000000000000000000000000000000..790829abd9fc092e4591b705776700c502cb07b6 --- /dev/null +++ b/starvector/serve/vllm_api_gradio/scroll.js @@ -0,0 +1,24 @@ +var autoScrollBottom = true; + +function updateScroll(){ + if (autoScrollBottom) { + var element = document.getElementsByClassName("cm-scroller")[0]; + 
element.scrollTop = element.scrollHeight; + } +} +function handleScroll() { + var element = document.getElementsByClassName("cm-scroller")[0]; + //if (element.scrollHeight - element.scrollTop === element.clientHeight) { + if (element.scrollHeight - (element.scrollTop + element.clientHeight) < 0.2*(element.scrollTop)) { + // User has scrolled to the bottom, enable auto-scrolling + autoScrollBottom = true; + console.log("bottom"); + } else { + console.log("not bottom"); + // User has scrolled away from the bottom, disable auto-scrolling + autoScrollBottom = false; + } +} +setInterval(updateScroll, 50); +var element = document.getElementsByClassName("cm-scroller")[0]; +element.addEventListener("scroll", handleScroll); diff --git a/starvector/train/train.py b/starvector/train/train.py new file mode 100644 index 0000000000000000000000000000000000000000..f569762bb17b77198fa5f15066df7641e206e9fc --- /dev/null +++ b/starvector/train/train.py @@ -0,0 +1,274 @@ +import os +from starvector.util import ( + set_env_vars, + flatten_dict, + get_exp_id, + instantiate_from_config, + generate_id_name_eval, + get_last_checkpoint, + model_summary_table, + copy_code, + ) +# set_env_vars() +from starvector.train.util import ( + save_checkpoint, + get_optimizer, + init_distributed_mode, + setup_train_env_variables, + load_fsdp_plugin, + apply_gradient_checkpointing, +) +import logging +import math +from torch.utils.data import DataLoader +from transformers import get_scheduler +from accelerate import Accelerator +from accelerate.logging import get_logger +from accelerate.utils import ProjectConfiguration +from tqdm.auto import tqdm +from omegaconf import OmegaConf +import os +import time +from starvector.metrics.util import AverageMeter +from util import save_checkpoint, get_optimizer +from starvector.util import get_output_dir +from starvector.model.builder import model_builder +from safetensors.torch import load_file as load_safetensors +from starvector.util import get_config +import torch + +from starvector.train.util import load_checkpoint, is_deepspeed, consolidate_deepspeed_checkpoint +logger = get_logger(__name__, log_level="INFO") + +def validate(model, dataloader, accelerator): + loss_meter = AverageMeter() + model.eval() + pbar = tqdm(total=len(dataloader), ncols=100, desc="Processing", disable=not accelerator.is_local_main_process) + with torch.no_grad(): + for i, batch in enumerate(dataloader): + batch_size = len(batch["image"]) + loss = model(batch) + loss_meter.update(loss.detach().item(), batch_size) + pbar.update(1) + + val_loss = ( + accelerator.gather(torch.tensor(loss_meter.avg).to(accelerator.device)) + .float() + .mean() + .item() + ) + + accelerator.wait_for_everyone() + pbar.close() + + return val_loss + +def main(config=None): + print(f"Experiment config: {config}") + set_env_vars() + + exp_id = get_exp_id(config) + output_dir = get_output_dir() + logging_dir = os.path.join(output_dir, config.data.train.params.dataset_name, exp_id) + + if os.path.exists(logging_dir) and not config.training.resume_from_checkpoint: + config.training.resume_from_checkpoint = get_last_checkpoint(logging_dir) + config.training.continue_training = True + + # Flatten config dict for logging it + log_config = flatten_dict(OmegaConf.to_container(config, resolve=True)) + log_config['logging_dir'] = logging_dir # Add logging dir to config + + if config.fsdp.enable: + init_distributed_mode(config) + setup_train_env_variables(config) + + # --------------- Datasets --------------- + train_dataset = 
instantiate_from_config(config.data.train) + test_dataset = instantiate_from_config(config.data.test) + train_dataloader = DataLoader(train_dataset, batch_size=config.data.train.batch_size, shuffle=True, num_workers=config.data.num_workers, pin_memory=True) + test_dataloader = DataLoader(test_dataset, batch_size=config.data.test.batch_size, shuffle=False, num_workers=config.data.num_workers, pin_memory=True) + num_update_steps_per_epoch = math.ceil(len(train_dataloader) / config.training.gradient_accumulation_steps) + max_train_steps = config.training.n_epochs * num_update_steps_per_epoch + + global_step = 0 + first_epoch = 0 + + model = model_builder(config) + + # Instantiate the model, fsdp and accelerator + if config.training.resume_from_checkpoint: + if not config.fsdp.enable: + if is_deepspeed(config.training.resume_from_checkpoint): + if accelerator.is_main_process: + consolidate_deepspeed_checkpoint(config.training.resume_from_checkpoint) + accelerator.wait_for_everyone() + model = load_checkpoint(model, config.training.resume_from_checkpoint) + else: + model.load_state_dict(torch.load(os.path.join(config.training.resume_from_checkpoint, "pytorch_model_fsdp.bin")), strict=False) + if config.training.continue_training: + global_step = int(os.path.basename(config.training.resume_from_checkpoint).split("-")[1]) + resume_global_step = global_step * config.training.gradient_accumulation_steps + first_epoch = global_step // num_update_steps_per_epoch + resume_step = resume_global_step % (num_update_steps_per_epoch * config.training.gradient_accumulation_steps) + else: + global_step = 0 + first_epoch = 0 + resume_step = 0 + print("Loaded checkpoint but not updating global step") + + if config.fsdp.enable: + fsdp_plugin = load_fsdp_plugin(config, model) + else: + fsdp_plugin = None + + # Define accelerator + kwargs_handler = None + accelerator = Accelerator( + gradient_accumulation_steps=config.training.gradient_accumulation_steps, + mixed_precision=config.training.model_precision, + log_with="wandb" if config.project.use_wandb else None, + project_dir=logging_dir, + project_config=ProjectConfiguration(logging_dir=logging_dir), + step_scheduler_with_optimizer=False, + fsdp_plugin=fsdp_plugin, + kwargs_handlers=kwargs_handler + ) + + # --------------- Logging --------------- + if accelerator.is_main_process: + if config.project.use_wandb: + import wandb + wandb.init(name=exp_id, project=config.project.project, entity=config.project.entity, config=log_config) + accelerator.init_trackers( + project_name=config.project.project, + ) + config.project.wandb_run_id = wandb.run.id + else: + run = os.path.split(__file__)[-1].split(".")[0] + accelerator.init_trackers(run) + + if logging_dir is not None: + os.makedirs(logging_dir, exist_ok=True) + + # Copy code and dependency versions + if config.project.copy_code: + out_dir = os.path.join(logging_dir, "code") + copy_code(os.path.join(os.path.dirname(__file__), "..", ".."), out_dir) + + logging.basicConfig( + format="%(asctime)s - %(levelname)s - %(name)s - %(message)s", + datefmt="%m/%d/%Y %H:%M:%S", + level=logging.INFO, + ) + logger.info(accelerator.state, main_process_only=True) + + total_batch_size = config.data.train.batch_size * accelerator.num_processes * config.training.gradient_accumulation_steps + + if accelerator.is_main_process and config.project.use_wandb: + wandb.log({"total_batch_size": total_batch_size}) + wandb.log({"num_update_steps_per_epoch": num_update_steps_per_epoch}) + wandb.log({"max_train_steps": max_train_steps}) + + # 
accelerate prepare model + model = accelerator.prepare(model) + + # activation/gradient checkpointing + if config.training.use_gradient_checkpointing: + print("apply gradient checkpointing") + model = apply_gradient_checkpointing(model) + + optimizer = get_optimizer(config, model) + + if accelerator.is_main_process: + print("Train dataset length: ", len(train_dataset)) + print("Test dataset length: ", len(test_dataset)) + + # --------------- Training config --------------- + lr_scheduler = get_scheduler( + config.training.lr_scheduler, + optimizer=optimizer, + num_warmup_steps=config.training.lr_warmup_steps * config.training.gradient_accumulation_steps, + num_training_steps= (len(train_dataloader) * config.training.n_epochs), + ) + + optimizer, train_dataloader, test_dataloader, lr_scheduler = accelerator.prepare( + optimizer, train_dataloader, test_dataloader, lr_scheduler + ) + + loss_meter = AverageMeter() + + if accelerator.is_main_process: + model_summary_table(model) + + if not os.path.exists(os.path.join(logging_dir, 'config.yaml')): + with open(os.path.join(logging_dir, 'config.yaml'), 'w') as f: + OmegaConf.save(config, f) + + logger.info("***** Running training *****") + logger.info(f" Num examples = {len(train_dataset)}") + logger.info(f" Num Epochs = {config.training.n_epochs}") + logger.info(f" Instantaneous batch size per device = {config.data.train.batch_size}") + logger.info(f" Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}") + logger.info(f" Gradient Accumulation steps = {config.training.gradient_accumulation_steps}") + logger.info(f" Total optimization steps = {max_train_steps}") + + # --------------- Generation/Validation arguments --------------- + generation_args = config.generation + + # Need to set some experiment specific arguments + generation_args.project_name = config.project.project + generation_args.use_wandb = config.project.use_wandb + generation_args.id = generate_id_name_eval(generation_args) + generation_args.out_path = os.path.join(logging_dir, generation_args.id) + generation_args.start_generation_at_step = config.generation.start_generation_at_step + generation_args.metrics = config.metrics + + os.makedirs(generation_args.out_path, exist_ok=True) + + # --------------- Training loop --------------- + total_steps = num_update_steps_per_epoch * config.training.n_epochs + progress_bar = tqdm(total=total_steps, disable=not accelerator.is_local_main_process) + progress_bar.set_description(f"Training Progress") + + for epoch in range(config.training.n_epochs): + model.train() + for step, batch in enumerate(train_dataloader): + s_time = time.time() + + if config.training.resume_from_checkpoint and epoch == first_epoch and step < resume_step: + if step % config.training.gradient_accumulation_steps == 0: + progress_bar.update(1) + continue + + with accelerator.accumulate(model): + loss = model(batch) + accelerator.backward(loss) + loss_meter.update(loss.detach().item(), batch['image'].shape[0]) + if accelerator.sync_gradients: + accelerator.clip_grad_norm_(model.parameters(), 1.0) + optimizer.step() + lr_scheduler.step() + optimizer.zero_grad() + if accelerator.sync_gradients: + progress_bar.update(1) + global_step += 1 + if global_step % config.training.checkpointing_steps == 0: + accelerator.wait_for_everyone() + val_loss = validate(model, test_dataloader, accelerator) + accelerator.log({"val_loss": val_loss}, step=global_step) + save_checkpoint(accelerator, model, global_step, logging_dir, 
config.training.checkpoints_total_limit) + model.train() + logs = { + "loss": loss_meter.val, + "last_lr": lr_scheduler.get_last_lr()[0], + "step": global_step, + "step_time": time.time() - s_time, + "epoch": epoch} + progress_bar.set_postfix(**logs) + accelerator.log(logs, step=global_step) + + accelerator.end_training() + +if __name__ == "__main__": + main(config=get_config()) \ No newline at end of file diff --git a/starvector/train/util.py b/starvector/train/util.py new file mode 100644 index 0000000000000000000000000000000000000000..c24c5fb24ac3cfedf43d4157a2273df54ad80c28 --- /dev/null +++ b/starvector/train/util.py @@ -0,0 +1,285 @@ +import os +import torch +import transformers +import os +from starvector.util import checkpoint_key +import glob +import shutil +import builtins +import datetime +from torch.distributed.fsdp.fully_sharded_data_parallel import FullOptimStateDictConfig, FullStateDictConfig +from torch.distributed.fsdp.wrap import transformer_auto_wrap_policy +from torch.distributed.fsdp import ( + MixedPrecision, + ShardingStrategy, +) +import functools +from accelerate import FullyShardedDataParallelPlugin +from accelerate.utils import PrecisionType + +from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import ( + checkpoint_wrapper, + CheckpointImpl, + apply_activation_checkpointing, +) + + +from transformers import ( + AutoConfig, + AutoModelForCausalLM +) +from starvector.model.starvector_arch import StarVectorConfig, StarVectorForCausalLM +from starvector.train.zero_to_fp32 import convert_zero_checkpoint_to_fp32_state_dict + + +def is_deepspeed(checkpoint_dir): + # Check zero_to_fp32.py file (generated only in deepspeed training) + return os.path.exists(os.path.join(checkpoint_dir, 'zero_to_fp32.py')) + +def consolidate_deepspeed_checkpoint(checkpoint_dir): + path_state_dict = os.path.join(checkpoint_dir, 'weights.pt') + if not os.path.exists(path_state_dict): + convert_zero_checkpoint_to_fp32_state_dict(checkpoint_dir, path_state_dict) + +def load_checkpoint(model, checkpoint_dir): + candidate_files = ['weights.pt', 'pytorch_model.bin', 'model.safetensors'] + + # Determine the correct file to load + for candidate in candidate_files: + path_state_dict = os.path.join(checkpoint_dir, candidate) + if os.path.exists(path_state_dict): + break + else: + raise FileNotFoundError(f"No checkpoint file found in {checkpoint_dir}") + + # Load the state dict based on file type + if path_state_dict.endswith('.safetensors'): + import safetensors.torch + state_dict = safetensors.torch.load_file(path_state_dict) + else: + state_dict = torch.load(path_state_dict) + + # Handle FSDP or module prefix + if list(model.state_dict().keys())[0].startswith('module'): + new_state_dict = {'module.' 
+ key: val for key, val in state_dict.items()} + else: + new_state_dict = state_dict + + # Handle Tied Weights + if hasattr(model, 'tie_weights'): + # Remove the lm_head.weight key if it exists and tie_weights will handle it + new_state_dict.pop("model.svg_transformer.transformer.lm_head.weight", None) + + # Load the state dict into the model with strict=False to ignore missing keys + model.load_state_dict(new_state_dict, strict=False) # Allow missing keys + + # Ensure weights are tied after loading + model.tie_weights() # This method should tie the weights internally + + return model + +from transformers import ( + AutoConfig, + AutoModelForCausalLM +) +from starvector.model.starvector_arch import StarVectorConfig, StarVectorForCausalLM + +def push_model_to_hub(model, new_model_name, tokenizer, processor): + # Register the model for HF + AutoConfig.register("starvector", StarVectorConfig) + AutoModelForCausalLM.register(StarVectorConfig, StarVectorForCausalLM) + StarVectorConfig.register_for_auto_class() + StarVectorForCausalLM.register_for_auto_class("AutoModelForCausalLM") + model.push_to_hub(new_model_name, commit_message=new_model_name, private=True) + tokenizer.push_to_hub(new_model_name, commit_message=new_model_name, private=True) + processor.push_to_hub(new_model_name, commit_message=new_model_name, private=True) + +# push_model_to_hub(self.model, new_model_name, self.tokenizer, self.processor) +def save_checkpoint(accelerator, model, global_step, logging_dir, checkpoint_limit): + print("Saving checkpoint! Global Step: " + str(global_step)) + save_checkpoint_dir = os.path.join(logging_dir, f"checkpoint-{global_step}") + os.makedirs(save_checkpoint_dir, exist_ok=True) + accelerator.wait_for_everyone() + accelerator.save_state(save_checkpoint_dir) + + chkp_dirs = sorted(glob.glob(os.path.join(logging_dir, "checkpoint-*")), key = checkpoint_key) + chkp_to_remove = chkp_dirs[:-checkpoint_limit] + for chkp in chkp_to_remove: + if accelerator.is_main_process: + try: + shutil.rmtree(chkp) + except: + print("could not remove checkpoint") + print(f"Saved state to {save_checkpoint_dir}") + +def push_model_to_hub(model, new_model_name, hf_token=None): + tokenizer = model.model.svg_transformer.tokenizer + # Register the model for HF + AutoConfig.register("starvector", StarVectorConfig) + AutoModelForCausalLM.register(StarVectorConfig, StarVectorForCausalLM) + StarVectorConfig.register_for_auto_class() + StarVectorForCausalLM.register_for_auto_class("AutoModelForCausalLM") + + + model.push_to_hub(new_model_name, commit_message=new_model_name, private=True, token=hf_token) + tokenizer.push_to_hub(new_model_name, commit_message=new_model_name, private=True, token=hf_token) + + processor = model.model.image_encoder.processor + from starvector.data.base import ImageTrainProcessor + if not isinstance(processor, ImageTrainProcessor): + processor.push_to_hub(new_model_name, commit_message=new_model_name, private=True, token=hf_token) + +def get_optimizer(config, model): + optimizer = config.training.optimizer + if optimizer == "adamw": + optimizer = torch.optim.AdamW( + model.parameters(), + lr=config.training.lr, + betas=(config.training.adam_beta1, config.training.adam_beta2), + weight_decay=config.training.adam_weight_decay, + eps=config.training.adam_epsilon, + ) + elif optimizer == "adafactor": + optimizer = transformers.Adafactor( + model.parameters(), + lr=config.training.lr, + relative_step=False, + scale_parameter=False, + ) + else: + raise ValueError(f"Optimizer {optimizer} not supported") 
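+    # Illustrative sketch of the config keys this function reads (key names are
+    # taken from the accesses above; the values shown are placeholders, not
+    # defaults shipped with the repo):
+    #   training:
+    #     optimizer: adamw        # or "adafactor"
+    #     lr: 1.0e-4
+    #     adam_beta1: 0.9
+    #     adam_beta2: 0.999
+    #     adam_weight_decay: 0.01
+    #     adam_epsilon: 1.0e-8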
+ return optimizer + + +def init_distributed_mode(config): + """ + Initializes torch distributed + """ + assert all(var in os.environ for var in ['WORLD_SIZE', 'LOCAL_RANK', 'RANK']) + + world_size = int(os.environ['WORLD_SIZE']) + rank = int(os.environ["RANK"]) + local_rank = int(os.environ['LOCAL_RANK']) + dist_url = 'env://' + + torch.cuda.set_device(local_rank) + dist_backend = 'nccl' + print('| distributed init (rank {}): {}, gpu {}'.format( + rank, dist_url, local_rank), flush=True) + torch.distributed.init_process_group(backend=dist_backend, init_method=dist_url, + world_size=world_size, rank=rank) + torch.distributed.barrier() + print_only_on_master(rank == 0) + +def print_only_on_master(is_master): + """ + This function disables printing when not in master process + """ + builtin_print = builtins.print + + def print(*args, **kwargs): + force = kwargs.pop('force', False) + kwargs['flush'] = True + if is_master or force: + now = datetime.datetime.now().time() + builtin_print('[{}] '.format(now), end='') # print with time stamp + builtin_print(*args, **kwargs) + + builtins.print = print + +def setup_train_env_variables(config): + """ + Set environment variables needed by FSDP and accelerate + """ + mixed_precision = config.training.model_precision.lower() + + try: + mixed_precision = PrecisionType(mixed_precision) + except ValueError: + raise ValueError(f"Unknown mixed_precision mode: {mixed_precision}. Choose between {PrecisionType.list()}.") + + os.environ["ACCELERATE_MIXED_PRECISION"] = str(mixed_precision) + + if config.fsdp.enable: + # We have to manually set some of the FSDP arguments as environment variables as these are not exposed by the FSDP Plugin API + os.environ['ACCELERATE_USE_FSDP']="true" + os.environ['FSDP_USE_ORIG_PARAMS']=str(config.fsdp.use_orig_params).lower() + os.environ['FSDP_FORWARD_PREFETCH']=str(config.fsdp.forward_prefetch).lower() + + if config.fsdp.cpu_ram_efficient_loading and not config.fsdp.sync_module_states: + raise ValueError("When using `fsdp.cpu_ram_efficient_loading` set `fsdp.sync_module_states` to `True`") + + os.environ['FSDP_CPU_RAM_EFFICIENT_LOADING']=str(config.fsdp.cpu_ram_efficient_loading).lower() + os.environ['FSDP_SYNC_MODULE_STATES']=str(config.fsdp.sync_module_states).lower() + +def load_fsdp_plugin(config, model): + if config.fsdp.enable: + # get mixed precsion dtype + mixed_precision_dtype = { + "fp16": torch.float16, + "bf16": torch.bfloat16, + "tf32": torch.float32, + }[config.training.model_precision] + + fsdp_plugin = FullyShardedDataParallelPlugin( + state_dict_config=FullStateDictConfig(offload_to_cpu=True, rank0_only=True), + optim_state_dict_config=FullOptimStateDictConfig(offload_to_cpu=True, rank0_only=True), + auto_wrap_policy=model.model.get_fsdp_wrapping_policy(), + mixed_precision_policy=MixedPrecision( + param_dtype=mixed_precision_dtype, + reduce_dtype=mixed_precision_dtype, + buffer_dtype=mixed_precision_dtype, + ), + sharding_strategy={ + "sdp": ShardingStrategy.SHARD_GRAD_OP, + "ddp": ShardingStrategy.NO_SHARD, + "fsdp": ShardingStrategy.FULL_SHARD, + "hsdp": ShardingStrategy.HYBRID_SHARD, + }[config.fsdp.sharding_strategy], + backward_prefetch=config.fsdp.backward_prefetch, + cpu_offload=config.fsdp.cpu_offload, + ) + else: + fsdp_plugin = None + + return fsdp_plugin + + +def apply_gradient_checkpointing(model): + """ Apply gradient checkpointing to Transformer cls of the LLM """ + + def check_fn(submodule): + return isinstance(submodule, model.model.svg_transformer.transformer_layer_cls) + + 
apply_activation_checkpointing( + model, + checkpoint_wrapper_fn=functools.partial( + checkpoint_wrapper, + checkpoint_impl=CheckpointImpl.NO_REENTRANT, + ), + check_fn=check_fn, + ) + + # Wait for all processes to finish + torch.distributed.barrier() + + return model + +def get_module_class_from_name(module, name): + """ + Gets a class from a module by its name. + + Args: + module (`torch.nn.Module`): The module to get the class from. + name (`str`): The name of the class. + """ + modules_children = list(module.children()) + if module.__class__.__name__ == name: + return module.__class__ + elif len(modules_children) == 0: + return + else: + for child_module in modules_children: + module_class = get_module_class_from_name(child_module, name) + if module_class is not None: + return module_class \ No newline at end of file diff --git a/starvector/train/zero_to_fp32.py b/starvector/train/zero_to_fp32.py new file mode 100755 index 0000000000000000000000000000000000000000..49b846633d6eb1e836e34681e44033581f4edb7b --- /dev/null +++ b/starvector/train/zero_to_fp32.py @@ -0,0 +1,592 @@ +#!/usr/bin/env python + +# Copyright (c) Microsoft Corporation. +# SPDX-License-Identifier: Apache-2.0 + +# DeepSpeed Team + +# This script extracts fp32 consolidated weights from a zero 1, 2 and 3 DeepSpeed checkpoints. It gets +# copied into the top level checkpoint dir, so the user can easily do the conversion at any point in +# the future. Once extracted, the weights don't require DeepSpeed and can be used in any +# application. +# +# example: python zero_to_fp32.py . pytorch_model.bin + +import argparse +import torch +import glob +import math +import os +import re +from collections import OrderedDict +from dataclasses import dataclass + +# while this script doesn't use deepspeed to recover data, since the checkpoints are pickled with +# DeepSpeed data structures it has to be available in the current python environment. 
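+# In this repository the same conversion is also triggered programmatically:
+# consolidate_deepspeed_checkpoint() in starvector/train/util.py calls
+# convert_zero_checkpoint_to_fp32_state_dict(checkpoint_dir, <checkpoint_dir>/weights.pt)
+# so that load_checkpoint() can then read the consolidated fp32 weights.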
+from deepspeed.utils import logger +from deepspeed.checkpoint.constants import (DS_VERSION, OPTIMIZER_STATE_DICT, SINGLE_PARTITION_OF_FP32_GROUPS, + FP32_FLAT_GROUPS, ZERO_STAGE, PARTITION_COUNT, PARAM_SHAPES, BUFFER_NAMES, + FROZEN_PARAM_SHAPES, FROZEN_PARAM_FRAGMENTS) + + +@dataclass +class zero_model_state: + buffers: dict() + param_shapes: dict() + shared_params: list + ds_version: int + frozen_param_shapes: dict() + frozen_param_fragments: dict() + + +debug = 0 + +# load to cpu +device = torch.device('cpu') + + +def atoi(text): + return int(text) if text.isdigit() else text + + +def natural_keys(text): + ''' + alist.sort(key=natural_keys) sorts in human order + http://nedbatchelder.com/blog/200712/human_sorting.html + (See Toothy's implementation in the comments) + ''' + return [atoi(c) for c in re.split(r'(\d+)', text)] + + +def get_model_state_file(checkpoint_dir, zero_stage): + if not os.path.isdir(checkpoint_dir): + raise FileNotFoundError(f"Directory '{checkpoint_dir}' doesn't exist") + + # there should be only one file + if zero_stage <= 2: + file = os.path.join(checkpoint_dir, "mp_rank_00_model_states.pt") + elif zero_stage == 3: + file = os.path.join(checkpoint_dir, "zero_pp_rank_0_mp_rank_00_model_states.pt") + + if not os.path.exists(file): + raise FileNotFoundError(f"can't find model states file at '{file}'") + + return file + + +def get_checkpoint_files(checkpoint_dir, glob_pattern): + # XXX: need to test that this simple glob rule works for multi-node setup too + ckpt_files = sorted(glob.glob(os.path.join(checkpoint_dir, glob_pattern)), key=natural_keys) + + if len(ckpt_files) == 0: + raise FileNotFoundError(f"can't find {glob_pattern} files in directory '{checkpoint_dir}'") + + return ckpt_files + + +def get_optim_files(checkpoint_dir): + return get_checkpoint_files(checkpoint_dir, "*_optim_states.pt") + + +def get_model_state_files(checkpoint_dir): + return get_checkpoint_files(checkpoint_dir, "*_model_states.pt") + + +def parse_model_states(files): + zero_model_states = [] + for file in files: + state_dict = torch.load(file, map_location=device) + + if BUFFER_NAMES not in state_dict: + raise ValueError(f"{file} is not a model state checkpoint") + buffer_names = state_dict[BUFFER_NAMES] + if debug: + print("Found buffers:", buffer_names) + + # recover just the buffers while restoring them to fp32 if they were saved in fp16 + buffers = {k: v.float() for k, v in state_dict["module"].items() if k in buffer_names} + param_shapes = state_dict[PARAM_SHAPES] + + # collect parameters that are included in param_shapes + param_names = [] + for s in param_shapes: + for name in s.keys(): + param_names.append(name) + + # update with frozen parameters + frozen_param_shapes = state_dict.get(FROZEN_PARAM_SHAPES, None) + if frozen_param_shapes is not None: + if debug: + print(f"Found frozen_param_shapes: {frozen_param_shapes}") + param_names += list(frozen_param_shapes.keys()) + + # handle shared params + shared_params = [[k, v] for k, v in state_dict["shared_params"].items()] + + ds_version = state_dict.get(DS_VERSION, None) + + frozen_param_fragments = state_dict.get(FROZEN_PARAM_FRAGMENTS, None) + + z_model_state = zero_model_state(buffers=buffers, + param_shapes=param_shapes, + shared_params=shared_params, + ds_version=ds_version, + frozen_param_shapes=frozen_param_shapes, + frozen_param_fragments=frozen_param_fragments) + zero_model_states.append(z_model_state) + + return zero_model_states + + +def parse_optim_states(files, ds_checkpoint_dir): + + total_files = len(files) + 
state_dicts = [] + for f in files: + state_dict = torch.load(f, map_location=device) + # immediately discard the potentially huge 2 optimizer states as we only care for fp32 master weights + # and also handle the case where it was already removed by another helper script + state_dict["optimizer_state_dict"].pop("optimizer_state_dict", None) + state_dicts.append(state_dict) + + if not ZERO_STAGE in state_dicts[0][OPTIMIZER_STATE_DICT]: + raise ValueError(f"{files[0]} is not a zero checkpoint") + zero_stage = state_dicts[0][OPTIMIZER_STATE_DICT][ZERO_STAGE] + world_size = state_dicts[0][OPTIMIZER_STATE_DICT][PARTITION_COUNT] + + # For ZeRO-2 each param group can have different partition_count as data parallelism for expert + # parameters can be different from data parallelism for non-expert parameters. So we can just + # use the max of the partition_count to get the dp world_size. + + if type(world_size) is list: + world_size = max(world_size) + + if world_size != total_files: + raise ValueError( + f"Expected {world_size} of '*_optim_states.pt' under '{ds_checkpoint_dir}' but found {total_files} files. " + "Possibly due to an overwrite of an old checkpoint, or a checkpoint didn't get saved by one or more processes." + ) + + # the groups are named differently in each stage + if zero_stage <= 2: + fp32_groups_key = SINGLE_PARTITION_OF_FP32_GROUPS + elif zero_stage == 3: + fp32_groups_key = FP32_FLAT_GROUPS + else: + raise ValueError(f"unknown zero stage {zero_stage}") + + if zero_stage <= 2: + fp32_flat_groups = [state_dicts[i][OPTIMIZER_STATE_DICT][fp32_groups_key] for i in range(len(state_dicts))] + elif zero_stage == 3: + # if there is more than one param group, there will be multiple flattened tensors - one + # flattened tensor per group - for simplicity merge them into a single tensor + # + # XXX: could make the script more memory efficient for when there are multiple groups - it + # will require matching the sub-lists of param_shapes for each param group flattened tensor + + fp32_flat_groups = [ + torch.cat(state_dicts[i][OPTIMIZER_STATE_DICT][fp32_groups_key], 0) for i in range(len(state_dicts)) + ] + + return zero_stage, world_size, fp32_flat_groups + + +def _get_fp32_state_dict_from_zero_checkpoint(ds_checkpoint_dir): + """ + Returns fp32 state_dict reconstructed from ds checkpoint + + Args: + - ``ds_checkpoint_dir``: path to the deepspeed checkpoint folder (where the optimizer files are) + + """ + print(f"Processing zero checkpoint '{ds_checkpoint_dir}'") + + optim_files = get_optim_files(ds_checkpoint_dir) + zero_stage, world_size, fp32_flat_groups = parse_optim_states(optim_files, ds_checkpoint_dir) + print(f"Detected checkpoint of type zero stage {zero_stage}, world_size: {world_size}") + + model_files = get_model_state_files(ds_checkpoint_dir) + + zero_model_states = parse_model_states(model_files) + print(f'Parsing checkpoint created by deepspeed=={zero_model_states[0].ds_version}') + + if zero_stage <= 2: + return _get_fp32_state_dict_from_zero2_checkpoint(world_size, fp32_flat_groups, zero_model_states) + elif zero_stage == 3: + return _get_fp32_state_dict_from_zero3_checkpoint(world_size, fp32_flat_groups, zero_model_states) + + +def _zero2_merge_frozen_params(state_dict, zero_model_states): + if zero_model_states[0].frozen_param_shapes is None or len(zero_model_states[0].frozen_param_shapes) == 0: + return + + frozen_param_shapes = zero_model_states[0].frozen_param_shapes + frozen_param_fragments = zero_model_states[0].frozen_param_fragments + + if debug: + num_elem = 
sum(s.numel() for s in frozen_param_shapes.values()) + print(f'rank 0: {FROZEN_PARAM_SHAPES}.numel = {num_elem}') + + wanted_params = len(frozen_param_shapes) + wanted_numel = sum(s.numel() for s in frozen_param_shapes.values()) + avail_numel = sum([p.numel() for p in frozen_param_fragments.values()]) + print(f'Frozen params: Have {avail_numel} numels to process.') + print(f'Frozen params: Need {wanted_numel} numels in {wanted_params} params') + + total_params = 0 + total_numel = 0 + for name, shape in frozen_param_shapes.items(): + total_params += 1 + unpartitioned_numel = shape.numel() + total_numel += unpartitioned_numel + + state_dict[name] = frozen_param_fragments[name] + + if debug: + print(f"{name} full shape: {shape} unpartitioned numel {unpartitioned_numel} ") + + print(f"Reconstructed Frozen fp32 state dict with {total_params} params {total_numel} elements") + + +def _has_callable(obj, fn): + attr = getattr(obj, fn, None) + return callable(attr) + + +def _zero2_merge_trainable_params(state_dict, world_size, fp32_flat_groups, zero_model_states): + param_shapes = zero_model_states[0].param_shapes + + # Reconstruction protocol: + # + # XXX: document this + + if debug: + for i in range(world_size): + for j in range(len(fp32_flat_groups[0])): + print(f"{FP32_FLAT_GROUPS}[{i}][{j}].shape={fp32_flat_groups[i][j].shape}") + + # XXX: memory usage doubles here (zero2) + num_param_groups = len(fp32_flat_groups[0]) + merged_single_partition_of_fp32_groups = [] + for i in range(num_param_groups): + merged_partitions = [sd[i] for sd in fp32_flat_groups] + full_single_fp32_vector = torch.cat(merged_partitions, 0) + merged_single_partition_of_fp32_groups.append(full_single_fp32_vector) + avail_numel = sum( + [full_single_fp32_vector.numel() for full_single_fp32_vector in merged_single_partition_of_fp32_groups]) + + if debug: + wanted_params = sum([len(shapes) for shapes in param_shapes]) + wanted_numel = sum([sum(shape.numel() for shape in shapes.values()) for shapes in param_shapes]) + # not asserting if there is a mismatch due to possible padding + print(f"Have {avail_numel} numels to process.") + print(f"Need {wanted_numel} numels in {wanted_params} params.") + + # params + # XXX: for huge models that can't fit into the host's RAM we will have to recode this to support + # out-of-core computing solution + total_numel = 0 + total_params = 0 + for shapes, full_single_fp32_vector in zip(param_shapes, merged_single_partition_of_fp32_groups): + offset = 0 + avail_numel = full_single_fp32_vector.numel() + for name, shape in shapes.items(): + + unpartitioned_numel = shape.numel() if _has_callable(shape, 'numel') else math.prod(shape) + total_numel += unpartitioned_numel + total_params += 1 + + if debug: + print(f"{name} full shape: {shape} unpartitioned numel {unpartitioned_numel} ") + state_dict[name] = full_single_fp32_vector.narrow(0, offset, unpartitioned_numel).view(shape) + offset += unpartitioned_numel + + # Z2 started to align to 2*world_size to improve nccl performance. Therefore both offset and + # avail_numel can differ by anywhere between 0..2*world_size. 
Due to two unrelated complex + # paddings performed in the code it's almost impossible to predict the exact numbers w/o the + # live optimizer object, so we are checking that the numbers are within the right range + align_to = 2 * world_size + + def zero2_align(x): + return align_to * math.ceil(x / align_to) + + if debug: + print(f"original offset={offset}, avail_numel={avail_numel}") + + offset = zero2_align(offset) + avail_numel = zero2_align(avail_numel) + + if debug: + print(f"aligned offset={offset}, avail_numel={avail_numel}") + + # Sanity check + if offset != avail_numel: + raise ValueError(f"consumed {offset} numels out of {avail_numel} - something is wrong") + + print(f"Reconstructed fp32 state dict with {total_params} params {total_numel} elements") + + +def _get_fp32_state_dict_from_zero2_checkpoint(world_size, fp32_flat_groups, zero_model_states): + state_dict = OrderedDict() + + # buffers + buffers = zero_model_states[0].buffers + state_dict.update(buffers) + if debug: + print(f"added {len(buffers)} buffers") + + _zero2_merge_frozen_params(state_dict, zero_model_states) + + _zero2_merge_trainable_params(state_dict, world_size, fp32_flat_groups, zero_model_states) + + # recover shared parameters + for pair in zero_model_states[0].shared_params: + if pair[1] in state_dict: + state_dict[pair[0]] = state_dict[pair[1]] + + return state_dict + + +def zero3_partitioned_param_info(unpartitioned_numel, world_size): + remainder = unpartitioned_numel % world_size + padding_numel = (world_size - remainder) if remainder else 0 + partitioned_numel = math.ceil(unpartitioned_numel / world_size) + return partitioned_numel, padding_numel + + +def _zero3_merge_frozen_params(state_dict, world_size, zero_model_states): + if zero_model_states[0].frozen_param_shapes is None or len(zero_model_states[0].frozen_param_shapes) == 0: + return + + if debug: + for i in range(world_size): + num_elem = sum(s.numel() for s in zero_model_states[i].frozen_param_fragments.values()) + print(f'rank {i}: {FROZEN_PARAM_SHAPES}.numel = {num_elem}') + + frozen_param_shapes = zero_model_states[0].frozen_param_shapes + wanted_params = len(frozen_param_shapes) + wanted_numel = sum(s.numel() for s in frozen_param_shapes.values()) + avail_numel = sum([p.numel() for p in zero_model_states[0].frozen_param_fragments.values()]) * world_size + print(f'Frozen params: Have {avail_numel} numels to process.') + print(f'Frozen params: Need {wanted_numel} numels in {wanted_params} params') + + total_params = 0 + total_numel = 0 + for name, shape in zero_model_states[0].frozen_param_shapes.items(): + total_params += 1 + unpartitioned_numel = shape.numel() + total_numel += unpartitioned_numel + + param_frags = tuple(model_state.frozen_param_fragments[name] for model_state in zero_model_states) + state_dict[name] = torch.cat(param_frags, 0).narrow(0, 0, unpartitioned_numel).view(shape) + + partitioned_numel, partitioned_padding_numel = zero3_partitioned_param_info(unpartitioned_numel, world_size) + + if debug: + print( + f"Frozen params: {total_params} {name} full shape: {shape} partition0 numel={partitioned_numel} partitioned_padding_numel={partitioned_padding_numel}" + ) + + print(f"Reconstructed Frozen fp32 state dict with {total_params} params {total_numel} elements") + + +def _zero3_merge_trainable_params(state_dict, world_size, fp32_flat_groups, zero_model_states): + param_shapes = zero_model_states[0].param_shapes + avail_numel = fp32_flat_groups[0].numel() * world_size + # Reconstruction protocol: For zero3 we need to zip the 
partitions together at boundary of each + # param, re-consolidating each param, while dealing with padding if any + + # merge list of dicts, preserving order + param_shapes = {k: v for d in param_shapes for k, v in d.items()} + + if debug: + for i in range(world_size): + print(f"{FP32_FLAT_GROUPS}[{i}].shape={fp32_flat_groups[i].shape}") + + wanted_params = len(param_shapes) + wanted_numel = sum(shape.numel() for shape in param_shapes.values()) + # not asserting if there is a mismatch due to possible padding + avail_numel = fp32_flat_groups[0].numel() * world_size + print(f"Trainable params: Have {avail_numel} numels to process.") + print(f"Trainable params: Need {wanted_numel} numels in {wanted_params} params.") + + # params + # XXX: for huge models that can't fit into the host's RAM we will have to recode this to support + # out-of-core computing solution + offset = 0 + total_numel = 0 + total_params = 0 + for name, shape in param_shapes.items(): + + unpartitioned_numel = shape.numel() + total_numel += unpartitioned_numel + total_params += 1 + + partitioned_numel, partitioned_padding_numel = zero3_partitioned_param_info(unpartitioned_numel, world_size) + + if debug: + print( + f"Trainable params: {total_params} {name} full shape: {shape} partition0 numel={partitioned_numel} partitioned_padding_numel={partitioned_padding_numel}" + ) + + # XXX: memory usage doubles here + state_dict[name] = torch.cat( + tuple(fp32_flat_groups[i].narrow(0, offset, partitioned_numel) for i in range(world_size)), + 0).narrow(0, 0, unpartitioned_numel).view(shape) + offset += partitioned_numel + + offset *= world_size + + # Sanity check + if offset != avail_numel: + raise ValueError(f"consumed {offset} numels out of {avail_numel} - something is wrong") + + print(f"Reconstructed Trainable fp32 state dict with {total_params} params {total_numel} elements") + + +def _get_fp32_state_dict_from_zero3_checkpoint(world_size, fp32_flat_groups, zero_model_states): + state_dict = OrderedDict() + + # buffers + buffers = zero_model_states[0].buffers + state_dict.update(buffers) + if debug: + print(f"added {len(buffers)} buffers") + + _zero3_merge_frozen_params(state_dict, world_size, zero_model_states) + + _zero3_merge_trainable_params(state_dict, world_size, fp32_flat_groups, zero_model_states) + + # recover shared parameters + for pair in zero_model_states[0].shared_params: + if pair[1] in state_dict: + state_dict[pair[0]] = state_dict[pair[1]] + + return state_dict + + +def get_fp32_state_dict_from_zero_checkpoint(checkpoint_dir, tag=None): + """ + Convert ZeRO 2 or 3 checkpoint into a single fp32 consolidated state_dict that can be loaded with + ``load_state_dict()`` and used for training without DeepSpeed or shared with others, for example + via a model hub. + + Args: + - ``checkpoint_dir``: path to the desired checkpoint folder + - ``tag``: checkpoint tag used as a unique identifier for checkpoint. If not provided will attempt to load tag in 'latest' file. e.g., ``global_step14`` + + Returns: + - pytorch ``state_dict`` + + Note: this approach may not work if your application doesn't have sufficient free CPU memory and + you may need to use the offline approach using the ``zero_to_fp32.py`` script that is saved with + the checkpoint. 
+ + A typical usage might be :: + + from deepspeed.utils.zero_to_fp32 import get_fp32_state_dict_from_zero_checkpoint + # do the training and checkpoint saving + state_dict = get_fp32_state_dict_from_zero_checkpoint(checkpoint_dir) # already on cpu + model = model.cpu() # move to cpu + model.load_state_dict(state_dict) + # submit to model hub or save the model to share with others + + In this example the ``model`` will no longer be usable in the deepspeed context of the same + application. i.e. you will need to re-initialize the deepspeed engine, since + ``model.load_state_dict(state_dict)`` will remove all the deepspeed magic from it. + + If you want it all done for you, use ``load_state_dict_from_zero_checkpoint`` instead. + + """ + if tag is None: + latest_path = os.path.join(checkpoint_dir, 'latest') + if os.path.isfile(latest_path): + with open(latest_path, 'r') as fd: + tag = fd.read().strip() + else: + raise ValueError(f"Unable to find 'latest' file at {latest_path}") + + ds_checkpoint_dir = os.path.join(checkpoint_dir, tag) + + if not os.path.isdir(ds_checkpoint_dir): + raise FileNotFoundError(f"Directory '{ds_checkpoint_dir}' doesn't exist") + + return _get_fp32_state_dict_from_zero_checkpoint(ds_checkpoint_dir) + + +def convert_zero_checkpoint_to_fp32_state_dict(checkpoint_dir, output_file, tag=None): + """ + Convert ZeRO 2 or 3 checkpoint into a single fp32 consolidated ``state_dict`` file that can be + loaded with ``torch.load(file)`` + ``load_state_dict()`` and used for training without DeepSpeed. + + Args: + - ``checkpoint_dir``: path to the desired checkpoint folder. (one that contains the tag-folder, like ``global_step14``) + - ``output_file``: path to the pytorch fp32 state_dict output file (e.g. path/pytorch_model.bin) + - ``tag``: checkpoint tag used as a unique identifier for checkpoint. If not provided will attempt to load tag in the file named ``latest`` in the checkpoint folder, e.g., ``global_step14`` + """ + + state_dict = get_fp32_state_dict_from_zero_checkpoint(checkpoint_dir, tag) + print(f"Saving fp32 state dict to {output_file}") + torch.save(state_dict, output_file) + + +def load_state_dict_from_zero_checkpoint(model, checkpoint_dir, tag=None): + """ + 1. Put the provided model to cpu + 2. Convert ZeRO 2 or 3 checkpoint into a single fp32 consolidated ``state_dict`` + 3. Load it into the provided model + + Args: + - ``model``: the model object to update + - ``checkpoint_dir``: path to the desired checkpoint folder. (one that contains the tag-folder, like ``global_step14``) + - ``tag``: checkpoint tag used as a unique identifier for checkpoint. If not provided will attempt to load tag in the file named ``latest`` in the checkpoint folder, e.g., ``global_step14`` + + Returns: + - ``model`: modified model + + Make sure you have plenty of CPU memory available before you call this function. If you don't + have enough use the ``zero_to_fp32.py`` utility to do the conversion. You will find it + conveniently placed for you in the checkpoint folder. + + A typical usage might be :: + + from deepspeed.utils.zero_to_fp32 import load_state_dict_from_zero_checkpoint + model = load_state_dict_from_zero_checkpoint(trainer.model, checkpoint_dir) + # submit to model hub or save the model to share with others + + Note, that once this was run, the ``model`` will no longer be usable in the deepspeed context + of the same application. i.e. you will need to re-initialize the deepspeed engine, since + ``model.load_state_dict(state_dict)`` will remove all the deepspeed magic from it. 
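+
+    If the goal is simply a standalone fp32 weights file, a minimal follow-up might be (a sketch;
+    the output filename below is only an example) ::
+
+        model = load_state_dict_from_zero_checkpoint(trainer.model, checkpoint_dir)
+        torch.save(model.state_dict(), "pytorch_model_fp32.bin")  # plain torch file, loadable without DeepSpeed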
+ + """ + logger.info(f"Extracting fp32 weights") + state_dict = get_fp32_state_dict_from_zero_checkpoint(checkpoint_dir, tag) + + logger.info(f"Overwriting model with fp32 weights") + model = model.cpu() + model.load_state_dict(state_dict, strict=False) + + return model + + +if __name__ == "__main__": + + parser = argparse.ArgumentParser() + parser.add_argument("checkpoint_dir", + type=str, + help="path to the desired checkpoint folder, e.g., path/checkpoint-12") + parser.add_argument( + "output_file", + type=str, + help="path to the pytorch fp32 state_dict output file (e.g. path/checkpoint-12/pytorch_model.bin)") + parser.add_argument("-t", + "--tag", + type=str, + default=None, + help="checkpoint tag used as a unique identifier for checkpoint. e.g., global_step1") + parser.add_argument("-d", "--debug", action='store_true', help="enable debug") + args = parser.parse_args() + + debug = args.debug + + convert_zero_checkpoint_to_fp32_state_dict(args.checkpoint_dir, args.output_file, tag=args.tag) diff --git a/starvector/util.py b/starvector/util.py new file mode 100644 index 0000000000000000000000000000000000000000..47793736486f87301baaa20fdd18826256525035 --- /dev/null +++ b/starvector/util.py @@ -0,0 +1,292 @@ +import os +import importlib +import hashlib +import re +import time +import subprocess +import logging +import shlex +import os +import shutil +import fnmatch +from huggingface_hub import login +import torch +from omegaconf import OmegaConf + +dtype_mapping = { + "fp16": torch.float16, + "bf16": torch.bfloat16, + "fp32": torch.float32, + "no": "no" +} + +# -------------- Metrics -------------- +class AverageMeter(object): + """Computes and stores the average and current value""" + + def __init__(self): + self.reset() + + def reset(self): + self.val = 0 + self.avg = 0 + self.sum = 0 + self.count = 0 + + def update(self, val, n=1): + self.val = val + self.sum += val * n + self.count += n + self.avg = self.sum / self.count + +def count_parameters(model): + num = sum(p.numel() for p in model.parameters() if p.requires_grad) + for unit in ['', 'K', 'M', 'B', 'T']: + if abs(num) < 1000: + return f"{num:.1f}{unit}" + num /= 1000 + return f"{num:.1f}P" + +def print_trainable_parameters(model): + """ + Prints the number of trainable parameters in the model. 
+    """
+    trainable_params = 0
+    all_param = 0
+    for _, param in model.named_parameters():
+        all_param += param.numel()
+        if param.requires_grad:
+            trainable_params += param.numel()
+    print(
+        f"trainable params: {trainable_params} || all params: {all_param} || trainable%: {100 * trainable_params / all_param}"
+    )
+
+def set_env_vars():
+    # use .get() so the explicit error below is reachable when the variable is missing
+    HF_HOME = os.environ.get('HF_HOME')
+
+    if HF_HOME is None:
+        raise EnvironmentError("HF_HOME environment variable is not defined.")
+
+    os.makedirs(HF_HOME, exist_ok=True)
+    # os.environ['TRANSFORMERS_CACHE'] = HF_HOME
+    os.environ['HUGGINGFACE_HUB_CACHE'] = HF_HOME
+    os.environ['TORCH_HOME'] = HF_HOME
+    os.environ['HF_HOME'] = HF_HOME
+    os.environ['HF_HUB_CACHE'] = HF_HOME
+    os.environ['PYDEVD_DISABLE_FILE_VALIDATION'] = '1'
+    os.environ['TOKENIZERS_PARALLELISM'] = "False"
+    os.environ['TORCH_DISTRIBUTED_DEBUG'] = 'INFO'
+    os.environ['CUDA_LAUNCH_BLOCKING'] = "1"
+
+    HF_TOKEN = os.environ.get('HF_TOKEN')
+
+    if HF_TOKEN is None:
+        raise EnvironmentError("HF_TOKEN environment variable is not defined.")
+    time.sleep(1)  # wait for the token to be saved
+    login(HF_TOKEN)
+
+def flatten_dict(d, parent_key='', sep='.'):
+    items = []
+    for k, v in d.items():
+        new_key = f"{parent_key}{sep}{k}" if parent_key else k
+        if isinstance(v, dict):
+            items.extend(flatten_dict(v, new_key, sep=sep).items())
+        else:
+            items.append((new_key, v))
+    return dict(items)
+
+def hash_dict(exp_dict):
+    """Create a hash for an experiment. Credits to github.com/haven-ai!
+
+    Parameters
+    ----------
+    exp_dict : dict
+        An experiment, which is a single set of hyper-parameters
+
+    Returns
+    -------
+    hash_id: str
+        A unique id defining the experiment
+    """
+    dict2hash = ""
+    if not isinstance(exp_dict, dict):
+        raise ValueError("exp_dict is not a dict")
+
+    for k in sorted(exp_dict.keys()):
+        if "." in k:
+            raise ValueError(". 
has special purpose") + elif isinstance(exp_dict[k], dict): + v = hash_dict(exp_dict[k]) + elif isinstance(exp_dict[k], tuple): + raise ValueError(f"{exp_dict[k]} tuples can't be hashed yet, consider converting tuples to lists") + elif isinstance(exp_dict[k], list) and len(exp_dict[k]) and isinstance(exp_dict[k][0], dict): + v_str = "" + for e in exp_dict[k]: + if isinstance(e, dict): + v_str += hash_dict(e) + else: + raise ValueError("all have to be dicts") + v = v_str + else: + v = exp_dict[k] + + dict2hash += str(k) + "/" + str(v) + hash_id = hashlib.md5(dict2hash.encode()).hexdigest() + + return hash_id + +def get_exp_id(config): + exp_hash_id = hash_dict(dict(config)) + if config.model.model_name is not None: + model_name = config.model.model_name.split("/")[1] + else: + model_name = config.model.starcoder_model_name.split("/")[1] + "_" + config.model.image_encoder_type + exp_id = f"{config.project.project}-{config.model.max_length}-{model_name}-{exp_hash_id}" + print("\n" + "Experiment ID: " + exp_id + "\n") + return exp_id + +def get_obj_from_str(string, reload=False): + module, cls = string.rsplit(".", 1) + if reload: + module_imp = importlib.import_module(module) + importlib.reload(module_imp) + return getattr(importlib.import_module(module, package=None), cls) + +def instantiate_from_config(config): + if not "target" in config: + raise KeyError("No target in config") + return get_obj_from_str(config["target"])(**config.get("params", dict())) + +def generate_id_name_eval(args): + id_name = f"len_{args.max_length}" + if args.use_nucleus_sampling: + id_name += "_nucleus" + id_name += f"_top_p_{args.top_p:.2f}" + + if args.num_beams > 1: + id_name += "_beam_search" + id_name += f"_beams_{args.num_beams}" + else: + if not args.use_nucleus_sampling: + id_name += "_greedy" + id_name += f"_rep_pen_{args.repetition_penalty:.2f}" + id_name += f"_len_pen_{args.length_penalty:.2f}" + id_name += f"_temp_{args.temperature:.2f}" + return id_name + +def get_last_checkpoint(log_dir): + """Get the last checkpoint. + + Returns + ------- + last_checkpoint: str + The last checkpoint + """ + + pattern = re.compile(r"checkpoint-(\d+)") + files = os.listdir(log_dir) + checkpoints = [f for f in files if pattern.match(f)] + if len(checkpoints) == 0: + return None + steps = [int(pattern.match(c).group(1)) for c in checkpoints] + max_step = max(steps) + last_checkpoint = f"checkpoint-{max_step}" + + return os.path.join(log_dir, last_checkpoint) + +def model_summary_table(model): + total_params = 0 + name_col_width = 20 # set the width of the name column + print("\n") + print(f"| {'Submodel Name'.ljust(name_col_width)} | Number of Parameters |") + print("|" + "-" * name_col_width + "|---------------------|") + for name, module in model.named_children(): + num_params = sum(p.numel() for p in module.parameters()) + total_params += num_params + print(f"| {name.ljust(name_col_width)} | {num_params:>20,} |") + + print("|" + "-" * name_col_width + "|---------------------|") + print(f"| {'Total'.ljust(name_col_width)} | {total_params:>20,} |") + print("\n") + +def checkpoint_key(checkpoint_dir): + return int(checkpoint_dir.split("-")[-1]) + +def subprocess_call(cmd_string): + """Run a terminal process. 
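+    This is a thin wrapper around ``subprocess.check_output``: on success it returns the decoded
+    stdout of the command (e.g. ``subprocess_call("ls -la")``), and a non-zero exit status raises
+    ``subprocess.CalledProcessError``.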
+ + Parameters + ---------- + cmd_string : str + Command to execute in the terminal + + Returns + ------- + [type] + Error code or 0 if no error happened + """ + return subprocess.check_output(shlex.split(cmd_string), shell=False, stderr=subprocess.STDOUT).decode("utf-8") + +def copy_code( + src_path, + dst_path, + verbose=1, + exclude_list=['__pycache__', 'wandb', '.vscode', '.ipynb_checkpoints', 'project_baselines', 'assets', 'tmp']): + time.sleep(0.5) + if verbose: + print(" > Copying code from %s to %s" % (src_path, dst_path)) + + os.makedirs(dst_path, exist_ok=True) + + rsync_avialable = len(subprocess.run(['which', 'rsync'], capture_output=True, text=True).stdout) > 0 + + if rsync_avialable: # TODO: validate this works + rsync_cmd_base = f"rsync -av -r -q --delete-before --exclude='.*' --exclude '__pycache__/'" + + exclude_options = " ".join([f"--exclude='{filename}'" for filename in exclude_list]) + + rsync_cmd = f"{rsync_cmd_base} {exclude_options} {src_path} {dst_path}" + + if os.path.exists(os.path.join(src_path, ".havenignore")): + rsync_cmd += f" --exclude-from={os.path.join(src_path, '.havenignore')}" + + copy_code_cmd = rsync_cmd + subprocess_call(copy_code_cmd) + else: + logging.warning("rsync not available. Doing a hard copy of the code folder.") + for dirpath, dirs, files in os.walk(src_path): + if any(ex in dirpath for ex in exclude_list): + continue + for filename in fnmatch.filter(files, '*'): + src_file = os.path.join(dirpath, filename) + dst_file = os.path.join(dst_path, src_file.replace(src_path+'/', '')) + if src_file == dst_file: + continue + dst_dir = os.path.dirname(dst_file) + if not os.path.exists(dst_dir): + os.makedirs(dst_dir, exist_ok=True) + if not os.path.isfile(dst_file): # check if destination is already a file + shutil.copy2(src_file, dst_file) + time.sleep(0.5) + +def get_output_dir(): + # get the environment variable if it exists + output_dir = os.environ.get("OUTPUT_DIR", None) + if output_dir is None: + output_dir = os.path.join(os.getcwd(), "logs") + return output_dir + +def get_config(): + base_conf = OmegaConf.load("configs/models/default.yaml") + cli_conf = OmegaConf.from_cli() + specific_conf = OmegaConf.load(cli_conf.pop('config')) if 'config' in cli_conf else {} + config = OmegaConf.merge(base_conf, specific_conf, cli_conf) + if config.training.resume_from_checkpoint: + if not os.path.exists(os.path.join(os.path.dirname(config.training.resume_from_checkpoint), 'config.yaml')): + config.training.resume_from_checkpoint = get_last_checkpoint(config.training.resume_from_checkpoint) + cli_conf.training.resume_from_checkpoint = config.training.resume_from_checkpoint + pretrained_conf = OmegaConf.load(os.path.join(os.path.dirname(config.training.resume_from_checkpoint), 'config.yaml')) + model_resume_conf = pretrained_conf.pop('model') + specific_conf['model'] = model_resume_conf + config = OmegaConf.merge(config, specific_conf, cli_conf) + return config \ No newline at end of file diff --git a/starvector/validation/README.md b/starvector/validation/README.md new file mode 100644 index 0000000000000000000000000000000000000000..472fdcc863eb7a4e59c305911a32411655cd91ad --- /dev/null +++ b/starvector/validation/README.md @@ -0,0 +1,142 @@ +# StarVector Validation + +This module provides validation functionality for StarVector models, allowing evaluation of SVG generation quality across different model architectures and generation backends. + +## Overview + +The validation framework consists of: + +1. 
A base `SVGValidator` class that handles common validation logic +2. Specific validator implementations for different backends: + - `StarVectorHFSVGValidator`: Uses HuggingFace generation API + - `StarVectorVLLMValidator`: Uses vLLM for faster generation + - `StarVectorVLLMAPIValidator`: Uses vLLM through REST API + +## 1. Running Validation + +### Using HuggingFace Backend + +```bash +# StarVector-1B +python starvector/validation/validate.py \ +config=configs/generation/hf/starvector-1b/im2svg.yaml \ +dataset.name=starvector/svg-stack + +# StarVector-8B +python starvector/validation/validate.py \ +config=configs/generation/hf/starvector-8b/im2svg.yaml \ +dataset.name=starvector/svg-stack +``` + +### vLLM Backend + +For using the vLLM backend (StarVectorVLLMAPIValidator), first install our StarVector fork of VLLM, [here](https://github.com/starvector/vllm). + +```bash +git clone https://github.com/starvector/vllm.git +cd vllm +pip install -e . +``` + +Then, launch the using the vllm config file (it uses StarVectorVLLMValidator): + +```bash +# StarVector-1B +python starvector/validation/validate.py \ +config=configs/generation/vllm/starvector-1b/im2svg.yaml \ +dataset.name=starvector/svg-stack + +# StarVector-8B +python starvector/validation/validate.py \ +config=configs/generation/vllm/starvector-8b/im2svg.yaml \ +dataset.name=starvector/svg-stack +``` + +## 2. Creating a New SVG Validator + +To create a new validator for a different model or generation backend: + +1. Create a new class inheriting from `SVGValidator` +2. Implement required abstract methods: + - `__init__(self, config)`: Initialize the validator with the given config + - `get_dataloader(self)`: Get the dataloader for the given dataset + - `generate_svg(self, batch)`: Generate SVG from input batch +3. Add the new validator to the registry in `starvector/validation/__init__.py` + +Example: + +```python +from .svg_validator_base import SVGValidator, register_validator + +@register_validator +class MyNewValidator(SVGValidator): + def __init__(self, config): + super().__init__(config) + # Initialize your model/client here + + def generate_svg(self, batch, generate_config): + # Implement generation logic + # Return list of generated SVG strings + pass + + def get_dataloader(self): + # Implement dataloader logic + pass +``` + +## Key Features + +The validation framework provides: + +- Automatic metrics calculation and logging +- WandB integration for experiment tracking +- Temperature sweep for exploring generation parameters +- Comparison plot generation +- Batch processing with configurable settings + +## Configuration + +Validation is configured through YAML files in `configs/generation/`. Key configuration sections: + +```yaml +model: + name: "model_name" # HF model name or path + task: "im2svg" # Task type + torch_dtype: "float16" # Model precision + +dataset: + dataset_name: "svg-stack" # Dataset to validate on + batch_size: 1 + num_workers: 4 + +generation_params: + temperature: 0.2 + top_p: 0.9 + max_length: 1024 + # ... 
other generation parameters + +run: + report_to: "wandb" # Logging backend + out_dir: "outputs" # Output directory +``` + +## Output Structure + +The validator creates the following directory structure: + +``` +out_dir/ +├── {model}_{dataset}_{timestamp}/ +│ ├── config.yaml # Run configuration +│ ├── results/ +│ │ ├── results_avg.json # Average metrics +│ │ └── all_results.csv # Per-sample metrics +│ └── {sample_id}/ # Per-sample outputs +│ ├── metadata.json +│ ├── {sample_id}.svg +│ ├── {sample_id}_raw.svg +│ ├── {sample_id}_gt.svg +│ ├── {sample_id}_generated.png +│ ├── {sample_id}_original.png +│ └── {sample_id}_comparison.png +``` diff --git a/starvector/validation/__init__.py b/starvector/validation/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..bda352cfbc94ff5ac8b0f9b5648367655e3fb980 --- /dev/null +++ b/starvector/validation/__init__.py @@ -0,0 +1,11 @@ +from .svg_validator_base import SVGValidator +from .starvector_hf_validator import StarVectorHFSVGValidator +from .starvector_vllm_svg_validator import StarVectorVLLMValidator +from .starvector_vllm_api_svg_validator import StarVectorVLLMAPIValidator + +__all__ = [ + 'SVGValidator', + 'StarVectorHFSVGValidator', + 'StarVectorVLLMValidator', + 'StarVectorVLLMAPIValidator' +] \ No newline at end of file diff --git a/starvector/validation/starvector_hf_validator.py b/starvector/validation/starvector_hf_validator.py new file mode 100644 index 0000000000000000000000000000000000000000..8d6166fe584d9718ad38503406026a5d6289e59f --- /dev/null +++ b/starvector/validation/starvector_hf_validator.py @@ -0,0 +1,89 @@ +# hf https://huggingface.co/docs/transformers/main_classes/text_generation +from starvector.validation.svg_validator_base import SVGValidator, register_validator +import torch +from transformers import AutoProcessor, AutoModelForCausalLM +from torch.utils.data import Dataset, DataLoader +from datasets import load_dataset +from starvector.data.util import rasterize_svg + +class SVGValDataset(Dataset): + def __init__(self, dataset_name, config_name, split, im_size, num_samples, processor): + self.dataset_name = dataset_name + self.config_name = config_name + self.split = split + self.im_size = im_size + self.num_samples = num_samples + self.processor = processor + + if self.config_name: + self.data = load_dataset(self.dataset_name, self.config_name, split=self.split) + else: + self.data = load_dataset(self.dataset_name, split=self.split) + + if self.num_samples != -1: + self.data = self.data.select(range(self.num_samples)) + + def __len__(self): + return len(self.data) + + def __getitem__(self, idx): + svg_str = self.data[idx]['Svg'] + sample_id = self.data[idx]['Filename'] + image = rasterize_svg(svg_str, resolution=self.im_size) + image = self.processor(image, return_tensors="pt")['pixel_values'].squeeze(0) + caption = self.data[idx].get('Caption', "") + return { + 'Svg': svg_str, + 'image': image, + 'Filename': sample_id, + 'Caption': caption + } + + +@register_validator +class StarVectorHFSVGValidator(SVGValidator): + def __init__(self, config): + super().__init__(config) + # Initialize HuggingFace model and tokenizer here + self.torch_dtype = { + 'bfloat16': torch.bfloat16, + 'float16': torch.float16, + 'float32': torch.float32 + }[config.model.torch_dtype] + + # could also use AutoModelForCausalLM + if config.model.from_checkpoint: + self.model = AutoModelForCausalLM.from_pretrained(self.resume_from_checkpoint, trust_remote_code=True, torch_dtype=self.torch_dtype).to(config.run.device) + else: 
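+            # no local checkpoint was requested, so load the released weights by model name (HF Hub id or local path)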
+ self.model = AutoModelForCausalLM.from_pretrained(config.model.name, trust_remote_code=True, torch_dtype=self.torch_dtype).to(config.run.device) + + self.tokenizer = self.model.model.svg_transformer.tokenizer + self.svg_end_token_id = self.tokenizer.encode("")[0] + + def get_dataloader(self): + self.dataset = SVGValDataset(self.config.dataset.dataset_name, self.config.dataset.config_name, self.config.dataset.split, self.config.dataset.im_size, self.config.dataset.num_samples, self.processor) + self.dataloader = DataLoader(self.dataset, batch_size=self.config.dataset.batch_size, shuffle=False, num_workers=self.config.dataset.num_workers) + + def release_memory(self): + # Clear references to free GPU memory + self.model.model.svg_transformer.tokenizer = None + self.model.model.svg_transformer.model = None + + # Force CUDA garbage collection + if torch.cuda.is_available(): + torch.cuda.empty_cache() + torch.cuda.ipc_collect() + + def generate_svg(self, batch, generate_config): + if generate_config['temperature'] == 0: + generate_config['temperature'] = 1.0 + generate_config['do_sample'] = False + outputs = [] + batch['image'] = batch['image'].to('cuda').to(self.torch_dtype) + # for i, batch in enumerate(batch['svg']): + if self.task == 'im2svg': + outputs = self.model.model.generate_im2svg(batch = batch, **generate_config) + elif self.task == 'text2svg': + outputs = self.model.model.generate_text2svg(batch = batch, **generate_config) + return outputs + \ No newline at end of file diff --git a/starvector/validation/starvector_vllm_api_svg_validator.py b/starvector/validation/starvector_vllm_api_svg_validator.py new file mode 100644 index 0000000000000000000000000000000000000000..b1333902efe2788753ffe5edf0517d94b531f4ab --- /dev/null +++ b/starvector/validation/starvector_vllm_api_svg_validator.py @@ -0,0 +1,77 @@ +# vllm https://docs.vllm.ai/en/v0.5.5/dev/sampling_params.html +# TODO: This is not maintained, need to update it to use the new VLLM API + +from .svg_validator_base import SVGValidator, register_validator +from starvector.data.util import rasterize_svg, clean_svg, use_placeholder +from starvector.data.util import encode_image_base64 +from svgpathtools import svgstr2paths +import os +import json +from copy import deepcopy +from openai import OpenAI + +@register_validator +class StarVectorVLLMAPIValidator(SVGValidator): + def __init__(self, config): + + super().__init__(config) + # Initialize VLLM OpenAI client here + self.client = OpenAI( + api_key=config.run.api.key, + base_url=f"{config.run.api.base_url}", + ) + if 'starvector-1b' in config.model.name: + self.svg_end_token_id = 49154 # Adjust as needed + elif 'starvector-8b' in config.model.name: + self.svg_end_token_id = 49156 # Adjust as needed + + def generate_svg(self, batch, generate_config): + outputs = [] + for i, sample in enumerate(batch['svg']): + if self.task == "im2svg": + image = rasterize_svg(sample, 512) + base64_image = encode_image_base64(image) + content = [ + { + "type": "image_url", + "image_url": {"url": f"data:image/jpeg;base64,{base64_image}"}, + }, + {"type": "text", "text": " 1, + 'best_of': generate_config['num_beams'] + }, + stream=generate_config['stream'], + logit_bias={self.svg_end_token_id: generate_config['logit_bias']} if generate_config['logit_bias'] else None, + ) + + if generate_config['stream']: + generated_text = self._handle_stream_response(response) + else: + generated_text = "