# Sd-Prompter / app.py
import gradio as gr
from transformers import AutoModelForCausalLM, AutoTokenizer
import torch
import spaces
import logging

# Configure logging
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')

# Load the model and tokenizer
model_name = "mrcuddle/SD-Prompter"
logging.info(f"Loading model and tokenizer for {model_name}")
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(model_name)
logging.info("Model and tokenizer loaded successfully")

# Function to generate a response for the Gradio ChatInterface
@spaces.GPU
def chat(message, history):
    logging.info(f"Received message: {message}")
    logging.info(f"Chat history: {history}")

    # Combine the history (pairs of user/assistant turns) and the new message
    # into a single prompt string
    turns = []
    for user_msg, bot_msg in history:
        turns.append(f"User: {user_msg}")
        turns.append(f"Assistant: {bot_msg}")
    turns.append(f"User: {message}")
    input_text = " ".join(turns)
    logging.info(f"Input text: {input_text}")

    inputs = tokenizer(input_text, return_tensors="pt")
    logging.info(f"Tokenized input: {inputs}")

    # Generate a response
    with torch.no_grad():
        outputs = model.generate(inputs.input_ids, max_length=300, num_return_sequences=1)
    response = tokenizer.decode(outputs[0], skip_special_tokens=True)
    logging.info(f"Generated response: {response}")

    # Keep only the newly generated text that follows the prompt
    response = response.replace(input_text, "").strip()
    logging.info(f"Extracted response: {response}")

    # gr.ChatInterface manages the chat history itself, so the function only
    # needs to return the assistant's reply
    return response
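
# Note (assumption): max_length in model.generate counts the prompt tokens as
# well as the reply, so long histories leave little room for new text. A
# variant that caps only the newly generated tokens would be:
#
#     outputs = model.generate(inputs.input_ids, max_new_tokens=300, num_return_sequences=1)
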
# Create the Gradio chat interface
iface = gr.ChatInterface(
    fn=chat,
    title="Llama3.2 1B Stable Diffusion Prompter",
    description="Generate Stable Diffusion prompts with Llama 3.2",
)

# Launch the interface
logging.info("Launching Gradio interface")
iface.launch()