tori29umai commited on
Commit
5826348
1 Parent(s): 2b32e3d
Files changed (2) hide show
  1. app.py +2 -11
  2. utils/prompt_analysis.py +1 -1
app.py CHANGED
@@ -2,7 +2,6 @@ import spaces
2
  import gradio as gr
3
  import torch
4
  from diffusers import ControlNetModel, StableDiffusionXLControlNetImg2ImgPipeline, DDIMScheduler
5
- from compel import Compel, ReturnedEmbeddingsType
6
  from PIL import Image
7
  import os
8
  import time
@@ -14,7 +13,6 @@ class Img2Img:
14
  def __init__(self):
15
  self.setup_paths()
16
  self.setup_models()
17
- self.compel = self.setup_compel()
18
  self.demo = self.layout()
19
 
20
  def setup_paths(self):
@@ -46,13 +44,6 @@ class Img2Img:
46
  self.pipe.load_lora_weights(self.lora_dir, weight_name="sdxl_BWLine.safetensors")
47
  self.pipe = self.pipe.to(self.device)
48
 
49
- def setup_compel(self):
50
- return Compel(
51
- tokenizer=[self.pipe.tokenizer, self.pipe.tokenizer_2],
52
- text_encoder=[self.pipe.text_encoder, self.pipe.text_encoder_2],
53
- returned_embeddings_type=ReturnedEmbeddingsType.PENULTIMATE_HIDDEN_STATES_NON_NORMALIZED,
54
- requires_pooled=[False, True],
55
- )
56
 
57
  def layout(self):
58
  css = """
@@ -65,13 +56,13 @@ class Img2Img:
65
  with gr.Blocks(css=css) as demo:
66
  with gr.Row():
67
  with gr.Column():
68
- self.input_image_path = gr.Image(label="入力画像", type='filepath')
69
  self.prompt_analysis = PromptAnalysis(self.tagger_dir)
70
  self.prompt, self.negative_prompt = self.prompt_analysis.layout(self.input_image_path)
71
  self.controlnet_scale = gr.Slider(minimum=0.5, maximum=1.25, value=1.0, step=0.01, label="線画忠実度")
72
  generate_button = gr.Button("生成")
73
  with gr.Column():
74
- self.output_image = gr.Image(type="pil", label="生成画像")
75
 
76
  generate_button.click(
77
  fn=self.predict,
 
2
  import gradio as gr
3
  import torch
4
  from diffusers import ControlNetModel, StableDiffusionXLControlNetImg2ImgPipeline, DDIMScheduler
 
5
  from PIL import Image
6
  import os
7
  import time
 
13
  def __init__(self):
14
  self.setup_paths()
15
  self.setup_models()
 
16
  self.demo = self.layout()
17
 
18
  def setup_paths(self):
 
44
  self.pipe.load_lora_weights(self.lora_dir, weight_name="sdxl_BWLine.safetensors")
45
  self.pipe = self.pipe.to(self.device)
46
 
 
 
 
 
 
 
 
47
 
48
  def layout(self):
49
  css = """
 
56
  with gr.Blocks(css=css) as demo:
57
  with gr.Row():
58
  with gr.Column():
59
+ self.input_image_path = gr.Image(label="input_image", type='filepath')
60
  self.prompt_analysis = PromptAnalysis(self.tagger_dir)
61
  self.prompt, self.negative_prompt = self.prompt_analysis.layout(self.input_image_path)
62
  self.controlnet_scale = gr.Slider(minimum=0.5, maximum=1.25, value=1.0, step=0.01, label="線画忠実度")
63
  generate_button = gr.Button("生成")
64
  with gr.Column():
65
+ self.output_image = gr.Image(type="pil", label="出力画像")
66
 
67
  generate_button.click(
68
  fn=self.predict,
utils/prompt_analysis.py CHANGED
@@ -22,7 +22,7 @@ class PromptAnalysis:
22
  with gr.Row():
23
  self.negative_prompt = gr.Textbox(label="negative_prompt", lines=3, value=self.default_nagative_prompt)
24
  with gr.Row():
25
- self.prompt_analysis_button = gr.Button()
26
 
27
  self.prompt_analysis_button.click(
28
  self.process_prompt_analysis,
 
22
  with gr.Row():
23
  self.negative_prompt = gr.Textbox(label="negative_prompt", lines=3, value=self.default_nagative_prompt)
24
  with gr.Row():
25
+ self.prompt_analysis_button = gr.Button("prompt解析")
26
 
27
  self.prompt_analysis_button.click(
28
  self.process_prompt_analysis,