mrcuddle committed
Commit 0874cc4 · verified · Parent(s): d5e5fe2

Create app.py

Files changed (1)
  1. app.py +43 -0
app.py ADDED
@@ -0,0 +1,43 @@
+ import gradio as gr
+ from transformers import AutoModelForCausalLM, AutoTokenizer
+ import torch
+
+ # Load the model and tokenizer
+ model_name = "mrcuddle/SD-Prompter-1B"
+ tokenizer = AutoTokenizer.from_pretrained(model_name)
+ model = AutoModelForCausalLM.from_pretrained(model_name)
+
+ # Function to generate a response
+ def chat(message, history):
+     # Combine the chat history and the new message into a single prompt
+     turns = [f"User: {u}\nAssistant: {a}" for u, a in history]
+     turns.append(f"User: {message}")
+     input_text = "\n".join(turns)
+     inputs = tokenizer(input_text, return_tensors="pt")
+
+     # Generate a response
+     with torch.no_grad():
+         outputs = model.generate(
+             inputs.input_ids,
+             attention_mask=inputs.attention_mask,
+             max_new_tokens=50,
+             num_return_sequences=1,
+         )
+
+     # Decode only the newly generated tokens, not the echoed prompt
+     new_tokens = outputs[0][inputs.input_ids.shape[-1]:]
+     response = tokenizer.decode(new_tokens, skip_special_tokens=True)
+
+     # gr.ChatInterface manages the history itself, so return only the reply
+     return response.strip()
+
+ # Create the Gradio chat interface
+ iface = gr.ChatInterface(
+     fn=chat,
+     title="Llama3.2 1B Stable Diffusion Prompter",
+     description="Generate Stable Diffusion prompts with Llama3.2",
+ )
+
+ # Launch the interface
+ if __name__ == "__main__":
+     iface.launch()
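
As a quick sanity check outside the Gradio UI, the chat function can be called directly. A minimal sketch, assuming app.py sits in the working directory and the mrcuddle/SD-Prompter-1B weights can be fetched from the Hub; the __main__ guard above keeps the import from launching the server. The smoke_test.py file name is hypothetical, not part of this commit.

  # smoke_test.py — hypothetical helper, not part of the commit.
  # Importing app triggers the model download; iface.launch() is skipped
  # because of the __name__ == "__main__" guard in app.py.
  from app import chat

  # An empty history simulates the first turn of a conversation.
  reply = chat("a medieval castle at sunset, cinematic lighting", [])
  print(reply)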