diff --git a/ebpf.py b/ebpf.py index cca2f1f826a5f345c04d572054df8d07987ae7c7..e7ddfd695d5526231be515e1ccef69deea38db2f 100644 --- a/ebpf.py +++ b/ebpf.py @@ -1,20 +1,107 @@ from collections import namedtuple from struct import pack +from enum import Enum from .bpf import prog_load Instruction = namedtuple("Instruction", ["opcode", "dst", "src", "off", "imm"]) +class Opcode(Enum): + ADD = 4 + SUB = 0x14 + MUL = 0x24 + DIV = 0x34 + OR = 0x44 + AND = 0x54 + LSH = 0x64 + RSH = 0x74 + NEG = 0x84 + MOD = 0x94 + XOR = 0xa4 + MOV = 0xb4 + ARSH = 0xc4 + + JMP = 5 + JEQ = 0x15 + JGT = 0x25 + JGE = 0x35 + JSET = 0x45 + JNE = 0x55 + JSGT = 0x65 + JSGE = 0x75 + JLT = 0xa5 + JLE = 0xb5 + JSLT = 0xc5 + JSLE = 0xd5 + + CALL = 0x85 + EXIT = 0x95 + + REG = 8 + LONG = 3 + + H = 8 + W = 0 + B = 0x10 + DW = 0x18 + + LD = 0x61 + ST = 0x62 + STX = 0x63 + + def __mul__(self, value): + if value: + return OpcodeFlags({self}) + else: + return OpcodeFlags(set()) + + def __add__(self, value): + return OpcodeFlags({self}) + value + + def __repr__(self): + return self.name + + def __eq__(self, value): + return self is value or self.value == value + + def __hash__(self): + return super().__hash__() + +class OpcodeFlags: + def __init__(self, opcodes): + self.opcodes = opcodes + + @property + def value(self): + return sum(op.value for op in self.opcodes) + + def __add__(self, value): + if isinstance(value, Opcode): + return OpcodeFlags(self.opcodes | {value}) + else: + return OpcodeFlags(self.opcodes | value.opcodes) + + def __repr__(self): + return "|".join(op.name for op in self.opcodes) + + def __eq__(self, value): + if isinstance(value, int): + return self.value == value + else: + self.opcodes == value.opcodes + + class AssembleError(Exception): pass def augassign(opcode): def ret(self, value): if isinstance(value, int): - return Instruction(opcode + 3 * self.long, self.no, 0, 0, value) + return Instruction(opcode + Opcode.LONG * self.long, self.no, + 0, 0, value) elif isinstance(value, Register) and self.long == value.long: - return Instruction(opcode + 8 + 3 * self.long, + return Instruction(opcode + Opcode.REG + Opcode.LONG * self.long, self.no, value.no, 0, 0) else: return NotImplemented @@ -45,13 +132,15 @@ class Comparison: self.target() return assert self.ebpf.opcodes[self.else_origin] is None - self.ebpf.opcodes[self.else_origin] = Instruction(5, 0, 0, len(self.ebpf.opcodes) - self.else_origin - 1, 0) + self.ebpf.opcodes[self.else_origin] = Instruction( + Opcode.JMP, 0, 0, + len(self.ebpf.opcodes) - self.else_origin - 1, 0) self.ebpf.owners, self.owners = \ self.ebpf.owners & self.owners, self.ebpf.owners if self.invert is not None: olen = len(self.ebpf.opcodes) - assert self.ebpf.opcodes[self.invert].opcode == 5 + assert self.ebpf.opcodes[self.invert].opcode == Opcode.JMP self.ebpf.opcodes[self.invert:self.invert] = \ self.ebpf.opcodes[self.else_origin+1:] del self.ebpf.opcodes[olen-1:] @@ -62,7 +151,7 @@ class Comparison: def Else(self): op, dst, src, off, imm = self.ebpf.opcodes[self.origin] - if op == 5: + if op == Opcode.JMP: self.invert = self.origin else: self.ebpf.opcodes[self.origin] = \ @@ -77,7 +166,7 @@ class Comparison: self.target() self.origin = origin self.right = self.dst = 0 - self.opcode = 5 + self.opcode = Opcode.JMP def __and__(self, value): return AndOrComparison(self.ebpf, self, value, True) @@ -122,7 +211,7 @@ class SimpleComparison(Comparison): len(self.ebpf.opcodes) - self.origin - 1, self.right) else: inst = Instruction( - self.opcode + 8, self.dst, self.src, + self.opcode + Opcode.REG, self.dst, self.src, len(self.ebpf.opcodes) - self.origin - 1, 0) self.ebpf.opcodes[self.origin] = inst self.ebpf.owners, self.owners = \ @@ -169,22 +258,22 @@ def binary(opcode, symetric=False): class Expression: - __radd__ = __add__ = binary(4, True) - __sub__ = binary(0x14) - __rmul__ = __mul__ = binary(0x24, True) - __truediv__ = binary(0x34) - __ror__ = __or__ = binary(0x44, True) - __lshift__ = binary(0x64) - __rshift__ = binary(0x74) - __mod__ = binary(0x94) - __rxor__ = __xor__ = binary(0xa4, True) - - __eq__ = comparison(0x15, 0x55) - __gt__ = comparison(0x25, 0xb5, 0x65, 0xd5) - __ge__ = comparison(0x35, 0xa5, 0x75, 0xc5) - __lt__ = comparison(0xa5, 0x35, 0xc5, 0x75) - __le__ = comparison(0xb5, 0x25, 0xd5, 0x65) - __ne__ = comparison(0x55, 0x15) + __radd__ = __add__ = binary(Opcode.ADD, True) + __sub__ = binary(Opcode.SUB) + __rmul__ = __mul__ = binary(Opcode.MUL, True) + __truediv__ = binary(Opcode.DIV) + __ror__ = __or__ = binary(Opcode.OR, True) + __lshift__ = binary(Opcode.LSH) + __rshift__ = binary(Opcode.RSH) + __mod__ = binary(Opcode.MOD) + __rxor__ = __xor__ = binary(Opcode.XOR, True) + + __eq__ = comparison(Opcode.JEQ, Opcode.JNE) + __gt__ = comparison(Opcode.JGT, Opcode.JLE, Opcode.JSGT, Opcode.JSLE) + __ge__ = comparison(Opcode.JGE, Opcode.JLT, Opcode.JSGE, Opcode.JSLT) + __lt__ = comparison(Opcode.JLT, Opcode.JGE, Opcode.JSLT, Opcode.JSGE) + __le__ = comparison(Opcode.JLE, Opcode.JGT, Opcode.JSLE, Opcode.JSGT) + __ne__ = comparison(Opcode.JNE, Opcode.JEQ) def __and__(self, value): return AndExpression(self.ebpf, self, value) @@ -207,16 +296,18 @@ class Binary(Expression): else: free = False dst, long, signed, _ = self.left.calculate(dst, long, signed, True) - if self.operator == 0x74 and signed: # >>= - operator = 0xc4 + if self.operator is Opcode.RSH and signed: # >>= + operator = Opcode.ARSH else: operator = self.operator if isinstance(self.right, int): - self.ebpf.append(operator + (3 if long is None else 3 * long), + self.ebpf.append(operator + (Opcode.LONG if long is None + else Opcode.LONG * long), dst, 0, 0, self.right) else: src, long, _, rfree = self.right.calculate(None, long, None) - self.ebpf.append(operator + 3 * long + 8, dst, src, 0, 0) + self.ebpf.append(operator + Opcode.LONG * long + Opcode.REG, + dst, src, 0, 0) if rfree: self.ebpf.owners.discard(src) return dst, long, signed, free @@ -224,7 +315,7 @@ class Binary(Expression): class Sum(Binary): def __init__(self, ebpf, left, right): - super().__init__(ebpf, left, right, 4) + super().__init__(ebpf, left, right, Opcode.ADD) def __add__(self, value): if isinstance(value, int): @@ -243,9 +334,9 @@ class Sum(Binary): class AndExpression(Binary, SimpleComparison): def __init__(self, ebpf, left, right): - Binary.__init__(self, ebpf, left, right, 0x54) - SimpleComparison.__init__(self, ebpf, left, right, 0x45) - self.opcode = (0x45, None, 0x45, None) + Binary.__init__(self, ebpf, left, right, Opcode.AND) + SimpleComparison.__init__(self, ebpf, left, right, Opcode.JSET) + self.opcode = (Opcode.JSET, None, Opcode.JSET, None) def compare(self, negative): super().compare(False) @@ -261,22 +352,23 @@ class Register(Expression): self.long = long self.signed = signed - __iadd__ = augassign(4) - __isub__ = augassign(0x14) - __imul__ = augassign(0x24) - __itruediv__ = augassign(0x34) - __ior__ = augassign(0x44) - __iand__ = augassign(0x54) - __ilshift__ = augassign(0x64) - __imod__ = augassign(0x94) - __ixor__ = augassign(0xa4) + __iadd__ = augassign(Opcode.ADD) + __isub__ = augassign(Opcode.SUB) + __imul__ = augassign(Opcode.MUL) + __itruediv__ = augassign(Opcode.DIV) + __ior__ = augassign(Opcode.OR) + __iand__ = augassign(Opcode.AND) + __ilshift__ = augassign(Opcode.LSH) + __imod__ = augassign(Opcode.MOD) + __ixor__ = augassign(Opcode.XOR) def __irshift__(self, value): + opcode = Opcode.ARSH if self.signed else Opcode.RSH if isinstance(value, int): - return Instruction(0x74 + 3 * self.long + 0x50 * self.signed, - self.no, 0, 0, value) + return Instruction(opcode + Opcode.LONG * self.long, self.no, + 0, 0, value) elif isinstance(value, Register) and self.long == value.long: - return Instruction(0x7c + 3 * self.long + 0x50 * self.signed, + return Instruction(opcode + Opcode.REG + Opcode.LONG * self.long, self.no, value.no, 0, 0) else: return NotImplemented @@ -303,7 +395,8 @@ class Register(Expression): if self.no not in self.ebpf.owners: raise AssembleError("register has no value") if dst != self.no and force: - self.ebpf.append(0xbc + 3 * self.long, dst, self.no, 0, 0) + self.ebpf.append(Opcode.MOV + Opcode.REG + Opcode.LONG * self.long, + dst, self.no, 0, 0) return dst, self.long, self.signed, False else: return self.no, self.long, self.signed, False @@ -316,7 +409,7 @@ class Memory(Expression): self.address = address def calculate(self, dst, long, signed, force=False): - if not long and self.bits == 0x18: + if not long and self.bits == Opcode.DW: raise AssembleError("cannot compile") if dst is None: dst = self.ebpf.get_free_register() @@ -324,11 +417,11 @@ class Memory(Expression): else: free = False if isinstance(self.address, Sum): - self.ebpf.append(0x61 + self.bits, dst, self.address.left.no, + self.ebpf.append(Opcode.LD + self.bits, dst, self.address.left.no, self.address.right, 0) else: src, _, _, rfree = self.address.calculate(None, None, None) - self.ebpf.append(0x61 + self.bits, dst, src, 0, 0) + self.ebpf.append(Opcode.LD + self.bits, dst, src, 0, 0) if rfree: self.ebpf.owners.discard(src) return dst, long, signed, free @@ -348,10 +441,10 @@ class MemoryDesc: dst, _, _, afree = addr.calculate(None, None, None) offset = 0 if isinstance(value, int): - self.ebpf.append(0x62 + self.bits, dst, 0, offset, value) + self.ebpf.append(Opcode.ST + self.bits, dst, 0, offset, value) else: src, _, _, free = value.calculate(None, None, None) - self.ebpf.append(0x63 + self.bits, dst, src, offset, 0) + self.ebpf.append(Opcode.STX + self.bits, dst, src, offset, 0) if free: self.ebpf.owners.discard(src) if afree: @@ -374,7 +467,7 @@ class PseudoFd(Expression): free = True else: free = False - self.ebpf.append(0x18, dst, 1, 0, self.fd) + self.ebpf.append(Opcode.DW, dst, 1, 0, self.fd) self.ebpf.append(0, 0, 0, 0, 0) return dst, long, signed, free @@ -395,9 +488,10 @@ class RegisterDesc: instance.owners.add(self.no) if isinstance(value, int): if -0x80000000 <= value < 0x80000000: - instance.append(0xb4 + 3 * self.long, self.no, 0, 0, value) + instance.append(Opcode.MOV + Opcode.LONG * self.long, + self.no, 0, 0, value) else: - instance.append(0x18, self.no, 0, 0, value & 0xffffffff) + instance.append(Opcode.DW, self.no, 0, 0, value & 0xffffffff) instance.append(0, 0, 0, 0, value >> 32) elif isinstance(value, Expression): value.calculate(self.no, self.long, self.signed, True) @@ -414,10 +508,10 @@ class EBPF: self.license = license self.kern_version = kern_version - self.m8 = MemoryDesc(self, 0x10) - self.m16 = MemoryDesc(self, 0x8) - self.m32 = MemoryDesc(self, 0) - self.m64 = MemoryDesc(self, 0x18) + self.m8 = MemoryDesc(self, Opcode.B) + self.m16 = MemoryDesc(self, Opcode.H) + self.m32 = MemoryDesc(self, Opcode.W) + self.m64 = MemoryDesc(self, Opcode.DW) self.owners = {1, 10} @@ -439,7 +533,7 @@ class EBPF: return comp def jump(self): - comp = SimpleComparison(self, None, 0, 5) + comp = SimpleComparison(self, None, 0, Opcode.JMP) comp.origin = len(self.opcodes) comp.dst = 0 comp.owners = self.owners.copy() @@ -455,12 +549,12 @@ class EBPF: return PseudoFd(self, fd) def call(self, no): - self.append(0x85, 0, 0, 0, no) + self.append(Opcode.CALL, 0, 0, 0, no) self.owners.add(0) self.owners -= set(range(1, 6)) def exit(self): - self.append(0x95, 0, 0, 0, 0) + self.append(Opcode.EXIT, 0, 0, 0, 0) def get_free_register(self): for i in range(10):