fadliaulawi's picture
Change ground truth preparation
593dcaa
raw
history blame
5.01 kB
from collections import defaultdict
from dotenv import load_dotenv
from langchain_google_genai import ChatGoogleGenerativeAI
from langchain_openai import ChatOpenAI
from prompt import *
from utils import call, permutate
from io import StringIO
import os
import pandas as pd
import re
import requests
# Load API keys and other settings from a local .env file into os.environ.
load_dotenv()

# GWAS Catalog source: a local TSV copy is preferred; otherwise the catalog is
# downloaded from the EBI endpoint below on every import (it is not cached).
raw_url = "https://www.ebi.ac.uk/gwas/api/search/downloads/alternative"
gwas_path = "resources/gwas_catalog.tsv"

# The only catalog columns kept, in the order the resulting frame exposes them.
_gwas_columns = ['DISEASE/TRAIT', 'CHR_ID', 'MAPPED_GENE', 'SNPS', 'P-VALUE', 'OR or BETA']

if os.path.exists(gwas_path):
    _gwas_source = gwas_path
else:
    # (connect, read) timeouts so a stalled server cannot hang the import
    # forever; the read timeout applies between received chunks, not to the
    # whole download.
    _response = requests.get(raw_url, timeout=(10, 300))
    # Fail with the real HTTP error instead of trying to parse an error page
    # as TSV further down.
    _response.raise_for_status()
    _gwas_source = StringIO(_response.content.decode('utf-8'))

# `usecols` avoids parsing the columns that would be thrown away anyway; the
# trailing selection restores the column order listed above.
gwas = pd.read_csv(_gwas_source, delimiter='\t', usecols=_gwas_columns)[_gwas_columns]
# Gene/rsID pairs from the catalog: rows missing either value are discarded and
# gene names are normalised (spaces removed, upper-cased). The frame is built
# in one chain so nothing is mutated in place on a slice of `gwas`.
gwas_gene_rsid = (
    gwas[['MAPPED_GENE', 'SNPS']]
    .dropna(ignore_index=True)
    .assign(
        MAPPED_GENE=lambda frame: frame['MAPPED_GENE']
        .str.replace(' ', '', regex=False)
        .str.upper()
    )
)

# Two-way lookup: gene name -> list of rsIDs, and rsID -> list of gene names.
# Duplicates are kept, exactly as they occur in the catalog.
ground_truth = defaultdict(list)
for gene, snp in zip(gwas_gene_rsid['MAPPED_GENE'], gwas_gene_rsid['SNPS']):
    # Register the mapped-gene string as it appears in the catalog...
    ground_truth[gene].append(snp)
    ground_truth[snp].append(gene)

    # ...and, for dash-joined entries such as "A-B", every individual part too.
    if '-' in gene:
        for part in gene.split('-'):
            ground_truth[part].append(snp)
            ground_truth[snp].append(part)
class Validation():
    """Clean a table of gene/rsID pairs and check it against reference data.

    The constructor builds a chat-model client and stores it on ``self.llm``;
    ``validate`` itself does not use that client.
    """

    # Characters that may join several gene names inside one cell. They are
    # tried in this order and only the first one present is used per split.
    _GENE_SEPARATORS = (',', '/', '|', '-')

    def __init__(self, llm):
        """Create the chat-model client matching the model name *llm*.

        Names starting with ``'gpt'`` use ``ChatOpenAI``, names starting with
        ``'gemini'`` use ``ChatGoogleGenerativeAI``; anything else is sent to
        the Perplexity endpoint through ``ChatOpenAI``.

        Raises:
            KeyError: if the Perplexity branch is taken and the
                ``PERPLEXITY_API_KEY`` environment variable is not set.
        """
        if llm.startswith('gpt'):
            self.llm = ChatOpenAI(temperature=0, model_name=llm)
        elif llm.startswith('gemini'):
            self.llm = ChatGoogleGenerativeAI(temperature=0, model=llm)
        else:
            self.llm = ChatOpenAI(
                temperature=0,
                model_name=llm,
                api_key=os.environ['PERPLEXITY_API_KEY'],
                base_url="https://api.perplexity.ai",
            )

    @classmethod
    def _expand_gene(cls, gene):
        """Return *gene* followed by the recursive expansion of its parts.

        ``'A,B,C'`` becomes ``['A,B,C', 'A', 'B,C', 'B', 'C']``: the joined
        name is kept and each side of the first separator is expanded in turn.
        """
        for sep in cls._GENE_SEPARATORS:
            if sep in gene:
                head, _, tail = gene.partition(sep)
                return [gene] + cls._expand_gene(head) + cls._expand_gene(tail)
        return [gene]

    @staticmethod
    def _normalise_rsid(snp):
        """Return the repaired rsID, or ``None`` when it cannot be repaired.

        Every ``l`` is first replaced by ``1``; then ``ts123``, ``s123`` and
        ``123`` are rewritten to ``rs123``. An empty string is accepted
        unchanged.
        """
        snp = snp.replace('l', '1')
        if re.fullmatch(r'rs\d+|', snp):
            return snp
        if re.fullmatch(r'ts\d+', snp):
            return 'r' + snp[1:]
        if re.fullmatch(r's\d+', snp):
            return 'r' + snp
        if re.fullmatch(r'\d+', snp):
            return 'rs' + snp
        return None

    @staticmethod
    def _fetch_dbsnp_genes(snp):
        """Return the de-duplicated gene names dbSNP reports for *snp*.

        Best effort: an empty ID, an ``error`` entry in the response or any
        failure while calling/parsing the API yields an empty list.
        """
        snp_id = snp[2:]  # strip the 'rs' prefix
        if not snp_id:
            return []
        try:
            res = call(f'https://eutils.ncbi.nlm.nih.gov/entrez/eutils/esummary.fcgi?db=snp&retmode=json&id={snp_id}').json()['result'][snp_id]
            if 'error' in res:
                return []
            return list({r['name'] for r in res['genes']})
        except Exception as e:
            # Deliberately broad: a failed lookup must not abort validation.
            print("Error at API", e)
            return []

    def validate(self, df, api):
        """Normalise and validate the gene/rsID pairs in *df*.

        Args:
            df: DataFrame with at least the columns ``'Genes'`` and ``'rsID'``.
                It is not modified.
            api: when truthy, each pair is additionally checked against the
                NCBI dbSNP esummary endpoint before the GWAS check.

        Returns:
            A tuple ``(df, df_clean)``. ``df_clean`` is the table after the
            cleaning steps only (multi-gene rows expanded, rsIDs repaired,
            unrepairable rows removed). ``df`` additionally has every row
            removed whose pair could not be confirmed by the enabled checks,
            with gene names replaced by the matching permutation.
        """
        df = df.fillna('').reset_index(drop=True)
        # astype(str) keeps the .str accessors working when a column was
        # parsed as numbers (e.g. rsIDs given as bare digits).
        df['Genes'] = df['Genes'].astype(str).str.replace(' ', '', regex=False).str.upper()
        df['rsID'] = df['rsID'].astype(str).str.lower()

        # Expand cells holding several gene names: the original row is kept
        # and followed by one row per part. Positions are computed up front so
        # inserting rows can never cause later rows to be skipped.
        expanded = [self._expand_gene(gene) for gene in df['Genes']]
        positions = [pos for pos, genes in enumerate(expanded) for _ in genes]
        df = df.iloc[positions].reset_index(drop=True)
        df['Genes'] = [gene for genes in expanded for gene in genes]

        # Repair rsIDs written without a proper 'rs' prefix and drop the rows
        # whose rsID cannot be repaired.
        fixed = [self._normalise_rsid(snp) for snp in df['rsID']]
        keep = pd.Series([snp is not None for snp in fixed], index=df.index, dtype=bool)
        df = df[keep].reset_index(drop=True)
        df['rsID'] = [snp for snp in fixed if snp is not None]
        df_clean = df.copy()

        # Validate genes and SNPs with APIs
        if api:
            dbsnp = {}  # rsID -> gene names reported by dbSNP (one lookup per rsID)
            for i in df.index:
                snp = df.loc[i, 'rsID']
                gene = df.loc[i, 'Genes']
                if snp not in dbsnp:
                    dbsnp[snp] = self._fetch_dbsnp_genes(snp)
                known = dbsnp[snp]
                if gene in known:
                    continue
                for other in permutate(gene):
                    if other in known:
                        df.loc[i, 'Genes'] = other
                        print(f'{gene} corrected to {other}')
                        break
                else:
                    # No permutation of the gene is known for this rsID.
                    df = df.drop(i)

        # Check with GWAS ground truth
        for i in df.index:
            gene = df.loc[i, 'Genes']
            snp = df.loc[i, 'rsID']
            for perm in permutate(gene):
                # Membership is tested first so the defaultdict is not grown
                # with keys for unknown names.
                if perm in ground_truth and snp in ground_truth[perm]:
                    df.loc[i, 'Genes'] = perm
                    if gene != perm:
                        print(f'{gene} corrected to {perm}')
                    else:
                        print(f'{gene} and {snp} safe')
                    break
            else:
                print(f'{gene} and {snp} not found')
                df = df.drop(i)

        df.reset_index(drop=True, inplace=True)
        return df, df_clean