gaur3009 committed
Commit 5e4062b · verified · 1 Parent(s): e94aa2a

Create app.py

Files changed (1)
  1. app.py +78 -0
app.py ADDED
@@ -0,0 +1,78 @@
+ import torch
+ import gradio as gr
+ from diffusers import StableDiffusionXLPipeline
+ from PIL import Image
+ from io import BytesIO
+ import os
+ import requests
+ import time
+ from tqdm import tqdm
+
+ # Load the local Stable Diffusion XL checkpoint (single .safetensors file)
+ model_path = "networks/TShirtDesignRedmondV2-Tshirtdesign-TshirtDesignAF.safetensors"
+ device = "cuda" if torch.cuda.is_available() else "cpu"
+ pipe = StableDiffusionXLPipeline.from_single_file(
+     model_path,
+     torch_dtype=torch.float16 if device == "cuda" else torch.float32,
+     use_safetensors=True,
+ )
+ pipe = pipe.to(device)
+
+ # Hugging Face Inference API fallback endpoint
+ repo = "artificialguybr/TshirtDesignRedmond-V2"
+ api_url = f"https://api-inference.huggingface.co/models/{repo}"
+
+ def infer(color_prompt, dress_type_prompt, design_prompt, text):
+     prompt = (
+         f"A single {color_prompt} colored {dress_type_prompt} featuring a bold {design_prompt} design printed on the {dress_type_prompt},"
+         " hanging on a plain wall. The soft light and shadows create a striking contrast against the minimal background, evoking modern sophistication."
+     )
+     # Include the optional text field in the prompt when provided
+     if text:
+         prompt += f' The design includes the text "{text}".'
+
+     print("Generating image locally with prompt:", prompt)
+     try:
+         image = pipe(prompt).images[0]
+         return image
+     except Exception as e:
+         print("Local generation failed. Switching to API.", str(e))
+
+     # API fallback; optionally authenticate with a token from the environment (HF_TOKEN, name assumed)
+     token = os.environ.get("HF_TOKEN", "")
+     headers = {"Authorization": f"Bearer {token}"} if token else {}
+     payload = {
+         "inputs": prompt,
+         "parameters": {
+             "negative_prompt": "(worst quality, low quality, lowres, bad details, watermark, text, blurry, cartoon, 3D, bad anatomy, outdated fashion, cheap look, unreal details, unwanted features)",
+             "num_inference_steps": 30,
+             "scheduler": "DPMSolverMultistepScheduler"
+         },
+     }
+
+     # Poll the API: retry while the model is loading (503) or on a few transient 500s
+     error_count = 0
+     pbar = tqdm(total=None, desc="Loading model")
+     while True:
+         response = requests.post(api_url, headers=headers, json=payload)
+         if response.status_code == 200:
+             pbar.close()
+             return Image.open(BytesIO(response.content))
+         elif response.status_code == 503:
+             time.sleep(1)
+             pbar.update(1)
+         elif response.status_code == 500 and error_count < 5:
+             time.sleep(1)
+             error_count += 1
+         else:
+             pbar.close()
+             raise Exception(f"API Error: {response.status_code}")
+
+ # Gradio interface
+ iface = gr.Interface(
+     fn=infer,
+     inputs=[
+         gr.Textbox(lines=1, placeholder="Color"),
+         gr.Textbox(lines=1, placeholder="Dress Type"),
+         gr.Textbox(lines=2, placeholder="Design"),
+         gr.Textbox(lines=1, placeholder="Text (Optional)")
+     ],
+     outputs="image",
+     title="AI-Generated T-Shirt Designs",
+     description="Generate custom t-shirt designs using AI!",
+     examples=[["Red", "T-shirt", "Minimalistic logo", "Brand Name"]]
+ )
+
+ print("Launching Gradio interface...")
+ iface.launch()
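
For reference, the Inference API fallback used in app.py can also be exercised on its own. The sketch below is an assumption-laden example, not part of this commit: it presumes an access token exported as HF_TOKEN (name chosen for illustration) and queries the same endpoint with a prompt of the same shape.

    import os
    import requests
    from io import BytesIO
    from PIL import Image

    # Same endpoint as in app.py; HF_TOKEN is an assumed environment variable
    api_url = "https://api-inference.huggingface.co/models/artificialguybr/TshirtDesignRedmond-V2"
    headers = {"Authorization": f"Bearer {os.environ['HF_TOKEN']}"}
    resp = requests.post(
        api_url,
        headers=headers,
        json={"inputs": "A single red colored t-shirt featuring a bold minimalistic logo design, hanging on a plain wall."},
    )
    resp.raise_for_status()
    Image.open(BytesIO(resp.content)).save("sample.png")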