// NNProcessor.compute
// Each "#pragma kernel" line names a function in this file to compile as a compute
// kernel; a single file may declare many kernels.
#pragma kernel NNAddWeights1024
#pragma kernel NNSumWeights256
#pragma kernel NNCreateOutput
#pragma kernel NNProcessVisuals
// Host side: create a RenderTexture with the enableRandomWrite flag and bind it
// with cs.SetTexture.
// Result is read by NNAddWeights1024 and NNProcessVisuals and written by NNCreateOutput.
RWTexture2D<float4> Result; // author's note: starts 32x32, ends 32x32 (size is set by the host — confirm)
// weightValues is only read by the kernels in this file.
RWTexture2D<float4> weightValues; // author's note: 1024x1024 texture (size is set by the host — confirm)
// weightedValues is written by NNAddWeights1024, updated in place by NNSumWeights256
// and read by NNCreateOutput.
RWTexture2D<float4> weightedValues; // author's note: 1024x1024 texture (size is set by the host — confirm)
// output is written by NNProcessVisuals, one float4 per thread.
RWStructuredBuffer<float4> output;
// Dispatch with (1, 4, 64) thread groups. Combined with the 8x8x16 group size below,
// that yields id.x in [0, 8), id.y in [0, 32) and id.z in [0, 1024).
//
// Each thread visits four Result texels in row id.y (columns id.x, id.x + 8,
// id.x + 16 and id.x + 24), adds the weightValues texel at (column * id.y, id.z)
// to each, and stores the mean of the four sums in weightedValues at
// (id.x * id.y, id.z).
//
// NOTE(review): the product column * id.y is not unique per (column, id.y) pair
// (every thread with id.x == 0 or id.y == 0 maps to 0), so several threads write
// the same weightedValues texel. Confirm whether a flattened index was intended.
[numthreads(8, 8, 16)]
void NNAddWeights1024(uint3 id : SV_DispatchThreadID)
{
    const uint row = id.y;
    const uint texRow = id.z;

    // The four columns handled by this thread, spaced 8 apart.
    const uint col0 = id.x;
    const uint col1 = id.x + 8;
    const uint col2 = id.x + 16;
    const uint col3 = id.x + 24;

    const float4 term0 = Result[uint2(col0, row)] + weightValues[uint2(col0 * row, texRow)];
    const float4 term1 = Result[uint2(col1, row)] + weightValues[uint2(col1 * row, texRow)];
    const float4 term2 = Result[uint2(col2, row)] + weightValues[uint2(col2 * row, texRow)];
    const float4 term3 = Result[uint2(col3, row)] + weightValues[uint2(col3 * row, texRow)];

    // Average of the four (input + weight) terms.
    weightedValues[uint2(col0 * row, texRow)] = (term0 + term1 + term2 + term3) / float4(4, 4, 4, 4);
}
// Dispatch with (16, 64, 1) thread groups. Combined with the 8x16x8 group size below,
// that yields id.x in [0, 128), id.y in [0, 1024) and id.z in [0, 8).
//
// id.x is the column, id.y the row and id.z selects a halving step: the step
// stride is 256 >> id.z. Threads whose column is below the stride replace
// weightedValues[column, row] with the average of itself and the texel one stride
// to the right.
//
// NOTE(review): all eight halving steps (id.z values) execute within the same
// dispatch and read/write the same texels, so the steps are not ordered with
// respect to each other. The barrier below only synchronizes threads of a single
// group. Confirm whether one dispatch per step was intended.
[numthreads(8, 16, 8)]
void NNSumWeights256(uint3 id : SV_DispatchThreadID)
{
    const uint stride = 256u >> id.z;

    if (id.x < stride)
    {
        weightedValues[id.xy] = ((weightedValues[uint2(id.x, id.y)] + weightedValues[uint2(id.x + stride, id.y)])) / 2;
    }

    // A group sync must be reached by every thread of the group, so it cannot sit
    // inside the thread-ID-dependent branch above; it is issued unconditionally.
    GroupMemoryBarrierWithGroupSync();
}
// Dispatch with (1, 1, 1): a single 32x32x1 group, so id.x and id.y each span [0, 32).
//
// Copies one texel per thread from column 0 of weightedValues (row id.x * id.y)
// into Result at (id.x, id.y).
//
// NOTE(review): id.x * id.y is not unique per thread (it is 0 whenever either
// component is 0), so many Result texels read the same source row. Confirm
// whether a flattened index was intended.
[numthreads(32, 32, 1)]
void NNCreateOutput(uint3 id : SV_DispatchThreadID)
{
    const uint2 src = uint2(0, id.x * id.y);
    const uint2 dst = id.xy;
    Result[dst] = weightedValues[src];
}
// Dispatch with (1, 1, 1): a single group of 1024 threads, so id.x spans [0, 1024).
//
// For its id.x, each thread walks the 32x32 block of Result texels, adds the
// weightValues texel at (x * y, id.x) to each one, averages the 1024 sums and
// writes the average to the output buffer.
//
// NOTE(review): x * y is not unique per (x, y) pair, so the same weightValues
// texels are read repeatedly. Confirm whether a flattened index was intended.
[numthreads(1024, 1, 1)]
void NNProcessVisuals(uint3 id : SV_DispatchThreadID)
{
    // The accumulator must start at zero; it was previously read by "+=" before
    // ever being assigned.
    float4 sum = float4(0, 0, 0, 0);

    for (int x = 0; x < 32; x++)
    {
        for (int y = 0; y < 32; y++)
        {
            sum += Result[uint2(x, y)] + weightValues[uint2(x * y, id.x)];
        }
    }

    // Mean over the 32 * 32 = 1024 accumulated terms.
    sum = sum / 1024;

    // NOTE(review): this index is (id.x % 32) + 1024 * (id.x / 32), which reaches
    // 31775 for id.x == 1023. Confirm the output buffer is that large, or whether
    // a plain id.x index was intended.
    output[id.x % 32 + ((id.x - id.x % 32) * 32)] = sum;
}