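From bcc434537cd97e5f2b382f5a7a474df6b2c8d9f6 Mon Sep 17 00:00:00 2001 From: Harold-Ran <56714856+Harold-Ran@users.noreply.github.com> Date: Sat, 4 Dec 2021 20:40:24 +0800 Subject: [PATCH] Add files via upload --- Task5/Task5 模型建立之SA-ConvLSTM.ipynb | 1818 +++++++++++++++++++++++ 1 file changed, 1818 insertions(+) create mode 100644 Task5/Task5 模型建立之SA-ConvLSTM.ipynb diff --git a/Task5/Task5 模型建立之SA-ConvLSTM.ipynb b/Task5/Task5 模型建立之SA-ConvLSTM.ipynb new file mode 100644 index 0000000..9ac6f98 --- /dev/null +++ b/Task5/Task5 模型建立之SA-ConvLSTM.ipynb @@ -0,0 +1,1818 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# Datawhale Meteorological and Ocean Prediction - Task5: Model Building with SA-ConvLSTM\n", + "\n", + "In this task we study the modeling solution of the top-ranked team “吴先生的队伍”, which uses the SA-ConvLSTM model.\n", + "\n", + "The previous two top solutions treated the competition as a multi-output task: a neural network directly outputs the 24 Nino3.4 predictions. The weakness of this approach is that sequence problems usually have temporal dependence. A multi-output model effectively treats the 24 Nino3.4 predictions as independent of each other, whereas in reality each prediction tends to be influenced by the prediction at the previous time step. This top solution therefore uses a Seq2Seq structure to account for the sequential dependence among the outputs.\n", + "\n", + "A Seq2Seq structure consists of an Encoder and a Decoder. The Encoder encodes the input sequence into a vector, and the Decoder decodes that vector into a predicted sequence. The key to applying Seq2Seq to different sequence problems is the Cell used at each time step. As mentioned earlier, CNNs are typically used to extract spatial information and RNNs or LSTMs to extract temporal information; combining the two gives ConvLSTM, a classic model for spatiotemporal sequences. The SA-ConvLSTM model studied here improves on ConvLSTM by adding a self-attention mechanism that strengthens the model's ability to capture long-range spatial dependencies.\n", + "\n", + "Another difference from the previous two top solutions is that this one does not predict the Nino3.4 index directly; it predicts sst and derives the Nino3.4 index sequence from it." + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Learning objectives\n", + "1. Learn how the top solution builds its model" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Contents\n", + "1. Data processing\n", + "    - Data flattening\n", + "    - Filling missing values\n", + "    - Building the dataset\n", + "2. Model building\n", + "    - Defining the evaluation function\n", + "    - Model construction\n", + "    - Model training\n", + "    - Model evaluation\n", + "3. Summary" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Code example" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Data processing\n", + "Data processing in this top solution has three main parts:\n", + "1. Data flattening.\n", + "2. Filling missing values.\n", + "3. 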
Building the dataset" + ] + }, + { + "cell_type": "code", + "execution_count": 1, + "metadata": { + "execution": { + "iopub.execute_input": "2021-11-29T03:04:34.698663Z", + "iopub.status.busy": "2021-11-29T03:04:34.697133Z", + "iopub.status.idle": "2021-11-29T03:04:37.035400Z", + "shell.execute_reply": "2021-11-29T03:04:37.034767Z", + "shell.execute_reply.started": "2021-11-29T01:02:51.883602Z" + }, + "papermill": { + "duration": 2.370278, + "end_time": "2021-11-29T03:04:37.035673", + "exception": false, + "start_time": "2021-11-29T03:04:34.665395", + "status": "completed" + }, + "tags": [] + }, + "outputs": [], + "source": [ + "import netCDF4 as nc\n", + "import random\n", + "import os\n", + "from tqdm import tqdm\n", + "import pandas as pd\n", + "import numpy as np\n", + "import math\n", + "import matplotlib.pyplot as plt\n", + "%matplotlib inline\n", + "\n", + "import torch\n", + "from torch import nn, optim\n", + "import torch.nn.functional as F\n", + "from torch.utils.data import Dataset, DataLoader\n", + "from torch.optim.lr_scheduler import ReduceLROnPlateau\n", + "\n", + "from sklearn.metrics import mean_squared_error" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "metadata": { + "execution": { + "iopub.execute_input": "2021-11-29T03:04:37.102995Z", + "iopub.status.busy": "2021-11-29T03:04:37.102144Z", + "iopub.status.idle": "2021-11-29T03:04:37.107646Z", + "shell.execute_reply": "2021-11-29T03:04:37.107161Z", + "shell.execute_reply.started": "2021-11-29T01:02:54.06493Z" + }, + "papermill": { + "duration": 0.040737, + "end_time": "2021-11-29T03:04:37.107761", + "exception": false, + "start_time": "2021-11-29T03:04:37.067024", + "status": "completed" + }, + "tags": [] + }, + "outputs": [], + "source": [ + "# Fix the random seeds for reproducibility\n", + "SEED = 22\n", + "\n", + "def seed_everything(seed=42):\n", + "    random.seed(seed)\n", + "    os.environ['PYTHONHASHSEED'] = str(seed)\n", + "    np.random.seed(seed)\n", + "    torch.manual_seed(seed)\n", + "    torch.cuda.manual_seed(seed)\n", + "    
torch.backends.cudnn.deterministic = True\n", + "\n", + "seed_everything(SEED)" + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "metadata": { + "execution": { + "iopub.execute_input": "2021-11-29T03:04:37.222525Z", + "iopub.status.busy": "2021-11-29T03:04:37.221100Z", + "iopub.status.idle": "2021-11-29T03:04:37.225844Z", + "shell.execute_reply": "2021-11-29T03:04:37.226442Z", + "shell.execute_reply.started": "2021-11-29T01:02:54.074875Z" + }, + "papermill": { + "duration": 0.090198, + "end_time": "2021-11-29T03:04:37.226602", + "exception": false, + "start_time": "2021-11-29T03:04:37.136404", + "status": "completed" + }, + "tags": [] + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "CUDA is available! Training on GPU ...\n" + ] + } + ], + "source": [ + "# Check whether CUDA is available\n", + "train_on_gpu = torch.cuda.is_available()\n", + "\n", + "if not train_on_gpu:\n", + "    print('CUDA is not available. Training on CPU ...')\n", + "else:\n", + "    print('CUDA is available! 
Training on GPU ...')" + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "metadata": { + "execution": { + "iopub.execute_input": "2021-11-29T03:04:37.353082Z", + "iopub.status.busy": "2021-11-29T03:04:37.352143Z", + "iopub.status.idle": "2021-11-29T03:04:37.432852Z", + "shell.execute_reply": "2021-11-29T03:04:37.434332Z", + "shell.execute_reply.started": "2021-11-28T10:13:13.644947Z" + }, + "papermill": { + "duration": 0.179146, + "end_time": "2021-11-29T03:04:37.434792", + "exception": false, + "start_time": "2021-11-29T03:04:37.255646", + "status": "completed" + }, + "tags": [] + }, + "outputs": [], + "source": [ + "# Load the data\n", + "\n", + "# Directory containing the data files\n", + "path = '/kaggle/input/ninoprediction/'\n", + "soda_train = nc.Dataset(path + 'SODA_train.nc')\n", + "soda_label = nc.Dataset(path + 'SODA_label.nc')\n", + "cmip_train = nc.Dataset(path + 'CMIP_train.nc')\n", + "cmip_label = nc.Dataset(path + 'CMIP_label.nc')" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "#### Data flattening\n", + "The dataset is built with a sliding window. This solution uses only the sst feature, and only the data with lon in the range [90, 330], possibly to save computing resources." + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "metadata": { + "execution": { + "iopub.execute_input": "2021-11-29T03:04:37.548239Z", + "iopub.status.busy": "2021-11-29T03:04:37.546737Z", + "iopub.status.idle": "2021-11-29T03:04:37.551951Z", + "shell.execute_reply": "2021-11-29T03:04:37.553081Z", + "shell.execute_reply.started": "2021-11-27T13:38:32.620904Z" + }, + "papermill": { + "duration": 0.065069, + "end_time": "2021-11-29T03:04:37.553274", + "exception": false, + "start_time": "2021-11-29T03:04:37.488205", + "status": "completed" + }, + "tags": [] + }, + "outputs": [], + "source": [ + "def make_flatted(train_ds, label_ds, info, start_idx=0):\n", + "    # Use only the sst feature\n", + "    keys = ['sst']\n", + "    label_key = 'nino'\n", + "    # Number of years\n", + "    years = info[1]\n", + "    # Number of climate models\n", + "    models = info[2]\n", + "\n", + "    train_list = []\n", + "    label_list = []\n", + "\n", + "    # 
Concatenate the data belonging to the same climate model\n", + "    for model_i in range(models):\n", + "        blocks = []\n", + "\n", + "        # For each feature, take the first 12 months of every sample and concatenate them, keeping only lon in [90, 330]\n", + "        for key in keys:\n", + "            block = train_ds[key][start_idx + model_i * years: start_idx + (model_i + 1) * years, :12, :, 19: 67].reshape(-1, 24, 48, 1).data\n", + "            blocks.append(block)\n", + "\n", + "        # Concatenate all features along the last dimension\n", + "        train_flatted = np.concatenate(blocks, axis=-1)\n", + "\n", + "        # Concatenate the labels of months 12-23, then append the labels of the last 12 months (24-35) of the final year\n", + "        # (together with months 12-23 of the final year they form the prediction target for that year's first 12 months)\n", + "        label_flatted = np.concatenate([\n", + "            label_ds[label_key][start_idx + model_i * years: start_idx + (model_i + 1) * years, 12: 24].reshape(-1).data,\n", + "            label_ds[label_key][start_idx + (model_i + 1) * years - 1, 24: 36].reshape(-1).data\n", + "        ], axis=0)\n", + "\n", + "        train_list.append(train_flatted)\n", + "        label_list.append(label_flatted)\n", + "\n", + "    return train_list, label_list" + ] + }, + { + "cell_type": "code", + "execution_count": 6, + "metadata": { + "execution": { + "iopub.execute_input": "2021-11-29T03:04:37.661954Z", + "iopub.status.busy": "2021-11-29T03:04:37.660977Z", + "iopub.status.idle": "2021-11-29T03:05:11.515409Z", + "shell.execute_reply": "2021-11-29T03:05:11.515853Z", + "shell.execute_reply.started": "2021-11-27T13:38:33.844185Z" + }, + "papermill": { + "duration": 33.912013, + "end_time": "2021-11-29T03:05:11.516001", + "exception": false, + "start_time": "2021-11-29T03:04:37.603988", + "status": "completed" + }, + "tags": [] + }, + "outputs": [ + { + "data": { + "text/plain": [ + "((1, 1200, 24, 48, 1), (15, 1812, 24, 48, 1), (17, 1680, 24, 48, 1))" + ] + }, + "execution_count": 6, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "soda_info = ('soda', 100, 1)\n", + "cmip6_info = ('cmip6', 151, 15)\n", + "cmip5_info = ('cmip5', 140, 17)\n", + "\n", + "soda_trains, soda_labels = make_flatted(soda_train, soda_label, soda_info)\n", + "cmip6_trains, cmip6_labels = make_flatted(cmip_train, 
cmip_label, cmip6_info)\n", + "cmip5_trains, cmip5_labels = make_flatted(cmip_train, cmip_label, cmip5_info, cmip6_info[1]*cmip6_info[2])\n", + "\n", + "# The flattened data has shape (models × sequence length × latitude × longitude × features), where sequence length = years × 12\n", + "np.shape(soda_trains), np.shape(cmip6_trains), np.shape(cmip5_trains)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "#### Filling missing values\n", + "Missing values are filled with 0." + ] + }, + { + "cell_type": "code", + "execution_count": 8, + "metadata": { + "execution": { + "iopub.execute_input": "2021-11-29T03:05:11.638562Z", + "iopub.status.busy": "2021-11-29T03:05:11.637553Z", + "iopub.status.idle": "2021-11-29T03:05:11.644302Z", + "shell.execute_reply": "2021-11-29T03:05:11.644742Z", + "shell.execute_reply.started": "2021-11-27T13:39:22.665855Z" + }, + "papermill": { + "duration": 0.040786, + "end_time": "2021-11-29T03:05:11.644893", + "exception": false, + "start_time": "2021-11-29T03:05:11.604107", + "status": "completed" + }, + "tags": [] + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Number of null in soda_trains after fillna: 0\n" + ] + } + ], + "source": [ + "# Fill missing values in the SODA data\n", + "soda_trains = np.array(soda_trains)\n", + "soda_trains_nan = np.isnan(soda_trains)\n", + "soda_trains[soda_trains_nan] = 0\n", + "print('Number of null in soda_trains after fillna:', np.sum(np.isnan(soda_trains)))" + ] + }, + { + "cell_type": "code", + "execution_count": 9, + "metadata": { + "execution": { + "iopub.execute_input": "2021-11-29T03:05:11.709054Z", + "iopub.status.busy": "2021-11-29T03:05:11.707767Z", + "iopub.status.idle": "2021-11-29T03:05:11.862744Z", + "shell.execute_reply": "2021-11-29T03:05:11.863294Z", + "shell.execute_reply.started": "2021-11-27T13:39:24.110039Z" + }, + "papermill": { + "duration": 0.18937, + "end_time": "2021-11-29T03:05:11.863480", + "exception": false, + "start_time": "2021-11-29T03:05:11.674110", + "status": "completed" + }, + "tags": [] + }, + "outputs": [ + { + "name": "stdout", + "output_type": 
"stream", + "text": [ + "Number of null in cmip6_trains after fillna: 0\n" + ] + } + ], + "source": [ + "# 填充CMIP6数据中的空值\n", + "cmip6_trains = np.array(cmip6_trains)\n", + "cmip6_trains_nan = np.isnan(cmip6_trains)\n", + "cmip6_trains[cmip6_trains_nan] = 0\n", + "print('Number of null in cmip6_trains after fillna:', np.sum(np.isnan(cmip6_trains)))" + ] + }, + { + "cell_type": "code", + "execution_count": 10, + "metadata": { + "execution": { + "iopub.execute_input": "2021-11-29T03:05:11.927752Z", + "iopub.status.busy": "2021-11-29T03:05:11.925353Z", + "iopub.status.idle": "2021-11-29T03:05:12.091117Z", + "shell.execute_reply": "2021-11-29T03:05:12.091855Z", + "shell.execute_reply.started": "2021-11-27T13:39:24.520724Z" + }, + "papermill": { + "duration": 0.197975, + "end_time": "2021-11-29T03:05:12.092014", + "exception": false, + "start_time": "2021-11-29T03:05:11.894039", + "status": "completed" + }, + "tags": [] + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Number of null in cmip6_trains after fillna: 0\n" + ] + } + ], + "source": [ + "# 填充CMIP5数据中的空值\n", + "cmip5_trains = np.array(cmip5_trains)\n", + "cmip5_trains_nan = np.isnan(cmip5_trains)\n", + "cmip5_trains[cmip5_trains_nan] = 0\n", + "print('Number of null in cmip6_trains after fillna:', np.sum(np.isnan(cmip5_trains)))" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "#### 构造数据集\n", + "构造训练和验证集。注意这里取每条输入数据的序列长度是38,这是因为输入sst序列长度是12,输出sst序列长度是26,在训练中采用teacher forcing策略(这个策略会在之后的模型构造时详细说明),因此这里在构造输入数据时包含了输出sst序列的实际值。" + ] + }, + { + "cell_type": "code", + "execution_count": 11, + "metadata": { + "execution": { + "iopub.execute_input": "2021-11-29T03:05:12.165242Z", + "iopub.status.busy": "2021-11-29T03:05:12.164045Z", + "iopub.status.idle": "2021-11-29T03:05:12.480257Z", + "shell.execute_reply": "2021-11-29T03:05:12.479767Z", + "shell.execute_reply.started": "2021-11-27T13:39:25.418945Z" + }, + "papermill": { + "duration": 0.361254, + 
"end_time": "2021-11-29T03:05:12.480405", + "exception": false, + "start_time": "2021-11-29T03:05:12.119151", + "status": "completed" + }, + "tags": [] + }, + "outputs": [], + "source": [ + "# 构造训练集\n", + "\n", + "X_train = []\n", + "y_train = []\n", + "# 从CMIP5的17种模式中各抽取100条数据\n", + "for model_i in range(17):\n", + " samples = np.random.choice(cmip5_trains.shape[1]-38, size=100)\n", + " for ind in samples:\n", + " X_train.append(cmip5_trains[model_i, ind: ind+38])\n", + " y_train.append(cmip5_labels[model_i][ind: ind+24])\n", + "# 从CMIP6的15种模式种各抽取100条数据\n", + "for model_i in range(15):\n", + " samples = np.random.choice(cmip6_trains.shape[1]-38, size=100)\n", + " for ind in samples:\n", + " X_train.append(cmip6_trains[model_i, ind: ind+38])\n", + " y_train.append(cmip6_labels[model_i][ind: ind+24])\n", + "X_train = np.array(X_train)\n", + "y_train = np.array(y_train)" + ] + }, + { + "cell_type": "code", + "execution_count": 12, + "metadata": { + "execution": { + "iopub.execute_input": "2021-11-29T03:05:12.541232Z", + "iopub.status.busy": "2021-11-29T03:05:12.540676Z", + "iopub.status.idle": "2021-11-29T03:05:12.548103Z", + "shell.execute_reply": "2021-11-29T03:05:12.547520Z", + "shell.execute_reply.started": "2021-11-27T13:39:26.341849Z" + }, + "papermill": { + "duration": 0.040262, + "end_time": "2021-11-29T03:05:12.548224", + "exception": false, + "start_time": "2021-11-29T03:05:12.507962", + "status": "completed" + }, + "tags": [] + }, + "outputs": [], + "source": [ + "# 构造测试集\n", + "\n", + "X_valid = []\n", + "y_valid = []\n", + "samples = np.random.choice(soda_trains.shape[1]-38, size=100)\n", + "for ind in samples:\n", + " X_valid.append(soda_trains[0, ind: ind+38])\n", + " y_valid.append(soda_labels[0][ind: ind+24])\n", + "X_valid = np.array(X_valid)\n", + "y_valid = np.array(y_valid)" + ] + }, + { + "cell_type": "code", + "execution_count": 13, + "metadata": { + "execution": { + "iopub.execute_input": "2021-11-29T03:05:12.606407Z", + "iopub.status.busy": 
"2021-11-29T03:05:12.605555Z", + "iopub.status.idle": "2021-11-29T03:05:12.611580Z", + "shell.execute_reply": "2021-11-29T03:05:12.611152Z", + "shell.execute_reply.started": "2021-11-27T13:39:27.247585Z" + }, + "papermill": { + "duration": 0.036214, + "end_time": "2021-11-29T03:05:12.611721", + "exception": false, + "start_time": "2021-11-29T03:05:12.575507", + "status": "completed" + }, + "tags": [] + }, + "outputs": [ + { + "data": { + "text/plain": [ + "((3200, 38, 24, 48, 1), (3200, 24), (100, 38, 24, 48, 1), (100, 24))" + ] + }, + "execution_count": 13, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "# 查看数据集维度\n", + "X_train.shape, y_train.shape, X_valid.shape, y_valid.shape" + ] + }, + { + "cell_type": "code", + "execution_count": 15, + "metadata": { + "execution": { + "iopub.execute_input": "2021-11-29T03:05:12.737322Z", + "iopub.status.busy": "2021-11-29T03:05:12.736558Z", + "iopub.status.idle": "2021-11-29T03:05:13.516712Z", + "shell.execute_reply": "2021-11-29T03:05:13.517217Z", + "shell.execute_reply.started": "2021-11-27T13:39:38.421657Z" + }, + "papermill": { + "duration": 0.812187, + "end_time": "2021-11-29T03:05:13.517368", + "exception": false, + "start_time": "2021-11-29T03:05:12.705181", + "status": "completed" + }, + "tags": [] + }, + "outputs": [], + "source": [ + "# 保存数据集\n", + "np.save('X_train_sample.npy', X_train)\n", + "np.save('y_train_sample.npy', y_train)\n", + "np.save('X_valid_sample.npy', X_valid)\n", + "np.save('y_valid_sample.npy', y_valid)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### 模型构建" + ] + }, + { + "cell_type": "code", + "execution_count": 16, + "metadata": { + "execution": { + "iopub.execute_input": "2021-11-29T03:05:13.577516Z", + "iopub.status.busy": "2021-11-29T03:05:13.576992Z", + "iopub.status.idle": "2021-11-29T03:05:21.917657Z", + "shell.execute_reply": "2021-11-29T03:05:21.918265Z", + "shell.execute_reply.started": "2021-11-29T01:03:01.505192Z" + }, + 
"papermill": { + "duration": 8.372964, + "end_time": "2021-11-29T03:05:21.918443", + "exception": false, + "start_time": "2021-11-29T03:05:13.545479", + "status": "completed" + }, + "tags": [] + }, + "outputs": [], + "source": [ + "# 读取数据集\n", + "X_train = np.load('../input/ai-earth-task05-samples/X_train_sample.npy')\n", + "y_train = np.load('../input/ai-earth-task05-samples/y_train_sample.npy')\n", + "X_valid = np.load('../input/ai-earth-task05-samples/X_valid_sample.npy')\n", + "y_valid = np.load('../input/ai-earth-task05-samples/y_valid_sample.npy')" + ] + }, + { + "cell_type": "code", + "execution_count": 17, + "metadata": { + "execution": { + "iopub.execute_input": "2021-11-29T03:05:21.983898Z", + "iopub.status.busy": "2021-11-29T03:05:21.982953Z", + "iopub.status.idle": "2021-11-29T03:05:21.986939Z", + "shell.execute_reply": "2021-11-29T03:05:21.986453Z", + "shell.execute_reply.started": "2021-11-29T01:03:11.548945Z" + }, + "papermill": { + "duration": 0.039398, + "end_time": "2021-11-29T03:05:21.987066", + "exception": false, + "start_time": "2021-11-29T03:05:21.947668", + "status": "completed" + }, + "tags": [] + }, + "outputs": [ + { + "data": { + "text/plain": [ + "((3200, 38, 24, 48, 1), (3200, 24), (100, 38, 24, 48, 1), (100, 24))" + ] + }, + "execution_count": 17, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "X_train.shape, y_train.shape, X_valid.shape, y_valid.shape" + ] + }, + { + "cell_type": "code", + "execution_count": 18, + "metadata": { + "execution": { + "iopub.execute_input": "2021-11-29T03:05:22.341929Z", + "iopub.status.busy": "2021-11-29T03:05:22.340932Z", + "iopub.status.idle": "2021-11-29T03:05:22.346140Z", + "shell.execute_reply": "2021-11-29T03:05:22.346878Z", + "shell.execute_reply.started": "2021-11-29T01:03:11.560457Z" + }, + "papermill": { + "duration": 0.143838, + "end_time": "2021-11-29T03:05:22.347113", + "exception": false, + "start_time": "2021-11-29T03:05:22.203275", + "status": "completed" + 
}, + "tags": [] + }, + "outputs": [], + "source": [ + "# Build the data pipeline\n", + "class AIEarthDataset(Dataset):\n", + "    def __init__(self, data, label):\n", + "        self.data = torch.tensor(data, dtype=torch.float32)\n", + "        self.label = torch.tensor(label, dtype=torch.float32)\n", + "\n", + "    def __len__(self):\n", + "        return len(self.label)\n", + "\n", + "    def __getitem__(self, idx):\n", + "        return self.data[idx], self.label[idx]" + ] + }, + { + "cell_type": "code", + "execution_count": 19, + "metadata": { + "execution": { + "iopub.execute_input": "2021-11-29T03:05:22.583350Z", + "iopub.status.busy": "2021-11-29T03:05:22.582298Z", + "iopub.status.idle": "2021-11-29T03:05:23.243100Z", + "shell.execute_reply": "2021-11-29T03:05:23.243851Z", + "shell.execute_reply.started": "2021-11-29T01:03:23.691846Z" + }, + "papermill": { + "duration": 0.825537, + "end_time": "2021-11-29T03:05:23.244098", + "exception": false, + "start_time": "2021-11-29T03:05:22.418561", + "status": "completed" + }, + "tags": [] + }, + "outputs": [], + "source": [ + "batch_size = 2\n", + "\n", + "trainset = AIEarthDataset(X_train, y_train)\n", + "trainloader = DataLoader(trainset, batch_size=batch_size, shuffle=True)\n", + "\n", + "validset = AIEarthDataset(X_valid, y_valid)\n", + "validloader = DataLoader(validset, batch_size=batch_size, shuffle=True)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "#### Defining the evaluation function" + ] + }, + { + "cell_type": "code", + "execution_count": 24, + "metadata": { + "execution": { + "iopub.execute_input": "2021-11-29T03:05:23.655820Z", + "iopub.status.busy": "2021-11-29T03:05:23.655241Z", + "iopub.status.idle": "2021-11-29T03:05:23.658416Z", + "shell.execute_reply": "2021-11-29T03:05:23.658859Z", + "shell.execute_reply.started": "2021-11-29T01:03:26.481561Z" + }, + "papermill": { + "duration": 0.040887, + "end_time": "2021-11-29T03:05:23.658990", + "exception": false, + "start_time": "2021-11-29T03:05:23.618103", + "status": "completed" + }, + "tags": [] + 
}, + "outputs": [], + "source": [ + "def rmse(y_true, y_preds):\n", + "    return np.sqrt(mean_squared_error(y_pred = y_preds, y_true = y_true))\n", + "\n", + "# Evaluation function\n", + "def score(y_true, y_preds):\n", + "    # Correlation skill score\n", + "    accskill_score = 0\n", + "    # RMSE\n", + "    rmse_scores = 0\n", + "    a = [1.5] * 4 + [2] * 7 + [3] * 7 + [4] * 6\n", + "    y_true_mean = np.mean(y_true, axis=0)\n", + "    y_pred_mean = np.mean(y_preds, axis=0)\n", + "    for i in range(24):\n", + "        # fenzi / fenmu are the numerator / denominator of the correlation coefficient for lead month i\n", + "        fenzi = np.sum((y_true[:, i] - y_true_mean[i]) * (y_preds[:, i] - y_pred_mean[i]))\n", + "        fenmu = np.sqrt(np.sum((y_true[:, i] - y_true_mean[i])**2) * np.sum((y_preds[:, i] - y_pred_mean[i])**2))\n", + "        cor_i = fenzi / fenmu\n", + "        accskill_score += a[i] * np.log(i+1) * cor_i\n", + "        rmse_score = rmse(y_true[:, i], y_preds[:, i])\n", + "        rmse_scores += rmse_score\n", + "    return 2/3.0 * accskill_score - rmse_scores" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "papermill": { + "duration": 0.028556, + "end_time": "2021-11-29T03:05:23.310560", + "exception": false, + "start_time": "2021-11-29T03:05:23.282004", + "status": "completed" + }, + "tags": [] + }, + "source": [ + "#### Model construction" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Unlike the multi-output neural networks built in the previous two top solutions, this top solution uses a Seq2Seq structure. For this competition the input sequence length is 12 and the output sequence length is 26, and the solution builds four hidden layers, so a basic Seq2Seq structure looks like the figure below:\n", + "\n", + "\n", + "\n", + "To apply the Seq2Seq structure to different problems, what matters is which Cell is used. The Cell used in this top solution is SA-ConvLSTM (Self-Attention ConvLSTM), proposed by Tsinghua University; for the original paper see https://ojs.aaai.org//index.php/AAAI/article/view/6819\n", + "\n", + "SA-ConvLSTM improves on ConvLSTM, the classic spatiotemporal sequence model proposed by Dr. Xingjian Shi (施行健). To capture the temporal dependence of spatial information, it adds a SAM module on top of ConvLSTM that memorizes aggregated spatial features. For the original ConvLSTM paper see https://arxiv.org/pdf/1506.04214.pdf\n", + "\n", + "1. 
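The ConvLSTM model\n", + "\n", + "LSTM is a classic sequence model. Its three-gate structure performs well on tasks with long-term temporal dependence, and compared with a plain RNN it largely mitigates the vanishing gradient problem. For a single input sample, each LSTM gate applies a fully connected layer to the input vector at every time step. In this competition the input X has shape (N,T,H,W,C), so at each time step a single sample feeds spatial information of shape (H,W,C) into the LSTM. Fully connected networks are not good at extracting this kind of spatial information, whereas convolutions greatly reduce the number of parameters and, when stacked in several layers, progressively extract more complex features. It is therefore natural to replace the fully connected operations in the LSTM with convolutions so that the model suits spatiotemporal sequence problems. This is exactly what ConvLSTM does, and practice has shown the approach to be very effective.\n", + "\n", + "\n", + "\n", + "2. The SAM module\n", + "\n", + "However, ConvLSTM has two problems:\n", + "\n", + "First, the receptive field of a convolutional layer is limited by its kernel size, so several convolutional layers must be stacked to enlarge the receptive field and discover global features. For example, if the first convolutional layer has a 3×3 kernel, each node in that layer only sees the input within a 3×3 spatial window. If we add another 3×3 convolutional layer, each of its nodes sees 3×3 nodes of the first layer, which with stride 1 in the first layer corresponds to a 5×5 window of the input. The second layer therefore sees a larger spatial range of the input than the first, but at the cost of more parameters. For a plain CNN an extra layer only adds the parameters of one more set of convolution kernels, but for ConvLSTM the burden is heavier: the extra parameters increase the risk of overfitting while the gain in model performance is small.\n", + "\n", + "Second, the convolution only operates on the spatial information of the current time step and ignores past spatial information, which makes it hard to capture how spatial information depends on time.\n", + "\n", + "Therefore, to capture both global and local spatial dependencies and to improve predictions on spatiotemporal forecasting tasks that span large spatial ranges and long time horizons, SA-ConvLSTM adds a SAM (self-attention memory) module to ConvLSTM.\n", + "\n", + "\n", + "\n", + "The SAM module introduces a new memory unit M that memorizes spatial information carrying temporal dependence. It takes as input the hidden state $H_t$ produced by ConvLSTM at the current time step and the memory $M_{t-1}$ from the previous time step. First, $H_t$ goes through self-attention to produce the feature $Z_h$; self-attention raises the weight of the parts of $H_t$ that are more strongly related to its other parts. $H_t$ also serves as the Query in an attention step with $M_{t-1}$ that produces the feature $Z_m$, which raises the weight of the parts of $M_{t-1}$ that depend more strongly on $H_t$. Concatenating $Z_h$ and $Z_m$ gives their aggregated feature $Z$. At this point $Z$ contains both the information of the current time step and the global spatiotemporal memory. Borrowing the gating structure of LSTM, $Z$ is then used to update the hidden state and the memory unit, which gives the updated hidden state $\\hat{H_t}$ and the memory of the current time step $M_t$. The formulas of the SAM module are:\n", + "\n", + "$$\n", + "\\begin{aligned}\n", + "& i'_t = \\sigma (W_{m;zi} \\ast Z + W_{m;hi} \\ast H_t + b_{m;i}) \\\\\n", + "& g'_t = \\tanh (W_{m;zg} \\ast Z + W_{m;hg} \\ast H_t + b_{m;g}) \\\\\n", + "& M_t = (1 - i'_t) \\circ M_{t-1} + i'_t \\circ g'_t \\\\\n", + "& o'_t = \\sigma (W_{m;zo} \\ast Z + W_{m;ho} \\ast H_t + b_{m;o}) \\\\\n", + "& \\hat{H_t} = o'_t \\circ M_t\n", + "\\end{aligned}\n", + "$$\n", + "\n", + "For more on attention and self-attention, see the following links:\n", + "\n", + " - Attention mechanisms in deep learning: https://blog.csdn.net/malefactor/article/details/78767781\n", + " - Current mainstream attention methods: https://www.zhihu.com/question/68482809\n", + "\n", + "3. 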
The SA-ConvLSTM model\n", + "\n", + "Combining the two gives the SA-ConvLSTM model:\n", + "\n", + "" + ] + }, + { + "cell_type": "code", + "execution_count": 20, + "metadata": { + "execution": { + "iopub.execute_input": "2021-11-29T03:05:23.372772Z", + "iopub.status.busy": "2021-11-29T03:05:23.371873Z", + "iopub.status.idle": "2021-11-29T03:05:23.373700Z", + "shell.execute_reply": "2021-11-29T03:05:23.374122Z", + "shell.execute_reply.started": "2021-11-29T01:03:24.585147Z" + }, + "papermill": { + "duration": 0.035787, + "end_time": "2021-11-29T03:05:23.374254", + "exception": false, + "start_time": "2021-11-29T03:05:23.338467", + "status": "completed" + }, + "tags": [] + }, + "outputs": [], + "source": [ + "# Attention mechanism\n", + "def attn(query, key, value):\n", + "    # query, key and value all have shape (N, C, H*W); let S=H*W\n", + "    # Scores use scaled dot-product attention: scores(i) = key(i)^T query / sqrt(C)\n", + "    scores = torch.matmul(query.transpose(1, 2), key / math.sqrt(query.size(1))) # (N, S, S)\n", + "    # Normalize the scores into attention weights with softmax\n", + "    attn = F.softmax(scores, dim=-1)\n", + "    output = torch.matmul(attn, value.transpose(1, 2)) # (N, S, C)\n", + "    return output.transpose(1, 2) # (N, C, S)" + ] + }, + { + "cell_type": "code", + "execution_count": 21, + "metadata": { + "execution": { + "iopub.execute_input": "2021-11-29T03:05:23.440765Z", + "iopub.status.busy": "2021-11-29T03:05:23.440042Z", + "iopub.status.idle": "2021-11-29T03:05:23.442191Z", + "shell.execute_reply": "2021-11-29T03:05:23.442569Z", + "shell.execute_reply.started": "2021-11-29T01:03:25.147999Z" + }, + "papermill": { + "duration": 0.041095, + "end_time": "2021-11-29T03:05:23.442725", + "exception": false, + "start_time": "2021-11-29T03:05:23.401630", + "status": "completed" + }, + "tags": [] + }, + "outputs": [], + "source": [ + "# SAM module\n", + "class SAAttnMem(nn.Module):\n", + "    def __init__(self, input_dim, d_model, kernel_size):\n", + "        super().__init__()\n", + "        pad = kernel_size[0] // 2, kernel_size[1] // 2\n", + "        self.d_model = d_model\n", + "        self.input_dim = input_dim\n", + "        # 
Use a 1*1 convolution to implement the fully connected operation WhHt\n", + "        self.conv_h = nn.Conv2d(input_dim, d_model*3, kernel_size=1)\n", + "        # Use a 1*1 convolution to implement the fully connected operation WmMt-1\n", + "        self.conv_m = nn.Conv2d(input_dim, d_model*2, kernel_size=1)\n", + "        # Use a 1*1 convolution to implement the fully connected operation Wz[Zh,Zm]\n", + "        self.conv_z = nn.Conv2d(d_model*2, d_model, kernel_size=1)\n", + "        # Note that the output dimension must match the input dimension: both are input_dim\n", + "        self.conv_output = nn.Conv2d(input_dim+d_model, input_dim*3, kernel_size=kernel_size, padding=pad)\n", + "\n", + "    def forward(self, h, m):\n", + "        # self.conv_h(h) gives WhHt; split it along dim=1 into chunks of size self.d_model,\n", + "        # each of shape (N, d_model, H, W); the three chunks are Qh, Kh and Vh\n", + "        hq, hk, hv = torch.split(self.conv_h(h), self.d_model, dim=1)\n", + "        # Obtain Km and Vm in the same way\n", + "        mk, mv = torch.split(self.conv_m(m), self.d_model, dim=1)\n", + "        N, C, H, W = hq.size()\n", + "        # Self-attention gives Zh\n", + "        Zh = attn(hq.view(N, C, -1), hk.view(N, C, -1), hv.view(N, C, -1)) # (N, C, S), C=d_model\n", + "        # Attention with the query from h and the key/value from m gives Zm\n", + "        Zm = attn(hq.view(N, C, -1), mk.view(N, C, -1), mv.view(N, C, -1)) # (N, C, S), C=d_model\n", + "        # Concatenate Zh and Zm and apply the fully connected operation to get the aggregated feature Z\n", + "        Z = self.conv_z(torch.cat([Zh.view(N, C, H, W), Zm.view(N, C, H, W)], dim=1)) # (N, C, H, W), C=d_model\n", + "        # Compute i't, g't and o't\n", + "        i, g, o = torch.split(self.conv_output(torch.cat([Z, h], dim=1)), self.input_dim, dim=1) # (N, C, H, W), C=input_dim\n", + "        i = torch.sigmoid(i)\n", + "        g = torch.tanh(g)\n", + "        # Updated memory unit Mt\n", + "        m_next = i * g + (1 - i) * m\n", + "        # Updated hidden state Ht\n", + "        h_next = torch.sigmoid(o) * m_next\n", + "        return h_next, m_next" + ] + }, + { + "cell_type": "code", + "execution_count": 22, + "metadata": { + "execution": { + "iopub.execute_input": "2021-11-29T03:05:23.509738Z", + "iopub.status.busy": "2021-11-29T03:05:23.509080Z", + "iopub.status.idle": "2021-11-29T03:05:23.512667Z", + "shell.execute_reply": "2021-11-29T03:05:23.512182Z", + "shell.execute_reply.started": "2021-11-29T01:03:25.667808Z" + }, + "papermill": { + "duration": 0.042616, + "end_time": "2021-11-29T03:05:23.512781", + 
"exception": false, + "start_time": "2021-11-29T03:05:23.470165", + "status": "completed" + }, + "tags": [] + }, + "outputs": [], + "source": [ + "# SA-ConvLSTM Cell\n", + "class SAConvLSTMCell(nn.Module):\n", + " def __init__(self, input_dim, hidden_dim, d_attn, kernel_size):\n", + " super().__init__()\n", + " self.input_dim = input_dim\n", + " self.hidden_dim = hidden_dim\n", + " pad = kernel_size[0] // 2, kernel_size[1] // 2\n", + " # 卷积操作Wx*Xt+Wh*Ht-1\n", + " self.conv = nn.Conv2d(in_channels=input_dim+hidden_dim, out_channels=4*hidden_dim, kernel_size=kernel_size, padding=pad)\n", + " self.sa = SAAttnMem(input_dim=hidden_dim, d_model=d_attn, kernel_size=kernel_size)\n", + " \n", + " def initialize(self, inputs):\n", + " device = inputs.device\n", + " N, _, H, W = inputs.size()\n", + " # 初始化隐藏层状态Ht\n", + " self.hidden_state = torch.zeros(N, self.hidden_dim, H, W, device=device)\n", + " # 初始化记忆细胞状态ct\n", + " self.cell_state = torch.zeros(N, self.hidden_dim, H, W, device=device)\n", + " # 初始化记忆单元状态Mt\n", + " self.memory_state = torch.zeros(N, self.hidden_dim, H, W, device=device)\n", + " \n", + " def forward(self, inputs, first_step=False):\n", + " # 如果当前是第一个时间步,初始化Ht、ct、Mt\n", + " if first_step:\n", + " self.initialize(inputs)\n", + " \n", + " # ConvLSTM部分\n", + " # 拼接Xt和Ht\n", + " combined = torch.cat([inputs, self.hidden_state], dim=1) # (N, C, H, W), C=input_dim+hidden_dim\n", + " # 进行卷积操作\n", + " combined_conv = self.conv(combined) \n", + " # 得到四个门控单元it、ft、ot、gt\n", + " cc_i, cc_f, cc_o, cc_g = torch.split(combined_conv, self.hidden_dim, dim=1)\n", + " i = torch.sigmoid(cc_i)\n", + " f = torch.sigmoid(cc_f)\n", + " o = torch.sigmoid(cc_o)\n", + " g = torch.tanh(cc_g)\n", + " # 得到当前时间步的记忆细胞状态ct=ft·ct-1+it·gt\n", + " self.cell_state = f * self.cell_state + i * g\n", + " # 得到当前时间步的隐藏层状态Ht=ot·tanh(ct)\n", + " self.hidden_state = o * torch.tanh(self.cell_state)\n", + " \n", + " # SAM部分,更新Ht和Mt\n", + " self.hidden_state, self.memory_state = 
self.sa(self.hidden_state, self.memory_state)\n", + "\n", + "        return self.hidden_state" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "There are two modes for training a Seq2Seq model. The first is free running, the traditional approach, which feeds the output $\\hat{y_{t-1}}$ of the previous time step as the input of the next time step. The problem is that early in training $\\hat{y_{t-1}}$ is far from the true label $y_{t-1}$, and using it as input makes the subsequent outputs drift further and further from the labels we expect. This motivates the second mode: teacher forcing.\n", + "\n", + "Teacher forcing feeds the true label $y_{t-1}$ directly as the input of the next time step, so the teacher (the ground truth) guides the model and keeps it from drifting. But the teacher cannot hold the student's hand forever and must gradually let the student learn independently, so Scheduled Sampling is used to control the probability of using the true label. Let ratio denote the Scheduled Sampling ratio. Early in training ratio=1 and the model is fully guided by the teacher. As the number of training rounds increases, ratio decays (this solution uses linear decay: ratio decreases by a decay rate decay_rate each time). At each time step a binary random number, 0 or 1, is drawn from a Bernoulli distribution with probability ratio of being 1; if it is 1 the input is the true label $y_{t-1}$, otherwise the input is $\\hat{y_{t-1}}$." + ] + }, + { + "cell_type": "code", + "execution_count": 23, + "metadata": { + "execution": { + "iopub.execute_input": "2021-11-29T03:05:23.587567Z", + "iopub.status.busy": "2021-11-29T03:05:23.586781Z", + "iopub.status.idle": "2021-11-29T03:05:23.588776Z", + "shell.execute_reply": "2021-11-29T03:05:23.589156Z", + "shell.execute_reply.started": "2021-11-29T01:03:26.065997Z" + }, + "papermill": { + "duration": 0.047514, + "end_time": "2021-11-29T03:05:23.589277", + "exception": false, + "start_time": "2021-11-29T03:05:23.541763", + "status": "completed" + }, + "tags": [] + }, + "outputs": [], + "source": [ + "# Build the SA-ConvLSTM model\n", + "class SAConvLSTM(nn.Module):\n", + "    def __init__(self, input_dim, hidden_dim, d_attn, kernel_size):\n", + "        super().__init__()\n", + "        self.input_dim = input_dim\n", + "        self.hidden_dim = hidden_dim\n", + "        self.num_layers = len(hidden_dim)\n", + "\n", + "        layers = []\n", + "        for i in range(self.num_layers):\n", + "            cur_input_dim = self.input_dim if i == 0 else self.hidden_dim[i-1]\n", + "            layers.append(SAConvLSTMCell(input_dim=cur_input_dim, hidden_dim=self.hidden_dim[i], d_attn = d_attn, kernel_size=kernel_size))\n", + "        self.layers = nn.ModuleList(layers)\n", + "\n", + "        self.conv_output = nn.Conv2d(self.hidden_dim[-1], 
1, kernel_size=1)\n", + " \n", + " def forward(self, input_x, device=None, input_frames=12, future_frames=26, output_frames=37, teacher_forcing=False, scheduled_sampling_ratio=0, train=True):\n", + " # Default to the device that holds the model parameters\n", + " device = next(self.parameters()).device if device is None else device\n", + " # Permute the input X from (N, T, H, W, C) to (N, T, C, H, W)\n", + " input_x = input_x.permute(0, 1, 4, 2, 3).contiguous()\n", + " \n", + " # Teacher forcing is only used during training\n", + " if train:\n", + " if teacher_forcing and scheduled_sampling_ratio > 1e-6:\n", + " teacher_forcing_mask = torch.bernoulli(scheduled_sampling_ratio * torch.ones(input_x.size(0), future_frames-1, 1, 1, 1))\n", + " else:\n", + " teacher_forcing = False\n", + " else:\n", + " teacher_forcing = False\n", + " \n", + " total_steps = input_frames + future_frames - 1\n", + " outputs = [None] * total_steps\n", + " \n", + " # Loop over the time steps\n", + " for t in range(total_steps):\n", + " # For the first 12 months, feed the observed input frame X_t of each month\n", + " if t < input_frames:\n", + " input_ = input_x[:, t].to(device)\n", + " # Without teacher forcing, feed the previous step's prediction as the current input\n", + " elif not teacher_forcing:\n", + " input_ = outputs[t-1]\n", + " # With teacher forcing, feed the ground-truth frame with probability ratio, otherwise the previous step's prediction\n", + " else:\n", + " mask = teacher_forcing_mask[:, t-input_frames].float().to(device)\n", + " input_ = input_x[:, t].to(device) * mask + outputs[t-1] * (1-mask)\n", + " first_step = (t==0)\n", + " input_ = input_.float()\n", + " \n", + " # Pass the current input through the stacked hidden layers\n", + " for layer_idx in range(self.num_layers):\n", + " input_ = self.layers[layer_idx](input_, first_step=first_step)\n", + " \n", + " # Record the output: at every step in training, only from step input_frames-1 onward in evaluation\n", + " if train or (t >= (input_frames - 1)):\n", + " outputs[t] = self.conv_output(input_)\n", + " \n", + " outputs = [x for x in outputs if x is not None]\n", + " \n", + " # Check the length of the output sequence\n", + " if train:\n", + " assert len(outputs) == output_frames\n", + " else:\n", + " assert len(outputs) == future_frames\n", + " \n", + " # Stack the outputs into the predicted sst sequence\n", + " outputs = torch.stack(outputs, dim=1)[:, :, 0] # (N, 37, H, W) in training, (N, 26, H, W) in evaluation\n", + " # Average the predicted sst over the nino3.4 region, then take a 3-month rolling mean to get the predicted nino3.4 index sequence\n", + " 
nino_pred = outputs[:, -future_frames:, 10:13, 19:30].mean(dim=[2, 3]) # (N, 26)\n", + " nino_pred = nino_pred.unfold(dimension=1, size=3, step=1).mean(dim=2) # (N, 24)\n", + " \n", + " return nino_pred" + ] + }, + { + "cell_type": "code", + "execution_count": 25, + "metadata": { + "execution": { + "iopub.execute_input": "2021-11-29T03:05:23.726291Z", + "iopub.status.busy": "2021-11-29T03:05:23.725688Z", + "iopub.status.idle": "2021-11-29T03:05:23.753509Z", + "shell.execute_reply": "2021-11-29T03:05:23.753976Z", + "shell.execute_reply.started": "2021-11-29T01:03:29.448921Z" + }, + "papermill": { + "duration": 0.066105, + "end_time": "2021-11-29T03:05:23.754109", + "exception": false, + "start_time": "2021-11-29T03:05:23.688004", + "status": "completed" + }, + "tags": [] + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "SAConvLSTM(\n", + " (layers): ModuleList(\n", + " (0): SAConvLSTMCell(\n", + " (conv): Conv2d(65, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n", + " (sa): SAAttnMem(\n", + " (conv_h): Conv2d(64, 96, kernel_size=(1, 1), stride=(1, 1))\n", + " (conv_m): Conv2d(64, 64, kernel_size=(1, 1), stride=(1, 1))\n", + " (conv_z): Conv2d(64, 32, kernel_size=(1, 1), stride=(1, 1))\n", + " (conv_output): Conv2d(96, 192, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n", + " )\n", + " )\n", + " (1): SAConvLSTMCell(\n", + " (conv): Conv2d(128, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n", + " (sa): SAAttnMem(\n", + " (conv_h): Conv2d(64, 96, kernel_size=(1, 1), stride=(1, 1))\n", + " (conv_m): Conv2d(64, 64, kernel_size=(1, 1), stride=(1, 1))\n", + " (conv_z): Conv2d(64, 32, kernel_size=(1, 1), stride=(1, 1))\n", + " (conv_output): Conv2d(96, 192, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n", + " )\n", + " )\n", + " (2): SAConvLSTMCell(\n", + " (conv): Conv2d(128, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n", + " (sa): SAAttnMem(\n", + " (conv_h): Conv2d(64, 96, kernel_size=(1, 
1), stride=(1, 1))\n", + " (conv_m): Conv2d(64, 64, kernel_size=(1, 1), stride=(1, 1))\n", + " (conv_z): Conv2d(64, 32, kernel_size=(1, 1), stride=(1, 1))\n", + " (conv_output): Conv2d(96, 192, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n", + " )\n", + " )\n", + " (3): SAConvLSTMCell(\n", + " (conv): Conv2d(128, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n", + " (sa): SAAttnMem(\n", + " (conv_h): Conv2d(64, 96, kernel_size=(1, 1), stride=(1, 1))\n", + " (conv_m): Conv2d(64, 64, kernel_size=(1, 1), stride=(1, 1))\n", + " (conv_z): Conv2d(64, 32, kernel_size=(1, 1), stride=(1, 1))\n", + " (conv_output): Conv2d(96, 192, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n", + " )\n", + " )\n", + " )\n", + " (conv_output): Conv2d(64, 1, kernel_size=(1, 1), stride=(1, 1))\n", + ")\n" + ] + } + ], + "source": [ + "# Number of input channels\n", + "input_dim = 1\n", + "# Number of hidden channels in each layer\n", + "hidden_dim = (64, 64, 64, 64)\n", + "# Hidden dimension of the self-attention module\n", + "d_attn = 32\n", + "# Convolution kernel size\n", + "kernel_size = (3, 3)\n", + "\n", + "model = SAConvLSTM(input_dim, hidden_dim, d_attn, kernel_size)\n", + "print(model)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "#### Model training" + ] + }, + { + "cell_type": "code", + "execution_count": 26, + "metadata": { + "execution": { + "iopub.execute_input": "2021-11-29T03:05:23.816671Z", + "iopub.status.busy": "2021-11-29T03:05:23.815927Z", + "iopub.status.idle": "2021-11-29T03:05:23.818479Z", + "shell.execute_reply": "2021-11-29T03:05:23.818058Z", + "shell.execute_reply.started": "2021-11-29T01:03:31.476806Z" + }, + "papermill": { + "duration": 0.035723, + "end_time": "2021-11-29T03:05:23.818579", + "exception": false, + "start_time": "2021-11-29T03:05:23.782856", + "status": "completed" + }, + "tags": [] + }, + "outputs": [], + "source": [ + "# Use RMSE as the loss: the RMSE over the batch is computed for each of the 24 lead months and then summed\n", + "def RMSELoss(y_pred,y_true):\n", + " loss = torch.sqrt(torch.mean((y_pred-y_true)**2, dim=0)).sum()\n", + " return loss" + ] + }, + { + "cell_type": "code", + 
"execution_count": 27, + "metadata": { + "execution": { + "iopub.execute_input": "2021-11-29T03:05:23.893469Z", + "iopub.status.busy": "2021-11-29T03:05:23.892684Z", + "iopub.status.idle": "2021-11-29T04:55:28.956056Z", + "shell.execute_reply": "2021-11-29T04:55:28.956434Z" + }, + "papermill": { + "duration": 6605.109145, + "end_time": "2021-11-29T04:55:28.956614", + "exception": false, + "start_time": "2021-11-29T03:05:23.847469", + "status": "completed" + }, + "tags": [] + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch: 1/5\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "100%|██████████| 1600/1600 [21:43<00:00, 1.23it/s]\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Training Loss: 3.289\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "50it [00:11, 4.47it/s]\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Validation Loss: 44.009\n", + "Score: -43.458\n", + "Epoch: 2/5\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "100%|██████████| 1600/1600 [21:43<00:00, 1.23it/s]\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Training Loss: 3.084\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "50it [00:11, 4.33it/s]\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Validation Loss: 25.011\n", + "Score: -19.966\n", + "Epoch: 3/5\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "100%|██████████| 1600/1600 [21:46<00:00, 1.22it/s]\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Training Loss: 13.461\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "50it [00:12, 4.16it/s]\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Validation Loss: 15.438\n", + "Score: -14.139\n", + "Epoch: 4/5\n" + ] + 
}, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "100%|██████████| 1600/1600 [21:54<00:00, 1.22it/s]\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Training Loss: 17.627\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "50it [00:12, 3.99it/s]\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Validation Loss: 15.389\n", + "Score: -22.500\n", + "Epoch: 5/5\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "100%|██████████| 1600/1600 [21:55<00:00, 1.22it/s]\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Training Loss: 17.592\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "50it [00:11, 4.48it/s]" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Validation Loss: 15.252\n", + "Score: -14.459\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "\n" + ] + } + ], + "source": [ + "model_weights = './task05_model_weights.pth'\n", + "device = 'cuda' if torch.cuda.is_available() else 'cpu'\n", + "model = SAConvLSTM(input_dim, hidden_dim, d_attn, kernel_size).to(device)\n", + "criterion = RMSELoss\n", + "optimizer = optim.Adam(model.parameters(), lr=1e-3)\n", + "# The scheduler is configured to track the validation score (mode='max'); the loop below never calls lr_scheduler.step(), so the learning rate stays at 1e-3\n", + "lr_scheduler = ReduceLROnPlateau(optimizer, mode='max', factor=0.3, patience=0, verbose=True, min_lr=0.0001)\n", + "epochs = 5\n", + "ratio, decay_rate = 1, 8e-5\n", + "train_losses, valid_losses = [], []\n", + "scores = []\n", + "best_score = float('-inf')\n", + "preds = np.zeros((len(y_valid),24))\n", + "\n", + "for epoch in range(epochs):\n", + " print('Epoch: {}/{}'.format(epoch+1, epochs))\n", + " \n", + " # Training\n", + " model.train()\n", + " losses = 0\n", + " for data, labels in tqdm(trainloader):\n", + " data = data.to(device)\n", + " labels = labels.to(device)\n", + " optimizer.zero_grad()\n", + " # Linearly decay the scheduled sampling ratio (once per batch)\n", + " ratio = max(ratio-decay_rate, 0)\n", + " pred = 
model(data, teacher_forcing=True, scheduled_sampling_ratio=ratio, train=True)\n", + " loss = criterion(pred, labels)\n", + " losses += loss.cpu().detach().numpy()\n", + " loss.backward()\n", + " optimizer.step()\n", + " train_loss = losses / len(trainloader)\n", + " train_losses.append(train_loss)\n", + " print('Training Loss: {:.3f}'.format(train_loss))\n", + " \n", + " # Validation\n", + " model.eval()\n", + " losses = 0\n", + " with torch.no_grad():\n", + " for i, data in tqdm(enumerate(validloader)):\n", + " data, labels = data\n", + " data = data.to(device)\n", + " labels = labels.to(device)\n", + " pred = model(data, train=False)\n", + " loss = criterion(pred, labels)\n", + " losses += loss.cpu().detach().numpy()\n", + " preds[i*batch_size:(i+1)*batch_size] = pred.detach().cpu().numpy()\n", + " valid_loss = losses / len(validloader)\n", + " valid_losses.append(valid_loss)\n", + " print('Validation Loss: {:.3f}'.format(valid_loss))\n", + " s = score(y_valid, preds)\n", + " scores.append(s)\n", + " print('Score: {:.3f}'.format(s))\n", + " \n", + " # Save the weights of the best-scoring model\n", + " if s > best_score:\n", + " best_score = s\n", + " checkpoint = {'best_score': s,\n", + " 'state_dict': model.state_dict()}\n", + " torch.save(checkpoint, model_weights)" + ] + }, + { + "cell_type": "code", + "execution_count": 28, + "metadata": { + "execution": { + "iopub.execute_input": "2021-11-29T04:55:33.957872Z", + "iopub.status.busy": "2021-11-29T04:55:33.957066Z", + "iopub.status.idle": "2021-11-29T04:55:33.960119Z", + "shell.execute_reply": "2021-11-29T04:55:33.959684Z", + "shell.execute_reply.started": "2021-11-28T14:00:36.33194Z" + }, + "papermill": { + "duration": 2.38263, + "end_time": "2021-11-29T04:55:33.960247", + "exception": false, + "start_time": "2021-11-29T04:55:31.577617", + "status": "completed" + }, + "tags": [] + }, + "outputs": [], + "source": [ + "# Plot the training/validation curves\n", + "def training_vis(train_losses, valid_losses):\n", + " # Plot the loss curves\n", + " fig = plt.figure(figsize=(8,4))\n", + " # 
subplot loss\n", + " ax1 = fig.add_subplot(121)\n", + " ax1.plot(train_losses, label='train_loss')\n", + " ax1.plot(valid_losses,label='val_loss')\n", + " ax1.set_xlabel('Epochs')\n", + " ax1.set_ylabel('Loss')\n", + " ax1.set_title('Loss on Training and Validation Data')\n", + " ax1.legend()\n", + " plt.tight_layout()" + ] + }, + { + "cell_type": "code", + "execution_count": 29, + "metadata": { + "execution": { + "iopub.execute_input": "2021-11-29T04:55:38.227343Z", + "iopub.status.busy": "2021-11-29T04:55:38.226636Z", + "iopub.status.idle": "2021-11-29T04:55:38.470256Z", + "shell.execute_reply": "2021-11-29T04:55:38.469252Z", + "shell.execute_reply.started": "2021-11-28T14:00:43.42651Z" + }, + "papermill": { + "duration": 2.378943, + "end_time": "2021-11-29T04:55:38.470387", + "exception": false, + "start_time": "2021-11-29T04:55:36.091444", + "status": "completed" + }, + "tags": [] + }, + "outputs": [ + { + "data": { + "image/png": "iVBORw0KGgoAAAANSUhEUgAAAS8AAAEYCAYAAAANoXDNAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjQuMywgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy/MnkTPAAAACXBIWXMAAAsTAAALEwEAmpwYAAAsYUlEQVR4nO3dd3wUdf7H8dcnhSRApIYSeu8QMKKIngoWRATPE1AsgIVT0cMu3umpWE899TwVT08BFQsWDsQu4A89UQiYUKRzlBBCDwQlkPL5/TETWCCBhOxkspvP8/HYR2ZnZmfeu5l8MvPd78yIqmKMMaEmwu8AxhhzIqx4GWNCkhUvY0xIsuJljAlJVryMMSHJipcxJiRZ8aqkRORzERke7Hn9JCLrRORcD5b7rYhc7w5fKSJflWTeE1hPUxHZKyKRJ5q1MqlUxcurjbu8uBt24aNARPYFPL+yNMtS1QtVdVKw562IRGSsiMwpYnxdETkgIp1LuixVnayq5wcp12Hbo6puUNXqqpofjOUfsS4VkV/dbWWHiMwUkaGleP3ZIpIe7FxlUamKV6hzN+zqqlod2ABcHDBucuF8IhLlX8oK6W3gdBFpccT4y4HFqrrEh0x+6OZuO+2AicCLIvKgv5FOnBUvQERiROR5EclwH8+LSIw7ra6IzBCRLBHZKSLfiUiEO+1eEdkkItkiskJE+haz/Boi8qaIbBOR9SJyf8AyRojI9yLyjIjsEpH/iciFpcx/toiku3kygQkiUsvNvc1d7gwRaRzwmsBDoWNmKOW8LURkjvuZfCMiL4nI28XkLknGR0Tkv+7yvhKRugHTr3Y/zx0i8pfiPh9VTQdmAVcfMeka4M3j5Tgi8wgR+T7g+XkislxEdovIi4AETGslIrPcfNtFZLKI1HSnvQU0BT5x94buEZHm7h5SlDtPoohMd7e71SJyQ8CyHxKRKe52lS0iS0UkubjP4IjPY7uqvgXcBNwnInXcZY4UkWXu8taKyB/d8dWAz4FEObSnnygiPUVkrvu3sVlEXhSRKiXJEAxWvBx/AU
4DkoBuQE/gfnfanUA6kADUB/4MqIi0A24BTlHVeOACYF0xy/8nUANoCZyF80czMmD6qcAKoC7wFPC6iMiRCzmOBkBtoBkwCud3O8F93hTYB7x4jNeXJsOx5n0HmAfUAR7i6IIRqCQZh+F8VvWAKsBdACLSERjvLj/RXV+RBcc1KTCL+/tLcvOW9rMqXEZd4GOcbaUusAboHTgL8ISbrwPQBOczQVWv5vC956eKWMV7ONteInAZ8LiI9AmYPtCdpyYwvSSZjzANiMLZ3gG2AgOAk3A+8+dEpIeq/gpcCGQE7OlnAPnA7e577wX0BW4uZYYTp6qV5oFTXM4tYvwaoH/A8wuAde7wOJxfcusjXtMa55d9LhB9jHVGAgeAjgHj/gh86w6PAFYHTKsKKNCgpO8FONtdR+wx5k8CdgU8/xa4viQZSjovzh9+HlA1YPrbwNsl/P0UlfH+gOc3A1+4w38F3guYVs39DI76/Qbk3AOc7j5/DJh2gp/V9+7wNcCPAfMJTrG5vpjlXgL8XNz2CDR3P8sonEKXD8QHTH8CmOgOPwR8EzCtI7DvGJ+tcsQ27I7PBK4s5jX/AcYEbGPpx/n93QZMLcnvOhgP2/NyJALrA56vd8cBPA2sBr5yd6XHAqjqapxf1kPAVhF5T0QSOVpdILqI5TcKeJ5ZOKCqv7mD1Uv5Hrapak7hExGpKiL/cg+r9gBzgJpS/DdZpclQ3LyJwM6AcQAbiwtcwoyZAcO/BWRKDFy2OnsHO4pbl5vpA+Aady/xSuDNUuQoypEZNPC5iNR3t4tN7nLfxtkeSqLws8wOGFfsdoPz2cRKKdo7RSQa54hip/v8QhH50T1MzQL6HyuviLR1D7Ez3ff3+LHmDzYrXo4MnEOGQk3dcahqtqreqaotcXbT7xC3bUtV31HVM9zXKvC3Ipa9HcgtYvmbgvwejrw8yJ04DbOnqupJwO/c8aU9HC2NzUBtEakaMK7JMeYvS8bNgct211nnOK+ZBAwBzgPigU/KmOPIDMLh7/dxnN9LF3e5Vx2xzGNd0iUD57OMDxgX7O1mEM6e8jxx2ng/Ap4B6qtqTeCzgLxFZR0PLAfauO/vz3i7fR2mMhavaBGJDXhEAe8C94tIgtuO8Vec/5KIyAARae1umLtxduULRKSdiPRxf+k5OO0kBUeuTJ2vvacAj4lIvIg0A+4oXL6H4t1MWSJSG/D8WyVVXQ+kAA+JSBUR6QVc7FHGD4EBInKG20g8juNvz98BWcCrOIecB8qY41Ogk4hc6m5Hf8I5fC4UD+wFdotII+DuI16/Bacd9CiquhH4AXjC3U67AtcRhO1GRGqL07XmJeBvqroDpz0xBtgG5InzJUxgl5AtQB0RqXHE+9sD7BWR9jhfAJSbyli8PsPZUAsfDwGP4vzRLQIWAwvdcQBtgG9wNsK5wMuqOhvnF/0kzp5VJk6D8n3FrPNW4FdgLfA9TiPxG8F9W0d5Hohz8/0IfOHx+gpdidN4uwPnM3wf2F/MvM9zghlVdSkwGuez3AzswmlvOtZrFOdQsZn7s0w5VHU7MBhnO9iBs638N2CWh4EeOP/0PsVp3A/0BM4/zSwRuauIVVyB0w6WAUwFHlTVb0qSrRhpIrIXpxnkeuB2Vf2r+16ycYrvFJzPchjOlwCF73U5zj/5tW7eRJwvT4YB2cBrOL/rciNuQ5sxnhCR94Hlqhqy/YlMxVQZ97yMh0TkFHH6N0WISD+cdpX/+BzLhCHriW2CrQHO4VEdnMO4m1T1Z38jmXBkh43GmJBkh43GmJAUEoeNdevW1ebNm/sdwxhTzhYsWLBdVROKmhYSxat58+akpKT4HcMYU85EZH1x0+yw0RgTkqx4GWNCkhUvY0xICok2L2MqotzcXNLT08nJyTn+zOaYYmNjady4MdHR0SV+jRUvY05Qeno68fHxNG/enNJfO9IUUlV27NhBeno6LVoceaXu4tlhozEnKCcnhzp16ljhKiMRoU6dOqXeg7
XiZUwZWOEKjhP5HMOreOXsgR/Hg53yZEzYC6/itXwGfDEWUicff15jTEgLr+LV9XJo2gu+uh9+3e53GmM8lZWVxcsvv1zq1/Xv35+srKxSv27EiBF8+OGHpX6dV8KreEVEwIDnYf9ep4AZE8aKK155eXnHfN1nn31GzZo1PUpVfsKvq0S99tB7DHz3DHS7Alqe5XciUwk8/MlSfsnYE9Rldkw8iQcv7lTs9LFjx7JmzRqSkpKIjo4mNjaWWrVqsXz5clauXMkll1zCxo0bycnJYcyYMYwaNQo4dK7w3r17ufDCCznjjDP44YcfaNSoEdOmTSMuLu642WbOnMldd91FXl4ep5xyCuPHjycmJoaxY8cyffp0oqKiOP/883nmmWf44IMPePjhh4mMjKRGjRrMmTMnKJ9PeO15FfrdXVC7Jcy4HXKtA6EJT08++SStWrUiNTWVp59+moULF/KPf/yDlStXAvDGG2+wYMECUlJSeOGFF9ix4+g7w61atYrRo0ezdOlSatasyUcffXTc9ebk5DBixAjef/99Fi9eTF5eHuPHj2fHjh1MnTqVpUuXsmjRIu6/3zn6GTduHF9++SVpaWlMnz79OEsvufDb8wKIjoOLnoW3LoHv/g59ir0TvDFBcaw9pPLSs2fPwzp5vvDCC0ydOhWAjRs3smrVKurUOfzucC1atCApKQmAk08+mXXr1h13PStWrKBFixa0bdsWgOHDh/PSSy9xyy23EBsby3XXXceAAQMYMGAAAL1792bEiBEMGTKESy+9NAjv1BGee14Arc6BrkPh++dg2wq/0xjjuWrVqh0c/vbbb/nmm2+YO3cuaWlpdO/evchOoDExMQeHIyMjj9tedixRUVHMmzePyy67jBkzZtCvXz8AXnnlFR599FE2btzIySefXOQe4IkI3+IFcP5jUKUafHIbFBx1S0VjQlp8fDzZ2dlFTtu9eze1atWiatWqLF++nB9//DFo623Xrh3r1q1j9erVALz11lucddZZ7N27l927d9O/f3+ee+450tLSAFizZg2nnnoq48aNIyEhgY0bi72JeqmE52FjoeoJcP4jMP1Wp+9Xj6v9TmRM0NSpU4fevXvTuXNn4uLiqF+//sFp/fr145VXXqFDhw60a9eO0047LWjrjY2NZcKECQwePPhgg/2NN97Izp07GTRoEDk5Oagqzz77LAB33303q1atQlXp27cv3bp1C0qOkLgBR3Jysp7wlVRVYeJFsGUp3JLiFDRjgmDZsmV06NDB7xhho6jPU0QWqGpyUfOH92EjgAgMeA4O/ApfWcO9MeEi/IsXQEI7OON2WPQ+rJntdxpjKrTRo0eTlJR02GPChAl+xzpKeLd5BTrzTljyIXx6B9z0g9OdwhhzlJdeesnvCCVSOfa8AKJjncPHnWthzjN+pzHGlFHlKV4ALc92Thn67z9g6zK/0xhjyqByFS+A8x+FmOrOqUPW98uYkFX5ile1uk4B2zAXfn7L7zTGmBPkefESkUgR+VlEZrjPW4jITyKyWkTeF5EqXmc4StKV0OwM+PoB2Lu13FdvjB+qV69e7LR169bRuXPnckxTduWx5zUGCGxg+hvwnKq2BnYB15VDhsMV9v3K3Qdf/rncV2+MKTtPu0qISGPgIuAx4A5xrrLfBxjmzjIJeAgY72WOIiW0hTPugP970mnEb9233COYMPL5WMhcHNxlNugCFz5Z7OSxY8fSpEkTRo8eDcBDDz1EVFQUs2fPZteuXeTm5vLoo48yaNCgUq02JyeHm266iZSUFKKionj22Wc555xzWLp0KSNHjuTAgQMUFBTw0UcfkZiYyJAhQ0hPTyc/P58HHniAoUOHlultl5TXe17PA/cAhS3jdYAsVS08dT0daFTUC0VklIikiEjKtm3bvEl3xu1Qp7XT9yt3nzfrMMYjQ4cOZcqUKQefT5kyheHDhzN16lQWLlzI7NmzufPOOyntKYAvvfQSIsLixYt59913GT58OD
k5ObzyyiuMGTOG1NRUUlJSaNy4MV988QWJiYmkpaWxZMmSg1eSKA+e7XmJyABgq6ouEJGzS/t6VX0VeBWccxuDm85V2Pdr0sUw52no+1dPVmMqgWPsIXmle/fubN26lYyMDLZt20atWrVo0KABt99+O3PmzCEiIoJNmzaxZcsWGjRoUOLlfv/999x6660AtG/fnmbNmrFy5Up69erFY489Rnp6Opdeeilt2rShS5cu3Hnnndx7770MGDCAM88806u3exQv97x6AwNFZB3wHs7h4j+AmiJSWDQbA5s8zHB8LX4H3YY5fb+2/OJrFGNKa/DgwXz44Ye8//77DB06lMmTJ7Nt2zYWLFhAamoq9evXL/XNXIszbNgwpk+fTlxcHP3792fWrFm0bduWhQsX0qVLF+6//37GjRsXlHWVhGfFS1XvU9XGqtocuByYpapXArOBy9zZhgPTvMpQYuc/CjEnwYzbrO+XCSlDhw7lvffe48MPP2Tw4MHs3r2bevXqER0dzezZs1m/fn2pl3nmmWcyebJz+8CVK1eyYcMG2rVrx9q1a2nZsiV/+tOfGDRoEIsWLSIjI4OqVaty1VVXcffdd7Nw4cJgv8Vi+dHP616cxvvVOG1gr/uQ4XDV6sAFj8HGn2DhJL/TGFNinTp1Ijs7m0aNGtGwYUOuvPJKUlJS6NKlC2+++Sbt27cv9TJvvvlmCgoK6NKlC0OHDmXixInExMQwZcoUOnfuTFJSEkuWLOGaa65h8eLF9OzZk6SkJB5++OGD160vD+F/Pa+SUnXavjIXwej5EF//+K8xlZpdzyu47HpeJ+qwvl/3+Z3GGHMcleeSOCVRtw2ceRd8+7jTiN/mXL8TGRNUixcv5uqrD78cekxMDD/99JNPiU6cFa8jnXEbLP7A6ft1849QparfiUwFpqo4fa9DQ5cuXUhNTfU7xlFOpPnKDhuPFBUDFz8PWethzlN+pzEVWGxsLDt27DihPzxziKqyY8cOYmNjS/U62/MqSvMzIOkq+OGf0GUw1Pf/hqKm4mncuDHp6el4dgZIJRIbG0vjxo1L9RorXsU5/xFY+blzz8drv4QI20k1h4uOjj7sDtWmfNlfZHGq1oYLHof0ebCg4t18wJjKzorXsXQd6pw+9M3DkJ3pdxpjTAArXsciAhc9B3k58IX1/TKmIrHidTx1W8Pv7oKlH8Oqr/1OY4xxWfEqid5joG47mHGHc+dtY4zvrHiVRFSMc+rQ7g3wf3/zO40xBiteJde8N3S/Gn54ETKX+J3GmErPildpnDcO4mrBJ2OgIN/vNMZUala8SqOw79emFEh5w+80xlRqVrxKq+sQaHk2zBwHezb7ncaYSsuKV2mJwEXPQt5++GKs32mMqbSseJ2IOq3grLvhl//Ayi/9TmNMpWTF60SdPgYS2sOnd1nfL2N8YMXrREVVgQHPO32/vn3C7zTGVDpWvMqiWS/oMRzmvgybF/mdxphKxYpXWZ37kNOFYsZt1vfLmHJkxausqtaGC56ATQtgvv+3oDSmsrDiFQxdLoNWfdy+Xxl+pzGmUrDiFQwicNHfoSAXPr/X7zTGVApWvIKldks46x5YNh1WfO53GmPCnhWvYOp1KyR0gM/uhv17/U5jTFiz4hVMUVWcez7u3mh9v4zxmBWvYGt6Gpw8En58GTan+Z3GmLBlxcsL5z4IVevadb+M8ZAVLy/E1YJ+T0DGzzDvNb/TGBOWrHh5pfMfoFVfmPUI7N7kdxpjwo4VL6+IwIBnncPGz+/xO40xYceKl5dqNYez74XlM2D5p36nMSasWPHyWq9boF5Ht+9Xtt9pjAkbVry8FhkNF/8D9myC2Y/7ncaYsGHFqzw06QnJ18JPrzjfQBpjysyKV3np+yBUS3D6fuXn+Z3GmJBnxau8xNWEfk86ve7nW98vY8rKs+IlIrEiMk9E0kRkqYg87I5vISI/ichqEXlfRKp4laHC6fR7aH0ezHoUdqf7ncaYkOblntd+oI
+qdgOSgH4ichrwN+A5VW0N7AKu8zBDxSICFz3j9P36zPp+GVMWnhUvdRReFybafSjQB/jQHT8JuMSrDBVSreZwzn2w4lNYNsPvNMaELE/bvEQkUkRSga3A18AaIEtVC1us04FGXmaokE67Gep3tr5fxpSBp8VLVfNVNQloDPQE2pf0tSIySkRSRCRl27ZtXkX0R2S0c8/H7M1O+5cxptTK5dtGVc0CZgO9gJoiEuVOagwUedayqr6qqsmqmpyQkFAeMctXk1PglOvgp385dx4yxpSKl982JohITXc4DjgPWIZTxC5zZxsOTPMqQ4XX969QvT58cpv1/TKmlLzc82oIzBaRRcB84GtVnQHcC9whIquBOkDlvdlhbA248G+QuQjm/cvvNMaElKjjz3JiVHUR0L2I8Wtx2r8MQMdB0OYCmPUYdBgINZv4nciYkGA97P0mAv2fBtT59lHV70TGhAQrXhVBrWZw9n2w8nNY9onfaYwJCVa8KorTbob6XZyrrubs8TuNMRWeFa+KIjLKue5Xdqb1/TKmBKx4VSSNT4aeN8C8VyHd+n4ZcyxWvCqaPvdDfAO77pcxx2HFq6Ip7Pu1ZTH8NN7vNMZUWFa8KqIOA6Hthc4177M2+J3GmArJildFdLDvl8Cnd1nfL2OKYMWroqrZBM75M6z6En6pvKd/GlMcK14V2ak3QoOu8Pm9kLPb7zTGVChWvCqywr5fv26FmY/4ncaYCsWKV0XXqAf0HAXz/w3pKX6nMabCsOIVCs75C8Q3dPt+5fqdxpgKwYpXKIg9Cfo/BVuWwI8v+53GmArBileoaD8A2vWH2U/ArvV+pzHGd1a8QkVh3y+JgM+s75cxVrxCSY3GzrmPq76CpVP9TmOMr0pUvESkmohEuMNtRWSgiER7G80UqecoaNjN6fu1fZXfaYzxTUn3vOYAsSLSCPgKuBqY6FUocwyRUTDoZdACeK0PrPzS70TG+KKkxUtU9TfgUuBlVR0MdPIuljmmBp1h1LdQqzm8MxTmPGNtYKbSKXHxEpFewJXAp+64SG8imRKp2QSu/RK6DIZZj8AHw2H/Xr9TGVNuSlq8bgPuA6aq6lIRaYlz81jjpypV4dJX4fxHnRt3vH4e7FzrdypjyoVoKQ833Ib76qpabneJSE5O1pQUOzXmmNbMgg9GOsODJ0CrPv7mMSYIRGSBqiYXNa2k3za+IyIniUg1YAnwi4jcHcyQpoxa9XHawU5qBG//Af77grWDmbBW0sPGju6e1iXA50ALnG8cTUVSuwVc9xV0uBi+fgA+vgEO/OZ3KmM8UdLiFe3267oEmK6quYD9W6+IYqrD4EnQ96+w+EN44wK7lLQJSyUtXv8C1gHVgDki0gywO6NWVCJw5p0wbIpzHuSrZ8P/vvM7lTFBVaLipaovqGojVe2vjvXAOR5nM2XV9ny4YRZUrQtvDoKf/mXtYCZslLTBvoaIPCsiKe7j7zh7Yaaiq9sarv8G2l4An98D00ZDbo7fqYwps5IeNr4BZAND3MceYIJXoUyQxZ4EQyfDWWMhdTJMuBB2b/I7lTFlUtLi1UpVH1TVte7jYaCll8FMkEVEwDn3OUVs+0qnHWzDj36nMuaElbR47RORMwqfiEhvYJ83kYynOgyA62c630pOHAApb/idyJgTElXC+W4E3hSRGu7zXcBwbyIZz9VrDzfMho+uhxm3w+Y0uPBpiKridzJjSqyk3zamqWo3oCvQVVW7A3b+SSiLqwnD3ocz7oAFE2HSAMjO9DuVMSVWqiupquqegHMa7/AgjylPEZFw7oNw2QTIXOy0g6Uv8DuVMSVSlstAS9BSGH91vhSu+xoiq8CEfvDz234nMua4ylK8rLdjOCm8wGHTXk5fsM/usXtEmgrtmA32IpJN0UVKgDhPEhn/VK0NV30M3zwIc1+ELUthyCSoVtfvZMYc5Zh7Xqoar6onFfGIV9WSflNpQklkFFzwGPz+VdiU4rSDZaT6ncqYo3h26z
MRaSIis0XkFxFZKiJj3PG1ReRrEVnl/qzlVQZTBt2GwrVfOOdCvnEBLJridyJjDuPlfRvzgDtVtSNwGjBaRDoCY4GZqtoGmOk+NxVRYnenHazRyc61wb78C+Tn+Z3KGMDD4qWqm1V1oTucDSwDGgGDgEnubJNwrhFmKqrqCXDNNOd+kXNfhMl/gN92+p3KmPK5Y7aINAe6Az8B9VV1szspE6hfzGtGFV7FYtu2beUR0xQnMhr6Pw0DX4T1PzjtYJlL/E5lKjnPi5eIVAc+Am478qYd6tz9o8guF6r6qqomq2pyQkKC1zFNSfS4GkZ+DvkHnDsVLf2P34lMJeZp8XIvHf0RMFlVP3ZHbxGRhu70hsBWLzOYIGuc7LSD1e/s3Cty5jgoyPc7lamEvPy2UYDXgWWq+mzApOkcOql7ODDNqwzGI/ENYMQM6DEcvvs7vHs57MvyO5WpZLzc8+qNc4ehPiKS6j76A08C54nIKuBc97kJNVExMPAFGPCcc8/I1/rAthV+pzKViGcdTVX1e4o//7GvV+s15Sz5WkjoAFOugdf6wqX/gvYX+Z3KVALl8m2jCXPNejntYHVbw3vD4NsnoaDA71QmzFnxMsFRo5HzTWS3K+DbJ+D9qyDH7o5nvGPFywRPdBxcMh76PQkrv4B/nwvbV/udyoQpK14muETgtJvg6qnw6zanIX/lV36nMmHIipfxRsuznHawmk3hnSHw3bN2w1sTVHZZG+OdWs3guq9g+i0w82HnRh+XvAxV7H7FAKrKwg27mJaawc8bslAUQRD3O3oBEDn4lb2IM07cGeTguEMvKBznPHWWdeRzDi5Pjpj/0DgOjpcjph9aX+Gyj17/kfkOrS86Unjqsm4n/qEFsOJlvFWlKvzhdWjYDb55CLavgssnQ+0WfifzzfLMPUxLzeCTtAzSd+0jJiqCU5rXJjrS+TNXnJ3Uwv1UdfdYC3dcFT00rEc8B7TAGYc7XgOXcdhynCeH1lPUsg+9NnCewOmFGfXgQg+97shlR0cG72DPipfxngj0HgP1O8GH18Jr5zg3/Wh1jt/Jys3Gnb8xPS2D6akZrNiSTWSE0Lt1XW4/ty3nd6pPfGy03xFDjmgItEMkJydrSkqK3zFMMOxYA+9dCdtXwHmPQK/RHHYsE0a2Ze/n00UZTE/LYOGGLACSm9ViYFIi/bs0pG71GH8DhgARWaCqyUVNsz0vU77qtILrv4b/3ARf/cVpBxv4gtPNIgzsycnlyyWZTE/L4L+rt1Og0L5BPPf0a8fFXRNpUruq3xHDhhUvU/5i4mHwm85J3bMfc/bChk6Gmk38TnZCcnLzmb18K9PTMpi5fCsH8gpoUjuOm85uxcBujWjXIN7viGHJipfxR0QEnHW3c8u1j0c5FzgcMgman+F3shLJyy9g7todTEvN4MslmWTvz6Nu9SoM69mUgUmJdG9S8+C3bsYbVryMv9pdCNfPdM6JfHMQXPAE9LyhQraDqSo/b8xiemoGMxZlsH3vAeJjorigcwMGJSXSq2UdooL4bZo5Nitexn8JbeGGmc4e2Od3Q2Ya9P87RMf6nQyAlVuymZa6ielpGWzcuY8qURH0bV+PQUmJnN2uHrHRkX5HrJSseJmKIbYGXP6uc1L3nKdg63IY+haclOhLnI07f+OTRU7XhuWZ2UQI9G5dlzF9na4NJ1nXBt9Z8TIVR0QE9PkLNOgCU2902sF+/wrUbonTfVtK/7MU827/9QCfL8nkk7RMUjZkoUD3JrV4eGAn+ndpSEK8dW2oSKyfl6mYtvzitIPt+p/fSQIUVfQiSlAYi3ttYJGNAIl0CrhEQkRkwDh3OCIy4PmxxkcUMV9x4wNef9iySrOM0oyPgianlPwTt35eJuTU7+ic2L3qayjIdc9p0cN/asHR46DoeQN+5ubns2brXpZl7GbN1mzyC5SacZF0aBhPhwbVSagec9xlHPPnCb2mwHkU5IPmuz8L3OGCgHEB0wryQQ8UMz5w/q
JeX8z4om/mFTyRVeCB4NzK0IqXqbjiakLXwUFZVH6BMnfNDqanbeLzJZlk5+RRp1oVBiQ3ZGBSIj2a1rKuDXCokB5V6IorjKUsmEFkxcuELVUldWMW09MymLFoM9uy91M9JooLOjVgYFIivVtZ14ajiBw6zKvgrHiZsLN6azbTUjOYlprBhp2/USUygj7t6zEwKZE+7a1rQ7iw4mXCwqasfXyS5hSsZZv3ECFwequ63NKnNRd0akCNOOvaEG6seJmQtfPXA3y6eDPTUzcxf90uAJKa1OTBiztyUdeG1IuvGJ1cjTeseJmQsnd/Hl//ksm01Ay+X7WdvAKlTb3q3HV+Wy7ulkizOnaV1srCipep8Pbn5fN/K7YxPS2Db5ZtISe3gEY147j+zJYMSkqkfYN4+6awErLiZSqs9F2/8eKs1Xy2eDN7cvKoXa0Kg09uwiC3a0NEhBWsysyKl6mQFqVnce3EFH7dn8eFnd2uDa3rBvUa6Ca0WfEyFc7MZVu45Z2fqVO9Cu+N6k3renYxP3M0K16mQnlr7joenL6UTok1eH1Esn1jaIplxctUCAUFyt++WM6/5qzl3A71eOGK7lStYpunKZ5tHcZ3Obn53PlBGp8u2szVpzXjoYGdiLTGeHMcVryMr3b9eoAb3kwhZf0u/ty/PTec2dK6PZgSseJlfLN+x6+MnDCf9Kx9vDSsBxd1beh3JBNCrHgZX/y8YRfXT0ohX5V3rj+V5Oa1/Y5kQowVL1PuvlyayZj3fqZefCwTR55Cy4TqfkcyIciKlylXb3z/Px759Be6Na7J68OTqWO3vDcnyIqXKRf5Bcpjny7jjf/+jws61ef5od2Jq2LX1TInzoqX8VxObj63vZfKF0szGdm7Ofdf1NG6Qpgys+JlPLVj736ufzOF1I1Z/HVAR649o4XfkUyYsOJlPLN2215GTpxP5u4cxl95Mv06N/A7kgkjnp2iLyJviMhWEVkSMK62iHwtIqvcn7W8Wr/x14L1O/nD+B/Izsnj3VGnWeEyQefl9UUmAv2OGDcWmKmqbYCZ7nMTZj5dtJkrXvuJmlWrMPXm0+nR1P5HmeDzrHip6hxg5xGjBwGT3OFJwCVerd+UP1XltTlrGf3OQro2qsFHN51ul2U2ninvNq/6qrrZHc4E6hc3o4iMAkYBNG3atByimbLIL1DGfbKUSXPXc1GXhvx9SDe7xZjxlG+XpVQtvNd5sdNfVdVkVU1OSEgox2SmtH47kMcf31rApLnrGfW7lvzziu5WuIznynvPa4uINFTVzSLSENhazus3QbYtez/XTZrPkk27eWRQJ67u1dzvSKaSKO89r+nAcHd4ODCtnNdvgmj11r38/uX/smrLXl69OtkKlylXnu15ici7wNlAXRFJBx4EngSmiMh1wHpgiFfrN976ae0ORr21gOhI4f0/nkbXxjX9jmQqGc+Kl6peUcykvl6t05SPaambuPuDRTSpHcfEkT1pUruq35FMJWQ97E2JqSrj/28NT32xgp4tavPa1cnUqBrtdyxTSVnxMiWSl1/AA9OW8u68DQzslsjTg7sSE2XfKBr/WPEyx/Xr/jxueWchs1ds4+azW3HX+e3sbtXGd1a8zDFt2ZPDtRPnszwzm8d/34Vhp1qHYVMxWPEyxVq5JZsRb8wja18u/x6ezDnt6vkdyZiDrHiZIv2wejt/fHsBcdGRTPljLzo3quF3JGMOY8XLHOXjhenc+9EiWtStxoSRPWlUM87vSMYcxYqXOUhV+ees1Tz79UpOb1WH8VedTI046wphKiYrXgaA3PwC/jJ1MVNS0rm0RyOevLQrVaJ8O2/fmOOy4mXIzsnl5skL+W7Vdv7Utw23n9sGEesKYSo2K16V3Obd+xg5YT6rt+7lqT90ZcgpTfyOZEyJWPGqxJZt3sPICfPZuz+PCSNP4cw2dt00EzqseFVSc1Zu4+bJC6keE8UHN/aiQ8OT/I5kTKlY8aqEpqRs5M8fL6Z1vepMGH
kKDWtYVwgTeqx4VSKqynPfrOKFmas4s01dXr6yB/Gx1hXChCYrXpXEgbwCxn68iI8XbmJIcmMe+30XoiOtK4QJXVa8KoHd+3K56e0F/LBmB3ee15Zb+rS2rhAm5FnxCnObsvYxcsI8/rf9V54d0o1LezT2O5IxQWHFK4wt2bSbayfOZ19uPpNG9uT01nX9jmRM0FjxClOzV2xl9OSF1KpahbevP5W29eP9jmRMUFnxCkPv/LSBB6YtoX2DeCaMOIV6J8X6HcmYoLPiFUYKCpRnvlrBy9+u4Zx2Cbw4rAfVYuxXbMKTbdlhYn9ePnd/sIjpaRkMO7Up4wZ2Isq6QpgwZsUrDGT9doBRby1g3v92cm+/9tx4VkvrCmHCnhWvELdx52+MmDCPjTv38Y/LkxiU1MjvSMaUCyteIWxRehbXTpxPbr7y1nU9ObVlHb8jGVNuwqp4fbk0k3s/WkRUhBAVEUFkhBAdKe5P53lUhBAVGTgtgugIZ56oSOd1Ue5wZETEwdcXvq5w2c68RS876uCy3BwBw1FFDB+2jMgipkVEHHWfxG9+2cKt7/5MnepVeG9UT1rXq+7Tp26MP8KqeDWqGcegbonkFij5+UpuQQH5BUpevpLnDufmq/uzgP25BeQW5JNfUODOc2jaoXkLyHOXkV/gLFO1/N+bCERHHCpwe/fn0bVRDf49/BQS4mPKP5AxPgur4tW5UY1yuUVXQUFAYQwojnn5hxe/w6YdMZx/1HglL7/gqAJ61OvcaTXiohn1u5ZUrRJWv0JjSsy2/BMQESHERET6HcOYSs06AhljQpIVL2NMSLLiZYwJSVa8jDEhyYqXMSYkWfEyxoQkK17GmJBkxcsYE5JE/TjXpZREZBuwvoSz1wW2exinorD3GV7sfRatmaomFDUhJIpXaYhIiqom+53Da/Y+w4u9z9Kzw0ZjTEiy4mWMCUnhWLxe9TtAObH3GV7sfZZS2LV5GWMqh3Dc8zLGVAJWvIwxISmsipeI9BORFSKyWkTG+p3HCyLyhohsFZElfmfxkog0EZHZIvKLiCwVkTF+Z/KCiMSKyDwRSXPf58N+Z/KSiESKyM8iMqOsywqb4iUikcBLwIVAR+AKEenobypPTAT6+R2iHOQBd6pqR+A0YHSY/j73A31UtRuQBPQTkdP8jeSpMcCyYCwobIoX0BNYraprVfUA8B4wyOdMQaeqc4CdfufwmqpuVtWF7nA2zgYfdjelVMde92m0+wjLb9FEpDFwEfDvYCwvnIpXI2BjwPN0wnBjr4xEpDnQHfjJ5yiecA+lUoGtwNeqGpbvE3geuAcoCMbCwql4mTAkItWBj4DbVHWP33m8oKr5qpoENAZ6ikhnnyMFnYgMALaq6oJgLTOcitcmoEnA88buOBOiRCQap3BNVtWP/c7jNVXNAmYTnm2avYGBIrIOp0mnj4i8XZYFhlPxmg+0EZEWIlIFuByY7nMmc4JERIDXgWWq+qzfebwiIgkiUtMdjgPOA5b7GsoDqnqfqjZW1eY4f5uzVPWqsiwzbIqXquYBtwBf4jTuTlHVpf6mCj4ReReYC7QTkXQRuc7vTB7pDVyN8x861X309zuUBxoCs0VkEc4/4K9VtczdCCoDOz3IGBOSwmbPyxhTuVjxMsaEJCtexpiQZMXLGBOSrHgZY0KSFS/jKRHJD+jqkBrMq32ISPNwv7qGKV6U3wFM2NvnnvpiTFDZnpfxhYisE5GnRGSxez2r1u745iIyS0QWichMEWnqjq8vIlPd616licjp7qIiReQ191pYX7m91BGRP7nXAlskIu/59DaNh6x4Ga/FHXHYODRg2m5V7QK8iHPFAYB/ApNUtSswGXjBHf8C8H/uda96AIVnT7QBXlLVTkAW8Ad3/Figu7ucG715a8ZP1sPeeEpE9qpq9SLGr8O5CN9a9wTsTFWtIyLbgYaqmuuO36yqdd27pjdW1f0By2iOczpNG/f5vUC0qj4qIl8Ae4H/AP8JuGaWCR
O252X8pMUMl8b+gOF8DrXjXoRzZd0ewHwRsfbdMGPFy/hpaMDPue7wDzhXHQC4EvjOHZ4J3AQHL95Xo7iFikgE0ERVZwP3AjWAo/b+TGiz/0bGa3HuVUILfaGqhd0larlXU9gPXOGOuxWYICJ3A9uAke74McCr7lU08nEK2eZi1hkJvO0WOAFecK+VZcKItXkZX7htXsmqut3vLCY02WGjMSYk2Z6XMSYk2Z6XMSYkWfEyxoQkK17GmJBkxcsYE5KseBljQtL/A/w8USoPr020AAAAAElFTkSuQmCC\n", + "text/plain": [ + "
" + ] + }, + "metadata": { + "needs_background": "light" + }, + "output_type": "display_data" + } + ], + "source": [ + "training_vis(train_losses, valid_losses)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "#### Model evaluation\n", + "\n", + "Evaluate the model on the test set." + ] + }, + { + "cell_type": "code", + "execution_count": 31, + "metadata": { + "execution": { + "iopub.execute_input": "2021-11-29T04:55:47.340416Z", + "iopub.status.busy": "2021-11-29T04:55:47.339537Z", + "iopub.status.idle": "2021-11-29T04:55:47.606447Z", + "shell.execute_reply": "2021-11-29T04:55:47.607038Z", + "shell.execute_reply.started": "2021-11-28T14:01:44.127754Z" + }, + "papermill": { + "duration": 2.453872, + "end_time": "2021-11-29T04:55:47.607210", + "exception": false, + "start_time": "2021-11-29T04:55:45.153338", + "status": "completed" + }, + "tags": [] + }, + "outputs": [ + { + "data": { + "text/plain": [ + "" + ] + }, + "execution_count": 31, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "# Load the best-scoring model; map_location='cpu' lets the checkpoint load on machines without a GPU\n", + "checkpoint = torch.load('../input/ai-earth-model-weights/task05_model_weights.pth', map_location='cpu')\n", + "model = SAConvLSTM(input_dim, hidden_dim, d_attn, kernel_size)\n", + "model.load_state_dict(checkpoint['state_dict'])" + ] + }, + { + "cell_type": "code", + "execution_count": 32, + "metadata": { + "execution": { + "iopub.execute_input": "2021-11-29T04:55:51.849996Z", + "iopub.status.busy": "2021-11-29T04:55:51.849073Z", + "iopub.status.idle": "2021-11-29T04:55:51.851492Z", + "shell.execute_reply": "2021-11-29T04:55:51.850969Z", + "shell.execute_reply.started": "2021-11-28T14:06:59.931318Z" + }, + "papermill": { + "duration": 2.125413, + "end_time": "2021-11-29T04:55:51.851629", + "exception": false, + "start_time": "2021-11-29T04:55:49.726216", + "status": "completed" + }, + "tags": [] + }, + "outputs": [], + "source": [ + "# Path to the test set\n", + "test_path = '../input/ai-earth-tests/'\n", + "# Path to the test set labels\n", + "test_label_path = '../input/ai-earth-tests-labels/'" 
+ ] + }, + { + "cell_type": "code", + "execution_count": 33, + "metadata": { + "execution": { + "iopub.execute_input": "2021-11-29T04:55:56.429364Z", + "iopub.status.busy": "2021-11-29T04:55:56.428800Z", + "iopub.status.idle": "2021-11-29T04:55:58.115486Z", + "shell.execute_reply": "2021-11-29T04:55:58.115007Z", + "shell.execute_reply.started": "2021-11-28T14:07:13.415385Z" + }, + "papermill": { + "duration": 4.135325, + "end_time": "2021-11-29T04:55:58.115667", + "exception": false, + "start_time": "2021-11-29T04:55:53.980342", + "status": "completed" + }, + "tags": [] + }, + "outputs": [], + "source": [ + "import os\n", + "\n", + "# Load the test data and the corresponding labels\n", + "files = os.listdir(test_path)\n", + "X_test = []\n", + "y_test = []\n", + "for file in files:\n", + " X_test.append(np.load(test_path + file))\n", + " y_test.append(np.load(test_label_path + file))" + ] + }, + { + "cell_type": "code", + "execution_count": 34, + "metadata": { + "execution": { + "iopub.execute_input": "2021-11-29T04:56:02.560786Z", + "iopub.status.busy": "2021-11-29T04:56:02.559461Z", + "iopub.status.idle": "2021-11-29T04:56:02.587431Z", + "shell.execute_reply": "2021-11-29T04:56:02.588024Z", + "shell.execute_reply.started": "2021-11-28T14:07:17.046359Z" + }, + "papermill": { + "duration": 2.329175, + "end_time": "2021-11-29T04:56:02.588201", + "exception": false, + "start_time": "2021-11-29T04:56:00.259026", + "status": "completed" + }, + "tags": [] + }, + "outputs": [ + { + "data": { + "text/plain": [ + "((103, 12, 24, 48, 1), (103, 24))" + ] + }, + "execution_count": 34, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "X_test = np.array(X_test)[:, :, :, 19: 67, :1]\n", + "y_test = np.array(y_test)\n", + "X_test.shape, y_test.shape" + ] + }, + { + "cell_type": "code", + "execution_count": 35, + "metadata": { + "execution": { + "iopub.execute_input": "2021-11-29T04:56:07.675344Z", + "iopub.status.busy": "2021-11-29T04:56:07.674488Z", + "iopub.status.idle": 
"2021-11-29T04:56:07.682481Z", + "shell.execute_reply": "2021-11-29T04:56:07.682895Z", + "shell.execute_reply.started": "2021-11-28T14:07:31.503452Z" + }, + "papermill": { + "duration": 2.455352, + "end_time": "2021-11-29T04:56:07.683041", + "exception": false, + "start_time": "2021-11-29T04:56:05.227689", + "status": "completed" + }, + "tags": [] + }, + "outputs": [], + "source": [ + "testset = AIEarthDataset(X_test, y_test)\n", + "testloader = DataLoader(testset, batch_size=batch_size, shuffle=False)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# 在测试集上评估模型效果\n", + "model.eval()\n", + "model.to(device)\n", + "preds = np.zeros((len(y_test),24))\n", + "for i, data in tqdm(enumerate(testloader)):\n", + " data, labels = data\n", + " data = data.to(device)\n", + " labels = labels.to(device)\n", + " pred = model(data, train=False)\n", + " preds[i*batch_size:(i+1)*batch_size] = pred.detach().cpu().numpy()\n", + "s = score(y_test, preds)\n", + "print('Score: {:.3f}'.format(s))" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "papermill": { + "duration": null, + "end_time": null, + "exception": null, + "start_time": null, + "status": "pending" + }, + "tags": [] + }, + "source": [ + "## 总结\n", + "\n", + "这一次的TOP方案没有自己设计模型,而是使用了目前时空序列预测领域现有的模型,另一组TOP选手“ailab”也使用了现有的模型PredRNN++,关于时空序列预测领域的一些比较经典的模型可以参考https://www.zhihu.com/column/c_1208033701705162752" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## 作业\n", + "\n", + "该TOP方案中以sst作为预测目标,间接计算nino3.4指数,学有余力的同学可以尝试用SA-ConvLSTM模型直接预测nino3.4指数。" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## 参考文献\n", + "\n", + "1. 吴先生的队伍方案分享:https://tianchi.aliyun.com/forum/postDetail?spm=5176.12586969.1002.9.561d5330dF9lX1&postId=231465\n", + "2. 
Approach write-up by the ailab team: https://tianchi.aliyun.com/forum/postDetail?spm=5176.12586969.1002.15.561d5330dF9lX1&postId=210734" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.7.3" + }, + "papermill": { + "default_parameters": {}, + "duration": 6708.571081, + "end_time": "2021-11-29T04:56:15.789285", + "environment_variables": {}, + "exception": true, + "input_path": "__notebook__.ipynb", + "output_path": "__notebook__.ipynb", + "parameters": {}, + "start_time": "2021-11-29T03:04:27.218204", + "version": "2.3.3" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +}