# Sd-Prompter / app.py
# Hugging Face Space by mrcuddle ("Create app.py", commit 0874cc4).
import gradio as gr
from transformers import AutoModelForCausalLM, AutoTokenizer
import torch
# Hugging Face Hub checkpoint served by this app; the tokenizer and the
# causal-LM weights are both loaded from this same repository id.
model_name = "mrcuddle/SD-Prompter-1B"
tokenizer, model = (
    AutoTokenizer.from_pretrained(model_name),
    AutoModelForCausalLM.from_pretrained(model_name),
)
# Function to generate a response
def chat(message, history):
# Combine the message and history into a single input
input_text = " ".join([f"{user}: {msg}" for user, msg in history] + [f"User: {message}"])
inputs = tokenizer(input_text, return_tensors="pt")
# Generate a response
with torch.no_grad():
outputs = model.generate(inputs.input_ids, max_length=50, num_return_sequences=1)
response = tokenizer.decode(outputs[0], skip_special_tokens=True)
# Extract only the new response part
response = response.replace(input_text, "").strip()
# Append the new message and response to the history
history.append(("User", message))
history.append(("Assistant", response))
return history, history
# Create the Gradio chat interface
iface = gr.ChatInterface(
fn=chat,
title="Llama3.2 1B Stable Diffusion Prompter",
description="Generate Stable Diffusion Prompt with Llama3.2"
)
# Launch the interface
iface.launch()