#!/usr/bin/env python3
from __future__ import annotations
# Masin RPG (Fullbirth) save-code encoder/decoder in Python 3
# -----------------------------------------------------------

import math, sys
import numpy
numpy.seterr(all='ignore')
import argparse
from dataclasses import dataclass
from typing import List, Tuple, Iterable, Optional
from pyjassprng import JASSPrng


# --- Alphabet ---------------------------------------------------------------
# Default code alphabet (91 symbols). MasinCodec uses len(alphabet) as the
# radix of the emitted code, and a character's position is its digit value,
# so the order of characters here is significant (note digits run 0,9,8,...,1).
DEFAULT_ALPHABET = (
    "abcdefghijklmnopqrstuvwxyz"
    "ABCDEFGHIJKLMNOPQRSTUVWXYZ"
    "0987654321"
    "_!@#<>?/$%^&*`'()[]{}~=+:;,.|"
)


# Bits carried per code character, indexed by alphabet size: entry k is
# log2(k) rounded to about 8 decimal places (entries 0 and 1 are 0). Because of
# the rounding, R2I(table[k] * n) can differ from floor(log2(k) * n) near
# integer boundaries. Presumably copied from the map's JASS table so that
# sizing matches the game exactly — TODO confirm against the map script.
# _per_char_bits() falls back to math.log2 for sizes past the end of the table.
CODE_LEN_TBL = (0, 0, 1.0, 1.5849625, 2.0, 2.32192809, 2.5849625, 2.80735492, 3.0, 3.169925, 3.32192809, 3.45943162, 3.5849625, 3.70043972, 3.80735492, 3.9068906, 4.0, 4.08746284, 4.169925, 4.24792751, 4.32192809, 4.39231742, 4.45943162, 4.52356196, 4.5849625, 4.64385619, 4.70043972, 4.7548875, 4.80735492, 4.857981, 4.9068906, 4.95419631, 5.0, 5.04439412, 5.08746284, 5.12928302, 5.169925, 5.20945337, 5.24792751, 5.28540222, 5.32192809, 5.357552, 5.39231742, 5.42626475, 5.45943162, 5.4918531, 5.52356196, 5.55458885, 5.5849625, 5.61470984, 5.64385619, 5.67242534, 5.70043972, 5.72792045, 5.7548875, 5.78135971, 5.80735492, 5.83289001, 5.857981, 5.88264305, 5.9068906, 5.93073734, 5.95419631, 5.97727992, 6.0, 6.02236781, 6.04439412, 6.06608919, 6.08746284, 6.10852446, 6.12928302, 6.14974712, 6.169925, 6.18982456, 6.20945337, 6.22881869, 6.24792751, 6.26678654, 6.28540222, 6.30378075, 6.32192809, 6.33985, 6.357552, 6.37503943, 6.39231742, 6.40939094, 6.42626475, 6.4429435, 6.45943162, 6.47573343, 6.4918531, 6.50779464, 6.52356196, 6.53915881, 6.55458885, 6.56985561, 6.5849625, 6.59991284, 6.61470984, 6.62935662, 6.64385619, 6.65821148, 6.67242534, 6.68650053, 6.70043972, 6.71424552, 6.72792045, 6.74146699, 6.7548875, 6.76818432, 6.78135971, 6.79441587, 6.80735492, 6.82017896, 6.83289001, 6.84549005, 6.857981, 6.87036472, 6.88264305, 6.89481776, 6.9068906, 6.91886324, 6.93073734, 6.94251451, 6.95419631, 6.96578429, 6.97727992, 6.98868469, 7.0, 7.01122726, 7.02236781, 7.033423, 7.04439412, 7.05528244, 7.06608919, 7.0768156, 7.08746284, 7.09803208)

def _per_char_bits(base_len: int) -> float:
    """Return how many bits one code character carries for an alphabet of ``base_len`` symbols.

    The rounded CODE_LEN_TBL value is preferred so that R2I-style truncations
    (``int(per_char * n)``) agree with the table the decoder uses; sizes past
    the end of the table (or non-integer sizes) fall back to ``math.log2``.

    Raises:
        ValueError: for a negative ``base_len`` (math domain error). Previously a
            negative size silently indexed the table from its end.
    """
    if base_len >= 0:
        try:
            return CODE_LEN_TBL[base_len]
        except (IndexError, TypeError):
            # Beyond the table, or a non-int size such as 91.0
            pass
    return math.log2(base_len)

# --- Bit utilities ----------------------------------------------------------

def int_to_bits(n: int, length: int) -> List[int]:
    """Return the lowest ``length`` bits of ``n``, least-significant bit first.

    High positions that ``n`` does not reach are filled with 0; bits of ``n``
    above ``length`` are discarded.
    """
    remaining = int(n)
    out: List[int] = []
    for _ in range(int(length)):
        out.append(remaining & 1)
        remaining >>= 1
    return out

def bits_to_int(bits: Iterable[int]) -> int:
    """Rebuild a Python int from an LSB-first bit sequence.

    Only the lowest bit of each element is considered. The result is a plain
    (arbitrary-precision) Python int, never a fixed-width numpy integer.
    """
    return sum(1 << position for position, bit in enumerate(bits) if bit & 1)

def normalize_display_name(raw: Optional[str], members_disp: Optional[List[str]] = None, gmax_member_id: Optional[int] = None) -> Optional[str]:
    """Canonicalize a player name.

    Drops a Battle.net '#tag' suffix (everything from the last '#'). If
    ``members_disp`` is given, the first entry among indices
    ``0..gmax_member_id`` (or the whole list when ``gmax_member_id`` is None)
    that equals the name case-insensitively is returned in its stored
    spelling; otherwise the stripped name is returned. ``None`` stays ``None``.
    """
    if raw is None:
        return None

    head, sep, _tag = raw.rpartition('#')
    name = head if sep else raw

    if not members_disp:
        return name

    cutoff = len(members_disp) if gmax_member_id is None else gmax_member_id + 1
    wanted = name.lower()
    found = next(
        (cand for cand in members_disp[:cutoff] if cand is not None and cand.lower() == wanted),
        None,
    )
    return name if found is None else found

def calc_player_name_checksum(name: Optional[str], members_disp: Optional[List[str]] = None, gmax_member_id: Optional[int] = None) -> numpy.int32:
    """Compute the 32-bit checksum of a player name.

    MasinCodec seeds its bit-shuffling PRNG with this value, so a code only
    decodes correctly under the same (normalized) name.

    The name is normalized first (Battle.net '#tag' stripped, spelling matched
    case-insensitively against ``members_disp``). The loop runs in numpy.int32,
    so products wrap on overflow; overflow warnings are silenced by the
    module-level ``numpy.seterr(all='ignore')``. Presumably this reproduces
    32-bit JASS integer arithmetic — TODO confirm against the map script.

    Args:
        name: raw player name, or None.
        members_disp: optional canonical display names (see normalize_display_name).
        gmax_member_id: optional highest member index considered.

    Returns:
        numpy.int32: 12345 when ``name`` is None, otherwise ``numpy.abs`` of the
        wrapped accumulator (with int32 wrap-around, -2**31 would stay negative).
    """
    name = normalize_display_name(name, members_disp, gmax_member_id)
    if name is None:
        return numpy.int32(12345)

    # A character's value is its index here. Note the digit order
    # (1..9, 0) differs from DEFAULT_ALPHABET's.
    PNAME_ALPHA = "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ1234567890_()"
    chksum = numpy.int32(0)
    for ch in name + '\0':  # include JASS null terminator
        try:
            char_val = PNAME_ALPHA.index(ch)
        except ValueError:
            # Characters missing from the table (including '\0') map to len(PNAME_ALPHA).
            char_val = len(PNAME_ALPHA)
        # Fold by 26: upper-case letters get the same value as lower-case ones.
        if char_val >= 26:
            char_val -= 26
        char_val = numpy.int32(char_val)
        # Every operand is int32, so each multiply/add wraps at 32 bits;
        # numpy.mod with a positive divisor yields a value in [0, 50).
        chksum = (chksum + numpy.int32(0xBCECE9)) * numpy.int32(0x63A105F) \
                 * numpy.mod(char_val * char_val * chksum + numpy.int32(123), numpy.int32(50))
    return numpy.abs(chksum)


@dataclass
class DecoderState:
    """Decoder-side bit-array state, mirroring the map's JASS codec globals.

    ``bits`` and ``rbits`` are LSB-first lists of 0/1 whose logical lengths are
    ``bit_len`` and ``rbit_len``. Positions outside ``[0, length)`` read as 0,
    like unset JASS array slots, even if the Python list still holds an entry
    there. Values are extracted with :meth:`pop`.
    """
    # Core bit arrays (LSB-first), mirroring gCodecBitArray__Mk and gCodecKeyBitArray__Mt
    bits: List[int]
    bit_len: int
    rbits: List[int]
    rbit_len: int

    # (metadata, optional)
    alphabet: str
    ml: int

    # --- JASS helpers on bit arrays ---

    @staticmethod
    def _jass_get(arr: List[int], length: int, idx: int) -> int:
        """Return ``arr[idx]``, or 0 when ``idx`` lies outside ``[0, length)``."""
        if idx < 0 or idx >= length:
            return 0
        return arr[idx]

    @staticmethod
    def _jass_set(arr: List[int], length: int, idx: int, val: int) -> int:
        """Write one bit into ``arr`` (logical length ``length``); return the new length.

        - A negative ``idx`` is ignored, matching ``_jass_get`` treating it as out of range.
        - Writing 0 at or past the end is a no-op.
        - Writing 1 past the end grows the array to ``idx + 1``. The gap is
          zeroed explicitly because the list may hold stale entries beyond
          ``length`` (e.g. after ``_mv_rbits_reverse_to_bits`` shrinks it).
        - Writing 0 at the top bit trims all trailing zero bits.
        """
        v = 1 if val else 0
        if idx < 0:
            return length
        if idx >= length:
            if v == 0:
                return length  # writing 0 past end is a no-op in JASS
            if len(arr) <= idx:
                arr.extend([0] * (idx + 1 - len(arr)))
            for k in range(length, idx):
                arr[k] = 0
            arr[idx] = 1
            return idx + 1
        arr[idx] = v
        if v == 0 and idx == length - 1:
            # Top bit cleared: shrink past every trailing zero. The list shrinks
            # in step, so the prefix [0, length) keeps its indices.
            while length > 0 and arr[length - 1] == 0:
                length -= 1
                arr.pop()
        return length

    def _get_bit(self, idx: int) -> int:
        return self._jass_get(self.bits, self.bit_len, idx)

    def _set_bit(self, idx: int, val: int) -> None:
        self.bit_len = self._jass_set(self.bits, self.bit_len, idx, val)

    def _get_rbit(self, idx: int) -> int:
        return self._jass_get(self.rbits, self.rbit_len, idx)

    def _set_rbit(self, idx: int, val: int) -> None:
        self.rbit_len = self._jass_set(self.rbits, self.rbit_len, idx, val)

    @staticmethod
    def _bit_length(n: int) -> int:
        """Number of binary digits in ``|n|`` (0 for 0).

        For n >= 0 this equals the old halve-until-zero loop. That loop used
        floor division, which never reaches 0 for negative n (-1 // 2 == -1)
        and so hung; counting the magnitude's digits terminates instead.
        """
        return abs(int(n)).bit_length()

    def _read_bits_be(self, lo: int, hi: int) -> int:
        """Read bits in [lo, hi) as a big-endian number (BDd): bit hi-1 is most significant."""
        out = 0
        for i in range(hi - lo):
            out = out * 2
            out = out + self._get_bit(hi - i - 1)
        return out

    def _insert_value_between(self, value: int, start: int, end_excl: int) -> None:
        """CodecInsertValueBetween__BDg: write ``value`` LSB-first across [start, end_excl)."""
        v = int(value)
        for i in range(end_excl - start):
            self._set_bit(start + i, v % 2)
            v = v // 2

    def _mv_rbits_reverse_to_bits(self) -> None:
        """CodecMvKeyBitsReverseToBits__BDY: bits[i] = rbits[rbit_len-1-i], then clear rbits."""
        # ensure capacity
        if len(self.bits) < self.rbit_len:
            self.bits += [0] * (self.rbit_len - len(self.bits))
        for i in range(self.rbit_len):
            self.bits[i] = self.rbits[self.rbit_len - 1 - i]
        self.bit_len = self.rbit_len
        self.rbit_len = 0
        self.rbits.clear()

    # --- Public API: JASS-accurate pop ---

    def pop(self, radix: int) -> int:
        """Divide the bit array by ``radix`` in place and return the remainder.

        Faithful to CodecPopValue_inner__BDm: a bit-serial long division whose
        quotient bits are collected MSB-first in ``rbits`` and then moved
        (reversed) back into ``bits``.

        Raises:
            ValueError: if ``radix`` is not positive.
        """
        Ln = int(radix)
        if Ln <= 0:
            raise ValueError("radix must be positive")

        i = 0
        bit_arr_len_initial = self.bit_len
        Ma = self._bit_length(Ln - 1)

        # Slide a Ma-wide window from the end
        while bit_arr_len_initial >= i + Ma:
            start = bit_arr_len_initial - i - Ma
            Mq    = self._read_bits_be(start, self.bit_len)

            # --- ensure rbits capacity up to index i (JASS arrays are implicitly 0) ---
            while len(self.rbits) <= i:
                self.rbits.append(0)   # fill with 0s

            if Mq >= Ln:
                # write Mq - Ln back into [start, current bit_len)
                self._insert_value_between(Mq - Ln, start, self.bit_len)
                self.rbits[i] = 1
            else:
                self.rbits[i] = 0

            i += 1

        # JASS sets length explicitly after the loop
        self.rbit_len = i

        # Output is the whole bitarray as big-endian int
        out = self._read_bits_be(0, self.bit_len)

        # Move reversed key bits to bits (CodecMvKeyBitsReverseToBits__BDY)
        self._mv_rbits_reverse_to_bits()

        return out

    def pop_sequence(self, radices: Iterable[int]) -> List[int]:
        """Pop once per radix, in order, and return the popped values."""
        return [self.pop(r) for r in radices]


# --- Core codec -------------------------------------------------------------

class MasinCodec:
    """Masin RPG save-code codec.

    encode() packs (value, radix) pairs into one mixed-radix integer. The last
    pair ends up least significant, so it is the first one popped. encode()
    then scrambles the bits with ``ml`` passes of a player-name-seeded shuffle
    and bit twiddle, and writes the result in base ``len(alphabet)``.
    decode() undoes the last two steps and returns a DecoderState to pop the
    pairs from.
    """

    def __init__(self, *, alphabet: str = DEFAULT_ALPHABET, ml: int = 3, min_extra_bits: int = 9,
                 members_disp: Optional[List[str]] = None, gmax_member_id: Optional[int] = None):
        """
        Args:
            alphabet: code symbols; ``len(alphabet)`` is the code radix.
            ml: number of shuffle + twiddle mixing passes.
            min_extra_bits: slack bits added to the payload size when choosing
                how many code characters to emit.
            members_disp, gmax_member_id: forwarded to
                calc_player_name_checksum for player-name normalization.
        """
        # Use Python ints for sizes/bases; keep numpy for checksum math where desired.
        self.alphabet = alphabet
        self.base = int(len(alphabet))
        self.ml = int(ml)
        self.min_extra_bits = int(min_extra_bits)
        self.members_disp = members_disp
        self.gmax_member_id = gmax_member_id

    @staticmethod
    def push_acc(n: int, value: int, radix: int) -> int:
        """Return ``n * radix + value``; raise ValueError unless 0 <= value < radix."""
        if value < 0 or value >= radix:
            raise ValueError(f"value {value} not in [0,{radix})")
        return n * radix + value

    def encode(self, values: Iterable[Tuple[int, int]], player_name: Optional[str]) -> str:
        """Pack ``values`` into a save code bound to ``player_name``.

        Args:
            values: (value, radix) pairs in push order; the last pair is the
                first one a decoder pops.
            player_name: name whose checksum seeds the bit mixing.

        Returns:
            The code string (most significant digit first).

        Raises:
            ValueError: if a value lies outside [0, radix).
        """
        n = 0
        for val, radix in values:
            n = self.push_acc(int(n), int(val), int(radix))

        bits_needed = int(n).bit_length()
        # Use the same per-character width as decode() (lookup table, log2
        # fallback). Both sides must agree on out_bit_len, because the
        # mix/unmix passes run over exactly that many bits.
        per_char_bits = _per_char_bits(self.base)
        Mf = math.ceil((bits_needed + self.min_extra_bits) / per_char_bits)
        out_bit_len = int(per_char_bits * Mf)  # truncate like R2I

        bits = int_to_bits(n, out_bit_len)
        bits = self._mix_bits_encode(bits, player_name)
        n_mixed = bits_to_int(bits)

        code_chars = []
        for _ in range(Mf):
            n_mixed, digit = divmod(n_mixed, self.base)
            code_chars.append(self.alphabet[digit])
        return "".join(reversed(code_chars))

    def _decode_init_bits(self, code: str) -> tuple[list[int], int, int]:
        """Parse ``code`` into its LSB-first bit array.

        Mirrors the map's decode-init: for every character, multiply the
        accumulator by |alphabet| ("roll"), then add the digit ("append").
        Python's arbitrary-precision int gives the same canonical bit array
        (no trailing zero bits) as the JASS bit-by-bit arithmetic.

        Returns:
            (bits, bit_len, Mf): the LSB-first bits, their count, and the number
            of characters consumed ('-' separators are skipped).

        Raises:
            ValueError: for a character found neither in the alphabet nor in a-z.
        """
        lowercase = "abcdefghijklmnopqrstuvwxyz"
        n = 0
        Mf = 0
        for ch in code:
            if ch == '-':
                continue
            idx = self.alphabet.find(ch)
            if idx == -1:
                idx = lowercase.find(ch)
                if idx == -1:
                    raise ValueError(f"Character '{ch}' not in alphabet")
            n = n * self.base + idx   # roll by |alphabet|, then add the digit
            Mf += 1

        bit_len = n.bit_length()
        bits = [(n >> i) & 1 for i in range(bit_len)]
        return bits, bit_len, Mf

    # --- decode (returns DecoderState holding the bitarray) ------------------
    def decode(self, code: str, player_name: Optional[str]) -> DecoderState:
        """Turn ``code`` back into a DecoderState for ``player_name``.

        The bit array is zero-padded to R2I(per-char bits * character count)
        and unscrambled with the name's checksum. The returned state's bit_len
        is that padded length.

        Raises:
            ValueError: if ``code`` contains an unknown character (see
                _decode_init_bits).
        """
        # 1) decode-init (roll + append) to build the bitarray
        bits, bit_len, Mf = self._decode_init_bits(code)

        # 2) R2I( table[base] * Mf )
        per_char = _per_char_bits(self.base)
        out_bit_len = int(per_char * Mf)  # truncate like R2I

        # 3) pad zeros up to out_bit_len (JASS CodecDecode__BDt)
        while bit_len < out_bit_len:
            bits.append(0)
            bit_len += 1

        # 4) unscramble (reverse twiddle + keyed shuffle undo), in-place
        bits = self._unmix_bits_decode(bits, player_name)

        # 5) return a DecoderState with the bitarray
        return DecoderState(bits=list(bits), bit_len=out_bit_len, rbits=[], rbit_len=0,
                            alphabet=self.alphabet, ml=self.ml)

    # --- internal mixing ----------------------------------------
    def _mix_bits_encode(self, bits: List[int], player_name: Optional[str]) -> List[int]:
        """Scramble ``bits`` in place (and return it) with ``ml`` keyed passes.

        Each pass:
        1. Shuffles the bits: swap i with JASSPrng(seed).GetRandomInt(i, L-1)
           for i = 0..L-2.
        2. Inverts every bit, also flipping the next bit (wrapping around)
           whenever the inverted bit was 0.
        The next pass's seed is the first GetRandomInt(0, 1_000_000) of a fresh
        PRNG seeded with the current one. The chain starts from the player-name
        checksum.
        """
        L = len(bits)
        if L == 0:
            return bits
        seed = int(calc_player_name_checksum(player_name, self.members_disp, self.gmax_member_id))
        for _ in range(self.ml):
            rng = JASSPrng(seed)
            for i in range(0, L - 1):
                j = rng.GetRandomInt(i, L - 1)
                bits[i], bits[j] = bits[j], bits[i]
            for j in range(L):
                if bits[j] == 0:
                    bits[j] = 1
                    k = (j + 1) % L
                    bits[k] = 1 - bits[k]
                else:
                    bits[j] = 0
            rng = JASSPrng(seed)
            seed = rng.GetRandomInt(0, 1_000_000)
        return bits

    def _unmix_bits_decode(self, bits: List[int], player_name: Optional[str]) -> List[int]:
        """Inverse of _mix_bits_encode; modifies ``bits`` in place and returns it.

        Rebuilds the same seed chain, then undoes the passes last-to-first.
        Each undo step reverses the twiddle (scanning right to left), then
        replays that pass's swap indices backwards.
        """
        L = len(bits)
        if L == 0:
            return bits

        # Build [s0, s1, ..., s{Ml-1}]
        seeds: List[int] = []
        seed = int(calc_player_name_checksum(player_name, self.members_disp, self.gmax_member_id))
        for _ in range(self.ml):
            seeds.append(seed)
            rng_tmp = JASSPrng(seed)
            seed = rng_tmp.GetRandomInt(0, 1_000_000)

        # Undo from the last pass back to the first
        for pass_idx in range(self.ml):
            # reverse twiddle
            for j in range(L - 1, -1, -1):
                if bits[j] == 1:
                    bits[j] = 0
                    k = (j + 1) % L
                    if bits[k] == 1:
                        bits[k] = 0
                    else:
                        bits[k] = 1
                else:
                    bits[j] = 1

            # undo shuffle with the seed used on that pass during encode
            use_seed = seeds[self.ml - 1 - pass_idx]
            rng = JASSPrng(use_seed)
            idxs = [rng.GetRandomInt(i, L - 1) for i in range(0, L - 1)]
            for i in range(L - 2, -1, -1):
                Mo = bits[idxs[i]]
                bits[idxs[i]] = bits[i]
                bits[i] = Mo
        return bits

    @staticmethod
    def trailing_zero_bits_from_msb(st: DecoderState) -> int:
        """Count consecutive 0 bits in ``st`` starting from its top bit (index bit_len - 1)."""
        tz = 0
        while tz < st.bit_len and st._get_bit(st.bit_len - 1 - tz) == 0:
            tz += 1
        return tz
# /CLASS
 
# === Helpers to build field pieces ===
def _last_digit(n: int) -> int:
    return abs(int(n)) % 10

def _split_stat(val: int) -> tuple[int, int, int]:
    """
    Split a stat value into (Fd, MD, Ll) where:
      stat = Fd*10000 + MD*100 + Ll   with 0<=Fd<=300, 0<=MD<100, 0<=Ll<100
    Clamp at 3_000_000 like the map does.
    """
    if val < 0:
        val = 0
    if val > 3_000_000:
        val = 3_000_000
    Fd = val // 10000
    rem = val - Fd * 10000
    MD = rem // 100
    Ll = rem - MD * 100
    # enforce radix ranges
    Fd = min(Fd, 300)
    MD %= 100
    Ll %= 100
    return Fd, MD, Ll

def _stat_checksum_from_lls(lls: list[int]) -> int:
    """
    Map’s checksum over the three Ll digits:
      if Ll >= 10, add Ll//10, else add Ll.
    Range: 0..27 (fits radix 28).
    """
    n9 = 0
    for ll in lls:
        n9 += (ll // 10) if ll >= 10 else ll
    return n9

def _gold_to_code(gold: int) -> int:
    if gold < 0:
        raise ValueError("gold must be non-negative")

    # 100k tier
    if gold % 100_000 == 0:
        code = gold // 100_000 + 10
        if 11 <= code <= 110:
            return int(code)

    # 20k tier
    if gold % 20_000 == 0:
        steps = gold // 20_000
        code = steps + 6
        if 7 <= code <= 10:
            return int(code)

    # 2k tier
    if gold % 2_000 == 0:
        steps = gold // 2_000
        code = steps + 1
        if 2 <= code <= 6:
            return int(code)

    # 1k tier
    if gold % 1_000 == 0:
        code = gold // 1_000
        if 0 <= code <= 1:
            return int(code)

    raise ValueError("gold must be a multiple of 1k/2k/20k/100k within the map limits")
    #raise ValueError(
    #    "gold not representable by the map's scheme. "
    #    "Allowed values: {0,1000, 2000..10000 step 2000, 20000..80000 step 20000, 100000..10000000 step 100000}."
    #)

def _level_to_pieces(level: int) -> tuple[int, int]:
    """
    Split level into (hi, lo) with hi in [0..20], lo in [0..99], since radix(hi)=21.
    """
    if level < 0:
        level = 0
    hi = level // 100
    lo = level % 100
    if hi > 20:
        raise ValueError("level too large for (hi radix 21, lo radix 100)")
    return hi, lo

def _lvl_ck(desc, level):
    """Level checksum: the sum of the last decimal digits of the level's
    (hi, lo) pieces from _level_to_pieces and of ``desc``.

    Raises ValueError (via _level_to_pieces) for levels of 2100 or more.
    """
    pieces = _level_to_pieces(level)
    return sum(_last_digit(part) for part in (*pieces, desc))

def _bed_pop_block(st, Ly: int, Gd_init: int) -> list[int]:
    """Decode 1 block of length Ly with BED constraints."""
    L5 = 0
    Gd = Gd_init
    out = []
    for Gf in range(Ly):
        v = st.pop(Gd - L5) + L5
        out.append(v)
        if (Gf % 2) == 0:
            L5 = v
        else:
            Gd = v + 1
    return out

def _bek_canonical_order(block: list[int]) -> list[int]:
    """Return [min, max, next_min, next_max, ...] for a 6-slot block."""
    a = sorted(int(v) for v in block)
    out = []
    low, high = 0, len(a) - 1
    for k in range(len(a)):
        if k % 2 == 0:        # k = 0,2,4 -> take from low and advance
            out.append(a[low])
            low += 1
        else:                  # k = 1,3,5 -> take from high and retreat
            out.append(a[high])
            high -= 1
    return out


def _bek_push_block(block: list[int], Gd_init: int) -> list[tuple[int,int]]:
    """
    Encode one block so that _bed_pop_block can read it back (inverse of BED).

    The block is first put into canonical order (see _bek_canonical_order).
    The result is the list of (relative value, relative radix) pushes, in push
    order: reversed relative to pop order, since the last push is popped first.

    Raises:
        ValueError: if an item lies outside [0, Gd_init) or breaks the bounds.
    """
    ordered = _bek_canonical_order(block)
    lower, upper = 0, Gd_init
    pushes: list[tuple[int, int]] = []
    for position, item in enumerate(ordered):
        if item < 0 or item >= Gd_init:
            raise ValueError(f"item {item} out of range [0,{Gd_init-1}]")
        radix = upper - lower
        relative = item - lower
        if relative < 0 or relative >= radix:
            raise ValueError(
                f"item {item} violates BED bounds at idx {position}: "
                f"must be in [{lower},{upper-1}]"
            )
        pushes.append((relative, radix))
        if position % 2 == 0:
            lower = item
        else:
            upper = item + 1
    pushes.reverse()
    return pushes

def dump_items(
    st,
    *,
    item_radix: int = 650,
    num_stash_pages: int = 6,
    grims_radix: int = 71,
):
    """Pop the item/grim section of a character code, in loader order.

    Order: 4 grim picks (LP, LO, LN, LM), then the stash pages from the last
    page down to page 0, then the bag block, then the 6 hero slots from slot 5
    down to 0. ``st`` only needs a ``pop(radix)`` method.

    Returns a dict with keys "grims", "stash_pages", "bag_items", "hero_items".
    """
    # Skill picks come first: LP, LO, LN, LM
    grims = tuple(st.pop(grims_radix) for _ in range(4))

    # Stash pages are read last page first; store each at its own index
    stash_pages = [None] * num_stash_pages
    for page in reversed(range(num_stash_pages)):
        stash_pages[page] = _bed_pop_block(st, 6, item_radix)

    bag_items = _bed_pop_block(st, 6, item_radix)

    # Hero slots are plain pops, slot 5 first
    hero_items = [0] * 6
    for slot in reversed(range(6)):
        hero_items[slot] = st.pop(item_radix)

    return {
        "grims": grims,
        "stash_pages": stash_pages,  # page 0..LJ-1, each [slot0..slot5]
        "bag_items": bag_items,      # [slot0..slot5]
        "hero_items": hero_items,    # [slot0..slot5]
    }

def encode_character_code(
    player_name: str,
    *,
    hero_id: int,
    bag_id: int,
    gold: int,
    level: int,
    str_stat: int,
    agi_stat: int,
    int_stat: int,
    version: int = 1,
    desc: int = 0,
    lvl_ck: Optional[int] = None,
    codec: Optional[MasinCodec] = None,
    item_radix: int = 650,
    grims_radix: int = 71,
    full_build: bool = False,  # should encode items and grims
    # All arguments after here are only used when full_build is True
    hero_items: list[int] | None = None,              # up to 6 ints (0 if empty)
    bag_items: list[int] | None = None,               # up to 6 ints
    stash_pages: list[list[int]] | None = None,       # each page: 6 ints
    # optional “skill pick” indices (the 4 single-choice categories), given in
    # the loader's pop order (LP, LO, LN, LM); they are pushed in reverse.
    grims: tuple[int,int,int,int] | None = None  # ex: (0,0,0,0)
) -> str:
    """
    Build a valid save code from high-level fields.

    Notes:
    - hero_id in [0..74]  (radix 75)
    - bag_id  in [0..99]  (radix 100)
    - desc    in [0..1000] (radix 1001) — 'Descensions' count
    - level   in [0..2099] (hi radix 21, lo radix 100)
    - gold    must fit one of the map’s bands (see _gold_to_code)
    - version in [0..15]  (radix 16)

    lvl_ck: if None, computed with _lvl_ck(desc, level): the sum of the last
            decimal digits of the level's hi/lo pieces and desc (0..27).
            Pass it explicitly to override.

    full_build (experimental): also encode items and grims. dump_items always
            pops exactly 6 hero and 6 bag items, so shorter lists are padded
            with 0 and longer ones truncated. Stash pages are encoded as
            given; pass exactly as many pages as the decoder expects
            (dump_items defaults to 6).

    Raises:
        ValueError: for any field outside its radix or an unrepresentable gold.
    """
    if codec is None:
        codec = MasinCodec()

    # Validate simple ranges
    if not (0 <= hero_id < 75):
        raise ValueError("hero_id out of range for radix 75")
    if not (0 <= bag_id < 100):
        raise ValueError("bag_id out of range for radix 100")
    if not (0 <= desc < 1001):
        raise ValueError("desc out of range for radix 1001")
    if not (0 <= version < 16):
        raise ValueError("version out of range for radix 16")

    # Split stats -> (Fd, MD, Ll)
    sFd_STR, sMD_STR, sLl_STR = _split_stat(str_stat)
    sFd_AGI, sMD_AGI, sLl_AGI = _split_stat(agi_stat)
    sFd_INT, sMD_INT, sLl_INT = _split_stat(int_stat)

    # Stat checksum (radix 28)
    stat_ck = _stat_checksum_from_lls([sLl_INT, sLl_AGI, sLl_STR])

    # Gold compaction code (radix 111)
    gold_c = _gold_to_code(gold)

    # Level pieces (radix 21 for hi, 100 for lo)
    lvl_hi, lvl_lo = _level_to_pieces(level)

    # Level checksum (radix 28)
    if lvl_ck is None:
        lvl_ck = _lvl_ck(desc, level)
    if not (0 <= lvl_ck < 28):
        raise ValueError("lvl_ck out of range for radix 28")

    equip_vals: list[tuple[int, int]] = []
    if full_build:  # Experimental
        # The loader always reads 6 hero and 6 bag slots; pad short lists
        # with empty (0) slots so every later field stays aligned.
        hero_items = (list(hero_items or []) + [0] * 6)[:6]
        bag_items = (list(bag_items or []) + [0] * 6)[:6]
        stash_pages = stash_pages or []
        for page in stash_pages:
            if len(page) != 6:
                raise ValueError("each stash page must have exactly 6 item ids")

        # 1) HERO items: 6 direct values (the loader pops slot 5 first)
        for v in hero_items:
            equip_vals.append((int(v), item_radix))

        # 2) BAG via the BED-constrained block encoding
        equip_vals += _bek_push_block(bag_items, item_radix)

        # 3) STASH PAGES, page 0 first (the loader pops the last page first)
        for page in stash_pages:
            equip_vals += _bek_push_block(page, item_radix)

        # 4) grims: pushed LM..LP so the loader pops LP, LO, LN, LM
        LP, LO, LN, LM = (grims or (0, 0, 0, 0))
        equip_vals += [
            (int(LM), grims_radix),
            (int(LN), grims_radix),
            (int(LO), grims_radix),
            (int(LP), grims_radix),
        ]

    # --- Compose mixed-radix values in the exact reverse of pop order ---
    # Pop order in decoder:
    #   version(16), stat_ck(28), hero_id(75), bag_id(100),
    #   desc(1001), lvl_lo(100), lvl_hi(21), lvl_ck(28), gold_c(111),
    #   [3x: Fd(301), MD(100), Ll(100)]
    #
    # Therefore push order must be:
    values: list[tuple[int, int]] = [
        # STR (last popped triplet -> push first): Ll, MD, Fd
        (sLl_STR, 100), (sMD_STR, 100), (sFd_STR, 301),
        # AGI
        (sLl_AGI, 100), (sMD_AGI, 100), (sFd_AGI, 301),
        # INT (first popped triplet -> push last of the three)
        (sLl_INT, 100), (sMD_INT, 100), (sFd_INT, 301),

        # then the scalar fields, still in reverse pop order
        (gold_c, 111),
        (lvl_ck, 28),
        (lvl_hi, 21),
        (lvl_lo, 100),
        (desc, 1001),
        (bag_id, 100),
        (hero_id, 75),
        (stat_ck, 28),
        (version, 16),
    ]
    if full_build:  # item section is pushed first, so it is popped last
        values = equip_vals + values
    # Encode via bit-mix + base-N alphabet
    return codec.encode(values, player_name)

def encode_shared_code(
    player_name: str,
    *,
    # 1000/1001 blocks (these are the ones checked by the 3 sums):
    mx: int = 0, mw: int = 0, mv: int = 0,
    mu: int = 0, mt: int = 0, ms: int = 0,
    mr: int = 0, mq: int = 0,
    lq: int = 0,                 # radix 1001
    # toggles and values popped next:
    lr: int = 0, lt: int = 0, mp: int = 0, mo: int = 0, mn: int = 0, mm: int = 0,  # each radix 2
    ep: int = 50,               # radix 151 (camera/zoom scalar), 0..150; 50 is the map’s default-ish
    # optional per-ability arrays; if omitted, loader will read zeros
    mh: list[int] | None = None,
    mi: list[int] | None = None,
    mj: list[int] | None = None,
    mk: list[int] | None = None,
    codec: MasinCodec | None = None,
) -> str:
    """
    Build a Shared.pld-style code that passes LoadSharedSave__BEW’s checks.

    The three radix-28 guards are the sums of the last decimal digits of
    (mx, mw, mv), (mu, mt, ms) and (mr, mq, lq). Toggles are coerced to 0/1.

    The bit arrays are pushed mh, mi, mj, mk, so the loader pops mk first
    and mh last. An omitted (None/empty) array reads back as zeros only if no
    array pushed before it was supplied; otherwise its pops would consume the
    earlier array's bits. In that case the omitted array is filled with zeros,
    using the length of the longest supplied array. That length is assumed to
    be the loader's per-array length (ll - 1) — confirm against the map.

    Raises:
        ValueError: if a value is outside its radix (raised by codec.encode).
    """
    if codec is None:
        codec = MasinCodec()

    # clamp ranges the loader expects
    def r2(x: int) -> int:
        return 1 if x else 0

    lr, lt, mp, mo, mn, mm = map(r2, [lr, lt, mp, mo, mn, mm])

    # Checksum guards, exactly as the loader compares them. OR is pushed last,
    # so it is popped first, followed by N8 and then OS.
    OS = _last_digit(mr) + _last_digit(mq) + _last_digit(lq)
    N8 = _last_digit(mu) + _last_digit(mt) + _last_digit(ms)
    OR = _last_digit(mx) + _last_digit(mw) + _last_digit(mv)

    # Build payload in exact reverse of loader pop order:
    values: list[tuple[int, int]] = []

    # Arrays first (push MH, MI, MJ, MK; the loader pops MK..MH)
    arrays = (mh, mi, mj, mk)
    fill_len = max((len(a) for a in arrays if a), default=0)
    earlier_supplied = False
    for arr in arrays:
        if arr:
            earlier_supplied = True
            values.extend((r2(v), 2) for v in arr)
        elif earlier_supplied:
            # Keep the earlier array's bits out of this array's pops
            values.extend([(0, 2)] * fill_len)

    # Then the scalar pops in reverse:
    values.extend([
        (ep, 151),
        (mm, 2), (mn, 2), (mo, 2), (mp, 2), (lt, 2), (lr, 2),

        (lq, 1001),
        (mq, 1000), (mr, 1000), (ms, 1000), (mt, 1000), (mu, 1000), (mv, 1000), (mw, 1000), (mx, 1000),

        (OS, 28), (N8, 28), (OR, 28),
    ])

    return codec.encode(values, player_name)


# --- Convenience demo (optional) --------------------------------------------
def dump_player(st):
    """Pop the character-save fields from ``st``, print them, and return them.

    ``st`` is any object with a ``pop(radix)`` method, normally the
    DecoderState from MasinCodec.decode(). Fields are popped in the loader's
    order, the reverse of encode_character_code's push order.

    Prints one ``key: value`` line per field, then ``Stats valid: <bool>``.
    That flag compares the recomputed stat checksum with the stored one.

    Returns:
        dict with keys version, stat_chksum, hero_id, bag_id, desc, lvl_ck,
        level, gold, STR, AGI, INT. Before, the function returned None, so
        callers could not use the decoded values programmatically.
    """
    p = {}
    # Versioning / identity
    p['version']     = st.pop(16)
    p['stat_chksum'] = st.pop(28)     # checksum over stat low-digits
    p['hero_id']     = st.pop(75)
    p['bag_id']      = st.pop(100)

    p['desc']        = st.pop(1001)
    lvl_lo           = st.pop(100)
    lvl_hi           = st.pop(21)
    p['lvl_ck']      = st.pop(28)
    p['level']       = lvl_lo + 100 * lvl_hi
    # Gold: expand the radix-111 band code
    gold_c  = st.pop(111)
    if   gold_c >= 11: gold = (gold_c - 10) * 100000
    elif gold_c >  6: gold = (gold_c -  6) *  20000
    elif gold_c >  1: gold = (gold_c -  1) *   2000
    else:             gold =  gold_c       *   1000
    p['gold'] = gold

    # Three stats (loop 3x): each stat = Fd*10000 + MD*100 + Ll (clamped 3,000,000)
    stats = []
    n9 = 0
    for _ in range(3):
        sFd = st.pop(301)
        sMD = st.pop(100)
        sLl = st.pop(100)
        # contribution to the stat checksum (the >= 100 branch cannot trigger with radix 100)
        if   sLl >= 100: n9 += sLl // 100
        elif sLl >= 10:  n9 += sLl // 10
        else:            n9 += sLl
        val = sFd * 10000 + sMD * 100 + sLl
        if val > 3000000: val = 3000000
        stats.append(val)
    p['STR'], p['AGI'], p['INT'] = stats[2], stats[1], stats[0]  # note the order in the map

    for k, v in p.items():
        print(f'{k}: {v}')
    print(f"Stats valid: {n9 == p['stat_chksum']}")
    return p

def dump_shared(st: DecoderState, ll=71) -> dict:
    """
    Decode a Shared.pld stream from a DecoderState `st`.

    Args:
      st: DecoderState from mc.decode(shared_code, player_name). Only its
          ``pop(radix)`` method is used.
      ll: gConst_71__LL from the map (default 71). Each bit array holds
          ``ll - 1`` entries (JASS indices 1..ll-1).

    Returns:
      dict with fields:
        OR, N8, OS,
        MX, MW, MV, MU, MT, MS, MR, MQ, LQ,
        valid_checksums (bool),
        LR, LT, MP, MO, MN, MM, Ep,
        MK, MJ, MI, MH  (each a list of length ll-1; index 0 == JASS index 1),
        H8 (the small bonus computed the same way the map does)
      When the checksums fail, nothing further is popped. LR..MM are 0,
      Ep is 50, the four arrays are empty and H8 is 0.0.
    """
    def last_digit(n: int) -> int:
        return abs(int(n)) % 10

    def pop_bits_array() -> list[int]:
        # The map pops index ll-1 first, down to index 1. Reverse so that
        # list index 0 corresponds to JASS index 1.
        bits = [st.pop(2) for _ in range(ll - 1)]
        bits.reverse()
        return bits

    out: dict = {}

    # Pop in the exact order used by LoadSharedSave__BEW
    out["OR"] = st.pop(28)
    out["N8"] = st.pop(28)
    out["OS"] = st.pop(28)

    out["MX"] = st.pop(1000)
    out["MW"] = st.pop(1000)
    out["MV"] = st.pop(1000)
    out["MU"] = st.pop(1000)
    out["MT"] = st.pop(1000)
    out["MS"] = st.pop(1000)
    out["MR"] = st.pop(1000)
    out["MQ"] = st.pop(1000)

    out["LQ"] = st.pop(1001)

    # Each guard is the sum of the last digits of three values (same
    # expressions as the JASS).
    os_calc = last_digit(out["MR"]) + last_digit(out["MQ"]) + last_digit(out["LQ"])
    n8_calc = last_digit(out["MU"]) + last_digit(out["MT"]) + last_digit(out["MS"])
    or_calc = last_digit(out["MX"]) + last_digit(out["MW"]) + last_digit(out["MV"])
    out["valid_checksums"] = (out["OS"] == os_calc and
                              out["N8"] == n8_calc and
                              out["OR"] == or_calc)

    # If the checksums fail, the map zeroes the rest and bails. Return the
    # values it would fall back to instead of popping further.
    if not out["valid_checksums"]:
        out.update(dict(LR=0, LT=0, MP=0, MO=0, MN=0, MM=0, Ep=50,
                        MK=[], MJ=[], MI=[], MH=[], H8=0.0))
        return out

    # Otherwise continue popping in map order
    out["LR"] = st.pop(2)
    out["LT"] = st.pop(2)
    out["MP"] = st.pop(2)
    out["MO"] = st.pop(2)
    out["MN"] = st.pop(2)
    out["MM"] = st.pop(2)
    out["Ep"] = st.pop(151)

    out["MK"] = pop_bits_array()
    out["MJ"] = pop_bits_array()
    out["MI"] = pop_bits_array()
    out["MH"] = pop_bits_array()

    # Compute the H8 bonus the same way the map does
    if out["LQ"] >= 11:
        # I2R(R2I(3.963 * Pow(LQ, .5))) / 120. then * .5
        # (LQ >= 0, so int() truncation matches R2I here.)
        h8 = (int(3.963 * (out["LQ"] ** 0.5)) / 120.0) * 0.5
    else:
        h8 = (out["LQ"] * 0.01) * 0.5
    out["H8"] = h8

    return out

def get_code(args):
    """Return the save code selected by the parsed CLI arguments.

    Sources are checked in priority order:
      1. ``args.code``: the raw code string.
      2. ``args.code_hex``: hex-encoded bytes, decoded as UTF-8.
      3. ``args.code_file``: a file whose contents (stripped of surrounding
         whitespace) are the code.

    If none is provided, an error is printed to stderr and the process exits
    with status -1.
    """
    if args.code:
        return args.code
    if args.code_hex:
        return bytes.fromhex(args.code_hex).decode('utf-8')
    if args.code_file:
        with open(args.code_file, 'r', encoding='utf-8') as f:
            return f.read().strip()
    print('--code, --code-hex or --code-file required', file=sys.stderr)
    sys.exit(-1)

# Stat presets laid out as [str_total, agi_total, int_total]. The encode
# branch unpacks one as `str_total, agi_total, int_total = <preset>`.
# Each preset sums to 3,000,000, the combined total above which the encode
# branch prints a warning.
INT_MAIN = [1_000_000, 0, 2_000_000]
STR_MAIN = [3_000_000, 0, 0]
AGI_MAIN = [1_000_000, 2_000_000, 0]
STR_AGI = [2_000_000, 1_000_000, 0]
# Hero display name -> hero id passed as `hero_id` to encode_character_code.
# dump_player reads this field with radix 75, so ids presumably must be < 75.
HERO_ID_MAP = {
    "Slaughter": 8,
    "Death": 11,
    "Sitael" : 20,
    "Angra" : 35, # floating clone
    "AngraNightmare" : 34, # Nightmare
    "Satan" : 36, # ?
    "Lilith" : 25,
    "Blood Vampire": 26, # STR_AGI
}
# Item display name -> item id for hero/bag/stash slots. The encode branch
# documents ids as lying in [0, LV), where LV = 650 is the item radix.
# NOTE(review): a few keys look misspelled ("Godess", "Globs",
# "Assasination"). They are lookup keys, so confirm against the in-game
# names before renaming.
ITEM_ID_MAP = {
    "Prophecy's Hand": 1,
    "Marble's Combat Gloves": 8,
    "Bandit King's Soul": 6, # item
    "Tribal Chief's Soul": 10, # item
    "Kameal's Soul": 42, # item
    "Satan Costume": 200,
    "Great Sword of Chaos": 251,
    "Skeletal Ribcage": 252,
    "Sloth's Magical Power": 116, # 7 deadly sins
    "Abyssal Godess's Hand": 222,
    "True Leather Globs of Assasination": 90,
    "Lilith's Magical Orb": 111,
    "Lilith's Soul": 112, # item
    "Evil God Stone of Mystic": 115,
    "Evil God Stone of Might": 113,
    "Evil God Stone of Endurance": 114,
    "Power's Second Axe": 250,
    "Wicked Robe": 230,
}

# Bag name -> bag id. dump_player reads the bag field with radix 100.
# The encode branch currently hard-codes `bag_id = 4` rather than looking
# it up here.
BAG_ID_MAP = {
    "default": 1,
    "cactus": 2,
    "tonberry": 3,
    "yoshi": 4,
}

def convert_code(name_in, code_in, name_out,
    item_radix: int = 650,
    num_stash_pages: int = 6,
    grims_radix: int = 71,
):
    """Re-encode a character save code for a different player name.

    Decodes ``code_in`` with ``name_in`` and pops every stored field in stack
    order, in this sequence:
      - header
      - three stat triples
      - four grimoire values
      - stash pages
      - bag block
      - six hero-item slots
    It then encodes the same values under ``name_out``.

    Args:
      name_in: player name the input code was generated for.
      code_in: the input save code.
      name_out: player name the output code is bound to.
      item_radix: radix of every item id (the map's Lv).
      num_stash_pages: number of stash pages stored in the code (the map's LJ).
      grims_radix: radix of each grimoire value (the map's LL).

    Returns:
      The new save code produced by ``MasinCodec.encode``.
    """
    # Decode the input code with the name it was generated for.
    decoder = MasinCodec()
    st = decoder.decode(code_in, name_in)

    # (value, radix) pairs collected in pop order. They are reversed at the
    # end into push order for the encoder.
    # Header: version, stat checksum, hero, bag, descensions, level lo/hi,
    # level check, gold bucket.
    header_radixes = (16, 28, 75, 100, 1001, 100, 21, 28, 111)
    values = [(st.pop(radix), radix) for radix in header_radixes]

    # Three stat triples.
    for _ in range(3):
        for radix in (301, 100, 100):
            values.append((st.pop(radix), radix))

    # Four grimoire volumes.
    for _ in range(4):
        values.append((st.pop(grims_radix), grims_radix))

    # Stash pages. Each 6-slot block is popped as a unit and re-expressed as
    # push pairs. Those pairs are reversed here so that the final reversal
    # restores their push order.
    for _ in range(num_stash_pages):
        block = _bed_pop_block(st, 6, item_radix)
        values += _bek_push_block(block, item_radix)[::-1]

    # Bag block (same layout as a stash page).
    block = _bed_pop_block(st, 6, item_radix)
    values += _bek_push_block(block, item_radix)[::-1]

    # Six hero-item slots.
    for _ in range(6):
        values.append((st.pop(item_radix), item_radix))

    # Push everything back in reverse pop order under the new name.
    encoder = MasinCodec()
    return encoder.encode(values[::-1], name_out)
    

if __name__ == "__main__":
    # Command-line entry point.
    # The global options --name/-n and --verbose/-v belong to the top-level
    # parser, so they must be given BEFORE the sub-command, e.g.:
    #   script.py --name Foo decode --code <code>
    parser = argparse.ArgumentParser(description="Wc3 Masin RPG Fullbirth save/load coder")

    subparsers = parser.add_subparsers(
        help='Operation to run [decode/decode-shared/encode/encode-shared/convert]',
        dest="operation", required=True)

    # Decode
    decode_p = subparsers.add_parser('decode', help='Decode a save code')
    decode_p.add_argument("--code", "-c", type=str, help="Save code to decode")
    decode_p.add_argument("--code-hex", "-e", type=str, help="Save code to decode (as hex string)")
    decode_p.add_argument("--code-file", "-f", type=str, help="File w/ save code")

    decode_s = subparsers.add_parser('decode-shared', help='Decode a shared save code')
    decode_s.add_argument("--code", "-c", type=str, help="Save code to decode")
    decode_s.add_argument("--code-hex", "-e", type=str, help="Save code to decode (as hex string)")
    decode_s.add_argument("--code-file", "-f", type=str, help="File w/ save code")

    # Encode (all inputs are configured in-code in the branches below)
    encode_p = subparsers.add_parser('encode', help='Encode a save code')

    encode_s = subparsers.add_parser('encode-shared', help='Encode a shared save code')

    # Convert
    convert_p = subparsers.add_parser('convert', help='Convert a save from 1 name to another')
    convert_p.add_argument("--code-file", "-f", type=str, help="File w/ save code INPUT")
    convert_p.add_argument("--code", "-c", type=str, help="save code INPUT")
    convert_p.add_argument("--code-hex", "-e", type=str, help="save code (hex) INPUT")
    convert_p.add_argument("--name-out", "-a", type=str, help="Name to convert the file to (required)")

    parser.add_argument("--name", "-n", type=str, default="Swag",
                        help="Player name")
    # NOTE: --verbose is parsed but not read anywhere in this block.
    parser.add_argument("--verbose", "-v", action="store_true",
                        help="Enable verbose output.")

    args = parser.parse_args()

    # === Map constants ===
    LV = 650        # item radix (Lv) = max(Lu[]) + 1
    LJ = 6          # number of stash pages (LJ)
    LL = 71         # gConst_71__LL (skill radix)

    if args.operation == 'decode':
        name = args.name
        code = get_code(args)

        print('==== CHARACTER SLOT ====')
        print(f'Name: {name}, code: {code}')
        mc = MasinCodec()
        st = mc.decode(code, name)
        left = mc.trailing_zero_bits_from_msb(st)
        print(f'Trailing bits: {left}') # want >= 9
        dump_player(st)
        items = dump_items(
            st,
            item_radix = LV, #Lv - qF()/Lu[], max_item_id + 1, or global bn
            num_stash_pages = LJ, # LJ - stash page count
        )
        print(items)

    elif args.operation == 'encode':
        name = args.name
        mc = MasinCodec()

        # --- Character basics ---
        version      = 9          # MF; MUST be <= 9 for this map to accept it

        # See HERO_ID_MAP dict
        #hero_id      = HERO_ID_MAP['Death'] # 11
        #hero_id      = HERO_ID_MAP['Sitael'] # 20
        hero_id      = HERO_ID_MAP['Angra']  # 35
        #hero_id      = HERO_ID_MAP['AngraNightmare']  # 34
        #hero_id      = HERO_ID_MAP['Blood Vampire']  # 26
        #bag_id       = BAG_ID_MAP['cactus']          # example bag index  (0..It-1)
        bag_id       = 4
        descensions  = 100          # Lh[G7]
        level        = 2_000        # level = hi*100 + lo (encoder will split/validate)
        gold         = 1_600_000     # encoder will bucketize to the compressed code (111)

        # --- Stats (totals). The encoder splits into Fd, MD, Ll with clamping. ---
        # max 3_000_000 total

        # set each one manually
        #str_total = 1_000_000
        #agi_total = 2_000_000
        #int_total = 0

        # or use a constant array from above
        str_total, agi_total, int_total = STR_AGI

        if str_total + agi_total + int_total > 3_000_000:
            print('!!! Stats over 3M -- this is impossible in game!')

        #hero_only = True
        hero_only = False
        if hero_only:
            # Just Hero
            code = encode_character_code(
                name,
                hero_id=hero_id,
                bag_id=bag_id,
                gold=gold,
                level=level,
                str_stat=str_total,
                agi_stat=agi_total,
                int_stat=int_total,
                version=version,
                desc=descensions,
                codec=mc,
            )
            print("Player name:", name)
            print("ENCODED:", code)
        else:
            # with items
            # --- Grims: four choices, each in [0, LL) ---
            # grimoire volumes in order: oblivion, spirit, salv, awakening
            grims = (0, 0, 0, 19)
            # --- Items ---
            # orders will be canonicalized and swapped around for the map loader
            # 6 hero inventory slots (IDs in [0, LV))
            # See ITEM_ID_MAP dict
            hero_items = [0, 0, 0, 0, 0, 0]
            hero_items[0]= ITEM_ID_MAP["Great Sword of Chaos"]
            hero_items[1]= ITEM_ID_MAP["Lilith's Soul"]
            hero_items[2]= ITEM_ID_MAP["Sloth's Magical Power"]
            hero_items[3]= ITEM_ID_MAP["Evil God Stone of Might"]

            # 6 bag slots on the bag unit (IDs in [0, LV))
            # these are the first page of
            # bag inventory
            bag_items = [0, 0, 0, 0, 0, 0]
            bag_items[0] = 0
            # LJ stash pages × 6 slots each (IDs in [0, LV))
            stash_pages = [[0]*6 for _ in range(LJ)]
            # e.g., put a couple items on page 0:
            stash_pages[0] = [111, 112, 113, 114, 115, 116]
            stash_pages[1] = [261, 262, 263, 264, 265, 267]
            #stash_pages[1][2] = 250
            #stash_pages[2][3] = 251
            #stash_pages[3][4] = 252
            #stash_pages[0][5] = 500

            # --- Encode ---
            code = encode_character_code(
                name,
                hero_id=hero_id,
                bag_id=bag_id,
                gold=gold,
                level=level,
                str_stat=str_total,
                agi_stat=agi_total,
                int_stat=int_total,
                version=version,
                desc=descensions,
                codec=mc,
                full_build = True,
                grims=grims,
                hero_items=hero_items,
                bag_items=bag_items,
                stash_pages=stash_pages,
                item_radix=LV,
                grims_radix=LL,
            )
            print("Player name:", name)
            print("ENCODED:", code)
        # Decode & inspect with existing dumper
        st = mc.decode(code, name)
        dump_player(st)
        if not hero_only:
            items = dump_items(
                st,
                item_radix = LV, #Lv - qF()/Lu[], max_item_id + 1, or global bn
                num_stash_pages = LJ, # LJ - stash page count
            )
            print(items)

    elif args.operation == 'encode-shared':
        name = args.name
        mc = MasinCodec()
        shared_code = encode_shared_code(name, codec=mc)
        print("SHARED CODE:", shared_code)
    elif args.operation == 'decode-shared':
        name = args.name
        code = get_code(args)

        print('==== SHARED SLOT ====')
        print(f'Name: {name}, code: {code}')
        mc = MasinCodec()
        st = mc.decode(code, name)
        print(dump_shared(st, ll=LL))
    elif args.operation == 'convert':
        name_in = args.name
        code = get_code(args)
        name_out = args.name_out
        if not name_out:
            print('--name-out, -a -- required', file=sys.stderr)
            sys.exit(-1)

        print('=== CODE IN DUMP ===')
        mc = MasinCodec()
        st = mc.decode(code, name_in)
        dump_player(st)
        print(dump_items(st, item_radix=LV, num_stash_pages=LJ))

        code_out = convert_code(
            name_in,
            code,
            name_out,
            item_radix=LV,
            num_stash_pages=LJ,
            grims_radix=LL
        )
        print('=== CODE OUT DUMP ===')
        mc = MasinCodec()
        st = mc.decode(code_out, name_out)
        dump_player(st)
        print(dump_items(st, item_radix=LV, num_stash_pages=LJ))

        print('Converted Code:', code_out)
