Jake Kalstad před 4 roky
revize
600b18cea0
3 změnil soubory, kde provedl 191 přidání a 0 odebrání
  1. 129 0
      ipfs.ipynb
  2. 57 0
      ipfs_dataset.py
  3. 5 0
      readme.md

+ 129 - 0
ipfs.ipynb

@@ -0,0 +1,129 @@
+{
+ "cells": [
+  {
+   "cell_type": "code",
+   "execution_count": 13,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "import torch\n",
+    "from torch.utils.data import Dataset, DataLoader\n",
+    "import io \n",
+    "import requests\n",
+    "import tarfile\n",
+    "from PIL import Image  \n",
+    "from torchvision import transforms\n",
+    "class IPFSDataset(Dataset):\n",
+    "    \"\"\"IPFS dataset.\"\"\"\n",
+    "\n",
+    "    def __init__(self, cid, transform=None, target_transform=None, url=\"http://127.0.0.1:5001/api/v0\"):\n",
+    "        \"\"\"\n",
+    "        Args: \n",
+    "            cid (string): IPFS Directory CID with all the files.\n",
+    "            url (string): IPFS base URL\n",
+    "            transform (callable, optional): Optional transform to be applied\n",
+    "                on a sample.\n",
+    "            target_transform (callable, optional): A function/transform that takes\n",
+    "                in the target and transforms it.\n",
+    "        \"\"\"\n",
+    "        response = requests.post(url+\"/get?arg=\"+cid)\n",
+    "        contents = response.content\n",
+    "        tar = tarfile.open(fileobj=io.BytesIO(contents))\n",
+    "        self.isNotImage = lambda n: 'jpg' not in n and 'webp' not in n and 'png' not in n and 'gif' not in n and 'jpeg' not in n and \"bmp\" not in n and \"tif\" not in n and \"ppm\" not in n\n",
+    "        self.files = []\n",
+    "        self.classes = [name for name in tar.getnames() if self.isNotImage(name) and '/' in name]\n",
+    "        self.classes.sort()\n",
+    "        self.class_to_idx = {self.classes[i]: i for i in range(len(self.classes))}\n",
+    "        print(self.class_to_idx)\n",
+    "        for member in tar.getmembers():\n",
+    "            if member.isfile: \n",
+    "                extractedFile = tar.extractfile(member)\n",
+    "                if extractedFile is not None:\n",
+    "                    member.path\n",
+    "                    for classkey in self.class_to_idx:\n",
+    "                        if classkey in member.path:\n",
+    "                            img = Image.open(extractedFile)\n",
+    "                            self.files.append((img, self.class_to_idx[classkey]))\n",
+    "        tar.close()\n",
+    "        self.targets = [s[1] for s in self.files]\n",
+    "        self.cid = cid\n",
+    "        self.transform = transform\n",
+    "        self.target_transform = target_transform\n",
+    "\n",
+    "    def __len__(self):\n",
+    "        return len(self.files)\n",
+    "\n",
+    "    def __getitem__(self, idx): \n",
+    "        image, target = self.files[idx]\n",
+    "        if self.transform is not None:\n",
+    "            image = self.transform(image)\n",
+    "        if self.target_transform is not None:\n",
+    "            target = self.target_transform(target)\n",
+    "        return image, target\n",
+    "\n",
+    "\n"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": []
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 15,
+   "metadata": {},
+   "outputs": [
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "{'Qmd3TMunETHoZ1ZrBv7hZqS75nLTLkjqCWWDF2A6L7RJuq/TestImages': 0, 'Qmd3TMunETHoZ1ZrBv7hZqS75nLTLkjqCWWDF2A6L7RJuq/TestImages2': 1}\n",
+      "<torch.utils.data.dataset.Subset object at 0x7f0418b81850>\n",
+      "<torch.utils.data.dataset.Subset object at 0x7f0418b81880>\n"
+     ]
+    }
+   ],
+   "source": [
+    "\n",
+    "transformed_IPFSDataset = IPFSDataset(\n",
+    "    cid='Qmd3TMunETHoZ1ZrBv7hZqS75nLTLkjqCWWDF2A6L7RJuq', \n",
+    "    transform=transforms.Compose([\n",
+    "        transforms.Resize(256),\n",
+    "        transforms.RandomCrop(224),\n",
+    "        transforms.ToTensor()\n",
+    "]))\n",
+    "\n",
+    "dataloader = DataLoader(transformed_IPFSDataset, batch_size=4, shuffle=True, num_workers=0)\n",
+    "test_size = int(.2 * len(dataloader))\n",
+    "train_size = len(dataloader) - test_size\n",
+    "train_dataset, test_dataset = torch.utils.data.random_split(dataloader, [train_size, test_size])\n"
+   ]
+  }
+ ],
+ "metadata": {
+  "interpreter": {
+   "hash": "916dbcbb3f70747c44a77c7bcd40155683ae19c65e1c03b4aa3499c5328201f1"
+  },
+  "kernelspec": {
+   "display_name": "Python 3.8.10 64-bit",
+   "language": "python",
+   "name": "python3"
+  },
+  "language_info": {
+   "codemirror_mode": {
+    "name": "ipython",
+    "version": 3
+   },
+   "file_extension": ".py",
+   "mimetype": "text/x-python",
+   "name": "python",
+   "nbconvert_exporter": "python",
+   "pygments_lexer": "ipython3",
+   "version": "3.8.10"
+  },
+  "orig_nbformat": 4
+ },
+ "nbformat": 4,
+ "nbformat_minor": 2
+}

+ 57 - 0
ipfs_dataset.py

@@ -0,0 +1,57 @@
+from torch.utils.data import Dataset
+import io
+import requests
+import tarfile
+from PIL import Image
+
+
+class IPFSDataset(Dataset):
+    """IPFS dataset."""
+
+    def __init__(self, cid, transform=None, target_transform=None, url="http://127.0.0.1:5001/api/v0"):
+        """
+        Args: 
+            cid (string): IPFS Directory CID with all the files.
+            url (string): IPFS base URL
+            transform (callable, optional): Optional transform to be applied
+                on a sample.
+            target_transform (callable, optional): A function/transform that takes
+                in the target and transforms it.
+        """
+        response = requests.post(url+"/get?arg="+cid)
+        contents = response.content
+        tar = tarfile.open(fileobj=io.BytesIO(contents))
+        self.isNotImage = lambda n: 'jpg' not in n and 'webp' not in n and 'png' not in n and 'gif' not in n and 'jpeg' not in n and "bmp" not in n and "tif" not in n and "ppm" not in n
+        self.files = []
+        self.classes = [name for name in tar.getnames(
+        ) if self.isNotImage(name) and '/' in name]
+        self.classes.sort()
+        self.class_to_idx = {self.classes[i]
+            : i for i in range(len(self.classes))}
+        print(self.class_to_idx)
+        for member in tar.getmembers():
+            if member.isfile:
+                extractedFile = tar.extractfile(member)
+                if extractedFile is not None:
+                    member.path
+                    for classkey in self.class_to_idx:
+                        if classkey in member.path:
+                            img = Image.open(extractedFile)
+                            self.files.append(
+                                (img, self.class_to_idx[classkey]))
+        tar.close()
+        self.targets = [s[1] for s in self.files]
+        self.cid = cid
+        self.transform = transform
+        self.target_transform = target_transform
+
+    def __len__(self):
+        return len(self.files)
+
+    def __getitem__(self, idx):
+        image, target = self.files[idx]
+        if self.transform is not None:
+            image = self.transform(image)
+        if self.target_transform is not None:
+            target = self.target_transform(target)
+        return image, target

+ 5 - 0
readme.md

@@ -0,0 +1,5 @@
+# PyTorch IPFS Dataset
+
+`IPFSDataset(Dataset)`
+
+See the jupyter notepad to see how it works and how it interacts with a standard pytorch DataLoader