Mansuba committed
Commit 6024488 · verified · Parent(s): d47dd8d

Update app.py

Files changed (1): app.py (+73 −54)
app.py CHANGED
@@ -1,5 +1,4 @@
 
-
 import torch
 from transformers import CLIPModel, CLIPProcessor, AutoTokenizer, MarianMTModel, MarianTokenizer
 from diffusers import StableDiffusionPipeline, DPMSolverMultistepScheduler
@@ -11,6 +10,7 @@ import json
 import logging
 from dataclasses import dataclass
 import gc
+import os
 
 # Configure logging
 logging.basicConfig(
@@ -30,6 +30,10 @@ class ModelCache:
     def __init__(self, cache_dir: Path):
         self.cache_dir = cache_dir
         self.cache_dir.mkdir(parents=True, exist_ok=True)
+
+        # Set environment variables for better memory management
+        os.environ['PYTORCH_CUDA_ALLOC_CONF'] = 'max_split_size_mb:512'
+        os.environ['CUDA_LAUNCH_BLOCKING'] = '1'
 
     def load_model(self, model_id: str, load_func: callable, cache_name: str) -> Any:
         try:
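A note on the two new environment variables: PYTORCH_CUDA_ALLOC_CONF is read by the CUDA caching allocator no later than its first allocation, so it only takes effect if it is set before any CUDA tensor is created, and CUDA_LAUNCH_BLOCKING=1 forces synchronous kernel launches, which is primarily a debugging aid and usually costs throughput. A minimal standalone sketch (not part of the commit) of an ordering that guarantees the allocator setting sticks:

import os

# Set before importing torch so the caching allocator sees it at first use.
os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "max_split_size_mb:512"

import torch  # noqa: E402  (deliberately imported after the env var)

if torch.cuda.is_available():
    _ = torch.zeros(1, device="cuda")     # first allocation uses the tuned allocator
    print(torch.cuda.memory_allocated())  # bytes currently held by tensors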
@@ -48,18 +52,32 @@ class EnhancedBanglaSDGenerator:
     ):
         self.device = device or torch.device("cuda" if torch.cuda.is_available() else "cpu")
         logger.info(f"Using device: {self.device}")
-
+
+        # Fraction of GPU memory this process may use (applies only when CUDA is available)
+        self.memory_split = 0.5  # Use 50% of available VRAM
+        self.setup_memory_management()
+
         self.cache = ModelCache(Path(cache_dir))
         self._initialize_models(banglaclip_weights_path)
         self._load_context_data()
 
+    def setup_memory_management(self):
+        """Set up memory management for CPU and GPU."""
+        if torch.cuda.is_available():
+            total_memory = torch.cuda.get_device_properties(0).total_memory
+            torch.cuda.set_per_process_memory_fraction(self.memory_split)
+
+        # Cap CPU thread pools
+        torch.set_num_threads(min(8, os.cpu_count() or 4))
+        torch.set_num_interop_threads(min(8, os.cpu_count() or 4))
+
     def _initialize_models(self, banglaclip_weights_path: str):
         try:
             # Initialize translation models
             self.bn2en_model_name = "Helsinki-NLP/opus-mt-bn-en"
             self.translator = self.cache.load_model(
                 self.bn2en_model_name,
-                MarianMTModel.from_pretrained,
+                lambda x: MarianMTModel.from_pretrained(x, low_cpu_mem_usage=True),
                 "translator"
             ).to(self.device)
             self.trans_tokenizer = MarianTokenizer.from_pretrained(self.bn2en_model_name)
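The loader change above wraps from_pretrained in a lambda so that low_cpu_mem_usage=True travels with the factory; load_model then only has to call load_func(model_id). The body of load_model is truncated in this diff, so the skeleton below is an assumption about its contract, not the repository's implementation:

from pathlib import Path
from typing import Any, Callable

class ModelCacheSketch:
    """Hypothetical stand-in for the repo's ModelCache."""

    def __init__(self, cache_dir: Path):
        self.cache_dir = cache_dir
        self.cache_dir.mkdir(parents=True, exist_ok=True)

    def load_model(self, model_id: str, load_func: Callable[[str], Any], cache_name: str) -> Any:
        # The real class presumably consults cache_dir/cache_name first; the
        # only contract the diff relies on is that load_func(model_id) is called.
        return load_func(model_id)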
@@ -71,7 +89,7 @@ class EnhancedBanglaSDGenerator:
             self.processor = CLIPProcessor.from_pretrained(self.clip_model_name)
             self.tokenizer = AutoTokenizer.from_pretrained(self.bangla_text_model)
 
-            # Initialize Stable Diffusion with optimizations
+            # Initialize Stable Diffusion
             self._initialize_stable_diffusion()
 
         except Exception as e:
@@ -79,45 +97,53 @@ class EnhancedBanglaSDGenerator:
             raise RuntimeError(f"Failed to initialize models: {str(e)}")
 
     def _initialize_stable_diffusion(self):
-        """Initialize Stable Diffusion pipeline with CPU performance optimizations."""
-        self.pipe = self.cache.load_model(
-            "runwayml/stable-diffusion-v1-5",
-            lambda model_id: StableDiffusionPipeline.from_pretrained(
-                model_id,
-                torch_dtype=torch.float32,
-                safety_checker=None,
-                use_safetensors=True,
-                use_memory_efficient_attention=True,
-                local_files_only=True
-            ),
-            "stable_diffusion"
-        )
+        """Initialize Stable Diffusion pipeline with optimized settings."""
+        try:
+            self.pipe = self.cache.load_model(
+                "runwayml/stable-diffusion-v1-5",
+                lambda model_id: StableDiffusionPipeline.from_pretrained(
+                    model_id,
+                    torch_dtype=torch.float32,
+                    safety_checker=None,
+                    use_safetensors=True,
+                    low_cpu_mem_usage=True,
+                ),
+                "stable_diffusion"
+            )
 
-        # Optimize scheduler
-        self.pipe.scheduler = DPMSolverMultistepScheduler.from_config(
-            self.pipe.scheduler.config,
-            use_karras_sigmas=True,
-            algorithm_type="dpmsolver++"
-        )
+            # Optimize scheduler for speed
+            self.pipe.scheduler = DPMSolverMultistepScheduler.from_config(
+                self.pipe.scheduler.config,
+                use_karras_sigmas=True,
+                algorithm_type="dpmsolver++",
+                solver_order=2
+            )
 
-        # CPU optimizations
-        self.pipe.enable_attention_slicing(slice_size=1)
-        self.pipe.enable_vae_slicing()
-        self.pipe.enable_sequential_cpu_offload()
-
-        # Component-level optimizations
-        for component in [self.pipe.text_encoder, self.pipe.vae, self.pipe.unet]:
-            if hasattr(component, 'enable_model_cpu_offload'):
-                component.enable_model_cpu_offload()
-
-        self.pipe = self.pipe.to(self.device)
+            # Memory optimizations
+            self.pipe.enable_attention_slicing(slice_size=1)
+            self.pipe.enable_vae_slicing()
+            self.pipe.enable_sequential_cpu_offload()
+
+            # VRAM optimization
+            if torch.cuda.is_available():
+                torch.cuda.empty_cache()
+                self.pipe.enable_model_cpu_offload()
+
+            self.pipe = self.pipe.to(self.device)
+
+        except Exception as e:
+            logger.error(f"Error initializing Stable Diffusion: {str(e)}")
+            raise
 
     def _load_banglaclip_model(self, weights_path: str) -> CLIPModel:
         try:
             if not Path(weights_path).exists():
                 raise FileNotFoundError(f"BanglaCLIP weights not found at {weights_path}")
 
-            clip_model = CLIPModel.from_pretrained(self.clip_model_name)
+            clip_model = CLIPModel.from_pretrained(
+                self.clip_model_name,
+                low_cpu_mem_usage=True
+            )
             state_dict = torch.load(weights_path, map_location=self.device)
 
             cleaned_state_dict = {
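One caveat on the rewritten _initialize_stable_diffusion: in diffusers, enable_sequential_cpu_offload() and enable_model_cpu_offload() install hooks that manage device placement themselves, so stacking them and then calling .to(self.device) on the same pipeline can raise or silently undo the offload. A sketch of choosing a single placement strategy (an alternative reading, not the commit's code):

import torch
from diffusers import StableDiffusionPipeline, DPMSolverMultistepScheduler

pipe = StableDiffusionPipeline.from_pretrained(
    "runwayml/stable-diffusion-v1-5",
    torch_dtype=torch.float32,
    safety_checker=None,
    use_safetensors=True,
    low_cpu_mem_usage=True,
)

# Same scheduler swap as the commit: DPM-Solver++ converges in fewer steps.
pipe.scheduler = DPMSolverMultistepScheduler.from_config(
    pipe.scheduler.config,
    use_karras_sigmas=True,
    algorithm_type="dpmsolver++",
    solver_order=2,
)

# Memory savers that are safe on any device.
pipe.enable_attention_slicing(slice_size=1)
pipe.enable_vae_slicing()

if torch.cuda.is_available():
    pipe.enable_model_cpu_offload()  # hooks move submodules to the GPU on demand
else:
    pipe = pipe.to("cpu")            # plain placement; no offload hooks on CPU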
@@ -152,22 +178,12 @@ class EnhancedBanglaSDGenerator:
         inputs = self.trans_tokenizer(bangla_text, return_tensors="pt", padding=True)
         inputs = {k: v.to(self.device) for k, v in inputs.items()}
 
-        with torch.no_grad():
+        with torch.no_grad(), torch.cpu.amp.autocast():
             outputs = self.translator.generate(**inputs)
 
         translated = self.trans_tokenizer.decode(outputs[0], skip_special_tokens=True)
         return translated
 
-    def _get_text_embedding(self, text: str):
-        """Get text embedding from BanglaCLIP model."""
-        inputs = self.tokenizer(text, return_tensors="pt", padding=True, truncation=True)
-        inputs = {k: v.to(self.device) for k, v in inputs.items()}
-
-        with torch.no_grad():
-            outputs = self.banglaclip_model.get_text_features(**inputs)
-
-        return outputs
-
     def generate_image(
         self,
         bangla_text: str,
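On the new autocast context around translation: torch.cpu.amp.autocast() runs supported CPU ops in bfloat16, and recent PyTorch releases spell the same context torch.amp.autocast("cpu"). A self-contained sketch of the pattern with the same Helsinki-NLP model:

import torch
from transformers import MarianMTModel, MarianTokenizer

model_name = "Helsinki-NLP/opus-mt-bn-en"
tokenizer = MarianTokenizer.from_pretrained(model_name)
model = MarianMTModel.from_pretrained(model_name, low_cpu_mem_usage=True)

text = "আমি বাংলায় গান গাই"  # "I sing in Bangla"
inputs = tokenizer(text, return_tensors="pt", padding=True)

# bfloat16 autocast on CPU; no_grad skips autograd bookkeeping during inference.
with torch.no_grad(), torch.amp.autocast("cpu"):
    outputs = model.generate(**inputs)

print(tokenizer.decode(outputs[0], skip_special_tokens=True))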
@@ -182,16 +198,15 @@
             if config.seed is not None:
                 torch.manual_seed(config.seed)
 
-            enhanced_prompt = self._enhance_prompt(bangla_text)
-            negative_prompt = self._get_negative_prompt()
-
-            # Pre-generation optimization
-            torch.set_num_threads(max(4, torch.get_num_threads()))
+            # Clear memory before generation
             gc.collect()
             torch.cuda.empty_cache() if torch.cuda.is_available() else None
 
-            # Memory-optimized generation
-            with torch.inference_mode():
+            enhanced_prompt = self._enhance_prompt(bangla_text)
+            negative_prompt = self._get_negative_prompt()
+
+            # Use mixed precision for faster generation
+            with torch.inference_mode(), torch.cpu.amp.autocast():
                 result = self.pipe(
                     prompt=enhanced_prompt,
                     negative_prompt=negative_prompt,
@@ -202,7 +217,7 @@
                     use_memory_efficient_cross_attention=True
                 )
 
-            # Post-generation cleanup
+            # Clear memory after generation
            gc.collect()
            torch.cuda.empty_cache() if torch.cuda.is_available() else None

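The gc.collect() plus guarded torch.cuda.empty_cache() pair now brackets generation on both sides; since the identical two lines repeat, a small helper (a refactoring sketch, not in the commit) would keep the idiom in one place:

import gc
import torch

def free_memory() -> None:
    """Release unreachable Python objects and cached CUDA blocks."""
    gc.collect()
    if torch.cuda.is_available():
        torch.cuda.empty_cache()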
 
@@ -337,5 +352,9 @@ def create_gradio_interface():
     return demo
 
 if __name__ == "__main__":
+    # Set environment variables for better performance
+    os.environ['PYTORCH_CUDA_ALLOC_CONF'] = 'max_split_size_mb:512'
+    os.environ['CUDA_LAUNCH_BLOCKING'] = '1'
+
     demo = create_gradio_interface()
     demo.queue().launch(share=True)
 