Conversation
    # Assign to a variable to prevent garbage collection before sync.
    logits = model_tpu(input_ids).logits

    torch_xla.sync()  # Wait for the computation to complete.
This doesn't actually wait for the computation to complete; it just launches the kernel on the TPU and proceeds. I think the right API is wait_device_ops.
I see, thank you for the info. Does this mean that even in eager mode we need to call wait_device_ops each time we want to measure time, so that we wait until the computation completes?
I noticed that if I use wait_device_ops for the preheat timing, it comes out to 35 ms, whereas torch_xla.sync() gives me 3000 ms. It looks like wait_device_ops does not include the compilation time of the initial run. Since we want to compare compilation time for the first run as well, I will use torch_xla.sync() for the preheat timing.
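The pitfall discussed above is generic to any asynchronous execution model: if you stop the clock right after dispatching work, you measure only the launch overhead, not the work itself. The sketch below illustrates this with a plain Python thread pool as a stand-in for TPU dispatch (no torch_xla required); `device_compute`, the sleep duration, and the executor setup are illustrative assumptions, with `future.result()` playing a role analogous to blocking on device ops.

```python
import time
from concurrent.futures import ThreadPoolExecutor

def device_compute():
    # Stand-in for a device kernel: sleep simulates 200 ms of real work.
    time.sleep(0.2)
    return 42

executor = ThreadPoolExecutor(max_workers=1)

# Pitfall: timing only the launch measures dispatch cost, not the work.
start = time.perf_counter()
future = executor.submit(device_compute)  # returns immediately
launch_ms = (time.perf_counter() - start) * 1000
future.result()  # drain so the next measurement starts clean

# Correct: block until the result is ready before stopping the clock
# (analogous to waiting on device ops before reading the timer).
start = time.perf_counter()
future = executor.submit(device_compute)
future.result()
complete_ms = (time.perf_counter() - start) * 1000

print(f"launch only: {launch_ms:.1f} ms, full completion: {complete_ms:.1f} ms")
executor.shutdown()
```

The launch-only number is a few microseconds, while the completion number reflects the actual 200 ms of simulated work, mirroring the 35 ms vs 3000 ms gap seen in the thread when different synchronization points are timed.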