
Conversation

@kshyatt (Member) commented Feb 11, 2026

No description provided.

@kshyatt kshyatt requested a review from lkdvos February 11, 2026 16:27
```julia
    ::Type{TorA}, codomain::TensorSpace{S}, domain::TensorSpace{S} = one(codomain)
) where {TorA, S <: IndexSpace}
    return Base.$fname(TorA, codomain ← domain)
    ::Type{T}, codomain::TensorSpace{S}, domain::TensorSpace{S} = one(codomain)
```
Member:
This change slightly confuses me: I thought passing along the array type was what made this just work for `CuArray`. Why is that no longer the case? Can we also set up test cases to capture this?

Member Author:
I think the problem was rather the reverse: the deleted method was breaking things for `ones(Float64, ...)`. There's still a CUDA-specific specialization in the extension.

Member:
I don't think this is true? There are specializations for `CUDA.ones`, but that is a different function: `CUDA.ones !== Base.ones`.

I would expect something like `CUDA.ones(Float64, spaces) == Base.ones(CuVector{Float64}, spaces)`, but I think this latter method is now no longer defined?

Member Author:
The latter function (`Base.ones` accepting a `CuArray` type) doesn't exist, no (neither does `Base.zeros`). We can add support for it in the extension, I guess? It's annoying having to work around the CUDA and Base APIs and their refusal to play nicely together...

Member:
Wait, it doesn't exist anymore because you deleted it in this PR, right? This is exactly why I am confused about the change: I thought that worked previously, simply because our generic implementation uses `tensormaptype(..., TorA)`, which would end up with a `CuTensorMap`.

Member Author:
It's now back in the extension 😇

Member Author:
Anyway, sorry for my bad explanation. I was running into problems earlier with `Base.$fname(TorA, codomain ← domain)` when `TorA == Vector{Float64}` (for example), because deep down this would try to call `Base.ones(Vector{Float64}, n)` where `n` is some integer. Maybe `tensormaptype` now fixes this?
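For the record, the failure mode described above can be reproduced in plain Julia, with no TensorKit or CUDA involved: `Base.ones` interprets its type argument as an *element* type and calls `one` on it, which has no method for array types.

```julia
# Base.ones(T, n) fills a new Array{T} with one(T), so T must be scalar-like.
a = ones(Float64, 3)            # works: one(Float64) == 1.0

# Passing an array type as if it were an element type fails, because
# one(Vector{Float64}) has no method.
err = try
    ones(Vector{Float64}, 3)
    nothing
catch e
    e                           # a MethodError from one(::Type{Vector{Float64}})
end
```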

Member:
No worries, I think I somehow just didn't expect it should end up calling that, since the endpoint should be this:

```julia
function Base.$fname(::Type{TorA}, V::TensorMapSpace) where {TorA}
    t = tensormaptype(spacetype(V), numout(V), numin(V), TorA)(undef, V)
    fill!(t, $felt(scalartype(t)))
    return t
end
```

Of course, I am just looking at code, which is notoriously a great way of debugging/reasoning about it 😉
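As a side note, the dispatch pattern in that endpoint can be sketched in isolation. The toy below (the names `toy_storagetype` and `toy_ones` are invented for illustration and are not TensorKit API) shows how a single `TorA` parameter can select either the scalar type or the storage type, so that `fill!` then dispatches on the concrete array type:

```julia
# Map either an element type or an array type to a concrete storage type.
toy_storagetype(::Type{T}) where {T <: Number} = Vector{T}        # scalar given
toy_storagetype(::Type{A}) where {A <: AbstractVector} = A        # storage given

function toy_ones(::Type{TorA}, n::Int) where {TorA}
    A = toy_storagetype(TorA)
    t = A(undef, n)
    # fill! dispatches on the concrete storage type, so a GPU array type
    # would be filled on-device by its own fill! method.
    return fill!(t, one(eltype(A)))
end
```

With this, `toy_ones(Float64, 3)` and `toy_ones(Vector{Float32}, 2)` both produce arrays of ones, whereas plain `Base.ones` only handles the scalar case.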

Member Author:
Let me see if I can add a bunch of tests to trigger these paths; hopefully that will illuminate things. Worst case, I'll make a third PR later.
