gaur3009 commited on
Commit
43ce2e9
·
verified ·
1 Parent(s): 75433b8

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +210 -0
app.py ADDED
@@ -0,0 +1,210 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import datetime
3
+ from typing import List, Dict, Optional
4
+ from diffusers import StableDiffusionPipeline, EulerAncestralDiscreteScheduler
5
+ from PIL import Image
6
+ import gradio as gr
7
+ from transformers import pipeline as hf_pipeline
8
+
9
class StableDiffusionAgent:
    """Stateful Stable Diffusion agent.

    Wraps a diffusers text-to-image pipeline with per-user style
    preferences, optional LLM prompt enhancement, NSFW image masking,
    and a bounded in-memory generation history.
    """

    def __init__(self, config: Optional[Dict] = None):
        """Create the agent and load all models.

        Args:
            config: Optional overrides merged over ``default_config``
                (model id, safety-checker toggle, style list, history
                size, prompt-enhancer toggle).
        """
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        self.default_config = {
            "model": "stabilityai/stable-diffusion-2-1",
            "safety_checker": True,
            "max_resolution": 1024,
            "art_styles": ["realistic", "anime", "cyberpunk", "watercolor", "pixel-art"],
            "default_style": "realistic",
            "memory_size": 10,
            "prompt_enhancer": True,
        }
        # Caller-supplied keys win over the defaults.
        self.config = {**self.default_config, **(config or {})}

        self._initialize_models()
        self.memory = []          # bounded FIFO of recent generations (all users)
        self.user_profiles = {}   # user_id -> per-user preferences (e.g. "style")
        self.current_style = self.config["default_style"]

    def _initialize_models(self):
        """Load the diffusion pipeline and (optionally) the prompt enhancer."""
        # BUGFIX: the original passed safety_checker=None in BOTH branches of
        # its conditional, silently disabling the checker regardless of the
        # config flag. Only override it when explicitly disabled; otherwise
        # keep the pipeline's built-in checker.
        pipeline_kwargs = {}
        if not self.config["safety_checker"]:
            pipeline_kwargs["safety_checker"] = None

        self.sd_pipeline = StableDiffusionPipeline.from_pretrained(
            self.config["model"],
            # fp16 halves VRAM on GPU; CPU inference requires fp32.
            torch_dtype=torch.float16 if self.device == "cuda" else torch.float32,
            **pipeline_kwargs,
        ).to(self.device)

        self.sd_pipeline.scheduler = EulerAncestralDiscreteScheduler.from_config(
            self.sd_pipeline.scheduler.config
        )

        if self.device == "cuda":
            # xformers is an optional dependency; don't hard-fail without it.
            try:
                self.sd_pipeline.enable_xformers_memory_efficient_attention()
            except Exception:
                pass
            self.sd_pipeline.enable_attention_slicing()

        # Prompt Enhancement Model (optional).
        if self.config["prompt_enhancer"]:
            self.prompt_pipeline = hf_pipeline(
                "text2text-generation",
                model="microsoft/Promptist",
            )
        else:
            self.prompt_pipeline = None

    def _enhance_prompt(self, prompt: str) -> str:
        """Return an LLM-enhanced prompt, or the original on any failure."""
        if self.config["prompt_enhancer"] and self.prompt_pipeline is not None:
            try:
                return self.prompt_pipeline(prompt, max_length=256)[0]["generated_text"]
            except Exception:
                # Enhancement is best-effort; never let it block generation.
                return prompt
        return prompt

    def _apply_style(self, prompt: str, style: str) -> str:
        """Append the artistic-style template to the prompt.

        Styles without a template (e.g. "realistic") return the prompt
        unchanged — the original appended a dangling ", " suffix.
        """
        style_templates = {
            "anime": "anime style, vibrant colors, detailed line art",
            "cyberpunk": "neon lights, cyberpunk style, rainy night, futuristic",
            "watercolor": "watercolor painting, soft edges, artistic",
            "pixel-art": "8-bit pixel art, retro gaming style",
        }
        template = style_templates.get(style)
        return f"{prompt}, {template}" if template else prompt

    def generate(
        self,
        user_id: str,
        prompt: str,
        negative_prompt: str = "",
        style: Optional[str] = None,
        **kwargs,
    ) -> Dict:
        """Generate image(s) for a user, applying their preferences.

        Args:
            user_id: Identifier used for per-user style lookup and history.
            prompt: Raw text prompt (enhanced and styled before generation).
            negative_prompt: Concepts to steer away from.
            style: Explicit style; falls back to the user's saved style,
                then the agent default.
            **kwargs: Pipeline overrides (height, width, seed,
                num_inference_steps, guidance_scale, ...).

        Returns:
            Dict with "images" (list of PIL images) and "metadata"
            (enhanced prompt, style, seed, ISO timestamp).
        """
        user_prefs = self.user_profiles.get(user_id, {})

        enhanced_prompt = self._enhance_prompt(prompt)

        style = style or user_prefs.get("style", self.current_style)
        final_prompt = self._apply_style(enhanced_prompt, style)

        results = self._generate_image(
            prompt=final_prompt,
            negative_prompt=negative_prompt,
            **{**self._get_default_params(), **kwargs},
        )

        response = {
            "images": results["images"],
            "metadata": {
                "enhanced_prompt": enhanced_prompt,
                "style": style,
                "seed": results["seed"],
                "timestamp": datetime.datetime.now().isoformat(),
            },
        }
        # BUGFIX: record AFTER building the response — _update_memory reads
        # results["metadata"], which the raw _generate_image dict (only
        # "images"/"seed") does not contain; the original raised KeyError here.
        self._update_memory(user_id, prompt, response)
        return response

    def _generate_image(self, **kwargs) -> Dict:
        """Run the pipeline once and mask any NSFW-flagged outputs.

        Returns:
            Dict with "images" (NSFW frames replaced by black images)
            and "seed" (the effective RNG seed used).
        """
        generator = torch.Generator(device=self.device)
        seed = kwargs.pop("seed", None)
        if seed is not None:
            generator = generator.manual_seed(seed)

        results = self.sd_pipeline(**kwargs, generator=generator)

        # nsfw_content_detected is None when the safety checker is disabled.
        nsfw_flags = getattr(results, "nsfw_content_detected", None)
        safe_images = []
        for i, img in enumerate(results.images):
            if nsfw_flags and nsfw_flags[i]:
                safe_images.append(self._create_black_image(kwargs["width"], kwargs["height"]))
            else:
                safe_images.append(img)

        return {
            "images": safe_images,
            # `seed or ...` would misreport an explicit seed of 0.
            "seed": seed if seed is not None else generator.initial_seed(),
        }

    def _update_memory(self, user_id: str, prompt: str, results: Dict):
        """Append one generation record, evicting the oldest past capacity."""
        self.memory.append({
            "user_id": user_id,
            "prompt": prompt,
            "timestamp": datetime.datetime.now(),
            "metadata": results["metadata"],
        })
        if len(self.memory) > self.config["memory_size"]:
            self.memory.pop(0)

    def _get_default_params(self) -> Dict:
        """Baseline pipeline parameters; callers may override via kwargs."""
        return {
            "height": 512,
            "width": 512,
            "num_images_per_prompt": 1,
            "num_inference_steps": 50,
            "guidance_scale": 7.5,
        }

    def _create_black_image(self, width: int, height: int) -> "Image.Image":
        """Solid-black placeholder used to mask NSFW-flagged outputs."""
        return Image.new("RGB", (width, height), (0, 0, 0))

    # ----------- User Interaction Methods -----------
    def set_style(self, user_id: str, style: str) -> str:
        """Persist a preferred style for a user; reject unknown styles."""
        if style in self.config["art_styles"]:
            self.user_profiles.setdefault(user_id, {})["style"] = style
            return f"Style set to {style}"
        return f"Invalid style. Available styles: {', '.join(self.config['art_styles'])}"

    def get_history(self, user_id: str) -> List[Dict]:
        """Return this user's entries from the shared generation history."""
        return [entry for entry in self.memory if entry["user_id"] == user_id]
163
+
164
+ # ------------------ Gradio Interface ------------------
165
def create_web_interface(agent: StableDiffusionAgent):
    """Build the Gradio Blocks UI for the agent.

    Args:
        agent: A fully initialized StableDiffusionAgent.

    Returns:
        The (unlaunched) gr.Blocks interface.
    """
    css = """
    .gradio-container {max-width: 900px!important}
    .output-image img {box-shadow: 0 4px 8px rgba(0,0,0,0.1)}
    """

    def _on_generate(user_id, prompt, negative_prompt, style, steps, guidance, seed):
        """Adapt UI inputs to agent.generate and its dict result to the outputs.

        BUGFIX: the original wired `lambda *args: agent.generate(*args)`,
        which passed steps/guidance/seed as extra POSITIONAL args (generate
        only accepts 4 positionally -> TypeError) and returned a dict where
        Gradio expected an (image, metadata) pair.
        """
        kwargs = {
            "num_inference_steps": int(steps),
            "guidance_scale": float(guidance),
        }
        # gr.Number yields None when left blank; only forward a real seed.
        if seed is not None:
            kwargs["seed"] = int(seed)
        result = agent.generate(user_id, prompt, negative_prompt, style=style, **kwargs)
        return result["images"][0], result["metadata"]

    with gr.Blocks(css=css) as interface:
        gr.Markdown("# 🎨 AI Art Generator Agent")

        with gr.Row():
            with gr.Column(scale=1):
                user_id = gr.Textbox(label="User ID", placeholder="Enter unique identifier")
                prompt = gr.Textbox(label="Prompt", lines=3)
                negative_prompt = gr.Textbox(label="Negative Prompt")
                style = gr.Dropdown(agent.config["art_styles"], label="Art Style")
                generate_btn = gr.Button("Generate", variant="primary")

            with gr.Column(scale=1):
                output_image = gr.Image(label="Generated Art", elem_classes=["output-image"])
                meta_info = gr.JSON(label="Generation Metadata")

        with gr.Accordion("Advanced Settings", open=False):
            with gr.Row():
                steps = gr.Slider(10, 100, value=50, label="Steps")
                guidance = gr.Slider(1.0, 20.0, value=7.5, label="Guidance Scale")
                seed = gr.Number(label="Seed (optional)")

        generate_btn.click(
            fn=_on_generate,
            inputs=[user_id, prompt, negative_prompt, style, steps, guidance, seed],
            outputs=[output_image, meta_info],
        )

    return interface
199
+
200
+ if __name__ == "__main__":
201
+ # Initialize agent
202
+ config = {
203
+ "prompt_enhancer": True,
204
+ "art_styles": ["realistic", "anime", "cyberpunk", "watercolor"]
205
+ }
206
+ agent = StableDiffusionAgent(config)
207
+
208
+ # Launch Gradio interface
209
+ interface = create_web_interface(agent)
210
+ interface.launch(server_port=7860, share=True)