Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
33 changes: 15 additions & 18 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -20,34 +20,31 @@ Download or fetch datasets locally:
{images_binary, tensor_type, shape} = train_images
```

You can also pass transform functions to `download/1`:
Most often you will convert those results to `Nx` tensors:

```elixir
transform_images = fn {binary, type, shape} ->
binary
|> Nx.from_binary(type)
|> Nx.reshape(shape)
{train_images, train_labels} =
Scidata.MNIST.download(transform_images: transform_images)

# Normalize and batch images
{images_binary, images_type, images_shape} = train_images

batched_images =
images_binary
|> Nx.from_binary(images_type)
|> Nx.reshape(images_shape)
|> Nx.divide(255)
|> Nx.to_batched_list(32)
end

{train_images, train_labels} =
Scidata.MNIST.download(transform_images: transform_images)
# One-hot-encode and batch labels
{labels_binary, labels_type, _shape} = train_labels

# Transform labels as well, e.g. get one-hot encoding
transform_labels = fn {labels_binary, type, _} ->
batched_labels =
labels_binary
|> Nx.from_binary(type)
|> Nx.from_binary(labels_type)
|> Nx.new_axis(-1)
|> Nx.equal(Nx.tensor(Enum.to_list(0..9)))
|> Nx.to_batched_list(32)
end

{images, labels} =
Scidata.MNIST.download(
transform_images: transform_images,
transform_labels: transform_labels
)
```

## Installation
Expand Down
47 changes: 15 additions & 32 deletions lib/scidata/cifar10.ex
Original file line number Diff line number Diff line change
Expand Up @@ -16,25 +16,17 @@ defmodule Scidata.CIFAR10 do
@doc """
Downloads the CIFAR10 training dataset or fetches it locally.

## Options
Returns a tuple of format:

* `:transform_images` - A function that transforms images, defaults to
`& &1`.
{{images_binary, images_type, images_shape},
{labels_binary, labels_type, labels_shape}}

It accepts a tuple like `{binary_data, tensor_type, data_shape}` which
can be used for converting the `binary_data` to a tensor with a function
like:
If you want to one-hot encode the labels, you can:

fn {labels_binary, type, _shape} ->
labels_binary
|> Nx.from_binary(type)
|> Nx.new_axis(-1)
|> Nx.equal(Nx.tensor(Enum.to_list(0..9)))
|> Nx.to_batched_list(32)
end

* `:transform_labels` - similar to `:transform_images` but applied to
dataset labels
labels_binary
|> Nx.from_binary(labels_type)
|> Nx.new_axis(-1)
|> Nx.equal(Nx.tensor(Enum.to_list(0..9)))

## Examples

Expand All @@ -48,17 +40,17 @@ defmodule Scidata.CIFAR10 do
{:u, 8}, {50000}}}

"""
def download(opts \\ []) do
download_dataset(:train, opts)
def download() do
download_dataset(:train)
end

@doc """
Downloads the CIFAR10 test dataset or fetches it locally.

Returns a tuple of the same format as `download/0`.
"""
def download_test(opts \\ []) do
download_dataset(:test, opts)
def download_test() do
download_dataset(:test)
end

defp parse_images(content) do
Expand All @@ -70,10 +62,7 @@ defmodule Scidata.CIFAR10 do
end
end

defp download_dataset(dataset_type, opts) do
transform_images = opts[:transform_images] || (& &1)
transform_labels = opts[:transform_labels] || (& &1)

defp download_dataset(dataset_type) do
files = Utils.get!(@base_url <> @dataset_file).body

{imgs, labels} =
Expand All @@ -93,13 +82,7 @@ defmodule Scidata.CIFAR10 do
{image_acc <> image, label_acc <> label}
end)

{transform_images.(
{imgs, {:u, 8},
if(dataset_type == :test, do: @test_images_shape, else: @train_images_shape)}
),
transform_labels.(
{labels, {:u, 8},
if(dataset_type == :test, do: @test_labels_shape, else: @train_labels_shape)}
)}
{{imgs, {:u, 8}, if(dataset_type == :test, do: @test_images_shape, else: @train_images_shape)},
{labels, {:u, 8}, if(dataset_type == :test, do: @test_labels_shape, else: @train_labels_shape)}}
end
end
47 changes: 15 additions & 32 deletions lib/scidata/cifar100.ex
Original file line number Diff line number Diff line change
Expand Up @@ -16,25 +16,17 @@ defmodule Scidata.CIFAR100 do
@doc """
Downloads the CIFAR100 training dataset or fetches it locally.

## Options
Returns a tuple of format:

* `:transform_images` - A function that transforms images, defaults to
`& &1`.
{{images_binary, images_type, images_shape},
{labels_binary, labels_type, labels_shape}}

It accepts a tuple like `{binary_data, tensor_type, data_shape}` which
can be used for converting the `binary_data` to a tensor with a function
like:
If you want to one-hot encode the labels, you can:

fn {labels_binary, type, _shape} ->
labels_binary
|> Nx.from_binary(type)
|> Nx.new_axis(-1)
|> Nx.equal(Nx.tensor(Enum.to_list(0..9)))
|> Nx.to_batched_list(32)
end

* `:transform_labels` - similar to `:transform_images` but applied to
dataset labels
labels_binary
|> Nx.from_binary(labels_type)
|> Nx.new_axis(-1)
|> Nx.equal(Nx.tensor(Enum.to_list(0..9)))

## Examples

Expand All @@ -48,17 +40,17 @@ defmodule Scidata.CIFAR100 do
{:u, 8}, {50000, 2}}}

"""
def download(opts \\ []) do
download_dataset(:train, opts)
def download() do
download_dataset(:train)
end

@doc """
Downloads the CIFAR100 test dataset or fetches it locally.

Returns a tuple of the same format as `download/0`.
"""
def download_test(opts \\ []) do
download_dataset(:test, opts)
def download_test() do
download_dataset(:test)
end

defp parse_images(content) do
Expand All @@ -70,10 +62,7 @@ defmodule Scidata.CIFAR100 do
end
end

defp download_dataset(dataset_type, opts) do
transform_images = opts[:transform_images] || (& &1)
transform_labels = opts[:transform_labels] || (& &1)

defp download_dataset(dataset_type) do
files = Utils.get!(@base_url <> @dataset_file).body

{imgs, labels} =
Expand All @@ -93,13 +82,7 @@ defmodule Scidata.CIFAR100 do
{image_acc <> image, label_acc <> label}
end)

{transform_images.(
{imgs, {:u, 8},
if(dataset_type == :test, do: @test_images_shape, else: @train_images_shape)}
),
transform_labels.(
{labels, {:u, 8},
if(dataset_type == :test, do: @test_labels_shape, else: @train_labels_shape)}
)}
{{imgs, {:u, 8}, if(dataset_type == :test, do: @test_images_shape, else: @train_images_shape)},
{labels, {:u, 8}, if(dataset_type == :test, do: @test_labels_shape, else: @train_labels_shape)}}
end
end
50 changes: 16 additions & 34 deletions lib/scidata/fashionmnist.ex
Original file line number Diff line number Diff line change
Expand Up @@ -15,25 +15,17 @@ defmodule Scidata.FashionMNIST do
@doc """
Downloads the FashionMNIST training dataset or fetches it locally.

## Options
Returns a tuple of format:

* `:transform_images` - A function that transforms images, defaults to
`& &1`.
{{images_binary, images_type, images_shape},
{labels_binary, labels_type, labels_shape}}

It accepts a tuple like `{binary_data, tensor_type, data_shape}` which
can be used for converting the `binary_data` to a tensor with a function
like:
If you want to one-hot encode the labels, you can:

fn {labels_binary, type, _shape} ->
labels_binary
|> Nx.from_binary(type)
|> Nx.new_axis(-1)
|> Nx.equal(Nx.tensor(Enum.to_list(0..9)))
|> Nx.to_batched_list(32)
end

* `:transform_labels` - similar to `:transform_images` but applied to
dataset labels
labels_binary
|> Nx.from_binary(labels_type)
|> Nx.new_axis(-1)
|> Nx.equal(Nx.tensor(Enum.to_list(0..9)))

## Examples

Expand All @@ -48,38 +40,28 @@ defmodule Scidata.FashionMNIST do
{3739854681}}}

"""
def download(opts \\ []) do
transform_images = opts[:transform_images] || (& &1)
transform_labels = opts[:transform_labels] || (& &1)

{download_images(@train_image_file, transform_images),
download_labels(@train_label_file, transform_labels)}
def download() do
{download_images(@train_image_file), download_labels(@train_label_file)}
end

@doc """
Downloads the FashionMNIST test dataset or fetches it locally.

Returns a tuple of the same format as `download/0`.
"""
def download_test(opts \\ []) do
transform_images = opts[:transform_images] || (& &1)
transform_labels = opts[:transform_labels] || (& &1)

{download_images(@test_image_file, transform_images),
download_labels(@test_label_file, transform_labels)}
def download_test() do
{download_images(@test_image_file), download_labels(@test_label_file)}
end

defp download_images(image_file, transform) do
defp download_images(image_file) do
data = Utils.get!(@base_url <> image_file).body
<<_::32, n_images::32, n_rows::32, n_cols::32, images::binary>> = data

transform.({images, {:u, 8}, {n_images, 1, n_rows, n_cols}})
{images, {:u, 8}, {n_images, 1, n_rows, n_cols}}
end

defp download_labels(label_file, transform) do
defp download_labels(label_file) do
data = Utils.get!(@base_url <> label_file).body
<<_::32, n_labels::32, labels::binary>> = data

transform.({labels, {:u, 8}, {n_labels}})
{labels, {:u, 8}, {n_labels}}
end
end
50 changes: 16 additions & 34 deletions lib/scidata/mnist.ex
Original file line number Diff line number Diff line change
Expand Up @@ -14,59 +14,41 @@ defmodule Scidata.MNIST do
@doc """
Downloads the MNIST training dataset or fetches it locally.

## Options
Returns a tuple of format:

* `:transform_images` - A function that transforms images, defaults to
`& &1`.
{{images_binary, images_type, images_shape},
{labels_binary, labels_type, labels_shape}}

It accepts a tuple like `{binary_data, tensor_type, data_shape}` which
can be used for converting the `binary_data` to a tensor with a function
like:
If you want to one-hot encode the labels, you can:

fn {labels_binary, type, _shape} ->
labels_binary
|> Nx.from_binary(type)
|> Nx.new_axis(-1)
|> Nx.equal(Nx.tensor(Enum.to_list(0..9)))
|> Nx.to_batched_list(32)
end

* `:transform_labels` - similar to `:transform_images` but applied to
dataset labels
labels_binary
|> Nx.from_binary(labels_type)
|> Nx.new_axis(-1)
|> Nx.equal(Nx.tensor(Enum.to_list(0..9)))

"""
def download(opts \\ []) do
transform_images = opts[:transform_images] || (& &1)
transform_labels = opts[:transform_labels] || (& &1)

{download_images(@train_image_file, transform_images),
download_labels(@train_label_file, transform_labels)}
def download() do
{download_images(@train_image_file), download_labels(@train_label_file)}
end

@doc """
Downloads the MNIST test dataset or fetches it locally.

Returns a tuple of the same format as `download/0`.
"""
def download_test(opts \\ []) do
transform_images = opts[:transform_images] || (& &1)
transform_labels = opts[:transform_labels] || (& &1)

{download_images(@test_image_file, transform_images),
download_labels(@test_label_file, transform_labels)}
def download_test() do
{download_images(@test_image_file), download_labels(@test_label_file)}
end

defp download_images(image_file, transform) do
defp download_images(image_file) do
data = Utils.get!(@base_url <> image_file).body
<<_::32, n_images::32, n_rows::32, n_cols::32, images::binary>> = data

transform.({images, {:u, 8}, {n_images, 1, n_rows, n_cols}})
{images, {:u, 8}, {n_images, 1, n_rows, n_cols}}
end

defp download_labels(label_file, transform) do
defp download_labels(label_file) do
data = Utils.get!(@base_url <> label_file).body
<<_::32, n_labels::32, labels::binary>> = data

transform.({labels, {:u, 8}, {n_labels}})
{labels, {:u, 8}, {n_labels}}
end
end