Update app.py
Browse files
app.py
CHANGED
@@ -2,13 +2,11 @@ import torch
|
|
2 |
import os
|
3 |
import requests
|
4 |
import logging
|
5 |
-
import gc
|
6 |
from pathlib import Path
|
7 |
-
|
8 |
from transformers import CLIPModel, CLIPProcessor, AutoTokenizer, MarianMTModel, MarianTokenizer
|
9 |
from diffusers import StableDiffusionPipeline, DPMSolverMultistepScheduler
|
10 |
import gradio as gr
|
11 |
-
from typing import List, Tuple, Optional,
|
12 |
from dataclasses import dataclass
|
13 |
|
14 |
# Configure logging
|
@@ -25,7 +23,7 @@ def download_model(model_url: str, model_path: str):
|
|
25 |
logger.info(f"Downloading model from {model_url}...")
|
26 |
response = requests.get(model_url, stream=True)
|
27 |
response.raise_for_status()
|
28 |
-
|
29 |
total_size = int(response.headers.get('content-length', 0))
|
30 |
block_size = 1024 * 1024 # 1 MB chunks
|
31 |
downloaded_size = 0
|
@@ -78,7 +76,6 @@ class EnhancedBanglaSDGenerator:
|
|
78 |
def _load_banglaclip_model(self):
|
79 |
"""Load BanglaCLIP model from Hugging Face directly"""
|
80 |
try:
|
81 |
-
# Use the Hugging Face model directly instead of local weight file
|
82 |
model = CLIPModel.from_pretrained("Mansuba/BanglaCLIP13")
|
83 |
return model.to(self.device)
|
84 |
except Exception as e:
|
@@ -113,7 +110,9 @@ class EnhancedBanglaSDGenerator:
|
|
113 |
logger.error(f"Error initializing models: {str(e)}")
|
114 |
raise RuntimeError(f"Failed to initialize models: {str(e)}")
|
115 |
|
116 |
-
|
|
|
|
|
117 |
|
118 |
def create_gradio_interface():
|
119 |
"""Create and configure the Gradio interface."""
|
@@ -156,8 +155,20 @@ def create_gradio_interface():
|
|
156 |
cleanup_generator()
|
157 |
return None, f"ছবি তৈরি ব্যর্থ হয়েছে: {str(e)}"
|
158 |
|
159 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
160 |
|
161 |
if __name__ == "__main__":
|
162 |
demo = create_gradio_interface()
|
163 |
-
demo.queue().launch(share=True, debug=True)
|
|
|
2 |
import os
|
3 |
import requests
|
4 |
import logging
|
|
|
5 |
from pathlib import Path
|
|
|
6 |
from transformers import CLIPModel, CLIPProcessor, AutoTokenizer, MarianMTModel, MarianTokenizer
|
7 |
from diffusers import StableDiffusionPipeline, DPMSolverMultistepScheduler
|
8 |
import gradio as gr
|
9 |
+
from typing import List, Tuple, Optional, Any
|
10 |
from dataclasses import dataclass
|
11 |
|
12 |
# Configure logging
|
|
|
23 |
logger.info(f"Downloading model from {model_url}...")
|
24 |
response = requests.get(model_url, stream=True)
|
25 |
response.raise_for_status()
|
26 |
+
|
27 |
total_size = int(response.headers.get('content-length', 0))
|
28 |
block_size = 1024 * 1024 # 1 MB chunks
|
29 |
downloaded_size = 0
|
|
|
76 |
def _load_banglaclip_model(self):
|
77 |
"""Load BanglaCLIP model from Hugging Face directly"""
|
78 |
try:
|
|
|
79 |
model = CLIPModel.from_pretrained("Mansuba/BanglaCLIP13")
|
80 |
return model.to(self.device)
|
81 |
except Exception as e:
|
|
|
110 |
logger.error(f"Error initializing models: {str(e)}")
|
111 |
raise RuntimeError(f"Failed to initialize models: {str(e)}")
|
112 |
|
113 |
+
def _initialize_stable_diffusion(self):
|
114 |
+
"""Load and initialize Stable Diffusion pipeline"""
|
115 |
+
pass # Your existing code for initializing Stable Diffusion
|
116 |
|
117 |
def create_gradio_interface():
|
118 |
"""Create and configure the Gradio interface."""
|
|
|
155 |
cleanup_generator()
|
156 |
return None, f"ছবি তৈরি ব্যর্থ হয়েছে: {str(e)}"
|
157 |
|
158 |
+
with gr.Blocks() as demo:
|
159 |
+
text_input = gr.Textbox(label="Text", placeholder="Enter your prompt here...")
|
160 |
+
num_images_input = gr.Slider(minimum=1, maximum=5, value=1, label="Number of Images")
|
161 |
+
steps_input = gr.Slider(minimum=1, maximum=100, value=50, label="Steps")
|
162 |
+
guidance_scale_input = gr.Slider(minimum=1, maximum=20, value=7.5, label="Guidance Scale")
|
163 |
+
seed_input = gr.Number(label="Seed", optional=True)
|
164 |
+
|
165 |
+
output_images = gr.Gallery(label="Generated Images")
|
166 |
+
|
167 |
+
generate_button = gr.Button("Generate Images")
|
168 |
+
generate_button.click(generate_images, inputs=[text_input, num_images_input, steps_input, guidance_scale_input, seed_input], outputs=[output_images])
|
169 |
+
|
170 |
+
return demo
|
171 |
|
172 |
if __name__ == "__main__":
|
173 |
demo = create_gradio_interface()
|
174 |
+
demo.queue().launch(share=True, debug=True)
|