arif670 committed on
Commit
bedec2e
·
verified ·
1 Parent(s): 9f82f30

Update models.py

Browse files
Files changed (1) hide show
  1. models.py +41 -21
models.py CHANGED
@@ -1,51 +1,71 @@
1
  import torch
2
  import logging
 
3
  from diffusers import StableDiffusionPipeline, DiffusionPipeline
4
  from huggingface_hub import login
5
- import os
6
 
7
  # Configure logging
8
  logging.basicConfig(level=logging.INFO)
9
  logger = logging.getLogger(__name__)
10
 
11
def load_models():
    """Load the text-to-image and image-to-video diffusion pipelines.

    Returns:
        tuple: ``(text_to_image, image_to_video, None)`` — the trailing
        ``None`` is a placeholder for a TTS model that is not loaded here.

    Raises:
        RuntimeError: if either pipeline fails to initialize.
    """
    try:
        # Authentication: fall back to anonymous access when no token is set.
        hf_token = os.getenv("HF_TOKEN", "")
        if hf_token:
            login(token=hf_token)
        else:
            logger.warning("No HF_TOKEN found, using anonymous access")

        # cudnn benchmarking only matters on CUDA; flushing denormals helps
        # CPU float math either way.
        torch.backends.cudnn.benchmark = True
        torch.set_flush_denormal(True)

        # BUGFIX: float16 is poorly supported on CPU (many ops raise or fall
        # back to fp32) — pick the dtype from the device actually available
        # instead of hard-coding fp16 on a CPU-pinned pipeline.
        use_cuda = torch.cuda.is_available()
        dtype = torch.float16 if use_cuda else torch.float32

        # Load text-to-image model with memory optimization
        logger.info("Loading text-to-image model...")
        text_to_image = StableDiffusionPipeline.from_pretrained(
            "runwayml/stable-diffusion-v1-5",  # Lighter variant
            torch_dtype=dtype,
            use_safetensors=True,
            safety_checker=None,
            requires_safety_checker=False,
        )
        if use_cuda:
            # BUGFIX: sequential CPU offload needs a CUDA accelerator and
            # manages device placement itself — do not pin the pipeline to a
            # device first. The original called it after .to("cpu").
            text_to_image.enable_sequential_cpu_offload()
        else:
            text_to_image = text_to_image.to("cpu")
        text_to_image.enable_attention_slicing()

        # Load video model with reduced parameters
        logger.info("Loading image-to-video model...")
        image_to_video = DiffusionPipeline.from_pretrained(
            "cerspense/zeroscope_v2_576w",  # Smaller video model
            torch_dtype=dtype,
        )
        if use_cuda:
            image_to_video.enable_sequential_cpu_offload()
        else:
            image_to_video = image_to_video.to("cpu")
        image_to_video.enable_attention_slicing(1)

        logger.info("All models loaded successfully")
        return text_to_image, image_to_video, None

    except Exception as e:
        logger.error(f"Model loading failed: {str(e)}")
        raise RuntimeError("Could not initialize models") from e
 
1
  import torch
2
  import logging
3
+ import os
4
  from diffusers import StableDiffusionPipeline, DiffusionPipeline
5
  from huggingface_hub import login
6
+ from typing import Tuple
7
 
8
  # Configure logging
9
  logging.basicConfig(level=logging.INFO)
10
  logger = logging.getLogger(__name__)
11
 
12
def load_models() -> Tuple[StableDiffusionPipeline, DiffusionPipeline, None]:
    """Load and configure AI models with memory optimizations.

    Returns:
        Tuple of ``(text_to_image, image_to_video, None)``; the trailing
        ``None`` is a placeholder for a TTS model that is not loaded here.

    Raises:
        RuntimeError: if any pipeline fails to initialize.
    """
    try:
        # Authentication setup
        hf_token = os.getenv("HF_TOKEN")
        if hf_token:
            login(token=hf_token)
            logger.info("HF authentication successful")
        else:
            logger.warning("Proceeding without HF authentication")

        # Inference only: disabling autograd saves memory on every forward.
        torch.set_grad_enabled(False)
        use_cuda = torch.cuda.is_available()
        if use_cuda:
            torch.backends.cuda.matmul.allow_tf32 = True
            torch.backends.cudnn.benchmark = True

        # BUGFIX: float16 (and the "fp16" weight variant) were requested
        # unconditionally, but fp16 is poorly supported on CPU — many ops
        # raise. Choose dtype/variant from the device actually available.
        dtype = torch.float16 if use_cuda else torch.float32
        variant = "fp16" if use_cuda else None

        # Load text-to-image model with optimizations
        logger.info("Loading text-to-image model...")
        text_to_image = StableDiffusionPipeline.from_pretrained(
            "runwayml/stable-diffusion-v1-5",
            torch_dtype=dtype,
            use_safetensors=True,
            safety_checker=None,
            # Silences the warning emitted when safety_checker=None.
            requires_safety_checker=False,
            variant=variant,
            # NOTE(review): `use_auth_token` is deprecated in newer diffusers
            # in favor of `token` — confirm the pinned diffusers version.
            use_auth_token=hf_token if hf_token else None,
        )
        _optimize_pipeline(text_to_image, use_cuda)

        # Load image-to-video model
        logger.info("Loading image-to-video model...")
        image_to_video = DiffusionPipeline.from_pretrained(
            "cerspense/zeroscope_v2_576w",
            torch_dtype=dtype,
            use_auth_token=hf_token if hf_token else None,
        )
        _optimize_pipeline(image_to_video, use_cuda)

        logger.info("All models loaded successfully")
        return text_to_image, image_to_video, None  # TTS placeholder

    except Exception as e:
        logger.error(f"Model loading failed: {str(e)}")
        raise RuntimeError("Model initialization error - check logs") from e


def _optimize_pipeline(pipeline, use_cuda: bool) -> None:
    """Apply memory optimizations appropriate for the available device."""
    if use_cuda:
        # BUGFIX: enable_model_cpu_offload manages device placement itself;
        # the original first called .to("cuda"), which defeats the offload
        # (the whole pipeline stays resident on the GPU).
        pipeline.enable_model_cpu_offload()
        try:
            # BUGFIX: xformers is an optional dependency — the original call
            # raised on hosts without it. Fall back to default attention.
            pipeline.enable_xformers_memory_efficient_attention()
        except Exception:
            logger.warning("xformers unavailable; using default attention")
    else:
        pipeline.to("cpu")
        pipeline.enable_attention_slicing()