def cached_function(inputs, outputs):
    """Compile a theano function, reusing an on-disk pickle when available.

    The symbolic ``outputs`` are pretty-printed with ``theano.pp`` and a
    digest of that text is used as the file name of the cached pickle
    inside ``~/.hierctrl_cache``.

    Args:
        inputs: Forwarded unchanged to ``compile_function``.
        outputs: A single symbolic output, or a sequence of them (anything
            exposing ``__len__`` is treated as a sequence).

    Returns:
        The unpickled cached function if a readable cache file exists,
        otherwise the freshly compiled function (which is also written to
        the cache).
    """
    import hashlib

    import theano

    with Message("Hashing theano fn"):
        if hasattr(outputs, '__len__'):
            hash_content = tuple(map(theano.pp, outputs))
        else:
            hash_content = theano.pp(outputs)

    # The key must be identical across interpreter runs for the cache to
    # ever hit. Builtin hash() of a str is salted per process on Python 3,
    # and the old ``hex(...)[:-1]`` only made sense for the Python 2 'L'
    # suffix, so use a content digest instead.
    key_material = repr(hash_content)
    if not isinstance(key_material, bytes):
        key_material = key_material.encode('utf-8')
    cache_key = hashlib.sha256(key_material).hexdigest()

    cache_dir = Path('~/.hierctrl_cache').expanduser()
    cache_dir.mkdir_p()
    cache_file = cache_dir / ('%s.pkl' % cache_key)

    if cache_file.exists():
        with Message("unpickling"):
            with open(cache_file, "rb") as f:
                # NOTE(security): unpickling executes arbitrary code; this
                # is only acceptable because the cache directory is
                # expected to be writable by the current user alone.
                try:
                    return pickle.load(f)
                except Exception:
                    # Deliberate best-effort: an unreadable or stale cache
                    # entry falls through to recompilation, which then
                    # overwrites the file below.
                    pass

    with Message("compiling"):
        fun = compile_function(inputs, outputs)
    with Message("picking"):
        with open(cache_file, "wb") as f:
            pickle.dump(fun, f, protocol=pickle.HIGHEST_PROTOCOL)
    return fun


# Immutable, lazily evaluated dict
def run(self, epochs, steps, api_key, rollouts_per_epoch=20,
        updateTargetNetwork=defaultRunSettings['updateTargetNetwork'],
        explorationRate=defaultRunSettings['explorationRate'],
        miniBatchSize=defaultRunSettings['miniBatchSize'],
        learnStart=defaultRunSettings['learnStart'],
        renderPerXEpochs=defaultRunSettings['renderPerXEpochs'],
        shouldRender=defaultRunSettings['shouldRender'],
        experimentId=defaultRunSettings['experimentId'],
        force=defaultRunSettings['force'],
        upload=defaultRunSettings['upload']):
    """Roll out the policy for ``epochs`` episodes of at most ``steps`` steps.

    Each step samples an action from the policy model, advances the
    environment, prints the pretty-printed form of ``grads[1][0]`` and
    computes the one-step TD error ``delta`` from the value model.

    Args:
        epochs: Number of episodes; the environment is reset before each.
        steps: Maximum number of environment steps per episode.
        api_key: Passed to ``gym.upload`` when ``upload`` is true.
        rollouts_per_epoch, updateTargetNetwork, explorationRate,
        miniBatchSize, learnStart, renderPerXEpochs, shouldRender:
            Accepted for interface compatibility; not read by this method.
        experimentId: When not ``None``, the environment monitor is
            started in ``'tmp/' + experimentId``.
        force: Forwarded to ``env.monitor.start``.
        upload: When true, the monitor directory is uploaded after the
            monitor has been closed.

    Raises:
        ValueError: If ``upload`` is requested without an ``experimentId``.
            Previously this surfaced only after training, as a
            ``TypeError`` from concatenating ``None`` to the path.
    """
    if upload and experimentId is None:
        raise ValueError("upload=True requires an experimentId")

    # One directory for both recording and upload. The original recorded
    # to 'tmp/<id>' but uploaded from '/tmp/<id>', which only coincide when
    # the working directory is '/'.
    monitorDir = None
    if experimentId is not None:
        monitorDir = 'tmp/' + experimentId
        self.env.monitor.start(monitorDir, force=force)

    for epoch in xrange(epochs):
        observation = self.env.reset()
        for t in xrange(steps):
            policyValues = self.runModel(self.policyModel, observation)
            action = self.selectActionByProbability(policyValues)
            newObservation, reward, done, info = self.env.step(action)

            cost, grads = self.get_cost_grads(self.policyModel)
            print(theano.pp(grads[1][0]))

            currentValue = self.runModel(self.valueModel, observation)
            if done:
                # Terminal transition: the value of the new observation
                # is 0, so there is nothing to bootstrap from.
                delta = reward - currentValue
            else:
                nextValue = self.runModel(self.valueModel, newObservation)
                delta = reward + self.discountFactor * nextValue - currentValue
            # NOTE(review): delta is not consumed yet; no parameter update
            # is applied anywhere in this method.

            if done:
                # The environment must be reset before stepping again.
                break
            # Advance so the next action is chosen from the current state.
            observation = newObservation

    self.env.monitor.close()
    if upload:
        gym.upload(monitorDir, api_key=api_key)