BenkHel commited on
Commit
51b1272
·
verified ·
1 Parent(s): 6d02f81

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +1 -41
app.py CHANGED
@@ -307,44 +307,4 @@ with gr.Blocks(title="CuMo", theme=gr.themes.Default(), css=block_css) as demo:
307
  demo.queue(
308
  status_update_rate=10,
309
  api_open=False
310
- ).launch()tokenizer, model, image_processor, context_len = load_pretrained_model(
311
- model_path, model_base, model_name, load_8bit, load_4bit, device, use_flash_attn=False
312
- )
313
-
314
-
315
- PROMPT = "What material is this item and how is it disposed of?"
316
- PROMPT_WITH_IMAGE = f"{DEFAULT_IMAGE_TOKEN} {PROMPT}"
317
-
318
- @spaces.GPU
319
- def classify_image(image):
320
- if image is None:
321
- return "Please upload an image."
322
- if not isinstance(image, Image.Image):
323
- image = Image.fromarray(image)
324
- images = process_images([image], image_processor, model.config)
325
- images = [img.to(device, dtype=torch.float16) for img in images]
326
- image_args = {"images": images}
327
- input_ids = tokenizer_image_token(PROMPT_WITH_IMAGE, tokenizer, IMAGE_TOKEN_INDEX, return_tensors='pt').unsqueeze(0).to(device)
328
- with torch.no_grad():
329
- outputs = model.generate(
330
- inputs=input_ids,
331
- max_new_tokens=128,
332
- pad_token_id=tokenizer.eos_token_id,
333
- **image_args
334
- )
335
- output_text = tokenizer.decode(outputs[0], skip_special_tokens=True)
336
- answer = output_text[len(PROMPT):].strip() if output_text.startswith(PROMPT) else output_text
337
- return answer
338
-
339
- iface = gr.Interface(
340
- fn=classify_image,
341
- inputs=gr.Image(type="pil", label="Upload an image of a waste item"),
342
- outputs=gr.Textbox(label="Classification & Disposal Recommendation"),
343
- title="CuMo Waste Classifier",
344
- description="Upload a photo of a household waste item. The model will classify the material and recommend how to dispose of it."
345
- )
346
-
347
- if __name__ == "__main__":
348
- iface.launch()
349
-
350
-
 
307
  demo.queue(
308
  status_update_rate=10,
309
  api_open=False
310
+ ).launch()