imwithye committed on
Commit
e34abba
·
1 Parent(s): 80f5283

fix 4 steps

Browse files
Files changed (1) hide show
  1. rlcube/cube2.ipynb +96 -251
rlcube/cube2.ipynb CHANGED
@@ -2,7 +2,7 @@
2
  "cells": [
3
  {
4
  "cell_type": "code",
5
- "execution_count": 127,
6
  "id": "dff864f2",
7
  "metadata": {},
8
  "outputs": [],
@@ -231,7 +231,7 @@
231
  },
232
  {
233
  "cell_type": "code",
234
- "execution_count": 128,
235
  "id": "624c83c1",
236
  "metadata": {},
237
  "outputs": [],
@@ -243,10 +243,14 @@
243
  " def state(self):\n",
244
  " return self.env.state\n",
245
  " \n",
 
 
 
246
  " def reset(self, *args, **kwargs):\n",
247
- " super().reset(*args, **kwargs)\n",
248
- " self.env.step(self.env.action_space.sample())\n",
249
- " self.env.step(self.env.action_space.sample())\n",
 
250
  " return self.env._get_obs(), {}\n",
251
  "\n",
252
  " def step(self, action):\n",
@@ -258,7 +262,7 @@
258
  },
259
  {
260
  "cell_type": "code",
261
- "execution_count": 130,
262
  "id": "f8b4d968",
263
  "metadata": {},
264
  "outputs": [
@@ -271,333 +275,183 @@
271
  "Wrapping the env in a DummyVecEnv.\n",
272
  "----------------------------------\n",
273
  "| rollout/ | |\n",
274
- "| ep_len_mean | 91.4 |\n",
275
- "| ep_rew_mean | -84.3 |\n",
276
- "| exploration_rate | 0.132 |\n",
277
  "| time/ | |\n",
278
  "| episodes | 100 |\n",
279
- "| fps | 4624 |\n",
280
  "| time_elapsed | 1 |\n",
281
- "| total_timesteps | 9136 |\n",
282
  "| train/ | |\n",
283
  "| learning_rate | 0.0001 |\n",
284
- "| loss | 0.00031 |\n",
285
- "| n_updates | 2258 |\n",
286
  "----------------------------------\n",
287
  "----------------------------------\n",
288
  "| rollout/ | |\n",
289
- "| ep_len_mean | 87.9 |\n",
290
- "| ep_rew_mean | -76.8 |\n",
291
  "| exploration_rate | 0.05 |\n",
292
  "| time/ | |\n",
293
  "| episodes | 200 |\n",
294
- "| fps | 4407 |\n",
295
  "| time_elapsed | 4 |\n",
296
- "| total_timesteps | 17928 |\n",
297
  "| train/ | |\n",
298
  "| learning_rate | 0.0001 |\n",
299
- "| loss | 0.00032 |\n",
300
- "| n_updates | 4456 |\n",
301
  "----------------------------------\n",
302
  "----------------------------------\n",
303
  "| rollout/ | |\n",
304
- "| ep_len_mean | 80.2 |\n",
305
- "| ep_rew_mean | -61 |\n",
306
  "| exploration_rate | 0.05 |\n",
307
  "| time/ | |\n",
308
  "| episodes | 300 |\n",
309
- "| fps | 4300 |\n",
310
  "| time_elapsed | 6 |\n",
311
- "| total_timesteps | 25946 |\n",
312
  "| train/ | |\n",
313
  "| learning_rate | 0.0001 |\n",
314
- "| loss | 0.000486 |\n",
315
- "| n_updates | 6461 |\n",
316
  "----------------------------------\n",
317
  "----------------------------------\n",
318
  "| rollout/ | |\n",
319
- "| ep_len_mean | 71.3 |\n",
320
- "| ep_rew_mean | -43 |\n",
321
  "| exploration_rate | 0.05 |\n",
322
  "| time/ | |\n",
323
  "| episodes | 400 |\n",
324
- "| fps | 4189 |\n",
325
- "| time_elapsed | 7 |\n",
326
- "| total_timesteps | 33072 |\n",
327
  "| train/ | |\n",
328
  "| learning_rate | 0.0001 |\n",
329
- "| loss | 0.000479 |\n",
330
- "| n_updates | 8242 |\n",
331
  "----------------------------------\n",
332
  "----------------------------------\n",
333
  "| rollout/ | |\n",
334
- "| ep_len_mean | 62.8 |\n",
335
- "| ep_rew_mean | -23.4 |\n",
336
  "| exploration_rate | 0.05 |\n",
337
  "| time/ | |\n",
338
  "| episodes | 500 |\n",
339
- "| fps | 4123 |\n",
340
- "| time_elapsed | 9 |\n",
341
- "| total_timesteps | 39348 |\n",
342
- "| train/ | |\n",
343
- "| learning_rate | 0.0001 |\n",
344
- "| loss | 0.000449 |\n",
345
- "| n_updates | 9811 |\n",
346
- "----------------------------------\n",
347
- "----------------------------------\n",
348
- "| rollout/ | |\n",
349
- "| ep_len_mean | 54.2 |\n",
350
- "| ep_rew_mean | -6.69 |\n",
351
- "| exploration_rate | 0.05 |\n",
352
- "| time/ | |\n",
353
- "| episodes | 600 |\n",
354
- "| fps | 4072 |\n",
355
  "| time_elapsed | 10 |\n",
356
- "| total_timesteps | 44764 |\n",
357
- "| train/ | |\n",
358
- "| learning_rate | 0.0001 |\n",
359
- "| loss | 0.000499 |\n",
360
- "| n_updates | 11165 |\n",
361
- "----------------------------------\n",
362
- "----------------------------------\n",
363
- "| rollout/ | |\n",
364
- "| ep_len_mean | 37.5 |\n",
365
- "| ep_rew_mean | 27.1 |\n",
366
- "| exploration_rate | 0.05 |\n",
367
- "| time/ | |\n",
368
- "| episodes | 700 |\n",
369
- "| fps | 4063 |\n",
370
- "| time_elapsed | 11 |\n",
371
- "| total_timesteps | 48514 |\n",
372
  "| train/ | |\n",
373
  "| learning_rate | 0.0001 |\n",
374
- "| loss | 0.000346 |\n",
375
- "| n_updates | 12103 |\n",
376
  "----------------------------------\n",
377
  "----------------------------------\n",
378
  "| rollout/ | |\n",
379
- "| ep_len_mean | 38.3 |\n",
380
- "| ep_rew_mean | 26.3 |\n",
381
  "| exploration_rate | 0.05 |\n",
382
  "| time/ | |\n",
383
- "| episodes | 800 |\n",
384
- "| fps | 4067 |\n",
385
  "| time_elapsed | 12 |\n",
386
- "| total_timesteps | 52346 |\n",
387
  "| train/ | |\n",
388
  "| learning_rate | 0.0001 |\n",
389
- "| loss | 0.000947 |\n",
390
- "| n_updates | 13061 |\n",
391
  "----------------------------------\n",
392
  "----------------------------------\n",
393
  "| rollout/ | |\n",
394
- "| ep_len_mean | 37.5 |\n",
395
- "| ep_rew_mean | 28.2 |\n",
396
  "| exploration_rate | 0.05 |\n",
397
  "| time/ | |\n",
398
- "| episodes | 900 |\n",
399
- "| fps | 4076 |\n",
400
- "| time_elapsed | 13 |\n",
401
- "| total_timesteps | 56094 |\n",
402
- "| train/ | |\n",
403
- "| learning_rate | 0.0001 |\n",
404
- "| loss | 0.00122 |\n",
405
- "| n_updates | 13998 |\n",
406
- "----------------------------------\n",
407
- "----------------------------------\n",
408
- "| rollout/ | |\n",
409
- "| ep_len_mean | 49.9 |\n",
410
- "| ep_rew_mean | 3.65 |\n",
411
- "| exploration_rate | 0.05 |\n",
412
- "| time/ | |\n",
413
- "| episodes | 1000 |\n",
414
- "| fps | 4092 |\n",
415
  "| time_elapsed | 14 |\n",
416
- "| total_timesteps | 61082 |\n",
417
  "| train/ | |\n",
418
  "| learning_rate | 0.0001 |\n",
419
- "| loss | 0.0014 |\n",
420
- "| n_updates | 15245 |\n",
421
  "----------------------------------\n",
422
  "----------------------------------\n",
423
  "| rollout/ | |\n",
424
- "| ep_len_mean | 42.8 |\n",
425
- "| ep_rew_mean | 16.8 |\n",
426
  "| exploration_rate | 0.05 |\n",
427
  "| time/ | |\n",
428
- "| episodes | 1100 |\n",
429
- "| fps | 4106 |\n",
430
  "| time_elapsed | 15 |\n",
431
- "| total_timesteps | 65360 |\n",
432
  "| train/ | |\n",
433
  "| learning_rate | 0.0001 |\n",
434
- "| loss | 0.0115 |\n",
435
- "| n_updates | 16314 |\n",
436
  "----------------------------------\n",
437
  "----------------------------------\n",
438
  "| rollout/ | |\n",
439
- "| ep_len_mean | 30.9 |\n",
440
- "| ep_rew_mean | 40.9 |\n",
441
  "| exploration_rate | 0.05 |\n",
442
  "| time/ | |\n",
443
- "| episodes | 1200 |\n",
444
- "| fps | 4113 |\n",
445
- "| time_elapsed | 16 |\n",
446
- "| total_timesteps | 68446 |\n",
447
- "| train/ | |\n",
448
- "| learning_rate | 0.0001 |\n",
449
- "| loss | 0.00337 |\n",
450
- "| n_updates | 17086 |\n",
451
- "----------------------------------\n",
452
- "----------------------------------\n",
453
- "| rollout/ | |\n",
454
- "| ep_len_mean | 35 |\n",
455
- "| ep_rew_mean | 33.7 |\n",
456
- "| exploration_rate | 0.05 |\n",
457
- "| time/ | |\n",
458
- "| episodes | 1300 |\n",
459
- "| fps | 4122 |\n",
460
  "| time_elapsed | 17 |\n",
461
- "| total_timesteps | 71948 |\n",
462
- "| train/ | |\n",
463
- "| learning_rate | 0.0001 |\n",
464
- "| loss | 0.039 |\n",
465
- "| n_updates | 17961 |\n",
466
- "----------------------------------\n",
467
- "----------------------------------\n",
468
- "| rollout/ | |\n",
469
- "| ep_len_mean | 31.7 |\n",
470
- "| ep_rew_mean | 39 |\n",
471
- "| exploration_rate | 0.05 |\n",
472
- "| time/ | |\n",
473
- "| episodes | 1400 |\n",
474
- "| fps | 4128 |\n",
475
- "| time_elapsed | 18 |\n",
476
- "| total_timesteps | 75122 |\n",
477
  "| train/ | |\n",
478
  "| learning_rate | 0.0001 |\n",
479
- "| loss | 0.00402 |\n",
480
- "| n_updates | 18755 |\n",
481
  "----------------------------------\n",
482
  "----------------------------------\n",
483
  "| rollout/ | |\n",
484
- "| ep_len_mean | 27.7 |\n",
485
- "| ep_rew_mean | 47 |\n",
486
  "| exploration_rate | 0.05 |\n",
487
  "| time/ | |\n",
488
- "| episodes | 1500 |\n",
489
- "| fps | 4129 |\n",
490
  "| time_elapsed | 18 |\n",
491
- "| total_timesteps | 77894 |\n",
492
- "| train/ | |\n",
493
- "| learning_rate | 0.0001 |\n",
494
- "| loss | 0.00105 |\n",
495
- "| n_updates | 19448 |\n",
496
- "----------------------------------\n",
497
- "----------------------------------\n",
498
- "| rollout/ | |\n",
499
- "| ep_len_mean | 35.5 |\n",
500
- "| ep_rew_mean | 31.2 |\n",
501
- "| exploration_rate | 0.05 |\n",
502
- "| time/ | |\n",
503
- "| episodes | 1600 |\n",
504
- "| fps | 4125 |\n",
505
- "| time_elapsed | 19 |\n",
506
- "| total_timesteps | 81440 |\n",
507
  "| train/ | |\n",
508
  "| learning_rate | 0.0001 |\n",
509
- "| loss | 0.00372 |\n",
510
- "| n_updates | 20334 |\n",
511
  "----------------------------------\n",
512
  "----------------------------------\n",
513
  "| rollout/ | |\n",
514
- "| ep_len_mean | 27.9 |\n",
515
- "| ep_rew_mean | 46.8 |\n",
516
  "| exploration_rate | 0.05 |\n",
517
  "| time/ | |\n",
518
- "| episodes | 1700 |\n",
519
- "| fps | 4122 |\n",
520
  "| time_elapsed | 20 |\n",
521
- "| total_timesteps | 84230 |\n",
522
- "| train/ | |\n",
523
- "| learning_rate | 0.0001 |\n",
524
- "| loss | 0.011 |\n",
525
- "| n_updates | 21032 |\n",
526
- "----------------------------------\n",
527
- "----------------------------------\n",
528
- "| rollout/ | |\n",
529
- "| ep_len_mean | 34.3 |\n",
530
- "| ep_rew_mean | 33.4 |\n",
531
- "| exploration_rate | 0.05 |\n",
532
- "| time/ | |\n",
533
- "| episodes | 1800 |\n",
534
- "| fps | 4122 |\n",
535
- "| time_elapsed | 21 |\n",
536
- "| total_timesteps | 87656 |\n",
537
  "| train/ | |\n",
538
  "| learning_rate | 0.0001 |\n",
539
- "| loss | 0.00412 |\n",
540
- "| n_updates | 21888 |\n",
541
  "----------------------------------\n",
542
  "----------------------------------\n",
543
  "| rollout/ | |\n",
544
- "| ep_len_mean | 27.3 |\n",
545
- "| ep_rew_mean | 48.5 |\n",
546
  "| exploration_rate | 0.05 |\n",
547
  "| time/ | |\n",
548
- "| episodes | 1900 |\n",
549
- "| fps | 4122 |\n",
550
  "| time_elapsed | 21 |\n",
551
- "| total_timesteps | 90384 |\n",
552
  "| train/ | |\n",
553
  "| learning_rate | 0.0001 |\n",
554
- "| loss | 6.7 |\n",
555
- "| n_updates | 22570 |\n",
556
- "----------------------------------\n",
557
- "----------------------------------\n",
558
- "| rollout/ | |\n",
559
- "| ep_len_mean | 35.2 |\n",
560
- "| ep_rew_mean | 31.5 |\n",
561
- "| exploration_rate | 0.05 |\n",
562
- "| time/ | |\n",
563
- "| episodes | 2000 |\n",
564
- "| fps | 4106 |\n",
565
- "| time_elapsed | 22 |\n",
566
- "| total_timesteps | 93900 |\n",
567
- "| train/ | |\n",
568
- "| learning_rate | 0.0001 |\n",
569
- "| loss | 0.0141 |\n",
570
- "| n_updates | 23449 |\n",
571
- "----------------------------------\n",
572
- "----------------------------------\n",
573
- "| rollout/ | |\n",
574
- "| ep_len_mean | 25.8 |\n",
575
- "| ep_rew_mean | 51 |\n",
576
- "| exploration_rate | 0.05 |\n",
577
- "| time/ | |\n",
578
- "| episodes | 2100 |\n",
579
- "| fps | 4105 |\n",
580
- "| time_elapsed | 23 |\n",
581
- "| total_timesteps | 96476 |\n",
582
- "| train/ | |\n",
583
- "| learning_rate | 0.0001 |\n",
584
- "| loss | 0.0174 |\n",
585
- "| n_updates | 24093 |\n",
586
- "----------------------------------\n",
587
- "----------------------------------\n",
588
- "| rollout/ | |\n",
589
- "| ep_len_mean | 27.7 |\n",
590
- "| ep_rew_mean | 47 |\n",
591
- "| exploration_rate | 0.05 |\n",
592
- "| time/ | |\n",
593
- "| episodes | 2200 |\n",
594
- "| fps | 4110 |\n",
595
- "| time_elapsed | 24 |\n",
596
- "| total_timesteps | 99250 |\n",
597
- "| train/ | |\n",
598
- "| learning_rate | 0.0001 |\n",
599
- "| loss | 6.27 |\n",
600
- "| n_updates | 24787 |\n",
601
  "----------------------------------\n"
602
  ]
603
  }
@@ -608,13 +462,12 @@
608
  "env = Cube2()\n",
609
  "env = RewardWrapper(env)\n",
610
  "model = DQN(\"MlpPolicy\", env, verbose=1)\n",
611
- "model.learn(total_timesteps=100000, log_interval=100)\n",
612
- "model.save(\"dqn_cube2.pkl\")"
613
  ]
614
  },
615
  {
616
  "cell_type": "code",
617
- "execution_count": 148,
618
  "id": "24132717",
619
  "metadata": {},
620
  "outputs": [
@@ -622,10 +475,10 @@
622
  "name": "stdout",
623
  "output_type": "stream",
624
  "text": [
625
- "rotationController.setState([[4.0, 4.0, 0.0, 0.0], [1.0, 5.0, 1.0, 5.0], [4.0, 2.0, 1.0, 2.0], [5.0, 3.0, 0.0, 3.0], [3.0, 1.0, 3.0, 4.0], [2.0, 0.0, 2.0, 5.0]])\n",
626
- "rotationController.addRotationStepCode(...[7, 10, 2, 5, 7, 4, 10, 3, 11, 4, 5, 4, 5, 4, 5, 4, 5, 4, 5, 4, 5, 4, 5, 4, 5, 4, 5, 4, 5, 4, 5, 4, 5, 4, 5, 4, 5, 4, 5, 4, 5, 4, 5, 4, 5, 4, 5, 4, 5, 4, 5, 4, 5, 4, 5, 4, 5, 4, 5, 4, 5, 4, 5, 4, 5, 4, 5, 4, 5, 4, 5, 4, 5, 4, 5, 4, 5, 4, 5, 4, 5, 4, 5, 4, 5, 4, 5, 4, 5, 4, 5, 4, 5, 4, 5, 4, 5, 4])\n",
627
  "\n",
628
- "Solved in 98 steps\n"
629
  ]
630
  }
631
  ],
@@ -650,14 +503,6 @@
650
  "print()\n",
651
  "print(f\"Solved in {len(solved_actions)} steps\")\n"
652
  ]
653
- },
654
- {
655
- "cell_type": "code",
656
- "execution_count": null,
657
- "id": "39924b6b",
658
- "metadata": {},
659
- "outputs": [],
660
- "source": []
661
  }
662
  ],
663
  "metadata": {
 
2
  "cells": [
3
  {
4
  "cell_type": "code",
5
+ "execution_count": 55,
6
  "id": "dff864f2",
7
  "metadata": {},
8
  "outputs": [],
 
231
  },
232
  {
233
  "cell_type": "code",
234
+ "execution_count": 56,
235
  "id": "624c83c1",
236
  "metadata": {},
237
  "outputs": [],
 
243
  " def state(self):\n",
244
  " return self.env.state\n",
245
  " \n",
246
+ " def step_count(self):\n",
247
+ " return self.env.step_count\n",
248
+ " \n",
249
  " def reset(self, *args, **kwargs):\n",
250
+ " self.env.reset(*args, **kwargs)\n",
251
+ " for _ in range(4):\n",
252
+ " self.env.step(self.env.action_space.sample())\n",
253
+ " self.env.step_count = 0\n",
254
  " return self.env._get_obs(), {}\n",
255
  "\n",
256
  " def step(self, action):\n",
 
262
  },
263
  {
264
  "cell_type": "code",
265
+ "execution_count": null,
266
  "id": "f8b4d968",
267
  "metadata": {},
268
  "outputs": [
 
275
  "Wrapping the env in a DummyVecEnv.\n",
276
  "----------------------------------\n",
277
  "| rollout/ | |\n",
278
+ "| ep_len_mean | 94.2 |\n",
279
+ "| ep_rew_mean | -88.2 |\n",
280
+ "| exploration_rate | 0.105 |\n",
281
  "| time/ | |\n",
282
  "| episodes | 100 |\n",
283
+ "| fps | 4943 |\n",
284
  "| time_elapsed | 1 |\n",
285
+ "| total_timesteps | 9424 |\n",
286
  "| train/ | |\n",
287
  "| learning_rate | 0.0001 |\n",
288
+ "| loss | 0.0004 |\n",
289
+ "| n_updates | 2330 |\n",
290
  "----------------------------------\n",
291
  "----------------------------------\n",
292
  "| rollout/ | |\n",
293
+ "| ep_len_mean | 98.1 |\n",
294
+ "| ep_rew_mean | -96.1 |\n",
295
  "| exploration_rate | 0.05 |\n",
296
  "| time/ | |\n",
297
  "| episodes | 200 |\n",
298
+ "| fps | 4426 |\n",
299
  "| time_elapsed | 4 |\n",
300
+ "| total_timesteps | 19236 |\n",
301
  "| train/ | |\n",
302
  "| learning_rate | 0.0001 |\n",
303
+ "| loss | 0.000292 |\n",
304
+ "| n_updates | 4783 |\n",
305
  "----------------------------------\n",
306
  "----------------------------------\n",
307
  "| rollout/ | |\n",
308
+ "| ep_len_mean | 95.2 |\n",
309
+ "| ep_rew_mean | -90.1 |\n",
310
  "| exploration_rate | 0.05 |\n",
311
  "| time/ | |\n",
312
  "| episodes | 300 |\n",
313
+ "| fps | 4349 |\n",
314
  "| time_elapsed | 6 |\n",
315
+ "| total_timesteps | 28754 |\n",
316
  "| train/ | |\n",
317
  "| learning_rate | 0.0001 |\n",
318
+ "| loss | 0.000103 |\n",
319
+ "| n_updates | 7163 |\n",
320
  "----------------------------------\n",
321
  "----------------------------------\n",
322
  "| rollout/ | |\n",
323
+ "| ep_len_mean | 88.4 |\n",
324
+ "| ep_rew_mean | -76.3 |\n",
325
  "| exploration_rate | 0.05 |\n",
326
  "| time/ | |\n",
327
  "| episodes | 400 |\n",
328
+ "| fps | 4391 |\n",
329
+ "| time_elapsed | 8 |\n",
330
+ "| total_timesteps | 37598 |\n",
331
  "| train/ | |\n",
332
  "| learning_rate | 0.0001 |\n",
333
+ "| loss | 0.000121 |\n",
334
+ "| n_updates | 9374 |\n",
335
  "----------------------------------\n",
336
  "----------------------------------\n",
337
  "| rollout/ | |\n",
338
+ "| ep_len_mean | 86.6 |\n",
339
+ "| ep_rew_mean | -72.5 |\n",
340
  "| exploration_rate | 0.05 |\n",
341
  "| time/ | |\n",
342
  "| episodes | 500 |\n",
343
+ "| fps | 4417 |\n",
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
344
  "| time_elapsed | 10 |\n",
345
+ "| total_timesteps | 46260 |\n",
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
346
  "| train/ | |\n",
347
  "| learning_rate | 0.0001 |\n",
348
+ "| loss | 0.000169 |\n",
349
+ "| n_updates | 11539 |\n",
350
  "----------------------------------\n",
351
  "----------------------------------\n",
352
  "| rollout/ | |\n",
353
+ "| ep_len_mean | 82.6 |\n",
354
+ "| ep_rew_mean | -64.4 |\n",
355
  "| exploration_rate | 0.05 |\n",
356
  "| time/ | |\n",
357
+ "| episodes | 600 |\n",
358
+ "| fps | 4436 |\n",
359
  "| time_elapsed | 12 |\n",
360
+ "| total_timesteps | 54520 |\n",
361
  "| train/ | |\n",
362
  "| learning_rate | 0.0001 |\n",
363
+ "| loss | 9.72e-05 |\n",
364
+ "| n_updates | 13604 |\n",
365
  "----------------------------------\n",
366
  "----------------------------------\n",
367
  "| rollout/ | |\n",
368
+ "| ep_len_mean | 79.4 |\n",
369
+ "| ep_rew_mean | -57.2 |\n",
370
  "| exploration_rate | 0.05 |\n",
371
  "| time/ | |\n",
372
+ "| episodes | 700 |\n",
373
+ "| fps | 4445 |\n",
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
374
  "| time_elapsed | 14 |\n",
375
+ "| total_timesteps | 62462 |\n",
376
  "| train/ | |\n",
377
  "| learning_rate | 0.0001 |\n",
378
+ "| loss | 6.99e-05 |\n",
379
+ "| n_updates | 15590 |\n",
380
  "----------------------------------\n",
381
  "----------------------------------\n",
382
  "| rollout/ | |\n",
383
+ "| ep_len_mean | 75.5 |\n",
384
+ "| ep_rew_mean | -49.2 |\n",
385
  "| exploration_rate | 0.05 |\n",
386
  "| time/ | |\n",
387
+ "| episodes | 800 |\n",
388
+ "| fps | 4456 |\n",
389
  "| time_elapsed | 15 |\n",
390
+ "| total_timesteps | 70012 |\n",
391
  "| train/ | |\n",
392
  "| learning_rate | 0.0001 |\n",
393
+ "| loss | 0.264 |\n",
394
+ "| n_updates | 17477 |\n",
395
  "----------------------------------\n",
396
  "----------------------------------\n",
397
  "| rollout/ | |\n",
398
+ "| ep_len_mean | 70.5 |\n",
399
+ "| ep_rew_mean | -39.2 |\n",
400
  "| exploration_rate | 0.05 |\n",
401
  "| time/ | |\n",
402
+ "| episodes | 900 |\n",
403
+ "| fps | 4471 |\n",
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
404
  "| time_elapsed | 17 |\n",
405
+ "| total_timesteps | 77066 |\n",
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
406
  "| train/ | |\n",
407
  "| learning_rate | 0.0001 |\n",
408
+ "| loss | 0.000102 |\n",
409
+ "| n_updates | 19241 |\n",
410
  "----------------------------------\n",
411
  "----------------------------------\n",
412
  "| rollout/ | |\n",
413
+ "| ep_len_mean | 66.1 |\n",
414
+ "| ep_rew_mean | -28.8 |\n",
415
  "| exploration_rate | 0.05 |\n",
416
  "| time/ | |\n",
417
+ "| episodes | 1000 |\n",
418
+ "| fps | 4489 |\n",
419
  "| time_elapsed | 18 |\n",
420
+ "| total_timesteps | 83678 |\n",
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
421
  "| train/ | |\n",
422
  "| learning_rate | 0.0001 |\n",
423
+ "| loss | 0.000145 |\n",
424
+ "| n_updates | 20894 |\n",
425
  "----------------------------------\n",
426
  "----------------------------------\n",
427
  "| rollout/ | |\n",
428
+ "| ep_len_mean | 66.9 |\n",
429
+ "| ep_rew_mean | -31.6 |\n",
430
  "| exploration_rate | 0.05 |\n",
431
  "| time/ | |\n",
432
+ "| episodes | 1100 |\n",
433
+ "| fps | 4504 |\n",
434
  "| time_elapsed | 20 |\n",
435
+ "| total_timesteps | 90370 |\n",
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
436
  "| train/ | |\n",
437
  "| learning_rate | 0.0001 |\n",
438
+ "| loss | 0.000488 |\n",
439
+ "| n_updates | 22567 |\n",
440
  "----------------------------------\n",
441
  "----------------------------------\n",
442
  "| rollout/ | |\n",
443
+ "| ep_len_mean | 68.6 |\n",
444
+ "| ep_rew_mean | -34.3 |\n",
445
  "| exploration_rate | 0.05 |\n",
446
  "| time/ | |\n",
447
+ "| episodes | 1200 |\n",
448
+ "| fps | 4517 |\n",
449
  "| time_elapsed | 21 |\n",
450
+ "| total_timesteps | 97230 |\n",
451
  "| train/ | |\n",
452
  "| learning_rate | 0.0001 |\n",
453
+ "| loss | 0.00045 |\n",
454
+ "| n_updates | 24282 |\n",
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
455
  "----------------------------------\n"
456
  ]
457
  }
 
462
  "env = Cube2()\n",
463
  "env = RewardWrapper(env)\n",
464
  "model = DQN(\"MlpPolicy\", env, verbose=1)\n",
465
+ "model.learn(total_timesteps=100000, log_interval=100)"
 
466
  ]
467
  },
468
  {
469
  "cell_type": "code",
470
+ "execution_count": 75,
471
  "id": "24132717",
472
  "metadata": {},
473
  "outputs": [
 
475
  "name": "stdout",
476
  "output_type": "stream",
477
  "text": [
478
+ "rotationController.setState([[0.0, 0.0, 3.0, 4.0], [5.0, 2.0, 1.0, 1.0], [3.0, 4.0, 3.0, 2.0], [2.0, 5.0, 4.0, 5.0], [0.0, 3.0, 5.0, 1.0], [1.0, 2.0, 4.0, 0.0]])\n",
479
+ "rotationController.addRotationStepCode(...[3, 1, 8, 3])\n",
480
  "\n",
481
+ "Solved in 4 steps\n"
482
  ]
483
  }
484
  ],
 
503
  "print()\n",
504
  "print(f\"Solved in {len(solved_actions)} steps\")\n"
505
  ]
 
 
 
 
 
 
 
 
506
  }
507
  ],
508
  "metadata": {