# Hugging Face Space "Chat-To-Sequence" (page header captured with this source; Space status at capture: Sleeping)
from functools import lru_cache

import gradio as gr
import pandas as pd
from datasets import Dataset
from sentence_transformers import SentenceTransformer

# NOTE(review): ParameterExtractor and DNAseq are not referenced directly here;
# presumably they are used by expressions eval'd from code_function_mapping.csv.
from parameter_extractor import ParameterExtractor
from DNAseq import DNAseq
from helper import list_at_index_0, list_at_index_1, logger
# Sentence-embedding model used to encode user queries. The FAISS index name
# below embeds the same model name, suggesting it was built with this model.
_MODEL_NAME = "all-mpnet-base-v2"
_INDEX_NAME = "all-mpnet-base-v2_embeddings"
# CSV file that every handled query is logged to (via helper.logger).
_LOG_FILENAME = "CTS_user_log.csv"
_SIMILARITY_METRIC = "k nearest neighbours"
# kNN distance thresholds (lower distance = closer match):
#   score <= lower            -> evaluate the mapped code
#   lower < score < upper     -> suggest the closest reference query
#   score >= upper            -> ask the user to rephrase
_PROXIMAL_LOWER_THRESHOLD = 1.1
_PROXIMAL_UPPER_THRESHOLD = 1.4
_THRESHOLD_EXCEEDED_MESSAGE = "Your Query Wasn't Understood. Can You Rephrase The Query"
_THRESHOLD_APPROXIMATE_MESSAGE = "Your Query Wasn't Understood Clearly. Try Using The Following Query Formats"
_EMPTY_SEQUENCE_MESSAGE = "Sequence Is Empty. Please Input A Sequence"
_EMPTY_QUERY_MESSAGE = "Query Is Empty. Please Input A Query"


def _is_blank(value):
    """Return True for None or for text that is empty after stripping whitespace."""
    return value is None or not str(value).strip()


@lru_cache(maxsize=1)
def _load_model(model_name):
    """Load the SentenceTransformer once per process instead of once per request."""
    return SentenceTransformer(model_name)


@lru_cache(maxsize=1)
def _load_reference_data():
    """Load the code->function mapping and the FAISS-indexed reference queries once.

    Returns:
        ``(code_function_mapping, ref_query_ds)``: the mapping DataFrame read from
        ``code_function_mapping.csv`` and a ``datasets.Dataset`` built from
        ``reference_query_db.json`` with the ``ref_query_db_index`` FAISS index attached.
    """
    code_function_mapping = pd.read_csv("code_function_mapping.csv")
    ref_query_df = pd.read_json("reference_query_db.json", orient="records")
    ref_query_ds = Dataset.from_pandas(ref_query_df)
    ref_query_ds.load_faiss_index(_INDEX_NAME, "ref_query_db_index")
    return code_function_mapping, ref_query_ds


def chat_to_sequence(sequence, user_query):
    """Answer a natural-language question about a DNA sequence.

    The query is embedded and matched against a FAISS index of reference
    queries. Depending on the closest match's kNN distance, the function asks
    the user to rephrase, suggests the closest reference query, or evaluates the
    Python expression mapped to that query's code in ``code_function_mapping.csv``.
    Every request that reaches the similarity search is logged to
    ``CTS_user_log.csv`` through ``helper.logger``.

    Args:
        sequence: DNA sequence text from the "Sequence" textbox.
        user_query: Natural-language question from the "Query" textbox.

    Returns:
        ``(response, code_descript_message)``: the answer text and a title-cased
        description of the executed action ("" when nothing was executed).
    """
    # Gradio textboxes typically deliver "" (not None) for blank input, and
    # gr.Warning only shows a toast without aborting, so stop here explicitly.
    problems = []
    if _is_blank(sequence):
        problems.append(_EMPTY_SEQUENCE_MESSAGE)
    if _is_blank(user_query):
        problems.append(_EMPTY_QUERY_MESSAGE)
    if problems:
        for message in problems:
            gr.Warning(message)
        return "\n".join(problems), ""

    # The expressions stored in code_function_mapping.csv are eval'd in this
    # frame. Per the original "ParameterExtractor class expected variable"
    # comments they refer to the sequence as `dna` and the question as `query`,
    # so do not rename these two locals.
    dna = sequence
    query = user_query

    model = _load_model(_MODEL_NAME)
    code_function_mapping, ref_query_ds = _load_reference_data()

    # Semantic similarity search of the user query against the reference queries.
    query_embedding = model.encode(user_query)
    index_result = ref_query_ds.get_nearest_examples(_INDEX_NAME, query_embedding, k=3)
    print(index_result)
    scores, examples = index_result

    # FAISS reports kNN distances: lower is closer and 0 is an exact match, so
    # the best candidate is the first row after an ascending sort.
    ranked = pd.DataFrame(examples)
    ranked["score"] = scores
    ranked = ranked.sort_values(by="score", ascending=True)

    best = ranked.iloc[0]
    ref_question = best["question"]
    query_code = best["code"]
    query_score = best["score"]
    print(ref_question, query_code, query_score)

    # Second- and third-closest reference queries, recorded in the log.
    ref_question_2 = ranked.iloc[1]["question"]
    query_score_2 = ranked.iloc[1]["score"]
    ref_question_3 = ranked.iloc[2]["question"]
    query_score_3 = ranked.iloc[2]["score"]

    # Look the code up once; an empty result is handled below instead of
    # raising IndexError here.
    matching_row = code_function_mapping[code_function_mapping["code"] == query_code]

    log_data = [
        user_query,
        ref_question,
        query_score,
        query_code,
        ref_question_2,
        query_score_2,
        ref_question_3,
        query_score_3,
        _SIMILARITY_METRIC,
        _MODEL_NAME,
        _PROXIMAL_LOWER_THRESHOLD,
        _PROXIMAL_UPPER_THRESHOLD,
    ]
    code_descript_message = ""

    if query_score >= _PROXIMAL_UPPER_THRESHOLD:
        response = _THRESHOLD_EXCEEDED_MESSAGE
        print(_THRESHOLD_EXCEEDED_MESSAGE)
    elif query_score > _PROXIMAL_LOWER_THRESHOLD:
        response = _THRESHOLD_APPROXIMATE_MESSAGE + "\n" + ref_question
        print(_THRESHOLD_APPROXIMATE_MESSAGE, ref_question)
    elif matching_row.empty:
        response = "Error processing query"
        # Record the failure in the log entry's "code" slot.
        log_data[3] = "No Match Error"
        print("No matching code found for the function:", query_code)
    else:
        print("Execute query")
        function = matching_row.iloc[0]["function"]
        try:
            # SECURITY: eval executes arbitrary Python read from
            # code_function_mapping.csv, so that file must be trusted and
            # read-only. NOTE(review): confirm no mapped expression evaluates
            # `dna` / `query` (user input) as code.
            response = str(eval(function))
        except Exception as err:  # UI boundary: report and log instead of crashing
            response = "Error processing query"
            print("Evaluation failed for code", query_code, "-", repr(err))
        else:
            code_descript_message = matching_row.iloc[0]["description"].title()

    logger(_LOG_FILENAME, log_data, response)
    return response, code_descript_message
# Every UI example pairs the same demo DNA sequence with a different question.
_EXAMPLE_SEQUENCE = "ggcattgaggagaccattgacaccgtcattagcaatgcactacaactgtcacaacctaaa"
_EXAMPLE_QUERIES = (
    "What is the length of the sequence",
    "How many guanines bases are there in the sequence",
    "What is the base at position 10",
    "What are the bases from position 2 to 10",
    "How many bases are there from position 2 to 10",
)

# Inputs, in the positional order of chat_to_sequence(sequence, user_query).
_input_components = [
    gr.Textbox(label="Sequence", placeholder="Input DNA Sequence..."),
    gr.Textbox(label="Query", placeholder="Input Query..."),
]
# Outputs, matching the (response, code_descript_message) tuple it returns.
_output_components = [
    gr.Textbox(label="Response"),
    gr.Textbox(label="Action Executed"),
]

_interface = gr.Interface(
    fn=chat_to_sequence,
    inputs=_input_components,
    outputs=_output_components,
    title="Chat-To-Sequence",
    description="This Demo App Allows You To Explore Your DNA Sequence Using Natural Language",
    theme=gr.themes.Soft(),
    examples=[[_EXAMPLE_SEQUENCE, question] for question in _EXAMPLE_QUERIES],
)

# Enable request queuing, then start the server with a public share link.
ChatToSequence = _interface.queue()
ChatToSequence.launch(share=True)