SameerArz committed
Commit f7c50b9 · verified · Parent: 51f5fc2

Update app.py

Files changed (1): app.py (+14, -25)
app.py CHANGED
@@ -4,7 +4,7 @@ from groq import Groq
 import os
 import json
 import torch
-from diffusers import AutoPipelineForText2Image
+from diffusers import StableDiffusionPipeline
 
 # Get Groq API key from environment variables (set in Space settings)
 GROQ_API_KEY = os.getenv("GROQ_API_KEY")
@@ -14,23 +14,14 @@ if not GROQ_API_KEY:
 # Initialize Groq client
 client = Groq(api_key=GROQ_API_KEY)
 
-# Set up device
+# Set up device and load Stable Diffusion model
 device = "cuda" if torch.cuda.is_available() else "cpu"
-
-# Load two open-access models
-realism_pipeline = AutoPipelineForText2Image.from_pretrained(
-    "runwayml/stable-diffusion-v1-5",
-    torch_dtype=torch.float16,  # Faster on GPU
-    safety_checker=None,  # Disable NSFW filter (optional; comment out if unwanted)
-).to(device)
-realism_pipeline.enable_model_cpu_offload()  # Optimize memory
-
-photo_pipeline = AutoPipelineForText2Image.from_pretrained(
-    "dreamlike-art/dreamlike-photoreal-2.0",
+pipe = StableDiffusionPipeline.from_pretrained(
+    "sd-legacy/stable-diffusion-v1-5",
     torch_dtype=torch.float16,
-    safety_checker=None,  # Disable NSFW filter (optional)
+    safety_checker=None,  # Optional: disable NSFW filter (comment out if unwanted)
 ).to(device)
-photo_pipeline.enable_model_cpu_offload()  # Optimize memory
+pipe.enable_model_cpu_offload()  # Optimize memory usage
 
 # Function to generate tutor output (lesson, question, feedback)
 def generate_tutor_output(subject, difficulty, student_input):
@@ -59,21 +50,19 @@ def generate_tutor_output(subject, difficulty, student_input):
     )
     return completion.choices[0].message.content
 
-# Function to generate images
+# Function to generate images using sd-legacy/stable-diffusion-v1-5
 def generate_images(text, selected_model):
     if selected_model == "Stable Diffusion (Realism)":
-        pipeline = realism_pipeline
         prompt_prefix = "realistic, detailed, vivid colors, "
-    elif selected_model == "Dreamlike Photoreal (Portraits)":
-        pipeline = photo_pipeline
-        prompt_prefix = "photorealistic portrait, cinematic lighting, "
+    elif selected_model == "Stable Diffusion (Portraits)":
+        prompt_prefix = "portrait, photorealistic, cinematic lighting, "
     else:
         return ["Invalid model selection."] * 3
 
     results = []
     for i in range(3):
         modified_text = f"{prompt_prefix}{text} variation {i+1}, high quality"
-        image = pipeline(modified_text, num_inference_steps=25).images[0]
+        image = pipe(modified_text, num_inference_steps=25).images[0]
         results.append(image)
     return results
 
@@ -108,8 +97,8 @@ with gr.Blocks(title="AI Tutor with Visuals") as demo:
     with gr.Row():
         with gr.Column(scale=2):
             model_selector = gr.Radio(
-                ["Stable Diffusion (Realism)", "Dreamlike Photoreal (Portraits)"],
-                label="Select Image Generation Model",
+                ["Stable Diffusion (Realism)", "Stable Diffusion (Portraits)"],
+                label="Select Image Style",
                 value="Stable Diffusion (Realism)"
             )
             submit_button_visual = gr.Button("Generate Visuals", variant="primary")
@@ -122,10 +111,10 @@ with gr.Blocks(title="AI Tutor with Visuals") as demo:
     gr.Markdown("""
     ### How to Use
     1. **Text Section**: Select a subject and difficulty, type your query, and click 'Generate Lesson & Question' to get your personalized lesson, question, and feedback.
-    2. **Visual Section**: Select a model and click 'Generate Visuals' to see 3 free, open-source image variations based on your input.
+    2. **Visual Section**: Select an image style and click 'Generate Visuals' to see 3 variations using Stable Diffusion v1.5 (free & open-source).
     3. Review the AI-generated content to enhance your learning!
 
-    *Note*: These use free, open-access models (Stable Diffusion & Dreamlike Photoreal). GPU recommended for speed.
+    *Note*: Powered by sd-legacy/stable-diffusion-v1-5. GPU recommended for faster results.
     """)
 
     def process_output_text(subject, difficulty, student_input):
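
One caveat for CPU-only Spaces: `torch.float16` weights and `enable_model_cpu_offload()` generally assume a CUDA device. Below is a minimal sketch of a more defensive loading path using the same model id and `diffusers` calls as the commit; the dtype fallback and the `preview.png` filename are illustrative assumptions, not part of this change.

```python
import torch
from diffusers import StableDiffusionPipeline

# Pick a dtype based on available hardware; float16 is only a safe choice on GPU (assumption, not in this commit).
device = "cuda" if torch.cuda.is_available() else "cpu"
dtype = torch.float16 if device == "cuda" else torch.float32

pipe = StableDiffusionPipeline.from_pretrained(
    "sd-legacy/stable-diffusion-v1-5",
    torch_dtype=dtype,
    safety_checker=None,  # Optional: disable NSFW filter
)

if device == "cuda":
    # Model CPU offload relies on accelerate and a CUDA device to offload to/from.
    pipe.enable_model_cpu_offload()
else:
    pipe = pipe.to(device)

prompt = "realistic, detailed, vivid colors, the water cycle, high quality"
image = pipe(prompt, num_inference_steps=25).images[0]
image.save("preview.png")  # Illustrative output path
```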