Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
39 changes: 35 additions & 4 deletions cuda_bindings/build_hooks.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,13 +78,32 @@ def __init__(self, name, members):
self._name = name
self._member_names = []
self._member_types = []
self._member_declarators = []
for var_name, var_type, _ in members:
var_type = var_type[0]
var_type = var_type.removeprefix("struct ")
var_type = var_type.removeprefix("union ")
base_type = var_type[0]
base_type = base_type.removeprefix("struct ")
base_type = base_type.removeprefix("union ")

self._member_names += [var_name]
self._member_types += [var_type]
self._member_types += [base_type]
self._member_declarators += [tuple(var_type[1:])]

def member_type(self, member_name):
try:
return self._member_types[self._member_names.index(member_name)]
except ValueError:
return None

def member_array_length(self, member_name):
try:
declarators = self._member_declarators[self._member_names.index(member_name)]
except ValueError:
return None

for declarator in declarators:
if isinstance(declarator, list) and len(declarator) == 1:
return declarator[0]
return None

def discoverMembers(self, memberDict, prefix, seen=None):
if seen is None:
Expand Down Expand Up @@ -161,6 +180,9 @@ def _parse_headers(header_dict, include_path_list, parser_caching):
# Since we only support 64 bit architectures, we can inline the sizeof(T*) to 8 and then compute the
# result in Python. The arithmetic expression is preserved to help with clarity and understanding
r"char reserved\[52 - sizeof\(CUcheckpointGpuPair \*\)\];": rf"char reserved[{52 - 8}];",
r"char reserved\[64 - sizeof\(CUcheckpointGpuPair \*\) - sizeof\(unsigned int\)\];": (
rf"char reserved[{64 - 8 - 4}];"
),
}

print(f'Parsing headers in "{include_path_list}" (Caching = {parser_caching})', flush=True)
Expand Down Expand Up @@ -310,6 +332,13 @@ def _build_cuda_bindings(strip=False):
found_types, found_functions, found_values, found_struct, struct_list = _parse_headers(
header_dict, include_path_list, parser_caching
)
struct_field_types = {}
struct_field_array_lengths = {}
for struct_name, struct in struct_list.items():
for member_name in struct._member_names:
key = f"{struct_name}.{member_name}"
struct_field_types[key] = struct.member_type(member_name)
struct_field_array_lengths[key] = struct.member_array_length(member_name)

# Generate code from .in templates
path_list = [
Expand All @@ -332,6 +361,8 @@ def _build_cuda_bindings(strip=False):
"found_values": found_values,
"found_struct": found_struct,
"struct_list": struct_list,
"struct_field_types": struct_field_types,
"struct_field_array_lengths": struct_field_array_lengths,
"os": os,
"sys": sys,
"platform": platform,
Expand Down
2 changes: 1 addition & 1 deletion cuda_bindings/cuda/bindings/_bindings/cydriver.pxd.in
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
# SPDX-FileCopyrightText: Copyright (c) 2021-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: LicenseRef-NVIDIA-SOFTWARE-LICENSE

# This code was automatically generated with version 12.9.0, generator version 49a8141. Do not modify it directly.
# This code was automatically generated with version 12.9.0, generator version 0.3.1.dev1711+g875fec45. Do not modify it directly.
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I wonder why this wasn't updated in the previous refresh...

from cuda.bindings.cydriver cimport *

{{if 'cuGetErrorString' in found_functions}}
Expand Down
2 changes: 1 addition & 1 deletion cuda_bindings/cuda/bindings/_bindings/cydriver.pyx.in
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
# SPDX-FileCopyrightText: Copyright (c) 2021-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: LicenseRef-NVIDIA-SOFTWARE-LICENSE

# This code was automatically generated with version 12.9.0, generator version 0.3.1.dev1630+gadce055ea.d20260422. Do not modify it directly.
# This code was automatically generated with version 12.9.0, generator version 0.3.1.dev1711+g875fec45. Do not modify it directly.
{{if 'Windows' == platform.system()}}
import os
cimport cuda.bindings._lib.windll as windll
Expand Down
9 changes: 7 additions & 2 deletions cuda_bindings/cuda/bindings/cydriver.pxd.in
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
# SPDX-FileCopyrightText: Copyright (c) 2021-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: LicenseRef-NVIDIA-SOFTWARE-LICENSE

# This code was automatically generated with version 12.9.0, generator version 49a8141. Do not modify it directly.
# This code was automatically generated with version 12.9.0, generator version 0.3.1.dev1711+g875fec45. Do not modify it directly.

from libc.stdint cimport uint32_t, uint64_t

Expand Down Expand Up @@ -2311,7 +2311,12 @@ cdef extern from "cuda.h":
ctypedef CUcheckpointCheckpointArgs_st CUcheckpointCheckpointArgs

cdef struct CUcheckpointRestoreArgs_st:
cuuint64_t reserved[8]
{{if struct_field_types.get('CUcheckpointRestoreArgs_st.reserved') == 'char'}}
char reserved[{{struct_field_array_lengths['CUcheckpointRestoreArgs_st.reserved']}}]
{{endif}}
{{if struct_field_types.get('CUcheckpointRestoreArgs_st.reserved') == 'cuuint64_t'}}
cuuint64_t reserved[{{struct_field_array_lengths['CUcheckpointRestoreArgs_st.reserved']}}]
{{endif}}
Comment on lines 2313 to +2319
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Same concern as in here: #2144 (comment)


ctypedef CUcheckpointRestoreArgs_st CUcheckpointRestoreArgs

Expand Down
2 changes: 1 addition & 1 deletion cuda_bindings/cuda/bindings/cydriver.pyx.in
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
# SPDX-FileCopyrightText: Copyright (c) 2021-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: LicenseRef-NVIDIA-SOFTWARE-LICENSE

# This code was automatically generated with version 12.9.0, generator version 49a8141. Do not modify it directly.
# This code was automatically generated with version 12.9.0, generator version 0.3.1.dev1711+g875fec45. Do not modify it directly.
cimport cuda.bindings._bindings.cydriver as cydriver

{{if 'cuGetErrorString' in found_functions}}
Expand Down
14 changes: 11 additions & 3 deletions cuda_bindings/cuda/bindings/driver.pxd.in
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
# SPDX-FileCopyrightText: Copyright (c) 2021-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: LicenseRef-NVIDIA-SOFTWARE-LICENSE

# This code was automatically generated with version 12.9.0, generator version 0.3.1.dev1588+g61faef43a. Do not modify it directly.
# This code was automatically generated with version 12.9.0, generator version 0.3.1.dev1711+g875fec45. Do not modify it directly.
cimport cuda.bindings.cydriver as cydriver

include "_lib/utils.pxd"
Expand Down Expand Up @@ -5097,7 +5097,11 @@ cdef class CUcheckpointRestoreArgs_st:

Attributes
----------
{{if 'CUcheckpointRestoreArgs_st.reserved' in found_struct}}
{{if struct_field_types.get('CUcheckpointRestoreArgs_st.reserved') == 'char'}}
reserved : bytes
Reserved for future use, must be zeroed
{{endif}}
{{if struct_field_types.get('CUcheckpointRestoreArgs_st.reserved') == 'cuuint64_t'}}
reserved : list[cuuint64_t]
Reserved for future use, must be zeroed
{{endif}}
Expand Down Expand Up @@ -10560,7 +10564,11 @@ cdef class CUcheckpointRestoreArgs(CUcheckpointRestoreArgs_st):

Attributes
----------
{{if 'CUcheckpointRestoreArgs_st.reserved' in found_struct}}
{{if struct_field_types.get('CUcheckpointRestoreArgs_st.reserved') == 'char'}}
reserved : bytes
Reserved for future use, must be zeroed
{{endif}}
{{if struct_field_types.get('CUcheckpointRestoreArgs_st.reserved') == 'cuuint64_t'}}
reserved : list[cuuint64_t]
Reserved for future use, must be zeroed
{{endif}}
Expand Down
Loading
Loading