import gradio as gr
from transformers import AutoModelForCausalLM, AutoTokenizer
import torch
import spaces
import logging
# Configure logging
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
# Load the model and tokenizer
model_name = "mrcuddle/SD-Prompter"
logging.info(f"Loading model and tokenizer for {model_name}")
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(model_name)
# Move the model to a GPU when one is available; fall back to CPU otherwise
model.to("cuda" if torch.cuda.is_available() else "cpu")
logging.info("Model and tokenizer loaded successfully")
# Function to generate a response
@spaces.GPU
def chat(message, history):
    logging.info(f"Received message: {message}")
    logging.info(f"Chat history: {history}")
    # Flatten the history and the new message into a single prompt.
    # gr.ChatInterface passes history as (user_message, assistant_message) pairs.
    turns = []
    for user_msg, assistant_msg in history:
        turns.append(f"User: {user_msg}")
        turns.append(f"Assistant: {assistant_msg}")
    turns.append(f"User: {message}")
    input_text = "\n".join(turns)
    logging.info(f"Input text: {input_text}")
    inputs = tokenizer(input_text, return_tensors="pt").to(model.device)
    logging.info(f"Tokenized input: {inputs}")
    # Generate a response (max_new_tokens bounds only the reply, not the prompt)
    with torch.no_grad():
        outputs = model.generate(
            inputs.input_ids,
            attention_mask=inputs.attention_mask,
            max_new_tokens=256,
            num_return_sequences=1,
        )
    # Decode only the newly generated tokens so the prompt is not echoed back
    response = tokenizer.decode(
        outputs[0][inputs.input_ids.shape[-1]:], skip_special_tokens=True
    ).strip()
    logging.info(f"Generated response: {response}")
    # gr.ChatInterface tracks the conversation itself, so return only the reply
    return response
# Create the Gradio chat interface
iface = gr.ChatInterface(
    fn=chat,
    title="Llama 3.2 1B Stable Diffusion Prompter",
    description="Generate Stable Diffusion prompts with Llama 3.2",
)
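# Note: gr.ChatInterface expects `fn` to return just the assistant's reply and
# manages the displayed conversation itself. The tuple-style history unpacked
# in chat() matches the Gradio 4.x default; newer releases prefer
# type="messages", where history entries arrive as {"role": ..., "content": ...} dicts.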
# Launch the interface
logging.info("Launching Gradio interface")
iface.launch()