AMfeta99 commited on
Commit
a568cb6
·
verified ·
1 Parent(s): e38cd3d

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +28 -25
app.py CHANGED
@@ -7,52 +7,54 @@ from PIL import Image, ImageDraw, ImageFont
7
  from functools import lru_cache
8
  import gradio as gr
9
  from io import BytesIO
10
- from transformers import pipeline
11
- from langchain_core.language_models.llms import LLM
12
  import os
13
 
14
- # === Global Model Setup ===
 
 
15
 
16
- # Preload image generation inference client
17
  image_client = InferenceClient("m-ric/text-to-image")
 
18
 
19
- # Preload text generation model via HuggingFace Transformers
20
- text_gen_pipeline = pipeline("text-generation", model="Qwen/Qwen2.5-72B-Instruct", max_new_tokens=512)
 
 
 
21
 
22
- # === LangChain Wrapper for Pipeline ===
23
- class PipelineLLM(LLM):
24
- def _call(self, prompt, stop=None):
25
- return text_gen_pipeline(prompt)[0]["generated_text"]
 
26
 
27
- @property
28
- def _llm_type(self):
29
- return "pipeline_llm"
30
-
31
- llm = PipelineLLM()
32
 
33
- # === Image Tool ===
34
  class TextToImageTool(BaseTool):
35
  name: str = "text_to_image"
36
  description: str = "Generate an image from a text prompt."
37
  client: InferenceClient = Field(default=image_client, exclude=True)
38
 
39
  def _run(self, prompt: str) -> Image.Image:
40
- print(f"[Tool] Generating image for prompt: {prompt}")
41
  image_bytes = self.client.text_to_image(prompt)
42
  return Image.open(BytesIO(image_bytes))
43
 
44
  def _arun(self, prompt: str):
45
- raise NotImplementedError("This tool does not support async.")
46
 
47
- # Instantiate tools
48
  text_to_image_tool = TextToImageTool()
 
49
  search_tool = DuckDuckGoSearchResults()
50
 
51
- # Create LangChain agent
52
- agent = create_react_agent(llm=llm, tools=[text_to_image_tool, search_tool])
53
  agent_executor = AgentExecutor(agent=agent, tools=[text_to_image_tool, search_tool], verbose=True)
54
 
55
- # === Utility: Add Label to Image ===
56
  def add_label_to_image(image, label):
57
  draw = ImageDraw.Draw(image)
58
  font_path = "/usr/share/fonts/truetype/dejavu/DejaVuSans-Bold.ttf"
@@ -68,7 +70,7 @@ def add_label_to_image(image, label):
68
  draw.text(position, label, fill="white", font=font)
69
  return image
70
 
71
- # === Prompt Generator ===
72
  @lru_cache(maxsize=128)
73
  def generate_prompts_for_object(object_name):
74
  return {
@@ -77,12 +79,13 @@ def generate_prompts_for_object(object_name):
77
  "future": f"Show a futuristic version of a {object_name}, predicting future features/designs.",
78
  }
79
 
80
- # === History Generator ===
81
  @lru_cache(maxsize=64)
82
  def generate_image_for_prompt(prompt, label):
83
  img = text_to_image_tool._run(prompt)
84
  return add_label_to_image(img, label)
85
 
 
86
  def generate_object_history(object_name: str):
87
  prompts = generate_prompts_for_object(object_name)
88
  images = []
@@ -120,7 +123,7 @@ def create_gradio_interface():
120
 
121
  return demo
122
 
123
- # === Launch App ===
124
  if __name__ == "__main__":
125
  demo = create_gradio_interface()
126
  demo.launch(share=True)
 
7
  from functools import lru_cache
8
  import gradio as gr
9
  from io import BytesIO
 
 
10
  import os
11
 
12
+ # === Setup Inference Clients ===
13
+ # Use your Hugging Face token if necessary:
14
+ # client = InferenceClient(repo_id="model", token="YOUR_HF_TOKEN")
15
 
 
16
  image_client = InferenceClient("m-ric/text-to-image")
17
+ text_client = InferenceClient("Qwen/Qwen2.5-72B-Instruct")
18
 
19
+ # === LangChain wrapper using InferenceClient for text generation ===
20
+ class InferenceClientLLM(BaseTool):
21
+ name: str = "inference_text_generator"
22
+ description: str = "Generate text using HF Inference API."
23
+ client: InferenceClient = Field(default=text_client, exclude=True)
24
 
25
+ def _run(self, prompt: str) -> str:
26
+ print(f"[LLM] Generating text for prompt: {prompt}")
27
+ response = self.client.text_generation(prompt)
28
+ # response is usually a dict with 'generated_text'
29
+ return response.get("generated_text", "")
30
 
31
+ def _arun(self, prompt: str):
32
+ raise NotImplementedError("Async not supported.")
 
 
 
33
 
34
+ # === Image generation tool ===
35
  class TextToImageTool(BaseTool):
36
  name: str = "text_to_image"
37
  description: str = "Generate an image from a text prompt."
38
  client: InferenceClient = Field(default=image_client, exclude=True)
39
 
40
  def _run(self, prompt: str) -> Image.Image:
41
+ print(f"[Image Tool] Generating image for prompt: {prompt}")
42
  image_bytes = self.client.text_to_image(prompt)
43
  return Image.open(BytesIO(image_bytes))
44
 
45
  def _arun(self, prompt: str):
46
+ raise NotImplementedError("Async not supported.")
47
 
48
+ # === Initialize tools ===
49
  text_to_image_tool = TextToImageTool()
50
+ text_gen_tool = InferenceClientLLM()
51
  search_tool = DuckDuckGoSearchResults()
52
 
53
+ # === Create agent ===
54
+ agent = create_react_agent(llm=text_gen_tool, tools=[text_to_image_tool, search_tool])
55
  agent_executor = AgentExecutor(agent=agent, tools=[text_to_image_tool, search_tool], verbose=True)
56
 
57
+ # === Image labeling ===
58
  def add_label_to_image(image, label):
59
  draw = ImageDraw.Draw(image)
60
  font_path = "/usr/share/fonts/truetype/dejavu/DejaVuSans-Bold.ttf"
 
70
  draw.text(position, label, fill="white", font=font)
71
  return image
72
 
73
+ # === Prompt generation with caching ===
74
  @lru_cache(maxsize=128)
75
  def generate_prompts_for_object(object_name):
76
  return {
 
79
  "future": f"Show a futuristic version of a {object_name}, predicting future features/designs.",
80
  }
81
 
82
+ # === Cache generated images ===
83
  @lru_cache(maxsize=64)
84
  def generate_image_for_prompt(prompt, label):
85
  img = text_to_image_tool._run(prompt)
86
  return add_label_to_image(img, label)
87
 
88
+ # === Main generation function ===
89
  def generate_object_history(object_name: str):
90
  prompts = generate_prompts_for_object(object_name)
91
  images = []
 
123
 
124
  return demo
125
 
126
+ # === Launch app ===
127
  if __name__ == "__main__":
128
  demo = create_gradio_interface()
129
  demo.launch(share=True)