import datetime
from typing import List, Dict, Optional

import torch
from diffusers import StableDiffusionPipeline, EulerAncestralDiscreteScheduler
from PIL import Image
import gradio as gr
from transformers import pipeline as hf_pipeline
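# Assumed dependencies (versions are not pinned in this file): torch, diffusers,
# transformers, gradio, and Pillow; xformers is optional and only used on CUDA.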

class StableDiffusionAgent:
    def __init__(self, config: Optional[Dict] = None):
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        self.default_config = {
            "model": "stabilityai/stable-diffusion-2-1",
            "safety_checker": True,
            "max_resolution": 1024,
            "art_styles": ["realistic", "anime", "cyberpunk", "watercolor", "pixel-art"],
            "default_style": "realistic",
            "memory_size": 10,
            "prompt_enhancer": True
        }
        self.config = {**self.default_config, **(config or {})}
        self._initialize_models()
        self.memory = []
        self.user_profiles = {}
        self.current_style = self.config["default_style"]
    def _initialize_models(self):
        """Load all required models."""
        # Text-to-image pipeline; drop the safety checker only when explicitly disabled.
        pipeline_kwargs = {}
        if not self.config["safety_checker"]:
            pipeline_kwargs["safety_checker"] = None
        self.sd_pipeline = StableDiffusionPipeline.from_pretrained(
            self.config["model"],
            torch_dtype=torch.float16 if self.device == "cuda" else torch.float32,
            **pipeline_kwargs
        ).to(self.device)
        self.sd_pipeline.scheduler = EulerAncestralDiscreteScheduler.from_config(
            self.sd_pipeline.scheduler.config
        )
        if self.device == "cuda":
            try:
                self.sd_pipeline.enable_xformers_memory_efficient_attention()
            except Exception:
                pass  # xformers is optional; attention slicing below still applies
            self.sd_pipeline.enable_attention_slicing()

        # Prompt enhancement model
        if self.config["prompt_enhancer"]:
            self.prompt_pipeline = hf_pipeline(
                "text2text-generation",
                model="microsoft/Promptist"
            )
    def _enhance_prompt(self, prompt: str) -> str:
        """Improve the prompt using the enhancement model, falling back to the original."""
        if self.config["prompt_enhancer"]:
            try:
                return self.prompt_pipeline(prompt, max_length=256)[0]["generated_text"]
            except Exception:
                return prompt
        return prompt

    def _apply_style(self, prompt: str, style: str) -> str:
        """Append style keywords to the prompt; unknown styles leave it unchanged."""
        style_templates = {
            "anime": "anime style, vibrant colors, detailed line art",
            "cyberpunk": "neon lights, cyberpunk style, rainy night, futuristic",
            "watercolor": "watercolor painting, soft edges, artistic",
            "pixel-art": "8-bit pixel art, retro gaming style"
        }
        suffix = style_templates.get(style)
        return f"{prompt}, {suffix}" if suffix else prompt
    def generate(
        self,
        user_id: str,
        prompt: str,
        negative_prompt: str = "",
        style: Optional[str] = None,
        **kwargs
    ) -> Dict:
        """Main generation method with user context."""
        # Get user preferences
        user_prefs = self.user_profiles.get(user_id, {})

        # Enhance prompt
        enhanced_prompt = self._enhance_prompt(prompt)

        # Apply style
        style = style or user_prefs.get("style", self.current_style)
        final_prompt = self._apply_style(enhanced_prompt, style)

        # Generate image
        results = self._generate_image(
            prompt=final_prompt,
            negative_prompt=negative_prompt,
            **{**self._get_default_params(), **kwargs}
        )

        metadata = {
            "enhanced_prompt": enhanced_prompt,
            "style": style,
            "seed": results["seed"],
            "timestamp": datetime.datetime.now().isoformat()
        }

        # Update memory and user profile
        self._update_memory(user_id, prompt, metadata)

        return {"images": results["images"], "metadata": metadata}
    def _generate_image(self, **kwargs) -> Dict:
        """Low-level generation with safety checks."""
        generator = torch.Generator(device=self.device)
        seed = kwargs.pop("seed", None)
        if seed is not None:
            generator = generator.manual_seed(int(seed))
        else:
            generator.seed()  # draw a fresh random seed so it can be reported back
        results = self.sd_pipeline(**kwargs, generator=generator)

        # Replace any flagged NSFW outputs with black placeholder images
        nsfw_flags = getattr(results, "nsfw_content_detected", None) or []
        safe_images = []
        for i, img in enumerate(results.images):
            if i < len(nsfw_flags) and nsfw_flags[i]:
                safe_images.append(self._create_black_image(kwargs["width"], kwargs["height"]))
            else:
                safe_images.append(img)

        return {
            "images": safe_images,
            "seed": seed if seed is not None else generator.initial_seed()
        }
    def _update_memory(self, user_id: str, prompt: str, metadata: Dict):
        """Store generation history, trimming to the configured memory size."""
        self.memory.append({
            "user_id": user_id,
            "prompt": prompt,
            "timestamp": datetime.datetime.now(),
            "metadata": metadata
        })
        if len(self.memory) > self.config["memory_size"]:
            self.memory.pop(0)
    def _get_default_params(self):
        return {
            "height": 512,
            "width": 512,
            "num_images_per_prompt": 1,
            "num_inference_steps": 50,
            "guidance_scale": 7.5
        }

    def _create_black_image(self, width: int, height: int) -> Image.Image:
        return Image.new("RGB", (width, height), (0, 0, 0))
    # ----------- User Interaction Methods -----------
    def set_style(self, user_id: str, style: str):
        if style in self.config["art_styles"]:
            self.user_profiles.setdefault(user_id, {})["style"] = style
            return f"Style set to {style}"
        return f"Invalid style. Available styles: {', '.join(self.config['art_styles'])}"

    def get_history(self, user_id: str) -> List[Dict]:
        return [entry for entry in self.memory if entry["user_id"] == user_id]
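
# A minimal sketch of programmatic use without the web UI. This helper is not part
# of the original file; the user id, prompt, and output path are illustrative
# placeholders, and it assumes enough GPU/CPU memory for Stable Diffusion 2.1.
def _example_cli_usage():
    agent = StableDiffusionAgent({"prompt_enhancer": False})
    result = agent.generate(
        user_id="demo-user",
        prompt="a lighthouse at dawn, detailed, golden hour",
        style="watercolor",
        num_inference_steps=30,
        seed=42,
    )
    # Save the first generated image and inspect the generation metadata
    result["images"][0].save("lighthouse.png")
    print(result["metadata"])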

# ------------------ Gradio Interface ------------------
def create_web_interface(agent: StableDiffusionAgent):
    css = """
    .gradio-container {max-width: 900px!important}
    .output-image img {box-shadow: 0 4px 8px rgba(0,0,0,0.1)}
    """

    def run_generation(user_id, prompt, negative_prompt, style, steps, guidance, seed):
        """Adapt the agent's dict result to the (image, metadata) outputs Gradio expects."""
        result = agent.generate(
            user_id=user_id,
            prompt=prompt,
            negative_prompt=negative_prompt,
            style=style,
            num_inference_steps=int(steps),
            guidance_scale=float(guidance),
            seed=int(seed) if seed is not None else None
        )
        return result["images"][0], result["metadata"]

    with gr.Blocks(css=css) as interface:
        gr.Markdown("# 🎨 AI Art Generator Agent")
        with gr.Row():
            with gr.Column(scale=1):
                user_id = gr.Textbox(label="User ID", placeholder="Enter unique identifier")
                prompt = gr.Textbox(label="Prompt", lines=3)
                negative_prompt = gr.Textbox(label="Negative Prompt")
                style = gr.Dropdown(agent.config["art_styles"], label="Art Style")
                generate_btn = gr.Button("Generate", variant="primary")
            with gr.Column(scale=1):
                output_image = gr.Image(label="Generated Art", elem_classes=["output-image"])
                meta_info = gr.JSON(label="Generation Metadata")
        with gr.Accordion("Advanced Settings", open=False):
            with gr.Row():
                steps = gr.Slider(10, 100, value=50, label="Steps")
                guidance = gr.Slider(1.0, 20.0, value=7.5, label="Guidance Scale")
                seed = gr.Number(label="Seed (optional)")

        generate_btn.click(
            fn=run_generation,
            inputs=[user_id, prompt, negative_prompt, style, steps, guidance, seed],
            outputs=[output_image, meta_info]
        )
    return interface

if __name__ == "__main__":
    # Initialize agent
    config = {
        "prompt_enhancer": True,
        "art_styles": ["realistic", "anime", "cyberpunk", "watercolor"]
    }
    agent = StableDiffusionAgent(config)

    # Launch Gradio interface
    interface = create_web_interface(agent)
    interface.launch(server_port=7860, share=True)