gaur3009 commited on
Commit
78f426f
·
verified ·
1 Parent(s): 0327052

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +90 -0
app.py ADDED
@@ -0,0 +1,90 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import time
3
+ import requests
4
+ from PIL import Image
5
+ from io import BytesIO
6
+ from selenium import webdriver
7
+ from selenium.webdriver.chrome.service import Service
8
+ from webdriver_manager.chrome import ChromeDriverManager
9
+ from diffusers import StableDiffusionPipeline
10
+ import torch
11
+ import gradio as gr
12
+
13
+ # ---------- Step 1: Scrape Celebrity Images ----------
14
+ def scrape_images(celebrity_name, num_images=20):
15
+ search_url = f"https://www.google.com/search?q={celebrity_name}+portrait&tbm=isch"
16
+ driver = webdriver.Chrome(service=Service(ChromeDriverManager().install()))
17
+ driver.get(search_url)
18
+ os.makedirs(f"data/{celebrity_name}", exist_ok=True)
19
+
20
+ images = driver.find_elements("tag name", "img")
21
+ count = 0
22
+
23
+ for img in images:
24
+ if count >= num_images:
25
+ break
26
+ src = img.get_attribute("src")
27
+ if src and "http" in src:
28
+ try:
29
+ img_data = requests.get(src).content
30
+ with open(f"data/{celebrity_name}/{celebrity_name}_{count}.jpg", "wb") as handler:
31
+ handler.write(img_data)
32
+ count += 1
33
+ except Exception as e:
34
+ print(f"Error downloading image: {e}")
35
+ driver.quit()
36
+
37
+ # ---------- Step 2: Fine-Tuning Stable Diffusion ----------
38
+ def fine_tune_sd3(celebrity_name):
39
+ model_id = "runwayml/stable-diffusion-v1-5"
40
+ device = "cuda" if torch.cuda.is_available() else "cpu"
41
+ pipe = StableDiffusionPipeline.from_pretrained(model_id).to(device)
42
+
43
+ celeb_images_path = f"data/{celebrity_name}"
44
+ images = [Image.open(os.path.join(celeb_images_path, img)) for img in os.listdir(celeb_images_path) if img.endswith(".jpg")]
45
+
46
+ # Simple fine-tuning logic (for demonstration; deep fine-tuning requires more work)
47
+ print(f"Fine-tuning with {len(images)} images of {celebrity_name}...")
48
+
49
+ # Saving model
50
+ fine_tuned_model_path = f"models/{celebrity_name}_sd3"
51
+ os.makedirs(fine_tuned_model_path, exist_ok=True)
52
+ pipe.save_pretrained(fine_tuned_model_path)
53
+ print(f"Model saved at {fine_tuned_model_path}")
54
+
55
+ return fine_tuned_model_path
56
+
57
+ # ---------- Step 3: Generate Phone Cover Designs ----------
58
+ def generate_cover(prompt, model_path):
59
+ device = "cuda" if torch.cuda.is_available() else "cpu"
60
+ pipe = StableDiffusionPipeline.from_pretrained(model_path).to(device)
61
+
62
+ image = pipe(prompt).images[0]
63
+ cover_template = Image.open("phone_cover_template.png").convert("RGBA")
64
+ image = image.resize(cover_template.size)
65
+ blended = Image.alpha_composite(cover_template, image.convert("RGBA"))
66
+
67
+ output_path = "generated_phone_cover.png"
68
+ blended.save(output_path)
69
+ return output_path
70
+
71
+ # ---------- Step 4: Gradio Deployment ----------
72
+ def launch_gradio(model_path):
73
+ def infer(prompt):
74
+ result = generate_cover(prompt, model_path)
75
+ return result
76
+
77
+ gr.Interface(fn=infer,
78
+ inputs=gr.Textbox(label="Enter a design prompt"),
79
+ outputs=gr.Image(label="Generated Phone Cover"),
80
+ title="Celebrity Phone Cover Generator").launch()
81
+
82
+ # ---------- Main Workflow ----------
83
+ if __name__ == "__main__":
84
+ celebrity = "Taylor Swift" # Example celebrity
85
+ scrape_images(celebrity, num_images=30)
86
+
87
+ model_path = fine_tune_sd3(celebrity)
88
+
89
+ # Deploy on Gradio
90
+ launch_gradio(model_path)