|
|
@@ -0,0 +1,129 @@
|
|
|
+{
|
|
|
+ "cells": [
|
|
|
+ {
|
|
|
+ "cell_type": "code",
|
|
|
+ "execution_count": 13,
|
|
|
+ "metadata": {},
|
|
|
+ "outputs": [],
|
|
|
+ "source": [
|
|
|
+ "import torch\n",
|
|
|
+ "from torch.utils.data import Dataset, DataLoader\n",
|
|
|
+ "import io \n",
|
|
|
+ "import requests\n",
|
|
|
+ "import tarfile\n",
|
|
|
+ "from PIL import Image \n",
|
|
|
+ "from torchvision import transforms\n",
|
|
|
+    "class IPFSDataset(Dataset):\n",
+    "    \"\"\"Image-classification dataset built from a directory stored on IPFS.\n",
+    "\n",
+    "    The directory ``cid`` is fetched as a tar archive from the IPFS HTTP API.\n",
+    "    Every sub-directory of the archive is a class, and every image file that\n",
+    "    sits directly inside a class directory becomes one ``(image, label)``\n",
+    "    sample. All images are decoded into memory when the dataset is created.\n",
+    "\n",
+    "    Attributes set by ``__init__``: ``classes`` (sorted directory paths as\n",
+    "    they appear in the archive), ``class_to_idx``, ``files`` (list of\n",
+    "    ``(PIL image, label)``), ``targets``, ``cid``, ``transform`` and\n",
+    "    ``target_transform``.\n",
+    "    \"\"\"\n",
+    "\n",
+    "    # Lower-case file-name suffixes that are treated as images.\n",
+    "    IMAGE_EXTENSIONS = ('.jpg', '.jpeg', '.png', '.gif', '.bmp', '.tif', '.tiff', '.ppm', '.webp')\n",
+    "\n",
+    "    def __init__(self, cid, transform=None, target_transform=None, url=\"http://127.0.0.1:5001/api/v0\"):\n",
+    "        \"\"\"\n",
+    "        Args:\n",
+    "            cid (string): IPFS Directory CID with all the files.\n",
+    "            url (string): IPFS base URL\n",
+    "            transform (callable, optional): Optional transform to be applied\n",
+    "                on a sample.\n",
+    "            target_transform (callable, optional): A function/transform that takes\n",
+    "                in the target and transforms it.\n",
+    "\n",
+    "        Raises:\n",
+    "            requests.HTTPError: if the IPFS API answers with an error status.\n",
+    "        \"\"\"\n",
+    "        # Let requests build the query string so the CID is URL-encoded.\n",
+    "        response = requests.post(url + \"/get\", params={\"arg\": cid})\n",
+    "        # Fail here rather than handing an error body to tarfile.\n",
+    "        response.raise_for_status()\n",
+    "        self.files = []\n",
+    "        with tarfile.open(fileobj=io.BytesIO(response.content)) as tar:\n",
+    "            members = tar.getmembers()\n",
+    "            self.classes = self._find_classes(members)\n",
+    "            self.class_to_idx = {name: i for i, name in enumerate(self.classes)}\n",
+    "            for member, label in self._label_members(members):\n",
+    "                img = Image.open(tar.extractfile(member))\n",
+    "                # PIL opens lazily; decode now so the sample does not depend\n",
+    "                # on the archive after this block closes it.\n",
+    "                img.load()\n",
+    "                self.files.append((img, label))\n",
+    "        self.targets = [s[1] for s in self.files]\n",
+    "        self.cid = cid\n",
+    "        self.transform = transform\n",
+    "        self.target_transform = target_transform\n",
+    "\n",
+    "    def isNotImage(self, n):\n",
+    "        \"\"\"Return True if the path ``n`` does not end with an image extension.\n",
+    "\n",
+    "        A method rather than a lambda stored on the instance, so the dataset\n",
+    "        stays picklable.\n",
+    "        \"\"\"\n",
+    "        return not n.lower().endswith(self.IMAGE_EXTENSIONS)\n",
+    "\n",
+    "    @staticmethod\n",
+    "    def _find_classes(members):\n",
+    "        \"\"\"Return the sorted names of all sub-directories among ``members``.\n",
+    "\n",
+    "        Names without a '/' (the archive root) are not classes.\n",
+    "        \"\"\"\n",
+    "        return sorted(m.name for m in members if m.isdir() and '/' in m.name)\n",
+    "\n",
+    "    def _label_members(self, members):\n",
+    "        \"\"\"Pair each image file with the index of the class directory holding it.\n",
+    "\n",
+    "        Returns a list of ``(TarInfo, class_index)`` in archive order. Files\n",
+    "        that are not images, or whose parent directory is not a class, are\n",
+    "        skipped.\n",
+    "        \"\"\"\n",
+    "        labelled = []\n",
+    "        for member in members:\n",
+    "            if not member.isfile() or self.isNotImage(member.name):\n",
+    "                continue\n",
+    "            # Match the exact parent directory: a substring test would also\n",
+    "            # put 'TestImages2/x.png' into class 'TestImages'.\n",
+    "            parent = member.name.rpartition('/')[0]\n",
+    "            if parent in self.class_to_idx:\n",
+    "                labelled.append((member, self.class_to_idx[parent]))\n",
+    "        return labelled\n",
+    "\n",
+    "    def __len__(self):\n",
+    "        \"\"\"Return the number of samples.\"\"\"\n",
+    "        return len(self.files)\n",
+    "\n",
+    "    def __getitem__(self, idx):\n",
+    "        \"\"\"Return ``(image, target)`` for sample ``idx`` with both transforms applied.\"\"\"\n",
+    "        image, target = self.files[idx]\n",
+    "        if self.transform is not None:\n",
+    "            image = self.transform(image)\n",
+    "        if self.target_transform is not None:\n",
+    "            target = self.target_transform(target)\n",
+    "        return image, target\n",
|
|
|
+ "\n",
|
|
|
+ "\n"
|
|
|
+ ]
|
|
|
+ },
|
|
|
+ {
|
|
|
+ "cell_type": "markdown",
|
|
|
+ "metadata": {},
|
|
|
+   "source": [
+    "## Build the dataset and split it into train and test sets"
+   ]
|
|
|
+ },
|
|
|
+ {
|
|
|
+ "cell_type": "code",
|
|
|
+ "execution_count": 15,
|
|
|
+ "metadata": {},
|
|
|
+ "outputs": [
|
|
|
+ {
|
|
|
+ "name": "stdout",
|
|
|
+ "output_type": "stream",
|
|
|
+ "text": [
|
|
|
+ "{'Qmd3TMunETHoZ1ZrBv7hZqS75nLTLkjqCWWDF2A6L7RJuq/TestImages': 0, 'Qmd3TMunETHoZ1ZrBv7hZqS75nLTLkjqCWWDF2A6L7RJuq/TestImages2': 1}\n",
|
|
|
+ "<torch.utils.data.dataset.Subset object at 0x7f0418b81850>\n",
|
|
|
+ "<torch.utils.data.dataset.Subset object at 0x7f0418b81880>\n"
|
|
|
+ ]
|
|
|
+ }
|
|
|
+ ],
|
|
|
+ "source": [
|
|
|
+ "\n",
|
|
|
+    "transformed_IPFSDataset = IPFSDataset(\n",
+    "    cid='Qmd3TMunETHoZ1ZrBv7hZqS75nLTLkjqCWWDF2A6L7RJuq',\n",
+    "    transform=transforms.Compose([\n",
+    "        transforms.Resize(256),\n",
+    "        transforms.RandomCrop(224),\n",
+    "        transforms.ToTensor()\n",
+    "]))\n",
+    "\n",
+    "# Split the dataset itself, not the DataLoader: random_split indexes its\n",
+    "# argument, and len(DataLoader) counts batches rather than samples.\n",
+    "test_size = int(.2 * len(transformed_IPFSDataset))\n",
+    "train_size = len(transformed_IPFSDataset) - test_size\n",
+    "train_dataset, test_dataset = torch.utils.data.random_split(\n",
+    "    transformed_IPFSDataset,\n",
+    "    [train_size, test_size],\n",
+    "    generator=torch.Generator().manual_seed(0),  # fixed seed: same split on every run\n",
+    ")\n",
+    "\n",
+    "# `dataloader` still iterates over the whole dataset; the two loaders below\n",
+    "# iterate over the train and test subsets.\n",
+    "dataloader = DataLoader(transformed_IPFSDataset, batch_size=4, shuffle=True, num_workers=0)\n",
+    "train_loader = DataLoader(train_dataset, batch_size=4, shuffle=True, num_workers=0)\n",
+    "test_loader = DataLoader(test_dataset, batch_size=4, shuffle=False, num_workers=0)\n"
|
|
|
+ ]
|
|
|
+ }
|
|
|
+ ],
|
|
|
+ "metadata": {
|
|
|
+ "interpreter": {
|
|
|
+ "hash": "916dbcbb3f70747c44a77c7bcd40155683ae19c65e1c03b4aa3499c5328201f1"
|
|
|
+ },
|
|
|
+ "kernelspec": {
|
|
|
+ "display_name": "Python 3.8.10 64-bit",
|
|
|
+ "language": "python",
|
|
|
+ "name": "python3"
|
|
|
+ },
|
|
|
+ "language_info": {
|
|
|
+ "codemirror_mode": {
|
|
|
+ "name": "ipython",
|
|
|
+ "version": 3
|
|
|
+ },
|
|
|
+ "file_extension": ".py",
|
|
|
+ "mimetype": "text/x-python",
|
|
|
+ "name": "python",
|
|
|
+ "nbconvert_exporter": "python",
|
|
|
+ "pygments_lexer": "ipython3",
|
|
|
+ "version": "3.8.10"
|
|
|
+ },
|
|
|
+ "orig_nbformat": 4
|
|
|
+ },
|
|
|
+ "nbformat": 4,
|
|
|
+ "nbformat_minor": 2
|
|
|
+}
|