# Nix flake — 67 lines, 1.9 KiB, no trailing newline (file-viewer metadata).
{
  description = "A Nix-flake-based PyTorch development environment";

  # CUDA binaries are cached by the community.
  nixConfig = {
    extra-substituters = [
      "https://nix-community.cachix.org"
    ];
    extra-trusted-public-keys = [
      "nix-community.cachix.org-1:mB9FSh9qf2dCimDSUo8Zy7bkq5CX+/rkCWyvRCYg3Fs="
    ];
  };

  inputs.nixpkgs.url = "https://flakehub.com/f/NixOS/nixpkgs/0.1.*.tar.gz";

  outputs = {
    self,
    nixpkgs,
  }: let
    # Systems for which a dev shell is generated.
    supportedSystems = ["x86_64-linux" "aarch64-linux" "x86_64-darwin" "aarch64-darwin"];

    # Builds an attribute set keyed by each supported system. `f` is called
    # with `{ pkgs }`, where `pkgs` is nixpkgs instantiated for that system.
    forEachSupportedSystem = f:
      nixpkgs.lib.genAttrs supportedSystems (system:
        f {
          pkgs = import nixpkgs {
            inherit system;
            # NOTE(review): presumably required because the CUDA packages
            # below are unfree — confirm before removing.
            config.allowUnfree = true;
          };
        });
  in {
    devShells = forEachSupportedSystem ({pkgs}: let
      inherit (pkgs) lib;

      # The CUDA packages and the NixOS driver path are only referenced on
      # Linux, so that the Darwin shells listed in `supportedSystems` do not
      # depend on them and fall back to a CPU-only setup.
      isLinux = pkgs.stdenv.hostPlatform.isLinux;

      # Entries are kept in their original order so that the resulting
      # LD_LIBRARY_PATH on Linux is unchanged.
      libs =
        # PyTorch and NumPy depend on the following libraries.
        lib.optionals isLinux [
          pkgs.cudaPackages.cudatoolkit
          pkgs.cudaPackages.cudnn
        ]
        ++ [
          pkgs.stdenv.cc.cc.lib
          pkgs.zlib
        ]
        # PyTorch also needs to know where your local "lib/libcuda.so" lives.
        # If you're not on NixOS, you should provide the right path (likely
        # another one).
        ++ lib.optionals isLinux [
          "/run/opengl-driver"
        ];
    in {
      default = pkgs.mkShell {
        packages = [
          pkgs.python312
          pkgs.python312Packages.venvShellHook
        ];

        env = {
          CC = "${pkgs.gcc}/bin/gcc"; # For `torch.compile`.
          LD_LIBRARY_PATH = lib.makeLibraryPath libs;
        };

        # Settings read by `venvShellHook`.
        venvDir = ".venv";
        postVenvCreation = ''
          # This is run only when creating the virtual environment.
          pip install torch==2.5.1 numpy==2.2.2
        '';
        postShellHook = ''
          # This is run every time you enter the devShell.
          python3 -c "import torch; print('CUDA available' if torch.cuda.is_available() else 'CPU only')"
        '';
      };
    });
  };
}