import datetime
import logging
from typing import Dict, List, Optional

import gradio as gr
import torch
from diffusers import EulerAncestralDiscreteScheduler, StableDiffusionPipeline
from PIL import Image
from transformers import pipeline as hf_pipeline

logger = logging.getLogger(__name__)


class StableDiffusionAgent:
    """Text-to-image agent wrapping a Stable Diffusion pipeline.

    Keeps a bounded, in-process generation history (``self.memory``) and
    per-user preferences (``self.user_profiles``; currently only ``"style"``).
    """

    # Suffix appended to the prompt for each named style. Styles without an
    # entry here (e.g. "realistic") leave the prompt unchanged.
    STYLE_TEMPLATES = {
        "anime": "anime style, vibrant colors, detailed line art",
        "cyberpunk": "neon lights, cyberpunk style, rainy night, futuristic",
        "watercolor": "watercolor painting, soft edges, artistic",
        "pixel-art": "8-bit pixel art, retro gaming style",
    }

    def __init__(self, config: Optional[Dict] = None):
        """Create the agent and load its models.

        Args:
            config: Optional overrides merged over ``default_config``.
        """
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        self.default_config = {
            "model": "stabilityai/stable-diffusion-2-1",
            "safety_checker": True,
            "max_resolution": 1024,
            "art_styles": ["realistic", "anime", "cyberpunk", "watercolor", "pixel-art"],
            "default_style": "realistic",
            "memory_size": 10,
            "prompt_enhancer": True,
        }
        self.config = {**self.default_config, **(config or {})}
        self._initialize_models()
        self.memory: List[Dict] = []
        self.user_profiles: Dict[str, Dict] = {}
        self.current_style = self.config["default_style"]

    def _initialize_models(self):
        """Load the diffusion pipeline and, if enabled, the prompt enhancer."""
        pipeline_kwargs = {
            "torch_dtype": torch.float16 if self.device == "cuda" else torch.float32,
        }
        # Only override the safety checker when it is explicitly disabled; any
        # other value lets diffusers load the model's default component.
        if not self.config["safety_checker"]:
            pipeline_kwargs["safety_checker"] = None

        self.sd_pipeline = StableDiffusionPipeline.from_pretrained(
            self.config["model"], **pipeline_kwargs
        ).to(self.device)
        self.sd_pipeline.scheduler = EulerAncestralDiscreteScheduler.from_config(
            self.sd_pipeline.scheduler.config
        )

        if self.device == "cuda":
            # xformers is an optional dependency. Attention slicing replaces the
            # xformers attention processor, so use it only as the fallback.
            try:
                self.sd_pipeline.enable_xformers_memory_efficient_attention()
            except Exception:
                logger.warning("xformers unavailable; falling back to attention slicing")
                self.sd_pipeline.enable_attention_slicing()

        self.prompt_pipeline = None
        if self.config["prompt_enhancer"]:
            # Promptist is a GPT-2 causal LM that reuses the stock GPT-2
            # tokenizer, so it must run under the "text-generation" task.
            self.prompt_pipeline = hf_pipeline(
                "text-generation", model="microsoft/Promptist", tokenizer="gpt2"
            )

    def _enhance_prompt(self, prompt: str) -> str:
        """Rewrite *prompt* with the Promptist model.

        Returns the original prompt when enhancement is disabled, fails, or
        produces empty text.
        """
        enhancer = getattr(self, "prompt_pipeline", None)
        if not self.config["prompt_enhancer"] or enhancer is None:
            return prompt
        try:
            # Promptist expects "<prompt> Rephrase:"; return only the continuation.
            outputs = enhancer(
                prompt.strip() + " Rephrase:",
                max_new_tokens=75,
                return_full_text=False,
            )
        except Exception:
            logger.exception("Prompt enhancement failed; using the original prompt")
            return prompt
        enhanced = outputs[0]["generated_text"].strip()
        return enhanced or prompt

    def _apply_style(self, prompt: str, style: str) -> str:
        """Append the style's template to *prompt*; unknown styles add nothing."""
        template = self.STYLE_TEMPLATES.get(style)
        return f"{prompt}, {template}" if template else prompt

    def generate(
        self,
        user_id: str,
        prompt: str,
        negative_prompt: str = "",
        style: Optional[str] = None,
        **kwargs,
    ) -> Dict:
        """Generate images for *prompt* using the user's stored preferences.

        Args:
            user_id: Key into ``user_profiles`` and the history.
            prompt: Raw user prompt (stored in history un-enhanced).
            negative_prompt: Passed through to the diffusion pipeline.
            style: Explicit style; falls back to the user's saved style, then
                ``current_style``.
            **kwargs: Overrides for ``_get_default_params()`` plus an optional
                ``seed``.

        Returns:
            ``{"images": [...], "metadata": {...}}``.
        """
        user_prefs = self.user_profiles.get(user_id, {})

        enhanced_prompt = self._enhance_prompt(prompt)
        style = style or user_prefs.get("style", self.current_style)
        final_prompt = self._apply_style(enhanced_prompt, style)

        results = self._generate_image(
            prompt=final_prompt,
            negative_prompt=negative_prompt,
            **{**self._get_default_params(), **kwargs},
        )

        response = {
            "images": results["images"],
            "metadata": {
                "enhanced_prompt": enhanced_prompt,
                "final_prompt": final_prompt,
                "style": style,
                "seed": results["seed"],
                "timestamp": datetime.datetime.now().isoformat(),
            },
        }
        # The history stores the metadata, so pass the full response.
        self._update_memory(user_id, prompt, response)
        return response

    def _generate_image(self, **kwargs) -> Dict:
        """Run the diffusion pipeline and blank out any NSFW-flagged images.

        A ``seed`` kwarg (int-like or None) is consumed here. When it is None,
        a fresh non-deterministic seed is drawn and returned, so the image can
        be reproduced.
        """
        seed = kwargs.pop("seed", None)
        generator = torch.Generator(device=self.device)
        if seed is None:
            # A new Generator always starts from the same default seed, so
            # draw a non-deterministic one instead.
            seed = generator.seed()
        else:
            seed = int(seed)  # UI widgets may deliver floats
            generator.manual_seed(seed)

        output = self.sd_pipeline(**kwargs, generator=generator)

        # nsfw_content_detected is None when no safety checker is loaded.
        flags = output.nsfw_content_detected or [False] * len(output.images)
        safe_images = [
            self._create_black_image(*image.size) if flagged else image
            for image, flagged in zip(output.images, flags)
        ]
        return {"images": safe_images, "seed": seed}

    def _update_memory(self, user_id: str, prompt: str, results: Dict):
        """Append a history entry, keeping at most ``memory_size`` entries."""
        self.memory.append({
            "user_id": user_id,
            "prompt": prompt,
            "timestamp": datetime.datetime.now(),
            "metadata": results.get("metadata", {}),
        })
        overflow = len(self.memory) - self.config["memory_size"]
        if overflow > 0:
            del self.memory[:overflow]  # drop the oldest entries

    def _get_default_params(self):
        """Default keyword arguments for the diffusion pipeline call."""
        return {
            "height": 512,
            "width": 512,
            "num_images_per_prompt": 1,
            "num_inference_steps": 50,
            "guidance_scale": 7.5,
        }

    def _create_black_image(self, width: int, height: int) -> Image.Image:
        """Return a solid black RGB placeholder of the given size."""
        return Image.new("RGB", (width, height), (0, 0, 0))

    # ----------- User Interaction Methods -----------
    def set_style(self, user_id: str, style: str):
        """Save *style* as the user's preference if it is a configured style."""
        if style in self.config["art_styles"]:
            self.user_profiles.setdefault(user_id, {})["style"] = style
            return f"Style set to {style}"
        return f"Invalid style. Available styles: {', '.join(self.config['art_styles'])}"

    def get_history(self, user_id: str) -> List[Dict]:
        """Return this user's entries from the bounded generation history."""
        return [entry for entry in self.memory if entry["user_id"] == user_id]


# ------------------ Gradio Interface ------------------
def create_web_interface(agent: StableDiffusionAgent):
    """Build a Gradio Blocks UI that drives *agent*."""
    css = """
    .gradio-container {max-width: 900px!important}
    .output-image img {box-shadow: 0 4px 8px rgba(0,0,0,0.1)}
    """

    def _on_generate(uid, text, negative, chosen_style, n_steps, cfg, seed_value):
        """Adapt widget values to agent.generate and split the result for the outputs."""
        result = agent.generate(
            user_id=uid,
            prompt=text,
            negative_prompt=negative,
            style=chosen_style,
            num_inference_steps=int(n_steps),
            guidance_scale=float(cfg),
            seed=None if seed_value is None else int(seed_value),
        )
        # Two output components => return a 2-tuple (image, metadata).
        return result["images"][0], result["metadata"]

    with gr.Blocks(css=css) as interface:
        gr.Markdown("# 🎨 AI Art Generator Agent")

        with gr.Row():
            with gr.Column(scale=1):
                user_id = gr.Textbox(label="User ID", placeholder="Enter unique identifier")
                prompt = gr.Textbox(label="Prompt", lines=3)
                negative_prompt = gr.Textbox(label="Negative Prompt")
                style = gr.Dropdown(agent.config["art_styles"], label="Art Style")
                generate_btn = gr.Button("Generate", variant="primary")

            with gr.Column(scale=1):
                output_image = gr.Image(label="Generated Art", elem_classes=["output-image"])
                meta_info = gr.JSON(label="Generation Metadata")

        with gr.Accordion("Advanced Settings", open=False):
            with gr.Row():
                steps = gr.Slider(10, 100, value=50, step=1, label="Steps")
                guidance = gr.Slider(1.0, 20.0, value=7.5, label="Guidance Scale")
                seed = gr.Number(label="Seed (optional)")

        generate_btn.click(
            fn=_on_generate,
            inputs=[user_id, prompt, negative_prompt, style, steps, guidance, seed],
            outputs=[output_image, meta_info],
        )

    return interface


if __name__ == "__main__":
    # Initialize agent
    config = {
        "prompt_enhancer": True,
        "art_styles": ["realistic", "anime", "cyberpunk", "watercolor"],
    }
    agent = StableDiffusionAgent(config)

    # Launch Gradio interface (share=True also creates a public share link)
    interface = create_web_interface(agent)
    interface.launch(server_port=7860, share=True)