SameerArz committed
Commit 5496590 · verified · 1 Parent(s): 0dd22f6

Update app.py

Files changed (1): app.py (+31, −22)
app.py CHANGED
@@ -6,26 +6,31 @@ import json
 import torch
 from diffusers import AutoPipelineForText2Image
 
-# Get API keys from environment variables (set in Space settings)
+# Get Groq API key from environment variables (set in Space settings)
 GROQ_API_KEY = os.getenv("GROQ_API_KEY")
-HF_TOKEN = os.getenv("HF_TOKEN")
-
-# Check if keys are provided
-if not GROQ_API_KEY or not HF_TOKEN:
-    raise ValueError("Please set GROQ_API_KEY and HF_TOKEN in the Space settings under 'Variables'.")
+if not GROQ_API_KEY:
+    raise ValueError("Please set GROQ_API_KEY in the Space settings under 'Variables'.")
 
 # Initialize Groq client
 client = Groq(api_key=GROQ_API_KEY)
 
-# Set up device and image generation pipeline
+# Set up device
 device = "cuda" if torch.cuda.is_available() else "cpu"
-pipeline = AutoPipelineForText2Image.from_pretrained(
-    "black-forest-labs/FLUX.1-dev",
-    torch_dtype=torch.bfloat16,
-    use_auth_token=HF_TOKEN
+
+# Load two open-access models
+realism_pipeline = AutoPipelineForText2Image.from_pretrained(
+    "runwayml/stable-diffusion-v1-5",
+    torch_dtype=torch.float16,  # Faster on GPU
+    safety_checker=None,  # Disable NSFW filter (optional; comment out if unwanted)
 ).to(device)
-pipeline.load_lora_weights("Purz/face-projection", weight_name="purz-f4c3_p40j3ct10n.safetensors")
-pipeline.enable_model_cpu_offload()  # Optimize memory usage
+realism_pipeline.enable_model_cpu_offload()  # Optimize memory
+
+photo_pipeline = AutoPipelineForText2Image.from_pretrained(
+    "dreamlike-art/dreamlike-photoreal-2.0",
+    torch_dtype=torch.float16,
+    safety_checker=None,  # Disable NSFW filter (optional)
+).to(device)
+photo_pipeline.enable_model_cpu_offload()  # Optimize memory
 
 # Function to generate tutor output (lesson, question, feedback)
 def generate_tutor_output(subject, difficulty, student_input):
@@ -56,17 +61,19 @@ def generate_tutor_output(subject, difficulty, student_input):
 
 # Function to generate images
 def generate_images(text, selected_model):
-    if selected_model == "Model 1 (Turbo Realism)":
-        prompt_prefix = "realistic, high detail, "
-    elif selected_model == "Model 2 (Face Projection)":
-        prompt_prefix = "f4c3_p40j3ct10n, projection on a face, "
+    if selected_model == "Stable Diffusion (Realism)":
+        pipeline = realism_pipeline
+        prompt_prefix = "realistic, detailed, vivid colors, "
+    elif selected_model == "Dreamlike Photoreal (Portraits)":
+        pipeline = photo_pipeline
+        prompt_prefix = "photorealistic portrait, cinematic lighting, "
     else:
         return ["Invalid model selection."] * 3
 
     results = []
     for i in range(3):
-        modified_text = f"{prompt_prefix}{text} variation {i+1}"
-        image = pipeline(modified_text, num_inference_steps=20).images[0]
+        modified_text = f"{prompt_prefix}{text} variation {i+1}, high quality"
+        image = pipeline(modified_text, num_inference_steps=25).images[0]
         results.append(image)
     return results
 
@@ -101,9 +108,9 @@ with gr.Blocks(title="AI Tutor with Visuals") as demo:
     with gr.Row():
         with gr.Column(scale=2):
             model_selector = gr.Radio(
-                ["Model 1 (Turbo Realism)", "Model 2 (Face Projection)"],
+                ["Stable Diffusion (Realism)", "Dreamlike Photoreal (Portraits)"],
                 label="Select Image Generation Model",
-                value="Model 1 (Turbo Realism)"
+                value="Stable Diffusion (Realism)"
             )
             submit_button_visual = gr.Button("Generate Visuals", variant="primary")
 
@@ -115,8 +122,10 @@ with gr.Blocks(title="AI Tutor with Visuals") as demo:
     gr.Markdown("""
     ### How to Use
     1. **Text Section**: Select a subject and difficulty, type your query, and click 'Generate Lesson & Question' to get your personalized lesson, question, and feedback.
-    2. **Visual Section**: Select a model and click 'Generate Visuals' to see 3 image variations based on your input.
+    2. **Visual Section**: Select a model and click 'Generate Visuals' to see 3 free, open-source image variations based on your input.
     3. Review the AI-generated content to enhance your learning!
+
+    *Note*: These use free, open-access models (Stable Diffusion & Dreamlike Photoreal). GPU recommended for speed.
     """)
 
     def process_output_text(subject, difficulty, student_input):
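
For readers skimming the diff, here is a minimal standalone sketch of the two-pipeline setup and the model dispatch this commit introduces, with the model IDs, prompt prefixes, and inference settings taken from the diff above. The Gradio UI, the Groq tutor logic, and the `enable_model_cpu_offload()` calls are omitted, and the float32 fallback on CPU is an illustrative addition rather than part of the commit.

```python
# Sketch of the two open-access pipelines from this commit (UI and Groq logic omitted).
import torch
from diffusers import AutoPipelineForText2Image

device = "cuda" if torch.cuda.is_available() else "cpu"
# float16 as in the commit when on GPU; fall back to float32 on CPU (illustrative, not in the commit)
dtype = torch.float16 if device == "cuda" else torch.float32

# Open-access checkpoints -- no HF_TOKEN needed, unlike the gated FLUX.1-dev model they replace
realism_pipeline = AutoPipelineForText2Image.from_pretrained(
    "runwayml/stable-diffusion-v1-5", torch_dtype=dtype, safety_checker=None
).to(device)
photo_pipeline = AutoPipelineForText2Image.from_pretrained(
    "dreamlike-art/dreamlike-photoreal-2.0", torch_dtype=dtype, safety_checker=None
).to(device)

def generate_images(text, selected_model):
    # Pick the pipeline and prompt prefix matching the radio-button label in the UI
    if selected_model == "Stable Diffusion (Realism)":
        pipeline, prefix = realism_pipeline, "realistic, detailed, vivid colors, "
    elif selected_model == "Dreamlike Photoreal (Portraits)":
        pipeline, prefix = photo_pipeline, "photorealistic portrait, cinematic lighting, "
    else:
        return ["Invalid model selection."] * 3
    # Three prompt variations, 25 denoising steps each, as in the updated app.py
    return [
        pipeline(f"{prefix}{text} variation {i+1}, high quality",
                 num_inference_steps=25).images[0]
        for i in range(3)
    ]
```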