AMfeta99 committed on
Commit
e38cd3d
·
verified ·
1 Parent(s): 11cd28e

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +57 -54
app.py CHANGED
@@ -1,22 +1,40 @@
1
  from huggingface_hub import InferenceClient
2
- from langchain_community.llms import HuggingFaceHub
3
  from langchain_community.tools import DuckDuckGoSearchResults
4
  from langchain.agents import create_react_agent, AgentExecutor
5
  from langchain_core.tools import BaseTool
6
  from pydantic import Field
7
  from PIL import Image, ImageDraw, ImageFont
8
- import tempfile
9
  import gradio as gr
10
  from io import BytesIO
11
- from typing import Optional
12
- from langchain_core.language_models.llms import LLM
13
  from transformers import pipeline
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
14
 
15
- # === Image generation tool ===
16
  class TextToImageTool(BaseTool):
17
  name: str = "text_to_image"
18
  description: str = "Generate an image from a text prompt."
19
- client: InferenceClient = Field(exclude=True)
20
 
21
  def _run(self, prompt: str) -> Image.Image:
22
  print(f"[Tool] Generating image for prompt: {prompt}")
@@ -26,16 +44,23 @@ class TextToImageTool(BaseTool):
26
  def _arun(self, prompt: str):
27
  raise NotImplementedError("This tool does not support async.")
28
 
 
 
 
 
 
 
 
29
 
30
- # === Labeling Function ===
31
  def add_label_to_image(image, label):
32
  draw = ImageDraw.Draw(image)
33
  font_path = "/usr/share/fonts/truetype/dejavu/DejaVuSans-Bold.ttf"
34
- font_size = 30
35
  try:
36
- font = ImageFont.truetype(font_path, font_size)
37
  except:
38
  font = ImageFont.load_default()
 
39
  text_width, text_height = draw.textsize(label, font=font)
40
  position = (image.width - text_width - 20, image.height - text_height - 20)
41
  rect_position = [position[0] - 10, position[1] - 10, position[0] + text_width + 10, position[1] + text_height + 10]
@@ -43,8 +68,8 @@ def add_label_to_image(image, label):
43
  draw.text(position, label, fill="white", font=font)
44
  return image
45
 
46
-
47
  # === Prompt Generator ===
 
48
  def generate_prompts_for_object(object_name):
49
  return {
50
  "past": f"Show an old version of a {object_name} from its early days.",
@@ -52,54 +77,32 @@ def generate_prompts_for_object(object_name):
52
  "future": f"Show a futuristic version of a {object_name}, predicting future features/designs.",
53
  }
54
 
55
-
56
- # === Agent Setup ===
57
- # Set up the tools
58
- text_to_image_client = InferenceClient("m-ric/text-to-image")
59
- text_to_image_tool = TextToImageTool(client=text_to_image_client)
60
- search_tool = DuckDuckGoSearchResults()
61
-
62
- # Load a public, token-free model locally via transformers pipeline
63
- text_gen_pipeline = pipeline("text-generation", model="Qwen/Qwen2.5-72B-Instruct", max_new_tokens=512)
64
- #tiiuae/falcon-7b-instruct
65
-
66
- # Wrap pipeline into a LangChain LLM
67
class PipelineLLM(LLM):
    """LangChain ``LLM`` adapter around the module-level ``text_gen_pipeline``."""

    def _call(self, prompt, stop=None, run_manager=None, **kwargs):
        """Return only the newly generated text for ``prompt``.

        Args:
            prompt: Text passed to ``text_gen_pipeline``.
            stop: Optional list of stop strings; the completion is cut at the
                earliest occurrence of any of them.
            run_manager: Accepted for compatibility with the LangChain
                ``_call`` signature; not used.
            **kwargs: Accepted for compatibility; not used.

        Returns:
            The completion string, without the echoed prompt.
        """
        # The text-generation pipeline echoes the prompt in front of the
        # completion unless told otherwise; an agent parsing the output would
        # then re-read its own instructions instead of the model's answer.
        completion = text_gen_pipeline(prompt, return_full_text=False)[0]["generated_text"]

        # The pipeline does not know about LangChain stop sequences, so they
        # are enforced here by truncating at the first one that appears.
        if stop:
            cut_points = [
                idx for idx in (completion.find(s) for s in stop if s) if idx != -1
            ]
            if cut_points:
                completion = completion[: min(cut_points)]
        return completion

    @property
    def _llm_type(self):
        """Identifier LangChain uses to label this LLM implementation."""
        return "pipeline_llm"
75
-
76
- llm = PipelineLLM()
77
-
78
# Create agent and executor.
# create_react_agent requires a prompt exposing the `tools`, `tool_names` and
# `agent_scratchpad` variables; calling it without one raises TypeError.
from langchain_core.prompts import PromptTemplate

react_prompt = PromptTemplate.from_template(
    """Answer the following questions as best you can. You have access to the following tools:

{tools}

Use the following format:

Question: the input question you must answer
Thought: you should always think about what to do
Action: the action to take, should be one of [{tool_names}]
Action Input: the input to the action
Observation: the result of the action
... (this Thought/Action/Action Input/Observation can repeat N times)
Thought: I now know the final answer
Final Answer: the final answer to the original input question

Begin!

Question: {input}
Thought:{agent_scratchpad}"""
)

agent = create_react_agent(
    llm=llm,
    tools=[text_to_image_tool, search_tool],
    prompt=react_prompt,
)
agent_executor = AgentExecutor(agent=agent, tools=[text_to_image_tool, search_tool], verbose=True)
81
-
82
-
83
  # === History Generator ===
 
 
 
 
 
84
  def generate_object_history(object_name: str):
85
  prompts = generate_prompts_for_object(object_name)
86
  images = []
87
- labels = {
88
- "past": f"{object_name} - Past",
89
- "present": f"{object_name} - Present",
90
- "future": f"{object_name} - Future"
91
- }
92
  for period, prompt in prompts.items():
93
- result = text_to_image_tool._run(prompt)
94
- labeled = add_label_to_image(result, labels[period])
95
- file_path = f"{object_name}_{period}.png"
96
- labeled.save(file_path)
97
- images.append((file_path, labels[period]))
98
- gif_path = f"{object_name}_evolution.gif"
99
- pil_images = [Image.open(img[0]) for img in images]
 
 
 
 
100
  pil_images[0].save(gif_path, save_all=True, append_images=pil_images[1:], duration=1000, loop=0)
101
- return images, gif_path
102
 
 
103
 
104
  # === Gradio UI ===
105
  def create_gradio_interface():
@@ -117,7 +120,7 @@ def create_gradio_interface():
117
 
118
  return demo
119
 
120
-
121
  # === Launch App ===
122
- demo = create_gradio_interface()
123
- demo.launch(share=True)
 
 
1
  from huggingface_hub import InferenceClient
 
2
  from langchain_community.tools import DuckDuckGoSearchResults
3
  from langchain.agents import create_react_agent, AgentExecutor
4
  from langchain_core.tools import BaseTool
5
  from pydantic import Field
6
  from PIL import Image, ImageDraw, ImageFont
7
+ from functools import lru_cache
8
  import gradio as gr
9
  from io import BytesIO
 
 
10
  from transformers import pipeline
11
+ from langchain_core.language_models.llms import LLM
12
+ import os
13
+
14
+ # === Global Model Setup ===
15
+
16
+ # Preload image generation inference client
17
+ image_client = InferenceClient("m-ric/text-to-image")
18
+
19
+ # Preload text generation model via HuggingFace Transformers
20
+ text_gen_pipeline = pipeline("text-generation", model="Qwen/Qwen2.5-72B-Instruct", max_new_tokens=512)
21
+
22
+ # === LangChain Wrapper for Pipeline ===
23
class PipelineLLM(LLM):
    """LangChain ``LLM`` adapter around the module-level ``text_gen_pipeline``."""

    def _call(self, prompt, stop=None, run_manager=None, **kwargs):
        """Return only the newly generated text for ``prompt``.

        Args:
            prompt: Text passed to ``text_gen_pipeline``.
            stop: Optional list of stop strings; the completion is cut at the
                earliest occurrence of any of them.
            run_manager: Accepted for compatibility with the LangChain
                ``_call`` signature; not used.
            **kwargs: Accepted for compatibility; not used.

        Returns:
            The completion string, without the echoed prompt.
        """
        # The text-generation pipeline echoes the prompt in front of the
        # completion unless told otherwise; an agent parsing the output would
        # then re-read its own instructions instead of the model's answer.
        completion = text_gen_pipeline(prompt, return_full_text=False)[0]["generated_text"]

        # The pipeline does not know about LangChain stop sequences, so they
        # are enforced here by truncating at the first one that appears.
        if stop:
            cut_points = [
                idx for idx in (completion.find(s) for s in stop if s) if idx != -1
            ]
            if cut_points:
                completion = completion[: min(cut_points)]
        return completion

    @property
    def _llm_type(self):
        """Identifier LangChain uses to label this LLM implementation."""
        return "pipeline_llm"
30
+
31
+ llm = PipelineLLM()
32
 
33
+ # === Image Tool ===
34
  class TextToImageTool(BaseTool):
35
  name: str = "text_to_image"
36
  description: str = "Generate an image from a text prompt."
37
+ client: InferenceClient = Field(default=image_client, exclude=True)
38
 
39
  def _run(self, prompt: str) -> Image.Image:
40
  print(f"[Tool] Generating image for prompt: {prompt}")
 
44
  def _arun(self, prompt: str):
45
  raise NotImplementedError("This tool does not support async.")
46
 
47
# Instantiate tools
text_to_image_tool = TextToImageTool()
search_tool = DuckDuckGoSearchResults()

# Create LangChain agent.
# create_react_agent requires a prompt exposing the `tools`, `tool_names` and
# `agent_scratchpad` variables; calling it without one raises TypeError.
from langchain_core.prompts import PromptTemplate

react_prompt = PromptTemplate.from_template(
    """Answer the following questions as best you can. You have access to the following tools:

{tools}

Use the following format:

Question: the input question you must answer
Thought: you should always think about what to do
Action: the action to take, should be one of [{tool_names}]
Action Input: the input to the action
Observation: the result of the action
... (this Thought/Action/Action Input/Observation can repeat N times)
Thought: I now know the final answer
Final Answer: the final answer to the original input question

Begin!

Question: {input}
Thought:{agent_scratchpad}"""
)

agent = create_react_agent(
    llm=llm,
    tools=[text_to_image_tool, search_tool],
    prompt=react_prompt,
)
agent_executor = AgentExecutor(agent=agent, tools=[text_to_image_tool, search_tool], verbose=True)
54
 
55
+ # === Utility: Add Label to Image ===
56
  def add_label_to_image(image, label):
57
  draw = ImageDraw.Draw(image)
58
  font_path = "/usr/share/fonts/truetype/dejavu/DejaVuSans-Bold.ttf"
 
59
  try:
60
+ font = ImageFont.truetype(font_path, 30)
61
  except:
62
  font = ImageFont.load_default()
63
+
64
  text_width, text_height = draw.textsize(label, font=font)
65
  position = (image.width - text_width - 20, image.height - text_height - 20)
66
  rect_position = [position[0] - 10, position[1] - 10, position[0] + text_width + 10, position[1] + text_height + 10]
 
68
  draw.text(position, label, fill="white", font=font)
69
  return image
70
 
 
71
  # === Prompt Generator ===
72
+ @lru_cache(maxsize=128)
73
  def generate_prompts_for_object(object_name):
74
  return {
75
  "past": f"Show an old version of a {object_name} from its early days.",
 
77
  "future": f"Show a futuristic version of a {object_name}, predicting future features/designs.",
78
  }
79
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
80
  # === History Generator ===
81
@lru_cache(maxsize=64)
def generate_image_for_prompt(prompt, label):
    """Generate the image for ``prompt`` and draw ``label`` onto it.

    Results are memoized per ``(prompt, label)`` pair (up to 64 entries), so
    a repeated request returns the previously produced image object instead
    of calling the image tool again.
    """
    return add_label_to_image(text_to_image_tool._run(prompt), label)
85
+
86
  def generate_object_history(object_name: str):
87
  prompts = generate_prompts_for_object(object_name)
88
  images = []
89
+ file_paths = []
90
+
 
 
 
91
  for period, prompt in prompts.items():
92
+ label = f"{object_name} - {period.capitalize()}"
93
+ labeled_image = generate_image_for_prompt(prompt, label)
94
+
95
+ file_path = f"/tmp/{object_name}_{period}.png"
96
+ labeled_image.save(file_path)
97
+ images.append((file_path, label))
98
+ file_paths.append(file_path)
99
+
100
+ # Create GIF
101
+ gif_path = f"/tmp/{object_name}_evolution.gif"
102
+ pil_images = [Image.open(p) for p in file_paths]
103
  pil_images[0].save(gif_path, save_all=True, append_images=pil_images[1:], duration=1000, loop=0)
 
104
 
105
+ return images, gif_path
106
 
107
  # === Gradio UI ===
108
  def create_gradio_interface():
 
120
 
121
  return demo
122
 
 
123
  # === Launch App ===
124
if __name__ == "__main__":
    # Build and serve the UI only when this file is executed as a script;
    # importing the module does not start a server.
    demo = create_gradio_interface()
    demo.launch(share=True)