Spaces: Running on Zero
yeq6x committed • c9cc441
Parent(s): init

Files changed:
- .gitignore +5 -0
- Dockerfile.backend +35 -0
- anime.py +152 -0
- app.py +187 -0
- data.py +97 -0
- generate_prompt.py +154 -0
- lineart_util.py +109 -0
- model.py +186 -0
- process_utils.py +345 -0
- requirements.txt +19 -0
- templates/index.html +60 -0
.gitignore
ADDED
@@ -0,0 +1,5 @@
models/
__pycache__/
venv/
output/
hf_gradio/
Dockerfile.backend
ADDED
@@ -0,0 +1,35 @@
# Use a CUDA 12.1 Ubuntu base image
FROM nvidia/cuda:12.1.0-cudnn8-devel-ubuntu20.04

RUN ln -sf /usr/share/zoneinfo/Asia/Tokyo /etc/localtime

# Install required packages
RUN apt-get update && apt-get install -y \
    software-properties-common \
    && add-apt-repository ppa:deadsnakes/ppa \
    && apt-get update && apt-get install -y \
    python3.10 \
    python3.10-dev \
    python3.10-distutils \
    wget \
    && rm -rf /var/lib/apt/lists/*
# Install pip
RUN wget https://bootstrap.pypa.io/get-pip.py \
    && python3.10 get-pip.py \
    && rm get-pip.py
# Link the default python and pip commands to python3.10 and pip3.10
RUN update-alternatives --install /usr/bin/python python /usr/bin/python3.10 1 \
    && update-alternatives --install /usr/bin/pip pip /usr/local/bin/pip3.10 1

WORKDIR /app

# Install dependencies
COPY requirements.txt /app/
RUN apt -y update && apt -y upgrade
RUN apt -y install libopencv-dev
RUN pip install --no-cache-dir -r requirements.txt
RUN pip install --no-dependencies transformers

EXPOSE 5000

CMD ["python", "app.py"]
anime.py
ADDED
@@ -0,0 +1,152 @@
"""Test script for anime-to-sketch translation
Example:
    python3 test.py --dataroot /your_path/dir --load_size 512
    python3 test.py --dataroot /your_path/img.jpg --load_size 512
"""

import os
import torch
from torchvision import transforms
from data import get_image_list, get_transform
from model import create_model
from data import read_img_path, tensor_to_img, save_image
import argparse
from tqdm.auto import tqdm
from kornia.enhance import equalize_clahe
from PIL import Image
import numpy as np


# Takes an image as a numpy array, generates a line drawing, and returns it as a numpy array
def generate_sketch(image, clahe_clip=-1, load_size=512):
    """
    Generate sketch image from input image
    Args:
        image (np.ndarray): input image
        clahe_clip (float): clip threshold for CLAHE
        load_size (int): image size to load
    Returns:
        np.ndarray: output image
    """
    # create model
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    model_opt = "default"
    model = create_model(model_opt).to(device)
    model.eval()

    aus_resize = None
    if load_size > 0:
        aus_resize = (image.shape[0], image.shape[1])
    transform = get_transform(load_size=load_size)
    image = torch.from_numpy(image).permute(2, 0, 1).float()
    # [0,255] to [-1,1]
    image = transform(image)
    if image.max() > 1:
        image = (image - image.min()) / (image.max() - image.min()) * 2 - 1

    img, aus_resize = image.unsqueeze(0), aus_resize
    if clahe_clip > 0:
        img = (img + 1) / 2  # [-1,1] to [0,1]
        img = equalize_clahe(img, clip_limit=clahe_clip)
        img = (img - .5) / .5  # [0,1] to [-1,1]

    aus_tensor = model(img.to(device))

    # resize to original size
    if aus_resize is not None:
        aus_tensor = torch.nn.functional.interpolate(aus_tensor, aus_resize, mode='bilinear', align_corners=False)

    aus_img = tensor_to_img(aus_tensor)
    return aus_img


if __name__ == '__main__':
    os.chdir(os.path.dirname("Anime2Sketch/"))
    parser = argparse.ArgumentParser(description='Anime-to-sketch test options.')
    parser.add_argument('--dataroot', '-i', default='test_samples/', type=str)
    parser.add_argument('--load_size', '-s', default=512, type=int)
    parser.add_argument('--output_dir', '-o', default='results/', type=str)
    parser.add_argument('--gpu_ids', '-g', default=[], help="gpu ids: e.g. 0 0,1,2 0,2.")
    parser.add_argument('--model', default="default", type=str, help="variant of model to use. you can choose from ['default','improved']")
    parser.add_argument('--clahe_clip', default=-1, type=float, help="clip threshold for CLAHE set to -1 to disable")
    opt = parser.parse_args()

    # # Generate line art with generate_sketch
    # for test_path in tqdm(get_image_list(opt.dataroot)):
    #     basename = os.path.basename(test_path)
    #     aus_path = os.path.join(opt.output_dir, basename)
    #     # Read the image as a numpy array
    #     img = Image.open(test_path)
    #     img = np.array(img)
    #     aus_img = generate_sketch(img, opt.clahe_clip)
    #     # Save the image
    #     save_image(aus_img, aus_path, (512, 512))


    # create model
    gpu_list = ','.join(str(x) for x in opt.gpu_ids)
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    model = create_model(opt.model).to(device)  # create a model given opt.model and other options
    model.eval()

    for test_path in tqdm(get_image_list(opt.dataroot)):
        basename = os.path.basename(test_path)
        aus_path = os.path.join(opt.output_dir, basename)

        img = Image.open(test_path).convert('RGB')
        img = np.array(img)

        load_size = 512
        aus_resize = None
        if load_size > 0:
            aus_resize = (img.shape[1], img.shape[0])
        transform = get_transform(load_size=load_size)
        img = torch.from_numpy(img).permute(2, 0, 1).float()
        # [0,255] to [-1,1]
        image = transform(img)
        if image.max() > 1:
            image = (image - image.min()) / (image.max() - image.min()) * 2 - 1
        print(image.min(), image.max())

        img, aus_resize = image.unsqueeze(0), aus_resize
        if opt.clahe_clip > 0:
            img = (img + 1) / 2  # [-1,1] to [0,1]
            img = equalize_clahe(img, clip_limit=opt.clahe_clip)
            img = (img - .5) / .5  # [0,1] to [-1,1]

        aus_tensor = model(img.to(device))
        aus_img = tensor_to_img(aus_tensor)
        save_image(aus_img, aus_path, aus_resize)
    """
    # create model
    gpu_list = ','.join(str(x) for x in opt.gpu_ids)
    os.environ['CUDA_VISIBLE_DEVICES'] = gpu_list
    device = torch.device('cuda' if len(opt.gpu_ids)>0 else 'cpu')
    model = create_model(opt.model).to(device)      # create a model given opt.model and other options
    model.eval()
    # get input data
    if os.path.isdir(opt.dataroot):
        test_list = get_image_list(opt.dataroot)
    elif os.path.isfile(opt.dataroot):
        test_list = [opt.dataroot]
    else:
        raise Exception("{} is not a valid directory or image file.".format(opt.dataroot))
    # save outputs
    save_dir = opt.output_dir
    os.makedirs(save_dir, exist_ok=True)

    for test_path in tqdm(test_list):
        basename = os.path.basename(test_path)
        aus_path = os.path.join(save_dir, basename)
        img, aus_resize = read_img_path(test_path, opt.load_size)

        if opt.clahe_clip > 0:
            img = (img + 1) / 2  # [-1,1] to [0,1]
            img = equalize_clahe(img, clip_limit=opt.clahe_clip)
            img = (img - .5) / .5  # [0,1] to [-1,1]

        aus_tensor = model(img.to(device))
        print(aus_tensor.shape)
        aus_img = tensor_to_img(aus_tensor)
        save_image(aus_img, aus_path, aus_resize)
    """
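A minimal usage sketch for generate_sketch above. It assumes the full requirements are installed and that the Space's model-repo environment variables are configured (create_model fetches netG.pth through download_file); "input.png" and "sketch.png" are placeholder paths, not files from this commit.

# Hypothetical usage of anime.generate_sketch; file names are placeholders.
import numpy as np
from PIL import Image
from anime import generate_sketch
from data import save_image

img = np.array(Image.open("input.png").convert("RGB"))
sketch = generate_sketch(img, clahe_clip=1.0, load_size=512)    # returns a numpy array
save_image(sketch, "sketch.png", (img.shape[1], img.shape[0]))  # resize back to the input size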
app.py
ADDED
@@ -0,0 +1,187 @@
from flask import Flask, request, render_template, send_file, jsonify, send_from_directory
from flask_socketio import SocketIO, emit
from flask_cors import CORS
import io
import os
from PIL import Image
import torch
import gc
from peft import PeftModel

import queue
import threading
import uuid
import concurrent.futures
from process_utils import *

app = Flask(__name__)
# app.secret_key = 'super_secret_key'
CORS(app)
socketio = SocketIO(app, cors_allowed_origins="*")

# Create the task queue
task_queue = queue.Queue()
active_tasks = {}
task_futures = {}

# Create the ThreadPoolExecutor
executor = concurrent.futures.ThreadPoolExecutor(max_workers=1)

class Task:
    def __init__(self, task_id, mode, weight1, weight2, file_data):
        self.task_id = task_id
        self.mode = mode
        self.weight1 = weight1
        self.weight2 = weight2
        self.file_data = file_data
        self.cancel_flag = False

def update_queue_status(message=None):
    socketio.emit('queue_update', {'active_tasks': len(active_tasks), 'message': message})

def process_task(task):
    try:
        # Convert the file data to a PIL Image
        image = Image.open(io.BytesIO(task.file_data))
        image = ensure_rgb(image)

        # Check for cancellation
        if task.cancel_flag:
            return

        # Call the image-processing logic
        sotai_image, sketch_image = process_image_as_base64(image, task.mode, task.weight1, task.weight2)

        # Check for cancellation
        if task.cancel_flag:
            return

        socketio.emit('task_complete', {
            'task_id': task.task_id,
            'sotai_image': sotai_image,
            'sketch_image': sketch_image
        })
    except Exception as e:
        if not task.cancel_flag:
            socketio.emit('task_error', {'task_id': task.task_id, 'error': str(e)})
    finally:
        if task.task_id in active_tasks:
            del active_tasks[task.task_id]
        if task.task_id in task_futures:
            del task_futures[task.task_id]
        update_queue_status('Task completed or cancelled')

def worker():
    while True:
        try:
            task = task_queue.get()
            if task.task_id in active_tasks:
                future = executor.submit(process_task, task)
                task_futures[task.task_id] = future
                update_queue_status(f'Task started: {task.task_id}')
        except Exception as e:
            print(f"Worker error: {str(e)}")
        finally:
            # Ensure the task is always removed from the queue
            task_queue.task_done()

# Start the worker thread
threading.Thread(target=worker, daemon=True).start()

@app.route('/submit_task', methods=['POST'])
def submit_task():
    task_id = str(uuid.uuid4())
    file = request.files['file']
    mode = request.form.get('mode', 'refine')
    weight1 = float(request.form.get('weight1', 0.4))
    weight2 = float(request.form.get('weight2', 0.3))

    # Keep the uploaded file as raw bytes
    file_data = file.read()

    task = Task(task_id, mode, weight1, weight2, file_data)
    task_queue.put(task)
    active_tasks[task_id] = task

    update_queue_status(f'Task submitted: {task_id}')

    queue_size = task_queue.qsize()
    return jsonify({'task_id': task_id, 'queue_size': queue_size})

@app.route('/cancel_task/<task_id>', methods=['POST'])
def cancel_task(task_id):
    if task_id in active_tasks:
        task = active_tasks[task_id]
        task.cancel_flag = True
        if task_id in task_futures:
            task_futures[task_id].cancel()
            del task_futures[task_id]
        del active_tasks[task_id]
        update_queue_status('Task cancelled')
        return jsonify({'message': 'Task cancellation requested'})
    else:
        return jsonify({'message': 'Task not found or already completed'}), 404

def get_active_task_order(task_id):
    return list(active_tasks.keys()).index(task_id) if task_id in active_tasks else None

# Handler for get_task_order requests
@app.route('/get_task_order/<task_id>', methods=['GET'])
def handle_get_task_order(task_id):
    task_order = get_active_task_order(task_id)
    return jsonify({'task_order': task_order})

@socketio.on('connect')
def handle_connect():
    emit('queue_update', {'active_tasks': len(active_tasks), 'active_task_order': None})

# Flask routes
@app.route('/', methods=['GET', 'POST'])
def process_refined():
    if request.method == 'POST':
        file = request.files['file']
        weight1 = float(request.form.get('weight1', 0.4))
        weight2 = float(request.form.get('weight2', 0.3))

        image = ensure_rgb(Image.open(file.stream))
        sotai_image, sketch_image = process_image_as_base64(image, "refine", weight1, weight2)

        return jsonify({
            'sotai_image': sotai_image,
            'sketch_image': sketch_image
        })

@app.route('/process_original', methods=['GET', 'POST'])
def process_original():
    if request.method == 'POST':
        file = request.files['file']

        image = ensure_rgb(Image.open(file.stream))
        sotai_image, sketch_image = process_image_as_base64(image, "original")

        return jsonify({
            'sotai_image': sotai_image,
            'sketch_image': sketch_image
        })

@app.route('/process_sketch', methods=['GET', 'POST'])
def process_sketch():
    if request.method == 'POST':
        file = request.files['file']

        image = ensure_rgb(Image.open(file.stream))
        sotai_image, sketch_image = process_image_as_base64(image, "sketch")

        return jsonify({
            'sotai_image': sotai_image,
            'sketch_image': sketch_image
        })

# Error handler
@app.errorhandler(500)
def server_error(e):
    return jsonify(error=str(e)), 500

if __name__ == '__main__':
    initialize(local_model=True)
    socketio.run(app, debug=True, host='0.0.0.0', port=5000)
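A minimal client sketch for the /submit_task endpoint above, using the Requests package from requirements.txt. The host/port and "test.png" are assumptions, and the generated images are delivered later via the Socket.IO 'task_complete' event rather than in this HTTP response.

# Hypothetical client call; "http://localhost:5000" and "test.png" are placeholders.
import requests

with open("test.png", "rb") as f:
    resp = requests.post(
        "http://localhost:5000/submit_task",
        files={"file": f},
        data={"mode": "refine", "weight1": "0.4", "weight2": "0.3"},
    )
print(resp.json())  # e.g. {'task_id': '...', 'queue_size': 1}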
data.py
ADDED
@@ -0,0 +1,97 @@
import os
from PIL import Image
import torchvision.transforms as transforms
try:
    from transforms import InterpolationMode
    bic = InterpolationMode.BICUBIC
except ImportError:
    bic = Image.BICUBIC

import numpy as np
import torch

IMG_EXTENSIONS = ['.jpg', '.JPG', '.jpeg', '.JPEG', '.png', '.PNG', '.ppm', '.PPM', '.bmp', '.BMP']

def is_image_file(filename):
    """if a given filename is a valid image
    Parameters:
        filename (str) -- image filename
    """
    return any(filename.endswith(extension) for extension in IMG_EXTENSIONS)

def get_image_list(path):
    """read the paths of valid images from the given directory path
    Parameters:
        path (str) -- input directory path
    """
    assert os.path.isdir(path), '{:s} is not a valid directory'.format(path)
    images = []
    for dirpath, _, fnames in sorted(os.walk(path)):
        for fname in sorted(fnames):
            if is_image_file(fname):
                img_path = os.path.join(dirpath, fname)
                images.append(img_path)
    assert images, '{:s} has no valid image file'.format(path)
    return images

def get_transform(load_size=0, grayscale=False, method=bic, convert=True):
    transform_list = []
    if grayscale:
        transform_list.append(transforms.Grayscale(1))
    if load_size > 0:
        osize = [load_size, load_size]
        transform_list.append(transforms.Resize(osize, method))
    if convert:
        # transform_list += [transforms.ToTensor()]
        if grayscale:
            transform_list += [transforms.Normalize((0.5,), (0.5,))]
        else:
            transform_list += [transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))]
    return transforms.Compose(transform_list)

def read_img_path(path, load_size):
    """read tensors from a given image path
    Parameters:
        path (str)     -- input image path
        load_size(int) -- the input size. If <= 0, don't resize
    """
    img = Image.open(path).convert('RGB')
    aus_resize = None
    if load_size > 0:
        aus_resize = img.size
    transform = get_transform(load_size=load_size)
    image = transform(img)
    return image.unsqueeze(0), aus_resize

def tensor_to_img(input_image, imtype=np.uint8):
    """Converts a Tensor array into a numpy image array.
    Parameters:
        input_image (tensor) -- the input image tensor array
        imtype (type)        -- the desired type of the converted numpy array
    """

    if not isinstance(input_image, np.ndarray):
        if isinstance(input_image, torch.Tensor):  # get the data from a variable
            image_tensor = input_image.data
        else:
            return input_image
        image_numpy = image_tensor[0].cpu().float().numpy()  # convert it into a numpy array
        if image_numpy.shape[0] == 1:  # grayscale to RGB
            image_numpy = np.tile(image_numpy, (3, 1, 1))
        image_numpy = (np.transpose(image_numpy, (1, 2, 0)) + 1) / 2.0 * 255.0  # post-processing: transpose and scaling
    else:  # if it is a numpy array, do nothing
        image_numpy = input_image
    return image_numpy.astype(imtype)

def save_image(image_numpy, image_path, output_resize=None):
    """Save a numpy image to the disk
    Parameters:
        image_numpy (numpy array)    -- input numpy array
        image_path (str)             -- the path of the image
        output_resize(None or tuple) -- the output size. If None, don't resize
    """

    image_pil = Image.fromarray(image_numpy)
    if output_resize:
        image_pil = image_pil.resize(output_resize, bic)
    image_pil.save(image_path)
generate_prompt.py
ADDED
@@ -0,0 +1,154 @@
import argparse
import csv
import os
import json

from PIL import Image
import cv2
import numpy as np
from tensorflow.keras.layers import TFSMLayer
from huggingface_hub import hf_hub_download
from pathlib import Path

# from wd14 tagger
IMAGE_SIZE = 448

# wd-v1-4-swinv2-tagger-v2 / wd-v1-4-vit-tagger / wd-v1-4-vit-tagger-v2 / wd-v1-4-convnext-tagger / wd-v1-4-convnext-tagger-v2
DEFAULT_WD14_TAGGER_REPO = "SmilingWolf/wd-v1-4-convnext-tagger-v2"
FILES = ["keras_metadata.pb", "saved_model.pb", "selected_tags.csv"]
SUB_DIR = "variables"
SUB_DIR_FILES = ["variables.data-00000-of-00001", "variables.index"]
CSV_FILE = FILES[-1]

def preprocess_image(image):
    image = np.array(image)
    image = image[:, :, ::-1]  # RGB->BGR

    # pad to square
    size = max(image.shape[0:2])
    pad_x = size - image.shape[1]
    pad_y = size - image.shape[0]
    pad_l = pad_x // 2
    pad_t = pad_y // 2
    image = np.pad(image, ((pad_t, pad_y - pad_t), (pad_l, pad_x - pad_l), (0, 0)), mode="constant", constant_values=255)

    interp = cv2.INTER_AREA if size > IMAGE_SIZE else cv2.INTER_LANCZOS4
    image = cv2.resize(image, (IMAGE_SIZE, IMAGE_SIZE), interpolation=interp)

    image = image.astype(np.float32)
    return image


def load_wd14_tagger_model():
    model_dir = "wd14_tagger_model"
    repo_id = DEFAULT_WD14_TAGGER_REPO

    if not os.path.exists(model_dir):
        print(f"downloading wd14 tagger model from hf_hub. id: {repo_id}")
        for file in FILES:
            hf_hub_download(repo_id, file, cache_dir=model_dir, force_download=True, force_filename=file)
        for file in SUB_DIR_FILES:
            hf_hub_download(
                repo_id,
                file,
                subfolder=SUB_DIR,
                cache_dir=os.path.join(model_dir, SUB_DIR),
                force_download=True,
                force_filename=file,
            )
    else:
        print("using existing wd14 tagger model")

    # Load the model
    model = TFSMLayer(model_dir, call_endpoint='serving_default')
    return model


def generate_tags(images, model_dir, model):
    with open(os.path.join(model_dir, CSV_FILE), "r", encoding="utf-8") as f:
        reader = csv.reader(f)
        l = [row for row in reader]
        header = l[0]  # tag_id,name,category,count
        rows = l[1:]
    assert header[0] == "tag_id" and header[1] == "name" and header[2] == "category", f"unexpected csv format: {header}"

    general_tags = [row[1] for row in rows[1:] if row[2] == "0"]
    character_tags = [row[1] for row in rows[1:] if row[2] == "4"]

    tag_freq = {}
    undesired_tags = ['one-piece_swimsuit',
                      'swimsuit',
                      'leotard',
                      'saitama_(one-punch_man)',
                      '1boy',
                      ]

    probs = model(images, training=False)
    probs = probs['predictions_sigmoid'].numpy()

    tag_text_list = []
    for prob in probs:
        combined_tags = []
        general_tag_text = ""
        character_tag_text = ""
        thresh = 0.35
        for i, p in enumerate(prob[4:]):
            if i < len(general_tags) and p >= thresh:
                tag_name = general_tags[i]
                if tag_name not in undesired_tags:
                    tag_freq[tag_name] = tag_freq.get(tag_name, 0) + 1
                    general_tag_text += ", " + tag_name
                    combined_tags.append(tag_name)
            elif i >= len(general_tags) and p >= thresh:
                tag_name = character_tags[i - len(general_tags)]
                if tag_name not in undesired_tags:
                    tag_freq[tag_name] = tag_freq.get(tag_name, 0) + 1
                    character_tag_text += ", " + tag_name
                    combined_tags.append(tag_name)

        if len(general_tag_text) > 0:
            general_tag_text = general_tag_text[2:]
        if len(character_tag_text) > 0:
            character_tag_text = character_tag_text[2:]

        tag_text = ", ".join(combined_tags)
        tag_text_list.append(tag_text)
    return tag_text_list


def generate_prompt_json(target_folder, prompt_file, model_dir, model):
    image_files = [f for f in os.listdir(target_folder) if os.path.isfile(os.path.join(target_folder, f))]
    image_count = len(image_files)

    prompt_list = []

    for i, filename in enumerate(image_files, 1):
        source_path = "source/" + filename
        target_path = os.path.join(target_folder, filename)  # Use absolute path
        target_path2 = "target/" + filename

        prompt = generate_tags(target_path, model_dir, model)

        for j in range(4):
            prompt_data = {
                "source": f"{source_path.split('.')[0]}_{j}.jpg",
                "target": f"{target_path2.split('.')[0]}_{j}.jpg",
                "prompt": prompt
            }

            prompt_list.append(prompt_data)

        print(f"Processed Images: {i}/{image_count}", end="\r", flush=True)

    with open(prompt_file, "w") as file:
        for prompt_data in prompt_list:
            json.dump(prompt_data, file)
            file.write("\n")

    print(f"Processing completed. Total Images: {image_count}")


if __name__ == '__main__':
    model_dir = "wd14_tagger_model"
    model = load_wd14_tagger_model()
    prompt = generate_tags(target_path, model_dir, model)
lineart_util.py
ADDED
@@ -0,0 +1,109 @@
import cv2
import numpy as np
from PIL import Image
from anime import generate_sketch

def pad64(x):
    return int(np.ceil(float(x) / 64.0) * 64 - x)

def HWC3(x):
    assert x.dtype == np.uint8
    if x.ndim == 2:
        x = x[:, :, None]
    assert x.ndim == 3
    H, W, C = x.shape
    assert C == 1 or C == 3 or C == 4
    if C == 3:
        return x
    if C == 1:
        return np.concatenate([x, x, x], axis=2)
    if C == 4:
        color = x[:, :, 0:3].astype(np.float32)
        alpha = x[:, :, 3:4].astype(np.float32) / 255.0
        y = color * alpha + 255.0 * (1.0 - alpha)
        y = y.clip(0, 255).astype(np.uint8)
        return y

def safer_memory(x):
    # Fix many MAC/AMD problems
    return np.ascontiguousarray(x.copy()).copy()

def resize_image_with_pad(input_image, resolution, skip_hwc3=False):
    if skip_hwc3:
        img = input_image
    else:
        img = HWC3(input_image)
    H_raw, W_raw, _ = img.shape
    k = float(resolution) / float(min(H_raw, W_raw))
    interpolation = cv2.INTER_CUBIC if k > 1 else cv2.INTER_AREA
    H_target = int(np.round(float(H_raw) * k))
    W_target = int(np.round(float(W_raw) * k))
    img = cv2.resize(img, (W_target, H_target), interpolation=interpolation)
    H_pad, W_pad = pad64(H_target), pad64(W_target)
    img_padded = np.pad(img, [[0, H_pad], [0, W_pad], [0, 0]], mode='edge')

    def remove_pad(x):
        return safer_memory(x[:H_target, :W_target])

    return safer_memory(img_padded), remove_pad

def scribble_xdog(img, res=512, thr_a=32, **kwargs):
    """
    Generate a sketch image with XDoG
    :param img: np.ndarray, input image
    :param res: int, output image resolution
    :param thr_a: int, threshold

    Returns
    -------
    Image : PIL.Image
    """
    img, remove_pad = resize_image_with_pad(img, res)
    g1 = cv2.GaussianBlur(img.astype(np.float32), (0, 0), 0.5)
    g2 = cv2.GaussianBlur(img.astype(np.float32), (0, 0), 5.0)
    dog = (255 - np.min(g2 - g1, axis=2)).clip(0, 255).astype(np.uint8)
    result = np.zeros_like(img, dtype=np.uint8)
    result[2 * (255 - dog) > thr_a] = 255
    result = Image.fromarray(remove_pad(result))
    return result, True

def canny(img, res=512, thr_a=100, thr_b=200, **kwargs):
    img, remove_pad = resize_image_with_pad(img, res)
    result = cv2.Canny(img, thr_a, thr_b)
    result = Image.fromarray(remove_pad(result))
    return result, True

def get_sketch(image, method='scribble_xdog', res=2048, thr=20, **kwargs):
    # image: np.ndarray
    input_height = image.shape[0]
    input_width = image.shape[1]

    if method == 'scribble_xdog':
        processed_image, _ = scribble_xdog(image, res, thr)  # PIL.Image
        processed_image = processed_image.resize((input_width, input_height))
        # make PIL.Image to cv2 and INVERSE
        processed_image = cv2.cvtColor(np.array(processed_image), cv2.COLOR_RGB2BGR)
        processed_image = 255 - processed_image
        processed_image = Image.fromarray(processed_image)
    elif method == 'anime2sketch':
        clahe = 1.0
        processed_image = generate_sketch(image, clahe_clip=clahe, load_size=1024)  # output: numpy.ndarray
        processed_image = Image.fromarray(processed_image)
        # processed_image.save(output_path.split('.')[0] + f'_{clahe}.png')
    elif method == 'both':
        alpha = 0.5
        # Blend the two results with weight alpha
        scribble_xdog_processed_image, _ = scribble_xdog(image, res, thr)
        scribble_xdog_processed_image = scribble_xdog_processed_image.resize((input_width, input_height))
        scribble_xdog_processed_image = cv2.cvtColor(np.array(scribble_xdog_processed_image), cv2.COLOR_RGB2BGR)
        scribble_xdog_processed_image = 255 - scribble_xdog_processed_image

        anime2sketch_processed_image = generate_sketch(image, clahe_clip=1.0, load_size=1024)
        anime2sketch_processed_image = Image.fromarray(anime2sketch_processed_image)
        anime2sketch_processed_image = anime2sketch_processed_image.resize((input_width, input_height))
        anime2sketch_processed_image = cv2.cvtColor(np.array(anime2sketch_processed_image), cv2.COLOR_RGB2BGR)

        processed_image = cv2.addWeighted(scribble_xdog_processed_image, alpha, anime2sketch_processed_image, 1 - alpha, 0)
        processed_image = Image.fromarray(processed_image)

    return processed_image
ADDED
@@ -0,0 +1,186 @@
|
import torch
import torch.nn as nn
import torch.nn.functional as F
import functools
from app import download_file

class UnetGenerator(nn.Module):
    """Create a Unet-based generator"""

    def __init__(self, input_nc, output_nc, num_downs, ngf=64, norm_layer=nn.BatchNorm2d, use_dropout=False):
        """Construct a Unet generator
        Parameters:
            input_nc (int)  -- the number of channels in input images
            output_nc (int) -- the number of channels in output images
            num_downs (int) -- the number of downsamplings in UNet. For example, if |num_downs| == 7,
                               an image of size 128x128 will become of size 1x1 at the bottleneck
            ngf (int)       -- the number of filters in the last conv layer
            norm_layer      -- normalization layer
        We construct the U-Net from the innermost layer to the outermost layer.
        It is a recursive process.
        """
        super(UnetGenerator, self).__init__()
        # construct unet structure
        unet_block = UnetSkipConnectionBlock(ngf * 8, ngf * 8, input_nc=None, submodule=None, norm_layer=norm_layer, innermost=True)  # add the innermost layer
        for _ in range(num_downs - 5):  # add intermediate layers with ngf * 8 filters
            unet_block = UnetSkipConnectionBlock(ngf * 8, ngf * 8, input_nc=None, submodule=unet_block, norm_layer=norm_layer, use_dropout=use_dropout)
        # gradually reduce the number of filters from ngf * 8 to ngf
        unet_block = UnetSkipConnectionBlock(ngf * 4, ngf * 8, input_nc=None, submodule=unet_block, norm_layer=norm_layer)
        unet_block = UnetSkipConnectionBlock(ngf * 2, ngf * 4, input_nc=None, submodule=unet_block, norm_layer=norm_layer)
        unet_block = UnetSkipConnectionBlock(ngf, ngf * 2, input_nc=None, submodule=unet_block, norm_layer=norm_layer)
        self.model = UnetSkipConnectionBlock(output_nc, ngf, input_nc=input_nc, submodule=unet_block, outermost=True, norm_layer=norm_layer)  # add the outermost layer

    def forward(self, input):
        """Standard forward"""
        return self.model(input)

class UnetSkipConnectionBlock(nn.Module):
    """Defines the Unet submodule with skip connection.
        X -------------------identity----------------------
        |-- downsampling -- |submodule| -- upsampling --|
    """

    def __init__(self, outer_nc, inner_nc, input_nc=None,
                 submodule=None, outermost=False, innermost=False, norm_layer=nn.BatchNorm2d, use_dropout=False):
        """Construct a Unet submodule with skip connections.
        Parameters:
            outer_nc (int) -- the number of filters in the outer conv layer
            inner_nc (int) -- the number of filters in the inner conv layer
            input_nc (int) -- the number of channels in input images/features
            submodule (UnetSkipConnectionBlock) -- previously defined submodules
            outermost (bool)    -- if this module is the outermost module
            innermost (bool)    -- if this module is the innermost module
            norm_layer          -- normalization layer
            use_dropout (bool)  -- if use dropout layers.
        """
        super(UnetSkipConnectionBlock, self).__init__()
        self.outermost = outermost
        if type(norm_layer) == functools.partial:
            use_bias = norm_layer.func == nn.InstanceNorm2d
        else:
            use_bias = norm_layer == nn.InstanceNorm2d
        if input_nc is None:
            input_nc = outer_nc
        downconv = nn.Conv2d(input_nc, inner_nc, kernel_size=4,
                             stride=2, padding=1, bias=use_bias)
        downrelu = nn.LeakyReLU(0.2, True)
        downnorm = norm_layer(inner_nc)
        uprelu = nn.ReLU(True)
        upnorm = norm_layer(outer_nc)

        if outermost:
            upconv = nn.ConvTranspose2d(inner_nc * 2, outer_nc,
                                        kernel_size=4, stride=2,
                                        padding=1)
            down = [downconv]
            up = [uprelu, upconv, nn.Tanh()]
            model = down + [submodule] + up
        elif innermost:
            upconv = nn.ConvTranspose2d(inner_nc, outer_nc,
                                        kernel_size=4, stride=2,
                                        padding=1, bias=use_bias)
            down = [downrelu, downconv]
            up = [uprelu, upconv, upnorm]
            model = down + up
        else:
            upconv = nn.ConvTranspose2d(inner_nc * 2, outer_nc,
                                        kernel_size=4, stride=2,
                                        padding=1, bias=use_bias)
            down = [downrelu, downconv, downnorm]
            up = [uprelu, upconv, upnorm]

            if use_dropout:
                model = down + [submodule] + up + [nn.Dropout(0.5)]
            else:
                model = down + [submodule] + up

        self.model = nn.Sequential(*model)

    def forward(self, x):
        if self.outermost:
            return self.model(x)
        else:  # add skip connections
            return torch.cat([x, self.model(x)], 1)


class Smooth(nn.Module):
    def __init__(self):
        super().__init__()
        kernel = [
            [1, 2, 1],
            [2, 4, 2],
            [1, 2, 1]
        ]
        kernel = torch.tensor([[kernel]], dtype=torch.float)
        kernel /= kernel.sum()
        self.register_buffer('kernel', kernel)
        self.pad = nn.ReplicationPad2d(1)

    def forward(self, x):
        b, c, h, w = x.shape
        x = x.view(-1, 1, h, w)
        x = self.pad(x)
        x = F.conv2d(x, self.kernel)
        return x.view(b, c, h, w)


class Upsample(nn.Module):
    def __init__(self, inc, outc, scale_factor=2):
        super().__init__()
        self.scale_factor = scale_factor
        self.up = nn.Upsample(scale_factor=scale_factor, mode='bilinear')
        self.smooth = Smooth()
        self.conv = nn.Conv2d(inc, outc, kernel_size=3, stride=1, padding=1)
        self.mlp = nn.Sequential(
            nn.Conv2d(outc, 4 * outc, kernel_size=1, stride=1, padding=0),
            nn.GELU(),
            nn.Conv2d(4 * outc, outc, kernel_size=1, stride=1, padding=0),
        )

    def forward(self, x):
        x = self.smooth(self.up(x))
        x = self.conv(x)
        x = self.mlp(x) + x
        return x


def create_model(model):
    """Create a model for anime2sketch
    hardcoding the options for simplicity
    """

    norm_layer = functools.partial(nn.InstanceNorm2d, affine=False, track_running_stats=False)
    net = UnetGenerator(3, 1, 8, 64, norm_layer=norm_layer, use_dropout=False)

    import os
    cwd = os.getcwd()  # Save the current directory
    os.chdir(os.path.dirname(__file__))  # Move to this file's directory
    if model == 'default':
        model_path = download_file("netG.pth", subfolder="models/Anime2Sketch")
        ckpt = torch.load(model_path)
        for key in list(ckpt.keys()):
            if 'module.' in key:
                ckpt[key.replace('module.', '')] = ckpt[key]
                del ckpt[key]
        net.load_state_dict(ckpt)

        os.chdir(cwd)  # Return to the original directory

    elif model == 'improved':
        ckpt = torch.load('weights/improved.bin', map_location=torch.device('cpu'))
        base = net.model.model[1]

        # swap deconvolution layers with resize + conv layers for 2x upsampling
        for _ in range(6):
            inc, outc = base.model[5].in_channels, base.model[5].out_channels
            base.model[5] = Upsample(inc, outc)
            base = base.model[3]

        net.load_state_dict(ckpt)

        os.chdir(cwd)  # Return to the original directory

    else:
        raise ValueError(f"model should be one of ['default', 'improved'], but got {model}")

    return net
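A minimal smoke-test sketch of the generator architecture in isolation, bypassing create_model's weight download (which requires the Space's private model repo and its environment variables); the input is random data, so the output is untrained noise and only the shapes are meaningful.

# Hypothetical shape check of the U-Net generator defined above; no pretrained weights are loaded.
import functools
import torch
import torch.nn as nn
from model import UnetGenerator  # importing model also pulls in app/process_utils

norm_layer = functools.partial(nn.InstanceNorm2d, affine=False, track_running_stats=False)
net = UnetGenerator(3, 1, 8, 64, norm_layer=norm_layer, use_dropout=False).eval()

x = torch.randn(1, 3, 512, 512)  # one RGB image scaled to [-1, 1]
with torch.no_grad():
    y = net(x)                   # single-channel sketch map
print(y.shape)                   # torch.Size([1, 1, 512, 512])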
process_utils.py
ADDED
@@ -0,0 +1,345 @@
import io
import os
import base64
from PIL import Image
import cv2
import numpy as np
from generate_prompt import load_wd14_tagger_model, generate_tags, preprocess_image as wd14_preprocess_image
from lineart_util import scribble_xdog, get_sketch, canny
import torch
from diffusers import StableDiffusionPipeline, StableDiffusionControlNetPipeline, ControlNetModel, UniPCMultistepScheduler, AutoencoderKL
import gc
from peft import PeftModel
from huggingface_hub import hf_hub_download
from dotenv import load_dotenv

load_dotenv()

# Global variables
local_model = False
model = None
# device = "cuda" if torch.cuda.is_available() else "cpu"
device = "cpu"
torch_dtype = torch.float16 if device == "cuda" else torch.float32
sotai_gen_pipe = None
refine_gen_pipe = None

def download_file(filename, subfolder=None):
    return hf_hub_download(
        repo_id=os.environ['REPO_ID'],
        filename=filename,
        subfolder=subfolder,
        token=os.environ['HF_TOKEN'],
        cache_dir=os.environ['CACHE_DIR']
    )

def get_file_path(filename, subfolder=None):
    if local_model:
        return os.path.join(subfolder, filename)
    else:
        return download_file(filename, subfolder)

def ensure_rgb(image):
    if image.mode != 'RGB':
        return image.convert('RGB')
    return image

def initialize(_local_model=False):
    global model, sotai_gen_pipe, refine_gen_pipe, local_model

    local_model = _local_model
    model = load_wd14_tagger_model()
    sotai_gen_pipe = initialize_sotai_model()
    refine_gen_pipe = initialize_refine_model()

def load_lora(pipeline, lora_path, alpha=0.75):
    pipeline.load_lora_weights(lora_path)
    pipeline.fuse_lora(lora_scale=alpha)

def initialize_sotai_model():
    global device, torch_dtype

    sotai_sd_model_path = get_file_path(os.environ["sotai_sd_model_name"], subfolder=os.environ["sd_models_dir"])
    controlnet_path1 = get_file_path(os.environ["controlnet_name1"], subfolder=os.environ["controlnet_dir2"])
    controlnet_path2 = get_file_path(os.environ["controlnet_name2"], subfolder=os.environ["controlnet_dir1"])

    # Load the Stable Diffusion model
    sd_pipe = StableDiffusionPipeline.from_single_file(
        sotai_sd_model_path,
        torch_dtype=torch_dtype,
        use_safetensors=True
    ).to(device)

    # Load the ControlNet model
    controlnet1 = ControlNetModel.from_single_file(
        controlnet_path1,
        torch_dtype=torch_dtype
    ).to(device)

    # Load the ControlNet model
    controlnet2 = ControlNetModel.from_single_file(
        controlnet_path2,
        torch_dtype=torch_dtype
    ).to(device)

    # Create the ControlNet pipeline
    sotai_gen_pipe = StableDiffusionControlNetPipeline(
        vae=sd_pipe.vae,
        text_encoder=sd_pipe.text_encoder,
        tokenizer=sd_pipe.tokenizer,
        unet=sd_pipe.unet,
        scheduler=sd_pipe.scheduler,
        safety_checker=sd_pipe.safety_checker,
        feature_extractor=sd_pipe.feature_extractor,
        controlnet=[controlnet1, controlnet2]
    ).to(device)

    # Apply LoRA weights
    lora_names = [
        (os.environ["lora_name1"], 1.0),
        # (os.environ["lora_name2"], 0.3),
    ]

    for lora_name, alpha in lora_names:
        lora_path = get_file_path(lora_name, subfolder=os.environ["lora_dir"])
        load_lora(sotai_gen_pipe, lora_path, alpha)

    # Configure the scheduler
    sotai_gen_pipe.scheduler = UniPCMultistepScheduler.from_config(sotai_gen_pipe.scheduler.config)

    return sotai_gen_pipe

def initialize_refine_model():
    global device, torch_dtype

    refine_sd_model_path = get_file_path(os.environ["refine_sd_model_name"], subfolder=os.environ["sd_models_dir"])
    controlnet_path3 = get_file_path(os.environ["controlnet_name3"], subfolder=os.environ["controlnet_dir1"])
    controlnet_path4 = get_file_path(os.environ["controlnet_name4"], subfolder=os.environ["controlnet_dir1"])
    vae_path = get_file_path(os.environ["vae_name"], subfolder=os.environ["vae_dir"])

    # Load the Stable Diffusion model
    sd_pipe = StableDiffusionPipeline.from_single_file(
        refine_sd_model_path,
        torch_dtype=torch_dtype,
        use_safetensors=True
    ).to(device)

    # controlnet_path = "models/cn/control_v11p_sd15_canny.pth"
    controlnet1 = ControlNetModel.from_single_file(
        controlnet_path3,
        torch_dtype=torch_dtype
    ).to(device)

    # Load the ControlNet model
    controlnet2 = ControlNetModel.from_single_file(
        controlnet_path4,
        torch_dtype=torch_dtype
    ).to(device)

    # Create the ControlNet pipeline
    refine_gen_pipe = StableDiffusionControlNetPipeline(
        vae=AutoencoderKL.from_single_file(vae_path, torch_dtype=torch_dtype).to(device),
        text_encoder=sd_pipe.text_encoder,
        tokenizer=sd_pipe.tokenizer,
        unet=sd_pipe.unet,
        scheduler=sd_pipe.scheduler,
        safety_checker=sd_pipe.safety_checker,
        feature_extractor=sd_pipe.feature_extractor,
        controlnet=[controlnet1, controlnet2],  # use multiple ControlNets
    ).to(device)

    # Configure the scheduler
    refine_gen_pipe.scheduler = UniPCMultistepScheduler.from_config(refine_gen_pipe.scheduler.config)

    return refine_gen_pipe

def get_wd_tags(images: list) -> list:
    global model
    if model is None:
        initialize()
    preprocessed_images = [wd14_preprocess_image(img) for img in images]
    preprocessed_images = np.array(preprocessed_images)
    return generate_tags(preprocessed_images, os.environ["wd_model_name"], model)

def preprocess_image_for_generation(image):
    if isinstance(image, str):  # base64 string
        image = Image.open(io.BytesIO(base64.b64decode(image)))
    elif isinstance(image, np.ndarray):  # numpy array
        image = Image.fromarray(image)
    elif not isinstance(image, Image.Image):
        raise ValueError("Unsupported image type")

    # Compute the output size
    input_width, input_height = image.size
    max_size = 736
    output_width = max_size if input_height < input_width else int(input_width / input_height * max_size)
    output_height = max_size if input_height > input_width else int(input_height / input_width * max_size)

    image = image.resize((output_width, output_height))
    return image, output_width, output_height

def binarize_image(image: Image.Image) -> np.ndarray:
    image = np.array(image.convert('L'))
    # Invert the colors
    image = 255 - image

    # Histogram equalization (CLAHE)
    clahe = cv2.createCLAHE(clipLimit=1.0, tileGridSize=(8, 8))
    image = clahe.apply(image)

    # Apply a Gaussian blur
    image = cv2.GaussianBlur(image, (5, 5), 0)

    # Adaptive thresholding
    binary_image = cv2.adaptiveThreshold(image, 255, cv2.ADAPTIVE_THRESH_GAUSSIAN_C, cv2.THRESH_BINARY, 9, -8)

    return binary_image

def create_rgba_image(binary_image: np.ndarray, color: list) -> Image.Image:
    rgba_image = np.zeros((binary_image.shape[0], binary_image.shape[1], 4), dtype=np.uint8)
    rgba_image[:, :, 0] = color[0]
    rgba_image[:, :, 1] = color[1]
    rgba_image[:, :, 2] = color[2]
    rgba_image[:, :, 3] = binary_image
    return Image.fromarray(rgba_image, 'RGBA')

def generate_sotai_image(input_image: Image.Image, output_width: int, output_height: int) -> Image.Image:
    input_image = ensure_rgb(input_image)
    global sotai_gen_pipe
    if sotai_gen_pipe is None:
        initialize()

    prompt = "anime pose, girl, (white background:1.5), (monochrome:1.5), full body, sketch, eyes, breasts, (slim legs, skinny legs:1.2)"
    try:
        # Resize the input image
        if input_image.size[0] > input_image.size[1]:
            input_image = input_image.resize((512, int(512 * input_image.size[1] / input_image.size[0])))
        else:
            input_image = input_image.resize((int(512 * input_image.size[0] / input_image.size[1]), 512))

        # Contents of EasyNegativeV2
        easy_negative_v2 = "(worst quality, low quality, normal quality:1.4), lowres, bad anatomy, bad hands, text, error, missing fingers, extra digit, fewer digits, cropped, jpeg artifacts, signature, watermark, username, blurry, artist name, (bad_prompt_version2:0.8)"

        output = sotai_gen_pipe(
            prompt,
            image=[input_image, input_image],
            negative_prompt=f"(wings:1.6), (clothes, garment, lighting, gray, missing limb, extra line, extra limb, extra arm, extra legs, hair, bangs, fringe, forelock, front hair, fill:1.4), (ink pool:1.6)",
            # negative_prompt=f"{easy_negative_v2}, (wings:1.6), (clothes, garment, lighting, gray, missing limb, extra line, extra limb, extra arm, extra legs, hair, bangs, fringe, forelock, front hair, fill:1.4), (ink pool:1.6)",
            num_inference_steps=40,
            guidance_scale=8,
            width=output_width,
            height=output_height,
            denoising_strength=0.13,
            num_images_per_prompt=1,  # Equivalent to batch_size
            guess_mode=[True, True],  # Equivalent to pixel_perfect
            controlnet_conditioning_scale=[1.2, 1.3],  # weight for each ControlNet
            guidance_start=[0.0, 0.0],
            guidance_end=[1.0, 1.0],
        )
        generated_image = output.images[0]

        return generated_image

    finally:
        # Free memory
        if device == "cuda":
            torch.cuda.empty_cache()
        gc.collect()

def generate_refined_image(prompt: str, original_image: Image.Image, output_width: int, output_height: int, weight1: float, weight2: float) -> Image.Image:
    original_image = ensure_rgb(original_image)
    global refine_gen_pipe
    if refine_gen_pipe is None:
        initialize()

    try:
        original_image_np = np.array(original_image)
        # scribble_xdog
        scribble_image, _ = scribble_xdog(original_image_np, 2048, 20)

        original_image = original_image.resize((output_width, output_height))
        output = refine_gen_pipe(
            prompt,
            image=[scribble_image, original_image],  # input images for the two ControlNets
            negative_prompt="extra limb, monochrome, black and white",
            num_inference_steps=20,
            width=output_width,
            height=output_height,
            controlnet_conditioning_scale=[weight1, weight2],  # weight for each ControlNet
            control_guidance_start=[0.0, 0.0],
            control_guidance_end=[1.0, 1.0],
            guess_mode=[False, False],  # pixel_perfect
        )
        generated_image = output.images[0]

        return generated_image

    finally:
        # Free memory
        if device == "cuda":
            torch.cuda.empty_cache()
        gc.collect()

def process_image(input_image, mode: str, weight1: float = 0.4, weight2: float = 0.3):
    input_image = ensure_rgb(input_image)
    # Get the size
    input_width, input_height = input_image.size
    max_size = 736
    output_width = max_size if input_height < input_width else int(input_width / input_height * max_size)
    output_height = max_size if input_height > input_width else int(input_height / input_width * max_size)

    if mode == "refine":
        # Generate a prompt with the WD-14 tagger
        image_np = np.array(ensure_rgb(input_image))
        prompt = get_wd_tags([image_np])[0]
        prompt = f"{prompt}"
        print(prompt)

        refined_image = generate_refined_image(prompt, input_image, output_width, output_height, weight1, weight2)
        refined_image = refined_image.convert('RGB')

        # Generate the sketch image
        refined_image_np = np.array(refined_image)
        sketch_image = get_sketch(refined_image_np, "both", 2048, 10)
        sketch_image = sketch_image.resize((output_width, output_height))  # match the output size
        # Binarize the sketch image
        sketch_binary = binarize_image(sketch_image)
        # Convert to RGBA (transparent base image) with blue lines
        sketch_image = create_rgba_image(sketch_binary, [0, 0, 255])

        # Generate the body (sotai) image
        sotai_image = generate_sotai_image(refined_image, output_width, output_height)

    elif mode == "original":
        sotai_image = generate_sotai_image(input_image, output_width, output_height)

        # Generate the sketch image
        input_image_np = np.array(input_image)
        sketch_image = get_sketch(input_image_np, "both", 2048, 16)

    elif mode == "sketch":
        # Generate the sketch image
        input_image_np = np.array(input_image)
        sketch_image = get_sketch(input_image_np, "both", 2048, 16)

        # Generate the body (sotai) image
        sotai_image = generate_sotai_image(sketch_image, output_width, output_height)

    else:
        raise ValueError("Invalid mode")

    # Binarize the body image
    sotai_binary = binarize_image(sotai_image)
    # Convert to RGBA (transparent base image) with red lines
    sotai_image = create_rgba_image(sotai_binary, [255, 0, 0])

    return sotai_image, sketch_image

def image_to_base64(img_array):
    buffered = io.BytesIO()
    img_array.save(buffered, format="PNG")
    return base64.b64encode(buffered.getvalue()).decode()

def process_image_as_base64(input_image, mode: str, weight1: float = 0.4, weight2: float = 0.3):
    sotai_image, sketch_image = process_image(input_image, mode, weight1, weight2)
    return image_to_base64(sotai_image), image_to_base64(sketch_image)
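process_utils resolves every model file through environment variables loaded with python-dotenv, so the Space needs a .env file or Space secrets defining them. A small hedged checker sketch; the names below are only those that appear in the code above, and their values are repo-specific (not part of this commit).

# Sanity check for the environment variables assumed by process_utils; values are repo-specific.
import os

required_vars = [
    "REPO_ID", "HF_TOKEN", "CACHE_DIR",
    "sotai_sd_model_name", "refine_sd_model_name", "sd_models_dir",
    "controlnet_name1", "controlnet_name2", "controlnet_name3", "controlnet_name4",
    "controlnet_dir1", "controlnet_dir2",
    "lora_name1", "lora_dir", "vae_name", "vae_dir", "wd_model_name",
]
missing = [name for name in required_vars if name not in os.environ]
if missing:
    raise RuntimeError(f"Missing environment variables: {missing}")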
requirements.txt
ADDED
@@ -0,0 +1,19 @@
--extra-index-url https://download.pytorch.org/whl/cu116
torch==2.2.0
torchvision==0.17.0
torchaudio==2.2.0
diffusers==0.29.1
Flask==3.0.3
Flask-Cors==4.0.0
gradio==4.36.1
huggingface_hub==0.23.2
kornia==0.7.1
numpy==1.23.5
opencv-python==4.9.0.80
Pillow==10.3.0
Requests==2.32.3
tensorflow==2.16.1
transforms==0.2.1
tokenizers
pytorch_lightning
python-dotenv
templates/index.html
ADDED
@@ -0,0 +1,60 @@
<!DOCTYPE html>
<html lang="en">
<head>
    <meta charset="UTF-8">
    <meta name="viewport" content="width=device-width, initial-scale=1.0">
    <title>Scribble Image Generator</title>
    <style>
        body { font-family: Arial, sans-serif; max-width: 800px; margin: 0 auto; padding: 20px; }
        .image-container { display: flex; justify-content: space-between; margin-top: 20px; }
        .image-container img { max-width: 48%; height: auto; }
        #loading { display: none; }
    </style>
</head>
<body>
    <h1>Scribble Image Generator</h1>
    <form id="upload-form">
        <input type="file" id="file-input" accept="image/*" required>
        <br><br>
        <label for="threshold">Threshold:</label>
        <input type="number" id="threshold" name="threshold" value="20" min="1" max="64">
        <br><br>
        <label for="processor_res">Processor Resolution:</label>
        <input type="number" id="processor_res" name="processor_res" value="2048" min="64" max="2048">
        <br><br>
        <button type="submit">Generate Scribble</button>
    </form>
    <div id="loading">Processing...</div>
    <div class="image-container">
        <img id="original-image" alt="Original Image">
        <img id="scribble-image" alt="Scribble Image">
    </div>

    <script>
        document.getElementById('upload-form').addEventListener('submit', function(e) {
            e.preventDefault();
            var formData = new FormData();
            formData.append('file', document.getElementById('file-input').files[0]);
            formData.append('threshold', document.getElementById('threshold').value);
            formData.append('processor_res', document.getElementById('processor_res').value);

            document.getElementById('loading').style.display = 'block';

            fetch('/process', {
                method: 'POST',
                body: formData
            })
            .then(response => response.json())
            .then(data => {
                document.getElementById('original-image').src = 'data:image/png;base64,' + data.original_image;
                document.getElementById('scribble-image').src = 'data:image/png;base64,' + data.scribble_image;
                document.getElementById('loading').style.display = 'none';
            })
            .catch(error => {
                console.error('Error:', error);
                document.getElementById('loading').style.display = 'none';
            });
        });
    </script>
</body>
</html>