From bcc434537cd97e5f2b382f5a7a474df6b2c8d9f6 Mon Sep 17 00:00:00 2001
From: Harold-Ran <56714856+Harold-Ran@users.noreply.github.com>
Date: Sat, 4 Dec 2021 20:40:24 +0800
Subject: [PATCH] Add files via upload
---
Task5/Task5 模型建立之SA-ConvLSTM.ipynb | 1818 +++++++++++++++++++++++
1 file changed, 1818 insertions(+)
create mode 100644 Task5/Task5 模型建立之SA-ConvLSTM.ipynb
diff --git a/Task5/Task5 模型建立之SA-ConvLSTM.ipynb b/Task5/Task5 模型建立之SA-ConvLSTM.ipynb
new file mode 100644
index 0000000..9ac6f98
--- /dev/null
+++ b/Task5/Task5 模型建立之SA-ConvLSTM.ipynb
@@ -0,0 +1,1818 @@
+{
+ "cells": [
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
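+ "# Datawhale Meteorological and Ocean Prediction, Task5: Model Building with SA-ConvLSTM\n",
+ "\n",
+ "In this task we study the modeling approach of the top-ranked team “吴先生的队伍”, whose model is SA-ConvLSTM.\n",
+ "\n",
+ "The previous two TOP solutions treated the competition as a multi-output task: a neural network directly outputs 24 Nino3.4 predictions. The weakness of this approach is that sequence problems are usually temporally dependent. A multi-output model implicitly treats the 24 predictions as independent of one another, whereas in practice each predicted value is usually influenced by the prediction at the previous time step. This TOP solution therefore adopts a Seq2Seq structure to model the sequential dependence among the outputs.\n",
+ "\n",
+ "A Seq2Seq structure consists of an Encoder and a Decoder. The Encoder encodes the input sequence into a vector, and the Decoder decodes that vector into a predicted sequence. The key to applying Seq2Seq to a particular sequence problem is the choice of Cell used at each time step. As discussed earlier, spatial information is usually extracted with a CNN and temporal information with an RNN or LSTM; combining the two yields ConvLSTM, a classic model for spatiotemporal sequences. The SA-ConvLSTM model studied here improves on ConvLSTM by introducing a self-attention mechanism that strengthens the model's ability to capture long-range spatial dependencies.\n",
+ "\n",
+ "Another difference from the previous two TOP solutions is that this one does not predict the Nino3.4 index directly; it predicts sst and derives the Nino3.4 index sequence from it."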
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "## Learning Objectives\n",
+ "1. Learn how the TOP solution builds its model"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
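+ "## Contents\n",
+ "1. Data processing\n",
+ "    - Data flattening\n",
+ "    - Filling missing values\n",
+ "    - Building the dataset\n",
+ "2. Model building\n",
+ "    - Building the evaluation function\n",
+ "    - Model construction\n",
+ "    - Model training\n",
+ "    - Model evaluation\n",
+ "3. Summary"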
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "## Code Example"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "### Data Processing\n",
+ "Data processing in this TOP solution consists of three parts:\n",
+ "1. Data flattening.\n",
+ "2. Filling missing values.\n",
+ "3. Building the dataset."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 1,
+ "metadata": {
+ "execution": {
+ "iopub.execute_input": "2021-11-29T03:04:34.698663Z",
+ "iopub.status.busy": "2021-11-29T03:04:34.697133Z",
+ "iopub.status.idle": "2021-11-29T03:04:37.035400Z",
+ "shell.execute_reply": "2021-11-29T03:04:37.034767Z",
+ "shell.execute_reply.started": "2021-11-29T01:02:51.883602Z"
+ },
+ "papermill": {
+ "duration": 2.370278,
+ "end_time": "2021-11-29T03:04:37.035673",
+ "exception": false,
+ "start_time": "2021-11-29T03:04:34.665395",
+ "status": "completed"
+ },
+ "tags": []
+ },
+ "outputs": [],
+ "source": [
+ "import netCDF4 as nc\n",
+ "import random\n",
+ "import os\n",
+ "from tqdm import tqdm\n",
+ "import pandas as pd\n",
+ "import numpy as np\n",
+ "import math\n",
+ "import matplotlib.pyplot as plt\n",
+ "%matplotlib inline\n",
+ "\n",
+ "import torch\n",
+ "from torch import nn, optim\n",
+ "import torch.nn.functional as F\n",
+ "from torch.utils.data import Dataset, DataLoader\n",
+ "from torch.optim.lr_scheduler import ReduceLROnPlateau\n",
+ "\n",
+ "from sklearn.metrics import mean_squared_error"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 2,
+ "metadata": {
+ "execution": {
+ "iopub.execute_input": "2021-11-29T03:04:37.102995Z",
+ "iopub.status.busy": "2021-11-29T03:04:37.102144Z",
+ "iopub.status.idle": "2021-11-29T03:04:37.107646Z",
+ "shell.execute_reply": "2021-11-29T03:04:37.107161Z",
+ "shell.execute_reply.started": "2021-11-29T01:02:54.06493Z"
+ },
+ "papermill": {
+ "duration": 0.040737,
+ "end_time": "2021-11-29T03:04:37.107761",
+ "exception": false,
+ "start_time": "2021-11-29T03:04:37.067024",
+ "status": "completed"
+ },
+ "tags": []
+ },
+ "outputs": [],
+ "source": [
+ "# Fix the random seeds for reproducibility\n",
+ "SEED = 22\n",
+ "\n",
+ "def seed_everything(seed=42):\n",
+ "    random.seed(seed)\n",
+ "    os.environ['PYTHONHASHSEED'] = str(seed)\n",
+ "    np.random.seed(seed)\n",
+ "    torch.manual_seed(seed)\n",
+ "    torch.cuda.manual_seed(seed)\n",
+ "    torch.backends.cudnn.deterministic = True\n",
+ "\n",
+ "seed_everything(SEED)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 3,
+ "metadata": {
+ "execution": {
+ "iopub.execute_input": "2021-11-29T03:04:37.222525Z",
+ "iopub.status.busy": "2021-11-29T03:04:37.221100Z",
+ "iopub.status.idle": "2021-11-29T03:04:37.225844Z",
+ "shell.execute_reply": "2021-11-29T03:04:37.226442Z",
+ "shell.execute_reply.started": "2021-11-29T01:02:54.074875Z"
+ },
+ "papermill": {
+ "duration": 0.090198,
+ "end_time": "2021-11-29T03:04:37.226602",
+ "exception": false,
+ "start_time": "2021-11-29T03:04:37.136404",
+ "status": "completed"
+ },
+ "tags": []
+ },
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "CUDA is available! Training on GPU ...\n"
+ ]
+ }
+ ],
+ "source": [
+ "# Check whether CUDA is available\n",
+ "train_on_gpu = torch.cuda.is_available()\n",
+ "\n",
+ "if not train_on_gpu:\n",
+ "    print('CUDA is not available. Training on CPU ...')\n",
+ "else:\n",
+ "    print('CUDA is available! Training on GPU ...')"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 4,
+ "metadata": {
+ "execution": {
+ "iopub.execute_input": "2021-11-29T03:04:37.353082Z",
+ "iopub.status.busy": "2021-11-29T03:04:37.352143Z",
+ "iopub.status.idle": "2021-11-29T03:04:37.432852Z",
+ "shell.execute_reply": "2021-11-29T03:04:37.434332Z",
+ "shell.execute_reply.started": "2021-11-28T10:13:13.644947Z"
+ },
+ "papermill": {
+ "duration": 0.179146,
+ "end_time": "2021-11-29T03:04:37.434792",
+ "exception": false,
+ "start_time": "2021-11-29T03:04:37.255646",
+ "status": "completed"
+ },
+ "tags": []
+ },
+ "outputs": [],
+ "source": [
+ "# Load the data\n",
+ "\n",
+ "# Directory containing the data files\n",
+ "path = '/kaggle/input/ninoprediction/'\n",
+ "soda_train = nc.Dataset(path + 'SODA_train.nc')\n",
+ "soda_label = nc.Dataset(path + 'SODA_label.nc')\n",
+ "cmip_train = nc.Dataset(path + 'CMIP_train.nc')\n",
+ "cmip_label = nc.Dataset(path + 'CMIP_label.nc')"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
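+ "#### Data Flattening\n",
+ "The dataset is built with a sliding window. This solution uses only the sst feature, and only the data whose lon values fall within [90, 330], possibly to save computing resources."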
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 5,
+ "metadata": {
+ "execution": {
+ "iopub.execute_input": "2021-11-29T03:04:37.548239Z",
+ "iopub.status.busy": "2021-11-29T03:04:37.546737Z",
+ "iopub.status.idle": "2021-11-29T03:04:37.551951Z",
+ "shell.execute_reply": "2021-11-29T03:04:37.553081Z",
+ "shell.execute_reply.started": "2021-11-27T13:38:32.620904Z"
+ },
+ "papermill": {
+ "duration": 0.065069,
+ "end_time": "2021-11-29T03:04:37.553274",
+ "exception": false,
+ "start_time": "2021-11-29T03:04:37.488205",
+ "status": "completed"
+ },
+ "tags": []
+ },
+ "outputs": [],
+ "source": [
+ "def make_flatted(train_ds, label_ds, info, start_idx=0):\n",
+ "    # Use only the sst feature\n",
+ "    keys = ['sst']\n",
+ "    label_key = 'nino'\n",
+ "    # Number of years per model\n",
+ "    years = info[1]\n",
+ "    # Number of models\n",
+ "    models = info[2]\n",
+ "\n",
+ "    train_list = []\n",
+ "    label_list = []\n",
+ "\n",
+ "    # Concatenate the data that belongs to the same model\n",
+ "    for model_i in range(models):\n",
+ "        blocks = []\n",
+ "\n",
+ "        # For each feature, take the first 12 months of every sample and concatenate them; keep only lon values within [90, 330]\n",
+ "        for key in keys:\n",
+ "            block = train_ds[key][start_idx + model_i * years: start_idx + (model_i + 1) * years, :12, :, 19: 67].reshape(-1, 24, 48, 1).data\n",
+ "            blocks.append(block)\n",
+ "\n",
+ "        # Concatenate all features along the last dimension\n",
+ "        train_flatted = np.concatenate(blocks, axis=-1)\n",
+ "\n",
+ "        # Concatenate the labels of months 12-23 of every year, then append months 24-35 of the final year (together with months 12-23 of the final year they form the prediction target for its first 12 months)\n",
+ "        label_flatted = np.concatenate([\n",
+ "            label_ds[label_key][start_idx + model_i * years: start_idx + (model_i + 1) * years, 12: 24].reshape(-1).data,\n",
+ "            label_ds[label_key][start_idx + (model_i + 1) * years - 1, 24: 36].reshape(-1).data\n",
+ "        ], axis=0)\n",
+ "\n",
+ "        train_list.append(train_flatted)\n",
+ "        label_list.append(label_flatted)\n",
+ "\n",
+ "    return train_list, label_list"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 6,
+ "metadata": {
+ "execution": {
+ "iopub.execute_input": "2021-11-29T03:04:37.661954Z",
+ "iopub.status.busy": "2021-11-29T03:04:37.660977Z",
+ "iopub.status.idle": "2021-11-29T03:05:11.515409Z",
+ "shell.execute_reply": "2021-11-29T03:05:11.515853Z",
+ "shell.execute_reply.started": "2021-11-27T13:38:33.844185Z"
+ },
+ "papermill": {
+ "duration": 33.912013,
+ "end_time": "2021-11-29T03:05:11.516001",
+ "exception": false,
+ "start_time": "2021-11-29T03:04:37.603988",
+ "status": "completed"
+ },
+ "tags": []
+ },
+ "outputs": [
+ {
+ "data": {
+ "text/plain": [
+ "((1, 1200, 24, 48, 1), (15, 1812, 24, 48, 1), (17, 1680, 24, 48, 1))"
+ ]
+ },
+ "execution_count": 6,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "soda_info = ('soda', 100, 1)\n",
+ "cmip6_info = ('cmip6', 151, 15)\n",
+ "cmip5_info = ('cmip5', 140, 17)\n",
+ "\n",
+ "soda_trains, soda_labels = make_flatted(soda_train, soda_label, soda_info)\n",
+ "cmip6_trains, cmip6_labels = make_flatted(cmip_train, cmip_label, cmip6_info)\n",
+ "cmip5_trains, cmip5_labels = make_flatted(cmip_train, cmip_label, cmip5_info, cmip6_info[1]*cmip6_info[2])\n",
+ "\n",
+ "# The flattened data has shape (models x sequence length x latitude x longitude x features), where sequence length = years x 12\n",
+ "np.shape(soda_trains), np.shape(cmip6_trains), np.shape(cmip5_trains)"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "#### Filling Missing Values\n",
+ "Missing values are filled with 0."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 8,
+ "metadata": {
+ "execution": {
+ "iopub.execute_input": "2021-11-29T03:05:11.638562Z",
+ "iopub.status.busy": "2021-11-29T03:05:11.637553Z",
+ "iopub.status.idle": "2021-11-29T03:05:11.644302Z",
+ "shell.execute_reply": "2021-11-29T03:05:11.644742Z",
+ "shell.execute_reply.started": "2021-11-27T13:39:22.665855Z"
+ },
+ "papermill": {
+ "duration": 0.040786,
+ "end_time": "2021-11-29T03:05:11.644893",
+ "exception": false,
+ "start_time": "2021-11-29T03:05:11.604107",
+ "status": "completed"
+ },
+ "tags": []
+ },
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "Number of null in soda_trains after fillna: 0\n"
+ ]
+ }
+ ],
+ "source": [
+ "# Fill missing values in the SODA data\n",
+ "soda_trains = np.array(soda_trains)\n",
+ "soda_trains_nan = np.isnan(soda_trains)\n",
+ "soda_trains[soda_trains_nan] = 0\n",
+ "print('Number of null in soda_trains after fillna:', np.sum(np.isnan(soda_trains)))"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 9,
+ "metadata": {
+ "execution": {
+ "iopub.execute_input": "2021-11-29T03:05:11.709054Z",
+ "iopub.status.busy": "2021-11-29T03:05:11.707767Z",
+ "iopub.status.idle": "2021-11-29T03:05:11.862744Z",
+ "shell.execute_reply": "2021-11-29T03:05:11.863294Z",
+ "shell.execute_reply.started": "2021-11-27T13:39:24.110039Z"
+ },
+ "papermill": {
+ "duration": 0.18937,
+ "end_time": "2021-11-29T03:05:11.863480",
+ "exception": false,
+ "start_time": "2021-11-29T03:05:11.674110",
+ "status": "completed"
+ },
+ "tags": []
+ },
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "Number of null in cmip6_trains after fillna: 0\n"
+ ]
+ }
+ ],
+ "source": [
+ "# Fill missing values in the CMIP6 data\n",
+ "cmip6_trains = np.array(cmip6_trains)\n",
+ "cmip6_trains_nan = np.isnan(cmip6_trains)\n",
+ "cmip6_trains[cmip6_trains_nan] = 0\n",
+ "print('Number of null in cmip6_trains after fillna:', np.sum(np.isnan(cmip6_trains)))"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 10,
+ "metadata": {
+ "execution": {
+ "iopub.execute_input": "2021-11-29T03:05:11.927752Z",
+ "iopub.status.busy": "2021-11-29T03:05:11.925353Z",
+ "iopub.status.idle": "2021-11-29T03:05:12.091117Z",
+ "shell.execute_reply": "2021-11-29T03:05:12.091855Z",
+ "shell.execute_reply.started": "2021-11-27T13:39:24.520724Z"
+ },
+ "papermill": {
+ "duration": 0.197975,
+ "end_time": "2021-11-29T03:05:12.092014",
+ "exception": false,
+ "start_time": "2021-11-29T03:05:11.894039",
+ "status": "completed"
+ },
+ "tags": []
+ },
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "Number of null in cmip5_trains after fillna: 0\n"
+ ]
+ }
+ ],
+ "source": [
+ "# Fill missing values in the CMIP5 data\n",
+ "cmip5_trains = np.array(cmip5_trains)\n",
+ "cmip5_trains_nan = np.isnan(cmip5_trains)\n",
+ "cmip5_trains[cmip5_trains_nan] = 0\n",
+ "print('Number of null in cmip5_trains after fillna:', np.sum(np.isnan(cmip5_trains)))"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
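+ "#### Building the Dataset\n",
+ "Build the training and validation sets. Note that each input sample has sequence length 38: the input sst sequence has length 12 and the output sst sequence has length 26. Training uses the teacher forcing strategy (explained in detail in the model construction part below), so the input data built here also contains the ground-truth values of the output sst sequence."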
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 11,
+ "metadata": {
+ "execution": {
+ "iopub.execute_input": "2021-11-29T03:05:12.165242Z",
+ "iopub.status.busy": "2021-11-29T03:05:12.164045Z",
+ "iopub.status.idle": "2021-11-29T03:05:12.480257Z",
+ "shell.execute_reply": "2021-11-29T03:05:12.479767Z",
+ "shell.execute_reply.started": "2021-11-27T13:39:25.418945Z"
+ },
+ "papermill": {
+ "duration": 0.361254,
+ "end_time": "2021-11-29T03:05:12.480405",
+ "exception": false,
+ "start_time": "2021-11-29T03:05:12.119151",
+ "status": "completed"
+ },
+ "tags": []
+ },
+ "outputs": [],
+ "source": [
+ "# Build the training set\n",
+ "\n",
+ "X_train = []\n",
+ "y_train = []\n",
+ "# Draw 100 samples from each of the 17 CMIP5 models\n",
+ "for model_i in range(17):\n",
+ "    samples = np.random.choice(cmip5_trains.shape[1]-38, size=100)\n",
+ "    for ind in samples:\n",
+ "        X_train.append(cmip5_trains[model_i, ind: ind+38])\n",
+ "        y_train.append(cmip5_labels[model_i][ind: ind+24])\n",
+ "# Draw 100 samples from each of the 15 CMIP6 models\n",
+ "for model_i in range(15):\n",
+ "    samples = np.random.choice(cmip6_trains.shape[1]-38, size=100)\n",
+ "    for ind in samples:\n",
+ "        X_train.append(cmip6_trains[model_i, ind: ind+38])\n",
+ "        y_train.append(cmip6_labels[model_i][ind: ind+24])\n",
+ "X_train = np.array(X_train)\n",
+ "y_train = np.array(y_train)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 12,
+ "metadata": {
+ "execution": {
+ "iopub.execute_input": "2021-11-29T03:05:12.541232Z",
+ "iopub.status.busy": "2021-11-29T03:05:12.540676Z",
+ "iopub.status.idle": "2021-11-29T03:05:12.548103Z",
+ "shell.execute_reply": "2021-11-29T03:05:12.547520Z",
+ "shell.execute_reply.started": "2021-11-27T13:39:26.341849Z"
+ },
+ "papermill": {
+ "duration": 0.040262,
+ "end_time": "2021-11-29T03:05:12.548224",
+ "exception": false,
+ "start_time": "2021-11-29T03:05:12.507962",
+ "status": "completed"
+ },
+ "tags": []
+ },
+ "outputs": [],
+ "source": [
+ "# Build the validation set (100 samples drawn from SODA)\n",
+ "\n",
+ "X_valid = []\n",
+ "y_valid = []\n",
+ "samples = np.random.choice(soda_trains.shape[1]-38, size=100)\n",
+ "for ind in samples:\n",
+ "    X_valid.append(soda_trains[0, ind: ind+38])\n",
+ "    y_valid.append(soda_labels[0][ind: ind+24])\n",
+ "X_valid = np.array(X_valid)\n",
+ "y_valid = np.array(y_valid)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 13,
+ "metadata": {
+ "execution": {
+ "iopub.execute_input": "2021-11-29T03:05:12.606407Z",
+ "iopub.status.busy": "2021-11-29T03:05:12.605555Z",
+ "iopub.status.idle": "2021-11-29T03:05:12.611580Z",
+ "shell.execute_reply": "2021-11-29T03:05:12.611152Z",
+ "shell.execute_reply.started": "2021-11-27T13:39:27.247585Z"
+ },
+ "papermill": {
+ "duration": 0.036214,
+ "end_time": "2021-11-29T03:05:12.611721",
+ "exception": false,
+ "start_time": "2021-11-29T03:05:12.575507",
+ "status": "completed"
+ },
+ "tags": []
+ },
+ "outputs": [
+ {
+ "data": {
+ "text/plain": [
+ "((3200, 38, 24, 48, 1), (3200, 24), (100, 38, 24, 48, 1), (100, 24))"
+ ]
+ },
+ "execution_count": 13,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "# Check the dataset shapes\n",
+ "X_train.shape, y_train.shape, X_valid.shape, y_valid.shape"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 15,
+ "metadata": {
+ "execution": {
+ "iopub.execute_input": "2021-11-29T03:05:12.737322Z",
+ "iopub.status.busy": "2021-11-29T03:05:12.736558Z",
+ "iopub.status.idle": "2021-11-29T03:05:13.516712Z",
+ "shell.execute_reply": "2021-11-29T03:05:13.517217Z",
+ "shell.execute_reply.started": "2021-11-27T13:39:38.421657Z"
+ },
+ "papermill": {
+ "duration": 0.812187,
+ "end_time": "2021-11-29T03:05:13.517368",
+ "exception": false,
+ "start_time": "2021-11-29T03:05:12.705181",
+ "status": "completed"
+ },
+ "tags": []
+ },
+ "outputs": [],
+ "source": [
+ "# Save the datasets\n",
+ "np.save('X_train_sample.npy', X_train)\n",
+ "np.save('y_train_sample.npy', y_train)\n",
+ "np.save('X_valid_sample.npy', X_valid)\n",
+ "np.save('y_valid_sample.npy', y_valid)"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "### Model Building"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 16,
+ "metadata": {
+ "execution": {
+ "iopub.execute_input": "2021-11-29T03:05:13.577516Z",
+ "iopub.status.busy": "2021-11-29T03:05:13.576992Z",
+ "iopub.status.idle": "2021-11-29T03:05:21.917657Z",
+ "shell.execute_reply": "2021-11-29T03:05:21.918265Z",
+ "shell.execute_reply.started": "2021-11-29T01:03:01.505192Z"
+ },
+ "papermill": {
+ "duration": 8.372964,
+ "end_time": "2021-11-29T03:05:21.918443",
+ "exception": false,
+ "start_time": "2021-11-29T03:05:13.545479",
+ "status": "completed"
+ },
+ "tags": []
+ },
+ "outputs": [],
+ "source": [
+ "# Load the datasets\n",
+ "X_train = np.load('../input/ai-earth-task05-samples/X_train_sample.npy')\n",
+ "y_train = np.load('../input/ai-earth-task05-samples/y_train_sample.npy')\n",
+ "X_valid = np.load('../input/ai-earth-task05-samples/X_valid_sample.npy')\n",
+ "y_valid = np.load('../input/ai-earth-task05-samples/y_valid_sample.npy')"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 17,
+ "metadata": {
+ "execution": {
+ "iopub.execute_input": "2021-11-29T03:05:21.983898Z",
+ "iopub.status.busy": "2021-11-29T03:05:21.982953Z",
+ "iopub.status.idle": "2021-11-29T03:05:21.986939Z",
+ "shell.execute_reply": "2021-11-29T03:05:21.986453Z",
+ "shell.execute_reply.started": "2021-11-29T01:03:11.548945Z"
+ },
+ "papermill": {
+ "duration": 0.039398,
+ "end_time": "2021-11-29T03:05:21.987066",
+ "exception": false,
+ "start_time": "2021-11-29T03:05:21.947668",
+ "status": "completed"
+ },
+ "tags": []
+ },
+ "outputs": [
+ {
+ "data": {
+ "text/plain": [
+ "((3200, 38, 24, 48, 1), (3200, 24), (100, 38, 24, 48, 1), (100, 24))"
+ ]
+ },
+ "execution_count": 17,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "X_train.shape, y_train.shape, X_valid.shape, y_valid.shape"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 18,
+ "metadata": {
+ "execution": {
+ "iopub.execute_input": "2021-11-29T03:05:22.341929Z",
+ "iopub.status.busy": "2021-11-29T03:05:22.340932Z",
+ "iopub.status.idle": "2021-11-29T03:05:22.346140Z",
+ "shell.execute_reply": "2021-11-29T03:05:22.346878Z",
+ "shell.execute_reply.started": "2021-11-29T01:03:11.560457Z"
+ },
+ "papermill": {
+ "duration": 0.143838,
+ "end_time": "2021-11-29T03:05:22.347113",
+ "exception": false,
+ "start_time": "2021-11-29T03:05:22.203275",
+ "status": "completed"
+ },
+ "tags": []
+ },
+ "outputs": [],
+ "source": [
+ "# Build the data pipeline\n",
+ "class AIEarthDataset(Dataset):\n",
+ "    def __init__(self, data, label):\n",
+ "        self.data = torch.tensor(data, dtype=torch.float32)\n",
+ "        self.label = torch.tensor(label, dtype=torch.float32)\n",
+ "\n",
+ "    def __len__(self):\n",
+ "        return len(self.label)\n",
+ "\n",
+ "    def __getitem__(self, idx):\n",
+ "        return self.data[idx], self.label[idx]"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 19,
+ "metadata": {
+ "execution": {
+ "iopub.execute_input": "2021-11-29T03:05:22.583350Z",
+ "iopub.status.busy": "2021-11-29T03:05:22.582298Z",
+ "iopub.status.idle": "2021-11-29T03:05:23.243100Z",
+ "shell.execute_reply": "2021-11-29T03:05:23.243851Z",
+ "shell.execute_reply.started": "2021-11-29T01:03:23.691846Z"
+ },
+ "papermill": {
+ "duration": 0.825537,
+ "end_time": "2021-11-29T03:05:23.244098",
+ "exception": false,
+ "start_time": "2021-11-29T03:05:22.418561",
+ "status": "completed"
+ },
+ "tags": []
+ },
+ "outputs": [],
+ "source": [
+ "batch_size = 2\n",
+ "\n",
+ "trainset = AIEarthDataset(X_train, y_train)\n",
+ "trainloader = DataLoader(trainset, batch_size=batch_size, shuffle=True)\n",
+ "\n",
+ "validset = AIEarthDataset(X_valid, y_valid)\n",
+ "validloader = DataLoader(validset, batch_size=batch_size, shuffle=True)"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "#### Building the Evaluation Function"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 24,
+ "metadata": {
+ "execution": {
+ "iopub.execute_input": "2021-11-29T03:05:23.655820Z",
+ "iopub.status.busy": "2021-11-29T03:05:23.655241Z",
+ "iopub.status.idle": "2021-11-29T03:05:23.658416Z",
+ "shell.execute_reply": "2021-11-29T03:05:23.658859Z",
+ "shell.execute_reply.started": "2021-11-29T01:03:26.481561Z"
+ },
+ "papermill": {
+ "duration": 0.040887,
+ "end_time": "2021-11-29T03:05:23.658990",
+ "exception": false,
+ "start_time": "2021-11-29T03:05:23.618103",
+ "status": "completed"
+ },
+ "tags": []
+ },
+ "outputs": [],
+ "source": [
+ "def rmse(y_true, y_preds):\n",
+ "    return np.sqrt(mean_squared_error(y_pred = y_preds, y_true = y_true))\n",
+ "\n",
+ "# Evaluation function\n",
+ "def score(y_true, y_preds):\n",
+ "    # Correlation skill score\n",
+ "    accskill_score = 0\n",
+ "    # Accumulated RMSE over the 24 lead months\n",
+ "    rmse_scores = 0\n",
+ "    a = [1.5] * 4 + [2] * 7 + [3] * 7 + [4] * 6\n",
+ "    y_true_mean = np.mean(y_true, axis=0)\n",
+ "    y_pred_mean = np.mean(y_preds, axis=0)\n",
+ "    for i in range(24):\n",
+ "        fenzi = np.sum((y_true[:, i] - y_true_mean[i]) * (y_preds[:, i] - y_pred_mean[i]))\n",
+ "        fenmu = np.sqrt(np.sum((y_true[:, i] - y_true_mean[i])**2) * np.sum((y_preds[:, i] - y_pred_mean[i])**2))\n",
+ "        cor_i = fenzi / fenmu\n",
+ "        accskill_score += a[i] * np.log(i+1) * cor_i\n",
+ "        rmse_score = rmse(y_true[:, i], y_preds[:, i])\n",
+ "        rmse_scores += rmse_score\n",
+ "    return 2/3.0 * accskill_score - rmse_scores"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "papermill": {
+ "duration": 0.028556,
+ "end_time": "2021-11-29T03:05:23.310560",
+ "exception": false,
+ "start_time": "2021-11-29T03:05:23.282004",
+ "status": "completed"
+ },
+ "tags": []
+ },
+ "source": [
+ "#### Model Construction"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
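+ "Unlike the multi-output neural networks built in the previous two TOP solutions, this TOP solution uses a Seq2Seq structure. For this competition the input sequence length is 12 and the output sequence length is 26, and the solution stacks four hidden layers, so a basic Seq2Seq structure looks like the figure below:\n",
+ "\n",
+ "[Figure: basic Seq2Seq structure with four hidden layers]\n",
+ "\n",
+ "To apply the Seq2Seq structure to different problems, the important choice is which Cell to use. The Cell used in this TOP solution is SA-ConvLSTM (Self-Attention ConvLSTM), proposed by Tsinghua University; the paper is available at https://ojs.aaai.org//index.php/AAAI/article/view/6819\n",
+ "\n",
+ "SA-ConvLSTM is an improvement on ConvLSTM, the classic spatiotemporal sequence model proposed by Dr. Xingjian Shi (施行健). To capture the temporal dependence of spatial information, it adds an SAM module on top of ConvLSTM to memorize aggregated spatial features. The ConvLSTM paper is available at https://arxiv.org/pdf/1506.04214.pdf\n",
+ "\n",
+ "1. The ConvLSTM model\n",
+ "\n",
+ "LSTM is a classic time-series model. Its three-gate structure performs well on tasks that require long-term temporal dependencies, and compared with a plain RNN it substantially alleviates the vanishing gradient problem. For a single input sample, each LSTM gate applies a fully connected layer to the input vector at every time step. In this competition the input X has shape (N,T,H,W,C), so at each time step a single sample feeds the LSTM spatial information of shape (H,W,C). Fully connected layers are weak at extracting this kind of spatial information, whereas convolutions greatly reduce the number of parameters and, when stacked, progressively extract more complex features. It is therefore natural to replace the fully connected operations in LSTM with convolutions so that the model suits spatiotemporal sequence problems. This is exactly what ConvLSTM does, and practice has shown the approach to be very effective.\n",
+ "\n",
+ "[Figure: ConvLSTM structure]\n",
+ "\n",
+ "2. The SAM module\n",
+ "\n",
+ "However, ConvLSTM has two problems:\n",
+ "\n",
+ "First, the receptive field of a convolutional layer is limited by the kernel size, so several layers must be stacked to enlarge the receptive field and discover global features. For example, if the first convolutional layer has a 3×3 kernel, each node in that layer only perceives the input within a 3×3 spatial window. Adding a second 3×3 layer lets each node perceive 3×3 nodes of the first layer, which with stride 1 corresponds to a 5×5 window of the input, so the second layer sees a larger spatial range than the first. The price is more parameters. For a plain CNN an extra layer only adds one kernel's worth of parameters, but for ConvLSTM the burden is much heavier: the added parameters increase the risk of overfitting while the gain in performance is small.\n",
+ "\n",
+ "Second, the convolution only acts on the spatial information of the current time step and ignores past spatial information, so it is hard to capture how spatial information depends on time.\n",
+ "\n",
+ "Therefore, to capture both global and local spatial dependencies and to improve predictions on spatiotemporal tasks with large spatial extent and long time spans, SA-ConvLSTM adds the SAM (self-attention memory) module to ConvLSTM.\n",
+ "\n",
+ "[Figure: structure of the SAM module]\n",
+ "\n",
+ "The SAM module introduces a new memory unit M that stores spatial information carrying temporal dependencies. It takes as input the hidden state $H_t$ produced by ConvLSTM at the current time step and the memory $M_{t-1}$ from the previous time step. First, $H_t$ passes through self-attention to produce the feature $Z_h$; self-attention increases the weights of the parts of $H_t$ that are more strongly related to its other parts. At the same time, $H_t$ serves as the Query in an attention operation over $M_{t-1}$ to produce the feature $Z_m$, which increases the weights of the parts of $M_{t-1}$ on which $H_t$ depends more strongly. Concatenating $Z_h$ and $Z_m$ gives their aggregated feature $Z$, which contains both the information of the current time step and the global spatiotemporal memory. Finally, borrowing the gating structure of LSTM, $Z$ is used to update the hidden state and the memory unit, giving the updated hidden state $\\hat{H_t}$ and the memory $M_t$ of the current time step. The SAM module is defined by:\n",
+ "\n",
+ "$$\n",
+ "\\begin{aligned}\n",
+ "& i'_t = \\sigma (W_{m;zi} \\ast Z + W_{m;hi} \\ast H_t + b_{m;i}) \\\\\n",
+ "& g'_t = \\tanh (W_{m;zg} \\ast Z + W_{m;hg} \\ast H_t + b_{m;g}) \\\\\n",
+ "& M_t = (1 - i'_t) \\circ M_{t-1} + i'_t \\circ g'_t \\\\\n",
+ "& o'_t = \\sigma (W_{m;zo} \\ast Z + W_{m;ho} \\ast H_t + b_{m;o}) \\\\\n",
+ "& \\hat{H_t} = o'_t \\circ M_t\n",
+ "\\end{aligned}\n",
+ "$$\n",
+ "\n",
+ "For background on attention and self-attention, see the following links:\n",
+ "\n",
+ "    - Attention mechanisms in deep learning: https://blog.csdn.net/malefactor/article/details/78767781\n",
+ "    - Current mainstream attention methods: https://www.zhihu.com/question/68482809\n",
+ "\n",
+ "3. The SA-ConvLSTM model\n",
+ "\n",
+ "Combining the two gives the SA-ConvLSTM model:\n",
+ "\n",
+ "[Figure: SA-ConvLSTM cell structure]"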
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 20,
+ "metadata": {
+ "execution": {
+ "iopub.execute_input": "2021-11-29T03:05:23.372772Z",
+ "iopub.status.busy": "2021-11-29T03:05:23.371873Z",
+ "iopub.status.idle": "2021-11-29T03:05:23.373700Z",
+ "shell.execute_reply": "2021-11-29T03:05:23.374122Z",
+ "shell.execute_reply.started": "2021-11-29T01:03:24.585147Z"
+ },
+ "papermill": {
+ "duration": 0.035787,
+ "end_time": "2021-11-29T03:05:23.374254",
+ "exception": false,
+ "start_time": "2021-11-29T03:05:23.338467",
+ "status": "completed"
+ },
+ "tags": []
+ },
+ "outputs": [],
+ "source": [
+ "# Attention mechanism\n",
+ "def attn(query, key, value):\n",
+ "    # query, key and value all have shape (N, C, H*W); let S = H*W\n",
+ "    # Scaled dot-product scores: scores(i) = key(i)^T query / sqrt(C)\n",
+ "    scores = torch.matmul(query.transpose(1, 2), key / math.sqrt(query.size(1))) # (N, S, S)\n",
+ "    # Normalize the scores into attention weights\n",
+ "    attn_weights = F.softmax(scores, dim=-1)\n",
+ "    output = torch.matmul(attn_weights, value.transpose(1, 2)) # (N, S, C)\n",
+ "    return output.transpose(1, 2) # (N, C, S)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 21,
+ "metadata": {
+ "execution": {
+ "iopub.execute_input": "2021-11-29T03:05:23.440765Z",
+ "iopub.status.busy": "2021-11-29T03:05:23.440042Z",
+ "iopub.status.idle": "2021-11-29T03:05:23.442191Z",
+ "shell.execute_reply": "2021-11-29T03:05:23.442569Z",
+ "shell.execute_reply.started": "2021-11-29T01:03:25.147999Z"
+ },
+ "papermill": {
+ "duration": 0.041095,
+ "end_time": "2021-11-29T03:05:23.442725",
+ "exception": false,
+ "start_time": "2021-11-29T03:05:23.401630",
+ "status": "completed"
+ },
+ "tags": []
+ },
+ "outputs": [],
+ "source": [
+ "# SAM module\n",
+ "class SAAttnMem(nn.Module):\n",
+ "    def __init__(self, input_dim, d_model, kernel_size):\n",
+ "        super().__init__()\n",
+ "        pad = kernel_size[0] // 2, kernel_size[1] // 2\n",
+ "        self.d_model = d_model\n",
+ "        self.input_dim = input_dim\n",
+ "        # 1x1 convolution implementing the fully connected operation Wh Ht\n",
+ "        self.conv_h = nn.Conv2d(input_dim, d_model*3, kernel_size=1)\n",
+ "        # 1x1 convolution implementing the fully connected operation Wm Mt-1\n",
+ "        self.conv_m = nn.Conv2d(input_dim, d_model*2, kernel_size=1)\n",
+ "        # 1x1 convolution implementing the fully connected operation Wz [Zh, Zm]\n",
+ "        self.conv_z = nn.Conv2d(d_model*2, d_model, kernel_size=1)\n",
+ "        # The output dimension must match the input dimension; both are input_dim\n",
+ "        self.conv_output = nn.Conv2d(input_dim+d_model, input_dim*3, kernel_size=kernel_size, padding=pad)\n",
+ "\n",
+ "    def forward(self, h, m):\n",
+ "        # self.conv_h(h) gives Wh Ht; split it along dim=1 into chunks of size self.d_model, each of shape (N, d_model, H, W); the three chunks are Qh, Kh and Vh\n",
+ "        hq, hk, hv = torch.split(self.conv_h(h), self.d_model, dim=1)\n",
+ "        # Obtain Km and Vm in the same way\n",
+ "        mk, mv = torch.split(self.conv_m(m), self.d_model, dim=1)\n",
+ "        N, C, H, W = hq.size()\n",
+ "        # Self-attention over Ht gives Zh\n",
+ "        Zh = attn(hq.view(N, C, -1), hk.view(N, C, -1), hv.view(N, C, -1)) # (N, C, S), C=d_model\n",
+ "        # Attention of Ht over Mt-1 gives Zm\n",
+ "        Zm = attn(hq.view(N, C, -1), mk.view(N, C, -1), mv.view(N, C, -1)) # (N, C, S), C=d_model\n",
+ "        # Concatenate Zh and Zm and apply the fully connected operation to get the aggregated feature Z\n",
+ "        Z = self.conv_z(torch.cat([Zh.view(N, C, H, W), Zm.view(N, C, H, W)], dim=1)) # (N, C, H, W), C=d_model\n",
+ "        # Compute i't, g't and o't\n",
+ "        i, g, o = torch.split(self.conv_output(torch.cat([Z, h], dim=1)), self.input_dim, dim=1) # (N, C, H, W), C=input_dim\n",
+ "        i = torch.sigmoid(i)\n",
+ "        g = torch.tanh(g)\n",
+ "        # Updated memory unit Mt\n",
+ "        m_next = i * g + (1 - i) * m\n",
+ "        # Updated hidden state Ht\n",
+ "        h_next = torch.sigmoid(o) * m_next\n",
+ "        return h_next, m_next"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 22,
+ "metadata": {
+ "execution": {
+ "iopub.execute_input": "2021-11-29T03:05:23.509738Z",
+ "iopub.status.busy": "2021-11-29T03:05:23.509080Z",
+ "iopub.status.idle": "2021-11-29T03:05:23.512667Z",
+ "shell.execute_reply": "2021-11-29T03:05:23.512182Z",
+ "shell.execute_reply.started": "2021-11-29T01:03:25.667808Z"
+ },
+ "papermill": {
+ "duration": 0.042616,
+ "end_time": "2021-11-29T03:05:23.512781",
+ "exception": false,
+ "start_time": "2021-11-29T03:05:23.470165",
+ "status": "completed"
+ },
+ "tags": []
+ },
+ "outputs": [],
+ "source": [
+ "# SA-ConvLSTM Cell\n",
+ "class SAConvLSTMCell(nn.Module):\n",
+ "    def __init__(self, input_dim, hidden_dim, d_attn, kernel_size):\n",
+ "        super().__init__()\n",
+ "        self.input_dim = input_dim\n",
+ "        self.hidden_dim = hidden_dim\n",
+ "        pad = kernel_size[0] // 2, kernel_size[1] // 2\n",
+ "        # Convolution computing Wx*Xt + Wh*Ht-1\n",
+ "        self.conv = nn.Conv2d(in_channels=input_dim+hidden_dim, out_channels=4*hidden_dim, kernel_size=kernel_size, padding=pad)\n",
+ "        self.sa = SAAttnMem(input_dim=hidden_dim, d_model=d_attn, kernel_size=kernel_size)\n",
+ "\n",
+ "    def initialize(self, inputs):\n",
+ "        device = inputs.device\n",
+ "        N, _, H, W = inputs.size()\n",
+ "        # Initialize the hidden state Ht\n",
+ "        self.hidden_state = torch.zeros(N, self.hidden_dim, H, W, device=device)\n",
+ "        # Initialize the LSTM cell state ct\n",
+ "        self.cell_state = torch.zeros(N, self.hidden_dim, H, W, device=device)\n",
+ "        # Initialize the SAM memory state Mt\n",
+ "        self.memory_state = torch.zeros(N, self.hidden_dim, H, W, device=device)\n",
+ "\n",
+ "    def forward(self, inputs, first_step=False):\n",
+ "        # At the first time step, initialize Ht, ct and Mt\n",
+ "        if first_step:\n",
+ "            self.initialize(inputs)\n",
+ "\n",
+ "        # ConvLSTM part\n",
+ "        # Concatenate Xt and Ht-1\n",
+ "        combined = torch.cat([inputs, self.hidden_state], dim=1) # (N, C, H, W), C=input_dim+hidden_dim\n",
+ "        # Apply the convolution\n",
+ "        combined_conv = self.conv(combined)\n",
+ "        # Split into the four gate pre-activations it, ft, ot, gt\n",
+ "        cc_i, cc_f, cc_o, cc_g = torch.split(combined_conv, self.hidden_dim, dim=1)\n",
+ "        i = torch.sigmoid(cc_i)\n",
+ "        f = torch.sigmoid(cc_f)\n",
+ "        o = torch.sigmoid(cc_o)\n",
+ "        g = torch.tanh(cc_g)\n",
+ "        # Cell state at the current time step: ct = ft * ct-1 + it * gt\n",
+ "        self.cell_state = f * self.cell_state + i * g\n",
+ "        # Hidden state at the current time step: Ht = ot * tanh(ct)\n",
+ "        self.hidden_state = o * torch.tanh(self.cell_state)\n",
+ "\n",
+ "        # SAM part: update Ht and Mt\n",
+ "        self.hidden_state, self.memory_state = self.sa(self.hidden_state, self.memory_state)\n",
+ "\n",
+ "        return self.hidden_state"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
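+ "There are two training modes for Seq2Seq models. The first is free running, the traditional approach, in which the output $\\hat{y}_{t-1}$ of the previous time step is used as the input of the next time step. The problem is that early in training $\\hat{y}_{t-1}$ is far from the true label $y_{t-1}$, and feeding it back in makes later outputs drift further and further from the targets we expect. This motivates the second training mode, teacher forcing.\n",
+ "\n",
+ "Teacher forcing feeds the true label $y_{t-1}$ directly as the input of the next time step, so the teacher (the ground truth) keeps the model from drifting. But the teacher cannot hold the student's hand forever and must gradually let the student learn independently, so we use Scheduled Sampling to control the probability of using the true label. Let ratio denote the Scheduled Sampling ratio. At the start of training ratio=1 and the model is led entirely by the teacher. As the number of training epochs grows, ratio decays according to some schedule (this solution uses linear decay, reducing ratio by a decay rate decay_rate each time). At every time step a binary random number is drawn from a Bernoulli distribution that yields 1 with probability ratio: if it is 1 the input is the true label $y_{t-1}$, otherwise the input is $\\hat{y}_{t-1}$."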
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 23,
+ "metadata": {
+ "execution": {
+ "iopub.execute_input": "2021-11-29T03:05:23.587567Z",
+ "iopub.status.busy": "2021-11-29T03:05:23.586781Z",
+ "iopub.status.idle": "2021-11-29T03:05:23.588776Z",
+ "shell.execute_reply": "2021-11-29T03:05:23.589156Z",
+ "shell.execute_reply.started": "2021-11-29T01:03:26.065997Z"
+ },
+ "papermill": {
+ "duration": 0.047514,
+ "end_time": "2021-11-29T03:05:23.589277",
+ "exception": false,
+ "start_time": "2021-11-29T03:05:23.541763",
+ "status": "completed"
+ },
+ "tags": []
+ },
+ "outputs": [],
+ "source": [
+    "# Build the SA-ConvLSTM model\n",
+    "class SAConvLSTM(nn.Module):\n",
+    "    def __init__(self, input_dim, hidden_dim, d_attn, kernel_size):\n",
+    "        super().__init__()\n",
+    "        self.input_dim = input_dim\n",
+    "        self.hidden_dim = hidden_dim\n",
+    "        self.num_layers = len(hidden_dim)\n",
+    "\n",
+    "        layers = []\n",
+    "        for i in range(self.num_layers):\n",
+    "            cur_input_dim = self.input_dim if i == 0 else self.hidden_dim[i-1]\n",
+    "            layers.append(SAConvLSTMCell(input_dim=cur_input_dim, hidden_dim=self.hidden_dim[i], d_attn=d_attn, kernel_size=kernel_size))\n",
+    "        self.layers = nn.ModuleList(layers)\n",
+    "\n",
+    "        self.conv_output = nn.Conv2d(self.hidden_dim[-1], 1, kernel_size=1)\n",
+    "\n",
+    "    def forward(self, input_x, device=None, input_frames=12, future_frames=26, output_frames=37, teacher_forcing=False, scheduled_sampling_ratio=0, train=True):\n",
+    "        # Default to the device the model parameters live on (a hard-coded 'cuda:0' default fails on CPU-only machines)\n",
+    "        if device is None:\n",
+    "            device = next(self.parameters()).device\n",
+    "        # Permute the input X from (N, T, H, W, C) to (N, T, C, H, W)\n",
+    "        input_x = input_x.permute(0, 1, 4, 2, 3).contiguous()\n",
+    "\n",
+    "        # Teacher forcing is only used during training\n",
+    "        if train:\n",
+    "            if teacher_forcing and scheduled_sampling_ratio > 1e-6:\n",
+    "                # One Bernoulli draw per sample and per decoding step\n",
+    "                teacher_forcing_mask = torch.bernoulli(scheduled_sampling_ratio * torch.ones(input_x.size(0), future_frames-1, 1, 1, 1))\n",
+    "            else:\n",
+    "                teacher_forcing = False\n",
+    "        else:\n",
+    "            teacher_forcing = False\n",
+    "\n",
+    "        total_steps = input_frames + future_frames - 1\n",
+    "        outputs = [None] * total_steps\n",
+    "\n",
+    "        # Loop over the time steps\n",
+    "        for t in range(total_steps):\n",
+    "            # For the first 12 months, use the observed input X_t of each month\n",
+    "            if t < input_frames:\n",
+    "                input_ = input_x[:, t].to(device)\n",
+    "            # Without teacher forcing, the prediction of the previous time step is the input of the current one\n",
+    "            elif not teacher_forcing:\n",
+    "                input_ = outputs[t-1]\n",
+    "            # With teacher forcing, the ground truth is used with probability ratio, otherwise the previous prediction\n",
+    "            else:\n",
+    "                mask = teacher_forcing_mask[:, t-input_frames].float().to(device)\n",
+    "                input_ = input_x[:, t].to(device) * mask + outputs[t-1] * (1-mask)\n",
+    "            first_step = (t==0)\n",
+    "            input_ = input_.float()\n",
+    "\n",
+    "            # Pass the input of the current time step through the stacked layers\n",
+    "            for layer_idx in range(self.num_layers):\n",
+    "                input_ = self.layers[layer_idx](input_, first_step=first_step)\n",
+    "\n",
+    "            # Record the output of each time step (in evaluation, only from the last input step onwards)\n",
+    "            if train or (t >= (input_frames - 1)):\n",
+    "                outputs[t] = self.conv_output(input_)\n",
+    "\n",
+    "        outputs = [x for x in outputs if x is not None]\n",
+    "\n",
+    "        # Check the length of the output sequence\n",
+    "        if train:\n",
+    "            assert len(outputs) == output_frames\n",
+    "        else:\n",
+    "            assert len(outputs) == future_frames\n",
+    "\n",
+    "        # Predicted SST sequence\n",
+    "        outputs = torch.stack(outputs, dim=1)[:, :, 0]  # (N, T_out, H, W); T_out is 37 in training and 26 in evaluation\n",
+    "        # Average the predicted SST over the Nino3.4 region, then take a 3-month running mean to get the predicted Nino3.4 index\n",
+    "        nino_pred = outputs[:, -future_frames:, 10:13, 19:30].mean(dim=[2, 3])  # (N, 26)\n",
+    "        nino_pred = nino_pred.unfold(dimension=1, size=3, step=1).mean(dim=2)  # (N, 24)\n",
+    "\n",
+    "        return nino_pred"
+ ]
+ },
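+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "The last two statements of `forward` above turn the predicted SST fields into the Nino3.4 index. As a sketch in formula form, using notation introduced here for illustration: $\\hat{s}_t(h, w)$ is the predicted SST at grid cell $(h, w)$ in month $t$ of the 26 predicted months, and $R$ is the set of grid cells selected by the slice `[10:13, 19:30]` (3 x 11 = 33 cells).\n",
+    "\n",
+    "$$a_t = \\frac{1}{|R|} \\sum_{(h, w) \\in R} \\hat{s}_t(h, w), \\qquad t = 1, \\dots, 26$$\n",
+    "\n",
+    "$$\\widehat{\\mathrm{nino}}_k = \\frac{1}{3} \\left(a_k + a_{k+1} + a_{k+2}\\right), \\qquad k = 1, \\dots, 24$$\n",
+    "\n",
+    "The window of size 3 with step 1 is why 26 predicted months yield 26 - 3 + 1 = 24 index values."
+   ]
+  },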
+ {
+ "cell_type": "code",
+ "execution_count": 25,
+ "metadata": {
+ "execution": {
+ "iopub.execute_input": "2021-11-29T03:05:23.726291Z",
+ "iopub.status.busy": "2021-11-29T03:05:23.725688Z",
+ "iopub.status.idle": "2021-11-29T03:05:23.753509Z",
+ "shell.execute_reply": "2021-11-29T03:05:23.753976Z",
+ "shell.execute_reply.started": "2021-11-29T01:03:29.448921Z"
+ },
+ "papermill": {
+ "duration": 0.066105,
+ "end_time": "2021-11-29T03:05:23.754109",
+ "exception": false,
+ "start_time": "2021-11-29T03:05:23.688004",
+ "status": "completed"
+ },
+ "tags": []
+ },
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "SAConvLSTM(\n",
+ " (layers): ModuleList(\n",
+ " (0): SAConvLSTMCell(\n",
+ " (conv): Conv2d(65, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
+ " (sa): SAAttnMem(\n",
+ " (conv_h): Conv2d(64, 96, kernel_size=(1, 1), stride=(1, 1))\n",
+ " (conv_m): Conv2d(64, 64, kernel_size=(1, 1), stride=(1, 1))\n",
+ " (conv_z): Conv2d(64, 32, kernel_size=(1, 1), stride=(1, 1))\n",
+ " (conv_output): Conv2d(96, 192, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
+ " )\n",
+ " )\n",
+ " (1): SAConvLSTMCell(\n",
+ " (conv): Conv2d(128, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
+ " (sa): SAAttnMem(\n",
+ " (conv_h): Conv2d(64, 96, kernel_size=(1, 1), stride=(1, 1))\n",
+ " (conv_m): Conv2d(64, 64, kernel_size=(1, 1), stride=(1, 1))\n",
+ " (conv_z): Conv2d(64, 32, kernel_size=(1, 1), stride=(1, 1))\n",
+ " (conv_output): Conv2d(96, 192, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
+ " )\n",
+ " )\n",
+ " (2): SAConvLSTMCell(\n",
+ " (conv): Conv2d(128, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
+ " (sa): SAAttnMem(\n",
+ " (conv_h): Conv2d(64, 96, kernel_size=(1, 1), stride=(1, 1))\n",
+ " (conv_m): Conv2d(64, 64, kernel_size=(1, 1), stride=(1, 1))\n",
+ " (conv_z): Conv2d(64, 32, kernel_size=(1, 1), stride=(1, 1))\n",
+ " (conv_output): Conv2d(96, 192, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
+ " )\n",
+ " )\n",
+ " (3): SAConvLSTMCell(\n",
+ " (conv): Conv2d(128, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
+ " (sa): SAAttnMem(\n",
+ " (conv_h): Conv2d(64, 96, kernel_size=(1, 1), stride=(1, 1))\n",
+ " (conv_m): Conv2d(64, 64, kernel_size=(1, 1), stride=(1, 1))\n",
+ " (conv_z): Conv2d(64, 32, kernel_size=(1, 1), stride=(1, 1))\n",
+ " (conv_output): Conv2d(96, 192, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
+ " )\n",
+ " )\n",
+ " )\n",
+ " (conv_output): Conv2d(64, 1, kernel_size=(1, 1), stride=(1, 1))\n",
+ ")\n"
+ ]
+ }
+ ],
+ "source": [
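+    "# Number of input channels\n",
+    "input_dim = 1\n",
+    "# Number of hidden channels in each layer\n",
+    "hidden_dim = (64, 64, 64, 64)\n",
+    "# Channel dimension used inside the attention module\n",
+    "d_attn = 32\n",
+    "# Convolution kernel size\n",
+    "kernel_size = (3, 3)\n",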
+ "\n",
+ "model = SAConvLSTM(input_dim, hidden_dim, d_attn, kernel_size)\n",
+ "print(model)"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+    "#### Model training"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 26,
+ "metadata": {
+ "execution": {
+ "iopub.execute_input": "2021-11-29T03:05:23.816671Z",
+ "iopub.status.busy": "2021-11-29T03:05:23.815927Z",
+ "iopub.status.idle": "2021-11-29T03:05:23.818479Z",
+ "shell.execute_reply": "2021-11-29T03:05:23.818058Z",
+ "shell.execute_reply.started": "2021-11-29T01:03:31.476806Z"
+ },
+ "papermill": {
+ "duration": 0.035723,
+ "end_time": "2021-11-29T03:05:23.818579",
+ "exception": false,
+ "start_time": "2021-11-29T03:05:23.782856",
+ "status": "completed"
+ },
+ "tags": []
+ },
+ "outputs": [],
+ "source": [
+    "# Use RMSE as the loss: RMSE over the batch for each lead time, summed over the lead times\n",
+    "def RMSELoss(y_pred, y_true):\n",
+    "    loss = torch.sqrt(torch.mean((y_pred-y_true)**2, dim=0)).sum()\n",
+    "    return loss"
+ ]
+ },
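+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "Written as a formula, the loss above is the RMSE over the batch for each lead time, summed over the lead times. In the sketch below the notation is introduced here for illustration: $N$ is the batch size, and $\\hat{y}_{n,k}$ and $y_{n,k}$ are the predicted and true Nino3.4 index of sample $n$ at lead time $k$.\n",
+    "\n",
+    "$$\\mathcal{L} = \\sum_{k=1}^{24} \\sqrt{\\frac{1}{N} \\sum_{n=1}^{N} \\left(\\hat{y}_{n,k} - y_{n,k}\\right)^2}$$\n",
+    "\n",
+    "Because the per-lead-time RMSE values are summed rather than averaged, the loss grows with the number of lead times, which is worth keeping in mind when reading the loss values in the training log."
+   ]
+  },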
+ {
+ "cell_type": "code",
+ "execution_count": 27,
+ "metadata": {
+ "execution": {
+ "iopub.execute_input": "2021-11-29T03:05:23.893469Z",
+ "iopub.status.busy": "2021-11-29T03:05:23.892684Z",
+ "iopub.status.idle": "2021-11-29T04:55:28.956056Z",
+ "shell.execute_reply": "2021-11-29T04:55:28.956434Z"
+ },
+ "papermill": {
+ "duration": 6605.109145,
+ "end_time": "2021-11-29T04:55:28.956614",
+ "exception": false,
+ "start_time": "2021-11-29T03:05:23.847469",
+ "status": "completed"
+ },
+ "tags": []
+ },
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "Epoch: 1/5\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "100%|██████████| 1600/1600 [21:43<00:00, 1.23it/s]\n"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "Training Loss: 3.289\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "50it [00:11, 4.47it/s]\n"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "Validation Loss: 44.009\n",
+ "Score: -43.458\n",
+ "Epoch: 2/5\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "100%|██████████| 1600/1600 [21:43<00:00, 1.23it/s]\n"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "Training Loss: 3.084\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "50it [00:11, 4.33it/s]\n"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "Validation Loss: 25.011\n",
+ "Score: -19.966\n",
+ "Epoch: 3/5\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "100%|██████████| 1600/1600 [21:46<00:00, 1.22it/s]\n"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "Training Loss: 13.461\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "50it [00:12, 4.16it/s]\n"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "Validation Loss: 15.438\n",
+ "Score: -14.139\n",
+ "Epoch: 4/5\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "100%|██████████| 1600/1600 [21:54<00:00, 1.22it/s]\n"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "Training Loss: 17.627\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "50it [00:12, 3.99it/s]\n"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "Validation Loss: 15.389\n",
+ "Score: -22.500\n",
+ "Epoch: 5/5\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "100%|██████████| 1600/1600 [21:55<00:00, 1.22it/s]\n"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "Training Loss: 17.592\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "50it [00:11, 4.48it/s]"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "Validation Loss: 15.252\n",
+ "Score: -14.459\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "\n"
+ ]
+ }
+ ],
+ "source": [
+    "model_weights = './task05_model_weights.pth'\n",
+    "device = 'cuda' if torch.cuda.is_available() else 'cpu'\n",
+    "model = SAConvLSTM(input_dim, hidden_dim, d_attn, kernel_size).to(device)\n",
+    "criterion = RMSELoss\n",
+    "optimizer = optim.Adam(model.parameters(), lr=1e-3)\n",
+    "# Note: this scheduler is created but never stepped in the loop below, so the learning rate stays at 1e-3\n",
+    "lr_scheduler = ReduceLROnPlateau(optimizer, mode='max', factor=0.3, patience=0, verbose=True, min_lr=0.0001)\n",
+    "epochs = 5\n",
+    "ratio, decay_rate = 1, 8e-5\n",
+    "train_losses, valid_losses = [], []\n",
+    "scores = []\n",
+    "best_score = float('-inf')\n",
+    "preds = np.zeros((len(y_valid),24))\n",
+    "\n",
+    "for epoch in range(epochs):\n",
+    "    print('Epoch: {}/{}'.format(epoch+1, epochs))\n",
+    "\n",
+    "    # Training\n",
+    "    model.train()\n",
+    "    losses = 0\n",
+    "    for data, labels in tqdm(trainloader):\n",
+    "        data = data.to(device)\n",
+    "        labels = labels.to(device)\n",
+    "        optimizer.zero_grad()\n",
+    "        # Linear decay of the scheduled sampling ratio, once per iteration\n",
+    "        ratio = max(ratio-decay_rate, 0)\n",
+    "        # Pass the device explicitly so the forward pass also works on CPU-only machines\n",
+    "        pred = model(data, device=device, teacher_forcing=True, scheduled_sampling_ratio=ratio, train=True)\n",
+    "        loss = criterion(pred, labels)\n",
+    "        losses += loss.cpu().detach().numpy()\n",
+    "        loss.backward()\n",
+    "        optimizer.step()\n",
+    "    train_loss = losses / len(trainloader)\n",
+    "    train_losses.append(train_loss)\n",
+    "    print('Training Loss: {:.3f}'.format(train_loss))\n",
+    "\n",
+    "    # Validation\n",
+    "    model.eval()\n",
+    "    losses = 0\n",
+    "    with torch.no_grad():\n",
+    "        for i, data in tqdm(enumerate(validloader)):\n",
+    "            data, labels = data\n",
+    "            data = data.to(device)\n",
+    "            labels = labels.to(device)\n",
+    "            pred = model(data, device=device, train=False)\n",
+    "            loss = criterion(pred, labels)\n",
+    "            losses += loss.cpu().detach().numpy()\n",
+    "            preds[i*batch_size:(i+1)*batch_size] = pred.detach().cpu().numpy()\n",
+    "    valid_loss = losses / len(validloader)\n",
+    "    valid_losses.append(valid_loss)\n",
+    "    print('Validation Loss: {:.3f}'.format(valid_loss))\n",
+    "    s = score(y_valid, preds)\n",
+    "    scores.append(s)\n",
+    "    print('Score: {:.3f}'.format(s))\n",
+    "\n",
+    "    # Save the weights of the best-scoring model\n",
+    "    if s > best_score:\n",
+    "        best_score = s\n",
+    "        checkpoint = {'best_score': s,\n",
+    "                      'state_dict': model.state_dict()}\n",
+    "        torch.save(checkpoint, model_weights)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 28,
+ "metadata": {
+ "execution": {
+ "iopub.execute_input": "2021-11-29T04:55:33.957872Z",
+ "iopub.status.busy": "2021-11-29T04:55:33.957066Z",
+ "iopub.status.idle": "2021-11-29T04:55:33.960119Z",
+ "shell.execute_reply": "2021-11-29T04:55:33.959684Z",
+ "shell.execute_reply.started": "2021-11-28T14:00:36.33194Z"
+ },
+ "papermill": {
+ "duration": 2.38263,
+ "end_time": "2021-11-29T04:55:33.960247",
+ "exception": false,
+ "start_time": "2021-11-29T04:55:31.577617",
+ "status": "completed"
+ },
+ "tags": []
+ },
+ "outputs": [],
+ "source": [
+    "# Plot the training/validation curves\n",
+    "def training_vis(train_losses, valid_losses):\n",
+    "    # Plot the loss curves\n",
+    "    fig = plt.figure(figsize=(8,4))\n",
+    "    # Loss subplot\n",
+    "    ax1 = fig.add_subplot(121)\n",
+    "    ax1.plot(train_losses, label='train_loss')\n",
+    "    ax1.plot(valid_losses,label='val_loss')\n",
+    "    ax1.set_xlabel('Epochs')\n",
+    "    ax1.set_ylabel('Loss')\n",
+    "    ax1.set_title('Loss on Training and Validation Data')\n",
+    "    ax1.legend()\n",
+    "    plt.tight_layout()"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 29,
+ "metadata": {
+ "execution": {
+ "iopub.execute_input": "2021-11-29T04:55:38.227343Z",
+ "iopub.status.busy": "2021-11-29T04:55:38.226636Z",
+ "iopub.status.idle": "2021-11-29T04:55:38.470256Z",
+ "shell.execute_reply": "2021-11-29T04:55:38.469252Z",
+ "shell.execute_reply.started": "2021-11-28T14:00:43.42651Z"
+ },
+ "papermill": {
+ "duration": 2.378943,
+ "end_time": "2021-11-29T04:55:38.470387",
+ "exception": false,
+ "start_time": "2021-11-29T04:55:36.091444",
+ "status": "completed"
+ },
+ "tags": []
+ },
+ "outputs": [
+ {
+ "data": {
+ "image/png": "iVBORw0KGgoAAAANSUhEUgAAAS8AAAEYCAYAAAANoXDNAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjQuMywgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy/MnkTPAAAACXBIWXMAAAsTAAALEwEAmpwYAAAsYUlEQVR4nO3dd3wUdf7H8dcnhSRApIYSeu8QMKKIngoWRATPE1AsgIVT0cMu3umpWE899TwVT08BFQsWDsQu4A89UQiYUKRzlBBCDwQlkPL5/TETWCCBhOxkspvP8/HYR2ZnZmfeu5l8MvPd78yIqmKMMaEmwu8AxhhzIqx4GWNCkhUvY0xIsuJljAlJVryMMSHJipcxJiRZ8aqkRORzERke7Hn9JCLrRORcD5b7rYhc7w5fKSJflWTeE1hPUxHZKyKRJ5q1MqlUxcurjbu8uBt24aNARPYFPL+yNMtS1QtVdVKw562IRGSsiMwpYnxdETkgIp1LuixVnayq5wcp12Hbo6puUNXqqpofjOUfsS4VkV/dbWWHiMwUkaGleP3ZIpIe7FxlUamKV6hzN+zqqlod2ABcHDBucuF8IhLlX8oK6W3gdBFpccT4y4HFqrrEh0x+6OZuO+2AicCLIvKgv5FOnBUvQERiROR5EclwH8+LSIw7ra6IzBCRLBHZKSLfiUiEO+1eEdkkItkiskJE+haz/Boi8qaIbBOR9SJyf8AyRojI9yLyjIjsEpH/iciFpcx/toiku3kygQkiUsvNvc1d7gwRaRzwmsBDoWNmKOW8LURkjvuZfCMiL4nI28XkLknGR0Tkv+7yvhKRugHTr3Y/zx0i8pfiPh9VTQdmAVcfMeka4M3j5Tgi8wgR+T7g+XkislxEdovIi4AETGslIrPcfNtFZLKI1HSnvQU0BT5x94buEZHm7h5SlDtPoohMd7e71SJyQ8CyHxKRKe52lS0iS0UkubjP4IjPY7uqvgXcBNwnInXcZY4UkWXu8taKyB/d8dWAz4FEObSnnygiPUVkrvu3sVlEXhSRKiXJEAxWvBx/AU4DkoBuQE/gfnfanUA6kADUB/4MqIi0A24BTlHVeOACYF0xy/8nUANoCZyF80czMmD6qcAKoC7wFPC6iMiRCzmOBkBtoBkwCud3O8F93hTYB7x4jNeXJsOx5n0HmAfUAR7i6IIRqCQZh+F8VvWAKsBdACLSERjvLj/RXV+RBcc1KTCL+/tLcvOW9rMqXEZd4GOcbaUusAboHTgL8ISbrwPQBOczQVWv5vC956eKWMV7ONteInAZ8LiI9AmYPtCdpyYwvSSZjzANiMLZ3gG2AgOAk3A+8+dEpIeq/gpcCGQE7OlnAPnA7e577wX0BW4uZYYTp6qV5oFTXM4tYvwaoH/A8wuAde7wOJxfcusjXtMa55d9LhB9jHVGAgeAjgHj/gh86w6PAFYHTKsKKNCgpO8FONtdR+wx5k8CdgU8/xa4viQZSjovzh9+HlA1YPrbwNsl/P0UlfH+gOc3A1+4w38F3guYVs39DI76/Qbk3AOc7j5/DJh2gp/V9+7wNcCPAfMJTrG5vpjlXgL8XNz2CDR3P8sonEKXD8QHTH8CmOgOPwR8EzCtI7DvGJ+tcsQ27I7PBK4s5jX/AcYEbGPpx/n93QZMLcnvOhgP2/NyJALrA56vd8cBPA2sBr5yd6XHAqjqapxf1kPAVhF5T0QSOVpdILqI5TcKeJ5ZOKCqv7mD1Uv5Hrapak7hExGpKiL/cg+r9gBzgJpS/DdZpclQ3LyJwM6AcQAbiwtcwoyZAcO/BWRKDFy2OnsHO4pbl5vpA+Aady/xSuDNUuQoypEZNPC5iNR3t4tN7nLfxtkeSqLws8wOGFfsdoPz2cRKKdo7RSQa54hip/v8QhH50T1MzQL6HyuviLR1D7Ez3ff3+LHmDzYrXo4MnEOGQk3dcahqtqreqaotcXbT7xC3bUtV31HVM9zXKvC3Ipa9Hcgt
Yvmbgvwejrw8yJ04DbOnqupJwO/c8aU9HC2NzUBtEakaMK7JMeYvS8bNgct211nnOK+ZBAwBzgPigU/KmOPIDMLh7/dxnN9LF3e5Vx2xzGNd0iUD57OMDxgX7O1mEM6e8jxx2ng/Ap4B6qtqTeCzgLxFZR0PLAfauO/vz3i7fR2mMhavaBGJDXhEAe8C94tIgtuO8Vec/5KIyAARae1umLtxduULRKSdiPRxf+k5OO0kBUeuTJ2vvacAj4lIvIg0A+4oXL6H4t1MWSJSG/D8WyVVXQ+kAA+JSBUR6QVc7FHGD4EBInKG20g8juNvz98BWcCrOIecB8qY41Ogk4hc6m5Hf8I5fC4UD+wFdotII+DuI16/Bacd9CiquhH4AXjC3U67AtcRhO1GRGqL07XmJeBvqroDpz0xBtgG5InzJUxgl5AtQB0RqXHE+9sD7BWR9jhfAJSbyli8PsPZUAsfDwGP4vzRLQIWAwvdcQBtgG9wNsK5wMuqOhvnF/0kzp5VJk6D8n3FrPNW4FdgLfA9TiPxG8F9W0d5Hohz8/0IfOHx+gpdidN4uwPnM3wf2F/MvM9zghlVdSkwGuez3AzswmlvOtZrFOdQsZn7s0w5VHU7MBhnO9iBs638N2CWh4EeOP/0PsVp3A/0BM4/zSwRuauIVVyB0w6WAUwFHlTVb0qSrRhpIrIXpxnkeuB2Vf2r+16ycYrvFJzPchjOlwCF73U5zj/5tW7eRJwvT4YB2cBrOL/rciNuQ5sxnhCR94Hlqhqy/YlMxVQZ97yMh0TkFHH6N0WISD+cdpX/+BzLhCHriW2CrQHO4VEdnMO4m1T1Z38jmXBkh43GmJBkh43GmJAUEoeNdevW1ebNm/sdwxhTzhYsWLBdVROKmhYSxat58+akpKT4HcMYU85EZH1x0+yw0RgTkqx4GWNCkhUvY0xICok2L2MqotzcXNLT08nJyTn+zOaYYmNjady4MdHR0SV+jRUvY05Qeno68fHxNG/enNJfO9IUUlV27NhBeno6LVoceaXu4tlhozEnKCcnhzp16ljhKiMRoU6dOqXeg7XiZUwZWOEKjhP5HMOreOXsgR/Hg53yZEzYC6/itXwGfDEWUicff15jTEgLr+LV9XJo2gu+uh9+3e53GmM8lZWVxcsvv1zq1/Xv35+srKxSv27EiBF8+OGHpX6dV8KreEVEwIDnYf9ep4AZE8aKK155eXnHfN1nn31GzZo1PUpVfsKvq0S99tB7DHz3DHS7Alqe5XciUwk8/MlSfsnYE9Rldkw8iQcv7lTs9LFjx7JmzRqSkpKIjo4mNjaWWrVqsXz5clauXMkll1zCxo0bycnJYcyYMYwaNQo4dK7w3r17ufDCCznjjDP44YcfaNSoEdOmTSMuLu642WbOnMldd91FXl4ep5xyCuPHjycmJoaxY8cyffp0oqKiOP/883nmmWf44IMPePjhh4mMjKRGjRrMmTMnKJ9PeO15FfrdXVC7Jcy4HXKtA6EJT08++SStWrUiNTWVp59+moULF/KPf/yDlStXAvDGG2+wYMECUlJSeOGFF9ix4+g7w61atYrRo0ezdOlSatasyUcffXTc9ebk5DBixAjef/99Fi9eTF5eHuPHj2fHjh1MnTqVpUuXsmjRIu6/3zn6GTduHF9++SVpaWlMnz79OEsvufDb8wKIjoOLnoW3LoHv/g59ir0TvDFBcaw9pPLSs2fPwzp5vvDCC0ydOhWAjRs3smrVKurUOfzucC1atCApKQmAk08+mXXr1h13PStWrKBFixa0bdsWgOHDh/PSSy9xyy23EBsby3XXXceAAQMYMGAAAL1792bEiBEMGTKESy+9NAjv1BGee14Arc6BrkPh++dg2wq/0xjjuWrVqh0c/vbbb/nmm2+YO3cuaWlpdO/evchOoDExMQeHIyMjj9tedixRUVHMmzePyy67jBkzZtCvXz8AXnnlFR599FE2btzIySefXOQe4IkI3+IFcP5jUKUa
fHIbFBx1S0VjQlp8fDzZ2dlFTtu9eze1atWiatWqLF++nB9//DFo623Xrh3r1q1j9erVALz11lucddZZ7N27l927d9O/f3+ee+450tLSAFizZg2nnnoq48aNIyEhgY0bi72JeqmE52FjoeoJcP4jMP1Wp+9Xj6v9TmRM0NSpU4fevXvTuXNn4uLiqF+//sFp/fr145VXXqFDhw60a9eO0047LWjrjY2NZcKECQwePPhgg/2NN97Izp07GTRoEDk5Oagqzz77LAB33303q1atQlXp27cv3bp1C0qOkLgBR3Jysp7wlVRVYeJFsGUp3JLiFDRjgmDZsmV06NDB7xhho6jPU0QWqGpyUfOH92EjgAgMeA4O/ApfWcO9MeEi/IsXQEI7OON2WPQ+rJntdxpjKrTRo0eTlJR02GPChAl+xzpKeLd5BTrzTljyIXx6B9z0g9OdwhhzlJdeesnvCCVSOfa8AKJjncPHnWthzjN+pzHGlFHlKV4ALc92Thn67z9g6zK/0xhjyqByFS+A8x+FmOrOqUPW98uYkFX5ile1uk4B2zAXfn7L7zTGmBPkefESkUgR+VlEZrjPW4jITyKyWkTeF5EqXmc4StKV0OwM+PoB2Lu13FdvjB+qV69e7LR169bRuXPnckxTduWx5zUGCGxg+hvwnKq2BnYB15VDhsMV9v3K3Qdf/rncV2+MKTtPu0qISGPgIuAx4A5xrrLfBxjmzjIJeAgY72WOIiW0hTPugP970mnEb9233COYMPL5WMhcHNxlNugCFz5Z7OSxY8fSpEkTRo8eDcBDDz1EVFQUs2fPZteuXeTm5vLoo48yaNCgUq02JyeHm266iZSUFKKionj22Wc555xzWLp0KSNHjuTAgQMUFBTw0UcfkZiYyJAhQ0hPTyc/P58HHniAoUOHlultl5TXe17PA/cAhS3jdYAsVS08dT0daFTUC0VklIikiEjKtm3bvEl3xu1Qp7XT9yt3nzfrMMYjQ4cOZcqUKQefT5kyheHDhzN16lQWLlzI7NmzufPOOyntKYAvvfQSIsLixYt59913GT58ODk5ObzyyiuMGTOG1NRUUlJSaNy4MV988QWJiYmkpaWxZMmSg1eSKA+e7XmJyABgq6ouEJGzS/t6VX0VeBWccxuDm85V2Pdr0sUw52no+1dPVmMqgWPsIXmle/fubN26lYyMDLZt20atWrVo0KABt99+O3PmzCEiIoJNmzaxZcsWGjRoUOLlfv/999x6660AtG/fnmbNmrFy5Up69erFY489Rnp6Opdeeilt2rShS5cu3Hnnndx7770MGDCAM88806u3exQv97x6AwNFZB3wHs7h4j+AmiJSWDQbA5s8zHB8LX4H3YY5fb+2/OJrFGNKa/DgwXz44Ye8//77DB06lMmTJ7Nt2zYWLFhAamoq9evXL/XNXIszbNgwpk+fTlxcHP3792fWrFm0bduWhQsX0qVLF+6//37GjRsXlHWVhGfFS1XvU9XGqtocuByYpapXArOBy9zZhgPTvMpQYuc/CjEnwYzbrO+XCSlDhw7lvffe48MPP2Tw4MHs3r2bevXqER0dzezZs1m/fn2pl3nmmWcyebJz+8CVK1eyYcMG2rVrx9q1a2nZsiV/+tOfGDRoEIsWLSIjI4OqVaty1VVXcffdd7Nw4cJgv8Vi+dHP616cxvvVOG1gr/uQ4XDV6sAFj8HGn2DhJL/TGFNinTp1Ijs7m0aNGtGwYUOuvPJKUlJS6NKlC2+++Sbt27cv9TJvvvlmCgoK6NKlC0OHDmXixInExMQwZcoUOnfuTFJSEkuWLOGaa65h8eLF9OzZk6SkJB5++OGD160vD+F/Pa+SUnXavjIXwej5EF//+K8xlZpdzyu47HpeJ+qwvl/3+Z3GGHMcleeSOCVRtw2ceRd8+7jTiN/mXL8TGRNUixcv5uqrD78cekxMDD/99JNPiU6cFa8jnXEbLP7A6ft1849QparfiUwFpqo4fa9DQ5cu
XUhNTfU7xlFOpPnKDhuPFBUDFz8PWethzlN+pzEVWGxsLDt27DihPzxziKqyY8cOYmNjS/U62/MqSvMzIOkq+OGf0GUw1Pf/hqKm4mncuDHp6el4dgZIJRIbG0vjxo1L9RorXsU5/xFY+blzz8drv4QI20k1h4uOjj7sDtWmfNlfZHGq1oYLHof0ebCg4t18wJjKzorXsXQd6pw+9M3DkJ3pdxpjTAArXsciAhc9B3k58IX1/TKmIrHidTx1W8Pv7oKlH8Oqr/1OY4xxWfEqid5joG47mHGHc+dtY4zvrHiVRFSMc+rQ7g3wf3/zO40xBiteJde8N3S/Gn54ETKX+J3GmErPildpnDcO4mrBJ2OgIN/vNMZUala8SqOw79emFEh5w+80xlRqVrxKq+sQaHk2zBwHezb7ncaYSsuKV2mJwEXPQt5++GKs32mMqbSseJ2IOq3grLvhl//Ayi/9TmNMpWTF60SdPgYS2sOnd1nfL2N8YMXrREVVgQHPO32/vn3C7zTGVDpWvMqiWS/oMRzmvgybF/mdxphKxYpXWZ37kNOFYsZt1vfLmHJkxausqtaGC56ATQtgvv+3oDSmsrDiFQxdLoNWfdy+Xxl+pzGmUrDiFQwicNHfoSAXPr/X7zTGVApWvIKldks46x5YNh1WfO53GmPCnhWvYOp1KyR0gM/uhv17/U5jTFiz4hVMUVWcez7u3mh9v4zxmBWvYGt6Gpw8En58GTan+Z3GmLBlxcsL5z4IVevadb+M8ZAVLy/E1YJ+T0DGzzDvNb/TGBOWrHh5pfMfoFVfmPUI7N7kdxpjwo4VL6+IwIBnncPGz+/xO40xYceKl5dqNYez74XlM2D5p36nMSasWPHyWq9boF5Ht+9Xtt9pjAkbVry8FhkNF/8D9myC2Y/7ncaYsGHFqzw06QnJ18JPrzjfQBpjysyKV3np+yBUS3D6fuXn+Z3GmJBnxau8xNWEfk86ve7nW98vY8rKs+IlIrEiMk9E0kRkqYg87I5vISI/ichqEXlfRKp4laHC6fR7aH0ezHoUdqf7ncaYkOblntd+oI+qdgOSgH4ichrwN+A5VW0N7AKu8zBDxSICFz3j9P36zPp+GVMWnhUvdRReFybafSjQB/jQHT8JuMSrDBVSreZwzn2w4lNYNsPvNMaELE/bvEQkUkRSga3A18AaIEtVC1us04FGXmaokE67Gep3tr5fxpSBp8VLVfNVNQloDPQE2pf0tSIySkRSRCRl27ZtXkX0R2S0c8/H7M1O+5cxptTK5dtGVc0CZgO9gJoiEuVOagwUedayqr6qqsmqmpyQkFAeMctXk1PglOvgp385dx4yxpSKl982JohITXc4DjgPWIZTxC5zZxsOTPMqQ4XX969QvT58cpv1/TKmlLzc82oIzBaRRcB84GtVnQHcC9whIquBOkDlvdlhbA248G+QuQjm/cvvNMaElKjjz3JiVHUR0L2I8Wtx2r8MQMdB0OYCmPUYdBgINZv4nciYkGA97P0mAv2fBtT59lHV70TGhAQrXhVBrWZw9n2w8nNY9onfaYwJCVa8KorTbob6XZyrrubs8TuNMRWeFa+KIjLKue5Xdqb1/TKmBKx4VSSNT4aeN8C8VyHd+n4ZcyxWvCqaPvdDfAO77pcxx2HFq6Ip7Pu1ZTH8NN7vNMZUWFa8KqIOA6Hthc4177M2+J3GmArJildFdLDvl8Cnd1nfL2OKYMWroqrZBM75M6z6En6pvKd/GlMcK14V2ak3QoOu8Pm9kLPb7zTGVChWvCqywr5fv26FmY/4ncaYCsWKV0XXqAf0HAXz/w3pKX6nMabCsOIVCs75C8Q3dPt+5fqdxpgKwYpXKIg9Cfo/BVuWwI8v+53GmArBileoaD8A2vWH2U/ArvV+pzHGd1a8QkVh3y+JgM+s75cxVrxCSY3GzrmPq76CpVP9TmOMr0pUvESkmohEuMNtRWSgiER7G80UqecoaNjN6fu1fZXfaYzxTUn3
vOYAsSLSCPgKuBqY6FUocwyRUTDoZdACeK0PrPzS70TG+KKkxUtU9TfgUuBlVR0MdPIuljmmBp1h1LdQqzm8MxTmPGNtYKbSKXHxEpFewJXAp+64SG8imRKp2QSu/RK6DIZZj8AHw2H/Xr9TGVNuSlq8bgPuA6aq6lIRaYlz81jjpypV4dJX4fxHnRt3vH4e7FzrdypjyoVoKQ833Ib76qpabneJSE5O1pQUOzXmmNbMgg9GOsODJ0CrPv7mMSYIRGSBqiYXNa2k3za+IyIniUg1YAnwi4jcHcyQpoxa9XHawU5qBG//Af77grWDmbBW0sPGju6e1iXA50ALnG8cTUVSuwVc9xV0uBi+fgA+vgEO/OZ3KmM8UdLiFe3267oEmK6quYD9W6+IYqrD4EnQ96+w+EN44wK7lLQJSyUtXv8C1gHVgDki0gywO6NWVCJw5p0wbIpzHuSrZ8P/vvM7lTFBVaLipaovqGojVe2vjvXAOR5nM2XV9ny4YRZUrQtvDoKf/mXtYCZslLTBvoaIPCsiKe7j7zh7Yaaiq9sarv8G2l4An98D00ZDbo7fqYwps5IeNr4BZAND3MceYIJXoUyQxZ4EQyfDWWMhdTJMuBB2b/I7lTFlUtLi1UpVH1TVte7jYaCll8FMkEVEwDn3OUVs+0qnHWzDj36nMuaElbR47RORMwqfiEhvYJ83kYynOgyA62c630pOHAApb/idyJgTElXC+W4E3hSRGu7zXcBwbyIZz9VrDzfMho+uhxm3w+Y0uPBpiKridzJjSqyk3zamqWo3oCvQVVW7A3b+SSiLqwnD3ocz7oAFE2HSAMjO9DuVMSVWqiupquqegHMa7/AgjylPEZFw7oNw2QTIXOy0g6Uv8DuVMSVSlstAS9BSGH91vhSu+xoiq8CEfvDz234nMua4ylK8rLdjOCm8wGHTXk5fsM/usXtEmgrtmA32IpJN0UVKgDhPEhn/VK0NV30M3zwIc1+ELUthyCSoVtfvZMYc5Zh7Xqoar6onFfGIV9WSflNpQklkFFzwGPz+VdiU4rSDZaT6ncqYo3h26zMRaSIis0XkFxFZKiJj3PG1ReRrEVnl/qzlVQZTBt2GwrVfOOdCvnEBLJridyJjDuPlfRvzgDtVtSNwGjBaRDoCY4GZqtoGmOk+NxVRYnenHazRyc61wb78C+Tn+Z3KGMDD4qWqm1V1oTucDSwDGgGDgEnubJNwrhFmKqrqCXDNNOd+kXNfhMl/gN92+p3KmPK5Y7aINAe6Az8B9VV1szspE6hfzGtGFV7FYtu2beUR0xQnMhr6Pw0DX4T1PzjtYJlL/E5lKjnPi5eIVAc+Am478qYd6tz9o8guF6r6qqomq2pyQkKC1zFNSfS4GkZ+DvkHnDsVLf2P34lMJeZp8XIvHf0RMFlVP3ZHbxGRhu70hsBWLzOYIGuc7LSD1e/s3Cty5jgoyPc7lamEvPy2UYDXgWWq+mzApOkcOql7ODDNqwzGI/ENYMQM6DEcvvs7vHs57MvyO5WpZLzc8+qNc4ehPiKS6j76A08C54nIKuBc97kJNVExMPAFGPCcc8/I1/rAthV+pzKViGcdTVX1e4o//7GvV+s15Sz5WkjoAFOugdf6wqX/gvYX+Z3KVALl8m2jCXPNejntYHVbw3vD4NsnoaDA71QmzFnxMsFRo5HzTWS3K+DbJ+D9qyDH7o5nvGPFywRPdBxcMh76PQkrv4B/nwvbV/udyoQpK14muETgtJvg6qnw6zanIX/lV36nMmHIipfxRsuznHawmk3hnSHw3bN2w1sTVHZZG+OdWs3guq9g+i0w82HnRh+XvAxV7H7FAKrKwg27mJaawc8bslAUQRD3O3oBEDn4lb2IM07cGeTguEMvKBznPHWWdeRzDi5Pjpj/0DgOjpcjph9aX+Gyj17/kfkOrS86Unjqsm4n/qEFsOJlvFWlKvzhdWjYDb55CLavgssnQ+0WfifzzfLMPUxLzeCTtAzSd+0j
JiqCU5rXJjrS+TNXnJ3Uwv1UdfdYC3dcFT00rEc8B7TAGYc7XgOXcdhynCeH1lPUsg+9NnCewOmFGfXgQg+97shlR0cG72DPipfxngj0HgP1O8GH18Jr5zg3/Wh1jt/Jys3Gnb8xPS2D6akZrNiSTWSE0Lt1XW4/ty3nd6pPfGy03xFDjmgItEMkJydrSkqK3zFMMOxYA+9dCdtXwHmPQK/RHHYsE0a2Ze/n00UZTE/LYOGGLACSm9ViYFIi/bs0pG71GH8DhgARWaCqyUVNsz0vU77qtILrv4b/3ARf/cVpBxv4gtPNIgzsycnlyyWZTE/L4L+rt1Og0L5BPPf0a8fFXRNpUruq3xHDhhUvU/5i4mHwm85J3bMfc/bChk6Gmk38TnZCcnLzmb18K9PTMpi5fCsH8gpoUjuOm85uxcBujWjXIN7viGHJipfxR0QEnHW3c8u1j0c5FzgcMgman+F3shLJyy9g7todTEvN4MslmWTvz6Nu9SoM69mUgUmJdG9S8+C3bsYbVryMv9pdCNfPdM6JfHMQXPAE9LyhQraDqSo/b8xiemoGMxZlsH3vAeJjorigcwMGJSXSq2UdooL4bZo5Nitexn8JbeGGmc4e2Od3Q2Ya9P87RMf6nQyAlVuymZa6ielpGWzcuY8qURH0bV+PQUmJnN2uHrHRkX5HrJSseJmKIbYGXP6uc1L3nKdg63IY+haclOhLnI07f+OTRU7XhuWZ2UQI9G5dlzF9na4NJ1nXBt9Z8TIVR0QE9PkLNOgCU2902sF+/wrUbonTfVtK/7MU827/9QCfL8nkk7RMUjZkoUD3JrV4eGAn+ndpSEK8dW2oSKyfl6mYtvzitIPt+p/fSQIUVfQiSlAYi3ttYJGNAIl0CrhEQkRkwDh3OCIy4PmxxkcUMV9x4wNef9iySrOM0oyPgianlPwTt35eJuTU7+ic2L3qayjIdc9p0cN/asHR46DoeQN+5ubns2brXpZl7GbN1mzyC5SacZF0aBhPhwbVSagec9xlHPPnCb2mwHkU5IPmuz8L3OGCgHEB0wryQQ8UMz5w/qJeX8z4om/mFTyRVeCB4NzK0IqXqbjiakLXwUFZVH6BMnfNDqanbeLzJZlk5+RRp1oVBiQ3ZGBSIj2a1rKuDXCokB5V6IorjKUsmEFkxcuELVUldWMW09MymLFoM9uy91M9JooLOjVgYFIivVtZ14ajiBw6zKvgrHiZsLN6azbTUjOYlprBhp2/USUygj7t6zEwKZE+7a1rQ7iw4mXCwqasfXyS5hSsZZv3ECFwequ63NKnNRd0akCNOOvaEG6seJmQtfPXA3y6eDPTUzcxf90uAJKa1OTBiztyUdeG1IuvGJ1cjTeseJmQsnd/Hl//ksm01Ay+X7WdvAKlTb3q3HV+Wy7ulkizOnaV1srCipep8Pbn5fN/K7YxPS2Db5ZtISe3gEY147j+zJYMSkqkfYN4+6awErLiZSqs9F2/8eKs1Xy2eDN7cvKoXa0Kg09uwiC3a0NEhBWsysyKl6mQFqVnce3EFH7dn8eFnd2uDa3rBvUa6Ca0WfEyFc7MZVu45Z2fqVO9Cu+N6k3renYxP3M0K16mQnlr7joenL6UTok1eH1Esn1jaIplxctUCAUFyt++WM6/5qzl3A71eOGK7lStYpunKZ5tHcZ3Obn53PlBGp8u2szVpzXjoYGdiLTGeHMcVryMr3b9eoAb3kwhZf0u/ty/PTec2dK6PZgSseJlfLN+x6+MnDCf9Kx9vDSsBxd1beh3JBNCrHgZX/y8YRfXT0ohX5V3rj+V5Oa1/Y5kQowVL1PuvlyayZj3fqZefCwTR55Cy4TqfkcyIciKlylXb3z/Px759Be6Na7J68OTqWO3vDcnyIqXKRf5Bcpjny7jjf/+jws61ef5od2Jq2LX1TInzoqX8VxObj63vZfKF0szGdm7Ofdf1NG6Qpgys+JlPLVj736ufzOF1I1Z/HVAR649o4XfkUyYsOJl
PLN2215GTpxP5u4cxl95Mv06N/A7kgkjnp2iLyJviMhWEVkSMK62iHwtIqvcn7W8Wr/x14L1O/nD+B/Izsnj3VGnWeEyQefl9UUmAv2OGDcWmKmqbYCZ7nMTZj5dtJkrXvuJmlWrMPXm0+nR1P5HmeDzrHip6hxg5xGjBwGT3OFJwCVerd+UP1XltTlrGf3OQro2qsFHN51ul2U2ninvNq/6qrrZHc4E6hc3o4iMAkYBNG3atByimbLIL1DGfbKUSXPXc1GXhvx9SDe7xZjxlG+XpVQtvNd5sdNfVdVkVU1OSEgox2SmtH47kMcf31rApLnrGfW7lvzziu5WuIznynvPa4uINFTVzSLSENhazus3QbYtez/XTZrPkk27eWRQJ67u1dzvSKaSKO89r+nAcHd4ODCtnNdvgmj11r38/uX/smrLXl69OtkKlylXnu15ici7wNlAXRFJBx4EngSmiMh1wHpgiFfrN976ae0ORr21gOhI4f0/nkbXxjX9jmQqGc+Kl6peUcykvl6t05SPaambuPuDRTSpHcfEkT1pUruq35FMJWQ97E2JqSrj/28NT32xgp4tavPa1cnUqBrtdyxTSVnxMiWSl1/AA9OW8u68DQzslsjTg7sSE2XfKBr/WPEyx/Xr/jxueWchs1ds4+azW3HX+e3sbtXGd1a8zDFt2ZPDtRPnszwzm8d/34Vhp1qHYVMxWPEyxVq5JZsRb8wja18u/x6ezDnt6vkdyZiDrHiZIv2wejt/fHsBcdGRTPljLzo3quF3JGMOY8XLHOXjhenc+9EiWtStxoSRPWlUM87vSMYcxYqXOUhV+ees1Tz79UpOb1WH8VedTI046wphKiYrXgaA3PwC/jJ1MVNS0rm0RyOevLQrVaJ8O2/fmOOy4mXIzsnl5skL+W7Vdv7Utw23n9sGEesKYSo2K16V3Obd+xg5YT6rt+7lqT90ZcgpTfyOZEyJWPGqxJZt3sPICfPZuz+PCSNP4cw2dt00EzqseFVSc1Zu4+bJC6keE8UHN/aiQ8OT/I5kTKlY8aqEpqRs5M8fL6Z1vepMGHkKDWtYVwgTeqx4VSKqynPfrOKFmas4s01dXr6yB/Gx1hXChCYrXpXEgbwCxn68iI8XbmJIcmMe+30XoiOtK4QJXVa8KoHd+3K56e0F/LBmB3ee15Zb+rS2rhAm5FnxCnObsvYxcsI8/rf9V54d0o1LezT2O5IxQWHFK4wt2bSbayfOZ19uPpNG9uT01nX9jmRM0FjxClOzV2xl9OSF1KpahbevP5W29eP9jmRMUFnxCkPv/LSBB6YtoX2DeCaMOIV6J8X6HcmYoLPiFUYKCpRnvlrBy9+u4Zx2Cbw4rAfVYuxXbMKTbdlhYn9ePnd/sIjpaRkMO7Up4wZ2Isq6QpgwZsUrDGT9doBRby1g3v92cm+/9tx4VkvrCmHCnhWvELdx52+MmDCPjTv38Y/LkxiU1MjvSMaUCyteIWxRehbXTpxPbr7y1nU9ObVlHb8jGVNuwqp4fbk0k3s/WkRUhBAVEUFkhBAdKe5P53lUhBAVGTgtgugIZ56oSOd1Ue5wZETEwdcXvq5w2c68RS876uCy3BwBw1FFDB+2jMgipkVEHHWfxG9+2cKt7/5MnepVeG9UT1rXq+7Tp26MP8KqeDWqGcegbonkFij5+UpuQQH5BUpevpLnDufmq/uzgP25BeQW5JNfUODOc2jaoXkLyHOXkV/gLFO1/N+bCERHHCpwe/fn0bVRDf49/BQS4mPKP5AxPgur4tW5UY1yuUVXQUFAYQwojnn5hxe/w6YdMZx/1HglL7/gqAJ61OvcaTXiohn1u5ZUrRJWv0JjSsy2/BMQESHERET6HcOYSs06AhljQpIVL2NMSLLiZYwJSVa8jDEhyYqXMSYkWfEyxoQkK17GmJBkxcsYE5JE/TjXpZREZBuwvoSz1wW2exinorD3GV7sfRatmaomFDUhJIpXaYhIiqom+53Da/Y+w4u9
z9Kzw0ZjTEiy4mWMCUnhWLxe9TtAObH3GV7sfZZS2LV5GWMqh3Dc8zLGVAJWvIwxISmsipeI9BORFSKyWkTG+p3HCyLyhohsFZElfmfxkog0EZHZIvKLiCwVkTF+Z/KCiMSKyDwRSXPf58N+Z/KSiESKyM8iMqOsywqb4iUikcBLwIVAR+AKEenobypPTAT6+R2iHOQBd6pqR+A0YHSY/j73A31UtRuQBPQTkdP8jeSpMcCyYCwobIoX0BNYraprVfUA8B4wyOdMQaeqc4CdfufwmqpuVtWF7nA2zgYfdjelVMde92m0+wjLb9FEpDFwEfDvYCwvnIpXI2BjwPN0wnBjr4xEpDnQHfjJ5yiecA+lUoGtwNeqGpbvE3geuAcoCMbCwql4mTAkItWBj4DbVHWP33m8oKr5qpoENAZ6ikhnnyMFnYgMALaq6oJgLTOcitcmoEnA88buOBOiRCQap3BNVtWP/c7jNVXNAmYTnm2avYGBIrIOp0mnj4i8XZYFhlPxmg+0EZEWIlIFuByY7nMmc4JERIDXgWWq+qzfebwiIgkiUtMdjgPOA5b7GsoDqnqfqjZW1eY4f5uzVPWqsiwzbIqXquYBtwBf4jTuTlHVpf6mCj4ReReYC7QTkXQRuc7vTB7pDVyN8x861X309zuUBxoCs0VkEc4/4K9VtczdCCoDOz3IGBOSwmbPyxhTuVjxMsaEJCtexpiQZMXLGBOSrHgZY0KSFS/jKRHJD+jqkBrMq32ISPNwv7qGKV6U3wFM2NvnnvpiTFDZnpfxhYisE5GnRGSxez2r1u745iIyS0QWichMEWnqjq8vIlPd616licjp7qIiReQ191pYX7m91BGRP7nXAlskIu/59DaNh6x4Ga/FHXHYODRg2m5V7QK8iHPFAYB/ApNUtSswGXjBHf8C8H/uda96AIVnT7QBXlLVTkAW8Ad3/Figu7ucG715a8ZP1sPeeEpE9qpq9SLGr8O5CN9a9wTsTFWtIyLbgYaqmuuO36yqdd27pjdW1f0By2iOczpNG/f5vUC0qj4qIl8Ae4H/AP8JuGaWCRO252X8pMUMl8b+gOF8DrXjXoRzZd0ewHwRsfbdMGPFy/hpaMDPue7wDzhXHQC4EvjOHZ4J3AQHL95Xo7iFikgE0ERVZwP3AjWAo/b+TGiz/0bGa3HuVUILfaGqhd0larlXU9gPXOGOuxWYICJ3A9uAke74McCr7lU08nEK2eZi1hkJvO0WOAFecK+VZcKItXkZX7htXsmqut3vLCY02WGjMSYk2Z6XMSYk2Z6XMSYkWfEyxoQkK17GmJBkxcsYE5KseBljQtL/A/w8USoPr020AAAAAElFTkSuQmCC\n",
+ "text/plain": [
+ ""
+ ]
+ },
+ "metadata": {
+ "needs_background": "light"
+ },
+ "output_type": "display_data"
+ }
+ ],
+ "source": [
+ "training_vis(train_losses, valid_losses)"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+    "#### Model evaluation\n",
+    "\n",
+    "Evaluate the model on the test set."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 31,
+ "metadata": {
+ "execution": {
+ "iopub.execute_input": "2021-11-29T04:55:47.340416Z",
+ "iopub.status.busy": "2021-11-29T04:55:47.339537Z",
+ "iopub.status.idle": "2021-11-29T04:55:47.606447Z",
+ "shell.execute_reply": "2021-11-29T04:55:47.607038Z",
+ "shell.execute_reply.started": "2021-11-28T14:01:44.127754Z"
+ },
+ "papermill": {
+ "duration": 2.453872,
+ "end_time": "2021-11-29T04:55:47.607210",
+ "exception": false,
+ "start_time": "2021-11-29T04:55:45.153338",
+ "status": "completed"
+ },
+ "tags": []
+ },
+ "outputs": [
+ {
+ "data": {
+ "text/plain": [
+ ""
+ ]
+ },
+ "execution_count": 31,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+    "# Load the weights of the best-scoring model\n",
+ "checkpoint = torch.load('../input/ai-earth-model-weights/task05_model_weights.pth')\n",
+ "model = SAConvLSTM(input_dim, hidden_dim, d_attn, kernel_size)\n",
+ "model.load_state_dict(checkpoint['state_dict'])"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 32,
+ "metadata": {
+ "execution": {
+ "iopub.execute_input": "2021-11-29T04:55:51.849996Z",
+ "iopub.status.busy": "2021-11-29T04:55:51.849073Z",
+ "iopub.status.idle": "2021-11-29T04:55:51.851492Z",
+ "shell.execute_reply": "2021-11-29T04:55:51.850969Z",
+ "shell.execute_reply.started": "2021-11-28T14:06:59.931318Z"
+ },
+ "papermill": {
+ "duration": 2.125413,
+ "end_time": "2021-11-29T04:55:51.851629",
+ "exception": false,
+ "start_time": "2021-11-29T04:55:49.726216",
+ "status": "completed"
+ },
+ "tags": []
+ },
+ "outputs": [],
+ "source": [
+    "# Path to the test data\n",
+    "test_path = '../input/ai-earth-tests/'\n",
+    "# Path to the test labels\n",
+    "test_label_path = '../input/ai-earth-tests-labels/'"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 33,
+ "metadata": {
+ "execution": {
+ "iopub.execute_input": "2021-11-29T04:55:56.429364Z",
+ "iopub.status.busy": "2021-11-29T04:55:56.428800Z",
+ "iopub.status.idle": "2021-11-29T04:55:58.115486Z",
+ "shell.execute_reply": "2021-11-29T04:55:58.115007Z",
+ "shell.execute_reply.started": "2021-11-28T14:07:13.415385Z"
+ },
+ "papermill": {
+ "duration": 4.135325,
+ "end_time": "2021-11-29T04:55:58.115667",
+ "exception": false,
+ "start_time": "2021-11-29T04:55:53.980342",
+ "status": "completed"
+ },
+ "tags": []
+ },
+ "outputs": [],
+ "source": [
+    "import os\n",
+    "\n",
+    "# Read the test data and the corresponding labels\n",
+    "files = os.listdir(test_path)\n",
+    "X_test = []\n",
+    "y_test = []\n",
+    "for file in files:\n",
+    "    X_test.append(np.load(test_path + file))\n",
+    "    y_test.append(np.load(test_label_path + file))"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 34,
+ "metadata": {
+ "execution": {
+ "iopub.execute_input": "2021-11-29T04:56:02.560786Z",
+ "iopub.status.busy": "2021-11-29T04:56:02.559461Z",
+ "iopub.status.idle": "2021-11-29T04:56:02.587431Z",
+ "shell.execute_reply": "2021-11-29T04:56:02.588024Z",
+ "shell.execute_reply.started": "2021-11-28T14:07:17.046359Z"
+ },
+ "papermill": {
+ "duration": 2.329175,
+ "end_time": "2021-11-29T04:56:02.588201",
+ "exception": false,
+ "start_time": "2021-11-29T04:56:00.259026",
+ "status": "completed"
+ },
+ "tags": []
+ },
+ "outputs": [
+ {
+ "data": {
+ "text/plain": [
+ "((103, 12, 24, 48, 1), (103, 24))"
+ ]
+ },
+ "execution_count": 34,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
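+    "# Keep indices 19-66 along the fourth axis and only the first feature channel (assumed to be sst)\n",
+    "X_test = np.array(X_test)[:, :, :, 19:67, :1]\n",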
+ "y_test = np.array(y_test)\n",
+ "X_test.shape, y_test.shape"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 35,
+ "metadata": {
+ "execution": {
+ "iopub.execute_input": "2021-11-29T04:56:07.675344Z",
+ "iopub.status.busy": "2021-11-29T04:56:07.674488Z",
+ "iopub.status.idle": "2021-11-29T04:56:07.682481Z",
+ "shell.execute_reply": "2021-11-29T04:56:07.682895Z",
+ "shell.execute_reply.started": "2021-11-28T14:07:31.503452Z"
+ },
+ "papermill": {
+ "duration": 2.455352,
+ "end_time": "2021-11-29T04:56:07.683041",
+ "exception": false,
+ "start_time": "2021-11-29T04:56:05.227689",
+ "status": "completed"
+ },
+ "tags": []
+ },
+ "outputs": [],
+ "source": [
+ "testset = AIEarthDataset(X_test, y_test)\n",
+ "testloader = DataLoader(testset, batch_size=batch_size, shuffle=False)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+    "# Evaluate the model on the test set\n",
+    "model.eval()\n",
+    "model.to(device)\n",
+    "preds = np.zeros((len(y_test), 24))\n",
+    "with torch.no_grad():  # inference only, so skip gradient tracking\n",
+    "    for i, (data, _) in enumerate(tqdm(testloader)):\n",
+    "        data = data.to(device)\n",
+    "        pred = model(data, train=False)\n",
+    "        preds[i * batch_size:(i + 1) * batch_size] = pred.cpu().numpy()\n",
+    "s = score(y_test, preds)\n",
+    "print('Score: {:.3f}'.format(s))"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "papermill": {
+ "duration": null,
+ "end_time": null,
+ "exception": null,
+ "start_time": null,
+ "status": "pending"
+ },
+ "tags": []
+ },
+ "source": [
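+    "## Summary\n",
+    "\n",
+    "This TOP solution did not design a model of its own; it applied an existing model from the spatiotemporal sequence forecasting field. Another TOP team, “ailab”, likewise used an existing model, PredRNN++. For an overview of some classic models in spatiotemporal sequence forecasting, see https://www.zhihu.com/column/c_1208033701705162752"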
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
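+    "## Exercise\n",
+    "\n",
+    "This TOP solution uses sst as the prediction target and derives the Nino3.4 index from it indirectly. If you have time to spare, try using the SA-ConvLSTM model to predict the Nino3.4 index directly."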
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+    "## References\n",
+    "\n",
+    "1. Solution write-up by the team “吴先生的队伍”: https://tianchi.aliyun.com/forum/postDetail?spm=5176.12586969.1002.9.561d5330dF9lX1&postId=231465\n",
+    "2. Approach write-up by the team “ailab”: https://tianchi.aliyun.com/forum/postDetail?spm=5176.12586969.1002.15.561d5330dF9lX1&postId=210734"
+ ]
+ }
+ ],
+ "metadata": {
+ "kernelspec": {
+ "display_name": "Python 3",
+ "language": "python",
+ "name": "python3"
+ },
+ "language_info": {
+ "codemirror_mode": {
+ "name": "ipython",
+ "version": 3
+ },
+ "file_extension": ".py",
+ "mimetype": "text/x-python",
+ "name": "python",
+ "nbconvert_exporter": "python",
+ "pygments_lexer": "ipython3",
+ "version": "3.7.3"
+ },
+ "papermill": {
+ "default_parameters": {},
+ "duration": 6708.571081,
+ "end_time": "2021-11-29T04:56:15.789285",
+ "environment_variables": {},
+ "exception": true,
+ "input_path": "__notebook__.ipynb",
+ "output_path": "__notebook__.ipynb",
+ "parameters": {},
+ "start_time": "2021-11-29T03:04:27.218204",
+ "version": "2.3.3"
+ }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 5
+}