-
Notifications
You must be signed in to change notification settings - Fork 2
/
Copy pathutils.py
123 lines (108 loc) · 3.96 KB
/
utils.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
'''
Built on the SAC implementation from
https://github.com/pranz24/pytorch-soft-actor-critic
except for the video processing utils, which are built on
Goal-Aware Prediction: Learning to Model What Matters (ICML 2020)
'''
import os
import cv2
import numpy as np
import plotly
from plotly.graph_objs import Scatter
from plotly.graph_objs.scatter import Line
import math
import torch
# Plots min, max, median and mean +/- one standard deviation of a population over time
def lineplot(xs, ys_population, title, path='', xaxis='episode'):
    """Render a line plot to an offline plotly HTML file.

    If the first element of ``ys_population`` is a list or tuple, the input
    is converted to a float32 array and per-row statistics (reduced along
    axis 1) are drawn: min, max, median, mean and a shaded band of one
    standard deviation around the mean. Otherwise ``ys_population`` is
    plotted directly as a single series.

    Args:
        xs: x-axis values shared by every trace.
        ys_population: either a flat series or a sequence of lists/tuples
            (one per x value).
        title: used as the plot title, the y-axis label and the stem of the
            output file name.
        path: directory joined in front of ``title + '.html'``.
        xaxis: x-axis label.
    """
    max_colour = 'rgb(0, 132, 180)'
    mean_colour = 'rgb(0, 172, 237)'
    std_colour = 'rgba(29, 202, 255, 0.2)'
    transparent = 'rgba(0, 0, 0, 0)'

    if not isinstance(ys_population[0], (list, tuple)):
        # Plain series: a single mean-coloured line.
        data = [Scatter(x=xs, y=ys_population, line=Line(color=mean_colour))]
    else:
        ys = np.asarray(ys_population, dtype=np.float32)
        ys_mean = ys.mean(1)
        ys_std = ys.std(1)

        dashed = dict(color=max_colour, dash='dash')
        invisible = dict(color=transparent)

        band_top = Scatter(
            x=xs,
            y=ys_mean + ys_std,
            line=Line(**invisible),
            name='+1 Std. Dev.',
            showlegend=False)
        centre = Scatter(
            x=xs,
            y=ys_mean,
            fill='tonexty',
            fillcolor=std_colour,
            line=Line(color=mean_colour),
            name='Mean')
        band_bottom = Scatter(
            x=xs,
            y=ys_mean - ys_std,
            fill='tonexty',
            fillcolor=std_colour,
            line=Line(**invisible),
            name='-1 Std. Dev.',
            showlegend=False)
        lowest = Scatter(x=xs, y=ys.min(1), line=Line(**dashed), name='Min')
        highest = Scatter(x=xs, y=ys.max(1), line=Line(**dashed), name='Max')
        middle = Scatter(
            x=xs,
            y=np.median(ys, 1),
            line=Line(color=max_colour),
            name='Median')

        # Trace order matters: each 'tonexty' fill refers to the trace
        # listed immediately before it.
        data = [band_top, centre, band_bottom, lowest, highest, middle]

    figure = {
        'data': data,
        'layout': {
            'title': title,
            'xaxis': {'title': xaxis},
            'yaxis': {'title': title},
        },
    }
    target = os.path.join(path, title + '.html')
    plotly.offline.plot(figure, filename=target, auto_open=False)
def write_video(frames, title, path=''):
    """Encode a sequence of frames as ``<title>.mp4`` inside ``path``.

    Args:
        frames: sequence of equally-shaped arrays that ``np.stack`` can
            combine; each frame is treated as channels-first, since axis 1
            of the stacked array is moved to the end. Values are multiplied
            by 255 and clipped to [0, 255] (presumably floats in [0, 1] --
            confirm against callers).
        title: file name stem of the output video.
        path: directory the video file is written to.

    The writer is released even when writing a frame raises, so the output
    file handle is not leaked.
    """
    # N x C x H x W -> N x H x W x C, scaled to uint8.
    video = np.stack(frames, axis=0).transpose(0, 2, 3, 1)
    video = np.multiply(video, 255).clip(0, 255).astype(np.uint8)
    # VideoWriter expects H x W x C in BGR, so reverse the channel axis.
    video = video[:, :, :, ::-1]

    _, H, W, _ = video.shape
    writer = cv2.VideoWriter(
        os.path.join(path, '%s.mp4' % title),
        cv2.VideoWriter_fourcc(*'mp4v'),
        30.,  # frames per second
        (W, H),
        True)
    try:
        for frame in video:
            writer.write(frame)
    finally:
        # Always finalise/close the container, even on a failed write.
        writer.release()
def create_log_gaussian(mean, log_std, t):
    """Return the log-density of ``t`` under a diagonal Gaussian.

    The density is factorised over the last dimension, so the per-dimension
    log-probabilities are summed along ``dim=-1``.

    Args:
        mean: tensor of Gaussian means.
        log_std: tensor of log standard deviations, same trailing size as
            ``mean``.
        t: tensor of points at which the density is evaluated.

    Returns:
        Tensor of log-probabilities with the last dimension reduced.
    """
    # The 0.5 factor must multiply the squared z-score. Placing it inside
    # the square (as before) yields -0.25 * z**2, which is not a Gaussian
    # log-density.
    quadratic = -0.5 * ((t - mean) / log_std.exp()).pow(2)
    n_dims = mean.shape[-1]
    # Normaliser: sum(log sigma) + (D / 2) * log(2 * pi)
    z = n_dims * math.log(2 * math.pi)
    log_p = quadratic.sum(dim=-1) - log_std.sum(dim=-1) - 0.5 * z
    return log_p
def logsumexp(inputs, dim=None, keepdim=False):
    """Compute ``log(sum(exp(inputs)))`` in a numerically stable way.

    Args:
        inputs: input tensor.
        dim: dimension to reduce. When ``None`` the tensor is flattened and
            reduced over all elements.
        keepdim: whether the reduced dimension is retained with size 1.
            With ``dim=None`` this yields a tensor of shape ``(1,)``.

    Returns:
        The reduced tensor.
    """
    if dim is None:
        # reshape (not view) so non-contiguous tensors are accepted too.
        inputs = inputs.reshape(-1)
        dim = 0
    # torch.logsumexp handles slices whose maximum is +/-inf, where the
    # manual "subtract the max" formulation produced NaN (inf - inf).
    return torch.logsumexp(inputs, dim=dim, keepdim=keepdim)
def soft_update(target, source, tau):
    """Blend ``source`` parameters into ``target`` in place.

    Each target parameter becomes ``(1 - tau) * target + tau * source``.
    Parameters are paired positionally via ``parameters()``.
    """
    pairs = zip(target.parameters(), source.parameters())
    for dst, src in pairs:
        blended = dst.data * (1.0 - tau) + src.data * tau
        dst.data.copy_(blended)
def hard_update(target, source):
    """Overwrite every ``target`` parameter with the matching ``source`` one.

    Values are copied in place; parameters are paired positionally via
    ``parameters()``.
    """
    pairs = zip(target.parameters(), source.parameters())
    for dst, src in pairs:
        dst.data.copy_(src.data)