Mansuba committed on
Commit 5ae1545 · verified · 1 Parent(s): 6024488

Update app.py

Files changed (1)
  1. app.py +38 -76
app.py CHANGED
@@ -10,7 +10,6 @@ import json
 import logging
 from dataclasses import dataclass
 import gc
-import os
 
 # Configure logging
 logging.basicConfig(
@@ -30,10 +29,6 @@ class ModelCache:
     def __init__(self, cache_dir: Path):
         self.cache_dir = cache_dir
         self.cache_dir.mkdir(parents=True, exist_ok=True)
-
-        # Set environment variables for better memory management
-        os.environ['PYTORCH_CUDA_ALLOC_CONF'] = 'max_split_size_mb:512'
-        os.environ['CUDA_LAUNCH_BLOCKING'] = '1'
 
     def load_model(self, model_id: str, load_func: callable, cache_name: str) -> Any:
         try:
@@ -52,32 +47,18 @@ class EnhancedBanglaSDGenerator:
     ):
         self.device = device or torch.device("cuda" if torch.cuda.is_available() else "cpu")
         logger.info(f"Using device: {self.device}")
-
-        # Set memory split for VRAM usage on CPU
-        self.memory_split = 0.5  # Use 50% of available VRAM
-        self.setup_memory_management()
-
+
         self.cache = ModelCache(Path(cache_dir))
         self._initialize_models(banglaclip_weights_path)
         self._load_context_data()
 
-    def setup_memory_management(self):
-        """Setup optimal memory management for CPU and VRAM"""
-        if torch.cuda.is_available():
-            total_memory = torch.cuda.get_device_properties(0).total_memory
-            torch.cuda.set_per_process_memory_fraction(self.memory_split)
-
-        # Optimize CPU memory
-        torch.set_num_threads(min(8, os.cpu_count() or 4))
-        torch.set_num_interop_threads(min(8, os.cpu_count() or 4))
-
     def _initialize_models(self, banglaclip_weights_path: str):
         try:
             # Initialize translation models
             self.bn2en_model_name = "Helsinki-NLP/opus-mt-bn-en"
             self.translator = self.cache.load_model(
                 self.bn2en_model_name,
-                lambda x: MarianMTModel.from_pretrained(x, low_cpu_mem_usage=True),
+                MarianMTModel.from_pretrained,
                 "translator"
             ).to(self.device)
             self.trans_tokenizer = MarianTokenizer.from_pretrained(self.bn2en_model_name)
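The loader swap above works because `ModelCache.load_model(model_id, load_func, cache_name)` accepts any callable that maps a model id to a model, so the bare `MarianMTModel.from_pretrained` replaces the old lambda (dropping only the `low_cpu_mem_usage` flag). The cache body is not part of this diff, so the following is only a minimal sketch of the assumed memoization pattern:

    from pathlib import Path
    from typing import Any, Callable, Dict

    class ModelCacheSketch:
        """Illustrative stand-in for ModelCache: memoize models by cache name."""
        def __init__(self, cache_dir: Path):
            self.cache_dir = cache_dir
            self.cache_dir.mkdir(parents=True, exist_ok=True)
            self._models: Dict[str, Any] = {}

        def load_model(self, model_id: str, load_func: Callable[[str], Any], cache_name: str) -> Any:
            # Any id -> model callable works here: a bare from_pretrained
            # or a lambda that adds keyword arguments.
            if cache_name not in self._models:
                self._models[cache_name] = load_func(model_id)
            return self._models[cache_name]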
@@ -98,52 +79,34 @@ class EnhancedBanglaSDGenerator:
 
     def _initialize_stable_diffusion(self):
         """Initialize Stable Diffusion pipeline with optimized settings."""
-        try:
-            self.pipe = self.cache.load_model(
-                "runwayml/stable-diffusion-v1-5",
-                lambda model_id: StableDiffusionPipeline.from_pretrained(
-                    model_id,
-                    torch_dtype=torch.float32,
-                    safety_checker=None,
-                    use_safetensors=True,
-                    low_cpu_mem_usage=True,
-                ),
-                "stable_diffusion"
-            )
+        self.pipe = self.cache.load_model(
+            "runwayml/stable-diffusion-v1-5",
+            lambda model_id: StableDiffusionPipeline.from_pretrained(
+                model_id,
+                torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
+                safety_checker=None
+            ),
+            "stable_diffusion"
+        )
 
-            # Optimize scheduler for speed
-            self.pipe.scheduler = DPMSolverMultistepScheduler.from_config(
-                self.pipe.scheduler.config,
-                use_karras_sigmas=True,
-                algorithm_type="dpmsolver++",
-                solver_order=2
-            )
+        self.pipe.scheduler = DPMSolverMultistepScheduler.from_config(
+            self.pipe.scheduler.config,
+            use_karras_sigmas=True,
+            algorithm_type="dpmsolver++"
+        )
+        self.pipe = self.pipe.to(self.device)
 
-            # Memory optimizations
-            self.pipe.enable_attention_slicing(slice_size=1)
-            self.pipe.enable_vae_slicing()
+        # Memory optimization
+        self.pipe.enable_attention_slicing()
+        if torch.cuda.is_available():
             self.pipe.enable_sequential_cpu_offload()
-
-            # VRAM optimization
-            if torch.cuda.is_available():
-                torch.cuda.empty_cache()
-                self.pipe.enable_model_cpu_offload()
-
-            self.pipe = self.pipe.to(self.device)
-
-        except Exception as e:
-            logger.error(f"Error initializing Stable Diffusion: {str(e)}")
-            raise
 
     def _load_banglaclip_model(self, weights_path: str) -> CLIPModel:
         try:
             if not Path(weights_path).exists():
                 raise FileNotFoundError(f"BanglaCLIP weights not found at {weights_path}")
 
-            clip_model = CLIPModel.from_pretrained(
-                self.clip_model_name,
-                low_cpu_mem_usage=True
-            )
+            clip_model = CLIPModel.from_pretrained(self.clip_model_name)
             state_dict = torch.load(weights_path, map_location=self.device)
 
             cleaned_state_dict = {
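For context, the rewritten `_initialize_stable_diffusion` boils down to the following standalone recipe. This is a sketch against the public diffusers API; exact offload behavior varies by diffusers version, and note that combining `.to(device)` with `enable_sequential_cpu_offload()` (as the hunk above does) can conflict in some releases, since offload manages device placement itself.

    import torch
    from diffusers import StableDiffusionPipeline, DPMSolverMultistepScheduler

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # fp16 roughly halves VRAM use on GPU; CPU inference needs fp32
    pipe = StableDiffusionPipeline.from_pretrained(
        "runwayml/stable-diffusion-v1-5",
        torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
        safety_checker=None,
    )

    # DPM-Solver++ with Karras sigmas converges well in roughly 20-30 steps
    pipe.scheduler = DPMSolverMultistepScheduler.from_config(
        pipe.scheduler.config,
        use_karras_sigmas=True,
        algorithm_type="dpmsolver++",
    )

    pipe = pipe.to(device)
    pipe.enable_attention_slicing()  # lower peak memory at a small speed cost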
@@ -178,12 +141,22 @@ class EnhancedBanglaSDGenerator:
         inputs = self.trans_tokenizer(bangla_text, return_tensors="pt", padding=True)
         inputs = {k: v.to(self.device) for k, v in inputs.items()}
 
-        with torch.no_grad(), torch.cpu.amp.autocast():
+        with torch.no_grad():
            outputs = self.translator.generate(**inputs)
 
         translated = self.trans_tokenizer.decode(outputs[0], skip_special_tokens=True)
         return translated
 
+    def _get_text_embedding(self, text: str):
+        """Get text embedding from BanglaCLIP model."""
+        inputs = self.tokenizer(text, return_tensors="pt", padding=True, truncation=True)
+        inputs = {k: v.to(self.device) for k, v in inputs.items()}
+
+        with torch.no_grad():
+            outputs = self.banglaclip_model.get_text_features(**inputs)
+
+        return outputs
+
     def generate_image(
         self,
         bangla_text: str,
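The newly added `_get_text_embedding` uses the standard `CLIPModel.get_text_features` API from transformers. A self-contained equivalent, where the checkpoint name is only a placeholder (the app loads its own BanglaCLIP weights instead):

    import torch
    from transformers import CLIPModel, CLIPTokenizer

    # Placeholder checkpoint, not the BanglaCLIP weights from this repo
    tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-base-patch32")
    model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32")

    inputs = tokenizer("a red flower", return_tensors="pt", padding=True, truncation=True)
    with torch.no_grad():
        # Projected text embedding, shape (batch, projection_dim)
        text_features = model.get_text_features(**inputs)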
@@ -198,29 +171,18 @@ class EnhancedBanglaSDGenerator:
             if config.seed is not None:
                 torch.manual_seed(config.seed)
 
-            # Clear memory before generation
-            gc.collect()
-            torch.cuda.empty_cache() if torch.cuda.is_available() else None
-
             enhanced_prompt = self._enhance_prompt(bangla_text)
             negative_prompt = self._get_negative_prompt()
 
-            # Use mixed precision for faster generation
-            with torch.inference_mode(), torch.cpu.amp.autocast():
+            with torch.autocast(self.device.type):
                 result = self.pipe(
                     prompt=enhanced_prompt,
                     negative_prompt=negative_prompt,
                     num_images_per_prompt=config.num_images,
                     num_inference_steps=config.num_inference_steps,
-                    guidance_scale=config.guidance_scale,
-                    use_memory_efficient_attention=True,
-                    use_memory_efficient_cross_attention=True
+                    guidance_scale=config.guidance_scale
                 )
 
-            # Clear memory after generation
-            gc.collect()
-            torch.cuda.empty_cache() if torch.cuda.is_available() else None
-
             return result.images, enhanced_prompt
 
         except Exception as e:
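The autocast change above replaces the CPU-only `torch.cpu.amp.autocast()` (a no-op for CUDA kernels) with the device-agnostic `torch.autocast`, which takes the device-type string. A minimal sketch of the pattern:

    import torch

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    layer = torch.nn.Linear(16, 4).to(device)
    x = torch.randn(8, 16, device=device)

    # One code path for both backends: float16 kernels on CUDA,
    # bfloat16 on CPU, falling back to float32 where unsupported.
    with torch.inference_mode(), torch.autocast(device_type=device.type):
        y = layer(x)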
@@ -231,10 +193,12 @@ class EnhancedBanglaSDGenerator:
         """Enhance prompt with context and style information."""
         translated_text = self._translate_text(bangla_text)
 
+        # Gather contexts
         contexts = []
         contexts.extend(context for loc, context in self.location_contexts.items() if loc in bangla_text)
         contexts.extend(context for scene, context in self.scene_contexts.items() if scene in bangla_text)
 
+        # Add photo style
         photo_style = [
             "professional photography",
             "high resolution",
@@ -244,6 +208,7 @@ class EnhancedBanglaSDGenerator:
             "beautiful composition"
         ]
 
+        # Combine all parts
         all_parts = [translated_text] + contexts + photo_style
         return ", ".join(dict.fromkeys(all_parts))
249
 
@@ -352,9 +317,6 @@ def create_gradio_interface():
     return demo
 
 if __name__ == "__main__":
-    # Set environment variables for better performance
-    os.environ['PYTORCH_CUDA_ALLOC_CONF'] = 'max_split_size_mb:512'
-    os.environ['CUDA_LAUNCH_BLOCKING'] = '1'
-
     demo = create_gradio_interface()
+    # Fixed queue configuration for newer Gradio versions
     demo.queue().launch(share=True)
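The no-argument `demo.queue()` matches current Gradio releases (older 3.x code often passed `concurrency_count`, which Gradio 4 removed). A minimal runnable sketch, since `create_gradio_interface` itself is outside this hunk and `greet` here is only a hypothetical stand-in:

    import gradio as gr

    # Hypothetical stand-in for the app built by create_gradio_interface()
    def greet(name: str) -> str:
        return f"Hello, {name}!"

    demo = gr.Interface(fn=greet, inputs="text", outputs="text")
    demo.queue()               # enables the request queue with default settings
    demo.launch(share=True)    # share=True opens a temporary public URL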
 