Spaces:
Runtime error
Runtime error
bipin
commited on
Commit
·
8985863
1
Parent(s):
0843a80
bug fixes
Browse files
app.py
CHANGED
@@ -3,9 +3,17 @@ import gradio as gr
|
|
3 |
from prefix_clip import download_pretrained_model, generate_caption
|
4 |
from gpt2_story_gen import generate_story
|
5 |
|
|
|
|
|
|
|
|
|
|
|
6 |
|
7 |
def main(pil_image, genre, model, use_beam_search=False):
|
8 |
-
|
|
|
|
|
|
|
9 |
|
10 |
download_pretrained_model(model.lower(), file_to_save=model_file)
|
11 |
|
@@ -14,7 +22,7 @@ def main(pil_image, genre, model, use_beam_search=False):
|
|
14 |
pil_image=pil_image,
|
15 |
use_beam_search=use_beam_search,
|
16 |
)
|
17 |
-
story = generate_story(image_caption,
|
18 |
return story
|
19 |
|
20 |
|
|
|
3 |
from prefix_clip import download_pretrained_model, generate_caption
|
4 |
from gpt2_story_gen import generate_story
|
5 |
|
6 |
+
coco_weights = 'coco_weights.pt'
|
7 |
+
conceptual_weights = 'conceptual_weights.pt'
|
8 |
+
download_pretrained_model('coco', file_to_save=coco_weights)
|
9 |
+
download_pretrained_model('conceptual', file_to_save=conceptual_weights)
|
10 |
+
|
11 |
|
12 |
def main(pil_image, genre, model, use_beam_search=False):
|
13 |
+
if model.lower()=='coco':
|
14 |
+
model_file = coco_weights
|
15 |
+
elif model.lower()=='conceptual':
|
16 |
+
model_file = conceptual_weights
|
17 |
|
18 |
download_pretrained_model(model.lower(), file_to_save=model_file)
|
19 |
|
|
|
22 |
pil_image=pil_image,
|
23 |
use_beam_search=use_beam_search,
|
24 |
)
|
25 |
+
story = generate_story(image_caption, pil_image, genre.lower())
|
26 |
return story
|
27 |
|
28 |
|