# Source recovered from a CPython 3.10 bytecode (.pyc) text dump of
# /var/www/html/chatgem/venv/lib/python3.10/site-packages/keras/src/backend/jax/nn.py
# (the Keras 3 JAX backend neural-network ops). Function bodies below were
# reconstructed from the string constants, symbol tables, and docstrings that
# survive in the dump and follow the upstream Keras implementation.

import builtins
import inspect
import math

import jax
import jax.experimental.sparse as jax_sparse
import jax.numpy as jnp
from jax import lax
from jax import nn as jnn
from jax.experimental.pallas.ops.tpu.splash_attention import (
    splash_attention_kernel,
)
from jax.experimental.pallas.ops.tpu.splash_attention import (
    splash_attention_mask,
)

from keras.src import backend
from keras.src.backend.common.backend_utils import (
    compute_conv_transpose_padding_args_for_jax,
)
from keras.src.backend.jax.core import cast
from keras.src.backend.jax.core import convert_to_tensor


def relu(x):
    x = convert_to_tensor(x)
    return jnn.relu(x)
def relu6(x):
    x = convert_to_tensor(x)
    return jnn.relu6(x)


def sigmoid(x):
    x = convert_to_tensor(x)
    return jnn.sigmoid(x)


def sparse_sigmoid(x):
    x = convert_to_tensor(x)
    return jnn.sparse_sigmoid(x)


def tanh(x):
    x = convert_to_tensor(x)
    return jnn.tanh(x)


def tanh_shrink(x):
    x = convert_to_tensor(x)
    return x - jnp.tanh(x)


def softplus(x):
    x = convert_to_tensor(x)
    return jnn.softplus(x)


def softsign(x):
    x = convert_to_tensor(x)
    return jnn.soft_sign(x)


def soft_shrink(x, threshold=0.5):
    x = convert_to_tensor(x)
    return jnp.where(
        x > threshold,
        x - threshold,
        jnp.where(x < -threshold, x + threshold, 0.0),
    )


def sparse_plus(x):
    x = convert_to_tensor(x)
    return jnn.sparse_plus(x)


def silu(x):
    x = convert_to_tensor(x)
    return jnn.silu(x)


def squareplus(x, b=4):
    x = convert_to_tensor(x)
    return jnn.squareplus(x, b=b)


def log_sigmoid(x):
    x = convert_to_tensor(x)
    return jnn.log_sigmoid(x)


def leaky_relu(x, negative_slope=0.2):
    x = convert_to_tensor(x)
    return jnn.leaky_relu(x, negative_slope=negative_slope)


def hard_sigmoid(x):
    x = convert_to_tensor(x)
    return jnn.hard_sigmoid(x)


def hard_silu(x):
    x = convert_to_tensor(x)
    return jnn.hard_silu(x)


def elu(x, alpha=1.0):
    x = convert_to_tensor(x)
    return jnn.elu(x, alpha=alpha)


def selu(x):
    x = convert_to_tensor(x)
    return jnn.selu(x)


def gelu(x, approximate=True):
    x = convert_to_tensor(x)
    return jnn.gelu(x, approximate)


def celu(x, alpha=1.0):
    x = convert_to_tensor(x)
    return jnn.celu(x, alpha=alpha)


def glu(x, axis=-1):
    x = convert_to_tensor(x)
    return jnn.glu(x, axis=axis)


def hard_tanh(x):
    x = convert_to_tensor(x)
    return jnn.hard_tanh(x)


def hard_shrink(x, threshold=0.5):
    x = convert_to_tensor(x)
    return jnp.where(jnp.abs(x) > threshold, x, 0.0)


def threshold(x, threshold, default_value):
    x = convert_to_tensor(x)
    return jnp.where(x > threshold, x, default_value)


def softmax(x, axis=-1):
    x = convert_to_tensor(x)
    return jnn.softmax(x, axis=axis)


def log_softmax(x, axis=-1):
    x = convert_to_tensor(x)
    return jnn.log_softmax(x, axis=axis)


def sparsemax(logits, axis=-1):
    # Sort logits along the specified axis in descending order.
    logits = convert_to_tensor(logits)
    logits_sorted = -1.0 * jnp.sort(logits * -1.0, axis=axis)
    logits_cumsum = jnp.cumsum(logits_sorted, axis=axis)
    r = jnp.arange(1, logits.shape[axis] + 1)
    r_shape = [1] * logits.ndim
    r_shape[axis] = -1  # Broadcast to match the target axis.
    r = r.reshape(r_shape)
    support = logits_sorted - (logits_cumsum - 1) / r > 0
    # Find the threshold tau and project onto the simplex.
    k = jnp.sum(support, axis=axis, keepdims=True)
    logits_cumsum_safe = jnp.where(support, logits_cumsum, 0.0)
    tau = (jnp.sum(logits_cumsum_safe, axis=axis, keepdims=True) - 1) / k
    output = jnp.maximum(logits - tau, 0.0)
    return output


def _convert_to_spatial_operand(
    x,
    num_spatial_dims,
    data_format="channels_last",
    include_batch_and_channels=True,
):
    # Broadcast an int to a per-dimension tuple, then prepend/append the
    # batch and channel dimensions as required by `lax.reduce_window`.
    x = (x,) * num_spatial_dims if isinstance(x, int) else x
    if not include_batch_and_channels:
        return x
    if data_format == "channels_last":
        x = (1,) + x + (1,)
    else:
        x = (1, 1) + x
    return x


def _pool(
    inputs,
    initial_value,
    reduce_fn,
    pool_size,
    strides=None,
    padding="valid",
):
    """Helper function to define pooling functions.

    Args:
        inputs: input data of shape `N+2`.
        initial_value: the initial value for the reduction.
        reduce_fn: a reduce function of the form `(T, T) -> T`.
        pool_size: a sequence of `N` integers, representing the window size to
            reduce over.
        strides: a sequence of `N` integers, representing the inter-window
            strides (default: `(1, ..., 1)`).
        padding: either the string `same` or `valid`.

    Returns:
        The output of the reduction for each window slice.
    """
    if padding not in ("same", "valid"):
        raise ValueError(
            f"Invalid padding '{padding}', must be 'same' or 'valid'."
        )
    padding = padding.upper()
    return lax.reduce_window(
        inputs,
        initial_value,
        reduce_fn,
        pool_size,
        strides,
        padding,
    )
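# Editor's note: usage sketch, not part of the recovered module. `_pool` is
# the shared `lax.reduce_window` core; `max_pool` and `average_pool` below
# specialize it with (-inf, lax.max) and (0.0, lax.add) respectively:
#
#   x = jnp.ones((1, 4, 4, 3))  # NHWC input
#   y = _pool(x, -jnp.inf, lax.max, (1, 2, 2, 1), (1, 2, 2, 1), "valid")
#   y.shape  # -> (1, 2, 2, 3)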
def max_pool(
    inputs,
    pool_size,
    strides=None,
    padding="valid",
    data_format=None,
):
    data_format = backend.standardize_data_format(data_format)
    num_spatial_dims = inputs.ndim - 2
    pool_size = _convert_to_spatial_operand(
        pool_size, num_spatial_dims, data_format
    )
    strides = pool_size if strides is None else strides
    strides = _convert_to_spatial_operand(
        strides, num_spatial_dims, data_format
    )
    return _pool(inputs, -jnp.inf, lax.max, pool_size, strides, padding)


def average_pool(
    inputs,
    pool_size,
    strides=None,
    padding="valid",
    data_format=None,
):
    data_format = backend.standardize_data_format(data_format)
    num_spatial_dims = inputs.ndim - 2
    pool_size = _convert_to_spatial_operand(
        pool_size, num_spatial_dims, data_format
    )
    strides = pool_size if strides is None else strides
    strides = _convert_to_spatial_operand(
        strides, num_spatial_dims, data_format
    )

    pooled = _pool(inputs, 0.0, lax.add, pool_size, strides, padding)
    if padding == "valid":
        # No padding was added, so the window size is constant everywhere.
        return pooled / math.prod(pool_size)
    else:
        # Count the number of valid entries at each output position, then
        # divide by that count. Avoid broadcasting on axes where pooling is
        # skipped (window size 1).
        shape = [
            (a if b != 1 else 1) for (a, b) in zip(inputs.shape, pool_size)
        ]
        window_counts = _pool(
            jnp.ones(shape, inputs.dtype),
            0.0,
            lax.add,
            pool_size,
            strides,
            padding,
        )
        return pooled / window_counts


def _convert_to_lax_conv_dimension_numbers(
    num_spatial_dims,
    data_format="channels_last",
    transpose=False,
):
    """Create a `lax.ConvDimensionNumbers` for the given inputs."""
    num_dims = num_spatial_dims + 2

    if data_format == "channels_last":
        spatial_dims = tuple(range(1, num_dims - 1))
        inputs_dn = (0, num_dims - 1) + spatial_dims
    else:
        spatial_dims = tuple(range(2, num_dims))
        inputs_dn = (0, 1) + spatial_dims

    if transpose:
        kernel_dn = (num_dims - 2, num_dims - 1) + tuple(range(num_dims - 2))
    else:
        kernel_dn = (num_dims - 1, num_dims - 2) + tuple(range(num_dims - 2))

    return lax.ConvDimensionNumbers(
        lhs_spec=inputs_dn, rhs_spec=kernel_dn, out_spec=inputs_dn
    )


def conv(
    inputs,
    kernel,
    strides=1,
    padding="valid",
    data_format=None,
    dilation_rate=1,
):
    data_format = backend.standardize_data_format(data_format)
    num_spatial_dims = inputs.ndim - 2
    dimension_numbers = _convert_to_lax_conv_dimension_numbers(
        num_spatial_dims,
        data_format,
        transpose=False,
    )
    strides = _convert_to_spatial_operand(
        strides,
        num_spatial_dims,
        data_format,
        include_batch_and_channels=False,
    )
    dilation_rate = _convert_to_spatial_operand(
        dilation_rate,
        num_spatial_dims,
        data_format,
        include_batch_and_channels=False,
    )
    if data_format == "channels_last":
        channels = inputs.shape[-1]
    else:
        channels = inputs.shape[1]
    kernel_in_channels = kernel.shape[-2]
    if channels % kernel_in_channels > 0:
        raise ValueError(
            "The number of input channels must be evenly divisible by "
            f"kernel's in_channels. Received input channels {channels} and "
            f"kernel in_channels {kernel_in_channels}. "
        )
    feature_group_count = channels // kernel_in_channels
    kernel = convert_to_tensor(kernel)
    inputs = convert_to_tensor(inputs, dtype=kernel.dtype)
    return jax.lax.conv_general_dilated(
        inputs,
        kernel,
        strides,
        padding,
        rhs_dilation=dilation_rate,
        dimension_numbers=dimension_numbers,
        feature_group_count=feature_group_count,
    )


def depthwise_conv(
    inputs,
    kernel,
    strides=1,
    padding="valid",
    data_format=None,
    dilation_rate=1,
):
    data_format = backend.standardize_data_format(data_format)
    num_spatial_dims = inputs.ndim - 2
    dimension_numbers = _convert_to_lax_conv_dimension_numbers(
        num_spatial_dims,
        data_format,
        transpose=False,
    )
    strides = _convert_to_spatial_operand(
        strides,
        num_spatial_dims,
        data_format,
        include_batch_and_channels=False,
    )
    dilation_rate = _convert_to_spatial_operand(
        dilation_rate,
        num_spatial_dims,
        data_format,
        include_batch_and_channels=False,
    )
    feature_group_count = (
        inputs.shape[-1] if data_format == "channels_last" else inputs.shape[1]
    )
    # Fold the depth multiplier into the output-channel dimension so the
    # depthwise case can reuse the grouped-convolution path.
    kernel = jnp.reshape(
        kernel,
        kernel.shape[:-2] + (1, feature_group_count * kernel.shape[-1]),
    )
    return jax.lax.conv_general_dilated(
        inputs,
        kernel,
        strides,
        padding,
        rhs_dilation=dilation_rate,
        dimension_numbers=dimension_numbers,
        feature_group_count=feature_group_count,
    )


def separable_conv(
    inputs,
    depthwise_kernel,
    pointwise_kernel,
    strides=1,
    padding="valid",
    data_format=None,
    dilation_rate=1,
):
    data_format = backend.standardize_data_format(data_format)
    depthwise_conv_output = depthwise_conv(
        inputs,
        depthwise_kernel,
        strides,
        padding,
        data_format,
        dilation_rate,
    )
    return conv(
        depthwise_conv_output,
        pointwise_kernel,
        strides=1,
        padding="valid",
        data_format=data_format,
        dilation_rate=dilation_rate,
    )


def conv_transpose(
    inputs,
    kernel,
    strides=1,
    padding="valid",
    output_padding=None,
    data_format=None,
    dilation_rate=1,
):
    data_format = backend.standardize_data_format(data_format)
    num_spatial_dims = inputs.ndim - 2
    padding_values = compute_conv_transpose_padding_args_for_jax(
        input_shape=inputs.shape,
        kernel_shape=kernel.shape,
        strides=strides,
        padding=padding,
        output_padding=output_padding,
        dilation_rate=dilation_rate,
    )
    dimension_numbers = _convert_to_lax_conv_dimension_numbers(
        num_spatial_dims,
        data_format,
        transpose=False,
    )
    strides = _convert_to_spatial_operand(
        strides,
        num_spatial_dims,
        data_format,
        include_batch_and_channels=False,
    )
    dilation_rate = _convert_to_spatial_operand(
        dilation_rate,
        num_spatial_dims,
        data_format,
        include_batch_and_channels=False,
    )
    return jax.lax.conv_transpose(
        inputs,
        kernel,
        strides,
        padding=padding_values,
        rhs_dilation=dilation_rate,
        dimension_numbers=dimension_numbers,
        transpose_kernel=True,
    )
def one_hot(x, num_classes, axis=-1, dtype=None, sparse=False):
    x = convert_to_tensor(x)
    if sparse:
        if axis < 0:
            axis = axis + len(x.shape) + 1
        if dtype is None:
            dtype = "float32"
        # Negative inputs produce zero rows in the output; keeping them makes
        # the result shape static.
        values = jnp.greater_equal(jnp.ravel(x), 0).astype(dtype)
        values_count = values.shape[0]
        indices = [jnp.arange(dim) for dim in x.shape]
        indices = jnp.meshgrid(*indices, indexing="ij")
        indices.insert(axis, jnp.maximum(x, 0))  # Deal with negative indices.
        indices = [a.reshape(values_count, 1).astype("int32") for a in indices]
        indices = jnp.concatenate(indices, axis=1)
        shape = list(x.shape)
        shape.insert(axis, num_classes)
        return jax_sparse.BCOO(
            (values, indices),
            shape=tuple(shape),
            indices_sorted=True,
            unique_indices=True,
        )
    return jnn.one_hot(x, num_classes, axis=axis, dtype=dtype)


def multi_hot(x, num_classes, axis=-1, dtype=None, sparse=False):
    x = convert_to_tensor(x)
    reduction_axis = 1 if len(x.shape) > 1 else 0
    if sparse:
        result = one_hot(x, num_classes, axis=axis, dtype="int32", sparse=True)
        # BCOO does not support a max reduction: sum duplicates, then compare.
        result = jax_sparse.bcoo_reduce_sum(result, axes=(reduction_axis,))
        result = jax_sparse.bcoo_sum_duplicates(result)
        values = jnp.greater_equal(result.data, jnp.asarray(1)).astype(dtype)
        return jax_sparse.BCOO(
            (values, result.indices),
            shape=result.shape,
            indices_sorted=True,
            unique_indices=True,
        )
    return jnp.max(
        one_hot(cast(x, "int32"), num_classes, axis=axis, dtype=dtype),
        axis=reduction_axis,
    )


def categorical_crossentropy(target, output, from_logits=False, axis=-1):
    target = jnp.array(target)
    output = jnp.array(output)

    if target.shape != output.shape:
        raise ValueError(
            "Arguments `target` and `output` must have the same shape. "
            f"Received: target.shape={target.shape}, "
            f"output.shape={output.shape}"
        )
    if len(target.shape) < 1:
        raise ValueError(
            "Arguments `target` and `output` must be at least rank 1. "
            f"Received: target.shape={target.shape}, "
            f"output.shape={output.shape}"
        )

    if from_logits:
        log_prob = jax.nn.log_softmax(output, axis=axis)
    else:
        output = output / jnp.sum(output, axis, keepdims=True)
        output = jnp.clip(output, backend.epsilon(), 1.0 - backend.epsilon())
        log_prob = jnp.log(output)
    return -jnp.sum(target * log_prob, axis=axis)


def sparse_categorical_crossentropy(target, output, from_logits=False, axis=-1):
    target = jnp.array(target, dtype="int32")
    output = jnp.array(output)
    if len(target.shape) == len(output.shape) and target.shape[-1] == 1:
        target = jnp.squeeze(target, axis=-1)

    if len(output.shape) < 1:
        raise ValueError(
            "Argument `output` must be at least rank 1. "
            f"Received: output.shape={output.shape}"
        )
    if target.shape != output.shape[:-1]:
        raise ValueError(
            "Arguments `target` and `output` must have the same shape "
            f"up until the last dimension: target.shape={target.shape}, "
            f"output.shape={output.shape}"
        )
    if from_logits:
        log_prob = jax.nn.log_softmax(output, axis=axis)
    else:
        output = output / jnp.sum(output, axis, keepdims=True)
        output = jnp.clip(output, backend.epsilon(), 1.0 - backend.epsilon())
        log_prob = jnp.log(output)
    target = jnn.one_hot(target, output.shape[axis], axis=axis)
    return -jnp.sum(target * log_prob, axis=axis)


def binary_crossentropy(target, output, from_logits=False):
    target = jnp.array(target)
    output = jnp.array(output)

    if target.shape != output.shape:
        raise ValueError(
            "Arguments `target` and `output` must have the same shape. "
            f"Received: target.shape={target.shape}, "
            f"output.shape={output.shape}"
        )

    if from_logits:
        log_logits = jax.nn.log_sigmoid(output)
        log_neg_logits = jax.nn.log_sigmoid(-output)
        return -1.0 * (target * log_logits + (1.0 - target) * log_neg_logits)

    output = jnp.clip(output, backend.epsilon(), 1.0 - backend.epsilon())
    bce = target * jnp.log(output)
    bce += (1.0 - target) * jnp.log(1.0 - output)
    return -bce


def moments(x, axes, keepdims=False, synchronized=False):
    if synchronized:
        raise NotImplementedError(
            "Argument synchronized=True is not supported with JAX."
        )
    need_cast = False
    ori_dtype = backend.standardize_dtype(x.dtype)
    if ori_dtype in ("float16", "bfloat16"):
        # Promote to float32 to avoid overflow in the reductions.
        need_cast = True
        x = cast(x, "float32")

    mean = jnp.mean(x, axes, keepdims=True)
    variance = jnp.var(x, axis=axes, keepdims=True)

    if not keepdims:
        mean = jnp.squeeze(mean, axes)
        variance = jnp.squeeze(variance, axes)
    if need_cast:
        # Avoid overflow and underflow when casting back to float16.
        mean = jnp.clip(
            mean, jnp.finfo(jnp.float16).min, jnp.finfo(jnp.float16).max
        )
        variance = jnp.clip(
            variance, jnp.finfo(jnp.float16).min, jnp.finfo(jnp.float16).max
        )
        mean = cast(mean, ori_dtype)
        variance = cast(variance, ori_dtype)
    return mean, variance
def batch_normalization(
    x, mean, variance, axis, offset=None, scale=None, epsilon=1e-3
):
    shape = [1] * len(x.shape)
    shape[axis] = mean.shape[0]
    mean = jnp.reshape(mean, shape)
    variance = jnp.reshape(variance, shape)

    inv = jax.lax.rsqrt(variance + epsilon)
    if scale is not None:
        scale = jnp.reshape(scale, shape)
        inv = inv * scale

    res = -mean * inv
    if offset is not None:
        offset = jnp.reshape(offset, shape)
        res = res + offset

    return jnp.add(x * inv, res)


def ctc_loss(target, output, target_length, output_length, mask_index=0):
    # Forward-algorithm CTC loss in log space, following the optax-style
    # implementation used upstream. The helper names (`_lengths_to_paddings`,
    # `update_phi_score`, `loop_body`) survive in the dump.
    target = convert_to_tensor(target, dtype="int32")
    output = convert_to_tensor(output)
    target_length = convert_to_tensor(target_length, "int32")
    output_length = convert_to_tensor(output_length, "int32")
    batch_size, max_input_length, num_classes = output.shape
    batch_size, max_label_length = target.shape

    log_epsilon = -1e5

    # Ensure that the dtype promotion behavior matches `tf.nn.ctc_loss`.
    dtype = backend.result_type(output.dtype, "float32")
    output = cast(output, dtype)

    def _lengths_to_paddings(lengths, max_length):
        indices = jnp.arange(max_length).reshape(
            (1,) * lengths.ndim + (max_length,)
        )
        lengths = jnp.expand_dims(lengths, axis=-1)
        elem_valid = indices < lengths
        return jnp.logical_not(elem_valid)

    target_paddings = _lengths_to_paddings(target_length, max_label_length)
    output_paddings = _lengths_to_paddings(output_length, max_input_length)
    target_paddings = target_paddings.astype(output.dtype)
    output_paddings = output_paddings.astype(output.dtype)

    logprobs = jnn.log_softmax(output)
    label_lengths = max_label_length - jnp.sum(
        target_paddings, axis=1
    ).astype(jnp.int32)

    # repeat[b, n] == 1.0 when target[b, n] == target[b, n + 1].
    repeat = (target[:, :-1] == target[:, 1:]).astype(jnp.float32)
    repeat = jnp.pad(repeat, ((0, 0), (0, 1)))

    logprobs_phi = logprobs[:, :, mask_index : mask_index + 1]  # [B, T, 1]
    logprobs_phi = jnp.transpose(logprobs_phi, (1, 0, 2))  # [T, B, 1]

    _one_hot = jax.nn.one_hot(target, num_classes=num_classes)  # [B, N, K]
    logprobs_emit = jnp.einsum("btk,bnk->btn", logprobs, _one_hot)
    logprobs_emit = jnp.transpose(logprobs_emit, (1, 0, 2))  # [T, B, N]

    # Initial alphas: only the empty prefix has nonzero probability.
    logalpha_phi_init = (
        jnp.ones((batch_size, max_label_length + 1), dtype=output.dtype)
        * log_epsilon
    )
    logalpha_phi_init = logalpha_phi_init.at[:, 0].set(0.0)
    logalpha_emit_init = (
        jnp.ones((batch_size, max_label_length), dtype=output.dtype)
        * log_epsilon
    )

    def update_phi_score(phi, added_score):
        # Update `phi[:, 1:]` by adding `added_score` in log space.
        return jnp.concatenate(
            [phi[:, :1], jnp.logaddexp(phi[:, 1:], added_score)], axis=-1
        )

    def loop_body(prev, x):
        prev_phi, prev_emit = prev
        # Emit-to-phi epsilon transition, unless the next label repeats.
        prev_phi_orig = prev_phi
        prev_phi = update_phi_score(prev_phi, prev_emit + log_epsilon * repeat)

        logprob_emit, logprob_phi, pad = x

        # Phi-to-emit transition.
        next_emit = jnp.logaddexp(
            prev_phi[:, :-1] + logprob_emit, prev_emit + logprob_emit
        )
        # Self-loop transition.
        next_phi = prev_phi + logprob_phi
        # Emit-to-phi blank transition only when the next label repeats.
        next_phi = update_phi_score(
            next_phi, prev_emit + logprob_phi + log_epsilon * (1.0 - repeat)
        )

        # Frozen updates on padded timesteps.
        pad = pad.reshape((batch_size, 1))
        next_emit = pad * prev_emit + (1.0 - pad) * next_emit
        next_phi = pad * prev_phi_orig + (1.0 - pad) * next_phi

        return (next_phi, next_emit), (next_phi, next_emit)

    xs = (logprobs_emit, logprobs_phi, output_paddings.transpose((1, 0)))
    _, (logalpha_phi, logalpha_emit) = jax.lax.scan(
        loop_body, (logalpha_phi_init, logalpha_emit_init), xs
    )

    # The last row also needs the final epsilon transition.
    logalpha_phi_last = update_phi_score(logalpha_phi[-1], logalpha_emit[-1])
    logalpha_phi = logalpha_phi.at[-1].set(logalpha_phi_last)

    # Extract the per-sequence loss at each sequence's label length.
    _one_hot = jax.nn.one_hot(label_lengths, num_classes=max_label_length + 1)
    per_seq_loss = -jnp.einsum("bn,bn->b", logalpha_phi_last, _one_hot)
    return per_seq_loss
t|
d}
t|
||}||k}t|d|}tjt|dd}t	||df}t|||}tj
|dd}tj||dd}tj|ddd d d f  }tj|dd}||fS )	Nr   r   rG   r:   r<   r   )r   r  r   )r	   rM   r   argmaxrt   rL   r!   r  r   tileargsorttake_along_axisrP   )rj   sequence_lengthsmerge_repeatedr  r  r   r   r   scoresseqlen_maskrepeat_maskinvalid_maskorderr   r   r   _ctc_greedy_decode  s2   (r1  d   c                    s  t | } t |}| j\}}t| } t|d d d f |d d d f k}d u r.d tj| dd}  d dtj|d |ftjd}t	
}	tj| d d df ddd d |	 d f }
t|
k|
}|jd d d |	df |}tj|d ftj | jdjd d d |	f tj| d d df |
dd}|d d d d df k}fddd	d
 fddfdd  fdd	fdd}t||||| |\}}t|k| d }t|g d}||fS )NrG   rq   r<   r:   r   r   c           
         s   t j| dd} t |}t |}t j|  kdd}t d  }| ||d f }t |dk |}t j  }t |d }|}| k}| ||k@ }	t |	 |}| j||f |} t |d }|| }| ||fS )Nr   r<   rG   rq   )r   r	  r&  rL   r!   r  r  r'  )
pathsr,  maskedr   path_tail_indexpaths_arange
path_tailsclassesprev_maskedmasked_repeat)_pad
beam_widthr  r   r   r   _extend_pathsY  s"   
z._ctc_beam_search_decode.<locals>._extend_pathsc                 S   s@   t |}t || }t |j|  |}t || }|S r   )r   rt   exp
zeros_liker  ry   r   )unique_inverser,  
scores_max
scores_expr   r   r   _merge_scoresr  s
   
z._ctc_beam_search_decode.<locals>._merge_scoresc                    s   t j| dd  dd\} }t|jdkrt j|dd}t |t j |}t ||t j } ||} ||}t ||}t | d  }| | } || }|| }t 	| d} t 
||g}t 
t tt tg}| ||fS )NTrq   r   return_inversesizer=   
fill_valuerG   r<   )rq   rG   )r   uniquer   rM   r   r!   rs   r   r(  r'  r   zerosboolr}   )r3  r,  r4  r@  emit_scoresmask_scorestotal_scorestop_indices)rC  r;  r<  r   r   r   _prune_pathsy  s0   




z-_ctc_beam_search_decode.<locals>._prune_pathsc                    s0    | |||\} }}| ||\} }}| ||fS r   r   r3  r,  r4  r   )r=  rO  r   r   _decode_step  s   
z-_ctc_beam_search_decode.<locals>._decode_stepc              	      s@   | \}}}|\}}t |dd  ||||\}}}|||fd fS )Nc                 S   s
   | ||fS r   r   rP  r   r   r   <lambda>  s   
 z8_ctc_beam_search_decode.<locals>._step.<locals>.<lambda>)r   cond)r   r   r3  r,  r4  r-  )rQ  r   r   _step  s   


z&_ctc_beam_search_decode.<locals>._stepc                    s   t | ||f|dd  |dd  f\\}}}}tj|dd  dd\}}	t|	jdkr7tj|	dd}	 |	|}t| d  d d d }
||
 }||
 }||fS )NrG   Trq   r   rD  r<   r:   )r   r  r   rH  r   rM   r   r(  )
init_pathsinit_scoresinit_maskedrj   r-  r3  r,  r4  r   r@  rN  )rC  r;  rT  r<  r   	top_pathsr   r   _decode_batch  s&   


z._ctc_beam_search_decode.<locals>._decode_batchr   )r	   rM   r   rE   r   rL   flipfullr   builtinsr   r(  r!   r  r  rs   r~   r)  r   vmapr   )rj   r*  r<  rX  r  r  max_seq_lenr-  rU  num_init_pathsmax_classesinit_classesrV  rW  rY  r3  r,  r   )
rQ  r=  rC  r;  rO  rT  r<  r  r   rX  r   _ctc_beam_search_decode/  sF   
&, 
rb  greedyc                 C   sb   t | } t| jd}t| |} |dkrt| |||dS |dkr)t| ||||dS td| d)Nr   rc  )r+  r  beam_search)r<  rX  r  zInvalid strategy z2. Supported values are 'greedy' and 'beam_search'.)r	   r   r  r~   r   r1  rb  rg   )rj   r*  strategyr<  rX  r+  r  r~   r   r   r   
def psnr(x1, x2, max_val):
    if x1.shape != x2.shape:
        raise ValueError(
            f"Input shapes {x1.shape} and {x2.shape} must "
            "match for PSNR calculation. "
        )

    max_val = convert_to_tensor(max_val, dtype=x2.dtype)
    mse = jnp.mean(jnp.square(x1 - x2))
    psnr = 20 * jnp.log10(max_val) - 10 * jnp.log10(mse)
    return psnr


def _can_use_flash_attention(query, key, value, bias, raise_error=False):
    """Verify the availability of flash attention."""
    try:
        from jax._src.cudnn.fused_attention_stablehlo import (
            _normalize_layout,
        )
        from jax._src.cudnn.fused_attention_stablehlo import (
            check_compute_capability,
        )
        from jax._src.cudnn.fused_attention_stablehlo import (
            check_cudnn_version,
        )
        from jax._src.cudnn.fused_attention_stablehlo import (
            check_is_flash_attention,
        )
        from jax._src.cudnn.fused_attention_stablehlo import check_layout
        from jax.nn import dot_product_attention as dot_product_attention
    except ImportError:
        if raise_error:
            raise ImportError(
                "Flash attention is not supported in your current JAX "
                "version. Please update it by following the official guide: "
                "https://jax.readthedocs.io/en/latest/installation.html"
            )
        return False

    if jax.devices()[0].platform == "tpu":
        return True
    try:
        # Check if cuDNN is installed; raises a RuntimeError if not detected.
        cudnn_version = check_cudnn_version()
        # Only support at least Ampere.
        if not check_compute_capability("8.0"):
            raise RuntimeError("Require at least Ampere arch to run")
        # Check the input layout; pass `None` for every optional parameter
        # `check_layout` accepts beyond the ones we supply.
        check_layout_params = list(
            inspect.signature(check_layout).parameters.keys()
        )
        for known_param in ("query", "key", "value", "bias", "layout"):
            check_layout_params.remove(known_param)
        kwargs = {key: None for key in check_layout_params}
        check_layout(
            query,
            key,
            value,
            bias,
            layout=_normalize_layout("BTNH"),
            **kwargs,
        )
        check_is_flash_attention(
            query,
            key,
            _normalize_layout("BTNH"),
            cudnn_version,
            bias is not None,
            is_training=False,
        )
        return True
    except:
        if raise_error:
            raise
        return False


def _apply_masks(logits, mask, is_causal):
    if mask is None and not is_causal:
        return logits

    combined_mask = jnp.ones_like(logits, dtype="bool")
    if mask is not None:
        combined_mask = jnp.logical_and(combined_mask, mask)

    if is_causal:
        T, S = logits.shape[2], logits.shape[3]
        mask = jnp.tril(jnp.ones((T, S), dtype="bool"))
        mask = mask[None, None, :, :]
        combined_mask = jnp.logical_and(combined_mask, mask)

    large_negative_number = jnp.asarray(
        -0.7 * jnp.finfo(logits.dtype).max, dtype=logits.dtype
    )
    padded_logits = jnp.where(combined_mask, logits, large_negative_number)
    return padded_logits


def _dot_product_attention_core(
    query, key, value, bias, mask, is_causal, scale
):
    logits_dtype = jnp.promote_types(query.dtype, jnp.float32)
    logits = jnp.einsum(
        "BTNH,BSNH->BNTS", query, key, preferred_element_type=logits_dtype
    )
    logits *= jnp.array(scale, dtype=logits.dtype)

    if bias is not None:
        logits = (logits + bias).astype(logits.dtype)

    padded_logits = _apply_masks(logits, mask, is_causal)

    # Softmax is always carried out in fp32.
    padded_logits = padded_logits.astype(jnp.float32)
    probs = jax.nn.softmax(padded_logits, axis=-1).astype(key.dtype)
    return jnp.einsum("BNTS,BSNH->BTNH", probs, value)


def wrap_flash_attention(
    query,
    key,
    value,
    decoder_segment_ids,
    custom_mask=None,
    attn_logits_soft_cap=None,
    head_shards=1,
    q_seq_shards=1,
):
    """Applies a wrapped flash attention mechanism using the Splash kernel.
    This function prepares the appropriate attention mask (causal or custom),
    constructs a multi-head mask, and applies the Splash multi-head attention
    kernel to the provided query, key, and value tensors. It supports optional
    sharding and soft capping of attention logits.
    Args:
        query: jax.Array. The query tensor of shape
            (batch, num_heads, seq_len, head_dim).
        key: jax.Array. The key tensor of shape
            (batch, num_heads, seq_len, head_dim).
        value: jax.Array. The value tensor of shape
            (batch, num_heads, seq_len, head_dim).
        decoder_segment_ids: Optional. Segment IDs for the decoder, used for
            sharding or masking.
        custom_mask: Optional[jax.Array]. A custom attention mask to apply. If
            None, a causal mask is used.
        attn_logits_soft_cap: Optional[float]. If provided, applies a soft cap
            to the attention logits.
        head_shards: int, default=1. Number of shards for the attention heads.
        q_seq_shards: int, default=1. Number of shards for the query sequence
            dimension.
    Returns:
        jax.Array: The result of applying the Splash multi-head attention
            kernel to the inputs.
    Raises:
        AssertionError: If sharding along the sequence dimension is attempted
            with decoder_segment_ids.
    """
    if decoder_segment_ids is not None:
        assert query.shape[2] == decoder_segment_ids.q.shape[1], (
            "Sharding along sequence dimension not allowed in TPU kernel "
            "attention"
        )

    if custom_mask is not None:
        mask = splash_attention_mask.NumpyMask(array=custom_mask)
    else:
        mask = splash_attention_mask.CausalMask(
            shape=(query.shape[2], query.shape[2])
        )

    # Create a per-head mask and build the Splash multi-head attention kernel.
    multi_head_mask = splash_attention_mask.MultiHeadMask(
        masks=(mask,) * query.shape[1]
    )
    splash_kernel = splash_attention_kernel.make_splash_mha(
        mask=multi_head_mask,
        head_shards=head_shards,
        q_seq_shards=q_seq_shards,
        attn_logits_soft_cap=attn_logits_soft_cap,
    )

    return jax.vmap(splash_kernel)(
        query, key, value, segment_ids=decoder_segment_ids
    )


def dot_product_attention(
    query,
    key,
    value,
    bias=None,
    mask=None,
    scale=None,
    is_causal=False,
    flash_attention=None,
    attn_logits_soft_cap=None,
):

    This is the core computation of attention that is used in transformers.
    For TPU platforms, flash attention optimizations are automatically applied
    when possible, and sharding parameters are inferred from the layout map
    in the current distribution context.

    Args:
        query: Queries with shape `[batch, time, heads,
            depth_k]`.
        key: Keys with shape `[batch, time, heads,
            depth_k]`.
        value: Values with shape `[batch, time, heads,
            depth_v]`.
        bias: Optional bias with shape broadcastable to
            `[batch, heads, dest_time, source_time]`.
        mask: Optional mask with shape broadcastable to
            `[batch, heads, dest_time, source_time]`.
        scale: Float. Optional scale that is applied to the attention
            computation.
        is_causal: Boolean. Specifying whether causal masking is applied.
        flash_attention: Boolean. Whether to use flash attention optimization
            for increased performance. Default to None, which means it will
            be auto-determined based on the platform, input shapes and
            compatibility.
        attn_logits_soft_cap: Float. Optional float to softly cap attention
            logits to avoid numerical stability issues. Applied as:
            `logits = logits / (1.0 + abs(logits) / attn_logits_soft_cap)`.

    Returns:
        JAX Array of shape `[batch, time, heads, depth_v]`.
    r'   zG`dot_product_attention` only supports 4D inputs. Received: query.shape=z, key.shape=z, value.shape=.r   rw  NT)r  )ModelParallel)distributionmodelrG   )r   rq   rG   r  r   r   )r  kvrJ  r  r   )r  r  r  r  r  Fru  cudnnxla)r{  r  r   r  implementationrv  r1   c              	      s|   | d ur<| j \}}}}|dkr+t| d d d d d d d d d f || ||f} | S |ks1J t| | ||f} | S )NrG   )rM   r   broadcast_torO   )ttBtNtTtSGKNr   r   _reshape_to_groupedf  s   4z2dot_product_attention.<locals>._reshape_to_grouped)r  NNrq   rq   NN)in_axesout_axes))r	   r   rM   rg   r   r  r  r  'keras.src.distribution.distribution_libr  r  r_   device_mesh
axis_namesindexr  AttributeErrorr   r   rz   sqrtrI  r   r   
SegmentIdsr~   bool_r   rN   r  r}   r  r  	Exceptionhasattrr   ru  r  rO   r]  r  )'rx  ry  rz  r{  r  r   r  flash_attentionr  r  is_tpur  get_distdistmeshmodel_dim_indexr  r  query_tpu_layoutkey_tpu_layoutvalue_tpu_layoutbs	num_headsq_lenhead_dimr  r  r  	mask_boolcausal_maskr[   output_shaper   HBr  r  
vmapped_fnencodedr   r  r   ru    s   +*







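# Editor's note: usage sketch for `dot_product_attention`, not part of the
# recovered module; shapes follow the docstring above:
#
#   import jax
#   B, T, N, H = 2, 8, 4, 16
#   k0, k1, k2 = jax.random.split(jax.random.PRNGKey(0), 3)
#   q = jax.random.normal(k0, (B, T, N, H))
#   k = jax.random.normal(k1, (B, T, N, H))
#   v = jax.random.normal(k2, (B, T, N, H))
#   out = dot_product_attention(q, k, v, is_causal=True)
#   out.shape  # -> (2, 8, 4, 16)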
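# Editor's note: end-to-end sketch of the CTC helpers, not part of the
# recovered module. Logits are [batch, time, num_classes]; with
# `strategy="greedy"` the blank class defaults to `mask_index=0`:
#
#   batch, time, classes = 2, 10, 5
#   logits = jax.random.normal(jax.random.PRNGKey(0), (batch, time, classes))
#   lengths = jnp.array([10, 8], dtype="int32")
#   decoded, scores = ctc_decode(logits, lengths, strategy="greedy")
#   decoded.shape  # -> (1, 2, 10): (top_paths, batch, time), padded with -1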