import os
import gradio as gr
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM
import logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
system_message = """
You are a prompt creation assistant for FLUX, an AI image generation model. Your mission is to help the user craft a detailed and optimized prompt by following these steps:
1. **Understanding the User's Needs**:
- The user provides a basic idea, concept, or description.
- Analyze their input to determine essential details and nuances.
2. **Enhancing Details**:
- Enrich the basic idea with vivid, specific, and descriptive elements.
- Include factors such as lighting, mood, style, perspective, and specific objects or elements the user wants in the scene.
3. **Formatting the Prompt**:
- Structure the enriched description into a clear, precise, and effective prompt.
- Ensure the prompt is tailored for high-quality output from the FLUX model, considering its strengths (e.g., photorealistic details, fine anatomy, or artistic styles).
Use this process to compose a detailed and coherent prompt. Write your response in English, and end it with a synthesized, ready-to-use version of the final prompt.
"""
def load_model_and_tokenizer(model_name):
"""
Load the model and tokenizer using Hugging Face's Auto classes.
    Args:
        model_name (str): Hugging Face model name or local path.

    Returns:
        tuple: (model, tokenizer, device)
    """
try:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
logger.info(f"Using device: {device}")
# Load model and tokenizer using Auto classes
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(model_name).to(device)
        # Ensure a padding token exists; reusing eos_token introduces no new
        # token, so the embedding matrix does not need to be resized.
        if tokenizer.pad_token is None:
            logger.info("No pad_token found; reusing eos_token as pad_token.")
            tokenizer.pad_token = tokenizer.eos_token
return model, tokenizer, device
except Exception as e:
logger.error(f"Error loading model or tokenizer: {e}")
raise
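# A minimal usage sketch (assumes network access to the Hugging Face Hub; the
# model name here is illustrative, this Space loads its own model further down):
#   model, tokenizer, device = load_model_and_tokenizer("gpt2")
#   print(device)  # cuda if a GPU is visible, otherwise cpu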
def chatbot_fn(prompt, chatbot_history=None):
"""
Chatbot function to handle user prompts and generate responses.
Args:
prompt (str): User input prompt.
        chatbot_history (list | None): Conversation history; a fresh list is created when None.
Returns:
tuple: Assistant's response, updated conversation history.
"""
    # Avoid the mutable-default-argument pitfall: create a fresh history per call.
    if chatbot_history is None:
        chatbot_history = []
    if not prompt.strip():
        return "Please enter a valid prompt.", chatbot_history
try:
# Initialize conversation with system message if empty
if not chatbot_history:
chatbot_history.append({"role": "system", "content": system_message})
        # Build the conversation context, labeling each turn so the model can
        # tell prior user and assistant messages apart.
        role_labels = {"system": "", "user": "User: ", "assistant": "Assistant: "}
        conversation = [
            f"{role_labels[item['role']]}{item['content']}" for item in chatbot_history
        ]
        input_text = "\n".join(conversation) + f"\nUser: {prompt}\nAssistant:"
# Tokenize input
inputs = tokenizer(
input_text,
return_tensors="pt",
truncation=True,
max_length=1024,
padding=True
).to(device)
# Generate response
with torch.no_grad():
outputs = model.generate(
**inputs, max_new_tokens=2000, pad_token_id=tokenizer.pad_token_id
)
response_text = tokenizer.decode(outputs[0], skip_special_tokens=True)
assistant_response = response_text.split("Assistant:")[-1].strip()
except Exception as e:
logger.error(f"Error generating response: {e}")
return f"An error occurred: {e}", chatbot_history
# Update history
chatbot_history.append({"role": "user", "content": prompt})
chatbot_history.append({"role": "assistant", "content": assistant_response})
return assistant_response, chatbot_history
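# Hypothetical multi-turn usage sketch (not executed by the app itself; Gradio
# threads the history automatically via the "state" component defined below):
#   reply, history = chatbot_fn("a cozy cabin in a snowy forest at dusk")
#   reply, history = chatbot_fn("make it more cinematic", history)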
# Initialize Hugging Face model and tokenizer
try:
model_name = "VincentGOURBIN/Llama-3.2-3B-Fluxed" # Model name
model, tokenizer, device = load_model_and_tokenizer(model_name)
except Exception as e:
    logger.critical(f"Failed to initialize the model: {e}. Exiting.")
raise
# Define Gradio interface
iface = gr.Interface(
fn=chatbot_fn,
inputs=["text", "state"],
outputs=["text", "state"],
title="Prompt Crafting Assistant for FLUX",
description=(
"This assistant helps you create detailed and optimized prompts for FLUX, "
"an AI image generation model. Provide a basic idea, and it will enhance it "
"with vivid details for high-quality results."
),
allow_flagging="never",
)
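# The ("text", "state") input pair maps onto chatbot_fn(prompt, chatbot_history):
# on each call, Gradio feeds the stored state back as the second argument and
# persists the second return value as the new state for the next turn.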
if __name__ == "__main__":
iface.launch(share=True)