ShawnAI commited on
Commit
1c2a427
·
1 Parent(s): bc2a463

Create main.py

Browse files
Files changed (1) hide show
  1. main.py +188 -0
main.py ADDED
@@ -0,0 +1,188 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+
3
+ from langchain.embeddings import HuggingFaceEmbeddings, HuggingFaceInstructEmbeddings, OpenAIEmbeddings
4
+ from pymilvus import Collection, connections
5
+ import json
6
+ import os
7
+ os.environ["TOKENIZERS_PARALLELISM"] = "false"
8
+
9
+
10
+ MILVUS_COLLECTION = os.environ.get("MILVUS_COLLECTION", "LangChainCollection")
11
+ MILVUS_INDEX = os.environ.get("MILVUS_INDEX", '_default_idx_103')
12
+
13
+ MILVUS_HOST = os.environ.get("MILVUS_HOST", "")
14
+ MILVUS_PORT = "19530"
15
+
16
+ EMBEDDING_MODEL = os.environ.get("EMBEDDING_MODEL", "hkunlp/instructor-large")
17
+ EMBEDDING_LOADER = os.environ.get("EMBEDDING_LOADER", "HuggingFaceInstructEmbeddings")
18
+ EMBEDDING_LIST = ["HuggingFaceInstructEmbeddings", "HuggingFaceEmbeddings"]
19
+
20
+ # return top-k text chunks from vector store
21
+ TOP_K_DEFAULT = 15
22
+ TOP_K_MAX = 30
23
+ SCORE_DEFAULT = 0.33
24
+
25
+ BUTTON_MIN_WIDTH = 100
26
+
27
+ global g_emb
28
+ g_emb = None
29
+ global g_col
30
+ g_col = None
31
+
32
+ def init_emb(emb_name, emb_loader, db_col_textbox):
33
+
34
+ global g_emb
35
+ global g_col
36
+
37
+ g_emb = eval(emb_loader)(model_name=emb_name)
38
+
39
+ connections.connect(
40
+ host=MILVUS_HOST,
41
+ port=MILVUS_PORT
42
+ )
43
+
44
+ g_col = Collection(db_col_textbox)
45
+
46
+ g_col.load()
47
+
48
+ return (str(g_emb), str(g_col))
49
+
50
+
51
+ def get_emb():
52
+ return g_emb
53
+
54
+ def get_col():
55
+ return g_col
56
+
57
+
58
+ def remove_duplicates(documents, score_min):
59
+ seen_content = set()
60
+ unique_documents = []
61
+ for (doc, score) in documents:
62
+ if (doc.page_content not in seen_content) and (score >= score_min):
63
+ seen_content.add(doc.page_content)
64
+ unique_documents.append(doc)
65
+ return unique_documents
66
+
67
+
68
+ def get_data(query, top_k, score, db_col, db_index):
69
+ if not query:
70
+ return "Please init db in configuration"
71
+
72
+ embed_query = g_emb.embed_query(query)
73
+
74
+ search_params = {"metric_type": "L2",
75
+ "params": {"nprobe": 2},
76
+ "offset": 5}
77
+
78
+
79
+ results = g_col.search(
80
+ data=[embed_query],
81
+ anns_field="vector",
82
+ param=search_params,
83
+ limit=10,
84
+ expr=None,
85
+ output_fields=['source', 'text'],
86
+ consistency_level="Strong"
87
+ )
88
+
89
+ jsons = json.dumps([{'source': hit.entity.get('source'),
90
+ 'text': hit.entity.get('text')}
91
+ for hit in results[0]],
92
+ indent=0)
93
+
94
+ return jsons
95
+
96
+ with gr.Blocks(
97
+ title = "3GPP Database",
98
+ theme = "Base",
99
+ css = """.bigbox {
100
+ min-height:250px;
101
+ }
102
+ """) as demo:
103
+ with gr.Tab("Matching"):
104
+ with gr.Accordion("Vector similarity"):
105
+ with gr.Row():
106
+ with gr.Column():
107
+ top_k = gr.Slider(1,
108
+ TOP_K_MAX,
109
+ value=TOP_K_DEFAULT,
110
+ step=1,
111
+ label="Vector similarity top_k",
112
+ interactive=True)
113
+ with gr.Column():
114
+ score = gr.Slider(0.01,
115
+ 0.99,
116
+ value=SCORE_DEFAULT,
117
+ step=0.01,
118
+ label="Vector similarity score",
119
+ interactive=True)
120
+
121
+ with gr.Row():
122
+ with gr.Column(scale=10):
123
+ input_box = gr.Textbox(label = "Input", placeholder="What are you looking for?")
124
+ with gr.Column(scale=1, min_width=BUTTON_MIN_WIDTH):
125
+ btn_run = gr.Button("Run", variant="primary")
126
+
127
+ output_box = gr.JSON(label = "Output")
128
+
129
+
130
+ with gr.Tab("Configuration"):
131
+ with gr.Row():
132
+ btn_init = gr.Button("Init")
133
+
134
+ load_emb = gr.Textbox(get_emb, label = 'Embedding Client', show_label=True)
135
+ load_col = gr.Textbox(get_col, label = 'Milvus Collection', show_label=True)
136
+
137
+ with gr.Accordion("Embedding"):
138
+
139
+ with gr.Row():
140
+ with gr.Column():
141
+ emb_textbox = gr.Textbox(
142
+ label = "Embedding Model",
143
+ # show_label = False,
144
+ value = EMBEDDING_MODEL,
145
+ placeholder = "Paste Your Embedding Model Repo on HuggingFace",
146
+ lines=1,
147
+ interactive=True,
148
+ type='email')
149
+
150
+ with gr.Column():
151
+ emb_dropdown = gr.Dropdown(
152
+ EMBEDDING_LIST,
153
+ value=EMBEDDING_LOADER,
154
+ multiselect=False,
155
+ interactive=True,
156
+ label="Embedding Loader")
157
+
158
+ with gr.Accordion("Milvus Database"):
159
+ with gr.Row():
160
+ db_col_textbox = gr.Textbox(
161
+ label = "Milvus Collection",
162
+ # show_label = False,
163
+ value = MILVUS_COLLECTION,
164
+ placeholder = "Paste Your Milvus Collection (xx-xx-xx) and Hit ENTER",
165
+ lines=1,
166
+ interactive=True,
167
+ type='email')
168
+ db_index_textbox = gr.Textbox(
169
+ label = "Milvus Index",
170
+ # show_label = False,
171
+ value = MILVUS_INDEX,
172
+ placeholder = "Paste Your Milvus Index (xxxx) and Hit ENTER",
173
+ lines=1,
174
+ interactive=True,
175
+ type='email')
176
+
177
+ btn_init.click(fn=init_emb,
178
+ inputs=[emb_textbox, emb_dropdown, db_col_textbox],
179
+ outputs=[load_emb, load_col])
180
+ btn_run.click(fn=get_data,
181
+ inputs=[input_box, top_k, score, db_col_textbox, db_index_textbox],
182
+ outputs=[output_box])
183
+
184
+ if __name__ == "__main__":
185
+ demo.queue()
186
+ demo.launch(inbrowser = True,
187
+ )
188
+