imwithye committed
Commit 80f5283 · Parent: b908f51

learn 2 steps

Files changed (1):
  1. rlcube/cube2.ipynb +386 -21

rlcube/cube2.ipynb CHANGED
@@ -2,7 +2,7 @@
  "cells": [
   {
    "cell_type": "code",
-   "execution_count": 2,
+   "execution_count": 127,
    "id": "dff864f2",
    "metadata": {},
    "outputs": [],
@@ -231,7 +231,7 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 6,
+   "execution_count": 128,
    "id": "624c83c1",
    "metadata": {},
    "outputs": [],
@@ -245,54 +245,419 @@
     " \n",
     "    def reset(self, *args, **kwargs):\n",
     "        super().reset(*args, **kwargs)\n",
-    "        actions = [self.env.action_space.sample() for _ in range(20)]\n",
-    "        for action in actions:\n",
-    "            self.env.step(action)\n",
+    "        self.env.step(self.env.action_space.sample())\n",
+    "        self.env.step(self.env.action_space.sample())\n",
     "        return self.env._get_obs(), {}\n",
     "\n",
     "    def step(self, action):\n",
     "        obs, reward, terminated, truncated, _ = super().step(action)\n",
+    "        if terminated:\n",
+    "            reward = 100\n",
     "        return obs, reward, terminated, truncated, _"
    ]
   },
   {
    "cell_type": "code",
-   "execution_count": 7,
-   "id": "639f54c6",
+   "execution_count": 130,
+   "id": "f8b4d968",
    "metadata": {},
    "outputs": [
     {
      "name": "stdout",
      "output_type": "stream",
      "text": [
-      "[[1. 1. 0. 3.]\n",
-      " [5. 4. 4. 2.]\n",
-      " [3. 4. 5. 5.]\n",
-      " [1. 2. 2. 4.]\n",
-      " [1. 3. 0. 0.]\n",
-      " [3. 5. 0. 2.]]\n"
+      "Using cpu device\n",
+      "Wrapping the env with a `Monitor` wrapper\n",
+      "Wrapping the env in a DummyVecEnv.\n",
+      "----------------------------------\n",
+      "| rollout/ | |\n",
+      "| ep_len_mean | 91.4 |\n",
+      "| ep_rew_mean | -84.3 |\n",
+      "| exploration_rate | 0.132 |\n",
+      "| time/ | |\n",
+      "| episodes | 100 |\n",
+      "| fps | 4624 |\n",
+      "| time_elapsed | 1 |\n",
+      "| total_timesteps | 9136 |\n",
+      "| train/ | |\n",
+      "| learning_rate | 0.0001 |\n",
+      "| loss | 0.00031 |\n",
+      "| n_updates | 2258 |\n",
+      "----------------------------------\n",
+      "----------------------------------\n",
+      "| rollout/ | |\n",
+      "| ep_len_mean | 87.9 |\n",
+      "| ep_rew_mean | -76.8 |\n",
+      "| exploration_rate | 0.05 |\n",
+      "| time/ | |\n",
+      "| episodes | 200 |\n",
+      "| fps | 4407 |\n",
+      "| time_elapsed | 4 |\n",
+      "| total_timesteps | 17928 |\n",
+      "| train/ | |\n",
+      "| learning_rate | 0.0001 |\n",
+      "| loss | 0.00032 |\n",
+      "| n_updates | 4456 |\n",
+      "----------------------------------\n",
+      "----------------------------------\n",
+      "| rollout/ | |\n",
+      "| ep_len_mean | 80.2 |\n",
+      "| ep_rew_mean | -61 |\n",
+      "| exploration_rate | 0.05 |\n",
+      "| time/ | |\n",
+      "| episodes | 300 |\n",
+      "| fps | 4300 |\n",
+      "| time_elapsed | 6 |\n",
+      "| total_timesteps | 25946 |\n",
+      "| train/ | |\n",
+      "| learning_rate | 0.0001 |\n",
+      "| loss | 0.000486 |\n",
+      "| n_updates | 6461 |\n",
+      "----------------------------------\n",
+      "----------------------------------\n",
+      "| rollout/ | |\n",
+      "| ep_len_mean | 71.3 |\n",
+      "| ep_rew_mean | -43 |\n",
+      "| exploration_rate | 0.05 |\n",
+      "| time/ | |\n",
+      "| episodes | 400 |\n",
+      "| fps | 4189 |\n",
+      "| time_elapsed | 7 |\n",
+      "| total_timesteps | 33072 |\n",
+      "| train/ | |\n",
+      "| learning_rate | 0.0001 |\n",
+      "| loss | 0.000479 |\n",
+      "| n_updates | 8242 |\n",
+      "----------------------------------\n",
+      "----------------------------------\n",
+      "| rollout/ | |\n",
+      "| ep_len_mean | 62.8 |\n",
+      "| ep_rew_mean | -23.4 |\n",
+      "| exploration_rate | 0.05 |\n",
+      "| time/ | |\n",
+      "| episodes | 500 |\n",
+      "| fps | 4123 |\n",
+      "| time_elapsed | 9 |\n",
+      "| total_timesteps | 39348 |\n",
+      "| train/ | |\n",
+      "| learning_rate | 0.0001 |\n",
+      "| loss | 0.000449 |\n",
+      "| n_updates | 9811 |\n",
+      "----------------------------------\n",
+      "----------------------------------\n",
+      "| rollout/ | |\n",
+      "| ep_len_mean | 54.2 |\n",
+      "| ep_rew_mean | -6.69 |\n",
+      "| exploration_rate | 0.05 |\n",
+      "| time/ | |\n",
+      "| episodes | 600 |\n",
+      "| fps | 4072 |\n",
+      "| time_elapsed | 10 |\n",
+      "| total_timesteps | 44764 |\n",
+      "| train/ | |\n",
+      "| learning_rate | 0.0001 |\n",
+      "| loss | 0.000499 |\n",
+      "| n_updates | 11165 |\n",
+      "----------------------------------\n",
+      "----------------------------------\n",
+      "| rollout/ | |\n",
+      "| ep_len_mean | 37.5 |\n",
+      "| ep_rew_mean | 27.1 |\n",
+      "| exploration_rate | 0.05 |\n",
+      "| time/ | |\n",
+      "| episodes | 700 |\n",
+      "| fps | 4063 |\n",
+      "| time_elapsed | 11 |\n",
+      "| total_timesteps | 48514 |\n",
+      "| train/ | |\n",
+      "| learning_rate | 0.0001 |\n",
+      "| loss | 0.000346 |\n",
+      "| n_updates | 12103 |\n",
+      "----------------------------------\n",
+      "----------------------------------\n",
+      "| rollout/ | |\n",
+      "| ep_len_mean | 38.3 |\n",
+      "| ep_rew_mean | 26.3 |\n",
+      "| exploration_rate | 0.05 |\n",
+      "| time/ | |\n",
+      "| episodes | 800 |\n",
+      "| fps | 4067 |\n",
+      "| time_elapsed | 12 |\n",
+      "| total_timesteps | 52346 |\n",
+      "| train/ | |\n",
+      "| learning_rate | 0.0001 |\n",
+      "| loss | 0.000947 |\n",
+      "| n_updates | 13061 |\n",
+      "----------------------------------\n",
+      "----------------------------------\n",
+      "| rollout/ | |\n",
+      "| ep_len_mean | 37.5 |\n",
+      "| ep_rew_mean | 28.2 |\n",
+      "| exploration_rate | 0.05 |\n",
+      "| time/ | |\n",
+      "| episodes | 900 |\n",
+      "| fps | 4076 |\n",
+      "| time_elapsed | 13 |\n",
+      "| total_timesteps | 56094 |\n",
+      "| train/ | |\n",
+      "| learning_rate | 0.0001 |\n",
+      "| loss | 0.00122 |\n",
+      "| n_updates | 13998 |\n",
+      "----------------------------------\n",
+      "----------------------------------\n",
+      "| rollout/ | |\n",
+      "| ep_len_mean | 49.9 |\n",
+      "| ep_rew_mean | 3.65 |\n",
+      "| exploration_rate | 0.05 |\n",
+      "| time/ | |\n",
+      "| episodes | 1000 |\n",
+      "| fps | 4092 |\n",
+      "| time_elapsed | 14 |\n",
+      "| total_timesteps | 61082 |\n",
+      "| train/ | |\n",
+      "| learning_rate | 0.0001 |\n",
+      "| loss | 0.0014 |\n",
+      "| n_updates | 15245 |\n",
+      "----------------------------------\n",
+      "----------------------------------\n",
+      "| rollout/ | |\n",
+      "| ep_len_mean | 42.8 |\n",
+      "| ep_rew_mean | 16.8 |\n",
+      "| exploration_rate | 0.05 |\n",
+      "| time/ | |\n",
+      "| episodes | 1100 |\n",
+      "| fps | 4106 |\n",
+      "| time_elapsed | 15 |\n",
+      "| total_timesteps | 65360 |\n",
+      "| train/ | |\n",
+      "| learning_rate | 0.0001 |\n",
+      "| loss | 0.0115 |\n",
+      "| n_updates | 16314 |\n",
+      "----------------------------------\n",
+      "----------------------------------\n",
+      "| rollout/ | |\n",
+      "| ep_len_mean | 30.9 |\n",
+      "| ep_rew_mean | 40.9 |\n",
+      "| exploration_rate | 0.05 |\n",
+      "| time/ | |\n",
+      "| episodes | 1200 |\n",
+      "| fps | 4113 |\n",
+      "| time_elapsed | 16 |\n",
+      "| total_timesteps | 68446 |\n",
+      "| train/ | |\n",
+      "| learning_rate | 0.0001 |\n",
+      "| loss | 0.00337 |\n",
+      "| n_updates | 17086 |\n",
+      "----------------------------------\n",
+      "----------------------------------\n",
+      "| rollout/ | |\n",
+      "| ep_len_mean | 35 |\n",
+      "| ep_rew_mean | 33.7 |\n",
+      "| exploration_rate | 0.05 |\n",
+      "| time/ | |\n",
+      "| episodes | 1300 |\n",
+      "| fps | 4122 |\n",
+      "| time_elapsed | 17 |\n",
+      "| total_timesteps | 71948 |\n",
+      "| train/ | |\n",
+      "| learning_rate | 0.0001 |\n",
+      "| loss | 0.039 |\n",
+      "| n_updates | 17961 |\n",
+      "----------------------------------\n",
+      "----------------------------------\n",
+      "| rollout/ | |\n",
+      "| ep_len_mean | 31.7 |\n",
+      "| ep_rew_mean | 39 |\n",
+      "| exploration_rate | 0.05 |\n",
+      "| time/ | |\n",
+      "| episodes | 1400 |\n",
+      "| fps | 4128 |\n",
+      "| time_elapsed | 18 |\n",
+      "| total_timesteps | 75122 |\n",
+      "| train/ | |\n",
+      "| learning_rate | 0.0001 |\n",
+      "| loss | 0.00402 |\n",
+      "| n_updates | 18755 |\n",
+      "----------------------------------\n",
+      "----------------------------------\n",
+      "| rollout/ | |\n",
+      "| ep_len_mean | 27.7 |\n",
+      "| ep_rew_mean | 47 |\n",
+      "| exploration_rate | 0.05 |\n",
+      "| time/ | |\n",
+      "| episodes | 1500 |\n",
+      "| fps | 4129 |\n",
+      "| time_elapsed | 18 |\n",
+      "| total_timesteps | 77894 |\n",
+      "| train/ | |\n",
+      "| learning_rate | 0.0001 |\n",
+      "| loss | 0.00105 |\n",
+      "| n_updates | 19448 |\n",
+      "----------------------------------\n",
+      "----------------------------------\n",
+      "| rollout/ | |\n",
+      "| ep_len_mean | 35.5 |\n",
+      "| ep_rew_mean | 31.2 |\n",
+      "| exploration_rate | 0.05 |\n",
+      "| time/ | |\n",
+      "| episodes | 1600 |\n",
+      "| fps | 4125 |\n",
+      "| time_elapsed | 19 |\n",
+      "| total_timesteps | 81440 |\n",
+      "| train/ | |\n",
+      "| learning_rate | 0.0001 |\n",
+      "| loss | 0.00372 |\n",
+      "| n_updates | 20334 |\n",
+      "----------------------------------\n",
+      "----------------------------------\n",
+      "| rollout/ | |\n",
+      "| ep_len_mean | 27.9 |\n",
+      "| ep_rew_mean | 46.8 |\n",
+      "| exploration_rate | 0.05 |\n",
+      "| time/ | |\n",
+      "| episodes | 1700 |\n",
+      "| fps | 4122 |\n",
+      "| time_elapsed | 20 |\n",
+      "| total_timesteps | 84230 |\n",
+      "| train/ | |\n",
+      "| learning_rate | 0.0001 |\n",
+      "| loss | 0.011 |\n",
+      "| n_updates | 21032 |\n",
+      "----------------------------------\n",
+      "----------------------------------\n",
+      "| rollout/ | |\n",
+      "| ep_len_mean | 34.3 |\n",
+      "| ep_rew_mean | 33.4 |\n",
+      "| exploration_rate | 0.05 |\n",
+      "| time/ | |\n",
+      "| episodes | 1800 |\n",
+      "| fps | 4122 |\n",
+      "| time_elapsed | 21 |\n",
+      "| total_timesteps | 87656 |\n",
+      "| train/ | |\n",
+      "| learning_rate | 0.0001 |\n",
+      "| loss | 0.00412 |\n",
+      "| n_updates | 21888 |\n",
+      "----------------------------------\n",
+      "----------------------------------\n",
+      "| rollout/ | |\n",
+      "| ep_len_mean | 27.3 |\n",
+      "| ep_rew_mean | 48.5 |\n",
+      "| exploration_rate | 0.05 |\n",
+      "| time/ | |\n",
+      "| episodes | 1900 |\n",
+      "| fps | 4122 |\n",
+      "| time_elapsed | 21 |\n",
+      "| total_timesteps | 90384 |\n",
+      "| train/ | |\n",
+      "| learning_rate | 0.0001 |\n",
+      "| loss | 6.7 |\n",
+      "| n_updates | 22570 |\n",
+      "----------------------------------\n",
+      "----------------------------------\n",
+      "| rollout/ | |\n",
+      "| ep_len_mean | 35.2 |\n",
+      "| ep_rew_mean | 31.5 |\n",
+      "| exploration_rate | 0.05 |\n",
+      "| time/ | |\n",
+      "| episodes | 2000 |\n",
+      "| fps | 4106 |\n",
+      "| time_elapsed | 22 |\n",
+      "| total_timesteps | 93900 |\n",
+      "| train/ | |\n",
+      "| learning_rate | 0.0001 |\n",
+      "| loss | 0.0141 |\n",
+      "| n_updates | 23449 |\n",
+      "----------------------------------\n",
+      "----------------------------------\n",
+      "| rollout/ | |\n",
+      "| ep_len_mean | 25.8 |\n",
+      "| ep_rew_mean | 51 |\n",
+      "| exploration_rate | 0.05 |\n",
+      "| time/ | |\n",
+      "| episodes | 2100 |\n",
+      "| fps | 4105 |\n",
+      "| time_elapsed | 23 |\n",
+      "| total_timesteps | 96476 |\n",
+      "| train/ | |\n",
+      "| learning_rate | 0.0001 |\n",
+      "| loss | 0.0174 |\n",
+      "| n_updates | 24093 |\n",
+      "----------------------------------\n",
+      "----------------------------------\n",
+      "| rollout/ | |\n",
+      "| ep_len_mean | 27.7 |\n",
+      "| ep_rew_mean | 47 |\n",
+      "| exploration_rate | 0.05 |\n",
+      "| time/ | |\n",
+      "| episodes | 2200 |\n",
+      "| fps | 4110 |\n",
+      "| time_elapsed | 24 |\n",
+      "| total_timesteps | 99250 |\n",
+      "| train/ | |\n",
+      "| learning_rate | 0.0001 |\n",
+      "| loss | 6.27 |\n",
+      "| n_updates | 24787 |\n",
+      "----------------------------------\n"
      ]
     }
    ],
    "source": [
+    "from stable_baselines3 import DQN\n",
+    "\n",
     "env = Cube2()\n",
     "env = RewardWrapper(env)\n",
-    "obs, _ = env.reset()\n",
-    "print(env.state())"
+    "model = DQN(\"MlpPolicy\", env, verbose=1)\n",
+    "model.learn(total_timesteps=100000, log_interval=100)\n",
+    "model.save(\"dqn_cube2.pkl\")"
    ]
   },
   {
    "cell_type": "code",
-   "execution_count": null,
-   "id": "f8b4d968",
+   "execution_count": 148,
+   "id": "24132717",
    "metadata": {},
-   "outputs": [],
+   "outputs": [
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "rotationController.setState([[4.0, 4.0, 0.0, 0.0], [1.0, 5.0, 1.0, 5.0], [4.0, 2.0, 1.0, 2.0], [5.0, 3.0, 0.0, 3.0], [3.0, 1.0, 3.0, 4.0], [2.0, 0.0, 2.0, 5.0]])\n",
+      "rotationController.addRotationStepCode(...[7, 10, 2, 5, 7, 4, 10, 3, 11, 4, 5, 4, 5, 4, 5, 4, 5, 4, 5, 4, 5, 4, 5, 4, 5, 4, 5, 4, 5, 4, 5, 4, 5, 4, 5, 4, 5, 4, 5, 4, 5, 4, 5, 4, 5, 4, 5, 4, 5, 4, 5, 4, 5, 4, 5, 4, 5, 4, 5, 4, 5, 4, 5, 4, 5, 4, 5, 4, 5, 4, 5, 4, 5, 4, 5, 4, 5, 4, 5, 4, 5, 4, 5, 4, 5, 4, 5, 4, 5, 4, 5, 4, 5, 4, 5, 4, 5, 4])\n",
+      "\n",
+      "Solved in 98 steps\n"
+     ]
+    }
+   ],
    "source": [
-    "from stable_baselines3 import DQN\n",
+    "# model = DQN.load(\"dqn_cube2.pkl\")\n",
+    "import json\n",
     "\n",
-    "model = DQN(\"MlpPolicy\", env, verbose=1)\n",
-    "model.learn(total_timesteps=10000, log_interval=10)"
+    "env = Cube2()\n",
+    "env = RewardWrapper(env)\n",
+    "obs, _ = env.reset()\n",
+    "print(f\"rotationController.setState({json.dumps(env.state().tolist())})\")\n",
+    "\n",
+    "solved_actions = []\n",
+    "for i in range(100):\n",
+    "    action, _ = model.predict(obs, deterministic=True)\n",
+    "    solved_actions.append(action.item())\n",
+    "    obs, reward, terminated, truncated, _ = env.step(action)\n",
+    "    if terminated or truncated:\n",
+    "        break\n",
+    "print(f\"rotationController.addRotationStepCode(...{json.dumps(solved_actions)})\")\n",
+    "\n",
+    "print()\n",
+    "print(f\"Solved in {len(solved_actions)} steps\")\n"
    ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "39924b6b",
+   "metadata": {},
+   "outputs": [],
+   "source": []
   }
  ],
 "metadata": {