trashchenkov commited on
Commit
2e24ab7
·
verified ·
1 Parent(s): 1ae054c

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +18 -2
app.py CHANGED
@@ -23,6 +23,7 @@ MAX_IMAGE_SIZE = 1024
23
 
24
  # @spaces.GPU #[uncomment to use ZeroGPU]
25
  def infer(
 
26
  prompt,
27
  negative_prompt,
28
  seed,
@@ -33,7 +34,10 @@ def infer(
33
  progress=gr.Progress(track_tqdm=True),
34
  ):
35
 
36
-
 
 
 
37
  generator = torch.Generator().manual_seed(seed)
38
 
39
  image = pipe(
@@ -61,10 +65,22 @@ css = """
61
  max-width: 640px;
62
  }
63
  """
64
-
 
 
 
 
 
 
 
65
  with gr.Blocks(css=css) as demo:
66
  with gr.Column(elem_id="col-container"):
67
  gr.Markdown(" # Text-to-Image Gradio Template")
 
 
 
 
 
68
  prompt = gr.Text(
69
  label="Prompt",
70
  show_label=False,
 
23
 
24
  # @spaces.GPU #[uncomment to use ZeroGPU]
25
  def infer(
26
+ model,
27
  prompt,
28
  negative_prompt,
29
  seed,
 
34
  progress=gr.Progress(track_tqdm=True),
35
  ):
36
 
37
+ global model_repo_id
38
+ if model != model_repo_id:
39
+ pipe = DiffusionPipeline.from_pretrained(model, torch_dtype=torch_dtype)
40
+ pipe = pipe.to(device)
41
  generator = torch.Generator().manual_seed(seed)
42
 
43
  image = pipe(
 
65
  max-width: 640px;
66
  }
67
  """
68
+ available_models = [
69
+ "CompVis/stable-diffusion-v1-4",
70
+ "stabilityai/sdxl-turbo"
71
+ "runwayml/stable-diffusion-v1-5",
72
+ "stabilityai/stable-diffusion-2-1",
73
+ "prompthero/openjourney",
74
+
75
+ ]
76
  with gr.Blocks(css=css) as demo:
77
  with gr.Column(elem_id="col-container"):
78
  gr.Markdown(" # Text-to-Image Gradio Template")
79
+ model = gr.Dropdown(
80
+ label="Model Selection",
81
+ choices=available_models,
82
+ value="CompVis/stable-diffusion-v1-4",
83
+ )
84
  prompt = gr.Text(
85
  label="Prompt",
86
  show_label=False,