shorndrup committed on
Commit
66ed76d
·
1 Parent(s): fa75e04

Update to API script

Browse files
Files changed (2) hide show
  1. app.py +22 -27
  2. test.py +34 -0
app.py CHANGED
@@ -1,37 +1,32 @@
 
1
  from transformers import OwlViTProcessor, OwlViTForObjectDetection
2
  from PIL import Image
3
- import torch
4
- import gradio as gr
5
 
6
- # Load model and processor
7
  processor = OwlViTProcessor.from_pretrained("google/owlvit-base-patch32")
8
  model = OwlViTForObjectDetection.from_pretrained("google/owlvit-base-patch32")
9
 
 
10
  def predict(image):
11
- # Prepare image
12
- image = Image.open(image).convert("RGB")
13
-
14
- # Define inputs (zero-shot queries)
15
- text_queries = ["A Pokémon", "Pikachu", "Bulbasaur"]
16
-
17
- # Run the model
18
  inputs = processor(text=text_queries, images=image, return_tensors="pt")
19
- with torch.no_grad():
20
- outputs = model(**inputs)
21
-
22
- # Get predictions
23
- target_sizes = torch.tensor([image.size[::-1]])
24
- results = processor.post_process_object_detection(outputs, target_sizes=target_sizes, threshold=0.1)
25
-
26
- # Extract boxes
27
- boxes = []
28
- for score, label, box in zip(results[0]["scores"], results[0]["labels"], results[0]["boxes"]):
29
- box = [round(i, 2) for i in box.tolist()]
30
- label_text = processor.tokenizer.decode([label])
31
- boxes.append({"score": round(score.item(), 3), "label": label_text, "box": box})
32
 
33
- return boxes
 
 
 
 
 
 
 
 
 
 
34
 
35
- # Create Gradio interface
36
- interface = gr.Interface(fn=predict, inputs="image", outputs="json")
37
- interface.launch()
 
1
+ import gradio as gr
2
  from transformers import OwlViTProcessor, OwlViTForObjectDetection
3
  from PIL import Image
 
 
4
 
5
# Load the OWL-ViT model and processor once at module import time
# (google/owlvit-base-patch32: zero-shot, text-conditioned object detector;
#  weights are downloaded from the Hugging Face Hub on first run).
processor = OwlViTProcessor.from_pretrained("google/owlvit-base-patch32")
model = OwlViTForObjectDetection.from_pretrained("google/owlvit-base-patch32")
8
 
9
# Define the prediction function
def predict(image):
    """Run zero-shot OWL-ViT object detection on *image* and return JSON results.

    Args:
        image: a PIL image, or anything ``PIL.Image.open`` accepts (path /
            file-like). NOTE(review): with ``inputs="image"`` Gradio may pass a
            numpy array — confirm the component type matches this expectation.

    Returns:
        dict with a status ``message`` and a ``detections`` list of
        ``{"score", "label", "box"}`` entries, where ``box`` is
        ``[x0, y0, x1, y1]`` in pixel coordinates.
    """
    # Accept either an already-decoded PIL image or a path/file-like object.
    if not isinstance(image, Image.Image):
        image = Image.open(image)
    image = image.convert("RGB")  # model expects 3-channel RGB input
    text_queries = ["A photo of a pokemon", "a photo of a human face", "a photo of a couch"]  # Example queries

    # Prepare inputs for the model
    inputs = processor(text=text_queries, images=image, return_tensors="pt")
    # Perform inference
    outputs = model(**inputs)

    # Post-process raw logits into thresholded, pixel-space boxes.
    # target_sizes entries are (height, width); image.size is (width, height),
    # hence the [::-1] reversal.
    results = processor.post_process_object_detection(
        outputs, target_sizes=[image.size[::-1]], threshold=0.1
    )

    # Format the response. Previously a hard-coded dummy dict was returned and
    # the model outputs were discarded; the actual detections are now included
    # (the original "message" key is kept for backward compatibility).
    detections = []
    for score, label, box in zip(
        results[0]["scores"], results[0]["labels"], results[0]["boxes"]
    ):
        detections.append({
            "score": round(score.item(), 3),
            "label": text_queries[int(label)],
            "box": [round(coord, 2) for coord in box.tolist()],
        })
    return {"message": "Detection successful!", "detections": detections}
22
+
23
# Create a Gradio interface and enable the queue (API mode)
interface = gr.Interface(
    fn=predict,
    inputs="image",
    outputs="json",
    # NOTE(review): allow_flagging was renamed to flagging_mode in Gradio 5 —
    # confirm against the installed Gradio version.
    allow_flagging="never",
)

# Launch with the request queue enabled. The `launch(enable_queue=True)`
# argument was deprecated and later removed from Gradio; calling `.queue()`
# before `.launch()` is the supported equivalent.
interface.queue()
interface.launch()
 
test.py ADDED
@@ -0,0 +1,34 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import requests
2
+ import json
3
+
4
# Replace with your actual Space URL.
# Bug fix: the scheme was "https:/" (single slash), which requests rejects —
# it must be "https://".
API_URL = "https://shorndrup-owlvit_api.hf.space/predict"
IMAGE_PATH = r"C:/Users/Administrator/Downloads/pokedex/images/solrock.png"
7
+
8
def call_gradio_api(image_path):
    """POST the image at *image_path* to the Space endpoint and print the result.

    Args:
        image_path: filesystem path of the PNG image to upload.

    Returns:
        The parsed predictions on success, or ``None`` on any HTTP or
        parsing error (errors are printed, not raised).
    """
    # Open the image file in binary mode and send it as multipart form data.
    with open(image_path, "rb") as image_file:
        files = {"data": (image_path, image_file, "image/png")}
        # A timeout prevents the script from hanging forever if the Space
        # is unreachable (requests has no default timeout).
        response = requests.post(API_URL, files=files, timeout=60)

    # Check for HTTP-level errors before trying to parse the body.
    if response.status_code != 200:
        print(f"Error: {response.status_code}")
        print(response.text)
        return None

    # Parse the JSON response.
    result = response.json()
    try:
        # Gradio wraps its outputs in a "data" list; the first entry is ours.
        predictions = result.get("data", [])[0]
    except (IndexError, KeyError) as e:
        print(f"Error parsing response: {e}")
        print(response.text)
        return None

    if predictions:
        print("Predictions:", json.dumps(predictions, indent=2))
    else:
        print("No predictions found.")
    # Return the predictions so callers can use them programmatically
    # (previously the function only printed and implicitly returned None).
    return predictions
32
+
33
# Entry point: run the API call only when executed as a script,
# not when imported as a module.
if __name__ == "__main__":
    call_gradio_api(IMAGE_PATH)