]> git.sesse.net Git - narabu/blob - encoder.shader
Fix the DCT scaling (I believe).
[narabu] / encoder.shader
1 #version 440
2 #extension GL_ARB_shader_clock : enable
3
4 layout(local_size_x = 8) in;
5
6 layout(r16ui) uniform restrict writeonly uimage2D dc_ac7_tex;
7 layout(r16ui) uniform restrict writeonly uimage2D ac1_ac6_tex;
8 layout(r16ui) uniform restrict writeonly uimage2D ac2_ac5_tex;
9 layout(r8i) uniform restrict writeonly iimage2D ac3_tex;
10 layout(r8i) uniform restrict writeonly iimage2D ac4_tex;
11 layout(r8ui) uniform restrict readonly uimage2D image_tex;
12
13 shared float temp[64];
14
15 // Scale factors; 1.0 / (sqrt(2.0) * cos(k * M_PI / 16.0)), except for the first which is 1.
16 const float sf[8] = {
17         1.0, 0.7209598220069479, 0.765366864730180, 0.8504300947672564,
18         1.0, 1.2727585805728336, 1.847759065022573, 3.6245097854115502
19 };
20
21 const float W[64] = {
22          8, 16, 19, 22, 26, 27, 29, 34,
23         16, 16, 22, 24, 27, 29, 34, 37,
24         19, 22, 26, 27, 29, 34, 34, 38,
25         22, 22, 26, 27, 29, 34, 37, 40,
26         22, 26, 27, 29, 32, 35, 40, 48,
27         26, 27, 29, 32, 35, 40, 48, 58,
28         26, 27, 29, 34, 38, 46, 56, 69,
29         27, 29, 35, 38, 46, 56, 69, 83
30 };
31 const float S = 4.0 * 0.5;  // whatever?
32
33 // NOTE: Contains factors to counteract the scaling in the DCT implementation.
34 const float quant_matrix[64] = {
35         sf[0] * sf[0] / 64.0,         sf[1] * sf[0] / (W[ 1] * S),  sf[2] * sf[0] / (W[ 2] * S),  sf[3] * sf[0] / (W[ 3] * S),  sf[4] * sf[0] / (W[ 4] * S),  sf[5] * sf[0] / (W[ 5] * S),  sf[6] * sf[0] / (W[ 6] * S),  sf[7] * sf[0] / (W[ 7] * S),
36         sf[0] * sf[1] / (W[ 8] * S),  sf[1] * sf[1] / (W[ 9] * S),  sf[2] * sf[1] / (W[10] * S),  sf[3] * sf[1] / (W[11] * S),  sf[4] * sf[1] / (W[12] * S),  sf[5] * sf[1] / (W[13] * S),  sf[6] * sf[1] / (W[14] * S),  sf[7] * sf[1] / (W[15] * S),
37         sf[0] * sf[2] / (W[16] * S),  sf[1] * sf[2] / (W[17] * S),  sf[2] * sf[2] / (W[18] * S),  sf[3] * sf[2] / (W[19] * S),  sf[4] * sf[2] / (W[20] * S),  sf[5] * sf[2] / (W[21] * S),  sf[6] * sf[2] / (W[22] * S),  sf[7] * sf[2] / (W[23] * S),
38         sf[0] * sf[3] / (W[24] * S),  sf[1] * sf[3] / (W[25] * S),  sf[2] * sf[3] / (W[26] * S),  sf[3] * sf[3] / (W[27] * S),  sf[4] * sf[3] / (W[28] * S),  sf[5] * sf[3] / (W[29] * S),  sf[6] * sf[3] / (W[30] * S),  sf[7] * sf[3] / (W[31] * S),
39         sf[0] * sf[4] / (W[32] * S),  sf[1] * sf[4] / (W[33] * S),  sf[2] * sf[4] / (W[34] * S),  sf[3] * sf[4] / (W[35] * S),  sf[4] * sf[4] / (W[36] * S),  sf[5] * sf[4] / (W[37] * S),  sf[6] * sf[4] / (W[38] * S),  sf[7] * sf[4] / (W[39] * S),
40         sf[0] * sf[5] / (W[40] * S),  sf[1] * sf[5] / (W[41] * S),  sf[2] * sf[5] / (W[42] * S),  sf[3] * sf[5] / (W[43] * S),  sf[4] * sf[5] / (W[44] * S),  sf[5] * sf[5] / (W[45] * S),  sf[6] * sf[5] / (W[46] * S),  sf[7] * sf[5] / (W[47] * S),
41         sf[0] * sf[6] / (W[48] * S),  sf[1] * sf[6] / (W[49] * S),  sf[2] * sf[6] / (W[50] * S),  sf[3] * sf[6] / (W[51] * S),  sf[4] * sf[6] / (W[52] * S),  sf[5] * sf[6] / (W[53] * S),  sf[6] * sf[6] / (W[54] * S),  sf[7] * sf[6] / (W[55] * S),
42         sf[0] * sf[7] / (W[56] * S),  sf[1] * sf[7] / (W[57] * S),  sf[2] * sf[7] / (W[58] * S),  sf[3] * sf[7] / (W[59] * S),  sf[4] * sf[7] / (W[60] * S),  sf[5] * sf[7] / (W[61] * S),  sf[6] * sf[7] / (W[62] * S),  sf[7] * sf[7] / (W[63] * S)
43 };
44
45 // Clamp and pack a 9-bit and a 7-bit signed value into a 16-bit word.
46 uint pack_9_7(int v9, int v7)
47 {
48         return (uint(clamp(v9, -256, 255)) & 0x1ffu) | ((uint(clamp(v7, -64, 63)) & 0x7fu) << 9);
49 }
50
51 // Scaled 1D DCT (AA&N). y0 is correctly scaled, all other y_k are scaled by sqrt(2) cos(k * Pi / 16).
52 void dct_1d(inout float y0, inout float y1, inout float y2, inout float y3, inout float y4, inout float y5, inout float y6, inout float y7)
53 {
54         const float a1 = 0.7071067811865474;   // sqrt(2)
55         const float a2 = 0.5411961001461971;   // cos(3/8 pi) * sqrt(2)
56         const float a4 = 1.3065629648763766;   // cos(pi/8) * sqrt(2)
57         // static const float a5 = 0.5 * (a4 - a2);
58         const float a5 = 0.3826834323650897;
59
60         // phase 1
61         const float p1_0 = y0 + y7;
62         const float p1_1 = y1 + y6;
63         const float p1_2 = y2 + y5;
64         const float p1_3 = y3 + y4;
65         const float p1_4 = y3 - y4;
66         const float p1_5 = y2 - y5;
67         const float p1_6 = y1 - y6;
68         const float p1_7 = y0 - y7;
69
70         // phase 2
71         const float p2_0 = p1_0 + p1_3;
72         const float p2_1 = p1_1 + p1_2;
73         const float p2_2 = p1_1 - p1_2;
74         const float p2_3 = p1_0 - p1_3;
75         const float p2_4 = p1_4 + p1_5;  // Inverted.
76         const float p2_5 = p1_5 + p1_6;
77         const float p2_6 = p1_6 + p1_7;
78
79         // phase 3
80         const float p3_0 = p2_0 + p2_1;
81         const float p3_1 = p2_0 - p2_1;
82         const float p3_2 = p2_2 + p2_3;
83         
84         // phase 4
85         const float p4_2 = p3_2 * a1;
86         const float p4_4 = p2_4 * a2 + (p2_4 - p2_6) * a5;
87         const float p4_5 = p2_5 * a1;
88         const float p4_6 = p2_6 * a4 + (p2_4 - p2_6) * a5;
89
90         // phase 5
91         const float p5_2 = p2_3 + p4_2;
92         const float p5_3 = p2_3 - p4_2;
93         const float p5_5 = p1_7 + p4_5;
94         const float p5_7 = p1_7 - p4_5;
95         
96         // phase 6
97         y0 = p3_0;
98         y4 = p3_1;
99         y2 = p5_2;
100         y6 = p5_3;
101         y5 = p4_4 + p5_7;
102         y1 = p5_5 + p4_6;
103         y7 = p5_5 - p4_6;
104         y3 = p5_7 - p4_4;
105 }
106 void main()
107 {
108         uint x = 8 * gl_WorkGroupID.x;
109         uint y = 8 * gl_WorkGroupID.y;
110         uint n = gl_LocalInvocationID.x;
111
112         // Load column.
113         float y0 = imageLoad(image_tex, ivec2(x + n, y + 0)).x;
114         float y1 = imageLoad(image_tex, ivec2(x + n, y + 1)).x;
115         float y2 = imageLoad(image_tex, ivec2(x + n, y + 2)).x;
116         float y3 = imageLoad(image_tex, ivec2(x + n, y + 3)).x;
117         float y4 = imageLoad(image_tex, ivec2(x + n, y + 4)).x;
118         float y5 = imageLoad(image_tex, ivec2(x + n, y + 5)).x;
119         float y6 = imageLoad(image_tex, ivec2(x + n, y + 6)).x;
120         float y7 = imageLoad(image_tex, ivec2(x + n, y + 7)).x;
121
122         // Vertical DCT.
123         dct_1d(y0, y1, y2, y3, y4, y5, y6, y7);
124
125         // Communicate with the other shaders in the group.
126         temp[n + 0 * 8] = y0;
127         temp[n + 1 * 8] = y1;
128         temp[n + 2 * 8] = y2;
129         temp[n + 3 * 8] = y3;
130         temp[n + 4 * 8] = y4;
131         temp[n + 5 * 8] = y5;
132         temp[n + 6 * 8] = y6;
133         temp[n + 7 * 8] = y7;
134
135         memoryBarrierShared();
136         barrier();
137
138         // Load row (so transpose, in a sense).
139         y0 = temp[n * 8 + 0];
140         y1 = temp[n * 8 + 1];
141         y2 = temp[n * 8 + 2];
142         y3 = temp[n * 8 + 3];
143         y4 = temp[n * 8 + 4];
144         y5 = temp[n * 8 + 5];
145         y6 = temp[n * 8 + 6];
146         y7 = temp[n * 8 + 7];
147
148         // Horizontal DCT.
149         dct_1d(y0, y1, y2, y3, y4, y5, y6, y7);
150
151         // Quantize.
152         int c0 = int(round(y0 * quant_matrix[n * 8 + 0]));
153         int c1 = int(round(y1 * quant_matrix[n * 8 + 1]));
154         int c2 = int(round(y2 * quant_matrix[n * 8 + 2]));
155         int c3 = int(round(y3 * quant_matrix[n * 8 + 3]));
156         int c4 = int(round(y4 * quant_matrix[n * 8 + 4]));
157         int c5 = int(round(y5 * quant_matrix[n * 8 + 5]));
158         int c6 = int(round(y6 * quant_matrix[n * 8 + 6]));
159         int c7 = int(round(y7 * quant_matrix[n * 8 + 7]));
160
161         // Clamp, pack and store.
162         uint sx = gl_WorkGroupID.x;
163         imageStore(dc_ac7_tex,  ivec2(sx, y + n), uvec4(pack_9_7(c0, c7), 0, 0, 0));
164         imageStore(ac1_ac6_tex, ivec2(sx, y + n), uvec4(pack_9_7(c1, c6), 0, 0, 0));
165         imageStore(ac2_ac5_tex, ivec2(sx, y + n), uvec4(pack_9_7(c2, c5), 0, 0, 0));
166         imageStore(ac3_tex,     ivec2(sx, y + n), ivec4(c3, 0, 0, 0));
167         imageStore(ac4_tex,     ivec2(sx, y + n), ivec4(c4, 0, 0, 0));
168 }
169