Spaces:
Running
on
Zero
Running
on
Zero
wenxiang guo
committed on
Update ldm/modules/encoders/modules.py
Browse files
ldm/modules/encoders/modules.py
CHANGED
|
@@ -56,6 +56,7 @@ class FrozenFLANEmbedder(AbstractEncoder):
|
|
| 56 |
|
| 57 |
def __init__(self, version="google/flan-t5-large", device="cuda", max_length=77,
|
| 58 |
freeze=True): # others are google/t5-v1_1-xl and google/t5-v1_1-xxl
|
|
|
|
| 59 |
super().__init__()
|
| 60 |
|
| 61 |
self.tokenizer = T5Tokenizer.from_pretrained(version)
|
|
@@ -88,8 +89,8 @@ class FrozenCLAPEmbedder(AbstractEncoder):
|
|
| 88 |
"""Uses the CLAP transformer encoder for text from microsoft"""
|
| 89 |
|
| 90 |
def __init__(self, weights_path, freeze=True, device="cuda", max_length=77): # clip-vit-base-patch32
|
|
|
|
| 91 |
super().__init__()
|
| 92 |
-
|
| 93 |
model_state_dict = torch.load(weights_path, map_location=torch.device('cpu'))['model']
|
| 94 |
match_params = dict()
|
| 95 |
for key in list(model_state_dict.keys()):
|
|
@@ -103,7 +104,7 @@ class FrozenCLAPEmbedder(AbstractEncoder):
|
|
| 103 |
self.caption_encoder = TextEncoder(
|
| 104 |
args.d_proj, args.text_model, args.transformer_embed_dim
|
| 105 |
)
|
| 106 |
-
|
| 107 |
self.max_length = max_length
|
| 108 |
self.device = device
|
| 109 |
if freeze: self.freeze()
|
|
@@ -130,6 +131,7 @@ class FrozenCLAPFLANEmbedder(AbstractEncoder):
|
|
| 130 |
|
| 131 |
def __init__(self, weights_path, t5version="google/t5-v1_1-large", freeze=True, device="cuda",
|
| 132 |
max_length=77): # clip-vit-base-patch32
|
|
|
|
| 133 |
super().__init__()
|
| 134 |
|
| 135 |
model_state_dict = torch.load(weights_path, map_location=torch.device('cpu'))['model']
|
|
|
|
| 56 |
|
| 57 |
def __init__(self, version="google/flan-t5-large", device="cuda", max_length=77,
|
| 58 |
freeze=True): # others are google/t5-v1_1-xl and google/t5-v1_1-xxl
|
| 59 |
+
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
|
| 60 |
super().__init__()
|
| 61 |
|
| 62 |
self.tokenizer = T5Tokenizer.from_pretrained(version)
|
|
|
|
| 89 |
"""Uses the CLAP transformer encoder for text from microsoft"""
|
| 90 |
|
| 91 |
def __init__(self, weights_path, freeze=True, device="cuda", max_length=77): # clip-vit-base-patch32
|
| 92 |
+
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
|
| 93 |
super().__init__()
|
|
|
|
| 94 |
model_state_dict = torch.load(weights_path, map_location=torch.device('cpu'))['model']
|
| 95 |
match_params = dict()
|
| 96 |
for key in list(model_state_dict.keys()):
|
|
|
|
| 104 |
self.caption_encoder = TextEncoder(
|
| 105 |
args.d_proj, args.text_model, args.transformer_embed_dim
|
| 106 |
)
|
| 107 |
+
|
| 108 |
self.max_length = max_length
|
| 109 |
self.device = device
|
| 110 |
if freeze: self.freeze()
|
|
|
|
| 131 |
|
| 132 |
def __init__(self, weights_path, t5version="google/t5-v1_1-large", freeze=True, device="cuda",
|
| 133 |
max_length=77): # clip-vit-base-patch32
|
| 134 |
+
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
|
| 135 |
super().__init__()
|
| 136 |
|
| 137 |
model_state_dict = torch.load(weights_path, map_location=torch.device('cpu'))['model']
|