Spaces:
Runtime error
Runtime error
Commit
·
665ec8e
1
Parent(s):
7d696ff
Update app.py
Browse files
app.py
CHANGED
@@ -32,16 +32,9 @@ st.cache(show_spinner=False)
|
|
32 |
def image_search(query, top_k=10):
|
33 |
with torch.no_grad():
|
34 |
text_embedding = text_encoder(**tokenizer(query, return_tensors='pt')).pooler_output
|
35 |
-
values, indices = torch.
|
36 |
return [links[i] for i in indices[:top_k]]
|
37 |
|
38 |
-
st.cache(show_spinner=False)
|
39 |
-
def image_search_(query, top_k=10):
|
40 |
-
with torch.no_grad():
|
41 |
-
text_embedding = text_encoder(**tokenizer(query, return_tensors='pt')).pooler_output.numpy()
|
42 |
-
results = np.argsort((image_embeddings@text_embedding.T)[:, 0])[-1:-top_k-1:-1]
|
43 |
-
return [links[i] for i in results]
|
44 |
-
|
45 |
|
46 |
description = '''
|
47 |
# Persian (fa) image search
|
@@ -79,7 +72,7 @@ def main():
|
|
79 |
_, c, _ = st.columns((1, 3, 1))
|
80 |
query = c.text_input('Search Box (type in fa)', value='گل صورتی')
|
81 |
if len(query) > 0:
|
82 |
-
results =
|
83 |
st.markdown(get_html(results), unsafe_allow_html=True)
|
84 |
|
85 |
|
|
|
32 |
def image_search(query, top_k=10):
|
33 |
with torch.no_grad():
|
34 |
text_embedding = text_encoder(**tokenizer(query, return_tensors='pt')).pooler_output
|
35 |
+
values, indices = torch.dot(image_embeddings, text_embedding.T).sort(descending=True)
|
36 |
return [links[i] for i in indices[:top_k]]
|
37 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
38 |
|
39 |
description = '''
|
40 |
# Persian (fa) image search
|
|
|
72 |
_, c, _ = st.columns((1, 3, 1))
|
73 |
query = c.text_input('Search Box (type in fa)', value='گل صورتی')
|
74 |
if len(query) > 0:
|
75 |
+
results = image_search(query)
|
76 |
st.markdown(get_html(results), unsafe_allow_html=True)
|
77 |
|
78 |
|