Maverick98 commited on
Commit
02011a6
·
verified ·
1 Parent(s): d478ff8
Files changed (1) hide show
  1. app.py +13 -14
app.py CHANGED
@@ -74,23 +74,22 @@ model.load_state_dict(new_state_dict)
74
  # Load the tokenizer from Jina
75
  tokenizer = AutoTokenizer.from_pretrained("jinaai/jina-embeddings-v2-base-en")
76
 
77
- def load_image(image):
78
- """
79
- Preprocess the uploaded image.
80
- """
 
 
 
81
  image = transform(image)
82
  image = image.unsqueeze(0) # Add batch dimension
83
  return image
84
 
85
- def predict(image, title, threshold=0.7):
86
- """
87
- Predict the top 3 categories for the given image and title.
88
- Includes "Others" if the confidence of the top prediction is below the threshold.
89
- """
90
  # Preprocess the image
91
- image = load_image(image)
92
 
93
- # Tokenize the title
94
  title_encoding = tokenizer(title, padding='max_length', max_length=200, truncation=True, return_tensors='pt')
95
  input_ids = title_encoding['input_ids']
96
  attention_mask = title_encoding['attention_mask']
@@ -118,9 +117,9 @@ def predict(image, title, threshold=0.7):
118
  return results
119
 
120
  # Define the Gradio interface
121
- title_input = gr.inputs.Textbox(label="Product Title", placeholder="Enter the product title here...")
122
- image_input = gr.inputs.Image(type="pil", label="Upload Image")
123
- output = gr.outputs.JSON(label="Top 3 Predictions with Probabilities")
124
 
125
  gr.Interface(
126
  fn=predict,
 
74
  # Load the tokenizer from Jina
75
  tokenizer = AutoTokenizer.from_pretrained("jinaai/jina-embeddings-v2-base-en")
76
 
77
+ def load_image(image_path_or_url):
78
+ if isinstance(image_path_or_url, str) and image_path_or_url.startswith("http"):
79
+ with urllib.request.urlopen(image_path_or_url) as url:
80
+ image = Image.open(url).convert('RGB')
81
+ else:
82
+ image = Image.open(image_path_or_url).convert('RGB')
83
+
84
  image = transform(image)
85
  image = image.unsqueeze(0) # Add batch dimension
86
  return image
87
 
88
+ def predict(image_path_or_file, title, threshold=0.7):
 
 
 
 
89
  # Preprocess the image
90
+ image = load_image(image_path_or_file)
91
 
92
+ # Tokenize title
93
  title_encoding = tokenizer(title, padding='max_length', max_length=200, truncation=True, return_tensors='pt')
94
  input_ids = title_encoding['input_ids']
95
  attention_mask = title_encoding['attention_mask']
 
117
  return results
118
 
119
  # Define the Gradio interface
120
+ title_input = gr.Textbox(label="Product Title", placeholder="Enter the product title here...")
121
+ image_input = gr.Image(type="filepath", label="Upload Image or Provide URL")
122
+ output = gr.JSON(label="Top 3 Predictions with Probabilities")
123
 
124
  gr.Interface(
125
  fn=predict,