
Prefix Caching with HBM and latency test #1278

Open · wants to merge 20 commits into main

Conversation

@yuyanpeng-google (Collaborator) commented Feb 17, 2025

Description

Implements Prefix Caching in HBM and a latency test in inference_microbenchmark.

Stores prefix tokens in a trie so the PrefixCache can quickly look up which key a value is stored under.

Inserting a longer key replaces a shorter key that is its prefix, so the longer key becomes the longest-common-prefix match.
The shorter key will never be returned, even if the longer key is later erased, and its value should eventually be evicted.

A key is assumed to have the same length as its token sequence, so it can be used to slice both the prompt and the cached Value.
The caller should check the common prefix length of the returned key.

Erasing a key that is not a leaf does nothing.
Erasing a key that matches a leaf deletes that node, together with any ancestors that become leaves after the deletion.
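
A hedged usage sketch of these key semantics; PrefixCacheTrie and erase() appear in the diffs below, while insert() and fetch_longest_common_prefix_key() are assumed names used only for illustration:

  # Hypothetical usage sketch; method names other than erase() are assumptions.
  trie = PrefixCacheTrie()
  trie.insert((1, 2, 3))
  trie.insert((1, 2, 3, 4))      # the longer key supersedes (1, 2, 3) as the stored key

  key = trie.fetch_longest_common_prefix_key((1, 2, 3, 9))
  # key == (1, 2, 3, 4); the caller must compute the common prefix length (3 here)
  # and only reuse that many tokens of the cached value.

  trie.erase((1, 2, 3))          # not a leaf: no-op
  trie.erase((1, 2, 3, 4))       # leaf: node removed, ancestors pruned if they become leaves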

The Value is moved into the cache, which means the same value reference cannot be used after add_to_cache.

JAX may modify the value even if it is held by another Python reference.
If the value needs to be used after add_to_cache, make sure to copy it before calling add_to_cache.

The returned value is copied from the cache so callers cannot modify the value held in the cache; the value is always copied before being returned.
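
A minimal usage sketch of this copy rule, using names that appear elsewhere in this PR (prefix_cache_inst, save, load, Value.clone); how the value itself is built is elided:

  value = ...                             # a Value holding jax arrays, built elsewhere
  backup = value.clone()                  # copy BEFORE handing the value to the cache
  prefix_cache_inst.save(key, value)      # `value` is moved into the cache; do not reuse this reference
  reloaded = prefix_cache_inst.load(key)  # load() already returns a copy, safe to modify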

Adds a PrefixCaching benchmark test to inference_microbenchmark.

It uses half of the prefill_length as the common prefix and saves 100 prefixes in the cache, roughly as sketched below.
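
(Sketch only: make_key and make_value are hypothetical stand-ins; cache_num, common_len, and the PrefixCache sizing mirror the diff further down.)

  import time

  cache_num = 100                        # config: inference_microbenchmark_cache_num
  common_len = prefill_length // 2       # shared prefix across all saved keys
  value = make_value(prefill_length)     # hypothetical: builds one prefix-cache Value
  prefix_cache_inst = prefix_cache.PrefixCache(cache_num * value.prefix_size_bytes)

  save_ms = []
  for i in range(cache_num):
    key = make_key(common_len, prefill_length, i)  # hypothetical: common prefix + unique suffix
    start = time.perf_counter()
    prefix_cache_inst.save(key, make_value(prefill_length))
    save_ms.append((time.perf_counter() - start) * 1000)
  print(f"Average save cache time: {sum(save_ms) / len(save_ms):.3f} ms")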

Loading the cache (including jax.array.copy) appears to be independent of the prefill_length (tested with 128 and 1024), even though the saved cache sizes differ.

jax.profiler shows that the copy operation consumes a similar amount of time on TPU in both cases. This might be because the sizes aren't large or different enough to show a significant impact.

Part of the results below:

Prefix Cache benchmark results for prefill length 128:

PrefixCaching results:
	Per prefix size bytes: 0.124 GB
	Average save cache time: 12.142 ms
	Average fetch longest prefix time: 0.029 ms
	Average load cache time: 5.589 ms


Prefix Cache benchmark results for prefill length 1024:

PrefixCaching results:
	Per prefix size bytes: 0.220 GB
	Average save cache time: 12.987 ms
	Average fetch longest prefix time: 0.218 ms
	Average load cache time: 5.143 ms

FIXES: b/389788256
TESTED: unittest

Checklist

Before submitting this PR, please make sure (put X in square brackets):

  • I have performed a self-review of my code.
  • I have necessary comments in my code, particularly in hard-to-understand areas.
  • I have run end-to-end tests and provided workload links above if applicable.
  • I have made or will make corresponding changes to the doc if needed.

Clone to prevent reusing the same jax array.

TESTED=unittest
jax.profiler shows that the tree copy and size calculation consume a lot of time in clone.
Needs more test dimensions and review.
@@ -499,6 +499,7 @@ vertex_tensorboard_region: ""
max_checkify: False

# Inference
inference_microbenchmark_cache_num: 100
Collaborator:

Is this the number of entries in the cache? Can we make it num_entries_in_cache? That way it will not be restricted to microbenchmarking.

Collaborator (Author):

The number was intended to make it easy to insert various suffixes into the cache, but I think it's not suitable as a hard limit.

We can optimize memory usage by saving identical prefix values shared across multiple keys. For instance, [1, 2, 3] and [1, 2, 3, 4] can share the same [1, 2, 3] value, so we only need to store one [1, 2, 3, 4] value for both entries. Adding a parameter with a maximum entries limitation could complicate the use case. I believe the optimal cache size should be determined through benchmarking under various scenarios rather than imposing a limit in the API.
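
For example (standalone illustration, not code from this PR):

  def common_prefix_len(a: tuple, b: tuple) -> int:
    """Length of the shared leading tokens of two keys."""
    n = 0
    for x, y in zip(a, b):
      if x != y:
        break
      n += 1
    return n

  # Only the (1, 2, 3, 4) entry needs to be stored; a lookup for (1, 2, 3)
  # reuses it up to the common prefix length.
  assert common_prefix_len((1, 2, 3), (1, 2, 3, 4)) == 3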

)
prefix_size_bytes_gb = value.prefix_size_bytes / 1024 / 1024 / 1024
prefix_cache_inst = prefix_cache.PrefixCache(cache_num * value.prefix_size_bytes)
common_len = prefill_length // 2
Collaborator:

Can you abstract the common_len factor into the config?

assert hbm_cache.retrieve_from_cache((1)) == value1
assert hbm_cache.retrieve_from_cache((2)) == value2

def test_evict_cache(self):
Collaborator:

Can you elaborate this test: add multiple values and then delete?
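
e.g., a hedged sketch of such a test (add_to_cache(), _make_values(), and the constructor arguments are assumptions; retrieve_from_cache() and evict_cache() appear in this PR):

  def test_evict_cache_multiple_values(self):
    hbm_cache = HBMCache(max_size_bytes=...)       # constructor signature is an assumption
    value1, value2, value3 = self._make_values(3)  # hypothetical helper
    hbm_cache.add_to_cache((1), value1)            # assumed insertion method
    hbm_cache.add_to_cache((2), value2)
    hbm_cache.add_to_cache((3), value3)

    assert hbm_cache.evict_cache((2)) == value2           # evicted value is returned
    assert hbm_cache.retrieve_from_cache((2)) is None     # and is no longer retrievable
    assert hbm_cache.retrieve_from_cache((1)) == value1   # other entries are untouched
    assert hbm_cache.retrieve_from_cache((3)) == value3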

@yuyanpeng-google marked this pull request as ready for review February 19, 2025 00:28
@yuyanpeng-google changed the title from [WIP] Prefix Caching to Prefix Caching with HBM and latency test on Feb 19, 2025
@vipannalla (Collaborator) left a comment:

Looks good, I left a few comments.

self._remain_size_bytes = max_size_bytes
self._saved_values: dict[Key, Value] = {}

def is_enough_space_remain(self, value: Value) -> bool:
Collaborator:

nit: rename to has_enough_space()

To avoid modifying the value in the cache, always copy the value before returning.
"""
if key in self._saved_values:
return self._saved_values[key].clone()
Collaborator:

That looks inefficient; you are making a copy of a large JAX array.

"""Save key/value to the cache."""
logger.debug("save key=%r", key)
if not self._hbm_cache.is_enough_space_remain(value):
self._evict_cache()
Collaborator:

Pass value.prefix_size_bytes into self._evict_cache() so you can check that you evicted enough bytes. See below.

self._trie = PrefixCacheTrie()
self._cache_strategy = LRUStrategy()

def _evict_cache(self) -> Optional[Value]:
Collaborator:

You need to evict until you make enough space:

def _evict_cache(self, bytes_needed: int) -> Optional[List[Value]]:
  evicted_bytes = 0
  evicted_values = []
  while evicted_bytes < bytes_needed:
    key = self._cache_strategy.evict()
    if key is None:
      logger.debug("no key to evict")
      return None
    logger.debug("evict key=%r", key)
    value = self._hbm_cache.evict_cache(key)
    self._trie.erase(key)
    if value is None:
      logger.warning("key=%r should exist in HBM cache.", key)
      continue
    evicted_bytes += value.prefix_size_bytes  # count freed bytes so the loop can terminate
    evicted_values.append(value)

  return evicted_values

# init in clear()
self._hbm_cache: HBMCache = None
self._trie: PrefixCacheTrie = None
self._cache_strategy: LRUStrategy = None
Collaborator:

You need a lock since the cache will get updated from multiple threads.

  self._lock = threading.RLock()


def save(self, key: Key, value: Value) -> bool:
"""Save key/value to the cache."""
logger.debug("save key=%r", key)
Collaborator:

Grab the lock before saving:

  with self._lock:
    if not self._hbm_cache.is_enough_space_remain(value):
      self._evict_cache()
    ...

Comment on lines +314 to +331
class LRUStrategy:
  """Least recently used cache strategy for managing keys."""

  def __init__(self):
    self._order: OrderedDict[Key, None] = OrderedDict()

  def evict(self) -> Optional[Key]:
    """Return and pop the least recently used key."""
    if len(self._order) == 0:
      return None
    return self._order.popitem(last=False)[0]

  def use(self, key: Key) -> None:
    """Update the usage history."""
    if key not in self._order:
      self._order[key] = None
    else:
      self._order.move_to_end(key, last=True)
Collaborator:

+1 for abstracting this into a separate class. We should pass the LRUStrategy into HBMCache and let HBMCache evict keys based on the strategy.
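
e.g., a hedged sketch of that shape (the constructor and evict_one() name are assumptions; the fields and LRUStrategy.evict() match the excerpts above):

  class HBMCache:
    def __init__(self, max_size_bytes: int, strategy: LRUStrategy):
      self._remain_size_bytes = max_size_bytes
      self._saved_values: dict[Key, Value] = {}
      self._strategy = strategy

    def evict_one(self) -> Optional[Value]:
      """Drop the entry chosen by the strategy and return its value."""
      key = self._strategy.evict()
      if key is None:
        return None
      value = self._saved_values.pop(key, None)
      if value is not None:
        self._remain_size_bytes += value.prefix_size_bytes  # reclaim its budget
      return value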


def load(self, key: Key) -> Optional[Value]:
"""Returns Value stored with key or None if not found."""
logger.debug("load key=%r", key)
Collaborator:

Same here, grab the lock...
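
e.g., a minimal sketch of the same pattern for load():

  def load(self, key: Key) -> Optional[Value]:
    """Returns Value stored with key or None if not found."""
    logger.debug("load key=%r", key)
    with self._lock:          # same RLock suggested for save()
      ...                     # existing lookup / copy logic, unchanged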
