From 69d7feb651c4c3443665fb271207ec6364bc7250 Mon Sep 17 00:00:00 2001 From: Martin Teichmann <martin.teichmann@gmail.com> Date: Sat, 25 Feb 2023 15:53:45 +0000 Subject: [PATCH] fix setting bits from data --- ebpfcat/ebpf.py | 62 +++++++++++++++++++++++--------------------- ebpfcat/ebpf_test.py | 10 +++++++ 2 files changed, 42 insertions(+), 30 deletions(-) diff --git a/ebpfcat/ebpf.py b/ebpfcat/ebpf.py index b74a64e..106f0be 100644 --- a/ebpfcat/ebpf.py +++ b/ebpfcat/ebpf.py @@ -968,38 +968,40 @@ class Memory(Expression): def _set(self, value): opcode = Opcode.STX - if isinstance(self.fmt, tuple): - pos, bits = self.fmt - self.fmt = "B" - if bits == 1: - try: - if value: - value = self | (1 << pos) - else: - value = self & ~(1 << pos) - except AssembleError: - with ebpf.wtmp: + with ExitStack() as exitStack: + if isinstance(self.fmt, tuple): + pos, bits = self.fmt + self.fmt = "B" + if bits == 1: + try: + if value: + value = self | (1 << pos) + else: + value = self & ~(1 << pos) + except AssembleError: + exitStack.enter_context(self.ebpf.wtmp) with value as Else: - ebpf.wtmp = self | (1 << pos) + self.ebpf.wtmp = self | (1 << pos) with Else: - ebpf.wtmp = self & ~(1 << pos) - else: - mask = ((1 << bits) - 1) << pos - value = (mask & (value << pos) | ~mask & self) - elif isinstance(value, IAdd) and len(self.fmt) == 1: - value = value.value - opcode = Opcode.XADD - elif not isinstance(value, Expression): - if self.fmt == "x": - value = Constant(self.ebpf, value) - else: - value = Constant(self.ebpf, - *unpack(self.fmt, pack(self.fmt[-1], value))) - if self.fmt == "x" and not value.fixed: - value *= Expression.FIXED_BASE - elif self.fmt != "x" and value.fixed: - value /= Expression.FIXED_BASE - with ExitStack() as exitStack: + self.ebpf.wtmp = self & ~(1 << pos) + value = self.ebpf.wtmp + else: + mask = ((1 << bits) - 1) << pos + value = (mask & (value << pos) | ~mask & self) + elif isinstance(value, IAdd) and len(self.fmt) == 1: + value = value.value + opcode = Opcode.XADD + elif not isinstance(value, Expression): + if self.fmt == "x": + value = Constant(self.ebpf, value) + else: + value = Constant( + self.ebpf, + *unpack(self.fmt, pack(self.fmt[-1], value))) + if self.fmt == "x" and not value.fixed: + value *= Expression.FIXED_BASE + elif self.fmt != "x" and value.fixed: + value /= Expression.FIXED_BASE if isinstance(self.address, Sum): dst = self.address.left.no offset = self.address.right.value diff --git a/ebpfcat/ebpf_test.py b/ebpfcat/ebpf_test.py index aa81111..01b61ac 100644 --- a/ebpfcat/ebpf_test.py +++ b/ebpfcat/ebpf_test.py @@ -316,6 +316,8 @@ class Tests(TestCase): with e.b: e.a = 0 + e.a = e.b + self.assertEqual(e.opcodes, [ Instruction(opcode=O.LD+O.B, dst=0, src=10, off=-1, imm=0), Instruction(opcode=O.JSET, dst=0, src=0, off=1, imm=32), @@ -343,6 +345,14 @@ class Tests(TestCase): Instruction(opcode=O.JMP, dst=0, src=0, off=3, imm=0), Instruction(opcode=O.LD+O.B, dst=0, src=10, off=-1, imm=0), Instruction(opcode=O.AND, dst=0, src=0, off=0, imm=-33), + Instruction(opcode=O.STX+O.B, dst=10, src=0, off=-1, imm=0), + Instruction(opcode=O.LD+O.B, dst=2, src=10, off=-2, imm=0), + Instruction(opcode=O.JSET, dst=2, src=0, off=3, imm=120), + Instruction(opcode=O.LD+O.B, dst=0, src=10, off=-1, imm=0), + Instruction(opcode=O.AND, dst=0, src=0, off=0, imm=-33), + Instruction(opcode=O.JMP, dst=0, src=0, off=2, imm=0), + Instruction(opcode=O.LD+O.B, dst=0, src=10, off=-1, imm=0), + Instruction(opcode=O.OR, dst=0, src=0, off=0, imm=32), Instruction(opcode=O.B+O.STX, dst=10, src=0, off=-1, imm=0)]) def test_local_subprog(self): -- GitLab