# DarkMuse / app.py
# Author: Nevaehni — commit ba5e9de ("testing")
# Size: 11.9 kB
import asyncio
import logging
import os
import re
import sys
import datetime # For scheduling
from huggingface_hub import HfApi # For Hugging Face API
import discord
import requests
from aiohttp import web # Added missing import
from discord import Embed
from discord.ext import commands
from gradio_client import Client
from gradio_client.exceptions import AppError
# **Fetch Discord Bot Token and Hugging Face Token from Environment Variables**
DISCORD_BOT_TOKEN = os.environ.get('DISCORD_BOT_TOKEN')
HF_TOKEN = os.environ.get('HF_TOKEN')  # Passed to the Gradio client and the Hub restart call

# Both tokens are required. Report every missing one (not just the first) on stderr,
# then exit with status 1 before anything tries to connect.
_missing_env_vars = [
    name
    for name, value in (('DISCORD_BOT_TOKEN', DISCORD_BOT_TOKEN), ('HF_TOKEN', HF_TOKEN))
    if not value
]
if _missing_env_vars:
    for _name in _missing_env_vars:
        print(f"Error: The environment variable '{_name}' is not set.", file=sys.stderr)
    sys.exit(1)
# Root logging setup: a single stdout stream handler using the format below.
# basicConfig builds that StreamHandler itself when given `stream=`.
# Switch level to logging.DEBUG for more detailed logs.
_LOG_FORMAT = '%(asctime)s - %(name)s - %(levelname)s - %(message)s'
logging.basicConfig(stream=sys.stdout, format=_LOG_FORMAT, level=logging.INFO)

# Module-level logger shared by the handlers below.
logger = logging.getLogger(__name__)
# Gateway intents: the library defaults plus message content, which prefix commands
# need in order to read what users typed.
intents = discord.Intents.default()
intents.message_content = True

# Command bot answering to the '!' prefix; the built-in help command is disabled.
bot = commands.Bot(
    intents=intents,
    help_command=None,
    command_prefix='!',
)
# Matches prompt="..." (whitespace allowed around '='). DOTALL lets '.' cross newlines,
# so multi-line prompts are captured; the lazy '.*?' stops at the first closing '"'.
_PROMPT_PATTERN = r'prompt\s*=\s*"(.*?)"'
PROMPT_REGEX = re.compile(_PROMPT_PATTERN, flags=re.DOTALL)
# Hugging Face Space that serves the image-generation endpoint.
_GENERATION_SPACE = "Nevaehni/FLUX.1-schnell"

# Gradio client for that Space, authenticated with HF_TOKEN.
GRADIO_CLIENT = Client(_GENERATION_SPACE, hf_token=HF_TOKEN)

# Hugging Face Hub API client; the token is supplied per call where it is used.
hf_api = HfApi()
@bot.event
async def on_ready():
    """Log the connected account's name and ID once discord.py reports readiness."""
    account = bot.user
    logger.info(f'Logged in as {account} (ID: {account.id})')
    logger.info('------')
def parse_prompt(command: str) -> str:
    """
    Parse the prompt from the command string.

    Looks for ``prompt="..."`` using PROMPT_REGEX. If nothing matches, the search is
    retried after converting curly double quotes (“ ”) to straight ones. Some
    keyboards with smart punctuation type curly quotes, which the pattern does not
    accept. The fallback runs only when the plain search fails, so any input that
    parsed before yields exactly the same result.

    Args:
        command (str): The command message content.

    Returns:
        str: The extracted prompt with surrounding whitespace stripped, or an empty
        string if not found.
    """
    match = PROMPT_REGEX.search(command)
    if match is None:
        normalized = command.replace('\u201c', '"').replace('\u201d', '"')
        if normalized != command:
            match = PROMPT_REGEX.search(normalized)
    return match.group(1).strip() if match else ''
def create_example_embed() -> Embed:
    """
    Build a blue embed that shows a sample ``!generate`` invocation in a code block.

    Returns:
        Embed: The Discord embed object.
    """
    # The '\n' below is a real line break: it shows that prompts may span lines.
    sample = '!generate prompt="High resolution serene landscape\nwith text \'cucolina\'. seed:1"'
    fenced = "```\n" + sample + "\n```"
    return Embed(description=fenced, color=discord.Color.blue())
# Lower-case URL suffixes that are shown inline as an embed image rather than as a link.
_IMAGE_EXTENSIONS = ('.png', '.jpg', '.jpeg', '.gif', '.bmp', '.webp')

# Maximum length Discord accepts for a single embed field value.
_EMBED_FIELD_VALUE_LIMIT = 1024


def _prompt_field_value(command_used: str) -> str:
    """Wrap the invoking command in a code block, truncated to fit one embed field."""
    opening, closing = "```\n", "\n```"
    budget = _EMBED_FIELD_VALUE_LIMIT - len(opening) - len(closing)
    if len(command_used) > budget:
        # An over-long field makes Discord reject the whole message; keep the head of
        # the command and mark the cut with an ellipsis.
        command_used = command_used[:budget - 1] + "…"
    return f"{opening}{command_used}{closing}"


def _build_url_embed(url: str, command_used: str) -> Embed:
    """Build the reply embed for a URL result: an image for image URLs, else a link."""
    if url.lower().endswith(_IMAGE_EXTENSIONS):
        embed = Embed(title="🎨 Generated Image", color=discord.Color.green())
        embed.set_image(url=url)
    else:
        embed = Embed(title="🔗 Generated Content", description=url, color=discord.Color.green())
    embed.add_field(name="Prompt", value=_prompt_field_value(command_used), inline=False)
    return embed


async def _send_local_file(ctx: commands.Context, file_path: str, command_used: str) -> None:
    """Upload a local output file and show it in an image embed, or report why not."""
    if not os.path.isfile(file_path):
        await ctx.send("❌ **Error:** The API returned an invalid file path.")
        return
    try:
        file_name = os.path.basename(file_path)
        discord_file = discord.File(file_path, filename=file_name)
        embed = Embed(title="🎨 Generated Image", color=discord.Color.green())
        # attachment://<name> points the embed image at the file uploaded with this message.
        embed.set_image(url=f"attachment://{file_name}")
        embed.add_field(name="Prompt", value=_prompt_field_value(command_used), inline=False)
        await ctx.send(content=f"{ctx.author.mention}", embed=embed, file=discord_file)
        logger.info(f"Sent image from {file_path} to Discord.")
    except Exception as e:
        logger.error(f"Failed to send image to Discord: {e}")
        await ctx.send("❌ **Error:** Failed to send the generated image to Discord.")


@bot.command(name='generate')
async def generate(ctx: commands.Context, *, args: str = None):
    """
    Command handler for !generate. Generates content based on the provided prompt.

    With no arguments, replies with an example invocation. Otherwise it does three things:
    - extracts the prompt="..." value;
    - calls the Gradio Space's "/predict" endpoint (60 s timeout);
    - replies with the result in an embed that mentions the caller.
    A temporary "processing" message is shown while the request runs and is removed
    afterwards.

    Args:
        ctx (commands.Context): The context in which the command was invoked.
        args (str, optional): Everything after the command name, unparsed.
    """
    if not args:
        # No parameters provided: show an example command instead.
        await ctx.send(embed=create_example_embed())
        return

    prompt = parse_prompt(args)
    if not prompt:
        await ctx.send("❌ **Error:** Prompt cannot be empty. Please provide a valid input.")
        return

    # Acknowledge the command and indicate processing.
    processing_message = await ctx.send("🔄 Generating your content, please wait...")
    try:
        logger.info(f"Received prompt: {prompt}")
        # GRADIO_CLIENT.predict blocks, so run it in a worker thread to keep the event
        # loop responsive. On timeout the thread is abandoned, not stopped.
        response = await asyncio.wait_for(
            asyncio.to_thread(GRADIO_CLIENT.predict, param_0=prompt, api_name="/predict"),
            timeout=60,
        )
        logger.info(f"API response: {response}")

        # Echo back exactly what the user typed.
        command_used = ctx.message.content.strip()
        mention = f"{ctx.author.mention}"

        if isinstance(response, dict):
            url = response.get('url')
            if not url:
                await ctx.send("❌ **Error:** The API response does not contain a 'url' key.")
            elif not isinstance(url, str):
                await ctx.send("❌ **Error:** Received an invalid URL format from the API.")
            else:
                await ctx.send(content=mention, embed=_build_url_embed(url, command_used))
        elif isinstance(response, str):
            # A string result is treated as the path of a local output file.
            await _send_local_file(ctx, response, command_used)
        elif isinstance(response, list):
            if response and isinstance(response[0], dict):
                url = response[0].get('url')
                if url and isinstance(url, str):
                    await ctx.send(content=mention, embed=_build_url_embed(url, command_used))
                else:
                    await ctx.send("❌ **Error:** Received an invalid URL format from the API.")
            else:
                await ctx.send("❌ **Error:** The API response is an unexpected list structure.")
        else:
            # Response is neither dict, str, nor list.
            await ctx.send("❌ **Error:** Unexpected response format from the API.")
    except asyncio.TimeoutError:
        logger.error("API request timed out.")
        await ctx.send("⏰ **Error:** The request to the API timed out. Please try again later.")
    except AppError as e:
        logger.error(f"API Error: {str(e)}")
        await ctx.send(f"❌ **API Error:** {str(e)}")
    except requests.exceptions.ConnectionError:
        # NOTE(review): gradio_client appears to use httpx rather than requests, so this
        # branch may never fire; confirm and consider catching httpx errors instead.
        logger.error("Failed to connect to the API.")
        await ctx.send("⚠️ **Error:** Failed to connect to the API. Please check your network connection.")
    except Exception as e:
        logger.exception("An unexpected error occurred.")
        await ctx.send(f"❌ **Error:** An unexpected error occurred: {str(e)}")
    finally:
        # A failed delete (e.g. the message was already removed) must not mask the
        # outcome above, so it is logged rather than raised.
        try:
            await processing_message.delete()
        except discord.HTTPException as e:
            logger.warning(f"Could not delete the processing message: {e}")
@bot.event
async def on_command_error(ctx: commands.Context, error):
    """
    Global error handler for command errors.

    Unknown commands and cooldowns get a friendly reply. Every other error is logged
    with its traceback and echoed to the channel.

    Args:
        ctx (commands.Context): The context in which the error occurred.
        error (Exception): The exception that was raised.
    """
    if isinstance(error, commands.CommandNotFound):
        await ctx.send("❓ **Error:** Unknown command. Please use `!generate` to generate content.")
    elif isinstance(error, commands.CommandOnCooldown):
        await ctx.send(f"⏳ **Please wait {error.retry_after:.2f} seconds before using this command again.**")
    else:
        # Log first, including the traceback, so the failure is recorded even if
        # sending the reply below also fails.
        logger.error(f"Unhandled command error: {str(error)}", exc_info=error)
        await ctx.send(f"❌ **Error:** {str(error)}")
async def handle_root(request):
    """Answer requests to the root route with a fixed plain-text 200 response."""
    body = "DarkMuse GOES VROOOOM"
    return web.Response(status=200, text=body)
async def start_web_server(host: str = '0.0.0.0', port: int = 7860) -> web.AppRunner:
    """
    Start the aiohttp server that serves GET / via handle_root.

    Returns as soon as the site is listening; serving continues on the running
    event loop.

    Args:
        host: Interface to bind. Defaults to all interfaces ('0.0.0.0').
        port: TCP port to listen on. Defaults to 7860.

    Returns:
        web.AppRunner: The runner, so a caller can ``await runner.cleanup()`` on shutdown.
    """
    app = web.Application()
    app.router.add_get('/', handle_root)
    runner = web.AppRunner(app)
    await runner.setup()
    site = web.TCPSite(runner, host, port)
    await site.start()
    logger.info(f"Web server started on port {port}")
    return runner
async def start_bot():
    """
    Log in and run the Discord bot until it stops.

    ``async with bot`` closes the client (HTTP session and gateway connection) when
    ``start`` returns, raises or is cancelled, instead of leaving it open.
    """
    async with bot:
        await bot.start(DISCORD_BOT_TOKEN)
async def schedule_restart(
    space_id: str = "Nevaehni/FLUX.1-schnell",
    restart_interval_seconds: float = 60,
):
    """
    Restart a Hugging Face Space at a fixed interval, forever.

    Each cycle does three things:
    - sleeps ``restart_interval_seconds``;
    - asks the Hub to restart ``space_id`` (authenticated with HF_TOKEN);
    - pauses 5 more seconds.
    Failures are logged and the loop continues. The defaults keep the original
    testing setup, a restart roughly every minute.

    NOTE(review): the default ``space_id`` is the same Space that GRADIO_CLIENT sends
    predictions to, and !generate allows up to 60 s per request. A 60 s restart interval
    can therefore interrupt in-flight generations; raise it outside of testing.

    Args:
        space_id: Hub ID of the Space to restart ("owner/name").
        restart_interval_seconds: Delay before each restart attempt.
    """
    while True:
        logger.info(f"Scheduled space restart in {restart_interval_seconds} seconds.")
        await asyncio.sleep(restart_interval_seconds)
        try:
            logger.info("Attempting to restart the Hugging Face Space...")
            # restart_space is a blocking HTTP call, so run it in a worker thread to
            # keep the event loop (shared with the Discord client) responsive.
            # HfApi.restart_space names its first parameter `repo_id`, not `space_id`,
            # so the Space ID is passed positionally.
            await asyncio.to_thread(hf_api.restart_space, space_id, token=HF_TOKEN)
            logger.info("Space restarted successfully.")
        except Exception as e:
            logger.error(f"Failed to restart space: {e}")
        # Short pause before the next restart cycle begins.
        await asyncio.sleep(5)
async def main():
    """Run the Space-restart scheduler, the Discord bot and the web server together."""
    # Keep a reference to the background task: the event loop only holds weak references
    # to tasks, so an unreferenced task may be garbage-collected mid-execution.
    restart_task = asyncio.create_task(schedule_restart())
    try:
        # Run the bot and web server concurrently.
        await asyncio.gather(
            start_bot(),
            start_web_server(),
        )
    finally:
        # Stop the scheduler once the bot has exited.
        restart_task.cancel()
# Entry point: run the bot, web server and scheduler on one asyncio event loop.
if __name__ == '__main__':
    try:
        asyncio.run(main())
    except Exception as e:
        logger.exception(f"Failed to run the bot: {e}")
        # Exit non-zero so whatever launched the process can detect the failure.
        sys.exit(1)