|
from dotenv import load_dotenv |
|
from langchain_google_genai import ChatGoogleGenerativeAI |
|
from langchain_openai import ChatOpenAI |
|
from prompt import * |
|
from utils import call, permutate |
|
|
|
import os |
|
import json |
|
import pandas as pd |
|
import re |
|
|
|
load_dotenv() |
|
|
|
class Validation():
    """Clean a gene/SNP/disease table, optionally cross-check it, then ask an LLM to validate it.

    The table handled by :meth:`validate` must have the columns ``Genes``,
    ``SNPs`` and ``Diseases``; any other columns are carried along untouched.
    """

    # Characters that may join several gene symbols in one cell, tried in this order.
    _GENE_SEPARATORS = (',', '/', '|', '-')
    # Optional 'rs' / 'ts' / 's' prefix followed by digits; group 1 is the numeric id.
    _SNP_PATTERN = re.compile(r'(?:rs|ts|s)?(\d+)')
    # Number of table rows sent to the LLM per request.
    _LLM_BATCH_SIZE = 50

    def __init__(self, llm):
        """Create the chat client selected by the model name prefix.

        Args:
            llm: Model name. Names starting with ``'gpt'`` use ``ChatOpenAI``,
                names starting with ``'gemini'`` use ``ChatGoogleGenerativeAI``,
                anything else is sent through ``ChatOpenAI`` pointed at the
                Perplexity endpoint.

        Raises:
            KeyError: If the fallback branch is taken and the
                ``PERPLEXITY_API_KEY`` environment variable is not set.
        """
        if llm.startswith('gpt'):
            self.llm = ChatOpenAI(temperature=0, model_name=llm)
        elif llm.startswith('gemini'):
            self.llm = ChatGoogleGenerativeAI(temperature=0, model=llm)
        else:
            self.llm = ChatOpenAI(
                temperature=0,
                model_name=llm,
                api_key=os.environ['PERPLEXITY_API_KEY'],
                base_url="https://api.perplexity.ai",
            )

    @classmethod
    def _expand_gene_variants(cls, gene):
        """Return ``gene`` followed by the pieces obtained by splitting it on separators.

        The unsplit value is kept as the first entry because a separator such
        as ``'-'`` can be part of a single symbol. The first separator (in
        ``_GENE_SEPARATORS`` order) found in the string splits it once into a
        head and a remainder, and both parts are expanded the same way.
        """
        for sep in cls._GENE_SEPARATORS:
            if sep in gene:
                head, remainder = gene.split(sep, 1)
                return (
                    [gene]
                    + cls._expand_gene_variants(head)
                    + cls._expand_gene_variants(remainder)
                )
        return [gene]

    @classmethod
    def _split_multi_gene_rows(cls, df):
        """Return a copy of ``df`` with one row per gene variant of every original row.

        Rows are addressed by position, so every row is examined regardless of
        how many rows earlier splits added and regardless of the index labels.
        The result has a fresh 0..n-1 index.
        """
        positions = []
        genes = []
        for pos, gene in enumerate(df['Genes']):
            variants = cls._expand_gene_variants(gene)
            positions.extend([pos] * len(variants))
            genes.extend(variants)
        expanded = df.iloc[positions].reset_index(drop=True)
        expanded['Genes'] = genes
        return expanded

    @classmethod
    def _normalize_snp(cls, snp):
        """Return the canonical form of an SNP id, or ``None`` if it is malformed.

        Every ``'l'`` is first replaced by ``'1'``. ``rs<digits>``,
        ``ts<digits>``, ``s<digits>`` and bare ``<digits>`` all become
        ``rs<digits>``. An empty string is accepted and returned unchanged.
        """
        snp = snp.replace('l', '1')
        if snp == '':
            return snp
        match = cls._SNP_PATTERN.fullmatch(snp)
        if match is None:
            return None
        return 'rs' + match.group(1)

    @classmethod
    def _clean_snps(cls, df):
        """Return a copy of ``df`` with normalized SNP ids and malformed rows removed.

        The result has a fresh 0..n-1 index.
        """
        kept = [
            (pos, snp)
            for pos, snp in enumerate(map(cls._normalize_snp, df['SNPs']))
            if snp is not None
        ]
        # Positional selection: a boolean list would be read as a column list when empty.
        cleaned = df.iloc[[pos for pos, _ in kept]].reset_index(drop=True)
        cleaned['SNPs'] = [snp for _, snp in kept]
        return cleaned

    @staticmethod
    def _fetch_snp_genes(snp):
        """Return the set of gene names reported for ``snp`` by the two remote services.

        ``call`` is the project helper imported from ``utils``; this code only
        relies on its return value having a ``.json()`` method.
        A failure of either lookup is printed and treated as "no genes from
        that service" instead of aborting the whole validation.
        """
        res = call(f'https://www.ebi.ac.uk/gwas/rest/api/singleNucleotidePolymorphisms/{snp}/')
        try:
            res = res.json()
            genes = [r['gene']['geneName'] for r in res['genomicContexts']]
        except Exception as e:
            print("Error at first API", e)
            genes = []

        try:
            # The second service is queried with the numeric part only (the 'rs' prefix removed).
            res = call(f'https://eutils.ncbi.nlm.nih.gov/entrez/eutils/esummary.fcgi?db=snp&retmode=json&id={snp[2:]}').json()['result'][snp[2:]]
            if 'error' not in res:
                genes.extend([r['name'] for r in res['genes']])
        except Exception as e:
            print("Error at second API", e)

        return set(genes)

    @classmethod
    def _cross_check_with_databases(cls, df):
        """Keep only rows whose gene is reported for its SNP by the remote services.

        When the gene itself is not reported, each candidate produced by the
        project helper ``permutate(gene)`` is tried and the first reported one
        replaces the gene. Rows with no match are dropped. ``df`` is modified
        in place for corrections; the returned frame has a fresh 0..n-1 index.
        """
        genes_by_snp = {}  # cache: one remote lookup per distinct SNP
        unmatched = []
        for i in df.index:
            snp = df.loc[i, 'SNPs']
            gene = df.loc[i, 'Genes']

            if snp not in genes_by_snp:
                genes_by_snp[snp] = cls._fetch_snp_genes(snp)
            known = genes_by_snp[snp]

            if gene in known:
                continue
            for other in permutate(gene):
                if other in known:
                    df.loc[i, 'Genes'] = other
                    print(f'{gene} corrected to {other}')
                    break
            else:
                unmatched.append(i)

        return df.drop(index=unmatched).reset_index(drop=True)

    @staticmethod
    def _parse_llm_list(text):
        """Extract the list literal embedded in an LLM reply.

        The span from the first ``'['`` to the last ``']'`` is parsed as a
        Python literal and, failing that, as JSON. Anything that cannot be
        parsed into a list yields ``[]``.
        """
        import ast

        snippet = text[text.find('['):text.rfind(']') + 1]
        # SECURITY: the reply is untrusted model output, so it is never passed
        # to eval(); only literal syntax is accepted.
        for parse in (ast.literal_eval, json.loads):
            try:
                parsed = parse(snippet)
            except (ValueError, SyntaxError, TypeError, MemoryError, RecursionError):
                continue
            if isinstance(parsed, list):
                return parsed
        return []

    def validate(self, df, api):
        """Clean ``df``, optionally cross-check it remotely, and validate it with the LLM.

        Args:
            df: DataFrame with string columns ``Genes``, ``SNPs`` and
                ``Diseases``. The caller's frame is not modified.
            api: When truthy, rows are additionally checked against the remote
                SNP services and unmatched rows are dropped.

        Returns:
            A tuple ``(df, df_no_llm, df_clean)``:
            ``df_clean`` is the table after gene splitting and SNP cleaning;
            ``df_no_llm`` is the table after the optional remote cross-check;
            ``df`` holds the records returned by the LLM, each combined with
            the non-gene/SNP/disease columns of the first ``df_no_llm`` row.
        """
        df = df.fillna('')
        df['Genes'] = df['Genes'].str.replace(' ', '').str.upper()
        df['SNPs'] = df['SNPs'].str.lower()

        df = self._split_multi_gene_rows(df)
        df = self._clean_snps(df)
        df_clean = df.copy()

        if api:
            df = self._cross_check_with_databases(df)
        df_no_llm = df.copy()

        results = []
        table = df[['Genes', 'SNPs', 'Diseases']]
        # An empty table produces no batches, so the LLM is not called at all.
        for start in range(0, len(df), self._LLM_BATCH_SIZE):
            json_table = table[start:start + self._LLM_BATCH_SIZE].to_json(orient='records')
            str_json_table = json.dumps(json.loads(json_table), indent=2)

            result = self.llm.invoke(input=prompt_validation.format(str_json_table)).content
            print('val', start)
            print(result)

            results.extend(self._parse_llm_list(result))

        df = pd.DataFrame(results)
        # Attach the remaining columns of the first cross-checked row to every LLM record.
        extra_columns = df_no_llm.head(1).drop(['Genes', 'SNPs', 'Diseases'], axis=1)
        df = df.merge(extra_columns, how='cross')

        return df, df_no_llm, df_clean