import logging

import gradio as gr
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

system_message = """
You are a prompt creation assistant for FLUX, an AI image generation model. Your mission is to help the user craft a detailed and optimized prompt by following these steps:

1. **Understanding the User's Needs**:
   - The user provides a basic idea, concept, or description.
   - Analyze their input to determine essential details and nuances.

2. **Enhancing Details**:
   - Enrich the basic idea with vivid, specific, and descriptive elements.
   - Include factors such as lighting, mood, style, perspective, and specific objects or elements the user wants in the scene.

3. **Formatting the Prompt**:
   - Structure the enriched description into a clear, precise, and effective prompt.
   - Ensure the prompt is tailored for high-quality output from the FLUX model, considering its strengths (e.g., photorealistic details, fine anatomy, or artistic styles).

Use this process to compose a detailed and coherent prompt. Ensure the final prompt is clear and complete, and write your response in English.

Ensure that the final part is a synthesized version of the prompt.
"""


def load_model_and_tokenizer(model_name):
    """
    Load the model and tokenizer using Hugging Face's Auto classes.

    Args:
        model_name (str): Hugging Face model name.

    Returns:
        tuple: model, tokenizer, device
    """
    try:
        # Prefer a CUDA GPU when available; otherwise fall back to CPU.
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        logger.info(f"Using device: {device}")

        tokenizer = AutoTokenizer.from_pretrained(model_name)
        model = AutoModelForCausalLM.from_pretrained(model_name).to(device)

        # Llama-style tokenizers often ship without a pad token; reuse the
        # EOS token so padded batches can be tokenized and generated.
        if tokenizer.pad_token is None:
            logger.info("Adding pad_token to the tokenizer.")
            tokenizer.add_special_tokens({"pad_token": tokenizer.eos_token})
            model.resize_token_embeddings(len(tokenizer))

        return model, tokenizer, device
    except Exception as e:
        logger.error(f"Error loading model or tokenizer: {e}")
        raise
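
# A possible memory-saving variant (a sketch, not used above): torch_dtype is
# a standard from_pretrained argument, and loading the weights in float16
# roughly halves the footprint of a 3B model on GPU at some cost in precision:
#
#   model = AutoModelForCausalLM.from_pretrained(
#       model_name, torch_dtype=torch.float16
#   ).to(device)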


def chatbot_fn(prompt, chatbot_history=None):
    """
    Chatbot function to handle user prompts and generate responses.

    Args:
        prompt (str): User input prompt.
        chatbot_history (list): History of the conversation.

    Returns:
        tuple: Assistant's response, updated conversation history.
    """
    # Avoid the mutable-default-argument pitfall: a shared [] default would
    # leak conversation state across sessions.
    if chatbot_history is None:
        chatbot_history = []

    if not prompt.strip():
        return "Please enter a valid prompt.", chatbot_history

    try:
        # Seed the conversation with the system message on the first turn.
        if not chatbot_history:
            chatbot_history.append({"role": "system", "content": system_message})

        # Flatten the history into a plain-text transcript for the model.
        conversation = [item["content"] for item in chatbot_history]
        input_text = "\n".join(conversation) + f"\nUser: {prompt}\nAssistant:"

        inputs = tokenizer(
            input_text,
            return_tensors="pt",
            truncation=True,
            max_length=1024,
            padding=True,
        ).to(device)

        with torch.no_grad():
            outputs = model.generate(
                **inputs, max_new_tokens=2000, pad_token_id=tokenizer.pad_token_id
            )
        response_text = tokenizer.decode(outputs[0], skip_special_tokens=True)
        # The decoded text echoes the whole transcript; keep only the text
        # after the final "Assistant:" marker.
        assistant_response = response_text.split("Assistant:")[-1].strip()
    except Exception as e:
        logger.error(f"Error generating response: {e}")
        return f"An error occurred: {e}", chatbot_history

    chatbot_history.append({"role": "user", "content": prompt})
    chatbot_history.append({"role": "assistant", "content": assistant_response})

    return assistant_response, chatbot_history
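
# Alternative prompt construction (a sketch, assuming the tokenizer ships a
# chat template, as Llama 3 tokenizers do): the history is already a list of
# role/content dicts, so the tokenizer could format it directly instead of the
# hand-rolled "User:"/"Assistant:" markers above:
#
#   input_ids = tokenizer.apply_chat_template(
#       chatbot_history + [{"role": "user", "content": prompt}],
#       add_generation_prompt=True,
#       return_tensors="pt",
#   ).to(device)
#
# This keeps the prompt in the format the model was fine-tuned on.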


# Load the model once at import time so every request reuses the same weights.
try:
    model_name = "VincentGOURBIN/Llama-3.2-3B-Fluxed"
    model, tokenizer, device = load_model_and_tokenizer(model_name)
except Exception:
    logger.critical("Failed to initialize the model. Exiting.")
    raise


iface = gr.Interface(
    fn=chatbot_fn,
    inputs=["text", "state"],
    outputs=["text", "state"],
    title="Prompt Crafting Assistant for FLUX",
    description=(
        "This assistant helps you create detailed and optimized prompts for FLUX, "
        "an AI image generation model. Provide a basic idea, and it will enhance it "
        "with vivid details for high-quality results."
    ),
    allow_flagging="never",
)
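
# The "text"/"state" pairing wires up Gradio's per-session storage: the history
# returned as the second output is fed back in as chatbot_history on the next
# call, so each browser session keeps its own conversation.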


if __name__ == "__main__":
    # share=True serves the app locally and also creates a temporary public
    # Gradio link.
    iface.launch(share=True)