From 657e8e6d21c41ecf38a37680cda924e42ace5307 Mon Sep 17 00:00:00 2001 From: Philip Sampaio Date: Thu, 9 Jan 2025 22:14:57 -0300 Subject: [PATCH] Update nx to v0.9.2 - use explicit type for tensors The integers tensors are now of the type s32 by default. So we need to specify the desired type - Explorer uses s64 for integer series by default. --- lib/explorer/data_frame.ex | 18 +++++++++--------- lib/explorer/series.ex | 14 +++++++------- mix.lock | 4 ++-- test/explorer/data_frame_test.exs | 16 ++++++++++------ test/explorer/tensor_frame_test.exs | 24 ++++++++++++------------ 5 files changed, 40 insertions(+), 36 deletions(-) diff --git a/lib/explorer/data_frame.ex b/lib/explorer/data_frame.ex index 329f04998..fe1c99ab5 100644 --- a/lib/explorer/data_frame.ex +++ b/lib/explorer/data_frame.ex @@ -1747,9 +1747,9 @@ defmodule Explorer.DataFrame do ...> ])) #Explorer.DataFrame< Polars[2 x 3] - x1 s64 [1, 4] - x2 s64 [2, 5] - x3 s64 [3, 6] + x1 s32 [1, 4] + x2 s32 [2, 5] + x3 s32 [3, 6] > Explorer expects tensors to have certain types, so you may need to cast @@ -1764,14 +1764,14 @@ defmodule Explorer.DataFrame do #Explorer.DataFrame< Polars[2 x 2] floats f64 [1.0, 2.0] - ints s64 [3, 4] + ints s32 [3, 4] > Use dtypes to force a particular representation: iex> Explorer.DataFrame.new([ ...> floats: Nx.tensor([1.0, 2.0], type: :f64), - ...> times: Nx.tensor([3_000, 4_000]) + ...> times: Nx.tensor([3_000, 4_000], type: :s64) ...> ], dtypes: [times: :time]) #Explorer.DataFrame< Polars[2 x 2] @@ -3158,7 +3158,7 @@ defmodule Explorer.DataFrame do iex> Explorer.DataFrame.put(df, :a, Nx.tensor([1, 2, 3])) #Explorer.DataFrame< Polars[3 x 1] - a s64 [1, 2, 3] + a s32 [1, 2, 3] > You can specify which dtype the tensor represents. @@ -3167,7 +3167,7 @@ defmodule Explorer.DataFrame do in microseconds from the Unix epoch: iex> df = Explorer.DataFrame.new([]) - iex> Explorer.DataFrame.put(df, :a, Nx.tensor([1, 2, 3]), dtype: {:naive_datetime, :microsecond}) + iex> Explorer.DataFrame.put(df, :a, Nx.tensor([1, 2, 3], type: :s64), dtype: {:naive_datetime, :microsecond}) #Explorer.DataFrame< Polars[3 x 1] a naive_datetime[μs] [1970-01-01 00:00:00.000001, 1970-01-01 00:00:00.000002, 1970-01-01 00:00:00.000003] @@ -3179,7 +3179,7 @@ defmodule Explorer.DataFrame do straight-forward: iex> df = Explorer.DataFrame.new(a: [~N[1970-01-01 00:00:00]]) - iex> Explorer.DataFrame.put(df, :a, Nx.tensor(529550625987654)) + iex> Explorer.DataFrame.put(df, :a, Nx.tensor(529550625987654, type: :s64)) #Explorer.DataFrame< Polars[1 x 1] a naive_datetime[μs] [1986-10-13 01:23:45.987654] @@ -3189,7 +3189,7 @@ defmodule Explorer.DataFrame do iex> cat = Explorer.Series.from_list(["foo", "bar", "baz"], dtype: :category) iex> df = Explorer.DataFrame.new(a: cat) - iex> Explorer.DataFrame.put(df, :a, Nx.tensor([2, 1, 0])) + iex> Explorer.DataFrame.put(df, :a, Nx.tensor([2, 1, 0], type: :s64)) #Explorer.DataFrame< Polars[3 x 1] a category ["baz", "bar", "foo"] diff --git a/lib/explorer/series.ex b/lib/explorer/series.ex index c88bdc301..ed91fb2cf 100644 --- a/lib/explorer/series.ex +++ b/lib/explorer/series.ex @@ -655,7 +655,7 @@ defmodule Explorer.Series do iex> Explorer.Series.from_tensor(tensor) #Explorer.Series< Polars[3] - s64 [1, 2, 3] + s32 [1, 2, 3] > iex> tensor = Nx.tensor([1.0, 2.0, 3.0], type: :f64) @@ -672,11 +672,11 @@ defmodule Explorer.Series do u8 [1, 0, 1] > - iex> tensor = Nx.tensor([-719162, 0, 6129], type: :s32) + iex> tensor = Nx.tensor([-719162, 0, 6129], type: :s64) iex> Explorer.Series.from_tensor(tensor) #Explorer.Series< Polars[3] - s32 [-719162, 0, 6129] + s64 [-719162, 0, 6129] > Booleans can be read from a tensor of `{:u, 8}` type if the dtype is explicitly given: @@ -691,7 +691,7 @@ defmodule Explorer.Series do Times are signed 64-bit representing nanoseconds from midnight and therefore must have their dtype explicitly given: - iex> tensor = Nx.tensor([0, 86399999999000]) + iex> tensor = Nx.tensor([0, 86399999999000], type: :s64) iex> Explorer.Series.from_tensor(tensor, dtype: :time) #Explorer.Series< Polars[2] @@ -700,7 +700,7 @@ defmodule Explorer.Series do Datetimes are signed 64-bit and therefore must have their dtype explicitly given: - iex> tensor = Nx.tensor([0, 529550625987654]) + iex> tensor = Nx.tensor([0, 529550625987654], type: :s64) iex> Explorer.Series.from_tensor(tensor, dtype: {:naive_datetime, :microsecond}) #Explorer.Series< Polars[2] @@ -736,7 +736,7 @@ defmodule Explorer.Series do ## Tensor examples iex> s = Explorer.Series.from_list([0, 1, 2]) - iex> Explorer.Series.replace(s, Nx.tensor([1, 2, 3])) + iex> Explorer.Series.replace(s, Nx.tensor([1, 2, 3], type: :s64)) #Explorer.Series< Polars[3] s64 [1, 2, 3] @@ -745,7 +745,7 @@ defmodule Explorer.Series do This is particularly useful for categorical columns: iex> s = Explorer.Series.from_list(["foo", "bar", "baz"], dtype: :category) - iex> Explorer.Series.replace(s, Nx.tensor([2, 1, 0])) + iex> Explorer.Series.replace(s, Nx.tensor([2, 1, 0], type: :s64)) #Explorer.Series< Polars[3] category ["baz", "bar", "foo"] diff --git a/mix.lock b/mix.lock index 978465fc0..9a1de57e7 100644 --- a/mix.lock +++ b/mix.lock @@ -5,7 +5,7 @@ "bypass": {:hex, :bypass, "2.1.0", "909782781bf8e20ee86a9cabde36b259d44af8b9f38756173e8f5e2e1fabb9b1", [:mix], [{:plug, "~> 1.7", [hex: :plug, repo: "hexpm", optional: false]}, {:plug_cowboy, "~> 2.0", [hex: :plug_cowboy, repo: "hexpm", optional: false]}, {:ranch, "~> 1.3", [hex: :ranch, repo: "hexpm", optional: false]}], "hexpm", "d9b5df8fa5b7a6efa08384e9bbecfe4ce61c77d28a4282f79e02f1ef78d96b80"}, "castore": {:hex, :castore, "1.0.11", "4bbd584741601eb658007339ea730b082cc61f3554cf2e8f39bf693a11b49073", [:mix], [], "hexpm", "e03990b4db988df56262852f20de0f659871c35154691427a5047f4967a16a62"}, "cc_precompiler": {:hex, :cc_precompiler, "0.1.10", "47c9c08d8869cf09b41da36538f62bc1abd3e19e41701c2cea2675b53c704258", [:mix], [{:elixir_make, "~> 0.7", [hex: :elixir_make, repo: "hexpm", optional: false]}], "hexpm", "f6e046254e53cd6b41c6bacd70ae728011aa82b2742a80d6e2214855c6e06b22"}, - "complex": {:hex, :complex, "0.5.0", "af2d2331ff6170b61bb738695e481b27a66780e18763e066ee2cd863d0b1dd92", [:mix], [], "hexpm", "2683bd3c184466cfb94fad74cbfddfaa94b860e27ad4ca1bffe3bff169d91ef1"}, + "complex": {:hex, :complex, "0.6.0", "b0130086a7a8c33574d293b2e0e250f4685580418eac52a5658a4bd148f3ccf1", [:mix], [], "hexpm", "0a5fa95580dcaf30fcd60fe1aaf24327c0fe401e98c24d892e172e79498269f9"}, "cowboy": {:hex, :cowboy, "2.10.0", "ff9ffeff91dae4ae270dd975642997afe2a1179d94b1887863e43f681a203e26", [:make, :rebar3], [{:cowlib, "2.12.1", [hex: :cowlib, repo: "hexpm", optional: false]}, {:ranch, "1.8.0", [hex: :ranch, repo: "hexpm", optional: false]}], "hexpm", "3afdccb7183cc6f143cb14d3cf51fa00e53db9ec80cdcd525482f5e99bc41d6b"}, "cowboy_telemetry": {:hex, :cowboy_telemetry, "0.4.0", "f239f68b588efa7707abce16a84d0d2acf3a0f50571f8bb7f56a15865aae820c", [:rebar3], [{:cowboy, "~> 2.7", [hex: :cowboy, repo: "hexpm", optional: false]}, {:telemetry, "~> 1.0", [hex: :telemetry, repo: "hexpm", optional: false]}], "hexpm", "7d98bac1ee4565d31b62d59f8823dfd8356a169e7fcbb83831b8a5397404c9de"}, "cowlib": {:hex, :cowlib, "2.12.1", "a9fa9a625f1d2025fe6b462cb865881329b5caff8f1854d1cbc9f9533f00e1e1", [:make, :rebar3], [], "hexpm", "163b73f6367a7341b33c794c4e88e7dbfe6498ac42dcd69ef44c5bc5507c8db0"}, @@ -29,7 +29,7 @@ "nimble_options": {:hex, :nimble_options, "1.1.1", "e3a492d54d85fc3fd7c5baf411d9d2852922f66e69476317787a7b2bb000a61b", [:mix], [], "hexpm", "821b2470ca9442c4b6984882fe9bb0389371b8ddec4d45a9504f00a66f650b44"}, "nimble_parsec": {:hex, :nimble_parsec, "1.4.0", "51f9b613ea62cfa97b25ccc2c1b4216e81df970acd8e16e8d1bdc58fef21370d", [:mix], [], "hexpm", "9c565862810fb383e9838c1dd2d7d2c437b3d13b267414ba6af33e50d2d1cf28"}, "nimble_pool": {:hex, :nimble_pool, "1.1.0", "bf9c29fbdcba3564a8b800d1eeb5a3c58f36e1e11d7b7fb2e084a643f645f06b", [:mix], [], "hexpm", "af2e4e6b34197db81f7aad230c1118eac993acc0dae6bc83bac0126d4ae0813a"}, - "nx": {:hex, :nx, "0.8.0", "81d801773cbcee654b8f6a41ccb7c2716d25073f2d64fec3d62950c9db98cf99", [:mix], [{:complex, "~> 0.5", [hex: :complex, repo: "hexpm", optional: false]}, {:telemetry, "~> 0.4.0 or ~> 1.0", [hex: :telemetry, repo: "hexpm", optional: false]}], "hexpm", "e61943143f1719dcceb8c5004286b719f3ab230058eafeea39799a4da3f0d754"}, + "nx": {:hex, :nx, "0.9.2", "17563029c01bf749aad3c31234326d7665abd0acc33ee2acbe531a4759f29a8a", [:mix], [{:complex, "~> 0.5", [hex: :complex, repo: "hexpm", optional: false]}, {:telemetry, "~> 0.4.0 or ~> 1.0", [hex: :telemetry, repo: "hexpm", optional: false]}], "hexpm", "914d74741617d8103de8ab1f8c880353e555263e1c397b8a1109f79a3716557f"}, "plug": {:hex, :plug, "1.16.1", "40c74619c12f82736d2214557dedec2e9762029b2438d6d175c5074c933edc9d", [:mix], [{:mime, "~> 1.0 or ~> 2.0", [hex: :mime, repo: "hexpm", optional: false]}, {:plug_crypto, "~> 1.1.1 or ~> 1.2 or ~> 2.0", [hex: :plug_crypto, repo: "hexpm", optional: false]}, {:telemetry, "~> 0.4.3 or ~> 1.0", [hex: :telemetry, repo: "hexpm", optional: false]}], "hexpm", "a13ff6b9006b03d7e33874945b2755253841b238c34071ed85b0e86057f8cddc"}, "plug_cowboy": {:hex, :plug_cowboy, "2.6.1", "9a3bbfceeb65eff5f39dab529e5cd79137ac36e913c02067dba3963a26efe9b2", [:mix], [{:cowboy, "~> 2.7", [hex: :cowboy, repo: "hexpm", optional: false]}, {:cowboy_telemetry, "~> 0.3", [hex: :cowboy_telemetry, repo: "hexpm", optional: false]}, {:plug, "~> 1.14", [hex: :plug, repo: "hexpm", optional: false]}], "hexpm", "de36e1a21f451a18b790f37765db198075c25875c64834bcc82d90b309eb6613"}, "plug_crypto": {:hex, :plug_crypto, "2.1.0", "f44309c2b06d249c27c8d3f65cfe08158ade08418cf540fd4f72d4d6863abb7b", [:mix], [], "hexpm", "131216a4b030b8f8ce0f26038bc4421ae60e4bb95c5cf5395e1421437824c4fa"}, diff --git a/test/explorer/data_frame_test.exs b/test/explorer/data_frame_test.exs index 368b6b227..d1966f5f0 100644 --- a/test/explorer/data_frame_test.exs +++ b/test/explorer/data_frame_test.exs @@ -3868,10 +3868,14 @@ defmodule Explorer.DataFrameTest do df = DF.new( - a: [1, 2, 3], - b: [4.0, 5.0, 6.0], - c: ["a", "b", "c"], - d: [~D[1970-01-01], ~D[1980-01-01], ~D[1990-01-01]] + %{ + a: [1, 2, 3], + b: [4.0, 5.0, 6.0], + c: ["a", "b", "c"], + d: [~D[1970-01-01], ~D[1980-01-01], ~D[1990-01-01]] + }, + # The dtype of `a` must match the default dtype of integer tensors, which is s32. + dtypes: [a: :s32] ) assert DF.put(df, :a, i)[:a] |> Series.to_list() == [1, 1, 1] @@ -3883,7 +3887,7 @@ defmodule Explorer.DataFrameTest do ~D[1970-01-02] ] - assert DF.put(df, :c, i, dtype: :integer)[:c] |> Series.to_list() == [1, 1, 1] + assert DF.put(df, :c, i, dtype: :s32)[:c] |> Series.to_list() == [1, 1, 1] assert DF.put(df, :c, f, dtype: {:f, 64})[:c] |> Series.to_list() == [1.0, 1.0, 1.0] assert DF.put(df, :c, d, dtype: :date)[:c] |> Series.to_list() == [ @@ -3892,7 +3896,7 @@ defmodule Explorer.DataFrameTest do ~D[1970-01-02] ] - assert DF.put(df, :d, i, dtype: :integer)[:d] |> Series.to_list() == [1, 1, 1] + assert DF.put(df, :d, i, dtype: :s32)[:d] |> Series.to_list() == [1, 1, 1] assert DF.put(df, :d, f, dtype: {:f, 64})[:d] |> Series.to_list() == [1.0, 1.0, 1.0] assert DF.put(df, :d, d)[:d] |> Series.to_list() == [ diff --git a/test/explorer/tensor_frame_test.exs b/test/explorer/tensor_frame_test.exs index 83610771f..366d545ca 100644 --- a/test/explorer/tensor_frame_test.exs +++ b/test/explorer/tensor_frame_test.exs @@ -58,8 +58,8 @@ defmodule Explorer.TensorFrameTest do test "handles deftransform functions" do tf = put_column(tf(a: [1, 2, 3], b: [4.0, 5.0, 6.0], c: ["a", "b", "c"])) - assert tf[:a] == Nx.tensor([1, 2, 3]) - assert tf["a"] == Nx.tensor([1, 2, 3]) + assert tf[:a] == Nx.tensor([1, 2, 3], type: :s64) + assert tf["a"] == Nx.tensor([1, 2, 3], type: :s64) assert tf[:b] == Nx.tensor([4.0, 5.0, 6.0], type: :f64) assert tf["b"] == Nx.tensor([4.0, 5.0, 6.0], type: :f64) assert tf[:d] == Nx.tensor([5.0, 7.0, 9.0], type: :f64) @@ -73,13 +73,13 @@ defmodule Explorer.TensorFrameTest do i = 1 f = Nx.tensor([1.0], type: :f64) tf = tf(a: [1, 2, 3], b: [4.0, 5.0, 6.0], c: ["a", "b", "c"]) - assert TF.put(tf, :a, i)[:a] == Nx.tensor([1, 1, 1]) + assert TF.put(tf, :a, i)[:a] == Nx.tensor([1, 1, 1], type: :s32) assert TF.put(tf, :a, f)[:a] == Nx.tensor([1.0, 1.0, 1.0], type: :f64) - assert TF.put(tf, :c, i)[:c] == Nx.tensor([1, 1, 1]) + assert TF.put(tf, :c, i)[:c] == Nx.tensor([1, 1, 1], type: :s32) assert TF.put(tf, :c, f)[:c] == Nx.tensor([1.0, 1.0, 1.0], type: :f64) - assert TF.put(tf, :d, i)[:d] == Nx.tensor([1, 1, 1]) + assert TF.put(tf, :d, i)[:d] == Nx.tensor([1, 1, 1], type: :s32) assert TF.put(tf, :d, f)[:d] == Nx.tensor([1.0, 1.0, 1.0], type: :f64) end end @@ -105,8 +105,8 @@ defmodule Explorer.TensorFrameTest do describe "access" do test "get" do tf = tf(a: [1, 2, 3], b: [4.0, 5.0, 6.0], c: ["a", "b", "c"]) - assert tf[:a] == Nx.tensor([1, 2, 3]) - assert tf["a"] == Nx.tensor([1, 2, 3]) + assert tf[:a] == Nx.tensor([1, 2, 3], type: :s64) + assert tf["a"] == Nx.tensor([1, 2, 3], type: :s64) assert tf[:b] == Nx.tensor([4.0, 5.0, 6.0], type: :f64) assert tf["b"] == Nx.tensor([4.0, 5.0, 6.0], type: :f64) @@ -122,12 +122,12 @@ defmodule Explorer.TensorFrameTest do test "get_and_update" do tf = tf(a: [1, 2, 3], b: [4.0, 5.0, 6.0], c: ["a", "b", "c"]) {get, update} = Access.get_and_update(tf, :a, fn a -> {a, Nx.multiply(a, 2)} end) - assert get == Nx.tensor([1, 2, 3]) - assert update[:a] == Nx.tensor([2, 4, 6]) + assert get == Nx.tensor([1, 2, 3], type: :s64) + assert update[:a] == Nx.tensor([2, 4, 6], type: :s64) {get, update} = Access.get_and_update(tf, :a, fn a -> {a, 123} end) - assert get == Nx.tensor([1, 2, 3]) - assert update[:a] == Nx.tensor([123, 123, 123]) + assert get == Nx.tensor([1, 2, 3], type: :s64) + assert update[:a] == Nx.tensor([123, 123, 123], type: :s32) assert_raise ArgumentError, ~r"cannot add tensor that does not match the frame size. Expected a tensor of shape \{3\}", @@ -139,7 +139,7 @@ defmodule Explorer.TensorFrameTest do test "pop" do tf = tf(a: [1, 2, 3], b: [4.0, 5.0, 6.0], c: ["a", "b", "c"]) {tensor, popped} = Access.pop(tf, :a) - assert tensor == Nx.tensor([1, 2, 3]) + assert tensor == Nx.tensor([1, 2, 3], type: :s64) assert popped.data[:a] == nil end end