torchcache.torchcache._TorchCache

class _TorchCache(*, memory_cache_device: str, subsample_count: int, persistent: bool, persistent_cache_dir: str, persistent_module_hash: str, max_persistent_cache_size: int, max_memory_cache_size: int, zstd_compression: bool, zstd_compression_level: int, zstd_compression_threads: int, cache_dtype: dtype, use_mmap_on_load: bool)

Class that implements the caching logic.

Do not initialize this class directly; use the torchcache decorator instead.

__init__(*, memory_cache_device: str, subsample_count: int, persistent: bool, persistent_cache_dir: str, persistent_module_hash: str, max_persistent_cache_size: int, max_memory_cache_size: int, zstd_compression: bool, zstd_compression_level: int, zstd_compression_threads: int, cache_dtype: dtype, use_mmap_on_load: bool)

Initialize the torchcache.

Methods

__init__(*, memory_cache_device, ...)

Initialize the torchcache.

cache_cleanup()

forward_hook(module, inputs, outputs)

Forward post-hook to replace and cache the embeddings.

forward_pre_hook(module, args, kwargs)

Forward pre-hook to check the cache.

hash_tensor(tensor)

Hash a tensor.

wrap_module(module, moduleClass, *args, **kwargs)

Wrap an nn.Module with the pre-hook and the post-hook.
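The pre-hook, post-hook, and hashing methods listed above work together. The following is a minimal sketch of that general pattern using only standard PyTorch hook registration. It is not torchcache's actual implementation: hash_rows, SketchCache, and attach are hypothetical names introduced for illustration, and the sketch assumes a frozen module called under torch.no_grad() with a single non-empty batch tensor on a NumPy-compatible dtype such as float32.

```python
import hashlib

import torch
from torch import nn


def hash_rows(batch: torch.Tensor) -> list:
    """Hypothetical helper: one SHA-256 hex digest per item in the batch."""
    flat = batch.detach().cpu().contiguous().reshape(batch.shape[0], -1)
    return [hashlib.sha256(row.numpy().tobytes()).hexdigest() for row in flat]


class SketchCache:
    """Hypothetical in-memory cache keyed by per-row input hashes."""

    def __init__(self):
        self.store = {}      # hash -> cached output row
        self.pending = None  # (hashes, missing row indices) for the call in flight

    def forward_pre_hook(self, module, args, kwargs):
        # Check the cache and forward only the rows that are not cached yet.
        batch = args[0]
        hashes = hash_rows(batch)
        missing = [i for i, h in enumerate(hashes) if h not in self.store]
        self.pending = (hashes, missing)
        index = torch.tensor(missing, dtype=torch.long, device=batch.device)
        return (batch[index], *args[1:]), kwargs

    def forward_hook(self, module, inputs, outputs):
        # Store the freshly computed rows, then rebuild the full batch
        # from the cache in the original row order.
        hashes, missing = self.pending
        for i, row in zip(missing, outputs):
            self.store[hashes[i]] = row.detach()
        return torch.stack([self.store[h] for h in hashes])


def attach(module: nn.Module) -> SketchCache:
    """Register both hooks on an existing module (requires PyTorch >= 2.0)."""
    cache = SketchCache()
    module.register_forward_pre_hook(cache.forward_pre_hook, with_kwargs=True)
    module.register_forward_hook(cache.forward_hook)
    return cache


encoder = nn.Linear(4, 2)
cache = attach(encoder)
x = torch.arange(12, dtype=torch.float32).reshape(3, 4)
with torch.no_grad():
    first = encoder(x)   # all three rows are computed and cached
    second = encoder(x)  # all three rows are served from the cache
```

On the second call the module itself receives an empty batch, so no rows are recomputed. A production cache also has to deal with memory limits, persistence, and dtype conversion, which the constructor parameters above suggest torchcache handles; none of that is modeled here.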