codermert commited on
Commit
562b89a
·
verified ·
1 Parent(s): 34466eb

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +35 -19
app.py CHANGED
@@ -2,13 +2,19 @@ import gradio as gr
2
  import torch
3
  import os
4
  from diffusers import AutoPipelineForText2Image
5
- from huggingface_hub import snapshot_download
6
  import logging
 
7
 
8
  # Logging ayarları
9
  logging.basicConfig(level=logging.INFO)
10
  logger = logging.getLogger(__name__)
11
 
 
 
 
 
 
12
  class ModelHandler:
13
  def __init__(self):
14
  self.pipeline = None
@@ -20,37 +26,45 @@ class ModelHandler:
20
  if self.pipeline is not None:
21
  return "Model zaten yüklü."
22
 
 
 
 
 
 
23
  progress(0, desc="Base model indiriliyor...")
24
  # Base modeli indir
25
- base_model_path = snapshot_download(
26
- repo_id="black-forest-labs/FLUX.1-dev",
27
- local_dir="./models/base_model",
28
- ignore_patterns=["*.bin", "*.onnx"] if os.path.exists("./models/base_model") else None,
29
- token=os.getenv("HF_TOKEN")
30
- )
 
31
 
32
  progress(0.5, desc="LoRA modeli indiriliyor...")
33
  # LoRA modelini indir
34
- lora_model_path = snapshot_download(
35
- repo_id="codermert/ezelll_flux",
36
- local_dir="./models/lora_model",
37
- ignore_patterns=["*.bin", "*.onnx"] if os.path.exists("./models/lora_model") else None,
38
- token=os.getenv("HF_TOKEN")
39
- )
 
40
 
41
  progress(0.7, desc="Pipeline oluşturuluyor...")
42
  # Pipeline'ı oluştur
43
  self.pipeline = AutoPipelineForText2Image.from_pretrained(
44
- base_model_path,
45
  torch_dtype=self.dtype,
46
- use_safetensors=True
 
47
  ).to(self.device)
48
 
49
  progress(0.9, desc="LoRA yükleniyor...")
50
  # LoRA'yı yükle
51
- lora_path = os.path.join(lora_model_path, "lora.safetensors")
52
- if os.path.exists(lora_path):
53
- self.pipeline.load_lora_weights(lora_path)
54
  else:
55
  return "LoRA dosyası bulunamadı!"
56
 
@@ -74,7 +88,9 @@ class ModelHandler:
74
  image = self.pipeline(
75
  prompt,
76
  num_inference_steps=30,
77
- guidance_scale=7.5
 
 
78
  ).images[0]
79
 
80
  progress(1.0, desc="Tamamlandı!")
 
2
  import torch
3
  import os
4
  from diffusers import AutoPipelineForText2Image
5
+ from huggingface_hub import hf_hub_download
6
  import logging
7
+ from pathlib import Path
8
 
9
  # Logging ayarları
10
  logging.basicConfig(level=logging.INFO)
11
  logger = logging.getLogger(__name__)
12
 
13
+ # Model dosyaları için sabit yollar
14
+ MODEL_CACHE = Path("./model_cache")
15
+ BASE_MODEL_PATH = MODEL_CACHE / "base_model"
16
+ LORA_MODEL_PATH = MODEL_CACHE / "lora_model"
17
+
18
  class ModelHandler:
19
  def __init__(self):
20
  self.pipeline = None
 
26
  if self.pipeline is not None:
27
  return "Model zaten yüklü."
28
 
29
+ # Model cache dizinlerini oluştur
30
+ MODEL_CACHE.mkdir(exist_ok=True)
31
+ BASE_MODEL_PATH.mkdir(exist_ok=True)
32
+ LORA_MODEL_PATH.mkdir(exist_ok=True)
33
+
34
  progress(0, desc="Base model indiriliyor...")
35
  # Base modeli indir
36
+ if not (BASE_MODEL_PATH / "model_index.json").exists():
37
+ hf_hub_download(
38
+ repo_id="black-forest-labs/FLUX.1-dev",
39
+ filename="model_index.json",
40
+ local_dir=BASE_MODEL_PATH,
41
+ token=os.getenv("HF_TOKEN")
42
+ )
43
 
44
  progress(0.5, desc="LoRA modeli indiriliyor...")
45
  # LoRA modelini indir
46
+ if not (LORA_MODEL_PATH / "lora.safetensors").exists():
47
+ hf_hub_download(
48
+ repo_id="codermert/ezelll_flux",
49
+ filename="lora.safetensors",
50
+ local_dir=LORA_MODEL_PATH,
51
+ token=os.getenv("HF_TOKEN")
52
+ )
53
 
54
  progress(0.7, desc="Pipeline oluşturuluyor...")
55
  # Pipeline'ı oluştur
56
  self.pipeline = AutoPipelineForText2Image.from_pretrained(
57
+ str(BASE_MODEL_PATH),
58
  torch_dtype=self.dtype,
59
+ use_safetensors=True,
60
+ cache_dir=MODEL_CACHE
61
  ).to(self.device)
62
 
63
  progress(0.9, desc="LoRA yükleniyor...")
64
  # LoRA'yı yükle
65
+ lora_path = LORA_MODEL_PATH / "lora.safetensors"
66
+ if lora_path.exists():
67
+ self.pipeline.load_lora_weights(str(lora_path))
68
  else:
69
  return "LoRA dosyası bulunamadı!"
70
 
 
88
  image = self.pipeline(
89
  prompt,
90
  num_inference_steps=30,
91
+ guidance_scale=7.5,
92
+ width=512,
93
+ height=512
94
  ).images[0]
95
 
96
  progress(1.0, desc="Tamamlandı!")