Skip to content

Commit 08c0240

Browse files
authored
Merge pull request #1 from will-pang/FoundationalEHR/wp-create-multimodal-task-notes
Foundational ehr/wp create multimodal task notes
2 parents 0584779 + 53761d6 commit 08c0240

File tree

3 files changed

+37
-164
lines changed

3 files changed

+37
-164
lines changed

examples/foundation_ehr/multimodal_task.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,7 @@
2727
note_tables=["discharge", "radiology"],
2828
cache_dir=CACHE_DIR,
2929
num_workers=8,
30-
# dev=True
30+
dev=True
3131
)
3232

3333
# Apply multimodal task

pyhealth/tasks/ehr_foundational_model_mimic4.py

Lines changed: 36 additions & 107 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,8 @@
66
class EHRFoundationalModelMIMIC4(BaseTask):
77

88
task_name: str = "EHRFoundationalModelMIMIC4"
9+
TOKEN_REPRESENTING_MISSING_TEXT = "<missing>"
10+
TOKEN_REPRESENTING_MISSING_FLOAT = float("nan")
911

1012
def __init__(self):
1113
"""Initialize the EHR Foundational Model task."""
@@ -23,47 +25,14 @@ def __init__(self):
2325
"tokenizer_name": "bert-base-uncased",
2426
"type_tag": "note",
2527
},
26-
),
27-
"icd_codes": (
28-
"stagenet",
29-
{"padding": 0
30-
}
31-
),
28+
)
3229
}
3330
self.output_schema: Dict[str, str] = {"mortality": "binary"}
3431

3532
def _clean_text(self, text: Optional[str]) -> Optional[str]:
3633
"""Return text if non-empty, otherwise None."""
3734
return text if text else None
3835

39-
def _compute_time_diffs(self, notes_with_timestamps, first_admission_time):
40-
"""Compute hourly time offsets for notes relative to first admission.
41-
42-
Sorts notes chronologically by timestamp, then computes each note's
43-
offset (in hours) from the first admission time.
44-
45-
Args:
46-
notes_with_timestamps: List of (text, timestamp) tuples where
47-
text is the clinical note string and timestamp is a datetime.
48-
first_admission_time: datetime of the patient's first admission,
49-
used as the anchor (t=0) for all time offsets.
50-
51-
Returns:
52-
Tuple of (texts, time_diffs) where:
53-
- texts: List[str] of note contents, sorted chronologically
54-
- time_diffs: List[float] of hours since first admission
55-
Returns (["<missing>"], [0.0]) if no notes are available.
56-
"""
57-
result = []
58-
59-
if not notes_with_timestamps:
60-
return (["<missing>"], [0.0]) # TODO: Need to also figure out how to tokenize missing timestamps
61-
notes_with_timestamps.sort(key=lambda x: x[1])
62-
result = [(text, (ts - first_admission_time).total_seconds() / 3600) for text, ts in notes_with_timestamps]
63-
texts, time_diffs = zip(*result)
64-
65-
return (list(texts), list(time_diffs))
66-
6736
def __call__(self, patient: Any) -> List[Dict[str, Any]]:
6837
# Get demographic info to filter by age
6938
demographics = patient.get_events(event_type="patients")
@@ -105,102 +74,62 @@ def __call__(self, patient: Any) -> List[Dict[str, Any]]:
10574
if len(admissions_to_process) == 0:
10675
return []
10776

108-
# Get first admission time as reference for notes time offset
109-
first_admission_time = admissions_to_process[0].timestamp
77+
# Aggregated notes and time offsets across all admissions (per hadm_id)
78+
all_discharge_texts: List[str] = []
79+
all_discharge_times_from_admission: List[float] = []
80+
all_radiology_texts: List[str] = []
81+
all_radiology_times_from_admission: List[float] = []
11082

111-
# Aggregated data across all admissions
112-
all_discharge_notes_timestamped = [] # List of (note_text, timestamp) tuples
113-
all_radiology_notes_timestamped = [] # List of (note_text, timestamp) tuples
114-
all_icd_codes = [] # ICD code lists per visit
115-
all_icd_times = [] # Hours from first admission per visit
116-
117-
# Process each admission and aggregate data
83+
# Process each admission independently (per hadm_id)
11884
for admission in admissions_to_process:
119-
# Parse admission discharge time for lab events filtering
120-
try:
121-
admission_dischtime = datetime.strptime(
122-
admission.dischtime, "%Y-%m-%d %H:%M:%S"
123-
)
124-
except (ValueError, AttributeError):
125-
# If we can't parse discharge time, skip this admission
126-
continue
127-
128-
# Skip if discharge is before admission (data quality issue)
129-
if admission_dischtime < admission.timestamp:
130-
continue
131-
132-
# Get notes using hadm_id filtering
85+
admission_time = admission.timestamp
86+
87+
# Get notes for this hadm_id only
13388
discharge_notes = patient.get_events(
13489
event_type="discharge", filters=[("hadm_id", "==", admission.hadm_id)]
13590
)
13691
radiology_notes = patient.get_events(
13792
event_type="radiology", filters=[("hadm_id", "==", admission.hadm_id)]
13893
)
13994

140-
# Extract and aggregate notes as individual items in lists
141-
# Note: attribute is "text" (from mimic4_note.yaml), not "discharge"/"radiology"
142-
for note in discharge_notes:
95+
for note in discharge_notes: #TODO: Maybe make this into a helper function?
14396
try:
14497
note_text = self._clean_text(note.text)
14598
if note_text:
146-
all_discharge_notes_timestamped.append((note_text, note.timestamp))
147-
except AttributeError:
99+
time_from_admission = (
100+
note.timestamp - admission_time
101+
).total_seconds() / 3600.0
102+
all_discharge_texts.append(note_text)
103+
all_discharge_times_from_admission.append(time_from_admission)
104+
except AttributeError: # note object is missing .text or .timestamp attribute (e.g. malformed note)
148105
pass
106+
if not discharge_notes: # If we get an empty list
107+
all_discharge_texts.append(self.TOKEN_REPRESENTING_MISSING_TEXT) # Token representing missing text
108+
all_discharge_times_from_admission.append(self.TOKEN_REPRESENTING_MISSING_FLOAT) # Token representing missing time(?)
149109

150-
for note in radiology_notes:
110+
for note in radiology_notes: #TODO: Maybe make this into a helper function?
151111
try:
152112
note_text = self._clean_text(note.text)
153113
if note_text:
154-
all_radiology_notes_timestamped.append((note_text, note.timestamp))
155-
except AttributeError:
114+
time_from_admission = (
115+
note.timestamp - admission_time
116+
).total_seconds() / 3600.0
117+
all_radiology_texts.append(note_text)
118+
all_radiology_times_from_admission.append(time_from_admission)
119+
except AttributeError: # note object is missing .text or .timestamp attribute (e.g. malformed note)
156120
pass
121+
if not radiology_notes: # If we receive empty list
122+
all_radiology_texts.append(self.TOKEN_REPRESENTING_MISSING_TEXT) # Token representing missing text
123+
all_radiology_times_from_admission.append(self.TOKEN_REPRESENTING_MISSING_FLOAT) # Token representing missing time(?)
157124

158-
# Get diagnosis codes for this admission using hadm_id
159-
diagnoses_icd = patient.get_events(
160-
event_type="diagnoses_icd",
161-
filters=[("hadm_id", "==", admission.hadm_id)],
162-
)
163-
visit_diagnoses = [
164-
event.icd_code
165-
for event in diagnoses_icd
166-
if hasattr(event, "icd_code") and event.icd_code
167-
]
168-
169-
# Get procedure codes for this admission using hadm_id
170-
procedures_icd = patient.get_events(
171-
event_type="procedures_icd",
172-
filters=[("hadm_id", "==", admission.hadm_id)],
173-
)
174-
visit_procedures = [
175-
event.icd_code
176-
for event in procedures_icd
177-
if hasattr(event, "icd_code") and event.icd_code
178-
]
179-
180-
# Combine diagnoses and procedures into single ICD code list
181-
visit_icd_codes = visit_diagnoses + visit_procedures
182-
183-
# Calculate time from admission start (hours)
184-
if visit_icd_codes:
185-
time_from_first = (
186-
admission.timestamp - first_admission_time
187-
).total_seconds() / 3600.0
188-
all_icd_codes.append(visit_icd_codes)
189-
all_icd_times.append(time_from_first)
190-
191-
# Convert (note_text, timestamp) tuples to (note_text, time_diff_hours) tuples
192-
discharge_note_times = self._compute_time_diffs(all_discharge_notes_timestamped, first_admission_time)
193-
radiology_note_times = self._compute_time_diffs(all_radiology_notes_timestamped, first_admission_time)
194-
195-
# icd_codes: (List[List[str]], List[float]) — codes per visit, hours from first admission
196-
icd_codes = (all_icd_codes, all_icd_times)
125+
discharge_note_times_from_admission = (all_discharge_texts, all_discharge_times_from_admission)
126+
radiology_note_times_from_admission = (all_radiology_texts, all_radiology_times_from_admission)
197127

198128
return [
199129
{
200130
"patient_id": patient.patient_id,
201-
"discharge_note_times": discharge_note_times,
202-
"radiology_note_times": radiology_note_times,
203-
"icd_codes": icd_codes,
131+
"discharge_note_times": discharge_note_times_from_admission,
132+
"radiology_note_times": radiology_note_times_from_admission,
204133
"mortality": mortality_label,
205134
}
206135
]

tests/core/test_ehr_foundational_model_mimic4.py

Lines changed: 0 additions & 56 deletions
This file was deleted.

0 commit comments

Comments
 (0)