# Header captured from the hosting page's file view (not code):
# Kevin Louis — "Update app.py" — ae93270 — 7.66 kB (raw | history | blame)
import functools
import os

from sentence_transformers import SentenceTransformer
from huggingface_hub import CommitScheduler
from datasets import Dataset
import gradio as gr
import pandas as pd
import plotly.graph_objects as go

from utility import load_from_hub_csv
from DNAseq import DNAseq
from grapher import DNAgrapher
from parameter_extractor import ParameterExtractor
from helper import list_at_index_0, list_at_index_1
from logger import cts_log_file_create, logger, cts_logger
# Hub credentials and the target repo are taken from the environment;
# each is None when the variable is not set.
HF_TOKEN = os.getenv("HF_TOKEN")
repo_id = os.getenv("repo_id")

# Create the CSV file used for query logging (inside the "flagged" folder).
log_file_path = cts_log_file_create("flagged")

# Periodically commit the log file's folder to the private dataset repo,
# under its "data/" path.
scheduler = CommitScheduler(
    repo_id=repo_id,
    repo_type="dataset",
    folder_path=log_file_path.parent,
    path_in_repo="data",
    every=1440,  # minutes between commits per huggingface_hub docs (i.e. daily)
    private=True,
    token=HF_TOKEN,
)

# Fetch the code -> function mapping from the repo into a local CSV that
# chat_to_sequence reads as "code_function_mapping.csv".
load_from_hub_csv(
    path=repo_id,
    data_file="app/code_function_mapping.csv",
    token=HF_TOKEN,
    csv_output_file="code_function_mapping.csv",
)
@functools.lru_cache(maxsize=1)
def _load_model(model_name):
    """Return the SentenceTransformer for ``model_name``, loading it only once."""
    return SentenceTransformer(model_name)


@functools.lru_cache(maxsize=1)
def _load_code_function_mapping():
    """Return the code -> function/description table read from the local CSV."""
    return pd.read_csv("code_function_mapping.csv")


@functools.lru_cache(maxsize=1)
def _load_reference_queries():
    """Return the reference-query Dataset with its saved FAISS index attached."""
    ref_query_df = pd.read_json('reference_query_db.json', orient='records')
    ref_query_ds = Dataset.from_pandas(ref_query_df)
    ref_query_ds.load_faiss_index('all-mpnet-base-v2_embeddings', 'ref_query_db_index')
    return ref_query_ds


def chat_to_sequence(sequence, user_query):
    """Answer a natural-language question about a DNA sequence.

    The query is embedded and matched against a FAISS index of reference
    questions. Depending on the best match's score, this either asks the user
    to rephrase, suggests the closest reference question, or runs the Python
    expression that the mapping CSV associates with the matched question code.
    Every request is logged via ``cts_logger``.

    Args:
        sequence: DNA sequence text entered by the user.
        user_query: The user's question about the sequence.

    Returns:
        A ``(response, fig, code_descript_message)`` tuple:
        - ``response``: text answer or message, or None when a chart is returned.
        - ``fig``: a plotly Figure for chart codes, else None.
        - ``code_descript_message``: title-cased description of the executed
          action, or ``''`` when nothing was executed.
    """
    # The mapping's function strings are eval()'d in this scope and presumably
    # refer to these names (DNAseq / ParameterExtractor inputs). Keep them as locals.
    dna = sequence
    query = user_query

    model_name = "all-mpnet-base-v2"
    fig = None
    response = None
    code_descript_message = ''

    # FAISS kNN scores: lower means closer (0 is an exact match).
    # score >= upper            -> query not understood
    # lower < score < upper     -> suggest the closest reference question
    # score <= lower            -> execute the mapped code
    proximal_lower_threshold = 1.1
    proximal_upper_threshold = 1.4
    threshold_exceeded_message = "Your Query Wasn't Understood. Can You Rephrase The Query"
    threshold_approximate_message = "Your Query Wasn't Understood Clearly. Try Using The Following Query Formats"

    # Loaded once per process and reused across requests.
    model = _load_model(model_name)
    code_function_mapping = _load_code_function_mapping()
    ref_query_ds = _load_reference_queries()

    # Semantic search of the user query against the reference questions.
    query_embedding = model.encode(user_query)
    scores, examples = ref_query_ds.get_nearest_examples(
        "all-mpnet-base-v2_embeddings", query_embedding, k=3
    )
    result_df = pd.DataFrame(examples)
    result_df['score'] = scores
    sorted_df = result_df.sort_values(by='score', ascending=True)

    # Best match plus the two runners-up (for logging).
    ref_question = sorted_df.iloc[0]['question']
    query_code = sorted_df.iloc[0]['code']
    query_score = sorted_df.iloc[0]['score']
    ref_question_2 = sorted_df.iloc[1]['question']
    query_score_2 = sorted_df.iloc[1]['score']
    ref_question_3 = sorted_df.iloc[2]['question']
    query_score_3 = sorted_df.iloc[2]['score']

    similarity_metric = "k nearest neighbours"
    log_data = [
        user_query,
        ref_question,
        query_score,
        query_code,  # index 3: overwritten on a mapping miss below
        ref_question_2,
        query_score_2,
        ref_question_3,
        query_score_3,
        similarity_metric,
        model_name,
        proximal_lower_threshold,
        proximal_upper_threshold,
    ]

    if query_score >= proximal_upper_threshold:
        response = threshold_exceeded_message
        print(threshold_exceeded_message)
    elif proximal_lower_threshold < query_score < proximal_upper_threshold:
        response = threshold_approximate_message + "\n" + ref_question
        print(threshold_approximate_message, ref_question)
    else:
        print("Execute query")
        code = query_code
        # Look the code up only here, so an unmapped code cannot crash the
        # "not understood" / "suggestion" paths above.
        matching_row = code_function_mapping[code_function_mapping["code"] == code]
        if not matching_row.empty:
            function = matching_row.iloc[0]["function"]
            # NOTE(review): eval() runs a string from the mapping CSV pulled from
            # the Hub repo; user text only reaches it as data (dna/query).
            # That CSV must stay trusted.
            f_response = eval(function)
            if str(code).startswith('c'):
                # Codes starting with 'c' produce chart data for a plotly Figure.
                fig = go.Figure(f_response)
            else:
                response = str(f_response)
            code_descript_message = matching_row.iloc[0]['description'].title()
        else:
            response = "Error processing query"
            # Record the miss in the log's query-code column.
            log_data[3] = "No Match Error"
            print("No matching code found for the function:", code)

    cts_logger(scheduler, log_file_path, log_data, response)
    return response, fig, code_descript_message
# Example (sequence, query) pairs shown under the interface.
_DEMO_SEQUENCE = "ggcattgaggagaccattgacaccgtcattagcaatgcactacaactgtcacaacctaaa"
_demo_examples = [
    [_DEMO_SEQUENCE, prompt]
    for prompt in (
        "What is the length of the sequence",
        "How many guanines bases are there in the sequence",
        "What is the base at position 10",
        "What are the bases from position 2 to 10",
        "How many bases are there from position 2 to 10",
    )
]
# The chart example uses its own, slightly longer sequence.
_demo_examples.append(
    ["ggcattgaggagaccattgacaccgtcattagcaatgcactacaactgtcacaacctaaaaa",
     "Show pie chart of total bases"]
)

# Inputs map to chat_to_sequence(sequence, user_query); outputs map to its
# (response, fig, code_descript_message) return tuple.
_sequence_box = gr.Textbox(label="Sequence", placeholder="Input DNA Sequence...")
_query_box = gr.Textbox(label="Query", placeholder="Input Query...")
_response_box = gr.Textbox(label="Response")
_graphic_plot = gr.Plot(label='Graphic Response')
_action_box = gr.Textbox(label="Action Executed")

ChatToSequence = gr.Interface(
    fn=chat_to_sequence,
    inputs=[_sequence_box, _query_box],
    outputs=[_response_box, _graphic_plot, _action_box],
    allow_flagging="never",
    title="Chat-To-Sequence",
    description="This Demo App Allows You To Explore Your DNA Sequence Using Natural Language",
    theme=gr.themes.Soft(),
    examples=_demo_examples,
).queue()
ChatToSequence.launch()