|
import asyncio |
|
import logging |
|
import os |
|
import re |
|
import sys |
|
import datetime |
|
from huggingface_hub import HfApi |
|
|
|
import discord |
|
import requests |
|
from aiohttp import web |
|
from discord import Embed |
|
from discord.ext import commands |
|
from gradio_client import Client |
|
from gradio_client.exceptions import AppError |
|
|
|
|
|
# Credentials are read from the environment: the Discord bot token and the
# Hugging Face token used for the Gradio client and Space management.
DISCORD_BOT_TOKEN = os.environ.get('DISCORD_BOT_TOKEN')
HF_TOKEN = os.environ.get('HF_TOKEN')

# Fail fast, before any client below is created, when either token is missing.
# Every missing variable is reported in one go, and sys.exit(<str>) writes the
# message to stderr and terminates with exit status 1.
_missing_env_vars = [
    name
    for name, value in (('DISCORD_BOT_TOKEN', DISCORD_BOT_TOKEN), ('HF_TOKEN', HF_TOKEN))
    if not value
]
if _missing_env_vars:
    sys.exit('\n'.join(
        f"Error: The environment variable '{name}' is not set."
        for name in _missing_env_vars
    ))
|
|
|
|
|
# Root logging: INFO and above, one timestamped line per record, sent to stdout.
_LOG_FORMAT = '%(asctime)s - %(name)s - %(levelname)s - %(message)s'
_stdout_handler = logging.StreamHandler(sys.stdout)
logging.basicConfig(handlers=[_stdout_handler], format=_LOG_FORMAT, level=logging.INFO)

# Logger shared by every handler and task defined in this module.
logger = logging.getLogger(__name__)
|
|
|
|
|
# Start from discord.py's default gateway intents and opt in to message
# content, so command handlers (e.g. !generate reading ctx.message.content)
# can see what users typed.
intents = discord.Intents.default()
intents.message_content = True

# '!'-prefixed command bot; help_command=None disables the library's built-in
# help command.
bot = commands.Bot(
    intents=intents,
    help_command=None,
    command_prefix='!',
)
|
|
|
|
|
# Captures the text inside prompt="...". The lazy group stops at the first
# closing double quote, and re.S (DOTALL) lets the quoted text span newlines.
PROMPT_REGEX = re.compile(r'prompt\s*=\s*"(.*?)"', flags=re.S)
|
|
|
|
|
# Hugging Face Space that serves the image-generation endpoint.
_FLUX_SPACE_ID = "Nevaehni/FLUX.1-schnell"

# Single Gradio client, created once at import time and shared by all commands.
GRADIO_CLIENT = Client(_FLUX_SPACE_ID, hf_token=HF_TOKEN)

# Hub API handle; no token is bound here, callers pass one explicitly per call.
hf_api = HfApi()
|
|
|
@bot.event
async def on_ready():
    """Log the bot account's name and ID once the Discord connection is ready."""
    me = bot.user
    logger.info('Logged in as %s (ID: %s)', me, me.id)
    logger.info('------')
|
|
|
def parse_prompt(command: str) -> str:
    """
    Pull the quoted text out of a ``prompt="..."`` argument.

    Args:
        command (str): Raw argument text following the command name.

    Returns:
        str: The first quoted prompt with surrounding whitespace removed, or
        ``''`` when no ``prompt="..."`` argument is present.
    """
    found = PROMPT_REGEX.search(command)
    return found.group(1).strip() if found is not None else ''
|
|
|
def create_example_embed() -> Embed:
    """
    Build the usage hint shown when ``!generate`` is called without arguments.

    Returns:
        Embed: A blue embed whose description is a sample command in a code block.
    """
    sample = '!generate prompt="High resolution serene landscape\nwith text \'cucolina\'. seed:1"'
    return Embed(color=discord.Color.blue(), description="```\n" + sample + "\n```")
|
|
|
# Extensions rendered inline as an image when the API returns a URL.
_IMAGE_EXTENSIONS = ('.png', '.jpg', '.jpeg', '.gif', '.bmp', '.webp')

# Discord limits: embed field values max 1024 characters, message content max
# 2000. Exceeding either makes the send fail with HTTP 400.
_EMBED_FIELD_VALUE_LIMIT = 1024
_MESSAGE_CONTENT_LIMIT = 2000

# Seconds to wait for the Gradio prediction before reporting a timeout.
_PREDICT_TIMEOUT_SECONDS = 60


def _truncate(text: str, limit: int) -> str:
    """Return *text* cut to at most *limit* characters, ending in an ellipsis when shortened."""
    if len(text) <= limit:
        return text
    return text[:limit - 1] + '…'


def _prompt_field_value(command_used: str) -> str:
    """Render the invoking command as a code block that fits in one embed field."""
    # Break up ``` in user text so it cannot close the surrounding code fence early.
    escaped = command_used.replace('```', '`\u200b`\u200b`')
    fence_overhead = len("```\n\n```")
    return f"```\n{_truncate(escaped, _EMBED_FIELD_VALUE_LIMIT - fence_overhead)}\n```"


def _build_url_embed(url: str, command_used: str) -> Embed:
    """Build the reply embed for a URL result: an inline image for image URLs, otherwise the link as text."""
    if url.lower().endswith(_IMAGE_EXTENSIONS):
        embed = Embed(title="🎨 Generated Image", color=discord.Color.green())
        embed.set_image(url=url)
    else:
        embed = Embed(title="📄 Generated Content", description=url, color=discord.Color.green())
    embed.add_field(name="Prompt", value=_prompt_field_value(command_used), inline=False)
    return embed


async def _reply_with_file(ctx: commands.Context, file_path: str, command_used: str) -> None:
    """Upload a local result file and reply with an embed that displays it."""
    if not os.path.isfile(file_path):
        await ctx.send("❌ **Error:** The API returned an invalid file path.")
        return
    try:
        file_name = os.path.basename(file_path)
        discord_file = discord.File(file_path, filename=file_name)
        embed = Embed(title="🎨 Generated Image", color=discord.Color.green())
        # Point the embed at the attachment uploaded with the same message.
        embed.set_image(url=f"attachment://{file_name}")
        embed.add_field(name="Prompt", value=_prompt_field_value(command_used), inline=False)
        await ctx.send(content=ctx.author.mention, embed=embed, file=discord_file)
        logger.info("Sent image from %s to Discord.", file_path)
    except Exception:
        logger.exception("Failed to send image to Discord.")
        await ctx.send("❌ **Error:** Failed to send the generated image to Discord.")


async def _reply_with_result(ctx: commands.Context, response, command_used: str) -> None:
    """
    Send the API result to the channel, dispatching on the response's shape.

    Accepted shapes: a dict with a string ``'url'``; a local file path string;
    a non-empty list whose first element is such a dict. Anything else gets
    an error reply.
    """
    if isinstance(response, dict):
        url = response.get('url')
        if not url:
            await ctx.send("❌ **Error:** The API response does not contain a 'url' key.")
        elif not isinstance(url, str):
            await ctx.send("❌ **Error:** Received an invalid URL format from the API.")
        else:
            await ctx.send(content=ctx.author.mention, embed=_build_url_embed(url, command_used))
    elif isinstance(response, str):
        await _reply_with_file(ctx, response, command_used)
    elif isinstance(response, list):
        if not response or not isinstance(response[0], dict):
            await ctx.send("❌ **Error:** The API response is an unexpected list structure.")
            return
        url = response[0].get('url')
        if url and isinstance(url, str):
            await ctx.send(content=ctx.author.mention, embed=_build_url_embed(url, command_used))
        else:
            await ctx.send("❌ **Error:** Received an invalid URL format from the API.")
    else:
        await ctx.send("❌ **Error:** Unexpected response format from the API.")


@bot.command(name='generate')
async def generate(ctx: commands.Context, *, args: str = None):
    """
    Command handler for !generate. Generates content based on the provided prompt.

    With no arguments, replies with a usage example. Otherwise:
    - extracts the ``prompt="..."`` text;
    - runs the Gradio prediction in a worker thread with a timeout;
    - replies with the result.
    A temporary "generating" message is shown meanwhile and always removed.

    Args:
        ctx (commands.Context): The context in which the command was invoked.
        args (str, optional): Everything after the command name.
    """
    if not args:
        await ctx.send(embed=create_example_embed())
        return

    prompt = parse_prompt(args)
    if not prompt:
        await ctx.send("❌ **Error:** Prompt cannot be empty. Please provide a valid input.")
        return

    processing_message = await ctx.send("🔄 Generating your content, please wait...")
    try:
        logger.info("Received prompt: %s", prompt)
        # to_thread keeps the blocking Gradio call off the event loop. On timeout
        # only the wait is abandoned; the worker thread cannot be cancelled.
        response = await asyncio.wait_for(
            asyncio.to_thread(GRADIO_CLIENT.predict, param_0=prompt, api_name="/predict"),
            timeout=_PREDICT_TIMEOUT_SECONDS,
        )
        logger.info("API response: %s", response)
        await _reply_with_result(ctx, response, ctx.message.content.strip())
    except asyncio.TimeoutError:
        logger.error("API request timed out.")
        await ctx.send("⏰ **Error:** The request to the API timed out. Please try again later.")
    except AppError as e:
        logger.error("API Error: %s", e)
        await ctx.send(_truncate(f"❌ **API Error:** {e}", _MESSAGE_CONTENT_LIMIT))
    except requests.exceptions.ConnectionError:
        logger.error("Failed to connect to the API.")
        await ctx.send("⚠️ **Error:** Failed to connect to the API. Please check your network connection.")
    except Exception as e:
        logger.exception("An unexpected error occurred.")
        await ctx.send(_truncate(f"❌ **Error:** An unexpected error occurred: {e}", _MESSAGE_CONTENT_LIMIT))
    finally:
        # The placeholder may already be gone or undeletable (missing
        # permissions). That must not mask the outcome reported above.
        try:
            await processing_message.delete()
        except discord.HTTPException as exc:
            logger.warning("Could not delete the processing message: %s", exc)
|
|
|
@bot.event
async def on_command_error(ctx: commands.Context, error):
    """
    Global error handler for command errors.

    Unknown commands and cooldowns get a friendly reply. Anything else is
    logged with its traceback and echoed back to the channel.

    Args:
        ctx (commands.Context): The context in which the error occurred.
        error (Exception): The exception that was raised.
    """
    if isinstance(error, commands.CommandNotFound):
        await ctx.send("❌ **Error:** Unknown command. Please use `!generate` to generate content.")
    elif isinstance(error, commands.CommandOnCooldown):
        await ctx.send(f"⏳ **Please wait {error.retry_after:.2f} seconds before using this command again.**")
    else:
        # Log first, with the traceback, so a failing send below cannot hide the error.
        logger.error("Unhandled command error: %s", error, exc_info=error)
        message = f"❌ **Error:** {error}"
        # Discord rejects message content longer than 2000 characters.
        if len(message) > 2000:
            message = message[:1999] + '…'
        await ctx.send(message)
|
|
|
async def handle_root(request):
    """Answer ``GET /`` with a fixed plain-text banner and HTTP 200; the request is not inspected."""
    banner = "DarkMuse GOES VROOOOM"
    return web.Response(status=200, text=banner)
|
|
|
async def start_web_server(host: str = '0.0.0.0', port: int = 7860):
    """
    Start the aiohttp HTTP server that answers ``GET /`` via ``handle_root``.

    The server keeps running in the background on the current event loop; this
    coroutine returns as soon as the listening socket is open.

    Args:
        host (str): Interface to bind; the default ``'0.0.0.0'`` listens on all interfaces.
        port (int): TCP port to listen on. Defaults to 7860, the port this bot
            has always used (presumably the one its host routes traffic to — confirm
            before changing).

    Returns:
        web.AppRunner: The runner for the started app; ``await runner.cleanup()``
        shuts the server down.
    """
    app = web.Application()
    app.router.add_get('/', handle_root)
    runner = web.AppRunner(app)
    await runner.setup()
    site = web.TCPSite(runner, host, port)
    await site.start()
    logger.info("Web server started on port %s", port)
    return runner
|
|
|
async def start_bot():
    """
    Log in and run the Discord bot until it disconnects or is closed.

    Using the bot as an async context manager makes sure ``bot.close()`` runs
    on exit, including on cancellation or error. This releases the HTTP session
    and gateway connection.
    """
    async with bot:
        await bot.start(DISCORD_BOT_TOKEN)
|
|
|
async def schedule_restart(
    space_id: str = "Nevaehni/FLUX.1-schnell",
    restart_interval_seconds: float = 60,
):
    """
    Restart a Hugging Face Space forever, at a fixed interval.

    Each cycle waits ``restart_interval_seconds``, asks the Hub to restart the
    Space, then pauses 5 more seconds. Failures are logged and the loop
    continues.

    NOTE: the 60-second default was chosen for testing. A Space restarted this
    often may never finish booting, so pass a longer interval for real use.

    Args:
        space_id (str): Hub id of the Space to restart.
        restart_interval_seconds (float): Delay before each restart attempt.
    """
    while True:
        logger.info(f"Scheduled space restart in {restart_interval_seconds} seconds.")
        await asyncio.sleep(restart_interval_seconds)
        try:
            logger.info("Attempting to restart the Hugging Face Space...")
            # restart_space's first parameter is repo_id (a `space_id=` keyword
            # raises TypeError), so pass it positionally. It is a blocking HTTP
            # call; run it in a worker thread so the event loop stays responsive.
            await asyncio.to_thread(hf_api.restart_space, space_id, token=HF_TOKEN)
            logger.info("Space restarted successfully.")
        except Exception as e:
            logger.exception(f"Failed to restart space: {e}")
        await asyncio.sleep(5)
|
|
|
async def main():
    """Run the Discord bot and the web server together, with periodic Space restarts in the background."""
    # The event loop keeps only weak references to tasks; hold our own so the
    # restart loop cannot be garbage-collected while it sleeps.
    restart_task = asyncio.create_task(schedule_restart())
    try:
        await asyncio.gather(
            start_bot(),
            start_web_server(),
        )
    finally:
        # Once the bot has stopped, stop restarting the Space too.
        restart_task.cancel()
|
|
|
|
|
if __name__ == '__main__':
    try:
        asyncio.run(main())
    except KeyboardInterrupt:
        # Ctrl+C is a normal way to stop the bot; don't dump a traceback for it.
        logger.info("Shutdown requested, exiting.")
    except Exception as e:
        logger.exception(f"Failed to run the bot: {e}")
        # Exit non-zero so supervisors and container runtimes see the failure.
        sys.exit(1)
|
|