Commit 62d5a8a2 authored by Florian Kirchen

rename and stuff

parent 3cf76e6c
1 merge request: !1 Removed old Files
.idea/misc.xml
<?xml version="1.0" encoding="UTF-8"?>
<project version="4">
  <component name="Black">
    <option name="sdkName" value="Python 3.11 (Ion_ai1)" />
  </component>
  <component name="ProjectRootManager" version="2" project-jdk-name="Python 3.11 (Ion_ai1)" project-jdk-type="Python SDK" />
</project>
\ No newline at end of file
(Three binary files changed; no preview available.)
@@ -46,7 +46,7 @@ import ionen_env
 """
 Global Variables:
 """
-MODELPATH = "model.pth"
+MODELPATH = "model1.pth"
 # New attempt with pytorch rl (torchrl)
 def torchrlAlg1():
@@ -62,7 +62,7 @@ def torchrlAlg1():
     frames_per_batch = 1000
     # For a complete training, bring the number of frames up to 1M
-    total_frames = 1_000_000
+    total_frames = 10_000
     sub_batch_size = 64  # cardinality of the sub-samples gathered from the current data in the inner loop
     num_epochs = 10  # optimization steps per batch of data collected
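The hyperparameters touched above determine how much data is collected and how many gradient steps are taken. As a minimal, self-contained sketch of that arithmetic (plain Python, not torchrl API; the variable names simply mirror the diff):

# hypothetical sketch: how the collection/optimization counts follow from the values in the commit
frames_per_batch = 1000   # frames gathered per collector iteration
total_frames = 10_000     # reduced from 1_000_000 in this commit
sub_batch_size = 64       # minibatch size sampled in the inner loop
num_epochs = 10           # optimization passes over each collected batch

num_batches = total_frames // frames_per_batch                           # 10 collection iterations
updates_per_batch = num_epochs * (frames_per_batch // sub_batch_size)    # 10 * 15 = 150 gradient steps per batch
print(num_batches, updates_per_batch, num_batches * updates_per_batch)   # 10 150 1500

So cutting total_frames from 1_000_000 to 10_000 reduces the run from roughly 150,000 gradient steps to about 1,500.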
@@ -259,12 +259,12 @@ def torchrlAlg1():
     # We evaluate the policy once every 10 batches of data.
     # Evaluation is rather simple: execute the policy without exploration
     # (take the expected value of the action distribution) for a given
-    # number of steps (1000, which is our ``env`` horizon).
+    # number of steps (1000, which is our ``env`` horizon). ****500 for me
     # The ``rollout`` method of the ``env`` can take a policy as argument:
     # it will then execute this policy at each step.
     with set_exploration_type(ExplorationType.MEAN), torch.no_grad():
         # execute a rollout with the trained policy
-        eval_rollout = env.rollout(1000, policy_module)
+        eval_rollout = env.rollout(500, policy_module)
         logs["eval reward"].append(eval_rollout["next", "reward"].mean().item())
         logs["eval reward (sum)"].append(
             eval_rollout["next", "reward"].sum().item()
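The evaluation block above follows the torchrl PPO tutorial pattern: exploration is switched to the distribution mean and a fixed-length rollout is executed with the trained policy. A hedged, self-contained sketch of the same pattern, assuming gymnasium's Pendulum-v1 is installed and using an untrained linear module as a stand-in for the trained policy_module from the diff:

import torch
from tensordict.nn import TensorDictModule
from torchrl.envs import GymEnv
from torchrl.envs.utils import ExplorationType, set_exploration_type

env = GymEnv("Pendulum-v1")

# hypothetical stand-in for the trained policy_module (untrained linear map obs -> action)
policy_module = TensorDictModule(
    torch.nn.Linear(
        env.observation_spec["observation"].shape[-1],
        env.action_spec.shape[-1],
    ),
    in_keys=["observation"],
    out_keys=["action"],
)

with set_exploration_type(ExplorationType.MEAN), torch.no_grad():
    # roll the policy out for 500 steps, matching the value used in this commit
    eval_rollout = env.rollout(500, policy_module)

print("eval reward (mean):", eval_rollout["next", "reward"].mean().item())
print("eval reward (sum): ", eval_rollout["next", "reward"].sum().item())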
@@ -282,10 +282,10 @@ def torchrlAlg1():
     # this is a nice-to-have but nothing necessary for PPO to work.
     scheduler.step()
-    torch.save({'policy_state_dict': policy_module.state_dict(),
-                'value_state_dict': value_module.state_dict(),
-                'loss_state_dict': loss_module.state_dict()},
-               MODELPATH)
+    # torch.save({'policy_state_dict': policy_module.state_dict(),
+    #             'value_state_dict': value_module.state_dict(),
+    #             'loss_state_dict': loss_module.state_dict()},
+    #            MODELPATH)
     # These modules are not good to load
     # 'collector_state_dict': collector.state_dict(),
...
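Since this commit comments out the torch.save call and notes that the collector state dict is not good to load, a hedged sketch of the matching restore step may be useful. It assumes MODELPATH points at a checkpoint written by the (now commented-out) save call, and that policy_module, value_module and loss_module have been rebuilt exactly as during training:

import torch

# load the checkpoint dict written by the commented-out torch.save above (assumption)
checkpoint = torch.load(MODELPATH, map_location="cpu")
policy_module.load_state_dict(checkpoint["policy_state_dict"])
value_module.load_state_dict(checkpoint["value_state_dict"])
loss_module.load_state_dict(checkpoint["loss_state_dict"])
# the collector's state is deliberately not restored, per the comment in the diff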