# Page header captured together with this file (not program code):
# camphong24032002 / Test / 87a7ec3 / raw / history blame / 2.33 kB
import pandas as pd
from sentence_transformers import SentenceTransformer, util
from flask import Flask, render_template, request, jsonify
import nltk
# Fetch the NLTK "stopwords" corpus at import time so that the corpus read
# below can find it.
nltk.download('stopwords')
from nltk.corpus import stopwords
# English stop words; text_preprocessing() filters query tokens against this.
stop = stopwords.words('english')
def text_preprocessing(text: str) -> str:
    """Lower-case *text* and drop English stop words.

    Punctuation is left untouched; only whitespace-delimited tokens that
    appear in the module-level ``stop`` collection are removed.

    Args:
        text: Raw query string.

    Returns:
        The surviving tokens joined by single spaces (``""`` when nothing
        survives).
    """
    # A set gives constant-time membership tests instead of scanning the
    # stop-word list once per token.
    stop_words = set(stop)
    # ``split()`` without an argument splits on any run of whitespace, so
    # repeated spaces, tabs and newlines never produce empty tokens and stop
    # words next to them are still recognised.
    keywords = [word for word in text.lower().split() if word not in stop_words]
    return " ".join(keywords)
def concat_content(title, value):
    """Return one display line of the form ``"<title>: <value>"``."""
    return "{}: {}".format(title, value)
def df_to_text(df):
    """Render every row of *df* as a block of ``"Title: value"`` lines.

    Args:
        df: DataFrame holding the columns ``ProductID``, ``ProductName``,
            ``ProductBrand``, ``Gender``, ``Price (INR)``, ``Description``
            and ``PrimaryColor``.

    Returns:
        One line per field for each row, with an empty line after every
        row, all joined by newlines. An empty frame yields ``""``.

    Raises:
        KeyError: If any of the expected columns is missing from *df*.
    """
    titles = ["Product ID", "Product Name", "Brand", "Gender", "Price (INR)", "Description", "Primary Color"]
    cols = ["ProductID", "ProductName", "ProductBrand", "Gender", "Price (INR)", "Description", "PrimaryColor"]
    lines = []
    # Iterating a DataFrame directly yields its column labels, not its rows,
    # so the rows are materialised as dicts and the cell value (not the
    # column name) is what gets formatted.
    for record in df[cols].to_dict("records"):
        lines.extend(
            concat_content(title, record[col]) for title, col in zip(titles, cols)
        )
        lines.append("")  # blank separator after each product
    return "\n".join(lines)
# Product catalogue. The index is reset to 0..n-1 because get_chat_response()
# selects rows positionally, using indices into the embedding matrix below.
df = pd.read_csv("data/dataset.csv").reset_index(drop=True)

# Precomputed document embeddings; the CSV has no header row.
# NOTE(review): presumably row i here corresponds to row i of ``df`` - confirm
# against whatever script produced data/embedding.csv.
embedding_df = pd.read_csv("data/embedding.csv", header=None)
docs = embedding_df.values

# No interactive ``input()`` prompt here: queries arrive through the /chat
# endpoint, and blocking on stdin at import time would stall the web server
# (or raise EOFError when no terminal is attached).
model = SentenceTransformer("bert-base-nli-mean-tokens")

app = Flask(__name__)
@app.route("/")
def index():
    """Serve the chat page rendered from the ``chat.html`` template."""
    page = render_template("chat.html")
    return page
@app.route("/chat", methods=["POST"])
def chat():
    """Answer a chat message with the best-matching products.

    Expects a JSON body of the form ``{"msg": "<query>"}``.

    Returns:
        A JSON object with ``response`` (``True`` on success, ``False`` on
        failure) and ``message`` (the formatted products, or an error text
        prefixed with ``"Error: "``).
    """
    # silent=True returns None instead of aborting the request when the body
    # is missing or is not valid JSON, so the client always gets the same
    # JSON error shape.
    data = request.get_json(silent=True)
    msg = data.get("msg") if isinstance(data, dict) else None
    if not isinstance(msg, str) or not msg.strip():
        return jsonify({
            "message": "Error: request body must be JSON with a non-empty 'msg' string",
            "response": False,
        })

    try:
        output_df = get_chat_response(msg)
        output_text = df_to_text(output_df)
    except Exception as e:  # request boundary: report the failure, don't crash
        # Log the full traceback rather than only the exception text.
        app.logger.exception("Failed to answer chat message")
        error_message = f'Error: {str(e)}'
        return jsonify({"message": error_message, "response": False})
    return jsonify({"response": True, "message": output_text})
def get_chat_response(text, top_n=3):
    """Return the catalogue rows most similar to the query *text*.

    The query is stop-word filtered, embedded with the module-level
    ``model``, and compared by cosine similarity against the precomputed
    ``docs`` embeddings.

    Args:
        text: Free-text search query.
        top_n: Maximum number of rows to return (default 3, the previously
            hard-coded value).

    Returns:
        The rows of the module-level ``df`` at the ``top_n`` highest-scoring
        positions, best match first.
    """
    query_vector = model.encode(text_preprocessing(text)).astype(float)
    results = util.pytorch_cos_sim(query_vector, docs)
    # Sort each row of scores in descending order and take the first (only)
    # query's ranking.
    ranking = results.argsort(descending=True, dim=1)[0]
    # Hand pandas plain Python ints rather than a torch tensor.
    top_positions = ranking[:top_n].tolist()
    return df.iloc[top_positions]
if __name__ == "__main__":
    import os

    # NOTE(review): debug mode turns on the reloader and Werkzeug's
    # interactive debugger, which must never be reachable in production.
    # It stays on by default for local runs, as before; set FLASK_DEBUG=0
    # (or "false"/"no") to switch it off.
    debug_enabled = os.environ.get("FLASK_DEBUG", "1").strip().lower() not in ("0", "false", "no")
    app.run(debug=debug_enabled)