DarkMuse / app.py
Nevaehni's picture
initial
b0d8cff
raw
history blame
10.4 kB
# discord_bot.py
import asyncio
import logging
import os
import re
import sys

import discord
import requests  # Ensure 'requests' is installed
from aiohttp import web  # Installed as a discord.py dependency; backs the health-check server
from discord import Embed
from discord.ext import commands
from gradio_client import Client
from gradio_client.exceptions import AppError  # Updated import
# Required secrets are read from the environment; each one is validated
# and the process exits with status 1 on the first one that is missing.
DISCORD_BOT_TOKEN = os.getenv('DISCORD_BOT_TOKEN')
if not DISCORD_BOT_TOKEN:
    print("Error: The environment variable 'DISCORD_BOT_TOKEN' is not set.")
    raise SystemExit(1)

# Token handed to the Gradio client when it is constructed further down.
HF_TOKEN = os.getenv('HF_TOKEN')
if not HF_TOKEN:
    print("Error: The environment variable 'HF_TOKEN' is not set.")
    raise SystemExit(1)
# Logging setup: everything at INFO and above goes to stdout.
# Switch the level to logging.DEBUG for more detailed logs.
_stdout_handler = logging.StreamHandler(sys.stdout)
logging.basicConfig(
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
    level=logging.INFO,
    handlers=[_stdout_handler],
)

# Module-level logger used throughout this file.
logger = logging.getLogger(__name__)
# Start from the default gateway intents and additionally opt in to
# message content, which the '!' prefix commands need in order to be read.
intents = discord.Intents.default()
intents.message_content = True

# The bot itself: '!' prefix, the intents above, and no built-in help command.
bot = commands.Bot(
    help_command=None,
    intents=intents,
    command_prefix='!',
)

# Captures the text between double quotes in `prompt="..."`.
PROMPT_REGEX = re.compile(r'prompt\s*=\s*"(.*?)"')

# Shared Gradio client for the image-generation Space, authenticated with HF_TOKEN.
GRADIO_CLIENT = Client("Nevaehni/FLUX.1-schnell", hf_token=HF_TOKEN)
@bot.event
async def on_ready():
    """Log the bot's account name and ID when the ready event fires."""
    me = bot.user
    logger.info(f'Logged in as {me} (ID: {me.id})')
    logger.info('------')
def parse_prompt(command: str) -> str:
    """
    Extract the quoted value of ``prompt="..."`` from a command string.

    Args:
        command (str): The text to search (the arguments of the command).

    Returns:
        str: The first captured prompt with surrounding whitespace removed,
        or an empty string when PROMPT_REGEX finds no match.
    """
    found = PROMPT_REGEX.search(command)
    return found.group(1).strip() if found else ''
def create_example_embed() -> Embed:
    """
    Build the usage hint shown when !generate is invoked without arguments.

    Returns:
        Embed: A blue embed whose description is a sample !generate command
        wrapped in a code block.
    """
    usage = '!generate prompt="High resolution serene landscape with text \'cucolina\'. seed:1"'
    code_block = "```\n" + usage + "\n```"
    return Embed(description=code_block, color=discord.Color.blue())
# URL suffixes that are embedded as an image rather than shown as a link.
_IMAGE_EXTENSIONS = ('.png', '.jpg', '.jpeg', '.gif', '.bmp', '.webp')
# Discord rejects embed field values longer than this many characters.
_EMBED_FIELD_LIMIT = 1024
# Maximum time to wait for the Gradio prediction before giving up.
_PREDICT_TIMEOUT_SECONDS = 60


def _prompt_field_value(command_used: str) -> str:
    """Wrap the user's command in a code block, truncated to fit an embed field."""
    budget = _EMBED_FIELD_LIMIT - len("```\n\n```")
    if len(command_used) > budget:
        # Leave room for the ellipsis so the total stays within the limit.
        command_used = command_used[:budget - 1] + "…"
    return f"```\n{command_used}\n```"


def _build_url_embed(url: str, command_used: str) -> "Embed":
    """Build the result embed for a URL: an image embed or a plain link embed."""
    if url.lower().endswith(_IMAGE_EXTENSIONS):
        embed = Embed(title="🎨 Generated Image", color=discord.Color.green())
        embed.set_image(url=url)
    else:
        embed = Embed(title="🔗 Generated Content", description=url, color=discord.Color.green())
    embed.add_field(name="Prompt", value=_prompt_field_value(command_used), inline=False)
    return embed


async def _send_file_result(ctx: "commands.Context", file_path: str, command_used: str) -> None:
    """Upload a generated file from a local path and show it inside an embed."""
    if not os.path.isfile(file_path):
        await ctx.send("❌ **Error:** The API returned an invalid file path.")
        return
    try:
        file_name = os.path.basename(file_path)
        embed = Embed(title="🎨 Generated Image", color=discord.Color.green())
        # attachment://<name> makes the embed display the uploaded file.
        embed.set_image(url=f"attachment://{file_name}")
        embed.add_field(name="Prompt", value=_prompt_field_value(command_used), inline=False)
        discord_file = discord.File(file_path, filename=file_name)
        await ctx.send(content=ctx.author.mention, embed=embed, file=discord_file)
        logger.info(f"Sent image from {file_path} to Discord.")
    except Exception as e:
        logger.error(f"Failed to send image to Discord: {e}")
        await ctx.send("❌ **Error:** Failed to send the generated image to Discord.")


async def _send_generation_result(ctx: "commands.Context", response, command_used: str) -> None:
    """Send the reply that matches the shape of the API response (dict, str or list)."""
    if isinstance(response, dict):
        url = response.get('url')
        if not url:
            await ctx.send("❌ **Error:** The API response does not contain a 'url' key.")
        elif not isinstance(url, str):
            await ctx.send("❌ **Error:** Received an invalid URL format from the API.")
        else:
            await ctx.send(content=ctx.author.mention, embed=_build_url_embed(url, command_used))
    elif isinstance(response, str):
        # A bare string is treated as a path to a file on local disk.
        await _send_file_result(ctx, response, command_used)
    elif isinstance(response, list):
        if response and isinstance(response[0], dict):
            # Only the first item of a list response is used.
            url = response[0].get('url')
            if url and isinstance(url, str):
                await ctx.send(content=ctx.author.mention, embed=_build_url_embed(url, command_used))
            else:
                await ctx.send("❌ **Error:** Received an invalid URL format from the API.")
        else:
            await ctx.send("❌ **Error:** The API response is an unexpected list structure.")
    else:
        await ctx.send("❌ **Error:** Unexpected response format from the API.")


@bot.command(name='generate')
async def generate(ctx: commands.Context, *, args: str = None):
    """
    Command handler for !generate. Generates content based on the provided prompt.

    With no arguments an example command is shown. Otherwise the prompt is
    parsed from ``prompt="..."``, sent to the Gradio API, and the result is
    posted back mentioning the author. A temporary "please wait" message is
    shown while the request runs and removed afterwards.

    Args:
        ctx (commands.Context): The context in which the command was invoked.
        args (str, optional): The arguments passed with the command.
    """
    if not args:
        # No parameters provided, send example command
        await ctx.send(embed=create_example_embed())
        return

    prompt = parse_prompt(args)
    if not prompt:
        # Prompt parameter not found or empty
        await ctx.send("❌ **Error:** Prompt cannot be empty. Please provide a valid input.")
        return

    # Acknowledge the command and indicate processing
    processing_message = await ctx.send("🔄 Generating your content, please wait...")
    try:
        logger.info(f"Received prompt: {prompt}")
        # predict() is a blocking call. Running it in the default executor
        # keeps the event loop responsive and lets the timeout below fire.
        # On timeout the worker thread is not interrupted; only the wait is
        # abandoned.
        loop = asyncio.get_running_loop()
        prediction = loop.run_in_executor(
            None,
            lambda: GRADIO_CLIENT.predict(param_0=prompt, api_name="/predict"),
        )
        response = await asyncio.wait_for(prediction, timeout=_PREDICT_TIMEOUT_SECONDS)
        logger.info(f"API response: {response}")

        # Echo the exact command used by the user in the result embed
        command_used = ctx.message.content.strip()
        await _send_generation_result(ctx, response, command_used)
    except asyncio.TimeoutError:
        logger.error("API request timed out.")
        await ctx.send("⏰ **Error:** The request to the API timed out. Please try again later.")
    except AppError as e:
        logger.error(f"API Error: {str(e)}")
        await ctx.send(f"❌ **API Error:** {str(e)}")
    except requests.exceptions.ConnectionError:
        logger.error("Failed to connect to the API.")
        await ctx.send("⚠️ **Error:** Failed to connect to the API. Please check your network connection.")
    except Exception as e:
        logger.exception("An unexpected error occurred.")
        await ctx.send(f"❌ **Error:** An unexpected error occurred: {str(e)}")
    finally:
        # Removing the placeholder is best-effort: it may already be gone or
        # the bot may lack permission, and that must not mask the outcome.
        try:
            await processing_message.delete()
        except discord.HTTPException as e:
            logger.warning(f"Could not delete the processing message: {e}")
@bot.event
async def on_command_error(ctx: commands.Context, error):
    """
    Global error handler for command errors.

    Unknown commands and cooldown violations get a friendly hint. Any other
    error is logged with its traceback and then reported to the channel.

    Args:
        ctx (commands.Context): The context in which the error occurred.
        error (Exception): The exception that was raised.
    """
    if isinstance(error, commands.CommandNotFound):
        await ctx.send("❓ **Error:** Unknown command. Please use `!generate` to generate content.")
        return
    if isinstance(error, commands.CommandOnCooldown):
        await ctx.send(f"⏳ **Please wait {error.retry_after:.2f} seconds before using this command again.**")
        return

    # Log before replying: if the reply itself fails (e.g. missing send
    # permission) the original error would otherwise never be recorded.
    logger.error(f"Unhandled command error: {str(error)}", exc_info=error)
    await ctx.send(f"❌ **Error:** {str(error)}")
async def handle_root(request):
    """Answer GET / with a fixed plain-text body and HTTP status 200."""
    reply = web.Response(status=200, text="DarkMuse GOES VROOOOM")
    return reply
async def start_web_server(host: str = '0.0.0.0', port: int = 7860) -> "web.AppRunner":
    """
    Start the small aiohttp server that answers GET / via handle_root.

    Args:
        host (str): Interface to bind to. Defaults to all interfaces.
        port (int): TCP port to listen on. Defaults to 7860.

    Returns:
        web.AppRunner: The running app runner; call ``await runner.cleanup()``
        on it to stop the server.

    Raises:
        OSError: If the address cannot be bound (e.g. the port is in use).
    """
    app = web.Application()
    app.router.add_get('/', handle_root)
    runner = web.AppRunner(app)
    await runner.setup()
    try:
        site = web.TCPSite(runner, host, port)
        await site.start()
    except Exception:
        # Do not leave a set-up runner behind when binding fails.
        await runner.cleanup()
        raise
    logger.info("Web server listening on %s:%s", host, port)
    return runner
async def main() -> None:
    """
    Start the web server, then run the bot until it disconnects.

    Both run on the same event loop. bot.run() cannot be used for this
    because it blocks until the bot shuts down, so nothing scheduled after
    it would ever run while the bot is alive.
    """
    try:
        await start_web_server()
    except Exception:
        # The web server is auxiliary; keep the bot usable without it.
        logger.exception("Failed to start the web server; continuing without it.")
    # The context manager closes the bot's connections on exit.
    async with bot:
        await bot.start(DISCORD_BOT_TOKEN)


# Run the bot
if __name__ == '__main__':
    try:
        asyncio.run(main())
    except KeyboardInterrupt:
        logger.info("Shutdown requested; exiting.")
    except Exception as e:
        logger.exception(f"Failed to run the bot: {e}")