diff --git a/ebpfcat/ebpf.py b/ebpfcat/ebpf.py index 245da0caf51a3eb40b426e91a02efde3f3410db5..3ae5975dea79435a8fc8221a4f489e53f4ce5df2 100644 --- a/ebpfcat/ebpf.py +++ b/ebpfcat/ebpf.py @@ -319,6 +319,9 @@ class Comparison: def __invert__(self): return InvertComparison(self.ebpf, self) + def __bool__(self): + raise RuntimeError("Use with statement for comparisons") + class SimpleComparison(Comparison): def __init__(self, ebpf, left, right, opcode): @@ -430,6 +433,17 @@ class Expression: def __neg__(self): return Negate(self.ebpf, self) + def __bool__(self): + raise RuntimeError("Expression only has a value at execution time") + + def __enter__(self): + ret = self != 0 + self.as_comparison = ret + return ret.__enter__() + + def __exit__(self, exc_type, exc, tb): + return self.as_comparison.__exit__(exc_type, exc, tb) + @contextmanager def calculate(self, dst, long, signed, force=False): with self.ebpf.get_free_register(dst) as dst: @@ -547,7 +561,7 @@ class Sum(Binary): return super().__sub__(value) -class AndExpression(Binary, SimpleComparison): +class AndExpression(SimpleComparison, Binary): def __init__(self, ebpf, left, right): Binary.__init__(self, ebpf, left, right, Opcode.AND) SimpleComparison.__init__(self, ebpf, left, right, Opcode.JSET) diff --git a/ebpfcat/ebpf_test.py b/ebpfcat/ebpf_test.py index d1d269d40d5c5bdeb74015e91d890f2b42bdac4e..073e13499a7ed5780c9984d32e7bd699428d2dc7 100644 --- a/ebpfcat/ebpf_test.py +++ b/ebpfcat/ebpf_test.py @@ -42,7 +42,7 @@ class Tests(TestCase): e.owners = {0, 1, 2, 3, 4, 5, 6, 7, 8, 9} e.r5 = 7 e.r6 = e.r3 - self.assertEqual(e.opcodes, + self.assertEqual(e.opcodes, [Instruction(0xb7, 5, 0, 0, 7), Instruction(0xbf, 6, 3, 0, 0)]) @@ -53,7 +53,7 @@ class Tests(TestCase): e.w4 = e.w1 e.w2 += 3 e.w5 += e.w6 - self.assertEqual(e.opcodes, + self.assertEqual(e.opcodes, [Instruction(O.MOV+O.LONG, 3, 0, 0, 7), Instruction(0xbc, 4, 1, 0, 0), Instruction(opcode=4, dst=2, src=0, off=0, imm=3), @@ -85,7 +85,7 @@ class Tests(TestCase): e.sr4 >>= 3 e.sr4 >>= e.r7 - self.assertEqual(e.opcodes, + self.assertEqual(e.opcodes, [Instruction(opcode=7, dst=5, src=0, off=0, imm=7), Instruction(opcode=15, dst=3, src=6, off=0, imm=0), Instruction(opcode=7, dst=4, src=0, off=0, imm=-3), @@ -305,11 +305,15 @@ class Tests(TestCase): e.r2 = 5 with cond.Else(): e.r6 = 7 + with e.r2: + e.r3 = 2 self.assertEqual(e.opcodes, [Instruction(opcode=0xb5, dst=2, src=0, off=2, imm=3), Instruction(opcode=0xb7, dst=2, src=0, off=0, imm=5), Instruction(opcode=0x5, dst=0, src=0, off=1, imm=0), - Instruction(opcode=0xb7, dst=6, src=0, off=0, imm=7)]) + Instruction(opcode=O.MOV+O.LONG, dst=6, src=0, off=0, imm=7), + Instruction(opcode=O.JEQ, dst=2, src=0, off=1, imm=0), + Instruction(opcode=O.MOV+O.LONG, dst=3, src=0, off=0, imm=2)]) def test_with_inversion(self): e = EBPF()