# Captured from a Hugging Face Space page whose status read: "Runtime error".
import base64
import json
import os
import random
import re
import tempfile
import uuid
from datetime import datetime

import anthropic
import gradio as gr
import numpy as np
import pandas as pd
import torch
from diffusers import StableDiffusionXLPipeline, EulerAncestralDiscreteScheduler
from PIL import Image
# ============================================================ | |
# === GLOBALS & DATA STORAGE FILES | |
# ============================================================ | |
# On-disk JSON caches for votes, chat/generation logs, and quotes.
LIKES_CACHE_FILE, LOG_CACHE_FILE, QUOTE_CACHE_FILE = (
    "likes_cache.json",
    "log_cache.json",
    "quotes_cache.json",
)
# Prefix prepended to saved image filenames to build shareable links.
# NOTE(review): Gradio serves files from the Space's *.hf.space host rather than
# huggingface.co/spaces/...; these links may not resolve — confirm the intended prefix.
STATIC_URL_PREFIX = (
    "https://huggingface.co/spaces/awacke1/dalle-3-xl-lora-v2"
    "/file="
)
# Initialize caches / load from JSON | |
def load_json(file, default=None):
    """Load and return the JSON content of *file*.

    Args:
        file: Path of the JSON cache file.
        default: Value returned when the file is missing or cannot be parsed.
            ``None`` (the default) means a fresh empty dict, as before.

    Returns:
        The decoded JSON value, or the fallback value.
    """
    fallback = {} if default is None else default
    if not os.path.exists(file):
        return fallback
    try:
        with open(file, 'r', encoding='utf-8') as f:
            return json.load(f)
    except ValueError as err:
        # json.JSONDecodeError and UnicodeDecodeError both subclass ValueError. A
        # corrupted or half-written cache should not prevent the app from starting.
        print(f"Warning: could not parse {file} ({err}); using default value.")
        return fallback
def save_json(file, data):
    """Write *data* to *file* as indented UTF-8 JSON, replacing the file atomically.

    The JSON is first written to a temporary file in the same directory, then moved
    over *file* with os.replace. A crash or serialization error mid-write therefore
    leaves the previous contents intact instead of a truncated file.

    Args:
        file: Destination path.
        data: JSON-serializable object.

    Raises:
        TypeError: if *data* is not JSON-serializable (the existing file is untouched).
    """
    directory = os.path.dirname(os.path.abspath(file))
    # Same directory as the target so os.replace is a same-filesystem rename.
    # Note: mkstemp creates the file with owner-only permissions.
    fd, tmp_path = tempfile.mkstemp(dir=directory, prefix=".tmp_", suffix=".json")
    try:
        with os.fdopen(fd, 'w', encoding='utf-8') as f:
            json.dump(data, f, indent=4)
        os.replace(tmp_path, file)
    except BaseException:
        # Remove the partial temp file, then let the original error propagate.
        try:
            os.remove(tmp_path)
        except OSError:
            pass
        raise
# Persisted caches. Each falls back to an empty container of the expected JSON type
# when its file is missing or holds something else, so later item assignment
# (likes_cache[...] = ...) and .append() calls cannot fail on a mismatched type.
likes_cache = load_json(LIKES_CACHE_FILE)
if not isinstance(likes_cache, dict):
    likes_cache = {}
chat_logs = load_json(LOG_CACHE_FILE)
if not isinstance(chat_logs, list):
    chat_logs = []
quotes = load_json(QUOTE_CACHE_FILE)
if not isinstance(quotes, list):
    quotes = []
# DataFrame for images (in memory only; it starts empty on every launch)
image_metadata = pd.DataFrame(columns=['Filename','Prompt','Likes','Dislikes','Hearts','Created'])
# ============================================================ | |
# === ANTHROPIC CLIENT (Claude) | |
# ============================================================ | |
# Claude client, created only when an API key is present in the environment.
anthropic_api_key = os.environ.get("ANTHROPIC_API_KEY")
claude_client = None
if anthropic_api_key:
    claude_client = anthropic.Anthropic(api_key=anthropic_api_key)
# ============================================================ | |
# === IMAGE PIPELINE | |
# ============================================================ | |
def _load_dalle_pipeline():
    """Build the Fluently-XL SDXL pipeline with the `dalle` LoRA adapter, moved to CUDA."""
    sdxl = StableDiffusionXLPipeline.from_pretrained(
        "fluently/Fluently-XL-v4",
        torch_dtype=torch.float16,
        use_safetensors=True,
    )
    sdxl.scheduler = EulerAncestralDiscreteScheduler.from_config(sdxl.scheduler.config)
    sdxl.load_lora_weights(
        "ehristoforu/dalle-3-xl-v2",
        weight_name="dalle-3-xl-lora-v2.safetensors",
        adapter_name="dalle",
    )
    sdxl.set_adapters("dalle")
    sdxl.to("cuda")
    return sdxl


# No GPU -> no pipeline; generate_image checks for None.
pipe = _load_dalle_pipeline() if torch.cuda.is_available() else None
MAX_SEED = np.iinfo(np.int32).max
# ============================================================ | |
# === HELPER FUNCTIONS | |
# ============================================================ | |
def randomize_seed_fn(seed: int, randomize_seed: bool):
    """Return *seed* as an int, or a fresh random seed in [0, MAX_SEED] when requested."""
    chosen = random.randint(0, MAX_SEED) if randomize_seed else seed
    return int(chosen)
def sanitize_prompt(prompt):
    """Turn a prompt into a short, filesystem- and URL-friendly filename fragment.

    The prompt is lowercased and every character except word characters, whitespace
    and '-' is dropped. Runs of whitespace (including newlines and tabs) are collapsed
    to a single '_', so the resulting filenames contain no spaces. The result is
    truncated to 50 characters.
    """
    cleaned = re.sub(r'[^\w\s-]', '', prompt.lower())
    # Raw whitespace in filenames breaks the unescaped links built from them.
    cleaned = re.sub(r'\s+', '_', cleaned.strip())
    return cleaned[:50]
def save_image_locally(img, prompt):
    """Save a generated image to the working directory and register it in the caches.

    The filename is ``<YYYYmmdd_HHMMSS>_<sanitized prompt>.png``. If that name is
    already taken on disk or in likes_cache, a numeric suffix (``_1``, ``_2``, ...) is
    appended. This stops two same-second generations of one prompt from overwriting
    each other and from producing duplicate metadata rows.

    Args:
        img: Image object exposing ``.save(path)`` (PIL images from the pipeline).
        prompt: Prompt text used to name the file and stored in the metadata.

    Returns:
        The filename that was written.
    """
    global image_metadata
    now = datetime.now()
    base = f"{now.strftime('%Y%m%d_%H%M%S')}_{sanitize_prompt(prompt)}"
    filename = f"{base}.png"
    counter = 1
    while os.path.exists(filename) or filename in likes_cache:
        filename = f"{base}_{counter}.png"
        counter += 1
    img.save(filename)
    # The name is guaranteed fresh, so start its vote counters at zero.
    likes_cache[filename] = {'likes': 0, 'dislikes': 0, 'hearts': 0}
    save_json(LIKES_CACHE_FILE, likes_cache)
    new_row = {
        'Filename': filename,
        'Prompt': prompt,
        'Likes': 0,
        'Dislikes': 0,
        'Hearts': 0,
        'Created': str(now)
    }
    image_metadata = pd.concat([image_metadata, pd.DataFrame([new_row])], ignore_index=True)
    return filename
def log_input_output(user_input, model_output, link=""):
    """Append one interaction record to chat_logs and persist the full log to disk."""
    # chat_logs is only mutated in place, so no `global` declaration is required.
    record = dict(
        timestamp=datetime.now().strftime("%Y-%m-%d %H:%M:%S"),
        input=user_input,
        output=model_output,
        file_link=link,
    )
    chat_logs.append(record)
    save_json(LOG_CACHE_FILE, chat_logs)
def generate_image(
    prompt, negative_prompt, use_negative_prompt, seed, width, height, guidance_scale, randomize_seed
):
    """Generate one image with the SDXL + `dalle` LoRA pipeline, save it, and log it.

    Args:
        prompt: Text prompt.
        negative_prompt: Negative prompt, used only when *use_negative_prompt* is true.
        use_negative_prompt: Whether to apply *negative_prompt*.
        seed: Seed to use (replaced by a random one when *randomize_seed* is true).
        width, height: Output size in pixels (cast to int for the pipeline).
        guidance_scale: Classifier-free guidance scale.
        randomize_seed: Whether to draw a fresh random seed.

    Returns:
        (filenames, seed, links, gallery_items, metadata_rows), matching the five UI outputs.

    Raises:
        gr.Error: if no CUDA pipeline was loaded at startup.
    """
    if pipe is None:
        # Show a user-facing error rather than passing a message string to the
        # Gallery output, which expects image paths.
        raise gr.Error("No GPU available, cannot generate images.")
    seed = randomize_seed_fn(seed, randomize_seed)
    if not use_negative_prompt:
        negative_prompt = ""
    # Seed the sampler so the returned seed actually reproduces this image.
    generator = torch.Generator(device="cuda").manual_seed(seed)
    images = pipe(
        prompt=prompt,
        negative_prompt=negative_prompt,
        width=int(width),
        height=int(height),
        guidance_scale=guidance_scale,
        num_inference_steps=20,
        num_images_per_prompt=1,
        cross_attention_kwargs={"scale": 0.65},
        generator=generator,
        output_type="pil",
    ).images
    filenames = [save_image_locally(img, prompt) for img in images]
    links = [f"{STATIC_URL_PREFIX}{f}" for f in filenames]
    # Log the generation
    log_input_output(user_input=prompt, model_output="(image generated)", link=", ".join(links))
    return filenames, seed, links, get_image_gallery(), image_metadata.values.tolist()
def get_image_gallery():
    """Return (filepath, caption) pairs for every tracked image that still exists on disk."""
    entries = []
    for _, rec in image_metadata.iterrows():
        path = rec["Filename"]
        if not os.path.exists(path):
            continue  # file was removed outside the app; hide it from the gallery
        caption = f"{path}\nPrompt: {rec['Prompt']}\n👍 {rec['Likes']} 👎 {rec['Dislikes']} ❤️ {rec['Hearts']}"
        entries.append((path, caption))
    return entries
def vote_image(filename, vote_type):
    """Record one vote for *filename* in likes_cache and in the metadata table.

    Args:
        filename: Image filename as tracked in likes_cache; falsy or unknown values are ignored.
        vote_type: One of 'likes', 'dislikes', 'hearts'. Its capitalized form names the
            matching image_metadata column.

    Returns:
        (gallery_items, metadata_rows) reflecting the updated counts.
    """
    if filename and filename in likes_cache:
        counts = likes_cache[filename]
        counts[vote_type] = counts.get(vote_type, 0) + 1
        save_json(LIKES_CACHE_FILE, likes_cache)
        column = vote_type.capitalize()
        mask = image_metadata['Filename'] == filename
        if mask.any():
            # `.at` only accepts a single scalar row label; passing an Index raises.
            # `.loc` with a boolean mask updates the matching row(s) in place.
            image_metadata.loc[mask, column] = image_metadata.loc[mask, column] + 1
    return get_image_gallery(), image_metadata.values.tolist()
def delete_image(filename):
    """Delete one generated image and drop it from likes_cache and image_metadata.

    Only filenames the app tracks (in likes_cache or image_metadata) are acted on, so
    this handler never removes arbitrary files. If the file is already missing from
    disk, its stale cache and metadata entries are still cleaned up.

    Args:
        filename: Image filename; falsy values are ignored.

    Returns:
        (gallery_items, metadata_rows) reflecting the state after deletion.
    """
    global image_metadata
    if not filename:
        return get_image_gallery(), image_metadata.values.tolist()
    in_metadata = bool((image_metadata['Filename'] == filename).any())
    if filename not in likes_cache and not in_metadata:
        return get_image_gallery(), image_metadata.values.tolist()
    try:
        os.remove(filename)
    except FileNotFoundError:
        pass  # already gone from disk; still clean the entries below
    if filename in likes_cache:
        del likes_cache[filename]
        save_json(LIKES_CACHE_FILE, likes_cache)
    image_metadata = image_metadata[image_metadata['Filename'] != filename].reset_index(drop=True)
    return get_image_gallery(), image_metadata.values.tolist()
def delete_all_images():
    """Delete every generated image file and reset image_metadata and likes_cache.

    Returns:
        (gallery_items, metadata_rows) for the now-empty gallery and table.
    """
    global image_metadata
    # image_metadata starts empty on every launch, while likes_cache is reloaded from
    # disk. Images from earlier sessions are therefore known only via likes_cache, so
    # delete the union of both sets.
    targets = set(image_metadata["Filename"].tolist()) | set(likes_cache)
    for f in targets:
        try:
            os.remove(f)
        except FileNotFoundError:
            pass  # already gone; nothing to delete
    image_metadata = pd.DataFrame(columns=['Filename','Prompt','Likes','Dislikes','Hearts','Created'])
    likes_cache.clear()  # mutated in place, so no `global` rebinding is needed
    save_json(LIKES_CACHE_FILE, likes_cache)
    return get_image_gallery(), image_metadata.values.tolist()
# === QUOTES Demo (Optional) === | |
def add_quote(q):
    """Store a non-blank quote and return table rows [index, text, likes, created] for all quotes."""
    if q.strip():
        entry = dict(
            text=q,
            likes=0,
            created=datetime.now().strftime("%Y-%m-%d %H:%M:%S"),
        )
        quotes.append(entry)
        save_json(QUOTE_CACHE_FILE, quotes)
    rows = []
    for position, item in enumerate(quotes):
        rows.append([position, item["text"], item["likes"], item["created"]])
    return rows
def like_quote(idx):
    """Increment the like count of the quote at position *idx* and persist the quotes.

    Args:
        idx: Quote index. gr.Number without a precision delivers floats (e.g. ``2.0``),
            or ``None`` when empty, so the value is normalized first. Missing,
            non-integral, or out-of-range values are ignored.

    Returns:
        Table rows [index, text, likes, created] for every quote.
    """
    try:
        position = float(idx)
    except (TypeError, ValueError):
        position = None
    if position is not None and position.is_integer() and 0 <= position < len(quotes):
        quotes[int(position)]["likes"] += 1
        save_json(QUOTE_CACHE_FILE, quotes)
    return [[i, itm["text"], itm["likes"], itm["created"]] for i, itm in enumerate(quotes)]
# === CLAUDE Chat === | |
def chat_claude(user_message, model=None):
    """Send *user_message* to Claude and return the reply text (logged on success).

    Args:
        user_message: User text; ``None``, empty, or whitespace-only input is rejected.
        model: Optional model id. Defaults to the CLAUDE_MODEL environment variable,
            else "claude-3-sonnet-20240229".

    Returns:
        The concatenated text blocks of Claude's reply, or a short status/error message.
    """
    if not claude_client:
        return "No Anthropic API key configured."
    if not user_message or not user_message.strip():
        return "Empty message."
    # NOTE(review): claude-3-sonnet-20240229 may be retired by Anthropic; CLAUDE_MODEL
    # allows switching models without a code change.
    model = model or os.environ.get("CLAUDE_MODEL", "claude-3-sonnet-20240229")
    try:
        resp = claude_client.messages.create(
            model=model,
            max_tokens=1000,
            messages=[{"role": "user", "content": user_message}],
        )
    except anthropic.APIError as err:
        # Report API/connection failures in the reply box instead of crashing the handler.
        return f"Claude API error: {err}"
    # Join every text block; the reply may be empty or contain non-text blocks.
    text = "".join(
        block.text for block in resp.content if getattr(block, "type", None) == "text"
    )
    log_input_output(user_input=user_message, model_output=text, link="")
    return text
# === Refresh gallery + DF | |
def refresh_gallery_and_df():
    """Return Gradio updates carrying the current gallery items and metadata rows."""
    gallery_update = gr.update(value=get_image_gallery())
    table_update = gr.update(value=image_metadata.values.tolist())
    return gallery_update, table_update
# ============================================================ | |
# === BUILD GRADIO UI | |
# ============================================================ | |
# Markdown shown at the top of the app.
DESCRIPTION = "\n".join([
    "# 🎨 ArtForge & Claude Chat",
    "Generate AI art, chat with Claude, log everything, and vote on images.",
    "",
])
# Sample prompts offered under the prompt box.
examples = list((
    "Futuristic cityscape in neon lighting",
    "Cute cat wearing a wizard hat",
    "Surreal landscape with floating islands",
))
# Build the Gradio UI: generation, Claude chat, logs, gallery voting, and quotes tabs.
with gr.Blocks(css=".gradio-container {max-width: 1024px !important}") as demo:
    gr.Markdown(DESCRIPTION)

    with gr.Tab("Generate Images"):
        with gr.Row():
            prompt = gr.Text(label="Prompt", max_lines=1)
            run_button = gr.Button("Run")
        result = gr.Gallery(label="Result", columns=1, preview=True)
        use_negative_prompt = gr.Checkbox(label="Use negative prompt", value=True)
        negative_prompt = gr.Text(
            label="Negative prompt",
            lines=3,
            value="(deformed, distorted:1.3), poorly drawn, bad anatomy",
        )
        seed = gr.Slider(label="Seed", minimum=0, maximum=MAX_SEED, step=1, value=0)
        randomize_seed = gr.Checkbox(label="Randomize seed", value=True)
        width = gr.Slider(label="Width", minimum=512, maximum=2048, step=64, value=1024)
        height = gr.Slider(label="Height", minimum=512, maximum=2048, step=64, value=1024)
        guidance_scale = gr.Slider(label="Guidance Scale", minimum=1, maximum=20, step=0.5, value=7)
        # Hidden sink for the list of links generate_image returns; a list value
        # fits a JSON component rather than an HTML one.
        result_links = gr.JSON(visible=False)
        gr.Examples(examples=examples, inputs=prompt)

    with gr.Tab("Chat with Claude"):
        claude_input = gr.Textbox(label="Your Message")
        claude_output = gr.Textbox(label="Claude's Reply", lines=4)
        send_claude = gr.Button("Send to Claude")
        send_claude.click(chat_claude, inputs=claude_input, outputs=claude_output)

    with gr.Tab("Logs & Management"):
        with gr.Accordion("All Logs", open=False):
            logs_data = gr.Dataframe(
                value=pd.DataFrame(chat_logs),
                label="Input/Output Logs",
                interactive=False,
                wrap=True
            )

    with gr.Tab("Gallery & Voting"):
        image_gallery = gr.Gallery(label="Generated Images", columns=4)
        metadata_df = gr.Dataframe(
            label="Image Metadata",
            headers=["Filename", "Prompt", "Likes", "Dislikes", "Hearts", "Created"],
            interactive=False
        )
        selected_image = gr.State()
        with gr.Row():
            like_button = gr.Button("👍 Like")
            dislike_button = gr.Button("👎 Dislike")
            heart_button = gr.Button("❤️ Heart")
            delete_image_button = gr.Button("🗑️ Delete Image")
            delete_all_button = gr.Button("🗑️ Delete All")

        def _select_gallery_image(evt: gr.SelectData):
            """Resolve the clicked gallery position (evt.index) to its tracked filename."""
            # Gradio injects event data only into a parameter annotated with
            # gr.SelectData. The old `lambda evt: evt` with no inputs got no argument.
            items = get_image_gallery()
            index = evt.index
            if isinstance(index, int) and 0 <= index < len(items):
                return items[index][0]
            return None

        image_gallery.select(fn=_select_gallery_image, inputs=None, outputs=[selected_image])
        like_button.click(fn=lambda x: vote_image(x, 'likes'), inputs=selected_image, outputs=[image_gallery, metadata_df])
        dislike_button.click(fn=lambda x: vote_image(x, 'dislikes'), inputs=selected_image, outputs=[image_gallery, metadata_df])
        heart_button.click(fn=lambda x: vote_image(x, 'hearts'), inputs=selected_image, outputs=[image_gallery, metadata_df])
        delete_image_button.click(fn=delete_image, inputs=selected_image, outputs=[image_gallery, metadata_df])
        delete_all_button.click(fn=delete_all_images, outputs=[image_gallery, metadata_df])

    with gr.Tab("Quotes (Optional)"):
        quote_input = gr.Textbox(label="Enter a quote")
        add_q_button = gr.Button("Add Quote")
        quote_df = gr.Dataframe(value=[(idx, q['text'], q['likes'], q['created']) for idx, q in enumerate(quotes)],
                                headers=["Index", "Text", "Likes", "Created"], interactive=False)
        # precision=0 makes the component deliver an int index instead of a float.
        selected_quote = gr.Number(label="Index to Like", precision=0)
        like_q_button = gr.Button("Like Quote")
        add_q_button.click(fn=add_quote, inputs=quote_input, outputs=quote_df)
        like_q_button.click(fn=like_quote, inputs=selected_quote, outputs=quote_df)

    # Registered once every tab exists, so generation refreshes the real gallery and
    # table in "Gallery & Voting". Previously it created stray unlabeled components.
    run_button.click(
        fn=generate_image,
        inputs=[prompt, negative_prompt, use_negative_prompt, seed, width, height, guidance_scale, randomize_seed],
        outputs=[result, seed, result_links, image_gallery, metadata_df],
        api_name="run",
    )
    demo.load(fn=refresh_gallery_and_df, outputs=[image_gallery, metadata_df])
if __name__ == "__main__":
    # Hold at most 20 pending events in the queue, then start the server.
    queued_app = demo.queue(max_size=20)
    queued_app.launch()