diff --git a/src/graphics/mtl_renderer.m b/src/graphics/mtl_renderer.m index 61e8daa..1ad20f8 100644 --- a/src/graphics/mtl_renderer.m +++ b/src/graphics/mtl_renderer.m @@ -1049,6 +1049,8 @@ void oc_mtl_render_batch(oc_mtl_canvas_backend* backend, [backpropEncoder setBuffer:backend->logBuffer[backend->bufferIndex] offset:0 atIndex:2]; [backpropEncoder setBuffer:backend->logOffsetBuffer[backend->bufferIndex] offset:0 atIndex:3]; + [backpropEncoder setBuffer:backend->segmentCountBuffer offset:0 atIndex:4]; + MTLSize backpropGroupSize = MTLSizeMake([backend->backpropPipeline maxTotalThreadsPerThreadgroup], 1, 1); MTLSize backpropGridSize = MTLSizeMake(pathCount * backpropGroupSize.width, 1, 1); @@ -1563,7 +1565,7 @@ oc_canvas_backend* oc_mtl_canvas_backend_create(oc_mtl_surface* surface) options:bufferOptions]; backend->segmentCountBuffer = [surface->device newBufferWithLength:sizeof(int) - options:bufferOptions]; + options:MTLResourceStorageModeShared]; backend->pathQueueBuffer = [surface->device newBufferWithLength:OC_MTL_DEFAULT_PATH_QUEUE_BUFFER_LEN * sizeof(oc_mtl_path_queue) options:bufferOptions]; diff --git a/src/graphics/mtl_renderer.metal b/src/graphics/mtl_renderer.metal index f043932..bcb39f1 100644 --- a/src/graphics/mtl_renderer.metal +++ b/src/graphics/mtl_renderer.metal @@ -83,7 +83,8 @@ int mtl_itoa_right_aligned(int bufSize, thread char* buffer, int64_t value, bool buffer[index] = '0' + (value % 10); index--; value /= 10; - } while(value != 0 && index >= stop); + } + while(value != 0 && index >= stop); if(zeroPad) { @@ -1259,8 +1260,8 @@ kernel void mtl_segment_setup(constant int* elementCount [[buffer(0)]], device oc_mtl_tile_queue* tileQueueBuffer [[buffer(5)]], device oc_mtl_tile_op* tileOpBuffer [[buffer(6)]], device atomic_int* tileOpCount [[buffer(7)]], - constant int* segmentMax [[buffer(8)]], - constant int* tileOpMax [[buffer(9)]], + constant int* tileOpMax [[buffer(8)]], + constant int* segmentMax [[buffer(9)]], constant int* tileSize [[buffer(10)]], constant float* scale [[buffer(11)]], @@ -1323,10 +1324,19 @@ kernel void mtl_backprop(const device oc_mtl_path_queue* pathQueueBuffer [[buffe device oc_mtl_tile_queue* tileQueueBuffer [[buffer(1)]], device char* logBuffer [[buffer(2)]], device atomic_int* logOffsetBuffer [[buffer(3)]], + device int* segmentCount [[buffer(4)]], uint pathIndex [[threadgroup_position_in_grid]], - uint localID [[thread_position_in_threadgroup]]) + uint localID [[thread_position_in_threadgroup]], + uint uid [[thread_position_in_grid]]) { - // mtl_log_context log = {.buffer = logBuffer, .offset = logOffsetBuffer, .enabled = false}; + mtl_log_context log = { .buffer = logBuffer, .offset = logOffsetBuffer, .enabled = false }; + + if(uid == 0) + { + mtl_log(log, "segmentCount = "); + mtl_log_i32(log, segmentCount[0]); + mtl_log(log, "\n"); + } threadgroup atomic_int nextRowIndex; if(localID == 0)