ipfs_dataset.py 2.2 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657
  1. from torch.utils.data import Dataset
  2. import io
  3. import requests
  4. import tarfile
  5. from PIL import Image
  6. class IPFSDataset(Dataset):
  7. """IPFS dataset."""
  8. def __init__(self, cid, transform=None, target_transform=None, url="http://127.0.0.1:5001/api/v0"):
  9. """
  10. Args:
  11. cid (string): IPFS Directory CID with all the files.
  12. url (string): IPFS base URL
  13. transform (callable, optional): Optional transform to be applied
  14. on a sample.
  15. target_transform (callable, optional): A function/transform that takes
  16. in the target and transforms it.
  17. """
  18. response = requests.post(url+"/get?arg="+cid)
  19. contents = response.content
  20. tar = tarfile.open(fileobj=io.BytesIO(contents))
  21. self.isNotImage = lambda n: 'jpg' not in n and 'webp' not in n and 'png' not in n and 'gif' not in n and 'jpeg' not in n and "bmp" not in n and "tif" not in n and "ppm" not in n
  22. self.files = []
  23. self.classes = [name for name in tar.getnames(
  24. ) if self.isNotImage(name) and '/' in name]
  25. self.classes.sort()
  26. self.class_to_idx = {self.classes[i]
  27. : i for i in range(len(self.classes))}
  28. print(self.class_to_idx)
  29. for member in tar.getmembers():
  30. if member.isfile:
  31. extractedFile = tar.extractfile(member)
  32. if extractedFile is not None:
  33. member.path
  34. for classkey in self.class_to_idx:
  35. if classkey in member.path:
  36. img = Image.open(extractedFile)
  37. self.files.append(
  38. (img, self.class_to_idx[classkey]))
  39. tar.close()
  40. self.targets = [s[1] for s in self.files]
  41. self.cid = cid
  42. self.transform = transform
  43. self.target_transform = target_transform
  44. def __len__(self):
  45. return len(self.files)
  46. def __getitem__(self, idx):
  47. image, target = self.files[idx]
  48. if self.transform is not None:
  49. image = self.transform(image)
  50. if self.target_transform is not None:
  51. target = self.target_transform(target)
  52. return image, target