# Swap in pysqlite3 as the stdlib sqlite3 module before chromadb is imported
# (works around "unsupported version of sqlite3" errors on older systems).
__import__('pysqlite3')
import sys
sys.modules['sqlite3'] = sys.modules.pop('pysqlite3')

import os

import gradio as gr
import chromadb
import pandas as pd
from sentence_transformers import SentenceTransformer
from huggingface_hub import InferenceClient

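# Read the Hugging Face API token from a local .env file (expects HF_API_KEY).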
from dotenv import load_dotenv, find_dotenv

_ = load_dotenv(find_dotenv())
hf_api_key = os.environ['HF_API_KEY']

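# Load the patents dataset and index the 'text' column in an in-memory Chroma collection.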
df = pd.read_csv('Patents.csv')
ids = [str(x) for x in df.index.tolist()]
docs = df['text'].tolist()

chroma_client = chromadb.Client()
collection = chroma_client.get_or_create_collection("patents")

# Chroma embeds the documents with its default embedding function
# (all-MiniLM-L6-v2), the same model used for query embeddings below.
collection.add(documents=docs, ids=ids)

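# Sentence-Transformers model used to embed user queries at query time.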
embedding_model = SentenceTransformer('sentence-transformers/all-MiniLM-L6-v2')


def text_embedding(text):
    return embedding_model.encode(text)

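# Retrieval: embed the query and fetch the 15 closest patent texts from Chroma.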
def gen_context(query):
    vector = text_embedding(query).tolist()

    results = collection.query(
        query_embeddings=[vector], n_results=15, include=["documents"]
    )

    # Join the retrieved documents into a single context block for the prompt.
    return "\n".join(str(item) for item in results['documents'][0])

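# LLM client: Mixtral-8x7B-Instruct served through the Hugging Face Inference API.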
client = InferenceClient(model="mistralai/Mixtral-8x7B-Instruct-v0.1", token=hf_api_key)

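# RAG chat: retrieve context for the query, build the prompt, and generate an answer.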
def chat_completion(query):
    length = 1000

    context = gen_context(query)

    user_prompt = f"""
Based on the context:

{context}

Answer the below query:

{query}
"""

    system_prompt = """\
You are a helpful AI assistant that can answer questions on the patents dataset. \
Answer based on the context provided. If you cannot find the correct answer, say "I don't know". \
Be concise and just include the response."""

    # Llama-2-style [INST]/<<SYS>> wrapper carrying the system and user prompts.
    final_prompt = f"""<s>[INST]<<SYS>>
{system_prompt}
<</SYS>>

{user_prompt}[/INST]"""

    return client.text_generation(prompt=final_prompt, max_new_tokens=length).strip()

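# Gradio UI: a query textbox wired to chat_completion, launched with a public share link.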
demo = gr.Interface(
    fn=chat_completion,
    inputs=[gr.Textbox(label="Query", lines=2)],
    outputs=[gr.Textbox(label="Result", lines=16)],
    title="Chat on Patents Data",
)

demo.queue().launch(share=True)