# Copyright 2025 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import os
from contextlib import contextmanager, nullcontext
from typing import Dict, List, Optional, Set, Tuple, Union

import safetensors.torch
import torch

from ..utils import get_logger, is_accelerate_available
from .hooks import HookRegistry, ModelHook


if is_accelerate_available():
    from accelerate.hooks import AlignDevicesHook, CpuOffload
    from accelerate.utils import send_to_device


logger = get_logger(__name__)  # pylint: disable=invalid-name


# fmt: off
_GROUP_OFFLOADING = "group_offloading"
_LAYER_EXECUTION_TRACKER = "layer_execution_tracker"
_LAZY_PREFETCH_GROUP_OFFLOADING = "lazy_prefetch_group_offloading"

_SUPPORTED_PYTORCH_LAYERS = (
    torch.nn.Conv1d, torch.nn.Conv2d, torch.nn.Conv3d,
    torch.nn.ConvTranspose1d, torch.nn.ConvTranspose2d, torch.nn.ConvTranspose3d,
    torch.nn.Linear,
    # TODO(aryan): look into torch.nn.LayerNorm, torch.nn.GroupNorm later, seems to be causing some issues with CogVideoX
    # because of double invocation of the same norm layer in CogVideoXLayerNorm
)
# fmt: on


class ModuleGroup:
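    r"""
    A group of modules (plus any stray parameters and buffers) that is offloaded and onloaded together. The group
    tracks its offload/onload devices, the leader modules that trigger the transfers, an optional CUDA/XPU stream for
    asynchronous copies, and, when `offload_to_disk_path` is set, a safetensors file used as the offload target
    instead of CPU RAM.
    """
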
    def __init__(
        self,
        modules: List[torch.nn.Module],
        offload_device: torch.device,
        onload_device: torch.device,
        offload_leader: torch.nn.Module,
        onload_leader: Optional[torch.nn.Module] = None,
        parameters: Optional[List[torch.nn.Parameter]] = None,
        buffers: Optional[List[torch.Tensor]] = None,
        non_blocking: bool = False,
        stream: Union[torch.cuda.Stream, torch.Stream, None] = None,
        record_stream: Optional[bool] = False,
        low_cpu_mem_usage: bool = False,
        onload_self: bool = True,
        offload_to_disk_path: Optional[str] = None,
    ) -> None:
        self.modules = modules
        self.offload_device = offload_device
        self.onload_device = onload_device
        self.offload_leader = offload_leader
        self.onload_leader = onload_leader
        self.parameters = parameters or []
        self.buffers = buffers or []
        self.non_blocking = non_blocking or stream is not None
        self.stream = stream
        self.record_stream = record_stream
        self.onload_self = onload_self
        self.low_cpu_mem_usage = low_cpu_mem_usage

        self.offload_to_disk_path = offload_to_disk_path
        self._is_offloaded_to_disk = False

        if self.offload_to_disk_path:
            self.safetensors_file_path = os.path.join(self.offload_to_disk_path, f"group_{id(self)}.safetensors")

            all_tensors = []
            for module in self.modules:
                all_tensors.extend(list(module.parameters()))
                all_tensors.extend(list(module.buffers()))
            all_tensors.extend(self.parameters)
            all_tensors.extend(self.buffers)
            all_tensors = list(dict.fromkeys(all_tensors))  # Remove duplicates

            self.tensor_to_key = {tensor: f"tensor_{i}" for i, tensor in enumerate(all_tensors)}
            self.key_to_tensor = {v: k for k, v in self.tensor_to_key.items()}
            self.cpu_param_dict = {}
        else:
            self.cpu_param_dict = self._init_cpu_param_dict()

    def _init_cpu_param_dict(self):
        cpu_param_dict = {}
        if self.stream is None:
            return cpu_param_dict

        for module in self.modules:
            for param in module.parameters():
                cpu_param_dict[param] = param.data.cpu() if self.low_cpu_mem_usage else param.data.cpu().pin_memory()
            for buffer in module.buffers():
                cpu_param_dict[buffer] = (
                    buffer.data.cpu() if self.low_cpu_mem_usage else buffer.data.cpu().pin_memory()
                )

        for param in self.parameters:
            cpu_param_dict[param] = param.data.cpu() if self.low_cpu_mem_usage else param.data.cpu().pin_memory()

        for buffer in self.buffers:
            cpu_param_dict[buffer] = buffer.data.cpu() if self.low_cpu_mem_usage else buffer.data.cpu().pin_memory()

        return cpu_param_dict

    @contextmanager
    def _pinned_memory_tensors(self):
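        r"""
        Context manager that yields a mapping from each tensor in `cpu_param_dict` to a pinned-memory copy (pinning
        on the fly when `low_cpu_mem_usage=True` left the CPU copies unpinned), so that host-to-device copies can be
        issued asynchronously.
        """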
        pinned_dict = {}
        try:
            for param, tensor in self.cpu_param_dict.items():
                if not tensor.is_pinned():
                    pinned_dict[param] = tensor.pin_memory()
                else:
                    pinned_dict[param] = tensor

            yield pinned_dict

        finally:
            pinned_dict = None

    @torch.compiler.disable()
    def onload_(self):
        r"""Onloads the group of modules to the onload_device."""
        torch_accelerator_module = (
            getattr(torch, torch.accelerator.current_accelerator().type)
            if hasattr(torch, "accelerator")
            else torch.cuda
        )
        context = nullcontext() if self.stream is None else torch_accelerator_module.stream(self.stream)
        current_stream = torch_accelerator_module.current_stream() if self.record_stream else None

        if self.offload_to_disk_path:
            if self.stream is not None:
                # Wait for previous Host->Device transfer to complete
                self.stream.synchronize()

            with context:
                if self.stream is not None:
                    # Load to CPU, pin, and async copy to device for overlapping transfer and compute
                    loaded_cpu_tensors = safetensors.torch.load_file(self.safetensors_file_path, device="cpu")
                    for key, tensor_obj in self.key_to_tensor.items():
                        pinned_tensor = loaded_cpu_tensors[key].pin_memory()
                        tensor_obj.data = pinned_tensor.to(self.onload_device, non_blocking=self.non_blocking)
                        if self.record_stream:
                            tensor_obj.data.record_stream(current_stream)
                else:
                    # Load directly to the target device (synchronous)
                    onload_device = (
                        self.onload_device.type if isinstance(self.onload_device, torch.device) else self.onload_device
                    )
                    loaded_tensors = safetensors.torch.load_file(self.safetensors_file_path, device=onload_device)
                    for key, tensor_obj in self.key_to_tensor.items():
                        tensor_obj.data = loaded_tensors[key]
            return

        if self.stream is not None:
            # Wait for previous Host->Device transfer to complete
            self.stream.synchronize()

        with context:
            if self.stream is not None:
                with self._pinned_memory_tensors() as pinned_memory:
                    for group_module in self.modules:
                        for param in group_module.parameters():
                            param.data = pinned_memory[param].to(self.onload_device, non_blocking=self.non_blocking)
                            if self.record_stream:
                                param.data.record_stream(current_stream)
                        for buffer in group_module.buffers():
                            buffer.data = pinned_memory[buffer].to(self.onload_device, non_blocking=self.non_blocking)
                            if self.record_stream:
                                buffer.data.record_stream(current_stream)

                    for param in self.parameters:
                        param.data = pinned_memory[param].to(self.onload_device, non_blocking=self.non_blocking)
                        if self.record_stream:
                            param.data.record_stream(current_stream)

                    for buffer in self.buffers:
                        buffer.data = pinned_memory[buffer].to(self.onload_device, non_blocking=self.non_blocking)
                        if self.record_stream:
                            buffer.data.record_stream(current_stream)

            else:
                for group_module in self.modules:
                    for param in group_module.parameters():
                        param.data = param.data.to(self.onload_device, non_blocking=self.non_blocking)
                    for buffer in group_module.buffers():
                        buffer.data = buffer.data.to(self.onload_device, non_blocking=self.non_blocking)

                for param in self.parameters:
                    param.data = param.data.to(self.onload_device, non_blocking=self.non_blocking)

                for buffer in self.buffers:
                    buffer.data = buffer.data.to(self.onload_device, non_blocking=self.non_blocking)
                    if self.record_stream:
                        buffer.data.record_stream(current_stream)

    @torch.compiler.disable()
    def offload_(self):
        r"""Offloads the group of modules to the offload_device."""
        if self.offload_to_disk_path:
            # TODO: we can potentially optimize this code path by checking if _all_ the desired safetensors files
            # exist on the disk and, if so, skip this step entirely, reducing IO overhead. Currently, we just check
            # if the given `safetensors_file_path` exists and if not we perform a write.
            # Check if the file has been saved in this session or if it already exists on disk.
            if not self._is_offloaded_to_disk and not os.path.exists(self.safetensors_file_path):
                os.makedirs(os.path.dirname(self.safetensors_file_path), exist_ok=True)
                tensors_to_save = {
                    key: tensor.data.to(self.offload_device) for tensor, key in self.tensor_to_key.items()
                }
                safetensors.torch.save_file(tensors_to_save, self.safetensors_file_path)

            # The group is now considered offloaded to disk for the rest of the session.
            self._is_offloaded_to_disk = True

            # We do this to free up the RAM which is still holding on to the tensor data.
            for tensor_obj in self.tensor_to_key.keys():
                tensor_obj.data = torch.empty_like(tensor_obj.data, device=self.offload_device)
            return

        torch_accelerator_module = (
            getattr(torch, torch.accelerator.current_accelerator().type)
            if hasattr(torch, "accelerator")
            else torch.cuda
        )
        if self.stream is not None:
            if not self.record_stream:
                torch_accelerator_module.current_stream().synchronize()
            for group_module in self.modules:
                for param in group_module.parameters():
                    param.data = self.cpu_param_dict[param]
            for param in self.parameters:
                param.data = self.cpu_param_dict[param]
            for buffer in self.buffers:
                buffer.data = self.cpu_param_dict[buffer]

        else:
            for group_module in self.modules:
                group_module.to(self.offload_device, non_blocking=self.non_blocking)
            for param in self.parameters:
                param.data = param.data.to(self.offload_device, non_blocking=self.non_blocking)
            for buffer in self.buffers:
                buffer.data = buffer.data.to(self.offload_device, non_blocking=self.non_blocking)


class GroupOffloadingHook(ModelHook):
    r"""
    A hook that offloads groups of torch.nn.Module to the CPU for storage and onloads to accelerator device for
    computation. Each group has one "onload leader" module that is responsible for onloading, and an "offload leader"
    module that is responsible for offloading. If prefetching is enabled, the onload leader of the previous module
    group is responsible for onloading the current module group.
    """

    _is_stateful = False

    def __init__(self, group: ModuleGroup, next_group: Optional[ModuleGroup] = None) -> None:
        self.group = group
        self.next_group = next_group

    def initialize_hook(self, module: torch.nn.Module) -> torch.nn.Module:
        if self.group.offload_leader == module:
            self.group.offload_()
        return module

    def pre_forward(self, module: torch.nn.Module, *args, **kwargs):
        # If there wasn't an onload_leader assigned, we assume that the submodule that first called its forward
        # method is the onload_leader of the group.
        if self.group.onload_leader is None:
            self.group.onload_leader = module

        # If the current module is the onload_leader of the group, we onload the group if it is supposed
        # to onload itself. In the case of using prefetching with streams, we onload the next group if
        # it is not supposed to onload itself.
        if self.group.onload_leader == module:
            if self.group.onload_self:
                self.group.onload_()
            if self.next_group is not None and not self.next_group.onload_self:
                self.next_group.onload_()

        args = send_to_device(args, self.group.onload_device, non_blocking=self.group.non_blocking)
        kwargs = send_to_device(kwargs, self.group.onload_device, non_blocking=self.group.non_blocking)
        return args, kwargs

    def post_forward(self, module: torch.nn.Module, output):
        if self.group.offload_leader == module:
            self.group.offload_()
        return output


class LazyPrefetchGroupOffloadingHook(ModelHook):
    r"""
    A hook, used in conjunction with GroupOffloadingHook, that applies lazy prefetching to groups of torch.nn.Module.
    This hook is used to determine the order in which the layers are executed during the forward pass. Once the layer
    invocation order is known, assignments of the next_group attribute for prefetching can be made, which allows
    prefetching groups in the correct order.
    """

    _is_stateful = False

    def __init__(self):
        self.execution_order: List[Tuple[str, torch.nn.Module]] = []
        self._layer_execution_tracker_module_names = set()

    def initialize_hook(self, module):
        def make_execution_order_update_callback(current_name, current_submodule):
            def callback():
                logger.debug(f"Adding {current_name} to the execution order")
                self.execution_order.append((current_name, current_submodule))

            return callback

        # To every submodule that contains a group offloading hook (at this point, no prefetching is enabled for any
        # of the groups), we add a layer execution tracker hook that will be used to determine the order in which the
        # layers are executed during the forward pass.
        for name, submodule in module.named_modules():
            if name == "" or not hasattr(submodule, "_diffusers_hook"):
                continue

            registry = HookRegistry.check_if_exists_or_initialize(submodule)
            group_offloading_hook = registry.get_hook(_GROUP_OFFLOADING)

            if group_offloading_hook is not None:
                # For the first forward pass, we have to load in a blocking manner
                group_offloading_hook.group.non_blocking = False
                layer_tracker_hook = LayerExecutionTrackerHook(make_execution_order_update_callback(name, submodule))
                registry.register_hook(layer_tracker_hook, _LAYER_EXECUTION_TRACKER)
                self._layer_execution_tracker_module_names.add(name)

        return module

    def post_forward(self, module, output):
        # At this point, for the current module's submodules, we know the execution order of the layers. We can now
        # remove the layer execution tracker hooks and apply prefetching by setting the next_group attribute for each
        # group offloading hook.
        num_executed = len(self.execution_order)
        execution_order_module_names = {name for name, _ in self.execution_order}

        # It may be possible that some layers were not executed during the forward pass. This can happen if the layer
        # is not used in the forward pass, or if the layer is not executed due to some other reason. In such cases, we
        # may not be able to apply prefetching in the correct order, which can lead to device-mismatch related errors
        # if the missing layers end up being executed in the future.
        if execution_order_module_names != self._layer_execution_tracker_module_names:
            unexecuted_layers = list(self._layer_execution_tracker_module_names - execution_order_module_names)
            logger.warning(
                "It seems like some layers were not executed during the forward pass. This may lead to problems when "
                "applying lazy prefetching with automatic tracing and lead to device-mismatch related errors. Please "
                "make sure that all layers are executed during the forward pass. The following layers were not executed:\n"
                f"{unexecuted_layers=}"
            )

        # Remove the layer execution tracker hooks from the submodules
        base_module_registry = module._diffusers_hook
        registries = [submodule._diffusers_hook for _, submodule in self.execution_order]
        group_offloading_hooks = [registry.get_hook(_GROUP_OFFLOADING) for registry in registries]

        for i in range(num_executed):
            registries[i].remove_hook(_LAYER_EXECUTION_TRACKER, recurse=False)

        # Remove the current lazy prefetch group offloading hook so that it doesn't interfere with the next forward pass
        base_module_registry.remove_hook(_LAZY_PREFETCH_GROUP_OFFLOADING, recurse=False)

        # LazyPrefetchGroupOffloadingHook is only used with streams, so we know that non_blocking should be True.
        # We disable non_blocking for the first forward pass, but need to enable it for the subsequent passes to
        # see the benefits of prefetching.
        for hook in group_offloading_hooks:
            hook.group.non_blocking = True

        # Set required attributes for prefetching
        if num_executed > 0:
            base_module_group_offloading_hook = base_module_registry.get_hook(_GROUP_OFFLOADING)
            base_module_group_offloading_hook.next_group = group_offloading_hooks[0].group
            base_module_group_offloading_hook.next_group.onload_self = False

        for i in range(num_executed - 1):
            name1, _ = self.execution_order[i]
            name2, _ = self.execution_order[i + 1]
            logger.debug(f"Applying lazy prefetch group offloading from {name1} to {name2}")
            group_offloading_hooks[i].next_group = group_offloading_hooks[i + 1].group
            group_offloading_hooks[i].next_group.onload_self = False

        return output


class LayerExecutionTrackerHook(ModelHook):
    r"""
    A hook that tracks the order in which the layers are executed during the forward pass by calling back to the
    LazyPrefetchGroupOffloadingHook to update the execution order.
    """

    _is_stateful = False

    def __init__(self, execution_order_update_callback):
        self.execution_order_update_callback = execution_order_update_callback

    def pre_forward(self, module, *args, **kwargs):
        self.execution_order_update_callback()
        return args, kwargs


def apply_group_offloading(
    module: torch.nn.Module,
    onload_device: torch.device,
    offload_device: torch.device = torch.device("cpu"),
    offload_type: str = "block_level",
    num_blocks_per_group: Optional[int] = None,
    non_blocking: bool = False,
    use_stream: bool = False,
    record_stream: bool = False,
    low_cpu_mem_usage: bool = False,
    offload_to_disk_path: Optional[str] = None,
) -> None:
    r"""
    Applies group offloading to the internal layers of a torch.nn.Module. To understand what group offloading is, and
    where it is beneficial, we need to first provide some context on how other supported offloading methods work.

    Typically, offloading is done at two levels:
    - Module-level: In Diffusers, this can be enabled using the `ModelMixin::enable_model_cpu_offload()` method. It
      works by offloading each component of a pipeline to the CPU for storage, and onloading to the accelerator device
      when needed for computation. This method is more memory-efficient than keeping all components on the
      accelerator, but the memory requirements are still quite high. For this method to work, one needs memory
      equivalent to the size of the model in runtime dtype + the size of the largest intermediate activation tensors
      to be able to complete the forward pass.
    - Leaf-level: In Diffusers, this can be enabled using the `ModelMixin::enable_sequential_cpu_offload()` method. It
      works by offloading the lowest leaf-level parameters of the computation graph to the CPU for storage, and
      onloading only the leafs to the accelerator device for computation. This uses the lowest amount of accelerator
      memory, but can be slower due to the excessive number of device synchronizations.

    Group offloading is a middle ground between the two methods. It works by offloading groups of internal layers
    (either `torch.nn.ModuleList` or `torch.nn.Sequential`). This method uses lower memory than module-level
    offloading. It is also faster than leaf-level/sequential offloading, as the number of device synchronizations is
    reduced.

    Another supported feature (for CUDA devices with support for asynchronous data transfer streams) is the ability to
    overlap data transfer and computation to reduce the overall execution time compared to sequential offloading. This
    is enabled using layer prefetching with streams, i.e., the layer that is to be executed next starts onloading to
    the accelerator device while the current layer is being executed - this increases the memory requirements
    slightly. Note that this implementation also supports leaf-level offloading, which can be made much faster when
    using streams.

    Args:
        module (`torch.nn.Module`):
            The module to which group offloading is applied.
        onload_device (`torch.device`):
            The device to which the group of modules are onloaded.
        offload_device (`torch.device`, defaults to `torch.device("cpu")`):
            The device to which the group of modules are offloaded. This should typically be the CPU. Default is CPU.
        offload_type (`str`, defaults to "block_level"):
            The type of offloading to be applied. Can be one of "block_level" or "leaf_level". Default is
            "block_level".
        offload_to_disk_path (`str`, *optional*, defaults to `None`):
            The path to the directory where parameters will be offloaded. Setting this option can be useful in limited
            RAM environment settings where a reasonable speed-memory trade-off is desired.
        num_blocks_per_group (`int`, *optional*):
            The number of blocks per group when using offload_type="block_level". This is required when using
            offload_type="block_level".
        non_blocking (`bool`, defaults to `False`):
            If True, offloading and onloading is done with non-blocking data transfer.
        use_stream (`bool`, defaults to `False`):
            If True, offloading and onloading is done asynchronously using a CUDA stream. This can be useful for
            overlapping computation and data transfer.
        record_stream (`bool`, defaults to `False`): When enabled with `use_stream`, it marks the current tensor
            as having been used by this stream. It is faster at the expense of slightly more memory usage. Refer to
            the [PyTorch official docs](https://pytorch.org/docs/stable/generated/torch.Tensor.record_stream.html) for
            more details.
        low_cpu_mem_usage (`bool`, defaults to `False`):
            If True, the CPU memory usage is minimized by pinning tensors on-the-fly instead of pre-pinning them. This
            option only matters when using streamed CPU offloading (i.e. `use_stream=True`). This can be useful when
            the CPU memory is a bottleneck but may counteract the benefits of using streams.

    Example:
    ```python
    >>> import torch
    >>> from diffusers import CogVideoXTransformer3DModel
    >>> from diffusers.hooks import apply_group_offloading

    >>> transformer = CogVideoXTransformer3DModel.from_pretrained(
    ...     "THUDM/CogVideoX-5b", subfolder="transformer", torch_dtype=torch.bfloat16
    ... )

    >>> apply_group_offloading(
    ...     transformer,
    ...     onload_device=torch.device("cuda"),
    ...     offload_device=torch.device("cpu"),
    ...     offload_type="block_level",
    ...     num_blocks_per_group=2,
    ...     use_stream=True,
    ... )
    ```
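
    A leaf-level variant that also offloads parameters to disk might look like the following (illustrative sketch;
    the offload directory is a placeholder):

    ```python
    >>> apply_group_offloading(
    ...     transformer,
    ...     onload_device=torch.device("cuda"),
    ...     offload_device=torch.device("cpu"),
    ...     offload_type="leaf_level",
    ...     use_stream=True,
    ...     offload_to_disk_path="/tmp/group_offload",  # hypothetical directory for the safetensors files
    ... )
    ```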
    """

    stream = None
    if use_stream:
        if torch.cuda.is_available():
            stream = torch.cuda.Stream()
        elif hasattr(torch, "xpu") and torch.xpu.is_available():
            stream = torch.Stream()
        else:
            raise ValueError("Using streams for data transfer requires a CUDA device, or an Intel XPU device.")

    if not use_stream and record_stream:
        raise ValueError("`record_stream` cannot be True when `use_stream=False`.")

    _raise_error_if_accelerate_model_or_sequential_hook_present(module)

    if offload_type == "block_level":
        if num_blocks_per_group is None:
            raise ValueError("num_blocks_per_group must be provided when using offload_type='block_level'.")

        _apply_group_offloading_block_level(
            module=module,
            num_blocks_per_group=num_blocks_per_group,
            offload_device=offload_device,
            onload_device=onload_device,
            offload_to_disk_path=offload_to_disk_path,
            non_blocking=non_blocking,
            stream=stream,
            record_stream=record_stream,
            low_cpu_mem_usage=low_cpu_mem_usage,
        )
    elif offload_type == "leaf_level":
        _apply_group_offloading_leaf_level(
            module=module,
            offload_device=offload_device,
            onload_device=onload_device,
            offload_to_disk_path=offload_to_disk_path,
            non_blocking=non_blocking,
            stream=stream,
            record_stream=record_stream,
            low_cpu_mem_usage=low_cpu_mem_usage,
        )
    else:
        raise ValueError(f"Unsupported offload_type: {offload_type}")


def _apply_group_offloading_block_level(
    module: torch.nn.Module,
    num_blocks_per_group: int,
    offload_device: torch.device,
    onload_device: torch.device,
    non_blocking: bool,
    stream: Union[torch.cuda.Stream, torch.Stream, None] = None,
    record_stream: Optional[bool] = False,
    low_cpu_mem_usage: bool = False,
    offload_to_disk_path: Optional[str] = None,
) -> None:
    r"""
    This function applies offloading to groups of torch.nn.ModuleList or torch.nn.Sequential blocks. In comparison to
    the "leaf_level" offloading, which is more fine-grained, this offloading is done at the top-level blocks.

    Args:
        module (`torch.nn.Module`):
            The module to which group offloading is applied.
        offload_device (`torch.device`):
            The device to which the group of modules are offloaded. This should typically be the CPU.
        offload_to_disk_path (`str`, *optional*, defaults to `None`):
            The path to the directory where parameters will be offloaded. Setting this option can be useful in limited
            RAM environment settings where a reasonable speed-memory trade-off is desired.
        onload_device (`torch.device`):
            The device to which the group of modules are onloaded.
        non_blocking (`bool`):
            If True, offloading and onloading is done asynchronously. This can be useful for overlapping computation
            and data transfer.
        stream (`torch.cuda.Stream` or `torch.Stream`, *optional*):
            If provided, offloading and onloading is done asynchronously using the provided stream. This can be useful
            for overlapping computation and data transfer.
        record_stream (`bool`, defaults to `False`): When enabled with `use_stream`, it marks the current tensor
            as having been used by this stream. It is faster at the expense of slightly more memory usage. Refer to
            the [PyTorch official docs](https://pytorch.org/docs/stable/generated/torch.Tensor.record_stream.html) for
            more details.
        low_cpu_mem_usage (`bool`, defaults to `False`):
            If True, the CPU memory usage is minimized by pinning tensors on-the-fly instead of pre-pinning them. This
            option only matters when using streamed CPU offloading (i.e. `use_stream=True`). This can be useful when
            the CPU memory is a bottleneck but may counteract the benefits of using streams.
    """
    if stream is not None and num_blocks_per_group != 1:
        logger.warning(
            f"Using streams is only supported for num_blocks_per_group=1. Got {num_blocks_per_group=}. Setting it to 1."
        )
        num_blocks_per_group = 1

    # Create module groups for ModuleList and Sequential blocks
    modules_with_group_offloading = set()
    unmatched_modules = []
    matched_module_groups = []
    for name, submodule in module.named_children():
        if not isinstance(submodule, (torch.nn.ModuleList, torch.nn.Sequential)):
            unmatched_modules.append((name, submodule))
            modules_with_group_offloading.add(name)
            continue

        for i in range(0, len(submodule), num_blocks_per_group):
            current_modules = submodule[i : i + num_blocks_per_group]
            group = ModuleGroup(
                modules=current_modules,
                offload_device=offload_device,
                onload_device=onload_device,
                offload_to_disk_path=offload_to_disk_path,
                offload_leader=current_modules[-1],
                onload_leader=current_modules[0],
                non_blocking=non_blocking,
                stream=stream,
                record_stream=record_stream,
                low_cpu_mem_usage=low_cpu_mem_usage,
                onload_self=True,
            )
            matched_module_groups.append(group)
            for j in range(i, i + len(current_modules)):
                modules_with_group_offloading.add(f"{name}.{j}")

    # Apply group offloading hooks to the module groups
    for i, group in enumerate(matched_module_groups):
        for group_module in group.modules:
            _apply_group_offloading_hook(group_module, group, None)

    # Parameters and Buffers of the top-level module need to be offloaded/onloaded separately
    # when the forward pass of this module is called. This is because the top-level module is not
    # part of any group (as doing so would lead to no VRAM savings).
    parameters = _gather_parameters_with_no_group_offloading_parent(module, modules_with_group_offloading)
    buffers = _gather_buffers_with_no_group_offloading_parent(module, modules_with_group_offloading)
    parameters = [param for _, param in parameters]
    buffers = [buffer for _, buffer in buffers]

    # Create a group for the unmatched submodules of the top-level module so that they are on the correct
    # device when the forward pass is called.
    unmatched_modules = [unmatched_module for _, unmatched_module in unmatched_modules]
    unmatched_group = ModuleGroup(
        modules=unmatched_modules,
        offload_device=offload_device,
        onload_device=onload_device,
        offload_to_disk_path=offload_to_disk_path,
        offload_leader=module,
        onload_leader=module,
        parameters=parameters,
        buffers=buffers,
        non_blocking=False,
        stream=None,
        record_stream=False,
        onload_self=True,
    )
    if stream is None:
        _apply_group_offloading_hook(module, unmatched_group, None)
    else:
        _apply_lazy_group_offloading_hook(module, unmatched_group, None)


def _apply_group_offloading_leaf_level(
    module: torch.nn.Module,
    offload_device: torch.device,
    onload_device: torch.device,
    non_blocking: bool,
    stream: Union[torch.cuda.Stream, torch.Stream, None] = None,
    record_stream: Optional[bool] = False,
    low_cpu_mem_usage: bool = False,
    offload_to_disk_path: Optional[str] = None,
) -> None:
    r"""
    This function applies offloading to groups of leaf modules in a torch.nn.Module. This method has minimal memory
    requirements. However, it can be slower compared to other offloading methods due to the excessive number of device
    synchronizations. When using devices that support streams to overlap data transfer and computation, this method
    can reduce memory usage without any performance degradation.

    Args:
        module (`torch.nn.Module`):
            The module to which group offloading is applied.
        offload_device (`torch.device`):
            The device to which the group of modules are offloaded. This should typically be the CPU.
        onload_device (`torch.device`):
            The device to which the group of modules are onloaded.
        offload_to_disk_path (`str`, *optional*, defaults to `None`):
            The path to the directory where parameters will be offloaded. Setting this option can be useful in limited
            RAM environment settings where a reasonable speed-memory trade-off is desired.
        non_blocking (`bool`):
            If True, offloading and onloading is done asynchronously. This can be useful for overlapping computation
            and data transfer.
        stream (`torch.cuda.Stream` or `torch.Stream`, *optional*):
            If provided, offloading and onloading is done asynchronously using the provided stream. This can be useful
            for overlapping computation and data transfer.
        record_stream (`bool`, defaults to `False`): When enabled with `use_stream`, it marks the current tensor
            as having been used by this stream. It is faster at the expense of slightly more memory usage. Refer to
            the [PyTorch official docs](https://pytorch.org/docs/stable/generated/torch.Tensor.record_stream.html) for
            more details.
        low_cpu_mem_usage (`bool`, defaults to `False`):
            If True, the CPU memory usage is minimized by pinning tensors on-the-fly instead of pre-pinning them. This
            option only matters when using streamed CPU offloading (i.e. `use_stream=True`). This can be useful when
            the CPU memory is a bottleneck but may counteract the benefits of using streams.
    """

    # Create module groups for leaf modules and apply group offloading hooks
    modules_with_group_offloading = set()
    for name, submodule in module.named_modules():
        if not isinstance(submodule, _SUPPORTED_PYTORCH_LAYERS):
            continue
        group = ModuleGroup(
            modules=[submodule],
            offload_device=offload_device,
            onload_device=onload_device,
            offload_to_disk_path=offload_to_disk_path,
            offload_leader=submodule,
            onload_leader=submodule,
            non_blocking=non_blocking,
            stream=stream,
            record_stream=record_stream,
            low_cpu_mem_usage=low_cpu_mem_usage,
            onload_self=True,
        )
        _apply_group_offloading_hook(submodule, group, None)
        modules_with_group_offloading.add(name)

    # Parameters and Buffers at all non-leaf levels need to be offloaded/onloaded separately when the forward pass
    # of the module is called
    module_dict = dict(module.named_modules())
    parameters = _gather_parameters_with_no_group_offloading_parent(module, modules_with_group_offloading)
    buffers = _gather_buffers_with_no_group_offloading_parent(module, modules_with_group_offloading)

    # Find closest module parent for each parameter and buffer, and attach group hooks
    parent_to_parameters = {}
    for name, param in parameters:
        parent_name = _find_parent_module_in_module_dict(name, module_dict)
        if parent_name in parent_to_parameters:
            parent_to_parameters[parent_name].append(param)
        else:
            parent_to_parameters[parent_name] = [param]

    parent_to_buffers = {}
    for name, buffer in buffers:
        parent_name = _find_parent_module_in_module_dict(name, module_dict)
        if parent_name in parent_to_buffers:
            parent_to_buffers[parent_name].append(buffer)
        else:
            parent_to_buffers[parent_name] = [buffer]

    parent_names = set(parent_to_parameters.keys()) | set(parent_to_buffers.keys())
    for name in parent_names:
        parameters = parent_to_parameters.get(name, [])
        buffers = parent_to_buffers.get(name, [])
        parent_module = module_dict[name]
        assert getattr(parent_module, "_diffusers_hook", None) is None
        group = ModuleGroup(
            modules=[],
            offload_device=offload_device,
            onload_device=onload_device,
            offload_leader=parent_module,
            onload_leader=parent_module,
            offload_to_disk_path=offload_to_disk_path,
            parameters=parameters,
            buffers=buffers,
            non_blocking=non_blocking,
            stream=stream,
            record_stream=record_stream,
            low_cpu_mem_usage=low_cpu_mem_usage,
            onload_self=True,
        )
        _apply_group_offloading_hook(parent_module, group, None)

    if stream is not None:
        # When using streams, we need to know the layer execution order for applying prefetching (to overlap data
        # transfer and computation). Since we don't know the order beforehand, we apply a lazy prefetching hook that
        # will find the execution order and apply prefetching in the correct order.
        unmatched_group = ModuleGroup(
            modules=[],
            offload_device=offload_device,
            onload_device=onload_device,
            offload_to_disk_path=offload_to_disk_path,
            offload_leader=module,
            onload_leader=module,
            parameters=None,
            buffers=None,
            non_blocking=False,
            stream=None,
            record_stream=False,
            low_cpu_mem_usage=low_cpu_mem_usage,
            onload_self=True,
        )
        _apply_lazy_group_offloading_hook(module, unmatched_group, None)


def _apply_group_offloading_hook(
    module: torch.nn.Module,
    group: ModuleGroup,
    next_group: Optional[ModuleGroup] = None,
) -> None:
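    r"""Registers a `GroupOffloadingHook` for `group` on `module`, unless one has already been registered."""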
    registry = HookRegistry.check_if_exists_or_initialize(module)

    # We may have already registered a group offloading hook if the module had a torch.nn.Parameter whose parent
    # is the current module. In such cases, we don't want to overwrite the existing group offloading hook.
    if registry.get_hook(_GROUP_OFFLOADING) is None:
        hook = GroupOffloadingHook(group, next_group)
        registry.register_hook(hook, _GROUP_OFFLOADING)


def _apply_lazy_group_offloading_hook(
    module: torch.nn.Module,
    group: ModuleGroup,
    next_group: Optional[ModuleGroup] = None,
) -> None:
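    r"""
    Registers a `GroupOffloadingHook` for `group` on `module` (if one is not already present) together with a
    `LazyPrefetchGroupOffloadingHook` that records the layer execution order on the first forward pass.
    """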
    registry = HookRegistry.check_if_exists_or_initialize(module)

    # We may have already registered a group offloading hook if the module had a torch.nn.Parameter whose parent
    # is the current module. In such cases, we don't want to overwrite the existing group offloading hook.
    if registry.get_hook(_GROUP_OFFLOADING) is None:
        hook = GroupOffloadingHook(group, next_group)
        registry.register_hook(hook, _GROUP_OFFLOADING)

    lazy_prefetch_hook = LazyPrefetchGroupOffloadingHook()
    registry.register_hook(lazy_prefetch_hook, _LAZY_PREFETCH_GROUP_OFFLOADING)


def _gather_parameters_with_no_group_offloading_parent(
    module: torch.nn.Module, modules_with_group_offloading: Set[str]
) -> List[Tuple[str, torch.nn.Parameter]]:
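    r"""Returns `(name, parameter)` pairs for parameters whose closest parent module has no group offloading hook."""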
    parameters = []
    for name, parameter in module.named_parameters():
        has_parent_with_group_offloading = False
        atoms = name.split(".")
        while len(atoms) > 0:
            parent_name = ".".join(atoms)
            if parent_name in modules_with_group_offloading:
                has_parent_with_group_offloading = True
                break
            atoms.pop()
        if not has_parent_with_group_offloading:
            parameters.append((name, parameter))
    return parameters


def _gather_buffers_with_no_group_offloading_parent(
    module: torch.nn.Module, modules_with_group_offloading: Set[str]
) -> List[Tuple[str, torch.Tensor]]:
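    r"""Returns `(name, buffer)` pairs for buffers whose closest parent module has no group offloading hook."""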
    buffers = []
    for name, buffer in module.named_buffers():
        has_parent_with_group_offloading = False
        atoms = name.split(".")
        while len(atoms) > 0:
            parent_name = ".".join(atoms)
            if parent_name in modules_with_group_offloading:
                has_parent_with_group_offloading = True
                break
            atoms.pop()
        if not has_parent_with_group_offloading:
            buffers.append((name, buffer))
    return buffers


def _find_parent_module_in_module_dict(name: str, module_dict: Dict[str, torch.nn.Module]) -> str:
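    r"""Returns the name of the closest ancestor of `name` that is present in `module_dict` (or `""` for the root)."""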
    atoms = name.split(".")
    while len(atoms) > 0:
        parent_name = ".".join(atoms)
        if parent_name in module_dict:
            return parent_name
        atoms.pop()
    return ""


def _raise_error_if_accelerate_model_or_sequential_hook_present(module: torch.nn.Module) -> None:
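    r"""Raises an error if the module (or any submodule) already uses Accelerate's CPU offloading or device-map hooks."""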
    if not is_accelerate_available():
        return
    for name, submodule in module.named_modules():
        if not hasattr(submodule, "_hf_hook"):
            continue
        if isinstance(submodule._hf_hook, (AlignDevicesHook, CpuOffload)):
            raise ValueError(
                f"Cannot apply group offloading to a module that is already applying an alternative "
                f"offloading strategy from Accelerate. If you want to apply group offloading, please "
                f"disable the existing offloading strategy first. Offending module: {name} ({type(submodule)})"
            )


def _is_group_offload_enabled(module: torch.nn.Module) -> bool:
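    r"""Returns True if any submodule of `module` has a group offloading hook attached."""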
    for submodule in module.modules():
        if hasattr(submodule, "_diffusers_hook") and submodule._diffusers_hook.get_hook(_GROUP_OFFLOADING) is not None:
            return True
    return False


def _get_group_onload_device(module: torch.nn.Module) -> torch.device:
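    r"""Returns the onload device of the first group offloading hook found on `module`, raising if none is present."""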
    for submodule in module.modules():
        if hasattr(submodule, "_diffusers_hook") and submodule._diffusers_hook.get_hook(_GROUP_OFFLOADING) is not None:
            return submodule._diffusers_hook.get_hook(_GROUP_OFFLOADING).group.onload_device
    raise ValueError("Group offloading is not enabled for the provided module.")