"""Serve a Gradio frontend for FashGen, a text-to-image fashion search application."""
import argparse
import logging
import os
import pickle
import zipfile
from pathlib import Path
from typing import Any, Callable, Dict, List, Tuple

import gradio as gr
import torch
import wandb
from multilingual_clip import pt_multilingual_clip
from PIL.Image import Image
from sentence_transformers import SentenceTransformer, util
from transformers import AutoTokenizer

os.environ["CUDA_VISIBLE_DEVICES"] = ""  # hide all GPUs so inference runs on CPU

logging.basicConfig(level=logging.INFO)
DEFAULT_APPLICATION_NAME = "FashGen"

APP_DIR = Path(__file__).resolve().parent
README = APP_DIR / "README.md"

DEFAULT_PORT = 11700

EMBEDDINGS_DIR = "artifacts/img-embeddings"
EMBEDDINGS_FILE = os.path.join(EMBEDDINGS_DIR, "embeddings.pkl")
RAW_PHOTOS_DIR = "artifacts/raw-photos"

# Log in to Weights & Biases; the API key is read from the WANDB_API_KEY
# environment variable (or ~/.netrc) rather than being hard-coded here.
wandb.login()
api = wandb.Api()
artifact_embeddings = api.artifact("ryparmar/fashion-aggregator/unimoda-images:v1")
artifact_embeddings.download(EMBEDDINGS_DIR)
artifact_raw_photos = api.artifact("ryparmar/fashion-aggregator/unimoda-raw-images:v1")
artifact_raw_photos.download("artifacts")

# Unpack the raw product photos shipped inside the downloaded artifact.
with zipfile.ZipFile("artifacts/unimoda.zip", "r") as zip_ref:
    zip_ref.extractall(RAW_PHOTOS_DIR)


class TextEncoder:
    """Encodes a text query with the multilingual CLIP text tower."""

    def __init__(self, model_path="M-CLIP/XLM-Roberta-Large-Vit-B-32"):
        self.model = pt_multilingual_clip.MultilingualCLIP.from_pretrained(model_path)
        self.tokenizer = AutoTokenizer.from_pretrained(model_path)

    @torch.no_grad()
    def encode(self, query: str) -> torch.Tensor:
        """Infer the text embedding for a given query."""
        query_emb = self.model.forward([query], self.tokenizer)
        return query_emb
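
# A minimal usage sketch (assuming the M-CLIP weights download succeeds); the
# returned tensor has shape (1, embedding_dim) and lives in the same space as
# the clip-ViT-B-32 image embeddings, so the two can be compared directly:
#
#   text_encoder = TextEncoder()
#   emb = text_encoder.encode("red floral summer dress")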


class ImageEncoder:
    """Encodes the given image."""

    def __init__(self, model_path="clip-ViT-B-32"):
        self.model = SentenceTransformer(model_path)

    @torch.no_grad()
    def encode(self, image: Image) -> torch.Tensor:
        """Infer the image embedding for a given image."""
        image_emb = self.model.encode([image], convert_to_tensor=True, show_progress_bar=False)
        return image_emb


class Retriever:
    """Retrieves the images most relevant to a given text query."""

    def __init__(self, image_embeddings_path=None):
        self.text_encoder = TextEncoder()
        self.image_encoder = ImageEncoder()

        # Load precomputed image embeddings and strip the training-time path
        # prefix so the names can be resolved against RAW_PHOTOS_DIR.
        with open(image_embeddings_path, "rb") as file:
            self.image_names, self.image_embeddings = pickle.load(file)
        self.image_names = [
            img_name.replace("fashion-aggregator/fashion_aggregator/data/photos/", "")
            for img_name in self.image_names
        ]
        logging.info("Loaded embeddings for %d images", len(self.image_names))

    @torch.no_grad()
    def predict(self, text_query: str, k: int = 10) -> List[Any]:
        """Return the top-k most relevant items for a given text query."""
        query_emb = self.text_encoder.encode(text_query)
        relevant_images = util.semantic_search(query_emb, self.image_embeddings, top_k=k)[0]
        return relevant_images

    @torch.no_grad()
    def search_images(self, text_query: str, k: int = 6) -> Dict[str, List[Any]]:
        """Return the paths and scores of the top-k most relevant images for a given text query."""
        images = self.predict(text_query, k)
        paths_and_scores = {"path": [], "score": []}
        for img in images:
            paths_and_scores["path"].append(os.path.join(RAW_PHOTOS_DIR, self.image_names[img["corpus_id"]]))
            paths_and_scores["score"].append(img["score"])
        return paths_and_scores
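
# A minimal usage sketch (assuming the artifacts above were downloaded):
#
#   retriever = Retriever(image_embeddings_path=EMBEDDINGS_FILE)
#   results = retriever.search_images("red floral summer dress", k=3)
#   # results == {"path": [<3 image paths>], "score": [<3 floats>]}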


def main(args):
    predictor = PredictorBackend(url=args.model_url)
    frontend = make_frontend(predictor.run, flagging=args.flagging, gantry=args.gantry, app_name=args.application)
    frontend.launch(
        server_name="0.0.0.0",  # bind all interfaces (an assumption; drop for localhost-only serving)
        server_port=args.port,  # expose the server on the port passed via --port
    )


def make_frontend(
    fn: Callable[[str], List[Any]], flagging: bool = False, gantry: bool = False, app_name: str = "fashion-aggregator"
):
    """Creates a gradio.Interface frontend for a text-to-image search function."""
    allow_flagging = "manual" if flagging else "never"

    # NOTE: the gantry flag is accepted for CLI parity, but logging feedback
    # to Gantry is not implemented in this frontend.
    frontend = gr.Interface(
        fn=fn,
        inputs=gr.components.Textbox(label="Item Description"),
        outputs=gr.Gallery(label="Relevant Items"),
        title=app_name,
        description=__doc__,
        cache_examples=False,
        allow_flagging=allow_flagging,
        flagging_options=["incorrect", "offensive", "other"],
    )
    return frontend


class PredictorBackend:
    """Interface to a backend that serves predictions.

    To communicate with a backend accessible via a URL, provide the url kwarg.
    Otherwise, runs a predictor locally.
    """

    def __init__(self, url=None):
        if url is not None:
            self.url = url
            # run() fetches results through _search_images, so the remote
            # endpoint must be wired to that attribute.
            self._search_images = self._predict_from_endpoint
        else:
            model = Retriever(image_embeddings_path=EMBEDDINGS_FILE)
            self._predict = model.predict
            self._search_images = model.search_images
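
    def _predict_from_endpoint(self, text_query: str, k: int = 6) -> Dict[str, List[Any]]:
        """Send a text query to the remote backend and return its response.

        A minimal sketch: the original code referenced this method without
        defining it, so the request/response contract here is assumed. It
        posts JSON of the form {"text": ..., "k": ...} and expects the
        endpoint to reply with the same {"path": [...], "score": [...]}
        structure that Retriever.search_images returns.
        """
        import requests  # local import: only needed when a remote backend is used

        response = requests.post(self.url, json={"text": text_query, "k": k}, timeout=30)
        response.raise_for_status()
        return response.json()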

    def run(self, text: str):
        pred, metrics = self._predict_with_metrics(text)
        self._log_inference(pred, metrics)
        return pred

    def _predict_with_metrics(self, text: str) -> Tuple[List[str], Dict[str, float]]:
        paths_and_scores = self._search_images(text)
        metrics = {"mean_score": sum(paths_and_scores["score"]) / len(paths_and_scores["score"])}
        return paths_and_scores["path"], metrics

    def _log_inference(self, pred, metrics):
        for key, value in metrics.items():
            logging.info(f"METRIC {key} {value}")
        logging.info(f"PRED >begin\n{pred}\nPRED >end")


def _make_parser():
    parser = argparse.ArgumentParser(description=__doc__)
    parser.add_argument(
        "--model_url",
        default=None,
        type=str,
        help="Identifies a URL to which to send text queries via a POST request as JSON with the key 'text'. Default is None, which instead sends the queries to a model running locally.",
    )
    parser.add_argument(
        "--port",
        default=DEFAULT_PORT,
        type=int,
        help=f"Port on which to expose this server. Default is {DEFAULT_PORT}.",
    )
    parser.add_argument(
        "--flagging",
        action="store_true",
        help="Pass this flag to allow users to 'flag' model behavior and provide feedback.",
    )
    parser.add_argument(
        "--gantry",
        action="store_true",
        help="Pass --flagging and this flag to log user feedback to Gantry. Requires GANTRY_API_KEY to be defined as an environment variable.",
    )
    parser.add_argument(
        "--application",
        default=DEFAULT_APPLICATION_NAME,
        type=str,
        help=f"Name of the Gantry application to which feedback should be logged, if --gantry and --flagging are passed. Default is {DEFAULT_APPLICATION_NAME}.",
    )
    return parser
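
# Example invocation (a sketch; assumes this module is saved as app.py):
#
#   python app.py --port 11700 --flagging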

if __name__ == "__main__":
    parser = _make_parser()
    args = parser.parse_args()
    main(args)