yeq6x committed on
Commit
c9cc441
0 Parent(s):
Files changed (11)
  1. .gitignore +5 -0
  2. Dockerfile.backend +35 -0
  3. anime.py +152 -0
  4. app.py +187 -0
  5. data.py +97 -0
  6. generate_prompt.py +154 -0
  7. lineart_util.py +109 -0
  8. model.py +186 -0
  9. process_utils.py +345 -0
  10. requirements.txt +19 -0
  11. templates/index.html +60 -0
.gitignore ADDED
@@ -0,0 +1,5 @@
+models/
+__pycache__/
+venv/
+output/
+hf_gradio/
Dockerfile.backend ADDED
@@ -0,0 +1,35 @@
+# Use a CUDA 12.1-based Ubuntu image
+FROM nvidia/cuda:12.1.0-cudnn8-devel-ubuntu20.04
+
+RUN ln -sf /usr/share/zoneinfo/Asia/Tokyo /etc/localtime
+
+# Install the required packages
+RUN apt-get update && apt-get install -y \
+    software-properties-common \
+    && add-apt-repository ppa:deadsnakes/ppa \
+    && apt-get update && apt-get install -y \
+    python3.10 \
+    python3.10-dev \
+    python3.10-distutils \
+    wget \
+    && rm -rf /var/lib/apt/lists/*
+# Install pip
+RUN wget https://bootstrap.pypa.io/get-pip.py \
+    && python3.10 get-pip.py \
+    && rm get-pip.py
+# Point the default python and pip commands at python3.10 and pip3.10
+RUN update-alternatives --install /usr/bin/python python /usr/bin/python3.10 1 \
+    && update-alternatives --install /usr/bin/pip pip /usr/local/bin/pip3.10 1
+
+WORKDIR /app
+
+# Install dependencies
+COPY requirements.txt /app/
+RUN apt -y update && apt -y upgrade
+RUN apt -y install libopencv-dev
+RUN pip install --no-cache-dir -r requirements.txt
+RUN pip install --no-dependencies transformers
+
+EXPOSE 5000
+
+CMD ["python", "app.py"]
anime.py ADDED
@@ -0,0 +1,152 @@
+"""Test script for anime-to-sketch translation
+Example:
+    python3 anime.py --dataroot /your_path/dir --load_size 512
+    python3 anime.py --dataroot /your_path/img.jpg --load_size 512
+"""
+
+import os
+import torch
+from torchvision import transforms
+from data import get_image_list, get_transform
+from model import create_model
+from data import read_img_path, tensor_to_img, save_image
+import argparse
+from tqdm.auto import tqdm
+from kornia.enhance import equalize_clahe
+from PIL import Image
+import numpy as np
+
+
+# Take an image as a numpy array, generate a line-art sketch, and return it as a numpy array
+def generate_sketch(image, clahe_clip=-1, load_size=512):
+    """
+    Generate sketch image from input image
+    Args:
+        image (np.ndarray): input image
+        clahe_clip (float): clip threshold for CLAHE
+        load_size (int): image size to load
+    Returns:
+        np.ndarray: output image
+    """
+    # create model
+    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
+    model_opt = "default"
+    model = create_model(model_opt).to(device)
+    model.eval()
+
+    aus_resize = None
+    if load_size > 0:
+        aus_resize = (image.shape[0], image.shape[1])
+    transform = get_transform(load_size=load_size)
+    image = torch.from_numpy(image).permute(2, 0, 1).float()
+    # [0,255] to [-1,1]
+    image = transform(image)
+    if image.max() > 1:
+        image = (image - image.min()) / (image.max() - image.min()) * 2 - 1
+
+    img, aus_resize = image.unsqueeze(0), aus_resize
+    if clahe_clip > 0:
+        img = (img + 1) / 2            # [-1,1] to [0,1]
+        img = equalize_clahe(img, clip_limit=clahe_clip)
+        img = (img - .5) / .5          # [0,1] to [-1,1]
+
+    aus_tensor = model(img.to(device))
+
+    # resize to original size
+    if aus_resize is not None:
+        aus_tensor = torch.nn.functional.interpolate(aus_tensor, aus_resize, mode='bilinear', align_corners=False)
+
+    aus_img = tensor_to_img(aus_tensor)
+    return aus_img
+
+
+if __name__ == '__main__':
+    os.chdir(os.path.dirname("Anime2Sketch/"))
+    parser = argparse.ArgumentParser(description='Anime-to-sketch test options.')
+    parser.add_argument('--dataroot', '-i', default='test_samples/', type=str)
+    parser.add_argument('--load_size', '-s', default=512, type=int)
+    parser.add_argument('--output_dir', '-o', default='results/', type=str)
+    parser.add_argument('--gpu_ids', '-g', default=[], help="gpu ids: e.g. 0 0,1,2 0,2.")
+    parser.add_argument('--model', default="default", type=str, help="variant of model to use. you can choose from ['default','improved']")
+    parser.add_argument('--clahe_clip', default=-1, type=float, help="clip threshold for CLAHE, set to -1 to disable")
+    opt = parser.parse_args()
+
+    # # Generate line art with generate_sketch
+    # for test_path in tqdm(get_image_list(opt.dataroot)):
+    #     basename = os.path.basename(test_path)
+    #     aus_path = os.path.join(opt.output_dir, basename)
+    #     # load the image as a numpy array
+    #     img = Image.open(test_path)
+    #     img = np.array(img)
+    #     aus_img = generate_sketch(img, opt.clahe_clip)
+    #     # save the image
+    #     save_image(aus_img, aus_path, (512, 512))
+
+
+    # create model
+    gpu_list = ','.join(str(x) for x in opt.gpu_ids)
+    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
+    model = create_model(opt.model).to(device)      # create a model given opt.model and other options
+    model.eval()
+
+    for test_path in tqdm(get_image_list(opt.dataroot)):
+        basename = os.path.basename(test_path)
+        aus_path = os.path.join(opt.output_dir, basename)
+
+        img = Image.open(test_path).convert('RGB')
+        img = np.array(img)
+
+        load_size = 512
+        aus_resize = None
+        if load_size > 0:
+            aus_resize = (img.shape[1], img.shape[0])
+        transform = get_transform(load_size=load_size)
+        img = torch.from_numpy(img).permute(2, 0, 1).float()
+        # [0,255] to [-1,1]
+        image = transform(img)
+        if image.max() > 1:
+            image = (image - image.min()) / (image.max() - image.min()) * 2 - 1
+        print(image.min(), image.max())
+
+        img, aus_resize = image.unsqueeze(0), aus_resize
+        if opt.clahe_clip > 0:
+            img = (img + 1) / 2        # [-1,1] to [0,1]
+            img = equalize_clahe(img, clip_limit=opt.clahe_clip)
+            img = (img - .5) / .5      # [0,1] to [-1,1]
+
+        aus_tensor = model(img.to(device))
+        aus_img = tensor_to_img(aus_tensor)
+        save_image(aus_img, aus_path, aus_resize)
+    """
+    # create model
+    gpu_list = ','.join(str(x) for x in opt.gpu_ids)
+    os.environ['CUDA_VISIBLE_DEVICES'] = gpu_list
+    device = torch.device('cuda' if len(opt.gpu_ids) > 0 else 'cpu')
+    model = create_model(opt.model).to(device)      # create a model given opt.model and other options
+    model.eval()
+    # get input data
+    if os.path.isdir(opt.dataroot):
+        test_list = get_image_list(opt.dataroot)
+    elif os.path.isfile(opt.dataroot):
+        test_list = [opt.dataroot]
+    else:
+        raise Exception("{} is not a valid directory or image file.".format(opt.dataroot))
+    # save outputs
+    save_dir = opt.output_dir
+    os.makedirs(save_dir, exist_ok=True)
+
+    for test_path in tqdm(test_list):
+        basename = os.path.basename(test_path)
+        aus_path = os.path.join(save_dir, basename)
+        img, aus_resize = read_img_path(test_path, opt.load_size)
+
+        if opt.clahe_clip > 0:
+            img = (img + 1) / 2        # [-1,1] to [0,1]
+            img = equalize_clahe(img, clip_limit=opt.clahe_clip)
+            img = (img - .5) / .5      # [0,1] to [-1,1]
+
+        aus_tensor = model(img.to(device))
+        print(aus_tensor.shape)
+        aus_img = tensor_to_img(aus_tensor)
+        save_image(aus_img, aus_path, aus_resize)
+    """
app.py ADDED
@@ -0,0 +1,187 @@
+from flask import Flask, request, render_template, send_file, jsonify, send_from_directory
+from flask_socketio import SocketIO, emit
+from flask_cors import CORS
+import io
+import os
+from PIL import Image
+import torch
+import gc
+from peft import PeftModel
+
+import queue
+import threading
+import uuid
+import concurrent.futures
+from process_utils import *
+
+app = Flask(__name__)
+# app.secret_key = 'super_secret_key'
+CORS(app)
+socketio = SocketIO(app, cors_allowed_origins="*")
+
+# Create the task queue
+task_queue = queue.Queue()
+active_tasks = {}
+task_futures = {}
+
+# Create the ThreadPoolExecutor
+executor = concurrent.futures.ThreadPoolExecutor(max_workers=1)
+
+class Task:
+    def __init__(self, task_id, mode, weight1, weight2, file_data):
+        self.task_id = task_id
+        self.mode = mode
+        self.weight1 = weight1
+        self.weight2 = weight2
+        self.file_data = file_data
+        self.cancel_flag = False
+
+def update_queue_status(message=None):
+    socketio.emit('queue_update', {'active_tasks': len(active_tasks), 'message': message})
+
+def process_task(task):
+    try:
+        # Convert the raw file data into a PIL Image
+        image = Image.open(io.BytesIO(task.file_data))
+        image = ensure_rgb(image)
+
+        # Cancellation check
+        if task.cancel_flag:
+            return
+
+        # Run the image-processing logic
+        sotai_image, sketch_image = process_image_as_base64(image, task.mode, task.weight1, task.weight2)
+
+        # Cancellation check
+        if task.cancel_flag:
+            return
+
+        socketio.emit('task_complete', {
+            'task_id': task.task_id,
+            'sotai_image': sotai_image,
+            'sketch_image': sketch_image
+        })
+    except Exception as e:
+        if not task.cancel_flag:
+            socketio.emit('task_error', {'task_id': task.task_id, 'error': str(e)})
+    finally:
+        if task.task_id in active_tasks:
+            del active_tasks[task.task_id]
+        if task.task_id in task_futures:
+            del task_futures[task.task_id]
+        update_queue_status('Task completed or cancelled')
+
+def worker():
+    while True:
+        try:
+            task = task_queue.get()
+            if task.task_id in active_tasks:
+                future = executor.submit(process_task, task)
+                task_futures[task.task_id] = future
+                update_queue_status(f'Task started: {task.task_id}')
+        except Exception as e:
+            print(f"Worker error: {str(e)}")
+        finally:
+            # Ensure the task is always removed from the queue
+            task_queue.task_done()
+
+# Start the worker thread
+threading.Thread(target=worker, daemon=True).start()
+
+@app.route('/submit_task', methods=['POST'])
+def submit_task():
+    task_id = str(uuid.uuid4())
+    file = request.files['file']
+    mode = request.form.get('mode', 'refine')
+    weight1 = float(request.form.get('weight1', 0.4))
+    weight2 = float(request.form.get('weight2', 0.3))
+
+    # Keep the file data as bytes
+    file_data = file.read()
+
+    task = Task(task_id, mode, weight1, weight2, file_data)
+    task_queue.put(task)
+    active_tasks[task_id] = task
+
+    update_queue_status(f'Task submitted: {task_id}')
+
+    queue_size = task_queue.qsize()
+    return jsonify({'task_id': task_id, 'queue_size': queue_size})
+
+@app.route('/cancel_task/<task_id>', methods=['POST'])
+def cancel_task(task_id):
+    if task_id in active_tasks:
+        task = active_tasks[task_id]
+        task.cancel_flag = True
+        if task_id in task_futures:
+            task_futures[task_id].cancel()
+            del task_futures[task_id]
+        del active_tasks[task_id]
+        update_queue_status('Task cancelled')
+        return jsonify({'message': 'Task cancellation requested'})
+    else:
+        return jsonify({'message': 'Task not found or already completed'}), 404
+
+def get_active_task_order(task_id):
+    return list(active_tasks.keys()).index(task_id) if task_id in active_tasks else None
+
+# get_task_order route
+@app.route('/get_task_order/<task_id>', methods=['GET'])
+def handle_get_task_order(task_id):
+    task_order = get_active_task_order(task_id)
+    return jsonify({'task_order': task_order})
+
+@socketio.on('connect')
+def handle_connect():
+    emit('queue_update', {'active_tasks': len(active_tasks), 'active_task_order': None})
+
+# Flask routes
+@app.route('/', methods=['GET', 'POST'])
+def process_refined():
+    if request.method == 'POST':
+        file = request.files['file']
+        weight1 = float(request.form.get('weight1', 0.4))
+        weight2 = float(request.form.get('weight2', 0.3))
+
+        image = ensure_rgb(Image.open(file.stream))
+        sotai_image, sketch_image = process_image_as_base64(image, "refine", weight1, weight2)
+
+        return jsonify({
+            'sotai_image': sotai_image,
+            'sketch_image': sketch_image
+        })
+
+@app.route('/process_original', methods=['GET', 'POST'])
+def process_original():
+    if request.method == 'POST':
+        file = request.files['file']
+
+        image = ensure_rgb(Image.open(file.stream))
+        sotai_image, sketch_image = process_image_as_base64(image, "original")
+
+        return jsonify({
+            'sotai_image': sotai_image,
+            'sketch_image': sketch_image
+        })
+
+@app.route('/process_sketch', methods=['GET', 'POST'])
+def process_sketch():
+    if request.method == 'POST':
+        file = request.files['file']
+
+        image = ensure_rgb(Image.open(file.stream))
+        sotai_image, sketch_image = process_image_as_base64(image, "sketch")
+
+        return jsonify({
+            'sotai_image': sotai_image,
+            'sketch_image': sketch_image
+        })
+
+# Error handler
+@app.errorhandler(500)
+def server_error(e):
+    return jsonify(error=str(e)), 500
+
+if __name__ == '__main__':
+    initialize(_local_model=True)
+    socketio.run(app, debug=True, host='0.0.0.0', port=5000)
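
A minimal client sketch for the queue API above; the host/port and file name are assumptions, and the generated images arrive asynchronously over the 'task_complete' Socket.IO event rather than in the HTTP response:

import requests

with open("input.png", "rb") as f:
    resp = requests.post(
        "http://localhost:5000/submit_task",
        files={"file": f},
        data={"mode": "refine", "weight1": "0.4", "weight2": "0.3"},
    )
task_id = resp.json()["task_id"]
# Poll the task's position in the queue; the result itself is pushed via Socket.IO
print(requests.get(f"http://localhost:5000/get_task_order/{task_id}").json())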
data.py ADDED
@@ -0,0 +1,97 @@
+import os
+from PIL import Image
+import torchvision.transforms as transforms
+try:
+    from torchvision.transforms import InterpolationMode
+    bic = InterpolationMode.BICUBIC
+except ImportError:
+    bic = Image.BICUBIC
+
+import numpy as np
+import torch
+
+IMG_EXTENSIONS = ['.jpg', '.JPG', '.jpeg', '.JPEG', '.png', '.PNG', '.ppm', '.PPM', '.bmp', '.BMP']
+
+def is_image_file(filename):
+    """if a given filename is a valid image
+    Parameters:
+        filename (str) -- image filename
+    """
+    return any(filename.endswith(extension) for extension in IMG_EXTENSIONS)
+
+def get_image_list(path):
+    """read the paths of valid images from the given directory path
+    Parameters:
+        path (str) -- input directory path
+    """
+    assert os.path.isdir(path), '{:s} is not a valid directory'.format(path)
+    images = []
+    for dirpath, _, fnames in sorted(os.walk(path)):
+        for fname in sorted(fnames):
+            if is_image_file(fname):
+                img_path = os.path.join(dirpath, fname)
+                images.append(img_path)
+    assert images, '{:s} has no valid image file'.format(path)
+    return images
+
+def get_transform(load_size=0, grayscale=False, method=bic, convert=True):
+    transform_list = []
+    if grayscale:
+        transform_list.append(transforms.Grayscale(1))
+    if load_size > 0:
+        osize = [load_size, load_size]
+        transform_list.append(transforms.Resize(osize, method))
+    if convert:
+        # transform_list += [transforms.ToTensor()]
+        if grayscale:
+            transform_list += [transforms.Normalize((0.5,), (0.5,))]
+        else:
+            transform_list += [transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))]
+    return transforms.Compose(transform_list)
+
+def read_img_path(path, load_size):
+    """read tensors from a given image path
+    Parameters:
+        path (str)      -- input image path
+        load_size (int) -- the input size. If <= 0, don't resize
+    """
+    img = Image.open(path).convert('RGB')
+    aus_resize = None
+    if load_size > 0:
+        aus_resize = img.size
+    transform = get_transform(load_size=load_size)
+    image = transform(img)
+    return image.unsqueeze(0), aus_resize
+
+def tensor_to_img(input_image, imtype=np.uint8):
+    """Converts a Tensor array into a numpy image array.
+    Parameters:
+        input_image (tensor) -- the input image tensor array
+        imtype (type)        -- the desired type of the converted numpy array
+    """
+
+    if not isinstance(input_image, np.ndarray):
+        if isinstance(input_image, torch.Tensor):  # get the data from a variable
+            image_tensor = input_image.data
+        else:
+            return input_image
+        image_numpy = image_tensor[0].cpu().float().numpy()  # convert it into a numpy array
+        if image_numpy.shape[0] == 1:  # grayscale to RGB
+            image_numpy = np.tile(image_numpy, (3, 1, 1))
+        image_numpy = (np.transpose(image_numpy, (1, 2, 0)) + 1) / 2.0 * 255.0  # post-processing: transpose and scaling
+    else:  # if it is a numpy array, do nothing
+        image_numpy = input_image
+    return image_numpy.astype(imtype)
+
+def save_image(image_numpy, image_path, output_resize=None):
+    """Save a numpy image to the disk
+    Parameters:
+        image_numpy (numpy array)     -- input numpy array
+        image_path (str)              -- the path of the image
+        output_resize (None or tuple) -- the output size. If None, don't resize
+    """
+
+    image_pil = Image.fromarray(image_numpy)
+    if output_resize:
+        image_pil = image_pil.resize(output_resize, bic)
+    image_pil.save(image_path)
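
A small sketch of the tensor-to-file round trip provided by tensor_to_img and save_image, using a random stand-in for a network output; out.png is a placeholder path:

import torch
from data import tensor_to_img, save_image

fake_output = torch.rand(1, 1, 512, 512) * 2 - 1   # stand-in for a generator output in [-1, 1]
img = tensor_to_img(fake_output)                    # 512x512x3 uint8 array
save_image(img, "out.png", output_resize=(640, 480))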
generate_prompt.py ADDED
@@ -0,0 +1,154 @@
+import argparse
+import csv
+import os
+import json
+
+from PIL import Image
+import cv2
+import numpy as np
+from tensorflow.keras.layers import TFSMLayer
+from huggingface_hub import hf_hub_download
+from pathlib import Path
+
+# from wd14 tagger
+IMAGE_SIZE = 448
+
+# wd-v1-4-swinv2-tagger-v2 / wd-v1-4-vit-tagger / wd-v1-4-vit-tagger-v2 / wd-v1-4-convnext-tagger / wd-v1-4-convnext-tagger-v2
+DEFAULT_WD14_TAGGER_REPO = "SmilingWolf/wd-v1-4-convnext-tagger-v2"
+FILES = ["keras_metadata.pb", "saved_model.pb", "selected_tags.csv"]
+SUB_DIR = "variables"
+SUB_DIR_FILES = ["variables.data-00000-of-00001", "variables.index"]
+CSV_FILE = FILES[-1]
+
+def preprocess_image(image):
+    image = np.array(image)
+    image = image[:, :, ::-1]  # RGB->BGR
+
+    # pad to square
+    size = max(image.shape[0:2])
+    pad_x = size - image.shape[1]
+    pad_y = size - image.shape[0]
+    pad_l = pad_x // 2
+    pad_t = pad_y // 2
+    image = np.pad(image, ((pad_t, pad_y - pad_t), (pad_l, pad_x - pad_l), (0, 0)), mode="constant", constant_values=255)
+
+    interp = cv2.INTER_AREA if size > IMAGE_SIZE else cv2.INTER_LANCZOS4
+    image = cv2.resize(image, (IMAGE_SIZE, IMAGE_SIZE), interpolation=interp)
+
+    image = image.astype(np.float32)
+    return image
+
+
+def load_wd14_tagger_model():
+    model_dir = "wd14_tagger_model"
+    repo_id = DEFAULT_WD14_TAGGER_REPO
+
+    if not os.path.exists(model_dir):
+        print(f"downloading wd14 tagger model from hf_hub. id: {repo_id}")
+        for file in FILES:
+            hf_hub_download(repo_id, file, cache_dir=model_dir, force_download=True, force_filename=file)
+        for file in SUB_DIR_FILES:
+            hf_hub_download(
+                repo_id,
+                file,
+                subfolder=SUB_DIR,
+                cache_dir=os.path.join(model_dir, SUB_DIR),
+                force_download=True,
+                force_filename=file,
+            )
+    else:
+        print("using existing wd14 tagger model")
+
+    # Load the model
+    model = TFSMLayer(model_dir, call_endpoint='serving_default')
+    return model
+
+
+def generate_tags(images, model_dir, model):
+    with open(os.path.join(model_dir, CSV_FILE), "r", encoding="utf-8") as f:
+        reader = csv.reader(f)
+        l = [row for row in reader]
+        header = l[0]  # tag_id,name,category,count
+        rows = l[1:]
+    assert header[0] == "tag_id" and header[1] == "name" and header[2] == "category", f"unexpected csv format: {header}"
+
+    general_tags = [row[1] for row in rows[1:] if row[2] == "0"]
+    character_tags = [row[1] for row in rows[1:] if row[2] == "4"]
+
+    tag_freq = {}
+    undesired_tags = ['one-piece_swimsuit',
+                      'swimsuit',
+                      'leotard',
+                      'saitama_(one-punch_man)',
+                      '1boy',
+                      ]
+
+    probs = model(images, training=False)
+    probs = probs['predictions_sigmoid'].numpy()
+
+    tag_text_list = []
+    for prob in probs:
+        combined_tags = []
+        general_tag_text = ""
+        character_tag_text = ""
+        thresh = 0.35
+        for i, p in enumerate(prob[4:]):
+            if i < len(general_tags) and p >= thresh:
+                tag_name = general_tags[i]
+                if tag_name not in undesired_tags:
+                    tag_freq[tag_name] = tag_freq.get(tag_name, 0) + 1
+                    general_tag_text += ", " + tag_name
+                    combined_tags.append(tag_name)
+            elif i >= len(general_tags) and p >= thresh:
+                tag_name = character_tags[i - len(general_tags)]
+                if tag_name not in undesired_tags:
+                    tag_freq[tag_name] = tag_freq.get(tag_name, 0) + 1
+                    character_tag_text += ", " + tag_name
+                    combined_tags.append(tag_name)
+
+        if len(general_tag_text) > 0:
+            general_tag_text = general_tag_text[2:]
+        if len(character_tag_text) > 0:
+            character_tag_text = character_tag_text[2:]
+
+        tag_text = ", ".join(combined_tags)
+        tag_text_list.append(tag_text)
+    return tag_text_list
+
+
+def generate_prompt_json(target_folder, prompt_file, model_dir, model):
+    image_files = [f for f in os.listdir(target_folder) if os.path.isfile(os.path.join(target_folder, f))]
+    image_count = len(image_files)
+
+    prompt_list = []
+
+    for i, filename in enumerate(image_files, 1):
+        source_path = "source/" + filename
+        target_path = os.path.join(target_folder, filename)  # Use absolute path
+        target_path2 = "target/" + filename
+
+        # Load, preprocess and tag the image (the original passed the path string straight to generate_tags)
+        image = preprocess_image(Image.open(target_path).convert("RGB"))
+        prompt = generate_tags(np.array([image]), model_dir, model)[0]
+
+        for j in range(4):
+            prompt_data = {
+                "source": f"{source_path.split('.')[0]}_{j}.jpg",
+                "target": f"{target_path2.split('.')[0]}_{j}.jpg",
+                "prompt": prompt
+            }
+
+            prompt_list.append(prompt_data)
+
+        print(f"Processed Images: {i}/{image_count}", end="\r", flush=True)
+
+    with open(prompt_file, "w") as file:
+        for prompt_data in prompt_list:
+            json.dump(prompt_data, file)
+            file.write("\n")
+
+    print(f"Processing completed. Total Images: {image_count}")
+
+
+if __name__ == '__main__':
+    model_dir = "wd14_tagger_model"
+    model = load_wd14_tagger_model()
+    # "target/" and "prompt.json" are placeholders; the original referenced an undefined target_path here
+    generate_prompt_json("target/", "prompt.json", model_dir, model)
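
A sketch of tagging a single image with the helpers above, mirroring how process_utils.get_wd_tags batches images; input.png is a placeholder, and the first call downloads the tagger into wd14_tagger_model/:

import numpy as np
from PIL import Image
from generate_prompt import load_wd14_tagger_model, preprocess_image, generate_tags

model = load_wd14_tagger_model()
batch = np.array([preprocess_image(Image.open("input.png").convert("RGB"))])
print(generate_tags(batch, "wd14_tagger_model", model)[0])   # comma-separated tag string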
lineart_util.py ADDED
@@ -0,0 +1,109 @@
+import cv2
+import numpy as np
+from PIL import Image
+from anime import generate_sketch
+
+def pad64(x):
+    return int(np.ceil(float(x) / 64.0) * 64 - x)
+
+def HWC3(x):
+    assert x.dtype == np.uint8
+    if x.ndim == 2:
+        x = x[:, :, None]
+    assert x.ndim == 3
+    H, W, C = x.shape
+    assert C == 1 or C == 3 or C == 4
+    if C == 3:
+        return x
+    if C == 1:
+        return np.concatenate([x, x, x], axis=2)
+    if C == 4:
+        color = x[:, :, 0:3].astype(np.float32)
+        alpha = x[:, :, 3:4].astype(np.float32) / 255.0
+        y = color * alpha + 255.0 * (1.0 - alpha)
+        y = y.clip(0, 255).astype(np.uint8)
+        return y
+
+def safer_memory(x):
+    # Fix many MAC/AMD problems
+    return np.ascontiguousarray(x.copy()).copy()
+
+def resize_image_with_pad(input_image, resolution, skip_hwc3=False):
+    if skip_hwc3:
+        img = input_image
+    else:
+        img = HWC3(input_image)
+    H_raw, W_raw, _ = img.shape
+    k = float(resolution) / float(min(H_raw, W_raw))
+    interpolation = cv2.INTER_CUBIC if k > 1 else cv2.INTER_AREA
+    H_target = int(np.round(float(H_raw) * k))
+    W_target = int(np.round(float(W_raw) * k))
+    img = cv2.resize(img, (W_target, H_target), interpolation=interpolation)
+    H_pad, W_pad = pad64(H_target), pad64(W_target)
+    img_padded = np.pad(img, [[0, H_pad], [0, W_pad], [0, 0]], mode='edge')
+
+    def remove_pad(x):
+        return safer_memory(x[:H_target, :W_target])
+
+    return safer_memory(img_padded), remove_pad
+
+def scribble_xdog(img, res=512, thr_a=32, **kwargs):
+    """
+    Generate a sketch image with XDoG
+    :param img: np.ndarray, input image
+    :param res: int, resolution of the output image
+    :param thr_a: int, threshold
+
+    Returns
+    -------
+    Image : PIL.Image
+    """
+    img, remove_pad = resize_image_with_pad(img, res)
+    g1 = cv2.GaussianBlur(img.astype(np.float32), (0, 0), 0.5)
+    g2 = cv2.GaussianBlur(img.astype(np.float32), (0, 0), 5.0)
+    dog = (255 - np.min(g2 - g1, axis=2)).clip(0, 255).astype(np.uint8)
+    result = np.zeros_like(img, dtype=np.uint8)
+    result[2 * (255 - dog) > thr_a] = 255
+    result = Image.fromarray(remove_pad(result))
+    return result, True
+
+def canny(img, res=512, thr_a=100, thr_b=200, **kwargs):
+    img, remove_pad = resize_image_with_pad(img, res)
+    result = cv2.Canny(img, thr_a, thr_b)
+    result = Image.fromarray(remove_pad(result))
+    return result, True
+
+def get_sketch(image, method='scribble_xdog', res=2048, thr=20, **kwargs):
+    # image: np.ndarray
+    input_height = image.shape[0]
+    input_width = image.shape[1]
+
+    if method == 'scribble_xdog':
+        processed_image, _ = scribble_xdog(image, res, thr)  # PIL.Image
+        processed_image = processed_image.resize((input_width, input_height))
+        # convert the PIL.Image to cv2 and invert it
+        processed_image = cv2.cvtColor(np.array(processed_image), cv2.COLOR_RGB2BGR)
+        processed_image = 255 - processed_image
+        processed_image = Image.fromarray(processed_image)
+    elif method == 'anime2sketch':
+        clahe = 1.0
+        processed_image = generate_sketch(image, clahe_clip=clahe, load_size=1024)  # output: numpy.ndarray
+        processed_image = Image.fromarray(processed_image)
+        # processed_image.save(output_path.split('.')[0] + f'_{clahe}.png')
+    elif method == 'both':
+        alpha = 0.5
+        # blend the two sketches with weight alpha
+        scribble_xdog_processed_image, _ = scribble_xdog(image, res, thr)
+        scribble_xdog_processed_image = scribble_xdog_processed_image.resize((input_width, input_height))
+        scribble_xdog_processed_image = cv2.cvtColor(np.array(scribble_xdog_processed_image), cv2.COLOR_RGB2BGR)
+        scribble_xdog_processed_image = 255 - scribble_xdog_processed_image
+
+        anime2sketch_processed_image = generate_sketch(image, clahe_clip=1.0, load_size=1024)
+        anime2sketch_processed_image = Image.fromarray(anime2sketch_processed_image)
+        anime2sketch_processed_image = anime2sketch_processed_image.resize((input_width, input_height))
+        anime2sketch_processed_image = cv2.cvtColor(np.array(anime2sketch_processed_image), cv2.COLOR_RGB2BGR)
+
+        processed_image = cv2.addWeighted(scribble_xdog_processed_image, alpha, anime2sketch_processed_image, 1 - alpha, 0)
+        processed_image = Image.fromarray(processed_image)
+
+    return processed_image
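
A quick sketch of the XDoG-only path of get_sketch, which needs no model weights; the file names are placeholders:

import numpy as np
from PIL import Image
from lineart_util import get_sketch

img = np.array(Image.open("input.png").convert("RGB"))
sketch = get_sketch(img, method="scribble_xdog", res=2048, thr=20)   # returns a PIL.Image
sketch.save("sketch_xdog.png")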
model.py ADDED
@@ -0,0 +1,186 @@
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+import functools
+
+class UnetGenerator(nn.Module):
+    """Create a Unet-based generator"""
+
+    def __init__(self, input_nc, output_nc, num_downs, ngf=64, norm_layer=nn.BatchNorm2d, use_dropout=False):
+        """Construct a Unet generator
+        Parameters:
+            input_nc (int)  -- the number of channels in input images
+            output_nc (int) -- the number of channels in output images
+            num_downs (int) -- the number of downsamplings in UNet. For example, if |num_downs| == 7,
+                               an image of size 128x128 will become 1x1 at the bottleneck
+            ngf (int)       -- the number of filters in the last conv layer
+            norm_layer      -- normalization layer
+        We construct the U-Net from the innermost layer to the outermost layer.
+        It is a recursive process.
+        """
+        super(UnetGenerator, self).__init__()
+        # construct unet structure
+        unet_block = UnetSkipConnectionBlock(ngf * 8, ngf * 8, input_nc=None, submodule=None, norm_layer=norm_layer, innermost=True)  # add the innermost layer
+        for _ in range(num_downs - 5):  # add intermediate layers with ngf * 8 filters
+            unet_block = UnetSkipConnectionBlock(ngf * 8, ngf * 8, input_nc=None, submodule=unet_block, norm_layer=norm_layer, use_dropout=use_dropout)
+        # gradually reduce the number of filters from ngf * 8 to ngf
+        unet_block = UnetSkipConnectionBlock(ngf * 4, ngf * 8, input_nc=None, submodule=unet_block, norm_layer=norm_layer)
+        unet_block = UnetSkipConnectionBlock(ngf * 2, ngf * 4, input_nc=None, submodule=unet_block, norm_layer=norm_layer)
+        unet_block = UnetSkipConnectionBlock(ngf, ngf * 2, input_nc=None, submodule=unet_block, norm_layer=norm_layer)
+        self.model = UnetSkipConnectionBlock(output_nc, ngf, input_nc=input_nc, submodule=unet_block, outermost=True, norm_layer=norm_layer)  # add the outermost layer
+
+    def forward(self, input):
+        """Standard forward"""
+        return self.model(input)
+
+class UnetSkipConnectionBlock(nn.Module):
+    """Defines the Unet submodule with skip connection.
+        X -------------------identity----------------------
+        |-- downsampling -- |submodule| -- upsampling --|
+    """
+
+    def __init__(self, outer_nc, inner_nc, input_nc=None,
+                 submodule=None, outermost=False, innermost=False, norm_layer=nn.BatchNorm2d, use_dropout=False):
+        """Construct a Unet submodule with skip connections.
+        Parameters:
+            outer_nc (int) -- the number of filters in the outer conv layer
+            inner_nc (int) -- the number of filters in the inner conv layer
+            input_nc (int) -- the number of channels in input images/features
+            submodule (UnetSkipConnectionBlock) -- previously defined submodules
+            outermost (bool)    -- if this module is the outermost module
+            innermost (bool)    -- if this module is the innermost module
+            norm_layer          -- normalization layer
+            use_dropout (bool)  -- if use dropout layers.
+        """
+        super(UnetSkipConnectionBlock, self).__init__()
+        self.outermost = outermost
+        if type(norm_layer) == functools.partial:
+            use_bias = norm_layer.func == nn.InstanceNorm2d
+        else:
+            use_bias = norm_layer == nn.InstanceNorm2d
+        if input_nc is None:
+            input_nc = outer_nc
+        downconv = nn.Conv2d(input_nc, inner_nc, kernel_size=4,
+                             stride=2, padding=1, bias=use_bias)
+        downrelu = nn.LeakyReLU(0.2, True)
+        downnorm = norm_layer(inner_nc)
+        uprelu = nn.ReLU(True)
+        upnorm = norm_layer(outer_nc)
+
+        if outermost:
+            upconv = nn.ConvTranspose2d(inner_nc * 2, outer_nc,
+                                        kernel_size=4, stride=2,
+                                        padding=1)
+            down = [downconv]
+            up = [uprelu, upconv, nn.Tanh()]
+            model = down + [submodule] + up
+        elif innermost:
+            upconv = nn.ConvTranspose2d(inner_nc, outer_nc,
+                                        kernel_size=4, stride=2,
+                                        padding=1, bias=use_bias)
+            down = [downrelu, downconv]
+            up = [uprelu, upconv, upnorm]
+            model = down + up
+        else:
+            upconv = nn.ConvTranspose2d(inner_nc * 2, outer_nc,
+                                        kernel_size=4, stride=2,
+                                        padding=1, bias=use_bias)
+            down = [downrelu, downconv, downnorm]
+            up = [uprelu, upconv, upnorm]
+
+            if use_dropout:
+                model = down + [submodule] + up + [nn.Dropout(0.5)]
+            else:
+                model = down + [submodule] + up
+
+        self.model = nn.Sequential(*model)
+
+    def forward(self, x):
+        if self.outermost:
+            return self.model(x)
+        else:   # add skip connections
+            return torch.cat([x, self.model(x)], 1)
+
+
+class Smooth(nn.Module):
+    def __init__(self):
+        super().__init__()
+        kernel = [
+            [1, 2, 1],
+            [2, 4, 2],
+            [1, 2, 1]
+        ]
+        kernel = torch.tensor([[kernel]], dtype=torch.float)
+        kernel /= kernel.sum()
+        self.register_buffer('kernel', kernel)
+        self.pad = nn.ReplicationPad2d(1)
+
+    def forward(self, x):
+        b, c, h, w = x.shape
+        x = x.view(-1, 1, h, w)
+        x = self.pad(x)
+        x = F.conv2d(x, self.kernel)
+        return x.view(b, c, h, w)
+
+
+class Upsample(nn.Module):
+    def __init__(self, inc, outc, scale_factor=2):
+        super().__init__()
+        self.scale_factor = scale_factor
+        self.up = nn.Upsample(scale_factor=scale_factor, mode='bilinear')
+        self.smooth = Smooth()
+        self.conv = nn.Conv2d(inc, outc, kernel_size=3, stride=1, padding=1)
+        self.mlp = nn.Sequential(
+            nn.Conv2d(outc, 4 * outc, kernel_size=1, stride=1, padding=0),
+            nn.GELU(),
+            nn.Conv2d(4 * outc, outc, kernel_size=1, stride=1, padding=0),
+        )
+
+    def forward(self, x):
+        x = self.smooth(self.up(x))
+        x = self.conv(x)
+        x = self.mlp(x) + x
+        return x
+
+
+def create_model(model):
+    """Create a model for anime2sketch
+    hardcoding the options for simplicity
+    """
+    # Imported here rather than at module level to avoid a circular import
+    # (process_utils -> lineart_util -> anime -> model); the original did `from app import download_file`
+    from process_utils import download_file
+
+    norm_layer = functools.partial(nn.InstanceNorm2d, affine=False, track_running_stats=False)
+    net = UnetGenerator(3, 1, 8, 64, norm_layer=norm_layer, use_dropout=False)
+
+    import os
+    cwd = os.getcwd()                    # save the current working directory
+    os.chdir(os.path.dirname(__file__))  # move to this file's directory
+    if model == 'default':
+        model_path = download_file("netG.pth", subfolder="models/Anime2Sketch")
+        ckpt = torch.load(model_path)
+        for key in list(ckpt.keys()):
+            if 'module.' in key:
+                ckpt[key.replace('module.', '')] = ckpt[key]
+                del ckpt[key]
+        net.load_state_dict(ckpt)
+
+        os.chdir(cwd)                    # return to the original directory
+
+    elif model == 'improved':
+        ckpt = torch.load('weights/improved.bin', map_location=torch.device('cpu'))
+        base = net.model.model[1]
+
+        # swap deconvolution layers with resize + conv layers for 2x upsampling
+        for _ in range(6):
+            inc, outc = base.model[5].in_channels, base.model[5].out_channels
+            base.model[5] = Upsample(inc, outc)
+            base = base.model[3]
+
+        net.load_state_dict(ckpt)
+
+        os.chdir(cwd)                    # return to the original directory
+
+    else:
+        raise ValueError(f"model should be one of ['default', 'improved'], but got {model}")
+
+    return net
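
A sketch of instantiating the generator and pushing a dummy batch through it; create_model('default') fetches netG.pth via download_file, so the REPO_ID / HF_TOKEN / CACHE_DIR environment variables are assumed to be set:

import torch
from model import create_model

net = create_model("default").eval()
with torch.no_grad():
    out = net(torch.zeros(1, 3, 512, 512))   # 3-channel image in, 1-channel sketch out
print(out.shape)                              # torch.Size([1, 1, 512, 512])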
process_utils.py ADDED
@@ -0,0 +1,345 @@
+import io
+import os
+import base64
+from PIL import Image
+import cv2
+import numpy as np
+from generate_prompt import load_wd14_tagger_model, generate_tags, preprocess_image as wd14_preprocess_image
+from lineart_util import scribble_xdog, get_sketch, canny
+import torch
+from diffusers import StableDiffusionPipeline, StableDiffusionControlNetPipeline, ControlNetModel, UniPCMultistepScheduler, AutoencoderKL
+import gc
+from peft import PeftModel
+from huggingface_hub import hf_hub_download
+from dotenv import load_dotenv
+
+load_dotenv()
+
+# Global variables
+local_model = False
+model = None
+# device = "cuda" if torch.cuda.is_available() else "cpu"
+device = "cpu"
+torch_dtype = torch.float16 if device == "cuda" else torch.float32
+sotai_gen_pipe = None
+refine_gen_pipe = None
+
+def download_file(filename, subfolder=None):
+    return hf_hub_download(
+        repo_id=os.environ['REPO_ID'],
+        filename=filename,
+        subfolder=subfolder,
+        token=os.environ['HF_TOKEN'],
+        cache_dir=os.environ['CACHE_DIR']
+    )
+
+def get_file_path(filename, subfolder=None):
+    if local_model:
+        return os.path.join(subfolder, filename)
+    else:
+        return download_file(filename, subfolder)
+
+def ensure_rgb(image):
+    if image.mode != 'RGB':
+        return image.convert('RGB')
+    return image
+
+def initialize(_local_model=False):
+    global model, sotai_gen_pipe, refine_gen_pipe, local_model
+
+    local_model = _local_model
+    model = load_wd14_tagger_model()
+    sotai_gen_pipe = initialize_sotai_model()
+    refine_gen_pipe = initialize_refine_model()
+
+def load_lora(pipeline, lora_path, alpha=0.75):
+    pipeline.load_lora_weights(lora_path)
+    pipeline.fuse_lora(lora_scale=alpha)
+
+def initialize_sotai_model():
+    global device, torch_dtype
+
+    sotai_sd_model_path = get_file_path(os.environ["sotai_sd_model_name"], subfolder=os.environ["sd_models_dir"])
+    controlnet_path1 = get_file_path(os.environ["controlnet_name1"], subfolder=os.environ["controlnet_dir2"])
+    controlnet_path2 = get_file_path(os.environ["controlnet_name2"], subfolder=os.environ["controlnet_dir1"])
+
+    # Load the Stable Diffusion model
+    sd_pipe = StableDiffusionPipeline.from_single_file(
+        sotai_sd_model_path,
+        torch_dtype=torch_dtype,
+        use_safetensors=True
+    ).to(device)
+
+    # Load the first ControlNet model
+    controlnet1 = ControlNetModel.from_single_file(
+        controlnet_path1,
+        torch_dtype=torch_dtype
+    ).to(device)
+
+    # Load the second ControlNet model
+    controlnet2 = ControlNetModel.from_single_file(
+        controlnet_path2,
+        torch_dtype=torch_dtype
+    ).to(device)
+
+    # Create the ControlNet pipeline
+    sotai_gen_pipe = StableDiffusionControlNetPipeline(
+        vae=sd_pipe.vae,
+        text_encoder=sd_pipe.text_encoder,
+        tokenizer=sd_pipe.tokenizer,
+        unet=sd_pipe.unet,
+        scheduler=sd_pipe.scheduler,
+        safety_checker=sd_pipe.safety_checker,
+        feature_extractor=sd_pipe.feature_extractor,
+        controlnet=[controlnet1, controlnet2]
+    ).to(device)
+
+    # Apply LoRA weights
+    lora_names = [
+        (os.environ["lora_name1"], 1.0),
+        # (os.environ["lora_name2"], 0.3),
+    ]
+
+    for lora_name, alpha in lora_names:
+        lora_path = get_file_path(lora_name, subfolder=os.environ["lora_dir"])
+        load_lora(sotai_gen_pipe, lora_path, alpha)
+
+    # Configure the scheduler
+    sotai_gen_pipe.scheduler = UniPCMultistepScheduler.from_config(sotai_gen_pipe.scheduler.config)
+
+    return sotai_gen_pipe
+
+def initialize_refine_model():
+    global device, torch_dtype
+
+    refine_sd_model_path = get_file_path(os.environ["refine_sd_model_name"], subfolder=os.environ["sd_models_dir"])
+    controlnet_path3 = get_file_path(os.environ["controlnet_name3"], subfolder=os.environ["controlnet_dir1"])
+    controlnet_path4 = get_file_path(os.environ["controlnet_name4"], subfolder=os.environ["controlnet_dir1"])
+    vae_path = get_file_path(os.environ["vae_name"], subfolder=os.environ["vae_dir"])
+
+    # Load the Stable Diffusion model
+    sd_pipe = StableDiffusionPipeline.from_single_file(
+        refine_sd_model_path,
+        torch_dtype=torch_dtype,
+        use_safetensors=True
+    ).to(device)
+
+    # controlnet_path = "models/cn/control_v11p_sd15_canny.pth"
+    controlnet1 = ControlNetModel.from_single_file(
+        controlnet_path3,
+        torch_dtype=torch_dtype
+    ).to(device)
+
+    # Load the ControlNet model
+    controlnet2 = ControlNetModel.from_single_file(
+        controlnet_path4,
+        torch_dtype=torch_dtype
+    ).to(device)
+
+    # Create the ControlNet pipeline
+    refine_gen_pipe = StableDiffusionControlNetPipeline(
+        vae=AutoencoderKL.from_single_file(vae_path, torch_dtype=torch_dtype).to(device),
+        text_encoder=sd_pipe.text_encoder,
+        tokenizer=sd_pipe.tokenizer,
+        unet=sd_pipe.unet,
+        scheduler=sd_pipe.scheduler,
+        safety_checker=sd_pipe.safety_checker,
+        feature_extractor=sd_pipe.feature_extractor,
+        controlnet=[controlnet1, controlnet2],  # use multiple ControlNets
+    ).to(device)
+
+    # Configure the scheduler
+    refine_gen_pipe.scheduler = UniPCMultistepScheduler.from_config(refine_gen_pipe.scheduler.config)
+
+    return refine_gen_pipe
+
+def get_wd_tags(images: list) -> list:
+    global model
+    if model is None:
+        initialize()
+    preprocessed_images = [wd14_preprocess_image(img) for img in images]
+    preprocessed_images = np.array(preprocessed_images)
+    return generate_tags(preprocessed_images, os.environ["wd_model_name"], model)
+
+def preprocess_image_for_generation(image):
+    if isinstance(image, str):           # base64 string
+        image = Image.open(io.BytesIO(base64.b64decode(image)))
+    elif isinstance(image, np.ndarray):  # numpy array
+        image = Image.fromarray(image)
+    elif not isinstance(image, Image.Image):
+        raise ValueError("Unsupported image type")
+
+    # Compute the output image size
+    input_width, input_height = image.size
+    max_size = 736
+    output_width = max_size if input_height < input_width else int(input_width / input_height * max_size)
+    output_height = max_size if input_height > input_width else int(input_height / input_width * max_size)
+
+    image = image.resize((output_width, output_height))
+    return image, output_width, output_height
+
+def binarize_image(image: Image.Image) -> np.ndarray:
+    image = np.array(image.convert('L'))
+    # Invert the colors
+    image = 255 - image
+
+    # Histogram equalization (CLAHE)
+    clahe = cv2.createCLAHE(clipLimit=1.0, tileGridSize=(8, 8))
+    image = clahe.apply(image)
+
+    # Apply a Gaussian blur
+    image = cv2.GaussianBlur(image, (5, 5), 0)
+
+    # Adaptive thresholding
+    binary_image = cv2.adaptiveThreshold(image, 255, cv2.ADAPTIVE_THRESH_GAUSSIAN_C, cv2.THRESH_BINARY, 9, -8)
+
+    return binary_image
+
+def create_rgba_image(binary_image: np.ndarray, color: list) -> Image.Image:
+    rgba_image = np.zeros((binary_image.shape[0], binary_image.shape[1], 4), dtype=np.uint8)
+    rgba_image[:, :, 0] = color[0]
+    rgba_image[:, :, 1] = color[1]
+    rgba_image[:, :, 2] = color[2]
+    rgba_image[:, :, 3] = binary_image
+    return Image.fromarray(rgba_image, 'RGBA')
+
+def generate_sotai_image(input_image: Image.Image, output_width: int, output_height: int) -> Image.Image:
+    input_image = ensure_rgb(input_image)
+    global sotai_gen_pipe
+    if sotai_gen_pipe is None:
+        initialize()
+
+    prompt = "anime pose, girl, (white background:1.5), (monochrome:1.5), full body, sketch, eyes, breasts, (slim legs, skinny legs:1.2)"
+    try:
+        # Resize the input image
+        if input_image.size[0] > input_image.size[1]:
+            input_image = input_image.resize((512, int(512 * input_image.size[1] / input_image.size[0])))
+        else:
+            input_image = input_image.resize((int(512 * input_image.size[0] / input_image.size[1]), 512))
+
+        # Contents of EasyNegativeV2
+        easy_negative_v2 = "(worst quality, low quality, normal quality:1.4), lowres, bad anatomy, bad hands, text, error, missing fingers, extra digit, fewer digits, cropped, jpeg artifacts, signature, watermark, username, blurry, artist name, (bad_prompt_version2:0.8)"
+
+        output = sotai_gen_pipe(
+            prompt,
+            image=[input_image, input_image],
+            negative_prompt=f"(wings:1.6), (clothes, garment, lighting, gray, missing limb, extra line, extra limb, extra arm, extra legs, hair, bangs, fringe, forelock, front hair, fill:1.4), (ink pool:1.6)",
+            # negative_prompt=f"{easy_negative_v2}, (wings:1.6), (clothes, garment, lighting, gray, missing limb, extra line, extra limb, extra arm, extra legs, hair, bangs, fringe, forelock, front hair, fill:1.4), (ink pool:1.6)",
+            num_inference_steps=40,
+            guidance_scale=8,
+            width=output_width,
+            height=output_height,
+            denoising_strength=0.13,
+            num_images_per_prompt=1,                    # equivalent to batch_size
+            guess_mode=[True, True],                    # equivalent to pixel_perfect
+            controlnet_conditioning_scale=[1.2, 1.3],   # weight of each ControlNet
+            control_guidance_start=[0.0, 0.0],
+            control_guidance_end=[1.0, 1.0],
+        )
+        generated_image = output.images[0]
+
+        return generated_image
+
+    finally:
+        # Free memory
+        if device == "cuda":
+            torch.cuda.empty_cache()
+        gc.collect()
+
+def generate_refined_image(prompt: str, original_image: Image.Image, output_width: int, output_height: int, weight1: float, weight2: float) -> Image.Image:
+    original_image = ensure_rgb(original_image)
+    global refine_gen_pipe
+    if refine_gen_pipe is None:
+        initialize()
+
+    try:
+        original_image_np = np.array(original_image)
+        # scribble_xdog
+        scribble_image, _ = scribble_xdog(original_image_np, 2048, 20)
+
+        original_image = original_image.resize((output_width, output_height))
+        output = refine_gen_pipe(
+            prompt,
+            image=[scribble_image, original_image],    # input images for the two ControlNets
+            negative_prompt="extra limb, monochrome, black and white",
+            num_inference_steps=20,
+            width=output_width,
+            height=output_height,
+            controlnet_conditioning_scale=[weight1, weight2],  # weight of each ControlNet
+            control_guidance_start=[0.0, 0.0],
+            control_guidance_end=[1.0, 1.0],
+            guess_mode=[False, False],                  # pixel_perfect
+        )
+        generated_image = output.images[0]
+
+        return generated_image
+
+    finally:
+        # Free memory
+        if device == "cuda":
+            torch.cuda.empty_cache()
+        gc.collect()
+
+def process_image(input_image, mode: str, weight1: float = 0.4, weight2: float = 0.3):
+    input_image = ensure_rgb(input_image)
+    # Get the input size
+    input_width, input_height = input_image.size
+    max_size = 736
+    output_width = max_size if input_height < input_width else int(input_width / input_height * max_size)
+    output_height = max_size if input_height > input_width else int(input_height / input_width * max_size)
+
+    if mode == "refine":
+        # Generate a prompt with the WD-14 tagger
+        image_np = np.array(ensure_rgb(input_image))
+        prompt = get_wd_tags([image_np])[0]
+        prompt = f"{prompt}"
+        print(prompt)
+
+        refined_image = generate_refined_image(prompt, input_image, output_width, output_height, weight1, weight2)
+        refined_image = refined_image.convert('RGB')
+
+        # Generate the sketch image
+        refined_image_np = np.array(refined_image)
+        sketch_image = get_sketch(refined_image_np, "both", 2048, 10)
+        sketch_image = sketch_image.resize((output_width, output_height))  # match the output size
+        # Binarize the sketch image
+        sketch_binary = binarize_image(sketch_image)
+        # Convert to RGBA (a transparent base image) and draw the lines in blue
+        sketch_image = create_rgba_image(sketch_binary, [0, 0, 255])
+
+        # Generate the sotai (body draft) image
+        sotai_image = generate_sotai_image(refined_image, output_width, output_height)
+
+    elif mode == "original":
+        sotai_image = generate_sotai_image(input_image, output_width, output_height)
+
+        # Generate the sketch image
+        input_image_np = np.array(input_image)
+        sketch_image = get_sketch(input_image_np, "both", 2048, 16)
+
+    elif mode == "sketch":
+        # Generate the sketch image
+        input_image_np = np.array(input_image)
+        sketch_image = get_sketch(input_image_np, "both", 2048, 16)
+
+        # Generate the sotai (body draft) image
+        sotai_image = generate_sotai_image(sketch_image, output_width, output_height)
+
+    else:
+        raise ValueError("Invalid mode")
+
+    # Binarize the sotai image
+    sotai_binary = binarize_image(sotai_image)
+    # Convert to RGBA (a transparent base image) and draw the lines in red
+    sotai_image = create_rgba_image(sotai_binary, [255, 0, 0])
+
+    return sotai_image, sketch_image
+
+def image_to_base64(img_array):
+    buffered = io.BytesIO()
+    img_array.save(buffered, format="PNG")
+    return base64.b64encode(buffered.getvalue()).decode()
+
+def process_image_as_base64(input_image, mode: str, weight1: float = 0.4, weight2: float = 0.3):
+    sotai_image, sketch_image = process_image(input_image, mode, weight1, weight2)
+    return image_to_base64(sotai_image), image_to_base64(sketch_image)
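
A sketch of running the pipeline once without the Flask server; it assumes the .env variables and model files referenced above are available (input.png is a placeholder) and downloads several checkpoints on first use:

from PIL import Image
from process_utils import initialize, process_image

initialize(_local_model=False)
sotai, sketch = process_image(Image.open("input.png"), mode="original")
sotai.save("sotai.png")    # red body-draft lines on a transparent background
sketch.save("sketch.png")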
requirements.txt ADDED
@@ -0,0 +1,19 @@
+--extra-index-url https://download.pytorch.org/whl/cu116
+torch==2.2.0
+torchvision==0.17.0
+torchaudio==2.2.0
+diffusers==0.29.1
+Flask==3.0.3
+Flask-Cors==4.0.0
+gradio==4.36.1
+huggingface_hub==0.23.2
+kornia==0.7.1
+numpy==1.23.5
+opencv-python==4.9.0.80
+Pillow==10.3.0
+Requests==2.32.3
+tensorflow==2.16.1
+transforms==0.2.1
+tokenizers
+pytorch_lightning
+python-dotenv
templates/index.html ADDED
@@ -0,0 +1,60 @@
+<!DOCTYPE html>
+<html lang="en">
+<head>
+    <meta charset="UTF-8">
+    <meta name="viewport" content="width=device-width, initial-scale=1.0">
+    <title>Scribble Image Generator</title>
+    <style>
+        body { font-family: Arial, sans-serif; max-width: 800px; margin: 0 auto; padding: 20px; }
+        .image-container { display: flex; justify-content: space-between; margin-top: 20px; }
+        .image-container img { max-width: 48%; height: auto; }
+        #loading { display: none; }
+    </style>
+</head>
+<body>
+    <h1>Scribble Image Generator</h1>
+    <form id="upload-form">
+        <input type="file" id="file-input" accept="image/*" required>
+        <br><br>
+        <label for="threshold">Threshold:</label>
+        <input type="number" id="threshold" name="threshold" value="20" min="1" max="64">
+        <br><br>
+        <label for="processor_res">Processor Resolution:</label>
+        <input type="number" id="processor_res" name="processor_res" value="2048" min="64" max="2048">
+        <br><br>
+        <button type="submit">Generate Scribble</button>
+    </form>
+    <div id="loading">Processing...</div>
+    <div class="image-container">
+        <img id="original-image" alt="Original Image">
+        <img id="scribble-image" alt="Scribble Image">
+    </div>
+
+    <script>
+        document.getElementById('upload-form').addEventListener('submit', function(e) {
+            e.preventDefault();
+            var formData = new FormData();
+            formData.append('file', document.getElementById('file-input').files[0]);
+            formData.append('threshold', document.getElementById('threshold').value);
+            formData.append('processor_res', document.getElementById('processor_res').value);
+
+            document.getElementById('loading').style.display = 'block';
+
+            fetch('/process', {
+                method: 'POST',
+                body: formData
+            })
+            .then(response => response.json())
+            .then(data => {
+                document.getElementById('original-image').src = 'data:image/png;base64,' + data.original_image;
+                document.getElementById('scribble-image').src = 'data:image/png;base64,' + data.scribble_image;
+                document.getElementById('loading').style.display = 'none';
+            })
+            .catch(error => {
+                console.error('Error:', error);
+                document.getElementById('loading').style.display = 'none';
+            });
+        });
+    </script>
+</body>
+</html>