Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Adding new feature: INT4 projection matrix #64

Open
wants to merge 11 commits into
base: master
Choose a base branch
from
2 changes: 1 addition & 1 deletion galore_torch/adafactor.py
Original file line number Diff line number Diff line change
Expand Up @@ -193,7 +193,7 @@ def step(self, closure=None):
if "rank" in group:
if "projector" not in state:
if group['dim'] <=2:
state["projector"] = GaLoreProjector(group["rank"], update_proj_gap=group["update_proj_gap"], scale=group["scale"], proj_type=group["proj_type"])
state["projector"] = GaLoreProjector(group["rank"], update_proj_gap=group["update_proj_gap"], scale=group["scale"], proj_type=group["proj_type"], proj_quant=group["proj_quant"])
else:
state["projector"] = GaLoreProjectorTensor(group["rank"], update_proj_gap=group["update_proj_gap"], scale=group["scale"], proj_type=group["proj_type"])

Expand Down
2 changes: 1 addition & 1 deletion galore_torch/adamw.py
Original file line number Diff line number Diff line change
Expand Up @@ -96,7 +96,7 @@ def step(self, closure: Callable = None):
if "rank" in group:
if "projector" not in state:
if group['dim'] <=2:
state["projector"] = GaLoreProjector(group["rank"], update_proj_gap=group["update_proj_gap"], scale=group["scale"], proj_type=group["proj_type"])
state["projector"] = GaLoreProjector(group["rank"], update_proj_gap=group["update_proj_gap"], scale=group["scale"], proj_type=group["proj_type"], proj_quant=group["proj_quant"])
else:
state["projector"] = GaLoreProjectorTensor(group["rank"], update_proj_gap=group["update_proj_gap"], scale=group["scale"], proj_type=group["proj_type"])
grad = state["projector"].project(grad, state["step"])
Expand Down
2 changes: 1 addition & 1 deletion galore_torch/adamw8bit.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@ def step(self, closure=None):
if "rank" in group:
if "projector" not in state:
if group['dim'] <= 2:
state["projector"] = GaLoreProjector(group["rank"], update_proj_gap=group["update_proj_gap"], scale=group["scale"], proj_type=group["proj_type"])
state["projector"] = GaLoreProjector(group["rank"], update_proj_gap=group["update_proj_gap"], scale=group["scale"], proj_type=group["proj_type"], proj_quant=group["proj_quant"])
else:
state["projector"] = GaLoreProjectorTensor(group["rank"], update_proj_gap=group["update_proj_gap"], scale=group["scale"], proj_type=group["proj_type"])
if 'weight_decay' in group and group['weight_decay'] > 0:
Expand Down
111 changes: 103 additions & 8 deletions galore_torch/galore_projector.py
Original file line number Diff line number Diff line change
@@ -1,24 +1,48 @@
import torch

class GaLoreProjector:
def __init__(self, rank, verbose=False, update_proj_gap=200, scale=1.0, proj_type='std'):
def __init__(self, rank, verbose=False, update_proj_gap=200, scale=1.0, proj_type='std', proj_quant=False):
self.rank = rank
self.verbose = verbose
self.update_proj_gap = update_proj_gap
self.scale = scale

self.ortho_matrix = None
self.ortho_matrix_scales = None
self.ortho_matrix_zeros = None
self.ortho_matrix_shape = None

self.proj_type = proj_type

self.proj_quant = proj_quant
self.quant_group_size = 256
self.quant_n_bit = 4

def project(self, full_rank_grad, iter):
# TODO: implement the quantized projection for the other proj_type values; currently only 'std' is supported

if self.proj_type == 'std':
if full_rank_grad.shape[0] >= full_rank_grad.shape[1]:
if self.ortho_matrix is None or iter % self.update_proj_gap == 0:
self.ortho_matrix = self.get_orthogonal_matrix(full_rank_grad, self.rank, type='right')
low_rank_grad = torch.matmul(full_rank_grad, self.ortho_matrix.t().to(full_rank_grad.device.type))
self.get_orthogonal_matrix(full_rank_grad, self.rank, type='right')

if self.proj_quant:
float_ortho_matrix = self.unpack_int4_projection()
else:
float_ortho_matrix = self.ortho_matrix

low_rank_grad = torch.matmul(full_rank_grad, float_ortho_matrix.t().to(full_rank_grad.device.type))
else:
if self.ortho_matrix is None or iter % self.update_proj_gap == 0:
self.ortho_matrix = self.get_orthogonal_matrix(full_rank_grad, self.rank, type='left')
low_rank_grad = torch.matmul(self.ortho_matrix.t().to(full_rank_grad.device.type), full_rank_grad)
self.get_orthogonal_matrix(full_rank_grad, self.rank, type='left')

if self.proj_quant:
float_ortho_matrix = self.unpack_int4_projection()
else:
float_ortho_matrix = self.ortho_matrix

low_rank_grad = torch.matmul(float_ortho_matrix.t().to(full_rank_grad.device.type), full_rank_grad)

elif self.proj_type == 'reverse_std':
if full_rank_grad.shape[0] >= full_rank_grad.shape[1]:
if self.ortho_matrix is None or iter % self.update_proj_gap == 0:
Expand All @@ -44,11 +68,21 @@ def project(self, full_rank_grad, iter):
return low_rank_grad

def project_back(self, low_rank_grad):
# TODO: implement the quantized projection for the other proj_type values; currently only 'std' is supported
if self.proj_type == 'std':
if low_rank_grad.shape[0] >= low_rank_grad.shape[1]:
full_rank_grad = torch.matmul(low_rank_grad, self.ortho_matrix.to(low_rank_grad.device.type))
if self.proj_quant:
float_ortho_matrix = self.unpack_int4_projection()
else:
float_ortho_matrix = self.ortho_matrix
full_rank_grad = torch.matmul(low_rank_grad, float_ortho_matrix.to(low_rank_grad.device.type))
else:
full_rank_grad = torch.matmul(self.ortho_matrix.to(low_rank_grad.device.type), low_rank_grad)
if self.proj_quant:
float_ortho_matrix = self.unpack_int4_projection()
else:
float_ortho_matrix = self.ortho_matrix
full_rank_grad = torch.matmul(float_ortho_matrix.to(low_rank_grad.device.type), low_rank_grad)

elif self.proj_type == 'reverse_std':
if low_rank_grad.shape[0] <= low_rank_grad.shape[1]: # note this is different from std
full_rank_grad = torch.matmul(self.ortho_matrix.to(low_rank_grad.device.type), low_rank_grad)
Expand All @@ -64,7 +98,6 @@ def project_back(self, low_rank_grad):

return full_rank_grad * self.scale


# svd decomposition
def get_orthogonal_matrix(self, weights, rank, type):
module_params = weights
Expand All @@ -85,11 +118,24 @@ def get_orthogonal_matrix(self, weights, rank, type):
B = Vh[:rank, :]
if not float_data:
B = B.to(original_device).type(original_type)

if self.proj_quant:
self._quantize(B, q_group_size=self.quant_group_size, n_bit=self.quant_n_bit)
else:
self.ortho_matrix = B

return B

elif type=='left':
A = U[:, :rank]
if not float_data:
A = A.to(original_device).type(original_type)

if self.proj_quant:
self._quantize(A, q_group_size=self.quant_group_size, n_bit=self.quant_n_bit)
else:
self.ortho_matrix = A

return A
elif type=='full':
A = U[:, :rank]
Expand All @@ -100,3 +146,52 @@ def get_orthogonal_matrix(self, weights, rank, type):
return [A, B]
else:
raise ValueError('type should be left, right or full')

def _quantize(self, w, q_group_size=-1, n_bit=4):
    """Quantize ``w`` group-wise and store the packed result on ``self``.

    Every row of the (optionally regrouped) matrix gets its own asymmetric
    scale / zero-point pair. The integer codes are then packed two per byte
    by ``self.pack_uint8_to_int4``.

    Args:
        w: Floating-point tensor to quantize.
        q_group_size: Number of consecutive elements that share one
            scale / zero-point pair. When not positive, ``w`` is used as
            is and must already be 2-D (one group per row).
        n_bit: Bits per quantized value. Must be between 1 and 4, because
            the codes are packed into 4-bit nibbles; a wider code would be
            silently truncated by the packing step.

    Raises:
        ValueError: If ``n_bit`` is outside 1..4, if the element count is
            not a multiple of ``q_group_size``, if the grouped tensor is
            not 2-D, if a row has an odd number of elements (two codes are
            packed per byte), or if ``w`` / the scales contain NaN.

    Side effects:
        Sets ``self.ortho_matrix`` (packed uint8 codes),
        ``self.ortho_matrix_scales``, ``self.ortho_matrix_zeros`` and
        ``self.ortho_matrix_shape`` (the shape of ``w`` before grouping).
    """
    if not 1 <= n_bit <= 4:
        raise ValueError(
            f"n_bit must be between 1 and 4 for nibble packing, got {n_bit}"
        )

    org_w_shape = w.shape
    if q_group_size > 0:
        if w.nelement() % q_group_size != 0:
            raise ValueError(
                f"number of elements ({w.nelement()}) is not divisible by "
                f"q_group_size ({q_group_size})"
            )
        w = w.reshape(-1, q_group_size)

    if w.dim() != 2:
        raise ValueError(f"expected a 2-D tensor to quantize, got {w.dim()}-D")
    if w.shape[1] % 2 != 0:
        raise ValueError(
            f"row length ({w.shape[1]}) must be even to pack two codes per byte"
        )

    max_val = w.amax(dim=1, keepdim=True)
    min_val = w.amin(dim=1, keepdim=True)
    max_int = 2**n_bit - 1
    min_int = 0
    # The clamp keeps the scale away from zero for constant rows.
    scales = (max_val - min_val).clamp(min=1e-5) / max_int
    zeros = (-torch.round(min_val / scales)).clamp_(min_int, max_int)

    if torch.isnan(scales).any() or torch.isnan(w).any():
        raise ValueError("cannot quantize a tensor containing NaN values")

    w = torch.clamp(torch.round(w / scales) + zeros, min_int, max_int).to(torch.uint8)
    packed_w = self.pack_uint8_to_int4(w)

    self.ortho_matrix = packed_w
    self.ortho_matrix_scales = scales
    self.ortho_matrix_zeros = zeros
    self.ortho_matrix_shape = org_w_shape

def pack_uint8_to_int4(self,tensor):
    """Pack adjacent pairs of 4-bit codes from each row into single bytes.

    The first element of every pair becomes the low nibble and the second
    element becomes the high nibble. Only the low four bits of each input
    value are kept.
    """
    pairs = tensor.view(tensor.shape[0], -1, 2)
    low_nibble = pairs[:, :, 0] & 0x0F
    high_nibble = (pairs[:, :, 1] & 0x0F) << 4
    return low_nibble | high_nibble

def unpack_int4_projection(self):
    """Rebuild the floating-point projection matrix from its packed form.

    Splits every byte of ``self.ortho_matrix`` into its low and high
    nibble (low first), applies the stored per-row zero-points and scales,
    and reshapes the result to ``self.ortho_matrix_shape``.
    """
    packed = self.ortho_matrix
    scales = self.ortho_matrix_scales

    low = packed & 0x0F
    high = (packed >> 4) & 0x0F
    # Interleave so each byte yields (low, high) in the original element order.
    codes = torch.stack([low, high], dim=-1).view(packed.shape[0], -1)

    dequantized = scales * (codes.to(scales.dtype) - self.ortho_matrix_zeros)
    return dequantized.reshape(self.ortho_matrix_shape)








16 changes: 16 additions & 0 deletions scripts/benchmark_c4/llama_130m_quant_proj.sh
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
# LLaMA-130M, GaLore-Adam, 1 A100, 1 Node
# Launches torchrun_main.py as a single standalone process with the
# galore_adamw optimizer and the --proj_quant flag set. torchrun_main.py
# forwards that flag to the GaLore parameter group as 'proj_quant'.
torchrun --standalone --nproc_per_node 1 torchrun_main.py \
--model_config configs/llama_130m.json \
--lr 0.01 \
--galore_scale 0.25 \
--rank 256 \
--update_proj_gap 200 \
--batch_size 256 \
--total_batch_size 512 \
--num_training_steps 20000 \
--warmup_steps 2000 \
--weight_decay 0 \
--dtype bfloat16 \
--eval_every 1000 \
--optimizer galore_adamw \
--proj_quant
5 changes: 3 additions & 2 deletions torchrun_main.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,8 @@ def parse_args(args):
parser.add_argument("--update_proj_gap", type=int, default=50)
parser.add_argument("--galore_scale", type=float, default=1.0)
parser.add_argument("--proj_type", type=str, default="std")

parser.add_argument("--proj_quant", default=False, action="store_true")

# disable ddp, single_gpu
parser.add_argument("--single_gpu", default=False, action="store_true")

Expand Down Expand Up @@ -278,7 +279,7 @@ def preprocess_batched(batch):
regular_params = [p for p in model.parameters() if id(p) not in id_galore_params]
# then call galore_adamw
param_groups = [{'params': regular_params},
{'params': galore_params, 'rank': args.rank, 'update_proj_gap': args.update_proj_gap, 'scale': args.galore_scale, 'proj_type': args.proj_type}]
{'params': galore_params, 'rank': args.rank, 'update_proj_gap': args.update_proj_gap, 'scale': args.galore_scale, 'proj_type': args.proj_type, 'proj_quant': args.proj_quant}]

# print params and trainable params
logger.info(f"\n{model}\n")
Expand Down