Skip to content

Commit 2957aa9

Browse files
committed
Add lab events and ICD-10 codes
1 parent 08c0240 commit 2957aa9

File tree

2 files changed

+107
-6
lines changed

2 files changed

+107
-6
lines changed

examples/foundation_ehr/multimodal_task.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,7 @@
3232

3333
# Apply multimodal task
3434
task = EHRFoundationalModelMIMIC4()
35-
samples = dataset.set_task(task, cache_dir=f"{CACHE_DIR}/task", num_workers=8)
35+
samples = dataset.set_task(task)
3636

3737
# Get and print sample
3838
sample = samples[0]

pyhealth/tasks/ehr_foundational_model_mimic4.py

Lines changed: 106 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,38 @@
11
from datetime import datetime
2-
from typing import Any, Dict, List, Optional, Union, Tuple
2+
from typing import Any, ClassVar, Dict, List, Optional, Tuple, Union
3+
import polars as pl
34

45
from pyhealth.tasks.base_task import BaseTask
56

67
class EHRFoundationalModelMIMIC4(BaseTask):
7-
8+
89
task_name: str = "EHRFoundationalModelMIMIC4"
910
TOKEN_REPRESENTING_MISSING_TEXT = "<missing>"
1011
TOKEN_REPRESENTING_MISSING_FLOAT = float("nan")
11-
12+
PADDING: int = 0
13+
14+
LAB_CATEGORIES: ClassVar[Dict[str, List[str]]] = {
15+
"Sodium": ["50824", "52455", "50983", "52623"],
16+
"Potassium": ["50822", "52452", "50971", "52610"],
17+
"Chloride": ["50806", "52434", "50902", "52535"],
18+
"Bicarbonate": ["50803", "50804"],
19+
"Glucose": ["50809", "52027", "50931", "52569"],
20+
"Calcium": ["50808", "51624"],
21+
"Magnesium": ["50960"],
22+
"Anion Gap": ["50868", "52500"],
23+
"Osmolality": ["52031", "50964", "51701"],
24+
"Phosphate": ["50970"],
25+
}
26+
27+
LAB_CATEGORY_NAMES: ClassVar[List[str]] = [
28+
"Sodium", "Potassium", "Chloride", "Bicarbonate", "Glucose",
29+
"Calcium", "Magnesium", "Anion Gap", "Osmolality", "Phosphate",
30+
]
31+
32+
LABITEMS: ClassVar[List[str]] = [
33+
item for itemids in LAB_CATEGORIES.values() for item in itemids
34+
]
35+
1236
def __init__(self):
1337
"""Initialize the EHR Foundational Model task."""
1438
self.input_schema: Dict[str, Union[str, Tuple[str, Dict]]] = {
@@ -25,7 +49,9 @@ def __init__(self):
2549
"tokenizer_name": "bert-base-uncased",
2650
"type_tag": "note",
2751
},
28-
)
52+
),
53+
"icd_codes": ("stagenet", {"padding": self.PADDING}),
54+
"labs": ("stagenet_tensor", {}),
2955
}
3056
self.output_schema: Dict[str, str] = {"mortality": "binary"}
3157

@@ -79,11 +105,26 @@ def __call__(self, patient: Any) -> List[Dict[str, Any]]:
79105
all_discharge_times_from_admission: List[float] = []
80106
all_radiology_texts: List[str] = []
81107
all_radiology_times_from_admission: List[float] = []
108+
all_icd_codes: List[List[str]] = []
109+
all_icd_times: List[float] = []
110+
all_lab_values: List[List[Any]] = []
111+
all_lab_times: List[float] = []
112+
previous_admission_time = None
82113

83114
# Process each admission independently (per hadm_id)
84115
for admission in admissions_to_process:
85116
admission_time = admission.timestamp
86117

118+
try:
119+
admission_dischtime = datetime.strptime(
120+
admission.dischtime, "%Y-%m-%d %H:%M:%S"
121+
)
122+
except (ValueError, AttributeError):
123+
continue
124+
125+
if admission_dischtime < admission_time:
126+
continue
127+
87128
# Get notes for this hadm_id only
88129
discharge_notes = patient.get_events(
89130
event_type="discharge", filters=[("hadm_id", "==", admission.hadm_id)]
@@ -122,6 +163,64 @@ def __call__(self, patient: Any) -> List[Dict[str, Any]]:
122163
all_radiology_texts.append(self.TOKEN_REPRESENTING_MISSING_TEXT) # Token representing missing text
123164
all_radiology_times_from_admission.append(self.TOKEN_REPRESENTING_MISSING_FLOAT) # Token representing missing time(?)
124165

166+
# ICD codes (diagnoses + procedures) with time in hours since the previous admission (0.0 when there is none)
167+
diagnoses_icd = patient.get_events(
168+
event_type="diagnoses_icd", filters=[("hadm_id", "==", admission.hadm_id)]
169+
)
170+
procedures_icd = patient.get_events(
171+
event_type="procedures_icd", filters=[("hadm_id", "==", admission.hadm_id)]
172+
)
173+
visit_icd_codes = (
174+
[e.icd_code for e in diagnoses_icd if hasattr(e, "icd_code") and e.icd_code] +
175+
[e.icd_code for e in procedures_icd if hasattr(e, "icd_code") and e.icd_code]
176+
)
177+
if visit_icd_codes:
178+
if previous_admission_time is None:
179+
time_from_previous = 0.0
180+
else:
181+
time_from_previous = (admission_time - previous_admission_time).total_seconds() / 3600.0
182+
all_icd_codes.append(visit_icd_codes)
183+
all_icd_times.append(time_from_previous)
184+
185+
previous_admission_time = admission_time
186+
187+
# Lab events with time in hours relative to this admission's start
188+
labevents_df = patient.get_events(
189+
event_type="labevents",
190+
start=admission_time,
191+
end=admission_dischtime,
192+
return_df=True,
193+
)
194+
labevents_df = labevents_df.filter(
195+
pl.col("labevents/itemid").is_in(self.LABITEMS)
196+
)
197+
if labevents_df.height > 0:
198+
labevents_df = labevents_df.with_columns(
199+
pl.col("labevents/storetime").str.strptime(pl.Datetime, "%Y-%m-%d %H:%M:%S")
200+
)
201+
labevents_df = labevents_df.filter(
202+
pl.col("labevents/storetime") <= admission_dischtime
203+
)
204+
if labevents_df.height > 0:
205+
labevents_df = labevents_df.select(
206+
pl.col("timestamp"),
207+
pl.col("labevents/itemid"),
208+
pl.col("labevents/valuenum").cast(pl.Float64),
209+
)
210+
for lab_ts in sorted(labevents_df["timestamp"].unique().to_list()):
211+
ts_labs = labevents_df.filter(pl.col("timestamp") == lab_ts)
212+
lab_vector: List[Any] = []
213+
for category_name in self.LAB_CATEGORY_NAMES:
214+
category_value = None
215+
for itemid in self.LAB_CATEGORIES[category_name]:
216+
matching = ts_labs.filter(pl.col("labevents/itemid") == itemid)
217+
if matching.height > 0:
218+
category_value = matching["labevents/valuenum"][0]
219+
break
220+
lab_vector.append(category_value)
221+
all_lab_values.append(lab_vector)
222+
all_lab_times.append((lab_ts - admission_time).total_seconds() / 3600.0)
223+
125224
discharge_note_times_from_admission = (all_discharge_texts, all_discharge_times_from_admission)
126225
radiology_note_times_from_admission = (all_radiology_texts, all_radiology_times_from_admission)
127226

@@ -130,6 +229,8 @@ def __call__(self, patient: Any) -> List[Dict[str, Any]]:
130229
"patient_id": patient.patient_id,
131230
"discharge_note_times": discharge_note_times_from_admission,
132231
"radiology_note_times": radiology_note_times_from_admission,
232+
"icd_codes": (all_icd_times, all_icd_codes),
233+
"labs": (all_lab_times, all_lab_values),
133234
"mortality": mortality_label,
134235
}
135-
]
236+
]

0 commit comments

Comments
 (0)