{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import io\n",
    "import posixpath\n",
    "import tarfile\n",
    "\n",
    "import requests\n",
    "import torch\n",
    "from PIL import Image\n",
    "from torch.utils.data import Dataset, DataLoader\n",
    "from torchvision import transforms\n",
    "\n",
    "\n",
    "class IPFSDataset(Dataset):\n",
    "    \"\"\"Image-classification dataset built from a directory stored on IPFS.\n",
    "\n",
    "    The directory is downloaded once, as a tar archive, through the IPFS HTTP\n",
    "    API. Every directory below the archive root is a class, and every image\n",
    "    file is labelled with the directory that directly contains it.\n",
    "\n",
    "    Attributes set by ``__init__``: ``files`` (list of ``(PIL image, class index)``),\n",
    "    ``classes`` (sorted class directory names), ``class_to_idx``, ``targets``,\n",
    "    ``cid``, ``transform`` and ``target_transform``.\n",
    "    \"\"\"\n",
    "\n",
    "    # File suffixes treated as images (compared case-insensitively).\n",
    "    IMAGE_EXTENSIONS = ('.jpg', '.jpeg', '.png', '.gif', '.bmp', '.tif', '.tiff', '.ppm', '.webp')\n",
    "\n",
    "    def __init__(self, cid, transform=None, target_transform=None, url=\"http://127.0.0.1:5001/api/v0\"):\n",
    "        \"\"\"\n",
    "        Args:\n",
    "            cid (string): IPFS Directory CID with all the files.\n",
    "            url (string): IPFS base URL\n",
    "            transform (callable, optional): Optional transform to be applied\n",
    "                on a sample.\n",
    "            target_transform (callable, optional): A function/transform that takes\n",
    "                in the target and transforms it.\n",
    "\n",
    "        Raises:\n",
    "            requests.HTTPError: If the IPFS API answers with an error status.\n",
    "        \"\"\"\n",
    "        # Let requests build the query string so the CID is URL-encoded.\n",
    "        response = requests.post(url + '/get', params={'arg': cid})\n",
    "        # Fail here on an HTTP error instead of handing an error body to tarfile.\n",
    "        response.raise_for_status()\n",
    "\n",
    "        self.files = []\n",
    "        # The context manager closes the archive even if an image fails to decode.\n",
    "        with tarfile.open(fileobj=io.BytesIO(response.content)) as tar:\n",
    "            members = tar.getmembers()\n",
    "            # A class is any directory below the archive root; the root itself\n",
    "            # has no '/' in its name and is therefore excluded.\n",
    "            self.classes = sorted(\n",
    "                member.name for member in members\n",
    "                if member.isdir() and '/' in member.name\n",
    "            )\n",
    "            self.class_to_idx = {name: idx for idx, name in enumerate(self.classes)}\n",
    "            print(self.class_to_idx)\n",
    "            for member in members:\n",
    "                # isfile is a method: without the call the test is always true.\n",
    "                if not member.isfile() or self.isNotImage(member.name):\n",
    "                    continue\n",
    "                # Exact lookup of the containing directory, so a class whose name\n",
    "                # is merely a substring of another path cannot claim this file.\n",
    "                target = self.class_to_idx.get(posixpath.dirname(member.name))\n",
    "                if target is None:\n",
    "                    # Files sitting directly in the archive root have no class.\n",
    "                    continue\n",
    "                extractedFile = tar.extractfile(member)\n",
    "                if extractedFile is None:\n",
    "                    continue\n",
    "                img = Image.open(extractedFile)\n",
    "                # Image.open is lazy; decode the pixels while the archive is open.\n",
    "                img.load()\n",
    "                self.files.append((img, target))\n",
    "\n",
    "        self.targets = [s[1] for s in self.files]\n",
    "        self.cid = cid\n",
    "        self.transform = transform\n",
    "        self.target_transform = target_transform\n",
    "\n",
    "    @staticmethod\n",
    "    def isNotImage(n):\n",
    "        \"\"\"Return True when path ``n`` does not end in a known image extension.\n",
    "\n",
    "        Defined as a static method rather than a lambda stored on the instance\n",
    "        so that dataset objects remain picklable.\n",
    "        \"\"\"\n",
    "        return not n.lower().endswith(IPFSDataset.IMAGE_EXTENSIONS)\n",
    "\n",
    "    def __len__(self):\n",
    "        \"\"\"Return the number of loaded samples.\"\"\"\n",
    "        return len(self.files)\n",
    "\n",
    "    def __getitem__(self, idx):\n",
    "        \"\"\"Return ``(image, target)`` for sample ``idx``, applying the optional transforms.\"\"\"\n",
    "        image, target = self.files[idx]\n",
    "        if self.transform is not None:\n",
    "            image = self.transform(image)\n",
    "        if self.target_transform is not None:\n",
    "            target = self.target_transform(target)\n",
    "        return image, target\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Example: load a dataset and split it\n",
    "\n",
    "Replace the placeholder CID with the CID of your root folder. The dataset is split 80/20 into train and test subsets, and a `DataLoader` is created for each subset."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "SEED = 42  # fixes the train/test split so that re-running the cell reproduces it\n",
    "\n",
    "transformed_IPFSDataset = IPFSDataset(\n",
    "    cid='your-root-folders-cid-goes-here',\n",
    "    transform=transforms.Compose([\n",
    "        transforms.Resize(256),\n",
    "        transforms.RandomCrop(224),\n",
    "        transforms.ToTensor(),\n",
    "    ]),\n",
    ")\n",
    "\n",
    "# Split the dataset itself: len(DataLoader) counts batches, not samples, and\n",
    "# the subsets returned by random_split index into the object they are given.\n",
    "test_size = int(.2 * len(transformed_IPFSDataset))\n",
    "train_size = len(transformed_IPFSDataset) - test_size\n",
    "train_dataset, test_dataset = torch.utils.data.random_split(\n",
    "    transformed_IPFSDataset,\n",
    "    [train_size, test_size],\n",
    "    generator=torch.Generator().manual_seed(SEED),\n",
    ")\n",
    "\n",
    "dataloader = DataLoader(transformed_IPFSDataset, batch_size=4, shuffle=True, num_workers=0)\n",
    "train_dataloader = DataLoader(train_dataset, batch_size=4, shuffle=True, num_workers=0)\n",
    "test_dataloader = DataLoader(test_dataset, batch_size=4, shuffle=False, num_workers=0)\n"
   ]
  }
 ],
 "metadata": {
  "interpreter": {
   "hash": "916dbcbb3f70747c44a77c7bcd40155683ae19c65e1c03b4aa3499c5328201f1"
  },
  "kernelspec": {
   "display_name": "Python 3.8.10 64-bit",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.10"
  },
  "orig_nbformat": 4
 },
 "nbformat": 4,
 "nbformat_minor": 2
}