Update app.py
Browse files
app.py
CHANGED
@@ -6,26 +6,31 @@ import json
|
|
6 |
import torch
|
7 |
from diffusers import AutoPipelineForText2Image
|
8 |
|
9 |
-
# Get API
|
10 |
GROQ_API_KEY = os.getenv("GROQ_API_KEY")
|
11 |
-
|
12 |
-
|
13 |
-
# Check if keys are provided
|
14 |
-
if not GROQ_API_KEY or not HF_TOKEN:
|
15 |
-
raise ValueError("Please set GROQ_API_KEY and HF_TOKEN in the Space settings under 'Variables'.")
|
16 |
|
17 |
# Initialize Groq client
|
18 |
client = Groq(api_key=GROQ_API_KEY)
|
19 |
|
20 |
-
# Set up device
|
21 |
device = "cuda" if torch.cuda.is_available() else "cpu"
|
22 |
-
|
23 |
-
|
24 |
-
|
25 |
-
|
|
|
|
|
26 |
).to(device)
|
27 |
-
|
28 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
29 |
|
30 |
# Function to generate tutor output (lesson, question, feedback)
|
31 |
def generate_tutor_output(subject, difficulty, student_input):
|
@@ -56,17 +61,19 @@ def generate_tutor_output(subject, difficulty, student_input):
|
|
56 |
|
57 |
# Function to generate images
|
58 |
def generate_images(text, selected_model):
|
59 |
-
if selected_model == "
|
60 |
-
|
61 |
-
|
62 |
-
|
|
|
|
|
63 |
else:
|
64 |
return ["Invalid model selection."] * 3
|
65 |
|
66 |
results = []
|
67 |
for i in range(3):
|
68 |
-
modified_text = f"{prompt_prefix}{text} variation {i+1}"
|
69 |
-
image = pipeline(modified_text, num_inference_steps=
|
70 |
results.append(image)
|
71 |
return results
|
72 |
|
@@ -101,9 +108,9 @@ with gr.Blocks(title="AI Tutor with Visuals") as demo:
|
|
101 |
with gr.Row():
|
102 |
with gr.Column(scale=2):
|
103 |
model_selector = gr.Radio(
|
104 |
-
["
|
105 |
label="Select Image Generation Model",
|
106 |
-
value="
|
107 |
)
|
108 |
submit_button_visual = gr.Button("Generate Visuals", variant="primary")
|
109 |
|
@@ -115,8 +122,10 @@ with gr.Blocks(title="AI Tutor with Visuals") as demo:
|
|
115 |
gr.Markdown("""
|
116 |
### How to Use
|
117 |
1. **Text Section**: Select a subject and difficulty, type your query, and click 'Generate Lesson & Question' to get your personalized lesson, question, and feedback.
|
118 |
-
2. **Visual Section**: Select a model and click 'Generate Visuals' to see 3 image variations based on your input.
|
119 |
3. Review the AI-generated content to enhance your learning!
|
|
|
|
|
120 |
""")
|
121 |
|
122 |
def process_output_text(subject, difficulty, student_input):
|
|
|
6 |
import torch
|
7 |
from diffusers import AutoPipelineForText2Image
|
8 |
|
9 |
+
# Get Groq API key from environment variables (set in Space settings)
|
10 |
GROQ_API_KEY = os.getenv("GROQ_API_KEY")
|
11 |
+
if not GROQ_API_KEY:
|
12 |
+
raise ValueError("Please set GROQ_API_KEY in the Space settings under 'Variables'.")
|
|
|
|
|
|
|
13 |
|
14 |
# Initialize Groq client
|
15 |
client = Groq(api_key=GROQ_API_KEY)
|
16 |
|
17 |
+
# Set up device
|
18 |
device = "cuda" if torch.cuda.is_available() else "cpu"
|
19 |
+
|
20 |
+
# Load two open-access models
|
21 |
+
realism_pipeline = AutoPipelineForText2Image.from_pretrained(
|
22 |
+
"runwayml/stable-diffusion-v1-5",
|
23 |
+
torch_dtype=torch.float16, # Faster on GPU
|
24 |
+
safety_checker=None, # Disable NSFW filter (optional; comment out if unwanted)
|
25 |
).to(device)
|
26 |
+
realism_pipeline.enable_model_cpu_offload() # Optimize memory
|
27 |
+
|
28 |
+
photo_pipeline = AutoPipelineForText2Image.from_pretrained(
|
29 |
+
"dreamlike-art/dreamlike-photoreal-2.0",
|
30 |
+
torch_dtype=torch.float16,
|
31 |
+
safety_checker=None, # Disable NSFW filter (optional)
|
32 |
+
).to(device)
|
33 |
+
photo_pipeline.enable_model_cpu_offload() # Optimize memory
|
34 |
|
35 |
# Function to generate tutor output (lesson, question, feedback)
|
36 |
def generate_tutor_output(subject, difficulty, student_input):
|
|
|
61 |
|
62 |
# Function to generate images
|
63 |
def generate_images(text, selected_model):
|
64 |
+
if selected_model == "Stable Diffusion (Realism)":
|
65 |
+
pipeline = realism_pipeline
|
66 |
+
prompt_prefix = "realistic, detailed, vivid colors, "
|
67 |
+
elif selected_model == "Dreamlike Photoreal (Portraits)":
|
68 |
+
pipeline = photo_pipeline
|
69 |
+
prompt_prefix = "photorealistic portrait, cinematic lighting, "
|
70 |
else:
|
71 |
return ["Invalid model selection."] * 3
|
72 |
|
73 |
results = []
|
74 |
for i in range(3):
|
75 |
+
modified_text = f"{prompt_prefix}{text} variation {i+1}, high quality"
|
76 |
+
image = pipeline(modified_text, num_inference_steps=25).images[0]
|
77 |
results.append(image)
|
78 |
return results
|
79 |
|
|
|
108 |
with gr.Row():
|
109 |
with gr.Column(scale=2):
|
110 |
model_selector = gr.Radio(
|
111 |
+
["Stable Diffusion (Realism)", "Dreamlike Photoreal (Portraits)"],
|
112 |
label="Select Image Generation Model",
|
113 |
+
value="Stable Diffusion (Realism)"
|
114 |
)
|
115 |
submit_button_visual = gr.Button("Generate Visuals", variant="primary")
|
116 |
|
|
|
122 |
gr.Markdown("""
|
123 |
### How to Use
|
124 |
1. **Text Section**: Select a subject and difficulty, type your query, and click 'Generate Lesson & Question' to get your personalized lesson, question, and feedback.
|
125 |
+
2. **Visual Section**: Select a model and click 'Generate Visuals' to see 3 free, open-source image variations based on your input.
|
126 |
3. Review the AI-generated content to enhance your learning!
|
127 |
+
|
128 |
+
*Note*: These use free, open-access models (Stable Diffusion & Dreamlike Photoreal). GPU recommended for speed.
|
129 |
""")
|
130 |
|
131 |
def process_output_text(subject, difficulty, student_input):
|