SameerArz commited on
Commit
5021bda
·
verified ·
1 Parent(s): d8788c3

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +30 -17
app.py CHANGED
@@ -6,7 +6,7 @@ import json
6
  import torch
7
  from diffusers import StableDiffusionPipeline
8
 
9
- # Get Groq API key from environment variables (set in Space settings)
10
  GROQ_API_KEY = os.getenv("GROQ_API_KEY")
11
  if not GROQ_API_KEY:
12
  raise ValueError("Please set GROQ_API_KEY in the Space settings under 'Variables'.")
@@ -14,15 +14,20 @@ if not GROQ_API_KEY:
14
  # Initialize Groq client
15
  client = Groq(api_key=GROQ_API_KEY)
16
 
17
- # Set up device and load the exact model you specified
18
  device = "cuda" if torch.cuda.is_available() else "cpu"
19
- model_id = "sd-legacy/stable-diffusion-v1-5" # Your exact model ID
20
- pipe = StableDiffusionPipeline.from_pretrained(
21
- model_id,
22
- torch_dtype=torch.float16, # Matches your snippet
23
- safety_checker=None, # Optional: disabled for flexibility
24
- ).to(device)
25
- pipe.enable_model_cpu_offload() # Optimize memory for Spaces
 
 
 
 
 
26
 
27
  # Function to generate tutor output (lesson, question, feedback)
28
  def generate_tutor_output(subject, difficulty, student_input):
@@ -51,21 +56,28 @@ def generate_tutor_output(subject, difficulty, student_input):
51
  )
52
  return completion.choices[0].message.content
53
 
54
- # Function to generate images using your model ID
55
  def generate_images(text, selected_model):
 
56
  if selected_model == "Stable Diffusion (Realism)":
57
  prompt_prefix = "realistic, detailed, vivid colors, "
58
  elif selected_model == "Stable Diffusion (Portraits)":
59
  prompt_prefix = "portrait, photorealistic, cinematic lighting, "
60
  else:
 
61
  return ["Invalid model selection."] * 3
62
 
63
  results = []
64
- for i in range(3):
65
- modified_text = f"{prompt_prefix}{text} variation {i+1}, high quality"
66
- image = pipe(modified_text, num_inference_steps=25).images[0] # Using your model
67
- results.append(image)
68
- return results
 
 
 
 
 
69
 
70
  # Gradio interface
71
  with gr.Blocks(title="AI Tutor with Visuals") as demo:
@@ -112,7 +124,7 @@ with gr.Blocks(title="AI Tutor with Visuals") as demo:
112
  gr.Markdown("""
113
  ### How to Use
114
  1. **Text Section**: Select a subject and difficulty, type your query, and click 'Generate Lesson & Question'.
115
- 2. **Visual Section**: Select a style and click 'Generate Visuals' to see 3 images from sd-legacy/stable-diffusion-v1-5.
116
  3. Review the AI-generated content!
117
 
118
  *Example*: Try "a photo of an astronaut riding a horse on mars" for visuals.
@@ -131,7 +143,8 @@ with gr.Blocks(title="AI Tutor with Visuals") as demo:
131
  images = generate_images(text, selected_model)
132
  return images[0], images[1], images[2]
133
  except Exception as e:
134
- return None, None, None
 
135
 
136
  submit_button_text.click(
137
  fn=process_output_text,
 
6
  import torch
7
  from diffusers import StableDiffusionPipeline
8
 
9
+ # Get Groq API key from environment variables
10
  GROQ_API_KEY = os.getenv("GROQ_API_KEY")
11
  if not GROQ_API_KEY:
12
  raise ValueError("Please set GROQ_API_KEY in the Space settings under 'Variables'.")
 
14
  # Initialize Groq client
15
  client = Groq(api_key=GROQ_API_KEY)
16
 
17
+ # Set up device and load the model
18
  device = "cuda" if torch.cuda.is_available() else "cpu"
19
+ print(f"Device available: {device}, CUDA: {torch.cuda.is_available()}")
20
+ model_id = "runwayml/stable-diffusion-v1-5" # Updated to a widely supported model
21
+ try:
22
+ pipe = StableDiffusionPipeline.from_pretrained(
23
+ model_id,
24
+ torch_dtype=torch.float16,
25
+ safety_checker=None, # Optional: disabled for flexibility
26
+ ).to(device)
27
+ pipe.enable_model_cpu_offload() # Optimize memory usage
28
+ print(f"Model loaded successfully on {device}: {model_id}")
29
+ except Exception as e:
30
+ raise ValueError(f"Failed to load model {model_id}: {str(e)}")
31
 
32
  # Function to generate tutor output (lesson, question, feedback)
33
  def generate_tutor_output(subject, difficulty, student_input):
 
56
  )
57
  return completion.choices[0].message.content
58
 
59
+ # Function to generate images
60
  def generate_images(text, selected_model):
61
+ print(f"Generating images for text: {text}, model: {selected_model}")
62
  if selected_model == "Stable Diffusion (Realism)":
63
  prompt_prefix = "realistic, detailed, vivid colors, "
64
  elif selected_model == "Stable Diffusion (Portraits)":
65
  prompt_prefix = "portrait, photorealistic, cinematic lighting, "
66
  else:
67
+ print("Invalid model selection")
68
  return ["Invalid model selection."] * 3
69
 
70
  results = []
71
+ try:
72
+ for i in range(3):
73
+ modified_text = f"{prompt_prefix}{text} variation {i+1}, high quality"
74
+ print(f"Generating image {i+1} with prompt: {modified_text}")
75
+ image = pipe(modified_text, num_inference_steps=25).images[0]
76
+ results.append(image)
77
+ return results
78
+ except Exception as e:
79
+ print(f"Error in image generation: {str(e)}")
80
+ return [f"Error: {str(e)}"] * 3
81
 
82
  # Gradio interface
83
  with gr.Blocks(title="AI Tutor with Visuals") as demo:
 
124
  gr.Markdown("""
125
  ### How to Use
126
  1. **Text Section**: Select a subject and difficulty, type your query, and click 'Generate Lesson & Question'.
127
+ 2. **Visual Section**: Select a style and click 'Generate Visuals' to see 3 images.
128
  3. Review the AI-generated content!
129
 
130
  *Example*: Try "a photo of an astronaut riding a horse on mars" for visuals.
 
143
  images = generate_images(text, selected_model)
144
  return images[0], images[1], images[2]
145
  except Exception as e:
146
+ error_msg = f"Visual generation failed: {str(e)}"
147
+ return error_msg, error_msg, error_msg
148
 
149
  submit_button_text.click(
150
  fn=process_output_text,