import argparse
import io
from pygltflib import GLTF2, BufferView, Buffer
from PIL import Image

# Default settings applied by main().
DEFAULT_MAXSIZE: int = 1024  # max length in px of a texture's longest side
DEFAULT_JPEGQUALITY: int = 90  # Pillow JPEG "quality" for textures without alpha

def get_image_bytes(gltf, img_idx):
    """Return ``(data, mime_type)`` for the embedded image ``img_idx``.

    Only images stored in a bufferView of buffer 0 are supported. That buffer's
    bytes are what ``gltf.binary_blob()`` returns, i.e. the GLB binary chunk.
    Returns ``(None, None)`` in these cases:
    - the image has no valid bufferView (e.g. it is referenced by ``uri``);
    - the view belongs to another buffer;
    - the document has no binary data;
    - the view extends past the end of the binary data.
    """
    img = gltf.images[img_idx]
    view_idx = img.bufferView
    # Reject negative indices too; Python would otherwise index from the end.
    if view_idx is None or not 0 <= view_idx < len(gltf.bufferViews):
        return None, None

    bv = gltf.bufferViews[view_idx]
    # In a GLB only buffer 0 can refer to the embedded binary chunk, so a view
    # in any other buffer cannot be read from binary_blob().
    if bv.buffer not in (None, 0):
        return None, None

    blob = gltf.binary_blob()
    if not blob:
        return None, None

    start = bv.byteOffset or 0  # byteOffset is optional and defaults to 0
    end = start + bv.byteLength
    if end > len(blob):
        # A truncated slice would only fail later as an undecodable image.
        return None, None
    return blob[start:end], img.mimeType


def resize_and_compress(data: bytes, name: str, maxsize: int, jpegquality: int):
    """Downscale an encoded image and re-encode it for embedding in a GLB.

    The image is shrunk with LANCZOS resampling, keeping the aspect ratio, so
    that its longest side is at most ``maxsize`` px. Smaller images keep their
    size. Images with transparency (an alpha band, or palette/colour-key
    transparency) are written as optimized PNG; all others as RGB JPEG.

    Args:
        data: encoded source image, in any format Pillow can open.
        name: base name for the output; falls back to "image" when empty. A
            trailing .png/.jpg/.jpeg/.webp extension is dropped before the new
            one is appended, so repeated runs do not stack extensions.
        maxsize: maximum length in px of the longest side.
        jpegquality: Pillow JPEG ``quality`` setting.

    Returns:
        ``(encoded_bytes, output_name, mime_type, original_size, new_size)``,
        where the sizes are ``(width, height)`` tuples.

    Raises:
        Whatever Pillow raises for data it cannot decode or encode.
    """
    base = name or "image"
    stem, dot, ext = base.rpartition(".")
    if dot and stem and ext.lower() in ("png", "jpg", "jpeg", "webp"):
        base = stem

    with Image.open(io.BytesIO(data)) as src:
        orig_size = src.size
        w, h = orig_size
        has_alpha = "A" in src.getbands() or "transparency" in src.info

        # Normalise the mode *before* resampling. Palette images would otherwise
        # be resized with NEAREST, and palette/colour-key transparency (which
        # lives in src.info, not in a band) would be dropped.
        if has_alpha:
            img = src if src.mode in ("RGBA", "LA") else src.convert("RGBA")
        else:
            img = src.convert("RGB")

        longest = max(w, h)
        if longest > maxsize:
            # Integer arithmetic makes the longest side exactly maxsize (float
            # scaling can land on maxsize - 1). max(1, ...) stops an extreme
            # aspect ratio from collapsing a side to 0 px, which resize() rejects.
            new_size = (max(1, w * maxsize // longest), max(1, h * maxsize // longest))
            img = img.resize(new_size, Image.LANCZOS)
        else:
            new_size = orig_size

        buf = io.BytesIO()
        if has_alpha:
            img.save(buf, format="PNG", optimize=True)
            return buf.getvalue(), base + ".png", "image/png", orig_size, new_size
        img.save(buf, format="JPEG", quality=jpegquality)
        return buf.getvalue(), base + ".jpg", "image/jpeg", orig_size, new_size


def _append_buffer_view(gltf, blob: bytearray, data: bytes) -> int:
    """Append ``data`` to ``blob`` and register a bufferView on buffer 0 for it.

    Returns the index of the new bufferView.
    """
    # Start the view on a 4-byte boundary. This is a conservative choice; image
    # views have no strict alignment rule, but aligned data is always valid.
    blob.extend(bytes(-len(blob) % 4))
    offset = len(blob)
    blob.extend(data)
    gltf.bufferViews.append(BufferView(buffer=0, byteOffset=offset, byteLength=len(data)))
    return len(gltf.bufferViews) - 1


def replace_image_bytes(gltf: "GLTF2", img_idx: int, data: bytes, mime: str):
    """Store ``data`` as the embedded payload of image ``img_idx``.

    The payload is written in place when the image already owns a bufferView
    that meets all of these conditions: it is in buffer 0, it lies inside the
    binary blob, and it is large enough. Any unused tail of that view is
    zero-filled and ``byteLength`` shrinks. Otherwise the payload is appended
    to the blob in a new bufferView.

    Afterwards the image refers to its data by bufferView (``uri`` is cleared),
    ``mimeType`` is set to ``mime``, buffer 0's ``byteLength`` matches the new
    blob, and the blob is written back with ``set_binary_blob``.

    Args:
        gltf: loaded GLB document; modified in place.
        img_idx: index into ``gltf.images``.
        data: encoded image bytes.
        mime: MIME type of ``data``, e.g. ``"image/jpeg"``.
    """
    img = gltf.images[img_idx]
    blob = bytearray(gltf.binary_blob() or b"")

    bv = None
    if img.bufferView is not None and 0 <= img.bufferView < len(gltf.bufferViews):
        bv = gltf.bufferViews[img.bufferView]

    # Only views of buffer 0 live in `blob`; writing another buffer's offsets
    # into it, or past its end, would corrupt unrelated data.
    fits = False
    if bv is not None and bv.buffer in (None, 0):
        start = bv.byteOffset or 0
        end = start + bv.byteLength
        fits = len(data) <= bv.byteLength and end <= len(blob)

    if fits:
        # Reuse the existing slot and zero the unused tail so no stale bytes
        # of the old image remain.
        # NOTE(review): if several images share this bufferView, shrinking
        # byteLength affects all of them.
        blob[start:start + len(data)] = data
        blob[start + len(data):end] = bytes(end - start - len(data))
        bv.byteLength = len(data)
    else:
        # Any old view is left in place, unused by this image.
        img.bufferView = _append_buffer_view(gltf, blob, data)

    img.uri = None
    img.mimeType = mime

    if not gltf.buffers:
        gltf.buffers = [Buffer(byteLength=len(blob))]
    else:
        gltf.buffers[0].byteLength = len(blob)

    gltf.set_binary_blob(bytes(blob))


def main():
    """Command-line entry point: recompress the embedded textures of a GLB file.

    Each image stored in a bufferView goes through resize_and_compress. The
    longest side is capped at DEFAULT_MAXSIZE px, and the image is written as
    PNG if it has alpha, otherwise as JPEG at DEFAULT_JPEGQUALITY. The result
    is then written back into the document.

    A result that was not downscaled and is not smaller than the original is
    discarded. This avoids growing the file or adding a pointless lossy
    re-encode.

    The input file is overwritten only if at least one texture changed. The
    write goes to a temporary file that is then renamed over the original, so
    a failed save cannot leave a truncated GLB behind.
    """
    import contextlib
    import os
    import sys

    parser = argparse.ArgumentParser(
        description="Compress GLB textures: JPEG (no alpha), PNG (with alpha). Overwrites original!"
    )
    parser.add_argument("inputfile", type=str, help="Input GLB file")
    args = parser.parse_args()

    outputfile = args.inputfile
    gltf = GLTF2().load(args.inputfile)

    changed = False
    for idx, img in enumerate(gltf.images or []):
        label = img.name or f"image_{idx}"
        data, _ = get_image_bytes(gltf, idx)
        if not data:
            print(f"Skip {label} (no data found)")
            continue

        try:
            new_bytes, new_name, new_mime, old_size, new_size = resize_and_compress(
                data, label, DEFAULT_MAXSIZE, DEFAULT_JPEGQUALITY
            )
        except Exception as e:
            # Best effort: an image Pillow cannot decode or encode stays untouched.
            print(f"Error at {label}: {e}", file=sys.stderr)
            continue

        if new_size == old_size and len(new_bytes) >= len(data):
            print(f"Keep {label} (recompression would not make it smaller)")
            continue

        # Deliberately outside the try: a failure here could leave the
        # document half-updated, and such a document must not be saved.
        replace_image_bytes(gltf, idx, new_bytes, new_mime)
        img.name = new_name
        changed = True
        print(
            f"{label}: {old_size[0]}x{old_size[1]} -> {new_size[0]}x{new_size[1]}, "
            f"{len(data)} -> {len(new_bytes)} bytes"
        )

    if not changed:
        print(f"- No textures changed; {outputfile} left as is")
        return

    # Save to a temporary file first, then atomically swap it in.
    tmpfile = f"{outputfile}.tmp"
    try:
        gltf.save_binary(tmpfile)
        os.replace(tmpfile, outputfile)
    except BaseException:
        with contextlib.suppress(FileNotFoundError):
            os.remove(tmpfile)
        raise
    print(f"- Compressed GLB saved as {outputfile}")


if __name__ == "__main__":
    # Run the CLI only when executed as a script, not on import. main()
    # returns None, so this exits with status 0 unless main() raises.
    raise SystemExit(main())
