diff --git a/.gitignore b/.gitignore
index 3b9288f..dd98a8e 100644
--- a/.gitignore
+++ b/.gitignore
@@ -18,4 +18,5 @@ act_scales/
 act_shifts/
 pre_quantized_models/
 temp.sh
-
+.venv/
+tmp.ipynb
\ No newline at end of file
diff --git a/imgs/omniquant.png b/imgs/omniquant.png
index 84b92b6..651d0ec 100644
Binary files a/imgs/omniquant.png and b/imgs/omniquant.png differ
diff --git a/main.py b/main.py
index 63a8c7b..c24b45f 100644
--- a/main.py
+++ b/main.py
@@ -93,10 +93,10 @@ def evaluate(lm, args, logger):
     if args.eval_ppl:
         # for dataset in ["wikitext2", "ptb", "c4","ptb-new",'c4-new']:
-        for dataset in ["wikitext2", "c4"]:
+        for dataset in ["wikitext2"]:
             cache_testloader = f'{args.cache_dir}/testloader_{args.model_family}_{dataset}_all.cache'
             if os.path.exists(cache_testloader):
-                testloader = torch.load(cache_testloader)
+                testloader = torch.load(cache_testloader, weights_only=False)
                 logger.info(f"load calibration from {cache_testloader}")
             else:
                 dataloader, testloader = get_loaders(
@@ -326,7 +326,7 @@ def main():
     # load calibration dataset
     cache_dataloader = f'{args.cache_dir}/dataloader_{args.model_family}_{args.calib_dataset}_{args.nsamples}.cache'
     if os.path.exists(cache_dataloader):
-        dataloader = torch.load(cache_dataloader)
+        dataloader = torch.load(cache_dataloader, weights_only=False)
         logger.info(f"load calibration from {cache_dataloader}")
     else:
         dataloader, _ = get_loaders(
@@ -340,8 +340,8 @@ def main():
     act_scales = None
     act_shifts = None
     if args.let:
-        act_scales = torch.load(args.act_scales)
-        act_shifts = torch.load(args.act_shifts)
+        act_scales = torch.load(args.act_scales, weights_only=False)
+        act_shifts = torch.load(args.act_shifts, weights_only=False)
     omniquant(
         lm,
         args,
diff --git a/models/int_llama_layer.py b/models/int_llama_layer.py
index 654e61b..303c471 100644
--- a/models/int_llama_layer.py
+++ b/models/int_llama_layer.py
@@ -60,6 +60,8 @@ def __init__(self,
         self.num_key_value_heads = config.num_key_value_heads
         self.num_key_value_groups = self.num_heads // self.num_key_value_heads
         self.max_position_embeddings = config.max_position_embeddings
+
+        self.layer_idx = getattr(org_module, "layer_idx", None)
 
         if (self.head_dim * self.num_heads) != self.hidden_size:
             raise ValueError(
@@ -67,7 +69,12 @@ def __init__(self,
                 f" and `num_heads`: {self.num_heads})."
             )
-        self.rotary_emb = copy.deepcopy(org_module.rotary_emb)
+        if hasattr(org_module, "rotary_emb"):
+            self.rotary_emb = copy.deepcopy(org_module.rotary_emb)
+        else:
+            # In newer versions, it might be in the model config or
+            # handled globally, so we might set it to None or a placeholder
+            self.rotary_emb = None
 
         self.k_proj = QuantLinear(
             org_module.k_proj,
@@ -103,12 +110,12 @@ def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int):
     def forward(
         self,
         hidden_states: torch.Tensor,
-        attention_mask: Optional[torch.Tensor] = None,
-        position_ids: Optional[torch.LongTensor] = None,
-        past_key_value: Optional[Tuple[torch.Tensor]] = None,
-        output_attentions: bool = False,
-        use_cache: bool = False,
-    ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
+        position_embeddings: tuple[torch.Tensor, torch.Tensor] | None = None,
+        attention_mask: torch.Tensor | None = None,
+        past_key_values=None,  # Cache
+        cache_position: torch.LongTensor | None = None,
+        **kwargs,
+    ) -> tuple[torch.Tensor, torch.Tensor]:
         bsz, q_len, _ = hidden_states.size()
 
         # query_states = self.q_proj(hidden_states)
@@ -117,22 +124,22 @@ def forward(
         query_states = self.q_proj(hidden_states).view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
         key_states =self.k_proj(hidden_states).view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
         value_states = self.v_proj(hidden_states).view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
-
-        kv_seq_len = key_states.shape[-2]
-        if past_key_value is not None:
-            kv_seq_len += past_key_value[0].shape[-2]
-        cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
-        query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)
+
+        # New API for rotary embedding
+        cos, sin = position_embeddings
+        query_states, key_states = apply_rotary_pos_emb(
+            query_states, key_states, cos, sin
+        )
+
+        if past_key_values is not None:
+            key_states, value_states = past_key_values.update(
+                key_states, value_states, self.layer_idx
+            )
+            kv_seq_len = key_states.shape[-2]
         # [bsz, nh, t, hd]
-
-        if past_key_value is not None:
-            # reuse k, v, self_attention
-            key_states = torch.cat([past_key_value[0], key_states], dim=2)
-            value_states = torch.cat([past_key_value[1], value_states], dim=2)
-
-        past_key_value = (key_states, value_states) if use_cache else None
+        # past_key_value = (key_states, value_states) if use_cache else None
 
         # repeat k/v heads if n_kv_heads < n_heads
         key_states = repeat_kv(key_states, self.num_key_value_groups)
@@ -173,10 +180,10 @@ def forward(
 
         attn_output = self.o_proj(attn_output)
 
-        if not output_attentions:
-            attn_weights = None
+        # if not output_attentions:
+        #     attn_weights = None
 
-        return attn_output, attn_weights, past_key_value
+        return attn_output, attn_weights  #, past_key_value
 
     def set_quant_state(self, weight_quant: bool = False, act_quant: bool = False):
         # setting weight quantization here does not affect actual forward pass
@@ -213,12 +220,14 @@ def __init__(self,
     def forward(
         self,
         hidden_states: torch.Tensor,
-        attention_mask: Optional[torch.Tensor] = None,
-        position_ids: Optional[torch.LongTensor] = None,
-        past_key_value: Optional[Tuple[torch.Tensor]] = None,
-        output_attentions: Optional[bool] = False,
-        use_cache: Optional[bool] = False,
-    ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
+        attention_mask: torch.Tensor | None = None,
+        position_ids: torch.LongTensor | None = None,
+        past_key_values=None,  # Cache
+        use_cache: bool | None = False,
+        cache_position: torch.LongTensor | None = None,
+        position_embeddings: tuple[torch.Tensor, torch.Tensor] | None = None,
+        **kwargs,
+    ) -> torch.Tensor:
         """
         Args:
             hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)`
@@ -238,13 +247,15 @@ def forward(
 
         # Self Attention
-        hidden_states, self_attn_weights, present_key_value = self.self_attn(
+        hidden_states, _ = self.self_attn(
             hidden_states=hidden_states,
             attention_mask=attention_mask,
             position_ids=position_ids,
-            past_key_value=past_key_value,
-            output_attentions=output_attentions,
+            past_key_values=past_key_values,
             use_cache=use_cache,
+            cache_position=cache_position,
+            position_embeddings=position_embeddings,
+            **kwargs,
         )
         hidden_states = residual + hidden_states
@@ -256,15 +267,15 @@ def forward(
         hidden_states = self.mlp(hidden_states)
         hidden_states = residual + hidden_states
 
-        outputs = (hidden_states,)
+        # outputs = (hidden_states,)
 
-        if output_attentions:
-            outputs += (self_attn_weights,)
+        # if output_attentions:
+        #     outputs += (self_attn_weights,)
 
-        if use_cache:
-            outputs += (present_key_value,)
+        # if use_cache:
+        #     outputs += (present_key_value,)
 
-        return outputs
+        return hidden_states
 
     def set_quant_state(self, weight_quant: bool = False, act_quant: bool = False):
         # setting weight quantization here does not affect actual forward pass
diff --git a/quantize/omniquant.py b/quantize/omniquant.py
index c4af3dd..a54883b 100644
--- a/quantize/omniquant.py
+++ b/quantize/omniquant.py
@@ -214,7 +214,9 @@ def forward(self, inp, **kwargs):
                         if args.aug_loss:
                             fp_inps_2[j] = qlayer(quant_inps[j].unsqueeze(0), attention_mask=attention_mask,position_ids=position_ids)[0]
         # init smooth parameters
-        set_quant_state(qlayer, weight_quant=False, act_quant=True)  # weight will be manually quantized before forward
+        # Issue OpenGVLab/OmniQuant#113
+        if args.abits < 16:
+            set_quant_state(qlayer, weight_quant=False, act_quant=True)  # weight will be manually quantized before forward
         qlayer.let = args.let
         use_shift = True
         if is_llama or args.abits == 16:
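
For reference, a minimal standalone sketch (not part of the patch) of the upstream `transformers` APIs that the patched `QuantLlamaAttention.forward` builds on, assuming transformers >= 4.44: the model-level rotary embedding that produces the `(cos, sin)` pair passed down as `position_embeddings`, and `DynamicCache.update()`, which replaces the old `(key, value)` tuple concatenation. The config sizes below are arbitrary placeholders.

```python
# Minimal sketch, assuming transformers >= 4.44; sizes are placeholders.
import torch
from transformers import LlamaConfig
from transformers.cache_utils import DynamicCache
from transformers.models.llama.modeling_llama import (
    LlamaRotaryEmbedding,
    apply_rotary_pos_emb,
)

config = LlamaConfig(hidden_size=64, num_attention_heads=4, num_key_value_heads=2)
head_dim = config.hidden_size // config.num_attention_heads

rotary_emb = LlamaRotaryEmbedding(config=config)
hidden_states = torch.randn(1, 8, config.hidden_size)
position_ids = torch.arange(8).unsqueeze(0)
# In newer transformers, the model computes (cos, sin) once per forward and
# hands them to every decoder layer as `position_embeddings`.
cos, sin = rotary_emb(hidden_states, position_ids)

q = torch.randn(1, config.num_attention_heads, 8, head_dim)
k = torch.randn(1, config.num_key_value_heads, 8, head_dim)
v = torch.randn(1, config.num_key_value_heads, 8, head_dim)
# No position_ids argument anymore; cos/sin already encode the positions.
q, k = apply_rotary_pos_emb(q, k, cos, sin)

cache = DynamicCache()
# update() stores this layer's states and returns the full cached key/value,
# which the patched forward assigns back to key_states/value_states.
k_all, v_all = cache.update(k, v, layer_idx=0)
print(k_all.shape)  # torch.Size([1, 2, 8, 16])
```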