torchcache.torchcache._TorchCache
- class _TorchCache(*, memory_cache_device: str, subsample_count: int, persistent: bool, persistent_cache_dir: str, persistent_module_hash: str, max_persistent_cache_size: int, max_memory_cache_size: int, zstd_compression: bool, zstd_compression_level: int, zstd_compression_threads: int, cache_dtype: dtype, use_mmap_on_load: bool)
Class that implements the caching logic.
Do not instantiate this class directly; use the torchcache decorator instead.
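The pattern this note describes, where a decorator builds and owns the private cache object so user code never instantiates it, can be sketched without PyTorch. In the sketch below the names cached, _Cache, Square, and max_items are hypothetical, and the wrapper replaces forward directly instead of registering PyTorch hooks. It illustrates the decorator pattern under those assumptions and is not torchcache's code.

```python
from functools import wraps


class _Cache:
    """Hypothetical stand-in for the internal caching object."""

    def __init__(self, *, max_items: int):
        self.max_items = max_items
        self.store = {}

    def wrap_module(self, module):
        """Replace module.forward with a version that consults the cache first.

        Assumption: a plain method swap stands in for PyTorch hook registration.
        """
        original_forward = module.forward

        @wraps(original_forward)
        def cached_forward(key):
            if key in self.store:
                return self.store[key]
            out = original_forward(key)
            if len(self.store) < self.max_items:
                self.store[key] = out
            return out

        module.forward = cached_forward
        return module


def cached(*, max_items: int = 1024):
    """Hypothetical class decorator: users never construct _Cache themselves."""

    def decorator(module_cls):
        original_init = module_cls.__init__

        @wraps(original_init)
        def __init__(self, *args, **kwargs):
            original_init(self, *args, **kwargs)
            # Each instance gets its own private cache object.
            _Cache(max_items=max_items).wrap_module(self)

        module_cls.__init__ = __init__
        return module_cls

    return decorator


@cached(max_items=2)
class Square:
    def __init__(self):
        self.calls = 0

    def forward(self, x):
        self.calls += 1
        return x * x


m = Square()
first = m.forward(3)   # computed by the original forward
second = m.forward(3)  # served from the cache; m.calls stays at 1
```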
- __init__(*, memory_cache_device: str, subsample_count: int, persistent: bool, persistent_cache_dir: str, persistent_module_hash: str, max_persistent_cache_size: int, max_memory_cache_size: int, zstd_compression: bool, zstd_compression_level: int, zstd_compression_threads: int, cache_dtype: dtype, use_mmap_on_load: bool)
Initialize the torchcache.
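The bare * at the start of the signature makes every parameter keyword-only, so the cache cannot be configured positionally. A minimal illustration, using a hypothetical CacheConfig stand-in that mirrors three of the parameters with arbitrary illustrative values (not torchcache defaults):

```python
class CacheConfig:
    """Hypothetical stand-in mirroring a few _TorchCache parameters."""

    def __init__(self, *, memory_cache_device: str, subsample_count: int, persistent: bool):
        # The leading * means each argument must be passed by name.
        self.memory_cache_device = memory_cache_device
        self.subsample_count = subsample_count
        self.persistent = persistent


# Accepted: every argument is named.
cfg = CacheConfig(memory_cache_device="cpu", subsample_count=1000, persistent=False)

# Rejected: positional arguments raise TypeError.
positional_error = ""
try:
    CacheConfig("cpu", 1000, False)
except TypeError as exc:
    positional_error = str(exc)
```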
Methods
- __init__(*, memory_cache_device, ...): Initialize the torchcache.
- cache_cleanup()
- forward_hook(module, inputs, outputs): Forward post-hook to replace and cache the embeddings.
- forward_pre_hook(module, args, kwargs): Forward pre-hook to check the cache.
- hash_tensor(tensor): Hash a tensor.
- wrap_module(module, moduleClass, *args, **kwargs): Wrap an nn.Module with the pre-hook and the post-hook.
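Taken together, the methods suggest a two-hook flow: the pre-hook hashes the inputs and looks them up, the module's forward runs only on a miss, and the post-hook stores the new outputs. The framework-free sketch below assumes that flow. The HookedCache class, its run helper, plain lists of floats standing in for tensors, and SHA-256 over an evenly spaced subsample of subsample_count elements are all assumptions for illustration, not torchcache's implementation.

```python
import hashlib
import struct


class HookedCache:
    """Sketch of a pre-hook / post-hook cache (hypothetical, framework-free).

    Assumption: a "tensor" is a plain list of floats, and each batch item
    is cached under its own key.
    """

    def __init__(self, subsample_count: int):
        self.subsample_count = subsample_count
        self.memory_cache = {}
        self._pending_keys = None

    def hash_tensor(self, tensor):
        """Hash the length plus an evenly spaced subsample of the elements."""
        step = max(1, len(tensor) // self.subsample_count)
        sample = tensor[::step][: self.subsample_count]
        digest = hashlib.sha256(struct.pack("<q", len(tensor)))
        digest.update(struct.pack(f"<{len(sample)}d", *sample))
        return digest.hexdigest()

    def forward_pre_hook(self, batch):
        """Return cached outputs if every item is cached, else None."""
        keys = [self.hash_tensor(item) for item in batch]
        if all(k in self.memory_cache for k in keys):
            return [self.memory_cache[k] for k in keys]
        self._pending_keys = keys  # remembered for the post-hook
        return None

    def forward_hook(self, outputs):
        """Store freshly computed outputs under the keys from the pre-hook."""
        for key, out in zip(self._pending_keys, outputs):
            self.memory_cache[key] = out
        self._pending_keys = None
        return outputs

    def run(self, forward, batch):
        """Mimic a module call: pre-hook, forward only on a miss, post-hook."""
        cached = self.forward_pre_hook(batch)
        if cached is not None:
            return cached
        return self.forward_hook(forward(batch))


calls = []


def forward(batch):
    # Stand-in for an expensive module forward pass.
    calls.append(len(batch))
    return [sum(item) for item in batch]


cache = HookedCache(subsample_count=2)
batch = [[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]
first = cache.run(forward, batch)   # miss: forward runs, outputs are stored
second = cache.run(forward, batch)  # hit: forward is skipped
```

Hashing only a subsample keeps lookups cheap for large inputs, but inputs that agree on length and on the sampled positions collide: with subsample_count=2, [1.0, 2.0, 3.0] and [1.0, 2.0, 99.0] get the same key in this sketch. A larger subsample_count lowers that risk at the cost of more hashing work.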