Prefix Caching with HBM and latency test #1278
base: main
Conversation
TESTED=unittest
Stores prefix tokens as a trie for fast lookup of the PrefixCache index stored in the cache. Inserting a longer key replaces a shorter key, so the longer key becomes the longest-common-prefix match. The shorter key will never be returned, even if the longer key is erased, and it should get evicted in the future. Keys are assumed to have the same length as the tokens, so they can be used to slice the prompt and the cached Value. The caller should check the common prefix length of the returned key. Erasing a key that is not a leaf does nothing. If the erased key matches a leaf, the node is deleted and its ancestors may become leaves after the deletion. TESTED=unittest
Clone to avoid reusing the same jax array. TESTED=unittest
Values are moved into the cache, which means the same value reference cannot be used after add_to_cache. JAX may modify the value even when it is stored under another Python reference. If the value needs to be used after add_to_cache, make sure to copy it before calling add_to_cache. Values returned from the cache are copies; the value is always copied before being returned to avoid modifying the value held in the cache. TESTED=unittest
jax.profiler shows that tree copy and size calculation consume a lot of time in clone.
Needs more test dimensions and review.
MaxText/configs/base.yml
Outdated
@@ -499,6 +499,7 @@ vertex_tensorboard_region: ""
max_checkify: False

# Inference
inference_microbenchmark_cache_num: 100
Is this the number of entries in the cache? Can we make it num_entries_in_cache? That way it will not be restricted to microbenchmarking.
The number was intended to make it easy to insert various suffixes into the cache, but I don't think it's suitable as a hard limit.
We can optimize memory usage by saving identical prefix values shared across multiple keys. For instance, [1, 2, 3] and [1, 2, 3, 4] can share the same [1, 2, 3] value, so we only need to store one [1, 2, 3, 4] value for both entries. Adding a parameter with a maximum entries limitation could complicate the use case. I believe the optimal cache size should be determined through benchmarking under various scenarios rather than imposing a limit in the API.
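For illustration, a hedged sketch of how a shorter prompt can reuse the value stored under a longer key by checking the common prefix length (the helper below is hypothetical and not part of this PR):

def common_prefix_len(a: tuple[int, ...], b: tuple[int, ...]) -> int:
  """Length of the longest common prefix of two token sequences."""
  n = 0
  for x, y in zip(a, b):
    if x != y:
      break
    n += 1
  return n

stored_key = (1, 2, 3, 4)   # only the longest key needs to be kept in the cache
prompt = (1, 2, 3)          # a shorter prompt arriving later
matched = common_prefix_len(prompt, stored_key)
assert matched == 3         # the caller slices the cached value to the first 3 tokens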
MaxText/inference_microbenchmark.py
Outdated
)
prefix_size_bytes_gb = value.prefix_size_bytes / 1024 / 1024 / 1024
prefix_cache_inst = prefix_cache.PrefixCache(cache_num * value.prefix_size_bytes)
common_len = prefill_length // 2
Can you abstract the common_len factor into the config?
MaxText/tests/prefix_cache_test.py
Outdated
assert hbm_cache.retrieve_from_cache((1)) == value1
assert hbm_cache.retrieve_from_cache((2)) == value2

def test_evict_cache(self):
Can you elaborate on this test: add multiple values and then delete some?
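A hedged sketch of such an extended test; make_test_value() is a hypothetical stand-in for however the existing tests build a Value, and the exact HBMCache signatures are assumptions based on the names visible in this PR:

def test_evict_cache_with_multiple_values(self):
  hbm_cache = prefix_cache.HBMCache(max_size_bytes=1024 * 1024)
  values = {(i,): make_test_value() for i in range(3)}  # hypothetical helper
  for key, value in values.items():
    hbm_cache.add_to_cache(key, value)

  # Evicting one key should return its value and remove only that entry.
  assert hbm_cache.evict_cache((0,)) is not None
  assert hbm_cache.retrieve_from_cache((0,)) is None
  for key in ((1,), (2,)):
    assert hbm_cache.retrieve_from_cache(key) is not None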
Looks good, I left a few comments.
MaxText/prefix_cache.py
Outdated
self._remain_size_bytes = max_size_bytes
self._saved_values: dict[Key, Value] = {}

def is_enough_space_remain(self, value: Value) -> bool:
nit: rename to has_enough_space()
To avoid modified value in the cache, always copied the value before return.
"""
if key in self._saved_values:
  return self._saved_values[key].clone()
That looks inefficient; you are making a copy of a large JAX array.
"""Save key/value to the cache.""" | ||
logger.debug("save key=%r", key) | ||
if not self._hbm_cache.is_enough_space_remain(value): | ||
self._evict_cache() |
Pass value.prefix_size_bytes into self._evict_cache() to make sure you evicted enough bytes. See below.
self._trie = PrefixCacheTrie()
self._cache_strategy = LRUStrategy()

def _evict_cache(self) -> Optional[Value]:
You need to evict until you make enough space:
def _evict_cache(self, bytes_needed: int) -> Optional[List[Value]]:
  evicted_bytes = 0
  evicted_values = []
  while evicted_bytes < bytes_needed:
    key = self._cache_strategy.evict()
    if key is None:
      logger.debug("no key to evict")
      return None
    logger.debug("evict key=%r", key)
    value = self._hbm_cache.evict_cache(key)
    if value is None:
      logger.warning("key=%r should exist in HBM cache.", key)
    self._trie.erase(key)
    if value is not None:
      evicted_bytes += value.prefix_size_bytes  # count freed bytes so the loop terminates
      evicted_values.append(value)
  return evicted_values
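For context, a hedged sketch of the call site in save() that would pass the byte count in (locking and the trie/LRU bookkeeping are omitted here, and add_to_cache returning a bool is an assumption):

def save(self, key: Key, value: Value) -> bool:
  """Save key/value, evicting LRU entries first if the value does not fit."""
  logger.debug("save key=%r", key)
  if not self._hbm_cache.is_enough_space_remain(value):
    # Free at least as many bytes as the incoming value needs.
    if self._evict_cache(value.prefix_size_bytes) is None:
      return False  # could not free enough space; skip caching this value
  return self._hbm_cache.add_to_cache(key, value)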
# init in clear()
self._hbm_cache: HBMCache = None
self._trie: PrefixCacheTrie = None
self._cache_strategy: LRUStrategy = None
You need a lock since the cache will get updated from multiple threads.
self._lock = threading.RLock()
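A hedged sketch of where that could live; the constructor argument and attribute names here are assumptions:

def __init__(self, max_size_bytes: int):  # assumes `import threading` at module top
  self._max_size_bytes = max_size_bytes
  # Guards the HBM cache, the trie, and the LRU state against concurrent use.
  self._lock = threading.RLock()
  self.clear()  # clear() creates _hbm_cache, _trie, and _cache_strategy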
def save(self, key: Key, value: Value) -> bool:
  """Save key/value to the cache."""
  logger.debug("save key=%r", key)
Grab lock before saving:
with self._lock:
  if not self._hbm_cache.is_enough_space_remain(value):
    self._evict_cache()
  ...
class LRUStrategy:
  """Least recently used cache strategy manage key."""

  def __init__(self):
    self._order: OrderedDict[Key, None] = OrderedDict()

  def evict(self) -> Optional[Key]:
    """Return and pop the least recently used key."""
    if len(self._order) == 0:
      return None
    return self._order.popitem(last=False)[0]

  def use(self, key: Key) -> None:
    """Updated the usage history."""
    if key not in self._order:
      self._order[key] = None
    else:
      self._order.move_to_end(key, last=True)
+1 for abstracting this into a separate class. We should pass the LRUStrategy into HBMCache and let HBMCache evict keys based on the strategy.
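A hedged sketch of that refactor; the HBMCache constructor parameter, the evict_one() method, and the size bookkeeping are assumptions layered on top of the names shown in this PR:

class HBMCache:
  """HBM-backed store that delegates eviction ordering to a pluggable strategy."""

  def __init__(self, max_size_bytes: int, strategy: Optional[LRUStrategy] = None):
    self._remain_size_bytes = max_size_bytes
    self._saved_values: dict[Key, Value] = {}
    self._strategy = strategy or LRUStrategy()

  def retrieve_from_cache(self, key: Key) -> Optional[Value]:
    if key in self._saved_values:
      self._strategy.use(key)  # record the access so LRU ordering stays current
      return self._saved_values[key].clone()
    return None

  def evict_one(self) -> Optional[Value]:
    """Evict whichever key the strategy chooses, or None if the cache is empty."""
    key = self._strategy.evict()
    if key is None:
      return None
    value = self._saved_values.pop(key, None)
    if value is not None:
      self._remain_size_bytes += value.prefix_size_bytes  # reclaim the freed space
    return value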
def load(self, key: Key) -> Optional[Value]:
  """Returns Value stored with key or None if not found."""
  logger.debug("load key=%r", key)
Same here, grab the lock...
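A hedged sketch of the locked version (it assumes load() goes through retrieve_from_cache() and the _lock suggested above):

def load(self, key: Key) -> Optional[Value]:
  """Returns Value stored with key or None if not found."""
  logger.debug("load key=%r", key)
  with self._lock:  # serialize with concurrent save() and eviction
    value = self._hbm_cache.retrieve_from_cache(key)
    if value is not None:
      self._cache_strategy.use(key)  # mark the key as recently used
    return value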
…ive limit in long context
Description
Implements prefix caching in HBM and a latency test in inference_microbenchmark.
Stores prefix tokens as a trie for fast lookup of the PrefixCache index stored in the cache.
Inserting a longer key replaces a shorter key, so the longer key becomes the longest-common-prefix match.
The shorter key will never be returned, even if the longer key is erased, and it should get evicted in the future.
Keys are assumed to have the same length as the tokens, so they can be used to slice the prompt and the cached Value.
The caller should check the common prefix length of the returned key.
Erasing a key that is not a leaf does nothing.
If the erased key matches a leaf, the node is deleted and its ancestors may become leaves after the deletion.
Values are moved into the cache, which means the same value reference cannot be used after add_to_cache.
JAX may modify the value even when it is stored under another Python reference.
If the value needs to be used after add_to_cache, make sure to copy it before calling add_to_cache.
Values returned from the cache are copies; the value is always copied before being returned to avoid modifying the value held in the cache.
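A hedged usage sketch of those ownership rules; the PrefixCache constructor argument and the clone() call follow what is visible in this PR, but exact signatures are assumptions:

cache = prefix_cache.PrefixCache(8 * 1024**3)  # cache budget in bytes

my_copy = value.clone()   # keep a private copy before handing the value over
cache.save(key, value)    # the cache now owns `value`; do not reuse that reference

loaded = cache.load(key)  # load() hands back a copy, safe to use without touching the cache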
Adds a prefix caching benchmark test in inference_microbenchmark.
Uses half of the prefill_length as the common prefix and saves 100 prefixes in the cache (a sketch of this setup follows).
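A hedged sketch of that setup; prefill_length, cache_num, and value come from the surrounding benchmark code, and the key construction here is illustrative only:

common_len = prefill_length // 2  # the first half of every prompt is shared
prefix_cache_inst = prefix_cache.PrefixCache(cache_num * value.prefix_size_bytes)

common_prefix = tuple(range(common_len))  # illustrative token ids
for i in range(cache_num):
  suffix = (i,) * (prefill_length - common_len)  # a distinct suffix per entry
  prefix_cache_inst.save(common_prefix + suffix, value.clone())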
Loading the cache (including jax.array.copy) appears to be independent of the prefill_length (tested with 128 and 1024), even though the saved cache sizes are different.
Using jax.profiler shows that the copy operation consumes a similar amount of time on TPU. This might be because the sizes aren't large or different enough to see a significant impact.
Part of the results are shown below.
FIXES: b/389788256
TESTED: unittest
Checklist
Before submitting this PR, please make sure (put X in square brackets):