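"""Gradio app that helps users craft detailed, optimized prompts for FLUX,
an AI image generation model, using a Hugging Face causal language model."""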
import logging
import os

import gradio as gr
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

system_message = """
You are a prompt creation assistant for FLUX, an AI image generation model. Your mission is to help the user craft a detailed and optimized prompt by following these steps:
1. **Understanding the User's Needs**:
- The user provides a basic idea, concept, or description.
- Analyze their input to determine essential details and nuances.
2. **Enhancing Details**:
- Enrich the basic idea with vivid, specific, and descriptive elements.
- Include factors such as lighting, mood, style, perspective, and specific objects or elements the user wants in the scene.
3. **Formatting the Prompt**:
- Structure the enriched description into a clear, precise, and effective prompt.
- Ensure the prompt is tailored for high-quality output from the FLUX model, considering its strengths (e.g., photorealistic details, fine anatomy, or artistic styles).
Use this process to compose a detailed and coherent prompt. Ensure the final prompt is clear and complete, write your response in English, and end your response with a synthesized version of the final prompt.
"""

def load_model_and_tokenizer(model_name):
    """
    Load the model and tokenizer using Hugging Face's Auto classes.
    Args:
        model_name (str): Hugging Face model name.
    Returns:
        tuple: model, tokenizer, device
    """
    try:
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        logger.info(f"Using device: {device}")
        # Load model and tokenizer using Auto classes
        tokenizer = AutoTokenizer.from_pretrained(model_name)
        model = AutoModelForCausalLM.from_pretrained(model_name).to(device)
        # Set or add padding token
        if tokenizer.pad_token is None:
            logger.info("Adding pad_token to the tokenizer.")
            tokenizer.add_special_tokens({'pad_token': tokenizer.eos_token})
            model.resize_token_embeddings(len(tokenizer))  # Adjust model embeddings for the new token
        return model, tokenizer, device
    except Exception as e:
        logger.error(f"Error loading model or tokenizer: {e}")
        raise

def chatbot_fn(prompt, chatbot_history=None):
    """
    Chatbot function to handle user prompts and generate responses.
    Uses the module-level model, tokenizer, and device initialized below.
    Args:
        prompt (str): User input prompt.
        chatbot_history (list): History of the conversation.
    Returns:
        tuple: Assistant's response, updated conversation history.
    """
    # Gradio passes None for the state component on the first call.
    chatbot_history = chatbot_history or []
    if not prompt.strip():
        return "Please enter a valid prompt.", chatbot_history
    try:
        # Initialize conversation with the system message if empty
        if not chatbot_history:
            chatbot_history.append({"role": "system", "content": system_message})
        # Build the conversation context
        conversation = [item['content'] for item in chatbot_history]
        input_text = "\n".join(conversation) + f"\nUser: {prompt}\nAssistant:"
        # Tokenize input
        inputs = tokenizer(
            input_text,
            return_tensors="pt",
            truncation=True,
            max_length=1024,
            padding=True,
        ).to(device)
        # Generate response
        with torch.no_grad():
            outputs = model.generate(
                **inputs, max_new_tokens=2000, pad_token_id=tokenizer.pad_token_id
            )
        response_text = tokenizer.decode(outputs[0], skip_special_tokens=True)
        assistant_response = response_text.split("Assistant:")[-1].strip()
    except Exception as e:
        logger.error(f"Error generating response: {e}")
        return f"An error occurred: {e}", chatbot_history
    # Update history
    chatbot_history.append({"role": "user", "content": prompt})
    chatbot_history.append({"role": "assistant", "content": assistant_response})
    return assistant_response, chatbot_history

# Initialize Hugging Face model and tokenizer
try:
    model_name = "VincentGOURBIN/Llama-3.2-3B-Fluxed"  # Model name
    model, tokenizer, device = load_model_and_tokenizer(model_name)
except Exception as e:
    logger.critical("Failed to initialize the model. Exiting.")
    raise

# Define Gradio interface
iface = gr.Interface(
    fn=chatbot_fn,
    inputs=["text", "state"],
    outputs=["text", "state"],
    title="Prompt Crafting Assistant for FLUX",
    description=(
        "This assistant helps you create detailed and optimized prompts for FLUX, "
        "an AI image generation model. Provide a basic idea, and it will enhance it "
        "with vivid details for high-quality results."
    ),
    allow_flagging="never",
)

if __name__ == "__main__":
    # share=True exposes a temporary public Gradio link in addition to the local server.
    iface.launch(share=True)