File size: 7,379 Bytes
a7228f9
 
 
7c244fe
 
e6cc6ba
ad774cf
a7228f9
7c244fe
e6cc6ba
 
 
a7228f9
 
 
e6cc6ba
ad774cf
 
a7228f9
 
 
 
 
 
 
 
 
e9633fa
a7228f9
ad774cf
a7228f9
7c244fe
ad774cf
7c244fe
d636407
7c244fe
 
 
a7228f9
7c244fe
 
 
 
 
 
 
 
 
 
 
 
 
 
e6cc6ba
 
 
 
 
7c244fe
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
d636407
7c244fe
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
d636407
7c244fe
 
 
 
 
 
 
d636407
7c244fe
 
 
 
 
 
 
 
 
 
 
 
 
 
dfbb079
7c244fe
 
 
a7228f9
7c244fe
 
 
a7228f9
 
7c244fe
 
 
 
 
 
 
 
 
 
 
 
e6cc6ba
 
 
 
 
 
 
7c244fe
a7228f9
7c244fe
 
 
a7228f9
7c244fe
 
e6cc6ba
 
7c244fe
 
 
 
 
 
1cc7286
 
 
a7228f9
7c244fe
 
 
 
 
 
 
 
 
 
 
 
 
 
ead1723
e6cc6ba
7c244fe
 
 
d636407
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
import functools
import os

from sentence_transformers import SentenceTransformer
from huggingface_hub import CommitScheduler
from datasets import Dataset
import gradio as gr
import pandas as pd
import plotly.graph_objects as go

from DNAseq import DNAseq
from grapher import DNAgrapher
from parameter_extractor import ParameterExtractor

from helper import list_at_index_0, list_at_index_1
from logger import cts_log_file_create, logger, cts_logger


# Hugging Face access token taken from the environment (None when unset).
HF_TOKEN = os.environ.get("HF_TOKEN")

# Create the CSV file that query/response metrics are logged to.
log_file_path = cts_log_file_create("flagged")

# Periodically commit the folder containing the log file to the private
# "kevkev05/CTS-performance-metrics" dataset repo, under "data/".
scheduler = CommitScheduler(
    repo_id="kevkev05/CTS-performance-metrics",
    repo_type="dataset",
    private=True,
    token=HF_TOKEN,
    folder_path=log_file_path.parent,
    path_in_repo="data",
    every=1440,  # minutes between scheduled commits (1440 min = once a day)
)


@functools.lru_cache(maxsize=1)
def _load_model(model_name):
    """Return the SentenceTransformer named *model_name*, loaded once per process.

    Loading the model is expensive, so it is memoised rather than rebuilt on
    every request.
    """
    return SentenceTransformer(model_name)


@functools.lru_cache(maxsize=1)
def _load_reference_data():
    """Return ``(code_function_mapping, ref_query_ds)``, loaded once per process.

    ``code_function_mapping`` is the DataFrame read from
    ``code_function_mapping.csv``. ``ref_query_ds`` is the reference-query
    Dataset built from ``reference_query_db.json`` with the
    ``ref_query_db_index`` FAISS index attached to the
    ``all-mpnet-base-v2_embeddings`` column. Callers must treat both as
    read-only because they are shared between requests.
    """
    code_function_mapping = pd.read_csv("code_function_mapping.csv")
    ref_query_df = pd.read_json('reference_query_db.json', orient='records')
    ref_query_ds = Dataset.from_pandas(ref_query_df)
    ref_query_ds.load_faiss_index('all-mpnet-base-v2_embeddings', 'ref_query_db_index')
    return code_function_mapping, ref_query_ds


def _neighbour(sorted_df, position):
    """Return ``(question, score)`` of row *position* in *sorted_df*, or ``(None, None)`` if absent."""
    if position < len(sorted_df):
        row = sorted_df.iloc[position]
        return row['question'], row['score']
    return None, None


def chat_to_sequence(sequence, user_query):
    """Answer a natural-language *user_query* about a DNA *sequence*.

    The query is embedded with a SentenceTransformer and matched against a
    FAISS index of reference queries (k=3). The closest match's score decides
    the outcome:

    * ``score >= proximal_upper_threshold``: the query is rejected.
    * ``lower < score < upper``: the closest reference query is suggested.
    * otherwise: the Python expression mapped to the matched code in
      ``code_function_mapping.csv`` is evaluated.

    Every outcome that completes is logged through ``cts_logger``.

    Returns:
        ``(response, fig, code_descript_message)``. ``response`` is the text
        answer (``None`` for chart codes starting with ``'c'``). ``fig`` is a
        plotly ``Figure`` or ``None``. ``code_descript_message`` is the
        title-cased description of the executed action, or ``''``.
    """
    # NOTE: the local names below (notably ``dna`` and ``query``) are kept on
    # purpose: the mapped expressions are run with ``eval`` in this scope and
    # refer to them by name.
    input_sequence = sequence

    # Set DNAseq class expected variable
    dna = input_sequence

    # Set ParameterExtractor class expected variable
    query = user_query

    # Model (must match the embeddings column the FAISS index was built on)
    model_name = "all-mpnet-base-v2"
    model = _load_model(model_name)

    # Default outputs
    fig = None
    response = None
    code_descript_message = ''

    # kNN score thresholds: scores above the lower threshold do not execute code;
    # scores at or above the upper threshold are rejected outright.
    proximal_lower_threshold = 1.1
    proximal_upper_threshold = 1.4

    threshold_exceeded_message = "Your Query Wasn't Understood. Can You Rephrase The Query"
    threshold_approximate_message = "Your Query Wasn't Understood Clearly. Try Using The Following Query Formats"

    code_function_mapping, ref_query_ds = _load_reference_data()

    # Semantic similarity search of the user query against the reference queries
    query_embedding = model.encode(user_query)
    scores, examples = ref_query_ds.get_nearest_examples(
        "all-mpnet-base-v2_embeddings", query_embedding, k=3
    )

    result_df = pd.DataFrame(examples)
    result_df['score'] = scores

    # Lower score = closer match (0 indicates an exact match), so sort ascending.
    sorted_df = result_df.sort_values(by='score', ascending=True)

    # Closest reference query, its code and its score
    ref_question = sorted_df.iloc[0]['question']
    query_code = sorted_df.iloc[0]['code']
    query_score = sorted_df.iloc[0]['score']

    # Mapping row for the matched code; may be empty if the CSV lacks the code,
    # so the description is looked up defensively instead of crashing here.
    matching_row = code_function_mapping[code_function_mapping['code'] == query_code]
    query_code_description = (
        matching_row['description'].iloc[0] if not matching_row.empty else ''
    )

    # Extra log entities: 2nd and 3rd nearest neighbours (None if not returned)
    similarity_metric = "k nearest neighbours"
    ref_question_2, query_score_2 = _neighbour(sorted_df, 1)
    ref_question_3, query_score_3 = _neighbour(sorted_df, 2)

    if query_score >= proximal_upper_threshold:
        response = threshold_exceeded_message
        print(threshold_exceeded_message)
    elif proximal_lower_threshold < query_score < proximal_upper_threshold:
        response = threshold_approximate_message + "\n" + ref_question
        print(threshold_approximate_message, ref_question)
    else:
        print("Execute query")
        code = query_code

        if not matching_row.empty:
            function = matching_row.iloc[0]["function"]
            # SECURITY NOTE: ``function`` comes from the local mapping CSV, not
            # from the user; user input is only bound to names (dna, query)
            # that the expression reads. Never let user text reach this eval.
            f_response = eval(function)
            if str(code).startswith('c'):
                # Codes starting with 'c' produce chart data for plotly.
                fig = go.Figure(f_response)
            else:
                response = str(f_response)
            code_descript_message = query_code_description.title()
        else:
            response = "Error processing query"
            # Recorded in the log entry built below.
            query_code = "No Match Error"
            print("No matching code found for the function:", code)

    # logger function log_data parameter input
    log_data = [
        user_query,
        ref_question,
        query_score,
        query_code,
        ref_question_2,
        query_score_2,
        ref_question_3,
        query_score_3,
        similarity_metric,
        model_name,
        proximal_lower_threshold,
        proximal_upper_threshold,
    ]
    cts_logger(scheduler, log_file_path, log_data, response)

    return response, fig, code_descript_message


# Example DNA sequence shared by most of the demo examples.
_DEMO_SEQUENCE = "ggcattgaggagaccattgacaccgtcattagcaatgcactacaactgtcacaacctaaa"

# Prompts demonstrated against _DEMO_SEQUENCE, in display order.
_DEMO_PROMPTS = (
    "What is the length of the sequence",
    "How many guanines bases are there in the sequence",
    "What is the base at position 10",
    "What are the bases from position 2 to 10",
    "How many bases are there from position 2 to 10",
)

_DEMO_EXAMPLES = [[_DEMO_SEQUENCE, prompt] for prompt in _DEMO_PROMPTS]
# The pie-chart example uses a sequence with two extra trailing "a" bases.
_DEMO_EXAMPLES.append([
    "ggcattgaggagaccattgacaccgtcattagcaatgcactacaactgtcacaacctaaaaa",
    "Show pie chart of total bases",
])

# Gradio UI: two text inputs (sequence, query) mapped onto the three values
# returned by chat_to_sequence (text response, plot, executed action).
ChatToSequence = gr.Interface(
    fn=chat_to_sequence,
    inputs=[
        gr.Textbox(label="Sequence", placeholder="Input DNA Sequence..."),
        gr.Textbox(label="Query", placeholder="Input Query..."),
    ],
    outputs=[
        gr.Textbox(label="Response"),
        gr.Plot(label='Graphic Response'),
        gr.Textbox(label="Action Executed"),
    ],
    title="Chat-To-Sequence",
    description="This Demo App Allows You To Explore Your DNA Sequence Using Natural Language",
    theme=gr.themes.Soft(),
    allow_flagging="never",
    examples=_DEMO_EXAMPLES,
).queue()

ChatToSequence.launch()