import os
import re
import httpx
import asyncio
from functools import lru_cache
from pathlib import Path
from typing import List, Optional, Dict, Any

# ---------------------------------------------------------------------------
# Configuration
# ---------------------------------------------------------------------------
UMLS_API_KEY = os.getenv("UMLS_KEY")
UMLS_AUTH_URL = "https://utslogin.nlm.nih.gov/cas/v1/api-key"
UMLS_SEARCH_URL = "https://uts-ws.nlm.nih.gov/rest/search/current"


# ---------------------------------------------------------------------------
# Named types
# ---------------------------------------------------------------------------
class UMLSResult(Dict[str, Optional[str]]):
    """
    Represents a single UMLS lookup result.
    Keys: term, cui, name, definition
    """
    pass


# ---------------------------------------------------------------------------
# NLP model loading with caching
# ---------------------------------------------------------------------------
@lru_cache(maxsize=None)
def _load_spacy_model(model_name: str):
    import spacy
    return spacy.load(model_name)


@lru_cache(maxsize=None)
def _load_scispacy_model():
    # Prefer the BioNLP model; fall back to the smaller sci model
    try:
        return _load_spacy_model("en_ner_bionlp13cg_md")
    except Exception:
        return _load_spacy_model("en_core_sci_sm")


@lru_cache(maxsize=None)
def _load_general_spacy():
    return _load_spacy_model("en_core_web_sm")


# ---------------------------------------------------------------------------
# Concept extraction utilities
# ---------------------------------------------------------------------------
def _extract_entities(nlp, text: str, min_length: int) -> List[str]:
    """
    Run a spaCy nlp pipeline over text and return unique entity texts
    of at least min_length.
    """
    doc = nlp(text)
    ents = {ent.text.strip() for ent in doc.ents if len(ent.text.strip()) >= min_length}
    return list(ents)


def _regex_fallback(text: str, min_length: int) -> List[str]:
    """
    Simple regex-based token extraction for fallback.
    """
    tokens = re.findall(r"\b[a-zA-Z0-9\-]+\b", text)
    return list({t for t in tokens if len(t) >= min_length})


def extract_umls_concepts(text: str, min_length: int = 3) -> List[str]:
    """
    Extract biomedical concepts from text in priority order:
      1. SciSpaCy (en_ner_bionlp13cg_md or en_core_sci_sm)
      2. spaCy general NER (en_core_web_sm)
      3. Regex tokens
    Guaranteed to return a list of unique strings.
    """
    # 1) SciSpaCy pipeline
    try:
        scispacy_nlp = _load_scispacy_model()
        entities = _extract_entities(scispacy_nlp, text, min_length)
        if entities:
            return entities
    except ImportError:
        # SciSpaCy not installed
        pass
    except Exception:
        # Unexpected failure in scispacy
        pass

    # 2) General spaCy pipeline
    try:
        general_nlp = _load_general_spacy()
        entities = _extract_entities(general_nlp, text, min_length)
        if entities:
            return entities
    except Exception:
        pass

    # 3) Regex fallback
    return _regex_fallback(text, min_length)
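
# ---------------------------------------------------------------------------
# Example: batch concept extraction (illustrative sketch)
# ---------------------------------------------------------------------------
# Not part of the pipeline above: a minimal sketch of how extract_umls_concepts
# might be applied to several documents at once. The helper name
# `batch_extract` is an assumption for demonstration only.
def batch_extract(texts: List[str], min_length: int = 3) -> Dict[str, List[str]]:
    """Map each input text to the list of concepts extracted from it."""
    return {text: extract_umls_concepts(text, min_length) for text in texts}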
""" if not UMLS_API_KEY: return None try: async with httpx.AsyncClient(timeout=10) as client: response = await client.post( UMLS_AUTH_URL, data={"apikey": UMLS_API_KEY} ) response.raise_for_status() tgt_url = response.text.split('action="')[1].split('"')[0] service_resp = await client.post( tgt_url, data={"service": "http://umlsks.nlm.nih.gov"} ) return service_resp.text except Exception: return None @lru_cache(maxsize=512) async def lookup_umls(term: str) -> UMLSResult: """ Look up a term in the UMLS API. Returns a dict containing the original term, its CUI, preferred name, and definition. On failure or quota issues, returns all values except 'term' as None. """ ticket = await _get_umls_ticket() if not ticket: return {"term": term, "cui": None, "name": None, "definition": None} params = {"string": term, "ticket": ticket, "pageSize": 1} try: async with httpx.AsyncClient(timeout=8) as client: resp = await client.get(UMLS_SEARCH_URL, params=params) resp.raise_for_status() results = resp.json().get("result", {}).get("results", []) first = results[0] if results else {} return { "term": term, "cui": first.get("ui"), "name": first.get("name"), "definition": first.get("definition") or first.get("rootSource"), } except Exception: return {"term": term, "cui": None, "name": None, "definition": None}