"""Validation of (gene, rsID) pairs against dbSNP and the GWAS Catalog.

On import this module loads the GWAS Catalog (from ``gwas_path`` when that
file exists, otherwise by downloading ``raw_url``) and builds
``ground_truth``: a mapping from each mapped gene symbol to its rsIDs and from
each rsID to its gene symbols.
"""
from collections import defaultdict
from dotenv import load_dotenv
from langchain_google_genai import ChatGoogleGenerativeAI
from langchain_openai import ChatOpenAI
from prompt import *
from utils import call, permutate
from io import StringIO
import os
import pandas as pd
import re
import requests

load_dotenv()

raw_url = "https://www.ebi.ac.uk/gwas/api/search/downloads/alternative"
gwas_path = "resources/gwas_catalog.tsv"

# GWAS Catalog columns kept by this module.
GWAS_COLUMNS = ['DISEASE/TRAIT', 'CHR_ID', 'MAPPED_GENE', 'SNPS', 'P-VALUE', 'OR or BETA']

# Separators that may join several gene symbols in one 'Genes' cell, in the
# order they are tried (only the first one found is split on per step).
GENE_SEPARATORS = (',', '/', '|', '-')

DBSNP_SUMMARY_URL = ('https://eutils.ncbi.nlm.nih.gov/entrez/eutils/esummary.fcgi'
                     '?db=snp&retmode=json&id={}')

# rsID repair patterns, tried in this order by _normalize_rsid.
_RSID_OK = re.compile(r'rs\d+|')    # well-formed rsID, or empty (empty cells are kept)
_RSID_T_TYPO = re.compile(r'ts\d+')  # 'ts123' is treated as 'rs123'
_RSID_NO_R = re.compile(r's\d+')     # 's123' is treated as 'rs123'
_RSID_DIGITS = re.compile(r'\d+')    # '123' is treated as 'rs123'


def _load_gwas(path=gwas_path, url=raw_url):
    """Return the GWAS Catalog restricted to GWAS_COLUMNS.

    Reads the local TSV at *path* if it exists, otherwise downloads *url*.

    Raises:
        requests.HTTPError: if the download answers with an error status.
    """
    if os.path.exists(path):
        source = path
    else:
        response = requests.get(url)
        response.raise_for_status()  # fail loudly instead of parsing an error page
        source = StringIO(response.content.decode('utf-8'))
    return pd.read_csv(source, delimiter='\t', usecols=GWAS_COLUMNS)[GWAS_COLUMNS]


def _build_ground_truth(pairs):
    """Map genes -> rsIDs and rsIDs -> genes from a MAPPED_GENE/SNPS frame.

    A hyphenated MAPPED_GENE value (e.g. 'A-B') is recorded both as-is and as
    each of its hyphen-separated parts.
    """
    truth = defaultdict(list)
    for gene, snp in zip(pairs['MAPPED_GENE'], pairs['SNPS']):
        truth[gene].append(snp)
        truth[snp].append(gene)
        if '-' in gene:
            for part in gene.split('-'):
                truth[part].append(snp)
                truth[snp].append(part)
    return truth


gwas = _load_gwas()
# Non-null (gene, rsID) pairs; gene names have spaces removed and are upper-cased.
gwas_gene_rsid = gwas[['MAPPED_GENE', 'SNPS']].dropna(ignore_index=True).copy()
gwas_gene_rsid['MAPPED_GENE'] = gwas_gene_rsid['MAPPED_GENE'].str.replace(' ', '').str.upper()
ground_truth = _build_ground_truth(gwas_gene_rsid)


def _expand_multi_gene_rows(df, separators=GENE_SEPARATORS):
    """Return *df* with extra rows for the parts of multi-gene 'Genes' cells.

    Each original row is kept. A cell containing a separator is split into the
    part before the first occurrence and the remainder. Both parts become new
    rows placed right after their source row, and are themselves split further.
    Empty parts are skipped. The result has a fresh RangeIndex.
    """
    expanded = []
    pending = df.to_dict('records')[::-1]  # used as a stack; reversed to pop in order
    while pending:
        row = pending.pop()
        expanded.append(row)
        gene = row['Genes']
        sep = next((s for s in separators if s in gene), None)
        if sep is None:
            continue
        head, *rest = gene.split(sep)
        tail = sep.join(rest)
        # Push the tail first so the head (and anything split from it) comes next.
        for part in (tail, head):
            if part:
                pending.append({**row, 'Genes': part})
    return pd.DataFrame(expanded, columns=df.columns)


def _normalize_rsid(snp):
    """Return *snp* repaired to 'rs<digits>' ('' stays ''), or None if unrecognizable."""
    snp = snp.replace('l', '1')  # a lower-case 'l' is treated as a mistyped '1'
    if _RSID_OK.fullmatch(snp):
        return snp
    if _RSID_T_TYPO.fullmatch(snp):
        return 'r' + snp[1:]
    if _RSID_NO_R.fullmatch(snp):
        return 'r' + snp
    if _RSID_DIGITS.fullmatch(snp):
        return 'rs' + snp
    return None


def _normalize_rsids(df):
    """Repair the 'rsID' column, dropping rows that cannot be repaired; reset the index."""
    normalized = df['rsID'].map(_normalize_rsid)
    keep = normalized.notna()
    df = df.loc[keep].copy()
    df['rsID'] = normalized[keep].to_numpy()
    return df.reset_index(drop=True)


def _dbsnp_genes(snp):
    """Return the gene names dbSNP lists for *snp*, or None if they could not be fetched.

    An empty set means dbSNP answered but reported an error for the ID or
    listed no genes. Lookup failures are printed and treated as "unknown".
    """
    snp_id = snp[2:]
    if not snp_id:
        return None  # empty rsID: nothing to look up
    try:
        res = call(DBSNP_SUMMARY_URL.format(snp_id)).json()['result'][snp_id]
        if 'error' in res:
            return set()
        return {entry['name'] for entry in res.get('genes', [])}
    except Exception as e:  # best-effort lookup: report and leave the row unchecked
        print("Error at API", e)
        return None


def _filter_by_dbsnp(df):
    """Keep rows whose gene (or a permutate() variant) dbSNP lists for the rsID.

    Corrected gene names are written back. Rows whose rsID could not be looked
    up are kept unchanged. Each distinct rsID is queried once.
    """
    cache = {}
    genes, keep = [], []
    for gene, snp in zip(df['Genes'], df['rsID']):
        if snp not in cache:
            cache[snp] = _dbsnp_genes(snp)
        known = cache[snp]
        if known is None or gene in known:
            resolved = gene
        else:
            resolved = next((other for other in permutate(gene) if other in known), None)
            if resolved is not None:
                print(f'{gene} corrected to {resolved}')
        genes.append(resolved)
        keep.append(resolved is not None)
    return df.assign(Genes=genes).loc[keep]


def _filter_by_gwas(df):
    """Keep rows whose gene (or a permutate() variant) is paired with the rsID in ground_truth.

    The matching variant replaces the gene name. The result has a fresh RangeIndex.
    """
    genes, keep = [], []
    for gene, snp in zip(df['Genes'], df['rsID']):
        match = next((perm for perm in permutate(gene)
                      if perm in ground_truth and snp in ground_truth[perm]), None)
        if match is None:
            print(f'{gene} and {snp} not found')
        elif match != gene:
            print(f'{gene} corrected to {match}')
        else:
            print(f'{gene} and {snp} safe')
        genes.append(match)
        keep.append(match is not None)
    return df.assign(Genes=genes).loc[keep].reset_index(drop=True)


class Validation:
    """Normalize and validate tables of (gene, rsID) pairs.

    Also holds a chat-model client (``self.llm``) chosen by model name.
    """

    def __init__(self, llm):
        """Create the chat-model client for the model named *llm*.

        Names starting with 'gpt' use ChatOpenAI and names starting with
        'gemini' use ChatGoogleGenerativeAI. Any other name uses ChatOpenAI
        pointed at https://api.perplexity.ai, with the key from the
        PERPLEXITY_API_KEY environment variable. All clients use temperature 0.
        """
        if llm.startswith('gpt'):
            self.llm = ChatOpenAI(temperature=0, model_name=llm)
        elif llm.startswith('gemini'):
            self.llm = ChatGoogleGenerativeAI(temperature=0, model=llm)
        else:
            self.llm = ChatOpenAI(temperature=0, model_name=llm,
                                  api_key=os.environ['PERPLEXITY_API_KEY'],
                                  base_url="https://api.perplexity.ai")

    def validate(self, df, api):
        """Clean *df* and keep only (gene, rsID) pairs confirmed by the GWAS Catalog.

        Args:
            df: DataFrame with at least 'Genes' and 'rsID' columns. It is not
                modified (``fillna`` works on a copy).
            api: if truthy, first filter the pairs against dbSNP (NCBI E-utilities).

        Returns:
            (validated, clean):
            - clean: the normalized table. Genes are upper-cased with spaces
              removed and multi-gene cells are expanded. rsIDs are repaired
              and unrepairable ones dropped.
            - validated: the subset of clean that passed the checks, with
              gene names possibly corrected and a fresh RangeIndex.
        """
        df = df.fillna('')
        df['Genes'] = df['Genes'].astype(str).str.replace(' ', '').str.upper()
        df['rsID'] = df['rsID'].astype(str).str.lower()

        df = _expand_multi_gene_rows(df)
        df = _normalize_rsids(df)
        df_clean = df.copy()

        if api:
            df = _filter_by_dbsnp(df)

        df = _filter_by_gwas(df)
        return df, df_clean