Fiqa commited on
Commit
f667084
·
verified ·
1 Parent(s): 8772806

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +64 -0
app.py ADDED
@@ -0,0 +1,64 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import requests
3
+ from PIL import Image
4
+ import streamlit as st
5
+ import torch
6
+ from huggingface_hub import login
7
+ from transformers import AutoProcessor, AutoModelForCausalLM
8
+ from diffusers import DiffusionPipeline
9
+
10
+ # Hugging Face token setup
11
+ hf_token = os.getenv('HF_AUTH_TOKEN')
12
+ if not hf_token:
13
+ raise ValueError("Hugging Face token is not set in the environment variables.")
14
+ login(token=hf_token)
15
+
16
+ # Initialize Stable Diffusion pipeline
17
+ pipe = DiffusionPipeline.from_pretrained("stabilityai/stable-diffusion-3.5-medium")
18
+
19
+ # Initialize captioning model and processor
20
+ caption_model_name = "pretrained-caption-model" # Replace with the actual model name
21
+ processor = AutoProcessor.from_pretrained(caption_model_name)
22
+ model = AutoModelForCausalLM.from_pretrained(caption_model_name)
23
+
24
+ # Move models to GPU if available
25
+ device = "cuda" if torch.cuda.is_available() else "cpu"
26
+ pipe.to(device)
27
+ model.to(device)
28
+
29
+ # Streamlit UI
30
+ st.title("Image Caption and Design Generator")
31
+ st.write("Upload an image or provide an image URL to generate a caption and use it to create a similar design.")
32
+
33
+ # Image upload or URL input
34
+ img_file = st.file_uploader("Choose an image...", type=["png", "jpg", "jpeg"])
35
+ img_url = st.text_input("Or provide an image URL:")
36
+
37
+ # Process the image
38
+ raw_image = None
39
+ if img_file:
40
+ raw_image = Image.open(img_file).convert("RGB")
41
+ st.image(raw_image, caption="Uploaded Image", use_column_width=True)
42
+ elif img_url:
43
+ try:
44
+ raw_image = Image.open(requests.get(img_url, stream=True).raw).convert("RGB")
45
+ st.image(raw_image, caption="Image from URL", use_column_width=True)
46
+ except Exception as e:
47
+ st.error(f"Error loading image from URL: {e}")
48
+
49
+ # Generate caption and design
50
+ if raw_image and st.button("Generate Caption and Design"):
51
+ with st.spinner("Generating caption..."):
52
+ # Generate caption
53
+ inputs = processor(raw_image, return_tensors="pt", padding=True, truncation=True, max_length=250)
54
+ inputs = {key: val.to(device) for key, val in inputs.items()}
55
+ out = model.generate(**inputs)
56
+ caption = processor.decode(out[0], skip_special_tokens=True)
57
+ st.success("Generated Caption:")
58
+ st.write(caption)
59
+
60
+ with st.spinner("Generating similar design..."):
61
+ # Generate similar design using the caption as a prompt
62
+ generated_image = pipe(caption).images[0]
63
+ st.success("Generated Design:")
64
+ st.image(generated_image, caption="Design Generated from Caption", use_column_width=True)