soiz committed on
Commit
6b25aaf
1 Parent(s): 9327b29

Update app.py

Files changed (1)
  1. app.py +19 -62
app.py CHANGED
@@ -1,17 +1,11 @@
 import os
-
-# Install Flask if not already installed
-return_code = os.system('pip install flask')
-if return_code != 0:
-    raise RuntimeError("Failed to install Flask")
-
-import gradio as gr
-from random import randint
-from all_models import models
 from flask import Flask, request, send_file
 from io import BytesIO
 from PIL import Image, ImageChops
+from diffusers import StableDiffusionPipeline, EulerDiscreteScheduler
+import torch
 
+# Initialize the Flask application
 app = Flask(__name__)
 
 # Global model dictionary
@@ -21,60 +15,35 @@ def load_model(model_name):
     global models_load
     if model_name not in models_load:
         try:
-            m = gr.load(f'models/{model_name}')
+            scheduler = EulerDiscreteScheduler.from_pretrained(model_name, subfolder="scheduler")
+            pipe = StableDiffusionPipeline.from_pretrained(model_name, scheduler=scheduler, torch_dtype=torch.float16)
+            pipe = pipe.to("cuda")
+            models_load[model_name] = pipe
             print(f"Model {model_name} loaded successfully.")
-            models_load[model_name] = m
         except Exception as error:
             print(f"Error loading model {model_name}: {error}")
-            models_load[model_name] = gr.Interface(lambda txt: None, ['text'], ['image'])
+            models_load[model_name] = None
 
-def gen_fn(model_str, prompt, negative_prompt=None, noise=None, cfg_scale=None, num_inference_steps=None, sampler=None):
+def gen_fn(model_str, prompt, negative_prompt=None, noise=None, cfg_scale=None, num_inference_steps=None):
     if model_str not in models_load:
         load_model(model_str)
 
-    if model_str in models_load:
+    if model_str in models_load and models_load[model_str] is not None:
         if noise == "random":
             noise = str(randint(0, 99999999999))
         full_prompt = f'{prompt} {noise}' if noise else prompt
 
-        # Log the negative prompt and the other parameters
-        print(f"Prompt: {full_prompt}, Negative Prompt: {negative_prompt}, CFG Scale: {cfg_scale}, Steps: {num_inference_steps}, Sampler: {sampler}")
-
-        # Construct the function call parameters dynamically
-        inputs = [full_prompt]
-        if negative_prompt:
-            inputs.append(negative_prompt)
-        if cfg_scale is not None:
-            inputs.append(cfg_scale)
-        if num_inference_steps is not None:
-            inputs.append(num_inference_steps)
-        if sampler:
-            inputs.append(sampler)
+        print(f"Prompt: {full_prompt}")
 
         try:
             # Call the model
-            result = models_load[model_str](*inputs)
-
-            # Debugging result type
-            print(f"Result type: {type(result)}, Result: {result}")
+            result = models_load[model_str](full_prompt, num_inference_steps=num_inference_steps)
+            image = result.images[0]  # Get the generated image
 
-            # Check if result is an image or a file path
-            if isinstance(result, str):  # Assuming result might be a file path
-                if os.path.exists(result):
-                    image = Image.open(result)
-                else:
-                    print(f"File path not found: {result}")
-                    return None, 'File path not found'
-            elif isinstance(result, Image.Image):
-                image = result
-            else:
-                print("Result is not an image:", type(result))
-                return None, f"Unexpected result type: {type(result)}"
-
             # Check if the image is completely black
             black = Image.new('RGB', image.size, (0, 0, 0))
             if ImageChops.difference(image, black).getbbox() is None:
-                return None, 'The image is completely black. There may be a parameter that cannot be specified, or an error may have occurred internally.'
+                return None, 'The image is completely black.'
 
             return image, None
 
@@ -89,21 +58,11 @@ def gen_fn(model_str, prompt, negative_prompt=None, noise=None, cfg_scale=None,
 def home():
     prompt = request.args.get('prompt', '')
     model = request.args.get('model', '')
-    negative_prompt = request.args.get('Nprompt', None)
     noise = request.args.get('noise', None)
-    cfg_scale = request.args.get('cfg_scale', None)
-    num_inference_steps = request.args.get('steps', None)
-    sampler = request.args.get('sampler', None)
-
-    try:
-        if cfg_scale is not None:
-            cfg_scale = float(cfg_scale)
-    except ValueError:
-        return 'Invalid "cfg_scale" parameter. It should be a number.', 400
+    num_inference_steps = request.args.get('steps', 50)  # Set a default value
 
     try:
-        if num_inference_steps is not None:
-            num_inference_steps = int(num_inference_steps)
+        num_inference_steps = int(num_inference_steps)
     except ValueError:
         return 'Invalid "steps" parameter. It should be an integer.', 400
 
@@ -114,12 +73,11 @@ def home():
         return 'Please provide a "prompt" query parameter in the URL.', 400
 
     # Generate the image
-    image, error_message = gen_fn(model, prompt, negative_prompt, noise, cfg_scale, num_inference_steps, sampler)
+    image, error_message = gen_fn(model, prompt, noise=noise, num_inference_steps=num_inference_steps)
     if error_message:
         return error_message, 400
 
-    if isinstance(image, Image.Image):  # Ensure the result is a PIL image
-        # Save image to BytesIO object
+    if isinstance(image, Image.Image):
         img_io = BytesIO()
         image.save(img_io, format='PNG')
         img_io.seek(0)
@@ -128,5 +86,4 @@ def home():
     return 'Failed to generate image.', 500
 
 if __name__ == '__main__':
-    # Launch Flask app
-    app.run(host='0.0.0.0', port=7860)  # Run Flask app
+    app.run(host='0.0.0.0', port=7860)
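The additions replace the old gr.load('models/…') call with a locally run diffusers pipeline. For reference, the new path reduces to the standard diffusers pattern sketched below; this is a minimal sketch, assuming a CUDA-capable GPU and that the model name refers to a Stable Diffusion checkpoint on the Hugging Face Hub with a scheduler subfolder (the id used here is only an example). Note that gen_fn still calls randint even though this commit removes `from random import randint`, so the sketch imports it explicitly.

# Minimal sketch of the new diffusers-based generation path.
# Assumptions: a CUDA GPU is available; model_id names a Stable Diffusion
# checkpoint on the Hugging Face Hub (example id below).
from random import randint  # still needed by gen_fn, though the commit drops this import

import torch
from diffusers import StableDiffusionPipeline, EulerDiscreteScheduler
from PIL import Image, ImageChops

model_id = "stabilityai/stable-diffusion-2-1"  # example checkpoint

scheduler = EulerDiscreteScheduler.from_pretrained(model_id, subfolder="scheduler")
pipe = StableDiffusionPipeline.from_pretrained(
    model_id, scheduler=scheduler, torch_dtype=torch.float16
).to("cuda")

# Mirror the app's noise="random" behaviour by appending a random suffix.
prompt = f"a watercolor fox {randint(0, 99999999999)}"
image = pipe(prompt, num_inference_steps=50).images[0]

# Same all-black check as gen_fn: an empty bounding box on the difference
# against a black canvas means the image contains no non-black pixel.
black = Image.new('RGB', image.size, (0, 0, 0))
if ImageChops.difference(image, black).getbbox() is None:
    raise RuntimeError("The image is completely black.")
image.save("out.png")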
 
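Assuming home() is bound to the root route and ends by returning send_file(img_io, mimetype='image/png') (both outside the hunks shown), the updated endpoint can be exercised with a plain GET request; prompt is required, model selects the checkpoint, and noise and steps are optional:

# Fetch a generated PNG from the running app (assumes it listens on
# localhost:7860 and that home() serves the root path).
import requests

resp = requests.get(
    "http://localhost:7860/",
    params={
        "prompt": "a watercolor fox",
        "model": "stabilityai/stable-diffusion-2-1",  # example model id
        "noise": "random",  # optional: appends a random suffix to the prompt
        "steps": 30,        # optional: the server defaults to 50
    },
    timeout=600,
)
resp.raise_for_status()
with open("fox.png", "wb") as f:
    f.write(resp.content)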