
The updated version gives me this (after a successful setup, running the "alien life" example):

    UnfilteredStackTrace                      Traceback (most recent call last)
    <ipython-input-2-0e20e3adf861> in <module>()
          2
    ----> 3 image = generate_image_from_text("alien life", seed=7)
          4 display(image)

    UnfilteredStackTrace: TypeError: lax.dynamic_update_slice requires arguments to have the same dtypes, got float16, float32.

    The stack trace below excludes JAX-internal frames.
    The preceding is the original exception that occurred, unmodified.

    --------------------

    The above exception was the direct cause of the following exception:

    TypeError                                 Traceback (most recent call last)
    /content/min-dalle/min_dalle/models/dalle_bart_decoder_flax.py in __call__(self, decoder_state, keys_state, values_state, attention_mask, state_index)
         38             keys_state,
         39             self.k_proj(decoder_state).reshape(shape_split),
    ---> 40             state_index
         41         )
         42         values_state = lax.dynamic_update_slice(

    TypeError: lax.dynamic_update_slice requires arguments to have the same dtypes, got float16, float32.



You need to install flax 0.4.2. If you're using Colab, just open a terminal (the icon in the bottom left of the screen) and run:

    pip3 install flax==0.4.2
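To confirm which flax version pip actually installed, a check like the following can be run in a notebook cell. This is a minimal sketch using only the standard library's `importlib.metadata`; `installed_version` is a hypothetical helper name. It reads the installed package metadata, not the module already loaded in the running session, so if flax was imported before the reinstall, the runtime needs a restart before the pinned version is used.

```python
from importlib import metadata
from typing import Optional


def installed_version(package: str) -> Optional[str]:
    """Return the version pip installed for `package`, or None if absent.

    Hypothetical helper: it reads package metadata only, so it does not
    reflect a module that an already-running session imported earlier.
    """
    try:
        return metadata.version(package)
    except metadata.PackageNotFoundError:
        return None


version = installed_version("flax")
if version != "0.4.2":
    print(f"flax version is {version}; expected 0.4.2")
else:
    print("flax 0.4.2 is installed")
```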


Yes, this is the fix for now. I still need to address what is actually causing the dtype mismatch.
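One possible way to address the mismatch itself, rather than pinning flax, is to cast the new keys and values to the cache's dtype before writing them with `lax.dynamic_update_slice`. The sketch below assumes the cache is float16 and the freshly projected values are float32 (the error reports exactly that pair); `update_cache` and the array shapes are hypothetical and not min-dalle's actual code.

```python
import jax.numpy as jnp
from jax import lax


def update_cache(cache, new_values, index):
    """Write new_values into cache starting at row `index` (hypothetical helper).

    lax.dynamic_update_slice requires the operand and the update to share a
    dtype, so the update is cast to the cache's dtype before writing.
    """
    new_values = new_values.astype(cache.dtype)
    # Start at `index` on the first axis and at 0 on every other axis.
    start_indices = (index,) + (0,) * (cache.ndim - 1)
    return lax.dynamic_update_slice(cache, new_values, start_indices)


# Hypothetical shapes: a float16 cache with 4 rows of 3 features, and a
# float32 row such as a projection layer might return.
cache = jnp.zeros((4, 3), dtype=jnp.float16)
update = jnp.ones((1, 3), dtype=jnp.float32)

# Passing the mismatched dtypes directly reproduces the TypeError from the trace.
try:
    lax.dynamic_update_slice(cache, update, (2, 0))
    mismatch_raised = False
except TypeError as err:
    mismatch_raised = True
    print("without cast:", err)

result = update_cache(cache, update, 2)
print(result.dtype)  # float16
```

Casting to the cache's dtype keeps the float16 cache at half the memory of a float32 one; the alternative, allocating the cache in float32, avoids the cast at the cost of doubling that memory.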



