vdcapriles commited on
Commit
7e8503c
·
verified ·
1 Parent(s): 9e0fb74

Update app.py

Browse files

Added image gen

Files changed (1) hide show
  1. app.py +37 -7
app.py CHANGED
@@ -4,19 +4,49 @@ import requests
4
  import pytz
5
  import yaml
6
  from tools.final_answer import FinalAnswerTool
7
-
 
 
 
8
  from Gradio_UI import GradioUI
9
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
10
  # Below is an example of a tool that does nothing. Amaze us with your creativity !
11
  @tool
12
- def my_custom_tool(arg1:str, arg2:int)-> str: #it's import to specify the return type
13
- #Keep this format for the description / args / args description but feel free to modify the tool
14
- """A tool that does nothing yet
15
  Args:
16
- arg1: the first argument
17
- arg2: the second argument
18
  """
19
- return "What magic will you build ?"
20
 
21
  @tool
22
  def get_current_time_in_timezone(timezone: str) -> str:
 
4
  import pytz
5
  import yaml
6
  from tools.final_answer import FinalAnswerTool
7
+ from diffusers import StableDiffusionPipeline
8
+ import torch
9
+ from io import BytesIO
10
+ import base64
11
  from Gradio_UI import GradioUI
12
 
13
+ class ImageGenerator:
14
+ def __init__(self, model_id="runwayml/stable-diffusion-v1-5", device="cuda" if torch.cuda.is_available() else "cpu"):
15
+ self.pipeline = StableDiffusionPipeline.from_pretrained(model_id, torch_dtype=torch.float16 if device == "cuda" else torch.float32).to(device)
16
+ self.device = device
17
+
18
+ def generate_image(self, prompt, num_inference_steps=25, guidance_scale=7.5):
19
+ """Generates an image from a text prompt."""
20
+ image = self.pipeline(prompt, num_inference_steps=num_inference_steps, guidance_scale=guidance_scale).images[0]
21
+ return image
22
+
23
+ def generate_base64_image(self, prompt, num_inference_steps=25, guidance_scale=7.5):
24
+ """Generates a base64 encoded image from a text prompt."""
25
+ image = self.generate_image(prompt, num_inference_steps, guidance_scale)
26
+ buffered = BytesIO()
27
+ image.save(buffered, format="PNG")
28
+ img_str = base64.b64encode(buffered.getvalue()).decode()
29
+ return img_str
30
+
31
+ def generate_image_tool(image_generator):
32
+ """Creates a tool function for image generation."""
33
+ def image_generation_tool(prompt):
34
+ """Generates an image from a prompt."""
35
+ return image_generator.generate_base64_image(prompt)
36
+ return image_generation_tool
37
+
38
+ # Initialize the ImageGenerator and tool
39
+ image_generator = ImageGenerator()
40
+ image_generation_tool_function = generate_image_tool(image_generator)
41
+
42
  # Below is an example of a tool that does nothing. Amaze us with your creativity !
43
  @tool
44
+ def generate_image_from_prompt(prompt: str) -> str:
45
+ """Generates an image from a text prompt.
 
46
  Args:
47
+ prompt: The text prompt to generate the image from.
 
48
  """
49
+ return image_generation_tool_function(prompt)
50
 
51
  @tool
52
  def get_current_time_in_timezone(timezone: str) -> str: