Skip to content

Commit 7c75907

Browse files
jhnwu3 (John Wu)
authored
Add/load tsv (#542)
* Update BaseDataset to add TSV support
* Initialize test cases
---------
Co-authored-by: John Wu <[email protected]>
1 parent c563621 commit 7c75907

File tree

2 files changed

+280
-11
lines changed

2 files changed

+280
-11
lines changed

pyhealth/datasets/base_dataset.py

Lines changed: 21 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -53,32 +53,42 @@ def path_exists(path: str) -> bool:
5353
return Path(path).exists()
5454

5555

def scan_csv_gz_or_csv_tsv(path: str) -> pl.LazyFrame:
    """Scan a CSV.gz, CSV, TSV.gz, or TSV file and return a LazyFrame.

    If the given path does not exist, it falls back to the alternative
    compression variant of the same extension (gzipped <-> plain).

    Args:
        path (str): URL or local path to a .csv, .csv.gz, .tsv, or .tsv.gz file.

    Returns:
        pl.LazyFrame: The LazyFrame for the CSV.gz, CSV, TSV.gz, or TSV file.

    Raises:
        FileNotFoundError: If the path has an unexpected extension, or if
            neither the path nor its alternative-extension variant exists.
    """

    def scan_file(file_path: str) -> pl.LazyFrame:
        # Match only on the file's suffix, not a substring anywhere in the
        # path: a directory named e.g. "data.tsv/" must not force a tab
        # separator onto a .csv file inside it.
        separator = "\t" if file_path.endswith((".tsv", ".tsv.gz")) else ","
        # infer_schema=False keeps every column as string; callers handle typing.
        return pl.scan_csv(file_path, separator=separator, infer_schema=False)

    if path_exists(path):
        return scan_file(path)

    # Try the alternative compression variant of the same extension.
    if path.endswith((".csv.gz", ".tsv.gz")):
        alt_path = path[:-3]  # Remove .gz -> try .csv / .tsv
    elif path.endswith((".csv", ".tsv")):
        alt_path = f"{path}.gz"  # Add .gz -> try .csv.gz / .tsv.gz
    else:
        raise FileNotFoundError(f"Path does not have expected extension: {path}")

    if path_exists(alt_path):
        logger.info(f"Original path does not exist. Using alternative: {alt_path}")
        return scan_file(alt_path)

    raise FileNotFoundError(f"Neither path exists: {path} or {alt_path}")

81-
8292
class BaseDataset(ABC):
8393
"""Abstract base class for all PyHealth datasets.
8494
@@ -198,7 +208,7 @@ def load_table(self, table_name: str) -> pl.LazyFrame:
198208
csv_path = clean_path(csv_path)
199209

200210
logger.info(f"Scanning table: {table_name} from {csv_path}")
201-
df = scan_csv_gz_or_csv(csv_path)
211+
df = scan_csv_gz_or_csv_tsv(csv_path)
202212

203213
# Convert column names to lowercase before calling preprocess_func
204214
col_names = df.collect_schema().names()
@@ -219,7 +229,7 @@ def load_table(self, table_name: str) -> pl.LazyFrame:
219229
other_csv_path = f"{self.root}/{join_cfg.file_path}"
220230
other_csv_path = clean_path(other_csv_path)
221231
logger.info(f"Joining with table: {other_csv_path}")
222-
join_df = scan_csv_gz_or_csv(other_csv_path)
232+
join_df = scan_csv_gz_or_csv_tsv(other_csv_path)
223233
join_df = join_df.with_columns(
224234
[
225235
pl.col(col).alias(col.lower())

tests/core/test_tsv_load.py

Lines changed: 259 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,259 @@
1+
import os
2+
import tempfile
3+
import unittest
4+
from pathlib import Path
5+
6+
import polars as pl
7+
import yaml
8+
9+
from pyhealth.datasets.base_dataset import BaseDataset
10+
11+
12+
class TestTSVLoad(unittest.TestCase):
    """Test TSV loading functionality with BaseDataset.

    Builds a small synthetic dataset of three tab-separated tables
    (patients, diagnoses, procedures) plus a YAML table config in a
    temporary directory, then exercises BaseDataset loading over it.
    """

    def setUp(self):
        """Set up temporary directory and create pseudo dataset."""
        self.temp_dir = tempfile.mkdtemp()
        self._create_pseudo_dataset()
        self._create_config_file()

    def tearDown(self):
        """Clean up temporary directory."""
        import shutil

        if os.path.exists(self.temp_dir):
            shutil.rmtree(self.temp_dir)

    def _create_pseudo_dataset(self):
        """Create pseudo TSV dataset files with random data."""
        # Create patients table: 5 patients, no per-event timestamp column.
        patients_data = {
            "patient_id": ["P001", "P002", "P003", "P004", "P005"],
            "gender": ["M", "F", "M", "F", "M"],
            "age": [45, 32, 67, 28, 53],
            "admission_date": [
                "2023-01-15",
                "2023-02-20",
                "2023-03-10",
                "2023-01-25",
                "2023-04-05",
            ],
        }
        patients_df = pl.DataFrame(patients_data)
        patients_path = Path(self.temp_dir) / "patients.tsv"
        # separator="\t" makes polars write a TSV rather than a CSV.
        patients_df.write_csv(patients_path, separator="\t")

        # Create diagnoses table: 6 events (P001 has two diagnoses).
        diagnoses_data = {
            "patient_id": ["P001", "P001", "P002", "P003", "P004", "P005"],
            "diagnosis_code": ["A01.1", "B15.9", "C78.0", "D50.0", "E11.9", "F32.9"],
            "diagnosis_desc": [
                "Typhoid fever",
                "Hepatitis A",
                "Lung cancer",
                "Iron deficiency",
                "Type 2 diabetes",
                "Depression",
            ],
            "timestamp": [
                "2023-01-15 10:00",
                "2023-01-16 14:30",
                "2023-02-20 09:15",
                "2023-03-10 11:45",
                "2023-01-25 16:20",
                "2023-04-05 08:30",
            ],
        }
        diagnoses_df = pl.DataFrame(diagnoses_data)
        diagnoses_path = Path(self.temp_dir) / "diagnoses.tsv"
        diagnoses_df.write_csv(diagnoses_path, separator="\t")

        # Create procedures table: one timestamped procedure per patient.
        procedures_data = {
            "patient_id": ["P001", "P002", "P003", "P004", "P005"],
            "procedure_code": ["99213", "99214", "99215", "99213", "99214"],
            "procedure_desc": [
                "Office visit",
                "Extended visit",
                "Complex visit",
                "Office visit",
                "Extended visit",
            ],
            "timestamp": [
                "2023-01-15 11:00",
                "2023-02-20 10:30",
                "2023-03-10 12:00",
                "2023-01-25 17:00",
                "2023-04-05 09:00",
            ],
        }
        procedures_df = pl.DataFrame(procedures_data)
        procedures_path = Path(self.temp_dir) / "procedures.tsv"
        procedures_df.write_csv(procedures_path, separator="\t")

        self.patients_file = str(patients_path)
        self.diagnoses_file = str(diagnoses_path)
        self.procedures_file = str(procedures_path)

    def _create_config_file(self):
        """Create YAML configuration file for the pseudo dataset."""
        # Each table entry maps file_path/patient_id/timestamp/attributes as
        # expected by the BaseDataset YAML config schema.
        config_data = {
            "version": "1.0",
            "tables": {
                "patients": {
                    "file_path": "patients.tsv",
                    "patient_id": "patient_id",
                    # Patients table has no event timestamp column.
                    "timestamp": None,
                    "attributes": ["gender", "age", "admission_date"],
                },
                "diagnoses": {
                    "file_path": "diagnoses.tsv",
                    "patient_id": "patient_id",
                    "timestamp": "timestamp",
                    "timestamp_format": "%Y-%m-%d %H:%M",
                    "attributes": ["diagnosis_code", "diagnosis_desc"],
                },
                "procedures": {
                    "file_path": "procedures.tsv",
                    "patient_id": "patient_id",
                    "timestamp": "timestamp",
                    "timestamp_format": "%Y-%m-%d %H:%M",
                    "attributes": ["procedure_code", "procedure_desc"],
                },
            },
        }

        self.config_path = Path(self.temp_dir) / "test_config.yaml"
        with open(self.config_path, "w") as f:
            yaml.dump(config_data, f, default_flow_style=False)

    def test_tsv_load(self):
        """Test loading TSV dataset with BaseDataset and using stats() function."""
        # Test loading the dataset with different table combinations.
        tables_to_test = [
            ["patients"],
            ["diagnoses"],
            ["procedures"],
            ["patients", "diagnoses"],
            ["diagnoses", "procedures"],
            ["patients", "diagnoses", "procedures"],
        ]

        for tables in tables_to_test:
            # subTest keeps going after a failure and reports which
            # table combination failed.
            with self.subTest(tables=tables):
                # Create BaseDataset instance
                dataset = BaseDataset(
                    root=self.temp_dir,
                    tables=tables,
                    dataset_name="TestTSVDataset",
                    config_path=str(self.config_path),
                    dev=False,
                )

                # Verify the dataset was loaded
                self.assertIsNotNone(dataset.global_event_df)
                self.assertIsNotNone(dataset.config)

                # Test that we can collect the (lazy) dataframe
                collected_df = dataset.collected_global_event_df
                self.assertIsInstance(collected_df, pl.DataFrame)
                self.assertGreater(
                    collected_df.height, 0, "Dataset should have at least one row"
                )

                # Verify patient_id column exists
                self.assertIn("patient_id", collected_df.columns)

                # Test stats() function; any exception is a test failure.
                try:
                    dataset.stats()
                except Exception as e:
                    self.fail(f"dataset.stats() failed with tables {tables}: {e}")

    def test_tsv_load_dev_mode(self):
        """Test loading TSV dataset in dev mode."""
        # Create dataset in dev mode
        dataset = BaseDataset(
            root=self.temp_dir,
            tables=["patients", "diagnoses", "procedures"],
            dataset_name="TestTSVDatasetDev",
            config_path=str(self.config_path),
            dev=True,
        )

        # Verify dev mode is enabled
        self.assertTrue(dataset.dev)

        # Test stats() function in dev mode
        try:
            dataset.stats()
        except Exception as e:
            self.fail(f"dataset.stats() failed in dev mode: {e}")

    def test_tsv_file_detection(self):
        """Test that TSV files are correctly detected and loaded."""
        dataset = BaseDataset(
            root=self.temp_dir,
            tables=["patients"],
            dataset_name="TestTSVDetection",
            config_path=str(self.config_path),
            dev=False,
        )

        collected_df = dataset.collected_global_event_df

        # Verify we have the expected number of patients
        self.assertEqual(collected_df["patient_id"].n_unique(), 5)

        # Verify we have the expected columns from the patients table
        # Note: attribute columns are prefixed with table name (e.g., "patients/gender")
        expected_base_columns = ["patient_id", "event_type", "timestamp"]
        expected_patient_columns = [
            "patients/gender",
            "patients/age",
            "patients/admission_date",
        ]

        for col in expected_base_columns:
            self.assertIn(col, collected_df.columns)

        for col in expected_patient_columns:
            self.assertIn(col, collected_df.columns)

    def test_multiple_tsv_tables(self):
        """Test loading and joining multiple TSV tables."""
        dataset = BaseDataset(
            root=self.temp_dir,
            tables=["diagnoses", "procedures"],
            dataset_name="TestMultipleTSV",
            config_path=str(self.config_path),
            dev=False,
        )

        collected_df = dataset.collected_global_event_df

        # Should have data from both tables
        self.assertGreater(collected_df.height, 5)  # More than just patients table

        # Should have timestamp column since both diagnoses and procedures have timestamps
        self.assertIn("timestamp", collected_df.columns)

        # Should have both diagnosis and procedure data
        # Note: columns from different tables are prefixed with table names
        all_columns = set(collected_df.columns)

        # Check for diagnosis-specific columns (prefixed with table name)
        diagnosis_columns = {"diagnoses/diagnosis_code", "diagnoses/diagnosis_desc"}
        procedure_columns = {"procedures/procedure_code", "procedures/procedure_desc"}

        # At least some of these should be present in the concatenated result
        self.assertTrue(
            len(diagnosis_columns.intersection(all_columns)) > 0
            or len(procedure_columns.intersection(all_columns)) > 0,
            f"Expected some diagnosis or procedure columns in {all_columns}",
        )
256+
257+
# Allow running this test module directly (outside a test runner).
if __name__ == "__main__":
    unittest.main()

0 commit comments

Comments
 (0)