.gitignore: 3 changes (2 additions, 1 deletion)
@@ -18,4 +18,5 @@ act_scales/
act_shifts/
pre_quantized_models/
temp.sh

.venv/
tmp.ipynb
Binary file modified imgs/omniquant.png
main.py: 10 changes (5 additions, 5 deletions)
@@ -93,10 +93,10 @@ def evaluate(lm, args, logger):

if args.eval_ppl:
# for dataset in ["wikitext2", "ptb", "c4","ptb-new",'c4-new']:
for dataset in ["wikitext2", "c4"]:
for dataset in ["wikitext2"]:
cache_testloader = f'{args.cache_dir}/testloader_{args.model_family}_{dataset}_all.cache'
if os.path.exists(cache_testloader):
testloader = torch.load(cache_testloader)
testloader = torch.load(cache_testloader, weights_only=False)
logger.info(f"load calibration from {cache_testloader}")
else:
dataloader, testloader = get_loaders(
@@ -326,7 +326,7 @@ def main():
# load calibration dataset
cache_dataloader = f'{args.cache_dir}/dataloader_{args.model_family}_{args.calib_dataset}_{args.nsamples}.cache'
if os.path.exists(cache_dataloader):
dataloader = torch.load(cache_dataloader)
dataloader = torch.load(cache_dataloader, weights_only=False)
logger.info(f"load calibration from {cache_dataloader}")
else:
dataloader, _ = get_loaders(
@@ -340,8 +340,8 @@
act_scales = None
act_shifts = None
if args.let:
act_scales = torch.load(args.act_scales)
act_shifts = torch.load(args.act_shifts)
act_scales = torch.load(args.act_scales, weights_only=False)
act_shifts = torch.load(args.act_shifts, weights_only=False)
omniquant(
lm,
args,
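Note on the torch.load changes above: PyTorch 2.6 switched the default of torch.load to weights_only=True, which refuses to unpickle arbitrary Python objects of the kind these cache files can contain, so the PR passes weights_only=False for the locally generated, trusted caches. A minimal sketch of the failure mode and the fix, using a hypothetical cache file name:

import torch

class CalibBatch:  # hypothetical stand-in for a non-tensor object stored inside a cache file
    def __init__(self, ids):
        self.ids = ids

torch.save([CalibBatch(torch.arange(4))], "tmp_loader.cache")

# With PyTorch >= 2.6 the default weights_only=True rejects arbitrary pickled classes
# and raises an UnpicklingError for a cache like this one:
# torch.load("tmp_loader.cache")
# Passing weights_only=False restores the pre-2.6 behaviour for trusted local files.
loader = torch.load("tmp_loader.cache", weights_only=False)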
models/int_llama_layer.py: 87 changes (49 additions, 38 deletions)
@@ -60,14 +60,21 @@ def __init__(self,
self.num_key_value_heads = config.num_key_value_heads
self.num_key_value_groups = self.num_heads // self.num_key_value_heads
self.max_position_embeddings = config.max_position_embeddings

self.layer_idx = getattr(org_module, "layer_idx", None)

if (self.head_dim * self.num_heads) != self.hidden_size:
raise ValueError(
f"hidden_size must be divisible by num_heads (got `hidden_size`: {self.hidden_size}"
f" and `num_heads`: {self.num_heads})."
)

self.rotary_emb = copy.deepcopy(org_module.rotary_emb)
if hasattr(org_module, "rotary_emb"):
self.rotary_emb = copy.deepcopy(org_module.rotary_emb)
else:
# In newer transformers versions the rotary embedding is computed at the model level
# and passed in via position_embeddings, so there is nothing to copy here.
self.rotary_emb = None

self.k_proj = QuantLinear(
org_module.k_proj,
@@ -103,12 +110,12 @@ def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int):
def forward(
self,
hidden_states: torch.Tensor,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_value: Optional[Tuple[torch.Tensor]] = None,
output_attentions: bool = False,
use_cache: bool = False,
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
position_embeddings: tuple[torch.Tensor, torch.Tensor] | None = None,
attention_mask: torch.Tensor | None = None,
past_key_values=None, # Cache
cache_position: torch.LongTensor | None = None,
**kwargs,
) -> tuple[torch.Tensor, torch.Tensor]:
bsz, q_len, _ = hidden_states.size()

# query_states = self.q_proj(hidden_states)
@@ -117,22 +124,22 @@ def forward(
query_states = self.q_proj(hidden_states).view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
key_states = self.k_proj(hidden_states).view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
value_states = self.v_proj(hidden_states).view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)

kv_seq_len = key_states.shape[-2]
if past_key_value is not None:
kv_seq_len += past_key_value[0].shape[-2]
cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)

# New API for rotary embedding
cos, sin = position_embeddings
query_states, key_states = apply_rotary_pos_emb(
query_states, key_states, cos, sin
)

if past_key_values is not None:
key_states, value_states = past_key_values.update(
key_states, value_states, self.layer_idx
)
kv_seq_len = key_states.shape[-2]


# [bsz, nh, t, hd]

if past_key_value is not None:
# reuse k, v, self_attention
key_states = torch.cat([past_key_value[0], key_states], dim=2)
value_states = torch.cat([past_key_value[1], value_states], dim=2)

past_key_value = (key_states, value_states) if use_cache else None
# past_key_value = (key_states, value_states) if use_cache else None

# repeat k/v heads if n_kv_heads < n_heads
key_states = repeat_kv(key_states, self.num_key_value_groups)
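Aside on the cache path introduced above: in recent transformers releases past_key_values is a Cache object (for example DynamicCache) whose update() appends this layer's key/value states and returns the full history accumulated so far, and the rotary (cos, sin) tables arrive precomputed through position_embeddings instead of being built from position_ids inside the layer. A small self-contained sketch of the cache call, assuming transformers.cache_utils.DynamicCache exists in the installed version:

import torch
from transformers.cache_utils import DynamicCache  # present in recent transformers releases

bsz, n_kv_heads, q_len, head_dim = 1, 8, 4, 64
key_states = torch.randn(bsz, n_kv_heads, q_len, head_dim)
value_states = torch.randn(bsz, n_kv_heads, q_len, head_dim)

past_key_values = DynamicCache()
layer_idx = 0  # the index this layer was constructed with
# update() stores the new K/V for this layer and returns the concatenated history.
key_states, value_states = past_key_values.update(key_states, value_states, layer_idx)
print(key_states.shape)  # torch.Size([1, 8, 4, 64]) on the first call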
@@ -173,10 +180,10 @@ def forward(

attn_output = self.o_proj(attn_output)

if not output_attentions:
attn_weights = None
# if not output_attentions:
# attn_weights = None

return attn_output, attn_weights, past_key_value
return attn_output, attn_weights #, past_key_value

def set_quant_state(self, weight_quant: bool = False, act_quant: bool = False):
# setting weight quantization here does not affect actual forward pass
@@ -213,12 +220,14 @@ def __init__(self,
def forward(
self,
hidden_states: torch.Tensor,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_value: Optional[Tuple[torch.Tensor]] = None,
output_attentions: Optional[bool] = False,
use_cache: Optional[bool] = False,
) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
attention_mask: torch.Tensor | None = None,
position_ids: torch.LongTensor | None = None,
past_key_values=None, # Cache
use_cache: bool | None = False,
cache_position: torch.LongTensor | None = None,
position_embeddings: tuple[torch.Tensor, torch.Tensor] | None = None,
**kwargs,
) -> torch.Tensor:
"""
Args:
hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)`
@@ -238,13 +247,15 @@ def forward(


# Self Attention
hidden_states, self_attn_weights, present_key_value = self.self_attn(
hidden_states, _ = self.self_attn(
hidden_states=hidden_states,
attention_mask=attention_mask,
position_ids=position_ids,
past_key_value=past_key_value,
output_attentions=output_attentions,
past_key_values=past_key_values,
use_cache=use_cache,
cache_position=cache_position,
position_embeddings=position_embeddings,
**kwargs,
)
hidden_states = residual + hidden_states

@@ -256,15 +267,15 @@
hidden_states = self.mlp(hidden_states)
hidden_states = residual + hidden_states

outputs = (hidden_states,)
# outputs = (hidden_states,)

if output_attentions:
outputs += (self_attn_weights,)
# if output_attentions:
# outputs += (self_attn_weights,)

if use_cache:
outputs += (present_key_value,)
# if use_cache:
# outputs += (present_key_value,)

return outputs
return hidden_states

def set_quant_state(self, weight_quant: bool = False, act_quant: bool = False):
# setting weight quantization here does not affect actual forward pass
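Because QuantLlamaDecoderLayer.forward now returns the hidden_states tensor directly instead of the old (hidden_states, attn_weights, present_key_value) tuple, callers that index the output with [0] will now slice the batch dimension rather than pick the first tuple element. A hypothetical compatibility helper (name and placement assumed, not part of this PR) that tolerates both return styles:

import torch

def first_hidden_state(layer_out):
    # Old decoder layers return a tuple whose first element is hidden_states;
    # the updated layer returns the (batch, seq_len, hidden_size) tensor directly.
    return layer_out[0] if isinstance(layer_out, tuple) else layer_out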
quantize/omniquant.py: 4 changes (3 additions, 1 deletion)
@@ -214,7 +214,9 @@ def forward(self, inp, **kwargs):
if args.aug_loss:
fp_inps_2[j] = qlayer(quant_inps[j].unsqueeze(0), attention_mask=attention_mask,position_ids=position_ids)[0]
# init smooth parameters
set_quant_state(qlayer, weight_quant=False, act_quant=True) # weight will be manually quantized before forward
# Issue OpenGVLab/OmniQuant#113
if args.abits < 16:
set_quant_state(qlayer, weight_quant=False, act_quant=True) # weight will be manually quantized before forward
qlayer.let = args.let
use_shift = True
if is_llama or args.abits == 16:
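The guard added in omniquant.py (see OpenGVLab/OmniQuant#113) enables activation fake-quant during block-wise calibration only when activations are actually being quantized, so weight-only configurations such as W4A16 keep the activation path in full precision. A minimal sketch of the intent, with set_quant_state taken from the diff and the argument handling assumed:

def prepare_block_for_calibration(qlayer, args, set_quant_state):
    # Weights are quantized manually before each forward pass, so weight_quant stays False.
    # Activation fake-quant is only enabled for configurations with abits < 16.
    if args.abits < 16:
        set_quant_state(qlayer, weight_quant=False, act_quant=True)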