Fiqa commited on
Commit
ea2971c
·
verified ·
1 Parent(s): 391fb98

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +27 -37
app.py CHANGED
@@ -1,8 +1,8 @@
1
  import os
2
  import requests
3
  from PIL import Image
4
- import streamlit as st
5
  import torch
 
6
  from huggingface_hub import login
7
  from transformers import AutoProcessor, AutoModelForCausalLM
8
  from diffusers import DiffusionPipeline
@@ -21,44 +21,34 @@ caption_model_name = "pretrained-caption-model" # Replace with the actual model
21
  processor = AutoProcessor.from_pretrained(caption_model_name)
22
  model = AutoModelForCausalLM.from_pretrained(caption_model_name)
23
 
24
- # Move models to GPU if available
25
  device = "cuda" if torch.cuda.is_available() else "cpu"
26
  pipe.to(device)
27
  model.to(device)
28
 
29
- # Streamlit UI
30
- st.title("Image Caption and Design Generator")
31
- st.write("Upload an image or provide an image URL to generate a caption and use it to create a similar design.")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
32
 
33
- # Image upload or URL input
34
- img_file = st.file_uploader("Choose an image...", type=["png", "jpg", "jpeg"])
35
- img_url = st.text_input("Or provide an image URL:")
36
-
37
- # Process the image
38
- raw_image = None
39
- if img_file:
40
- raw_image = Image.open(img_file).convert("RGB")
41
- st.image(raw_image, caption="Uploaded Image", use_column_width=True)
42
- elif img_url:
43
- try:
44
- raw_image = Image.open(requests.get(img_url, stream=True).raw).convert("RGB")
45
- st.image(raw_image, caption="Image from URL", use_column_width=True)
46
- except Exception as e:
47
- st.error(f"Error loading image from URL: {e}")
48
-
49
- # Generate caption and design
50
- if raw_image and st.button("Generate Caption and Design"):
51
- with st.spinner("Generating caption..."):
52
- # Generate caption
53
- inputs = processor(raw_image, return_tensors="pt", padding=True, truncation=True, max_length=250)
54
- inputs = {key: val.to(device) for key, val in inputs.items()}
55
- out = model.generate(**inputs)
56
- caption = processor.decode(out[0], skip_special_tokens=True)
57
- st.success("Generated Caption:")
58
- st.write(caption)
59
-
60
- with st.spinner("Generating similar design..."):
61
- # Generate similar design using the caption as a prompt
62
- generated_image = pipe(caption).images[0]
63
- st.success("Generated Design:")
64
- st.image(generated_image, caption="Design Generated from Caption", use_column_width=True)
 
1
  import os
2
  import requests
3
  from PIL import Image
 
4
  import torch
5
+ import gradio as gr
6
  from huggingface_hub import login
7
  from transformers import AutoProcessor, AutoModelForCausalLM
8
  from diffusers import DiffusionPipeline
 
21
  processor = AutoProcessor.from_pretrained(caption_model_name)
22
  model = AutoModelForCausalLM.from_pretrained(caption_model_name)
23
 
24
+ # Check for GPU availability (handled automatically by Hugging Face Spaces)
25
  device = "cuda" if torch.cuda.is_available() else "cpu"
26
  pipe.to(device)
27
  model.to(device)
28
 
29
+ # Function to process the image and generate caption and design
30
+ @spaces.GPU
31
+ def generate_caption_and_design(image):
32
+ # Generate caption
33
+ inputs = processor(image, return_tensors="pt", padding=True, truncation=True, max_length=250)
34
+ inputs = {key: val.to(device) for key, val in inputs.items()}
35
+ out = model.generate(**inputs)
36
+ caption = processor.decode(out[0], skip_special_tokens=True)
37
+
38
+ # Generate design based on caption
39
+ generated_image = pipe(caption).images[0]
40
+
41
+ return caption, generated_image
42
+
43
+ # Gradio Interface
44
+ interface = gr.Interface(
45
+ fn=generate_caption_and_design,
46
+ inputs=gr.Image(type="pil", label="Upload an Image"),
47
+ outputs=[gr.Textbox(label="Generated Caption"), gr.Image(label="Generated Design")],
48
+ title="Image Caption and Design Generator",
49
+ description="Upload an image or provide an image URL to generate a caption and use it to create a similar design.",
50
+ )
51
+
52
+ # Launch Gradio app
53
+ interface.launch()
54