From 1e1b6d3291aaf7b63be478bc50a441dff66f7a88 Mon Sep 17 00:00:00 2001 From: Harold-Ran <56714856+Harold-Ran@users.noreply.github.com> Date: Sat, 4 Dec 2021 20:39:47 +0800 Subject: [PATCH] =?UTF-8?q?Delete=20Task5=20=E6=A8=A1=E5=9E=8B=E5=BB=BA?= =?UTF-8?q?=E7=AB=8B=E4=B9=8BSA-ConvLSTM.ipynb?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- Task5/Task5 模型建立之SA-ConvLSTM.ipynb | 1818 ----------------------- 1 file changed, 1818 deletions(-) delete mode 100644 Task5/Task5 模型建立之SA-ConvLSTM.ipynb diff --git a/Task5/Task5 模型建立之SA-ConvLSTM.ipynb b/Task5/Task5 模型建立之SA-ConvLSTM.ipynb deleted file mode 100644 index 0117872..0000000 --- a/Task5/Task5 模型建立之SA-ConvLSTM.ipynb +++ /dev/null @@ -1,1818 +0,0 @@ -{ - "cells": [ - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "# Datawhale 气象海洋预测-Task5 模型建立之 SA-ConvLSTM\n", - "\n", - "本次任务我们将学习来自TOP选手“吴先生的队伍”的建模方案,该方案中采用的模型是SA-ConvLSTM。\n", - "\n", - "前两个TOP方案中选择把赛题当作一个多输出的任务,通过构建神经网络直接输出24个nino3.4预测值,这种方案的问题在于,序列问题往往是时序依赖的,当我们采用多输出的方法时其实把这24个nino3.4预测值看作是完全独立的,但是实际上它们之间是存在序列依赖的,即每个预测值往往受上一个时间步的预测值的影响。因此,在这次的TOP方案中,采用Seq2Seq结构来考虑输出预测值的序列依赖性。\n", - "\n", - "Seq2Seq结构包括Encoder(编码器)和Decoder(解码器)两部分,Encoder部分将输入序列编码成一个向量,Decoder部分对向量进行解码,输出一个预测序列。要将Seq2Seq结构应用于不同的序列问题,关键在于每一个时间步所使用的Cell。我们之前说到,挖掘空间信息通常会采用CNN,挖掘时间信息通常会采用RNN或LSTM,将二者结合在一起构成一个Cell就得到了时空序列领域的经典模型——ConvLSTM,我们本次要学习的TOP方案中使用的SA-ConvLSTM模型是对ConvLSTM模型的改进,在其基础上增加了自注意力机制来提高模型对于长期空间依赖关系的挖掘能力。\n", - "\n", - "另外与前两个TOP方案所不同的一点是,该TOP方案并不直接预测未来24个月的nino3.4指数,而是预测未来26个月的sst值,nino3.4指数由当前时刻起连续三个月的sst值求平均得到。" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "## 学习目标\n", - "1. 学习TOP方案的模型构建方法" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "## 内容介绍\n", - "1. 数据处理\n", - " - 数据扁平化\n", - " - 空值填充\n", - " - 构造数据集\n", - "2. 模型构建\n", - " - 构造评估函数\n", - " - 模型构造\n", - " - 模型训练\n", - " - 模型评估\n", - "3. 
总结" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "## 代码示例" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "### 数据处理\n", - "该TOP方案的数据处理主要包括三部分:\n", - "1. 数据扁平化。\n", - "2. 空值填充。\n", - "3. 构造数据集" - ] - }, - { - "cell_type": "code", - "execution_count": 1, - "metadata": { - "execution": { - "iopub.execute_input": "2021-11-29T03:04:34.698663Z", - "iopub.status.busy": "2021-11-29T03:04:34.697133Z", - "iopub.status.idle": "2021-11-29T03:04:37.035400Z", - "shell.execute_reply": "2021-11-29T03:04:37.034767Z", - "shell.execute_reply.started": "2021-11-29T01:02:51.883602Z" - }, - "papermill": { - "duration": 2.370278, - "end_time": "2021-11-29T03:04:37.035673", - "exception": false, - "start_time": "2021-11-29T03:04:34.665395", - "status": "completed" - }, - "tags": [] - }, - "outputs": [], - "source": [ - "import netCDF4 as nc\n", - "import random\n", - "import os\n", - "from tqdm import tqdm\n", - "import pandas as pd\n", - "import numpy as np\n", - "import math\n", - "import matplotlib.pyplot as plt\n", - "%matplotlib inline\n", - "\n", - "import torch\n", - "from torch import nn, optim\n", - "import torch.nn.functional as F\n", - "from torch.utils.data import Dataset, DataLoader\n", - "from torch.optim.lr_scheduler import ReduceLROnPlateau\n", - "\n", - "from sklearn.metrics import mean_squared_error" - ] - }, - { - "cell_type": "code", - "execution_count": 2, - "metadata": { - "execution": { - "iopub.execute_input": "2021-11-29T03:04:37.102995Z", - "iopub.status.busy": "2021-11-29T03:04:37.102144Z", - "iopub.status.idle": "2021-11-29T03:04:37.107646Z", - "shell.execute_reply": "2021-11-29T03:04:37.107161Z", - "shell.execute_reply.started": "2021-11-29T01:02:54.06493Z" - }, - "papermill": { - "duration": 0.040737, - "end_time": "2021-11-29T03:04:37.107761", - "exception": false, - "start_time": "2021-11-29T03:04:37.067024", - "status": "completed" - }, - "tags": [] - }, - "outputs": [], - "source": [ - "# 固定随机种子\n", 
# Fix the random seeds so that sampling and training are repeatable.
SEED = 22


def seed_everything(seed=42):
    """Seed every random number generator used in this notebook.

    Seeds Python's ``random`` module, NumPy's legacy global generator and
    PyTorch (CPU and all CUDA devices), exports ``PYTHONHASHSEED`` and puts
    cuDNN in deterministic mode.

    Args:
        seed: Integer seed applied to every generator.
    """
    random.seed(seed)
    # Only affects hash randomisation of Python processes started after this
    # point; exported so that sub-processes inherit it.
    os.environ['PYTHONHASHSEED'] = str(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    # Seed every visible GPU, not only the current one. Silently ignored when
    # CUDA is unavailable.
    torch.cuda.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True
    # The cuDNN auto-tuner may pick a different convolution algorithm on each
    # run; switch it off explicitly so determinism does not depend on what an
    # earlier cell left behind.
    torch.backends.cudnn.benchmark = False


seed_everything(SEED)
def make_flatted(train_ds, label_ds, info, start_idx=0):
    """Flatten per-year samples into one continuous monthly series per model.

    Only the ``sst`` feature is used, and only longitude indices 19..66 are
    kept (the notebook text describes this as the lon range [90, 330]).

    Args:
        train_ds: Mapping from feature name to an array indexed as
            (sample, month, lat, lon). ``.data`` is read from the sliced
            result, so it is presumably a masked array as returned by
            netCDF4 -- TODO confirm.
        label_ds: Mapping holding the ``nino`` labels, indexed (sample, month).
        info: Tuple whose element 1 is the number of years per model and
            element 2 is the number of models.
        start_idx: Offset of the first sample belonging to this data source.

    Returns:
        ``(train_list, label_list)`` with one entry per model. Each feature
        entry has shape (years * 12, 24, 48, n_features); each label entry has
        length years * 12 + 12.
    """
    feature_names = ['sst']
    target_name = 'nino'
    n_years = info[1]
    n_models = info[2]

    features_per_model = []
    labels_per_model = []

    for m in range(n_models):
        # Sample rows [lo, hi) belong to model m.
        lo = start_idx + m * n_years
        hi = start_idx + (m + 1) * n_years

        # First 12 months of every sample, stacked along time; a trailing
        # feature axis is added so several features could be concatenated.
        per_feature = [
            train_ds[name][lo:hi, :12, :, 19:67].reshape(-1, 24, 48, 1).data
            for name in feature_names
        ]
        features_per_model.append(np.concatenate(per_feature, axis=-1))

        # Months 12-23 of every sample, followed by months 24-35 of the last
        # sample so that the final input year still has a full target.
        yearly_part = label_ds[target_name][lo:hi, 12:24].reshape(-1).data
        tail_part = label_ds[target_name][hi - 1, 24:36].reshape(-1).data
        labels_per_model.append(np.concatenate([yearly_part, tail_part], axis=0))

    return features_per_model, labels_per_model
# ---------------------------------------------------------------------------
# Flatten the three data sources.
# Each info tuple is (name, years per model, number of models).
# ---------------------------------------------------------------------------
soda_info = ('soda', 100, 1)
cmip6_info = ('cmip6', 151, 15)
cmip5_info = ('cmip5', 140, 17)

soda_trains, soda_labels = make_flatted(soda_train, soda_label, soda_info)
cmip6_trains, cmip6_labels = make_flatted(cmip_train, cmip_label, cmip6_info)
# CMIP5 samples are stored after all CMIP6 samples in the same file.
cmip5_trains, cmip5_labels = make_flatted(cmip_train, cmip_label, cmip5_info, cmip6_info[1]*cmip6_info[2])

# Flattened shape is (models, sequence length, lat, lon, features),
# where sequence length = years * 12.
np.shape(soda_trains), np.shape(cmip6_trains), np.shape(cmip5_trains)

# ---------------------------------------------------------------------------
# Fill missing values with 0.
# ---------------------------------------------------------------------------

# SODA
soda_trains = np.array(soda_trains)
soda_trains_nan = np.isnan(soda_trains)
soda_trains[soda_trains_nan] = 0
print('Number of null in soda_trains after fillna:', np.sum(np.isnan(soda_trains)))

# CMIP6
cmip6_trains = np.array(cmip6_trains)
cmip6_trains_nan = np.isnan(cmip6_trains)
cmip6_trains[cmip6_trains_nan] = 0
print('Number of null in cmip6_trains after fillna:', np.sum(np.isnan(cmip6_trains)))

# CMIP5 (the message previously said "cmip6_trains" by copy-paste mistake)
cmip5_trains = np.array(cmip5_trains)
cmip5_trains_nan = np.isnan(cmip5_trains)
cmip5_trains[cmip5_trains_nan] = 0
print('Number of null in cmip5_trains after fillna:', np.sum(np.isnan(cmip5_trains)))

# ---------------------------------------------------------------------------
# Build the training and validation sets by sampling sliding windows.
# ---------------------------------------------------------------------------


def sample_windows(features, labels, n_models, n_samples=100, window=38, horizon=24):
    """Draw random sliding windows from every model of one data source.

    Args:
        features: Array indexed as (model, time, ...).
        labels: Per-model sequence of label series, indexed ``labels[model]``.
        n_models: Number of models to draw from (the first ``n_models``).
        n_samples: Windows drawn per model (with replacement).
        window: Input window length in months. 38 = 12 input months + 26
            target sst months, which are included because training uses
            teacher forcing.
        horizon: Number of label months kept per window.

    Returns:
        ``(xs, ys)`` lists of equal length ``n_models * n_samples``.
    """
    xs, ys = [], []
    for model_i in range(n_models):
        # NOTE: np.random.choice(k) samples from range(k), so the largest
        # start index is time - window - 1 and the very last valid window is
        # never drawn. Kept as is so the sampled data stay unchanged.
        starts = np.random.choice(features.shape[1] - window, size=n_samples)
        for ind in starts:
            xs.append(features[model_i, ind: ind + window])
            ys.append(labels[model_i][ind: ind + horizon])
    return xs, ys


# Training set: 100 windows from each of the 17 CMIP5 models, then 100 from
# each of the 15 CMIP6 models. The call order matters: it fixes the order in
# which the global NumPy generator is consumed.
cmip5_X, cmip5_y = sample_windows(cmip5_trains, cmip5_labels, 17)
cmip6_X, cmip6_y = sample_windows(cmip6_trains, cmip6_labels, 15)
X_train = np.array(cmip5_X + cmip6_X)
y_train = np.array(cmip5_y + cmip6_y)

# Validation set: 100 windows from the single SODA series.
soda_X, soda_y = sample_windows(soda_trains, soda_labels, 1)
X_valid = np.array(soda_X)
y_valid = np.array(soda_y)
"2021-11-29T03:05:12.605555Z", - "iopub.status.idle": "2021-11-29T03:05:12.611580Z", - "shell.execute_reply": "2021-11-29T03:05:12.611152Z", - "shell.execute_reply.started": "2021-11-27T13:39:27.247585Z" - }, - "papermill": { - "duration": 0.036214, - "end_time": "2021-11-29T03:05:12.611721", - "exception": false, - "start_time": "2021-11-29T03:05:12.575507", - "status": "completed" - }, - "tags": [] - }, - "outputs": [ - { - "data": { - "text/plain": [ - "((3200, 38, 24, 48, 1), (3200, 24), (100, 38, 24, 48, 1), (100, 24))" - ] - }, - "execution_count": 13, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "# 查看数据集维度\n", - "X_train.shape, y_train.shape, X_valid.shape, y_valid.shape" - ] - }, - { - "cell_type": "code", - "execution_count": 15, - "metadata": { - "execution": { - "iopub.execute_input": "2021-11-29T03:05:12.737322Z", - "iopub.status.busy": "2021-11-29T03:05:12.736558Z", - "iopub.status.idle": "2021-11-29T03:05:13.516712Z", - "shell.execute_reply": "2021-11-29T03:05:13.517217Z", - "shell.execute_reply.started": "2021-11-27T13:39:38.421657Z" - }, - "papermill": { - "duration": 0.812187, - "end_time": "2021-11-29T03:05:13.517368", - "exception": false, - "start_time": "2021-11-29T03:05:12.705181", - "status": "completed" - }, - "tags": [] - }, - "outputs": [], - "source": [ - "# 保存数据集\n", - "np.save('X_train_sample.npy', X_train)\n", - "np.save('y_train_sample.npy', y_train)\n", - "np.save('X_valid_sample.npy', X_valid)\n", - "np.save('y_valid_sample.npy', y_valid)" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "### 模型构建" - ] - }, - { - "cell_type": "code", - "execution_count": 16, - "metadata": { - "execution": { - "iopub.execute_input": "2021-11-29T03:05:13.577516Z", - "iopub.status.busy": "2021-11-29T03:05:13.576992Z", - "iopub.status.idle": "2021-11-29T03:05:21.917657Z", - "shell.execute_reply": "2021-11-29T03:05:21.918265Z", - "shell.execute_reply.started": "2021-11-29T01:03:01.505192Z" - }, - 
"papermill": { - "duration": 8.372964, - "end_time": "2021-11-29T03:05:21.918443", - "exception": false, - "start_time": "2021-11-29T03:05:13.545479", - "status": "completed" - }, - "tags": [] - }, - "outputs": [], - "source": [ - "# 读取数据集\n", - "X_train = np.load('../input/ai-earth-task05-samples/X_train_sample.npy')\n", - "y_train = np.load('../input/ai-earth-task05-samples/y_train_sample.npy')\n", - "X_valid = np.load('../input/ai-earth-task05-samples/X_valid_sample.npy')\n", - "y_valid = np.load('../input/ai-earth-task05-samples/y_valid_sample.npy')" - ] - }, - { - "cell_type": "code", - "execution_count": 17, - "metadata": { - "execution": { - "iopub.execute_input": "2021-11-29T03:05:21.983898Z", - "iopub.status.busy": "2021-11-29T03:05:21.982953Z", - "iopub.status.idle": "2021-11-29T03:05:21.986939Z", - "shell.execute_reply": "2021-11-29T03:05:21.986453Z", - "shell.execute_reply.started": "2021-11-29T01:03:11.548945Z" - }, - "papermill": { - "duration": 0.039398, - "end_time": "2021-11-29T03:05:21.987066", - "exception": false, - "start_time": "2021-11-29T03:05:21.947668", - "status": "completed" - }, - "tags": [] - }, - "outputs": [ - { - "data": { - "text/plain": [ - "((3200, 38, 24, 48, 1), (3200, 24), (100, 38, 24, 48, 1), (100, 24))" - ] - }, - "execution_count": 17, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "X_train.shape, y_train.shape, X_valid.shape, y_valid.shape" - ] - }, - { - "cell_type": "code", - "execution_count": 18, - "metadata": { - "execution": { - "iopub.execute_input": "2021-11-29T03:05:22.341929Z", - "iopub.status.busy": "2021-11-29T03:05:22.340932Z", - "iopub.status.idle": "2021-11-29T03:05:22.346140Z", - "shell.execute_reply": "2021-11-29T03:05:22.346878Z", - "shell.execute_reply.started": "2021-11-29T01:03:11.560457Z" - }, - "papermill": { - "duration": 0.143838, - "end_time": "2021-11-29T03:05:22.347113", - "exception": false, - "start_time": "2021-11-29T03:05:22.203275", - "status": "completed" - 
# Data pipeline
class AIEarthDataset(Dataset):
    """In-memory dataset pairing each input window with its label vector.

    Both arrays are converted to ``float32`` tensors once, at construction.

    Args:
        data: Array-like of input samples; the first axis indexes samples.
        label: Array-like of targets; the first axis indexes samples.

    Raises:
        ValueError: If ``data`` and ``label`` hold a different number of
            samples.
    """

    def __init__(self, data, label):
        self.data = torch.tensor(data, dtype=torch.float32)
        self.label = torch.tensor(label, dtype=torch.float32)
        # __len__ follows the labels, so a mismatch would otherwise surface
        # only later as an IndexError or as silently ignored samples.
        if len(self.data) != len(self.label):
            raise ValueError(
                'data and label must contain the same number of samples, '
                'got {} and {}'.format(len(self.data), len(self.label))
            )

    def __len__(self):
        """Return the number of samples."""
        return len(self.label)

    def __getitem__(self, idx):
        """Return the ``(data, label)`` pair stored at position ``idx``."""
        return self.data[idx], self.label[idx]
def rmse(y_true, y_preds):
    """Return the root mean squared error between two equally shaped arrays.

    Computed with NumPy as the square root of the mean of the squared
    differences over all elements.
    """
    diff = np.asarray(y_true, dtype=float) - np.asarray(y_preds, dtype=float)
    return np.sqrt(np.mean(diff ** 2))


# Evaluation function
def score(y_true, y_preds):
    """Score 24-month predictions: 2/3 * accuracy skill - summed RMSE.

    For each of the first 24 columns (lead months) the Pearson correlation
    between truth and prediction is weighted by ``a[i] * ln(i + 1)`` and
    summed; the per-column RMSEs are summed and subtracted.

    Args:
        y_true: Array of shape (n_samples, >=24) with the true values.
        y_preds: Array of shape (n_samples, >=24) with the predictions.

    Returns:
        The scalar score (higher is better).

    Raises:
        IndexError: If either input has fewer than 24 columns.
    """
    n_leads = 24
    y_true = np.asarray(y_true, dtype=float)
    y_preds = np.asarray(y_preds, dtype=float)
    if y_true.shape[1] < n_leads or y_preds.shape[1] < n_leads:
        raise IndexError('score expects at least 24 lead-time columns')

    # Lead-time weights: later months count more.
    a = np.array([1.5] * 4 + [2] * 7 + [3] * 7 + [4] * 6)
    log_lead = np.log(np.arange(1, n_leads + 1))

    # Column-wise Pearson correlation (correlation skill).
    true_centered = y_true[:, :n_leads] - np.mean(y_true[:, :n_leads], axis=0)
    pred_centered = y_preds[:, :n_leads] - np.mean(y_preds[:, :n_leads], axis=0)
    fenzi = np.sum(true_centered * pred_centered, axis=0)
    fenmu = np.sqrt(np.sum(true_centered ** 2, axis=0) * np.sum(pred_centered ** 2, axis=0))
    # A constant column has zero variance and an undefined correlation. Count
    # it as 0 instead of dividing by zero, which used to turn the whole score
    # into NaN (e.g. for a model that still outputs a constant).
    cor = np.divide(fenzi, fenmu, out=np.zeros_like(fenzi), where=fenmu > 0)
    accskill_score = np.sum(a * log_lead * cor)

    # Sum of the per-lead RMSEs.
    rmse_scores = sum(rmse(y_true[:, i], y_preds[:, i]) for i in range(n_leads))

    return 2/3.0 * accskill_score - rmse_scores
ConvLSTM模型\n", - "\n", - "LSTM模型是非常经典的时序模型,三个门的结构使得它在挖掘长期的时间依赖任务中有不俗的表现,并且相较于RNN,LSTM能够有效地避免梯度消失问题。对于单个输入样本,在每个时间步上,LSTM的每个门实际是对输入向量做了一个全连接,那么对应到我们这个赛题上,输入X的形状是(N,T,H,W,C),则单个输入样本在每个时间步上输入LSTM的就是形状为(H,W,C)的空间信息。我们知道,全连接网络对于这种空间信息的提取能力并不强,转换成卷积操作后能够在大大减少参数量的同时通过堆叠多层网络逐步提取出更复杂的特征,到这里就可以很自然地想到,把LSTM中的全连接操作转换为卷积操作,就能够适用于时空序列问题。ConvLSTM模型就是这么做的,实践也表明这样的作法是非常有效的。\n", - "\n", - "\n", - "\n", - "2. SAM模块\n", - "\n", - "然而,ConvLSTM模型存在两个问题:\n", - "\n", - "一是卷积层的感受野受限于卷积核的大小,需要通过堆叠多个卷积层来扩大感受野,发掘全局的特征。举例来说,假设第一个卷积层的卷积核大小是3×3,那么这一层的每个节点就只能感知这3×3的空间范围内的输入信息,此时再增加一个3×3的卷积层,那么每个节点所能感知的就是3×3个第一层的节点内的信息,在第一层步长为1的情况下,就是4×4范围内的输入信息,于是相比于第一个卷积层,第二层所能感知的输入信息的空间范围就增大了,而这样做所带来的后果就是参数量增加。对于单纯的CNN模型来说增加一层只是增加了一个卷积核大小的参数量,但是对于ConvLSTM来说就有些不堪重负,参数量的增加增大了过拟合的风险,与此同时模型的收效却并不高。\n", - "\n", - "二是卷积操作只针对当前时间步输入的空间信息,而忽视了过去的空间信息,因此难以挖掘空间信息在时间上的依赖关系。\n", - "\n", - "因此,为了同时挖掘全局和本地的空间依赖,提升模型在大空间范围和长时间的时空序列预测任务中的预测效果,SA-ConvLSTM模型在ConvLSTM模型的基础上引入了SAM(self-attention memory)模块。\n", - "\n", - "\n", - "\n", - "SAM模块引入了一个新的记忆单元M,用来记忆包含时序依赖关系的空间信息。SAM模块以当前时间步通过ConvLSTM所获得的隐藏层状态$H_t$和上一个时间步的记忆$M_{t-1}$作为输入,首先将$H_t$通过自注意力机制得到特征$Z_h$,自注意力机制能够增加$H_t$中与其他部分更相关的部分的权重,同时$H_t$也作为Query与$M_{t-1}$共同通过注意力机制得到特征$Z_m$,用以增强对$M_{t-1}$中与$H_t$有更强依赖关系的部分的权重,将$Z_h$和$Z_m$拼接起来就得到了二者的聚合特征$Z$。此时,聚合特征$Z$中既包含了当前时间步的信息,又包含了全局的时空记忆信息,接下来借鉴LSTM中的门控结构用聚合特征$Z$对隐藏层状态和记忆单元进行更新,就得到了更新后的隐藏层状态$\\hat{H_t}$和当前时间步的记忆$M_t$。SAM模块的公式如下:\n", - "\n", - "$$\n", - "\\begin{aligned}\n", - "& i'_t = \\sigma (W_{m;zi} \\ast Z + W_{m;hi} \\ast H_t + b_{m;i}) \\\\\n", - "& g'_t = tanh (W_{m;zg} \\ast Z + W_{m;hg} \\ast H_t + b_{m;g}) \\\\\n", - "& M_t = (1 - i'_t) \\circ M_{t-1} + i'_t \\circ g'_t \\\\\n", - "& o'_t = \\sigma (W_{m;zo} \\ast Z + W_{m;ho} \\ast H_t + b_{m;o}) \\\\\n", - "& \\hat{H_t} = o'_t \\circ M_t\n", - "\\end{aligned}\n", - "$$\n", - "\n", - "关于注意力机制和自注意力机制可以参考以下链接:\n", - "\n", - " - 深度学习中的注意力机制:https://blog.csdn.net/malefactor/article/details/78767781\n", - " - 目前主流的Attention方法:https://www.zhihu.com/question/68482809\n", - "\n", - "3. 
# Attention mechanism
def attn(query, key, value):
    """Scaled dot-product attention over flattened spatial positions.

    Args:
        query, key, value: Tensors of shape (N, C, S), where S = H * W.

    Returns:
        Tensor of shape (N, C, S): for every query position, the
        softmax-weighted sum of ``value`` over all key positions.
    """
    scale = math.sqrt(query.size(1))
    # Similarity of every query position with every key position: (N, S, S).
    logits = torch.matmul(query.transpose(1, 2), key / scale)
    weights = F.softmax(logits, dim=-1)
    mixed = torch.matmul(weights, value.transpose(1, 2))  # (N, S, C)
    return mixed.transpose(1, 2)  # (N, C, S)


# SAM (self-attention memory) module
class SAAttnMem(nn.Module):
    """Self-attention memory: refreshes the hidden state H and the memory M.

    Args:
        input_dim: Channel count of both ``h`` and ``m`` (also of the outputs).
        d_model: Channel count of the query/key/value projections.
        kernel_size: ``(kh, kw)`` of the gate convolution; padding of
            ``k // 2`` per axis is applied.
    """

    def __init__(self, input_dim, d_model, kernel_size):
        super().__init__()
        self.d_model = d_model
        self.input_dim = input_dim
        half_kernel = (kernel_size[0] // 2, kernel_size[1] // 2)
        # 1x1 convolutions act as per-position linear layers.
        # H -> (Q_h, K_h, V_h)
        self.conv_h = nn.Conv2d(input_dim, d_model*3, kernel_size=1)
        # M -> (K_m, V_m)
        self.conv_m = nn.Conv2d(input_dim, d_model*2, kernel_size=1)
        # [Z_h, Z_m] -> Z
        self.conv_z = nn.Conv2d(d_model*2, d_model, kernel_size=1)
        # [Z, H] -> gates (i, g, o); each gate has input_dim channels so the
        # outputs match the inputs.
        self.conv_output = nn.Conv2d(input_dim+d_model, input_dim*3, kernel_size=kernel_size, padding=half_kernel)

    def forward(self, h, m):
        """Return ``(h_next, m_next)``, both shaped like ``h`` (N, input_dim, H, W)."""
        q_h, k_h, v_h = self.conv_h(h).split(self.d_model, dim=1)
        k_m, v_m = self.conv_m(m).split(self.d_model, dim=1)
        batch, chan, height, width = q_h.size()

        def flat(t):
            # (N, C, H, W) -> (N, C, S)
            return t.view(batch, chan, -1)

        def grid(t):
            # (N, C, S) -> (N, C, H, W)
            return t.view(batch, chan, height, width)

        # Self-attention on H, and attention of H's queries over the memory.
        z_h = attn(flat(q_h), flat(k_h), flat(v_h))
        z_m = attn(flat(q_h), flat(k_m), flat(v_m))
        # Aggregated feature Z, (N, d_model, H, W).
        z = self.conv_z(torch.cat([grid(z_h), grid(z_m)], dim=1))

        gate_i, cand, gate_o = self.conv_output(torch.cat([z, h], dim=1)).split(self.input_dim, dim=1)
        gate_i = torch.sigmoid(gate_i)
        cand = torch.tanh(cand)
        # M_t = i' * g' + (1 - i') * M_{t-1}
        m_next = gate_i * cand + (1 - gate_i) * m
        # H_t (updated) = o' * M_t
        h_next = torch.sigmoid(gate_o) * m_next
        return h_next, m_next


# SA-ConvLSTM cell
class SAConvLSTMCell(nn.Module):
    """One ConvLSTM step followed by a self-attention memory update.

    The recurrent state (``hidden_state``, ``cell_state``, ``memory_state``)
    is stored on the module and carried from call to call.

    Args:
        input_dim: Channels of the per-step input.
        hidden_dim: Channels of the hidden, cell and memory states.
        d_attn: Projection width handed to ``SAAttnMem`` as ``d_model``.
        kernel_size: ``(kh, kw)`` of the convolutions.
    """

    def __init__(self, input_dim, hidden_dim, d_attn, kernel_size):
        super().__init__()
        self.input_dim = input_dim
        self.hidden_dim = hidden_dim
        half_kernel = (kernel_size[0] // 2, kernel_size[1] // 2)
        # One convolution over [X_t, H_{t-1}] yields all four gates.
        self.conv = nn.Conv2d(in_channels=input_dim+hidden_dim, out_channels=4*hidden_dim, kernel_size=kernel_size, padding=half_kernel)
        self.sa = SAAttnMem(input_dim=hidden_dim, d_model=d_attn, kernel_size=kernel_size)

    def initialize(self, inputs):
        """Reset H, c and M to zeros shaped after ``inputs`` (N, _, H, W)."""
        batch, _, height, width = inputs.size()
        state_shape = (batch, self.hidden_dim, height, width)
        self.hidden_state = torch.zeros(state_shape, device=inputs.device)
        self.cell_state = torch.zeros(state_shape, device=inputs.device)
        self.memory_state = torch.zeros(state_shape, device=inputs.device)

    def forward(self, inputs, first_step=False):
        """Advance one time step and return the new hidden state.

        Args:
            inputs: Tensor of shape (N, input_dim, H, W).
            first_step: When true, the recurrent state is zeroed first.
        """
        if first_step:
            self.initialize(inputs)

        # ConvLSTM part.
        stacked = torch.cat([inputs, self.hidden_state], dim=1)
        gate_i, gate_f, gate_o, cand = self.conv(stacked).split(self.hidden_dim, dim=1)
        # c_t = f * c_{t-1} + i * g
        self.cell_state = torch.sigmoid(gate_f) * self.cell_state + torch.sigmoid(gate_i) * torch.tanh(cand)
        # H_t = o * tanh(c_t)
        self.hidden_state = torch.sigmoid(gate_o) * torch.tanh(self.cell_state)

        # SAM part: refine H_t and update M_t.
        self.hidden_state, self.memory_state = self.sa(self.hidden_state, self.memory_state)

        return self.hidden_state
1, kernel_size=1)\n", - " \n", - " def forward(self, input_x, device=torch.device('cuda:0'), input_frames=12, future_frames=26, output_frames=37, teacher_forcing=False, scheduled_sampling_ratio=0, train=True):\n", - " # 将输入样本X的形状(N, T, H, W, C)转换为(N, T, C, H, W)\n", - " input_x = input_x.permute(0, 1, 4, 2, 3).contiguous()\n", - " \n", - " # 仅在训练时使用teacher forcing\n", - " if train:\n", - " if teacher_forcing and scheduled_sampling_ratio > 1e-6:\n", - " teacher_forcing_mask = torch.bernoulli(scheduled_sampling_ratio * torch.ones(input_x.size(0), future_frames-1, 1, 1, 1))\n", - " else:\n", - " teacher_forcing = False\n", - " else:\n", - " teacher_forcing = False\n", - " \n", - " total_steps = input_frames + future_frames - 1\n", - " outputs = [None] * total_steps\n", - " \n", - " # 对于每一个时间步\n", - " for t in range(total_steps):\n", - " # 在前12个月,使用每个月的输入样本Xt\n", - " if t < input_frames:\n", - " input_ = input_x[:, t].to(device)\n", - " # 若不使用teacher forcing,则以上一个时间步的预测标签作为当前时间步的输入\n", - " elif not teacher_forcing:\n", - " input_ = outputs[t-1]\n", - " # 若使用teacher forcing,则以ratio的概率使用上一个时间步的实际标签作为当前时间步的输入\n", - " else:\n", - " mask = teacher_forcing_mask[:, t-input_frames].float().to(device)\n", - " input_ = input_x[:, t].to(device) * mask + outputs[t-1] * (1-mask)\n", - " first_step = (t==0)\n", - " input_ = input_.float()\n", - " \n", - " # 将当前时间步的输入通过隐藏层\n", - " for layer_idx in range(self.num_layers):\n", - " input_ = self.layers[layer_idx](input_, first_step=first_step)\n", - " \n", - " # 记录每个时间步的输出\n", - " if train or (t >= (input_frames - 1)):\n", - " outputs[t] = self.conv_output(input_)\n", - " \n", - " outputs = [x for x in outputs if x is not None]\n", - " \n", - " # 确认输出序列的长度\n", - " if train:\n", - " assert len(outputs) == output_frames\n", - " else:\n", - " assert len(outputs) == future_frames\n", - " \n", - " # 得到sst的预测序列\n", - " outputs = torch.stack(outputs, dim=1)[:, :, 0] # (N, 37, H, W)\n", - " # 对sst的预测序列在nino3.4区域取三个月的平均值就得到nino3.4指数的预测序列\n", - " 
nino_pred = outputs[:, -future_frames:, 10:13, 19:30].mean(dim=[2, 3]) # (N, 26)\n", - " nino_pred = nino_pred.unfold(dimension=1, size=3, step=1).mean(dim=2) # (N, 24)\n", - " \n", - " return nino_pred" - ] - }, - { - "cell_type": "code", - "execution_count": 25, - "metadata": { - "execution": { - "iopub.execute_input": "2021-11-29T03:05:23.726291Z", - "iopub.status.busy": "2021-11-29T03:05:23.725688Z", - "iopub.status.idle": "2021-11-29T03:05:23.753509Z", - "shell.execute_reply": "2021-11-29T03:05:23.753976Z", - "shell.execute_reply.started": "2021-11-29T01:03:29.448921Z" - }, - "papermill": { - "duration": 0.066105, - "end_time": "2021-11-29T03:05:23.754109", - "exception": false, - "start_time": "2021-11-29T03:05:23.688004", - "status": "completed" - }, - "tags": [] - }, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "SAConvLSTM(\n", - " (layers): ModuleList(\n", - " (0): SAConvLSTMCell(\n", - " (conv): Conv2d(65, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n", - " (sa): SAAttnMem(\n", - " (conv_h): Conv2d(64, 96, kernel_size=(1, 1), stride=(1, 1))\n", - " (conv_m): Conv2d(64, 64, kernel_size=(1, 1), stride=(1, 1))\n", - " (conv_z): Conv2d(64, 32, kernel_size=(1, 1), stride=(1, 1))\n", - " (conv_output): Conv2d(96, 192, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n", - " )\n", - " )\n", - " (1): SAConvLSTMCell(\n", - " (conv): Conv2d(128, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n", - " (sa): SAAttnMem(\n", - " (conv_h): Conv2d(64, 96, kernel_size=(1, 1), stride=(1, 1))\n", - " (conv_m): Conv2d(64, 64, kernel_size=(1, 1), stride=(1, 1))\n", - " (conv_z): Conv2d(64, 32, kernel_size=(1, 1), stride=(1, 1))\n", - " (conv_output): Conv2d(96, 192, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n", - " )\n", - " )\n", - " (2): SAConvLSTMCell(\n", - " (conv): Conv2d(128, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n", - " (sa): SAAttnMem(\n", - " (conv_h): Conv2d(64, 96, kernel_size=(1, 
1), stride=(1, 1))\n", - " (conv_m): Conv2d(64, 64, kernel_size=(1, 1), stride=(1, 1))\n", - " (conv_z): Conv2d(64, 32, kernel_size=(1, 1), stride=(1, 1))\n", - " (conv_output): Conv2d(96, 192, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n", - " )\n", - " )\n", - " (3): SAConvLSTMCell(\n", - " (conv): Conv2d(128, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n", - " (sa): SAAttnMem(\n", - " (conv_h): Conv2d(64, 96, kernel_size=(1, 1), stride=(1, 1))\n", - " (conv_m): Conv2d(64, 64, kernel_size=(1, 1), stride=(1, 1))\n", - " (conv_z): Conv2d(64, 32, kernel_size=(1, 1), stride=(1, 1))\n", - " (conv_output): Conv2d(96, 192, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n", - " )\n", - " )\n", - " )\n", - " (conv_output): Conv2d(64, 1, kernel_size=(1, 1), stride=(1, 1))\n", - ")\n" - ] - } - ], - "source": [ - "# 输入特征数\n", - "input_dim = 1\n", - "# 隐藏层节点数\n", - "hidden_dim = (64, 64, 64, 64)\n", - "# 注意力机制节点数\n", - "d_attn = 32\n", - "# 卷积核大小\n", - "kernel_size = (3, 3)\n", - "\n", - "model = SAConvLSTM(input_dim, hidden_dim, d_attn, kernel_size)\n", - "print(model)" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "#### 模型训练" - ] - }, - { - "cell_type": "code", - "execution_count": 26, - "metadata": { - "execution": { - "iopub.execute_input": "2021-11-29T03:05:23.816671Z", - "iopub.status.busy": "2021-11-29T03:05:23.815927Z", - "iopub.status.idle": "2021-11-29T03:05:23.818479Z", - "shell.execute_reply": "2021-11-29T03:05:23.818058Z", - "shell.execute_reply.started": "2021-11-29T01:03:31.476806Z" - }, - "papermill": { - "duration": 0.035723, - "end_time": "2021-11-29T03:05:23.818579", - "exception": false, - "start_time": "2021-11-29T03:05:23.782856", - "status": "completed" - }, - "tags": [] - }, - "outputs": [], - "source": [ - "# 采用RMSE作为损失函数\n", - "def RMSELoss(y_pred,y_true):\n", - " loss = torch.sqrt(torch.mean((y_pred-y_true)**2, dim=0)).sum()\n", - " return loss" - ] - }, - { - "cell_type": "code", - 
"execution_count": 27, - "metadata": { - "execution": { - "iopub.execute_input": "2021-11-29T03:05:23.893469Z", - "iopub.status.busy": "2021-11-29T03:05:23.892684Z", - "iopub.status.idle": "2021-11-29T04:55:28.956056Z", - "shell.execute_reply": "2021-11-29T04:55:28.956434Z" - }, - "papermill": { - "duration": 6605.109145, - "end_time": "2021-11-29T04:55:28.956614", - "exception": false, - "start_time": "2021-11-29T03:05:23.847469", - "status": "completed" - }, - "tags": [] - }, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "Epoch: 1/5\n" - ] - }, - { - "name": "stderr", - "output_type": "stream", - "text": [ - "100%|██████████| 1600/1600 [21:43<00:00, 1.23it/s]\n" - ] - }, - { - "name": "stdout", - "output_type": "stream", - "text": [ - "Training Loss: 3.289\n" - ] - }, - { - "name": "stderr", - "output_type": "stream", - "text": [ - "50it [00:11, 4.47it/s]\n" - ] - }, - { - "name": "stdout", - "output_type": "stream", - "text": [ - "Validation Loss: 44.009\n", - "Score: -43.458\n", - "Epoch: 2/5\n" - ] - }, - { - "name": "stderr", - "output_type": "stream", - "text": [ - "100%|██████████| 1600/1600 [21:43<00:00, 1.23it/s]\n" - ] - }, - { - "name": "stdout", - "output_type": "stream", - "text": [ - "Training Loss: 3.084\n" - ] - }, - { - "name": "stderr", - "output_type": "stream", - "text": [ - "50it [00:11, 4.33it/s]\n" - ] - }, - { - "name": "stdout", - "output_type": "stream", - "text": [ - "Validation Loss: 25.011\n", - "Score: -19.966\n", - "Epoch: 3/5\n" - ] - }, - { - "name": "stderr", - "output_type": "stream", - "text": [ - "100%|██████████| 1600/1600 [21:46<00:00, 1.22it/s]\n" - ] - }, - { - "name": "stdout", - "output_type": "stream", - "text": [ - "Training Loss: 13.461\n" - ] - }, - { - "name": "stderr", - "output_type": "stream", - "text": [ - "50it [00:12, 4.16it/s]\n" - ] - }, - { - "name": "stdout", - "output_type": "stream", - "text": [ - "Validation Loss: 15.438\n", - "Score: -14.139\n", - "Epoch: 4/5\n" - ] - 
}, - { - "name": "stderr", - "output_type": "stream", - "text": [ - "100%|██████████| 1600/1600 [21:54<00:00, 1.22it/s]\n" - ] - }, - { - "name": "stdout", - "output_type": "stream", - "text": [ - "Training Loss: 17.627\n" - ] - }, - { - "name": "stderr", - "output_type": "stream", - "text": [ - "50it [00:12, 3.99it/s]\n" - ] - }, - { - "name": "stdout", - "output_type": "stream", - "text": [ - "Validation Loss: 15.389\n", - "Score: -22.500\n", - "Epoch: 5/5\n" - ] - }, - { - "name": "stderr", - "output_type": "stream", - "text": [ - "100%|██████████| 1600/1600 [21:55<00:00, 1.22it/s]\n" - ] - }, - { - "name": "stdout", - "output_type": "stream", - "text": [ - "Training Loss: 17.592\n" - ] - }, - { - "name": "stderr", - "output_type": "stream", - "text": [ - "50it [00:11, 4.48it/s]" - ] - }, - { - "name": "stdout", - "output_type": "stream", - "text": [ - "Validation Loss: 15.252\n", - "Score: -14.459\n" - ] - }, - { - "name": "stderr", - "output_type": "stream", - "text": [ - "\n" - ] - } - ], - "source": [ - "model_weights = './task05_model_weights.pth'\n", - "device = 'cuda' if torch.cuda.is_available() else 'cpu'\n", - "model = SAConvLSTM(input_dim, hidden_dim, d_attn, kernel_size).to(device)\n", - "criterion = RMSELoss\n", - "optimizer = optim.Adam(model.parameters(), lr=1e-3)\n", - "lr_scheduler = ReduceLROnPlateau(optimizer, mode='max', factor=0.3, patience=0, verbose=True, min_lr=0.0001)\n", - "epochs = 5\n", - "ratio, decay_rate = 1, 8e-5\n", - "train_losses, valid_losses = [], []\n", - "scores = []\n", - "best_score = float('-inf')\n", - "preds = np.zeros((len(y_valid),24))\n", - "\n", - "for epoch in range(epochs):\n", - " print('Epoch: {}/{}'.format(epoch+1, epochs))\n", - " \n", - " # 模型训练\n", - " model.train()\n", - " losses = 0\n", - " for data, labels in tqdm(trainloader):\n", - " data = data.to(device)\n", - " labels = labels.to(device)\n", - " optimizer.zero_grad()\n", - " # ratio线性衰减\n", - " ratio = max(ratio-decay_rate, 0)\n", - " pred = 
model(data, teacher_forcing=True, scheduled_sampling_ratio=ratio, train=True)\n", - " loss = criterion(pred, labels)\n", - " losses += loss.cpu().detach().numpy()\n", - " loss.backward()\n", - " optimizer.step()\n", - " train_loss = losses / len(trainloader)\n", - " train_losses.append(train_loss)\n", - " print('Training Loss: {:.3f}'.format(train_loss))\n", - " \n", - " # 模型验证\n", - " model.eval()\n", - " losses = 0\n", - " with torch.no_grad():\n", - " for i, data in tqdm(enumerate(validloader)):\n", - " data, labels = data\n", - " data = data.to(device)\n", - " labels = labels.to(device)\n", - " pred = model(data, train=False)\n", - " loss = criterion(pred, labels)\n", - " losses += loss.cpu().detach().numpy()\n", - " preds[i*batch_size:(i+1)*batch_size] = pred.detach().cpu().numpy()\n", - " valid_loss = losses / len(validloader)\n", - " valid_losses.append(valid_loss)\n", - " print('Validation Loss: {:.3f}'.format(valid_loss))\n", - " s = score(y_valid, preds)\n", - " scores.append(s)\n", - " print('Score: {:.3f}'.format(s))\n", - " \n", - " # 保存最佳模型权重\n", - " if s > best_score:\n", - " best_score = s\n", - " checkpoint = {'best_score': s,\n", - " 'state_dict': model.state_dict()}\n", - " torch.save(checkpoint, model_weights)" - ] - }, - { - "cell_type": "code", - "execution_count": 28, - "metadata": { - "execution": { - "iopub.execute_input": "2021-11-29T04:55:33.957872Z", - "iopub.status.busy": "2021-11-29T04:55:33.957066Z", - "iopub.status.idle": "2021-11-29T04:55:33.960119Z", - "shell.execute_reply": "2021-11-29T04:55:33.959684Z", - "shell.execute_reply.started": "2021-11-28T14:00:36.33194Z" - }, - "papermill": { - "duration": 2.38263, - "end_time": "2021-11-29T04:55:33.960247", - "exception": false, - "start_time": "2021-11-29T04:55:31.577617", - "status": "completed" - }, - "tags": [] - }, - "outputs": [], - "source": [ - "# 绘制训练/验证曲线\n", - "def training_vis(train_losses, valid_losses):\n", - " # 绘制损失函数曲线\n", - " fig = plt.figure(figsize=(8,4))\n", - " # 
subplot loss\n", - " ax1 = fig.add_subplot(121)\n", - " ax1.plot(train_losses, label='train_loss')\n", - " ax1.plot(valid_losses,label='val_loss')\n", - " ax1.set_xlabel('Epochs')\n", - " ax1.set_ylabel('Loss')\n", - " ax1.set_title('Loss on Training and Validation Data')\n", - " ax1.legend()\n", - " plt.tight_layout()" - ] - }, - { - "cell_type": "code", - "execution_count": 29, - "metadata": { - "execution": { - "iopub.execute_input": "2021-11-29T04:55:38.227343Z", - "iopub.status.busy": "2021-11-29T04:55:38.226636Z", - "iopub.status.idle": "2021-11-29T04:55:38.470256Z", - "shell.execute_reply": "2021-11-29T04:55:38.469252Z", - "shell.execute_reply.started": "2021-11-28T14:00:43.42651Z" - }, - "papermill": { - "duration": 2.378943, - "end_time": "2021-11-29T04:55:38.470387", - "exception": false, - "start_time": "2021-11-29T04:55:36.091444", - "status": "completed" - }, - "tags": [] - }, - "outputs": [ - { - "data": { - "image/png": "iVBORw0KGgoAAAANSUhEUgAAAS8AAAEYCAYAAAANoXDNAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjQuMywgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy/MnkTPAAAACXBIWXMAAAsTAAALEwEAmpwYAAAsYUlEQVR4nO3dd3wUdf7H8dcnhSRApIYSeu8QMKKIngoWRATPE1AsgIVT0cMu3umpWE899TwVT08BFQsWDsQu4A89UQiYUKRzlBBCDwQlkPL5/TETWCCBhOxkspvP8/HYR2ZnZmfeu5l8MvPd78yIqmKMMaEmwu8AxhhzIqx4GWNCkhUvY0xIsuJljAlJVryMMSHJipcxJiRZ8aqkRORzERke7Hn9JCLrRORcD5b7rYhc7w5fKSJflWTeE1hPUxHZKyKRJ5q1MqlUxcurjbu8uBt24aNARPYFPL+yNMtS1QtVdVKw562IRGSsiMwpYnxdETkgIp1LuixVnayq5wcp12Hbo6puUNXqqpofjOUfsS4VkV/dbWWHiMwUkaGleP3ZIpIe7FxlUamKV6hzN+zqqlod2ABcHDBucuF8IhLlX8oK6W3gdBFpccT4y4HFqrrEh0x+6OZuO+2AicCLIvKgv5FOnBUvQERiROR5EclwH8+LSIw7ra6IzBCRLBHZKSLfiUiEO+1eEdkkItkiskJE+haz/Boi8qaIbBOR9SJyf8AyRojI9yLyjIjsEpH/iciFpcx/toiku3kygQkiUsvNvc1d7gwRaRzwmsBDoWNmKOW8LURkjvuZfCMiL4nI28XkLknGR0Tkv+7yvhKRugHTr3Y/zx0i8pfiPh9VTQdmAVcfMeka4M3j5Tgi8wgR+T7g+XkislxEdovIi4AETGslIrPcfNtFZLKI1HSnvQU0BT5x94buEZHm7h5SlDtPoohMd7e71SJyQ8CyHxKRKe52lS0iS0UkubjP4IjPY7uqvgXcBNwnInXcZY4UkWXu8taKyB/d8dWAz4FEObSnnygiPUVkrvu3sVlEXhSRKiXJEAxWvBx/AU
4DkoBuQE/gfnfanUA6kADUB/4MqIi0A24BTlHVeOACYF0xy/8nUANoCZyF80czMmD6qcAKoC7wFPC6iMiRCzmOBkBtoBkwCud3O8F93hTYB7x4jNeXJsOx5n0HmAfUAR7i6IIRqCQZh+F8VvWAKsBdACLSERjvLj/RXV+RBcc1KTCL+/tLcvOW9rMqXEZd4GOcbaUusAboHTgL8ISbrwPQBOczQVWv5vC956eKWMV7ONteInAZ8LiI9AmYPtCdpyYwvSSZjzANiMLZ3gG2AgOAk3A+8+dEpIeq/gpcCGQE7OlnAPnA7e577wX0BW4uZYYTp6qV5oFTXM4tYvwaoH/A8wuAde7wOJxfcusjXtMa55d9LhB9jHVGAgeAjgHj/gh86w6PAFYHTKsKKNCgpO8FONtdR+wx5k8CdgU8/xa4viQZSjovzh9+HlA1YPrbwNsl/P0UlfH+gOc3A1+4w38F3guYVs39DI76/Qbk3AOc7j5/DJh2gp/V9+7wNcCPAfMJTrG5vpjlXgL8XNz2CDR3P8sonEKXD8QHTH8CmOgOPwR8EzCtI7DvGJ+tcsQ27I7PBK4s5jX/AcYEbGPpx/n93QZMLcnvOhgP2/NyJALrA56vd8cBPA2sBr5yd6XHAqjqapxf1kPAVhF5T0QSOVpdILqI5TcKeJ5ZOKCqv7mD1Uv5Hrapak7hExGpKiL/cg+r9gBzgJpS/DdZpclQ3LyJwM6AcQAbiwtcwoyZAcO/BWRKDFy2OnsHO4pbl5vpA+Aady/xSuDNUuQoypEZNPC5iNR3t4tN7nLfxtkeSqLws8wOGFfsdoPz2cRKKdo7RSQa54hip/v8QhH50T1MzQL6HyuviLR1D7Ez3ff3+LHmDzYrXo4MnEOGQk3dcahqtqreqaotcXbT7xC3bUtV31HVM9zXKvC3Ipa9HcgtYvmbgvwejrw8yJ04DbOnqupJwO/c8aU9HC2NzUBtEakaMK7JMeYvS8bNgct211nnOK+ZBAwBzgPigU/KmOPIDMLh7/dxnN9LF3e5Vx2xzGNd0iUD57OMDxgX7O1mEM6e8jxx2ng/Ap4B6qtqTeCzgLxFZR0PLAfauO/vz3i7fR2mMhavaBGJDXhEAe8C94tIgtuO8Vec/5KIyAARae1umLtxduULRKSdiPRxf+k5OO0kBUeuTJ2vvacAj4lIvIg0A+4oXL6H4t1MWSJSG/D8WyVVXQ+kAA+JSBUR6QVc7FHGD4EBInKG20g8juNvz98BWcCrOIecB8qY41Ogk4hc6m5Hf8I5fC4UD+wFdotII+DuI16/Bacd9CiquhH4AXjC3U67AtcRhO1GRGqL07XmJeBvqroDpz0xBtgG5InzJUxgl5AtQB0RqXHE+9sD7BWR9jhfAJSbyli8PsPZUAsfDwGP4vzRLQIWAwvdcQBtgG9wNsK5wMuqOhvnF/0kzp5VJk6D8n3FrPNW4FdgLfA9TiPxG8F9W0d5Hohz8/0IfOHx+gpdidN4uwPnM3wf2F/MvM9zghlVdSkwGuez3AzswmlvOtZrFOdQsZn7s0w5VHU7MBhnO9iBs638N2CWh4EeOP/0PsVp3A/0BM4/zSwRuauIVVyB0w6WAUwFHlTVb0qSrRhpIrIXpxnkeuB2Vf2r+16ycYrvFJzPchjOlwCF73U5zj/5tW7eRJwvT4YB2cBrOL/rciNuQ5sxnhCR94Hlqhqy/YlMxVQZ97yMh0TkFHH6N0WISD+cdpX/+BzLhCHriW2CrQHO4VEdnMO4m1T1Z38jmXBkh43GmJBkh43GmJAUEoeNdevW1ebNm/sdwxhTzhYsWLBdVROKmhYSxat58+akpKT4HcMYU85EZH1x0+yw0RgTkqx4GWNCkhUvY0xICok2L2MqotzcXNLT08nJyTn+zOaYYmNjady4MdHR0SV+jRUvY05Qeno68fHxNG/enNJfO9IUUlV27NhBeno6LVoceaXu4tlhozEnKCcnhzp16ljhKiMRoU6dOqXeg7
XiZUwZWOEKjhP5HMOreOXsgR/Hg53yZEzYC6/itXwGfDEWUicff15jTEgLr+LV9XJo2gu+uh9+3e53GmM8lZWVxcsvv1zq1/Xv35+srKxSv27EiBF8+OGHpX6dV8KreEVEwIDnYf9ep4AZE8aKK155eXnHfN1nn31GzZo1PUpVfsKvq0S99tB7DHz3DHS7Alqe5XciUwk8/MlSfsnYE9Rldkw8iQcv7lTs9LFjx7JmzRqSkpKIjo4mNjaWWrVqsXz5clauXMkll1zCxo0bycnJYcyYMYwaNQo4dK7w3r17ufDCCznjjDP44YcfaNSoEdOmTSMuLu642WbOnMldd91FXl4ep5xyCuPHjycmJoaxY8cyffp0oqKiOP/883nmmWf44IMPePjhh4mMjKRGjRrMmTMnKJ9PeO15FfrdXVC7Jcy4HXKtA6EJT08++SStWrUiNTWVp59+moULF/KPf/yDlStXAvDGG2+wYMECUlJSeOGFF9ix4+g7w61atYrRo0ezdOlSatasyUcffXTc9ebk5DBixAjef/99Fi9eTF5eHuPHj2fHjh1MnTqVpUuXsmjRIu6/3zn6GTduHF9++SVpaWlMnz79OEsvufDb8wKIjoOLnoW3LoHv/g59ir0TvDFBcaw9pPLSs2fPwzp5vvDCC0ydOhWAjRs3smrVKurUOfzucC1atCApKQmAk08+mXXr1h13PStWrKBFixa0bdsWgOHDh/PSSy9xyy23EBsby3XXXceAAQMYMGAAAL1792bEiBEMGTKESy+9NAjv1BGee14Arc6BrkPh++dg2wq/0xjjuWrVqh0c/vbbb/nmm2+YO3cuaWlpdO/evchOoDExMQeHIyMjj9tedixRUVHMmzePyy67jBkzZtCvXz8AXnnlFR599FE2btzIySefXOQe4IkI3+IFcP5jUKUafHIbFBx1S0VjQlp8fDzZ2dlFTtu9eze1atWiatWqLF++nB9//DFo623Xrh3r1q1j9erVALz11lucddZZ7N27l927d9O/f3+ee+450tLSAFizZg2nnnoq48aNIyEhgY0bi72JeqmE52FjoeoJcP4jMP1Wp+9Xj6v9TmRM0NSpU4fevXvTuXNn4uLiqF+//sFp/fr145VXXqFDhw60a9eO0047LWjrjY2NZcKECQwePPhgg/2NN97Izp07GTRoEDk5Oagqzz77LAB33303q1atQlXp27cv3bp1C0qOkLgBR3Jysp7wlVRVYeJFsGUp3JLiFDRjgmDZsmV06NDB7xhho6jPU0QWqGpyUfOH92EjgAgMeA4O/ApfWcO9MeEi/IsXQEI7OON2WPQ+rJntdxpjKrTRo0eTlJR02GPChAl+xzpKeLd5BTrzTljyIXx6B9z0g9OdwhhzlJdeesnvCCVSOfa8AKJjncPHnWthzjN+pzHGlFHlKV4ALc92Thn67z9g6zK/0xhjyqByFS+A8x+FmOrOqUPW98uYkFX5ile1uk4B2zAXfn7L7zTGmBPkefESkUgR+VlEZrjPW4jITyKyWkTeF5EqXmc4StKV0OwM+PoB2Lu13FdvjB+qV69e7LR169bRuXPnckxTduWx5zUGCGxg+hvwnKq2BnYB15VDhsMV9v3K3Qdf/rncV2+MKTtPu0qISGPgIuAx4A5xrrLfBxjmzjIJeAgY72WOIiW0hTPugP970mnEb9233COYMPL5WMhcHNxlNugCFz5Z7OSxY8fSpEkTRo8eDcBDDz1EVFQUs2fPZteuXeTm5vLoo48yaNCgUq02JyeHm266iZSUFKKionj22Wc555xzWLp0KSNHjuTAgQMUFBTw0UcfkZiYyJAhQ0hPTyc/P58HHniAoUOHlultl5TXe17PA/cAhS3jdYAsVS08dT0daFTUC0VklIikiEjKtm3bvEl3xu1Qp7XT9yt3nzfrMMYjQ4cOZcqUKQefT5kyheHDhzN16lQWLlzI7NmzufPOOyntKYAvvfQSIsLixYt59913GT58OD
k5ObzyyiuMGTOG1NRUUlJSaNy4MV988QWJiYmkpaWxZMmSg1eSKA+e7XmJyABgq6ouEJGzS/t6VX0VeBWccxuDm85V2Pdr0sUw52no+1dPVmMqgWPsIXmle/fubN26lYyMDLZt20atWrVo0KABt99+O3PmzCEiIoJNmzaxZcsWGjRoUOLlfv/999x6660AtG/fnmbNmrFy5Up69erFY489Rnp6Opdeeilt2rShS5cu3Hnnndx7770MGDCAM88806u3exQv97x6AwNFZB3wHs7h4j+AmiJSWDQbA5s8zHB8LX4H3YY5fb+2/OJrFGNKa/DgwXz44Ye8//77DB06lMmTJ7Nt2zYWLFhAamoq9evXL/XNXIszbNgwpk+fTlxcHP3792fWrFm0bduWhQsX0qVLF+6//37GjRsXlHWVhGfFS1XvU9XGqtocuByYpapXArOBy9zZhgPTvMpQYuc/CjEnwYzbrO+XCSlDhw7lvffe48MPP2Tw4MHs3r2bevXqER0dzezZs1m/fn2pl3nmmWcyebJz+8CVK1eyYcMG2rVrx9q1a2nZsiV/+tOfGDRoEIsWLSIjI4OqVaty1VVXcffdd7Nw4cJgv8Vi+dHP616cxvvVOG1gr/uQ4XDV6sAFj8HGn2DhJL/TGFNinTp1Ijs7m0aNGtGwYUOuvPJKUlJS6NKlC2+++Sbt27cv9TJvvvlmCgoK6NKlC0OHDmXixInExMQwZcoUOnfuTFJSEkuWLOGaa65h8eLF9OzZk6SkJB5++OGD160vD+F/Pa+SUnXavjIXwej5EF//+K8xlZpdzyu47HpeJ+qwvl/3+Z3GGHMcleeSOCVRtw2ceRd8+7jTiN/mXL8TGRNUixcv5uqrD78cekxMDD/99JNPiU6cFa8jnXEbLP7A6ft1849QparfiUwFpqo4fa9DQ5cuXUhNTfU7xlFOpPnKDhuPFBUDFz8PWethzlN+pzEVWGxsLDt27DihPzxziKqyY8cOYmNjS/U62/MqSvMzIOkq+OGf0GUw1Pf/hqKm4mncuDHp6el4dgZIJRIbG0vjxo1L9RorXsU5/xFY+blzz8drv4QI20k1h4uOjj7sDtWmfNlfZHGq1oYLHof0ebCg4t18wJjKzorXsXQd6pw+9M3DkJ3pdxpjTAArXsciAhc9B3k58IX1/TKmIrHidTx1W8Pv7oKlH8Oqr/1OY4xxWfEqid5joG47mHGHc+dtY4zvrHiVRFSMc+rQ7g3wf3/zO40xBiteJde8N3S/Gn54ETKX+J3GmErPildpnDcO4mrBJ2OgIN/vNMZUala8SqOw79emFEh5w+80xlRqVrxKq+sQaHk2zBwHezb7ncaYSsuKV2mJwEXPQt5++GKs32mMqbSseJ2IOq3grLvhl//Ayi/9TmNMpWTF60SdPgYS2sOnd1nfL2N8YMXrREVVgQHPO32/vn3C7zTGVDpWvMqiWS/oMRzmvgybF/mdxphKxYpXWZ37kNOFYsZt1vfLmHJkxausqtaGC56ATQtgvv+3oDSmsrDiFQxdLoNWfdy+Xxl+pzGmUrDiFQwicNHfoSAXPr/X7zTGVApWvIKldks46x5YNh1WfO53GmPCnhWvYOp1KyR0gM/uhv17/U5jTFiz4hVMUVWcez7u3mh9v4zxmBWvYGt6Gpw8En58GTan+Z3GmLBlxcsL5z4IVevadb+M8ZAVLy/E1YJ+T0DGzzDvNb/TGBOWrHh5pfMfoFVfmPUI7N7kdxpjwo4VL6+IwIBnncPGz+/xO40xYceKl5dqNYez74XlM2D5p36nMSasWPHyWq9boF5Ht+9Xtt9pjAkbVry8FhkNF/8D9myC2Y/7ncaYsGHFqzw06QnJ18JPrzjfQBpjysyKV3np+yBUS3D6fuXn+Z3GmJBnxau8xNWEfk86ve7nW98vY8rKs+IlIrEiMk9E0kRkqYg87I5vISI/ichqEXlfRKp4laHC6fR7aH0ezHoUdqf7ncaYkOblntd+oI
+qdgOSgH4ichrwN+A5VW0N7AKu8zBDxSICFz3j9P36zPp+GVMWnhUvdRReFybafSjQB/jQHT8JuMSrDBVSreZwzn2w4lNYNsPvNMaELE/bvEQkUkRSga3A18AaIEtVC1us04FGXmaokE67Gep3tr5fxpSBp8VLVfNVNQloDPQE2pf0tSIySkRSRCRl27ZtXkX0R2S0c8/H7M1O+5cxptTK5dtGVc0CZgO9gJoiEuVOagwUedayqr6qqsmqmpyQkFAeMctXk1PglOvgp385dx4yxpSKl982JohITXc4DjgPWIZTxC5zZxsOTPMqQ4XX969QvT58cpv1/TKmlLzc82oIzBaRRcB84GtVnQHcC9whIquBOkDlvdlhbA248G+QuQjm/cvvNMaElKjjz3JiVHUR0L2I8Wtx2r8MQMdB0OYCmPUYdBgINZv4nciYkGA97P0mAv2fBtT59lHV70TGhAQrXhVBrWZw9n2w8nNY9onfaYwJCVa8KorTbob6XZyrrubs8TuNMRWeFa+KIjLKue5Xdqb1/TKmBKx4VSSNT4aeN8C8VyHd+n4ZcyxWvCqaPvdDfAO77pcxx2HFq6Ip7Pu1ZTH8NN7vNMZUWFa8KqIOA6Hthc4177M2+J3GmArJildFdLDvl8Cnd1nfL2OKYMWroqrZBM75M6z6En6pvKd/GlMcK14V2ak3QoOu8Pm9kLPb7zTGVChWvCqywr5fv26FmY/4ncaYCsWKV0XXqAf0HAXz/w3pKX6nMabCsOIVCs75C8Q3dPt+5fqdxpgKwYpXKIg9Cfo/BVuWwI8v+53GmArBileoaD8A2vWH2U/ArvV+pzHGd1a8QkVh3y+JgM+s75cxVrxCSY3GzrmPq76CpVP9TmOMr0pUvESkmohEuMNtRWSgiER7G80UqecoaNjN6fu1fZXfaYzxTUn3vOYAsSLSCPgKuBqY6FUocwyRUTDoZdACeK0PrPzS70TG+KKkxUtU9TfgUuBlVR0MdPIuljmmBp1h1LdQqzm8MxTmPGNtYKbSKXHxEpFewJXAp+64SG8imRKp2QSu/RK6DIZZj8AHw2H/Xr9TGVNuSlq8bgPuA6aq6lIRaYlz81jjpypV4dJX4fxHnRt3vH4e7FzrdypjyoVoKQ833Ib76qpabneJSE5O1pQUOzXmmNbMgg9GOsODJ0CrPv7mMSYIRGSBqiYXNa2k3za+IyIniUg1YAnwi4jcHcyQpoxa9XHawU5qBG//Af77grWDmbBW0sPGju6e1iXA50ALnG8cTUVSuwVc9xV0uBi+fgA+vgEO/OZ3KmM8UdLiFe3267oEmK6quYD9W6+IYqrD4EnQ96+w+EN44wK7lLQJSyUtXv8C1gHVgDki0gywO6NWVCJw5p0wbIpzHuSrZ8P/vvM7lTFBVaLipaovqGojVe2vjvXAOR5nM2XV9ny4YRZUrQtvDoKf/mXtYCZslLTBvoaIPCsiKe7j7zh7Yaaiq9sarv8G2l4An98D00ZDbo7fqYwps5IeNr4BZAND3MceYIJXoUyQxZ4EQyfDWWMhdTJMuBB2b/I7lTFlUtLi1UpVH1TVte7jYaCll8FMkEVEwDn3OUVs+0qnHWzDj36nMuaElbR47RORMwqfiEhvYJ83kYynOgyA62c630pOHAApb/idyJgTElXC+W4E3hSRGu7zXcBwbyIZz9VrDzfMho+uhxm3w+Y0uPBpiKridzJjSqyk3zamqWo3oCvQVVW7A3b+SSiLqwnD3ocz7oAFE2HSAMjO9DuVMSVWqiupquqegHMa7/AgjylPEZFw7oNw2QTIXOy0g6Uv8DuVMSVSlstAS9BSGH91vhSu+xoiq8CEfvDz234nMua4ylK8rLdjOCm8wGHTXk5fsM/usXtEmgrtmA32IpJN0UVKgDhPEhn/VK0NV30M3zwIc1+ELUthyCSoVtfvZMYc5Zh7Xqoar6onFfGIV9WSflNpQklkFFzwGPz+VdiU4rSDZaT6ncqYo3h26z
MRaSIis0XkFxFZKiJj3PG1ReRrEVnl/qzlVQZTBt2GwrVfOOdCvnEBLJridyJjDuPlfRvzgDtVtSNwGjBaRDoCY4GZqtoGmOk+NxVRYnenHazRyc61wb78C+Tn+Z3KGMDD4qWqm1V1oTucDSwDGgGDgEnubJNwrhFmKqrqCXDNNOd+kXNfhMl/gN92+p3KmPK5Y7aINAe6Az8B9VV1szspE6hfzGtGFV7FYtu2beUR0xQnMhr6Pw0DX4T1PzjtYJlL/E5lKjnPi5eIVAc+Am478qYd6tz9o8guF6r6qqomq2pyQkKC1zFNSfS4GkZ+DvkHnDsVLf2P34lMJeZp8XIvHf0RMFlVP3ZHbxGRhu70hsBWLzOYIGuc7LSD1e/s3Cty5jgoyPc7lamEvPy2UYDXgWWq+mzApOkcOql7ODDNqwzGI/ENYMQM6DEcvvs7vHs57MvyO5WpZLzc8+qNc4ehPiKS6j76A08C54nIKuBc97kJNVExMPAFGPCcc8/I1/rAthV+pzKViGcdTVX1e4o//7GvV+s15Sz5WkjoAFOugdf6wqX/gvYX+Z3KVALl8m2jCXPNejntYHVbw3vD4NsnoaDA71QmzFnxMsFRo5HzTWS3K+DbJ+D9qyDH7o5nvGPFywRPdBxcMh76PQkrv4B/nwvbV/udyoQpK14muETgtJvg6qnw6zanIX/lV36nMmHIipfxRsuznHawmk3hnSHw3bN2w1sTVHZZG+OdWs3guq9g+i0w82HnRh+XvAxV7H7FAKrKwg27mJaawc8bslAUQRD3O3oBEDn4lb2IM07cGeTguEMvKBznPHWWdeRzDi5Pjpj/0DgOjpcjph9aX+Gyj17/kfkOrS86Unjqsm4n/qEFsOJlvFWlKvzhdWjYDb55CLavgssnQ+0WfifzzfLMPUxLzeCTtAzSd+0jJiqCU5rXJjrS+TNXnJ3Uwv1UdfdYC3dcFT00rEc8B7TAGYc7XgOXcdhynCeH1lPUsg+9NnCewOmFGfXgQg+97shlR0cG72DPipfxngj0HgP1O8GH18Jr5zg3/Wh1jt/Jys3Gnb8xPS2D6akZrNiSTWSE0Lt1XW4/ty3nd6pPfGy03xFDjmgItEMkJydrSkqK3zFMMOxYA+9dCdtXwHmPQK/RHHYsE0a2Ze/n00UZTE/LYOGGLACSm9ViYFIi/bs0pG71GH8DhgARWaCqyUVNsz0vU77qtILrv4b/3ARf/cVpBxv4gtPNIgzsycnlyyWZTE/L4L+rt1Og0L5BPPf0a8fFXRNpUruq3xHDhhUvU/5i4mHwm85J3bMfc/bChk6Gmk38TnZCcnLzmb18K9PTMpi5fCsH8gpoUjuOm85uxcBujWjXIN7viGHJipfxR0QEnHW3c8u1j0c5FzgcMgman+F3shLJyy9g7todTEvN4MslmWTvz6Nu9SoM69mUgUmJdG9S8+C3bsYbVryMv9pdCNfPdM6JfHMQXPAE9LyhQraDqSo/b8xiemoGMxZlsH3vAeJjorigcwMGJSXSq2UdooL4bZo5Nitexn8JbeGGmc4e2Od3Q2Ya9P87RMf6nQyAlVuymZa6ielpGWzcuY8qURH0bV+PQUmJnN2uHrHRkX5HrJSseJmKIbYGXP6uc1L3nKdg63IY+haclOhLnI07f+OTRU7XhuWZ2UQI9G5dlzF9na4NJ1nXBt9Z8TIVR0QE9PkLNOgCU2902sF+/wrUbonTfVtK/7MU827/9QCfL8nkk7RMUjZkoUD3JrV4eGAn+ndpSEK8dW2oSKyfl6mYtvzitIPt+p/fSQIUVfQiSlAYi3ttYJGNAIl0CrhEQkRkwDh3OCIy4PmxxkcUMV9x4wNef9iySrOM0oyPgianlPwTt35eJuTU7+ic2L3qayjIdc9p0cN/asHR46DoeQN+5ubns2brXpZl7GbN1mzyC5SacZF0aBhPhwbVSagec9xlHPPnCb2mwHkU5IPmuz8L3OGCgHEB0wryQQ8UMz5w/q
JeX8z4om/mFTyRVeCB4NzK0IqXqbjiakLXwUFZVH6BMnfNDqanbeLzJZlk5+RRp1oVBiQ3ZGBSIj2a1rKuDXCokB5V6IorjKUsmEFkxcuELVUldWMW09MymLFoM9uy91M9JooLOjVgYFIivVtZ14ajiBw6zKvgrHiZsLN6azbTUjOYlprBhp2/USUygj7t6zEwKZE+7a1rQ7iw4mXCwqasfXyS5hSsZZv3ECFwequ63NKnNRd0akCNOOvaEG6seJmQtfPXA3y6eDPTUzcxf90uAJKa1OTBiztyUdeG1IuvGJ1cjTeseJmQsnd/Hl//ksm01Ay+X7WdvAKlTb3q3HV+Wy7ulkizOnaV1srCipep8Pbn5fN/K7YxPS2Db5ZtISe3gEY147j+zJYMSkqkfYN4+6awErLiZSqs9F2/8eKs1Xy2eDN7cvKoXa0Kg09uwiC3a0NEhBWsysyKl6mQFqVnce3EFH7dn8eFnd2uDa3rBvUa6Ca0WfEyFc7MZVu45Z2fqVO9Cu+N6k3renYxP3M0K16mQnlr7joenL6UTok1eH1Esn1jaIplxctUCAUFyt++WM6/5qzl3A71eOGK7lStYpunKZ5tHcZ3Obn53PlBGp8u2szVpzXjoYGdiLTGeHMcVryMr3b9eoAb3kwhZf0u/ty/PTec2dK6PZgSseJlfLN+x6+MnDCf9Kx9vDSsBxd1beh3JBNCrHgZX/y8YRfXT0ohX5V3rj+V5Oa1/Y5kQowVL1PuvlyayZj3fqZefCwTR55Cy4TqfkcyIciKlylXb3z/Px759Be6Na7J68OTqWO3vDcnyIqXKRf5Bcpjny7jjf/+jws61ef5od2Jq2LX1TInzoqX8VxObj63vZfKF0szGdm7Ofdf1NG6Qpgys+JlPLVj736ufzOF1I1Z/HVAR649o4XfkUyYsOJlPLN2215GTpxP5u4cxl95Mv06N/A7kgkjnp2iLyJviMhWEVkSMK62iHwtIqvcn7W8Wr/x14L1O/nD+B/Izsnj3VGnWeEyQefl9UUmAv2OGDcWmKmqbYCZ7nMTZj5dtJkrXvuJmlWrMPXm0+nR1P5HmeDzrHip6hxg5xGjBwGT3OFJwCVerd+UP1XltTlrGf3OQro2qsFHN51ul2U2ninvNq/6qrrZHc4E6hc3o4iMAkYBNG3atByimbLIL1DGfbKUSXPXc1GXhvx9SDe7xZjxlG+XpVQtvNd5sdNfVdVkVU1OSEgox2SmtH47kMcf31rApLnrGfW7lvzziu5WuIznynvPa4uINFTVzSLSENhazus3QbYtez/XTZrPkk27eWRQJ67u1dzvSKaSKO89r+nAcHd4ODCtnNdvgmj11r38/uX/smrLXl69OtkKlylXnu15ici7wNlAXRFJBx4EngSmiMh1wHpgiFfrN976ae0ORr21gOhI4f0/nkbXxjX9jmQqGc+Kl6peUcykvl6t05SPaambuPuDRTSpHcfEkT1pUruq35FMJWQ97E2JqSrj/28NT32xgp4tavPa1cnUqBrtdyxTSVnxMiWSl1/AA9OW8u68DQzslsjTg7sSE2XfKBr/WPEyx/Xr/jxueWchs1ds4+azW3HX+e3sbtXGd1a8zDFt2ZPDtRPnszwzm8d/34Vhp1qHYVMxWPEyxVq5JZsRb8wja18u/x6ezDnt6vkdyZiDrHiZIv2wejt/fHsBcdGRTPljLzo3quF3JGMOY8XLHOXjhenc+9EiWtStxoSRPWlUM87vSMYcxYqXOUhV+ees1Tz79UpOb1WH8VedTI046wphKiYrXgaA3PwC/jJ1MVNS0rm0RyOevLQrVaJ8O2/fmOOy4mXIzsnl5skL+W7Vdv7Utw23n9sGEesKYSo2K16V3Obd+xg5YT6rt+7lqT90ZcgpTfyOZEyJWPGqxJZt3sPICfPZuz+PCSNP4cw2dt00EzqseFVSc1Zu4+bJC6keE8UHN/aiQ8OT/I5kTKlY8aqEpqRs5M8fL6Z1vepMGH
kKDWtYVwgTeqx4VSKqynPfrOKFmas4s01dXr6yB/Gx1hXChCYrXpXEgbwCxn68iI8XbmJIcmMe+30XoiOtK4QJXVa8KoHd+3K56e0F/LBmB3ee15Zb+rS2rhAm5FnxCnObsvYxcsI8/rf9V54d0o1LezT2O5IxQWHFK4wt2bSbayfOZ19uPpNG9uT01nX9jmRM0FjxClOzV2xl9OSF1KpahbevP5W29eP9jmRMUFnxCkPv/LSBB6YtoX2DeCaMOIV6J8X6HcmYoLPiFUYKCpRnvlrBy9+u4Zx2Cbw4rAfVYuxXbMKTbdlhYn9ePnd/sIjpaRkMO7Up4wZ2Isq6QpgwZsUrDGT9doBRby1g3v92cm+/9tx4VkvrCmHCnhWvELdx52+MmDCPjTv38Y/LkxiU1MjvSMaUCyteIWxRehbXTpxPbr7y1nU9ObVlHb8jGVNuwqp4fbk0k3s/WkRUhBAVEUFkhBAdKe5P53lUhBAVGTgtgugIZ56oSOd1Ue5wZETEwdcXvq5w2c68RS876uCy3BwBw1FFDB+2jMgipkVEHHWfxG9+2cKt7/5MnepVeG9UT1rXq+7Tp26MP8KqeDWqGcegbonkFij5+UpuQQH5BUpevpLnDufmq/uzgP25BeQW5JNfUODOc2jaoXkLyHOXkV/gLFO1/N+bCERHHCpwe/fn0bVRDf49/BQS4mPKP5AxPgur4tW5UY1yuUVXQUFAYQwojnn5hxe/w6YdMZx/1HglL7/gqAJ61OvcaTXiohn1u5ZUrRJWv0JjSsy2/BMQESHERET6HcOYSs06AhljQpIVL2NMSLLiZYwJSVa8jDEhyYqXMSYkWfEyxoQkK17GmJBkxcsYE5JE/TjXpZREZBuwvoSz1wW2exinorD3GV7sfRatmaomFDUhJIpXaYhIiqom+53Da/Y+w4u9z9Kzw0ZjTEiy4mWMCUnhWLxe9TtAObH3GV7sfZZS2LV5GWMqh3Dc8zLGVAJWvIwxISmsipeI9BORFSKyWkTG+p3HCyLyhohsFZElfmfxkog0EZHZIvKLiCwVkTF+Z/KCiMSKyDwRSXPf58N+Z/KSiESKyM8iMqOsywqb4iUikcBLwIVAR+AKEenobypPTAT6+R2iHOQBd6pqR+A0YHSY/j73A31UtRuQBPQTkdP8jeSpMcCyYCwobIoX0BNYraprVfUA8B4wyOdMQaeqc4CdfufwmqpuVtWF7nA2zgYfdjelVMde92m0+wjLb9FEpDFwEfDvYCwvnIpXI2BjwPN0wnBjr4xEpDnQHfjJ5yiecA+lUoGtwNeqGpbvE3geuAcoCMbCwql4mTAkItWBj4DbVHWP33m8oKr5qpoENAZ6ikhnnyMFnYgMALaq6oJgLTOcitcmoEnA88buOBOiRCQap3BNVtWP/c7jNVXNAmYTnm2avYGBIrIOp0mnj4i8XZYFhlPxmg+0EZEWIlIFuByY7nMmc4JERIDXgWWq+qzfebwiIgkiUtMdjgPOA5b7GsoDqnqfqjZW1eY4f5uzVPWqsiwzbIqXquYBtwBf4jTuTlHVpf6mCj4ReReYC7QTkXQRuc7vTB7pDVyN8x861X309zuUBxoCs0VkEc4/4K9VtczdCCoDOz3IGBOSwmbPyxhTuVjxMsaEJCtexpiQZMXLGBOSrHgZY0KSFS/jKRHJD+jqkBrMq32ISPNwv7qGKV6U3wFM2NvnnvpiTFDZnpfxhYisE5GnRGSxez2r1u745iIyS0QWichMEWnqjq8vIlPd616licjp7qIiReQ191pYX7m91BGRP7nXAlskIu/59DaNh6x4Ga/FHXHYODRg2m5V7QK8iHPFAYB/ApNUtSswGXjBHf8C8H/uda96AIVnT7QBXlLVTkAW8Ad3/Figu7ucG715a8ZP1sPeeEpE9qpq9SLGr8O5CN9a9wTsTFWtIyLbgYaqmuuO36yqdd27pjdW1f0By2iOczpNG/f5vUC0qj4qIl8Ae4H/AP8JuGaWCR
O252X8pMUMl8b+gOF8DrXjXoRzZd0ewHwRsfbdMGPFy/hpaMDPue7wDzhXHQC4EvjOHZ4J3AQHL95Xo7iFikgE0ERVZwP3AjWAo/b+TGiz/0bGa3HuVUILfaGqhd0larlXU9gPXOGOuxWYICJ3A9uAke74McCr7lU08nEK2eZi1hkJvO0WOAFecK+VZcKItXkZX7htXsmqut3vLCY02WGjMSYk2Z6XMSYk2Z6XMSYkWfEyxoQkK17GmJBkxcsYE5KseBljQtL/A/w8USoPr020AAAAAElFTkSuQmCC\n", - "text/plain": [ - "
" - ] - }, - "metadata": { - "needs_background": "light" - }, - "output_type": "display_data" - } - ], - "source": [ - "training_vis(train_losses, valid_losses)" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "#### 模型评估\n", - "\n", - "在测试集上评估模型效果。" - ] - }, - { - "cell_type": "code", - "execution_count": 31, - "metadata": { - "execution": { - "iopub.execute_input": "2021-11-29T04:55:47.340416Z", - "iopub.status.busy": "2021-11-29T04:55:47.339537Z", - "iopub.status.idle": "2021-11-29T04:55:47.606447Z", - "shell.execute_reply": "2021-11-29T04:55:47.607038Z", - "shell.execute_reply.started": "2021-11-28T14:01:44.127754Z" - }, - "papermill": { - "duration": 2.453872, - "end_time": "2021-11-29T04:55:47.607210", - "exception": false, - "start_time": "2021-11-29T04:55:45.153338", - "status": "completed" - }, - "tags": [] - }, - "outputs": [ - { - "data": { - "text/plain": [ - "" - ] - }, - "execution_count": 31, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "# 加载得分最高的模型\n", - "checkpoint = torch.load('../input/ai-earth-model-weights/task05_model_weights.pth')\n", - "model = SAConvLSTM(input_dim, hidden_dim, d_attn, kernel_size)\n", - "model.load_state_dict(checkpoint['state_dict'])" - ] - }, - { - "cell_type": "code", - "execution_count": 32, - "metadata": { - "execution": { - "iopub.execute_input": "2021-11-29T04:55:51.849996Z", - "iopub.status.busy": "2021-11-29T04:55:51.849073Z", - "iopub.status.idle": "2021-11-29T04:55:51.851492Z", - "shell.execute_reply": "2021-11-29T04:55:51.850969Z", - "shell.execute_reply.started": "2021-11-28T14:06:59.931318Z" - }, - "papermill": { - "duration": 2.125413, - "end_time": "2021-11-29T04:55:51.851629", - "exception": false, - "start_time": "2021-11-29T04:55:49.726216", - "status": "completed" - }, - "tags": [] - }, - "outputs": [], - "source": [ - "# 测试集路径\n", - "test_path = '../input/ai-earth-tests/'\n", - "# 测试集标签路径\n", - "test_label_path = '../input/ai-earth-tests-labels/'" 
# %%
# Read the test samples and their labels.
# The label directory is indexed with the same file names as the sample
# directory, so both lists stay aligned. os.listdir returns entries in
# arbitrary order; sorting makes the sample order reproducible between runs.
# (`os` is already imported in the notebook's top import cell.)
test_files = sorted(os.listdir(test_path))
test_inputs = [np.load(os.path.join(test_path, name)) for name in test_files]
test_labels = [np.load(os.path.join(test_label_path, name)) for name in test_files]

# %%
# Stack the per-file arrays and crop the inputs: keep indices 19..66 along
# axis 3 and only the first entry of the last axis. The stored output of this
# cell was ((103, 12, 24, 48, 1), (103, 24)).
# NOTE(review): presumably the same spatial window / variable selection that
# was applied to the training data - confirm against the data-preparation cells.
X_test = np.array(test_inputs)[:, :, :, 19:67, :1]
y_test = np.array(test_labels)
X_test.shape, y_test.shape

# %%
testset = AIEarthDataset(X_test, y_test)
# shuffle=False keeps the batches in the same order as y_test, which the
# evaluation below relies on when it writes predictions back by position.
testloader = DataLoader(testset, batch_size=batch_size, shuffle=False)

# %%
# Evaluate the model on the test set.
model.to(device)
model.eval()
# One row of predictions per test sample, same layout as the labels.
preds = np.zeros(y_test.shape)
offset = 0
# Inference only: no_grad stops autograd from recording a graph for every
# batch, which saves memory and time.
with torch.no_grad():
    # Wrapping the loader itself lets tqdm show the total number of batches.
    # The labels in each batch are not needed here; scoring uses y_test.
    for inputs, _ in tqdm(testloader):
        pred = model(inputs.to(device), train=False)
        n_rows = len(pred)
        # Advance by the real batch length so a short final batch is handled
        # without depending on the global batch_size.
        preds[offset:offset + n_rows] = pred.cpu().numpy()
        offset += n_rows
s = score(y_test, preds)
print('Score: {:.3f}'.format(s))

# %% [markdown]
# ## Summary
#
# This TOP solution did not design a model of its own; it used a model that
# currently performs well in the field of spatio-temporal sequence prediction.
# Another TOP team, "ailab", likewise used a classic model from this field,
# PredRNN++. In other tasks you can also try state-of-the-art models from the
# relevant field. For some typical models in spatio-temporal sequence
# prediction, see https://www.zhihu.com/column/c_1208033701705162752

# %% [markdown]
# ## Homework
#
# This TOP solution uses sst as the prediction target and then converts it into
# the nino3.4 index. Students with spare capacity can try using the SA-ConvLSTM
# model to predict the nino3.4 index directly.

# %% [markdown]
# ## References
#
# 1. Solution write-up by the team "吴先生的队伍": https://tianchi.aliyun.com/forum/postDetail?spm=5176.12586969.1002.9.561d5330dF9lX1&postId=231465
# 2.
ailab团队思路分享:https://tianchi.aliyun.com/forum/postDetail?spm=5176.12586969.1002.15.561d5330dF9lX1&postId=210734" - ] - } - ], - "metadata": { - "kernelspec": { - "display_name": "Python 3", - "language": "python", - "name": "python3" - }, - "language_info": { - "codemirror_mode": { - "name": "ipython", - "version": 3 - }, - "file_extension": ".py", - "mimetype": "text/x-python", - "name": "python", - "nbconvert_exporter": "python", - "pygments_lexer": "ipython3", - "version": "3.7.3" - }, - "papermill": { - "default_parameters": {}, - "duration": 6708.571081, - "end_time": "2021-11-29T04:56:15.789285", - "environment_variables": {}, - "exception": true, - "input_path": "__notebook__.ipynb", - "output_path": "__notebook__.ipynb", - "parameters": {}, - "start_time": "2021-11-29T03:04:27.218204", - "version": "2.3.3" - } - }, - "nbformat": 4, - "nbformat_minor": 5 -}