diff --git a/lib/bumblebee.ex b/lib/bumblebee.ex index 07f1b8b8..6237dc61 100644 --- a/lib/bumblebee.ex +++ b/lib/bumblebee.ex @@ -568,6 +568,13 @@ defmodule Bumblebee do * `:params_filename` - the file with the model parameters to be loaded + * `:safetensors_reader` - a 1-arity function used to read `.safetensors` + parameter files. Receives the file path and must return a map from + tensor name to an `Nx.Tensor` or any term implementing + `Nx.LazyContainer`. Defaults to `&Safetensors.read!(&1, lazy: true)`. + Override to plug in a custom reader, for example one that + memory-maps the file for zero-copy loading + * `:log_params_diff` - whether to log missing, mismatched and unused parameters. By default diff is logged only if some parameters cannot be loaded @@ -617,6 +624,7 @@ defmodule Bumblebee do :architecture, :params_variant, :params_filename, + :safetensors_reader, :log_params_diff, :backend, :type @@ -659,7 +667,7 @@ defmodule Bumblebee do filename |> String.replace_suffix(".index.json", "") |> Path.extname() - |> params_file_loader_fun() + |> params_file_loader_fun(opts) with {:ok, paths} <- download_params_files(repository, repo_files, filename, sharded?) do opts = @@ -768,8 +776,11 @@ defmodule Bumblebee do end end - defp params_file_loader_fun(".safetensors"), do: &Safetensors.read!(&1, lazy: true) - defp params_file_loader_fun(_), do: &Bumblebee.Conversion.PyTorchLoader.load!/1 + defp params_file_loader_fun(".safetensors", opts) do + opts[:safetensors_reader] || (&Safetensors.read!(&1, lazy: true)) + end + + defp params_file_loader_fun(_, _opts), do: &Bumblebee.Conversion.PyTorchLoader.load!/1 @doc """ Featurizes `input` with the given featurizer. diff --git a/test/bumblebee_test.exs b/test/bumblebee_test.exs index 765e9d82..2df4a4b3 100644 --- a/test/bumblebee_test.exs +++ b/test/bumblebee_test.exs @@ -84,5 +84,30 @@ defmodule BumblebeeTest do assert Nx.type(params["decoder.blocks.0.ffn.output"]["kernel"]) == {:f, 16} assert Nx.type(params["decoder.blocks.0.ffn.output"]["bias"]) == {:f, 16} end + + test "uses :safetensors_reader to read .safetensors files" do + test_pid = self() + + reader = fn path -> + send(test_pid, {:reader_called, path}) + Safetensors.read!(path, lazy: true) + end + + assert {:ok, %{params: params}} = + Bumblebee.load_model( + {:hf, "bumblebee-testing/tiny-random-GPT2Model-safetensors-only"}, + safetensors_reader: reader + ) + + assert_received {:reader_called, path} + assert File.exists?(path) + + assert {:ok, %{params: default_params}} = + Bumblebee.load_model( + {:hf, "bumblebee-testing/tiny-random-GPT2Model-safetensors-only"} + ) + + assert Enum.sort(Map.keys(params.data)) == Enum.sort(Map.keys(default_params.data)) + end end end