Spaces:
Sleeping
Sleeping
Update app.py
Browse files
app.py
CHANGED
@@ -88,10 +88,14 @@ scheduler = ReduceLROnPlateau(optimizer, mode='min', factor=0.5, patience=3, ver
|
|
88 |
# Load SentenceTransformer model
|
89 |
clip_model = SentenceTransformer('sentence-transformers/clip-ViT-B-32-multilingual-v1').to(device)
|
90 |
|
|
|
91 |
model_name = "EleutherAI/gpt-neo-1.3B"
|
92 |
tokenizer = AutoTokenizer.from_pretrained(model_name)
|
|
|
|
|
93 |
model_gptneo = AutoModelForCausalLM.from_pretrained(model_name).to(device)
|
94 |
|
|
|
95 |
def generate_description(image):
|
96 |
image_resnet = data_transforms(image).unsqueeze(0).to(device)
|
97 |
|
@@ -123,13 +127,14 @@ def generate_description(image):
|
|
123 |
top_p=0.9,
|
124 |
repetition_penalty=1.2,
|
125 |
do_sample=True,
|
126 |
-
pad_token_id=tokenizer.
|
127 |
)
|
128 |
|
129 |
description_text = tokenizer.decode(output[0], skip_special_tokens=True)
|
130 |
|
131 |
return predicted_style, predicted_artist, description_text
|
132 |
|
|
|
133 |
# Gradio interface
|
134 |
def gradio_interface(image):
|
135 |
if image is None:
|
|
|
88 |
# Load SentenceTransformer model
|
89 |
clip_model = SentenceTransformer('sentence-transformers/clip-ViT-B-32-multilingual-v1').to(device)
|
90 |
|
91 |
+
# Load GPT-Neo and set padding token
|
92 |
model_name = "EleutherAI/gpt-neo-1.3B"
|
93 |
tokenizer = AutoTokenizer.from_pretrained(model_name)
|
94 |
+
if tokenizer.pad_token is None:
|
95 |
+
tokenizer.pad_token = tokenizer.eos_token # Set pad_token to eos_token
|
96 |
model_gptneo = AutoModelForCausalLM.from_pretrained(model_name).to(device)
|
97 |
|
98 |
+
|
99 |
def generate_description(image):
|
100 |
image_resnet = data_transforms(image).unsqueeze(0).to(device)
|
101 |
|
|
|
127 |
top_p=0.9,
|
128 |
repetition_penalty=1.2,
|
129 |
do_sample=True,
|
130 |
+
pad_token_id=tokenizer.pad_token_id
|
131 |
)
|
132 |
|
133 |
description_text = tokenizer.decode(output[0], skip_special_tokens=True)
|
134 |
|
135 |
return predicted_style, predicted_artist, description_text
|
136 |
|
137 |
+
|
138 |
# Gradio interface
|
139 |
def gradio_interface(image):
|
140 |
if image is None:
|