# Source code for unimol_tools.weights.weighthub

import os

from ..utils import logger

# huggingface_hub is an optional dependency. When it cannot be imported, a
# stand-in ``snapshot_download`` is installed so that importing this module
# still succeeds and the failure is only reported when a download is requested.
try:
    from huggingface_hub import snapshot_download
except ImportError:
    # Only a failed import selects the fallback; anything else (including
    # KeyboardInterrupt / SystemExit) propagates instead of being swallowed.
    huggingface_hub_installed = False

    def snapshot_download(*args, **kwargs):
        """Stand-in used when huggingface_hub is unavailable.

        Accepts any arguments and always raises.

        :raises ImportError: always, with installation / manual download hints.
        """
        raise ImportError(
            'huggingface_hub is not installed. If weights are not avaliable, please install it by running: pip install huggingface_hub. Otherwise, please download the weights manually from https://huggingface.co/dptech/Uni-Mol-Models'
        )
else:
    # Bind the flag on the success path too, so it can always be read.
    huggingface_hub_installed = True


# Absolute path of the directory that contains this module; used as the
# weight directory when UNIMOL_WEIGHT_DIR is not set.
DEFAULT_WEIGHT_DIR = os.path.split(os.path.abspath(__file__))[0]

def get_weight_dir():
    """Return the directory where weights should be stored.

    The value of the ``UNIMOL_WEIGHT_DIR`` environment variable is returned
    whenever that variable is present; otherwise ``DEFAULT_WEIGHT_DIR`` is
    returned.
    """
    try:
        return os.environ["UNIMOL_WEIGHT_DIR"]
    except KeyError:
        return DEFAULT_WEIGHT_DIR
# Mirror endpoint used for the single retry when the first download fails.
HF_MIRROR = "https://hf-mirror.com"


def _snapshot_download_with_fallback(**kwargs):
    """Try downloading with the current endpoint and fall back to the mirror.

    The mirror is only tried when the first attempt fails and the user has
    neither set ``HF_ENDPOINT`` in the environment nor passed an explicit
    ``endpoint`` keyword argument.

    :param kwargs: keyword arguments forwarded unchanged to
        ``snapshot_download`` on the first attempt.
    :return: whatever ``snapshot_download`` returns.
    :raises ImportError: re-raised immediately; retrying against a mirror
        cannot help when the download function itself is unavailable.
    :raises Exception: the first failure when the endpoint was chosen by the
        user, otherwise whatever the mirror retry raises.
    """
    user_set = "HF_ENDPOINT" in os.environ
    try:
        return snapshot_download(**kwargs)
    except ImportError:
        raise
    except Exception as e:
        if user_set or kwargs.get("endpoint") is not None:
            raise
        logger.warning(
            f"Download failed from Hugging Face: {e}. Retrying with {HF_MIRROR}"
        )
        # Kept for code that reads the variable later (e.g. child processes).
        os.environ["HF_ENDPOINT"] = HF_MIRROR
        # huggingface_hub resolves HF_ENDPOINT once at import time, so the
        # assignment above does not redirect an already-imported library.
        # Pass the mirror explicitly so the retry really targets it.
        # NOTE(review): relies on snapshot_download accepting ``endpoint``;
        # confirm the minimum supported huggingface_hub version.
        retry_kwargs = dict(kwargs)
        retry_kwargs["endpoint"] = HF_MIRROR
        return snapshot_download(**retry_kwargs)
def log_weights_dir():
    """
    Report which directory the weights are stored in.

    A warning is logged when ``UNIMOL_WEIGHT_DIR`` is present in the
    environment; otherwise an info message names the default directory.
    """
    weight_dir = get_weight_dir()
    if 'UNIMOL_WEIGHT_DIR' not in os.environ:
        logger.info(f'Weights will be downloaded to default directory: {weight_dir}')
        return
    logger.warning(
        f'Using custom weight directory from UNIMOL_WEIGHT_DIR: {weight_dir}'
    )
def weight_download(pretrain, save_path, local_dir_use_symlinks=True):
    """
    Fetch one set of pretrained weights from the ``dptech/Uni-Mol-Models`` repo.

    The download is skipped when ``save_path/pretrain`` already exists.

    :param pretrain: (str), The name of the pretrained model to download; it is
        also used as the ``allow_patterns`` filter of the download.
    :param save_path: (str), The directory where the weights should be saved.
    :param local_dir_use_symlinks: (bool, optional), Accepted but currently not
        forwarded to the download call. Defaults to True.
    """
    log_weights_dir()

    target = os.path.join(save_path, pretrain)
    if not os.path.exists(target):
        logger.info(f'Downloading {pretrain}')
        _snapshot_download_with_fallback(
            repo_id="dptech/Uni-Mol-Models",
            local_dir=save_path,
            allow_patterns=pretrain,
        )
        return

    logger.info(f'{pretrain} exists in {save_path}')
def weight_download_v2(pretrain, save_path, local_dir_use_symlinks=True):
    """
    Fetch one set of pretrained weights from the ``dptech/Uni-Mol2`` repo.

    The download is skipped when ``save_path/pretrain`` already exists.

    :param pretrain: (str), The name of the pretrained model to download; it is
        also used as the ``allow_patterns`` filter of the download.
    :param save_path: (str), The directory where the weights should be saved.
    :param local_dir_use_symlinks: (bool, optional), Accepted but currently not
        forwarded to the download call. Defaults to True.
    """
    log_weights_dir()

    target = os.path.join(save_path, pretrain)
    if not os.path.exists(target):
        logger.info(f'Downloading {pretrain}')
        _snapshot_download_with_fallback(
            repo_id="dptech/Uni-Mol2",
            local_dir=save_path,
            allow_patterns=pretrain,
        )
        return

    logger.info(f'{pretrain} exists in {save_path}')
# Download all the weights when this script is run
def download_all_weights(local_dir_use_symlinks=False):
    """
    Download every file of the ``dptech/Uni-Mol-Models`` repo into the
    directory returned by ``get_weight_dir()``.

    :param local_dir_use_symlinks: (bool, optional), Accepted but currently not
        forwarded to the download call. Defaults to False.
    """
    log_weights_dir()
    weight_dir = get_weight_dir()
    logger.info(f'Downloading all weights to {weight_dir}')

    # '*' matches every file in the repository.
    request = {
        "repo_id": "dptech/Uni-Mol-Models",
        "local_dir": weight_dir,
        "allow_patterns": '*',
    }
    _snapshot_download_with_fallback(**request)
# Script entry point: fetch every available weight file.
if __name__ == '__main__':
    download_all_weights()