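# NOTE: web-page residue from the Hugging Face Spaces file viewer
# (original text: "Spaces: Sleeping Sleeping"); not part of the program.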
from sentence_transformers import SentenceTransformer | |
from huggingface_hub import CommitScheduler | |
from datasets import Dataset | |
import gradio as gr | |
import pandas as pd | |
import plotly.graph_objects as go | |
import os | |
from utility import load_from_hub_csv | |
from DNAseq import DNAseq | |
from grapher import DNAgrapher | |
from parameter_extractor import ParameterExtractor | |
from helper import list_at_index_0, list_at_index_1 | |
from logger import cts_log_file_create, logger, cts_logger | |
# Hub credentials and the target repository are supplied through the
# environment; both are None when the variable is not set.
HF_TOKEN = os.getenv("HF_TOKEN")
repo_id = os.getenv("repo_id")

# CSV file that receives one log entry per handled query.
log_file_path = cts_log_file_create("flagged")

# Background scheduler that pushes the folder containing the log file to the
# "data" directory of the (private) dataset repository.
# `every` is expressed in minutes by huggingface_hub, so 2880 = every 48 hours.
scheduler = CommitScheduler(
    repo_id=repo_id,
    repo_type="dataset",
    folder_path=log_file_path.parent,
    path_in_repo="data",
    every=2880,
    private=True,
    token=HF_TOKEN,
)

# Fetch the code -> function mapping from the Hub and store it locally as
# "code_function_mapping.csv", where chat_to_sequence reads it from.
load_from_hub_csv(
    path=repo_id,
    data_file="app/code_function_mapping.csv",
    token=HF_TOKEN,
    csv_output_file="code_function_mapping.csv",
)
# SentenceTransformer instances keyed by model name, so the model is loaded
# once per process instead of once per request.
_MODEL_CACHE = {}


def _get_model(model_name):
    """Return the SentenceTransformer for *model_name*, loading it on first use."""
    if model_name not in _MODEL_CACHE:
        _MODEL_CACHE[model_name] = SentenceTransformer(model_name)
    return _MODEL_CACHE[model_name]


def _cell_or_none(frame, position, column):
    """Return ``frame.iloc[position][column]``, or None if that row is absent."""
    if position >= len(frame):
        return None
    return frame.iloc[position][column]


def chat_to_sequence(sequence, user_query):
    """Answer a natural-language query about a DNA sequence.

    The query is embedded and matched against the reference queries in
    ``reference_query_db.json`` through a FAISS index. The score of the
    closest reference query (lower is closer, 0 is an exact match) decides
    the outcome:

    * score >= upper threshold: the query is rejected;
    * lower threshold < score < upper threshold: the closest reference
      question is suggested as a query format;
    * otherwise: the expression mapped to the matched code in
      ``code_function_mapping.csv`` is evaluated.

    Every outcome is recorded through ``cts_logger``.

    Args:
        sequence: DNA sequence text entered by the user.
        user_query: Natural-language question about the sequence.

    Returns:
        A ``(response, fig, code_descript_message)`` tuple: the text
        response (None when a figure is produced), the plotly figure (None
        for text responses) and the title-cased description of the executed
        action ('' when nothing was executed).
    """
    # The expressions stored in the "function" column are evaluated with
    # eval() in this scope. Per the original comments, `dna` is the name the
    # DNAseq expressions expect and `query` the one ParameterExtractor
    # expressions expect, so both must stay defined as locals.
    dna = sequence  # noqa: F841 - consumed by eval()
    query = user_query  # noqa: F841 - consumed by eval()

    model_name = "all-mpnet-base-v2"
    model = _get_model(model_name)

    fig = None
    response = None
    code_descript_message = ''

    # Similarity thresholds applied to the best match's score: at or below
    # the lower bound the mapped code is executed, between the bounds a query
    # format is suggested, at or above the upper bound the query is rejected.
    proximal_lower_threshold = 1.1
    proximal_upper_threshold = 1.4
    threshold_exceeded_message = "Your Query Wasn't Understood. Can You Rephrase The Query"
    threshold_approximate_message = "Your Query Wasn't Understood Clearly. Try Using The Following Query Formats"

    # Code -> function/description mapping and the reference query database.
    code_function_mapping = pd.read_csv("code_function_mapping.csv")
    ref_query_df = pd.read_json('reference_query_db.json', orient='records')
    ref_query_ds = Dataset.from_pandas(ref_query_df)
    ref_query_ds.load_faiss_index('all-mpnet-base-v2_embeddings', 'ref_query_db_index')

    # Nearest reference queries for the embedded user query.
    query_embedding = model.encode(user_query)
    scores, examples = ref_query_ds.get_nearest_examples(
        "all-mpnet-base-v2_embeddings", query_embedding, k=3
    )

    # Rank the candidates by score; the first row is the closest match.
    result_df = pd.DataFrame(examples)
    result_df['score'] = scores
    sorted_df = result_df.sort_values(by='score', ascending=True)

    ref_question = sorted_df.iloc[0]['question']
    query_code = sorted_df.iloc[0]['code']
    query_score = sorted_df.iloc[0]['score']

    # Runner-up matches, logged for later analysis. The third entry must come
    # from row 2 (it previously repeated row 1). None is logged when fewer
    # candidates were returned.
    similarity_metric = "k nearest neighbours"
    ref_question_2 = _cell_or_none(sorted_df, 1, 'question')
    query_score_2 = _cell_or_none(sorted_df, 1, 'score')
    ref_question_3 = _cell_or_none(sorted_df, 2, 'question')
    query_score_3 = _cell_or_none(sorted_df, 2, 'score')

    def build_log_data(logged_code):
        """Assemble the cts_logger row, recording *logged_code* as the code."""
        return [
            user_query,
            ref_question,
            query_score,
            logged_code,
            ref_question_2,
            query_score_2,
            ref_question_3,
            query_score_3,
            similarity_metric,
            model_name,
            proximal_lower_threshold,
            proximal_upper_threshold,
        ]

    # Too far from every reference query: reject.
    if query_score >= proximal_upper_threshold:
        response = threshold_exceeded_message
        cts_logger(scheduler, log_file_path, build_log_data(query_code), response)
        print(threshold_exceeded_message)
        return response, fig, code_descript_message

    # Close but not close enough: suggest the nearest reference question.
    if proximal_lower_threshold < query_score < proximal_upper_threshold:
        response = threshold_approximate_message + "\n" + ref_question
        cts_logger(scheduler, log_file_path, build_log_data(query_code), response)
        print(threshold_approximate_message, ref_question)
        return response, fig, code_descript_message

    print("Execute query")
    matching_row = code_function_mapping[code_function_mapping["code"] == query_code]

    # No mapping for the matched code: report the error and log it under the
    # "No Match Error" label instead of failing on an empty lookup.
    if matching_row.empty:
        response = "Error processing query"
        cts_logger(scheduler, log_file_path, build_log_data("No Match Error"), response)
        print("No matching code found for the function:", query_code)
        return response, fig, code_descript_message

    function = matching_row.iloc[0]["function"]
    query_code_description = matching_row["description"].values[0]

    # SECURITY NOTE(review): eval() executes whatever expression the mapping
    # CSV contains; anyone able to modify that file in the Hub repository can
    # run arbitrary code in this process. Consider replacing it with an
    # explicit dispatch table of callables.
    f_response = eval(function)

    # Codes starting with 'c' yield figure data; all others yield a value
    # that is shown as text.
    if query_code.startswith('c'):
        response = None
        fig = go.Figure(f_response)
    else:
        response = str(f_response)
        fig = None
    code_descript_message = query_code_description.title()
    cts_logger(scheduler, log_file_path, build_log_data(query_code), response)
    return response, fig, code_descript_message
# Gradio UI: two text inputs (sequence, query) mapped onto chat_to_sequence,
# whose three return values feed the response text, the plot and the
# description of the executed action. Built-in flagging is disabled.
ChatToSequence = gr.Interface(
    fn=chat_to_sequence,
    inputs=[
        gr.Textbox(label="Sequence", placeholder="Input DNA Sequence..."),
        gr.Textbox(label="Query", placeholder="Input Query..."),
    ],
    outputs=[
        gr.Textbox(label="Response"),
        gr.Plot(label='Graphic Response'),
        gr.Textbox(label="Action Executed"),
    ],
    allow_flagging="never",
    title="Chat-To-Sequence",
    # The closing tags are now nested in the reverse order of the opening
    # tags, and the disclaimer ends with closing tags instead of a second
    # pair of opening tags.
    description=(
        "<h2><center><span style='color: purple;'>This Demo App Allows You To Explore Your DNA Sequence Using Natural Language</span></center></h2>"
        "<h5><center>Disclaimer: The app stores the user queries but doesn't store the DNA sequence."
        " Please Don't Input Any Information You Don't Wish To Share Into The Query Box.</center></h5>"
    ),
    theme=gr.themes.Soft(),
    # Each example is a [sequence, query] pair matching the two inputs.
    examples=[
        ["ggcattgaggagaccattgacaccgtcattagcaatgcactacaactgtcacaacctaaaa",
         "What is the length of the sequence"],
        ["ggcattgaggagaccattgacaccgtcattagcaatgcactacaactgtcacaacctaaaa",
         "How many guanines bases are there in the sequence"],
        ["ggcattgaggagaccattgacaccgtcattagcaatgcactacaactgtcacaacctaaaa",
         "What is the base at position 10"],
        ["ggcattgaggagaccattgacaccgtcattagcaatgcactacaactgtcacaacctaaaa",
         "What are the bases from position 2 to 10"],
        ["ggcattgaggagaccattgacaccgtcattagcaatgcactacaactgtcacaacctaaaa",
         "How many bases are there from position 2 to 10"],
        ["ggcattgaggagaccattgacaccgtcattagcaatgcactacaactgtcacaacctaaaaaa",
         "Show pie chart of total bases"],
    ],
).queue()

# Start the web server for the interface.
ChatToSequence.launch()