ipfs_dataset.py 2.2 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455
  1. from torch.utils.data import Dataset
  2. import io
  3. import requests
  4. import tarfile
  5. from PIL import Image
  6. class IPFSDataset(Dataset):
  7. """IPFS dataset."""
  8. def __init__(self, cid, transform=None, target_transform=None, url="http://127.0.0.1:5001/api/v0"):
  9. """
  10. Args:
  11. cid (string): IPFS Directory CID with all the sub directories (categories) & files.
  12. url (string): IPFS base URL
  13. transform (callable, optional): Optional transform to be applied
  14. on a sample.
  15. target_transform (callable, optional): A function/transform that takes
  16. in the target and transforms it.
  17. """
  18. response = requests.post(url+"/get?arg="+cid)
  19. contents = response.content
  20. tar = tarfile.open(fileobj=io.BytesIO(contents))
  21. self.isNotImage = lambda n: 'jpg' not in n and 'webp' not in n and 'png' not in n and 'gif' not in n and 'jpeg' not in n and "bmp" not in n and "tif" not in n and "ppm" not in n
  22. self.files = []
  23. self.classes = [name for name in tar.getnames(
  24. ) if self.isNotImage(name) and '/' in name]
  25. self.classes.sort()
  26. self.class_to_idx = {self.classes[i]
  27. : i for i in range(len(self.classes))}
  28. for member in tar.getmembers():
  29. if member.isfile:
  30. extractedFile = tar.extractfile(member)
  31. if extractedFile is not None:
  32. for classkey in self.class_to_idx:
  33. if classkey in member.path:
  34. img = Image.open(extractedFile)
  35. self.files.append(
  36. (img, self.class_to_idx[classkey]))
  37. tar.close()
  38. self.targets = [s[1] for s in self.files]
  39. self.cid = cid
  40. self.transform = transform
  41. self.target_transform = target_transform
  42. def __len__(self):
  43. return len(self.files)
  44. def __getitem__(self, idx):
  45. image, target = self.files[idx]
  46. if self.transform is not None:
  47. image = self.transform(image)
  48. if self.target_transform is not None:
  49. target = self.target_transform(target)
  50. return image, target