Transp2d + Transp1d refactor + minor syntax updates #36
base: pcs
@@ -28,39 +28,47 @@ void k2c_conv1d_transpose(k2c_tensor *output, const k2c_tensor *input,
 
     const size_t ker_dim12 = n_channels * n_filters;
 
-    size_t cs = 0;
-    size_t ce = 0;
-    size_t ts = 0;
-    size_t ks = 0;
+    // changed some names for refactor clarity
+    size_t output_start_idx = 0; // cs
+    size_t output_end_idx = 0;   // ce
+    size_t output_raw_idx = 0;   // ts
+    size_t kernel_offset = 0;    // ks
 
     for (size_t f = 0; f < n_filters; ++f)
     {
         for (size_t ch = 0; ch < n_channels; ++ch)
         {
             for (size_t t = 0; t < n_height; ++t)
             {
-                ts = t * stride;
-                if (ts > start_crop)
+                output_raw_idx = t * stride;
+
+                // start index
+                if (output_raw_idx > start_crop)
                 {
-                    cs = ts - start_crop;
+                    output_start_idx = output_raw_idx - start_crop;
                 }
                 else
                 {
-                    cs = 0;
+                    output_start_idx = 0;
                 }
-                if (ts + k_size - start_crop > out_height)
+
+                // end index
+                if (output_raw_idx + k_size - start_crop > out_height)
                 {
-                    ce = out_height;
+                    output_end_idx = out_height;
                 }
                 else
                 {
-                    ce = ts + k_size - start_crop;
+                    output_end_idx = output_raw_idx + k_size - start_crop;
                 }
-                ks = cs - (ts - start_crop);
-                for (size_t i = 0; i < ce - cs; ++i)
+
+                kernel_offset = output_start_idx - (output_raw_idx - start_crop);
+
+                // convolution
+                for (size_t i = 0; i < output_end_idx - output_start_idx; ++i)
                 {
-                    output->array[(i + cs) * n_filters + f] +=
-                        kernel->array[(i + ks) * ker_dim12 + f * n_channels + ch] *
+                    output->array[(i + output_start_idx) * n_filters + f] +=
+                        kernel->array[(i + kernel_offset) * ker_dim12 + f * n_channels + ch] *
                         input->array[t * n_channels + ch];
                 }
             }
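The bounds arithmetic introduced by this refactor can be exercised in isolation. The sketch below is not part of the PR; it copies the per-time-step expressions into a standalone helper and uses assumed example values for `stride`, `k_size`, `start_crop`, and `out_height`:

```c
#include <stdio.h>

/* Illustrative only: per-time-step bounds from the refactored
 * k2c_conv1d_transpose loop, with plain scalars instead of tensors. */
static void transpose_bounds(size_t t, size_t stride, size_t k_size,
                             size_t start_crop, size_t out_height)
{
    size_t output_raw_idx = t * stride;  /* uncropped output position */

    /* start index: clamp the cropped start to 0 */
    size_t output_start_idx =
        (output_raw_idx > start_crop) ? output_raw_idx - start_crop : 0;

    /* end index: clamp the cropped end to the output length */
    size_t output_end_idx =
        (output_raw_idx + k_size - start_crop > out_height)
            ? out_height
            : output_raw_idx + k_size - start_crop;

    /* kernel offset: number of kernel taps cut off by the start crop.
     * When output_raw_idx < start_crop, the unsigned wrap-around in
     * (output_raw_idx - start_crop) cancels out and this evaluates to
     * start_crop - output_raw_idx, exactly as in the PR code. */
    size_t kernel_offset = output_start_idx - (output_raw_idx - start_crop);

    printf("t=%zu -> writes output[%zu, %zu), kernel starts at tap %zu\n",
           t, output_start_idx, output_end_idx, kernel_offset);
}

int main(void)
{
    /* Assumed example: kernel size 3, stride 2, start crop 1, output length 9. */
    for (size_t t = 0; t < 5; ++t)
        transpose_bounds(t, 2, 3, 1, 9);
    return 0;
}
```

With these values, the first iteration shows the start clamp (a nonzero kernel offset) and the last shows the end clamp (a shortened write range), while the middle iterations write the full kernel width.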
@@ -71,3 +79,152 @@ void k2c_conv1d_transpose(k2c_tensor *output, const k2c_tensor *input,
     k2c_bias_add(output, bias);
     activation(output->array, output->numel);
 }
+
+/**
+ * 2D Transposed Convolution (Deconvolution).
+ * Assumes a "channels last" structure.
+ *
+ * :param output: output tensor.
+ * :param input: input tensor.
+ * :param kernel: kernel tensor.
+ * :param bias: bias tensor.
+ * :param stride: array[2] {stride_height, stride_width}.
+ * :param dilation: array[2] {dilation_height, dilation_width}.
+ *                  (Note: Logic below assumes dilation is 1 for the optimized bounds check.)
+ * :param padding: array[2] {crop_top, crop_left}.
+ *                 Amount to crop from the output (inverse of padding).
+ * :param activation: activation function to apply to output.
+ */
+void k2c_conv2d_transpose(k2c_tensor *output, const k2c_tensor *input,
+                          const k2c_tensor *kernel, const k2c_tensor *bias,
+                          const size_t *stride, const size_t *dilation,
+                          const size_t *padding, k2c_activationType *activation)
+{
+    // Initialize output memory to zero
+    memset(output->array, 0, output->numel * sizeof(output->array[0]));
+
+    // --- Dimensions ---
+    const size_t in_rows = input->shape[0];
+    const size_t in_cols = input->shape[1];
+    const size_t in_channels = input->shape[2];
+
+    // Kernel Shape: {Rows, Cols, InChannels, OutChannels} based on reference
+    const size_t k_rows = kernel->shape[0];
+    const size_t k_cols = kernel->shape[1];
+    const size_t n_filters = kernel->shape[3];
+
+    const size_t out_rows = output->shape[0];
+    const size_t out_cols = output->shape[1];
+
+    // Access strides/padding from arrays
+    const size_t stride_h = stride[0];
+    const size_t stride_w = stride[1];
+    const size_t crop_h = padding[0];
+    const size_t crop_w = padding[1];
+
+    // Pre-calculate dimensional steps for Kernel
+    // Kernel index math: z0 * (cols*in*out) + z1 * (in*out) + q * (out) + k
+    // Note: This matches the "Out-Channel Last" memory layout of the reference.
Suggested change (review comment on the layout note above):

-    // Note: This matches the "Out-Channel Last" memory layout of the reference.
+    // Note: This matches a kernel layout of (rows, cols, in_channels, out_channels),
+    //       i.e., with output channels in the last dimension.
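The kernel layout described in that note (and clarified by the suggestion) can be illustrated with a short standalone sketch. It is not part of the PR, and the shape values are assumed for the example:

```c
#include <stdio.h>

/* Illustrative only: flat index into a kernel stored as
 * (rows, cols, in_channels, out_channels), i.e., output channels last,
 * matching the index math in the PR comment:
 *   z0 * (cols*in*out) + z1 * (in*out) + q * (out) + k */
static size_t kernel_flat_index(size_t z0, size_t z1, size_t q, size_t k,
                                size_t k_cols, size_t in_channels,
                                size_t n_filters)
{
    return z0 * (k_cols * in_channels * n_filters)
         + z1 * (in_channels * n_filters)
         + q  * n_filters
         + k;
}

int main(void)
{
    /* Assumed example shape: 3x3 kernel, 2 input channels, 4 filters. */
    const size_t k_cols = 3, in_channels = 2, n_filters = 4;

    /* Stepping k (the output channel / filter) yields consecutive flat
     * indices, confirming that the filter dimension varies fastest. */
    for (size_t k = 0; k < n_filters; ++k)
        printf("kernel[1][2][0][%zu] -> flat index %zu\n",
               k, kernel_flat_index(1, 2, 0, k, k_cols, in_channels, n_filters));
    return 0;
}
```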
Review comment on lines 31-35: the variable names are being renamed for clarity (e.g., cs to output_start_idx), but the trailing comments show only the old names in a format that reads as if they were being defined (e.g., "// cs"), which is ambiguous. Consider a format like "// was: cs" or "// renamed from: cs" to make it clear these are the old names being replaced.
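Applied to the declarations in this diff, the suggested format would look roughly like the following (illustrative only, not a committed change):

```c
// changed some names for refactor clarity
size_t output_start_idx = 0; // was: cs
size_t output_end_idx = 0;   // was: ce
size_t output_raw_idx = 0;   // was: ts
size_t kernel_offset = 0;    // was: ks
```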