vibs08 commited on
Commit
1b3bb2f
·
verified ·
1 Parent(s): a3395fd

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +65 -88
app.py CHANGED
@@ -20,6 +20,8 @@ from PIL import Image
20
  from functools import partial
21
  import io
22
 
 
 
23
 
24
  subprocess.run(shlex.split('pip install wheel/torchmcubes-0.1.0-cp310-cp310-linux_x86_64.whl'))
25
 
@@ -34,6 +36,9 @@ if torch.cuda.is_available():
34
  else:
35
  device = "cpu"
36
 
 
 
 
37
  model = TSR.from_pretrained(
38
  "stabilityai/TripoSR",
39
  config_name="config.yaml",
@@ -47,17 +52,30 @@ ACCESS = os.getenv("ACCESS")
47
  SECRET = os.getenv("SECRET")
48
  bedrock = boto3.client(service_name='bedrock', aws_access_key_id = ACCESS, aws_secret_access_key = SECRET, region_name='us-east-1')
49
  bedrock_runtime = boto3.client(service_name='bedrock-runtime', aws_access_key_id = ACCESS, aws_secret_access_key = SECRET, region_name='us-east-1')
50
- # def generate_image_from_text(pos_prompt):
51
- # # bedrock_runtime = boto3.client(region_name = 'us-east-1', service_name='bedrock-runtime')
52
- # parameters = {'text_prompts': [{'text': pos_prompt , 'weight':1},
53
- # {'text': """Blurry, out of frame, out of focus, Detailed, dull, duplicate, bad quality, low resolution, cropped""", 'weight': -1}],
54
- # 'cfg_scale': 7, 'seed': 0, 'samples': 1}
55
- # request_body = json.dumps(parameters)
56
- # response = bedrock_runtime.invoke_model(body=request_body,modelId = 'stability.stable-diffusion-xl-v1')
57
- # response_body = json.loads(response.get('body').read())
58
- # base64_image_data = base64.b64decode(response_body['artifacts'][0]['base64'])
59
 
60
- # return Image.open(io.BytesIO(base64_image_data))
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
61
 
62
 
63
  def gen_pos_prompt(text):
@@ -164,88 +182,47 @@ def generate(image, mc_resolution, formats=["obj", "glb"]):
164
  return mesh_path_obj.name, mesh_path_glb.name
165
 
166
  def run_example(text_prompt,seed ,do_remove_background, foreground_ratio, mc_resolution):
167
- # Step 1: Generate the image from text prompt
168
  image_pil = generate_image_from_text(text_prompt, seed)
169
-
170
- # Step 2: Preprocess the image
171
  preprocessed = preprocess(image_pil, do_remove_background, foreground_ratio)
172
 
173
- # Step 3: Generate the 3D model
174
  mesh_name_obj, mesh_name_glb = generate(preprocessed, mc_resolution, ["obj", "glb"])
175
 
176
  return preprocessed, mesh_name_obj, mesh_name_glb
177
 
178
- with gr.Blocks() as demo:
179
- gr.Markdown(HEADER)
180
- with gr.Row(variant="panel"):
181
- with gr.Column():
182
- with gr.Row():
183
- text_prompt = gr.Textbox(
184
- label="Text Prompt",
185
- placeholder="Enter a text prompt for image generation"
186
- )
187
- input_image = gr.Image(
188
- label="Generated Image",
189
- image_mode="RGBA",
190
- sources="upload",
191
- type="pil",
192
- elem_id="content_image",
193
- visible=False # Hidden since we generate the image from text
194
- )
195
- seed = gr.Number(value=0)
196
- processed_image = gr.Image(label="Processed Image", interactive=False, visible=False)
197
- with gr.Row():
198
- with gr.Group():
199
- do_remove_background = gr.Checkbox(
200
- label="Remove Background", value=True
201
- )
202
- foreground_ratio = gr.Slider(
203
- label="Foreground Ratio",
204
- minimum=0.5,
205
- maximum=1.0,
206
- value=0.85,
207
- step=0.05,
208
- )
209
- mc_resolution = gr.Slider(
210
- label="Marching Cubes Resolution",
211
- minimum=32,
212
- maximum=320,
213
- value=256,
214
- step=32
215
- )
216
- with gr.Row():
217
- submit = gr.Button("Generate", elem_id="generate", variant="primary")
218
- with gr.Column():
219
- with gr.Tab("OBJ"):
220
- output_model_obj = gr.Model3D(
221
- label="Output Model (OBJ Format)",
222
- interactive=False,
223
- )
224
- gr.Markdown("Note: Downloaded object will be flipped in case of .obj export. Export .glb instead or manually flip it before usage.")
225
- with gr.Tab("GLB"):
226
- output_model_glb = gr.Model3D(
227
- label="Output Model (GLB Format)",
228
- interactive=False,
229
- )
230
- gr.Markdown("Note: The model shown here has a darker appearance. Download to get correct results.")
231
- # with gr.Row(variant="panel"):
232
- # gr.Examples(
233
- # examples=[
234
- # os.path.join("examples", img_name) for img_name in sorted(os.listdir("examples"))
235
- # ],
236
- # inputs=[text_prompt],
237
- # outputs=[processed_image, output_model_obj, output_model_glb],
238
- # cache_examples=True,
239
- # fn=partial(run_example, do_remove_background=True, foreground_ratio=0.85, mc_resolution=256),
240
- # label="Examples",
241
- # examples_per_page=20
242
- # )
243
- submit.click(fn=check_input_image, inputs=[text_prompt]).success(
244
- fn=run_example,
245
- inputs=[text_prompt, seed, do_remove_background, foreground_ratio, mc_resolution],
246
- outputs=[processed_image, output_model_obj, output_model_glb],
247
- # outputs=[output_model_obj, output_model_glb],
248
- )
249
-
250
- demo.queue(max_size=10)
251
- demo.launch(auth=(os.getenv('USERNAME'), os.getenv('PASSWORD')))
 
20
  from functools import partial
21
  import io
22
 
23
+ app = FastAPI()
24
+
25
 
26
  subprocess.run(shlex.split('pip install wheel/torchmcubes-0.1.0-cp310-cp310-linux_x86_64.whl'))
27
 
 
36
  else:
37
  device = "cpu"
38
 
39
+ torch.cuda.synchronize()
40
+
41
+
42
  model = TSR.from_pretrained(
43
  "stabilityai/TripoSR",
44
  config_name="config.yaml",
 
52
  SECRET = os.getenv("SECRET")
53
  bedrock = boto3.client(service_name='bedrock', aws_access_key_id = ACCESS, aws_secret_access_key = SECRET, region_name='us-east-1')
54
  bedrock_runtime = boto3.client(service_name='bedrock-runtime', aws_access_key_id = ACCESS, aws_secret_access_key = SECRET, region_name='us-east-1')
55
+
56
+ def upload_file_to_s3(file_path, bucket_name, object_name=None):
57
+ s3_client = boto3.client('s3',aws_access_key_id = ACCESS, aws_secret_access_key = SECRET, region_name='us-east-1')
 
 
 
 
 
 
58
 
59
+ if object_name is None:
60
+ object_name = file_path
61
+
62
+ try:
63
+ s3_client.upload_file(file_path, bucket_name, object_name)
64
+ except FileNotFoundError:
65
+ print(f"The file {file_path} was not found.")
66
+ return False
67
+ except NoCredentialsError:
68
+ print("Credentials not available.")
69
+ return False
70
+ except PartialCredentialsError:
71
+ print("Incomplete credentials provided.")
72
+ return False
73
+ except Exception as e:
74
+ print(f"An error occurred: {e}")
75
+ return False
76
+
77
+ print(f"File {file_path} uploaded successfully to {bucket_name}/{object_name}.")
78
+ return True
79
 
80
 
81
  def gen_pos_prompt(text):
 
182
  return mesh_path_obj.name, mesh_path_glb.name
183
 
184
  def run_example(text_prompt,seed ,do_remove_background, foreground_ratio, mc_resolution):
 
185
  image_pil = generate_image_from_text(text_prompt, seed)
186
+
 
187
  preprocessed = preprocess(image_pil, do_remove_background, foreground_ratio)
188
 
 
189
  mesh_name_obj, mesh_name_glb = generate(preprocessed, mc_resolution, ["obj", "glb"])
190
 
191
  return preprocessed, mesh_name_obj, mesh_name_glb
192
 
193
+
194
+ @app.post("/process_text/")
195
+ async def process_image(
196
+ text_prompt: str = Form(...),
197
+ seed: int = Form(...),
198
+ do_remove_background: bool = Form(...),
199
+ foreground_ratio: float = Form(...),
200
+ mc_resolution: int = Form(...),
201
+ auth: str = Form(...)
202
+ ):
203
+
204
+ if auth == os.getenv("AUTHORIZE"):
205
+
206
+ preprocessed, mesh_name_obj, mesh_name_glb = run_example(text_prompt,seed ,do_remove_background, foreground_ratio, mc_resolution)
207
+ # preprocessed = preprocess(image_pil, do_remove_background, foreground_ratio)
208
+ # mesh_name_obj, mesh_name_glb = generate(preprocessed, mc_resolution)
209
+ timestamp = datetime.datetime.now().strftime('%Y%m%d%H%M%S%f')
210
+ object_name = f'object_{timestamp}_1.obj'
211
+ object_name_2 = f'object_{timestamp}_2.glb'
212
+
213
+ if upload_file_to_s3(mesh_name_obj, 'framebucket3d',object_name) and upload_file_to_s3(mesh_name_glb, 'framebucket3d',object_name_2):
214
+
215
+ return {
216
+ "obj_path": f"https://framebucket3d.s3.amazonaws.com/{object_name}",
217
+ "glb_path": f"https://framebucket3d.s3.amazonaws.com/{object_name_2}"
218
+
219
+ }
220
+
221
+ else:
222
+ return {"Internal Server Error": False}
223
+ else:
224
+ return {"Authentication":"Failed"}
225
+
226
+ if __name__ == "__main__":
227
+ import uvicorn
228
+ uvicorn.run(app, host="0.0.0.0", port=7860)