arif670 committed on
Commit
edb58da
·
verified ·
1 Parent(s): 6c5f599

Update models.py

Browse files
Files changed (1) hide show
  1. models.py +21 -41
models.py CHANGED
@@ -5,70 +5,50 @@ from diffusers import StableDiffusionPipeline, DiffusionPipeline
5
  from huggingface_hub import login
6
  from typing import Tuple
7
 
8
- # Configure logging
9
  logging.basicConfig(level=logging.INFO)
10
  logger = logging.getLogger(__name__)
11
 
12
def load_models():
    """Load the text-to-image and image-to-video diffusion pipelines.

    Returns:
        Tuple of (text_to_image, image_to_video, None); the trailing None
        is a placeholder for a future TTS model.

    Raises:
        RuntimeError: wraps any failure during authentication or model
            download/initialization (original cause chained via ``from e``).
    """
    # Safety check for CPU-only environments: flushing denormals avoids
    # pathological slowdowns in float math on some CPUs.
    if not hasattr(torch, 'cuda') or not torch.cuda.is_available():
        torch.set_flush_denormal(True)
    try:
        # Optional Hugging Face authentication (needed for gated/private repos).
        hf_token = os.getenv("HF_TOKEN")
        if hf_token:
            login(token=hf_token)
            logger.info("HF authentication successful")
        else:
            logger.warning("Proceeding without HF authentication")

        # Inference-only process: disable autograd globally and enable
        # CUDA matmul/cudnn fast paths when a GPU is present.
        torch.set_grad_enabled(False)
        use_cuda = torch.cuda.is_available()
        if use_cuda:
            torch.backends.cuda.matmul.allow_tf32 = True
            torch.backends.cudnn.benchmark = True

        # FIX: fp16 weights are only usable on CUDA; plain SD fp16 inference
        # on CPU fails or produces garbage, so pick dtype/variant by device.
        dtype = torch.float16 if use_cuda else torch.float32
        variant = "fp16" if use_cuda else None

        # Load text-to-image model. NOTE: `use_auth_token` is deprecated and
        # redundant here — login() above already establishes credentials.
        logger.info("Loading text-to-image model...")
        text_to_image = StableDiffusionPipeline.from_pretrained(
            "runwayml/stable-diffusion-v1-5",
            torch_dtype=dtype,
            use_safetensors=True,
            safety_checker=None,  # NOTE(review): NSFW filter disabled — confirm intended
            variant=variant,
        )
        if use_cuda:
            # FIX: model CPU offload must manage device placement itself;
            # do NOT call .to("cuda") first (diffusers docs).
            text_to_image.enable_model_cpu_offload()
            # FIX: xformers is optional — degrade gracefully when absent.
            try:
                text_to_image.enable_xformers_memory_efficient_attention()
            except Exception:
                logger.warning("xformers unavailable; using attention slicing")
                text_to_image.enable_attention_slicing()
        else:
            text_to_image = text_to_image.to("cpu")
            text_to_image.enable_attention_slicing()

        # Load image-to-video model with the same device-aware settings.
        logger.info("Loading image-to-video model...")
        image_to_video = DiffusionPipeline.from_pretrained(
            "cerspense/zeroscope_v2_576w",
            torch_dtype=dtype,
        )
        if use_cuda:
            image_to_video.enable_model_cpu_offload()
            try:
                image_to_video.enable_xformers_memory_efficient_attention()
            except Exception:
                logger.warning("xformers unavailable; using attention slicing")
                image_to_video.enable_attention_slicing()
        else:
            image_to_video = image_to_video.to("cpu")
            image_to_video.enable_attention_slicing()

        logger.info("All models loaded successfully")
        return text_to_image, image_to_video, None  # TTS placeholder

    except Exception as e:
        logger.error(f"Model loading failed: {str(e)}")
        raise RuntimeError("Model initialization error - check logs") from e
 
5
  from huggingface_hub import login
6
  from typing import Tuple
7
 
 
8
  logging.basicConfig(level=logging.INFO)
9
  logger = logging.getLogger(__name__)
10
 
11
def load_models() -> Tuple[StableDiffusionPipeline, DiffusionPipeline, None]:
    """Load and configure the text-to-image and image-to-video pipelines.

    Device and dtype are chosen automatically: fp16 on CUDA, fp32 on CPU.
    If an HF_TOKEN environment variable is set, it is used to authenticate
    against the Hugging Face Hub before downloading.

    Returns:
        Tuple of (text_to_image, image_to_video, None); the trailing None
        is a placeholder for a future TTS model.

    Raises:
        Exception: any failure during download/initialization is logged
            and re-raised unchanged.
    """
    try:
        # Device and precision configuration: fp16 only makes sense on CUDA.
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        dtype = torch.float16 if device.type == "cuda" else torch.float32

        # Optional authentication (needed for gated/private repos).
        hf_token = os.getenv("HF_TOKEN")
        if hf_token:
            login(token=hf_token)

        # Text-to-image model
        logger.info(f"Loading text-to-image model on {device} with {dtype}")
        text_to_image = StableDiffusionPipeline.from_pretrained(
            "runwayml/stable-diffusion-v1-5",
            torch_dtype=dtype,
            use_safetensors=True,
            safety_checker=None  # NOTE(review): NSFW filter disabled — confirm intended
        )
        text_to_image = text_to_image.to(device)
        _enable_attention_optimization(text_to_image, device)

        # Image-to-video model
        logger.info(f"Loading video model on {device} with {dtype}")
        image_to_video = DiffusionPipeline.from_pretrained(
            "cerspense/zeroscope_v2_576w",
            torch_dtype=dtype
        )
        image_to_video = image_to_video.to(device)
        _enable_attention_optimization(image_to_video, device)

        return text_to_image, image_to_video, None

    except Exception as e:
        logger.error(f"Model load failed: {str(e)}")
        raise


def _enable_attention_optimization(pipeline, device) -> None:
    """Enable the best available attention optimization for *pipeline*.

    On CUDA, try xformers memory-efficient attention; xformers is an
    optional dependency, so if enabling it raises (package missing or
    incompatible) fall back to attention slicing instead of aborting the
    whole model load. On CPU, attention slicing is the only option.
    """
    if device.type == "cuda":
        try:
            pipeline.enable_xformers_memory_efficient_attention()
        except Exception:
            # FIX: previously an unguarded call — a missing xformers install
            # crashed load_models(). Degrade gracefully instead.
            logger.warning("xformers unavailable; falling back to attention slicing")
            pipeline.enable_attention_slicing()
    else:
        pipeline.enable_attention_slicing()