diff --git a/ebpfcat/arraymap.py b/ebpfcat/arraymap.py index 1cf4ae87b63aa11f408d729978b0ff70ce9eb661..9552eb35d5c2995297f2b02d7354e3c233e64e11 100644 --- a/ebpfcat/arraymap.py +++ b/ebpfcat/arraymap.py @@ -1,7 +1,7 @@ from itertools import chain from struct import pack_into, unpack_from, calcsize -from .ebpf import FuncId, Map, Memory, MemoryDesc, Opcode +from .ebpf import FuncId, Map, MemoryDesc from .bpf import create_map, lookup_elem, MapType, update_elem @@ -22,18 +22,24 @@ class ArrayGlobalVarDesc(MemoryDesc): def __get__(self, instance, owner): if instance is None: return self - fmt, addr = self.fmt_addr(instance) if instance.ebpf.loaded: + fmt, addr = self.fmt_addr(instance) data = instance.ebpf.__dict__[self.map.name].data - return unpack_from(fmt, data, addr)[0] + ret = unpack_from(fmt, data, addr) + if len(ret) == 1: + return ret[0] + else: + return ret else: return super().__get__(instance, owner) def __set__(self, instance, value): - fmt, addr = self.fmt_addr(instance) if instance.ebpf.loaded: + fmt, addr = self.fmt_addr(instance) + if not isinstance(value, tuple): + value = value, pack_into(fmt, instance.ebpf.__dict__[self.map.name].data, - addr, value) + addr, *value) else: super().__set__(instance, value) @@ -85,7 +91,6 @@ class ArrayMap(Map): else: return write_size, position - def __set_name__(self, owner, name): self.name = name @@ -97,7 +102,7 @@ class ArrayMap(Map): fd = create_map(MapType.ARRAY, 4, size, 1) setattr(ebpf, self.name, ArrayMapAccess(fd, write_size, size)) with ebpf.save_registers(list(range(6))), ebpf.get_stack(4) as stack: - ebpf.append(Opcode.ST, 10, 0, stack, 0) + ebpf.mI[ebpf.r10 + stack] = 0 ebpf.r1 = ebpf.get_fd(fd) ebpf.r2 = ebpf.r10 + stack ebpf.call(FuncId.map_lookup_elem) diff --git a/ebpfcat/devices.py b/ebpfcat/devices.py index 047610c5167b9228f06b79100dabe7ccd768f583..55e4288355c1f6f8897babb4783f7ff7b6bf0317 100644 --- a/ebpfcat/devices.py +++ b/ebpfcat/devices.py @@ -184,3 +184,15 @@ class Dummy(Device): def program(self): pass + + +class RandomDropper(Device): + rate = DeviceVar("I", write=True) + + def program(self): + from .xdp import XDPExitCode + with self.ebpf.tmp: + self.ebpf.tmp = ktime(self.ebpf) + self.ebpf.tmp = self.ebpf.tmp * 0xcf019d85 + 1 + with self.ebpf.tmp & 0xffff < self.rate: + self.ebpf.exit(XDPExitCode.DROP) diff --git a/ebpfcat/ebpf.py b/ebpfcat/ebpf.py index 99b91ff101480703172732ab60e7de570e4edb61..090dc42e3c0fc92de8be0c55907ec9f4f3b15f58 100644 --- a/ebpfcat/ebpf.py +++ b/ebpfcat/ebpf.py @@ -1,6 +1,6 @@ from collections import namedtuple from contextlib import contextmanager, ExitStack -from struct import pack, unpack +from struct import pack, unpack, calcsize from enum import Enum from . import bpf @@ -261,7 +261,6 @@ def comparison(uposop, unegop, sposop=None, snegop=None): class Comparison: def __init__(self, ebpf): self.ebpf = ebpf - self.invert = None self.else_origin = None def __enter__(self): @@ -280,36 +279,16 @@ class Comparison: self.ebpf.owners, self.owners = \ self.ebpf.owners & self.owners, self.ebpf.owners - if self.invert is not None: - olen = len(self.ebpf.opcodes) - assert self.ebpf.opcodes[self.invert].opcode == Opcode.JMP - self.ebpf.opcodes[self.invert:self.invert] = \ - self.ebpf.opcodes[self.else_origin+1:] - del self.ebpf.opcodes[olen-1:] - op, dst, src, off, imm = self.ebpf.opcodes[self.invert - 1] - self.ebpf.opcodes[self.invert - 1] = \ - Instruction(op, dst, src, - len(self.ebpf.opcodes) - self.else_origin + 1, imm) + def retarget_one(self): + op, dst, src, off, imm = self.ebpf.opcodes[self.origin] + self.ebpf.opcodes[self.origin] = Instruction(op, dst, src, off+1, imm) def Else(self): - op, dst, src, off, imm = self.ebpf.opcodes[self.origin] - if op == Opcode.JMP: - self.invert = self.origin - else: - self.ebpf.opcodes[self.origin] = \ - Instruction(op, dst, src, off+1, imm) self.else_origin = len(self.ebpf.opcodes) self.ebpf.opcodes.append(None) + self.target(True) return self - def invert_result(self): - origin = len(self.ebpf.opcodes) - self.ebpf.opcodes.append(None) - self.target() - self.origin = origin - self.right = self.dst = 0 - self.opcode = Opcode.JMP - def __and__(self, value): return AndOrComparison(self.ebpf, self, value, True) @@ -333,19 +312,22 @@ class SimpleComparison(Comparison): def compare(self, negative): with self.left.calculate(None, None, None) as (self.dst, _, lsigned): with ExitStack() as exitStack: - if not isinstance(self.right, int): + if isinstance(self.right, int): + rsigned = (self.right < 0) + else: self.src, _, rsigned = exitStack.enter_context( self.right.calculate(None, None, None)) - if rsigned != rsigned: - raise AssembleError("need same signedness for comparison") self.origin = len(self.ebpf.opcodes) self.ebpf.opcodes.append(None) - self.opcode = self.opcode[negative + 2 * lsigned] + self.opcode = self.opcode[negative + 2 * (lsigned or rsigned)] self.owners = self.ebpf.owners.copy() - def target(self): - assert self.ebpf.opcodes[self.origin] is None - if isinstance(self.right, int): + def target(self, retarget=False): + assert retarget or self.ebpf.opcodes[self.origin] is None + if self.opcode == Opcode.JMP: + inst = Instruction(Opcode.JMP, 0, 0, + len(self.ebpf.opcodes) - self.origin - 1, 0) + elif isinstance(self.right, int): inst = Instruction( self.opcode, self.dst, 0, len(self.ebpf.opcodes) - self.origin - 1, self.right) @@ -354,8 +336,9 @@ class SimpleComparison(Comparison): self.opcode + Opcode.REG, self.dst, self.src, len(self.ebpf.opcodes) - self.origin - 1, 0) self.ebpf.opcodes[self.origin] = inst - self.ebpf.owners, self.owners = \ - self.ebpf.owners & self.owners, self.ebpf.owners + if not retarget: + self.ebpf.owners, self.owners = \ + self.ebpf.owners & self.owners, self.ebpf.owners class AndOrComparison(Comparison): @@ -367,19 +350,18 @@ class AndOrComparison(Comparison): self.targetted = False def compare(self, negative): - self.left.compare(self.is_and != negative) - self.right.compare(self.is_and != negative) + self.negative = negative + self.left.compare(self.is_and) + self.right.compare(negative) + self.origin = len(self.ebpf.opcodes) if self.is_and != negative: - self.invert_result() - self.owners = self.ebpf.owners.copy() - - def target(self): - if self.targetted: - super().target() - else: self.left.target() - self.right.target() - self.targetted = True + self.owners = self.ebpf.owners.copy() + + def target(self, retarget=False): + if self.is_and == self.negative: + self.left.target(retarget) + self.right.target(retarget) class InvertComparison(Comparison): @@ -447,8 +429,9 @@ class Expression: @contextmanager def calculate(self, dst, long, signed, force=False): with self.ebpf.get_free_register(dst) as dst: - with self.get_address(dst, long, signed) as (src, bits): - self.ebpf.append(Opcode.LD + bits, dst, src, 0, 0) + with self.get_address(dst, long, signed) as (src, fmt): + self.ebpf.append(Opcode.LD + Memory.fmt_to_opcode[fmt], + dst, src, 0, 0) yield dst, long, self.signed @contextmanager @@ -476,25 +459,31 @@ class Binary(Expression): dst = None with self.ebpf.get_free_register(dst) as dst: with self.left.calculate(dst, long, signed, True) \ - as (dst, long, signed): - pass + as (dst, l_long, l_signed): + if long is None: + long = l_long + signed = signed or l_signed if self.operator is Opcode.RSH and signed: # >>= operator = Opcode.ARSH else: operator = self.operator if isinstance(self.right, int): - self.ebpf.append(operator + (Opcode.LONG if long is None - else Opcode.LONG * long), + r_signed = self.right < 0 + self.ebpf.append(operator + Opcode.LONG * long, dst, 0, 0, self.right) else: - with self.right.calculate(None, long, None) as (src, long, _): - self.ebpf.append(operator + Opcode.LONG*long + Opcode.REG, - dst, src, 0, 0) + with self.right.calculate(None, long, None) as \ + (src, r_long, r_signed): + self.ebpf.append( + operator + Opcode.REG + + Opcode.LONG * ((r_long or l_long) + if long is None else long), + dst, src, 0, 0) if orig_dst is None or orig_dst == dst: - yield dst, long, signed + yield dst, long, signed or r_signed return self.ebpf.append(Opcode.MOV + Opcode.REG + Opcode.LONG * long, orig_dst, dst, 0, 0) - yield orig_dst, long, signed + yield orig_dst, long, signed or r_signed def contains(self, no): return self.left.contains(no) or (not isinstance(self.right, int) @@ -566,11 +555,38 @@ class AndExpression(SimpleComparison, Binary): Binary.__init__(self, ebpf, left, right, Opcode.AND) SimpleComparison.__init__(self, ebpf, left, right, Opcode.JSET) self.opcode = (Opcode.JSET, None, Opcode.JSET, None) + self.invert = None def compare(self, negative): super().compare(False) if negative: - self.invert_result() + origin = len(self.ebpf.opcodes) + self.ebpf.opcodes.append(None) + self.target() + self.origin = origin + self.opcode = Opcode.JMP + + def __exit__(self, exc, etype, tb): + super().__exit__(exc, etype, tb) + if self.invert is not None: + olen = len(self.ebpf.opcodes) + assert self.ebpf.opcodes[self.invert].opcode == Opcode.JMP + self.ebpf.opcodes[self.invert:self.invert] = \ + self.ebpf.opcodes[self.else_origin+1:] + del self.ebpf.opcodes[olen-1:] + op, dst, src, off, imm = self.ebpf.opcodes[self.invert - 1] + self.ebpf.opcodes[self.invert - 1] = \ + Instruction(op, dst, src, + len(self.ebpf.opcodes) - self.else_origin + 1, imm) + + def Else(self): + if self.ebpf.opcodes[self.origin][0] == Opcode.JMP: + self.invert = self.origin + else: + self.retarget_one() + self.else_origin = len(self.ebpf.opcodes) + self.ebpf.opcodes.append(None) + return self class Register(Expression): offset = 0 @@ -621,36 +637,41 @@ class Memory(Expression): bits_to_opcode = {32: Opcode.W, 16: Opcode.H, 8: Opcode.B, 64: Opcode.DW} fmt_to_opcode = {'I': Opcode.W, 'H': Opcode.H, 'B': Opcode.B, 'Q': Opcode.DW, 'i': Opcode.W, 'h': Opcode.H, 'b': Opcode.B, 'q': Opcode.DW} - fmt_to_size = {'I': 4, 'H': 2, 'B': 1, 'Q': 8, - 'i': 4, 'h': 2, 'b': 1, 'q': 8} - def __init__(self, ebpf, bits, address, signed=False): + def __init__(self, ebpf, fmt, address, signed=False, long=False): self.ebpf = ebpf - self.bits = bits + self.fmt = fmt self.address = address self.signed = signed + self.long = long def __iadd__(self, value): - return IAdd(value) + if self.fmt in "qQiI": + return IAdd(value) + else: + return NotImplemented def __isub__(self, value): - return IAdd(-value) + if self.fmt in "qQiI": + return IAdd(-value) + else: + return NotImplemented @contextmanager def calculate(self, dst, long, signed, force=False): if isinstance(self.address, Sum): with self.ebpf.get_free_register(dst) as dst: - self.ebpf.append(Opcode.LD + self.bits, dst, + self.ebpf.append(Opcode.LD + self.fmt_to_opcode[self.fmt], dst, self.address.left.no, self.address.right, 0) - yield dst, long, self.signed + yield dst, self.long, self.signed else: - with super().calculate(dst, long, signed, force) as ret: - yield ret + with super().calculate(dst, long, signed, force) as (dst, _, _): + yield dst, self.long, self.signed @contextmanager def get_address(self, dst, long, signed, force=False): - with self.address.calculate(dst, None, None) as (src, _, _): - yield src, self.bits + with self.address.calculate(dst, True, None) as (src, _, _): + yield src, self.fmt def contains(self, no): return self.address.contains(no) @@ -661,7 +682,7 @@ class MemoryDesc: if instance is None: return self fmt, addr = self.fmt_addr(instance) - return Memory(instance.ebpf, Memory.fmt_to_opcode[fmt], + return Memory(instance.ebpf, fmt, instance.ebpf.r[self.base_register] + addr, fmt.islower()) @@ -695,7 +716,7 @@ class LocalVar(MemoryDesc): self.fmt = fmt def __set_name__(self, owner, name): - size = Memory.fmt_to_size[self.fmt] + size = calcsize(self.fmt) owner.stack -= size owner.stack &= -size self.relative_addr = owner.stack @@ -709,9 +730,11 @@ class LocalVar(MemoryDesc): class MemoryMap: - def __init__(self, ebpf, bits): + def __init__(self, ebpf, fmt, signed=False, long=False): self.ebpf = ebpf - self.bits = bits + self.fmt = fmt + self.long = long + self.signed = signed def __setitem__(self, addr, value): with ExitStack() as exitStack: @@ -720,10 +743,11 @@ class MemoryMap: offset = addr.right else: dst, _, _ = exitStack.enter_context( - addr.calculate(None, None, None)) + addr.calculate(None, True, None)) offset = 0 if isinstance(value, int): - self.ebpf.append(Opcode.ST + self.bits, dst, 0, offset, value) + self.ebpf.append(Opcode.ST + Memory.fmt_to_opcode[self.fmt], + dst, 0, offset, value) return elif isinstance(value, IAdd): value = value.value @@ -731,18 +755,20 @@ class MemoryMap: with self.ebpf.get_free_register(None) as src: self.ebpf.r[src] = value self.ebpf.append( - Opcode.XADD + self.bits, dst, src, offset, 0) + Opcode.XADD + Memory.fmt_to_opcode[self.fmt], + dst, src, offset, 0) return opcode = Opcode.XADD else: opcode = Opcode.STX with value.calculate(None, None, None) as (src, _, _): - self.ebpf.append(opcode + self.bits, dst, src, offset, 0) + self.ebpf.append(opcode + Memory.fmt_to_opcode[self.fmt], + dst, src, offset, 0) def __getitem__(self, addr): if isinstance(addr, Register): addr = addr + 0 - return Memory(self.ebpf, self.bits, addr) + return Memory(self.ebpf, self.fmt, addr, self.signed, self.long) class Map: @@ -872,10 +898,11 @@ class EBPF: self.name = name self.loaded = False - self.mB = MemoryMap(self, Opcode.B) - self.mH = MemoryMap(self, Opcode.H) - self.mI = MemoryMap(self, Opcode.W) - self.mQ = MemoryMap(self, Opcode.DW) + self.mB = MemoryMap(self, "B") + self.mH = MemoryMap(self, "H") + self.mI = MemoryMap(self, "I") + self.mA = MemoryMap(self, "I", False, True) + self.mQ = MemoryMap(self, "Q", False, True) self.r = RegisterArray(self, True, False) self.sr = RegisterArray(self, True, True) diff --git a/ebpfcat/ebpf_test.py b/ebpfcat/ebpf_test.py index 7eabe576a69ebad531a7ad32ea9c9d5103d14ac9..1ad296066f5985521a40bb1e528fe43941b2d421 100644 --- a/ebpfcat/ebpf_test.py +++ b/ebpfcat/ebpf_test.py @@ -183,18 +183,35 @@ class Tests(TestCase): def test_lock_add(self): class Local(EBPF): a = LocalVar('I') + b = LocalVar('q') + c = LocalVar('h') e = Local(ProgType.XDP, "GPL") e.a += 3 e.mI[e.r1] += e.r1 e.a -= 3 + e.b += 3 + e.mQ[e.r1] += e.r1 + + # do not generate XADD for bytes and words + e.c += 3 + e.mB[e.r1] += e.r1 self.assertEqual(e.opcodes, [ Instruction(opcode=O.LONG+O.MOV, dst=0, src=0, off=0, imm=3), Instruction(opcode=O.XADD+O.W, dst=10, src=0, off=-4, imm=0), Instruction(opcode=O.XADD+O.W, dst=1, src=1, off=0, imm=0), Instruction(opcode=O.LONG+O.MOV, dst=0, src=0, off=0, imm=-3), - Instruction(opcode=O.XADD+O.W, dst=10, src=0, off=-4, imm=0)]) + Instruction(opcode=O.XADD+O.W, dst=10, src=0, off=-4, imm=0), + Instruction(opcode=O.LONG+O.MOV, dst=0, src=0, off=0, imm=3), + Instruction(opcode=O.XADD+O.DW, dst=10, src=0, off=-16, imm=0), + Instruction(opcode=O.XADD+O.DW, dst=1, src=1, off=0, imm=0), + Instruction(opcode=O.LD+O.REG, dst=0, src=10, off=-18, imm=0), + Instruction(opcode=O.ADD, dst=0, src=0, off=0, imm=3), + Instruction(opcode=O.STX+O.REG, dst=10, src=0, off=-18, imm=0), + Instruction(opcode=O.B+O.LD, dst=0, src=1, off=0, imm=0), + Instruction(opcode=O.ADD+O.REG, dst=0, src=1, off=0, imm=0), + Instruction(opcode=O.STX+O.B, dst=1, src=0, off=0, imm=0)]) def test_jump(self): @@ -294,13 +311,21 @@ class Tests(TestCase): e.r6 = 7 with e.r2: e.r3 = 2 + with e.r4 > 3 as cond: + e.r5 = 7 + with cond.Else(): + e.r7 = 8 self.assertEqual(e.opcodes, [Instruction(opcode=0xb5, dst=2, src=0, off=2, imm=3), Instruction(opcode=0xb7, dst=2, src=0, off=0, imm=5), Instruction(opcode=0x5, dst=0, src=0, off=1, imm=0), Instruction(opcode=O.MOV+O.LONG, dst=6, src=0, off=0, imm=7), Instruction(opcode=O.JEQ, dst=2, src=0, off=1, imm=0), - Instruction(opcode=O.MOV+O.LONG, dst=3, src=0, off=0, imm=2)]) + Instruction(opcode=O.MOV+O.LONG, dst=3, src=0, off=0, imm=2), + Instruction(opcode=O.JLE, dst=4, src=0, off=2, imm=3), + Instruction(opcode=O.MOV+O.LONG, dst=5, src=0, off=0, imm=7), + Instruction(opcode=O.JMP, dst=0, src=0, off=1, imm=0), + Instruction(opcode=O.MOV+O.LONG, dst=7, src=0, off=0, imm=8)]) def test_with_inversion(self): e = EBPF() @@ -321,6 +346,50 @@ class Tests(TestCase): Instruction(opcode=183, dst=0, src=0, off=0, imm=2), Instruction(opcode=183, dst=1, src=0, off=0, imm=4)]) + def test_with_and(self): + e = EBPF() + e.owners = set(range(11)) + with (e.r2 > 3) & (e.r3 > 2) as cond: + e.r1 = 5 + with (e.r2 > 2) & (e.r1 < 2) as cond: + e.r2 = 5 + with cond.Else(): + e.r3 = 7 + self.maxDiff = None + self.assertEqual(e.opcodes, [ + Instruction(opcode=O.JLE, dst=2, src=0, off=2, imm=3), + Instruction(opcode=O.JLE, dst=3, src=0, off=1, imm=2), + Instruction(opcode=O.MOV+O.LONG, dst=1, src=0, off=0, imm=5), + Instruction(opcode=O.JLE, dst=2, src=0, off=3, imm=2), + Instruction(opcode=O.JGE, dst=1, src=0, off=2, imm=2), + Instruction(opcode=O.MOV+O.LONG, dst=2, src=0, off=0, imm=5), + Instruction(opcode=O.JMP, dst=0, src=0, off=1, imm=0), + Instruction(opcode=O.MOV+O.LONG, dst=3, src=0, off=0, imm=7)]) + + def test_with_or(self): + e = EBPF() + e.owners = set(range(11)) + with (e.r2 > 3) | (e.r3 > 2) as cond: + e.r1 = 5 + with (e.r2 > 2) | (e.r1 > 2) as cond: + e.r2 = 5 + e.r5 = 4 + with cond.Else(): + e.r3 = 7 + e.r4 = 3 + self.maxDiff = None + self.assertEqual(e.opcodes, [ + Instruction(opcode=O.JGT, dst=2, src=0, off=1, imm=3), + Instruction(opcode=O.JLE, dst=3, src=0, off=1, imm=2), + Instruction(opcode=O.MOV+O.LONG, dst=1, src=0, off=0, imm=5), + Instruction(opcode=O.JGT, dst=2, src=0, off=1, imm=2), + Instruction(opcode=O.JLE, dst=1, src=0, off=3, imm=2), + Instruction(opcode=O.MOV+O.LONG, dst=2, src=0, off=0, imm=5), + Instruction(opcode=O.MOV+O.LONG, dst=5, src=0, off=0, imm=4), + Instruction(opcode=O.JMP, dst=0, src=0, off=2, imm=0), + Instruction(opcode=O.MOV+O.LONG, dst=3, src=0, off=0, imm=7), + Instruction(opcode=O.MOV+O.LONG, dst=4, src=0, off=0, imm=3)]) + def test_comp_binary(self): e = EBPF() e.owners = {1, 2, 3, 5} @@ -335,7 +404,7 @@ class Tests(TestCase): self.assertEqual(e.opcodes, [ Instruction(opcode=191, dst=0, src=1, off=0, imm=0), - Instruction(opcode=15, dst=0, src=3, off=0, imm=0), + Instruction(opcode=O.ADD+O.REG+O.LONG, dst=0, src=3, off=0, imm=0), Instruction(opcode=181, dst=0, src=0, off=2, imm=3), Instruction(opcode=183, dst=0, src=0, off=0, imm=5), Instruction(opcode=5, dst=0, src=0, off=1, imm=0), @@ -392,6 +461,34 @@ class Tests(TestCase): Instruction(opcode=191, dst=0, src=1, off=0, imm=0), Instruction(opcode=95, dst=0, src=2, off=0, imm=0)]) + def test_mixed_binary(self): + e = EBPF() + e.owners = {0, 1, 2, 3} + e.w1 = e.r2 + e.w3 + e.r1 = e.w2 + e.w3 + e.w1 = e.w2 + e.w3 + self.assertEqual(e.opcodes, [ + Instruction(opcode=O.MOV+O.LONG+O.REG, dst=1, src=2, off=0, imm=0), + Instruction(opcode=O.REG+O.ADD, dst=1, src=3, off=0, imm=0), + Instruction(opcode=O.MOV+O.REG, dst=1, src=2, off=0, imm=0), + Instruction(opcode=O.LONG+O.REG+O.ADD, dst=1, src=3, off=0, imm=0), + Instruction(opcode=O.MOV+O.REG, dst=1, src=2, off=0, imm=0), + Instruction(opcode=O.REG+O.ADD, dst=1, src=3, off=0, imm=0)]) + + def test_mixed_compare(self): + e = EBPF() + e.owners = {0, 1, 2, 3} + with e.r1 > e.sr2: + pass + with (e.r1 + e.sr2) > 3: + pass + self.assertEqual(e.opcodes, [ + Instruction(opcode=O.JSLE+O.REG, dst=1, src=2, off=0, imm=0), + Instruction(opcode=O.MOV+O.LONG+O.REG, dst=4, src=1, off=0, imm=0), + Instruction(opcode=O.ADD+O.LONG+O.REG, dst=4, src=2, off=0, imm=0), + Instruction(opcode=O.JSLE, dst=4, src=0, off=0, imm=3)]) + + def test_reverse_binary(self): e = EBPF() e.owners = {0, 1, 2, 3} @@ -489,16 +586,16 @@ class Tests(TestCase): Instruction(opcode=39, dst=0, src=0, off=0, imm=2), Instruction(opcode=31, dst=3, src=0, off=0, imm=0), Instruction(opcode=191, dst=0, src=3, off=0, imm=0), - Instruction(opcode=39, dst=0, src=0, off=0, imm=2), + Instruction(opcode=O.MUL+O.LONG, dst=0, src=0, off=0, imm=2), Instruction(opcode=107, dst=10, src=0, off=-10, imm=0), Instruction(opcode=191, dst=0, src=10, off=0, imm=0), Instruction(opcode=15, dst=0, src=3, off=0, imm=0), Instruction(opcode=191, dst=2, src=3, off=0, imm=0), - Instruction(opcode=39, dst=2, src=0, off=0, imm=2), + Instruction(opcode=O.MUL+O.LONG, dst=2, src=0, off=0, imm=2), Instruction(opcode=107, dst=0, src=2, off=0, imm=0), Instruction(opcode=191, dst=5, src=10, off=0, imm=0), - Instruction(opcode=15, dst=5, src=3, off=0, imm=0), + Instruction(opcode=O.ADD+O.REG+O.LONG, dst=5, src=3, off=0, imm=0), Instruction(opcode=105, dst=5, src=5, off=0, imm=0), Instruction(opcode=191, dst=0, src=1, off=0, imm=0), diff --git a/ebpfcat/ebpfcat.py b/ebpfcat/ebpfcat.py index d10ca31ca581a085d5a12c27dbbb8ad1d9c1f7af..bebc790f147f5b41985c64946758c79690d448f0 100644 --- a/ebpfcat/ebpfcat.py +++ b/ebpfcat/ebpfcat.py @@ -4,9 +4,8 @@ from struct import pack, unpack, calcsize, pack_into, unpack_from from time import time from .arraymap import ArrayMap, ArrayGlobalVarDesc from .ethercat import ECCmd, EtherCat, Packet, Terminal -from .ebpf import FuncId, MemoryDesc, SubProgram +from .ebpf import FuncId, MemoryDesc, SubProgram, ktime from .xdp import XDP, XDPExitCode -from .hashmap import HashMap from .bpf import ( ProgType, MapType, create_map, update_elem, prog_test_run, lookup_elem) @@ -207,7 +206,7 @@ class EBPFTerminal(Terminal): (self.vendorId, self.productCode) not in self.compatibility): raise RuntimeError("Incompatible Terminal") - def allocate(self, packet): + def allocate(self, packet, readonly): if self.pdo_in_sz: bases = [packet.size + packet.DATAGRAM_HEADER] packet.append(ECCmd.FPRD, b"\0" * self.pdo_in_sz, 0, @@ -216,44 +215,52 @@ class EBPFTerminal(Terminal): bases = [None] if self.pdo_out_sz: bases.append(packet.size + packet.DATAGRAM_HEADER) - packet.append(ECCmd.FPWR, b"\0" * self.pdo_out_sz, 0, - self.position, self.pdo_out_off) + if readonly: + packet.on_the_fly.append((packet.size, ECCmd.FPWR)) + packet.append(ECCmd.NOP, b"\0" * self.pdo_out_sz, 0, + self.position, self.pdo_out_off) + else: + packet.append(ECCmd.FPWR, b"\0" * self.pdo_out_sz, 0, + self.position, self.pdo_out_off) return bases def update(self, data): pass -class EBPFCat(XDP): - vars = HashMap() - - count = vars.globalVar() - ptype = vars.globalVar() - - def program(self): - #with self.If(self.packet16[12] != 0xA488): - # self.exit(2) - self.count += 1 - #self.ptype = self.packet32[18] - self.exit(2) - - class EtherXDP(XDP): license = "GPL" - variables = HashMap() - count = variables.globalVar() - allcount = variables.globalVar() + variables = ArrayMap() + counters = variables.globalVar("64I") + + rate = 0 def program(self): - e = self - with e.packetSize > 24 as p, p.pH[12] == 0xA488, p.pB[16] == 0: - e.count += 1 - e.r2 = e.get_fd(self.programs) - e.r3 = p.pI[18] - e.call(FuncId.tail_call) - e.allcount += 1 - e.exit(XDPExitCode.PASS) + with self.tmp: + self.ebpf.tmp = ktime(self.ebpf) + self.ebpf.tmp = self.ebpf.tmp * 0xcf019d85 + 1 + with self.ebpf.tmp & 0xffff < self.rate: + self.ebpf.exit(XDPExitCode.DROP) + with self.packetSize > 24 as p, p.pH[12] == 0xA488, p.pB[16] == 0: + self.r3 = p.pI[18] + with self.counters.get_address(None, False, False) as (dst, _), \ + self.r3 < FastEtherCat.MAX_PROGS: + self.mH[self.r[dst] + 4 * self.r3] += 1 + p.pB[17] += 2 + with p.pB[17] & 1 as is_regular: + self.mB[self.r[dst] + 4 * self.r3 + 3] += 1 + self.mB[self.r[dst] + 4 * self.r3 + 2] = 0 + with is_regular.Else(): + self.mB[self.r[dst] + 4 * self.r3 + 2] += 1 + self.mB[self.r[dst] + 4 * self.r3 + 3] = 0 + with self.mB[self.r[dst] + 4 * self.r3 + 2] > 3 as exceed: + p.pB[17] += 1 # turn into regular package + with exceed.Else(): + self.exit(XDPExitCode.TX) + self.r2 = self.get_fd(self.programs) + self.call(FuncId.tail_call) + self.exit(XDPExitCode.PASS) class SimpleEtherCat(EtherCat): @@ -275,20 +282,36 @@ class FastEtherCat(SimpleEtherCat): self.programs = create_map(MapType.PROG_ARRAY, 4, 4, self.MAX_PROGS) self.sync_groups = {} - def register_sync_group(self, sg): + def register_sync_group(self, sg, packet): index = len(self.sync_groups) while index in self.sync_groups: index = (index + 1) % self.MAX_PROGS fd, _ = sg.load(log_level=1) update_elem(self.programs, pack("<I", index), pack("<I", fd), 0) self.sync_groups[index] = sg + sg.assembled = packet.assemble(index) return index + async def watchdog(self): + lastcounts = [0] * 64 + while True: + t0 = time() + self.ebpf.variables.read() + counts = self.ebpf.counters + for i, sg in self.sync_groups.items(): + if ((counts[i] ^ lastcounts[i]) & 0xffff == 0 + or (counts[i] >> 24) > 3): + self.send_packet(sg.assembled) + lastcounts[i] = counts[i] + await sleep(0.001) + async def connect(self): await super().connect() self.ebpf = EtherXDP() self.ebpf.programs = self.programs self.fd = await self.ebpf.attach(self.addr[0]) + ensure_future(self.watchdog()) + class SyncGroupBase: def __init__(self, ec, devices, **kwargs): @@ -304,10 +327,6 @@ class SyncGroupBase: self.terminals = {t: None for t in sorted(terminals, key=lambda t: t.position)} - def allocate(self): - self.packet = Packet() - self.terminals = {t: t.allocate(self.packet) for t in self.terminals} - class SyncGroup(SyncGroupBase): """A group of devices communicating at the same time""" @@ -334,6 +353,11 @@ class SyncGroup(SyncGroupBase): self.asm_packet = self.packet.assemble(self.packet_index) return ensure_future(self.run()) + def allocate(self): + self.packet = Packet() + self.terminals = {t: t.allocate(self.packet, False) + for t in self.terminals} + class FastSyncGroup(SyncGroupBase, XDP): license = "GPL" @@ -347,14 +371,21 @@ class FastSyncGroup(SyncGroupBase, XDP): def program(self): with self.packetSize >= self.packet.size + Packet.ETHERNET_HEADER as p: + for pos, cmd in self.packet.on_the_fly: + p.pB[pos + Packet.ETHERNET_HEADER] = cmd.value for dev in self.devices: dev.program() self.exit(XDPExitCode.TX) def start(self): self.allocate() - index = self.ec.register_sync_group(self) - self.ec.send_packet(self.packet.assemble(index)) + self.ec.register_sync_group(self, self.packet) self.monitor = ensure_future(gather(*[t.to_operational() for t in self.terminals])) return self.monitor + + def allocate(self): + self.packet = Packet() + self.packet.on_the_fly = [] + self.terminals = {t: t.allocate(self.packet, True) + for t in self.terminals} diff --git a/ebpfcat/ethercat.py b/ebpfcat/ethercat.py index 194da3b4c4a0ef5199d57bc1dddacdc16252e9fd..c9d2956c397c6a701d083eda10d67800eb4da8cb 100644 --- a/ebpfcat/ethercat.py +++ b/ebpfcat/ethercat.py @@ -483,100 +483,100 @@ class Terminal: async def mbx_send(self, type, *args, data=None, address=0, priority=0, channel=0): """send data to the mailbox""" - async with self.mbx_lock: - status, = await self.read(0x805, "B") # always using mailbox 0, OK? - if status & 8: - raise RuntimeError("mailbox full, read first") - await gather(self.write(self.mbx_out_off, "HHBB", - datasize(args, data), - address, channel | priority << 6, - type.value | self.mbx_cnt << 4, - *args, data=data), - self.write(self.mbx_out_off + self.mbx_out_sz - 1, - data=1) - ) - self.mbx_cnt = self.mbx_cnt % 7 + 1 # yes, we start at 1 not 0 + status, = await self.read(0x805, "B") # always using mailbox 0, OK? + if status & 8: + raise RuntimeError("mailbox full, read first") + await self.write(self.mbx_out_off, "HHBB", + datasize(args, data), + address, channel | priority << 6, + type.value | self.mbx_cnt << 4, + *args, data=data) + await self.write(self.mbx_out_off + self.mbx_out_sz - 1, + data=1) + self.mbx_cnt = self.mbx_cnt % 7 + 1 # yes, we start at 1 not 0 async def mbx_recv(self): """receive data from the mailbox""" status = 0 - async with self.mbx_lock: - while status & 8 == 0: - # always using mailbox 1, OK? - status, = await self.read(0x80D, "B") - dlen, address, prio, type, data = await self.read( - self.mbx_in_off, "HHBB", data=self.mbx_in_sz - 6) + while status & 8 == 0: + # always using mailbox 1, OK? + status, = await self.read(0x80D, "B") + dlen, address, prio, type, data = await self.read( + self.mbx_in_off, "HHBB", data=self.mbx_in_sz - 6) return MBXType(type & 0xf), data[:dlen] async def coe_request(self, coecmd, odcmd, *args, **kwargs): - await self.mbx_send(MBXType.COE, "HBxH", coecmd.value << 12, - odcmd.value, 0, *args, **kwargs) - fragments = True - ret = [] - offset = 8 # skip header in first packet - - while fragments: - type, data = await self.mbx_recv() - if type is not MBXType.COE: - raise RuntimeError(f"expected CoE package, got {type}") - coecmd, rodcmd, fragments = unpack("<HBxH", data[:6]) - if rodcmd & 0x7f != odcmd.value + 1: - raise RuntimeError(f"expected {odcmd.value}, got {odcmd}") - ret.append(data[offset:]) - offset = 6 - return b"".join(ret) + async with self.mbx_lock: + await self.mbx_send(MBXType.COE, "HBxH", coecmd.value << 12, + odcmd.value, 0, *args, **kwargs) + fragments = True + ret = [] + offset = 8 # skip header in first packet + + while fragments: + type, data = await self.mbx_recv() + if type is not MBXType.COE: + raise RuntimeError(f"expected CoE package, got {type}") + coecmd, rodcmd, fragments = unpack("<HBxH", data[:6]) + if rodcmd & 0x7f != odcmd.value + 1: + raise RuntimeError(f"expected {odcmd.value}, got {odcmd}") + ret.append(data[offset:]) + offset = 6 + return b"".join(ret) async def sdo_read(self, index, subindex=None): - await self.mbx_send( - MBXType.COE, "HBHB4x", CoECmd.SDOREQ.value << 12, - ODCmd.UP_REQ_CA.value if subindex is None - else ODCmd.UP_REQ.value, - index, 1 if subindex is None else subindex) - type, data = await self.mbx_recv() - if type is not MBXType.COE: - raise RuntimeError(f"expected CoE, got {type}") - coecmd, sdocmd, idx, subidx, size = unpack("<HBHBI", data[:10]) - if coecmd >> 12 != CoECmd.SDORES.value: - raise RuntimeError(f"expected CoE SDORES (3), got {coecmd>>12:x}") - if idx != index: - raise RuntimeError(f"requested index {index}, got {idx}") - if sdocmd & 2: - return data[6 : 10 - ((sdocmd>>2) & 3)] - ret = [data[10:]] - retsize = len(ret[0]) - - toggle = 0 - while retsize < size: + async with self.mbx_lock: await self.mbx_send( MBXType.COE, "HBHB4x", CoECmd.SDOREQ.value << 12, - ODCmd.SEG_UP_REQ.value + toggle, index, - 1 if subindex is None else subindex) + ODCmd.UP_REQ_CA.value if subindex is None + else ODCmd.UP_REQ.value, + index, 1 if subindex is None else subindex) type, data = await self.mbx_recv() if type is not MBXType.COE: raise RuntimeError(f"expected CoE, got {type}") - coecmd, sdocmd = unpack("<HB", data[:3]) + coecmd, sdocmd, idx, subidx, size = unpack("<HBHBI", data[:10]) if coecmd >> 12 != CoECmd.SDORES.value: - raise RuntimeError(f"expected CoE cmd SDORES, got {coecmd}") - if sdocmd & 0xe0 != 0: + raise RuntimeError(f"expected CoE SDORES (3), got {coecmd>>12:x}") + if idx != index: raise RuntimeError(f"requested index {index}, got {idx}") - if sdocmd & 1 and len(data) == 7: - data = data[:3 + (sdocmd >> 1) & 7] - ret += data[3:] - retsize += len(data) - 3 - if sdocmd & 1: - break - toggle ^= 0x10 - if retsize != size: - raise RuntimeError(f"expected {size} bytes, got {retsize}") - return b"".join(ret) + if sdocmd & 2: + return data[6 : 10 - ((sdocmd>>2) & 3)] + ret = [data[10:]] + retsize = len(ret[0]) + + toggle = 0 + while retsize < size: + await self.mbx_send( + MBXType.COE, "HBHB4x", CoECmd.SDOREQ.value << 12, + ODCmd.SEG_UP_REQ.value + toggle, index, + 1 if subindex is None else subindex) + type, data = await self.mbx_recv() + if type is not MBXType.COE: + raise RuntimeError(f"expected CoE, got {type}") + coecmd, sdocmd = unpack("<HB", data[:3]) + if coecmd >> 12 != CoECmd.SDORES.value: + raise RuntimeError(f"expected CoE cmd SDORES, got {coecmd}") + if sdocmd & 0xe0 != 0: + raise RuntimeError(f"requested index {index}, got {idx}") + if sdocmd & 1 and len(data) == 7: + data = data[:3 + (sdocmd >> 1) & 7] + ret += data[3:] + retsize += len(data) - 3 + if sdocmd & 1: + break + toggle ^= 0x10 + if retsize != size: + raise RuntimeError(f"expected {size} bytes, got {retsize}") + return b"".join(ret) async def sdo_write(self, data, index, subindex=None): if len(data) <= 4 and subindex is not None: - await self.mbx_send( - MBXType.COE, "HBHB4s", CoECmd.SDOREQ.value << 12, - ODCmd.DOWN_EXP.value | (((4 - len(data)) << 2) & 0xc), - index, subindex, data) - type, data = await self.mbx_recv() + async with self.mbx_lock: + await self.mbx_send( + MBXType.COE, "HBHB4s", CoECmd.SDOREQ.value << 12, + ODCmd.DOWN_EXP.value | (((4 - len(data)) << 2) & 0xc), + index, subindex, data) + type, data = await self.mbx_recv() if type is not MBXType.COE: raise RuntimeError(f"expected CoE, got {type}") coecmd, sdocmd, idx, subidx = unpack("<HBHB", data[:6]) @@ -585,45 +585,46 @@ class Terminal: if coecmd >> 12 != CoECmd.SDORES.value: raise RuntimeError(f"expected CoE SDORES, got {coecmd>>12:x}") else: - stop = min(len(data), self.mbx_out_sz - 16) - await self.mbx_send( - MBXType.COE, "HBHB4x", CoECmd.SDOREQ.value << 12, - ODCmd.DOWN_INIT_CA.value if subindex is None - else ODCmd.DOWN_INIT.value, - index, 1 if subindex is None else subindex, - data=data[:stop]) - type, data = await self.mbx_recv() - if type is not MBXType.COE: - raise RuntimeError(f"expected CoE, got {type}") - coecmd, sdocmd, idx, subidx = unpack("<HBHB", data[:6]) - if coecmd >> 12 != CoECmd.SDORES.value: - raise RuntimeError(f"expected CoE SDORES, got {coecmd>>12:x}") - if idx != index or subindex != subidx: - raise RuntimeError(f"requested index {index}, got {idx}") - toggle = 0 - while stop < len(data): - start = stop - stop = min(len(data), start + self.mbx_out_sz - 9) - if stop == len(data): - if stop - start < 7: - cmd = 1 + (7-stop+start << 1) - d = data[start:stop] + b"\0" * (7 - stop + start) - else: - cmd = 1 - d = data[start:stop] - await self.mbx_send( - MBXType.COE, "HBHB4x", CoECmd.SDOREQ.value << 12, - cmd + toggle, index, - 1 if subindex is None else subindex, data=d) - type, data = await self.mbx_recv() - if type is not MBXType.COE: - raise RuntimeError(f"expected CoE, got {type}") - coecmd, sdocmd, idx, subidx = unpack("<HBHB", data[:6]) - if coecmd >> 12 != CoECmd.SDORES.value: - raise RuntimeError(f"expected CoE SDORES") - if idx != index or subindex != subidx: - raise RuntimeError(f"requested index {index}") - toggle ^= 0x10 + async with self.mbx_lock: + stop = min(len(data), self.mbx_out_sz - 16) + await self.mbx_send( + MBXType.COE, "HBHB4x", CoECmd.SDOREQ.value << 12, + ODCmd.DOWN_INIT_CA.value if subindex is None + else ODCmd.DOWN_INIT.value, + index, 1 if subindex is None else subindex, + data=data[:stop]) + type, data = await self.mbx_recv() + if type is not MBXType.COE: + raise RuntimeError(f"expected CoE, got {type}") + coecmd, sdocmd, idx, subidx = unpack("<HBHB", data[:6]) + if coecmd >> 12 != CoECmd.SDORES.value: + raise RuntimeError(f"expected CoE SDORES, got {coecmd>>12:x}") + if idx != index or subindex != subidx: + raise RuntimeError(f"requested index {index}, got {idx}") + toggle = 0 + while stop < len(data): + start = stop + stop = min(len(data), start + self.mbx_out_sz - 9) + if stop == len(data): + if stop - start < 7: + cmd = 1 + (7-stop+start << 1) + d = data[start:stop] + b"\0" * (7 - stop + start) + else: + cmd = 1 + d = data[start:stop] + await self.mbx_send( + MBXType.COE, "HBHB4x", CoECmd.SDOREQ.value << 12, + cmd + toggle, index, + 1 if subindex is None else subindex, data=d) + type, data = await self.mbx_recv() + if type is not MBXType.COE: + raise RuntimeError(f"expected CoE, got {type}") + coecmd, sdocmd, idx, subidx = unpack("<HBHB", data[:6]) + if coecmd >> 12 != CoECmd.SDORES.value: + raise RuntimeError(f"expected CoE SDORES") + if idx != index or subindex != subidx: + raise RuntimeError(f"requested index {index}") + toggle ^= 0x10 async def read_ODlist(self): idxes = await self.coe_request(CoECmd.SDOINFO, ODCmd.LIST_REQ, "H", 1) diff --git a/ebpfcat/ethercat_test.py b/ebpfcat/ethercat_test.py index c5e885bb72f74e823defbb05ec4379793997e971..43b2e802715e2ae128c09daf09a2c5cc408eaecc 100644 --- a/ebpfcat/ethercat_test.py +++ b/ebpfcat/ethercat_test.py @@ -4,7 +4,8 @@ from unittest import TestCase, main from .devices import AnalogInput, AnalogOutput from .terminals import EL4104, EL3164, EK1814 from .ethercat import ECCmd -from .ebpfcat import FastSyncGroup, SyncGroup, TerminalVar, Device +from .ebpfcat import ( + FastSyncGroup, SyncGroup, TerminalVar, Device, EBPFTerminal, PacketDesc) from .ebpf import Instruction, Opcode as O @@ -27,7 +28,7 @@ class MockEtherCat: await sleep(0) return self.results.pop(0) - def register_sync_group(self, sg): + def register_sync_group(self, sg, packet): self.rsg = sg return 0x33 @@ -187,7 +188,9 @@ class Tests(TestCase): Instruction(opcode=O.W+O.LD, dst=0, src=1, off=4, imm=0), Instruction(opcode=O.W+O.LD, dst=2, src=1, off=0, imm=0), Instruction(opcode=O.ADD+O.LONG, dst=2, src=0, off=0, imm=83), - Instruction(opcode=O.JLE+O.REG, dst=0, src=2, off=21, imm=0), + Instruction(opcode=O.JLE+O.REG, dst=0, src=2, off=23, imm=0), + Instruction(opcode=O.ST+O.B, dst=9, src=0, off=41, imm=5), + Instruction(opcode=O.ST+O.B, dst=9, src=0, off=54, imm=5), Instruction(opcode=O.B+O.LD, dst=0, src=9, off=51, imm=0), Instruction(opcode=O.AND, dst=0, src=0, off=0, imm=-2), diff --git a/ebpfcat/hashmap.py b/ebpfcat/hashmap.py index 3221b6c8b9705f609f9a5c4b5ff17eee03262a24..963afdc149eb265a09f46fc0c81142fdc8a0220f 100644 --- a/ebpfcat/hashmap.py +++ b/ebpfcat/hashmap.py @@ -29,7 +29,7 @@ class HashGlobalVar(Expression): 0, 0, 0) else: dst = 0 - yield dst, Memory.fmt_to_opcode[self.fmt] + yield dst, self.fmt class HashGlobalVarDesc: diff --git a/ebpfcat/terminals.py b/ebpfcat/terminals.py index 48d83ae7a7d756d7de266e517387de6411734afc..c46f4d4e1d6e6ad7f7fa63166e8b293fe2c5323e 100644 --- a/ebpfcat/terminals.py +++ b/ebpfcat/terminals.py @@ -10,6 +10,32 @@ class Skip(EBPFTerminal): pass +class EL1808(EBPFTerminal): + compatibility = {(2, 118501458)} + + ch1 = PacketDesc((0, 0), 0) + ch2 = PacketDesc((0, 0), 1) + ch3 = PacketDesc((0, 0), 2) + ch4 = PacketDesc((0, 0), 3) + ch5 = PacketDesc((0, 0), 4) + ch6 = PacketDesc((0, 0), 5) + ch7 = PacketDesc((0, 0), 6) + ch8 = PacketDesc((0, 0), 7) + + +class EL2808(EBPFTerminal): + compatibility = {(2, 184037458)} + + ch1 = PacketDesc((1, 0), 0) + ch2 = PacketDesc((1, 0), 1) + ch3 = PacketDesc((1, 0), 2) + ch4 = PacketDesc((1, 0), 3) + ch5 = PacketDesc((1, 0), 4) + ch6 = PacketDesc((1, 0), 5) + ch7 = PacketDesc((1, 0), 6) + ch8 = PacketDesc((1, 0), 7) + + class EL4104(EBPFTerminal): ch1_value = PacketDesc((1, 0), 'H') ch2_value = PacketDesc((1, 2), 'H') @@ -40,14 +66,15 @@ class EK1814(EBPFTerminal): class EL5042(EBPFTerminal): + compatibility = {(2, 330444882)} class Channel(Struct): - position = PacketDesc((0, 2), "Q") + position = PacketDesc((0, 2), "q") warning = PacketDesc((0, 0), 0) error = PacketDesc((0, 0), 1) status = PacketDesc((0, 0), "H") - channel1 = Channel(0) - channel2 = Channel(10) + channel1 = Channel(0, None, 0) + channel2 = Channel(10, None, 0x10) class EL6022(EBPFTerminal): @@ -57,20 +84,19 @@ class EL6022(EBPFTerminal): init_accept = PacketDesc((0, 0), 2) status = PacketDesc((0, 0), "H") in_string = PacketDesc((0, 1), "23p") - wkc1 = PacketDesc((0, 24), "H") transmit_request = PacketDesc((1, 0), 0) receive_accept = PacketDesc((1, 0), 1) init_request = PacketDesc((1, 0), 2) control = PacketDesc((1, 0), "H") out_string = PacketDesc((1, 1), "23p") - wkc2 = PacketDesc((0, 24), "H") channel1 = Channel(0, 0) channel2 = Channel(24, 24) class EL7041(EBPFTerminal): + compatibility = {(2, 461451346)} velocity = PacketDesc((1, 6), "h") enable = PacketDesc((1, 4), 0) status = PacketDesc((0, 6), "H") diff --git a/ebpfcat/xdp.py b/ebpfcat/xdp.py index 73c745e215a74a3af43ba16d305d8bb71a8a97ea..522b376330bd237ef9ce4457b8944dfe6d798b1a 100644 --- a/ebpfcat/xdp.py +++ b/ebpfcat/xdp.py @@ -5,7 +5,7 @@ from socket import AF_NETLINK, NETLINK_ROUTE, if_nametoindex import socket from struct import pack, unpack -from .ebpf import EBPF, Expression, Memory, Opcode, Comparison +from .ebpf import EBPF from .bpf import ProgType @@ -75,11 +75,11 @@ class PacketArray: self.no = no self.memory = memory - def __getitem__(self, value): - return self.memory[self.ebpf.r[self.no] + value] + def __getitem__(self, pos): + return self.memory[self.ebpf.r[self.no] + pos] - def __setitem__(self, value): - self.memory[self.ebpf.r[self.no]] = value + def __setitem__(self, pos, value): + self.memory[self.ebpf.r[self.no] + pos] = value class Packet: @@ -104,15 +104,15 @@ class PacketSize: @contextmanager def __lt__(self, value): e = self.ebpf - e.r9 = e.mI[e.r1] - with e.mI[e.r1 + 4] < e.mI[e.r1] + value as comp: + e.r9 = e.mA[e.r1] + with e.mA[e.r1 + 4] < e.mA[e.r1] + value as comp: yield Packet(e, comp, 9) @contextmanager def __gt__(self, value): e = self.ebpf - e.r9 = e.mI[e.r1] - with e.mI[e.r1 + 4] > e.mI[e.r1] + value as comp: + e.r9 = e.mA[e.r1] + with e.mA[e.r1 + 4] > e.mA[e.r1] + value as comp: yield Packet(e, comp, 9) def __le__(self, value):