metadata
license: cc-by-nc-4.0
library_name: gpt-base-2048-clmbr
tags:
  - healthcare
  - medical
extra_gated_prompt: >-
  You agree to all terms outlined in 'The EHRSHOT Credentialed Health Data
  License' (see https://shahlab.stanford.edu/ehrshot_license). Access requires a
  verified CITI training certificate using the same process outlined by
  PhysioNet (see https://physionet.org/about/citi-course/). Please complete the
  'Data or Specimens Only Research' course and provide proof via the
  verification URL, which takes the form
  https://www.citiprogram.org/verify/?XXXXXX. You agree to not use the model to
  conduct experiments that cause harm to human subjects.
extra_gated_fields:
  Full Name: text
  Email: text
  Affiliation: text
  CITI Certification Verification URL: text
  I agree to all terms outlined in 'The EHRSHOT Credentialed Health Data License': checkbox
  I agree to use this model for non-commercial use ONLY: checkbox

gpt-base-2048-clmbr

This is a GPT model with a context length of 2048 tokens and 117,209,088 parameters, introduced in the Context Clues paper.
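
The stated parameter count can be reproduced from a standard GPT-2 layout. The sketch below assumes 12 transformer layers, a hidden size of 768, learned position embeddings for all 2048 positions, an output head tied to the token embedding, and a vocabulary of 39,818 codes (the last dimension of the logits in the usage example). None of these settings are stated on this card; they are assumptions that happen to yield exactly 117,209,088.

```python
def gpt2_param_count(vocab_size: int, n_positions: int, d_model: int, n_layers: int) -> int:
    """Count parameters of a GPT-2-style decoder with a tied LM head (assumed layout)."""
    # Token embeddings + learned position embeddings.
    embeddings = vocab_size * d_model + n_positions * d_model
    # Per block: LayerNorm (2d) + fused QKV projection (3d^2 + 3d)
    # + attention output projection (d^2 + d) + LayerNorm (2d)
    # + MLP up-projection (4d^2 + 4d) + MLP down-projection (4d^2 + d)
    # = 12d^2 + 13d.
    per_block = 12 * d_model**2 + 13 * d_model
    # Final LayerNorm (weight + bias).
    final_norm = 2 * d_model
    # The LM head reuses the token embedding matrix, so it adds nothing here.
    return embeddings + n_layers * per_block + final_norm


total = gpt2_param_count(vocab_size=39818, n_positions=2048, d_model=768, n_layers=12)
print(f"{total:,}")  # 117,209,088
```

Under these assumptions, about 85.1M parameters sit in the 12 blocks and about 30.6M in the vocabulary-sized token embedding.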

It is a foundation model trained from scratch on the structured data within 2.57 million deidentified EHRs from Stanford Medicine.

As input, this model expects a sequence of coded medical events that have been mapped to Standard Concepts within the OMOP-CDM vocabulary. As output, the model can generate either (a) synthetic future timelines or (b) a vector representation of a patient, which can then be used for downstream prediction tasks.
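
Output (a) would come from ordinary autoregressive sampling: score the next code given the timeline so far, sample one, append it, and repeat. The sketch below shows that loop in plain Python with a hypothetical `next_logits` callable standing in for the model (in practice, the last position of the model's logits); the toy scoring function, token ids, and temperature are illustrative assumptions. Note that the authors have not evaluated this model's generative abilities.

```python
import math
import random
from typing import Callable, List, Sequence


def sample_timeline(
    next_logits: Callable[[Sequence[int]], List[float]],
    prompt: List[int],
    max_new_tokens: int,
    eos_id: int,
    temperature: float = 1.0,
    seed: int = 0,
) -> List[int]:
    """Extend `prompt` one code at a time by sampling from softmax(logits / temperature)."""
    rng = random.Random(seed)
    tokens = list(prompt)
    for _ in range(max_new_tokens):
        scaled = [x / temperature for x in next_logits(tokens)]
        # Subtract the max before exponentiating for numerical stability;
        # random.choices normalizes the weights, so no explicit division is needed.
        top = max(scaled)
        weights = [math.exp(s - top) for s in scaled]
        next_id = rng.choices(range(len(weights)), weights=weights, k=1)[0]
        tokens.append(next_id)
        if next_id == eos_id:  # stop once the end-of-timeline token is sampled
            break
    return tokens


# Hypothetical toy "model" over a 5-code vocabulary: strongly favors (last code + 1) mod 5.
def toy_next_logits(tokens: Sequence[int]) -> List[float]:
    favored = (tokens[-1] + 1) % 5
    return [5.0 if i == favored else 0.0 for i in range(5)]


print(sample_timeline(toy_next_logits, prompt=[0], max_new_tokens=6, eos_id=4, temperature=0.1))
# [0, 1, 2, 3, 4]
```

A low temperature makes sampling nearly greedy; higher temperatures spread probability onto less likely codes.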

Usage

First, install the hf_ehr package along with transformers and torch:

```bash
pip install transformers torch hf_ehr
```

Second, run this Python script to compute a patient representation:

```python
from transformers import AutoModelForCausalLM
from hf_ehr.data.tokenization import CLMBRTokenizer
from hf_ehr.config import Event
from typing import List, Dict
import torch

####################################
# 1. Load model and tokenizer
model = AutoModelForCausalLM.from_pretrained("StanfordShahLab/gpt-base-2048-clmbr")
tokenizer = CLMBRTokenizer.from_pretrained("StanfordShahLab/gpt-base-2048-clmbr")

####################################
# 2. Define patient as sequence of `Event` objects. Only `code` is required.
patient: List[Event] = [
    Event(code='SNOMED/3950001', value=None, unit=None, start=None, end=None, omop_table=None),
    Event(code='Gender/F', value=None, unit=None, start=None, end=None, omop_table=None),
    Event(code='Ethnicity/Hispanic', value=None, unit=None, start=None, end=None, omop_table=None),
    Event(code='SNOMED/609040007', value=None, unit=None, start=None, end=None, omop_table=None),
    Event(code='LOINC/2236-8', value=-3.0, unit=None, start=None, end=None, omop_table=None),
    Event(code='SNOMED/12199005', value=26.3, unit=None, start=None, end=None, omop_table=None),
]

####################################
# 3. Tokenize patient
batch: Dict[str, torch.Tensor] = tokenizer([patient], add_special_tokens=True, return_tensors='pt')
# > batch = {
#     'input_ids': tensor([[ 5, 0, 7, 9, 27, 2049, 6557, 22433, 1]]),
#     'token_type_ids': tensor([[0, 0, 0, 0, 0, 0, 0, 0, 0]]),
#     'attention_mask': tensor([[1, 1, 1, 1, 1, 1, 1, 1, 1]])
# }
textual_tokens: List[str] = tokenizer.convert_events_to_tokens(patient)
# > textual_tokens = ['SNOMED/3950001', 'Gender/F', 'Ethnicity/Hispanic', 'SNOMED/609040007', 'LOINC/2236-8 || None || -1.7976931348623157e+308 - 4.0', 'SNOMED/12199005 || None || 26.0 - 28.899999618530273']

####################################
# 4. Run model
logits = model(**batch).logits
# > logits.shape = torch.Size([1, 9, 39818])

####################################
# 5. Get patient representation for finetuning (usually we choose the last token's logits)
representation = logits[:, -1, :]
# > representation.shape = torch.Size([1, 39818])
```
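
The textual tokens in step 3 show that numeric events are not fed in as raw numbers: the LOINC value -3.0 becomes a token for the range -1.7976931348623157e+308 to 4.0, and 26.3 becomes a token for roughly 26.0 to 28.9. The sketch below shows one way such a bin lookup could work, assuming each code has a sorted list of bin edges. The edges here are hypothetical (the real ones are presumably stored with the released tokenizer), and this is not hf_ehr's actual implementation.

```python
import bisect
import sys
from typing import List, Optional


def numeric_token(code: str, value: float, unit: Optional[str], edges: List[float]) -> str:
    """Render a numeric event as 'code || unit || lo - hi' for the bin containing `value`.

    `edges` must be sorted ascending; bin i covers [edges[i], edges[i + 1]).
    Values outside the outermost edges are clamped into the first or last bin.
    """
    i = bisect.bisect_right(edges, value) - 1
    i = max(0, min(i, len(edges) - 2))
    return f"{code} || {unit} || {edges[i]} - {edges[i + 1]}"


# Hypothetical bin edges; -sys.float_info.max stands in for an open lower bound.
lowest = -sys.float_info.max
print(numeric_token("LOINC/2236-8", -3.0, None, [lowest, 4.0, 10.0, sys.float_info.max]))
# LOINC/2236-8 || None || -1.7976931348623157e+308 - 4.0
print(numeric_token("SNOMED/12199005", 26.3, None, [20.0, 26.0, 28.9, 35.0]))
# SNOMED/12199005 || None || 26.0 - 28.9
```

Binning keeps the vocabulary finite: each (code, unit, range) triple is one token, so lab values can be handled by the same next-code objective as diagnoses and procedures.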

Model Details

  • Developed by: Shah lab @ Stanford University
  • Funded by: Stanford Healthcare
  • Shared by: Shah lab @ Stanford University
  • Model type: GPT
  • Languages: Electronic health record codes (as standardized by the OMOP-CDM)
  • License: CC BY-NC 4.0
  • Finetuned from model: N/A -- trained from scratch

Uses

This model is intended to generate representations for patients based on the structured data within their electronic health record. These representations can then be used for downstream tasks such as predicting diagnoses, detecting anomalies, or doing propensity score matching for causal inference.
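
Of these downstream uses, propensity score matching is the least self-explanatory. A typical workflow fits a classifier that predicts treatment assignment from each patient's representation, then pairs each treated patient with an untreated patient whose predicted probability (propensity score) is close. The sketch below covers only the pairing step, using greedy 1:1 nearest-neighbour matching without replacement within a caliper. The patient ids, scores, and the 0.05 caliper are made-up values, and this is one common matching scheme rather than a method prescribed by the authors.

```python
from typing import Dict, List, Tuple


def greedy_caliper_match(
    treated: Dict[str, float], control: Dict[str, float], caliper: float = 0.05
) -> List[Tuple[str, str]]:
    """Pair each treated patient with the closest unused control whose propensity
    score is within `caliper`; treated patients with no such control stay unmatched."""
    available = dict(control)  # controls not yet used (matching without replacement)
    pairs: List[Tuple[str, str]] = []
    # Visit treated patients in ascending score order so the result is deterministic.
    for t_id, t_score in sorted(treated.items(), key=lambda kv: kv[1]):
        if not available:
            break
        c_id = min(available, key=lambda c: abs(available[c] - t_score))
        if abs(available[c_id] - t_score) <= caliper:
            pairs.append((t_id, c_id))
            del available[c_id]
    return pairs


# Hypothetical propensity scores, e.g. from a classifier trained on patient representations.
treated = {"t1": 0.30, "t2": 0.70, "t3": 0.95}
control = {"c1": 0.32, "c2": 0.68, "c3": 0.10}
print(greedy_caliper_match(treated, control))  # [('t1', 'c1'), ('t2', 'c2')]
```

Here t3 stays unmatched because no remaining control lies within the caliper; dropping such patients trades sample size for better balance between the groups.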

Direct Use

You will likely want to fine-tune the model for your downstream use case.

Out-of-Scope Use

This model is for research purposes only. It is not for use in any real-world decision making that impacts patients, providers, or hospital operations.

Bias, Risks, and Limitations

This model was trained on a corpus of 2 billion tokens sourced from 2.57 million patients from Stanford Medicine. The model will thus reflect the patterns of how care is delivered at Stanford Medicine, in addition to the racial and socioeconomic makeup of Stanford Medicine's patient base. This model may not generalize well to other hospitals and demographic mixes.

While this is technically a generative model, we have not tested its generative abilities and thus do not anticipate it being used to generate synthetic EHR records. We aim to explore its generative abilities in future work.

Training Details

Full training details are provided in our accompanying paper, Context Clues.

Training Data

The model is trained on 2 billion tokens sourced from 2.57 million patients from the Stanford Medicine Research Data Repository (STARR), which contains structured EHR data from both Stanford Health Care (primarily adult care) and Lucile Packard Children’s Hospital (primarily pediatric care). The dataset contains only structured data (i.e. no clinical text or images) and covers demographics (e.g. age, sex, race), diagnoses, procedures, laboratory results, medication prescriptions, and other coded clinical observations. The data is formatted according to the Observational Medical Outcomes Partnership Common Data Model (OMOP-CDM). All data that we work with is deidentified.
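
Set against the 2048-token context length, these corpus figures imply that an average patient timeline is roughly 778 tokens long. The arithmetic below uses the rounded figures stated on this card (2 billion tokens, 2.57 million patients), so it is approximate, and an average does not reveal how many individual timelines exceed 2048 tokens.

```python
# Rounded corpus figures stated on this model card.
total_tokens = 2_000_000_000
n_patients = 2_570_000
context_length = 2048

avg_tokens = total_tokens / n_patients
print(f"average tokens per patient: {avg_tokens:.0f}")  # about 778
print(f"fraction of context used on average: {avg_tokens / context_length:.2f}")  # about 0.38
```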

Training Procedure

We train our model using an autoregressive next-code prediction objective, i.e., predicting the next code in a patient's timeline given the previous codes.
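
Concretely, the logits at position t score candidates for the code at position t + 1, so the training loss is the cross-entropy between each position's softmax and the code that actually comes next. The pure-Python sketch below illustrates that loss; it mirrors the [sequence length, vocabulary size] layout of the logits in the usage example, but it is not the authors' training code, and the toy logits and timeline are made up.

```python
import math
from typing import List, Sequence


def log_softmax(logits: Sequence[float]) -> List[float]:
    # log p_i = x_i - log(sum_j exp(x_j)), computed with the max-subtraction trick.
    top = max(logits)
    log_z = top + math.log(sum(math.exp(x - top) for x in logits))
    return [x - log_z for x in logits]


def next_code_loss(logits: Sequence[Sequence[float]], token_ids: Sequence[int]) -> float:
    """Mean negative log-likelihood of token_ids[t + 1] under logits[t].

    logits[t] is the model's score vector after reading token_ids[: t + 1];
    the final position has no next code, so it is not scored.
    """
    losses = [
        -log_softmax(logits[t])[token_ids[t + 1]] for t in range(len(token_ids) - 1)
    ]
    return sum(losses) / len(losses)


codes = [0, 2, 1]  # toy timeline over a 4-code vocabulary
uniform = [[0.0] * 4 for _ in codes]
confident = [[0.0, 0.0, 4.0, 0.0], [0.0, 4.0, 0.0, 0.0], [0.0] * 4]
print(next_code_loss(uniform, codes))    # log(4) ≈ 1.386
print(next_code_loss(confident, codes))  # lower: the model favours the true next codes
```

A model that spreads probability evenly over the vocabulary pays log(vocabulary size) per code; training pushes the loss below that by concentrating probability on the codes that actually follow.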

Citation

BibTeX:

@article{wornow2024contextclues,
      title={Context Clues: Evaluating Long Context Models for Clinical Prediction Tasks on EHRs}, 
      author={Michael Wornow and Suhana Bedi and Miguel Angel Fuentes Hernandez and Ethan Steinberg and Jason Alan Fries and Christopher Ré and Sanmi Koyejo and Nigam H. Shah},
      year={2024},
      eprint={2412.16178},
      url={https://arxiv.org/abs/2412.16178}, 
}

Model Card Authors

Michael Wornow, Suhana Bedi, Ethan Steinberg