import inspect

import numpy as np

from keras.src import backend
from keras.src import tree
from keras.src.api_export import keras_export
from keras.src.backend.common.variables import is_float_dtype
from keras.src.backend.common.variables import standardize_dtype
from keras.src.layers.layer import Layer
from keras.src.saving import serialization_lib
from keras.src.utils import jax_utils
from keras.src.utils import tracking
from keras.src.utils.module_utils import jax


@keras_export("keras.layers.JaxLayer")
class JaxLayer(Layer):
    """Keras Layer that wraps a JAX model.

    This layer enables the use of JAX components within Keras when using JAX as
    the backend for Keras.

    ## Model function

    This layer accepts JAX models in the form of a function, `call_fn`, which
    must take the following arguments with these exact names:

    - `params`: trainable parameters of the model.
    - `state` (*optional*): non-trainable state of the model. Can be omitted if
        the model has no non-trainable state.
    - `rng` (*optional*): a `jax.random.PRNGKey` instance. Can be omitted if the
        model does not need RNGs, either during training or during inference.
    - `inputs`: inputs to the model, a JAX array or a `PyTree` of arrays.
    - `training` (*optional*): an argument specifying if we're in training mode
        or inference mode, `True` is passed in training mode. Can be omitted if
        the model behaves the same in training mode and inference mode.

    The `inputs` argument is mandatory. Inputs to the model must be provided via
    a single argument. If the JAX model takes multiple inputs as separate
    arguments, they must be combined into a single structure, for instance in a
    `tuple` or a `dict`.

    ## Model weights initialization

    The initialization of the `params` and `state` of the model can be handled
    by this layer, in which case the `init_fn` argument must be provided. This
    allows the model to be initialized dynamically with the right shape.
    Alternatively, and if the shape is known, the `params` argument and
    optionally the `state` argument can be used to create an already initialized
    model.
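
    As a minimal sketch of the two construction modes (assuming hypothetical
    `my_call_fn` and `my_init_fn` functions following the signatures described
    above, and a hypothetical `trained_params` `PyTree`):

    ```python
    # Dynamic initialization: `init_fn` is called at build time with
    # placeholder inputs of the right shape to create `params`.
    layer = JaxLayer(call_fn=my_call_fn, init_fn=my_init_fn)

    # Static initialization: pass already initialized parameters directly,
    # in which case `init_fn` is not needed.
    layer = JaxLayer(call_fn=my_call_fn, params=trained_params)
    ```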

    The `init_fn` function, if provided, must take the following arguments with
    these exact names:

    - `rng`: a `jax.random.PRNGKey` instance.
    - `inputs`: a JAX array or a `PyTree` of arrays with placeholder values to
        provide the shape of the inputs.
    - `training` (*optional*): an argument specifying if we're in training mode
        or inference mode. `True` is always passed to `init_fn`. Can be omitted
        regardless of whether `call_fn` has a `training` argument.

    ## Models with non-trainable state

    For JAX models that have non-trainable state:

    - `call_fn` must have a `state` argument
    - `call_fn` must return a `tuple` containing the outputs of the model and
        the new non-trainable state of the model
    - `init_fn` must return a `tuple` containing the initial trainable params of
        the model and the initial non-trainable state of the model.

    This code shows a possible combination of `call_fn` and `init_fn` signatures
    for a model with non-trainable state. In this example, the model has a
    `training` argument and an `rng` argument in `call_fn`.

    ```python
    def stateful_call(params, state, rng, inputs, training):
        outputs = ...
        new_state = ...
        return outputs, new_state

    def stateful_init(rng, inputs):
        initial_params = ...
        initial_state = ...
        return initial_params, initial_state
    ```
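
    A layer wrapping this pair could then be constructed as follows (a minimal
    sketch, assuming the elided bodies above are filled in):

    ```python
    layer = JaxLayer(call_fn=stateful_call, init_fn=stateful_init)
    ```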

    ## Models without non-trainable state

    For JAX models with no non-trainable state:

    - `call_fn` must not have a `state` argument
    - `call_fn` must return only the outputs of the model
    - `init_fn` must return only the initial trainable params of the model.

    This code shows a possible combination of `call_fn` and `init_fn` signatures
    for a model without non-trainable state. In this example, the model does not
    have a `training` argument and does not have an `rng` argument in `call_fn`.

    ```python
    def stateless_call(params, inputs):
        outputs = ...
        return outputs

    def stateless_init(rng, inputs):
        initial_params = ...
        return initial_params
    ```
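
    Correspondingly, a minimal construction sketch for this stateless pair:

    ```python
    layer = JaxLayer(call_fn=stateless_call, init_fn=stateless_init)
    ```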

    ## Conforming to the required signature

    If a model has a different signature than the one required by `JaxLayer`,
    one can easily write a wrapper method to adapt the arguments. This example
    shows a model that has multiple inputs as separate arguments, expects
    multiple RNGs in a `dict`, and has a `deterministic` argument with the
    opposite meaning of `training`. To conform, the inputs are combined in a
    single structure using a `tuple`, the RNG is split and used to populate the
    expected `dict`, and the Boolean flag is negated:

    ```python
    def my_model_fn(params, rngs, input1, input2, deterministic):
        ...
        if not deterministic:
            dropout_rng = rngs["dropout"]
            keep = jax.random.bernoulli(dropout_rng, dropout_rate, x.shape)
            x = jax.numpy.where(keep, x / dropout_rate, 0)
            ...
        ...
        return outputs

    def my_model_wrapper_fn(params, rng, inputs, training):
        input1, input2 = inputs
        rng1, rng2 = jax.random.split(rng)
        rngs = {"dropout": rng1, "preprocessing": rng2}
        deterministic = not training
        return my_model_fn(params, rngs, input1, input2, deterministic)

    keras_layer = JaxLayer(my_model_wrapper_fn, params=initial_params)
    ```

    ## Usage with Haiku modules

    `JaxLayer` enables the use of [Haiku](https://dm-haiku.readthedocs.io)
    components in the form of
    [`haiku.Module`](https://dm-haiku.readthedocs.io/en/latest/api.html#module).
    This is achieved by transforming the module per the Haiku pattern and then
    passing `module.apply` in the `call_fn` parameter and `module.init` in the
    `init_fn` parameter if needed.

    If the model has non-trainable state, it should be transformed with
    [`haiku.transform_with_state`](
      https://dm-haiku.readthedocs.io/en/latest/api.html#haiku.transform_with_state).
    If the model has no non-trainable state, it should be transformed with
    [`haiku.transform`](
      https://dm-haiku.readthedocs.io/en/latest/api.html#haiku.transform).
    Additionally, and optionally, if the module does not use RNGs in "apply", it
    can be transformed with
    [`haiku.without_apply_rng`](
      https://dm-haiku.readthedocs.io/en/latest/api.html#without-apply-rng).

    The following example shows how to create a `JaxLayer` from a Haiku module
    that uses random number generators via `hk.next_rng_key()` and takes a
    training positional argument:

    ```python
    class MyHaikuModule(hk.Module):
        def __call__(self, x, training):
            x = hk.Conv2D(32, (3, 3))(x)
            x = jax.nn.relu(x)
            x = hk.AvgPool((1, 2, 2, 1), (1, 2, 2, 1), "VALID")(x)
            x = hk.Flatten()(x)
            x = hk.Linear(200)(x)
            if training:
                x = hk.dropout(rng=hk.next_rng_key(), rate=0.3, x=x)
            x = jax.nn.relu(x)
            x = hk.Linear(10)(x)
            x = jax.nn.softmax(x)
            return x

    def my_haiku_module_fn(inputs, training):
        module = MyHaikuModule()
        return module(inputs, training)

    transformed_module = hk.transform(my_haiku_module_fn)

    keras_layer = JaxLayer(
        call_fn=transformed_module.apply,
        init_fn=transformed_module.init,
    )
    ```
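
    If the module had non-trainable state, the same pattern would use
    `hk.transform_with_state` instead. A minimal sketch, assuming a
    hypothetical `my_stateful_module_fn` that reads and writes Haiku state via
    `hk.get_state()` / `hk.set_state()`:

    ```python
    transformed_module = hk.transform_with_state(my_stateful_module_fn)

    # `transformed_module.init(rng, inputs)` returns `(params, state)` and
    # `transformed_module.apply(params, state, rng, inputs)` returns
    # `(outputs, new_state)`, matching the stateful signatures above.
    keras_layer = JaxLayer(
        call_fn=transformed_module.apply,
        init_fn=transformed_module.init,
    )
    ```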

    Args:
        call_fn: The function to call the model. See description above for the
            list of arguments it takes and the outputs it returns.
        init_fn: The function to call to initialize the model. See description
            above for the list of arguments it takes and the outputs it
            returns. If `None`, then `params` and/or `state` must be provided.
        params: A `PyTree` containing all the model trainable parameters. This
            allows passing trained parameters or controlling the
            initialization. If both `params` and `state` are `None`, `init_fn`
            is called at build time to initialize the trainable parameters of
            the model.
        state: A `PyTree` containing all the model non-trainable state. This
            allows passing learned state or controlling the initialization. If
            both `params` and `state` are `None`, and `call_fn` takes a `state`
            argument, then `init_fn` is called at build time to initialize the
            non-trainable state of the model.
        seed: Seed for random number generator. Optional.
        dtype: The dtype of the layer's computations and weights. Can also be a
            `keras.DTypePolicy`. Optional. Defaults to the default policy.
    """

    def __init__(
        self,
        call_fn,
        init_fn=None,
        params=None,
        state=None,
        seed=None,
        **kwargs,
    ):
        if backend.backend() != "jax":
            raise ValueError(
                "JaxLayer is only supported with the JAX backend. Current "
                f"backend: {backend.backend()}"
            )
        if init_fn is None and params is None and state is None:
            raise ValueError(
                "`init_fn`, `params` and `state` cannot all be `None`."
            )
        super().__init__(**kwargs)
        self.call_fn = call_fn
        self.init_fn = init_fn
        self.seed_generator = backend.random.SeedGenerator(seed)
        self.tracked_params = self._create_variables(params, trainable=True)
        self.tracked_state = self._create_variables(state, trainable=False)
        if self.params is not None or self.state is not None:
            self._build_at_init()

        self.call_fn_arguments = self._validate_signature(
            call_fn,
            "call_fn",
            {"params", "state", "rng", "inputs", "training"},
            {"inputs"},
        )
        self.has_state = "state" in self.call_fn_arguments

        if init_fn:
            self.init_fn_arguments = self._validate_signature(
                init_fn, "init_fn", {"rng", "inputs", "training"}, {"inputs"}
            )

    def _validate_signature(self, fn, fn_name, allowed, required):
        fn_parameters = inspect.signature(fn).parameters
        for parameter_name in required:
            if parameter_name not in fn_parameters:
                raise ValueError(
                    f"Missing required argument in `{fn_name}`: "
                    f"`{parameter_name}`"
                )

        parameter_names = []
        for parameter in fn_parameters.values():
            if parameter.name not in allowed:
                raise ValueError(
                    f"Unsupported argument in `{fn_name}`: `{parameter.name}`, "
                    f"supported arguments are `{'`, `'.join(allowed)}`"
                )
            parameter_names.append(parameter.name)

        return parameter_names

    @tracking.no_automatic_dependency_tracking
    def _create_variables(self, values, trainable):
        """Create a structure of variables from a structure of JAX arrays.

        `values` is traversed via JAX's `tree_map`. When a leaf is a JAX array
        or a tensor-like object, a corresponding variable is created with it as
        the initial value. The resulting structure of variables is assigned to
        `self.params` or `self.state` depending on `trainable`. Then, a
        flattened version of the variables is returned for tracking.
        `self.params` or `self.state` are intentionally not tracked because
        structures like `TrackedList` interfere with `jax.tree_utils`.
        Note that leaf objects that are not JAX arrays and not tensor-like are
        left intact as they are assumed to be configuration used by the model.

        Args:
            values: the structure of values to traverse.
            trainable: whether to create trainable variables.

        Returns:
            flat list of variables initialized with `values` for tracking.
        """

        def create_variable(value):
            if backend.is_tensor(value) or isinstance(
                value, (np.ndarray, np.generic)
            ):
                dtype = value.dtype
                if is_float_dtype(dtype):
                    dtype = None  # Use the layer dtype policy.
                return self.add_weight(
                    value.shape,
                    initializer=value,
                    dtype=dtype,
                    trainable=trainable,
                )
            elif isinstance(value, (bool, int, float)):
                dtype = standardize_dtype(type(value))
                if is_float_dtype(dtype):
                    dtype = None  # Use the layer dtype policy.
                return self.add_weight(
                    (),
                    initializer=backend.convert_to_tensor(value),
                    dtype=dtype,
                    trainable=trainable,
                )
            else:
                # Leave non-tensor leaves (model configuration) intact.
                return value

        # Use JAX's tree_map as it understands registered classes.
        variables = jax.tree_util.tree_map(create_variable, values)

        if trainable:
            self.params = variables
        else:
            self.state = variables

        flat_variables, _ = jax.tree_util.tree_flatten(variables)
        return flat_variables

    def _get_init_rng(self):
        """
        Returns a JAX `PRNGKey` or structure of `PRNGKey`s to pass to `init_fn`.

        By default, this returns a single `PRNGKey` retrieved by calling
        `self.seed_generator.next()`. Override this to return a different
        structure.

        Returns:
            a JAX `PRNGKey` or structure of `PRNGKey`s that will be passed as
            the `rng` argument of `init_fn`.
        """
        return self.seed_generator.next()

    def _get_call_rng(self, training):
        """
        Returns a JAX `PRNGKey` or structure of `PRNGKey`s to pass to `call_fn`.

        By default, this returns a single `PRNGKey` retrieved by calling
        `self.seed_generator.next()` when `training` is `True`, and `None` when
        `training` is `False`. Override this to return a different structure or
        to pass RNGs in inference mode too.

        Returns:
            a JAX `PRNGKey` or structure of `PRNGKey`s that will be passed as
            the `rng` argument of `call_fn`.
        """
        if training:
            return self.seed_generator.next()
        else:
            return None

    def build(self, input_shape):
        if self.params is not None or self.state is not None:
            return

        if jax_utils.is_in_jax_tracing_scope():
            raise ValueError("'JaxLayer' cannot be built in tracing scope")

        # Initialize `params` and `state` if needed by calling `init_fn`.
        def create_input(shape):
            # Replace unknown dimensions with 1 for the placeholder inputs.
            shape = [d if d is not None else 1 for d in shape]
            return jax.numpy.ones(shape)

        init_inputs = tree.map_shape_structure(create_input, input_shape)
        init_args = []
        for argument_name in self.init_fn_arguments:
            if argument_name == "rng":
                init_args.append(self._get_init_rng())
            elif argument_name == "inputs":
                init_args.append(init_inputs)
            elif argument_name == "training":
                init_args.append(True)

        init_result = self.init_fn(*init_args)
        if self.has_state:
            init_params, init_state = init_result
        else:
            init_params, init_state = init_result, None

        self.tracked_params = self._create_variables(
            init_params, trainable=True
        )
        self.tracked_state = self._create_variables(
            init_state, trainable=False
        )

    def call(self, inputs, training=False):
        def unwrap_variable(variable):
            return None if variable is None else variable.value

        call_args = []
        for argument_name in self.call_fn_arguments:
            if argument_name == "params":
                call_args.append(
                    jax.tree_util.tree_map(unwrap_variable, self.params)
                )
            elif argument_name == "state":
                call_args.append(
                    jax.tree_util.tree_map(unwrap_variable, self.state)
                )
            elif argument_name == "rng":
                call_args.append(self._get_call_rng(training))
            elif argument_name == "inputs":
                call_args.append(inputs)
            elif argument_name == "training":
                call_args.append(training)

        def assign_state_to_variable(value, variable):
            if not hasattr(variable, "assign"):
                raise ValueError(
                    "Structure mismatch: the structure of the state returned "
                    "by `call` does not match the structure of the state at "
                    "initialization time."
                )
            variable.assign(value)

        if self.has_state:
            predictions, new_state = self.call_fn(*call_args)
            jax.tree_util.tree_map(
                assign_state_to_variable, new_state, self.state
            )
            return predictions
        else:
            return self.call_fn(*call_args)

    def get_config(self):
        config = {
            "call_fn": serialization_lib.serialize_keras_object(self.call_fn),
            "init_fn": serialization_lib.serialize_keras_object(self.init_fn),
        }
        base_config = super().get_config()
        return dict(list(base_config.items()) + list(config.items()))

    @classmethod
    def from_config(cls, config):
        call_fn = serialization_lib.deserialize_keras_object(config["call_fn"])
        init_fn = serialization_lib.deserialize_keras_object(config["init_fn"])
        config["call_fn"] = call_fn
        config["init_fn"] = init_fn
        return super().from_config(config)


@keras_export("keras.layers.FlaxLayer")
class FlaxLayer(JaxLayer):
    """Keras Layer that wraps a [Flax](https://flax.readthedocs.io) module.

    This layer enables the use of Flax components in the form of
    [`flax.linen.Module`](
        https://flax.readthedocs.io/en/latest/api_reference/flax.linen/module.html)
    instances within Keras when using JAX as the backend for Keras.

    The module method to use for the forward pass can be specified via the
    `method` argument and is `__call__` by default. This method must take the
    following arguments with these exact names:

    - `self` if the method is bound to the module, which is the case for the
        default of `__call__`; otherwise `module`, so that the module is
        passed as the first argument.
    - `inputs`: the inputs to the model, a JAX array or a `PyTree` of arrays.
    - `training` *(optional)*: an argument specifying if we're in training mode
        or inference mode, `True` is passed in training mode.

    `FlaxLayer` handles the non-trainable state of your model and required RNGs
    automatically. Note that the `mutable` parameter of
    [`flax.linen.Module.apply()`](
        https://flax.readthedocs.io/en/latest/api_reference/flax.linen/module.html#flax.linen.apply)
    is set to `DenyList(["params"])`, therefore making the assumption that all
    the variables outside of the "params" collection are non-trainable weights.

    This example shows how to create a `FlaxLayer` from a Flax `Module` with
    the default `__call__` method and no training argument:

    ```python
    class MyFlaxModule(flax.linen.Module):
        @flax.linen.compact
        def __call__(self, inputs):
            x = inputs
            x = flax.linen.Conv(features=32, kernel_size=(3, 3))(x)
            x = flax.linen.relu(x)
            x = flax.linen.avg_pool(x, window_shape=(2, 2), strides=(2, 2))
            x = x.reshape((x.shape[0], -1))  # flatten
            x = flax.linen.Dense(features=200)(x)
            x = flax.linen.relu(x)
            x = flax.linen.Dense(features=10)(x)
            x = flax.linen.softmax(x)
            return x

    flax_module = MyFlaxModule()
    keras_layer = FlaxLayer(flax_module)
    ```

    This example shows how to wrap the module method to conform to the required
    signature. This allows having multiple input arguments and a training
    argument that has a different name and values. This additionally shows how
    to use a function that is not bound to the module.

    ```python
    class MyFlaxModule(flax.linen.Module):
        @flax.linen.compact
        def forward(self, input1, input2, deterministic):
            ...
            return outputs

    def my_flax_module_wrapper(module, inputs, training):
        input1, input2 = inputs
        return module.forward(input1, input2, not training)

    flax_module = MyFlaxModule()
    keras_layer = FlaxLayer(
        module=flax_module,
        method=my_flax_module_wrapper,
    )
    ```

    Args:
        module: An instance of `flax.linen.Module` or subclass.
        method: The method to call the model. This is generally a method in the
            `Module`. If not provided, the `__call__` method is used. `method`
            can also be a function not defined in the `Module`, in which case it
            must take the `Module` as the first argument. It is used for both
            `Module.init` and `Module.apply`. Details are documented in the
            `method` argument of [`flax.linen.Module.apply()`](
              https://flax.readthedocs.io/en/latest/api_reference/flax.linen/module.html#flax.linen.apply).
        variables: A `dict` containing all the variables of the module in the
            same format as what is returned by [`flax.linen.Module.init()`](
              https://flax.readthedocs.io/en/latest/api_reference/flax.linen/module.html#flax.linen.init).
            It should contain a "params" key and, if applicable, other keys for
            collections of variables for non-trainable state. This allows
            passing trained parameters and learned non-trainable state or
            controlling the initialization. If `None` is passed, the module's
            `init` function is called at build time to initialize the variables
            of the model.
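
    For example, to wrap a module whose weights were already initialized or
    trained outside of Keras (a minimal sketch, assuming a hypothetical
    `pretrained_variables` `dict` in the `flax.linen.Module.init()` format):

    ```python
    flax_module = MyFlaxModule()
    keras_layer = FlaxLayer(flax_module, variables=pretrained_variables)
    ```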
    """

    def __init__(
        self,
        module,
        method=None,
        variables=None,
        **kwargs,
    ):
        # Late import to only require Flax when this class is used.
        from flax.core import scope as flax_scope

        if backend.backend() != "jax":
            raise ValueError(
                "FlaxLayer is only supported with the JAX backend. Current "
                f"backend: {backend.backend()}"
            )

        self.module = module
        self.method = method
        # Treat all collections other than "params" as non-trainable state.
        apply_mutable = flax_scope.DenyList(["params"])

        def apply_with_training(params, state, rng, inputs, training):
            return self.module.apply(
                self._params_and_state_to_variables(params, state),
                inputs,
                rngs=rng,
                method=self.method,
                mutable=apply_mutable,
                training=training,
            )

        def apply_without_training(params, state, rng, inputs):
            return self.module.apply(
                self._params_and_state_to_variables(params, state),
                inputs,
                rngs=rng,
                method=self.method,
                mutable=apply_mutable,
            )

        def init_with_training(rng, inputs, training):
            return self._variables_to_params_and_state(
                self.module.init(
                    rng,
                    inputs,
                    method=self.method,
                    training=training,
                )
            )

        def init_without_training(rng, inputs):
            return self._variables_to_params_and_state(
                self.module.init(
                    rng,
                    inputs,
                    method=self.method,
                )
            )

        if (
            "training"
            in inspect.signature(method or module.__call__).parameters
        ):
            call_fn, init_fn = apply_with_training, init_with_training
        else:
            call_fn, init_fn = apply_without_training, init_without_training

        params, state = self._variables_to_params_and_state(variables)

        super().__init__(
            call_fn=call_fn,
            init_fn=init_fn,
            params=params,
            state=state,
            **kwargs,
        )

    def _params_and_state_to_variables(self, params, state):
        if params:
            if state:
                return {**params, **state}
            else:
                return params
        elif state:
            return state
        return {}

    def _variables_to_params_and_state(self, variables):
        if variables is None:
            return None, None
        if "params" not in variables:
            # State only.
            return {}, variables
        if len(variables) == 1:
            # Params only.
            return variables, {}
        # Both params and state: split the "params" collection from the rest.
        params = {"params": variables["params"]}
        state = {k: v for k, v in variables.items() if k != "params"}
        return params, state

    def _get_init_rng(self):
        return {
            "params": self.seed_generator.next(),
            "dropout": self.seed_generator.next(),
        }

    def _get_call_rng(self, training):
        if training:
            return {"dropout": self.seed_generator.next()}
        else:
            return {}

    def get_config(self):
        config_method = self.method
        if (
            hasattr(self.method, "__self__")
            and self.method.__self__ == self.module
        ):
            # A method bound to the module is serialized by name.
            config_method = self.method.__name__
        config = {
            "module": serialization_lib.serialize_keras_object(self.module),
            "method": serialization_lib.serialize_keras_object(config_method),
        }
        base_config = super().get_config()
        base_config.pop("call_fn")
        base_config.pop("init_fn")
        return dict(list(base_config.items()) + list(config.items()))

    @classmethod
    def from_config(cls, config):
        module = serialization_lib.deserialize_keras_object(config["module"])
        method = serialization_lib.deserialize_keras_object(config["method"])
        if isinstance(method, str):
            method = getattr(module, method)
        config["module"] = module
        config["method"] = method
        return cls(**config)