Skip to content
Snippets Groups Projects
Commit 183a37fb authored by Martin Teichmann's avatar Martin Teichmann
Browse files

amidst making & and | work for If

super complicated, pausing.
parent b2b9ee88
No related branches found
No related tags found
No related merge requests found
...@@ -32,35 +32,33 @@ def comparison(uposop, unegop, sposop=None, snegop=None): ...@@ -32,35 +32,33 @@ def comparison(uposop, unegop, sposop=None, snegop=None):
class Comparison: class Comparison:
def target(self): def __init__(self, ebpf):
assert self.ebpf.opcodes[self.origin] is None self.ebpf = ebpf
if isinstance(self.right, int): self.invert = None
inst = Instruction( self.else_origin = None
self.opcode, self.dst, 0,
len(self.ebpf.opcodes) - self.origin - 1, self.right)
else:
inst = Instruction(
self.opcode + 8, self.dst, self.src,
len(self.ebpf.opcodes) - self.origin - 1, 0)
self.ebpf.opcodes[self.origin] = inst
self.ebpf.owners, self.owners = \
self.ebpf.owners & self.owners, self.ebpf.owners
def __enter__(self): def __enter__(self):
return self return self
def __exit__(self, exc_type, exc, tb): def __exit__(self, exc_type, exc, tb):
self.target() if self.else_origin is None:
self.target()
return
assert self.ebpf.opcodes[self.else_origin] is None
self.ebpf.opcodes[self.else_origin] = Instruction(5, 0, 0, len(self.ebpf.opcodes) - self.else_origin - 1, 0)
self.ebpf.owners, self.owners = \
self.ebpf.owners & self.owners, self.ebpf.owners
if self.invert is not None: if self.invert is not None:
olen = len(self.ebpf.opcodes) olen = len(self.ebpf.opcodes)
assert self.ebpf.opcodes[self.invert].opcode == 5 assert self.ebpf.opcodes[self.invert].opcode == 5
self.ebpf.opcodes[self.invert:self.invert] = \ self.ebpf.opcodes[self.invert:self.invert] = \
self.ebpf.opcodes[self.origin+1:] self.ebpf.opcodes[self.else_origin+1:]
del self.ebpf.opcodes[olen-1:] del self.ebpf.opcodes[olen-1:]
op, dst, src, off, imm = self.ebpf.opcodes[self.invert - 1] op, dst, src, off, imm = self.ebpf.opcodes[self.invert - 1]
self.ebpf.opcodes[self.invert - 1] = \ self.ebpf.opcodes[self.invert - 1] = \
Instruction(op, dst, src, Instruction(op, dst, src,
len(self.ebpf.opcodes) - self.origin + 1, imm) len(self.ebpf.opcodes) - self.else_origin + 1, imm)
def Else(self): def Else(self):
op, dst, src, off, imm = self.ebpf.opcodes[self.origin] op, dst, src, off, imm = self.ebpf.opcodes[self.origin]
...@@ -69,20 +67,34 @@ class Comparison: ...@@ -69,20 +67,34 @@ class Comparison:
else: else:
self.ebpf.opcodes[self.origin] = \ self.ebpf.opcodes[self.origin] = \
Instruction(op, dst, src, off+1, imm) Instruction(op, dst, src, off+1, imm)
self.origin = len(self.ebpf.opcodes) self.else_origin = len(self.ebpf.opcodes)
self.ebpf.opcodes.append(None)
return self
def invert_result(self):
origin = len(self.ebpf.opcodes)
self.ebpf.opcodes.append(None) self.ebpf.opcodes.append(None)
self.target()
self.origin = origin
self.right = self.dst = 0 self.right = self.dst = 0
self.opcode = 5 self.opcode = 5
return self
def __and__(self, value):
return AndOrComparison(self.ebpf, self, value, True)
def __or__(self, value):
return AndOrComparison(self.ebpf, self, value, False)
def __invert__(self):
return InvertComparison(self.ebpf, self)
class SimpleComparison(Comparison): class SimpleComparison(Comparison):
def __init__(self, ebpf, left, right, opcode): def __init__(self, ebpf, left, right, opcode):
self.ebpf = ebpf super().__init__(ebpf)
self.left = left self.left = left
self.right = right self.right = right
self.opcode = opcode self.opcode = opcode
self.invert = None
def compare(self, negative): def compare(self, negative):
self.dst, _, lsigned, lfree = self.left.calculate(None, None, None) self.dst, _, lsigned, lfree = self.left.calculate(None, None, None)
...@@ -102,14 +114,52 @@ class SimpleComparison(Comparison): ...@@ -102,14 +114,52 @@ class SimpleComparison(Comparison):
self.ebpf.owners.discard(self.src) self.ebpf.owners.discard(self.src)
self.owners = self.ebpf.owners.copy() self.owners = self.ebpf.owners.copy()
def target(self):
assert self.ebpf.opcodes[self.origin] is None
if isinstance(self.right, int):
inst = Instruction(
self.opcode, self.dst, 0,
len(self.ebpf.opcodes) - self.origin - 1, self.right)
else:
inst = Instruction(
self.opcode + 8, self.dst, self.src,
len(self.ebpf.opcodes) - self.origin - 1, 0)
self.ebpf.opcodes[self.origin] = inst
self.ebpf.owners, self.owners = \
self.ebpf.owners & self.owners, self.ebpf.owners
class AndOrComparison(Comparison): class AndOrComparison(Comparison):
def __init__(self, ebpf, left, right, is_and):
super().__init__(ebpf)
self.left = left
self.right = right
self.is_and = is_and
self.targetted = False
def compare(self, negative): def compare(self, negative):
self.left.compare(negative) self.left.compare(self.is_and != negative)
self.right.compare(negative) self.right.compare(self.is_and != negative)
if self.is_and != negative:
self.invert_result()
self.owners = self.ebpf.owners.copy()
def target(self): def target(self):
self.left.target() if self.targetted:
self.right.target() super().target()
else:
self.left.target()
self.right.target()
self.targetted = True
class InvertComparison(Comparison):
def __init__(self, ebpf, value):
self.ebpf = ebpf
self.value = value
def compare(self, negative):
self.value.compare(not negative)
def binary(opcode, symetric=False): def binary(opcode, symetric=False):
...@@ -192,9 +242,6 @@ class Sum(Binary): ...@@ -192,9 +242,6 @@ class Sum(Binary):
class AndExpression(Binary, SimpleComparison): class AndExpression(Binary, SimpleComparison):
__and__ = __rand__ = comparison(0x45, None)
__rand__ = __and__ = binary(0x54, True)
def __init__(self, ebpf, left, right): def __init__(self, ebpf, left, right):
Binary.__init__(self, ebpf, left, right, 0x54) Binary.__init__(self, ebpf, left, right, 0x54)
SimpleComparison.__init__(self, ebpf, left, right, 0x45) SimpleComparison.__init__(self, ebpf, left, right, 0x45)
...@@ -203,12 +250,7 @@ class AndExpression(Binary, SimpleComparison): ...@@ -203,12 +250,7 @@ class AndExpression(Binary, SimpleComparison):
def compare(self, negative): def compare(self, negative):
super().compare(False) super().compare(False)
if negative: if negative:
origin = len(self.ebpf.opcodes) self.invert_result()
self.ebpf.opcodes.append(None)
self.target()
self.origin = origin
self.right = self.dst = 0
self.opcode = 5
class Register(Expression): class Register(Expression):
offset = 0 offset = 0
......
...@@ -394,11 +394,10 @@ class Tests(TestCase): ...@@ -394,11 +394,10 @@ class Tests(TestCase):
class KernelTests(TestCase): class KernelTests(TestCase):
def test_minimal(self): def test_minimal(self):
e = EBPF(ProgType.XDP, "GPL") e = EBPF(ProgType.XDP, "GPL")
with e.If(e.r1 & 1111111) as cond: with e.If((e.r1 == 0x111111) & (e.r10 == 0x22222)) as cond:
e.r0 = 2 e.r0 = 333333
e.r1 = 4
with cond.Else(): with cond.Else():
e.r0 = 3 e.r0 = 444444
e.exit() e.exit()
print(e.load(log_level=1)[1]) print(e.load(log_level=1)[1])
self.fail() self.fail()
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment