imwithye commited on
Commit
314d0a6
·
1 Parent(s): 11d757e

implement dataset

Browse files
rlcube/pyproject.toml CHANGED
@@ -10,5 +10,6 @@ dependencies = [
10
  "ipykernel>=6.30.1",
11
  "lightning>=2.5.5",
12
  "numpy>=2.3.2",
 
13
  "torch>=2.8.0",
14
  ]
 
10
  "ipykernel>=6.30.1",
11
  "lightning>=2.5.5",
12
  "numpy>=2.3.2",
13
+ "tensordict>=0.10.0",
14
  "torch>=2.8.0",
15
  ]
rlcube/rlcube/envs/cube2.py CHANGED
@@ -16,7 +16,7 @@ class Cube2(gym.Env):
16
  self.observation_space = gym.spaces.Box(
17
  low=0, high=1, shape=(24, 6), dtype=np.int8
18
  )
19
- self.state = np.zeros((6, 4))
20
  self.step_count = 0
21
 
22
  def reset(self, seed=None, options=None, state: np.ndarray = None):
@@ -216,6 +216,15 @@ class Cube2(gym.Env):
216
  {},
217
  )
218
 
 
 
 
 
 
 
 
 
 
219
  def _get_obs(self):
220
  one_hots = []
221
  for i in range(6):
 
16
  self.observation_space = gym.spaces.Box(
17
  low=0, high=1, shape=(24, 6), dtype=np.int8
18
  )
19
+ self.state = np.zeros((6, 4), dtype=np.int8)
20
  self.step_count = 0
21
 
22
  def reset(self, seed=None, options=None, state: np.ndarray = None):
 
216
  {},
217
  )
218
 
219
+ def neighbors(self):
220
+ neighbors = []
221
+ for i in range(12):
222
+ env = Cube2()
223
+ env.reset(state=self.state)
224
+ obs, _, _, _, _ = env.step(i)
225
+ neighbors.append(obs)
226
+ return np.array(neighbors)
227
+
228
  def _get_obs(self):
229
  one_hots = []
230
  for i in range(6):
rlcube/rlcube/models/dataset.py ADDED
@@ -0,0 +1,27 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from torch.utils.data import Dataset
2
+ from rlcube.envs.cube2 import Cube2
3
+ import numpy as np
4
+
5
+
6
+ class Cube2Dataset(Dataset):
7
+ def __init__(self, num_envs: int = 1000, num_steps: int = 20):
8
+ self.num_envs = num_envs
9
+ self.num_steps = num_steps
10
+ self.states = []
11
+ self.D = []
12
+ for _ in range(num_envs):
13
+ env = Cube2()
14
+ obs, _ = env.reset()
15
+ for _ in range(num_steps):
16
+ action = env.action_space.sample()
17
+ obs, _, _, _, _ = env.step(action)
18
+ self.states.append(obs)
19
+ self.D.append(env.step_count)
20
+ self.states = np.array(self.states)
21
+ self.D = np.array(self.D)
22
+
23
+ def __len__(self):
24
+ return len(self.states)
25
+
26
+ def __getitem__(self, idx):
27
+ return self.states[idx], self.D[idx]
rlcube/rlcube/models/models.py CHANGED
@@ -1,6 +1,7 @@
1
  import torch.nn as nn
2
  import torch.nn.functional as F
3
  import torch
 
4
 
5
 
6
  class ResidualBlock(nn.Module):
@@ -32,7 +33,14 @@ class DNN(nn.Module):
32
  ]
33
  )
34
 
35
- self.fc_out = nn.Linear(hidden_dim, 1)
 
 
 
 
 
 
 
36
 
37
  def forward(self, x):
38
  batch_size = x.size(0)
@@ -40,8 +48,9 @@ class DNN(nn.Module):
40
  out = F.relu(self.fc_in(x))
41
  for block in self.residual_blocks:
42
  out = block(out)
43
- out = self.fc_out(out)
44
- return out
 
45
 
46
 
47
  if __name__ == "__main__":
@@ -52,7 +61,9 @@ if __name__ == "__main__":
52
 
53
  print("Testing Cube2VNetwork, input_dim=24, num_residual_blocks=4")
54
  x = torch.randn(4, 24, 6)
55
- net = DNN()
56
  print("Input shape:", x.shape)
57
- print("Output shape:", net(x).shape)
 
 
 
58
  print("Number of parameters:", sum(p.numel() for p in net.parameters()))
 
1
  import torch.nn as nn
2
  import torch.nn.functional as F
3
  import torch
4
+ from tensordict import TensorDict
5
 
6
 
7
  class ResidualBlock(nn.Module):
 
33
  ]
34
  )
35
 
36
+ # Value head
37
+ self.fc_value = nn.Sequential(
38
+ nn.Linear(hidden_dim, 64), nn.ReLU(), nn.Linear(64, 1)
39
+ )
40
+ # Policy head
41
+ self.fc_policy = nn.Sequential(
42
+ nn.Linear(hidden_dim, 64), nn.ReLU(), nn.Linear(64, 12)
43
+ )
44
 
45
  def forward(self, x):
46
  batch_size = x.size(0)
 
48
  out = F.relu(self.fc_in(x))
49
  for block in self.residual_blocks:
50
  out = block(out)
51
+ value = self.fc_value(out)
52
+ policy = self.fc_policy(out)
53
+ return TensorDict({"value": value, "policy": policy}, batch_size=batch_size)
54
 
55
 
56
  if __name__ == "__main__":
 
61
 
62
  print("Testing Cube2VNetwork, input_dim=24, num_residual_blocks=4")
63
  x = torch.randn(4, 24, 6)
 
64
  print("Input shape:", x.shape)
65
+ net = DNN()
66
+ y = net(x)
67
+ print("Output value shape:", y["value"].shape)
68
+ print("Output policy shape:", y["policy"].shape)
69
  print("Number of parameters:", sum(p.numel() for p in net.parameters()))
rlcube/rlcube/train/train.py CHANGED
@@ -1,18 +1,6 @@
1
- import sys
2
- import os
3
- sys.path.append(os.path.join(os.path.dirname(__file__), '..', '..'))
4
-
5
- from rlcube.envs.cube2 import Cube2
6
-
7
-
8
- def generate_train_data(num_envs: int = 1000, num_steps: int = 20):
9
- for _ in range(num_envs):
10
- env = Cube2()
11
- obs, _ = env.reset()
12
-
13
- print(obs)
14
- break
15
 
16
 
17
  if __name__ == "__main__":
18
- generate_train_data()
 
 
1
+ from rlcube.models.dataset import Cube2Dataset
 
 
 
 
 
 
 
 
 
 
 
 
 
2
 
3
 
4
  if __name__ == "__main__":
5
+ dataset = Cube2Dataset(num_envs=10, num_steps=20)
6
+ print(dataset[10])
rlcube/uv.lock CHANGED
@@ -528,6 +528,18 @@ wheels = [
528
  { url = "https://files.pythonhosted.org/packages/76/c6/c88e154df9c4e1a2a66ccf0005a88dfb2650c1dffb6f5ce603dfbd452ce3/idna-3.10-py3-none-any.whl", hash = "sha256:946d195a0d259cbba61165e88e65941f16e9b36ea6ddb97f00452bae8b1287d3", size = 70442, upload-time = "2024-09-15T18:07:37.964Z" },
529
  ]
530
 
 
 
 
 
 
 
 
 
 
 
 
 
531
  [[package]]
532
  name = "ipykernel"
533
  version = "6.30.1"
@@ -1023,6 +1035,55 @@ wheels = [
1023
  { url = "https://files.pythonhosted.org/packages/a2/eb/86626c1bbc2edb86323022371c39aa48df6fd8b0a1647bc274577f72e90b/nvidia_nvtx_cu12-12.8.90-py3-none-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:5b17e2001cc0d751a5bc2c6ec6d26ad95913324a4adb86788c944f8ce9ba441f", size = 89954, upload-time = "2025-03-07T01:42:44.131Z" },
1024
  ]
1025
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1026
  [[package]]
1027
  name = "packaging"
1028
  version = "25.0"
@@ -1294,6 +1355,17 @@ wheels = [
1294
  { url = "https://files.pythonhosted.org/packages/04/f6/99a5c66478f469598dee25b0e29b302b5bddd4e03ed0da79608ac964056e/pytorch_lightning-2.5.5-py3-none-any.whl", hash = "sha256:0b533991df2353c0c6ea9ca10a7d0728b73631fd61f5a15511b19bee2aef8af0", size = 832431, upload-time = "2025-09-05T16:01:16.234Z" },
1295
  ]
1296
 
 
 
 
 
 
 
 
 
 
 
 
1297
  [[package]]
1298
  name = "pywin32"
1299
  version = "311"
@@ -1462,6 +1534,7 @@ dependencies = [
1462
  { name = "ipykernel" },
1463
  { name = "lightning" },
1464
  { name = "numpy" },
 
1465
  { name = "torch" },
1466
  ]
1467
 
@@ -1472,6 +1545,7 @@ requires-dist = [
1472
  { name = "ipykernel", specifier = ">=6.30.1" },
1473
  { name = "lightning", specifier = ">=2.5.5" },
1474
  { name = "numpy", specifier = ">=2.3.2" },
 
1475
  { name = "torch", specifier = ">=2.8.0" },
1476
  ]
1477
 
@@ -1563,6 +1637,34 @@ wheels = [
1563
  { url = "https://files.pythonhosted.org/packages/a2/09/77d55d46fd61b4a135c444fc97158ef34a095e5681d0a6c10b75bf356191/sympy-1.14.0-py3-none-any.whl", hash = "sha256:e091cc3e99d2141a0ba2847328f5479b05d94a6635cb96148ccb3f34671bd8f5", size = 6299353, upload-time = "2025-04-27T18:04:59.103Z" },
1564
  ]
1565
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1566
  [[package]]
1567
  name = "torch"
1568
  version = "2.8.0"
@@ -1934,3 +2036,12 @@ wheels = [
1934
  { url = "https://files.pythonhosted.org/packages/94/c3/b2e9f38bc3e11191981d57ea08cab2166e74ea770024a646617c9cddd9f6/yarl-1.20.1-cp313-cp313t-win_amd64.whl", hash = "sha256:541d050a355bbbc27e55d906bc91cb6fe42f96c01413dd0f4ed5a5240513874f", size = 93003, upload-time = "2025-06-10T00:45:27.752Z" },
1935
  { url = "https://files.pythonhosted.org/packages/b4/2d/2345fce04cfd4bee161bf1e7d9cdc702e3e16109021035dbb24db654a622/yarl-1.20.1-py3-none-any.whl", hash = "sha256:83b8eb083fe4683c6115795d9fc1cfaf2cbbefb19b3a1cb68f6527460f483a77", size = 46542, upload-time = "2025-06-10T00:46:07.521Z" },
1936
  ]
 
 
 
 
 
 
 
 
 
 
528
  { url = "https://files.pythonhosted.org/packages/76/c6/c88e154df9c4e1a2a66ccf0005a88dfb2650c1dffb6f5ce603dfbd452ce3/idna-3.10-py3-none-any.whl", hash = "sha256:946d195a0d259cbba61165e88e65941f16e9b36ea6ddb97f00452bae8b1287d3", size = 70442, upload-time = "2024-09-15T18:07:37.964Z" },
529
  ]
530
 
531
+ [[package]]
532
+ name = "importlib-metadata"
533
+ version = "8.7.0"
534
+ source = { registry = "https://pypi.org/simple" }
535
+ dependencies = [
536
+ { name = "zipp" },
537
+ ]
538
+ sdist = { url = "https://files.pythonhosted.org/packages/76/66/650a33bd90f786193e4de4b3ad86ea60b53c89b669a5c7be931fac31cdb0/importlib_metadata-8.7.0.tar.gz", hash = "sha256:d13b81ad223b890aa16c5471f2ac3056cf76c5f10f82d6f9292f0b415f389000", size = 56641, upload-time = "2025-04-27T15:29:01.736Z" }
539
+ wheels = [
540
+ { url = "https://files.pythonhosted.org/packages/20/b0/36bd937216ec521246249be3bf9855081de4c5e06a0c9b4219dbeda50373/importlib_metadata-8.7.0-py3-none-any.whl", hash = "sha256:e5dd1551894c77868a30651cef00984d50e1002d06942a7101d34870c5f02afd", size = 27656, upload-time = "2025-04-27T15:29:00.214Z" },
541
+ ]
542
+
543
  [[package]]
544
  name = "ipykernel"
545
  version = "6.30.1"
 
1035
  { url = "https://files.pythonhosted.org/packages/a2/eb/86626c1bbc2edb86323022371c39aa48df6fd8b0a1647bc274577f72e90b/nvidia_nvtx_cu12-12.8.90-py3-none-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:5b17e2001cc0d751a5bc2c6ec6d26ad95913324a4adb86788c944f8ce9ba441f", size = 89954, upload-time = "2025-03-07T01:42:44.131Z" },
1036
  ]
1037
 
1038
+ [[package]]
1039
+ name = "orjson"
1040
+ version = "3.11.3"
1041
+ source = { registry = "https://pypi.org/simple" }
1042
+ sdist = { url = "https://files.pythonhosted.org/packages/be/4d/8df5f83256a809c22c4d6792ce8d43bb503be0fb7a8e4da9025754b09658/orjson-3.11.3.tar.gz", hash = "sha256:1c0603b1d2ffcd43a411d64797a19556ef76958aef1c182f22dc30860152a98a", size = 5482394, upload-time = "2025-08-26T17:46:43.171Z" }
1043
+ wheels = [
1044
+ { url = "https://files.pythonhosted.org/packages/3d/b0/a7edab2a00cdcb2688e1c943401cb3236323e7bfd2839815c6131a3742f4/orjson-3.11.3-cp312-cp312-macosx_10_15_x86_64.macosx_11_0_arm64.macosx_10_15_universal2.whl", hash = "sha256:8c752089db84333e36d754c4baf19c0e1437012242048439c7e80eb0e6426e3b", size = 238259, upload-time = "2025-08-26T17:45:15.093Z" },
1045
+ { url = "https://files.pythonhosted.org/packages/e1/c6/ff4865a9cc398a07a83342713b5932e4dc3cb4bf4bc04e8f83dedfc0d736/orjson-3.11.3-cp312-cp312-macosx_15_0_arm64.whl", hash = "sha256:9b8761b6cf04a856eb544acdd82fc594b978f12ac3602d6374a7edb9d86fd2c2", size = 127633, upload-time = "2025-08-26T17:45:16.417Z" },
1046
+ { url = "https://files.pythonhosted.org/packages/6e/e6/e00bea2d9472f44fe8794f523e548ce0ad51eb9693cf538a753a27b8bda4/orjson-3.11.3-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:8b13974dc8ac6ba22feaa867fc19135a3e01a134b4f7c9c28162fed4d615008a", size = 123061, upload-time = "2025-08-26T17:45:17.673Z" },
1047
+ { url = "https://files.pythonhosted.org/packages/54/31/9fbb78b8e1eb3ac605467cb846e1c08d0588506028b37f4ee21f978a51d4/orjson-3.11.3-cp312-cp312-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:f83abab5bacb76d9c821fd5c07728ff224ed0e52d7a71b7b3de822f3df04e15c", size = 127956, upload-time = "2025-08-26T17:45:19.172Z" },
1048
+ { url = "https://files.pythonhosted.org/packages/36/88/b0604c22af1eed9f98d709a96302006915cfd724a7ebd27d6dd11c22d80b/orjson-3.11.3-cp312-cp312-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:e6fbaf48a744b94091a56c62897b27c31ee2da93d826aa5b207131a1e13d4064", size = 130790, upload-time = "2025-08-26T17:45:20.586Z" },
1049
+ { url = "https://files.pythonhosted.org/packages/0e/9d/1c1238ae9fffbfed51ba1e507731b3faaf6b846126a47e9649222b0fd06f/orjson-3.11.3-cp312-cp312-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:bc779b4f4bba2847d0d2940081a7b6f7b5877e05408ffbb74fa1faf4a136c424", size = 132385, upload-time = "2025-08-26T17:45:22.036Z" },
1050
+ { url = "https://files.pythonhosted.org/packages/a3/b5/c06f1b090a1c875f337e21dd71943bc9d84087f7cdf8c6e9086902c34e42/orjson-3.11.3-cp312-cp312-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:bd4b909ce4c50faa2192da6bb684d9848d4510b736b0611b6ab4020ea6fd2d23", size = 135305, upload-time = "2025-08-26T17:45:23.4Z" },
1051
+ { url = "https://files.pythonhosted.org/packages/a0/26/5f028c7d81ad2ebbf84414ba6d6c9cac03f22f5cd0d01eb40fb2d6a06b07/orjson-3.11.3-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:524b765ad888dc5518bbce12c77c2e83dee1ed6b0992c1790cc5fb49bb4b6667", size = 132875, upload-time = "2025-08-26T17:45:25.182Z" },
1052
+ { url = "https://files.pythonhosted.org/packages/fe/d4/b8df70d9cfb56e385bf39b4e915298f9ae6c61454c8154a0f5fd7efcd42e/orjson-3.11.3-cp312-cp312-musllinux_1_2_aarch64.whl", hash = "sha256:84fd82870b97ae3cdcea9d8746e592b6d40e1e4d4527835fc520c588d2ded04f", size = 130940, upload-time = "2025-08-26T17:45:27.209Z" },
1053
+ { url = "https://files.pythonhosted.org/packages/da/5e/afe6a052ebc1a4741c792dd96e9f65bf3939d2094e8b356503b68d48f9f5/orjson-3.11.3-cp312-cp312-musllinux_1_2_armv7l.whl", hash = "sha256:fbecb9709111be913ae6879b07bafd4b0785b44c1eb5cac8ac76da048b3885a1", size = 403852, upload-time = "2025-08-26T17:45:28.478Z" },
1054
+ { url = "https://files.pythonhosted.org/packages/f8/90/7bbabafeb2ce65915e9247f14a56b29c9334003536009ef5b122783fe67e/orjson-3.11.3-cp312-cp312-musllinux_1_2_i686.whl", hash = "sha256:9dba358d55aee552bd868de348f4736ca5a4086d9a62e2bfbbeeb5629fe8b0cc", size = 146293, upload-time = "2025-08-26T17:45:29.86Z" },
1055
+ { url = "https://files.pythonhosted.org/packages/27/b3/2d703946447da8b093350570644a663df69448c9d9330e5f1d9cce997f20/orjson-3.11.3-cp312-cp312-musllinux_1_2_x86_64.whl", hash = "sha256:eabcf2e84f1d7105f84580e03012270c7e97ecb1fb1618bda395061b2a84a049", size = 135470, upload-time = "2025-08-26T17:45:31.243Z" },
1056
+ { url = "https://files.pythonhosted.org/packages/38/70/b14dcfae7aff0e379b0119c8a812f8396678919c431efccc8e8a0263e4d9/orjson-3.11.3-cp312-cp312-win32.whl", hash = "sha256:3782d2c60b8116772aea8d9b7905221437fdf53e7277282e8d8b07c220f96cca", size = 136248, upload-time = "2025-08-26T17:45:32.567Z" },
1057
+ { url = "https://files.pythonhosted.org/packages/35/b8/9e3127d65de7fff243f7f3e53f59a531bf6bb295ebe5db024c2503cc0726/orjson-3.11.3-cp312-cp312-win_amd64.whl", hash = "sha256:79b44319268af2eaa3e315b92298de9a0067ade6e6003ddaef72f8e0bedb94f1", size = 131437, upload-time = "2025-08-26T17:45:34.949Z" },
1058
+ { url = "https://files.pythonhosted.org/packages/51/92/a946e737d4d8a7fd84a606aba96220043dcc7d6988b9e7551f7f6d5ba5ad/orjson-3.11.3-cp312-cp312-win_arm64.whl", hash = "sha256:0e92a4e83341ef79d835ca21b8bd13e27c859e4e9e4d7b63defc6e58462a3710", size = 125978, upload-time = "2025-08-26T17:45:36.422Z" },
1059
+ { url = "https://files.pythonhosted.org/packages/fc/79/8932b27293ad35919571f77cb3693b5906cf14f206ef17546052a241fdf6/orjson-3.11.3-cp313-cp313-macosx_10_15_x86_64.macosx_11_0_arm64.macosx_10_15_universal2.whl", hash = "sha256:af40c6612fd2a4b00de648aa26d18186cd1322330bd3a3cc52f87c699e995810", size = 238127, upload-time = "2025-08-26T17:45:38.146Z" },
1060
+ { url = "https://files.pythonhosted.org/packages/1c/82/cb93cd8cf132cd7643b30b6c5a56a26c4e780c7a145db6f83de977b540ce/orjson-3.11.3-cp313-cp313-macosx_15_0_arm64.whl", hash = "sha256:9f1587f26c235894c09e8b5b7636a38091a9e6e7fe4531937534749c04face43", size = 127494, upload-time = "2025-08-26T17:45:39.57Z" },
1061
+ { url = "https://files.pythonhosted.org/packages/a4/b8/2d9eb181a9b6bb71463a78882bcac1027fd29cf62c38a40cc02fc11d3495/orjson-3.11.3-cp313-cp313-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:61dcdad16da5bb486d7227a37a2e789c429397793a6955227cedbd7252eb5a27", size = 123017, upload-time = "2025-08-26T17:45:40.876Z" },
1062
+ { url = "https://files.pythonhosted.org/packages/b4/14/a0e971e72d03b509190232356d54c0f34507a05050bd026b8db2bf2c192c/orjson-3.11.3-cp313-cp313-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:11c6d71478e2cbea0a709e8a06365fa63da81da6498a53e4c4f065881d21ae8f", size = 127898, upload-time = "2025-08-26T17:45:42.188Z" },
1063
+ { url = "https://files.pythonhosted.org/packages/8e/af/dc74536722b03d65e17042cc30ae586161093e5b1f29bccda24765a6ae47/orjson-3.11.3-cp313-cp313-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:ff94112e0098470b665cb0ed06efb187154b63649403b8d5e9aedeb482b4548c", size = 130742, upload-time = "2025-08-26T17:45:43.511Z" },
1064
+ { url = "https://files.pythonhosted.org/packages/62/e6/7a3b63b6677bce089fe939353cda24a7679825c43a24e49f757805fc0d8a/orjson-3.11.3-cp313-cp313-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:ae8b756575aaa2a855a75192f356bbda11a89169830e1439cfb1a3e1a6dde7be", size = 132377, upload-time = "2025-08-26T17:45:45.525Z" },
1065
+ { url = "https://files.pythonhosted.org/packages/fc/cd/ce2ab93e2e7eaf518f0fd15e3068b8c43216c8a44ed82ac2b79ce5cef72d/orjson-3.11.3-cp313-cp313-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:c9416cc19a349c167ef76135b2fe40d03cea93680428efee8771f3e9fb66079d", size = 135313, upload-time = "2025-08-26T17:45:46.821Z" },
1066
+ { url = "https://files.pythonhosted.org/packages/d0/b4/f98355eff0bd1a38454209bbc73372ce351ba29933cb3e2eba16c04b9448/orjson-3.11.3-cp313-cp313-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:b822caf5b9752bc6f246eb08124c3d12bf2175b66ab74bac2ef3bbf9221ce1b2", size = 132908, upload-time = "2025-08-26T17:45:48.126Z" },
1067
+ { url = "https://files.pythonhosted.org/packages/eb/92/8f5182d7bc2a1bed46ed960b61a39af8389f0ad476120cd99e67182bfb6d/orjson-3.11.3-cp313-cp313-musllinux_1_2_aarch64.whl", hash = "sha256:414f71e3bdd5573893bf5ecdf35c32b213ed20aa15536fe2f588f946c318824f", size = 130905, upload-time = "2025-08-26T17:45:49.414Z" },
1068
+ { url = "https://files.pythonhosted.org/packages/1a/60/c41ca753ce9ffe3d0f67b9b4c093bdd6e5fdb1bc53064f992f66bb99954d/orjson-3.11.3-cp313-cp313-musllinux_1_2_armv7l.whl", hash = "sha256:828e3149ad8815dc14468f36ab2a4b819237c155ee1370341b91ea4c8672d2ee", size = 403812, upload-time = "2025-08-26T17:45:51.085Z" },
1069
+ { url = "https://files.pythonhosted.org/packages/dd/13/e4a4f16d71ce1868860db59092e78782c67082a8f1dc06a3788aef2b41bc/orjson-3.11.3-cp313-cp313-musllinux_1_2_i686.whl", hash = "sha256:ac9e05f25627ffc714c21f8dfe3a579445a5c392a9c8ae7ba1d0e9fb5333f56e", size = 146277, upload-time = "2025-08-26T17:45:52.851Z" },
1070
+ { url = "https://files.pythonhosted.org/packages/8d/8b/bafb7f0afef9344754a3a0597a12442f1b85a048b82108ef2c956f53babd/orjson-3.11.3-cp313-cp313-musllinux_1_2_x86_64.whl", hash = "sha256:e44fbe4000bd321d9f3b648ae46e0196d21577cf66ae684a96ff90b1f7c93633", size = 135418, upload-time = "2025-08-26T17:45:54.806Z" },
1071
+ { url = "https://files.pythonhosted.org/packages/60/d4/bae8e4f26afb2c23bea69d2f6d566132584d1c3a5fe89ee8c17b718cab67/orjson-3.11.3-cp313-cp313-win32.whl", hash = "sha256:2039b7847ba3eec1f5886e75e6763a16e18c68a63efc4b029ddf994821e2e66b", size = 136216, upload-time = "2025-08-26T17:45:57.182Z" },
1072
+ { url = "https://files.pythonhosted.org/packages/88/76/224985d9f127e121c8cad882cea55f0ebe39f97925de040b75ccd4b33999/orjson-3.11.3-cp313-cp313-win_amd64.whl", hash = "sha256:29be5ac4164aa8bdcba5fa0700a3c9c316b411d8ed9d39ef8a882541bd452fae", size = 131362, upload-time = "2025-08-26T17:45:58.56Z" },
1073
+ { url = "https://files.pythonhosted.org/packages/e2/cf/0dce7a0be94bd36d1346be5067ed65ded6adb795fdbe3abd234c8d576d01/orjson-3.11.3-cp313-cp313-win_arm64.whl", hash = "sha256:18bd1435cb1f2857ceb59cfb7de6f92593ef7b831ccd1b9bfb28ca530e539dce", size = 125989, upload-time = "2025-08-26T17:45:59.95Z" },
1074
+ { url = "https://files.pythonhosted.org/packages/ef/77/d3b1fef1fc6aaeed4cbf3be2b480114035f4df8fa1a99d2dac1d40d6e924/orjson-3.11.3-cp314-cp314-macosx_10_15_x86_64.macosx_11_0_arm64.macosx_10_15_universal2.whl", hash = "sha256:cf4b81227ec86935568c7edd78352a92e97af8da7bd70bdfdaa0d2e0011a1ab4", size = 238115, upload-time = "2025-08-26T17:46:01.669Z" },
1075
+ { url = "https://files.pythonhosted.org/packages/e4/6d/468d21d49bb12f900052edcfbf52c292022d0a323d7828dc6376e6319703/orjson-3.11.3-cp314-cp314-macosx_15_0_arm64.whl", hash = "sha256:bc8bc85b81b6ac9fc4dae393a8c159b817f4c2c9dee5d12b773bddb3b95fc07e", size = 127493, upload-time = "2025-08-26T17:46:03.466Z" },
1076
+ { url = "https://files.pythonhosted.org/packages/67/46/1e2588700d354aacdf9e12cc2d98131fb8ac6f31ca65997bef3863edb8ff/orjson-3.11.3-cp314-cp314-manylinux_2_34_aarch64.whl", hash = "sha256:88dcfc514cfd1b0de038443c7b3e6a9797ffb1b3674ef1fd14f701a13397f82d", size = 122998, upload-time = "2025-08-26T17:46:04.803Z" },
1077
+ { url = "https://files.pythonhosted.org/packages/3b/94/11137c9b6adb3779f1b34fd98be51608a14b430dbc02c6d41134fbba484c/orjson-3.11.3-cp314-cp314-manylinux_2_34_x86_64.whl", hash = "sha256:d61cd543d69715d5fc0a690c7c6f8dcc307bc23abef9738957981885f5f38229", size = 132915, upload-time = "2025-08-26T17:46:06.237Z" },
1078
+ { url = "https://files.pythonhosted.org/packages/10/61/dccedcf9e9bcaac09fdabe9eaee0311ca92115699500efbd31950d878833/orjson-3.11.3-cp314-cp314-musllinux_1_2_aarch64.whl", hash = "sha256:2b7b153ed90ababadbef5c3eb39549f9476890d339cf47af563aea7e07db2451", size = 130907, upload-time = "2025-08-26T17:46:07.581Z" },
1079
+ { url = "https://files.pythonhosted.org/packages/0e/fd/0e935539aa7b08b3ca0f817d73034f7eb506792aae5ecc3b7c6e679cdf5f/orjson-3.11.3-cp314-cp314-musllinux_1_2_armv7l.whl", hash = "sha256:7909ae2460f5f494fecbcd10613beafe40381fd0316e35d6acb5f3a05bfda167", size = 403852, upload-time = "2025-08-26T17:46:08.982Z" },
1080
+ { url = "https://files.pythonhosted.org/packages/4a/2b/50ae1a5505cd1043379132fdb2adb8a05f37b3e1ebffe94a5073321966fd/orjson-3.11.3-cp314-cp314-musllinux_1_2_i686.whl", hash = "sha256:2030c01cbf77bc67bee7eef1e7e31ecf28649353987775e3583062c752da0077", size = 146309, upload-time = "2025-08-26T17:46:10.576Z" },
1081
+ { url = "https://files.pythonhosted.org/packages/cd/1d/a473c158e380ef6f32753b5f39a69028b25ec5be331c2049a2201bde2e19/orjson-3.11.3-cp314-cp314-musllinux_1_2_x86_64.whl", hash = "sha256:a0169ebd1cbd94b26c7a7ad282cf5c2744fce054133f959e02eb5265deae1872", size = 135424, upload-time = "2025-08-26T17:46:12.386Z" },
1082
+ { url = "https://files.pythonhosted.org/packages/da/09/17d9d2b60592890ff7382e591aa1d9afb202a266b180c3d4049b1ec70e4a/orjson-3.11.3-cp314-cp314-win32.whl", hash = "sha256:0c6d7328c200c349e3a4c6d8c83e0a5ad029bdc2d417f234152bf34842d0fc8d", size = 136266, upload-time = "2025-08-26T17:46:13.853Z" },
1083
+ { url = "https://files.pythonhosted.org/packages/15/58/358f6846410a6b4958b74734727e582ed971e13d335d6c7ce3e47730493e/orjson-3.11.3-cp314-cp314-win_amd64.whl", hash = "sha256:317bbe2c069bbc757b1a2e4105b64aacd3bc78279b66a6b9e51e846e4809f804", size = 131351, upload-time = "2025-08-26T17:46:15.27Z" },
1084
+ { url = "https://files.pythonhosted.org/packages/28/01/d6b274a0635be0468d4dbd9cafe80c47105937a0d42434e805e67cd2ed8b/orjson-3.11.3-cp314-cp314-win_arm64.whl", hash = "sha256:e8f6a7a27d7b7bec81bd5924163e9af03d49bbb63013f107b48eb5d16db711bc", size = 125985, upload-time = "2025-08-26T17:46:16.67Z" },
1085
+ ]
1086
+
1087
  [[package]]
1088
  name = "packaging"
1089
  version = "25.0"
 
1355
  { url = "https://files.pythonhosted.org/packages/04/f6/99a5c66478f469598dee25b0e29b302b5bddd4e03ed0da79608ac964056e/pytorch_lightning-2.5.5-py3-none-any.whl", hash = "sha256:0b533991df2353c0c6ea9ca10a7d0728b73631fd61f5a15511b19bee2aef8af0", size = 832431, upload-time = "2025-09-05T16:01:16.234Z" },
1356
  ]
1357
 
1358
+ [[package]]
1359
+ name = "pyvers"
1360
+ version = "0.1.0"
1361
+ source = { registry = "https://pypi.org/simple" }
1362
+ dependencies = [
1363
+ { name = "packaging" },
1364
+ ]
1365
+ wheels = [
1366
+ { url = "https://files.pythonhosted.org/packages/7f/39/c5432f541e6ea1d616dfd6ef42ce02792f7eb42dd44f5ed4439dbe17a58b/pyvers-0.1.0-py3-none-any.whl", hash = "sha256:065249805ae537ddf9a2d1a8dffc6d0a12474a347d2eaa2f35ebdae92c0c8199", size = 10092, upload-time = "2025-06-08T23:46:46.219Z" },
1367
+ ]
1368
+
1369
  [[package]]
1370
  name = "pywin32"
1371
  version = "311"
 
1534
  { name = "ipykernel" },
1535
  { name = "lightning" },
1536
  { name = "numpy" },
1537
+ { name = "tensordict" },
1538
  { name = "torch" },
1539
  ]
1540
 
 
1545
  { name = "ipykernel", specifier = ">=6.30.1" },
1546
  { name = "lightning", specifier = ">=2.5.5" },
1547
  { name = "numpy", specifier = ">=2.3.2" },
1548
+ { name = "tensordict", specifier = ">=0.10.0" },
1549
  { name = "torch", specifier = ">=2.8.0" },
1550
  ]
1551
 
 
1637
  { url = "https://files.pythonhosted.org/packages/a2/09/77d55d46fd61b4a135c444fc97158ef34a095e5681d0a6c10b75bf356191/sympy-1.14.0-py3-none-any.whl", hash = "sha256:e091cc3e99d2141a0ba2847328f5479b05d94a6635cb96148ccb3f34671bd8f5", size = 6299353, upload-time = "2025-04-27T18:04:59.103Z" },
1638
  ]
1639
 
1640
+ [[package]]
1641
+ name = "tensordict"
1642
+ version = "0.10.0"
1643
+ source = { registry = "https://pypi.org/simple" }
1644
+ dependencies = [
1645
+ { name = "cloudpickle" },
1646
+ { name = "importlib-metadata" },
1647
+ { name = "numpy" },
1648
+ { name = "orjson", marker = "python_full_version < '3.13'" },
1649
+ { name = "packaging" },
1650
+ { name = "pyvers" },
1651
+ { name = "torch" },
1652
+ ]
1653
+ wheels = [
1654
+ { url = "https://files.pythonhosted.org/packages/e6/89/2914b6d2796bdbe64ba8d42b568bf02c25673f187079e8795fc668c609fa/tensordict-0.10.0-cp312-cp312-macosx_14_0_arm64.whl", hash = "sha256:f52321ddec5da5acb3b90785c68b4fd4cce805aab6eb700ec59352e9f9e6214f", size = 801519, upload-time = "2025-09-08T10:07:18.954Z" },
1655
+ { url = "https://files.pythonhosted.org/packages/9e/88/2c1bf6c1abdc4d0bfcbdda2d1a5b19c7a9540f67ff6d20fe08d328d78305/tensordict-0.10.0-cp312-cp312-manylinux_2_28_aarch64.whl", hash = "sha256:6a6ba462cc4c04299eb3746e8994df4a625706e60e2dfb88cc5f9513b6cbad2f", size = 445427, upload-time = "2025-09-08T10:07:20.148Z" },
1656
+ { url = "https://files.pythonhosted.org/packages/5e/05/6e7d130c5e9af947fad25fb7d40a3aa2fd9ef9d37c9c7ddc94ba11853d23/tensordict-0.10.0-cp312-cp312-manylinux_2_28_x86_64.whl", hash = "sha256:6f0a52524c7c46778bf250444f1cd508f055735667b8d596a1a7e2fb38824e8c", size = 449961, upload-time = "2025-09-08T10:07:21.977Z" },
1657
+ { url = "https://files.pythonhosted.org/packages/f4/40/877fd0453c9c79a14063ecd21102a23460d033e3760e83eae7fd6c09b3ef/tensordict-0.10.0-cp312-cp312-win_amd64.whl", hash = "sha256:76e7c1d6604addd08e026141c3943fa15fbe36db537f9ff311af5d2caee25daf", size = 494502, upload-time = "2025-09-08T10:07:23.2Z" },
1658
+ { url = "https://files.pythonhosted.org/packages/f6/28/95cabf70c3e6a44476ff03649f847178603d713a3d754ddc9245416c48df/tensordict-0.10.0-cp313-cp313-macosx_14_0_arm64.whl", hash = "sha256:7074663d59bb42586f7ee9859377299cac8882bd28cfb43800d56976b893db1a", size = 801382, upload-time = "2025-09-08T10:07:24.683Z" },
1659
+ { url = "https://files.pythonhosted.org/packages/3f/c4/bbee035a2330c856bb5437a368e8e696c7bf0929b2cbf413999305b5d055/tensordict-0.10.0-cp313-cp313-manylinux_2_28_aarch64.whl", hash = "sha256:c3c9c16f6601b396155701b05ce40f328edae4a7d4d3069240fbee66a3645fc3", size = 445382, upload-time = "2025-09-08T10:07:26.266Z" },
1660
+ { url = "https://files.pythonhosted.org/packages/64/70/7aabb69d9dc760e4053187631e2af793a449a5a352e94c0c739d091e66ef/tensordict-0.10.0-cp313-cp313-manylinux_2_28_x86_64.whl", hash = "sha256:102776f9e3ea17bcc17fd45896f22c5b045fb61b6b502a1a6a6ff627eacb9b68", size = 449861, upload-time = "2025-09-08T10:07:27.477Z" },
1661
+ { url = "https://files.pythonhosted.org/packages/44/f3/c108125d0c5bfd9ca17904b90f4f9c1a6cbcf831389e24e2bb5ac2f90019/tensordict-0.10.0-cp313-cp313-win_amd64.whl", hash = "sha256:61dd9a7cd6229d1d21b492f6d920312029447c183dd41abc05d3a3671bc089e6", size = 494476, upload-time = "2025-09-08T10:07:28.961Z" },
1662
+ { url = "https://files.pythonhosted.org/packages/cf/ed/5f25c5ffcd7a7d30d2672f1cd513622b6df6869392664b087cfe6e0bfdc1/tensordict-0.10.0-cp313-cp313t-macosx_14_0_arm64.whl", hash = "sha256:3ef1d115c907f82c6b2bd7c2f1fcabb1b5541d4a327f6e53a60c95c56dd3323d", size = 807091, upload-time = "2025-09-08T10:07:30.609Z" },
1663
+ { url = "https://files.pythonhosted.org/packages/a4/da/092ff3ee460a97892f9294871908fd2536b31f891fc9eb006527064c27ad/tensordict-0.10.0-cp313-cp313t-manylinux_2_28_aarch64.whl", hash = "sha256:3aeb85359ae11142f9c29cb55d4e46e516381d772271aeff001e3c310715d5fa", size = 446841, upload-time = "2025-09-08T10:07:31.788Z" },
1664
+ { url = "https://files.pythonhosted.org/packages/9d/2a/3a1f103cf0da8b01ef1dc5605cbc5a3242bf6cd6151797dd32b15eaef278/tensordict-0.10.0-cp313-cp313t-manylinux_2_28_x86_64.whl", hash = "sha256:194adb6cac77e8e4f32b4b10b6a373d3389133e58b6eef184c598908d7cadc93", size = 449578, upload-time = "2025-09-08T10:07:32.933Z" },
1665
+ { url = "https://files.pythonhosted.org/packages/e7/7a/d7eb7472e7134b7cc9557e061e27061c3b9e1942426462f8ce93d1a9f13e/tensordict-0.10.0-cp313-cp313t-win_amd64.whl", hash = "sha256:7e5ba5d9eb21d76f1fc8fd5b06453cc497b2283f2317cbb2cb1bd848a92f2e2a", size = 504441, upload-time = "2025-09-08T10:07:34.114Z" },
1666
+ ]
1667
+
1668
  [[package]]
1669
  name = "torch"
1670
  version = "2.8.0"
 
2036
  { url = "https://files.pythonhosted.org/packages/94/c3/b2e9f38bc3e11191981d57ea08cab2166e74ea770024a646617c9cddd9f6/yarl-1.20.1-cp313-cp313t-win_amd64.whl", hash = "sha256:541d050a355bbbc27e55d906bc91cb6fe42f96c01413dd0f4ed5a5240513874f", size = 93003, upload-time = "2025-06-10T00:45:27.752Z" },
2037
  { url = "https://files.pythonhosted.org/packages/b4/2d/2345fce04cfd4bee161bf1e7d9cdc702e3e16109021035dbb24db654a622/yarl-1.20.1-py3-none-any.whl", hash = "sha256:83b8eb083fe4683c6115795d9fc1cfaf2cbbefb19b3a1cb68f6527460f483a77", size = 46542, upload-time = "2025-06-10T00:46:07.521Z" },
2038
  ]
2039
+
2040
+ [[package]]
2041
+ name = "zipp"
2042
+ version = "3.23.0"
2043
+ source = { registry = "https://pypi.org/simple" }
2044
+ sdist = { url = "https://files.pythonhosted.org/packages/e3/02/0f2892c661036d50ede074e376733dca2ae7c6eb617489437771209d4180/zipp-3.23.0.tar.gz", hash = "sha256:a07157588a12518c9d4034df3fbbee09c814741a33ff63c05fa29d26a2404166", size = 25547, upload-time = "2025-06-08T17:06:39.4Z" }
2045
+ wheels = [
2046
+ { url = "https://files.pythonhosted.org/packages/2e/54/647ade08bf0db230bfea292f893923872fd20be6ac6f53b2b936ba839d75/zipp-3.23.0-py3-none-any.whl", hash = "sha256:071652d6115ed432f5ce1d34c336c0adfd6a884660d1e9712a256d3d3bd4b14e", size = 10276, upload-time = "2025-06-08T17:06:38.034Z" },
2047
+ ]