jbilcke-hf's picture
jbilcke-hf HF Staff
upgrading finetrainers (and losing my extra code + improvements)
80ebcb3
raw
history blame
700 Bytes
import random
from typing import List, Union
import torch
def dropout_caption(caption: Union[str, List[str]], dropout_p: float = 0) -> Union[str, List[str]]:
    """Randomly replace a caption (or a list of captions) with empty text.

    One sample is drawn from ``random.random()`` per call. The decision is
    all-or-nothing: for a list, either every entry is blanked or the list is
    returned untouched.

    Args:
        caption: A single caption string or a list of caption strings.
        dropout_p: Threshold compared against the random sample. With the
            default of 0 the caption is always kept.

    Returns:
        The original ``caption`` object when the sample is ``>= dropout_p``;
        otherwise ``""`` for a string input, or a new list of empty strings
        with the same length as the input list.
    """
    keep = random.random() >= dropout_p
    if keep:
        return caption
    # Dropped: mirror the input's container type with empty text.
    return "" if isinstance(caption, str) else [""] * len(caption)
def dropout_embeddings_to_zero(embed: torch.Tensor, dropout_p: float = 0) -> torch.Tensor:
    """Randomly replace an embedding tensor with zeros.

    One sample is drawn from ``random.random()`` per call, and the whole
    tensor is either kept or zeroed (no per-element dropout).

    Args:
        embed: The embedding tensor.
        dropout_p: Threshold compared against the random sample. With the
            default of 0 the tensor is always kept.

    Returns:
        ``embed`` itself when the sample is ``>= dropout_p``; otherwise a new
        tensor produced by ``torch.zeros_like(embed)``. The input tensor is
        not modified in either case.
    """
    keep = random.random() >= dropout_p
    return embed if keep else torch.zeros_like(embed)
def remove_prefix(text: str, prefixes: List[str]) -> str:
    """Strip the first matching prefix from ``text``.

    Prefixes are tried in list order and only the first one that ``text``
    starts with is removed; later candidates are not considered.

    Args:
        text: The string to process.
        prefixes: Candidate prefixes, checked in order.

    Returns:
        ``text`` with the first matching prefix removed and surrounding
        whitespace stripped. If no prefix matches, ``text`` is returned
        unchanged (and is NOT stripped).
    """
    matched = next((candidate for candidate in prefixes if text.startswith(candidate)), None)
    if matched is None:
        return text
    return text.removeprefix(matched).strip()