From 783a3b721f673167a734557bea30bbbe96609dae Mon Sep 17 00:00:00 2001
From: Martin Teichmann <martin.teichmann@xfel.eu>
Date: Mon, 28 Dec 2020 15:15:58 +0000
Subject: [PATCH] use Opcode enum for better readability

---
 ebpf.py | 212 ++++++++++++++++++++++++++++++++++++++++----------------
 1 file changed, 153 insertions(+), 59 deletions(-)

diff --git a/ebpf.py b/ebpf.py
index cca2f1f..e7ddfd6 100644
--- a/ebpf.py
+++ b/ebpf.py
@@ -1,20 +1,107 @@
 from collections import namedtuple
 from struct import pack
+from enum import Enum
 
 from .bpf import prog_load
 
 Instruction = namedtuple("Instruction",
                          ["opcode", "dst", "src", "off", "imm"])
 
+class Opcode(Enum):
+    ADD = 4
+    SUB = 0x14
+    MUL = 0x24
+    DIV = 0x34
+    OR = 0x44
+    AND = 0x54
+    LSH = 0x64
+    RSH = 0x74
+    NEG = 0x84
+    MOD = 0x94
+    XOR = 0xa4
+    MOV = 0xb4
+    ARSH = 0xc4
+
+    JMP = 5
+    JEQ = 0x15
+    JGT = 0x25
+    JGE = 0x35
+    JSET = 0x45
+    JNE = 0x55
+    JSGT = 0x65
+    JSGE = 0x75
+    JLT = 0xa5
+    JLE = 0xb5
+    JSLT = 0xc5
+    JSLE = 0xd5
+
+    CALL = 0x85
+    EXIT = 0x95
+
+    REG = 8
+    LONG = 3
+
+    H = 8
+    W = 0
+    B = 0x10
+    DW = 0x18
+
+    LD = 0x61
+    ST = 0x62
+    STX = 0x63
+
+    def __mul__(self, value):
+        if value:
+            return OpcodeFlags({self})
+        else:
+            return OpcodeFlags(set())
+
+    def __add__(self, value):
+        return OpcodeFlags({self}) + value
+
+    def __repr__(self):
+        return self.name
+
+    def __eq__(self, value):
+        return self is value or self.value == value
+
+    def __hash__(self):
+        return super().__hash__()
+
+class OpcodeFlags:
+    def __init__(self, opcodes):
+        self.opcodes = opcodes
+
+    @property
+    def value(self):
+        return sum(op.value for op in self.opcodes)
+
+    def __add__(self, value):
+        if isinstance(value, Opcode):
+            return OpcodeFlags(self.opcodes | {value})
+        else:
+            return OpcodeFlags(self.opcodes | value.opcodes)
+
+    def __repr__(self):
+        return "|".join(op.name for op in self.opcodes)
+
+    def __eq__(self, value):
+        if isinstance(value, int):
+            return self.value == value
+        else:
+            self.opcodes == value.opcodes
+
+
 class AssembleError(Exception):
     pass
 
 def augassign(opcode):
     def ret(self, value):
         if isinstance(value, int):
-            return Instruction(opcode + 3 * self.long, self.no, 0, 0, value)
+            return Instruction(opcode + Opcode.LONG * self.long, self.no,
+                               0, 0, value)
         elif isinstance(value, Register) and self.long == value.long:
-            return Instruction(opcode + 8 + 3 * self.long,
+            return Instruction(opcode + Opcode.REG + Opcode.LONG * self.long,
                                self.no, value.no, 0, 0)
         else:
             return NotImplemented
@@ -45,13 +132,15 @@ class Comparison:
             self.target()
             return
         assert self.ebpf.opcodes[self.else_origin] is None
-        self.ebpf.opcodes[self.else_origin] = Instruction(5, 0, 0, len(self.ebpf.opcodes) - self.else_origin - 1, 0)
+        self.ebpf.opcodes[self.else_origin] = Instruction(
+                Opcode.JMP, 0, 0,
+                len(self.ebpf.opcodes) - self.else_origin - 1, 0)
         self.ebpf.owners, self.owners = \
                 self.ebpf.owners & self.owners, self.ebpf.owners
 
         if self.invert is not None:
             olen = len(self.ebpf.opcodes)
-            assert self.ebpf.opcodes[self.invert].opcode == 5
+            assert self.ebpf.opcodes[self.invert].opcode == Opcode.JMP
             self.ebpf.opcodes[self.invert:self.invert] = \
                     self.ebpf.opcodes[self.else_origin+1:]
             del self.ebpf.opcodes[olen-1:]
@@ -62,7 +151,7 @@ class Comparison:
 
     def Else(self):
         op, dst, src, off, imm = self.ebpf.opcodes[self.origin]
-        if op == 5:
+        if op == Opcode.JMP:
             self.invert = self.origin
         else:
             self.ebpf.opcodes[self.origin] = \
@@ -77,7 +166,7 @@ class Comparison:
         self.target()
         self.origin = origin
         self.right = self.dst = 0
-        self.opcode = 5
+        self.opcode = Opcode.JMP
 
     def __and__(self, value):
         return AndOrComparison(self.ebpf, self, value, True)
@@ -122,7 +211,7 @@ class SimpleComparison(Comparison):
                 len(self.ebpf.opcodes) - self.origin - 1, self.right)
         else:
             inst = Instruction(
-                self.opcode + 8, self.dst, self.src,
+                self.opcode + Opcode.REG, self.dst, self.src,
                 len(self.ebpf.opcodes) - self.origin - 1, 0)
         self.ebpf.opcodes[self.origin] = inst
         self.ebpf.owners, self.owners = \
@@ -169,22 +258,22 @@ def binary(opcode, symetric=False):
 
 
 class Expression:
-    __radd__ = __add__ = binary(4, True)
-    __sub__ = binary(0x14)
-    __rmul__ = __mul__ = binary(0x24, True)
-    __truediv__ = binary(0x34)
-    __ror__ = __or__ = binary(0x44, True)
-    __lshift__ = binary(0x64)
-    __rshift__ = binary(0x74)
-    __mod__ = binary(0x94)
-    __rxor__ = __xor__ = binary(0xa4, True)
-
-    __eq__ = comparison(0x15, 0x55)
-    __gt__ = comparison(0x25, 0xb5, 0x65, 0xd5)
-    __ge__ = comparison(0x35, 0xa5, 0x75, 0xc5)
-    __lt__ = comparison(0xa5, 0x35, 0xc5, 0x75)
-    __le__ = comparison(0xb5, 0x25, 0xd5, 0x65)
-    __ne__ = comparison(0x55, 0x15)
+    __radd__ = __add__ = binary(Opcode.ADD, True)
+    __sub__ = binary(Opcode.SUB)
+    __rmul__ = __mul__ = binary(Opcode.MUL, True)
+    __truediv__ = binary(Opcode.DIV)
+    __ror__ = __or__ = binary(Opcode.OR, True)
+    __lshift__ = binary(Opcode.LSH)
+    __rshift__ = binary(Opcode.RSH)
+    __mod__ = binary(Opcode.MOD)
+    __rxor__ = __xor__ = binary(Opcode.XOR, True)
+
+    __eq__ = comparison(Opcode.JEQ, Opcode.JNE)
+    __gt__ = comparison(Opcode.JGT, Opcode.JLE, Opcode.JSGT, Opcode.JSLE)
+    __ge__ = comparison(Opcode.JGE, Opcode.JLT, Opcode.JSGE, Opcode.JSLT)
+    __lt__ = comparison(Opcode.JLT, Opcode.JGE, Opcode.JSLT, Opcode.JSGE)
+    __le__ = comparison(Opcode.JLE, Opcode.JGT, Opcode.JSLE, Opcode.JSGT)
+    __ne__ = comparison(Opcode.JNE, Opcode.JEQ)
 
     def __and__(self, value):
         return AndExpression(self.ebpf, self, value)
@@ -207,16 +296,18 @@ class Binary(Expression):
         else:
             free = False
         dst, long, signed, _ = self.left.calculate(dst, long, signed, True)
-        if self.operator == 0x74 and signed:  # >>=
-            operator = 0xc4
+        if self.operator is Opcode.RSH and signed:  # >>=
+            operator = Opcode.ARSH
         else:
             operator = self.operator
         if isinstance(self.right, int):
-            self.ebpf.append(operator + (3 if long is None else 3 * long),
+            self.ebpf.append(operator + (Opcode.LONG if long is None
+                                         else Opcode.LONG * long),
                              dst, 0, 0, self.right)
         else:
             src, long, _, rfree = self.right.calculate(None, long, None)
-            self.ebpf.append(operator + 3 * long + 8, dst, src, 0, 0)
+            self.ebpf.append(operator + Opcode.LONG * long + Opcode.REG,
+                             dst, src, 0, 0)
             if rfree:
                 self.ebpf.owners.discard(src)
         return dst, long, signed, free
@@ -224,7 +315,7 @@ class Binary(Expression):
 
 class Sum(Binary):
     def __init__(self, ebpf, left, right):
-        super().__init__(ebpf, left, right, 4)
+        super().__init__(ebpf, left, right, Opcode.ADD)
 
     def __add__(self, value):
         if isinstance(value, int):
@@ -243,9 +334,9 @@ class Sum(Binary):
 
 class AndExpression(Binary, SimpleComparison):
     def __init__(self, ebpf, left, right):
-        Binary.__init__(self, ebpf, left, right, 0x54)
-        SimpleComparison.__init__(self, ebpf, left, right, 0x45)
-        self.opcode = (0x45, None, 0x45, None)
+        Binary.__init__(self, ebpf, left, right, Opcode.AND)
+        SimpleComparison.__init__(self, ebpf, left, right, Opcode.JSET)
+        self.opcode = (Opcode.JSET, None, Opcode.JSET, None)
 
     def compare(self, negative):
         super().compare(False)
@@ -261,22 +352,23 @@ class Register(Expression):
         self.long = long
         self.signed = signed
 
-    __iadd__ = augassign(4)
-    __isub__ = augassign(0x14)
-    __imul__ = augassign(0x24)
-    __itruediv__ = augassign(0x34)
-    __ior__ = augassign(0x44)
-    __iand__ = augassign(0x54)
-    __ilshift__ = augassign(0x64)
-    __imod__ = augassign(0x94)
-    __ixor__ = augassign(0xa4)
+    __iadd__ = augassign(Opcode.ADD)
+    __isub__ = augassign(Opcode.SUB)
+    __imul__ = augassign(Opcode.MUL)
+    __itruediv__ = augassign(Opcode.DIV)
+    __ior__ = augassign(Opcode.OR)
+    __iand__ = augassign(Opcode.AND)
+    __ilshift__ = augassign(Opcode.LSH)
+    __imod__ = augassign(Opcode.MOD)
+    __ixor__ = augassign(Opcode.XOR)
 
     def __irshift__(self, value):
+        opcode = Opcode.ARSH if self.signed else Opcode.RSH
         if isinstance(value, int):
-            return Instruction(0x74 + 3 * self.long + 0x50 * self.signed,
-                               self.no, 0, 0, value)
+            return Instruction(opcode + Opcode.LONG * self.long, self.no,
+                               0, 0, value)
         elif isinstance(value, Register) and self.long == value.long:
-            return Instruction(0x7c + 3 * self.long + 0x50 * self.signed,
+            return Instruction(opcode + Opcode.REG + Opcode.LONG * self.long,
                                self.no, value.no, 0, 0)
         else:
             return NotImplemented
@@ -303,7 +395,8 @@ class Register(Expression):
         if self.no not in self.ebpf.owners:
             raise AssembleError("register has no value")
         if dst != self.no and force:
-            self.ebpf.append(0xbc + 3 * self.long, dst, self.no, 0, 0)
+            self.ebpf.append(Opcode.MOV + Opcode.REG + Opcode.LONG * self.long,
+                             dst, self.no, 0, 0)
             return dst, self.long, self.signed, False
         else:
             return self.no, self.long, self.signed, False
@@ -316,7 +409,7 @@ class Memory(Expression):
         self.address = address
 
     def calculate(self, dst, long, signed, force=False):
-        if not long and self.bits == 0x18:
+        if not long and self.bits == Opcode.DW:
             raise AssembleError("cannot compile")
         if dst is None:
             dst = self.ebpf.get_free_register()
@@ -324,11 +417,11 @@ class Memory(Expression):
         else:
             free = False
         if isinstance(self.address, Sum):
-            self.ebpf.append(0x61 + self.bits, dst, self.address.left.no,
+            self.ebpf.append(Opcode.LD + self.bits, dst, self.address.left.no,
                              self.address.right, 0)
         else:
             src, _, _, rfree = self.address.calculate(None, None, None)
-            self.ebpf.append(0x61 + self.bits, dst, src, 0, 0)
+            self.ebpf.append(Opcode.LD + self.bits, dst, src, 0, 0)
             if rfree:
                 self.ebpf.owners.discard(src)
         return dst, long, signed, free
@@ -348,10 +441,10 @@ class MemoryDesc:
             dst, _, _, afree = addr.calculate(None, None, None)
             offset = 0
         if isinstance(value, int):
-            self.ebpf.append(0x62 + self.bits, dst, 0, offset, value)
+            self.ebpf.append(Opcode.ST + self.bits, dst, 0, offset, value)
         else:
             src, _, _, free = value.calculate(None, None, None)
-            self.ebpf.append(0x63 + self.bits, dst, src, offset, 0)
+            self.ebpf.append(Opcode.STX + self.bits, dst, src, offset, 0)
             if free:
                 self.ebpf.owners.discard(src)
         if afree:
@@ -374,7 +467,7 @@ class PseudoFd(Expression):
             free = True
         else:
             free = False
-        self.ebpf.append(0x18, dst, 1, 0, self.fd)
+        self.ebpf.append(Opcode.DW, dst, 1, 0, self.fd)
         self.ebpf.append(0, 0, 0, 0, 0)
         return dst, long, signed, free
 
@@ -395,9 +488,10 @@ class RegisterDesc:
         instance.owners.add(self.no)
         if isinstance(value, int):
             if -0x80000000 <= value < 0x80000000:
-                instance.append(0xb4 + 3 * self.long, self.no, 0, 0, value)
+                instance.append(Opcode.MOV + Opcode.LONG * self.long,
+                                self.no, 0, 0, value)
             else:
-                instance.append(0x18, self.no, 0, 0, value & 0xffffffff)
+                instance.append(Opcode.DW, self.no, 0, 0, value & 0xffffffff)
                 instance.append(0, 0, 0, 0, value >> 32)
         elif isinstance(value, Expression):
             value.calculate(self.no, self.long, self.signed, True)
@@ -414,10 +508,10 @@ class EBPF:
         self.license = license
         self.kern_version = kern_version
 
-        self.m8 = MemoryDesc(self, 0x10)
-        self.m16 = MemoryDesc(self, 0x8)
-        self.m32 = MemoryDesc(self, 0)
-        self.m64 = MemoryDesc(self, 0x18)
+        self.m8 = MemoryDesc(self, Opcode.B)
+        self.m16 = MemoryDesc(self, Opcode.H)
+        self.m32 = MemoryDesc(self, Opcode.W)
+        self.m64 = MemoryDesc(self, Opcode.DW)
 
         self.owners = {1, 10}
 
@@ -439,7 +533,7 @@ class EBPF:
         return comp
 
     def jump(self):
-        comp = SimpleComparison(self, None, 0, 5)
+        comp = SimpleComparison(self, None, 0, Opcode.JMP)
         comp.origin = len(self.opcodes)
         comp.dst = 0
         comp.owners = self.owners.copy()
@@ -455,12 +549,12 @@ class EBPF:
         return PseudoFd(self, fd)
 
     def call(self, no):
-        self.append(0x85, 0, 0, 0, no)
+        self.append(Opcode.CALL, 0, 0, 0, no)
         self.owners.add(0)
         self.owners -= set(range(1, 6))
 
     def exit(self):
-        self.append(0x95, 0, 0, 0, 0)
+        self.append(Opcode.EXIT, 0, 0, 0, 0)
 
     def get_free_register(self):
         for i in range(10):
-- 
GitLab