-
Notifications
You must be signed in to change notification settings - Fork 55
A few more small fixes for upstream + CUDA #373
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Conversation
| ::Type{TorA}, codomain::TensorSpace{S}, domain::TensorSpace{S} = one(codomain) | ||
| ) where {TorA, S <: IndexSpace} | ||
| return Base.$fname(TorA, codomain ← domain) | ||
| ::Type{T}, codomain::TensorSpace{S}, domain::TensorSpace{S} = one(codomain) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This change slightly confuses me, I thought that passing along the array type is what is needed to just make this work for CuArray, why is this no longer the case? Can we also set up test cases to capture this?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I think the problem was rather the reverse, that the deleted method was breaking things for ones(Float64, ...). There's still a CUDA specific specialization in the extension
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I don't think this is true? There are specializations for CUDA.ones, but this is a different function !== Base.ones.
I would expect something like CUDA.ones(Float64, spaces) == Base.ones(CuVector{Float64}, spaces), but I think this latter method now is no longer defined?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The latter function (Base.ones accepting a CuArray) doesn't exist, no (neither does Base.zeros). We can add support for it in the extension, I guess? It's annoying having to work around the CUDA and Base APIs and their refusal to play nicely together...
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Wait, it doesn't exist anymore because you delete it in this PR right? This is exactly why I am confused about the change, since I thought that would work previously, simply by how our generic implementation uses tensormaptype(..., TorA) which would end up with a CuTensorMap
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
It's now back in the extension 😇
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Anyway sorry for my bad explanation, I was running into problems earlier with
Base.$fname(TorA, codomain ← domain) when TorA == Vector{Float64} (for example) because this would try to call Base.ones(Vector{Float64}, n) where n is some integer deep down. Maybe now tensormaptype fixes this?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
No worries, I think I somehow just didn't expect it should end up calling that, since the endpoint should be this:
TensorKit.jl/src/tensors/tensor.jl
Lines 314 to 318 in 71d6c00
| function Base.$fname(::Type{TorA}, V::TensorMapSpace) where {TorA} | |
| t = tensormaptype(spacetype(V), numout(V), numin(V), TorA)(undef, V) | |
| fill!(t, $felt(scalartype(t))) | |
| return t | |
| end |
Of course, I am just looking at code, which is notoriously a great way of debugging/reasoning about it 😉
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Let me see if I can't add a bunch of tests to trigger these paths and hopefully that will illuminate things. Worst case I make a 3rd PR later.
No description provided.