[mtl canvas, wip] implemented backprop

This commit is contained in:
Martin Fouilleul 2023-03-29 14:27:05 +02:00
parent c4415aaeea
commit a6c53649bd
3 changed files with 260 additions and 55 deletions

View File

@ -59,18 +59,26 @@ typedef struct mg_mtl_path_queue
using namespace metal;
#endif
typedef enum { MG_MTL_OP_SEGMENT } mg_mtl_tile_op_kind;
//NOTE: kind tag stored in mg_mtl_tile_op.kind.
//      MG_MTL_OP_SEGMENT ops reference a segment of the segment buffer through the op's index field.
//NOTE(review): MG_MTL_OP_START is only produced by the (currently commented-out) mtl_merge kernel, where it
//      marks the beginning of a path's ops in a tile list -- confirm its final meaning once merge is enabled.
typedef enum { MG_MTL_OP_START,
MG_MTL_OP_SEGMENT } mg_mtl_tile_op_kind;
//NOTE: one entry of a tile's op list, stored in the tile op buffer. Ops of a tile are chained
//      through their next field.
typedef struct mg_mtl_tile_op
{
mg_mtl_tile_op_kind kind;
//NOTE: for MG_MTL_OP_SEGMENT ops, this is the index of the segment in the segment buffer
int index;
//NOTE: index of the following op in the tile op buffer, or -1 for the end of the list
int next;
union
{
//NOTE: written by the segment setup pass for segment ops: true when the segment was found to cross
//      the right boundary of the tile. The raster pass uses it to adjust the winding number.
bool crossRight;
//NOTE(review): only written for MG_MTL_OP_START ops by the commented-out mtl_merge kernel -- it shares
//      storage with crossRight, so only one of the two is meaningful for a given op kind.
int windingOffset;
};
} mg_mtl_tile_op;
//NOTE: per-path, per-tile queue of tile ops, plus the winding offset used as the starting winding
//      number when rasterizing the tile.
typedef struct mg_mtl_tile_queue
{
//NOTE: reset to 0 by the path setup pass. The segment setup pass atomically adds the winding increment
//      of segments crossing the tile's bottom boundary. The backprop pass then replaces it with the sum
//      of the offsets accumulated in the tiles that come after this one in the same row.
atomic_int windingOffset;
//NOTE: index of the first op of the tile's list in the tile op buffer, or -1 if the list is empty.
//      New ops are pushed at the head with an atomic exchange.
atomic_int first;
//NOTE: initialized to -1, then set to the index of the op that was pushed while the list was empty.
//      Since ops are pushed at the head, that op ends up at the tail of the list.
int last;
} mg_mtl_tile_queue;

View File

@ -27,13 +27,18 @@ typedef struct mg_mtl_canvas_backend
id<MTLComputePipelineState> pathPipeline;
id<MTLComputePipelineState> segmentPipeline;
id<MTLComputePipelineState> backpropPipeline;
id<MTLComputePipelineState> rasterPipeline;
id<MTLRenderPipelineState> blitPipeline;
id<MTLTexture> outTexture;
id<MTLBuffer> pathBuffer;
id<MTLBuffer> elementBuffer;
int bufferIndex;
dispatch_semaphore_t bufferSemaphore;
id<MTLBuffer> pathBuffer[MG_MTL_INPUT_BUFFERS_COUNT];
id<MTLBuffer> elementBuffer[MG_MTL_INPUT_BUFFERS_COUNT];
id<MTLBuffer> segmentCountBuffer;
id<MTLBuffer> segmentBuffer;
id<MTLBuffer> pathQueueBuffer;
@ -63,8 +68,11 @@ void mg_mtl_canvas_render(mg_canvas_backend* interface,
//TODO: update rolling buffers
mg_mtl_path_elt* elementBufferData = (mg_mtl_path_elt*)[backend->elementBuffer contents];
mg_mtl_path* pathBufferData = (mg_mtl_path*)[backend->pathBuffer contents];
dispatch_semaphore_wait(backend->bufferSemaphore, DISPATCH_TIME_FOREVER);
backend->bufferIndex = (backend->bufferIndex + 1) % MG_MTL_INPUT_BUFFERS_COUNT;
mg_mtl_path_elt* elementBufferData = (mg_mtl_path_elt*)[backend->elementBuffer[backend->bufferIndex] contents];
mg_mtl_path* pathBufferData = (mg_mtl_path*)[backend->pathBuffer[backend->bufferIndex] contents];
//NOTE: fill renderer input buffers
int pathCount = 0;
@ -153,14 +161,14 @@ void mg_mtl_canvas_render(mg_canvas_backend* interface,
[pathEncoder setComputePipelineState: backend->pathPipeline];
[pathEncoder setBytes:&pathCount length:sizeof(int) atIndex:0];
[pathEncoder setBuffer:backend->pathBuffer offset:0 atIndex:1];
[pathEncoder setBuffer:backend->pathBuffer[backend->bufferIndex] offset:0 atIndex:1];
[pathEncoder setBuffer:backend->pathQueueBuffer offset:0 atIndex:2];
[pathEncoder setBuffer:backend->tileQueueBuffer offset:0 atIndex:3];
[pathEncoder setBuffer:backend->tileQueueCountBuffer offset:0 atIndex:4];
[pathEncoder setBytes:&tileSize length:sizeof(int) atIndex:5];
MTLSize pathGridSize = MTLSizeMake(pathCount, 1, 1);
MTLSize pathGroupSize = MTLSizeMake(64, 1, 1);
MTLSize pathGroupSize = MTLSizeMake([backend->pathPipeline maxTotalThreadsPerThreadgroup], 1, 1);
[pathEncoder dispatchThreads: pathGridSize threadsPerThreadgroup: pathGroupSize];
[pathEncoder endEncoding];
@ -171,7 +179,7 @@ void mg_mtl_canvas_render(mg_canvas_backend* interface,
[segmentEncoder setComputePipelineState: backend->segmentPipeline];
[segmentEncoder setBytes:&eltCount length:sizeof(int) atIndex:0];
[segmentEncoder setBuffer:backend->elementBuffer offset:0 atIndex:1];
[segmentEncoder setBuffer:backend->elementBuffer[backend->bufferIndex] offset:0 atIndex:1];
[segmentEncoder setBuffer:backend->segmentCountBuffer offset:0 atIndex:2];
[segmentEncoder setBuffer:backend->segmentBuffer offset:0 atIndex:3];
[segmentEncoder setBuffer:backend->pathQueueBuffer offset:0 atIndex:4];
@ -181,18 +189,32 @@ void mg_mtl_canvas_render(mg_canvas_backend* interface,
[segmentEncoder setBytes:&tileSize length:sizeof(int) atIndex:8];
MTLSize segmentGridSize = MTLSizeMake(mtlEltCount, 1, 1);
MTLSize segmentGroupSize = MTLSizeMake(64, 1, 1);
MTLSize segmentGroupSize = MTLSizeMake([backend->segmentPipeline maxTotalThreadsPerThreadgroup], 1, 1);
[segmentEncoder dispatchThreads: segmentGridSize threadsPerThreadgroup: segmentGroupSize];
[segmentEncoder endEncoding];
//NOTE: backprop pass
id<MTLComputeCommandEncoder> backpropEncoder = [surface->commandBuffer computeCommandEncoder];
backpropEncoder.label = @"backprop pass";
[backpropEncoder setComputePipelineState: backend->backpropPipeline];
[backpropEncoder setBuffer:backend->pathQueueBuffer offset:0 atIndex:0];
[backpropEncoder setBuffer:backend->tileQueueBuffer offset:0 atIndex:1];
MTLSize backpropGroupSize = MTLSizeMake([backend->backpropPipeline maxTotalThreadsPerThreadgroup], 1, 1);
MTLSize backpropGridSize = MTLSizeMake(pathCount*backpropGroupSize.width, 1, 1);
[backpropEncoder dispatchThreads: backpropGridSize threadsPerThreadgroup: backpropGroupSize];
[backpropEncoder endEncoding];
//NOTE: raster pass
id<MTLComputeCommandEncoder> rasterEncoder = [surface->commandBuffer computeCommandEncoder];
rasterEncoder.label = @"raster pass";
[rasterEncoder setComputePipelineState: backend->rasterPipeline];
[rasterEncoder setBytes:&pathCount length:sizeof(int) atIndex:0];
[rasterEncoder setBuffer:backend->pathBuffer offset:0 atIndex:1];
[rasterEncoder setBuffer:backend->pathBuffer[backend->bufferIndex] offset:0 atIndex:1];
[rasterEncoder setBuffer:backend->segmentCountBuffer offset:0 atIndex:2];
[rasterEncoder setBuffer:backend->segmentBuffer offset:0 atIndex:3];
[rasterEncoder setBuffer:backend->pathQueueBuffer offset:0 atIndex:4];
@ -230,6 +252,12 @@ void mg_mtl_canvas_render(mg_canvas_backend* interface,
vertexCount: 3 ];
[renderEncoder endEncoding];
}
//NOTE: finalize
[surface->commandBuffer addCompletedHandler:^(id<MTLCommandBuffer> commandBuffer)
{
dispatch_semaphore_signal(backend->bufferSemaphore);
}];
}
}
@ -241,11 +269,15 @@ void mg_mtl_canvas_destroy(mg_canvas_backend* interface)
{
[backend->pathPipeline release];
[backend->segmentPipeline release];
[backend->backpropPipeline release];
[backend->rasterPipeline release];
[backend->blitPipeline release];
[backend->pathBuffer release];
[backend->elementBuffer release];
for(int i=0; i<MG_MTL_INPUT_BUFFERS_COUNT; i++)
{
[backend->pathBuffer[i] release];
[backend->elementBuffer[i] release];
}
[backend->segmentCountBuffer release];
[backend->segmentBuffer release];
[backend->tileQueueBuffer release];
@ -296,6 +328,7 @@ mg_canvas_backend* mg_mtl_canvas_create(mg_surface surface)
}
id<MTLFunction> pathFunction = [library newFunctionWithName:@"mtl_path_setup"];
id<MTLFunction> segmentFunction = [library newFunctionWithName:@"mtl_segment_setup"];
id<MTLFunction> backpropFunction = [library newFunctionWithName:@"mtl_backprop"];
id<MTLFunction> rasterFunction = [library newFunctionWithName:@"mtl_raster"];
id<MTLFunction> vertexFunction = [library newFunctionWithName:@"mtl_vertex_shader"];
id<MTLFunction> fragmentFunction = [library newFunctionWithName:@"mtl_fragment_shader"];
@ -309,6 +342,9 @@ mg_canvas_backend* mg_mtl_canvas_create(mg_surface surface)
backend->segmentPipeline = [metalSurface->device newComputePipelineStateWithFunction: segmentFunction
error:&error];
backend->backpropPipeline = [metalSurface->device newComputePipelineStateWithFunction: backpropFunction
error:&error];
backend->rasterPipeline = [metalSurface->device newComputePipelineStateWithFunction: rasterFunction
error:&error];
@ -343,14 +379,21 @@ mg_canvas_backend* mg_mtl_canvas_create(mg_surface surface)
backend->outTexture = [metalSurface->device newTextureWithDescriptor:texDesc];
//NOTE: create buffers
backend->bufferSemaphore = dispatch_semaphore_create(MG_MTL_INPUT_BUFFERS_COUNT);
backend->bufferIndex = 0;
MTLResourceOptions bufferOptions = MTLResourceCPUCacheModeWriteCombined
| MTLResourceStorageModeShared;
backend->pathBuffer = [metalSurface->device newBufferWithLength: MG_MTL_PATH_BUFFER_SIZE
options: bufferOptions];
for(int i=0; i<MG_MTL_INPUT_BUFFERS_COUNT; i++)
{
backend->pathBuffer[i] = [metalSurface->device newBufferWithLength: MG_MTL_PATH_BUFFER_SIZE
options: bufferOptions];
backend->elementBuffer = [metalSurface->device newBufferWithLength: MG_MTL_ELEMENT_BUFFER_SIZE
options: bufferOptions];
backend->elementBuffer[i] = [metalSurface->device newBufferWithLength: MG_MTL_ELEMENT_BUFFER_SIZE
options: bufferOptions];
}
bufferOptions = MTLResourceStorageModePrivate;
backend->segmentBuffer = [metalSurface->device newBufferWithLength: MG_MTL_SEGMENT_BUFFER_SIZE

View File

@ -33,6 +33,8 @@ kernel void mtl_path_setup(constant int* pathCount [[buffer(0)]],
for(int i=0; i<tileCount; i++)
{
atomic_store_explicit(&tileQueues[i].first, -1, memory_order_relaxed);
tileQueues[i].last = -1;
atomic_store_explicit(&tileQueues[i].windingOffset, 0, memory_order_relaxed);
}
}
@ -67,12 +69,12 @@ bool mtl_is_left_of_segment(float2 p, const device mg_mtl_segment* seg)
float dy = p.y - seg->box.y;
if( (seg->config == MG_MTL_BR && dy > alpha*dx)
||(seg->config == MG_MTL_TR && dy < ofs - alpha*dx))
||(seg->config == MG_MTL_TR && dy < ofs - alpha*dx))
{
isLeft = true;
}
else if( !(seg->config == MG_MTL_TL && dy < alpha*dx)
&& !(seg->config == MG_MTL_BL && dy > ofs - alpha*dx))
&& !(seg->config == MG_MTL_BL && dy > ofs - alpha*dx))
{
//Need implicit test, but for lines, we only have config BR or TR, so the test is always negative for now
}
@ -95,7 +97,7 @@ kernel void mtl_segment_setup(constant int* elementCount [[buffer(0)]],
float2 p0 = elt->p[0];
float2 p3 = elt->p[3];
if(elt->kind == MG_MTL_LINE && p0.y != p3.y)
if(elt->kind == MG_MTL_LINE)
{
int segIndex = atomic_fetch_add_explicit(segmentCount, 1, memory_order_relaxed);
device mg_mtl_segment* seg = &segmentBuffer[segIndex];
@ -107,13 +109,15 @@ kernel void mtl_segment_setup(constant int* elementCount [[buffer(0)]],
max(p0.y, p3.y)};
if( (p3.x > p0.x && p3.y < p0.y)
||(p3.x <= p0.x && p3.y > p0.y))
||(p3.x <= p0.x && p3.y > p0.y))
{
seg->config = MG_MTL_TR;
}
else if( (p3.x > p0.x && p3.y > p0.y)
||(p3.x <= p0.x && p3.y < p0.y))
else if( (p3.x > p0.x && p3.y >= p0.y)
||(p3.x <= p0.x && p3.y <= p0.y))
{
//NOTE: it is important to include horizontal segments here, so that the mtl_is_left_of_segment() test
// becomes x > seg->box.x, in order to correctly detect right-crossing horizontal segments
seg->config = MG_MTL_BR;
}
@ -139,15 +143,20 @@ kernel void mtl_segment_setup(constant int* elementCount [[buffer(0)]],
float(y + pathQueue->area.y + 1)} * float(tileSize[0]);
//NOTE: select two corners of tile box to test against the curve
float2 testPoint[2] = {{tileBox.x, tileBox.y},
{tileBox.z, tileBox.w}};
if(seg->config == MG_MTL_BR || seg->config == MG_MTL_TL)
float2 testPoint0;
float2 testPoint1;
if(seg->config == MG_MTL_BL || seg->config == MG_MTL_TR)
{
testPoint[0] = (float2){tileBox.x, tileBox.w};
testPoint[1] = (float2){tileBox.z, tileBox.y};
testPoint0 = (float2){tileBox.x, tileBox.y},
testPoint1 = (float2){tileBox.z, tileBox.w};
}
bool test0 = mtl_is_left_of_segment(testPoint[0], seg);
bool test1 = mtl_is_left_of_segment(testPoint[1], seg);
else
{
testPoint0 = (float2){tileBox.z, tileBox.y};
testPoint1 = (float2){tileBox.x, tileBox.w};
}
bool test0 = mtl_is_left_of_segment(testPoint0, seg);
bool test1 = mtl_is_left_of_segment(testPoint1, seg);
//NOTE: the curve overlaps the tile only if test points are on opposite sides of segment
if(test0 != test1)
@ -159,13 +168,160 @@ kernel void mtl_segment_setup(constant int* elementCount [[buffer(0)]],
op->index = segIndex;
int tileIndex = y*pathQueue->area.z + x;
op->next = atomic_exchange_explicit(&tileQueues[tileIndex].first, tileOpIndex, memory_order_relaxed);
device mg_mtl_tile_queue* tile = &tileQueues[tileIndex];
op->next = atomic_exchange_explicit(&tile->first, tileOpIndex, memory_order_relaxed);
if(op->next == -1)
{
tile->last = tileOpIndex;
}
//NOTE: if the segment crosses the tile's bottom boundary, update the tile's winding offset
// testPoint0 is always a bottom point. We select the other one and check if they are on
// opposite sides of the curve.
// We also need to check that the endpoints of the curve are on opposite sides of the bottom
// boundary.
float2 testPoint3;
if(seg->config == MG_MTL_BL || seg->config == MG_MTL_TR)
{
testPoint3 = (float2){tileBox.z, tileBox.y};
}
else
{
testPoint3 = (float2){tileBox.x, tileBox.y};
}
bool test3 = mtl_is_left_of_segment(testPoint3, seg);
if( test0 != test3
&& seg->box.y <= testPoint0.y
&& seg->box.w > testPoint0.y)
{
atomic_fetch_add_explicit(&tile->windingOffset, seg->windingIncrement, memory_order_relaxed);
}
//NOTE: if the segment crosses the right boundary, mark it. We reuse one of the previous tests
float2 top = {tileBox.z, tileBox.w};
bool testTop = mtl_is_left_of_segment(top, seg);
bool testBottom = (seg->config == MG_MTL_BL || seg->config == MG_MTL_TR)? test3 : test0;
if(testTop != testBottom
&& seg->box.x <= top.x
&& seg->box.z > top.x)
{
op->crossRight = true;
}
else
{
op->crossRight = false;
}
}
}
}
}
}
//NOTE: backprop pass. The path to process is selected by the threadgroup's position in the grid, so each
//      threadgroup handles one path queue. Threads of the group repeatedly grab a row of the path's tile
//      grid from a threadgroup-shared counter. For each row, tiles are walked from the last one to the
//      first one, and each tile's winding offset is replaced by the sum of the offsets that were stored
//      in the tiles after it in that row (i.e. an exclusive suffix sum over the row).
//
//      pathQueueBuffer: path queues, indexed by pathIndex. area.z / area.w are used as the number of
//                       tiles per row / the number of rows, tileQueues as the offset of the path's
//                       first tile queue in tileQueueBuffer.
//      tileQueueBuffer: tile queues whose windingOffset fields are rewritten in place.
kernel void mtl_backprop(const device mg_mtl_path_queue* pathQueueBuffer [[buffer(0)]],
                         device mg_mtl_tile_queue* tileQueueBuffer [[buffer(1)]],
                         uint pathIndex [[threadgroup_position_in_grid]],
                         uint localID [[thread_position_in_threadgroup]])
{
	//NOTE: shared row counter, reset by the first thread of the group before anyone uses it
	threadgroup atomic_int nextRowIndex;
	if(localID == 0)
	{
		atomic_store_explicit(&nextRowIndex, 0, memory_order_relaxed);
	}
	threadgroup_barrier(mem_flags::mem_threadgroup);

	const device mg_mtl_path_queue* pathQueue = &pathQueueBuffer[pathIndex];
	device mg_mtl_tile_queue* tiles = &tileQueueBuffer[pathQueue->tileQueues];
	int rowSize = pathQueue->area.z;
	int rowCount = pathQueue->area.w;

	//NOTE: each fetch-add hands out a distinct row, so a given row is only ever processed by one thread
	int rowIndex = atomic_fetch_add_explicit(&nextRowIndex, 1, memory_order_relaxed);
	while(rowIndex < rowCount)
	{
		device mg_mtl_tile_queue* row = &tiles[rowIndex * rowSize];
		int sum = 0;

		for(int x = rowSize-1; x >= 0; x--)
		{
			device mg_mtl_tile_queue* tile = &row[x];

			//NOTE: go through the atomic API rather than casting the atomic_int to a plain int pointer,
			//      so we don't depend on the representation of atomic_int. Relaxed ordering is enough
			//      here since no other thread touches this row during this pass.
			int offset = atomic_load_explicit(&tile->windingOffset, memory_order_relaxed);
			atomic_store_explicit(&tile->windingOffset, sum, memory_order_relaxed);
			sum += offset;
		}
		rowIndex = atomic_fetch_add_explicit(&nextRowIndex, 1, memory_order_relaxed);
	}
}
/*
kernel void mtl_merge(constant int* pathCount [[buffer(0)]],
const device mg_mtl_path* pathBuffer [[buffer(1)]],
const device mg_mtl_path_queue* pathQueueBuffer [[buffer(2)]],
const device mg_mtl_tile_queue* tileQueueBuffer [[buffer(3)]],
device mg_mtl_tile_op* tileOpBuffer [[buffer(4)]],
device atomic_int* tileOpCount [[buffer(5)]],
device int* screenTilesBuffer [[buffer(6)]],
uint2 threadCoord [[thread_position_in_grid]],
uint2 gridSize [[threads_per_grid]])
{
int2 tileCoord = int2(threadCoord);
int tileIndex = tileCoord.y * gridSize.x + tileCoord.x;
device int* nextLink = &screenTilesBuffer[tileIndex];
for(int pathIndex = 0; pathIndex < pathCount[0]; pathIndex++)
{
const device mg_mtl_path_queue* pathQueue = &pathQueueBuffer[pathIndex];
int2 pathTileCoord = tileCoord - pathQueue->area.xy;
if( pathTileCoord.x >= 0
&& pathTileCoord.x < pathQueue->area.z
&& pathTileCoord.y >= 0
&& pathTileCoord.y < pathQueue->area.w)
{
int pathTileIndex = pathTileCoord.y * pathQueue->area.z + pathTileCoord.x;
const device mg_mtl_tile_queue* tileQueue = &tileQueueBuffer[pathQueue->tileQueues + pathTileIndex];
int windingOffset = atomic_load_explicit(&tileQueue->windingOffset, memory_order_relaxed);
int opIndex = atomic_load_explicit(&tileQueue->first, memory_order_relaxed);
if((opIndex != -1) || (windingOffset & 1))
{
//NOTE: add path start op (with winding offset)
int startOpIndex = atomic_fetch_add_explicit(tileOpCount, 1, memory_order_relaxed);
device mg_mtl_tile_op* startOp = &tileOpBuffer[startOpIndex];
startOp->kind = MG_MTL_OP_START;
startOp->index = pathIndex;
startOp->windingOffset = windingOffset;
if(opIndex == -1)
{
//NOTE: the tile is fully covered by path fill. Insert start op,
// and if the fill color is opaque, trim tile list.
if(pathBuffer[pathIndex].color.a == 1)
{
screenTilesBuffer[tileIndex] = startOpIndex;
}
else
{
*nextLink = startOpIndex;
}
nextLink = &startOp->next;
}
else
{
//NOTE: add start op
*nextLink = startOpIndex;
nextLink = &startOp->next;
//NOTE: chain path ops to end of tile list
device mg_mtl_tile_op* lastOp = &tileOpBuffer[opIndex];
*nextLink = opIndex;
nextLink = &lastOp->next;
}
}
}
}
}
*/
kernel void mtl_raster(constant int* pathCount [[buffer(0)]],
const device mg_mtl_path* pathBuffer [[buffer(1)]],
constant int* segCount [[buffer(2)]],
@ -182,8 +338,6 @@ kernel void mtl_raster(constant int* pathCount [[buffer(0)]],
int2 tileCoord = pixelCoord / tileSize[0];
float4 color = float4(0, 0, 0, 0);
int currentPath = 0;
int winding = 0;
if( (pixelCoord.x % tileSize[0] == 0)
||(pixelCoord.y % tileSize[0] == 0))
@ -205,31 +359,17 @@ kernel void mtl_raster(constant int* pathCount [[buffer(0)]],
int pathTileIndex = pathTileCoord.y * pathQueue->area.z + pathTileCoord.x;
const device mg_mtl_tile_queue* tileQueue = &tileQueueBuffer[pathQueue->tileQueues + pathTileIndex];
int winding = atomic_load_explicit(&tileQueue->windingOffset, memory_order_relaxed);
int opIndex = atomic_load_explicit(&tileQueue->first, memory_order_relaxed);
while(opIndex != -1)
{
//outTexture.write(float4(0, 0, 1, 1), uint2(pixelCoord));
//return;
const device mg_mtl_tile_op* op = &tileOpBuffer[opIndex];
if(op->kind == MG_MTL_OP_SEGMENT)
{
const device mg_mtl_segment* seg = &segmentBuffer[op->index];
if(seg->pathIndex != currentPath)
{
//depending on winding number, update color
if(winding & 1)
{
float4 pathColor = pathBuffer[currentPath].color;
pathColor.rgb *= pathColor.a;
color = color*(1-pathColor.a) + pathColor;
}
currentPath = seg->pathIndex;
winding = 0;
}
if(pixelCoord.y >= seg->box.y && pixelCoord.y < seg->box.w)
{
if(pixelCoord.x < seg->box.x)
@ -258,17 +398,31 @@ kernel void mtl_raster(constant int* pathCount [[buffer(0)]],
}
}
}
if(op->crossRight)
{
if( (seg->config == MG_MTL_BR || seg->config == MG_MTL_TL)
&&(pixelCoord.y >= seg->box.w))
{
winding += seg->windingIncrement;
}
else if( (seg->config == MG_MTL_BL || seg->config == MG_MTL_TR)
&&(pixelCoord.y >= seg->box.y))
{
winding -= seg->windingIncrement;
}
}
}
opIndex = op->next;
}
}
}
if(winding & 1)
{
float4 pathColor = pathBuffer[currentPath].color;
pathColor.rgb *= pathColor.a;
color = color*(1-pathColor.a) + pathColor;
if(winding & 1)
{
float4 pathColor = pathBuffer[pathIndex].color;
pathColor.rgb *= pathColor.a;
color = color*(1-pathColor.a) + pathColor;
}
}
}
outTexture.write(color, uint2(pixelCoord));