Source code for neurotic.datasets.download

# -*- coding: utf-8 -*-
"""
The :mod:`neurotic.datasets.download` module implements a general purpose
download function that handles connecting to remote servers, performing
authentication, and downloading files with progress reporting. The function
handles various errors and will automatically reprompt the user for login
credentials if a bad user name or password is given.

The module installs an :class:`urllib.request.HTTPBasicAuthHandler` and a
:class:`neurotic.datasets.ftpauth.FTPBasicAuthHandler` at import time.

.. autofunction:: download
"""

import os
import shutil
import urllib
from getpass import getpass
import numpy as np
from tqdm.auto import tqdm

from urllib.request import HTTPBasicAuthHandler, HTTPPasswordMgrWithDefaultRealm
from .. import global_config
from ..datasets.ftpauth import FTPBasicAuthHandler
from ..datasets.gdrive import GoogleDriveDownloader

import logging
logger = logging.getLogger(__name__)


# install HTTP and FTP authentication handlers, the latter of which also adds
# reliable file size retrieval before downloading

# number of rejected logins tolerated by _authenticate before it gives up
_max_bad_login_attempts = 3

# handlers are kept as module-level objects so that _authenticate can register
# user-supplied credentials on them via add_password
_http_auth_handler = HTTPBasicAuthHandler(HTTPPasswordMgrWithDefaultRealm())
_ftp_auth_handler = FTPBasicAuthHandler()

# installing the opener makes every later urllib.request.urlopen call in this
# process go through the two handlers above (this happens at import time)
_opener = urllib.request.build_opener(_http_auth_handler, _ftp_auth_handler)
urllib.request.install_opener(_opener)

# shared downloader used by download() for URLs with the 'gdrive' scheme,
# configured from the 'gdrive' section of the global configuration
gdrive_downloader = GoogleDriveDownloader(
    client_secret_file=global_config['gdrive']['client_secret_file'],
    tokens_file=global_config['gdrive']['tokens_file'],
    save_tokens=global_config['gdrive']['save_tokens'],
)


def _download_error_code(error):
    """
    Extract a numeric error code from an exception raised while downloading.

    :param error: the :class:`urllib.error.URLError` (or
        :class:`urllib.error.HTTPError`) that was caught
    :return: the HTTP status code, the leading FTP reply code or OS error
        number embedded in an ``'ftp error: ...'`` reason string, or the
        ``errno`` of the underlying exception; ``None`` if no numeric code
        can be determined
    """

    if isinstance(error, urllib.error.HTTPError):
        return error.code

    reason = error.reason
    if isinstance(reason, str):
        # special cases for ftp errors, whose reason is the repr of the
        # original exception, e.g. "ftp error: error_perm('550 No such file')"
        # or "ftp error: TimeoutError(10060, 'A connection attempt failed')"
        for prefix in ('ftp error: error_perm(', 'ftp error: TimeoutError('):
            if reason.startswith(prefix):
                text = reason[len(prefix):].lstrip('\'"')
                # take only the leading run of digits so that codes of any
                # length are parsed and malformed text cannot raise ValueError
                digits = text[:len(text) - len(text.lstrip('0123456789'))]
                return int(digits) if digits else None
        return None

    # the reason is normally an OSError, but errno may be None (e.g., for
    # socket timeouts) or missing entirely on other exception types
    code = getattr(reason, 'errno', None)
    return code if isinstance(code, int) else None


def download(url, local_file, overwrite_existing=False, show_progress=True, bytes_per_chunk=1024*8):
    """
    Download a file.

    URLs with the ``gdrive`` scheme are delegated to the module's
    ``gdrive_downloader``; note that for these URLs ``overwrite_existing`` and
    ``bytes_per_chunk`` are not forwarded (a fixed chunk size of 5 MiB is
    used). All other URLs are fetched with :mod:`urllib`, authenticating
    interactively if the server requires it.

    :param url: the location of the remote file
    :param local_file: the path where the file should be saved
    :param overwrite_existing: if ``False`` and ``local_file`` already exists,
        the download is skipped
    :param show_progress: whether to display a progress bar
    :param bytes_per_chunk: the number of bytes requested per read
    :return: the result of ``gdrive_downloader.download`` for ``gdrive`` URLs,
        otherwise ``None``
    :raises urllib.error.URLError: if the download fails; the problem is
        logged before the original exception is re-raised
    """

    if urllib.parse.urlparse(url).scheme == 'gdrive':
        return gdrive_downloader.download(url, local_file, show_progress=show_progress, bytes_per_chunk=1024*1024*5)

    basename = os.path.basename(local_file)

    if not overwrite_existing and os.path.exists(local_file):
        logger.info(f'Skipping {basename} (already exists)')
        return

    logger.info(f'Downloading {basename}')
    try:
        _download_with_progress_bar(url, local_file, show_progress=show_progress, bytes_per_chunk=bytes_per_chunk)

    except urllib.error.URLError as error:
        # HTTPError is a subclass of URLError, so both are handled here
        error_code = _download_error_code(error)

        if error_code == 404:
            # not found
            logger.error(f'Skipping {basename} (not found on server)')

        elif error_code == 550:
            # no such file or folder, or permission denied
            logger.error(f'Skipping {basename} (not found on server, or user is unauthorized)')

        elif error_code == 10060:
            # timeout
            hostname = urllib.parse.urlparse(url).hostname
            logger.error(f'Skipping {basename} (timed out when connecting to {hostname})')

        elif error_code == 11001:
            # could not reach server or resolve hostname
            hostname = urllib.parse.urlparse(url).hostname
            logger.error(f'Skipping {basename} (cannot connect to {hostname})')

        else:
            # any other failure, including those without a numeric code,
            # which must not be silently ignored
            logger.error(f'Encountered a problem: {error}')

        raise
def _download_with_progress_bar(url, local_file, show_progress=True, bytes_per_chunk=1024*8):
    """
    Authenticate if necessary, then download while showing a progress bar.

    The data is first written to ``local_file + '.part'`` and moved to
    ``local_file`` only after the download completes; the partial file is
    deleted if anything goes wrong. If authentication is required but the
    login is aborted, nothing is downloaded and the function returns quietly.

    :param url: the location of the remote file
    :param local_file: the path where the file should be saved
    :param show_progress: whether to display a progress bar
    :param bytes_per_chunk: the number of bytes requested per read
    """

    if _auth_needed(url) and not _authenticate(url):
        # login was aborted, which _authenticate has already logged
        return

    # determine where to temporarily save the file during download
    temp_file = local_file + '.part'
    logger.debug(f'Temporarily downloading to {temp_file}')

    # create the containing directory if necessary (a bare file name has no
    # directory component, and os.makedirs('') would raise)
    local_dir = os.path.dirname(local_file)
    if local_dir:
        os.makedirs(local_dir, exist_ok=True)

    try:
        with urllib.request.urlopen(urllib.parse.quote(url, safe='/:')) as dist:
            with open(temp_file, 'wb') as f:
                pbar = None
                if show_progress:
                    if 'Content-Length' in dist.headers:
                        # knowing the file size allows progress to be displayed
                        total = int(dist.headers['Content-Length'])
                    else:
                        # progress can't be displayed, but other stats can be
                        total = None
                    pbar = tqdm(total=total, unit='B', unit_scale=True)

                try:
                    # read until the server signals the end with empty bytes
                    for chunk in iter(lambda: dist.read(bytes_per_chunk), b''):
                        f.write(chunk)
                        if pbar is not None:
                            # count the bytes actually received, since reads
                            # may return fewer bytes than requested
                            pbar.update(len(chunk))
                finally:
                    # always release the progress bar, even on failure
                    if pbar is not None:
                        pbar.close()

    except BaseException:
        # the download is likely incomplete, so delete the temporary file
        # (this includes interruption by the user)
        if os.path.exists(temp_file):
            os.remove(temp_file)

        # raise the exception so that it can be handled elsewhere
        raise

    else:
        # download completed, so move the temp file to the final location
        shutil.move(temp_file, local_file)


def _auth_error_code(error):
    """
    Extract a numeric error code from an exception raised while connecting.

    :param error: the :class:`urllib.error.URLError` (or
        :class:`urllib.error.HTTPError`) that was caught
    :return: the HTTP status code, the leading FTP reply code embedded in an
        ``"ftp error: error_perm('NNN ...')"`` reason string, or the ``errno``
        of the underlying exception; ``None`` if no numeric code can be
        determined
    """

    if isinstance(error, urllib.error.HTTPError):
        return error.code

    reason = error.reason
    if isinstance(reason, str):
        # special case for ftp errors, whose reason is the repr of the
        # original exception
        prefix = 'ftp error: error_perm('
        if not reason.startswith(prefix):
            return None
        text = reason[len(prefix):].lstrip('\'"')
        # take only the leading run of digits so that malformed text cannot
        # raise ValueError and hide the original error
        digits = text[:len(text) - len(text.lstrip('0123456789'))]
        return int(digits) if digits else None

    # the reason is normally an OSError, but errno may be None (e.g., for
    # socket timeouts) or missing entirely on other exception types
    code = getattr(reason, 'errno', None)
    return code if isinstance(code, int) else None


def _auth_needed(url):
    """
    Determine whether authentication is needed by attempting to make a
    connection.

    :param url: the location of the remote file
    :return: ``False`` if the connection succeeded, ``True`` if the server
        reported that the user is unauthorized (HTTP 401, FTP 530 or 553)
    :raises urllib.error.URLError: if the connection failed for any other
        reason
    """

    # escape spaces and other unsafe characters
    url = urllib.parse.quote(url, safe='/:')

    try:
        # try to connect
        with urllib.request.urlopen(url):
            # if connection was successful, additional auth is not needed
            return False

    except urllib.error.URLError as error:
        if _auth_error_code(error) in (401, 530, 553):
            # unauthorized
            return True

        # every other failure, including one without a numeric code, is
        # passed on to the caller rather than being mistaken for success
        raise


def _authenticate(url):
    """
    Perform HTTP or FTP authentication.

    The user is prompted for a user name and password, which are registered
    with the matching module-level authentication handler, until a connection
    succeeds, the user enters an empty user name, or too many logins are
    rejected.

    :param url: the location of the remote file
    :return: ``True`` if a connection was made, ``False`` if the login was
        aborted
    :raises urllib.error.URLError: if the connection failed for a reason
        other than the user being unauthorized
    """

    # escape spaces and other unsafe characters
    url = urllib.parse.quote(url, safe='/:')

    parsed_url = urllib.parse.urlparse(url)
    netloc = parsed_url.netloc
    host = parsed_url.hostname or netloc

    user = None
    bad_login_attempts = 0
    while True:
        try:
            # try to connect
            with urllib.request.urlopen(url):
                # if connection was successful, authentication is done
                return True

        except urllib.error.URLError as error:
            error_code = _auth_error_code(error)
            if error_code == 401:
                # unauthorized -- will try to authenticate with http handler
                handler = _http_auth_handler
            elif error_code in (530, 553):
                # unauthorized -- will try to authenticate with ftp handler
                handler = _ftp_auth_handler
            else:
                # not an authorization problem (or no numeric code at all),
                # so retrying with new credentials cannot help
                raise

        if bad_login_attempts >= _max_bad_login_attempts:
            logger.error('Unauthorized: Aborting login')
            return False

        if bad_login_attempts == 0:
            logger.info('Authentication required')
        else:
            logger.warning(f'Failed login ({bad_login_attempts} of '
                           f'{_max_bad_login_attempts}): Bad login credentials, or '
                           f'else user {user} does not have permission to access '
                           f'{url}')
        bad_login_attempts += 1

        user = input(f'User name on {host}: ')
        if not user:
            logger.error('No user given, aborting login')
            return False
        passwd = getpass('Password: ')
        handler.add_password(None, netloc, user, passwd)