PyTorch Gradient Descent and Backpropagation
Preface:
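The purpose of backpropagation is to compute the partial derivatives of the cost function C with respect to every weight w and bias b in the network. Once we have these partial derivatives, we update each weight and bias by subtracting a constant α (the learning rate) times the partial derivative of the cost function with respect to that parameter. This is the popular gradient descent algorithm. The gradient (the vector of partial derivatives) points in the direction of steepest ascent, so we take a small step in the opposite direction, the direction of steepest descent, which moves us toward a local minimum of the cost function. The rest of this article applies backpropagation to a small example.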
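Written as an update rule (a standard formulation, with α the learning rate; the code in this article uses α = 0.01 and applies one update per sample):

```latex
w \leftarrow w - \alpha \frac{\partial C}{\partial w}, \qquad
b \leftarrow b - \alpha \frac{\partial C}{\partial b}
```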
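The exercise: fit the data with the quadratic model ŷ = w₁x² + w₂x + b and use it for prediction, driving down the value of the loss function (MSE).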
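For one sample (x, y) the code's loss is the squared error ℓ; averaging it over the three samples would give the MSE. Applying the chain rule gives the three partial derivatives that `l.backward()` stores in `w.grad` (in the code, `w[0]`, `w[1]`, `w[2]` play the roles of w₁, w₂, b; this hand derivation is added for illustration):

```latex
\begin{aligned}
\hat{y} &= w_1 x^2 + w_2 x + b, \qquad \ell = (\hat{y} - y)^2 \\
\frac{\partial \ell}{\partial w_1} &= 2(\hat{y} - y)\,x^2, \qquad
\frac{\partial \ell}{\partial w_2} = 2(\hat{y} - y)\,x, \qquad
\frac{\partial \ell}{\partial b} = 2(\hat{y} - y)
\end{aligned}
```

For example, at w₁ = w₂ = b = 1 with the first sample (x, y) = (1, 2), ŷ = 3 and the gradient is (2, 2, 2), which matches the first grad line of the output.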
The code is as follows:
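import os

import matplotlib.pyplot as plt
import torch

# Workaround for the "OMP: Error #15 ... already initialized" crash that some
# setups hit when two copies of the OpenMP runtime get loaded
os.environ["KMP_DUPLICATE_LIB_OK"] = "TRUE"

# Dataset
x_data = [1.0, 2.0, 3.0]
y_data = [2.0, 4.0, 6.0]

# Weight parameters [w1, w2, b], all initialized to 1
w = torch.tensor([1.0, 1.0, 1.0])
w.requires_grad = True  # gradients are needed for w

# Forward pass: y_hat = w1*x^2 + w2*x + b
def forward(x):
    return w[0] * (x ** 2) + w[1] * x + w[2]

# Loss of a single sample: squared error
def loss(x, y):
    y_pred = forward(x)
    return (y_pred - y) ** 2

# Training
print('predict (before training) ', 4, forward(4).item())

epoch_list = []
w_list = []
loss_list = []

for epoch in range(100):
    for x, y in zip(x_data, y_data):
        l = loss(x, y)
        l.backward()  # backward pass: accumulates dl/dw into w.grad
        print('\tgrad: ', x, y, w.grad)
        with torch.no_grad():
            w -= 0.01 * w.grad  # gradient descent step, learning rate 0.01
        w.grad.zero_()  # zero the gradient, otherwise backward() keeps adding to it
    # l is the loss of the last sample (x=3) in this epoch, not the full-data MSE
    print('progress: ', epoch, l.item())
    epoch_list.append(epoch)
    w_list.append(w.detach().clone())  # copy, so each entry keeps this epoch's values
    loss_list.append(l.item())

print('predict (after training) ', 4, forward(4).item())

# Plot loss vs. epoch
plt.plot(epoch_list, loss_list, 'b')
plt.xlabel('Epoch')
plt.ylabel('Loss')
plt.grid()
plt.show()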
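To see what `l.backward()` actually computes in this example, here is a minimal sketch that performs the same per-sample updates in plain Python, with the gradients derived by hand from the chain rule instead of autograd. The names `forward`, `grad` and `train` are illustrative, not part of the original code; data, initial values and learning rate are the article's.

```python
# Same training as the PyTorch loop, but with hand-derived gradients
# (plain Python, no autograd), to show what backward() fills into w.grad.
x_data = [1.0, 2.0, 3.0]
y_data = [2.0, 4.0, 6.0]


def forward(w, x):
    """Quadratic model y_hat = w1*x^2 + w2*x + b, with w = [w1, w2, b]."""
    return w[0] * x ** 2 + w[1] * x + w[2]


def grad(w, x, y):
    """Chain rule for l = (y_hat - y)^2: dl/dw = 2*(y_hat - y) * [x^2, x, 1]."""
    r = forward(w, x) - y  # residual y_hat - y
    return [2 * r * x ** 2, 2 * r * x, 2 * r]


def train(epochs, lr=0.01):
    """Per-sample gradient descent in the same order as the PyTorch loop."""
    w = [1.0, 1.0, 1.0]  # same initial values as the PyTorch code
    first_epoch_grads = []
    last_loss = None
    for epoch in range(epochs):
        for x, y in zip(x_data, y_data):
            g = grad(w, x, y)  # the values PyTorch would store in w.grad
            last_loss = (forward(w, x) - y) ** 2  # loss before the update, like l
            if epoch == 0:
                first_epoch_grads.append(g)
            w = [wi - lr * gi for wi, gi in zip(w, g)]  # w <- w - lr * grad
    return w, first_epoch_grads, last_loss


print("predict (before training)", 4, forward([1.0, 1.0, 1.0], 4))
w, grads0, last_loss = train(100)
for (x, y), g in zip(zip(x_data, y_data), grads0):
    print("grad:", x, y, [round(v, 4) for v in g])
print("last-sample loss after 100 epochs:", last_loss)
print("predict (after training)", 4, forward(w, 4))
```

Because `grad()` returns a fresh list on every call, there is nothing to zero here; in PyTorch, `backward()` accumulates into `w.grad`, which is why the original loop calls `w.grad.zero_()` after each step. The sketch runs in float64 while PyTorch defaults to float32, so the last digits can differ slightly from the logged values.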
The output is as follows:
predict (before training) 4 21.0
    grad: 1.0 2.0 tensor([2., 2., 2.])
    grad: 2.0 4.0 tensor([22.8800, 11.4400, 5.7200])
    grad: 3.0 6.0 tensor([77.0472, 25.6824, 8.5608])
progress: 0 18.321826934814453
    grad: 1.0 2.0 tensor([-1.1466, -1.1466, -1.1466])
    grad: 2.0 4.0 tensor([-15.5367, -7.7683, -3.8842])
    grad: 3.0 6.0 tensor([-30.4322, -10.1441, -3.3814])
progress: 1 2.858394145965576
    grad: 1.0 2.0 tensor([0.3451, 0.3451, 0.3451])
    grad: 2.0 4.0 tensor([2.4273, 1.2137, 0.6068])
    grad: 3.0 6.0 tensor([19.4499, 6.4833, 2.1611])
progress: 2 1.1675907373428345
    grad: 1.0 2.0 tensor([-0.3224, -0.3224, -0.3224])
    grad: 2.0 4.0 tensor([-5.8458, -2.9229, -1.4614])
    grad: 3.0 6.0 tensor([-3.8829, -1.2943, -0.4314])
progress: 3 0.04653334245085716
    grad: 1.0 2.0 tensor([0.0137, 0.0137, 0.0137])
    grad: 2.0 4.0 tensor([-1.9141, -0.9570, -0.4785])
    grad: 3.0 6.0 tensor([6.8557, 2.2852, 0.7617])
progress: 4 0.14506366848945618
    grad: 1.0 2.0 tensor([-0.1182, -0.1182, -0.1182])
    grad: 2.0 4.0 tensor([-3.6644, -1.8322, -0.9161])
    grad: 3.0 6.0 tensor([1.7455, 0.5818, 0.1939])
progress: 5 0.009403289295732975
    grad: 1.0 2.0 tensor([-0.0333, -0.0333, -0.0333])
    grad: 2.0 4.0 tensor([-2.7739, -1.3869, -0.6935])
    grad: 3.0 6.0 tensor([4.0140, 1.3380, 0.4460])
progress: 6 0.04972923547029495
    grad: 1.0 2.0 tensor([-0.0501, -0.0501, -0.0501])
    grad: 2.0 4.0 tensor([-3.1150, -1.5575, -0.7788])
    grad: 3.0 6.0 tensor([2.8534, 0.9511, 0.3170])
progress: 7 0.025129113346338272
    grad: 1.0 2.0 tensor([-0.0205, -0.0205, -0.0205])
    grad: 2.0 4.0 tensor([-2.8858, -1.4429, -0.7215])
    grad: 3.0 6.0 tensor([3.2924, 1.0975, 0.3658])
progress: 8 0.03345605731010437
    grad: 1.0 2.0 tensor([-0.0134, -0.0134, -0.0134])
    grad: 2.0 4.0 tensor([-2.9247, -1.4623, -0.7312])
    grad: 3.0 6.0 tensor([2.9909, 0.9970, 0.3323])
progress: 9 0.027609655633568764
    grad: 1.0 2.0 tensor([0.0033, 0.0033, 0.0033])
    grad: 2.0 4.0 tensor([-2.8414, -1.4207, -0.7103])
    grad: 3.0 6.0 tensor([3.0377, 1.0126, 0.3375])
progress: 10 0.02848036028444767
    grad: 1.0 2.0 tensor([0.0148, 0.0148, 0.0148])
    grad: 2.0 4.0 tensor([-2.8174, -1.4087, -0.7043])
    grad: 3.0 6.0 tensor([2.9260, 0.9753, 0.3251])
progress: 11 0.02642466314136982
    grad: 1.0 2.0 tensor([0.0280, 0.0280, 0.0280])
    grad: 2.0 4.0 tensor([-2.7682, -1.3841, -0.6920])
    grad: 3.0 6.0 tensor([2.8915, 0.9638, 0.3213])
progress: 12 0.025804826989769936
    grad: 1.0 2.0 tensor([0.0397, 0.0397, 0.0397])
    grad: 2.0 4.0 tensor([-2.7330, -1.3665, -0.6832])
    grad: 3.0 6.0 tensor([2.8243, 0.9414, 0.3138])
progress: 13 0.02462013065814972
    grad: 1.0 2.0 tensor([0.0514, 0.0514, 0.0514])
    grad: 2.0 4.0 tensor([-2.6934, -1.3467, -0.6734])
    grad: 3.0 6.0 tensor([2.7756, 0.9252, 0.3084])
progress: 14 0.023777369409799576
    grad: 1.0 2.0 tensor([0.0624, 0.0624, 0.0624])
    grad: 2.0 4.0 tensor([-2.6580, -1.3290, -0.6645])
    grad: 3.0 6.0 tensor([2.7213, 0.9071, 0.3024])
progress: 15 0.0228563379496336
    grad: 1.0 2.0 tensor([0.0731, 0.0731, 0.0731])
    grad: 2.0 4.0 tensor([-2.6227, -1.3113, -0.6557])
    grad: 3.0 6.0 tensor([2.6725, 0.8908, 0.2969])
progress: 16 0.022044027224183083
    grad: 1.0 2.0 tensor([0.0833, 0.0833, 0.0833])
    grad: 2.0 4.0 tensor([-2.5893, -1.2946, -0.6473])
    grad: 3.0 6.0 tensor([2.6240, 0.8747, 0.2916])
progress: 17 0.02125072106719017
    grad: 1.0 2.0 tensor([0.0931, 0.0931, 0.0931])
    grad: 2.0 4.0 tensor([-2.5568, -1.2784, -0.6392])
    grad: 3.0 6.0 tensor([2.5780, 0.8593, 0.2864])
progress: 18 0.020513182505965233
    grad: 1.0 2.0 tensor([0.1025, 0.1025, 0.1025])
    grad: 2.0 4.0 tensor([-2.5258, -1.2629, -0.6314])
    grad: 3.0 6.0 tensor([2.5335, 0.8445, 0.2815])
progress: 19 0.019810274243354797
    grad: 1.0 2.0 tensor([0.1116, 0.1116, 0.1116])
    grad: 2.0 4.0 tensor([-2.4958, -1.2479, -0.6239])
    grad: 3.0 6.0 tensor([2.4908, 0.8303, 0.2768])
progress: 20 0.019148115068674088
    grad: 1.0 2.0 tensor([0.1203, 0.1203, 0.1203])
    grad: 2.0 4.0 tensor([-2.4669, -1.2335, -0.6167])
    grad: 3.0 6.0 tensor([2.4496, 0.8165, 0.2722])
progress: 21 0.018520694226026535
    grad: 1.0 2.0 tensor([0.1286, 0.1286, 0.1286])
    grad: 2.0 4.0 tensor([-2.4392, -1.2196, -0.6098])
    grad: 3.0 6.0 tensor([2.4101, 0.8034, 0.2678])
progress: 22 0.017927465960383415
    grad: 1.0 2.0 tensor([0.1367, 0.1367, 0.1367])
    grad: 2.0 4.0 tensor([-2.4124, -1.2062, -0.6031])
    grad: 3.0 6.0 tensor([2.3720, 0.7907, 0.2636])
progress: 23 0.01736525259912014
    grad: 1.0 2.0 tensor([0.1444, 0.1444, 0.1444])
    grad: 2.0 4.0 tensor([-2.3867, -1.1933, -0.5967])
    grad: 3.0 6.0 tensor([2.3354, 0.7785, 0.2595])
progress: 24 0.016833148896694183
    grad: 1.0 2.0 tensor([0.1518, 0.1518, 0.1518])
    grad: 2.0 4.0 tensor([-2.3619, -1.1810, -0.5905])
    grad: 3.0 6.0 tensor([2.3001, 0.7667, 0.2556])
progress: 25 0.01632905937731266
    grad: 1.0 2.0 tensor([0.1589, 0.1589, 0.1589])
    grad: 2.0 4.0 tensor([-2.3380, -1.1690, -0.5845])
    grad: 3.0 6.0 tensor([2.2662, 0.7554, 0.2518])
progress: 26 0.01585075818002224
    grad: 1.0 2.0 tensor([0.1657, 0.1657, 0.1657])
    grad: 2.0 4.0 tensor([-2.3151, -1.1575, -0.5788])
    grad: 3.0 6.0 tensor([2.2336, 0.7445, 0.2482])
progress: 27 0.015397666022181511
    grad: 1.0 2.0 tensor([0.1723, 0.1723, 0.1723])
    grad: 2.0 4.0 tensor([-2.2929, -1.1465, -0.5732])
    grad: 3.0 6.0 tensor([2.2022, 0.7341, 0.2447])
progress: 28 0.014967591501772404
    grad: 1.0 2.0 tensor([0.1786, 0.1786, 0.1786])
    grad: 2.0 4.0 tensor([-2.2716, -1.1358, -0.5679])
    grad: 3.0 6.0 tensor([2.1719, 0.7240, 0.2413])
progress: 29 0.014559715054929256
    grad: 1.0 2.0 tensor([0.1846, 0.1846, 0.1846])
    grad: 2.0 4.0 tensor([-2.2511, -1.1255, -0.5628])
    grad: 3.0 6.0 tensor([2.1429, 0.7143, 0.2381])
progress: 30 0.014172340743243694
    grad: 1.0 2.0 tensor([0.1904, 0.1904, 0.1904])
    grad: 2.0 4.0 tensor([-2.2313, -1.1157, -0.5578])
    grad: 3.0 6.0 tensor([2.1149, 0.7050, 0.2350])
progress: 31 0.013804304413497448
    grad: 1.0 2.0 tensor([0.1960, 0.1960, 0.1960])
    grad: 2.0 4.0 tensor([-2.2123, -1.1061, -0.5531])
    grad: 3.0 6.0 tensor([2.0879, 0.6960, 0.2320])
progress: 32 0.013455045409500599
    grad: 1.0 2.0 tensor([0.2014, 0.2014, 0.2014])
    grad: 2.0 4.0 tensor([-2.1939, -1.0970, -0.5485])
    grad: 3.0 6.0 tensor([2.0620, 0.6873, 0.2291])
progress: 33 0.013122711330652237
    grad: 1.0 2.0 tensor([0.2065, 0.2065, 0.2065])
    grad: 2.0 4.0 tensor([-2.1763, -1.0881, -0.5441])
    grad: 3.0 6.0 tensor([2.0370, 0.6790, 0.2263])
progress: 34 0.01280694268643856
    grad: 1.0 2.0 tensor([0.2114, 0.2114, 0.2114])
    grad: 2.0 4.0 tensor([-2.1592, -1.0796, -0.5398])
    grad: 3.0 6.0 tensor([2.0130, 0.6710, 0.2237])
progress: 35 0.012506747618317604
    grad: 1.0 2.0 tensor([0.2162, 0.2162, 0.2162])
    grad: 2.0 4.0 tensor([-2.1428, -1.0714, -0.5357])
    grad: 3.0 6.0 tensor([1.9899, 0.6633, 0.2211])
progress: 36 0.012220758944749832
    grad: 1.0 2.0 tensor([0.2207, 0.2207, 0.2207])
    grad: 2.0 4.0 tensor([-2.1270, -1.0635, -0.5317])
    grad: 3.0 6.0 tensor([1.9676, 0.6559, 0.2186])
progress: 37 0.01194891706109047
    grad: 1.0 2.0 tensor([0.2251, 0.2251, 0.2251])
    grad: 2.0 4.0 tensor([-2.1118, -1.0559, -0.5279])
    grad: 3.0 6.0 tensor([1.9462, 0.6487, 0.2162])
progress: 38 0.011689926497638226
    grad: 1.0 2.0 tensor([0.2292, 0.2292, 0.2292])
    grad: 2.0 4.0 tensor([-2.0971, -1.0485, -0.5243])
    grad: 3.0 6.0 tensor([1.9255, 0.6418, 0.2139])
progress: 39 0.01144315768033266
    grad: 1.0 2.0 tensor([0.2333, 0.2333, 0.2333])
    grad: 2.0 4.0 tensor([-2.0829, -1.0414, -0.5207])
    grad: 3.0 6.0 tensor([1.9057, 0.6352, 0.2117])
progress: 40 0.011208509095013142
    grad: 1.0 2.0 tensor([0.2371, 0.2371, 0.2371])
    grad: 2.0 4.0 tensor([-2.0693, -1.0346, -0.5173])
    grad: 3.0 6.0 tensor([1.8865, 0.6288, 0.2096])
progress: 41 0.0109840864315629
    grad: 1.0 2.0 tensor([0.2408, 0.2408, 0.2408])
    grad: 2.0 4.0 tensor([-2.0561, -1.0280, -0.5140])
    grad: 3.0 6.0 tensor([1.8681, 0.6227, 0.2076])
progress: 42 0.010770938359200954
    grad: 1.0 2.0 tensor([0.2444, 0.2444, 0.2444])
    grad: 2.0 4.0 tensor([-2.0434, -1.0217, -0.5108])
    grad: 3.0 6.0 tensor([1.8503, 0.6168, 0.2056])
progress: 43 0.010566935874521732
    grad: 1.0 2.0 tensor([0.2478, 0.2478, 0.2478])
    grad: 2.0 4.0 tensor([-2.0312, -1.0156, -0.5078])
    grad: 3.0 6.0 tensor([1.8332, 0.6111, 0.2037])
progress: 44 0.010372749529778957
    grad: 1.0 2.0 tensor([0.2510, 0.2510, 0.2510])
    grad: 2.0 4.0 tensor([-2.0194, -1.0097, -0.5048])
    grad: 3.0 6.0 tensor([1.8168, 0.6056, 0.2019])
progress: 45 0.010187389329075813
    grad: 1.0 2.0 tensor([0.2542, 0.2542, 0.2542])
    grad: 2.0 4.0 tensor([-2.0080, -1.0040, -0.5020])
    grad: 3.0 6.0 tensor([1.8009, 0.6003, 0.2001])
progress: 46 0.010010283440351486
    grad: 1.0 2.0 tensor([0.2572, 0.2572, 0.2572])
    grad: 2.0 4.0 tensor([-1.9970, -0.9985, -0.4992])
    grad: 3.0 6.0 tensor([1.7856, 0.5952, 0.1984])
progress: 47 0.00984097272157669
    grad: 1.0 2.0 tensor([0.2600, 0.2600, 0.2600])
    grad: 2.0 4.0 tensor([-1.9864, -0.9932, -0.4966])
    grad: 3.0 6.0 tensor([1.7709, 0.5903, 0.1968])
progress: 48 0.009679674170911312
    grad: 1.0 2.0 tensor([0.2628, 0.2628, 0.2628])
    grad: 2.0 4.0 tensor([-1.9762, -0.9881, -0.4940])
    grad: 3.0 6.0 tensor([1.7568, 0.5856, 0.1952])
progress: 49 0.009525291621685028
    grad: 1.0 2.0 tensor([0.2655, 0.2655, 0.2655])
    grad: 2.0 4.0 tensor([-1.9663, -0.9832, -0.4916])
    grad: 3.0 6.0 tensor([1.7431, 0.5810, 0.1937])
progress: 50 0.00937769003212452
    grad: 1.0 2.0 tensor([0.2680, 0.2680, 0.2680])
    grad: 2.0 4.0 tensor([-1.9568, -0.9784, -0.4892])
    grad: 3.0 6.0 tensor([1.7299, 0.5766, 0.1922])
progress: 51 0.009236648678779602
    grad: 1.0 2.0 tensor([0.2704, 0.2704, 0.2704])
    grad: 2.0 4.0 tensor([-1.9476, -0.9738, -0.4869])
    grad: 3.0 6.0 tensor([1.7172, 0.5724, 0.1908])
progress: 52 0.00910158734768629
    grad: 1.0 2.0 tensor([0.2728, 0.2728, 0.2728])
    grad: 2.0 4.0 tensor([-1.9387, -0.9694, -0.4847])
    grad: 3.0 6.0 tensor([1.7050, 0.5683, 0.1894])
progress: 53 0.00897257961332798
    grad: 1.0 2.0 tensor([0.2750, 0.2750, 0.2750])
    grad: 2.0 4.0 tensor([-1.9301, -0.9651, -0.4825])
    grad: 3.0 6.0 tensor([1.6932, 0.5644, 0.1881])
progress: 54 0.008848887868225574
    grad: 1.0 2.0 tensor([0.2771, 0.2771, 0.2771])
    grad: 2.0 4.0 tensor([-1.9219, -0.9609, -0.4805])
    grad: 3.0 6.0 tensor([1.6819, 0.5606, 0.1869])
progress: 55 0.008730598725378513
    grad: 1.0 2.0 tensor([0.2792, 0.2792, 0.2792])
    grad: 2.0 4.0 tensor([-1.9139, -0.9569, -0.4785])
    grad: 3.0 6.0 tensor([1.6709, 0.5570, 0.1857])
progress: 56 0.00861735362559557
    grad: 1.0 2.0 tensor([0.2811, 0.2811, 0.2811])
    grad: 2.0 4.0 tensor([-1.9062, -0.9531, -0.4765])
    grad: 3.0 6.0 tensor([1.6604, 0.5535, 0.1845])
progress: 57 0.008508718572556973
    grad: 1.0 2.0 tensor([0.2830, 0.2830, 0.2830])
    grad: 2.0 4.0 tensor([-1.8987, -0.9493, -0.4747])
    grad: 3.0 6.0 tensor([1.6502, 0.5501, 0.1834])
progress: 58 0.008404706604778767
    grad: 1.0 2.0 tensor([0.2848, 0.2848, 0.2848])
    grad: 2.0 4.0 tensor([-1.8915, -0.9457, -0.4729])
    grad: 3.0 6.0 tensor([1.6404, 0.5468, 0.1823])
progress: 59 0.008305158466100693
    grad: 1.0 2.0 tensor([0.2865, 0.2865, 0.2865])
    grad: 2.0 4.0 tensor([-1.8845, -0.9423, -0.4711])
    grad: 3.0 6.0 tensor([1.6309, 0.5436, 0.1812])
progress: 60 0.00820931326597929
    grad: 1.0 2.0 tensor([0.2882, 0.2882, 0.2882])
    grad: 2.0 4.0 tensor([-1.8778, -0.9389, -0.4694])
    grad: 3.0 6.0 tensor([1.6218, 0.5406, 0.1802])
progress: 61 0.008117804303765297
    grad: 1.0 2.0 tensor([0.2898, 0.2898, 0.2898])
    grad: 2.0 4.0 tensor([-1.8713, -0.9356, -0.4678])
    grad: 3.0 6.0 tensor([1.6130, 0.5377, 0.1792])
progress: 62 0.008029798977077007
    grad: 1.0 2.0 tensor([0.2913, 0.2913, 0.2913])
    grad: 2.0 4.0 tensor([-1.8650, -0.9325, -0.4662])
    grad: 3.0 6.0 tensor([1.6045, 0.5348, 0.1783])
progress: 63 0.007945418357849121
    grad: 1.0 2.0 tensor([0.2927, 0.2927, 0.2927])
    grad: 2.0 4.0 tensor([-1.8589, -0.9294, -0.4647])
    grad: 3.0 6.0 tensor([1.5962, 0.5321, 0.1774])
progress: 64 0.007864190265536308
    grad: 1.0 2.0 tensor([0.2941, 0.2941, 0.2941])
    grad: 2.0 4.0 tensor([-1.8530, -0.9265, -0.4632])
    grad: 3.0 6.0 tensor([1.5884, 0.5295, 0.1765])
progress: 65 0.007786744274199009
    grad: 1.0 2.0 tensor([0.2954, 0.2954, 0.2954])
    grad: 2.0 4.0 tensor([-1.8473, -0.9236, -0.4618])
    grad: 3.0 6.0 tensor([1.5807, 0.5269, 0.1756])
progress: 66 0.007711691781878471
    grad: 1.0 2.0 tensor([0.2967, 0.2967, 0.2967])
    grad: 2.0 4.0 tensor([-1.8417, -0.9209, -0.4604])
    grad: 3.0 6.0 tensor([1.5733, 0.5244, 0.1748])
progress: 67 0.007640169933438301
    grad: 1.0 2.0 tensor([0.2979, 0.2979, 0.2979])
    grad: 2.0 4.0 tensor([-1.8364, -0.9182, -0.4591])
    grad: 3.0 6.0 tensor([1.5662, 0.5221, 0.1740])
progress: 68 0.007570972666144371
    grad: 1.0 2.0 tensor([0.2991, 0.2991, 0.2991])
    grad: 2.0 4.0 tensor([-1.8312, -0.9156, -0.4578])
    grad: 3.0 6.0 tensor([1.5593, 0.5198, 0.1733])
progress: 69 0.007504733745008707
    grad: 1.0 2.0 tensor([0.3002, 0.3002, 0.3002])
    grad: 2.0 4.0 tensor([-1.8262, -0.9131, -0.4566])
    grad: 3.0 6.0 tensor([1.5527, 0.5176, 0.1725])
progress: 70 0.007440924644470215
    grad: 1.0 2.0 tensor([0.3012, 0.3012, 0.3012])
    grad: 2.0 4.0 tensor([-1.8214, -0.9107, -0.4553])
    grad: 3.0 6.0 tensor([1.5463, 0.5154, 0.1718])
progress: 71 0.007379599846899509
    grad: 1.0 2.0 tensor([0.3022, 0.3022, 0.3022])
    grad: 2.0 4.0 tensor([-1.8167, -0.9083, -0.4542])
    grad: 3.0 6.0 tensor([1.5401, 0.5134, 0.1711])
progress: 72 0.007320486940443516
    grad: 1.0 2.0 tensor([0.3032, 0.3032, 0.3032])
    grad: 2.0 4.0 tensor([-1.8121, -0.9060, -0.4530])
    grad: 3.0 6.0 tensor([1.5341, 0.5114, 0.1705])
progress: 73 0.007263725157827139
    grad: 1.0 2.0 tensor([0.3041, 0.3041, 0.3041])
    grad: 2.0 4.0 tensor([-1.8077, -0.9038, -0.4519])
    grad: 3.0 6.0 tensor([1.5283, 0.5094, 0.1698])
progress: 74 0.007209045812487602
    grad: 1.0 2.0 tensor([0.3050, 0.3050, 0.3050])
    grad: 2.0 4.0 tensor([-1.8034, -0.9017, -0.4508])
    grad: 3.0 6.0 tensor([1.5227, 0.5076, 0.1692])
progress: 75 0.007156429346650839
    grad: 1.0 2.0 tensor([0.3058, 0.3058, 0.3058])
    grad: 2.0 4.0 tensor([-1.7992, -0.8996, -0.4498])
    grad: 3.0 6.0 tensor([1.5173, 0.5058, 0.1686])
progress: 76 0.007105532102286816
    grad: 1.0 2.0 tensor([0.3066, 0.3066, 0.3066])
    grad: 2.0 4.0 tensor([-1.7952, -0.8976, -0.4488])
    grad: 3.0 6.0 tensor([1.5121, 0.5040, 0.1680])
progress: 77 0.00705681974068284
    grad: 1.0 2.0 tensor([0.3073, 0.3073, 0.3073])
    grad: 2.0 4.0 tensor([-1.7913, -0.8956, -0.4478])
    grad: 3.0 6.0 tensor([1.5070, 0.5023, 0.1674])
progress: 78 0.007009552326053381
    grad: 1.0 2.0 tensor([0.3081, 0.3081, 0.3081])
    grad: 2.0 4.0 tensor([-1.7875, -0.8937, -0.4469])
    grad: 3.0 6.0 tensor([1.5021, 0.5007, 0.1669])
progress: 79 0.006964194122701883
    grad: 1.0 2.0 tensor([0.3087, 0.3087, 0.3087])
    grad: 2.0 4.0 tensor([-1.7838, -0.8919, -0.4459])
    grad: 3.0 6.0 tensor([1.4974, 0.4991, 0.1664])
progress: 80 0.006920332089066505
    grad: 1.0 2.0 tensor([0.3094, 0.3094, 0.3094])
    grad: 2.0 4.0 tensor([-1.7802, -0.8901, -0.4450])
    grad: 3.0 6.0 tensor([1.4928, 0.4976, 0.1659])
progress: 81 0.006878111511468887
    grad: 1.0 2.0 tensor([0.3100, 0.3100, 0.3100])
    grad: 2.0 4.0 tensor([-1.7767, -0.8883, -0.4442])
    grad: 3.0 6.0 tensor([1.4884, 0.4961, 0.1654])
progress: 82 0.006837360095232725
    grad: 1.0 2.0 tensor([0.3106, 0.3106, 0.3106])
    grad: 2.0 4.0 tensor([-1.7733, -0.8867, -0.4433])
    grad: 3.0 6.0 tensor([1.4841, 0.4947, 0.1649])
progress: 83 0.006797831039875746
    grad: 1.0 2.0 tensor([0.3111, 0.3111, 0.3111])
    grad: 2.0 4.0 tensor([-1.7700, -0.8850, -0.4425])
    grad: 3.0 6.0 tensor([1.4800, 0.4933, 0.1644])
progress: 84 0.006760062649846077
    grad: 1.0 2.0 tensor([0.3117, 0.3117, 0.3117])
    grad: 2.0 4.0 tensor([-1.7668, -0.8834, -0.4417])
    grad: 3.0 6.0 tensor([1.4759, 0.4920, 0.1640])
progress: 85 0.006723103579133749
    grad: 1.0 2.0 tensor([0.3122, 0.3122, 0.3122])
    grad: 2.0 4.0 tensor([-1.7637, -0.8818, -0.4409])
    grad: 3.0 6.0 tensor([1.4720, 0.4907, 0.1636])
progress: 86 0.00668772729113698
    grad: 1.0 2.0 tensor([0.3127, 0.3127, 0.3127])
    grad: 2.0 4.0 tensor([-1.7607, -0.8803, -0.4402])
    grad: 3.0 6.0 tensor([1.4682, 0.4894, 0.1631])
progress: 87 0.006653300020843744
    grad: 1.0 2.0 tensor([0.3131, 0.3131, 0.3131])
    grad: 2.0 4.0 tensor([-1.7577, -0.8789, -0.4394])
    grad: 3.0 6.0 tensor([1.4646, 0.4882, 0.1627])
progress: 88 0.0066203586757183075
    grad: 1.0 2.0 tensor([0.3135, 0.3135, 0.3135])
    grad: 2.0 4.0 tensor([-1.7548, -0.8774, -0.4387])
    grad: 3.0 6.0 tensor([1.4610, 0.4870, 0.1623])
progress: 89 0.0065881176851689816
    grad: 1.0 2.0 tensor([0.3139, 0.3139, 0.3139])
    grad: 2.0 4.0 tensor([-1.7520, -0.8760, -0.4380])
    grad: 3.0 6.0 tensor([1.4576, 0.4859, 0.1620])
progress: 90 0.0065572685562074184
    grad: 1.0 2.0 tensor([0.3143, 0.3143, 0.3143])
    grad: 2.0 4.0 tensor([-1.7493, -0.8747, -0.4373])
    grad: 3.0 6.0 tensor([1.4542, 0.4847, 0.1616])
progress: 91 0.0065271081402897835
    grad: 1.0 2.0 tensor([0.3147, 0.3147, 0.3147])
    grad: 2.0 4.0 tensor([-1.7466, -0.8733, -0.4367])
    grad: 3.0 6.0 tensor([1.4510, 0.4837, 0.1612])
progress: 92 0.00649801641702652
    grad: 1.0 2.0 tensor([0.3150, 0.3150, 0.3150])
    grad: 2.0 4.0 tensor([-1.7441, -0.8720, -0.4360])
    grad: 3.0 6.0 tensor([1.4478, 0.4826, 0.1609])
progress: 93 0.0064699104987084866
    grad: 1.0 2.0 tensor([0.3153, 0.3153, 0.3153])
    grad: 2.0 4.0 tensor([-1.7415, -0.8708, -0.4354])
    grad: 3.0 6.0 tensor([1.4448, 0.4816, 0.1605])
progress: 94 0.006442630663514137
    grad: 1.0 2.0 tensor([0.3156, 0.3156, 0.3156])
    grad: 2.0 4.0 tensor([-1.7391, -0.8695, -0.4348])
    grad: 3.0 6.0 tensor([1.4418, 0.4806, 0.1602])
progress: 95 0.006416172254830599
    grad: 1.0 2.0 tensor([0.3159, 0.3159, 0.3159])
    grad: 2.0 4.0 tensor([-1.7366, -0.8683, -0.4342])
    grad: 3.0 6.0 tensor([1.4389, 0.4796, 0.1599])
progress: 96 0.006390606984496117
    grad: 1.0 2.0 tensor([0.3161, 0.3161, 0.3161])
    grad: 2.0 4.0 tensor([-1.7343, -0.8671, -0.4336])
    grad: 3.0 6.0 tensor([1.4361, 0.4787, 0.1596])
progress: 97 0.0063657015562057495
    grad: 1.0 2.0 tensor([0.3164, 0.3164, 0.3164])
    grad: 2.0 4.0 tensor([-1.7320, -0.8660, -0.4330])
    grad: 3.0 6.0 tensor([1.4334, 0.4778, 0.1593])
progress: 98 0.0063416799530386925
    grad: 1.0 2.0 tensor([0.3166, 0.3166, 0.3166])
    grad: 2.0 4.0 tensor([-1.7297, -0.8649, -0.4324])
    grad: 3.0 6.0 tensor([1.4308, 0.4769, 0.1590])
progress: 99 0.00631808303296566
predict (after training) 4 8.544171333312988
As the figure below shows, the loss trends downward as the number of epochs increases:
[Figure: Loss vs. Epoch]
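Note that the value printed after "progress:" (and plotted) is `l` for the last sample of each epoch (x = 3), computed before that sample's update, not the MSE over all three points. The sketch below tracks both side by side. It is a plain-Python re-implementation with hand-derived gradients (so it runs without PyTorch); `mse` and `train` are illustrative names, and the data, initial values and learning rate are the article's.

```python
# Compare the logged per-sample loss with the full-data MSE after each epoch.
x_data = [1.0, 2.0, 3.0]
y_data = [2.0, 4.0, 6.0]


def forward(w, x):
    """Quadratic model y_hat = w1*x^2 + w2*x + b, with w = [w1, w2, b]."""
    return w[0] * x ** 2 + w[1] * x + w[2]


def mse(w):
    """Mean squared error of the model over the whole dataset."""
    return sum((forward(w, x) - y) ** 2 for x, y in zip(x_data, y_data)) / len(x_data)


def train(epochs, lr=0.01):
    w = [1.0, 1.0, 1.0]
    logged, full = [], []
    for _ in range(epochs):
        for x, y in zip(x_data, y_data):
            r = forward(w, x) - y              # residual y_hat - y
            last = r ** 2                      # same quantity as l in the article
            w = [w[0] - lr * 2 * r * x ** 2,   # SGD step with gradient 2r*[x^2, x, 1]
                 w[1] - lr * 2 * r * x,
                 w[2] - lr * 2 * r]
        logged.append(last)   # what "progress:" prints for this epoch
        full.append(mse(w))   # MSE after the epoch's three updates
    return logged, full


logged, full = train(100)
for epoch in (0, 1, 50, 99):
    print(f"epoch {epoch:2d}: logged {logged[epoch]:.6f}   full-data MSE {full[epoch]:.6f}")
```

If the goal is to monitor the MSE named in the exercise, computing it over the whole dataset once per epoch (as `mse` does here) is the more faithful measurement; the per-sample value mostly reflects how well the model fits x = 3.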
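As the output shows, the prediction at x = 4 after training is about 8.54, still some way from the true value of 8 (the data follow y = 2x). The gap can be narrowed by training for more epochs, or by adjusting the learning rate, the initial parameter values, and so on.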
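The suggestion to train longer can be checked with a small experiment. The three points lie on y = 2x, and a quadratic through three points with distinct x is unique, so the exact fit is w₁ = 0, w₂ = 2, b = 0, which predicts 8 at x = 4. The sketch below repeats the same per-sample updates in plain Python with hand-derived gradients for several epoch counts; `W_STAR`, `train` and the chosen epoch counts are illustrative assumptions, not part of the original article.

```python
# Run the same per-sample SGD for more epochs and watch w approach the exact fit.
x_data = [1.0, 2.0, 3.0]
y_data = [2.0, 4.0, 6.0]
W_STAR = (0.0, 2.0, 0.0)  # exact solution: y = 2x passes through all three points


def forward(w, x):
    """Quadratic model y_hat = w1*x^2 + w2*x + b, with w = [w1, w2, b]."""
    return w[0] * x ** 2 + w[1] * x + w[2]


def dist_to_star(w):
    """Euclidean distance between w and the exact solution."""
    return sum((wi - si) ** 2 for wi, si in zip(w, W_STAR)) ** 0.5


def train(epochs, lr=0.01, w0=(1.0, 1.0, 1.0)):
    w = list(w0)
    dists = [dist_to_star(w)]  # distance to W_STAR after every single update
    for _ in range(epochs):
        for x, y in zip(x_data, y_data):
            r = forward(w, x) - y  # residual y_hat - y
            w = [w[0] - lr * 2 * r * x ** 2, w[1] - lr * 2 * r * x, w[2] - lr * 2 * r]
            dists.append(dist_to_star(w))
    return w, dists


for epochs in (100, 1000, 10000):
    w, _ = train(epochs)
    print(f"{epochs:>6} epochs: w = {[round(v, 4) for v in w]}, predict(4) = {forward(w, 4):.4f}")
```

A larger learning rate is not automatically better here. One update on a single sample multiplies that sample's residual by 1 − 2α‖a‖², where a = (x², x, 1). For x = 3, ‖a‖² = 91, so at α = 0.01 the factor is already −0.82, and for α > 1/91 ≈ 0.011 it falls below −1, meaning that step overshoots by more than it corrects.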