44import logging
55import pathlib
66import time
7+ from typing import AsyncIterable
78from urllib .parse import urljoin
89
910import httpx
1011from rich .logging import RichHandler
11- from rich .progress import Progress
12+ from rich .progress import Progress , TaskID
13+
14+ BYTES_IN_KILOBYTE = 1024
15+ BYTES_IN_MEGABYTE = BYTES_IN_KILOBYTE ** 2
16+ DOWNLOAD_READ_SIZE = BYTES_IN_MEGABYTE
1217
1318log = logging .getLogger ("ollama-dl" )
1419
2227
2328
2429def get_short_hash (layer : dict ) -> str :
25- assert layer ["digest" ].startswith ("sha256:" )
30+ if not layer ["digest" ].startswith ("sha256:" ):
31+ raise ValueError (f"Unexpected digest: { layer ['digest' ]} " )
2632 return layer ["digest" ].partition (":" )[2 ][:12 ]
2733
2834
2935def format_size (size : int ) -> str :
30- if size < 1024 :
36+ if size < BYTES_IN_KILOBYTE :
3137 return f"{ size } B"
32- if size < 1048576 :
33- return f"{ size // 1024 } KB"
34- return f"{ size // 1048576 } MB"
38+ if size < BYTES_IN_MEGABYTE :
39+ return f"{ size // BYTES_IN_KILOBYTE } KB"
40+ return f"{ size // BYTES_IN_MEGABYTE } MB"
3541
3642
3743@dataclasses .dataclass (frozen = True )
@@ -49,9 +55,9 @@ async def _inner_download(
4955 temp_path : pathlib .Path ,
5056 size : int ,
5157 progress : Progress ,
52- task_id ,
58+ task_id : TaskID ,
5359) -> None :
54- if size < 1048576 :
60+ if size < BYTES_IN_MEGABYTE :
5561 resp = await client .get (url , follow_redirects = True )
5662 resp .raise_for_status ()
5763 temp_path .write_bytes (resp .content )
@@ -71,10 +77,11 @@ async def _inner_download(
7177 headers = headers ,
7278 follow_redirects = True ,
7379 ) as resp :
74- assert resp .status_code == (206 if start_offset else 200 )
80+ if resp .status_code != (206 if start_offset else 200 ):
81+ raise ValueError (f"Unexpected status code: { resp .status_code } " )
7582 resp .raise_for_status ()
7683 with temp_path .open ("ab" ) as f :
77- async for chunk in resp .aiter_bytes (1048576 ):
84+ async for chunk in resp .aiter_bytes (DOWNLOAD_READ_SIZE ):
7885 f .write (chunk )
7986 progress .update (task_id , completed = f .tell ())
8087
@@ -85,7 +92,7 @@ async def download_blob(
8592 * ,
8693 progress : Progress ,
8794 num_retries : int = 10 ,
88- ):
95+ ) -> None :
8996 job .dest_path .parent .mkdir (parents = True , exist_ok = True )
9097 task_desc = f"{ job .dest_path } ({ format_size (job .size )} )"
9198 task = progress .add_task (task_desc , total = job .size )
@@ -94,7 +101,8 @@ async def download_blob(
94101 for attempt in range (1 , num_retries + 1 ):
95102 if attempt != 1 :
96103 progress .update (
97- task , description = f"{ task_desc } (retry { attempt } /{ num_retries } )"
104+ task ,
105+ description = f"{ task_desc } (retry { attempt } /{ num_retries } )" ,
98106 )
99107 try :
100108 await _inner_download (
@@ -117,7 +125,11 @@ async def download_blob(
117125 raise
118126 else :
119127 break
120- assert temp_path .stat ().st_size == job .size
128+ result_size = temp_path .stat ().st_size
129+ if result_size != job .size :
130+ raise RuntimeError (
131+ f"Did not download expected size: { result_size } != { job .size } " ,
132+ )
121133 temp_path .rename (job .dest_path )
122134 progress .update (task , completed = job .size )
123135 finally :
@@ -132,15 +144,16 @@ async def get_download_jobs_for_image(
132144 dest_dir : str ,
133145 name : str ,
134146 version : str ,
135- ):
147+ ) -> AsyncIterable [ DownloadJob ] :
136148 manifest_url = urljoin (registry , f"v2/{ name } /manifests/{ version } " )
137149 resp = await client .get (manifest_url )
138150 resp .raise_for_status ()
139151 manifest_data = resp .json ()
140- assert (
141- manifest_data ["mediaType" ]
142- == "application/vnd.docker.distribution.manifest.v2+json"
143- )
152+ manifest_media_type = manifest_data ["mediaType" ]
153+ if manifest_media_type != "application/vnd.docker.distribution.manifest.v2+json" :
154+ raise ValueError (
155+ f"Unexpected media type for manifest: { manifest_media_type } " ,
156+ )
144157 for layer in sorted (manifest_data ["layers" ], key = lambda x : x ["size" ]):
145158 file_template = media_type_to_file_template .get (layer ["mediaType" ])
146159 if not file_template :
@@ -159,7 +172,7 @@ async def get_download_jobs_for_image(
159172 )
160173
161174
162- async def download (* , registry : str , name : str , version : str , dest_dir : str ):
175+ async def download (* , registry : str , name : str , version : str , dest_dir : str ) -> None :
163176 with Progress () as progress :
164177 async with httpx .AsyncClient () as client :
165178 tasks = []
@@ -178,7 +191,7 @@ async def download(*, registry: str, name: str, version: str, dest_dir: str):
178191 await asyncio .gather (* tasks )
179192
180193
181- def main ():
194+ def main () -> None :
182195 ap = argparse .ArgumentParser ()
183196 ap .add_argument ("name" )
184197 ap .add_argument ("--registry" , default = "https://registry.ollama.ai/" )
@@ -208,7 +221,7 @@ def main():
208221 name = name ,
209222 dest_dir = dest_dir ,
210223 version = version ,
211- )
224+ ),
212225 )
213226
214227
0 commit comments