Using a GPU with PyTorch
Environment setup
The NVIDIA driver and CUDA are usually installed already; if not, check the official NVIDIA site for instructions.
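If you are not sure whether they are present, a quick check (assuming the NVIDIA driver is installed and on the PATH) is:

nvidia-smi

which prints the driver version, the highest CUDA version it supports, and the detected GPUs.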
Next, install Anaconda and create a Python 3.8 environment in it (called py38 here).
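A minimal sketch of creating that environment from a terminal:

conda create -n py38 python=3.8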
Finally, unlike other packages, PyTorch should not be installed from inside PyCharm; instead, pick the build matching your setup on this page and install from there: https://pytorch.org/get-started/locally/
In my case: first run conda activate py38, then install with the pip command the page generates.
Selecting the conda option should work as well.
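For reference, the pip command the page generates for a CUDA build looks like this (the CUDA tag cu118 here is just an example; use whatever the selector gives you):

pip3 install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu118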
Output:
One GPU.
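With a single CUDA device visible, the two checks at the top of the script below print something like:

cuda is available:True
gpu:1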
Code:
import os

import numpy as np
import pandas as pd
import matplotlib
import matplotlib.pyplot as plt
import torch
import torch.nn.functional as F
from torch import nn
from torch.autograd import Variable
from torchmetrics import R2Score, MeanAbsoluteError, MeanSquaredError

os.environ['KMP_DUPLICATE_LIB_OK'] = 'True'

print("cuda is available:" + str(torch.cuda.is_available()))
print("gpu:" + str(torch.cuda.device_count()))
# use the GPU if one is available, otherwise fall back to the CPU
device = "cuda:0" if torch.cuda.is_available() else "cpu"

print(matplotlib.matplotlib_fname())  # show which matplotlibrc is in use
matplotlib.rcParams['font.sans-serif'] = ['SimHei']  # font that can render CJK glyphs
matplotlib.rcParams['axes.unicode_minus'] = False    # keep '-' from rendering as a box in saved figures

# data_csv = pd.read_csv('data/rnn.spider.csv', usecols=[2])
data_csv = pd.read_csv('data/rnn_spider_9.csv', usecols=[2])
plt.plot(data_csv)
# plt.show()

# Preprocessing: drop NaN rows, then min-max normalize to [0, 1]
data_csv = data_csv.dropna()
dataset = data_csv.values
olddataset = dataset.astype('float32')
max_value = np.max(dataset)
min_value = np.min(dataset)
scalar = max_value - min_value
dataset = list(map(lambda x: (x - min_value) / scalar, olddataset))
# dataset = list(map(lambda x: x / scalar, olddataset))

# X: each window of two consecutive values forms one sample, e.g. [284, 1425]
#    and [1425, 2674] become [0.1, 0.5] after normalization.
# Y: the value right after each window, e.g. [2674] becomes [1.0].
def create_dataset(dataset, look_back=2):
    dataX, dataY = [], []
    for i in range(len(dataset) - look_back):
        a = dataset[i:(i + look_back)]
        dataX.append(a)
        dataY.append(dataset[i + look_back])
    return np.array(dataX), np.array(dataY)

# Build the inputs and targets
data_X, data_Y = create_dataset(dataset)

# Split into training and test sets, 70% for training
train_size = int(len(data_X) * 0.7)
test_size = len(data_X) - train_size
train_X = data_X[:train_size]
train_Y = data_Y[:train_size]
test_X = data_X[train_size:]
test_Y = data_Y[train_size:]

train_max_value = np.max(train_X)
train_min_value = np.min(train_X)

# reshape to 3-D (seq, batch, feature): first axis is the sequence of samples,
# second is the batch size (1 here), third is the feature size (look_back = 2)
train_X = train_X.reshape(-1, 1, 2)
train_Y = train_Y.reshape(-1, 1, 1)
test_X = test_X.reshape(-1, 1, 2)

train_x = torch.from_numpy(train_X).to(device)
train_y = torch.from_numpy(train_Y).to(device)
test_x = torch.from_numpy(test_X).to(device)

# Define the model
class lstm_reg(nn.Module):
    def __init__(self, input_size, hidden_size, output_size=1, num_layers=2):
        super(lstm_reg, self).__init__()
        self.rnn = nn.LSTM(input_size, hidden_size, num_layers)  # rnn
        self.reg = nn.Linear(hidden_size, output_size)           # regression head

    def forward(self, x):
        x, _ = self.rnn(x)  # (seq, batch, hidden)
        s, b, h = x.shape
        x = x.view(s * b, h)  # flatten for the linear layer
        x = self.reg(x)
        x = x.view(s, b, -1)
        return x

# input size 2 (the look-back window), 4 hidden units
net = lstm_reg(2, 4).to(device)
criterion = nn.MSELoss()
# learning rate 0.01
optimizer = torch.optim.Adam(net.parameters(), lr=0.01)
print(net)

# Training loop
for e in range(1000):
    var_x = Variable(train_x)
    var_y = Variable(train_y)
    # forward pass
    out = net(var_x)
    loss = criterion(out, var_y)
    # backward pass
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    if (e + 1) % 100 == 0:  # report every 100 epochs
        print('Epoch: {}, Loss: {:.5f}'.format(e + 1, loss.data))

net = net.eval()  # switch to evaluation mode
data_X = data_X.reshape(-1, 1, 2)
data_X = torch.from_numpy(data_X).to(device)
var_data = Variable(data_X)
pred_test = net(var_data)  # predictions over the whole series
if device != "cpu":
    pred_test = pred_test.cpu()
# flatten the output back to a 1-D numpy array
pred_test = pred_test.view(-1).data.numpy()

# print_pred_test = net(torch.tensor([[0.8, 0.6]]).unsqueeze(0).float())
# print(print_pred_test)
# print(print_pred_test * (max_value - min_value))
# print(olddataset * (max_value - min_value))

# Plot the real series against the predictions
plt.plot(pred_test, 'r', label='predicted')
plt.plot(dataset, 'b', label='actual')
plt.legend(loc='best')
plt.show()

# Compute the R2 metric
r2 = R2Score()
# flatten the nested normalized series into a plain list
realData = []
for sublist in dataset:
    for element in sublist:
        realData.append(element)
# each prediction targets dataset[i + look_back], so align the ground truth
# by skipping the first look_back (= 2) values
r2(torch.as_tensor(np.array(pred_test)), torch.as_tensor(np.array(realData[2:])))
print("R2 score:", r2.compute())

mae = MeanAbsoluteError()
mae(torch.as_tensor(np.array(pred_test)), torch.as_tensor(np.array(realData[2:])))
print("mae score:", mae.compute())

mse = MeanSquaredError()
mse(torch.as_tensor(np.array(pred_test)), torch.as_tensor(np.array(realData[2:])))
print("mse score:", mse.compute())
print("rmse score:", torch.sqrt(mse.compute()))

'''
# alternative MAE via torch.nn.functional (the shapes would need aligning first)
mae = F.l1_loss(torch.tensor(pred_test), torch.tensor(dataset))
print(mae)
'''
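The predictions and metrics above are on the normalized [0, 1] scale. A minimal sketch of mapping predictions back to the original units, reusing the min_value and scalar computed during preprocessing (this mirrors the commented-out rescaling lines in the script):

# invert the min-max normalization x_norm = (x - min_value) / scalar
pred_original = pred_test * scalar + min_value
print(pred_original[:5])  # first few predictions in the original units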
Data
rnn_spider_9.csv
id,fileNum,totalTime/fileNum 683,1409,0.5862 682,865,0.6069 681,474,0.6371 678,482,0.5830 677,1385,0.7487 676,142,0.6127 675,795,0.6226 674,156,0.5577 669,61,0.5738 668,1139,0.6084 667,1346,0.6293 666,365,0.5890 653,56,0.6964 652,567,0.6155 651,624,0.6122 650,468,0.5876 649,670,0.6463 648,538,0.6041 645,257,0.5642 644,928,0.5636 643,169,0.6509 642,1078,0.5946 641,289,0.5952 640,167,0.5988 638,427,0.5738 637,669,0.5800 636,488,0.6434 635,1099,0.6497 634,930,0.5634 633,1401,0.6089 632,319,0.6395 631,406,0.7266 630,91,0.6264 629,23,0.7391 628,86,0.7209 625,23,1.4348 624,934,0.9550 623,180,0.6333 622,642,0.6495 621,717,0.6388 620,246,0.5854 619,182,0.6538 618,66,0.6212 613,149,0.5973 612,417,0.5779 611,1228,0.6466 610,499,0.6453 609,712,0.6236 598,881,1.1975 597,1123,0.6438 596,932,0.6148 591,600,0.6233 590,1062,0.6507 589,714,0.6148 588,667,1.2309 583,182,1.4505 582,799,0.6671 581,499,0.6814 580,530,0.6377 579,848,0.7028 572,480,0.6667 571,1288,0.6382 570,808,0.6238 569,183,0.6393 568,838,0.6337 564,66,0.6061 563,669,0.6368 562,371,0.6792 561,68,0.5735 560,363,0.6556 559,968,0.6312 558,470,0.6447 557,672,0.6429 554,326,0.6472 553,178,0.6404 552,916,0.6354 551,155,0.7548 550,216,0.6250 549,632,0.6392 548,518,0.6622 541,33,0.6364 540,766,0.6397 539,99,0.6162 538,1413,0.6369 537,665,0.6436 527,661,0.6399 526,1496,0.6310 525,428,0.6285 524,403,0.6551 519,57,0.6667 518,1397,0.6464 517,259,0.6873 516,136,0.6838 515,35,0.6286 514,597,0.6683 513,27,0.7037 509,703,0.6358 508,112,0.6071 507,812,0.5924 506,1362,0.5675 505,544,0.5993 504,1347,0.6288 503,989,0.6643 499,886,0.6716 498,162,0.7716 497,1120,0.6625 493,613,0.7015 487,383,0.6240 486,1390,0.6417 485,119,0.6303 484,879,0.6667 483,298,0.6477 482,385,0.6494 480,92,0.6522 479,1193,0.6438 478,1075,0.6335 477,775,0.6284 476,489,0.6380 463,206,0.6359 462,1316,0.6071 461,247,0.5870 460,36,0.5556 459,342,0.6374 458,56,0.6786 457,497,0.6318 456,329,0.6383 446,847,0.6198 445,1184,0.6453 444,304,0.6875 443,327,0.5688 442,311,0.5145 436,158,0.5443 435,493,0.5659 434,1462,0.6005 433,944,0.6356 432,687,0.6317 431,427,0.6745 430,579,0.6269 429,32,0.6250 428,479,0.7035 422,632,0.6598 421,1384,0.6575 420,94,0.6489 419,777,0.6499 414,605,0.7025 413,1148,0.6516 412,435,0.6437 411,674,0.6899 407,732,0.6325 406,1456,0.6518 405,221,0.6968 404,545,0.6312 402,381,0.6220 401,1430,0.6245 400,338,0.6243 399,455,0.6813 398,355,0.6620 393,978,0.6196 392,561,0.6132 391,163,0.5951 390,893,0.5980 389,375,0.6453 388,90,0.5889 383,220,0.6682 382,689,0.6488 381,970,0.6227 380,1058,0.6172 376,788,0.6345 375,333,0.5796 374,52,0.6731 373,741,0.6478 372,357,0.6022 371,670,0.6373 365,851,0.6428 364,976,0.8145 363,852,0.6631 361,387,0.6098 360,274,0.6460 355,878,0.6378 354,825,0.6255 353,897,0.6410 352,407,0.6536 351,1193,0.6203 350,581,0.6368 349,433,0.6005 348,229,0.5764 344,261,0.6322 343,1076,0.6747 342,11,8.8182 341,1262,0.6458 340,892,0.6670 339,444,0.6486 338,810,0.6802 337,544,0.6507 336,393,0.6616 335,98,0.7347 330,1140,0.6623 329,1089,0.6804 328,266,0.6579 327,720,0.6306 326,320,0.6688 325,1291,0.6414 324,268,0.7127 323,485,0.6165 322,273,0.6374 316,193,0.5389 315,493,0.6085 314,918,0.6133 313,1224,0.6242 312,116,0.6293 310,823,0.7400 309,60,0.6500 308,167,0.7425 307,190,3.7895 306,141,0.7021 305,362,0.7182 304,93,0.6774 302,98,0.7755 301,333,0.7327 300,412,0.7112 299,382,0.7487 298,636,0.7327 297,933,0.7256 296,149,0.7517 292,1118,0.7504 291,1183,1.5097 290,344,0.7587 289,217,0.8203 279,620,0.7613 278,450,0.7156 277,1253,0.7007 276,722,0.6856 271,820,0.7085 270,967,0.7787 
269,1120,0.7429 268,127,0.7480 261,388,0.7964 260,997,0.7593 259,960,0.8063 258,573,0.7243 256,179,0.6927 255,1472,0.7310 254,757,0.7596 253,503,0.7614 249,805,0.7565 248,739,0.7442 247,1323,0.7385 238,1121,0.7101 237,1066,0.7242 236,137,0.6788 235,625,0.7728 229,291,0.7320 228,109,0.7615 227,824,0.6917 226,1115,0.7031 225,650,0.7215 224,521,0.7179 223,1131,0.6932 222,562,0.6726 217,860,0.6756 216,1386,0.6710 215,800,0.6738 208,451,0.7095 207,651,0.6390 206,330,0.6394 205,391,0.6726 204,320,1.5750 203,800,0.6475 197,277,0.6354 196,305,0.6230 195,337,0.6202 194,558,1.1219 193,1008,0.6250 192,125,0.7200 188,566,0.6484 187,481,0.6258 186,335,0.6388 185,1231,0.6255 184,1027,0.6475 183,505,0.6277 182,754,0.6737 181,820,0.6134 180,325,0.6431 179,643,0.6174 173,293,0.6348 162,612,1.1225 161,504,1.1905 160,1554,1.1782 159,1661,1.0108 158,266,0.8609 157,1155,0.8970 156,337,0.9644 155,1681,0.9590 154,1827,0.8752 153,1257,0.8815 152,1452,0.8333 151,936,0.8665 150,1159,0.9094 149,1616,0.9462 148,581,0.8090 147,965,0.8435 146,671,1.0313 145,1079,1.0074 144,67,1.0597 142,116,1.0862 141,1084,1.3782 140,332,1.0693 139,1406,1.0462 138,1866,1.0263 137,555,1.0198 136,1551,1.0393 135,1395,1.0975 134,500,0.9920 133,1143,1.0647 132,232,1.1422 131,764,1.0681 130,618,1.1165 129,755,1.0477 128,258,1.1395 127,1319,1.0447 126,464,1.2823 124,254,1.2362 123,1296,1.1566 122,1077,1.2228 121,291,1.3299 119,1031,1.1736 117,1849,1.4462 116,1690,1.1438 115,1254,1.0654 114,1418,1.0755 113,1052,1.2376 112,702,1.1339