Skip to content

Commit 1ac5a19

Browse files
laltengemini-code-assist[bot]aignas
authored
fix: Quote all files if original RECORD had all files quoted (#3515)
When patching a single file in Pytorch, repack_whl.py will print over 20k lines of RECORD.patch. The reason is that the original RECORD has all filenames quoted for some reason, but the automatically generated one quotes only when required (such as commas in file names, see #2269). This PR refactors the wheelmaker.py to still use the csv.writer to auto-quote, but adds an additional detection for forced quote usage. This makes the RECORD.patch match intuitive expectations of what could change. There are some relevant existing tests in examples/wheel/wheel_test.py and tests/whl_filegroup/extract_wheel_files_test.py --------- Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com> Co-authored-by: Ignas Anikevicius <240938+aignas@users.noreply.github.com>
1 parent e7e659a commit 1ac5a19

File tree

7 files changed

+136
-30
lines changed

7 files changed

+136
-30
lines changed

CHANGELOG.md

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -61,6 +61,8 @@ END_UNRELEASED_TEMPLATE
6161
* (binaries/tests) The `PYTHONBREAKPOINT` environment variable is automatically inherited
6262
* (binaries/tests) The {obj}`stamp` attribute now transitions the Bazel builtin
6363
{obj}`--stamp` flag.
64+
* (pypi) RECORD file patches now follow the original RECORD's filename quoting convention (all quoted, or quoted only when required)
65+
in order to make `pytorch` and friends easier to patch.
6466

6567
{#v0-0-0-fixed}
6668
### Fixed

python/private/pypi/BUILD.bazel

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
# limitations under the License.
1414

1515
load("@bazel_skylib//:bzl_library.bzl", "bzl_library")
16+
load("//python:py_library.bzl", "py_library")
1617

1718
package(default_visibility = ["//:__subpackages__"])
1819

@@ -377,6 +378,12 @@ bzl_library(
377378
],
378379
)
379380

381+
py_library(
382+
name = "repack_whl",
383+
srcs = ["repack_whl.py"],
384+
deps = ["//tools:wheelmaker"],
385+
)
386+
380387
bzl_library(
381388
name = "requirements_files_by_platform_bzl",
382389
srcs = ["requirements_files_by_platform.bzl"],

python/private/pypi/repack_whl.py

Lines changed: 16 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -44,6 +44,16 @@
4444
_DISTINFO = "dist-info"
4545

4646

47+
def _has_all_quoted_filenames(record_contents: str) -> bool:
48+
"""Check if all filenames in the RECORD are quoted.
49+
50+
Some wheels (like torch) have all filenames quoted in their RECORD file.
51+
We detect this to preserve the quoting style when repacking.
52+
"""
53+
lines = record_contents.splitlines()
54+
return all(line.startswith('"') for line in lines)
55+
56+
4757
def _unidiff_output(expected, actual, record):
4858
"""
4959
Helper function. Returns a string containing the unified diff of two
@@ -151,17 +161,21 @@ def main(sys_argv):
151161
logging.debug(f"Found dist-info dir: {distinfo_dir}")
152162
record_path = distinfo_dir / "RECORD"
153163
record_contents = record_path.read_text() if record_path.exists() else ""
164+
quote_files = _has_all_quoted_filenames(record_contents)
154165
distribution_prefix = distinfo_dir.with_suffix("").name
155166

156167
with _WhlFile(
157-
args.output, mode="w", distribution_prefix=distribution_prefix
168+
args.output,
169+
mode="w",
170+
distribution_prefix=distribution_prefix,
171+
quote_all_filenames=quote_files,
158172
) as out:
159173
for p in _files_to_pack(patched_wheel_dir, record_contents):
160174
rel_path = p.relative_to(patched_wheel_dir)
161175
out.add_file(str(rel_path), p)
162176

163177
logging.debug(f"Writing RECORD file")
164-
got_record = out.add_recordfile().decode("utf-8", "surrogateescape")
178+
got_record = out.add_recordfile()
165179

166180
if got_record == record_contents:
167181
logging.info(f"Created a whl file: {args.output}")

tests/pypi/repack_whl/BUILD.bazel

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,8 @@
1+
load("//python:py_test.bzl", "py_test")
2+
3+
py_test(
4+
name = "repack_whl_test",
5+
size = "small",
6+
srcs = ["repack_whl_test.py"],
7+
deps = ["//python/private/pypi:repack_whl"],
8+
)
Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,37 @@
1+
import unittest
2+
3+
from python.private.pypi import repack_whl
4+
5+
6+
class HasAllQuotedFilenamesTest(unittest.TestCase):
    """Exercises the RECORD quoting-style detection in repack_whl."""

    def test_all_quoted(self) -> None:
        """A RECORD whose every line begins with a quote is detected (torch-style)."""
        contents = """\
"torch/__init__.py",sha256=abc,123
"torch/utils.py",sha256=def,456
"torch-2.0.0.dist-info/WHEEL",sha256=ghi,789
"""
        detected = repack_whl._has_all_quoted_filenames(contents)
        self.assertTrue(detected)

    def test_none_quoted(self) -> None:
        """A RECORD without any quoted line is not detected (standard style)."""
        contents = """\
torch/__init__.py,sha256=abc,123
torch/utils.py,sha256=def,456
torch-2.0.0.dist-info/WHEEL,sha256=ghi,789
"""
        detected = repack_whl._has_all_quoted_filenames(contents)
        self.assertFalse(detected)

    def test_mixed_quoting(self) -> None:
        """A RECORD that quotes only some lines is not detected."""
        contents = """\
"file,with,commas.py",sha256=abc,123
normal_file.py,sha256=def,456
"""
        detected = repack_whl._has_all_quoted_filenames(contents)
        self.assertFalse(detected)
34+
35+
36+
if __name__ == "__main__":
37+
unittest.main()

tests/tools/wheelmaker_test.py

Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,45 @@
1+
import io
12
import unittest
23

34
import tools.wheelmaker as wheelmaker
45

56

7+
class QuoteAllFilenamesTest(unittest.TestCase):
    """Tests for quote_all_filenames behavior in _WhlFile.

    Some wheels (like torch) have all filenames quoted in their RECORD file.
    When repacking, we preserve this style to minimize diffs.
    """

    def _make_whl_file(self, quote_all: bool) -> wheelmaker._WhlFile:
        """Create an in-memory _WhlFile that is closed when the test ends.

        Args:
            quote_all: Value passed as `quote_all_filenames`.

        Returns:
            A write-mode _WhlFile backed by an in-memory buffer.
        """
        buf = io.BytesIO()
        whl = wheelmaker._WhlFile(
            buf,
            mode="w",
            distribution_prefix="test-1.0.0",
            quote_all_filenames=quote_all,
        )
        # Release the archive deterministically instead of relying on garbage
        # collection.
        # NOTE(review): assumes _WhlFile exposes close() like zipfile.ZipFile,
        # which its constructor arguments suggest - confirm.
        self.addCleanup(whl.close)
        return whl

    def test_quote_all_quotes_simple_filenames(self) -> None:
        """When quote_all_filenames=True, all filenames are quoted."""
        whl = self._make_whl_file(quote_all=True)
        self.assertEqual(whl._quote_filename("foo/bar.py"), '"foo/bar.py"')

    def test_quote_all_false_leaves_simple_filenames_unquoted(self) -> None:
        """When quote_all_filenames=False, simple filenames stay unquoted."""
        whl = self._make_whl_file(quote_all=False)
        self.assertEqual(whl._quote_filename("foo/bar.py"), "foo/bar.py")

    def test_quote_all_quotes_filenames_with_commas(self) -> None:
        """Filenames with commas are always quoted, regardless of quote_all_filenames."""
        for quote_all in (True, False):
            # subTest reports which flag value failed instead of stopping at
            # the first failing assertion.
            with self.subTest(quote_all=quote_all):
                whl = self._make_whl_file(quote_all=quote_all)
                self.assertEqual(
                    whl._quote_filename("foo,bar/baz.py"), '"foo,bar/baz.py"'
                )
41+
42+
643
class ArcNameFromTest(unittest.TestCase):
744
def test_arcname_from(self) -> None:
845
# (name, distribution_prefix, strip_path_prefixes, want) tuples

tools/wheelmaker.py

Lines changed: 29 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -132,13 +132,17 @@ def __init__(
132132
distribution_prefix: str,
133133
strip_path_prefixes=None,
134134
compression=zipfile.ZIP_DEFLATED,
135+
quote_all_filenames: bool = False,
135136
**kwargs,
136137
):
137138
self._distribution_prefix = distribution_prefix
138139

139140
self._strip_path_prefixes = strip_path_prefixes or []
140-
# Entries for the RECORD file as (filename, hash, size) tuples.
141-
self._record = []
141+
# Entries for the RECORD file as (filename, digest, size) tuples.
142+
self._record: list[tuple[str, str, str]] = []
143+
# Whether to quote filenames in the RECORD file (for compatibility with
144+
# some wheels like torch that have quoted filenames in their RECORD).
145+
self.quote_all_filenames = quote_all_filenames
142146

143147
super().__init__(filename, mode=mode, compression=compression, **kwargs)
144148

@@ -192,16 +196,15 @@ def add_string(self, filename, contents):
192196
hash.update(contents)
193197
self._add_to_record(filename, self._serialize_digest(hash), len(contents))
194198

195-
def _serialize_digest(self, hash):
199+
def _serialize_digest(self, hash) -> str:
196200
# https://www.python.org/dev/peps/pep-0376/#record
197201
# "base64.urlsafe_b64encode(digest) with trailing = removed"
198202
digest = base64.urlsafe_b64encode(hash.digest())
199203
digest = b"sha256=" + digest.rstrip(b"=")
200-
return digest
204+
return digest.decode("utf-8", "surrogateescape")
201205

202-
def _add_to_record(self, filename, hash, size):
203-
size = str(size).encode("ascii")
204-
self._record.append((filename, hash, size))
206+
def _add_to_record(self, filename: str, hash: str, size: int) -> None:
207+
self._record.append((filename, hash, str(size)))
205208

206209
def _zipinfo(self, filename):
207210
"""Construct deterministic ZipInfo entry for a file named filename"""
@@ -223,29 +226,27 @@ def _zipinfo(self, filename):
223226
zinfo.compress_type = self.compression
224227
return zinfo
225228

226-
def add_recordfile(self):
229+
def _quote_filename(self, filename: str) -> str:
230+
"""Return a possibly quoted filename for RECORD file."""
231+
filename = filename.lstrip("/")
232+
# Some RECORDs like torch have *all* filenames quoted and we must minimize diff.
233+
# Otherwise, we quote only when necessary (e.g. for filenames with commas).
234+
quoting = csv.QUOTE_ALL if self.quote_all_filenames else csv.QUOTE_MINIMAL
235+
with io.StringIO() as buf:
236+
csv.writer(buf, quoting=quoting).writerow([filename])
237+
return buf.getvalue().strip()
238+
239+
def add_recordfile(self) -> str:
227240
"""Write RECORD file to the distribution."""
228241
record_path = self.distinfo_path("RECORD")
229-
entries = self._record + [(record_path, b"", b"")]
230-
with io.StringIO() as contents_io:
231-
writer = csv.writer(contents_io, lineterminator="\n")
232-
for filename, digest, size in entries:
233-
if isinstance(filename, str):
234-
filename = filename.lstrip("/")
235-
writer.writerow(
236-
(
237-
(
238-
c
239-
if isinstance(c, str)
240-
else c.decode("utf-8", "surrogateescape")
241-
)
242-
for c in (filename, digest, size)
243-
)
244-
)
245-
246-
contents = contents_io.getvalue()
247-
self.add_string(record_path, contents)
248-
return contents.encode("utf-8", "surrogateescape")
242+
entries = self._record + [(record_path, "", "")]
243+
entries = [
244+
(self._quote_filename(fname), digest, size)
245+
for fname, digest, size in entries
246+
]
247+
contents = "\n".join(",".join(entry) for entry in entries) + "\n"
248+
self.add_string(record_path, contents)
249+
return contents
249250

250251

251252
class WheelMaker(object):

0 commit comments

Comments
 (0)