"""Chat-To-Sequence: a Gradio demo for exploring a DNA sequence in natural language.

A user query is embedded with a SentenceTransformer model and matched against a
FAISS index of reference queries. Depending on how close the nearest reference
query is (kNN distance), the app either asks the user to rephrase, suggests the
nearest reference query format, or evaluates the code expression mapped to that
reference query in ``code_function_mapping.csv``. Every outcome is logged
through ``cts_logger`` and periodically committed to a dataset repo.
"""

import functools
import os

import gradio as gr
import pandas as pd
import plotly.graph_objects as go
from datasets import Dataset
from huggingface_hub import CommitScheduler
from sentence_transformers import SentenceTransformer

# NOTE(review): DNAseq, DNAgrapher, ParameterExtractor, list_at_index_0/1 and
# logger are not referenced directly in this file; presumably they are used by
# the expressions in code_function_mapping.csv that chat_to_sequence() passes to
# eval(). Verify before removing any of them.
from DNAseq import DNAseq
from grapher import DNAgrapher
from parameter_extractor import ParameterExtractor
from helper import list_at_index_0, list_at_index_1
from logger import cts_log_file_create, logger, cts_logger

HF_TOKEN = os.environ.get("HF_TOKEN", None)
repo_id = os.environ.get("repo_id", None)

# Create csv file for data logging
log_file_path = cts_log_file_create("flagged")

# Commit the logging folder to the private dataset repo on a schedule
# (huggingface_hub documents `every` in minutes).
scheduler = CommitScheduler(
    repo_id=repo_id,
    repo_type="dataset",
    folder_path=log_file_path.parent,
    path_in_repo="data",
    every=1440,
    private=True,
    token=HF_TOKEN,
)

# Sentence embedding model used to encode user queries
MODEL_NAME = "all-mpnet-base-v2"
# Name of the FAISS index attached to the reference query dataset
INDEX_NAME = "all-mpnet-base-v2_embeddings"
# Number of nearest reference queries retrieved (and logged) per user query
TOP_K = 3

# kNN distance thresholds (lower distance = closer match, 0 = exact match):
#   score <= lower          -> evaluate the mapped code
#   lower < score < upper   -> suggest the nearest reference query format
#   score >= upper          -> ask the user to rephrase
PROXIMAL_LOWER_THRESHOLD = 1.1
PROXIMAL_UPPER_THRESHOLD = 1.4
THRESHOLD_EXCEEDED_MESSAGE = "Your Query Wasn't Understood. Can You Rephrase The Query"
THRESHOLD_APPROXIMATE_MESSAGE = (
    "Your Query Wasn't Understood Clearly. Try Using The Following Query Formats"
)
SIMILARITY_METRIC = "k nearest neighbours"

# Position of the query-code column inside the log_data row
_LOG_QUERY_CODE_INDEX = 3


@functools.lru_cache(maxsize=1)
def _load_resources():
    """Load the embedding model, the code mapping and the indexed reference queries.

    Loading is deferred to the first request (as before) but done only once,
    instead of on every request. This assumes the CSV/JSON/index files are not
    modified while the app is running.

    Returns:
        tuple: (SentenceTransformer model, code_function_mapping DataFrame,
        reference-query Dataset with its FAISS index loaded).
    """
    model = SentenceTransformer(MODEL_NAME)
    # Maps each query code to a Python expression ("function") and a description
    code_function_mapping = pd.read_csv("code_function_mapping.csv")
    ref_query_df = pd.read_json("reference_query_db.json", orient="records")
    ref_query_ds = Dataset.from_pandas(ref_query_df)
    ref_query_ds.load_faiss_index(INDEX_NAME, "ref_query_db_index")
    return model, code_function_mapping, ref_query_ds


def _neighbour(sorted_df, rank):
    """Return (question, score) of the rank-th nearest reference query, or (None, None) if absent."""
    if rank < len(sorted_df):
        row = sorted_df.iloc[rank]
        return row["question"], row["score"]
    return None, None


def chat_to_sequence(sequence, user_query):
    """Answer a natural-language query about a DNA sequence.

    Args:
        sequence: DNA sequence text entered by the user.
        user_query: Natural-language question about the sequence.

    Returns:
        tuple: (response, fig, code_descript_message). ``response`` is the text
        answer (None when a chart is returned). ``fig`` is a plotly Figure for
        chart codes (codes starting with 'c'), otherwise None.
        ``code_descript_message`` is the title-cased description of the executed
        action, or '' when nothing was executed.
    """
    model, code_function_mapping, ref_query_ds = _load_resources()

    # The expressions in code_function_mapping.csv are evaluated in this scope
    # and (per the original comments) expect the sequence as `dna` (DNAseq) and
    # the query as `query` (ParameterExtractor). Keep these local names.
    dna = sequence
    query = user_query

    fig = None
    response = None
    code_descript_message = ""

    # Semantic similarity search: user query vs. reference queries
    query_embedding = model.encode(user_query)
    scores, examples = ref_query_ds.get_nearest_examples(
        INDEX_NAME, query_embedding, k=TOP_K
    )
    result_df = pd.DataFrame(examples)
    result_df["score"] = scores
    # FAISS kNN scores are distances: lowest score = closest reference query
    sorted_df = result_df.sort_values(by="score", ascending=True)

    best = sorted_df.iloc[0]
    ref_question = best["question"]
    query_code = best["code"]
    query_score = best["score"]
    ref_question_2, query_score_2 = _neighbour(sorted_df, 1)
    ref_question_3, query_score_3 = _neighbour(sorted_df, 2)

    # logger function log_data parameter input
    log_data = [
        user_query,
        ref_question,
        query_score,
        query_code,
        ref_question_2,
        query_score_2,
        ref_question_3,
        query_score_3,
        SIMILARITY_METRIC,
        MODEL_NAME,
        PROXIMAL_LOWER_THRESHOLD,
        PROXIMAL_UPPER_THRESHOLD,
    ]

    # Too far from any reference query: ask the user to rephrase
    if query_score >= PROXIMAL_UPPER_THRESHOLD:
        response = THRESHOLD_EXCEEDED_MESSAGE
        cts_logger(scheduler, log_file_path, log_data, response)
        print(THRESHOLD_EXCEEDED_MESSAGE)
        return response, fig, code_descript_message

    # Close but not close enough: suggest the nearest reference query format
    if query_score > PROXIMAL_LOWER_THRESHOLD:
        response = THRESHOLD_APPROXIMATE_MESSAGE + "\n" + ref_question
        cts_logger(scheduler, log_file_path, log_data, response)
        print(THRESHOLD_APPROXIMATE_MESSAGE, ref_question)
        return response, fig, code_descript_message

    print("Execute query")
    matching_row = code_function_mapping[code_function_mapping["code"] == query_code]

    if matching_row.empty:
        response = "Error processing query"
        print("No matching code found for the function:", query_code)
        # Record the failure in the logged query-code column
        query_code = "No Match Error"
        log_data[_LOG_QUERY_CODE_INDEX] = query_code
        cts_logger(scheduler, log_file_path, log_data, response)
        return response, fig, code_descript_message

    function = matching_row.iloc[0]["function"]
    query_code_description = matching_row.iloc[0]["description"]

    # SECURITY NOTE(review): `function` is an arbitrary Python expression read
    # from code_function_mapping.csv and evaluated with user-supplied `dna` /
    # `query` values in scope. It has full interpreter access, so the mapping
    # file must never come from an untrusted source. A dispatch table of
    # callables would be safer.
    f_response = eval(function)

    # Codes starting with 'c' produce chart data; others produce a text answer
    if str(query_code).startswith("c"):
        fig = go.Figure(f_response)
    else:
        response = str(f_response)

    code_descript_message = query_code_description.title()
    cts_logger(scheduler, log_file_path, log_data, response)
    return response, fig, code_descript_message


ChatToSequence = gr.Interface(
    fn=chat_to_sequence,
    inputs=[
        gr.Textbox(label="Sequence", placeholder="Input DNA Sequence..."),
        gr.Textbox(label="Query", placeholder="Input Query..."),
    ],
    outputs=[
        gr.Textbox(label="Response"),
        gr.Plot(label="Graphic Response"),
        gr.Textbox(label="Action Executed"),
    ],
    allow_flagging="never",
    title="Chat-To-Sequence",
    description="This Demo App Allows You To Explore Your DNA Sequence Using Natural Language",
    theme=gr.themes.Soft(),
    examples=[
        ["ggcattgaggagaccattgacaccgtcattagcaatgcactacaactgtcacaacctaaa", "What is the length of the sequence"],
        ["ggcattgaggagaccattgacaccgtcattagcaatgcactacaactgtcacaacctaaa", "How many guanines bases are there in the sequence"],
        ["ggcattgaggagaccattgacaccgtcattagcaatgcactacaactgtcacaacctaaa", "What is the base at position 10"],
        ["ggcattgaggagaccattgacaccgtcattagcaatgcactacaactgtcacaacctaaa", "What are the bases from position 2 to 10"],
        ["ggcattgaggagaccattgacaccgtcattagcaatgcactacaactgtcacaacctaaa", "How many bases are there from position 2 to 10"],
        ["ggcattgaggagaccattgacaccgtcattagcaatgcactacaactgtcacaacctaaaaa", "Show pie chart of total bases"],
    ],
).queue()

ChatToSequence.launch()