load_model_ipfs.py 4.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120
  1. import errno
  2. import os
  3. import sys
  4. import torch
  5. import io
  6. import errno
  7. import hashlib
  8. import os
  9. import shutil
  10. import sys
  11. import tempfile
  12. import torch
  13. import requests
  14. import tarfile
  15. def download_cid_to_file(url, cid, dst, hash_prefix=None):
  16. r"""Download object at the given CID to a local path.
  17. Args:
  18. url (string): URL of the IPFS instance
  19. cid (string): CID of the model to download
  20. dst (string): Full path where object will be saved, e.g. ``/tmp/temporary_file``
  21. hash_prefix (string, optional): If not None, the SHA256 downloaded file should start with ``hash_prefix``.
  22. Default: None
  23. progress (bool, optional): whether or not to display a progress bar to stderr
  24. Default: True
  25. Example:
  26. >>> torch.hub.download_url_to_file('the-models-ipfs-cid-here', '/tmp/temporary_file')
  27. """
  28. # We deliberately save it in a temp file and move it after
  29. # download is complete. This prevents a local working checkpoint
  30. # being overridden by a broken download.
  31. dst = os.path.expanduser(dst)
  32. dst_dir = os.path.dirname(dst)
  33. f = tempfile.NamedTemporaryFile(delete=False, dir=dst_dir)
  34. response = requests.post(url+"/get?arg="+cid)
  35. contents = response.content
  36. tar = tarfile.open(fileobj=io.BytesIO(contents))
  37. for member in tar.getmembers():
  38. if member.isfile:
  39. extractedFile = tar.extractfile(member)
  40. if extractedFile is not None:
  41. f.write(extractedFile.read())
  42. try:
  43. if hash_prefix is not None:
  44. sha256 = hashlib.sha256()
  45. f.close()
  46. if hash_prefix is not None:
  47. digest = sha256.hexdigest()
  48. if digest[:len(hash_prefix)] != hash_prefix:
  49. raise RuntimeError('invalid hash value (expected "{}", got "{}")'
  50. .format(hash_prefix, digest))
  51. shutil.move(f.name, dst)
  52. finally:
  53. f.close()
  54. if os.path.exists(f.name):
  55. os.remove(f.name)
  56. def load_state_dict_from_ipfs(cid, model_dir=None, url="http://127.0.0.1:5001/api/v0", map_location=None, check_hash=False, file_name=None):
  57. r"""Loads the Torch serialized object at the given IPFS CID.
  58. If downloaded file is a zip file, it will be automatically
  59. decompressed.
  60. If the object is already present in `model_dir`, it's deserialized and
  61. returned.
  62. The default value of ``model_dir`` is ``<hub_dir>/checkpoints`` where
  63. ``hub_dir`` is the directory returned by :func:`~torch.hub.get_dir`.
  64. Args:
  65. cid (string): CID of the model to download
  66. url (string): URL of the IPFS instance
  67. model_dir (string, optional): directory in which to save the object
  68. map_location (optional): a function or a dict specifying how to remap storage locations (see torch.load)
  69. progress (bool, optional): whether or not to display a progress bar to stderr.
  70. Default: True
  71. check_hash(bool, optional): If True, the filename part of the URL should follow the naming convention
  72. ``filename-<sha256>.ext`` where ``<sha256>`` is the first eight or more
  73. digits of the SHA256 hash of the contents of the file. The hash is used to
  74. ensure unique names and to verify the contents of the file.
  75. Default: False
  76. file_name (string, optional): name for the downloaded file. Filename from ``url`` will be used if not set.
  77. Example:
  78. >>> state_dict = torch.hub.load_state_dict_from_ipfs('my-cid-goes-here')
  79. """
  80. if model_dir is None:
  81. hub_dir = torch.hub.get_dir()
  82. model_dir = os.path.join(hub_dir, 'checkpoints')
  83. try:
  84. os.makedirs(model_dir)
  85. except OSError as e:
  86. if e.errno == errno.EEXIST:
  87. # Directory already exists, ignore.
  88. pass
  89. else:
  90. # Unexpected OSError, re-raise.
  91. raise
  92. filename = cid
  93. if file_name is not None:
  94. filename = file_name
  95. cached_file = os.path.join(model_dir, filename)
  96. if not os.path.exists(cached_file):
  97. sys.stderr.write('Downloading: "{}" to {}\n'.format(url, cached_file))
  98. hash_prefix = None
  99. if check_hash:
  100. r = torch.hub.HASH_REGEX.search(filename)
  101. hash_prefix = r.group(1) if r else None
  102. download_cid_to_file(url, cid, cached_file, hash_prefix)
  103. if torch.hub._is_legacy_zip_format(cached_file):
  104. return torch.hub._legacy_zip_load(cached_file, model_dir, map_location)
  105. return torch.load(cached_file, map_location=map_location)