Spaces:
Runtime error
Runtime error
import pandas as pd | |
from sentence_transformers import SentenceTransformer, util | |
from flask import Flask, render_template, request, jsonify | |
from nltk.corpus import stopwords | |
import os | |
# NLTK English stop words that text_preprocessing drops from queries. Stored as
# a set so each per-token membership test is O(1) rather than a list scan.
stop = set(stopwords.words('english'))
def text_preprocessing(text, stop_words=None):
    """Normalise a free-text query before it is embedded.

    Lower-cases *text*, splits it on whitespace, drops stop words and re-joins
    the remaining tokens with single spaces. Punctuation is left in place, so a
    token such as ``"the,"`` is not treated as a stop word.

    Args:
        text: The raw query string.
        stop_words: Container of lower-case words to drop. Defaults to the
            module-level ``stop`` collection (NLTK English stop words).

    Returns:
        The cleaned query; an empty string if every token was a stop word.
    """
    if stop_words is None:
        stop_words = stop
    # str.split() with no argument splits on any run of whitespace (tabs,
    # newlines, repeated spaces), so no empty tokens are produced and words
    # adjacent to a line break are still matched against the stop list.
    return ' '.join(word for word in text.lower().split() if word not in stop_words)
def concat_content(title, value): | |
return f"{title}: {value}" | |
def df_to_text(df): | |
text = [] | |
titles = ["Product ID", "Product Name", "Brand", "Gender", "Price (INR)", "Description", "Primary Color"] | |
cols = ["ProductID", "ProductName", "ProductBrand", "Gender", "Price (INR)", "Description", "PrimaryColor"] | |
for data in df: | |
for title, col in zip(titles, cols): | |
text.append(concat_content(title, col)) | |
text.append('') | |
return '\n'.join(text) | |
# Product catalogue. get_chat_response indexes it by row position of ``docs``,
# so it must be row-aligned with data/embedding.csv.
df = pd.read_csv("data/dataset.csv").reset_index(drop=True)
# Pre-computed document embeddings: one row per catalogue entry, no header row.
embedding_df = pd.read_csv("data/embedding.csv", header=None)
docs = embedding_df.values
# os.environ is a mapping, not a callable. .get() returns None when the
# variable is unset, which is also the loader's default for this argument.
HF_TOKEN = os.environ.get("HF_TOKEN")
model = SentenceTransformer(
    "bert-base-nli-mean-tokens",
    cache_folder="/code/",
    use_auth_token=HF_TOKEN,
)
app = Flask(__name__)
# Registered at the site root: without a route decorator Flask never
# dispatches to this view and every request gets a 404.
@app.route("/")
def index():
    """Serve the chat page rendered from the ``chat.html`` template."""
    return render_template("chat.html")
def chat():
    """Handle one chat message and reply with matching products as JSON.

    Reads ``msg`` from the request's JSON body, finds the best-matching
    products with ``get_chat_response`` and formats them with ``df_to_text``.

    Returns:
        ``{"response": True, "message": <products text>}`` on success, or
        ``{"message": "Error: ...", "response": False}`` on failure. Failures
        are reported through the ``response`` flag; the HTTP status stays 200.
    """
    # NOTE(review): no @app.route decorator registers this view in the visible
    # code, so Flask never dispatches to it. Add one whose URL and method match
    # the request chat.html sends -- TODO confirm the path.
    payload = request.get_json(silent=True)
    # get_json(silent=True) yields None for a missing or malformed JSON body,
    # and a JSON array/scalar has no .get(); treat both as "no message".
    msg = payload.get("msg") if isinstance(payload, dict) else None
    if not isinstance(msg, str):
        return jsonify({"message": "Error: request JSON must contain a string 'msg' field", "response": False})
    try:
        output_df = get_chat_response(msg)
        output_text = df_to_text(output_df)
    except Exception as e:
        # Request boundary: keep the full traceback in the server log and
        # report the failure to the client instead of an HTML 500 page.
        app.logger.exception("chat request failed")
        return jsonify({"message": f"Error: {e}", "response": False})
    return jsonify({"response": True, "message": output_text})
def get_chat_response(text, top_n=3):
    """Return the catalogue rows whose embeddings best match *text*.

    The query is normalised with ``text_preprocessing``, encoded with the
    module-level sentence-transformer ``model`` and compared by cosine
    similarity against every row of ``docs``.

    Args:
        text: The user's free-text query.
        top_n: Maximum number of products to return (default 3).

    Returns:
        A slice of ``df`` with at most ``top_n`` rows, best match first.
    """
    # astype(float) casts the query to float64 so it matches ``docs``.
    # NOTE(review): docs comes from a CSV, so it is presumably float64 -- confirm.
    query_vector = model.encode(text_preprocessing(text)).astype(float)
    # pytorch_cos_sim returns a (1, len(docs)) matrix; row 0 holds one score
    # per catalogue row.
    scores = util.pytorch_cos_sim(query_vector, docs)[0]
    k = max(0, min(top_n, scores.shape[0]))
    # topk selects the best k without fully sorting every score; its indices
    # come back in descending score order.
    top_idx = scores.topk(k).indices
    # Convert to plain ints: a torch tensor is not a documented df.iloc input.
    return df.iloc[top_idx.tolist()]
if __name__ == "__main__":
    # Debug mode is off by default: the Werkzeug interactive debugger it
    # enables allows arbitrary code execution from the browser. Flask.run()
    # still honours FLASK_DEBUG=1 for local development.
    app.run()