|
__import__('pysqlite3') |
|
import sys |
|
sys.modules['sqlite3'] = sys.modules.pop('pysqlite3') |
|
|
|
import os |
|
import gradio as gr |
|
import chromadb |
|
from sentence_transformers import SentenceTransformer |
|
import pandas as pd |
|
import numpy as np |
|
|
|
from chromadb.utils import embedding_functions |
|
from huggingface_hub import InferenceClient |
|
|
|
# Load the patent corpus; each row's DataFrame index becomes its Chroma id.
dfs = pd.read_csv('Patents.csv')
# Chroma only accepts string documents; rows whose 'text' cell is missing
# (read by pandas as NaN) are dropped instead of crashing collection.add.
dfs = dfs.dropna(subset=['text'])

ids = [str(x) for x in dfs.index.tolist()]
docs = dfs['text'].astype(str).tolist()

# Named chroma_client (not `client`) so the LLM InferenceClient that is bound
# to `client` further down the file does not shadow the vector-store client.
chroma_client = chromadb.Client()

# Index documents with the same SentenceTransformer model used for query
# embeddings. Without an explicit embedding_function Chroma uses its default
# model (all-MiniLM-L6-v2, 384-d), whose vectors do not match the 768-d
# all-mpnet-base-v2 query embeddings.
collection = chroma_client.get_or_create_collection(
    "patents",
    embedding_function=embedding_functions.SentenceTransformerEmbeddingFunction(
        model_name='sentence-transformers/all-mpnet-base-v2'
    ),
)

# Add in chunks so a large CSV stays under Chroma's per-call batch size limit;
# an empty corpus simply skips the loop.
_ADD_BATCH_SIZE = 5000
for _start in range(0, len(ids), _ADD_BATCH_SIZE):
    _stop = _start + _ADD_BATCH_SIZE
    collection.add(documents=docs[_start:_stop], ids=ids[_start:_stop])
|
|
|
# Lazily created SentenceTransformer shared by every text_embedding() call.
_text_embedding_model = None


def text_embedding(input):
    """Encode *input* with the all-mpnet-base-v2 SentenceTransformer.

    The model is constructed on first use and cached at module level, so
    repeated queries reuse the loaded weights instead of reloading them.

    Args:
        input: Text (or list of texts) to embed; passed unchanged to
            ``SentenceTransformer.encode``.

    Returns:
        The value returned by ``SentenceTransformer.encode`` for *input*.
    """
    global _text_embedding_model
    if _text_embedding_model is None:
        _text_embedding_model = SentenceTransformer(
            'sentence-transformers/all-mpnet-base-v2'
        )
    return _text_embedding_model.encode(input)
|
|
|
def gen_context(query, n_results=15):
    """Retrieve the patent documents most similar to *query*.

    The query is passed as text (``query_texts``) so the collection's own
    embedding function embeds it. That guarantees the query vector has the
    same model and dimensionality as the indexed documents.

    Args:
        query: Free-text user question.
        n_results: Maximum number of documents to retrieve (default 15).

    Returns:
        The retrieved documents joined with newlines, or an empty string
        when *query* is empty or blank.
    """
    if not query or not query.strip():
        return ""

    results = collection.query(
        query_texts=[query],
        n_results=n_results,
        include=["documents"],
    )
    # One query was sent, so the hits are in the first (only) inner list.
    return "\n".join(str(item) for item in results['documents'][0])
|
|
|
def chat_completion(user_prompt, max_new_tokens=1000):
    """Answer *user_prompt* with the LLM, grounded in retrieved patent context.

    The documents returned by ``gen_context(user_prompt)`` are inserted into
    the prompt ahead of the question. The system prompt is wrapped in
    ``<<SYS>>`` tags inside an ``[INST]`` block, and the completion comes
    from the module-level ``client.text_generation``.

    Args:
        user_prompt: The user's question.
        max_new_tokens: Upper bound on generated tokens (default 1000).

    Returns:
        The generated answer with surrounding whitespace stripped.
    """
    system_prompt = (
        "You are a helpful AI assistant that can answer questions on the "
        "patents dataset. Answer based on the context provided. If you cannot "
        "find the correct answer, say I don't know. Be concise and just "
        "include the response"
    )

    # Retrieval step: without this the model never sees the patents data the
    # system prompt tells it to rely on.
    context = gen_context(user_prompt)

    final_prompt = (
        f"<s>[INST]<<SYS>>{system_prompt}<</SYS>>"
        f"Context:\n{context}\n\n"
        f"Question: {user_prompt}[/INST]"
    )

    return client.text_generation(
        prompt=final_prompt, max_new_tokens=max_new_tokens
    ).strip()
|
|
|
# Hugging Face inference client for the Mixtral instruct model. The global
# name `client` is what chat_completion calls at request time.
client = InferenceClient(model="mistralai/Mixtral-8x7B-Instruct-v0.1")

# Single text-in / text-out web UI wrapped around chat_completion.
query_box = gr.Textbox(label="Query", lines=2)
result_box = gr.Textbox(label="Result", lines=16)

demo = gr.Interface(
    fn=chat_completion,
    inputs=[query_box],
    outputs=[result_box],
    title="Chat on Patents Data",
)

# Turn on request queuing (queue() configures `demo` in place), then serve
# the app with share=True.
demo.queue()
demo.launch(share=True)