import torch
from diffusers import StableDiffusionXLPipeline
import gradio as gr

# Load the model (RealVisXL V4.0 is an SDXL checkpoint, so use the SDXL pipeline)
model_id = "SG161222/RealVisXL_V4.0"
device = "cuda" if torch.cuda.is_available() else "cpu"
torch_dtype = torch.float16 if device == "cuda" else torch.float32  # float16 is only practical on GPU
pipe = StableDiffusionXLPipeline.from_pretrained(model_id, torch_dtype=torch_dtype)
pipe.to(device)
# Define placeholder functions and variables
DEFAULT_STYLE_NAME = "default"
default_negative = ""
NUM_IMAGES_PER_PROMPT = 1

def check_text(prompt, negative_prompt):
    # Implement your text check logic here
    return False

def apply_style(style, prompt, negative_prompt):
    # Implement your style application logic here
    return prompt, negative_prompt

def randomize_seed_fn(seed, randomize_seed):
    # Implement your seed randomization logic here
    return seed

def save_image(image):
    # Implement your image saving logic here
    return image
def generate_image(prompt, negative_prompt="", use_negative_prompt=False, style=DEFAULT_STYLE_NAME,
                   seed=0, width=1024, height=1024, guidance_scale=3, randomize_seed=False,
                   use_resolution_binning=True,  # accepted for compatibility; not used by StableDiffusionXLPipeline
                   progress=gr.Progress(track_tqdm=True)):
    if check_text(prompt, negative_prompt):
        raise ValueError("Prompt contains restricted words.")
    prompt, negative_prompt = apply_style(style, prompt, negative_prompt)
    seed = int(randomize_seed_fn(seed, randomize_seed))
    generator = torch.Generator().manual_seed(seed)
    if not use_negative_prompt:
        negative_prompt = ""
    negative_prompt += default_negative
    options = {
        "prompt": prompt,
        "negative_prompt": negative_prompt,
        "width": width,
        "height": height,
        "guidance_scale": guidance_scale,
        "num_inference_steps": 25,
        "generator": generator,
        "num_images_per_prompt": NUM_IMAGES_PER_PROMPT,
        "output_type": "pil",
    }
    images = pipe(**options).images
    image_paths = [save_image(img) for img in images]
    return image_paths, seed
def chatbot(prompt):
    # Generate the image based on the user's input; generate_image returns (images, seed)
    image_paths, _ = generate_image(prompt)
    return image_paths[0]
# Create the Gradio interface
interface = gr.Interface(
    fn=chatbot,
    inputs="text",
    outputs="image",
    title="RealVisXL V4.0 Text-to-Image Chatbot",
    description="Enter a text prompt and get an AI-generated image.",
)

# Launch the interface
interface.launch()
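
The helper functions above are intentionally left as stubs. A minimal sketch of how they could be filled in is shown below; the MAX_SEED bound, the uuid-based file naming, and the STYLES presets are illustrative assumptions rather than part of the original code. If save_image returns a file path, Gradio's image output can serve the saved file directly.

import random
import uuid

MAX_SEED = 2**32 - 1  # assumed upper bound for manual seeds

def randomize_seed_fn(seed, randomize_seed):
    # Pick a fresh random seed when requested; otherwise keep the one given.
    if randomize_seed:
        seed = random.randint(0, MAX_SEED)
    return seed

def save_image(image):
    # Write the PIL image to disk and return its path so Gradio can serve the file.
    path = f"{uuid.uuid4()}.png"
    image.save(path)
    return path

# Hypothetical style presets; names and templates are examples only.
STYLES = {
    "default": ("{prompt}", ""),
    "cinematic": ("cinematic still, {prompt}, shallow depth of field", "cartoon, drawing, sketch"),
}

def apply_style(style, prompt, negative_prompt):
    # Expand the prompt with the chosen style template and extend the negative prompt.
    positive_template, extra_negative = STYLES.get(style, STYLES["default"])
    prompt = positive_template.format(prompt=prompt)
    negative_prompt = (negative_prompt + " " + extra_negative).strip()
    return prompt, negative_prompt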