样例：快速入门 PyTorch

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
from torch.autograd import Variable
from sklearn.model_selection import train_test_split
import torch.optim as optim

class modellogit(nn.Module):
    """Three-layer fully connected network: in -> hidden1 -> hidden2 -> out.

    Despite the name, no sigmoid/logit is applied: ``forward`` returns the raw
    linear output of ``fc3``, i.e. the network is a plain regressor.

    The defaults (2 -> 16 -> 8 -> 1) reproduce the original fixed architecture,
    and the layer attribute names ``fc1``/``fc2``/``fc3`` are kept so existing
    ``state_dict`` checkpoints still load.

    Args:
        in_features: Size of the last dimension of the input.
        hidden1: Width of the first hidden layer.
        hidden2: Width of the second hidden layer.
        out_features: Size of the last dimension of the output.
    """

    def __init__(self, in_features=2, hidden1=16, hidden2=8, out_features=1):
        super().__init__()
        self.fc1 = nn.Linear(in_features, hidden1)
        self.fc2 = nn.Linear(hidden1, hidden2)
        self.fc3 = nn.Linear(hidden2, out_features)

    def forward(self, x):
        """Map ``x`` of shape ``(..., in_features)`` to ``(..., out_features)``."""
        x = torch.relu(self.fc1(x))
        x = torch.relu(self.fc2(x))
        # No output activation: unbounded regression output.
        return self.fc3(x)

class hello_model(object):
    """Minimal full-batch trainer for an ``nn.Module`` regressor.

    The training objective is ``MSE(output, target) + l1 * sum(|p|)`` where the
    L1 sum runs over *all* model parameters (weights and biases alike),
    optimised with Adam on the whole dataset at every step.

    Args:
        max_iter: Number of optimisation steps performed by ``fit``.
        learning_rate: Adam learning rate.
        l1: Weight of the L1 penalty on the model parameters.
        model: Network to train. ``None`` (the default) builds a fresh
            ``modellogit()``.
    """

    def __init__(self, max_iter=200, learning_rate=0.01, l1=1e-4, model=None):
        self.model = modellogit() if model is None else model
        self.max_iter = max_iter
        self.learning_rate = learning_rate
        self.l1 = l1

    def fit(self, train_x, train_y):
        """Train ``self.model`` on ``(train_x, train_y)``.

        Args:
            train_x: Array-like of inputs, converted to a float32 tensor.
            train_y: Array-like of targets, converted to a float32 tensor. If
                its shape differs from the model output but the element count
                matches (e.g. 1-D ``y`` vs. ``(n, 1)`` predictions), it is
                reshaped to the output shape so ``MSELoss`` does not broadcast
                it into an ``(n, n)`` comparison.

        Prints a start message, the loss roughly five times during training,
        and an end message.
        """
        print('训练开始')  # "training started"
        optimizer = optim.Adam(self.model.parameters(), lr=self.learning_rate)
        criterion = nn.MSELoss()
        # The data is constant across epochs, so convert it once up front.
        input_x = torch.as_tensor(train_x, dtype=torch.float32)
        target = torch.as_tensor(train_y, dtype=torch.float32)
        # Log about five times per run; max(1, ...) prevents a modulo by zero
        # when max_iter < 5.
        log_every = max(1, self.max_iter // 5)
        self.model.train()
        for epoch in range(self.max_iter):
            # Reset gradients accumulated by the previous step.
            optimizer.zero_grad()
            # Forward pass.
            output = self.model(input_x)
            if target.shape != output.shape and target.numel() == output.numel():
                target = target.reshape(output.shape)
            # Data loss plus L1 penalty over every parameter tensor.
            loss = criterion(output, target)
            regular_loss = sum(p.abs().sum() for p in self.model.parameters())
            loss = loss + self.l1 * regular_loss
            # Backward pass and parameter update.
            loss.backward()
            optimizer.step()
            if epoch % log_every == 0:
                print('[%d, %5d] loss: %.3f' % (epoch + 1, epoch + 1, loss.item()))
        print('训练结束')  # "training finished"

    def predict(self, test_x):
        """Return the model's predictions for ``test_x`` as a NumPy array.

        Runs in eval mode under ``torch.no_grad()`` so no autograd graph is
        built; the model's previous train/eval mode is restored afterwards.
        """
        was_training = self.model.training
        self.model.eval()
        try:
            with torch.no_grad():
                y_hat = self.model(torch.as_tensor(test_x, dtype=torch.float32))
        finally:
            self.model.train(was_training)
        return y_hat.cpu().numpy()

# Synthetic regression demo:
#   y = x0 + x0**2 + x1**2 + 4*x1 + N(0, 0.1**2) noise, with x ~ N(0, I_2).
# Seed NumPy (data) and PyTorch (weight init) so runs are reproducible;
# random_state=42 alone only fixes the split, not the data itself.
np.random.seed(0)
torch.manual_seed(0)
data_x = np.random.normal(size=(30000, 2))
data_y = (data_x[:, [0]] + data_x[:, [0]] ** 2 + data_x[:, [1]] ** 2
          + data_x[:, [1]] * 4 + np.random.normal(size=(30000, 1)) * 0.1)
train_x, test_x, train_y, test_y = train_test_split(
    data_x, data_y, test_size=0.33, random_state=42)

model = hello_model(max_iter=5000, learning_rate=0.001)
model.fit(train_x, train_y)
# Evaluate on the held-out split; the noise variance (0.01) bounds the best
# achievable MSE from below.
pred_y = model.predict(test_x)
print('test MSE: %.4f' % np.mean((pred_y - test_y) ** 2))
作者

周江峰

发布于

2022-06-16

更新于

2022-06-22

许可协议

You need to set install_url to use ShareThis. Please set it in _config.yml.
You forgot to set the business or currency_code for Paypal. Please set it in _config.yml.

评论

You forgot to set the shortname for Disqus. Please set it in _config.yml.