Skip to content

Commit

Permalink
Update nx to v0.9.2 - use explicit type for tensors
Browse files Browse the repository at this point in the history
Integer tensors are now of type s32 by default, so we need to
specify the desired type explicitly - Explorer uses s64 for integer
series by default.
  • Loading branch information
philss committed Jan 10, 2025
1 parent d4f5197 commit 657e8e6
Show file tree
Hide file tree
Showing 5 changed files with 40 additions and 36 deletions.
18 changes: 9 additions & 9 deletions lib/explorer/data_frame.ex
Original file line number Diff line number Diff line change
Expand Up @@ -1747,9 +1747,9 @@ defmodule Explorer.DataFrame do
...> ]))
#Explorer.DataFrame<
Polars[2 x 3]
x1 s64 [1, 4]
x2 s64 [2, 5]
x3 s64 [3, 6]
x1 s32 [1, 4]
x2 s32 [2, 5]
x3 s32 [3, 6]
>
Explorer expects tensors to have certain types, so you may need to cast
Expand All @@ -1764,14 +1764,14 @@ defmodule Explorer.DataFrame do
#Explorer.DataFrame<
Polars[2 x 2]
floats f64 [1.0, 2.0]
ints s64 [3, 4]
ints s32 [3, 4]
>
Use dtypes to force a particular representation:
iex> Explorer.DataFrame.new([
...> floats: Nx.tensor([1.0, 2.0], type: :f64),
...> times: Nx.tensor([3_000, 4_000])
...> times: Nx.tensor([3_000, 4_000], type: :s64)
...> ], dtypes: [times: :time])
#Explorer.DataFrame<
Polars[2 x 2]
Expand Down Expand Up @@ -3158,7 +3158,7 @@ defmodule Explorer.DataFrame do
iex> Explorer.DataFrame.put(df, :a, Nx.tensor([1, 2, 3]))
#Explorer.DataFrame<
Polars[3 x 1]
a s64 [1, 2, 3]
a s32 [1, 2, 3]
>
You can specify which dtype the tensor represents.
Expand All @@ -3167,7 +3167,7 @@ defmodule Explorer.DataFrame do
in microseconds from the Unix epoch:
iex> df = Explorer.DataFrame.new([])
iex> Explorer.DataFrame.put(df, :a, Nx.tensor([1, 2, 3]), dtype: {:naive_datetime, :microsecond})
iex> Explorer.DataFrame.put(df, :a, Nx.tensor([1, 2, 3], type: :s64), dtype: {:naive_datetime, :microsecond})
#Explorer.DataFrame<
Polars[3 x 1]
a naive_datetime[μs] [1970-01-01 00:00:00.000001, 1970-01-01 00:00:00.000002, 1970-01-01 00:00:00.000003]
Expand All @@ -3179,7 +3179,7 @@ defmodule Explorer.DataFrame do
straightforward:
iex> df = Explorer.DataFrame.new(a: [~N[1970-01-01 00:00:00]])
iex> Explorer.DataFrame.put(df, :a, Nx.tensor(529550625987654))
iex> Explorer.DataFrame.put(df, :a, Nx.tensor(529550625987654, type: :s64))
#Explorer.DataFrame<
Polars[1 x 1]
a naive_datetime[μs] [1986-10-13 01:23:45.987654]
Expand All @@ -3189,7 +3189,7 @@ defmodule Explorer.DataFrame do
iex> cat = Explorer.Series.from_list(["foo", "bar", "baz"], dtype: :category)
iex> df = Explorer.DataFrame.new(a: cat)
iex> Explorer.DataFrame.put(df, :a, Nx.tensor([2, 1, 0]))
iex> Explorer.DataFrame.put(df, :a, Nx.tensor([2, 1, 0], type: :s64))
#Explorer.DataFrame<
Polars[3 x 1]
a category ["baz", "bar", "foo"]
Expand Down
14 changes: 7 additions & 7 deletions lib/explorer/series.ex
Original file line number Diff line number Diff line change
Expand Up @@ -655,7 +655,7 @@ defmodule Explorer.Series do
iex> Explorer.Series.from_tensor(tensor)
#Explorer.Series<
Polars[3]
s64 [1, 2, 3]
s32 [1, 2, 3]
>
iex> tensor = Nx.tensor([1.0, 2.0, 3.0], type: :f64)
Expand All @@ -672,11 +672,11 @@ defmodule Explorer.Series do
u8 [1, 0, 1]
>
iex> tensor = Nx.tensor([-719162, 0, 6129], type: :s32)
iex> tensor = Nx.tensor([-719162, 0, 6129], type: :s64)
iex> Explorer.Series.from_tensor(tensor)
#Explorer.Series<
Polars[3]
s32 [-719162, 0, 6129]
s64 [-719162, 0, 6129]
>
Booleans can be read from a tensor of `{:u, 8}` type if the dtype is explicitly given:
Expand All @@ -691,7 +691,7 @@ defmodule Explorer.Series do
Times are signed 64-bit representing nanoseconds from midnight and
therefore must have their dtype explicitly given:
iex> tensor = Nx.tensor([0, 86399999999000])
iex> tensor = Nx.tensor([0, 86399999999000], type: :s64)
iex> Explorer.Series.from_tensor(tensor, dtype: :time)
#Explorer.Series<
Polars[2]
Expand All @@ -700,7 +700,7 @@ defmodule Explorer.Series do
Datetimes are signed 64-bit and therefore must have their dtype explicitly given:
iex> tensor = Nx.tensor([0, 529550625987654])
iex> tensor = Nx.tensor([0, 529550625987654], type: :s64)
iex> Explorer.Series.from_tensor(tensor, dtype: {:naive_datetime, :microsecond})
#Explorer.Series<
Polars[2]
Expand Down Expand Up @@ -736,7 +736,7 @@ defmodule Explorer.Series do
## Tensor examples
iex> s = Explorer.Series.from_list([0, 1, 2])
iex> Explorer.Series.replace(s, Nx.tensor([1, 2, 3]))
iex> Explorer.Series.replace(s, Nx.tensor([1, 2, 3], type: :s64))
#Explorer.Series<
Polars[3]
s64 [1, 2, 3]
Expand All @@ -745,7 +745,7 @@ defmodule Explorer.Series do
This is particularly useful for categorical columns:
iex> s = Explorer.Series.from_list(["foo", "bar", "baz"], dtype: :category)
iex> Explorer.Series.replace(s, Nx.tensor([2, 1, 0]))
iex> Explorer.Series.replace(s, Nx.tensor([2, 1, 0], type: :s64))
#Explorer.Series<
Polars[3]
category ["baz", "bar", "foo"]
Expand Down
4 changes: 2 additions & 2 deletions mix.lock
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
"bypass": {:hex, :bypass, "2.1.0", "909782781bf8e20ee86a9cabde36b259d44af8b9f38756173e8f5e2e1fabb9b1", [:mix], [{:plug, "~> 1.7", [hex: :plug, repo: "hexpm", optional: false]}, {:plug_cowboy, "~> 2.0", [hex: :plug_cowboy, repo: "hexpm", optional: false]}, {:ranch, "~> 1.3", [hex: :ranch, repo: "hexpm", optional: false]}], "hexpm", "d9b5df8fa5b7a6efa08384e9bbecfe4ce61c77d28a4282f79e02f1ef78d96b80"},
"castore": {:hex, :castore, "1.0.11", "4bbd584741601eb658007339ea730b082cc61f3554cf2e8f39bf693a11b49073", [:mix], [], "hexpm", "e03990b4db988df56262852f20de0f659871c35154691427a5047f4967a16a62"},
"cc_precompiler": {:hex, :cc_precompiler, "0.1.10", "47c9c08d8869cf09b41da36538f62bc1abd3e19e41701c2cea2675b53c704258", [:mix], [{:elixir_make, "~> 0.7", [hex: :elixir_make, repo: "hexpm", optional: false]}], "hexpm", "f6e046254e53cd6b41c6bacd70ae728011aa82b2742a80d6e2214855c6e06b22"},
"complex": {:hex, :complex, "0.5.0", "af2d2331ff6170b61bb738695e481b27a66780e18763e066ee2cd863d0b1dd92", [:mix], [], "hexpm", "2683bd3c184466cfb94fad74cbfddfaa94b860e27ad4ca1bffe3bff169d91ef1"},
"complex": {:hex, :complex, "0.6.0", "b0130086a7a8c33574d293b2e0e250f4685580418eac52a5658a4bd148f3ccf1", [:mix], [], "hexpm", "0a5fa95580dcaf30fcd60fe1aaf24327c0fe401e98c24d892e172e79498269f9"},
"cowboy": {:hex, :cowboy, "2.10.0", "ff9ffeff91dae4ae270dd975642997afe2a1179d94b1887863e43f681a203e26", [:make, :rebar3], [{:cowlib, "2.12.1", [hex: :cowlib, repo: "hexpm", optional: false]}, {:ranch, "1.8.0", [hex: :ranch, repo: "hexpm", optional: false]}], "hexpm", "3afdccb7183cc6f143cb14d3cf51fa00e53db9ec80cdcd525482f5e99bc41d6b"},
"cowboy_telemetry": {:hex, :cowboy_telemetry, "0.4.0", "f239f68b588efa7707abce16a84d0d2acf3a0f50571f8bb7f56a15865aae820c", [:rebar3], [{:cowboy, "~> 2.7", [hex: :cowboy, repo: "hexpm", optional: false]}, {:telemetry, "~> 1.0", [hex: :telemetry, repo: "hexpm", optional: false]}], "hexpm", "7d98bac1ee4565d31b62d59f8823dfd8356a169e7fcbb83831b8a5397404c9de"},
"cowlib": {:hex, :cowlib, "2.12.1", "a9fa9a625f1d2025fe6b462cb865881329b5caff8f1854d1cbc9f9533f00e1e1", [:make, :rebar3], [], "hexpm", "163b73f6367a7341b33c794c4e88e7dbfe6498ac42dcd69ef44c5bc5507c8db0"},
Expand All @@ -29,7 +29,7 @@
"nimble_options": {:hex, :nimble_options, "1.1.1", "e3a492d54d85fc3fd7c5baf411d9d2852922f66e69476317787a7b2bb000a61b", [:mix], [], "hexpm", "821b2470ca9442c4b6984882fe9bb0389371b8ddec4d45a9504f00a66f650b44"},
"nimble_parsec": {:hex, :nimble_parsec, "1.4.0", "51f9b613ea62cfa97b25ccc2c1b4216e81df970acd8e16e8d1bdc58fef21370d", [:mix], [], "hexpm", "9c565862810fb383e9838c1dd2d7d2c437b3d13b267414ba6af33e50d2d1cf28"},
"nimble_pool": {:hex, :nimble_pool, "1.1.0", "bf9c29fbdcba3564a8b800d1eeb5a3c58f36e1e11d7b7fb2e084a643f645f06b", [:mix], [], "hexpm", "af2e4e6b34197db81f7aad230c1118eac993acc0dae6bc83bac0126d4ae0813a"},
"nx": {:hex, :nx, "0.8.0", "81d801773cbcee654b8f6a41ccb7c2716d25073f2d64fec3d62950c9db98cf99", [:mix], [{:complex, "~> 0.5", [hex: :complex, repo: "hexpm", optional: false]}, {:telemetry, "~> 0.4.0 or ~> 1.0", [hex: :telemetry, repo: "hexpm", optional: false]}], "hexpm", "e61943143f1719dcceb8c5004286b719f3ab230058eafeea39799a4da3f0d754"},
"nx": {:hex, :nx, "0.9.2", "17563029c01bf749aad3c31234326d7665abd0acc33ee2acbe531a4759f29a8a", [:mix], [{:complex, "~> 0.5", [hex: :complex, repo: "hexpm", optional: false]}, {:telemetry, "~> 0.4.0 or ~> 1.0", [hex: :telemetry, repo: "hexpm", optional: false]}], "hexpm", "914d74741617d8103de8ab1f8c880353e555263e1c397b8a1109f79a3716557f"},
"plug": {:hex, :plug, "1.16.1", "40c74619c12f82736d2214557dedec2e9762029b2438d6d175c5074c933edc9d", [:mix], [{:mime, "~> 1.0 or ~> 2.0", [hex: :mime, repo: "hexpm", optional: false]}, {:plug_crypto, "~> 1.1.1 or ~> 1.2 or ~> 2.0", [hex: :plug_crypto, repo: "hexpm", optional: false]}, {:telemetry, "~> 0.4.3 or ~> 1.0", [hex: :telemetry, repo: "hexpm", optional: false]}], "hexpm", "a13ff6b9006b03d7e33874945b2755253841b238c34071ed85b0e86057f8cddc"},
"plug_cowboy": {:hex, :plug_cowboy, "2.6.1", "9a3bbfceeb65eff5f39dab529e5cd79137ac36e913c02067dba3963a26efe9b2", [:mix], [{:cowboy, "~> 2.7", [hex: :cowboy, repo: "hexpm", optional: false]}, {:cowboy_telemetry, "~> 0.3", [hex: :cowboy_telemetry, repo: "hexpm", optional: false]}, {:plug, "~> 1.14", [hex: :plug, repo: "hexpm", optional: false]}], "hexpm", "de36e1a21f451a18b790f37765db198075c25875c64834bcc82d90b309eb6613"},
"plug_crypto": {:hex, :plug_crypto, "2.1.0", "f44309c2b06d249c27c8d3f65cfe08158ade08418cf540fd4f72d4d6863abb7b", [:mix], [], "hexpm", "131216a4b030b8f8ce0f26038bc4421ae60e4bb95c5cf5395e1421437824c4fa"},
Expand Down
16 changes: 10 additions & 6 deletions test/explorer/data_frame_test.exs
Original file line number Diff line number Diff line change
Expand Up @@ -3868,10 +3868,14 @@ defmodule Explorer.DataFrameTest do

df =
DF.new(
a: [1, 2, 3],
b: [4.0, 5.0, 6.0],
c: ["a", "b", "c"],
d: [~D[1970-01-01], ~D[1980-01-01], ~D[1990-01-01]]
%{
a: [1, 2, 3],
b: [4.0, 5.0, 6.0],
c: ["a", "b", "c"],
d: [~D[1970-01-01], ~D[1980-01-01], ~D[1990-01-01]]
},
# The dtype of `a` must match the default dtype of integer tensors, which is s32.
dtypes: [a: :s32]
)

assert DF.put(df, :a, i)[:a] |> Series.to_list() == [1, 1, 1]
Expand All @@ -3883,7 +3887,7 @@ defmodule Explorer.DataFrameTest do
~D[1970-01-02]
]

assert DF.put(df, :c, i, dtype: :integer)[:c] |> Series.to_list() == [1, 1, 1]
assert DF.put(df, :c, i, dtype: :s32)[:c] |> Series.to_list() == [1, 1, 1]
assert DF.put(df, :c, f, dtype: {:f, 64})[:c] |> Series.to_list() == [1.0, 1.0, 1.0]

assert DF.put(df, :c, d, dtype: :date)[:c] |> Series.to_list() == [
Expand All @@ -3892,7 +3896,7 @@ defmodule Explorer.DataFrameTest do
~D[1970-01-02]
]

assert DF.put(df, :d, i, dtype: :integer)[:d] |> Series.to_list() == [1, 1, 1]
assert DF.put(df, :d, i, dtype: :s32)[:d] |> Series.to_list() == [1, 1, 1]
assert DF.put(df, :d, f, dtype: {:f, 64})[:d] |> Series.to_list() == [1.0, 1.0, 1.0]

assert DF.put(df, :d, d)[:d] |> Series.to_list() == [
Expand Down
24 changes: 12 additions & 12 deletions test/explorer/tensor_frame_test.exs
Original file line number Diff line number Diff line change
Expand Up @@ -58,8 +58,8 @@ defmodule Explorer.TensorFrameTest do

test "handles deftransform functions" do
tf = put_column(tf(a: [1, 2, 3], b: [4.0, 5.0, 6.0], c: ["a", "b", "c"]))
assert tf[:a] == Nx.tensor([1, 2, 3])
assert tf["a"] == Nx.tensor([1, 2, 3])
assert tf[:a] == Nx.tensor([1, 2, 3], type: :s64)
assert tf["a"] == Nx.tensor([1, 2, 3], type: :s64)
assert tf[:b] == Nx.tensor([4.0, 5.0, 6.0], type: :f64)
assert tf["b"] == Nx.tensor([4.0, 5.0, 6.0], type: :f64)
assert tf[:d] == Nx.tensor([5.0, 7.0, 9.0], type: :f64)
Expand All @@ -73,13 +73,13 @@ defmodule Explorer.TensorFrameTest do
i = 1
f = Nx.tensor([1.0], type: :f64)
tf = tf(a: [1, 2, 3], b: [4.0, 5.0, 6.0], c: ["a", "b", "c"])
assert TF.put(tf, :a, i)[:a] == Nx.tensor([1, 1, 1])
assert TF.put(tf, :a, i)[:a] == Nx.tensor([1, 1, 1], type: :s32)
assert TF.put(tf, :a, f)[:a] == Nx.tensor([1.0, 1.0, 1.0], type: :f64)

assert TF.put(tf, :c, i)[:c] == Nx.tensor([1, 1, 1])
assert TF.put(tf, :c, i)[:c] == Nx.tensor([1, 1, 1], type: :s32)
assert TF.put(tf, :c, f)[:c] == Nx.tensor([1.0, 1.0, 1.0], type: :f64)

assert TF.put(tf, :d, i)[:d] == Nx.tensor([1, 1, 1])
assert TF.put(tf, :d, i)[:d] == Nx.tensor([1, 1, 1], type: :s32)
assert TF.put(tf, :d, f)[:d] == Nx.tensor([1.0, 1.0, 1.0], type: :f64)
end
end
Expand All @@ -105,8 +105,8 @@ defmodule Explorer.TensorFrameTest do
describe "access" do
test "get" do
tf = tf(a: [1, 2, 3], b: [4.0, 5.0, 6.0], c: ["a", "b", "c"])
assert tf[:a] == Nx.tensor([1, 2, 3])
assert tf["a"] == Nx.tensor([1, 2, 3])
assert tf[:a] == Nx.tensor([1, 2, 3], type: :s64)
assert tf["a"] == Nx.tensor([1, 2, 3], type: :s64)
assert tf[:b] == Nx.tensor([4.0, 5.0, 6.0], type: :f64)
assert tf["b"] == Nx.tensor([4.0, 5.0, 6.0], type: :f64)

Expand All @@ -122,12 +122,12 @@ defmodule Explorer.TensorFrameTest do
test "get_and_update" do
tf = tf(a: [1, 2, 3], b: [4.0, 5.0, 6.0], c: ["a", "b", "c"])
{get, update} = Access.get_and_update(tf, :a, fn a -> {a, Nx.multiply(a, 2)} end)
assert get == Nx.tensor([1, 2, 3])
assert update[:a] == Nx.tensor([2, 4, 6])
assert get == Nx.tensor([1, 2, 3], type: :s64)
assert update[:a] == Nx.tensor([2, 4, 6], type: :s64)

{get, update} = Access.get_and_update(tf, :a, fn a -> {a, 123} end)
assert get == Nx.tensor([1, 2, 3])
assert update[:a] == Nx.tensor([123, 123, 123])
assert get == Nx.tensor([1, 2, 3], type: :s64)
assert update[:a] == Nx.tensor([123, 123, 123], type: :s32)

assert_raise ArgumentError,
~r"cannot add tensor that does not match the frame size. Expected a tensor of shape \{3\}",
Expand All @@ -139,7 +139,7 @@ defmodule Explorer.TensorFrameTest do
test "pop" do
tf = tf(a: [1, 2, 3], b: [4.0, 5.0, 6.0], c: ["a", "b", "c"])
{tensor, popped} = Access.pop(tf, :a)
assert tensor == Nx.tensor([1, 2, 3])
assert tensor == Nx.tensor([1, 2, 3], type: :s64)
assert popped.data[:a] == nil
end
end
Expand Down

0 comments on commit 657e8e6

Please sign in to comment.