Spaces:
Sleeping
Sleeping
Update app.py
Browse files
app.py
CHANGED
@@ -1,93 +1,97 @@
|
|
1 |
-
import requests
|
2 |
-
from PIL import Image
|
3 |
-
import gradio as gr
|
4 |
-
from transformers import BlipProcessor, BlipForConditionalGeneration, AutoModelForCausalLM, AutoTokenizer, GPT2LMHeadModel, TextGenerationPipeline
|
5 |
-
import torch
|
6 |
-
import tempfile
|
7 |
-
|
8 |
-
|
9 |
-
|
10 |
-
|
11 |
-
|
12 |
-
|
13 |
-
|
14 |
-
|
15 |
-
# Load the
|
16 |
-
|
17 |
-
|
18 |
-
|
19 |
-
|
20 |
-
|
21 |
-
|
22 |
-
|
23 |
-
|
24 |
-
|
25 |
-
|
26 |
-
|
27 |
-
|
28 |
-
|
29 |
-
|
30 |
-
|
31 |
-
|
32 |
-
|
33 |
-
|
34 |
-
#
|
35 |
-
|
36 |
-
|
37 |
-
|
38 |
-
#
|
39 |
-
|
40 |
-
|
41 |
-
|
42 |
-
|
43 |
-
|
44 |
-
|
45 |
-
|
46 |
-
|
47 |
-
|
48 |
-
|
49 |
-
|
50 |
-
|
51 |
-
|
52 |
-
|
53 |
-
|
54 |
-
|
55 |
-
|
56 |
-
|
57 |
-
|
58 |
-
|
59 |
-
|
60 |
-
)
|
61 |
-
|
62 |
-
|
63 |
-
|
64 |
-
|
65 |
-
|
66 |
-
|
67 |
-
|
68 |
-
|
69 |
-
|
70 |
-
|
71 |
-
|
72 |
-
|
73 |
-
|
74 |
-
|
75 |
-
|
76 |
-
|
77 |
-
|
78 |
-
|
79 |
-
|
80 |
-
|
81 |
-
|
82 |
-
|
83 |
-
|
84 |
-
|
85 |
-
|
86 |
-
|
87 |
-
|
88 |
-
|
89 |
-
|
90 |
-
|
91 |
-
|
92 |
-
|
93 |
-
|
|
|
|
|
|
|
|
|
|
1 |
+
import requests
|
2 |
+
from PIL import Image
|
3 |
+
import gradio as gr
|
4 |
+
from transformers import BlipProcessor, BlipForConditionalGeneration, AutoModelForCausalLM, AutoTokenizer, GPT2LMHeadModel, TextGenerationPipeline
|
5 |
+
import torch
|
6 |
+
import tempfile
|
7 |
+
import os
|
8 |
+
|
9 |
+
# NOTE: the previous startup calls `pip uninstall -y gradio` followed by
# `pip install gradio==2.6.4` were removed. They ran *after* `import gradio`
# had already executed, so they could not change the module used by this
# process; they only rewrote the package files on disk underneath it. The pin
# also contradicts this file, which relies on `gr.Blocks` (not available in
# gradio 2.x). Pin the desired gradio version in the Space's requirements
# instead of reinstalling at runtime.

# Initialize the device: use the GPU when one is visible, otherwise the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Load the image captioning model and processor.
# The captioning model is moved to `device`; inputs must be moved to the same
# device before calling `generate`.
image_processor = BlipProcessor.from_pretrained("AdhamEhab/ImageCaptioning")
image_model = BlipForConditionalGeneration.from_pretrained("AdhamEhab/ImageCaptioning").to(device)

# Load the story generation model and tokenizer.
# NOTE(review): the story model is not moved to `device`; the pipeline below
# is left with its default device placement — confirm this is intended.
story_generation_model = GPT2LMHeadModel.from_pretrained("AdhamEhab/StoryGen")
story_generation_tokenizer = AutoTokenizer.from_pretrained("AdhamEhab/StoryGen")

# Text-generation pipeline used by the story generation handler.
generator = TextGenerationPipeline(model=story_generation_model, tokenizer=story_generation_tokenizer)
|
24 |
+
|
25 |
+
# Define a function to generate caption from an image
|
26 |
+
def generate_caption(image):
    """Generate a caption for an image with the BLIP captioning model.

    Args:
        image: Path to an image file, or a temporary-file wrapper whose
            ``name`` attribute holds the path. ``None`` (no upload) is
            reported as an error message instead of being passed on.

    Returns:
        The decoded caption string, or a string starting with
        ``"An error occurred: "`` describing what went wrong. The function
        never raises, so the UI always has text to display.
    """
    # Nothing was uploaded: report it clearly rather than surfacing an
    # opaque internal exception message from the image loader.
    if image is None:
        return "An error occurred: no image was provided."

    try:
        # If image is a file object, extract the file path
        if isinstance(image, tempfile._TemporaryFileWrapper):
            image_path = image.name
        else:
            image_path = image

        # Load and preprocess the image. The context manager closes the
        # underlying file handle once the pixel data has been converted
        # into tensors by the processor.
        with Image.open(image_path) as pil_image:
            inputs = image_processor(pil_image, return_tensors="pt", padding="max_length", truncation=True)

        # Generate caption token ids without tracking gradients; inputs are
        # moved to the same device as the model first.
        with torch.no_grad():
            caption_ids = image_model.generate(**inputs.to(device))

        # Decode the first (only) generated sequence into text.
        return image_processor.decode(caption_ids[0], skip_special_tokens=True)
    except Exception as e:
        # UI boundary: turn any failure into a displayable message.
        return f"An error occurred: {str(e)}"
|
47 |
+
|
48 |
+
# Define a function to generate a story based on a prompt
|
49 |
+
def generate_story(prompt):
    """Generate a story continuation for ``prompt``.

    Args:
        prompt: Text used to seed the text-generation pipeline.

    Returns:
        The generated text of the first pipeline result, or a string
        starting with ``"An error occurred: "`` if generation fails.
    """
    try:
        # Sampled generation, capped at 200 tokens via max_length.
        results = generator(prompt, max_length=200, do_sample=True)
        first_result = results[0]
        return first_result['generated_text']
    except Exception as e:
        # UI boundary: report the failure as text instead of raising.
        return f"An error occurred: {str(e)}"
|
56 |
+
|
57 |
+
# Create Gradio interfaces
|
58 |
+
# Standalone interface for image captioning.
# Uses the top-level `gr.Image` component rather than the legacy
# `gr.inputs.Image` namespace, which is deprecated in gradio 3.x and absent
# from later releases.
image_caption_interface = gr.Interface(
    fn=generate_caption,
    inputs=gr.Image(type="filepath", label="Upload Image"),
    outputs="text",
    title="Image Captioning",
    description="Generate a caption for the provided image."
)

# Standalone interface for story generation from a text prompt.
story_generation_interface = gr.Interface(
    fn=generate_story,
    inputs="text",
    outputs="text",
    title="Story Generation",
    description="Generate a story based on the provided prompt."
)
|
73 |
+
|
74 |
+
# Build the combined Blocks UI: image captioning on the left, story
# generation on the right. The custom CSS hides the page footer.
with gr.Blocks(css="footer{display:none !important}") as combined_interface:
    gr.Markdown(
        """
# The Seer's Legacy
Crafting tales from the fabric of imagination.
""")
    with gr.Row():
        with gr.Column():
            # `gr.Image` / `gr.Textbox` are the component classes meant for
            # use inside Blocks; the legacy `gr.inputs.*` namespace is
            # deprecated in gradio 3.x and absent from later releases.
            image_input = gr.Image(type="filepath", label="Upload Image")
            image_output = gr.Text(label="Image Caption")
            image_btn = gr.Button("Generate Image Caption")

        with gr.Column():
            text_input = gr.Textbox(label="Story Prompt")
            story_output = gr.Text(label="Generated Story")
            story_btn = gr.Button("Generate Story")

    # Wire each button to its handler; listeners must be registered inside
    # the Blocks context.
    image_btn.click(generate_caption, inputs=[image_input], outputs=[image_output])
    story_btn.click(generate_story, inputs=[text_input], outputs=[story_output])

# Launch the combined interface
combined_interface.launch()
|
97 |
+
|