Ron1006 commited on
Commit
43cda70
·
1 Parent(s): 6012149
Files changed (1) hide show
  1. app.py +26 -16
app.py CHANGED
@@ -8,22 +8,28 @@ import os
8
  import torch
9
  from transformers import utils
10
 
11
- # 清理缓存
12
- utils.move_cache()
13
-
14
  # 从环境变量中获取访问令牌
15
  hf_token = os.getenv("hf_token")
16
 
17
  # 使用令牌进行登录(如果不用Hugging Face的模型,可以跳过这一步)
18
  if hf_token:
19
  login(token=hf_token)
20
-
21
- # 加载预训练的图像到图像的 Stable Diffusion 模型
22
- pipe = StableDiffusionImg2ImgPipeline.from_pretrained(
23
- "runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16
24
- )
25
-
26
- pipe.to("cuda") # 使用 CPU 来处理,如果有 GPU,可修改为 "cuda"
 
 
 
 
 
 
 
 
 
27
 
28
  # 定义 Gradio 接口中调用的函数
29
  def image_to_image(input_image, prompt):
@@ -34,12 +40,16 @@ def image_to_image(input_image, prompt):
34
  # 将用户上传的输入图像转化为 RGB 格式
35
  input_image = input_image.convert("RGB")
36
 
37
- # 使用预训练模型生成新图像
38
- generated_image = pipe(
39
- prompt=prompt, image=input_image, strength=0.75, guidance_scale=7.5
40
- ).images[0]
 
41
 
42
- return generated_image # 返回生成的图像
 
 
 
43
 
44
  # 创建 Gradio 接口
45
  demo = gr.Interface(
@@ -49,4 +59,4 @@ demo = gr.Interface(
49
  )
50
 
51
  if __name__ == "__main__":
52
- demo.launch(share=True) # 如果需要公开链接,可以设置 share=True
 
8
  import torch
9
  from transformers import utils
10
 
 
 
 
11
  # 从环境变量中获取访问令牌
12
  hf_token = os.getenv("hf_token")
13
 
14
  # 使用令牌进行登录(如果不用Hugging Face的模型,可以跳过这一步)
15
  if hf_token:
16
  login(token=hf_token)
17
+ else:
18
+ print("Hugging Face token not found. Please set the 'hf_token' environment variable.")
19
+
20
+ # 检查是否有GPU可用,并选择相应的设备和数据类型
21
+ if torch.cuda.is_available():
22
+ pipe = StableDiffusionImg2ImgPipeline.from_pretrained(
23
+ "runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16
24
+ )
25
+ pipe.to("cuda")
26
+ print("Using CUDA")
27
+ else:
28
+ pipe = StableDiffusionImg2ImgPipeline.from_pretrained(
29
+ "runwayml/stable-diffusion-v1-5", torch_dtype=torch.float32
30
+ )
31
+ pipe.to("cpu")
32
+ print("CUDA not available, using CPU")
33
 
34
  # 定义 Gradio 接口中调用的函数
35
  def image_to_image(input_image, prompt):
 
40
  # 将用户上传的输入图像转化为 RGB 格式
41
  input_image = input_image.convert("RGB")
42
 
43
+ try:
44
+ # 使用预训练模型生成新图像
45
+ generated_image = pipe(
46
+ prompt=prompt, image=input_image, strength=0.75, guidance_scale=7.5
47
+ ).images[0]
48
 
49
+ return generated_image # 返回生成的图像
50
+ except Exception as e:
51
+ print(f"Error during image generation: {e}")
52
+ return None
53
 
54
  # 创建 Gradio 接口
55
  demo = gr.Interface(
 
59
  )
60
 
61
  if __name__ == "__main__":
62
+ demo.launch(share=False) # 如果需要公开链接,可以设置 share=True