Skip to content
Snippets Groups Projects
Commit 08c86558 authored by Martin Teichmann's avatar Martin Teichmann
Browse files

move bit access to ebpf proper

this is not actually ethercat specific
parent e17ec174
No related branches found
No related tags found
No related merge requests found
...@@ -266,10 +266,7 @@ class AssembleError(Exception): ...@@ -266,10 +266,7 @@ class AssembleError(Exception):
pass pass
def comparison(uposop, unegop, sposop=None, snegop=None): def comparison(uposop, unegop, sposop, snegop):
if sposop is None:
sposop = uposop
snegop = unegop
def ret(self, value): def ret(self, value):
return SimpleComparison(self.ebpf, self, value, return SimpleComparison(self.ebpf, self, value,
(uposop, unegop, sposop, snegop)) (uposop, unegop, sposop, snegop))
...@@ -385,7 +382,6 @@ class AndOrComparison(Comparison): ...@@ -385,7 +382,6 @@ class AndOrComparison(Comparison):
self.left = left self.left = left
self.right = right self.right = right
self.is_and = is_and self.is_and = is_and
self.targetted = False
def compare(self, negative): def compare(self, negative):
self.negative = negative self.negative = negative
...@@ -404,11 +400,15 @@ class AndOrComparison(Comparison): ...@@ -404,11 +400,15 @@ class AndOrComparison(Comparison):
class InvertComparison(Comparison): class InvertComparison(Comparison):
def __init__(self, ebpf, value): def __init__(self, ebpf, value):
self.ebpf = ebpf super().__init__(ebpf)
self.value = value self.value = value
def compare(self, negative): def compare(self, negative):
self.value.compare(not negative) self.value.compare(not negative)
self.owners = self.value.owners
def target(self, retarget=False):
self.value.target(retarget)
def binary(opcode): def binary(opcode):
...@@ -439,16 +439,22 @@ class Expression: ...@@ -439,16 +439,22 @@ class Expression:
__rmod__ = rbinary(Opcode.MOD) __rmod__ = rbinary(Opcode.MOD)
__rxor__ = __xor__ = binary(Opcode.XOR) __rxor__ = __xor__ = binary(Opcode.XOR)
__eq__ = comparison(Opcode.JEQ, Opcode.JNE)
__gt__ = comparison(Opcode.JGT, Opcode.JLE, Opcode.JSGT, Opcode.JSLE) __gt__ = comparison(Opcode.JGT, Opcode.JLE, Opcode.JSGT, Opcode.JSLE)
__ge__ = comparison(Opcode.JGE, Opcode.JLT, Opcode.JSGE, Opcode.JSLT) __ge__ = comparison(Opcode.JGE, Opcode.JLT, Opcode.JSGE, Opcode.JSLT)
__lt__ = comparison(Opcode.JLT, Opcode.JGE, Opcode.JSLT, Opcode.JSGE) __lt__ = comparison(Opcode.JLT, Opcode.JGE, Opcode.JSLT, Opcode.JSGE)
__le__ = comparison(Opcode.JLE, Opcode.JGT, Opcode.JSLE, Opcode.JSGT) __le__ = comparison(Opcode.JLE, Opcode.JGT, Opcode.JSLE, Opcode.JSGT)
__ne__ = comparison(Opcode.JNE, Opcode.JEQ)
def __and__(self, value): def __and__(self, value):
return AndExpression(self.ebpf, self, value) return AndExpression(self.ebpf, self, value)
def __ne__(self, value):
return SimpleComparison(
self.ebpf, self, value,
(Opcode.JNE, Opcode.JEQ, Opcode.JNE, Opcode.JEQ))
def __eq__(self, value):
return ~(self != value)
__rand__ = __and__ __rand__ = __and__
def __neg__(self): def __neg__(self):
...@@ -632,8 +638,20 @@ class Sum(Binary): ...@@ -632,8 +638,20 @@ class Sum(Binary):
return super().__sub__(value) return super().__sub__(value)
class AndExpression(SimpleComparison, Binary): class AndExpression(Binary):
"""The & operator may also be used as a comparison""" # there is a special comparison with & instruction
def __init__(self, ebpf, left, right):
super().__init__(ebpf, left, right, Opcode.AND)
def __ne__(self, value):
if isinstance(value, int) and value == 0:
return AndComparison(self.ebpf, self.left, self.right)
return super().__ne__(value)
class AndComparison(SimpleComparison):
# there is a special comparison with & instruction
# it is the only one which has not inversion
def __init__(self, ebpf, left, right): def __init__(self, ebpf, left, right):
Binary.__init__(self, ebpf, left, right, Opcode.AND) Binary.__init__(self, ebpf, left, right, Opcode.AND)
SimpleComparison.__init__(self, ebpf, left, right, Opcode.JSET) SimpleComparison.__init__(self, ebpf, left, right, Opcode.JSET)
...@@ -743,13 +761,21 @@ class Memory(Expression): ...@@ -743,13 +761,21 @@ class Memory(Expression):
@contextmanager @contextmanager
def calculate(self, dst, long, signed, force=False): def calculate(self, dst, long, signed, force=False):
if isinstance(self.address, Sum): with ExitStack() as exitStack:
with self.ebpf.get_free_register(dst) as dst: if isinstance(self.address, Sum):
self.ebpf.append(Opcode.LD + self.fmt_to_opcode[self.fmt], dst, dst = exitStack.enter_context(self.ebpf.get_free_register(dst))
self.address.left.no, self.address.right, 0) self.ebpf.append(
yield dst, self.fmt in "QqA", self.fmt.islower() Opcode.LD + self.fmt_to_opcode.get(self.fmt, Opcode.B),
else: dst, self.address.left.no, self.address.right, 0)
with super().calculate(dst, long, signed, force) as (dst, _, _): else:
dst, _, _ = exitStack.enter_context(
super().calculate(dst, long, signed, force))
if isinstance(self.fmt, tuple):
self.ebpf.r[dst] &= ((1 << self.fmt[1]) - 1) << self.fmt[0]
if self.fmt[0] > 0:
self.ebpf.r[dst] >>= self.fmt[0]
yield dst, "B", False
else:
yield dst, self.fmt in "QqA", self.fmt.islower() yield dst, self.fmt in "QqA", self.fmt.islower()
@contextmanager @contextmanager
...@@ -764,6 +790,18 @@ class Memory(Expression): ...@@ -764,6 +790,18 @@ class Memory(Expression):
def signed(self): def signed(self):
return isinstance(self.fmt, str) and self.fmt.islower() return isinstance(self.fmt, str) and self.fmt.islower()
def __invert__(self):
if not isinstance(self.fmt, tuple) or self.fmt[1] != 1:
return NotImplemented
return self == 0
def __ne__(self, value):
if isinstance(self.fmt, tuple) and isinstance(value, int) \
and value == 0:
mask = ((1 << self.fmt[1]) - 1) << self.fmt[0]
return Memory(self.ebpf, "B", self.address) & mask != 0
return super().__ne__(value)
class MemoryDesc: class MemoryDesc:
"""A base class used by descriptors for memory """A base class used by descriptors for memory
...@@ -782,8 +820,26 @@ class MemoryDesc: ...@@ -782,8 +820,26 @@ class MemoryDesc:
def __set__(self, instance, value): def __set__(self, instance, value):
ebpf = instance.ebpf ebpf = instance.ebpf
fmt, addr = self.fmt_addr(instance) fmt, addr = self.fmt_addr(instance)
bits = Memory.fmt_to_opcode[fmt] bits = Memory.fmt_to_opcode.get(fmt, Opcode.B)
if isinstance(value, int): if isinstance(fmt, tuple):
before = Memory(ebpf, "B", ebpf.r[self.base_register] + addr)
if fmt[1] == 1:
try:
if value:
value = before | (1 << fmt[0])
else:
value = before & ~(1 << fmt[0])
except AssembleError:
with ebpf.wtmp:
with value as cond:
ebpf.wtmp = before | (1 << fmt[0])
with cond.Else():
ebpf.wtmp = before & ~(1 << fmt[0])
else:
mask = ((1 << fmt[1]) - 1) << fmt[0]
value = (mask & (value << self.fmt[0]) | ~mask & before)
opcode = Opcode.STX
elif isinstance(value, int):
ebpf.append(Opcode.ST + bits, self.base_register, 0, ebpf.append(Opcode.ST + bits, self.base_register, 0,
addr, value) addr, value)
return return
...@@ -798,7 +854,9 @@ class MemoryDesc: ...@@ -798,7 +854,9 @@ class MemoryDesc:
opcode = Opcode.XADD opcode = Opcode.XADD
else: else:
opcode = Opcode.STX opcode = Opcode.STX
with value.calculate(None, fmt in 'qQ', fmt.islower()) as (src, _, _): with value.calculate(None, isinstance(fmt, str) and fmt in 'qQ',
isinstance(fmt, str) and fmt.islower()
) as (src, _, _):
ebpf.append(opcode + bits, self.base_register, src, addr, 0) ebpf.append(opcode + bits, self.base_register, src, addr, 0)
...@@ -810,7 +868,10 @@ class LocalVar(MemoryDesc): ...@@ -810,7 +868,10 @@ class LocalVar(MemoryDesc):
self.fmt = fmt self.fmt = fmt
def __set_name__(self, owner, name): def __set_name__(self, owner, name):
size = calcsize(self.fmt) if isinstance(self.fmt, str):
size = calcsize(self.fmt)
else: # this is to support bit addressing, mostly for testing
size = 1
owner.stack -= size owner.stack -= size
owner.stack &= -size owner.stack &= -size
self.relative_addr = owner.stack self.relative_addr = owner.stack
...@@ -1064,6 +1125,8 @@ class EBPF: ...@@ -1064,6 +1125,8 @@ class EBPF:
def jumpIf(self, comp): def jumpIf(self, comp):
"""jump if `comp` is true to a later defined `target`""" """jump if `comp` is true to a later defined `target`"""
if isinstance(comp, Expression):
comp = comp != 0
comp.compare(False) comp.compare(False)
return comp return comp
......
...@@ -174,6 +174,53 @@ class Tests(TestCase): ...@@ -174,6 +174,53 @@ class Tests(TestCase):
Instruction(opcode=O.REG+O.STX, dst=10, src=0, off=-4, imm=0), Instruction(opcode=O.REG+O.STX, dst=10, src=0, off=-4, imm=0),
Instruction(opcode=O.DW+O.STX, dst=10, src=1, off=-16, imm=0)]) Instruction(opcode=O.DW+O.STX, dst=10, src=1, off=-16, imm=0)])
def test_local_bits(self):
class Local(EBPF):
a = LocalVar((5, 1))
b = LocalVar((3, 4))
e = Local(ProgType.XDP, "GPL")
with e.a:
e.a = 1
e.b = e.a
with ~e.a:
e.b = 3
with e.b:
e.a = 0
self.assertEqual(e.opcodes, [
Instruction(opcode=O.LD+O.B, dst=0, src=10, off=-1, imm=0),
Instruction(opcode=O.JSET, dst=0, src=0, off=1, imm=32),
Instruction(opcode=O.JMP, dst=0, src=0, off=3, imm=0),
Instruction(opcode=O.LD+O.B, dst=0, src=10, off=-1, imm=0),
Instruction(opcode=O.OR, dst=0, src=0, off=0, imm=32),
Instruction(opcode=O.B+O.STX, dst=10, src=0, off=-1, imm=0),
Instruction(opcode=O.LD+O.B, dst=0, src=10, off=-1, imm=0),
Instruction(opcode=O.AND+O.LONG, dst=0, src=0, off=0, imm=32),
Instruction(opcode=O.RSH+O.LONG, dst=0, src=0, off=0, imm=5),
Instruction(opcode=O.LSH, dst=0, src=0, off=0, imm=3),
Instruction(opcode=O.AND, dst=0, src=0, off=0, imm=120),
Instruction(opcode=O.LD+O.B, dst=2, src=10, off=-2, imm=0),
Instruction(opcode=O.AND, dst=2, src=0, off=0, imm=-121),
Instruction(opcode=O.REG+O.OR, dst=0, src=2, off=0, imm=0),
Instruction(opcode=O.B+O.STX, dst=10, src=0, off=-2, imm=0),
Instruction(opcode=O.LD+O.B, dst=0, src=10, off=-1, imm=0),
Instruction(opcode=O.JSET, dst=0, src=0, off=4, imm=32),
Instruction(opcode=O.LD+O.B, dst=0, src=10, off=-2, imm=0),
Instruction(opcode=O.AND, dst=0, src=0, off=0, imm=-121),
Instruction(opcode=O.OR, dst=0, src=0, off=0, imm=24),
Instruction(opcode=O.B+O.STX, dst=10, src=0, off=-2, imm=0),
Instruction(opcode=O.LD+O.B, dst=0, src=10, off=-2, imm=0),
Instruction(opcode=O.JSET, dst=0, src=0, off=1, imm=120),
Instruction(opcode=O.JMP, dst=0, src=0, off=3, imm=0),
Instruction(opcode=O.LD+O.B, dst=0, src=10, off=-1, imm=0),
Instruction(opcode=O.AND, dst=0, src=0, off=0, imm=-33),
Instruction(opcode=O.B+O.STX, dst=10, src=0, off=-1, imm=0)])
def test_local_subprog(self): def test_local_subprog(self):
class Local(EBPF): class Local(EBPF):
a = LocalVar('I') a = LocalVar('I')
......
...@@ -93,25 +93,6 @@ class PacketVar(MemoryDesc): ...@@ -93,25 +93,6 @@ class PacketVar(MemoryDesc):
def set(self, device, value): def set(self, device, value):
if device.sync_group.current_data is None: if device.sync_group.current_data is None:
if isinstance(self.size, int):
try:
bool(value)
except RuntimeError:
e = device.sync_group
with e.wtmp:
e.wtmp = super().__get__(device, None)
with value as cond:
e.wtmp |= 1 << self.size
with cond.Else():
e.wtmp &= ~(1 << self.size)
super().__set__(device, e.wtmp)
return
else:
old = super().__get__(device, None)
if value:
value = old | (1 << self.size)
else:
value = old & ~(1 << self.size)
super().__set__(device, value) super().__set__(device, value)
else: else:
data = device.sync_group.current_data data = device.sync_group.current_data
...@@ -126,10 +107,7 @@ class PacketVar(MemoryDesc): ...@@ -126,10 +107,7 @@ class PacketVar(MemoryDesc):
def get(self, device): def get(self, device):
if device.sync_group.current_data is None: if device.sync_group.current_data is None:
if isinstance(self.size, int): return super().__get__(device, None)
return super().__get__(device, None) & (1 << self.size)
else:
return super().__get__(device, None)
else: else:
data = device.sync_group.current_data data = device.sync_group.current_data
start = self._start(device) start = self._start(device)
...@@ -143,7 +121,7 @@ class PacketVar(MemoryDesc): ...@@ -143,7 +121,7 @@ class PacketVar(MemoryDesc):
+ self.position + self.position
def fmt_addr(self, device): def fmt_addr(self, device):
return ("B" if isinstance(self.size, int) else self.size, return ((self.size, 1) if isinstance(self.size, int) else self.size,
self._start(device) + Packet.ETHERNET_HEADER) self._start(device) + Packet.ETHERNET_HEADER)
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment