From 404a01f350940b1e7978f87eb1f706cbb8b558bf Mon Sep 17 00:00:00 2001
From: Martin Teichmann <martin.teichmann@xfel.eu>
Date: Wed, 6 Jan 2021 08:28:24 +0000
Subject: [PATCH] make MapType an Enum

---
 arraymap.py |  8 ++++----
 bpf.py      | 28 ++++++++++++++++++++++++++--
 ebpfcat.py  |  2 +-
 hashmap.py  | 10 +++++-----
 4 files changed, 36 insertions(+), 12 deletions(-)

diff --git a/arraymap.py b/arraymap.py
index a7cc881..c030503 100644
--- a/arraymap.py
+++ b/arraymap.py
@@ -1,7 +1,7 @@
 from struct import pack, unpack
 
 from .ebpf import Map, Memory, Opcode
-from . import bpf
+from .bpf import create_map, lookup_elem, MapType, update_elem
 
 
 class ArrayGlobalVarDesc:
@@ -42,10 +42,10 @@ class ArrayMapAccess:
         self.size = size
 
     def read(self):
-        self.data = bpf.lookup_elem(self.fd, b"\0\0\0\0", self.size)
+        self.data = lookup_elem(self.fd, b"\0\0\0\0", self.size)
 
     def write(self):
-        bpf.update_elem(self.fd, b"\0\0\0\0", self.data, 0)
+        update_elem(self.fd, b"\0\0\0\0", self.data, 0)
 
 
 class ArrayMap(Map):
@@ -64,7 +64,7 @@ class ArrayMap(Map):
         self.name = name
 
     def init(self, ebpf):
-        fd = bpf.create_map(2, 4, self.position, 1)
+        fd = create_map(MapType.ARRAY, 4, self.position, 1)
         setattr(ebpf, self.name, ArrayMapAccess(fd, self.position))
         with ebpf.save_registers(list(range(6))), ebpf.get_stack(4) as stack:
             ebpf.append(Opcode.ST, 10, 0, stack, 0)
diff --git a/bpf.py b/bpf.py
index 7ff7474..a545eb9 100644
--- a/bpf.py
+++ b/bpf.py
@@ -7,6 +7,29 @@ from os import strerror
 class BPFError(OSError):
     pass
 
+
+class MapType(Enum):
+    UNSPEC = 0
+    HASH = 1
+    ARRAY = 2
+    PROG_ARRAY = 3
+    PERF_EVENT_ARRAY = 4
+    PERCPU_HASH = 5
+    PERCPU_ARRAY = 6
+    STACK_TRACE = 7
+    CGROUP_ARRAY = 8
+    LRU_HASH = 9
+    LRU_PERCPU_HASH = 10
+    LPM_TRIE = 11
+    ARRAY_OF_MAPS = 12
+    HASH_OF_MAPS = 13
+    DEVMAP = 14
+    SOCKMAP = 15
+    CPUMAP = 16
+    XSKMAP = 17
+    SOCKHASH = 18
+
+
 class ProgType(Enum):
     UNSPEC = 0
     SOCKET_FILTER = 1
@@ -44,7 +67,8 @@ def bpf(cmd, fmt, *args):
     return ret, unpack(fmt, attr.raw)
 
 def create_map(map_type, key_size, value_size, max_entries):
-    return bpf(0, "IIII", map_type, key_size, value_size, max_entries)[0]
+    assert isinstance(map_type, MapType)
+    return bpf(0, "IIII", map_type.value, key_size, value_size, max_entries)[0]
 
 def lookup_elem(fd, key, size):
     value = create_string_buffer(size)
@@ -100,7 +124,7 @@ def prog_test_run(fd, data_in, data_out, ctx_in, ctx_out,
     return ret, retval, duration, data_out.value, ctx_out.value
 
 if __name__ == "__main__":
-    fd = create_map(1, 4, 4, 10)
+    fd = create_map(MapType.HASH, 4, 4, 10)
     update_elem(fd, b"asdf", b"ckde", 0)
     ret = lookup_elem(fd, b"asdf", 4)
     ret[2:4] = b"kk"
diff --git a/ebpfcat.py b/ebpfcat.py
index 516ddf4..fdc00c4 100644
--- a/ebpfcat.py
+++ b/ebpfcat.py
@@ -6,7 +6,7 @@ from .ebpf import EBPF
 from .bpf import ProgType, create_map, update_elem, prog_test_run, lookup_elem
 
 def script():
-    fd = create_map(1, 4, 4, 7)
+    fd = create_map(MapType.HASH, 4, 4, 7)
     update_elem(fd, b"AAAA", b"BBBB", 0)
 
     e = EBPF(ProgType.XDP, "GPL")
diff --git a/hashmap.py b/hashmap.py
index 2631290..aca0a20 100644
--- a/hashmap.py
+++ b/hashmap.py
@@ -2,7 +2,7 @@ from contextlib import contextmanager
 from struct import pack, unpack, unpack
 
 from .ebpf import AssembleError, Expression, Opcode, Map
-from . import bpf
+from .bpf import create_map, lookup_elem, MapType, update_elem
 
 
 class HashGlobalVar(Expression):
@@ -43,7 +43,7 @@ class HashGlobalVarDesc:
             return self
         if instance.loaded:
             fd = instance.__dict__[self.name].fd
-            ret = bpf.lookup_elem(fd, pack("B", self.count), 4)
+            ret = lookup_elem(fd, pack("B", self.count), 4)
             return unpack("i" if self.signed else "I", ret)[0]
         ret = instance.__dict__.get(self.name, None)
         if ret is None:
@@ -57,8 +57,8 @@ class HashGlobalVarDesc:
     def __set__(self, ebpf, value):
         if ebpf.loaded:
             fd = ebpf.__dict__[self.name].fd
-            bpf.update_elem(fd, pack("B", self.count),
-                            pack("i" if self.signed else "I", value), 0)
+            update_elem(fd, pack("B", self.count),
+                        pack("i" if self.signed else "I", value), 0)
             return
         with ebpf.save_registers([3]):
             with value.get_address(3, False, self.signed, True):
@@ -84,7 +84,7 @@ class HashMap(Map):
         return ret
 
     def init(self, ebpf):
-        fd = bpf.create_map(1, 1, 4, self.count)
+        fd = create_map(MapType.HASH, 1, 4, self.count)
         for v in self.vars:
             getattr(ebpf, v.name).fd = fd
 
-- 
GitLab