diff --git a/srsly/_json_api.py b/srsly/_json_api.py
index 24d25fd..46d8140 100644
--- a/srsly/_json_api.py
+++ b/srsly/_json_api.py
@@ -1,4 +1,4 @@
-from typing import Union, Iterable, Sequence, Any, Optional, Iterator
+from typing import Any, Iterable, Dict, List, Optional, Iterator, Sequence, Union, Type, cast
 import sys
 import json as _builtin_json
 import gzip
@@ -39,6 +39,32 @@ def json_loads(data: Union[str, bytes]) -> JSONOutput:
     return ujson.loads(data)
 
 
+def json_loads_dict(data: Union[str, bytes]) -> Dict[str, JSONOutput]:
+    """Deserialize unicode or bytes to a Python dict.
+
+    data (str / bytes): The data to deserialize.
+    RAISES: ValueError if the loaded data is not a dict.
+    RETURNS: The deserialized Python dict.
+    """
+    obj = json_loads(data)
+    if not isinstance(obj, dict):
+        raise ValueError("JSON data could not be parsed to a dict.")
+    return obj
+
+
+def json_loads_list(data: Union[str, bytes]) -> List[JSONOutput]:
+    """Deserialize unicode or bytes to a Python list.
+
+    data (str / bytes): The data to deserialize.
+    RAISES: ValueError if the loaded data is not a list.
+    RETURNS: The deserialized Python list.
+    """
+    loaded = json_loads(data)
+    if not isinstance(loaded, list):
+        raise ValueError("JSON data could not be parsed to a list.")
+    return loaded
+
+
 def read_json(path: FilePath) -> JSONOutput:
     """Load JSON from file or standard input.
 
@@ -53,6 +79,60 @@ def read_json(path: FilePath) -> JSONOutput:
         return ujson.load(f)
 
 
+def read_json_dict(path: FilePath) -> Dict[str, JSONOutput]:
+    """Load JSON from file or standard input and validate it is a dict.
+
+    path (FilePath): The file path. "-" for reading from stdin.
+    RAISES: ValueError if the loaded data is not a dict.
+    RETURNS (Dict[str, JSONOutput]): The loaded JSON content.
+    """
+    data = read_json(path)
+    if not isinstance(data, dict):
+        raise ValueError("JSON data could not be parsed to a dict.")
+    return data
+
+
+def read_json_list(path: FilePath) -> List[JSONOutput]:
+    """Load JSON from file or standard input and validate it is a list.
+
+    path (FilePath): The file path. "-" for reading from stdin.
+    RAISES: ValueError if the loaded data is not a list.
+    RETURNS (List[JSONOutput]): The loaded JSON content.
+    """
+    data = read_json(path)
+    if not isinstance(data, list):
+        raise ValueError("JSON data could not be parsed to a list.")
+    return data
+
+
+def read_json_list_of_dicts(
+    path: FilePath, skip_invalid: bool = False
+) -> List[Dict[str, JSONOutput]]:
+    """Load JSON from file or standard input and validate it is a list of dicts.
+
+    path (FilePath): The file path. "-" for reading from stdin.
+    skip_invalid (bool): Drop list items that are not dicts instead of
+        raising ValueError.
+    RAISES: ValueError if the loaded data is not a list, or if an item is not
+        a dict and skip_invalid is False.
+    RETURNS (List[Dict[str, JSONOutput]]): The dict items of the loaded list.
+    """
+    data = read_json(path)
+    if not isinstance(data, list):
+        raise ValueError("JSON data could not be parsed to a list.")
+    output = []
+    for i, obj in enumerate(data):
+        if not isinstance(obj, dict):
+            if skip_invalid:
+                continue
+            # Report the real 0-based list index so it matches data[i].
+            raise ValueError(
+                f"JSON object at index: {i} of list could not be parsed to a valid dict."
+            )
+        output.append(obj)
+    return output
+
+
 def read_gzip_json(path: FilePath) -> JSONOutput:
     """Load JSON from a gzipped file.
 
@@ -149,6 +229,30 @@ def read_jsonl(path: FilePath, skip: bool = False) -> Iterable[JSONOutput]:
                 yield line
 
 
+def read_jsonl_dicts(
+    path: FilePath, skip: bool = False
+) -> Iterable[Dict[str, JSONOutput]]:
+    """Read a .jsonl file or standard input and yield contents line by line.
+    Blank lines will always be skipped. Validates the contents of each line is a dict.
+
+    path (FilePath): The file path. "-" for reading from stdin.
+    skip (bool): Skip broken lines and non-dict lines instead of raising ValueError.
+    RAISES: ValueError if a line is not a dict and skip is False.
+    YIELDS (Dict[str, JSONOutput]): The loaded JSON contents of each line.
+    """
+    # The counter enumerates what read_jsonl yields; blank lines are skipped
+    # upstream, so it is a record position, not a physical file line number.
+    for i, line in enumerate(read_jsonl(path, skip=skip)):
+        if not isinstance(line, dict):
+            if skip:
+                continue
+            raise ValueError(
+                f"Invalid JSON object at record: {i + 1} (blank lines are not counted). "
+                "Record is not a valid dict."
+            )
+        yield line
+
+
 def write_jsonl(
     path: FilePath,
     lines: Iterable[JSONInput],
diff --git a/srsly/tests/test_json_api.py b/srsly/tests/test_json_api.py
index 89ce400..1764c83 100644
--- a/srsly/tests/test_json_api.py
+++ b/srsly/tests/test_json_api.py
@@ -4,13 +4,19 @@
 import gzip
 import numpy
 
+from typing import Any, Dict, List, Union
+
 from .._json_api import (
+    JSONOutput,
     read_json,
+    read_json_dict,
+    read_json_list,
+    read_jsonl_dicts,
     write_json,
     read_jsonl,
     write_jsonl,
     read_gzip_jsonl,
     write_gzip_jsonl,
 )
 from .._json_api import write_gzip_json, json_dumps, is_json_serializable
 from .._json_api import json_loads
@@ -262,3 +268,36 @@ def test_read_jsonl_gzip():
     assert len(data[1]) == 1
     assert data[0]["hello"] == "world"
     assert data[1]["test"] == 123
+
+
+READ_JSONL_DICTS_TEST_CASES = {
+    "invalid_str": ('"test"', ValueError()),
+    "invalid_num": ('-32', ValueError()),
+    "invalid_json_list": ('[{"hello": "world"}\n{"test": 123}]', ValueError()),
+    "valid_dicts": (
+        '{"hello": "world"}\n{"test": 123}',
+        [{"hello": "world"}, {"test": 123}],
+    ),
+}
+
+
+@pytest.mark.parametrize(
+    "file_contents, expected",
+    list(READ_JSONL_DICTS_TEST_CASES.values()),
+    ids=list(READ_JSONL_DICTS_TEST_CASES),
+)
+def test_read_jsonl_dicts(
+    file_contents: str, expected: Union[List[Dict[str, JSONOutput]], ValueError]
+):
+    with make_tempdir({"tmp.json": file_contents}) as temp_dir:
+        file_path = temp_dir / "tmp.json"
+        assert file_path.exists()
+        data = read_jsonl_dicts(file_path)
+        # Make sure this returns a generator, not just a list
+        assert not hasattr(data, "__len__")
+        if isinstance(expected, ValueError):
+            # Errors only surface once the generator is actually consumed.
+            with pytest.raises(ValueError):
+                list(data)
+        else:
+            assert list(data) == expected