Skip to content
Snippets Groups Projects
Commit 783a3b72 authored by Martin Teichmann's avatar Martin Teichmann
Browse files

use Opcode enum for better readability

parent 183a37fb
No related branches found
No related tags found
No related merge requests found
from collections import namedtuple from collections import namedtuple
from struct import pack from struct import pack
from enum import Enum
from .bpf import prog_load from .bpf import prog_load
Instruction = namedtuple("Instruction", Instruction = namedtuple("Instruction",
["opcode", "dst", "src", "off", "imm"]) ["opcode", "dst", "src", "off", "imm"])
class Opcode(Enum):
ADD = 4
SUB = 0x14
MUL = 0x24
DIV = 0x34
OR = 0x44
AND = 0x54
LSH = 0x64
RSH = 0x74
NEG = 0x84
MOD = 0x94
XOR = 0xa4
MOV = 0xb4
ARSH = 0xc4
JMP = 5
JEQ = 0x15
JGT = 0x25
JGE = 0x35
JSET = 0x45
JNE = 0x55
JSGT = 0x65
JSGE = 0x75
JLT = 0xa5
JLE = 0xb5
JSLT = 0xc5
JSLE = 0xd5
CALL = 0x85
EXIT = 0x95
REG = 8
LONG = 3
H = 8
W = 0
B = 0x10
DW = 0x18
LD = 0x61
ST = 0x62
STX = 0x63
def __mul__(self, value):
if value:
return OpcodeFlags({self})
else:
return OpcodeFlags(set())
def __add__(self, value):
return OpcodeFlags({self}) + value
def __repr__(self):
return self.name
def __eq__(self, value):
return self is value or self.value == value
def __hash__(self):
return super().__hash__()
class OpcodeFlags:
def __init__(self, opcodes):
self.opcodes = opcodes
@property
def value(self):
return sum(op.value for op in self.opcodes)
def __add__(self, value):
if isinstance(value, Opcode):
return OpcodeFlags(self.opcodes | {value})
else:
return OpcodeFlags(self.opcodes | value.opcodes)
def __repr__(self):
return "|".join(op.name for op in self.opcodes)
def __eq__(self, value):
if isinstance(value, int):
return self.value == value
else:
self.opcodes == value.opcodes
class AssembleError(Exception): class AssembleError(Exception):
pass pass
def augassign(opcode): def augassign(opcode):
def ret(self, value): def ret(self, value):
if isinstance(value, int): if isinstance(value, int):
return Instruction(opcode + 3 * self.long, self.no, 0, 0, value) return Instruction(opcode + Opcode.LONG * self.long, self.no,
0, 0, value)
elif isinstance(value, Register) and self.long == value.long: elif isinstance(value, Register) and self.long == value.long:
return Instruction(opcode + 8 + 3 * self.long, return Instruction(opcode + Opcode.REG + Opcode.LONG * self.long,
self.no, value.no, 0, 0) self.no, value.no, 0, 0)
else: else:
return NotImplemented return NotImplemented
...@@ -45,13 +132,15 @@ class Comparison: ...@@ -45,13 +132,15 @@ class Comparison:
self.target() self.target()
return return
assert self.ebpf.opcodes[self.else_origin] is None assert self.ebpf.opcodes[self.else_origin] is None
self.ebpf.opcodes[self.else_origin] = Instruction(5, 0, 0, len(self.ebpf.opcodes) - self.else_origin - 1, 0) self.ebpf.opcodes[self.else_origin] = Instruction(
Opcode.JMP, 0, 0,
len(self.ebpf.opcodes) - self.else_origin - 1, 0)
self.ebpf.owners, self.owners = \ self.ebpf.owners, self.owners = \
self.ebpf.owners & self.owners, self.ebpf.owners self.ebpf.owners & self.owners, self.ebpf.owners
if self.invert is not None: if self.invert is not None:
olen = len(self.ebpf.opcodes) olen = len(self.ebpf.opcodes)
assert self.ebpf.opcodes[self.invert].opcode == 5 assert self.ebpf.opcodes[self.invert].opcode == Opcode.JMP
self.ebpf.opcodes[self.invert:self.invert] = \ self.ebpf.opcodes[self.invert:self.invert] = \
self.ebpf.opcodes[self.else_origin+1:] self.ebpf.opcodes[self.else_origin+1:]
del self.ebpf.opcodes[olen-1:] del self.ebpf.opcodes[olen-1:]
...@@ -62,7 +151,7 @@ class Comparison: ...@@ -62,7 +151,7 @@ class Comparison:
def Else(self): def Else(self):
op, dst, src, off, imm = self.ebpf.opcodes[self.origin] op, dst, src, off, imm = self.ebpf.opcodes[self.origin]
if op == 5: if op == Opcode.JMP:
self.invert = self.origin self.invert = self.origin
else: else:
self.ebpf.opcodes[self.origin] = \ self.ebpf.opcodes[self.origin] = \
...@@ -77,7 +166,7 @@ class Comparison: ...@@ -77,7 +166,7 @@ class Comparison:
self.target() self.target()
self.origin = origin self.origin = origin
self.right = self.dst = 0 self.right = self.dst = 0
self.opcode = 5 self.opcode = Opcode.JMP
def __and__(self, value): def __and__(self, value):
return AndOrComparison(self.ebpf, self, value, True) return AndOrComparison(self.ebpf, self, value, True)
...@@ -122,7 +211,7 @@ class SimpleComparison(Comparison): ...@@ -122,7 +211,7 @@ class SimpleComparison(Comparison):
len(self.ebpf.opcodes) - self.origin - 1, self.right) len(self.ebpf.opcodes) - self.origin - 1, self.right)
else: else:
inst = Instruction( inst = Instruction(
self.opcode + 8, self.dst, self.src, self.opcode + Opcode.REG, self.dst, self.src,
len(self.ebpf.opcodes) - self.origin - 1, 0) len(self.ebpf.opcodes) - self.origin - 1, 0)
self.ebpf.opcodes[self.origin] = inst self.ebpf.opcodes[self.origin] = inst
self.ebpf.owners, self.owners = \ self.ebpf.owners, self.owners = \
...@@ -169,22 +258,22 @@ def binary(opcode, symetric=False): ...@@ -169,22 +258,22 @@ def binary(opcode, symetric=False):
class Expression: class Expression:
__radd__ = __add__ = binary(4, True) __radd__ = __add__ = binary(Opcode.ADD, True)
__sub__ = binary(0x14) __sub__ = binary(Opcode.SUB)
__rmul__ = __mul__ = binary(0x24, True) __rmul__ = __mul__ = binary(Opcode.MUL, True)
__truediv__ = binary(0x34) __truediv__ = binary(Opcode.DIV)
__ror__ = __or__ = binary(0x44, True) __ror__ = __or__ = binary(Opcode.OR, True)
__lshift__ = binary(0x64) __lshift__ = binary(Opcode.LSH)
__rshift__ = binary(0x74) __rshift__ = binary(Opcode.RSH)
__mod__ = binary(0x94) __mod__ = binary(Opcode.MOD)
__rxor__ = __xor__ = binary(0xa4, True) __rxor__ = __xor__ = binary(Opcode.XOR, True)
__eq__ = comparison(0x15, 0x55) __eq__ = comparison(Opcode.JEQ, Opcode.JNE)
__gt__ = comparison(0x25, 0xb5, 0x65, 0xd5) __gt__ = comparison(Opcode.JGT, Opcode.JLE, Opcode.JSGT, Opcode.JSLE)
__ge__ = comparison(0x35, 0xa5, 0x75, 0xc5) __ge__ = comparison(Opcode.JGE, Opcode.JLT, Opcode.JSGE, Opcode.JSLT)
__lt__ = comparison(0xa5, 0x35, 0xc5, 0x75) __lt__ = comparison(Opcode.JLT, Opcode.JGE, Opcode.JSLT, Opcode.JSGE)
__le__ = comparison(0xb5, 0x25, 0xd5, 0x65) __le__ = comparison(Opcode.JLE, Opcode.JGT, Opcode.JSLE, Opcode.JSGT)
__ne__ = comparison(0x55, 0x15) __ne__ = comparison(Opcode.JNE, Opcode.JEQ)
def __and__(self, value): def __and__(self, value):
return AndExpression(self.ebpf, self, value) return AndExpression(self.ebpf, self, value)
...@@ -207,16 +296,18 @@ class Binary(Expression): ...@@ -207,16 +296,18 @@ class Binary(Expression):
else: else:
free = False free = False
dst, long, signed, _ = self.left.calculate(dst, long, signed, True) dst, long, signed, _ = self.left.calculate(dst, long, signed, True)
if self.operator == 0x74 and signed: # >>= if self.operator is Opcode.RSH and signed: # >>=
operator = 0xc4 operator = Opcode.ARSH
else: else:
operator = self.operator operator = self.operator
if isinstance(self.right, int): if isinstance(self.right, int):
self.ebpf.append(operator + (3 if long is None else 3 * long), self.ebpf.append(operator + (Opcode.LONG if long is None
else Opcode.LONG * long),
dst, 0, 0, self.right) dst, 0, 0, self.right)
else: else:
src, long, _, rfree = self.right.calculate(None, long, None) src, long, _, rfree = self.right.calculate(None, long, None)
self.ebpf.append(operator + 3 * long + 8, dst, src, 0, 0) self.ebpf.append(operator + Opcode.LONG * long + Opcode.REG,
dst, src, 0, 0)
if rfree: if rfree:
self.ebpf.owners.discard(src) self.ebpf.owners.discard(src)
return dst, long, signed, free return dst, long, signed, free
...@@ -224,7 +315,7 @@ class Binary(Expression): ...@@ -224,7 +315,7 @@ class Binary(Expression):
class Sum(Binary): class Sum(Binary):
def __init__(self, ebpf, left, right): def __init__(self, ebpf, left, right):
super().__init__(ebpf, left, right, 4) super().__init__(ebpf, left, right, Opcode.ADD)
def __add__(self, value): def __add__(self, value):
if isinstance(value, int): if isinstance(value, int):
...@@ -243,9 +334,9 @@ class Sum(Binary): ...@@ -243,9 +334,9 @@ class Sum(Binary):
class AndExpression(Binary, SimpleComparison): class AndExpression(Binary, SimpleComparison):
def __init__(self, ebpf, left, right): def __init__(self, ebpf, left, right):
Binary.__init__(self, ebpf, left, right, 0x54) Binary.__init__(self, ebpf, left, right, Opcode.AND)
SimpleComparison.__init__(self, ebpf, left, right, 0x45) SimpleComparison.__init__(self, ebpf, left, right, Opcode.JSET)
self.opcode = (0x45, None, 0x45, None) self.opcode = (Opcode.JSET, None, Opcode.JSET, None)
def compare(self, negative): def compare(self, negative):
super().compare(False) super().compare(False)
...@@ -261,22 +352,23 @@ class Register(Expression): ...@@ -261,22 +352,23 @@ class Register(Expression):
self.long = long self.long = long
self.signed = signed self.signed = signed
__iadd__ = augassign(4) __iadd__ = augassign(Opcode.ADD)
__isub__ = augassign(0x14) __isub__ = augassign(Opcode.SUB)
__imul__ = augassign(0x24) __imul__ = augassign(Opcode.MUL)
__itruediv__ = augassign(0x34) __itruediv__ = augassign(Opcode.DIV)
__ior__ = augassign(0x44) __ior__ = augassign(Opcode.OR)
__iand__ = augassign(0x54) __iand__ = augassign(Opcode.AND)
__ilshift__ = augassign(0x64) __ilshift__ = augassign(Opcode.LSH)
__imod__ = augassign(0x94) __imod__ = augassign(Opcode.MOD)
__ixor__ = augassign(0xa4) __ixor__ = augassign(Opcode.XOR)
def __irshift__(self, value): def __irshift__(self, value):
opcode = Opcode.ARSH if self.signed else Opcode.RSH
if isinstance(value, int): if isinstance(value, int):
return Instruction(0x74 + 3 * self.long + 0x50 * self.signed, return Instruction(opcode + Opcode.LONG * self.long, self.no,
self.no, 0, 0, value) 0, 0, value)
elif isinstance(value, Register) and self.long == value.long: elif isinstance(value, Register) and self.long == value.long:
return Instruction(0x7c + 3 * self.long + 0x50 * self.signed, return Instruction(opcode + Opcode.REG + Opcode.LONG * self.long,
self.no, value.no, 0, 0) self.no, value.no, 0, 0)
else: else:
return NotImplemented return NotImplemented
...@@ -303,7 +395,8 @@ class Register(Expression): ...@@ -303,7 +395,8 @@ class Register(Expression):
if self.no not in self.ebpf.owners: if self.no not in self.ebpf.owners:
raise AssembleError("register has no value") raise AssembleError("register has no value")
if dst != self.no and force: if dst != self.no and force:
self.ebpf.append(0xbc + 3 * self.long, dst, self.no, 0, 0) self.ebpf.append(Opcode.MOV + Opcode.REG + Opcode.LONG * self.long,
dst, self.no, 0, 0)
return dst, self.long, self.signed, False return dst, self.long, self.signed, False
else: else:
return self.no, self.long, self.signed, False return self.no, self.long, self.signed, False
...@@ -316,7 +409,7 @@ class Memory(Expression): ...@@ -316,7 +409,7 @@ class Memory(Expression):
self.address = address self.address = address
def calculate(self, dst, long, signed, force=False): def calculate(self, dst, long, signed, force=False):
if not long and self.bits == 0x18: if not long and self.bits == Opcode.DW:
raise AssembleError("cannot compile") raise AssembleError("cannot compile")
if dst is None: if dst is None:
dst = self.ebpf.get_free_register() dst = self.ebpf.get_free_register()
...@@ -324,11 +417,11 @@ class Memory(Expression): ...@@ -324,11 +417,11 @@ class Memory(Expression):
else: else:
free = False free = False
if isinstance(self.address, Sum): if isinstance(self.address, Sum):
self.ebpf.append(0x61 + self.bits, dst, self.address.left.no, self.ebpf.append(Opcode.LD + self.bits, dst, self.address.left.no,
self.address.right, 0) self.address.right, 0)
else: else:
src, _, _, rfree = self.address.calculate(None, None, None) src, _, _, rfree = self.address.calculate(None, None, None)
self.ebpf.append(0x61 + self.bits, dst, src, 0, 0) self.ebpf.append(Opcode.LD + self.bits, dst, src, 0, 0)
if rfree: if rfree:
self.ebpf.owners.discard(src) self.ebpf.owners.discard(src)
return dst, long, signed, free return dst, long, signed, free
...@@ -348,10 +441,10 @@ class MemoryDesc: ...@@ -348,10 +441,10 @@ class MemoryDesc:
dst, _, _, afree = addr.calculate(None, None, None) dst, _, _, afree = addr.calculate(None, None, None)
offset = 0 offset = 0
if isinstance(value, int): if isinstance(value, int):
self.ebpf.append(0x62 + self.bits, dst, 0, offset, value) self.ebpf.append(Opcode.ST + self.bits, dst, 0, offset, value)
else: else:
src, _, _, free = value.calculate(None, None, None) src, _, _, free = value.calculate(None, None, None)
self.ebpf.append(0x63 + self.bits, dst, src, offset, 0) self.ebpf.append(Opcode.STX + self.bits, dst, src, offset, 0)
if free: if free:
self.ebpf.owners.discard(src) self.ebpf.owners.discard(src)
if afree: if afree:
...@@ -374,7 +467,7 @@ class PseudoFd(Expression): ...@@ -374,7 +467,7 @@ class PseudoFd(Expression):
free = True free = True
else: else:
free = False free = False
self.ebpf.append(0x18, dst, 1, 0, self.fd) self.ebpf.append(Opcode.DW, dst, 1, 0, self.fd)
self.ebpf.append(0, 0, 0, 0, 0) self.ebpf.append(0, 0, 0, 0, 0)
return dst, long, signed, free return dst, long, signed, free
...@@ -395,9 +488,10 @@ class RegisterDesc: ...@@ -395,9 +488,10 @@ class RegisterDesc:
instance.owners.add(self.no) instance.owners.add(self.no)
if isinstance(value, int): if isinstance(value, int):
if -0x80000000 <= value < 0x80000000: if -0x80000000 <= value < 0x80000000:
instance.append(0xb4 + 3 * self.long, self.no, 0, 0, value) instance.append(Opcode.MOV + Opcode.LONG * self.long,
self.no, 0, 0, value)
else: else:
instance.append(0x18, self.no, 0, 0, value & 0xffffffff) instance.append(Opcode.DW, self.no, 0, 0, value & 0xffffffff)
instance.append(0, 0, 0, 0, value >> 32) instance.append(0, 0, 0, 0, value >> 32)
elif isinstance(value, Expression): elif isinstance(value, Expression):
value.calculate(self.no, self.long, self.signed, True) value.calculate(self.no, self.long, self.signed, True)
...@@ -414,10 +508,10 @@ class EBPF: ...@@ -414,10 +508,10 @@ class EBPF:
self.license = license self.license = license
self.kern_version = kern_version self.kern_version = kern_version
self.m8 = MemoryDesc(self, 0x10) self.m8 = MemoryDesc(self, Opcode.B)
self.m16 = MemoryDesc(self, 0x8) self.m16 = MemoryDesc(self, Opcode.H)
self.m32 = MemoryDesc(self, 0) self.m32 = MemoryDesc(self, Opcode.W)
self.m64 = MemoryDesc(self, 0x18) self.m64 = MemoryDesc(self, Opcode.DW)
self.owners = {1, 10} self.owners = {1, 10}
...@@ -439,7 +533,7 @@ class EBPF: ...@@ -439,7 +533,7 @@ class EBPF:
return comp return comp
def jump(self): def jump(self):
comp = SimpleComparison(self, None, 0, 5) comp = SimpleComparison(self, None, 0, Opcode.JMP)
comp.origin = len(self.opcodes) comp.origin = len(self.opcodes)
comp.dst = 0 comp.dst = 0
comp.owners = self.owners.copy() comp.owners = self.owners.copy()
...@@ -455,12 +549,12 @@ class EBPF: ...@@ -455,12 +549,12 @@ class EBPF:
return PseudoFd(self, fd) return PseudoFd(self, fd)
def call(self, no): def call(self, no):
self.append(0x85, 0, 0, 0, no) self.append(Opcode.CALL, 0, 0, 0, no)
self.owners.add(0) self.owners.add(0)
self.owners -= set(range(1, 6)) self.owners -= set(range(1, 6))
def exit(self): def exit(self):
self.append(0x95, 0, 0, 0, 0) self.append(Opcode.EXIT, 0, 0, 0, 0)
def get_free_register(self): def get_free_register(self):
for i in range(10): for i in range(10):
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment