Source code for pymavryk.michelson.instructions.arithmetic

from typing import Callable
from typing import List
from typing import Tuple
from typing import Type
from typing import Union
from typing import cast

from py_ecc import optimized_bls12_381 as bls12_381

from pymavryk.context.abstract import AbstractContext
from pymavryk.michelson.instructions.base import MichelsonInstruction
from pymavryk.michelson.instructions.base import dispatch_types
from pymavryk.michelson.instructions.base import format_stdout
from pymavryk.michelson.stack import MichelsonStack
from pymavryk.michelson.types import BLS12_381_FrType
from pymavryk.michelson.types import BLS12_381_G1Type
from pymavryk.michelson.types import BLS12_381_G2Type
from pymavryk.michelson.types import BytesType
from pymavryk.michelson.types import IntType
from pymavryk.michelson.types import MumavType
from pymavryk.michelson.types import NatType
from pymavryk.michelson.types import OptionType
from pymavryk.michelson.types import PairType
from pymavryk.michelson.types import TimestampType
from pymavryk.michelson.types.base import MichelsonType


[docs]class AbsInstruction(MichelsonInstruction, prim='ABS'):
[docs] @classmethod def execute(cls, stack: MichelsonStack, stdout: List[str], context: AbstractContext): a = cast(IntType, stack.pop1()) a.assert_type_equal(IntType) res = NatType.from_value(abs(int(a))) stack.push(res) stdout.append(format_stdout(cls.prim, [a], [res])) # type: ignore return cls(stack_items_added=1)
[docs]class AddInstruction(MichelsonInstruction, prim='ADD'):
[docs] @classmethod def execute(cls, stack: MichelsonStack, stdout: List[str], context: AbstractContext): a, b = cast( Tuple[ Union[IntType, NatType, MumavType, TimestampType, BLS12_381_G1Type, BLS12_381_G2Type, BLS12_381_FrType], ..., ], stack.pop2(), ) (res_type,) = dispatch_types( type(a), type(b), mapping={ (NatType, NatType): (NatType,), (NatType, IntType): (IntType,), (IntType, NatType): (IntType,), (IntType, IntType): (IntType,), (TimestampType, IntType): (TimestampType,), (IntType, TimestampType): (TimestampType,), (MumavType, MumavType): (MumavType,), (BLS12_381_FrType, BLS12_381_FrType): (BLS12_381_FrType,), (BLS12_381_G1Type, BLS12_381_G1Type): (BLS12_381_G1Type,), (BLS12_381_G2Type, BLS12_381_G2Type): (BLS12_381_G2Type,), }, ) res_type = cast( Union[ Type[IntType], Type[NatType], Type[TimestampType], Type[MumavType], Type[BLS12_381_G1Type], Type[BLS12_381_G2Type], Type[BLS12_381_FrType], ], res_type, ) if issubclass(res_type, IntType): res = res_type.from_value(int(a) + int(b)) # type: ignore else: res = res_type.from_point(bls12_381.add(a.to_point(), b.to_point())) # type: ignore stack.push(res) stdout.append(format_stdout(cls.prim, [a, b], [res])) # type: ignore return cls(stack_items_added=1)
[docs]class EdivInstruction(MichelsonInstruction, prim='EDIV'):
[docs] @classmethod def execute(cls, stack: MichelsonStack, stdout: List[str], context: AbstractContext): a, b = cast( Tuple[Union[IntType, NatType, MumavType, TimestampType], Union[IntType, NatType, MumavType, TimestampType]], stack.pop2(), ) q_type, r_type = dispatch_types( type(a), type(b), mapping={ # type: ignore (NatType, NatType): (NatType, NatType), (NatType, IntType): (IntType, NatType), (IntType, NatType): (IntType, NatType), (IntType, IntType): (IntType, NatType), (MumavType, NatType): (MumavType, MumavType), (MumavType, MumavType): (NatType, MumavType), }, ) # type: Tuple[Union[Type[IntType], Type[NatType], Type[TimestampType], Type[MumavType]], Union[Type[IntType], Type[NatType], Type[TimestampType], Type[MumavType]]] if int(b) == 0: res = OptionType.none(PairType.create_type(args=[q_type, r_type])) else: q, r = divmod(int(a), int(b)) if r < 0: r += abs(int(b)) q += 1 items: List[MichelsonType] = [q_type.from_value(q), r_type.from_value(r)] res = OptionType.from_some(PairType.from_comb(items)) stack.push(res) stdout.append(format_stdout(cls.prim, [a, b], [res])) # type: ignore return cls(stack_items_added=1)
[docs]def execute_shift(prim: str, stack: MichelsonStack, stdout: List[str], shift: Callable[[Tuple[int, int]], int]): a, b = cast(Tuple[NatType, NatType], stack.pop2()) a.assert_type_equal(NatType) b.assert_type_equal(NatType) assert int(b) < 257, f'shift overflow {int(b)}, should not exceed 256' c = shift((int(a), int(b))) res = NatType.from_value(c) stack.push(res) stdout.append(format_stdout(prim, [a, b], [res]))
[docs]class LslInstruction(MichelsonInstruction, prim='LSL'):
[docs] @classmethod def execute(cls, stack: MichelsonStack, stdout: List[str], context: AbstractContext): execute_shift(cls.prim, stack, stdout, lambda x: x[0] << x[1]) # type: ignore return cls(stack_items_added=1)
[docs]class LsrInstruction(MichelsonInstruction, prim='LSR'):
[docs] @classmethod def execute(cls, stack: MichelsonStack, stdout: List[str], context: AbstractContext): execute_shift(cls.prim, stack, stdout, lambda x: x[0] >> x[1]) # type: ignore return cls(stack_items_added=1)
[docs]class MulInstruction(MichelsonInstruction, prim='MUL'):
[docs] @classmethod def execute(cls, stack: MichelsonStack, stdout: List[str], context: AbstractContext): a, b = cast( Tuple[Union[IntType, NatType, MumavType, BLS12_381_FrType, BLS12_381_G1Type, BLS12_381_G2Type], ...], stack.pop2(), ) (res_type,) = dispatch_types( type(a), type(b), mapping={ (NatType, NatType): (NatType,), (NatType, IntType): (IntType,), (IntType, NatType): (IntType,), (IntType, IntType): (IntType,), (MumavType, NatType): (MumavType,), (NatType, MumavType): (MumavType,), (NatType, BLS12_381_FrType): (BLS12_381_FrType,), (IntType, BLS12_381_FrType): (BLS12_381_FrType,), (BLS12_381_FrType, NatType): (BLS12_381_FrType,), (BLS12_381_FrType, IntType): (BLS12_381_FrType,), (BLS12_381_FrType, BLS12_381_FrType): (BLS12_381_FrType,), (BLS12_381_G1Type, BLS12_381_FrType): (BLS12_381_G1Type,), (BLS12_381_G2Type, BLS12_381_FrType): (BLS12_381_G2Type,), }, ) res_type = cast( Union[ Type[IntType], Type[NatType], Type[TimestampType], Type[MumavType], Type[BLS12_381_FrType], Type[BLS12_381_G1Type], Type[BLS12_381_G2Type], ], res_type, ) if issubclass(res_type, IntType): res = res_type.from_value(int(a) * int(b)) # type: ignore else: res = res_type.from_point(bls12_381.multiply(a.to_point(), int(b))) # type: ignore stack.push(res) stdout.append(format_stdout(cls.prim, [a, b], [res])) # type: ignore return cls(stack_items_added=1)
[docs]class NegInstruction(MichelsonInstruction, prim='NEG'):
[docs] @classmethod def execute(cls, stack: MichelsonStack, stdout: List[str], context: AbstractContext): a = cast(Union[IntType, NatType, BLS12_381_FrType, BLS12_381_G1Type, BLS12_381_G2Type], stack.pop1()) (res_type,) = dispatch_types( type(a), mapping={ (IntType,): (IntType,), (NatType,): (IntType,), (BLS12_381_FrType,): (BLS12_381_FrType,), (BLS12_381_G1Type,): (BLS12_381_G1Type,), (BLS12_381_G2Type,): (BLS12_381_G2Type,), }, ) if issubclass(res_type, IntType): res = IntType.from_value(-int(a)) # type: ignore else: res = res_type.from_point(bls12_381.neg(a.to_point())) # type: ignore stack.push(res) stdout.append(format_stdout(cls.prim, [a], [res])) # type: ignore return cls(stack_items_added=1)
[docs]class SubInstruction(MichelsonInstruction, prim='SUB'):
[docs] @classmethod def execute(cls, stack: MichelsonStack, stdout: List[str], context: AbstractContext): a, b = cast(Tuple[Union[IntType, NatType, MumavType, TimestampType], ...], stack.pop2()) (res_type,) = dispatch_types( type(a), type(b), mapping={ # type: ignore (NatType, NatType): (IntType,), (NatType, IntType): (IntType,), (IntType, NatType): (IntType,), (IntType, IntType): (IntType,), (TimestampType, IntType): (TimestampType,), (TimestampType, TimestampType): (IntType,), (MumavType, MumavType): (MumavType,), }, ) # type: Tuple[Union[Type[IntType], Type[NatType], Type[TimestampType], Type[MumavType]]] res = res_type.from_value(int(a) - int(b)) stack.push(res) stdout.append(format_stdout(cls.prim, [a, b], [res])) # type: ignore return cls(stack_items_added=1)
[docs]class SubMumavInstruction(MichelsonInstruction, prim='SUB_MUMAV'):
[docs] @classmethod def execute(cls, stack: MichelsonStack, stdout: List[str], context: AbstractContext): a, b = cast(Tuple[MumavType, MumavType], stack.pop2()) a.assert_type_equal(MumavType) b.assert_type_equal(MumavType) try: res = OptionType.from_some(MumavType.from_value(int(a) - int(b))) except OverflowError: res = OptionType.none(MumavType) stack.push(res) stdout.append(format_stdout(cls.prim, [a, b], [res])) # type: ignore return cls(stack_items_added=1)
[docs]class IntInstruction(MichelsonInstruction, prim='INT'):
[docs] @classmethod def execute(cls, stack: MichelsonStack, stdout: List[str], context: AbstractContext): a = stack.pop1() if isinstance(a, BytesType): res = IntType.from_value(int.from_bytes(bytes(a), 'big', signed=True)) else: a = cast(Union[NatType, BLS12_381_FrType], a) a.assert_type_in(NatType, BLS12_381_FrType) res = IntType.from_value(int(a)) stack.push(res) stdout.append(f'{cls.prim} / {repr(a)} => {repr(res)}') return cls(stack_items_added=1)
[docs]class IsNatInstruction(MichelsonInstruction, prim='ISNAT'):
[docs] @classmethod def execute(cls, stack: MichelsonStack, stdout: List[str], context: AbstractContext): a = cast(IntType, stack.pop1()) a.assert_type_equal(IntType) if int(a) >= 0: res = OptionType.from_some(NatType.from_value(int(a))) else: res = OptionType.none(NatType) stack.push(res) stdout.append(format_stdout(cls.prim, [a], [res])) # type: ignore return cls(stack_items_added=1)
[docs]class NatInstruction(MichelsonInstruction, prim='NAT'):
[docs] @classmethod def execute(cls, stack: MichelsonStack, stdout: List[str], context: AbstractContext): a = cast(BytesType, stack.pop1()) a.assert_type_in(BytesType) res = NatType.from_value(int.from_bytes(bytes(a), 'big')) stack.push(res) stdout.append(f'{cls.prim} / {repr(a)} => {repr(res)}') return cls(stack_items_added=1)
[docs]class BytesInstruction(MichelsonInstruction, prim='BYTES'):
[docs] @classmethod def execute(cls, stack: MichelsonStack, stdout: List[str], context: AbstractContext): a = cast(Union[NatType, IntType], stack.pop1()) a.assert_type_in(NatType, IntType) int_val = int(a) signed = isinstance(a, IntType) if signed: length = (8 + (int_val + (int_val < 0)).bit_length()) // 8 else: length = (7 + int_val.bit_length()) // 8 # NOTE: the shortest big-endian encoding of natural number or integer n byte_val = int_val.to_bytes(length, 'big', signed=signed).lstrip(b'\x00') res = BytesType.from_value(byte_val) stack.push(res) stdout.append(f'{cls.prim} / {repr(a)} => {repr(res)}') return cls(stack_items_added=1)