vdcapriles commited on
Commit
12c51be
·
verified ·
1 Parent(s): 9df3f95

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +1 -26
app.py CHANGED
@@ -10,31 +10,6 @@ from io import BytesIO
10
  import base64
11
  from Gradio_UI import GradioUI
12
 
13
- class ImageGenerator:
14
- def __init__(self, model_id="runwayml/stable-diffusion-v1-5", device="cuda" if torch.cuda.is_available() else "cpu"):
15
- self.pipeline = StableDiffusionPipeline.from_pretrained(model_id, torch_dtype=torch.float16 if device == "cuda" else torch.float32).to(device)
16
- self.device = device
17
-
18
- def generate_image(self, prompt, num_inference_steps=25, guidance_scale=7.5):
19
- """Generates an image from a text prompt."""
20
- image = self.pipeline(prompt, num_inference_steps=num_inference_steps, guidance_scale=guidance_scale).images[0]
21
- return image
22
-
23
- def generate_base64_image(self, prompt, num_inference_steps=25, guidance_scale=7.5):
24
- """Generates a base64 encoded image from a text prompt."""
25
- image = self.generate_image(prompt, num_inference_steps, guidance_scale)
26
- buffered = BytesIO()
27
- image.save(buffered, format="PNG")
28
- img_str = base64.b64encode(buffered.getvalue()).decode()
29
- return img_str
30
-
31
- def generate_image_tool(image_generator):
32
- """Creates a tool function for image generation."""
33
- def image_generation_tool(prompt):
34
- """Generates an image from a prompt."""
35
- return image_generator.generate_base64_image(prompt)
36
- return image_generation_tool
37
-
38
  # Initialize the ImageGenerator and tool
39
  image_generator = ImageGenerator()
40
  image_generation_tool_function = generate_image_tool(image_generator)
@@ -86,7 +61,7 @@ with open("prompts.yaml", 'r') as stream:
86
 
87
  agent = CodeAgent(
88
  model=model,
89
- tools=[final_answer, generate_image_from_prompt, get_current_time_in_timezone], ## add your tools here (don't remove final answer)
90
  max_steps=6,
91
  verbosity_level=1,
92
  grammar=None,
 
10
  import base64
11
  from Gradio_UI import GradioUI
12
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
13
  # Initialize the ImageGenerator and tool
14
  image_generator = ImageGenerator()
15
  image_generation_tool_function = generate_image_tool(image_generator)
 
61
 
62
  agent = CodeAgent(
63
  model=model,
64
+ tools=[final_answer, get_current_time_in_timezone], ## add your tools here (don't remove final answer)
65
  max_steps=6,
66
  verbosity_level=1,
67
  grammar=None,