```
UnfilteredStackTrace                      Traceback (most recent call last)
<ipython-input-2-0e20e3adf861> in <module>()
      2 
----> 3 image = generate_image_from_text("alien life", seed=7)
      4 display(image)

67 frames
UnfilteredStackTrace: TypeError: lax.dynamic_update_slice requires arguments to have the same dtypes, got float16, float32.

The stack trace below excludes JAX-internal frames.
The preceding is the original exception that occurred, unmodified.

--------------------

The above exception was the direct cause of the following exception:

TypeError                                 Traceback (most recent call last)
/content/min-dalle/min_dalle/models/dalle_bart_decoder_flax.py in __call__(self, decoder_state, keys_state, values_state, attention_mask, state_index)
     38             keys_state,
     39             self.k_proj(decoder_state).reshape(shape_split),
---> 40             state_index
     41         )
     42         values_state = lax.dynamic_update_slice(

TypeError: lax.dynamic_update_slice requires arguments to have the same dtypes, got float16, float32.
```
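Unlike `jax.numpy` functions, `lax.dynamic_update_slice` does not promote dtypes: the operand (here the key/value cache) and the update must already share a dtype. The sketch below is a minimal reproduction, assuming a float16 cache receiving a float32 projection; the names `keys_state`, `update`, and the helper `write_to_cache` are illustrative and not taken from min-dalle. Casting the update to the cache dtype before the write avoids the error.

```python
import jax.numpy as jnp
from jax import lax

# Illustrative cache buffer stored in half precision (assumption: cache is float16).
keys_state = jnp.zeros((1, 4, 2), dtype=jnp.float16)
# Illustrative projection output computed in float32 (assumption).
update = jnp.ones((1, 1, 2), dtype=jnp.float32)
# Write position: batch 0, sequence step 1, feature 0.
state_index = (0, 1, 0)

# Reproduce the mismatch: lax primitives raise TypeError instead of promoting.
try:
    lax.dynamic_update_slice(keys_state, update, state_index)
except TypeError as err:
    print("dtype mismatch:", err)


def write_to_cache(cache, new_values, index):
    """Hypothetical helper: cast new_values to the cache dtype, then write it in."""
    return lax.dynamic_update_slice(cache, new_values.astype(cache.dtype), index)


fixed = write_to_cache(keys_state, update, state_index)
print(fixed.dtype)
```

Whether to cast the update down or the cache up is a design choice: casting to the cache dtype keeps memory usage unchanged, while keeping everything in float32 trades memory for precision.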
```
pip3 install flax==0.4.2
```
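In a notebook, a package that was already imported keeps its old version until the runtime is restarted, so it may be worth confirming which Flax version is actually installed after pinning. This is a small sketch using only the standard library's `importlib.metadata`; the helper names `installed_version` and `matches_pin` are hypothetical.

```python
from importlib.metadata import PackageNotFoundError, version
from typing import Optional


def installed_version(package: str) -> Optional[str]:
    """Return the installed distribution version of `package`, or None if absent."""
    try:
        return version(package)
    except PackageNotFoundError:
        return None


def matches_pin(package: str, pinned: str) -> bool:
    """True only if the installed version exactly equals the pinned one."""
    return installed_version(package) == pinned


if __name__ == "__main__":
    print("flax:", installed_version("flax"), "| matches 0.4.2:", matches_pin("flax", "0.4.2"))
```

Note that this reports what is installed on disk, not what an already-running kernel has imported; comparing against `flax.__version__` inside the session would show the latter.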