"""Byte-range downloader.

Downloads a file in parallel chunks using a semaphore, asyncio, and aiohttp
to speed up large transfers. Originally written by BhimRaj Yadav.
"""
import asyncio
import random
import time
from typing import Dict, Optional
from urllib.parse import urljoin

import aiohttp
from aiohttp import ClientSession
from tqdm import tqdm


async def download_chunk(
    session: ClientSession,
    url: str,
    start: int,
    stop: int,
    headers: Dict[str, str],
    buffer: bytearray,
    progress_bar: tqdm,
    retries: int = 3,
):
    """Download the byte range [start, stop] of *url* into *buffer*.

    Issues an HTTP Range request and writes the response body into the
    preallocated ``buffer`` at offset ``start``. Failed attempts are retried
    up to ``retries`` times with a short randomized back-off.

    Args:
        session: Shared aiohttp session used for the request.
        url: Source URL (must support Range requests).
        start: First byte offset of the chunk (inclusive).
        stop: Last byte offset of the chunk (inclusive).
        headers: Base request headers; copied, never mutated.
        buffer: Preallocated destination for the whole file.
        progress_bar: tqdm bar advanced by the bytes actually received.
        retries: Maximum number of attempts before giving up.

    Raises:
        Exception: If the server does not answer 206 or all retries fail.
    """
    # Copy so concurrent tasks never mutate the shared header dict.
    local_headers = headers.copy()
    local_headers.update({"Range": f"bytes={start}-{stop}"})

    for attempt in range(1, retries + 1):
        try:
            async with session.get(url, headers=local_headers) as response:
                if response.status != 206:  # 206 Partial Content is expected
                    raise Exception(
                        f"Failed to download chunk {start}-{stop}: HTTP {response.status}"
                    )
                content = await response.read()
                # Write the downloaded content into the buffer at its offset.
                buffer[start : start + len(content)] = content

            # Advance by the bytes actually received (and written), so the
            # bar stays accurate even if the server sends a short body.
            progress_bar.update(len(content))
            return  # Successful download; exit the loop
        except Exception as e:
            print(f"Error downloading chunk {start}-{stop}: {e}")
            if attempt < retries:
                wait_time = random.uniform(1, 3)
                print(f"Retrying chunk {start}-{stop} in {wait_time:.2f} seconds...")
                await asyncio.sleep(wait_time)
            else:
                print(
                    f"Failed to download chunk {start}-{stop} after {retries} retries."
                )
                raise  # bare raise preserves the original traceback


async def download_file(
    url: str,
    filename: str,
    chunk_size: int,
    max_connections: int,
    headers: Optional[Dict[str, str]] = None,
):
    """Download *url* to *filename* in parallel byte-range chunks.

    Probes the file size with a HEAD request (following one redirect
    manually, since aiohttp's ``head()`` does not follow redirects by
    default), preallocates a bytearray of that size, downloads each
    ``chunk_size``-byte range concurrently (at most ``max_connections``
    in flight), then writes the assembled buffer to disk.

    Args:
        url: Source URL; must support Range requests.
        filename: Destination path for the downloaded file.
        chunk_size: Bytes per range request.
        max_connections: Maximum number of concurrent range requests.
        headers: Optional extra request headers.

    Raises:
        Exception: If the file size cannot be determined.
    """
    headers = headers or {}

    # Get total file size (handling redirects if necessary).
    async with aiohttp.ClientSession() as session:
        async with session.head(url, headers=headers) as response:
            # Handle the whole redirect family, not just 302: servers
            # commonly answer 301/303/307/308 as well.
            if response.status in (301, 302, 303, 307, 308):
                location = response.headers.get("Location")
                if location:
                    # Location may be relative (RFC 9110); resolve it
                    # against the URL we just requested.
                    url = urljoin(url, location)
                    async with session.head(url, headers=headers) as new_response:
                        if new_response.status != 200:
                            raise Exception(
                                f"Failed to get file info: HTTP {new_response.status}"
                            )
                        content_length = int(
                            new_response.headers.get("Content-Length", 0)
                        )
                else:
                    raise Exception(
                        f"Failed to get file info: HTTP {response.status} - No Location header found."
                    )
            elif response.status == 200:
                content_length = int(response.headers.get("Content-Length", 0))
            else:
                raise Exception(f"Failed to get file info: HTTP {response.status}")

    print(f"Total file size: {content_length} bytes")

    # Preallocate a bytearray for the whole file; chunks write into it
    # at their own offsets, so no reassembly step is needed.
    buffer = bytearray(content_length)

    with tqdm(
        total=content_length, unit="B", unit_scale=True, desc=filename
    ) as progress_bar:
        tasks = []
        # Semaphore caps the number of simultaneous range requests.
        semaphore = asyncio.Semaphore(max_connections)

        async with aiohttp.ClientSession() as session:
            for start in range(0, content_length, chunk_size):
                stop = min(start + chunk_size - 1, content_length - 1)

                # Default-argument binding captures start/stop per task
                # (avoids the late-binding closure pitfall).
                async def limited_download(start=start, stop=stop):
                    async with semaphore:
                        await download_chunk(
                            session, url, start, stop, headers, buffer, progress_bar
                        )

                tasks.append(asyncio.create_task(limited_download()))

            await asyncio.gather(*tasks)

        # After downloading all chunks, write the complete buffer to file.
        with open(filename, "wb") as f:
            f.write(buffer)


# Example usage:
if __name__ == "__main__":
    file_url = (
        "https://huggingface.co/microsoft/OmniParser-v2.0/resolve/main/"
        "icon_caption/model.safetensors"
    )
    output_path = "model-byte-range.safetensors"
    part_size = 1 * 1024 * 1024  # 1 MB per byte-range request
    connection_cap = 16  # cap on simultaneous HTTP connections
    print(
        f"Downloading with {connection_cap} connections and chunk size of {part_size} bytes"
    )

    started_at = time.time()
    asyncio.run(download_file(file_url, output_path, part_size, connection_cap))
    took = time.time() - started_at
    print(f"Download completed in {took:.2f} seconds.")