/*
 * ProRes RAW decoder
 *
 * Copyright (c) 2025 Lynne <dev@lynne.ee>
 *
 * This file is part of FFmpeg.
 *
 * FFmpeg is free software; you can redistribute it and/or
 * modify it under the terms of the GNU Lesser General Public
 * License as published by the Free Software Foundation; either
 * version 2.1 of the License, or (at your option) any later version.
 *
 * FFmpeg is distributed in the hope that it will be useful,
 * but WITHOUT ANY WARRANTY; without even the implied warranty of
 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
 * Lesser General Public License for more details.
 *
 * You should have received a copy of the GNU Lesser General Public
 * License along with FFmpeg; if not, write to the Free Software
 * Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA
 */

#version 460
#pragma shader_stage(compute)
#extension GL_GOOGLE_include_directive : require

#include "common.comp"
#include "dct.glsl"

/* Per-tile parameters; main() reads one entry per workgroup from tile_data[]. */
struct TileData {
   ivec2 pos;     /* top-left pixel of the tile in dst */
   uint offset;   /* offset of the tile payload from pkt_data (added to its
                   * address, so presumably in bytes) */
   uint size;     /* tile payload size; not read by this shader */
};

/* Used in place. imageLoad() reads the coefficients from it, and imageStore()
 * writes the reconstructed samples back to it. The coefficients are presumably
 * stored here by an earlier decoding pass (TODO confirm). */
layout (set = 0, binding = 0) uniform uimage2D dst;
/* Tile table, indexed by the flattened workgroup ID in main(). */
layout (set = 0, binding = 1, scalar) readonly buffer frame_data_buf {
    TileData tile_data[];
};

layout (push_constant, scalar) uniform pushConstants {
   u8buf pkt_data;    /* device address of the packet data */
   ivec2 tile_size;   /* nominal tile size in pixels; only .x is read here */
   uint8_t qmat[64];  /* quantisation matrix, indexed like scan[] (row*8 + column) */
};

/* Invocation layout within a workgroup.
 * - COMP_ID picks the component of the 2x2 interleave. main() derives the pixel
 *   offset from bit 0 (x) and bit 1 (y), so it is presumably 0..3.
 * - BLOCK_ID picks an 8x8 block along the tile row.
 * - ROW_ID picks a line within that block. That line is a column during the
 *   load, the column iDCT and the store, and a row during the row iDCT. */
#define COMP_ID (gl_LocalInvocationID.z)
#define BLOCK_ID (gl_LocalInvocationID.y)
#define ROW_ID (gl_LocalInvocationID.x)

/* Coefficient location table. main() adds scan[k] to the block origin in dst
 * to fetch coefficient k, where k = row*8 + column of the 8x8 block that the
 * iDCT works on.
 * - Every entry is even, so the fetch stays on the same component of the 2x2
 *   interleave as the block origin.
 * - The 0..14 range covers the 8x8 block at a pixel stride of 2. */
const u8vec2 scan[64] = {
    u8vec2( 0,  0), u8vec2( 4,  0), u8vec2( 0,  2), u8vec2( 4,  2),
    u8vec2( 0,  8), u8vec2( 4,  8), u8vec2( 6,  8), u8vec2( 2, 10),
    u8vec2( 2,  0), u8vec2( 6,  0), u8vec2( 2,  2), u8vec2( 6,  2),
    u8vec2( 2,  8), u8vec2( 8,  8), u8vec2( 0, 10), u8vec2( 4, 10),
    u8vec2( 8,  0), u8vec2(12,  0), u8vec2( 8,  2), u8vec2(12,  2),
    u8vec2(10,  8), u8vec2(14,  8), u8vec2( 6, 10), u8vec2( 2, 12),
    u8vec2(10,  0), u8vec2(14,  0), u8vec2(10,  2), u8vec2(14,  2),
    u8vec2(12,  8), u8vec2( 8, 10), u8vec2( 0, 12), u8vec2( 4, 12),
    u8vec2( 0,  4), u8vec2( 4,  4), u8vec2( 6,  4), u8vec2( 2,  6),
    u8vec2(10, 10), u8vec2(14, 10), u8vec2( 6, 12), u8vec2( 2, 14),
    u8vec2( 2,  4), u8vec2( 8,  4), u8vec2( 0,  6), u8vec2( 4,  6),
    u8vec2(12, 10), u8vec2( 8, 12), u8vec2( 0, 14), u8vec2( 4, 14),
    u8vec2(10,  4), u8vec2(14,  4), u8vec2( 6,  6), u8vec2(12,  6),
    u8vec2(10, 12), u8vec2(14, 12), u8vec2( 6, 14), u8vec2(12, 14),
    u8vec2(12,  4), u8vec2( 8,  6), u8vec2(10,  6), u8vec2(14,  6),
    u8vec2(12, 12), u8vec2( 8, 14), u8vec2(10, 14), u8vec2(14, 14),
};

/*
 * Dequantise and inverse-transform the coefficients of one tile, in place.
 *
 * One workgroup processes one tile; tile_data[] is indexed by the flattened
 * workgroup ID. Within the workgroup:
 * - COMP_ID picks the component of the 2x2 interleave.
 * - BLOCK_ID picks an 8x8 block along the tile.
 * - ROW_ID picks one line of that block.
 *
 * Processing steps:
 * 1. Fetch the coefficients from dst through scan[].
 * 2. Scale them by qmat * qscale.
 * 3. Run a column pass and then a row pass of idct8() over blocks[].
 * 4. Clamp the result to 12 bits, shift it left by 4 and store it back to dst.
 *
 * Only the first nb_blocks blocks of a tile touch dst. The other invocations
 * would otherwise read and write outside the image or over a neighbouring
 * tile. These are the clipped tail of a tile at the right frame edge, or
 * surplus blocks if the workgroup is wider than the tile. They still run the
 * iDCT on zero coefficients, so every invocation reaches both barrier()
 * calls.
 */
void main(void)
{
    const uint tile_idx = gl_WorkGroupID.y*gl_NumWorkGroups.x + gl_WorkGroupID.x;
    TileData td = tile_data[tile_idx];

    int width = imageSize(dst).x;
    /* td depends only on the workgroup ID. The whole workgroup therefore takes
     * this exit together, and no invocation is left waiting at a barrier. */
    if (expectEXT(td.pos.x >= width, false))
        return;

    /* The tile payload starts with a big-endian 16-bit quantiser scale. */
    uint64_t pkt_offset = uint64_t(pkt_data) + td.offset;
    u8vec2buf hdr_data = u8vec2buf(pkt_offset);
    int qscale = pack16(hdr_data[0].v.yx);

    const ivec2 offs = td.pos + ivec2(COMP_ID & 1, COMP_ID >> 1);
    /* Per-component tile width clipped at the right frame edge, and the
     * number of complete 8-wide blocks it contains. */
    const uint w = min(tile_size.x, width - td.pos.x) >> 1;
    const uint nb_blocks = w >> 3;
    /* Only blocks inside the clipped tile may access dst. */
    const bool active = BLOCK_ID < nb_blocks;

    /* We have to do non-uniform access, so copy it */
    uint8_t qmat_buf[64] = qmat;

    /* Load and dequantise. This invocation fills column ROW_ID of its block.
     * Inactive blocks are fed zeroes instead of doing an out-of-range
     * imageLoad(). */
    [[unroll]]
    for (uint y = 0; y < 8; y++) {
        uint block_off = y*8 + ROW_ID;
        float coeff = 0.0;
        if (active) {
            int v = int(imageLoad(dst, offs + 2*ivec2(BLOCK_ID*8, 0) + scan[block_off])[0]);
            float vf = float(sign_extend(v, 16)) / 32768.0;
            vf *= qmat_buf[block_off] * qscale;
            coeff = (vf / (64*4.56)) * idct_scale[block_off];
        }
        blocks[BLOCK_ID][COMP_ID*72 + y*9 + ROW_ID] = coeff;
    }

    /* Column-wise iDCT over the column this invocation just wrote. */
    idct8(BLOCK_ID, COMP_ID*72 + ROW_ID, 9);
    barrier();

    /* Bias the first element of each row before the row pass. This is
     * presumably a DC offset that centres the output range; TODO confirm
     * against idct8's normalisation. */
    blocks[BLOCK_ID][COMP_ID*72 + ROW_ID * 9] += 0.5f;

    /* Row-wise iDCT */
    idct8(BLOCK_ID, COMP_ID*72 + ROW_ID * 9, 1);
    barrier();

    /* No barriers follow, so inactive invocations can leave now without
     * writing outside the tile. */
    if (!active)
        return;

    /* Store column ROW_ID as 12-bit samples, left-aligned to 16 bits. */
    [[unroll]]
    for (uint y = 0; y < 8; y++) {
        int v = int(round(blocks[BLOCK_ID][COMP_ID*72 + y*9 + ROW_ID]*4095.0));
        v = clamp(v, 0, 4095);
        v <<= 4;
        imageStore(dst,
                   offs + 2*ivec2(BLOCK_ID*8 + ROW_ID, y),
                   ivec4(v));
    }
}
