Spaces:
Sleeping
Sleeping
File size: 1,778 Bytes
a1551fc |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 |
import configparser
import functools

from pymilvus import Collection, DataType, FieldSchema, CollectionSchema, connections
from pymilvus import utility
from sentence_transformers import SentenceTransformer
@functools.lru_cache(maxsize=1)
def _get_embedding_model():
    """Construct the ``bert-base-uncased`` sentence encoder once and reuse it.

    Building the model loads its weights, which costs far more than encoding a
    single query, so the instance is cached for the lifetime of the process.
    """
    return SentenceTransformer(model_name_or_path="bert-base-uncased")


def _read_milvus_credentials(config_path):
    """Return ``(uri, token)`` from the ``[example]`` section of *config_path*.

    Raises:
        FileNotFoundError: if *config_path* could not be read.
        configparser.Error: if the ``example`` section or its keys are missing.
    """
    cfp = configparser.RawConfigParser()
    # read() silently skips files it cannot open and returns the list of files
    # it did read; fail loudly here instead of with a confusing NoSectionError.
    if not cfp.read(config_path):
        raise FileNotFoundError(f"Milvus config file not found: {config_path}")
    return cfp.get('example', 'uri'), cfp.get('example', 'token')


def _molecule_collection_schema(dim):
    """Build the schema expected for the ``molecule_embeddings`` collection.

    Do not change the field/collection description strings casually: pymilvus
    compares a schema passed to ``Collection`` with the existing collection's
    schema and rejects any mismatch.
    """
    molecule_cid = FieldSchema(name="molecule_cid", dtype=DataType.INT64, description="cid", is_primary=True)
    molecule_name = FieldSchema(name="molecule_name", dtype=DataType.VARCHAR, max_length=256, description="name")
    molecule_embeddings = FieldSchema(name="molecule_embedding", dtype=DataType.FLOAT_VECTOR, dim=dim)
    return CollectionSchema(fields=[molecule_cid, molecule_name, molecule_embeddings],
                            auto_id=False,
                            description="my first collection!")


def retrieve_molecule_index(molecule, topk=1, config_path='config.ini'):
    """Return the ids of the stored molecules most similar to *molecule*.

    The query is embedded with the ``bert-base-uncased`` SentenceTransformer and
    searched against the ``molecule_embeddings`` Milvus collection using
    inner-product ("IP") similarity on the ``molecule_embedding`` field.

    Args:
        molecule: Text to embed and search for.
        topk: Number of nearest neighbours to return (default 1).
        config_path: INI file holding ``uri`` and ``token`` under ``[example]``.

    Returns:
        The primary keys (``molecule_cid``) of the ``topk`` hits for the query.

    Raises:
        FileNotFoundError: if *config_path* cannot be read.
        ValueError: if the ``molecule_embeddings`` collection does not exist.
    """
    model = _get_embedding_model()
    # One query string -> a single row; Milvus takes a list of query vectors.
    search_vector = model.encode(molecule).reshape(1, -1).tolist()

    milvus_uri, token = _read_milvus_credentials(config_path)
    connections.connect("default",
                        uri=milvus_uri,
                        token=token)
    print(f"Connecting to DB: {milvus_uri}")
    try:
        collection_name = "molecule_embeddings"
        # Passing a schema to Collection() creates the collection if it is
        # missing; a read path must not silently create an empty one.
        if not utility.has_collection(collection_name, using="default"):
            raise ValueError(f"Milvus collection '{collection_name}' does not exist")
        dim = 768  # Output dimensionality of bert-base-uncased embeddings
        print(f"Opening collection: {collection_name}")
        collection = Collection(name=collection_name, schema=_molecule_collection_schema(dim))
        # Milvus only searches collections loaded into memory; a no-op if already loaded.
        collection.load()
        search_params = {"metric_type": "IP"}
        results = collection.search(search_vector, anns_field='molecule_embedding',
                                    param=search_params, limit=topk)
        print(results)
    finally:
        # Always release the connection, even when the search fails.
        connections.disconnect("default")
        print("Disconnected from Milvus server.")
    return results[0].ids