From 0e57945bdc5dce5bcd13fcc633bf069588661db0 Mon Sep 17 00:00:00 2001
From: Martin Teichmann <martin.teichmann@gmail.com>
Date: Sun, 26 Feb 2023 11:28:48 +0000
Subject: [PATCH] add switch_endian method, and use it

---
 ebpfcat/ebpf.py | 93 ++++++++++++++++++++++++++++++++-----------------
 1 file changed, 62 insertions(+), 31 deletions(-)

diff --git a/ebpfcat/ebpf.py b/ebpfcat/ebpf.py
index e796cf6..879759f 100644
--- a/ebpfcat/ebpf.py
+++ b/ebpfcat/ebpf.py
@@ -546,10 +546,15 @@ class Expression:
     __rand__ = __and__
 
     def __neg__(self):
-        return Negate(self.ebpf, self)
+        return Negate(self)
 
     def __abs__(self):
-        return Absolute(self.ebpf, self)
+        return Absolute(self)
+
+    def switch_endian(self, fmt):
+        if isinstance(fmt, str) and len(fmt) > 1:
+            return SwitchEndian(self, fmt)
+        return self
 
     def __bool__(self):
         raise AssembleError("Expression only has a value at execution time")
@@ -590,7 +595,6 @@ class Expression:
             with self.get_address(dst, long) as (src, fmt):
                 self.ebpf.append(Opcode.LD + fmt_to_opcode(fmt),
                                  dst, src, 0, 0)
-                self.ebpf.append_endian(fmt, dst)
                 yield dst, long
 
     @contextmanager
@@ -656,38 +660,54 @@ class Binary(Expression):
                                           and self.right.contains(no))
 
 
-class Negate(Expression):
-    def __init__(self, ebpf, arg):
-        self.ebpf = ebpf
+class Unary(Expression):
+    def __init__(self, arg):
         self.arg = arg
-        self.signed = True
+        self.ebpf = arg.ebpf
+        self.signed = arg.signed
         self.fixed = arg.fixed
 
     @contextmanager
     def calculate(self, dst, long, force=False):
         with self.arg.calculate(dst, long, force) as (dst, long):
-            self.ebpf.append(Opcode.NEG + Opcode.LONG * long, dst, 0, 0, 0)
+            self.calculate_unary(dst, long)
             yield dst, long
 
     def contains(self, no):
         return self.arg.contains(no)
 
 
-class Absolute(Expression):
-    def __init__(self, ebpf, arg):
-        self.ebpf = ebpf
-        self.arg = arg
-        self.fixed = arg.fixed
+class Negate(Unary):
+    def __init__(self, arg):
+        super().__init__(arg)
+        self.signed = True
 
-    @contextmanager
-    def calculate(self, dst, long, force=False):
-        with self.arg.calculate(dst, long, force) as (dst, long):
-            with self.ebpf.sr[dst] < 0:
-                self.ebpf.sr[dst] = -self.ebpf.sr[dst]
-            yield dst, long
+    def calculate_unary(self, dst, long):
+        self.ebpf.append(Opcode.NEG + Opcode.LONG * long, dst, 0, 0, 0)
 
-    def contains(self, no):
-        return self.arg.contains(no)
+
+class Absolute(Unary):
+    def __init__(self, arg):
+        super().__init__(arg)
+        self.signed = False
+
+    def calculate_unary(self, dst, long):
+        with self.ebpf.sr[dst] < 0:
+            self.ebpf.sr[dst] = -self.ebpf.sr[dst]
+
+
+class SwitchEndian(Unary):
+    def __init__(self, arg, fmt):
+        super().__init__(arg)
+        self.fmt = fmt
+
+    def calculate_unary(self, dst, long):
+        endian, size = self.fmt
+        if endian == "<":
+            opcode = Opcode.LE
+        elif endian in ">!":
+            opcode = Opcode.BE
+        self.ebpf.append(opcode, dst, 0, 0, calcsize(size) * 8)
 
 
 class Sum(Binary):
@@ -796,6 +816,11 @@ class Constant(Expression):
                 self.ebpf.append(Opcode.W, 0, 0, 0, value >> 32)
             yield dst, not (-0x80000000 <= value < 0x100000000)
 
+    def switch_endian(self, fmt):
+        if not isinstance(fmt, str) or len(fmt) == 1:
+            return self
+        return Constant(self.ebpf, *unpack(fmt, pack(fmt[-1], self.value)))
+
 
 class Register(Expression):
     """represent one EBPF register"""
@@ -881,13 +906,17 @@ class Memory(Expression):
 
     @contextmanager
     def calculate(self, dst, long, force=False):
+        if self.has_endian():
+            with self.without_endian().switch_endian(self.fmt) \
+                 .calculate(dst, long, force) as (dst, long):
+                yield dst, long
+                return
         with ExitStack() as exitStack:
             if isinstance(self.address, Sum):
                 dst = exitStack.enter_context(self.ebpf.get_free_register(dst))
                 opcode = fmt_to_opcode(self.fmt)
                 self.ebpf.append(Opcode.LD + opcode, dst, self.address.left.no,
                                  self.address.right.value, 0)
-                self.ebpf.append_endian(self.fmt, dst)
             else:
                 dst, _ = exitStack.enter_context(
                     super().calculate(dst, long, force))
@@ -897,7 +926,7 @@ class Memory(Expression):
                     self.ebpf.r[dst] >>= self.fmt[0]
                 yield dst, "B"
             else:
-                yield dst, self.fmt in "QqA"
+                yield dst, self.fmt[-1] in "QqAx"
 
     @contextmanager
     def get_address(self, dst, long, force=False):
@@ -915,6 +944,14 @@ class Memory(Expression):
     def fixed(self):
         return isinstance(self.fmt, str) and self.fmt == "x"
 
+    def has_endian(self):
+        return isinstance(self.fmt, str) and len(self.fmt) > 1
+
+    def without_endian(self):
+        if self.has_endian():
+            return Memory(self.ebpf, self.fmt[-1], self.address)
+        return self
+
     def __invert__(self):
         if not isinstance(self.fmt, tuple) or self.fmt[1] != 1:
             return NotImplemented
@@ -953,12 +990,7 @@ class Memory(Expression):
                 value = value.value
                 opcode = Opcode.XADD
             elif not isinstance(value, Expression):
-                if self.fmt == "x":
-                    value = Constant(self.ebpf, value)
-                else:
-                    value = Constant(
-                        self.ebpf,
-                        *unpack(self.fmt, pack(self.fmt[-1], value)))
+                value = Constant(self.ebpf, value)
             if self.fmt == "x" and not value.fixed:
                 value *= Expression.FIXED_BASE
             elif self.fmt != "x" and value.fixed:
@@ -970,6 +1002,7 @@ class Memory(Expression):
                 dst, _ = exitStack.enter_context(
                     self.address.calculate(None, True))
                 offset = 0
+            value = value.switch_endian(self.fmt)
             if value.small_constant and opcode == Opcode.STX:
                 self.ebpf.append(Opcode.ST + fmt_to_opcode(self.fmt), dst, 0,
                                  offset, int(value.value))
@@ -977,8 +1010,6 @@ class Memory(Expression):
             src, _ = exitStack.enter_context(
                 value.calculate(None, isinstance(self.fmt, str)
                                       and self.fmt[-1] in 'qQx'))
-            if not isinstance(value, Constant):
-                self.ebpf.append_endian(self.fmt, src)
             self.ebpf.append(opcode + fmt_to_opcode(self.fmt),
                              dst, src, offset, 0)
 
-- 
GitLab