From 08c865580b4eaa3cafea9fb2e010ab0f42d22f97 Mon Sep 17 00:00:00 2001 From: Martin Teichmann <martin.teichmann@gmail.com> Date: Tue, 7 Feb 2023 23:22:21 +0000 Subject: [PATCH] move bit access to ebpf proper this is not actually ethercat specific --- ebpfcat/ebpf.py | 105 ++++++++++++++++++++++++++++++++++--------- ebpfcat/ebpf_test.py | 47 +++++++++++++++++++ ebpfcat/ebpfcat.py | 26 +---------- 3 files changed, 133 insertions(+), 45 deletions(-) diff --git a/ebpfcat/ebpf.py b/ebpfcat/ebpf.py index 7144f58..15dea9d 100644 --- a/ebpfcat/ebpf.py +++ b/ebpfcat/ebpf.py @@ -266,10 +266,7 @@ class AssembleError(Exception): pass -def comparison(uposop, unegop, sposop=None, snegop=None): - if sposop is None: - sposop = uposop - snegop = unegop +def comparison(uposop, unegop, sposop, snegop): def ret(self, value): return SimpleComparison(self.ebpf, self, value, (uposop, unegop, sposop, snegop)) @@ -385,7 +382,6 @@ class AndOrComparison(Comparison): self.left = left self.right = right self.is_and = is_and - self.targetted = False def compare(self, negative): self.negative = negative @@ -404,11 +400,15 @@ class AndOrComparison(Comparison): class InvertComparison(Comparison): def __init__(self, ebpf, value): - self.ebpf = ebpf + super().__init__(ebpf) self.value = value def compare(self, negative): self.value.compare(not negative) + self.owners = self.value.owners + + def target(self, retarget=False): + self.value.target(retarget) def binary(opcode): @@ -439,16 +439,22 @@ class Expression: __rmod__ = rbinary(Opcode.MOD) __rxor__ = __xor__ = binary(Opcode.XOR) - __eq__ = comparison(Opcode.JEQ, Opcode.JNE) __gt__ = comparison(Opcode.JGT, Opcode.JLE, Opcode.JSGT, Opcode.JSLE) __ge__ = comparison(Opcode.JGE, Opcode.JLT, Opcode.JSGE, Opcode.JSLT) __lt__ = comparison(Opcode.JLT, Opcode.JGE, Opcode.JSLT, Opcode.JSGE) __le__ = comparison(Opcode.JLE, Opcode.JGT, Opcode.JSLE, Opcode.JSGT) - __ne__ = comparison(Opcode.JNE, Opcode.JEQ) def __and__(self, value): return AndExpression(self.ebpf, self, value) + def __ne__(self, value): + return SimpleComparison( + self.ebpf, self, value, + (Opcode.JNE, Opcode.JEQ, Opcode.JNE, Opcode.JEQ)) + + def __eq__(self, value): + return ~(self != value) + __rand__ = __and__ def __neg__(self): @@ -632,8 +638,20 @@ class Sum(Binary): return super().__sub__(value) -class AndExpression(SimpleComparison, Binary): - """The & operator may also be used as a comparison""" +class AndExpression(Binary): + # there is a special comparison with & instruction + def __init__(self, ebpf, left, right): + super().__init__(ebpf, left, right, Opcode.AND) + + def __ne__(self, value): + if isinstance(value, int) and value == 0: + return AndComparison(self.ebpf, self.left, self.right) + return super().__ne__(value) + + +class AndComparison(SimpleComparison): + # there is a special comparison with & instruction + # it is the only one which has not inversion def __init__(self, ebpf, left, right): Binary.__init__(self, ebpf, left, right, Opcode.AND) SimpleComparison.__init__(self, ebpf, left, right, Opcode.JSET) @@ -743,13 +761,21 @@ class Memory(Expression): @contextmanager def calculate(self, dst, long, signed, force=False): - if isinstance(self.address, Sum): - with self.ebpf.get_free_register(dst) as dst: - self.ebpf.append(Opcode.LD + self.fmt_to_opcode[self.fmt], dst, - self.address.left.no, self.address.right, 0) - yield dst, self.fmt in "QqA", self.fmt.islower() - else: - with super().calculate(dst, long, signed, force) as (dst, _, _): + with ExitStack() as exitStack: + if isinstance(self.address, Sum): + dst = exitStack.enter_context(self.ebpf.get_free_register(dst)) + self.ebpf.append( + Opcode.LD + self.fmt_to_opcode.get(self.fmt, Opcode.B), + dst, self.address.left.no, self.address.right, 0) + else: + dst, _, _ = exitStack.enter_context( + super().calculate(dst, long, signed, force)) + if isinstance(self.fmt, tuple): + self.ebpf.r[dst] &= ((1 << self.fmt[1]) - 1) << self.fmt[0] + if self.fmt[0] > 0: + self.ebpf.r[dst] >>= self.fmt[0] + yield dst, "B", False + else: yield dst, self.fmt in "QqA", self.fmt.islower() @contextmanager @@ -764,6 +790,18 @@ class Memory(Expression): def signed(self): return isinstance(self.fmt, str) and self.fmt.islower() + def __invert__(self): + if not isinstance(self.fmt, tuple) or self.fmt[1] != 1: + return NotImplemented + return self == 0 + + def __ne__(self, value): + if isinstance(self.fmt, tuple) and isinstance(value, int) \ + and value == 0: + mask = ((1 << self.fmt[1]) - 1) << self.fmt[0] + return Memory(self.ebpf, "B", self.address) & mask != 0 + return super().__ne__(value) + class MemoryDesc: """A base class used by descriptors for memory @@ -782,8 +820,26 @@ class MemoryDesc: def __set__(self, instance, value): ebpf = instance.ebpf fmt, addr = self.fmt_addr(instance) - bits = Memory.fmt_to_opcode[fmt] - if isinstance(value, int): + bits = Memory.fmt_to_opcode.get(fmt, Opcode.B) + if isinstance(fmt, tuple): + before = Memory(ebpf, "B", ebpf.r[self.base_register] + addr) + if fmt[1] == 1: + try: + if value: + value = before | (1 << fmt[0]) + else: + value = before & ~(1 << fmt[0]) + except AssembleError: + with ebpf.wtmp: + with value as cond: + ebpf.wtmp = before | (1 << fmt[0]) + with cond.Else(): + ebpf.wtmp = before & ~(1 << fmt[0]) + else: + mask = ((1 << fmt[1]) - 1) << fmt[0] + value = (mask & (value << self.fmt[0]) | ~mask & before) + opcode = Opcode.STX + elif isinstance(value, int): ebpf.append(Opcode.ST + bits, self.base_register, 0, addr, value) return @@ -798,7 +854,9 @@ class MemoryDesc: opcode = Opcode.XADD else: opcode = Opcode.STX - with value.calculate(None, fmt in 'qQ', fmt.islower()) as (src, _, _): + with value.calculate(None, isinstance(fmt, str) and fmt in 'qQ', + isinstance(fmt, str) and fmt.islower() + ) as (src, _, _): ebpf.append(opcode + bits, self.base_register, src, addr, 0) @@ -810,7 +868,10 @@ class LocalVar(MemoryDesc): self.fmt = fmt def __set_name__(self, owner, name): - size = calcsize(self.fmt) + if isinstance(self.fmt, str): + size = calcsize(self.fmt) + else: # this is to support bit addressing, mostly for testing + size = 1 owner.stack -= size owner.stack &= -size self.relative_addr = owner.stack @@ -1064,6 +1125,8 @@ class EBPF: def jumpIf(self, comp): """jump if `comp` is true to a later defined `target`""" + if isinstance(comp, Expression): + comp = comp != 0 comp.compare(False) return comp diff --git a/ebpfcat/ebpf_test.py b/ebpfcat/ebpf_test.py index c6ca17d..4646af4 100644 --- a/ebpfcat/ebpf_test.py +++ b/ebpfcat/ebpf_test.py @@ -174,6 +174,53 @@ class Tests(TestCase): Instruction(opcode=O.REG+O.STX, dst=10, src=0, off=-4, imm=0), Instruction(opcode=O.DW+O.STX, dst=10, src=1, off=-16, imm=0)]) + def test_local_bits(self): + class Local(EBPF): + a = LocalVar((5, 1)) + b = LocalVar((3, 4)) + + e = Local(ProgType.XDP, "GPL") + + with e.a: + e.a = 1 + + e.b = e.a + + with ~e.a: + e.b = 3 + + with e.b: + e.a = 0 + + self.assertEqual(e.opcodes, [ + Instruction(opcode=O.LD+O.B, dst=0, src=10, off=-1, imm=0), + Instruction(opcode=O.JSET, dst=0, src=0, off=1, imm=32), + Instruction(opcode=O.JMP, dst=0, src=0, off=3, imm=0), + Instruction(opcode=O.LD+O.B, dst=0, src=10, off=-1, imm=0), + Instruction(opcode=O.OR, dst=0, src=0, off=0, imm=32), + Instruction(opcode=O.B+O.STX, dst=10, src=0, off=-1, imm=0), + Instruction(opcode=O.LD+O.B, dst=0, src=10, off=-1, imm=0), + Instruction(opcode=O.AND+O.LONG, dst=0, src=0, off=0, imm=32), + Instruction(opcode=O.RSH+O.LONG, dst=0, src=0, off=0, imm=5), + Instruction(opcode=O.LSH, dst=0, src=0, off=0, imm=3), + Instruction(opcode=O.AND, dst=0, src=0, off=0, imm=120), + Instruction(opcode=O.LD+O.B, dst=2, src=10, off=-2, imm=0), + Instruction(opcode=O.AND, dst=2, src=0, off=0, imm=-121), + Instruction(opcode=O.REG+O.OR, dst=0, src=2, off=0, imm=0), + Instruction(opcode=O.B+O.STX, dst=10, src=0, off=-2, imm=0), + Instruction(opcode=O.LD+O.B, dst=0, src=10, off=-1, imm=0), + Instruction(opcode=O.JSET, dst=0, src=0, off=4, imm=32), + Instruction(opcode=O.LD+O.B, dst=0, src=10, off=-2, imm=0), + Instruction(opcode=O.AND, dst=0, src=0, off=0, imm=-121), + Instruction(opcode=O.OR, dst=0, src=0, off=0, imm=24), + Instruction(opcode=O.B+O.STX, dst=10, src=0, off=-2, imm=0), + Instruction(opcode=O.LD+O.B, dst=0, src=10, off=-2, imm=0), + Instruction(opcode=O.JSET, dst=0, src=0, off=1, imm=120), + Instruction(opcode=O.JMP, dst=0, src=0, off=3, imm=0), + Instruction(opcode=O.LD+O.B, dst=0, src=10, off=-1, imm=0), + Instruction(opcode=O.AND, dst=0, src=0, off=0, imm=-33), + Instruction(opcode=O.B+O.STX, dst=10, src=0, off=-1, imm=0)]) + def test_local_subprog(self): class Local(EBPF): a = LocalVar('I') diff --git a/ebpfcat/ebpfcat.py b/ebpfcat/ebpfcat.py index ae42165..f0f0b6c 100644 --- a/ebpfcat/ebpfcat.py +++ b/ebpfcat/ebpfcat.py @@ -93,25 +93,6 @@ class PacketVar(MemoryDesc): def set(self, device, value): if device.sync_group.current_data is None: - if isinstance(self.size, int): - try: - bool(value) - except RuntimeError: - e = device.sync_group - with e.wtmp: - e.wtmp = super().__get__(device, None) - with value as cond: - e.wtmp |= 1 << self.size - with cond.Else(): - e.wtmp &= ~(1 << self.size) - super().__set__(device, e.wtmp) - return - else: - old = super().__get__(device, None) - if value: - value = old | (1 << self.size) - else: - value = old & ~(1 << self.size) super().__set__(device, value) else: data = device.sync_group.current_data @@ -126,10 +107,7 @@ class PacketVar(MemoryDesc): def get(self, device): if device.sync_group.current_data is None: - if isinstance(self.size, int): - return super().__get__(device, None) & (1 << self.size) - else: - return super().__get__(device, None) + return super().__get__(device, None) else: data = device.sync_group.current_data start = self._start(device) @@ -143,7 +121,7 @@ class PacketVar(MemoryDesc): + self.position def fmt_addr(self, device): - return ("B" if isinstance(self.size, int) else self.size, + return ((self.size, 1) if isinstance(self.size, int) else self.size, self._start(device) + Packet.ETHERNET_HEADER) -- GitLab