
Commit 1c5aa10

Tighten up lint rules
1 parent 459de93 commit 1c5aa10

3 files changed: +51 -22 lines changed


.pre-commit-config.yaml

Lines changed: 1 addition & 1 deletion
@@ -1,6 +1,6 @@
 repos:
   - repo: https://github.com/astral-sh/ruff-pre-commit
-    rev: v0.7.0
+    rev: v0.7.3
     hooks:
       - id: ruff
         args:

ollama_dl.py

Lines changed: 34 additions & 21 deletions
@@ -4,11 +4,16 @@
 import logging
 import pathlib
 import time
+from typing import AsyncIterable
 from urllib.parse import urljoin

 import httpx
 from rich.logging import RichHandler
-from rich.progress import Progress
+from rich.progress import Progress, TaskID
+
+BYTES_IN_KILOBYTE = 1024
+BYTES_IN_MEGABYTE = BYTES_IN_KILOBYTE**2
+DOWNLOAD_READ_SIZE = BYTES_IN_MEGABYTE

 log = logging.getLogger("ollama-dl")

@@ -22,16 +27,17 @@


 def get_short_hash(layer: dict) -> str:
-    assert layer["digest"].startswith("sha256:")
+    if not layer["digest"].startswith("sha256:"):
+        raise ValueError(f"Unexpected digest: {layer['digest']}")
     return layer["digest"].partition(":")[2][:12]


 def format_size(size: int) -> str:
-    if size < 1024:
+    if size < BYTES_IN_KILOBYTE:
         return f"{size} B"
-    if size < 1048576:
-        return f"{size // 1024} KB"
-    return f"{size // 1048576} MB"
+    if size < BYTES_IN_MEGABYTE:
+        return f"{size // BYTES_IN_KILOBYTE} KB"
+    return f"{size // BYTES_IN_MEGABYTE} MB"


 @dataclasses.dataclass(frozen=True)
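
The constants introduced above replace the magic numbers 1024 and 1048576 the old code hard-coded. A minimal sketch of how the refactored helpers behave (values chosen only for illustration):

    BYTES_IN_MEGABYTE == 1024**2 == 1048576  # same value the removed literals used

    format_size(512)        # -> "512 B"
    format_size(200_000)    # -> "195 KB"  (floor division by BYTES_IN_KILOBYTE)
    format_size(5_000_000)  # -> "4 MB"    (floor division by BYTES_IN_MEGABYTE)

    get_short_hash({"digest": "sha256:0123456789abcdef"})  # -> "0123456789ab"
    get_short_hash({"digest": "md5:abc"})  # now raises ValueError instead of failing an assert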
@@ -49,9 +55,9 @@ async def _inner_download(
     temp_path: pathlib.Path,
     size: int,
     progress: Progress,
-    task_id,
+    task_id: TaskID,
 ) -> None:
-    if size < 1048576:
+    if size < BYTES_IN_MEGABYTE:
         resp = await client.get(url, follow_redirects=True)
         resp.raise_for_status()
         temp_path.write_bytes(resp.content)
@@ -71,10 +77,11 @@ async def _inner_download(
         headers=headers,
         follow_redirects=True,
     ) as resp:
-        assert resp.status_code == (206 if start_offset else 200)
+        if resp.status_code != (206 if start_offset else 200):
+            raise ValueError(f"Unexpected status code: {resp.status_code}")
         resp.raise_for_status()
         with temp_path.open("ab") as f:
-            async for chunk in resp.aiter_bytes(1048576):
+            async for chunk in resp.aiter_bytes(DOWNLOAD_READ_SIZE):
                 f.write(chunk)
                 progress.update(task_id, completed=f.tell())
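
The 206-versus-200 check distinguishes resumed downloads from fresh ones: when a partial temp file already exists, the request presumably carries a Range header (built earlier in this function, outside this hunk), and a compliant server answers 206 Partial Content; with no Range header the expected status is 200. A rough sketch of that pairing, with the header construction assumed for illustration:

    # Assumed shape of the resume logic; the exact header-building code is not part of this diff.
    start_offset = temp_path.stat().st_size if temp_path.exists() else 0
    headers = {"Range": f"bytes={start_offset}-"} if start_offset else {}
    # A server honouring the Range header replies 206 (Partial Content); otherwise 200.
    # That is the condition the replaced assert, now a ValueError, verifies.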

@@ -85,7 +92,7 @@ async def download_blob(
     *,
     progress: Progress,
     num_retries: int = 10,
-):
+) -> None:
     job.dest_path.parent.mkdir(parents=True, exist_ok=True)
     task_desc = f"{job.dest_path} ({format_size(job.size)})"
     task = progress.add_task(task_desc, total=job.size)
@@ -94,7 +101,8 @@ async def download_blob(
         for attempt in range(1, num_retries + 1):
             if attempt != 1:
                 progress.update(
-                    task, description=f"{task_desc} (retry {attempt}/{num_retries})"
+                    task,
+                    description=f"{task_desc} (retry {attempt}/{num_retries})",
                 )
             try:
                 await _inner_download(
@@ -117,7 +125,11 @@ async def download_blob(
                     raise
             else:
                 break
-        assert temp_path.stat().st_size == job.size
+        result_size = temp_path.stat().st_size
+        if result_size != job.size:
+            raise RuntimeError(
+                f"Did not download expected size: {result_size} != {job.size}",
+            )
         temp_path.rename(job.dest_path)
         progress.update(task, completed=job.size)
     finally:
@@ -132,15 +144,16 @@ async def get_download_jobs_for_image(
     dest_dir: str,
     name: str,
     version: str,
-):
+) -> AsyncIterable[DownloadJob]:
     manifest_url = urljoin(registry, f"v2/{name}/manifests/{version}")
     resp = await client.get(manifest_url)
     resp.raise_for_status()
     manifest_data = resp.json()
-    assert (
-        manifest_data["mediaType"]
-        == "application/vnd.docker.distribution.manifest.v2+json"
-    )
+    manifest_media_type = manifest_data["mediaType"]
+    if manifest_media_type != "application/vnd.docker.distribution.manifest.v2+json":
+        raise ValueError(
+            f"Unexpected media type for manifest: {manifest_media_type}",
+        )
     for layer in sorted(manifest_data["layers"], key=lambda x: x["size"]):
         file_template = media_type_to_file_template.get(layer["mediaType"])
         if not file_template:
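
For context, the media type being validated here is the Docker Registry v2 image manifest. A trimmed, illustrative example of the JSON the function expects (digest, size, and the layer media type are made up for this sketch):

    manifest_data = {
        "schemaVersion": 2,
        "mediaType": "application/vnd.docker.distribution.manifest.v2+json",
        "layers": [
            {
                "mediaType": "application/vnd.ollama.image.model",  # illustrative layer type
                "digest": "sha256:0123abcd...",                     # made-up digest
                "size": 4_000_000_000,
            },
        ],
    }
    # The function sorts layers by size and only yields a DownloadJob for layers whose
    # mediaType has an entry in media_type_to_file_template.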
@@ -159,7 +172,7 @@ async def get_download_jobs_for_image(
         )


-async def download(*, registry: str, name: str, version: str, dest_dir: str):
+async def download(*, registry: str, name: str, version: str, dest_dir: str) -> None:
     with Progress() as progress:
         async with httpx.AsyncClient() as client:
             tasks = []
@@ -178,7 +191,7 @@ async def download(*, registry: str, name: str, version: str, dest_dir: str):
         await asyncio.gather(*tasks)


-def main():
+def main() -> None:
     ap = argparse.ArgumentParser()
     ap.add_argument("name")
     ap.add_argument("--registry", default="https://registry.ollama.ai/")
@@ -208,7 +221,7 @@ def main():
             name=name,
             dest_dir=dest_dir,
             version=version,
-        )
+        ),
     )


pyproject.toml

Lines changed: 16 additions & 0 deletions
@@ -15,3 +15,19 @@ ollama-dl = "ollama_dl:main"
 dev-dependencies = [
     "ruff>=0.6.7",
 ]
+
+[tool.ruff.lint]
+select = [
+    "ANN",
+    "COM812",
+    "E",
+    "F",
+    "I",
+    "W",
+]
+ignore = [
+    "D",
+    "EM102",
+    "PLR0913",
+    "TRY003",
+]
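
For reference, the selected Ruff rule groups are: ANN (flake8-annotations, which motivates the return-type annotations added in ollama_dl.py), COM812 (missing trailing comma), E and W (pycodestyle errors and warnings), F (Pyflakes), and I (isort import ordering). The ignore list drops pydocstyle (D), EM102 (f-string literals inside raised exceptions), PLR0913 (too many function arguments), and TRY003 (long messages passed outside the exception class).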
