IAMTFRMZA commited on
Commit
8eea10a
·
verified ·
1 Parent(s): 9b91377

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +71 -60
app.py CHANGED
@@ -9,11 +9,16 @@ import rembg
9
  import torch
10
  from PIL import Image
11
  from functools import partial
 
 
 
 
12
 
13
  from tsr.system import TSR
14
  from tsr.utils import remove_background, resize_foreground, to_gradio_3d_orientation
15
 
16
- #HF_TOKEN = os.getenv("HF_TOKEN")
 
17
 
18
  HEADER = """
19
  **TripoSR** is a state-of-the-art open-source model for **fast** feedforward 3D reconstruction from a single image, developed in collaboration between [Tripo AI](https://www.tripo3d.ai/) and [Stability AI](https://stability.ai/).
@@ -23,32 +28,29 @@ HEADER = """
23
  2. Please disable "Remove Background" option only if your input image is RGBA with transparent background, image contents are centered and occupy more than 70% of image width or height.
24
  """
25
 
26
-
27
- if torch.cuda.is_available():
28
- device = "cuda:0"
29
- else:
30
- device = "cpu"
31
-
32
- d = os.environ.get("DEVICE", None)
33
- if d != None:
34
- device = d
35
-
36
- model = TSR.from_pretrained(
37
- "stabilityai/TripoSR",
38
- config_name="config.yaml",
39
- weight_name="model.ckpt",
40
- # token=HF_TOKEN
41
- )
42
- model.renderer.set_chunk_size(131072)
43
- model.to(device)
44
-
45
- rembg_session = rembg.new_session()
46
-
47
-
48
- def check_input_image(input_image):
49
- if input_image is None:
50
- raise gr.Error("No image uploaded!")
51
-
52
 
53
  def preprocess(input_image, do_remove_background, foreground_ratio):
54
  def fill_background(image):
@@ -68,7 +70,6 @@ def preprocess(input_image, do_remove_background, foreground_ratio):
68
  image = fill_background(image)
69
  return image
70
 
71
-
72
  def generate(image):
73
  scene_codes = model(image, device=device)
74
  mesh = model.extract_mesh(scene_codes)[0]
@@ -79,23 +80,52 @@ def generate(image):
79
  mesh.export(mesh_path2.name)
80
  return mesh_path.name, mesh_path2.name
81
 
82
- def run_example(image_pil):
83
- preprocessed = preprocess(image_pil, False, 0.9)
84
- mesh_name, mesn_name2 = generate(preprocessed)
85
- return preprocessed, mesh_name, mesh_name2
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
86
 
87
  with gr.Blocks() as demo:
88
  gr.Markdown(HEADER)
89
  with gr.Row(variant="panel"):
90
  with gr.Column():
91
  with gr.Row():
92
- input_image = gr.Image(
93
- label="Input Image",
94
- image_mode="RGBA",
95
- sources="upload",
96
- type="pil",
97
- elem_id="content_image",
98
- )
99
  processed_image = gr.Image(label="Processed Image", interactive=False)
100
  with gr.Row():
101
  with gr.Group():
@@ -122,27 +152,8 @@ with gr.Blocks() as demo:
122
  label="Output Model",
123
  interactive=False,
124
  )
125
- with gr.Row(variant="panel"):
126
- gr.Examples(
127
- examples=[
128
- os.path.join("examples", img_name) for img_name in sorted(os.listdir("examples"))
129
- ],
130
- inputs=[input_image],
131
- outputs=[processed_image, output_model, output_model2],
132
- #cache_examples=True,
133
- fn=partial(run_example),
134
- label="Examples",
135
- examples_per_page=20
136
- )
137
- submit.click(fn=check_input_image, inputs=[input_image]).success(
138
- fn=preprocess,
139
- inputs=[input_image, do_remove_background, foreground_ratio],
140
- outputs=[processed_image],
141
- ).success(
142
- fn=generate,
143
- inputs=[processed_image],
144
- outputs=[output_model, output_model2],
145
- )
146
 
147
  demo.queue(max_size=10)
148
  demo.launch()
 
 
9
  import torch
10
  from PIL import Image
11
  from functools import partial
12
+ from serpapi import GoogleSearch
13
+ import requests
14
+ from io import BytesIO
15
+ import matplotlib.pyplot as plt
16
 
17
  from tsr.system import TSR
18
  from tsr.utils import remove_background, resize_foreground, to_gradio_3d_orientation
19
 
20
+ # Set your SerpApi key here
21
+ SERPAPI_KEY = "YOUR_SERPAPI_KEY"
22
 
23
  HEADER = """
24
  **TripoSR** is a state-of-the-art open-source model for **fast** feedforward 3D reconstruction from a single image, developed in collaboration between [Tripo AI](https://www.tripo3d.ai/) and [Stability AI](https://stability.ai/).
 
28
  2. Please disable "Remove Background" option only if your input image is RGBA with transparent background, image contents are centered and occupy more than 70% of image width or height.
29
  """
30
 
31
+ def get_motorcycle_image(make, model):
32
+ params = {
33
+ "api_key": SERPAPI_KEY,
34
+ "engine": "google",
35
+ "q": f"{make} {model} motorcycle product photo",
36
+ "tbm": "isch"
37
+ }
38
+
39
+ search = GoogleSearch(params)
40
+ results = search.get_dict()
41
+ if "images_results" in results:
42
+ first_image = results["images_results"][0]
43
+ image_url = first_image.get("original")
44
+ if image_url:
45
+ image_response = requests.get(image_url)
46
+ image = Image.open(BytesIO(image_response.content))
47
+ return image
48
+ else:
49
+ print("Image URL not found in results.")
50
+ return None
51
+ else:
52
+ print("No image results found.")
53
+ return None
 
 
 
54
 
55
  def preprocess(input_image, do_remove_background, foreground_ratio):
56
  def fill_background(image):
 
70
  image = fill_background(image)
71
  return image
72
 
 
73
  def generate(image):
74
  scene_codes = model(image, device=device)
75
  mesh = model.extract_mesh(scene_codes)[0]
 
80
  mesh.export(mesh_path2.name)
81
  return mesh_path.name, mesh_path2.name
82
 
83
+ def run_example(make, model):
84
+ image = get_motorcycle_image(make, model)
85
+ if image:
86
+ # Save the image
87
+ input_image_path = '/content/motorcycle.jpg'
88
+ image.save(input_image_path)
89
+
90
+ # Load the image
91
+ img = Image.open(input_image_path)
92
+ output_image_path = '/content/motorcyclebg.png'
93
+ img_no_bg = rembg_remove(img)
94
+ img_no_bg.save(output_image_path)
95
+
96
+ # Preprocess and generate 3D model
97
+ preprocessed = preprocess(img_no_bg, False, 0.9)
98
+ mesh_name, mesh_name2 = generate(preprocessed)
99
+ return preprocessed, mesh_name, mesh_name2
100
+ else:
101
+ raise gr.Error("Image could not be fetched.")
102
+
103
+ if torch.cuda.is_available():
104
+ device = "cuda:0"
105
+ else:
106
+ device = "cpu"
107
+
108
+ d = os.environ.get("DEVICE", None)
109
+ if d != None:
110
+ device = d
111
+
112
+ model = TSR.from_pretrained(
113
+ "stabilityai/TripoSR",
114
+ config_name="config.yaml",
115
+ weight_name="model.ckpt",
116
+ )
117
+ model.renderer.set_chunk_size(131072)
118
+ model.to(device)
119
+
120
+ rembg_session = rembg.new_session()
121
 
122
  with gr.Blocks() as demo:
123
  gr.Markdown(HEADER)
124
  with gr.Row(variant="panel"):
125
  with gr.Column():
126
  with gr.Row():
127
+ make_input = gr.Textbox(label="Motorcycle Make", placeholder="Enter motorcycle make")
128
+ model_input = gr.Textbox(label="Motorcycle Model", placeholder="Enter motorcycle model")
 
 
 
 
 
129
  processed_image = gr.Image(label="Processed Image", interactive=False)
130
  with gr.Row():
131
  with gr.Group():
 
152
  label="Output Model",
153
  interactive=False,
154
  )
155
+ submit.click(fn=run_example, inputs=[make_input, model_input], outputs=[processed_image, output_model, output_model2])
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
156
 
157
  demo.queue(max_size=10)
158
  demo.launch()
159
+