# app.py — TimeMetamorphy: Object Evolution Visualizer.
# Provenance (copied from the Hugging Face file viewer): author AMfeta99,
# commit 84abbea ("Update app.py"), 2.37 kB.
import os
import re
import tempfile

import gradio as gr
import torch
from diffusers import StableDiffusionPipeline
from PIL import Image, ImageDraw, ImageFont
# Load the Stable Diffusion v1.5 pipeline once at import time; every request
# below reuses this single `pipe` object.
# NOTE(review): the "runwayml/stable-diffusion-v1-5" Hub repo is believed to
# have been taken down; if loading fails, the community mirror
# "stable-diffusion-v1-5/stable-diffusion-v1-5" should host the same weights — confirm.
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
# Use half precision only on GPU. diffusers documents that float16 pipelines
# cannot run on CPU (several CPU kernels lack float16 support), so fall back
# to float32 there instead of failing at generation time.
DTYPE = torch.float16 if DEVICE == "cuda" else torch.float32
pipe = StableDiffusionPipeline.from_pretrained(
    "runwayml/stable-diffusion-v1-5", torch_dtype=DTYPE
).to(DEVICE)
# Function to add label
def add_label_to_image(image, label):
draw = ImageDraw.Draw(image)
try:
font = ImageFont.truetype("/usr/share/fonts/truetype/dejavu/DejaVuSans-Bold.ttf", 30)
except:
font = ImageFont.load_default()
position = (20, image.height - 50)
draw.rectangle([position, (position[0]+400, position[1]+40)], fill=(0, 0, 0, 180))
draw.text(position, label, font=font, fill="white")
return image
def _safe_file_stem(name, max_length=64):
    """Turn free-text *name* into a filename stem with no path separators or dots."""
    # Collapse every run of characters outside [A-Za-z0-9_-] into "_". This
    # removes "/", "\\" and "." so user text cannot escape the output directory.
    stem = re.sub(r"[^A-Za-z0-9_-]+", "_", name).strip("_")
    return stem[:max_length] or "object"


def generate_object_history(object_name):
    """Generate labeled past/present/future images of an object plus a GIF.

    Args:
        object_name: Free-text object name typed by the user (e.g. "bicycle").

    Returns:
        A ``(gallery_items, gif_path)`` tuple.
        ``gallery_items`` is a list of ``(png_path, caption)`` pairs in
        past -> present -> future order. ``gif_path`` is the path of an
        endlessly looping GIF of the same labeled frames, one second each.

    Raises:
        gr.Error: If *object_name* is empty or only whitespace. Gradio shows
            this message in the UI.
    """
    object_name = (object_name or "").strip()
    if not object_name:
        raise gr.Error("Please enter an object name.")

    prompts = {
        "past": f"An old version of a {object_name}, vintage, old-fashioned",
        "present": f"A modern {object_name}, realistic, current design",
        "future": f"A futuristic {object_name}, sci-fi, advanced design",
    }

    # Use a fresh directory per request. Concurrent users, or repeated requests
    # for the same object, then never overwrite each other's files.
    output_dir = tempfile.mkdtemp(prefix="timemetamorphy_")
    file_stem = _safe_file_stem(object_name)
    title = object_name.title()

    gallery_items = []
    frames = []
    for period, prompt in prompts.items():
        caption = f"{title} - {period.title()}"
        image = pipe(prompt).images[0]
        labeled_image = add_label_to_image(image, caption)
        png_path = os.path.join(output_dir, f"{file_stem}_{period}.png")
        labeled_image.save(png_path)
        gallery_items.append((png_path, caption))
        frames.append(labeled_image)

    gif_path = os.path.join(output_dir, f"{file_stem}_evolution.gif")
    # duration is milliseconds per frame; loop=0 repeats the GIF forever.
    frames[0].save(
        gif_path,
        save_all=True,
        append_images=frames[1:],
        duration=1000,
        loop=0,
    )
    return gallery_items, gif_path
def create_gradio_interface():
    """Build, without launching, the TimeMetamorphy Gradio Blocks UI.

    The layout is, top to bottom: a title, a text box for the object name,
    a "Generate Evolution" button, a one-row, three-column gallery for the
    per-period images, and an image component for the evolution GIF.
    Clicking the button runs ``generate_object_history`` on the text box
    value and fills the gallery and the GIF view.

    Returns:
        The constructed ``gr.Blocks`` app.
    """
    with gr.Blocks() as blocks:
        gr.Markdown("# TimeMetamorphy: Object Evolution Visualizer")
        name_box = gr.Textbox(label="Enter an object name (e.g., bicycle, phone)")
        run_button = gr.Button("Generate Evolution")
        gallery = gr.Gallery(label="Generated Images", columns=3, rows=1)
        gif_view = gr.Image(label="Generated GIF")

        run_button.click(
            fn=generate_object_history,
            inputs=[name_box],
            outputs=[gallery, gif_view],
        )
    return blocks
# Build the UI at module level so tooling that imports this file can find the
# app object under the name `demo`.
demo = create_gradio_interface()

# Start the server only when run as a script (e.g. `python app.py`), not as a
# side effect of importing this module.
if __name__ == "__main__":
    demo.launch(share=True)