| # Implementation of marshal.loads() in pure Python |
| |
| import ast |
| |
| from typing import Any, Tuple |
| |
| |
class Type:
    """Type codes of the marshal format (adapted from marshal.c).

    Each constant is the integer value of the single character that
    introduces an object of that kind in a marshal stream.
    """

    # Singletons and markers
    NULL = ord("0")
    NONE = ord("N")
    FALSE = ord("F")
    TRUE = ord("T")
    STOPITER = ord("S")
    ELLIPSIS = ord(".")
    UNKNOWN = ord("?")

    # Numbers
    INT = ord("i")
    INT64 = ord("I")
    LONG = ord("l")
    FLOAT = ord("f")
    BINARY_FLOAT = ord("g")
    COMPLEX = ord("x")
    BINARY_COMPLEX = ord("y")

    # Byte strings and text
    STRING = ord("s")
    INTERNED = ord("t")
    UNICODE = ord("u")
    ASCII = ord("a")
    ASCII_INTERNED = ord("A")
    SHORT_ASCII = ord("z")
    SHORT_ASCII_INTERNED = ord("Z")

    # Containers
    TUPLE = ord("(")
    SMALL_TUPLE = ord(")")
    LIST = ord("[")
    DICT = ord("{")
    SET = ord("<")
    FROZENSET = ord(">")

    # Code objects and back-references
    CODE = ord("c")
    REF = ord("r")
| |
| |
# Bit OR-ed into a type byte: the decoded object is also added to the
# reference index.
FLAG_REF = 1 << 7  # 0x80

# Unique marker object returned when a Type.NULL code is read.
NULL = object()

# Cell kinds (bit flags matched against co_localspluskinds entries)
CO_FAST_LOCAL = 1 << 5  # 0x20
CO_FAST_CELL = 1 << 6  # 0x40
CO_FAST_FREE = 1 << 7  # 0x80
| |
| |
class Code:
    """Plain attribute container standing in for an unmarshalled code object.

    Every keyword argument given to the constructor, and every attribute
    assigned afterwards, is stored in the instance ``__dict__``.  The
    ``co_varnames``, ``co_cellvars``, ``co_freevars`` and ``co_nlocals``
    properties are computed from ``co_localsplusnames`` and
    ``co_localspluskinds``, which must be set before they are used.
    """

    # Variable names, paired position-by-position with co_localspluskinds.
    co_localsplusnames: Tuple[str, ...]
    # One integer bit mask per name, tested against the CO_FAST_* flags.
    # NOTE(review): marshal data presumably supplies a bytes object here
    # (iterating it yields ints) -- confirm against the writer.
    co_localspluskinds: Tuple[int, ...]

    def __init__(self, **kwds: Any):
        self.__dict__.update(kwds)

    def __repr__(self) -> str:
        return f"Code(**{self.__dict__})"

    def get_localsplus_names(self, select_kind: int) -> Tuple[str, ...]:
        """Return the names whose kind mask shares a bit with *select_kind*.

        Names are returned in their original order.  Pairs are formed with
        zip(), so extra entries in the longer sequence are ignored.
        """
        return tuple(
            name
            for name, kind in zip(self.co_localsplusnames,
                                  self.co_localspluskinds)
            if kind & select_kind
        )

    @property
    def co_varnames(self) -> Tuple[str, ...]:
        """Names flagged CO_FAST_LOCAL."""
        return self.get_localsplus_names(CO_FAST_LOCAL)

    @property
    def co_cellvars(self) -> Tuple[str, ...]:
        """Names flagged CO_FAST_CELL."""
        return self.get_localsplus_names(CO_FAST_CELL)

    @property
    def co_freevars(self) -> Tuple[str, ...]:
        """Names flagged CO_FAST_FREE."""
        return self.get_localsplus_names(CO_FAST_FREE)

    @property
    def co_nlocals(self) -> int:
        """Number of names flagged CO_FAST_LOCAL."""
        return len(self.co_varnames)
| |
| |
class Reader:
    """Decoder for a marshal byte stream.

    A fairly literal translation of the marshal reader.  The reader walks
    ``data`` from ``pos`` towards ``end``.  Objects whose type byte carries
    FLAG_REF are recorded in ``refs`` so that a later Type.REF entry can
    refer back to them by index.
    """

    def __init__(self, data: bytes):
        self.data: bytes = data
        self.end: int = len(self.data)
        self.pos: int = 0
        self.refs: list[Any] = []
        # Nesting depth of _r_object() calls; restored by r_object().
        self.level: int = 0

    def r_string(self, n: int) -> bytes:
        """Consume and return the next *n* bytes.

        Raises AssertionError if *n* is negative or larger than the number
        of bytes remaining.
        """
        # Explicit raise (not ``assert``) so the bounds check survives -O.
        if not 0 <= n <= self.end - self.pos:
            raise AssertionError(
                f"Cannot read {n} bytes at offset {self.pos} "
                f"(data length {self.end})"
            )
        start = self.pos
        self.pos = start + n
        return self.data[start:self.pos]

    def r_byte(self) -> int:
        """Consume one byte and return it as an int in range(256)."""
        return self.r_string(1)[0]

    def _r_int(self, size: int) -> int:
        """Consume *size* bytes as a signed little-endian integer."""
        return int.from_bytes(self.r_string(size), "little", signed=True)

    def r_short(self) -> int:
        """Consume a signed 16-bit little-endian integer."""
        return self._r_int(2)

    def r_long(self) -> int:
        """Consume a signed 32-bit little-endian integer."""
        return self._r_int(4)

    def r_long64(self) -> int:
        """Consume a signed 64-bit little-endian integer."""
        return self._r_int(8)

    def r_PyLong(self) -> int:
        """Consume an arbitrary-precision integer.

        The encoding is a signed 32-bit digit count (its sign is the sign
        of the result) followed by that many 16-bit values, each
        contributing 15 bits, least significant first.
        """
        n = self.r_long()
        magnitude = 0
        for i in range(abs(n)):
            magnitude |= self.r_short() << (15 * i)
        return -magnitude if n < 0 else magnitude

    def r_float_bin(self) -> float:
        """Consume an 8-byte IEEE 754 double."""
        buf = self.r_string(8)
        import struct  # Lazy import to avoid breaking UNIX build
        # "<d": marshal stores doubles little-endian regardless of the
        # host's native byte order.
        return struct.unpack("<d", buf)[0]

    def r_float_str(self) -> float:
        """Consume a float stored as a length-prefixed ASCII string."""
        n = self.r_byte()
        # float() always yields a float and accepts "inf" / "nan".
        return float(self.r_string(n).decode("ascii"))

    def r_ref_reserve(self, flag: int) -> int:
        """Reserve a slot in ``refs`` for an object still being built.

        Returns the index of the reserved slot (filled with None for now),
        or 0 without reserving anything when *flag* is false.
        """
        if not flag:
            return 0
        idx = len(self.refs)
        self.refs.append(None)
        return idx

    def r_ref_insert(self, obj: Any, idx: int, flag: int) -> Any:
        """Store *obj* in the slot reserved at *idx* if *flag* is set.

        Returns *obj* either way.
        """
        if flag:
            self.refs[idx] = obj
        return obj

    def r_ref(self, obj: Any, flag: int) -> Any:
        """Append *obj* to ``refs`` and return it; *flag* must have FLAG_REF."""
        assert flag & FLAG_REF
        self.refs.append(obj)
        return obj

    def r_object(self) -> Any:
        """Decode and return the next object, restoring ``level`` afterwards."""
        old_level = self.level
        try:
            return self._r_object()
        finally:
            self.level = old_level

    def _r_object(self) -> Any:
        """Decode one object according to its leading type byte.

        Raises AssertionError for an unknown type code, or for a Type.REF
        that points at a reserved slot which has not been filled in yet.
        """
        byte = self.r_byte()
        flag = byte & FLAG_REF
        tag = byte & ~FLAG_REF
        self.level += 1

        def R_REF(obj: Any) -> Any:
            # Record obj in the reference index when the type byte asks
            # for it.
            if flag:
                obj = self.r_ref(obj, flag)
            return obj

        if tag == Type.NULL:
            return NULL
        elif tag == Type.NONE:
            return None
        elif tag == Type.STOPITER:
            return StopIteration
        elif tag == Type.ELLIPSIS:
            return Ellipsis
        elif tag == Type.FALSE:
            return False
        elif tag == Type.TRUE:
            return True
        elif tag == Type.INT:
            return R_REF(self.r_long())
        elif tag == Type.INT64:
            return R_REF(self.r_long64())
        elif tag == Type.LONG:
            return R_REF(self.r_PyLong())
        elif tag == Type.FLOAT:
            return R_REF(self.r_float_str())
        elif tag == Type.BINARY_FLOAT:
            return R_REF(self.r_float_bin())
        elif tag == Type.COMPLEX:
            return R_REF(complex(self.r_float_str(),
                                 self.r_float_str()))
        elif tag == Type.BINARY_COMPLEX:
            return R_REF(complex(self.r_float_bin(),
                                 self.r_float_bin()))
        elif tag == Type.STRING:
            n = self.r_long()
            return R_REF(self.r_string(n))
        elif tag in (Type.ASCII_INTERNED, Type.ASCII):
            n = self.r_long()
            return R_REF(self.r_string(n).decode("ascii"))
        elif tag in (Type.SHORT_ASCII_INTERNED, Type.SHORT_ASCII):
            n = self.r_byte()
            return R_REF(self.r_string(n).decode("ascii"))
        elif tag in (Type.INTERNED, Type.UNICODE):
            n = self.r_long()
            return R_REF(self.r_string(n).decode("utf8", "surrogatepass"))
        elif tag in (Type.SMALL_TUPLE, Type.TUPLE):
            # The two tuple forms differ only in the width of the length.
            n = self.r_byte() if tag == Type.SMALL_TUPLE else self.r_long()
            # Immutable: reserve the ref slot before reading the items,
            # fill it in once the tuple exists.
            idx = self.r_ref_reserve(flag)
            retval: Any = tuple(self.r_object() for _ in range(n))
            self.r_ref_insert(retval, idx, flag)
            return retval
        elif tag == Type.LIST:
            n = self.r_long()
            retval = R_REF([])
            for _ in range(n):
                retval.append(self.r_object())
            return retval
        elif tag == Type.DICT:
            retval = R_REF({})
            while True:
                key = self.r_object()
                # The NULL marker terminates the key/value pairs.
                if key is NULL:
                    break
                retval[key] = self.r_object()
            return retval
        elif tag == Type.SET:
            n = self.r_long()
            retval = R_REF(set())
            for _ in range(n):
                retval.add(self.r_object())
            return retval
        elif tag == Type.FROZENSET:
            n = self.r_long()
            idx = self.r_ref_reserve(flag)
            retval = frozenset(self.r_object() for _ in range(n))
            self.r_ref_insert(retval, idx, flag)
            return retval
        elif tag == Type.CODE:
            # Registered before its fields are read; the fields are then
            # consumed in exactly this order.
            retval = R_REF(Code())
            retval.co_argcount = self.r_long()
            retval.co_posonlyargcount = self.r_long()
            retval.co_kwonlyargcount = self.r_long()
            retval.co_stacksize = self.r_long()
            retval.co_flags = self.r_long()
            retval.co_code = self.r_object()
            retval.co_consts = self.r_object()
            retval.co_names = self.r_object()
            retval.co_localsplusnames = self.r_object()
            retval.co_localspluskinds = self.r_object()
            retval.co_filename = self.r_object()
            retval.co_name = self.r_object()
            retval.co_qualname = self.r_object()
            retval.co_firstlineno = self.r_long()
            retval.co_linetable = self.r_object()
            retval.co_exceptiontable = self.r_object()
            return retval
        elif tag == Type.REF:
            n = self.r_long()
            retval = self.refs[n]
            # None means the slot was reserved but never filled in.
            if retval is None:
                raise AssertionError(f"Reference {n} is not yet complete")
            return retval
        else:
            raise AssertionError(f"Unknown type {tag} {chr(tag)!r}")
| |
| |
def loads(data: bytes) -> Any:
    """Decode the first marshalled object in *data* and return it.

    *data* must be a ``bytes`` instance (checked with an assertion).
    Decoding starts at offset 0 of a fresh Reader.
    """
    assert isinstance(data, bytes)
    return Reader(data).r_object()
| |
| |
def main():
    """Self-test: round-trip sample values through marshal.dumps() and loads()."""
    import marshal
    import pprint

    # A nested container: dict -> set -> tuple of int, str and float.
    expected = {'foo': {(42, "bar", 3.14)}}
    decoded = loads(marshal.dumps(expected))
    assert decoded == expected, decoded

    # A real code object must come back as a Code instance.
    decoded = loads(marshal.dumps(main.__code__))
    assert isinstance(decoded, Code), decoded
    pprint.pprint(decoded.__dict__)


if __name__ == "__main__":
    main()