torch.hub

news/2024/11/16 22:35:20/

Reference: torch.hub - Cloud+ Community - Tencent Cloud

PyTorch Hub is a pre-trained model repository designed to facilitate research reproducibility.

Publishing models

PyTorch Hub supports publishing pre-trained models (model definitions and pre-trained weights) to a GitHub repository by adding a simple hubconf.py file.

hubconf.py can have multiple entrypoints. Each entrypoint is defined as a Python function (for example, a pre-trained model you want to publish).

def entrypoint_name(*args, **kwargs):
    # args & kwargs are optional, for models which take positional/keyword arguments.
    ...
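To make the entrypoint convention concrete, here is a minimal hubconf.py-style sketch with two entrypoints and one helper. It assumes a hypothetical pure-Python class `_TinyModel` standing in for a real nn.Module so that it runs without torch; the entrypoint names `tiny` and `tiny_wide` are illustrative only.

```python
# hubconf.py (sketch): one repo may expose several entrypoints.
# Assumption: _TinyModel is a hypothetical stand-in for a torch.nn.Module.

dependencies = []  # packages needed to *load* the models (none in this sketch)


class _TinyModel:
    """Hypothetical model class; the leading underscore hides it from listing."""

    def __init__(self, width, pretrained):
        self.width = width
        self.pretrained = pretrained


def _build(width, pretrained):
    # Helper shared by the entrypoints below; not an entrypoint itself.
    return _TinyModel(width, pretrained)


def tiny(pretrained=False, width=8):
    """Tiny model.

    pretrained (bool): if True, load pretrained weights (skipped in this sketch)
    width (int): hidden width, default 8
    """
    return _build(width, pretrained)


def tiny_wide(pretrained=False):
    """Tiny model with a fixed width of 32."""
    return _build(32, pretrained)
```

Each public function is one entrypoint; keyword arguments such as `width` are whatever the entrypoint chooses to accept.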

How to implement an entrypoint?

Here is a code snippet that specifies an entrypoint for the resnet18 model, obtained by expanding the implementation in pytorch/vision/hubconf.py. In most cases, importing the right function in hubconf.py is sufficient; the expanded version is used here only as an example of how it works. You can see the full script in the pytorch/vision repo.

dependencies = ['torch']
from torchvision.models.resnet import resnet18 as _resnet18

# resnet18 is the name of entrypoint
def resnet18(pretrained=False, **kwargs):
    # This docstring shows up in hub.help()
    """
    Resnet18 model
    pretrained (bool): if True, load pretrained weights into the model
    """
    # Call the model, load pretrained weights
    model = _resnet18(pretrained=pretrained, **kwargs)
    return model

  • The dependencies variable is a list of package names required to load the model. Note that this might be slightly different from the dependencies required for training a model.
  • args and kwargs are passed along to the real callable function.
  • The docstring of the function works as a help message. It explains what the model does and which positional/keyword arguments are allowed. It's highly recommended to add a few examples here.
  • An entrypoint function can return either a model (nn.Module) or auxiliary tools that make the user workflow smoother, e.g. tokenizers.
  • Callables prefixed with an underscore are considered helper functions and won't show up in torch.hub.list().
  • Pretrained weights can either be stored locally in the GitHub repo or be loadable by torch.hub.load_state_dict_from_url(). If the weights are smaller than 2GB, it's recommended to attach them to a project release and use the URL from the release. In the example above, torchvision.models.resnet.resnet18 handles pretrained; alternatively, you can put the following logic in the entrypoint definition.
# Inside the entrypoint, after `model` is built (hubconf.py needs `import os` and `import torch`).
# Use one of the two options below.
if pretrained:
    # For a checkpoint saved in the local github repo, e.g. <RELATIVE_PATH_TO_CHECKPOINT>=weights/save.pth
    dirname = os.path.dirname(__file__)
    checkpoint = os.path.join(dirname, <RELATIVE_PATH_TO_CHECKPOINT>)
    state_dict = torch.load(checkpoint)
    model.load_state_dict(state_dict)

    # Alternatively, for a checkpoint saved elsewhere
    checkpoint = 'https://download.pytorch.org/models/resnet18-5c106cde.pth'
    model.load_state_dict(torch.hub.load_state_dict_from_url(checkpoint, progress=False))
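The listing and dependency rules described above can be sketched with the standard library alone. This is not torch.hub's own code: `list_entrypoints` and `missing_dependencies` are hypothetical helpers, and the hubconf module is built in memory purely for demonstration.

```python
import importlib.util
import types


def list_entrypoints(hub_module):
    """Public callables only: names starting with '_' are treated as helpers."""
    return [
        name
        for name in dir(hub_module)
        if not name.startswith("_") and callable(getattr(hub_module, name))
    ]


def missing_dependencies(hub_module):
    """Entries of the module's `dependencies` list that cannot be imported."""
    deps = getattr(hub_module, "dependencies", [])
    return [pkg for pkg in deps if importlib.util.find_spec(pkg) is None]


# A fake hubconf module assembled in memory (assumption: stands in for a repo's hubconf.py).
hub = types.ModuleType("hubconf")
hub.dependencies = ["json", "surely_not_installed_pkg"]
hub.resnet18 = lambda pretrained=False: "model"
hub._helper = lambda: None
hub.VERSION = "1.0"  # not callable, so not an entrypoint

print(list_entrypoints(hub))      # ['resnet18']
print(missing_dependencies(hub))  # ['surely_not_installed_pkg']
```

Both `_helper` and the non-callable `VERSION` are filtered out, which is why only real entrypoints appear in the listing.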

Important Notice

  • Published models must be on at least a branch or tag; they can't be an arbitrary commit.

Loading models from Hub

PyTorch Hub provides convenient APIs to explore all available models in a hub through torch.hub.list(), show docstrings and examples through torch.hub.help(), and load pre-trained models using torch.hub.load().

torch.hub.list(github, force_reload=False)

List all entrypoints available in the GitHub repo's hubconf.py.

Parameters:

  • github (string) – a string with format “repo_owner/repo_name[:tag_name]” with an optional tag/branch. The default branch is master if not specified. Example: ‘pytorch/vision[:hub]’
  • force_reload (bool, optional) – whether to discard the existing cache and force a fresh download. Default is False.

Returns:

  • a list of available entrypoint names

Return type:

  • list of str

Example:

>>> entrypoints = torch.hub.list('pytorch/vision', force_reload=True)
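In the format string, the square brackets only mark the tag/branch as optional, so real arguments look like 'pytorch/vision' or 'pytorch/vision:hub'. A minimal parsing sketch follows; `parse_repo_info` is a hypothetical helper, and the "master" default simply follows the text above (newer PyTorch releases may resolve the default branch differently).

```python
def parse_repo_info(github, default_ref="master"):
    """Split 'repo_owner/repo_name[:tag_name]' into (owner, name, ref)."""
    repo, has_ref, ref = github.partition(":")
    if not has_ref:
        ref = default_ref  # no ':tag_name' given, fall back to the default branch
    owner, slash, name = repo.partition("/")
    if not (slash and owner and name and ref) or "/" in name:
        raise ValueError(f"expected 'owner/name[:ref]', got {github!r}")
    return owner, name, ref


print(parse_repo_info("pytorch/vision"))      # ('pytorch', 'vision', 'master')
print(parse_repo_info("pytorch/vision:hub"))  # ('pytorch', 'vision', 'hub')
```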

torch.hub.help(github, model, force_reload=False)

Show the docstring of entrypoint model.

Parameters:

  • github (string) – a string with format “repo_owner/repo_name[:tag_name]” with an optional tag/branch. The default branch is master if not specified. Example: ‘pytorch/vision[:hub]’
  • model (string) – a string of entrypoint name defined in repo’s hubconf.py
  • force_reload (bool, optional) – whether to discard the existing cache and force a fresh download. Default is False.

Example:

>>> print(torch.hub.help('pytorch/vision', 'resnet18', force_reload=True))
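What help() returns is essentially the entrypoint's docstring, which is why writing a clear docstring matters. A stdlib-only sketch of that lookup, assuming a hypothetical `hub_help` helper and an in-memory hubconf module (not torch.hub's implementation):

```python
import inspect
import types


def hub_help(hub_module, model):
    """Return the cleaned-up docstring of entrypoint `model`."""
    if model.startswith("_"):
        raise ValueError(f"{model!r} is a helper, not an entrypoint")
    func = getattr(hub_module, model, None)
    if not callable(func):
        raise RuntimeError(f"cannot find callable {model!r} in hubconf")
    return inspect.getdoc(func)  # strips the docstring's leading indentation


def resnet18(pretrained=False):
    """Resnet18 model

    pretrained (bool): if True, load pretrained weights into the model
    """


hub = types.ModuleType("hubconf")
hub.resnet18 = resnet18
print(hub_help(hub, "resnet18").splitlines()[0])  # Resnet18 model
```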

torch.hub.load(github, model, *args, **kwargs)

Load a model from a github repo, with pretrained weights.

Parameters:

  • github (string) – a string with format “repo_owner/repo_name[:tag_name]” with an optional tag/branch. The default branch is master if not specified. Example: ‘pytorch/vision[:hub]’
  • model (string) – a string of entrypoint name defined in repo’s hubconf.py
  • *args (optional) – the corresponding args for callable model.
  • force_reload (bool, optional) – whether to force a fresh download of github repo unconditionally. Default is False.
  • verbose (bool, optional) – If False, mute messages about hitting local caches. Note that the message about the first download cannot be muted. Default is True.
  • **kwargs (optional) – the corresponding kwargs for callable model.

Returns:

  • a single model with corresponding pretrained weights.

Example:

>>> model = torch.hub.load('pytorch/vision', 'resnet50', pretrained=True)
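Conceptually, load() looks up the entrypoint and forwards the remaining positional and keyword arguments to it, while options such as force_reload and verbose configure the hub itself. The sketch below models only that forwarding step; `hub_load` and the stand-in `resnet50` (which returns a dict instead of a network) are hypothetical.

```python
import types


def hub_load(hub_module, model, *args, **kwargs):
    """Look up entrypoint `model` and call it with the remaining arguments."""
    # Hub-level options are consumed here instead of being forwarded.
    kwargs.pop("force_reload", None)
    kwargs.pop("verbose", None)
    entry = getattr(hub_module, model, None)
    if model.startswith("_") or not callable(entry):
        raise RuntimeError(f"cannot find entrypoint {model!r}")
    return entry(*args, **kwargs)


def resnet50(pretrained=False, num_classes=1000):
    # Stand-in entrypoint: describes the model it would build.
    return {"arch": "resnet50", "pretrained": pretrained, "num_classes": num_classes}


hub = types.ModuleType("hubconf")
hub.resnet50 = resnet50
print(hub_load(hub, "resnet50", pretrained=True, force_reload=True))
# {'arch': 'resnet50', 'pretrained': True, 'num_classes': 1000}
```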

torch.hub.download_url_to_file(url, dst, hash_prefix=None, progress=True)

Download object at the given URL to a local path.

Parameters:

  • url (string) – URL of the object to download
  • dst (string) – Full path where object will be saved, e.g. /tmp/temporary_file
  • hash_prefix (string, optional) – If not None, the SHA256 hash of the downloaded file should start with hash_prefix. Default: None
  • progress (bool, optional) – whether or not to display a progress bar to stderr. Default: True

Example:

>>> torch.hub.download_url_to_file('https://s3.amazonaws.com/pytorch/models/resnet18-5c106cde.pth', '/tmp/temporary_file')
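The hash_prefix check amounts to hashing the file with SHA256 and comparing the start of the hex digest. A stdlib-only sketch of that comparison, using a hypothetical `sha256_matches_prefix` helper (it does not download anything):

```python
import hashlib


def sha256_matches_prefix(path, hash_prefix, chunk_size=8192):
    """True if the SHA256 hex digest of the file at `path` starts with `hash_prefix`."""
    sha256 = hashlib.sha256()
    with open(path, "rb") as f:
        # Hash in chunks so large checkpoints need not fit in memory.
        for chunk in iter(lambda: f.read(chunk_size), b""):
            sha256.update(chunk)
    return sha256.hexdigest().startswith(hash_prefix)
```

Reading in chunks keeps memory use flat for multi-gigabyte weight files; a mismatch would indicate a corrupted or unexpected download.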

torch.hub.load_state_dict_from_url(url, model_dir=None, map_location=None, progress=True, check_hash=False)

Loads the Torch serialized object at the given URL.

If downloaded file is a zip file, it will be automatically decompressed.

If the object is already present in model_dir, it's deserialized and returned. The default value of model_dir is $TORCH_HOME/checkpoints, where the environment variable $TORCH_HOME defaults to $XDG_CACHE_HOME/torch. $XDG_CACHE_HOME follows the X Desktop Group (XDG) Base Directory specification of the Linux filesystem layout, with a default value of ~/.cache if not set.

Parameters:

  • url (string) – URL of the object to download
  • model_dir (string, optional) – directory in which to save the object
  • map_location (optional) – a function or a dict specifying how to remap storage locations (see torch.load)
  • progress (bool, optional) – whether or not to display a progress bar to stderr. Default: True
  • check_hash (bool, optional) – If True, the filename part of the URL should follow the naming convention filename-<sha256>.ext where <sha256> is the first eight or more digits of the SHA256 hash of the contents of the file. The hash is used to ensure unique names and to verify the contents of the file. Default: False

Example:

>>> state_dict = torch.hub.load_state_dict_from_url('https://s3.amazonaws.com/pytorch/models/resnet18-5c106cde.pth')
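With check_hash=True, the expected hash comes from the URL's filename (filename-<sha256>.ext). A minimal extraction sketch, assuming the hash is the run of at least eight lowercase hex digits after the last '-' and before the extension; `hash_prefix_from_url` is a hypothetical helper, not part of torch.hub.

```python
import os
import re
from urllib.parse import urlparse

# Assumption: '-<hex digits>' sits directly before the file extension.
_HASH_IN_NAME = re.compile(r"-([a-f0-9]{8,})\.[^.]+$")


def hash_prefix_from_url(url):
    """Extract <sha256> from a URL whose filename follows filename-<sha256>.ext."""
    filename = os.path.basename(urlparse(url).path)
    match = _HASH_IN_NAME.search(filename)
    if match is None:
        raise ValueError(f"{filename!r} does not follow filename-<sha256>.ext")
    return match.group(1)


print(hash_prefix_from_url(
    "https://s3.amazonaws.com/pytorch/models/resnet18-5c106cde.pth"))  # 5c106cde
```

The extracted prefix is what the downloaded file's SHA256 digest would then be compared against.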

Running a loaded model:

Note that *args and **kwargs in torch.hub.load() are used to instantiate a model. After you have loaded a model, how can you find out what you can do with it? A suggested workflow is:

  • dir(model) to see all available methods of the model.
  • help(model.foo) to check what arguments model.foo takes to run.

To help users explore without referring to documentation back and forth, we strongly recommend repo owners make function help messages clear and succinct. It’s also helpful to include a minimal working example.
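The dir()/help() workflow can also be done non-interactively. This sketch uses a hypothetical `ToyModel` in place of a loaded model; on a real nn.Module the dir() output is much longer, so filtering out private names helps, and inspect.signature shows the same argument list help() would.

```python
import inspect


class ToyModel:
    """Hypothetical loaded model used only for this demonstration."""

    def predict(self, x, threshold=0.5):
        """Return 1 if x exceeds threshold, else 0."""
        return int(x > threshold)

    def _internal(self):
        pass


def public_methods(obj):
    """dir(obj) restricted to public callables."""
    return [n for n in dir(obj) if not n.startswith("_") and callable(getattr(obj, n))]


model = ToyModel()
print(public_methods(model))             # ['predict']
print(inspect.signature(model.predict))  # (x, threshold=0.5)
```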

Where are my downloaded models saved?

The locations are used in the following order:

  • Calling hub.set_dir(<PATH_TO_HUB_DIR>)
  • $TORCH_HOME/hub, if environment variable TORCH_HOME is set.
  • $XDG_CACHE_HOME/torch/hub, if environment variable XDG_CACHE_HOME is set.
  • ~/.cache/torch/hub
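The precedence above can be expressed as a small resolver. This is a sketch of the listed order only; `resolve_hub_dir` is hypothetical and takes the environment as a dict so it can be tried without touching real variables (newer PyTorch releases expose the resolved path through torch.hub.get_dir()).

```python
import os


def resolve_hub_dir(set_dir=None, environ=None):
    """Return the hub directory following the precedence listed above."""
    env = os.environ if environ is None else environ
    if set_dir is not None:  # 1. hub.set_dir(<PATH_TO_HUB_DIR>)
        return set_dir
    if env.get("TORCH_HOME"):  # 2. $TORCH_HOME/hub
        return os.path.join(env["TORCH_HOME"], "hub")
    if env.get("XDG_CACHE_HOME"):  # 3. $XDG_CACHE_HOME/torch/hub
        return os.path.join(env["XDG_CACHE_HOME"], "torch", "hub")
    # 4. ~/.cache/torch/hub
    return os.path.join(os.path.expanduser("~"), ".cache", "torch", "hub")
```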

torch.hub.set_dir(d)

Optionally set hub_dir to a local dir to save downloaded models & weights.

If set_dir is not called, the default path is $TORCH_HOME/hub, where the environment variable $TORCH_HOME defaults to $XDG_CACHE_HOME/torch. $XDG_CACHE_HOME follows the X Desktop Group (XDG) Base Directory specification of the Linux filesystem layout, with a default value of ~/.cache if the environment variable is not set.

Parameters:

  • d (string) – path to a local folder to save downloaded models & weights.

Caching logic

By default, we don't clean up files after loading them. Hub uses the cached copy by default if it already exists in hub_dir.

Users can force a reload by calling hub.load(..., force_reload=True). This deletes the existing GitHub folder and downloaded weights and starts a fresh download. This is useful when updates are published to the same branch, so users can keep up with the latest release.
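The caching policy (reuse what is cached, never clean up, wipe and re-fetch on force_reload) can be sketched independently of torch. `get_cached` and the `fetch` callback are hypothetical; `fetch` stands in for whatever downloads the repo into the destination folder.

```python
import os
import shutil


def get_cached(cache_dir, name, fetch, force_reload=False):
    """Return the cached folder for `name`, calling fetch(dest) only when needed."""
    dest = os.path.join(cache_dir, name)
    if force_reload and os.path.exists(dest):
        shutil.rmtree(dest)  # discard the existing copy
    if not os.path.exists(dest):
        os.makedirs(dest, exist_ok=True)
        fetch(dest)  # hypothetical downloader fills the folder
    return dest  # files are left in place after use
```

A second call with the same name returns the existing folder without calling `fetch`; only force_reload=True triggers another download.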

Known limitations:

Torch hub works by importing the package as if it was installed. There are some side effects introduced by importing in Python. For example, you can see new items in the Python caches sys.modules and sys.path_importer_cache, which is normal Python behavior.

A known limitation worth mentioning here is that users CANNOT load two different branches of the same repo in the same Python process. It's just like installing two packages with the same name in Python, which is not good. The cache might join the party and give you surprises if you actually try that. Of course, it's totally fine to load them in separate processes.
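The "surprise" comes from Python's module cache: once a module name is in sys.modules, importing the same name again returns the cached module, even if a different directory is now first on sys.path. The sketch below reproduces this with two temporary folders standing in for two branches; the module name `demo_hubconf` and the helper `import_from` are hypothetical.

```python
import importlib
import os
import sys
import tempfile


def import_from(directory, module_name):
    """Import `module_name` with `directory` temporarily first on sys.path."""
    sys.path.insert(0, directory)
    try:
        importlib.invalidate_caches()
        return importlib.import_module(module_name)
    finally:
        sys.path.remove(directory)


with tempfile.TemporaryDirectory() as branch_a, tempfile.TemporaryDirectory() as branch_b:
    for folder, tag in ((branch_a, "A"), (branch_b, "B")):
        with open(os.path.join(folder, "demo_hubconf.py"), "w") as f:
            f.write(f"BRANCH = {tag!r}\n")

    first = import_from(branch_a, "demo_hubconf")
    second = import_from(branch_b, "demo_hubconf")  # served from sys.modules
    print(first.BRANCH, second.BRANCH)  # A A
```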


http://www.ppmy.cn/news/364522.html
