diff --git a/Task4/Task4 模型建立之TCNN+RNN.ipynb b/Task4/Task4 模型建立之TCNN+RNN.ipynb
new file mode 100644
index 0000000..2e575a0
--- /dev/null
+++ b/Task4/Task4 模型建立之TCNN+RNN.ipynb
@@ -0,0 +1,2009 @@
+{
+ "cells": [
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "# Datawhale 气象海洋预测-Task4 模型建立之 TCNN+RNN\n",
+ "本次任务我们将学习来自TOP选手“swg-lhl”的冠军建模方案,该方案中采用的模型是TCNN+RNN\n",
+ "\n",
+ "在Task3中我们学习了CNN+LSTM模型,但是LSTM有四个门,构建一个LSTM层的参数量非常大,这就带来以下问题:一是参数量大在数据量小的情况下模型容易过拟合;二是为了尽量避免过拟合,在有限的数据集下我们无法构建更深层的模型,难以挖掘到更复杂的信息,这一点在可用的特征数非常少的情况下是很不利的。相较于LSTM,CNN的参数量只与过滤器的大小有关,在各类任务中往往都有不错的表现,因此我们可以考虑同样用卷积操作来挖掘时间信息。但是如果用三维卷积来同时挖掘时间和空间信息,假设使用的过滤器大小为(T_f, H_f, W_f),那么一层的参数量就是T_f×H_f×W_f,这样的参数量仍然是比较大的。为了进一步降低每一层的参数,增加模型深度,我们本次学习的这个TOP方案对时间和空间分别进行卷积操作,即采用TCN单元挖掘时间信息,然后输入CNN单元中挖掘空间信息,将TCN单元+CNN单元的串行结构称为TCNN层,通过堆叠多层的TCNN层就可以很好地挖掘到复杂的时空信息。同时,考虑到不同时间尺度下的时空信息对预测结果的影响可能是不同的,该方案采用了三个RNN层来抽取三种时间尺度下的特征,将三者拼接起来得到最终的预测结果。\n",
+ "\n",
+ "可以看出,该方案是基于给定的数据集情况、基于问题背景深刻理解后得出来的,希望同学们在学习的过程中能够深入思考,将这种建模方法灵活迁移到其他任务中。"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "## 学习目标\n",
+ "1. 学习TOP方案的模型构建方法"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "## 内容介绍\n",
+ "1. 数据处理\n",
+ " - 数据扁平化\n",
+ " - 空值填充\n",
+ " - 构造数据集\n",
+ "2. 模型构建\n",
+ " - 构造评估函数\n",
+ " - 模型构造\n",
+ " - 模型训练\n",
+ " - 模型评估\n",
+ "3. 总结"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "## 代码示例"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "### 数据处理\n",
+ "该TOP方案的数据处理主要包括三部分:\n",
+ "1. 数据扁平化。\n",
+ "2. 空值填充。\n",
+ "3. 构造数据集\n",
+ "在该方案中除了没有构造新的特征外,其他数据处理方法都与Task3基本相同,因此不多做赘述。"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 1,
+ "metadata": {
+ "execution": {
+ "iopub.execute_input": "2021-11-07T11:32:41.439634Z",
+ "iopub.status.busy": "2021-11-07T11:32:41.431440Z",
+ "iopub.status.idle": "2021-11-07T11:32:43.645979Z",
+ "shell.execute_reply": "2021-11-07T11:32:43.645292Z",
+ "shell.execute_reply.started": "2021-11-07T11:22:11.718250Z"
+ },
+ "papermill": {
+ "duration": 2.245092,
+ "end_time": "2021-11-07T11:32:43.646157",
+ "exception": false,
+ "start_time": "2021-11-07T11:32:41.401065",
+ "status": "completed"
+ },
+ "tags": []
+ },
+ "outputs": [],
+ "source": [
+ "import netCDF4 as nc\n",
+ "import random\n",
+ "import os\n",
+ "from tqdm import tqdm\n",
+ "import pandas as pd\n",
+ "import numpy as np\n",
+ "import matplotlib.pyplot as plt\n",
+ "%matplotlib inline\n",
+ "\n",
+ "import torch\n",
+ "from torch import nn, optim\n",
+ "import torch.nn.functional as F\n",
+ "from torch.utils.data import Dataset, DataLoader\n",
+ "\n",
+ "from sklearn.metrics import mean_squared_error"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 2,
+ "metadata": {
+ "execution": {
+ "iopub.execute_input": "2021-11-07T11:32:43.703083Z",
+ "iopub.status.busy": "2021-11-07T11:32:43.702422Z",
+ "iopub.status.idle": "2021-11-07T11:32:43.706718Z",
+ "shell.execute_reply": "2021-11-07T11:32:43.707099Z",
+ "shell.execute_reply.started": "2021-11-07T11:22:17.072537Z"
+ },
+ "papermill": {
+ "duration": 0.03493,
+ "end_time": "2021-11-07T11:32:43.707228",
+ "exception": false,
+ "start_time": "2021-11-07T11:32:43.672298",
+ "status": "completed"
+ },
+ "tags": []
+ },
+ "outputs": [],
+ "source": [
+ "# 固定随机种子\n",
+ "SEED = 22\n",
+ "\n",
+ "def seed_everything(seed=42):\n",
+ " random.seed(seed)\n",
+ " os.environ['PYTHONHASHSEED'] = str(seed)\n",
+ " np.random.seed(seed)\n",
+ " torch.manual_seed(seed)\n",
+ " torch.cuda.manual_seed(seed)\n",
+ " torch.backends.cudnn.deterministic = True\n",
+ " \n",
+ "seed_everything(SEED)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 3,
+ "metadata": {
+ "execution": {
+ "iopub.execute_input": "2021-11-07T11:32:43.806316Z",
+ "iopub.status.busy": "2021-11-07T11:32:43.805429Z",
+ "iopub.status.idle": "2021-11-07T11:32:43.809106Z",
+ "shell.execute_reply": "2021-11-07T11:32:43.809544Z",
+ "shell.execute_reply.started": "2021-11-07T11:22:34.845477Z"
+ },
+ "papermill": {
+ "duration": 0.077409,
+ "end_time": "2021-11-07T11:32:43.809691",
+ "exception": false,
+ "start_time": "2021-11-07T11:32:43.732282",
+ "status": "completed"
+ },
+ "tags": []
+ },
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "CUDA is available! Training on GPU ...\n"
+ ]
+ }
+ ],
+ "source": [
+ "# 查看GPU是否可用\n",
+ "train_on_gpu = torch.cuda.is_available()\n",
+ "\n",
+ "if not train_on_gpu:\n",
+ " print('CUDA is not available. Training on CPU ...')\n",
+ "else:\n",
+ " print('CUDA is available! Training on GPU ...')"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 4,
+ "metadata": {
+ "execution": {
+ "iopub.execute_input": "2021-11-07T11:32:43.938252Z",
+ "iopub.status.busy": "2021-11-07T11:32:43.937335Z",
+ "iopub.status.idle": "2021-11-07T11:32:43.991840Z",
+ "shell.execute_reply": "2021-11-07T11:32:43.991334Z",
+ "shell.execute_reply.started": "2021-11-07T11:22:38.372763Z"
+ },
+ "papermill": {
+ "duration": 0.156188,
+ "end_time": "2021-11-07T11:32:43.991963",
+ "exception": false,
+ "start_time": "2021-11-07T11:32:43.835775",
+ "status": "completed"
+ },
+ "tags": []
+ },
+ "outputs": [],
+ "source": [
+ "# 读取数据\n",
+ "\n",
+ "# 存放数据的路径\n",
+ "path = '/kaggle/input/ninoprediction/'\n",
+ "soda_train = nc.Dataset(path + 'SODA_train.nc')\n",
+ "soda_label = nc.Dataset(path + 'SODA_label.nc')\n",
+ "cmip_train = nc.Dataset(path + 'CMIP_train.nc')\n",
+ "cmip_label = nc.Dataset(path + 'CMIP_label.nc')"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "#### 数据扁平化\n",
+ "采用滑窗构造数据集。"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 5,
+ "metadata": {
+ "execution": {
+ "iopub.execute_input": "2021-11-07T11:32:44.053052Z",
+ "iopub.status.busy": "2021-11-07T11:32:44.052369Z",
+ "iopub.status.idle": "2021-11-07T11:32:44.055801Z",
+ "shell.execute_reply": "2021-11-07T11:32:44.055258Z",
+ "shell.execute_reply.started": "2021-11-07T11:22:41.386272Z"
+ },
+ "papermill": {
+ "duration": 0.037368,
+ "end_time": "2021-11-07T11:32:44.055938",
+ "exception": false,
+ "start_time": "2021-11-07T11:32:44.018570",
+ "status": "completed"
+ },
+ "tags": []
+ },
+ "outputs": [],
+ "source": [
+ "def make_flatted(train_ds, label_ds, info, start_idx=0):\n",
+ " keys = ['sst', 't300', 'ua', 'va']\n",
+ " label_key = 'nino'\n",
+ " # 年数\n",
+ " years = info[1]\n",
+ " # 模式数\n",
+ " models = info[2]\n",
+ " \n",
+ " train_list = []\n",
+ " label_list = []\n",
+ " \n",
+ " # 将同种模式下的数据拼接起来\n",
+ " for model_i in range(models):\n",
+ " blocks = []\n",
+ " \n",
+ " # 对每个特征,取每条数据的前12个月进行拼接\n",
+ " for key in keys:\n",
+ " block = train_ds[key][start_idx + model_i * years: start_idx + (model_i + 1) * years, :12].reshape(-1, 24, 72, 1).data\n",
+ " blocks.append(block)\n",
+ " \n",
+ " # 将所有特征在最后一个维度上拼接起来\n",
+ " train_flatted = np.concatenate(blocks, axis=-1)\n",
+ " \n",
+ " # 取12-23月的标签进行拼接,注意加上最后一年的最后12个月的标签(与最后一年12-23月的标签共同构成最后一年前12个月的预测目标)\n",
+ " label_flatted = np.concatenate([\n",
+ " label_ds[label_key][start_idx + model_i * years: start_idx + (model_i + 1) * years, 12: 24].reshape(-1).data,\n",
+ " label_ds[label_key][start_idx + (model_i + 1) * years - 1, 24: 36].reshape(-1).data\n",
+ " ], axis=0)\n",
+ " \n",
+ " train_list.append(train_flatted)\n",
+ " label_list.append(label_flatted)\n",
+ " \n",
+ " return train_list, label_list"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 6,
+ "metadata": {
+ "execution": {
+ "iopub.execute_input": "2021-11-07T11:32:44.114429Z",
+ "iopub.status.busy": "2021-11-07T11:32:44.113857Z",
+ "iopub.status.idle": "2021-11-07T11:33:35.423369Z",
+ "shell.execute_reply": "2021-11-07T11:33:35.423920Z",
+ "shell.execute_reply.started": "2021-11-07T11:22:43.054065Z"
+ },
+ "papermill": {
+ "duration": 51.342264,
+ "end_time": "2021-11-07T11:33:35.424098",
+ "exception": false,
+ "start_time": "2021-11-07T11:32:44.081834",
+ "status": "completed"
+ },
+ "tags": []
+ },
+ "outputs": [
+ {
+ "data": {
+ "text/plain": [
+ "((1, 1200, 24, 72, 4), (15, 1812, 24, 72, 4), (17, 1680, 24, 72, 4))"
+ ]
+ },
+ "execution_count": 6,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "soda_info = ('soda', 100, 1)\n",
+ "cmip6_info = ('cmip6', 151, 15)\n",
+ "cmip5_info = ('cmip5', 140, 17)\n",
+ "\n",
+ "soda_trains, soda_labels = make_flatted(soda_train, soda_label, soda_info)\n",
+ "cmip6_trains, cmip6_labels = make_flatted(cmip_train, cmip_label, cmip6_info)\n",
+ "cmip5_trains, cmip5_labels = make_flatted(cmip_train, cmip_label, cmip5_info, cmip6_info[1]*cmip6_info[2])\n",
+ "\n",
+ "# 得到扁平化后的数据维度为(模式数×序列长度×纬度×经度×特征数),其中序列长度=年数×12\n",
+ "np.shape(soda_trains), np.shape(cmip6_trains), np.shape(cmip5_trains)"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "#### 空值填充\n",
+ "将空值填充为0。"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 8,
+ "metadata": {
+ "execution": {
+ "iopub.execute_input": "2021-11-07T11:33:35.538420Z",
+ "iopub.status.busy": "2021-11-07T11:33:35.537032Z",
+ "iopub.status.idle": "2021-11-07T11:33:35.589717Z",
+ "shell.execute_reply": "2021-11-07T11:33:35.590116Z",
+ "shell.execute_reply.started": "2021-11-07T11:23:30.781276Z"
+ },
+ "papermill": {
+ "duration": 0.083633,
+ "end_time": "2021-11-07T11:33:35.590263",
+ "exception": false,
+ "start_time": "2021-11-07T11:33:35.506630",
+ "status": "completed"
+ },
+ "tags": []
+ },
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "Number of null in soda_trains after fillna: 0\n"
+ ]
+ }
+ ],
+ "source": [
+ "# 填充SODA数据中的空值\n",
+ "soda_trains = np.array(soda_trains)\n",
+ "soda_trains_nan = np.isnan(soda_trains)\n",
+ "soda_trains[soda_trains_nan] = 0\n",
+ "print('Number of null in soda_trains after fillna:', np.sum(np.isnan(soda_trains)))"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 9,
+ "metadata": {
+ "execution": {
+ "iopub.execute_input": "2021-11-07T11:33:35.649057Z",
+ "iopub.status.busy": "2021-11-07T11:33:35.647940Z",
+ "iopub.status.idle": "2021-11-07T11:33:36.785018Z",
+ "shell.execute_reply": "2021-11-07T11:33:36.784016Z",
+ "shell.execute_reply.started": "2021-11-07T11:23:30.842000Z"
+ },
+ "papermill": {
+ "duration": 1.1683,
+ "end_time": "2021-11-07T11:33:36.785204",
+ "exception": false,
+ "start_time": "2021-11-07T11:33:35.616904",
+ "status": "completed"
+ },
+ "tags": []
+ },
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "Number of null in cmip6_trains after fillna: 0\n"
+ ]
+ }
+ ],
+ "source": [
+ "# 填充CMIP6数据中的空值\n",
+ "cmip6_trains = np.array(cmip6_trains)\n",
+ "cmip6_trains_nan = np.isnan(cmip6_trains)\n",
+ "cmip6_trains[cmip6_trains_nan] = 0\n",
+ "print('Number of null in cmip6_trains after fillna:', np.sum(np.isnan(cmip6_trains)))"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 10,
+ "metadata": {
+ "execution": {
+ "iopub.execute_input": "2021-11-07T11:33:36.844494Z",
+ "iopub.status.busy": "2021-11-07T11:33:36.843382Z",
+ "iopub.status.idle": "2021-11-07T11:33:37.982369Z",
+ "shell.execute_reply": "2021-11-07T11:33:37.983434Z",
+ "shell.execute_reply.started": "2021-11-07T11:23:31.982897Z"
+ },
+ "papermill": {
+ "duration": 1.170648,
+ "end_time": "2021-11-07T11:33:37.983683",
+ "exception": false,
+ "start_time": "2021-11-07T11:33:36.813035",
+ "status": "completed"
+ },
+ "tags": []
+ },
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "Number of null in cmip6_trains after fillna: 0\n"
+ ]
+ }
+ ],
+ "source": [
+ "# 填充CMIP5数据中的空值\n",
+ "cmip5_trains = np.array(cmip5_trains)\n",
+ "cmip5_trains_nan = np.isnan(cmip5_trains)\n",
+ "cmip5_trains[cmip5_trains_nan] = 0\n",
+ "print('Number of null in cmip6_trains after fillna:', np.sum(np.isnan(cmip5_trains)))"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "#### 构造数据集\n",
+ "构造训练和验证集。"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 11,
+ "metadata": {
+ "execution": {
+ "iopub.execute_input": "2021-11-07T11:33:38.091882Z",
+ "iopub.status.busy": "2021-11-07T11:33:38.091030Z",
+ "iopub.status.idle": "2021-11-07T11:33:38.757089Z",
+ "shell.execute_reply": "2021-11-07T11:33:38.756486Z",
+ "shell.execute_reply.started": "2021-11-07T11:23:33.105534Z"
+ },
+ "papermill": {
+ "duration": 0.72117,
+ "end_time": "2021-11-07T11:33:38.757230",
+ "exception": false,
+ "start_time": "2021-11-07T11:33:38.036060",
+ "status": "completed"
+ },
+ "tags": []
+ },
+ "outputs": [],
+ "source": [
+ "# 构造训练集\n",
+ "\n",
+ "X_train = []\n",
+ "y_train = []\n",
+ "# 从CMIP5的17种模式中各抽取100条数据\n",
+ "for model_i in range(17):\n",
+ " samples = np.random.choice(cmip5_trains.shape[1]-12, size=100)\n",
+ " for ind in samples:\n",
+ " X_train.append(cmip5_trains[model_i, ind: ind+12])\n",
+ " y_train.append(cmip5_labels[model_i][ind: ind+24])\n",
+ "# 从CMIP6的15种模式种各抽取100条数据\n",
+ "for model_i in range(15):\n",
+ " samples = np.random.choice(cmip6_trains.shape[1]-12, size=100)\n",
+ " for ind in samples:\n",
+ " X_train.append(cmip6_trains[model_i, ind: ind+12])\n",
+ " y_train.append(cmip6_labels[model_i][ind: ind+24])\n",
+ "X_train = np.array(X_train)\n",
+ "y_train = np.array(y_train)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 12,
+ "metadata": {
+ "execution": {
+ "iopub.execute_input": "2021-11-07T11:33:38.819204Z",
+ "iopub.status.busy": "2021-11-07T11:33:38.818020Z",
+ "iopub.status.idle": "2021-11-07T11:33:38.840229Z",
+ "shell.execute_reply": "2021-11-07T11:33:38.839737Z",
+ "shell.execute_reply.started": "2021-11-07T11:23:33.801270Z"
+ },
+ "papermill": {
+ "duration": 0.055138,
+ "end_time": "2021-11-07T11:33:38.840360",
+ "exception": false,
+ "start_time": "2021-11-07T11:33:38.785222",
+ "status": "completed"
+ },
+ "tags": []
+ },
+ "outputs": [],
+ "source": [
+ "# 构造验证集\n",
+ "\n",
+ "X_valid = []\n",
+ "y_valid = []\n",
+ "samples = np.random.choice(soda_trains.shape[1]-12, size=100)\n",
+ "for ind in samples:\n",
+ " X_valid.append(soda_trains[0, ind: ind+12])\n",
+ " y_valid.append(soda_labels[0][ind: ind+24])\n",
+ "X_valid = np.array(X_valid)\n",
+ "y_valid = np.array(y_valid)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 13,
+ "metadata": {
+ "execution": {
+ "iopub.execute_input": "2021-11-07T11:33:38.898632Z",
+ "iopub.status.busy": "2021-11-07T11:33:38.897899Z",
+ "iopub.status.idle": "2021-11-07T11:33:38.900564Z",
+ "shell.execute_reply": "2021-11-07T11:33:38.901077Z",
+ "shell.execute_reply.started": "2021-11-07T11:23:33.839695Z"
+ },
+ "papermill": {
+ "duration": 0.034011,
+ "end_time": "2021-11-07T11:33:38.901204",
+ "exception": false,
+ "start_time": "2021-11-07T11:33:38.867193",
+ "status": "completed"
+ },
+ "tags": []
+ },
+ "outputs": [
+ {
+ "data": {
+ "text/plain": [
+ "((3200, 12, 24, 72, 4), (3200, 24), (100, 12, 24, 72, 4), (100, 24))"
+ ]
+ },
+ "execution_count": 13,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "# 查看数据集维度\n",
+ "X_train.shape, y_train.shape, X_valid.shape, y_valid.shape"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 15,
+ "metadata": {
+ "execution": {
+ "iopub.execute_input": "2021-11-07T11:33:39.033073Z",
+ "iopub.status.busy": "2021-11-07T11:33:39.032319Z",
+ "iopub.status.idle": "2021-11-07T11:33:40.963743Z",
+ "shell.execute_reply": "2021-11-07T11:33:40.962736Z",
+ "shell.execute_reply.started": "2021-11-07T11:23:33.879524Z"
+ },
+ "papermill": {
+ "duration": 1.963702,
+ "end_time": "2021-11-07T11:33:40.963922",
+ "exception": false,
+ "start_time": "2021-11-07T11:33:39.000220",
+ "status": "completed"
+ },
+ "tags": []
+ },
+ "outputs": [],
+ "source": [
+ "# 保存数据集\n",
+ "np.save('X_train_sample.npy', X_train)\n",
+ "np.save('y_train_sample.npy', y_train)\n",
+ "np.save('X_valid_sample.npy', X_valid)\n",
+ "np.save('y_valid_sample.npy', y_valid)"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "papermill": {
+ "duration": 0.442674,
+ "end_time": "2021-11-07T11:33:43.003352",
+ "exception": false,
+ "start_time": "2021-11-07T11:33:42.560678",
+ "status": "completed"
+ },
+ "tags": []
+ },
+ "source": [
+ "### 模型构建"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "这一部分我们来重点学习一下该方案的模型结构。"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 16,
+ "metadata": {
+ "execution": {
+ "iopub.execute_input": "2021-11-07T11:33:43.228569Z",
+ "iopub.status.busy": "2021-11-07T11:33:43.223113Z",
+ "iopub.status.idle": "2021-11-07T11:33:58.100748Z",
+ "shell.execute_reply": "2021-11-07T11:33:58.101185Z",
+ "shell.execute_reply.started": "2021-11-07T11:23:42.357343Z"
+ },
+ "papermill": {
+ "duration": 15.039399,
+ "end_time": "2021-11-07T11:33:58.101351",
+ "exception": false,
+ "start_time": "2021-11-07T11:33:43.061952",
+ "status": "completed"
+ },
+ "tags": []
+ },
+ "outputs": [],
+ "source": [
+ "# 读取数据集\n",
+ "X_train = np.load('../input/ai-earth-task04-samples/X_train_sample.npy')\n",
+ "y_train = np.load('../input/ai-earth-task04-samples/y_train_sample.npy')\n",
+ "X_valid = np.load('../input/ai-earth-task04-samples/X_valid_sample.npy')\n",
+ "y_valid = np.load('../input/ai-earth-task04-samples/y_valid_sample.npy')"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 17,
+ "metadata": {
+ "execution": {
+ "iopub.execute_input": "2021-11-07T11:33:58.162900Z",
+ "iopub.status.busy": "2021-11-07T11:33:58.161742Z",
+ "iopub.status.idle": "2021-11-07T11:33:58.164817Z",
+ "shell.execute_reply": "2021-11-07T11:33:58.165196Z",
+ "shell.execute_reply.started": "2021-11-07T11:23:56.303716Z"
+ },
+ "papermill": {
+ "duration": 0.036534,
+ "end_time": "2021-11-07T11:33:58.165327",
+ "exception": false,
+ "start_time": "2021-11-07T11:33:58.128793",
+ "status": "completed"
+ },
+ "tags": []
+ },
+ "outputs": [
+ {
+ "data": {
+ "text/plain": [
+ "((3200, 12, 24, 72, 4), (3200, 24), (100, 12, 24, 72, 4), (100, 24))"
+ ]
+ },
+ "execution_count": 17,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "X_train.shape, y_train.shape, X_valid.shape, y_valid.shape"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 18,
+ "metadata": {
+ "execution": {
+ "iopub.execute_input": "2021-11-07T11:33:58.225845Z",
+ "iopub.status.busy": "2021-11-07T11:33:58.224285Z",
+ "iopub.status.idle": "2021-11-07T11:33:58.226437Z",
+ "shell.execute_reply": "2021-11-07T11:33:58.226877Z",
+ "shell.execute_reply.started": "2021-11-07T11:23:56.313927Z"
+ },
+ "papermill": {
+ "duration": 0.03485,
+ "end_time": "2021-11-07T11:33:58.227001",
+ "exception": false,
+ "start_time": "2021-11-07T11:33:58.192151",
+ "status": "completed"
+ },
+ "tags": []
+ },
+ "outputs": [],
+ "source": [
+ "# 构造数据管道\n",
+ "class AIEarthDataset(Dataset):\n",
+ " def __init__(self, data, label):\n",
+ " self.data = torch.tensor(data, dtype=torch.float32)\n",
+ " self.label = torch.tensor(label, dtype=torch.float32)\n",
+ "\n",
+ " def __len__(self):\n",
+ " return len(self.label)\n",
+ " \n",
+ " def __getitem__(self, idx):\n",
+ " return self.data[idx], self.label[idx]"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 19,
+ "metadata": {
+ "execution": {
+ "iopub.execute_input": "2021-11-07T11:33:58.289091Z",
+ "iopub.status.busy": "2021-11-07T11:33:58.288549Z",
+ "iopub.status.idle": "2021-11-07T11:33:59.029333Z",
+ "shell.execute_reply": "2021-11-07T11:33:59.028812Z",
+ "shell.execute_reply.started": "2021-11-07T11:23:56.324788Z"
+ },
+ "papermill": {
+ "duration": 0.775212,
+ "end_time": "2021-11-07T11:33:59.029494",
+ "exception": false,
+ "start_time": "2021-11-07T11:33:58.254282",
+ "status": "completed"
+ },
+ "tags": []
+ },
+ "outputs": [],
+ "source": [
+ "batch_size = 32\n",
+ "\n",
+ "trainset = AIEarthDataset(X_train, y_train)\n",
+ "trainloader = DataLoader(trainset, batch_size=batch_size, shuffle=True)\n",
+ "\n",
+ "validset = AIEarthDataset(X_valid, y_valid)\n",
+ "validloader = DataLoader(validset, batch_size=batch_size, shuffle=True)"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "#### 构造评估函数"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 24,
+ "metadata": {
+ "execution": {
+ "iopub.execute_input": "2021-11-07T11:33:59.402757Z",
+ "iopub.status.busy": "2021-11-07T11:33:59.401182Z",
+ "iopub.status.idle": "2021-11-07T11:33:59.404701Z",
+ "shell.execute_reply": "2021-11-07T11:33:59.405088Z",
+ "shell.execute_reply.started": "2021-11-07T11:23:57.219014Z"
+ },
+ "papermill": {
+ "duration": 0.037919,
+ "end_time": "2021-11-07T11:33:59.405211",
+ "exception": false,
+ "start_time": "2021-11-07T11:33:59.367292",
+ "status": "completed"
+ },
+ "tags": []
+ },
+ "outputs": [],
+ "source": [
+ "def rmse(y_true, y_preds):\n",
+ " return np.sqrt(mean_squared_error(y_pred = y_preds, y_true = y_true))\n",
+ "\n",
+ "# 评估函数\n",
+ "def score(y_true, y_preds):\n",
+ " # 相关性技巧评分\n",
+ " accskill_score = 0\n",
+ " # RMSE\n",
+ " rmse_scores = 0\n",
+ " a = [1.5] * 4 + [2] * 7 + [3] * 7 + [4] * 6\n",
+ " y_true_mean = np.mean(y_true, axis=0)\n",
+ " y_pred_mean = np.mean(y_preds, axis=0)\n",
+ " for i in range(24):\n",
+ " fenzi = np.sum((y_true[:, i] - y_true_mean[i]) * (y_preds[:, i] - y_pred_mean[i]))\n",
+ " fenmu = np.sqrt(np.sum((y_true[:, i] - y_true_mean[i])**2) * np.sum((y_preds[:, i] - y_pred_mean[i])**2))\n",
+ " cor_i = fenzi / fenmu\n",
+ " accskill_score += a[i] * np.log(i+1) * cor_i\n",
+ " rmse_score = rmse(y_true[:, i], y_preds[:, i])\n",
+ " rmse_scores += rmse_score\n",
+ " return 2/3.0 * accskill_score - rmse_scores"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "#### 模型构造\n",
+ "\n",
+ "该TOP方案采用TCN单元+CNN单元串行组成TCNN层,通过堆叠多层的TCNN层来交替地提取时间和空间信息,并将提取到的时空信息用RNN来抽取出三种不同时间尺度的特征表达。"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "- **TCN单元**\n",
+ "\n",
+ "TCN模型全称时间卷积网络(Temporal Convolutional Network),与RNN一样是时序模型。TCN以CNN为基础,为了适应序列问题,它从以下三方面做出了改进:\n",
+ "\n",
+ "1. 因果卷积\n",
+ "\n",
+ "TCN处理输入与输出等长的序列问题,它的每一个隐藏层节点数与输入步长是相同的,并且隐藏层t时刻节点的值只依赖于前一层t时刻及之前节点的值。也就是说TCN通过追溯前因(t时刻及之前的值)来获得当前结果,称为因果卷积。\n",
+ "\n",
+ "2. 扩张卷积\n",
+ "\n",
+ "传统CNN的感受野受限于卷积核的大小,需要通过增加池化层来获得更大的感受野,但是池化的操作会带来信息的损失。为了解决这个问题,TCN采用扩张卷积来增大感受野,获取更长时间的信息。扩张卷积对输入进行间隔采样,采样间隔由扩张因子d控制,公式定义如下:\n",
+ "$$\n",
+ "F(s) = (X * df)(s) = \\sum_{i=0}^{k-1} f(i) \\times X_{s-di}\n",
+ "$$\n",
+ "其中X为当前层的输入,k为当前层的卷积核大小,s为当前节点的时刻。也就是说,对于扩张因子为d、卷积核为k的隐藏层,对前一层的输入每d个点采样一次,共采样k个点作为当前时刻s的输入。这样TCN的感受野就由卷积核的大小k和扩张因子d共同决定,可以获取更长时间的依赖信息。\n",
+ "\n",
+ "
\n",
+ "\n",
+ "3. 残差连接\n",
+ "\n",
+ "网络的层数越多,所能提取到的特征就越丰富,但这也会带来梯度消失或爆炸的问题,目前解决这个问题的一个有效方法就是残差连接。TCN的残差模块包含两层卷积操作,并且采用了WeightNorm和Dropout进行正则化,如下图所示。\n",
+ "\n",
+ "
\n",
+ "\n",
+ "总的来说,TCN是卷积操作在序列问题上的改进,具有CNN参数量少的优点,可以搭建更深层的网络,相比于RNN不容易存在梯度消失和爆炸的问题,同时TCN具有灵活的感受野,能够适应不同的任务,在许多数据集上的比较表明TCN比RNN、LSTM、GRU等序列模型有更好的表现。\n",
+ "\n",
+ "想要更深入地了解TCN可以参考以下链接:\n",
+ " \n",
+ " - 论文原文:https://arxiv.org/pdf/1803.01271.pdf\n",
+ " - GitHub:https://github.com/locuslab/tcn\n",
+ " \n",
+ "该方案中所构建的TCN单元并不是标准的TCN层,它的结构如下图所示,可以看到,这里的TCN单元只是用了一个卷积层,并且在卷积层前后都采用了BatchNormalization来提高模型的泛化能力。需要注意的是,这里的卷积操作是对时间维度进行操作,因此需要对输入的形状进行转换,并且为了便于匹配之后的网络层,需要将输出的形状转换回输入时的(N,T,C,H,W)的形式。\n",
+ "\n",
+ "
"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 20,
+ "metadata": {
+ "execution": {
+ "iopub.execute_input": "2021-11-07T11:33:59.094709Z",
+ "iopub.status.busy": "2021-11-07T11:33:59.094119Z",
+ "iopub.status.idle": "2021-11-07T11:33:59.097687Z",
+ "shell.execute_reply": "2021-11-07T11:33:59.097168Z",
+ "shell.execute_reply.started": "2021-11-07T11:23:57.153300Z"
+ },
+ "papermill": {
+ "duration": 0.039993,
+ "end_time": "2021-11-07T11:33:59.097803",
+ "exception": false,
+ "start_time": "2021-11-07T11:33:59.057810",
+ "status": "completed"
+ },
+ "tags": []
+ },
+ "outputs": [],
+ "source": [
+ "# 构建TCN单元\n",
+ "class TCNBlock(nn.Module):\n",
+ " def __init__(self, in_channels, out_channels, kernel_size, stride, padding):\n",
+ " super().__init__()\n",
+ " self.bn1 = nn.BatchNorm1d(in_channels)\n",
+ " self.conv = nn.Conv1d(in_channels, out_channels, kernel_size, stride, padding)\n",
+ " self.bn2 = nn.BatchNorm1d(out_channels)\n",
+ " \n",
+ " if in_channels == out_channels and stride == 1:\n",
+ " self.res = lambda x: x\n",
+ " else:\n",
+ " self.res = nn.Conv1d(in_channels, out_channels, kernel_size=1, stride=stride)\n",
+ " \n",
+ " def forward(self, x):\n",
+ " # 转换输入形状\n",
+ " N, T, C, H, W = x.shape\n",
+ " x = x.permute(0, 3, 4, 2, 1).contiguous()\n",
+ " x = x.view(N*H*W, C, T)\n",
+ " \n",
+ " # 残差\n",
+ " res = self.res(x) \n",
+ " res = self.bn2(res)\n",
+ "\n",
+ " x = F.relu(self.bn1(x))\n",
+ " x = self.conv(x)\n",
+ " x = self.bn2(x)\n",
+ " \n",
+ " x = x + res\n",
+ " \n",
+ " # 将输出转换回(N,T,C,H,W)的形式\n",
+ " _, C_new, T_new = x.shape\n",
+ " x = x.view(N, H, W, C_new, T_new)\n",
+ " x = x.permute(0, 4, 3, 1, 2).contiguous()\n",
+ " \n",
+ " return x"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "- **CNN单元**\n",
+ "\n",
+ "CNN单元结构与TCN单元相似,都只有一个卷积层,并且使用BatchNormalization来提高模型泛化能力。同时,类似TCN单元,CNN单元中也加入了残差连接。结构如下图所示:\n",
+ "\n",
+ "
"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 21,
+ "metadata": {
+ "execution": {
+ "iopub.execute_input": "2021-11-07T11:33:59.160481Z",
+ "iopub.status.busy": "2021-11-07T11:33:59.158973Z",
+ "iopub.status.idle": "2021-11-07T11:33:59.163197Z",
+ "shell.execute_reply": "2021-11-07T11:33:59.163610Z",
+ "shell.execute_reply.started": "2021-11-07T11:23:57.170593Z"
+ },
+ "papermill": {
+ "duration": 0.038796,
+ "end_time": "2021-11-07T11:33:59.163736",
+ "exception": false,
+ "start_time": "2021-11-07T11:33:59.124940",
+ "status": "completed"
+ },
+ "tags": []
+ },
+ "outputs": [],
+ "source": [
+ "# 构建CNN单元\n",
+ "class CNNBlock(nn.Module):\n",
+ " def __init__(self, in_channels, out_channels, kernel_size, stride, padding):\n",
+ " super().__init__()\n",
+ " self.bn1 = nn.BatchNorm2d(in_channels)\n",
+ " self.conv = nn.Conv2d(in_channels, out_channels, kernel_size, stride, padding)\n",
+ " self.bn2 = nn.BatchNorm2d(out_channels)\n",
+ " \n",
+ " if (in_channels == out_channels) and (stride == 1):\n",
+ " self.res = lambda x: x\n",
+ " else:\n",
+ " self.res = nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=stride)\n",
+ " \n",
+ " def forward(self, x):\n",
+ " # 转换输入形状\n",
+ " N, T, C, H, W = x.shape\n",
+ " x = x.view(N*T, C, H, W)\n",
+ " \n",
+ " # 残差\n",
+ " res = self.res(x)\n",
+ " res = self.bn2(res)\n",
+ "\n",
+ " x = F.relu(self.bn1(x))\n",
+ " x = self.conv(x)\n",
+ " x = self.bn2(x)\n",
+ " \n",
+ " x = x + res\n",
+ " \n",
+ " # 将输出转换回(N,T,C,H,W)的形式\n",
+ " _, C_new, H_new, W_new = x.shape\n",
+ " x = x.view(N, T, C_new, H_new, W_new)\n",
+ " \n",
+ " return x"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "- **TCNN层**\n",
+ "\n",
+ "将TCN单元和CNN单元串行连接,就构成了一个TCNN层。\n",
+ "\n",
+ "
"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 22,
+ "metadata": {
+ "execution": {
+ "iopub.execute_input": "2021-11-07T11:33:59.223231Z",
+ "iopub.status.busy": "2021-11-07T11:33:59.222402Z",
+ "iopub.status.idle": "2021-11-07T11:33:59.224980Z",
+ "shell.execute_reply": "2021-11-07T11:33:59.224529Z",
+ "shell.execute_reply.started": "2021-11-07T11:23:57.182192Z"
+ },
+ "papermill": {
+ "duration": 0.034509,
+ "end_time": "2021-11-07T11:33:59.225082",
+ "exception": false,
+ "start_time": "2021-11-07T11:33:59.190573",
+ "status": "completed"
+ },
+ "tags": []
+ },
+ "outputs": [],
+ "source": [
+ "class TCNNBlock(nn.Module):\n",
+ " def __init__(self, in_channels, out_channels, kernel_size, stride_tcn, stride_cnn, padding):\n",
+ " super().__init__()\n",
+ " self.tcn = TCNBlock(in_channels, out_channels, kernel_size, stride_tcn, padding)\n",
+ " self.cnn = CNNBlock(out_channels, out_channels, kernel_size, stride_cnn, padding)\n",
+ " \n",
+ " def forward(self, x):\n",
+ " x = self.tcn(x)\n",
+ " x = self.cnn(x)\n",
+ " return x"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "- **TCNN+RNN模型**\n",
+ "\n",
+ "整体的模型结构如下图所示:\n",
+ "\n",
+ "
\n",
+ "\n",
+ "1. TCNN部分\n",
+ "\n",
+ "TCNN部分的模型结构类似传统CNN的结构,非常规整,通过逐渐增加通道数来提取更丰富的特征表达。需要注意的是输入数据的格式是(N,T,H,W,C),为了匹配卷积层的输入格式,需要将数据格式转换为(N,T,C,H,W)。\n",
+ "\n",
+ "2. GAP层\n",
+ "\n",
+ "GAP全称为全局平均池化(Global Average Pooling)层,它的作用是把每个通道上的特征图取全局平均,假设经过TCNN部分得到的输出格式为(N,T,C,H,W),那么GAP层就会把每个通道上形状为H×W的特征图上的所有值求平均,最终得到的输出格式就变成(N,T,C)。GAP层最早出现在论文《Network in Network》(论文原文:https://arxiv.org/pdf/1312.4400.pdf )中用于代替传统CNN中的全连接层,之后的许多实验证明GAP层确实可以提高CNN的效果。\n",
+ "\n",
+ "那么GAP层为什么可以代替全连接层呢?在传统CNN中,经过多层卷积和池化的操作后,会由Flatten层将特征图拉伸成一列,然后经过全连接层,那么对于形状为(C,H,W)的一条数据,经Flatten层拉伸后的长度为C×H×W,此时假设全连接层节点数为U,全连接层的参数量就是C×H×W×U,这么大的参数量很容易使得模型过拟合。相比之下,GAP层不引入新的参数,因此可以有效减少过拟合问题,并且模型参数少也能加快训练速度。另一方面,全连接层是一个黑箱子,我们很难解释多分类的信息是怎样传回卷积层的,而GAP层就很容易理解,每个通道的值就代表了经过多层卷积操作后所提取出来的特征。更详细的理解可以参考https://www.zhihu.com/question/373188099\n",
+ "\n",
+ "在Pytorch中没有内置的GAP层,因此可以用adaptive_avg_pool2d来替代,这个函数可以将特征图压缩成给定的输出形状,将output_size参数设置为(1,1),就等同于GAP操作,函数的详细使用方法可以参考https://pytorch.org/docs/stable/generated/torch.nn.functional.adaptive_avg_pool2d.html?highlight=adaptive_avg_pool2d#torch.nn.functional.adaptive_avg_pool2d\n",
+ "\n",
+ "3. RNN部分\n",
+ "\n",
+ "至此为止我们所使用的都是长度为12的时间序列,每个时间步代表一个月的信息。不同尺度的时间序列所携带的信息是不尽相同的,比如用长度为6的时间序列来表达一年的SST值,那么每个时间步所代表的就是两个月的SST信息,这种时间尺度下的SST序列与长度为12的SST序列所反映的一年中SST变化趋势等信息就不完全相同。所以,为了尽可能全面地挖掘更多信息,该TOP方案中用MaxPool层来获得三种不同时间尺度的序列,同时,用RNN层来抽取序列的特征表达。RNN非常适合用于线性序列的自动特征提取,例如对于形状为(T,C1)的一条输入数据,R经过节点数为C2的RNN层就能抽取出长度为C2的向量,由于RNN由前往后进行信息线性传递的网络结构,抽取出的向量能够很好地表达序列中的依赖关系。\n",
+ "\n",
+ "此时三种不同时间尺度的序列都抽取出了一个向量来表示,将向量拼接起来再经过一个全连接层就得到了24个月的预测序列。"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 23,
+ "metadata": {
+ "execution": {
+ "iopub.execute_input": "2021-11-07T11:33:59.339242Z",
+ "iopub.status.busy": "2021-11-07T11:33:59.337656Z",
+ "iopub.status.idle": "2021-11-07T11:33:59.339842Z",
+ "shell.execute_reply": "2021-11-07T11:33:59.340257Z",
+ "shell.execute_reply.started": "2021-11-07T11:23:57.199137Z"
+ },
+ "papermill": {
+ "duration": 0.088367,
+ "end_time": "2021-11-07T11:33:59.340377",
+ "exception": false,
+ "start_time": "2021-11-07T11:33:59.252010",
+ "status": "completed"
+ },
+ "tags": []
+ },
+ "outputs": [],
+ "source": [
+ "# 构造模型\n",
+ "class Model(nn.Module):\n",
+ " def __init__(self):\n",
+ " super().__init__()\n",
+ " self.conv = nn.Conv2d(4, 64, kernel_size=7, stride=2, padding=3)\n",
+ " self.tcnn1 = TCNNBlock(64, 64, 3, 1, 1, 1)\n",
+ " self.tcnn2 = TCNNBlock(64, 128, 3, 1, 2, 1)\n",
+ " self.tcnn3 = TCNNBlock(128, 128, 3, 1, 1, 1)\n",
+ " self.tcnn4 = TCNNBlock(128, 256, 3, 1, 2, 1)\n",
+ " self.tcnn5 = TCNNBlock(256, 256, 3, 1, 1, 1)\n",
+ " self.rnn = nn.RNN(256, 256, batch_first=True)\n",
+ " self.maxpool = nn.MaxPool1d(2)\n",
+ " self.fc = nn.Linear(256*3, 24)\n",
+ " \n",
+ " def forward(self, x):\n",
+ " # 转换输入形状\n",
+ " N, T, H, W, C = x.shape\n",
+ " x = x.permute(0, 1, 4, 2, 3).contiguous()\n",
+ " x = x.view(N*T, C, H, W)\n",
+ " \n",
+ " # 经过一个卷积层\n",
+ " x = self.conv(x)\n",
+ " _, C_new, H_new, W_new = x.shape\n",
+ " x = x.view(N, T, C_new, H_new, W_new)\n",
+ " \n",
+ " # TCNN部分\n",
+ " for i in range(3):\n",
+ " x = self.tcnn1(x)\n",
+ " x = self.tcnn2(x)\n",
+ " for i in range(2):\n",
+ " x = self.tcnn3(x)\n",
+ " x = self.tcnn4(x)\n",
+ " for i in range(2):\n",
+ " x = self.tcnn5(x)\n",
+ " \n",
+ " # 全局平均池化\n",
+ " x = F.adaptive_avg_pool2d(x, (1, 1)).squeeze()\n",
+ " \n",
+ " # RNN部分,分别得到长度为T、T/2、T/4三种时间尺度的特征表达,注意转换RNN层输出的格式\n",
+ " hidden_state = []\n",
+ " for i in range(3):\n",
+ " x, h = self.rnn(x)\n",
+ " h = h.squeeze()\n",
+ " hidden_state.append(h)\n",
+ " x = self.maxpool(x.transpose(1, 2)).transpose(1, 2)\n",
+ " \n",
+ " x = torch.cat(hidden_state, dim=1)\n",
+ " x = self.fc(x)\n",
+ " \n",
+ " return x"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 25,
+ "metadata": {
+ "execution": {
+ "iopub.execute_input": "2021-11-07T11:33:59.467079Z",
+ "iopub.status.busy": "2021-11-07T11:33:59.466474Z",
+ "iopub.status.idle": "2021-11-07T11:33:59.507567Z",
+ "shell.execute_reply": "2021-11-07T11:33:59.507973Z",
+ "shell.execute_reply.started": "2021-11-07T11:23:57.232949Z"
+ },
+ "papermill": {
+ "duration": 0.076245,
+ "end_time": "2021-11-07T11:33:59.508101",
+ "exception": false,
+ "start_time": "2021-11-07T11:33:59.431856",
+ "status": "completed"
+ },
+ "tags": []
+ },
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "Model(\n",
+ " (conv): Conv2d(4, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3))\n",
+ " (tcnn1): TCNNBlock(\n",
+ " (tcn): TCNBlock(\n",
+ " (bn1): BatchNorm1d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
+ " (conv): Conv1d(64, 64, kernel_size=(3,), stride=(1,), padding=(1,))\n",
+ " (bn2): BatchNorm1d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
+ " )\n",
+ " (cnn): CNNBlock(\n",
+ " (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
+ " (conv): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
+ " (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
+ " )\n",
+ " )\n",
+ " (tcnn2): TCNNBlock(\n",
+ " (tcn): TCNBlock(\n",
+ " (bn1): BatchNorm1d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
+ " (conv): Conv1d(64, 128, kernel_size=(3,), stride=(1,), padding=(1,))\n",
+ " (bn2): BatchNorm1d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
+ " (res): Conv1d(64, 128, kernel_size=(1,), stride=(1,))\n",
+ " )\n",
+ " (cnn): CNNBlock(\n",
+ " (bn1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
+ " (conv): Conv2d(128, 128, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))\n",
+ " (bn2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
+ " (res): Conv2d(128, 128, kernel_size=(1, 1), stride=(2, 2))\n",
+ " )\n",
+ " )\n",
+ " (tcnn3): TCNNBlock(\n",
+ " (tcn): TCNBlock(\n",
+ " (bn1): BatchNorm1d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
+ " (conv): Conv1d(128, 128, kernel_size=(3,), stride=(1,), padding=(1,))\n",
+ " (bn2): BatchNorm1d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
+ " )\n",
+ " (cnn): CNNBlock(\n",
+ " (bn1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
+ " (conv): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
+ " (bn2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
+ " )\n",
+ " )\n",
+ " (tcnn4): TCNNBlock(\n",
+ " (tcn): TCNBlock(\n",
+ " (bn1): BatchNorm1d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
+ " (conv): Conv1d(128, 256, kernel_size=(3,), stride=(1,), padding=(1,))\n",
+ " (bn2): BatchNorm1d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
+ " (res): Conv1d(128, 256, kernel_size=(1,), stride=(1,))\n",
+ " )\n",
+ " (cnn): CNNBlock(\n",
+ " (bn1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
+ " (conv): Conv2d(256, 256, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))\n",
+ " (bn2): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
+ " (res): Conv2d(256, 256, kernel_size=(1, 1), stride=(2, 2))\n",
+ " )\n",
+ " )\n",
+ " (tcnn5): TCNNBlock(\n",
+ " (tcn): TCNBlock(\n",
+ " (bn1): BatchNorm1d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
+ " (conv): Conv1d(256, 256, kernel_size=(3,), stride=(1,), padding=(1,))\n",
+ " (bn2): BatchNorm1d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
+ " )\n",
+ " (cnn): CNNBlock(\n",
+ " (bn1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
+ " (conv): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
+ " (bn2): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
+ " )\n",
+ " )\n",
+ " (rnn): RNN(256, 256, batch_first=True)\n",
+ " (maxpool): MaxPool1d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)\n",
+ " (fc): Linear(in_features=768, out_features=24, bias=True)\n",
+ ")\n"
+ ]
+ }
+ ],
+ "source": [
+ "model = Model()\n",
+ "print(model)"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "#### 模型训练"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 26,
+ "metadata": {
+ "execution": {
+ "iopub.execute_input": "2021-11-07T11:33:59.569324Z",
+ "iopub.status.busy": "2021-11-07T11:33:59.568443Z",
+ "iopub.status.idle": "2021-11-07T11:33:59.570995Z",
+ "shell.execute_reply": "2021-11-07T11:33:59.570559Z",
+ "shell.execute_reply.started": "2021-11-07T11:23:59.196182Z"
+ },
+ "papermill": {
+ "duration": 0.035076,
+ "end_time": "2021-11-07T11:33:59.571104",
+ "exception": false,
+ "start_time": "2021-11-07T11:33:59.536028",
+ "status": "completed"
+ },
+ "tags": []
+ },
+ "outputs": [],
+ "source": [
+ "# 采用RMSE作为损失函数\n",
+ "def RMSELoss(y_pred,y_true):\n",
+ " loss = torch.sqrt(torch.mean((y_pred-y_true)**2, dim=0)).sum()\n",
+ " return loss"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 27,
+ "metadata": {
+ "execution": {
+ "iopub.execute_input": "2021-11-07T11:33:59.635947Z",
+ "iopub.status.busy": "2021-11-07T11:33:59.635133Z",
+ "iopub.status.idle": "2021-11-07T11:41:54.987872Z",
+ "shell.execute_reply": "2021-11-07T11:41:54.987159Z",
+ "shell.execute_reply.started": "2021-11-07T11:24:00.777235Z"
+ },
+ "papermill": {
+ "duration": 475.387685,
+ "end_time": "2021-11-07T11:41:54.988052",
+ "exception": false,
+ "start_time": "2021-11-07T11:33:59.600367",
+ "status": "completed"
+ },
+ "tags": []
+ },
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "Epoch: 1/10\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "100%|██████████| 100/100 [00:51<00:00, 1.95it/s]\n"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "Training Loss: 18.099\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "4it [00:00, 6.19it/s]\n"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "Validation Loss: 16.756\n",
+ "Score: -4.320\n",
+ "Epoch: 2/10\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "100%|██████████| 100/100 [00:45<00:00, 2.17it/s]\n"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "Training Loss: 16.955\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "4it [00:00, 6.41it/s]\n"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "Validation Loss: 17.657\n",
+ "Score: -32.332\n",
+ "Epoch: 3/10\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "100%|██████████| 100/100 [00:45<00:00, 2.18it/s]\n"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "Training Loss: 16.639\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "4it [00:00, 6.29it/s]\n"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "Validation Loss: 19.156\n",
+ "Score: -25.483\n",
+ "Epoch: 4/10\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "100%|██████████| 100/100 [00:45<00:00, 2.17it/s]\n"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "Training Loss: 16.173\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "4it [00:00, 6.29it/s]\n"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "Validation Loss: 18.130\n",
+ "Score: -15.470\n",
+ "Epoch: 5/10\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "100%|██████████| 100/100 [00:45<00:00, 2.17it/s]\n"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "Training Loss: 15.818\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "4it [00:00, 6.28it/s]\n"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "Validation Loss: 17.367\n",
+ "Score: -14.745\n",
+ "Epoch: 6/10\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "100%|██████████| 100/100 [00:46<00:00, 2.17it/s]\n"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "Training Loss: 15.464\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "4it [00:00, 6.28it/s]\n"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "Validation Loss: 18.289\n",
+ "Score: -4.441\n",
+ "Epoch: 7/10\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "100%|██████████| 100/100 [00:46<00:00, 2.17it/s]\n"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "Training Loss: 15.175\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "4it [00:00, 6.26it/s]\n"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "Validation Loss: 18.604\n",
+ "Score: -21.144\n",
+ "Epoch: 8/10\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "100%|██████████| 100/100 [00:46<00:00, 2.17it/s]\n"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "Training Loss: 15.004\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "4it [00:00, 6.27it/s]\n"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "Validation Loss: 18.593\n",
+ "Score: -27.508\n",
+ "Epoch: 9/10\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "100%|██████████| 100/100 [00:46<00:00, 2.17it/s]\n"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "Training Loss: 14.578\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "4it [00:00, 6.28it/s]\n"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "Validation Loss: 18.264\n",
+ "Score: -19.113\n",
+ "Epoch: 10/10\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "100%|██████████| 100/100 [00:46<00:00, 2.17it/s]\n"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "Training Loss: 14.330\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "4it [00:00, 6.27it/s]\n"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "Validation Loss: 17.739\n",
+ "Score: -18.628\n"
+ ]
+ }
+ ],
+ "source": [
+ "model_weights = './task04_model_weights.pth'\n",
+ "device = 'cuda' if torch.cuda.is_available() else 'cpu'\n",
+ "model = Model().to(device)\n",
+ "criterion = RMSELoss\n",
+ "optimizer = optim.Adam(model.parameters(), lr=1e-3)\n",
+ "epochs = 10\n",
+ "train_losses, valid_losses = [], []\n",
+ "scores = []\n",
+ "best_score = float('-inf')\n",
+ "preds = np.zeros((len(y_valid),24))\n",
+ "\n",
+ "for epoch in range(epochs):\n",
+ " print('Epoch: {}/{}'.format(epoch+1, epochs))\n",
+ " \n",
+ " # 模型训练\n",
+ " model.train()\n",
+ " losses = 0\n",
+ " for data, labels in tqdm(trainloader):\n",
+ " data = data.to(device)\n",
+ " labels = labels.to(device)\n",
+ " optimizer.zero_grad()\n",
+ " pred = model(data)\n",
+ " loss = criterion(pred, labels)\n",
+ " losses += loss.cpu().detach().numpy()\n",
+ " loss.backward()\n",
+ " optimizer.step()\n",
+ " train_loss = losses / len(trainloader)\n",
+ " train_losses.append(train_loss)\n",
+ " print('Training Loss: {:.3f}'.format(train_loss))\n",
+ " \n",
+ " # 模型验证\n",
+ " model.eval()\n",
+ " losses = 0\n",
+ " with torch.no_grad():\n",
+ " for i, data in tqdm(enumerate(validloader)):\n",
+ " data, labels = data\n",
+ " data = data.to(device)\n",
+ " labels = labels.to(device)\n",
+ " pred = model(data)\n",
+ " loss = criterion(pred, labels)\n",
+ " losses += loss.cpu().detach().numpy()\n",
+ " preds[i*batch_size:(i+1)*batch_size] = pred.detach().cpu().numpy()\n",
+ " valid_loss = losses / len(validloader)\n",
+ " valid_losses.append(valid_loss)\n",
+ " print('Validation Loss: {:.3f}'.format(valid_loss))\n",
+ " s = score(y_valid, preds)\n",
+ " scores.append(s)\n",
+ " print('Score: {:.3f}'.format(s))\n",
+ " \n",
+ " # 保存最佳模型权重\n",
+ " if s > best_score:\n",
+ " best_score = s\n",
+ " checkpoint = {'best_score': s,\n",
+ " 'state_dict': model.state_dict()}\n",
+ " torch.save(checkpoint, model_weights)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 28,
+ "metadata": {
+ "execution": {
+ "iopub.execute_input": "2021-11-07T11:41:55.621271Z",
+ "iopub.status.busy": "2021-11-07T11:41:55.620191Z",
+ "iopub.status.idle": "2021-11-07T11:41:55.623008Z",
+ "shell.execute_reply": "2021-11-07T11:41:55.623387Z",
+ "shell.execute_reply.started": "2021-11-07T11:31:56.277287Z"
+ },
+ "papermill": {
+ "duration": 0.31815,
+ "end_time": "2021-11-07T11:41:55.623547",
+ "exception": false,
+ "start_time": "2021-11-07T11:41:55.305397",
+ "status": "completed"
+ },
+ "tags": []
+ },
+ "outputs": [],
+ "source": [
+ "# 绘制训练/验证曲线\n",
+ "def training_vis(train_losses, valid_losses):\n",
+ " # 绘制损失函数曲线\n",
+ " fig = plt.figure(figsize=(8,4))\n",
+ " # subplot loss\n",
+ " ax1 = fig.add_subplot(121)\n",
+ " ax1.plot(train_losses, label='train_loss')\n",
+ " ax1.plot(valid_losses,label='val_loss')\n",
+ " ax1.set_xlabel('Epochs')\n",
+ " ax1.set_ylabel('Loss')\n",
+ " ax1.set_title('Loss on Training and Validation Data')\n",
+ " ax1.legend()\n",
+ " plt.tight_layout()"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 29,
+ "metadata": {
+ "execution": {
+ "iopub.execute_input": "2021-11-07T11:41:56.271045Z",
+ "iopub.status.busy": "2021-11-07T11:41:56.270383Z",
+ "iopub.status.idle": "2021-11-07T11:41:56.549178Z",
+ "shell.execute_reply": "2021-11-07T11:41:56.549615Z",
+ "shell.execute_reply.started": "2021-11-07T11:31:56.286373Z"
+ },
+ "papermill": {
+ "duration": 0.611301,
+ "end_time": "2021-11-07T11:41:56.549765",
+ "exception": false,
+ "start_time": "2021-11-07T11:41:55.938464",
+ "status": "completed"
+ },
+ "tags": []
+ },
+ "outputs": [
+ {
+ "data": {
+ "image/png": "iVBORw0KGgoAAAANSUhEUgAAAS8AAAEYCAYAAAANoXDNAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjQuMywgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy/MnkTPAAAACXBIWXMAAAsTAAALEwEAmpwYAAA51ElEQVR4nO3dd3hU1dbA4d9KIyGEHnoJNbQQSkBAelGaYENERVRsoGLlil4/9d7rtVcsSBELIIrYBRWpAQUh9E7oRFoSeifJ/v7YgzdAAikzc2aS9T5PHmbOzJy9zsywZp99dhFjDEop5W8CnA5AKaXyQpOXUsovafJSSvklTV5KKb+kyUsp5Zc0eSml/JImr0JKRH4WkUHufq6TRGS7iHT1wH7nisjdrtu3isiMnDw3D+VUE5FjIhKY11gLk0KVvDz15fYW1xf73F+GiJzMdP/W3OzLGNPDGPOpu5/ri0RkhIjEZ7G9rIicEZFGOd2XMWaSMeYqN8V13vfRGLPTGFPMGJPujv1fUJYRkeOu70qqiMwSkf65eH1HEUlyd1z5UaiSl79zfbGLGWOKATuBazJtm3TueSIS5FyUPmki0EZEalyw/WZgtTFmjQMxOSHW9d2JBj4B3hOR55wNKe80eQEiUkRE3haR3a6/t0WkiOuxsiLyk4gcEpEDIjJfRAJcjz0pIn+JyFER2SgiXbLZfwkR+UxEkkVkh4g8k2kfd4jIAhF5XUQOisg2EemRy/g7ikiSK569wMciUsoVd7Jrvz+JSJVMr8l8KnTJGHL53BoiEu96T2aKyPsiMjGbuHMS439E5HfX/maISNlMjw90vZ+pIvLP7N4fY0wSMBsYeMFDtwOfXS6OC2K+Q0QWZLrfTUQ2iMhhEXkPkEyP1RKR2a74UkRkkoiUdD02AagG/OiqDf1DRKJcNaQg13MqicgPru/dZhG5J9O+nxeRKa7v1VERWSsicdm9Bxe8HynGmAnAEOApESnj2uedIrLetb+tInKfa3s48DNQSf5X068kIi1FZKHr/8YeEXlPREJyEoM7aPKy/gm0ApoAsUBL4BnXY48DSUAkUB54GjAiEg08CLQwxkQAVwPbs9n/u0AJoCbQAfuf5s5Mj18BbATKAq8CH4mIXLiTy6gAlAaqA/diP9uPXferASeB9y7x+tzEcKnnfg4sBsoAz3NxwsgsJzHegn2vygEhwBMAItIAGOXafyVXeVkmHJdPM8fi+vyauOLN7Xt1bh9lgW+w35WywBbgysxPAV5yxVcfqIp9TzDGDOT82vOrWRTxBfa7Vwm4EXhRRDpneryP6zklgR9yEvMFvgeCsN93gP1Ab6A49j1/S0SaGWOOAz2A3Zlq+ruBdOBR17G3BroAQ3MZQ94ZYwrNHza5dM1i+xagZ6b7VwPbXbf/jf2Qa1/wmtrYD7srEHyJMgOBM0CDTNvuA+a6bt8BbM70WFHAABVyeixAR1cZoZd4fhPgYKb7c4G7cxJDTp+L/Y+fBhTN9PhEYGIOP5+sYnwm0/2hwC+u288CX2R6LNz1Hlz0+WaK8wjQxnX/v8D3eXyvFrhu3w4syvQ8wSabu7PZ77XA8uy+j0CU670Mwia6dCAi0+MvAZ+4bj8PzMz0WAPg5CXeW8MF32HX9r3Ardm85jvg4UzfsaTLfH6PAN/m5LN2x5/WvKxKwI5M93e4tgG8BmwGZriq0iMAjDGbsR/W88B+EflCRCpxsbJAcBb7r5zp/t5zN4wxJ1w3i+XyGJKNMafO3RGRoiIy2nVadQSIB0pK9leychNDds+tBBzItA1gV3YB5zDGvZlun8gUU6XM+za2dpCaXVmumL4CbnfVEm8FPstFHFm5MAaT+b6IlHd9L/5y7Xci9vuQE+fey6OZtmX7vcG+N6GSi/ZOEQnGnlEccN3vISKLXKeph4Cel4pXROq6TrH3uo7vxUs93900eVm7sacM51RzbcMYc9QY87gxpia2mv6YuNq2jDGfG2Paul5rgFey2HcKcDaL/f/l5mO4cHqQx7ENs1cYY4oD7V3bc3s6mht7gNIiUjTTtqqXeH5+YtyTed+uMstc5jWfAjcB3YAI4Md8xnFhDML5x/si9nOJce33tgv2eakpXXZj38uITNvc/b3pi60pLxbbxvs18DpQ3hhTEpieKd6sYh0FbADquI7vaTz7/TpPYUxewSISmukvCJgMPCMika52jGexv5KISG8Rqe36Yh7GVuUzRCRaRDq7PvRT2HaSjAsLM/ay9xTgvyISISLVgcfO7d+DIlwxHRKR0oDHryoZY3YACcDzIhIiIq2BazwU41Sgt4i0dTUS/5vLf5/nA4eAMdhTzjP5jGMa0FBErnd9j4ZhT5/PiQCOAYdFpDIw/ILX78O2g17EGLML+AN4yfU9bQwMxg3fGxEpLbZrzfvAK8aYVGx7YhEgGUgTexEmc5eQfUAZESlxwfEdAY6JSD3sBQCvKYzJazr2i3ru73ngBex/ulXAamCZaxtAHWAm9ku4EPjAGDMH+0G/jK1Z7cU2KD+VTZkPAceBrcACbCPxePce1kXeBsJc8S0CfvFweefcim28TcW+h18Cp7N57tvkMUZjzFrgAex7uQc4iG1vutRrDPZUsbrr33zFYYxJAfphvwep2O/K75me8i+gGfZHbxq2cT+zl7A/modE5IksihiAbQfbDXwLPGeMmZmT2LKxUkSOYZtB7gYeNcY86zqWo9jkOwX7Xt6CvQhw7lg3YH/kt7rirYS9eHILcBQYi/2svUZcDW1KeYSIfAlsMMb4bX8i5ZsKY81LeZCItBDbvylARLpj21W+czgsVQBpT2zlbhWwp0dlsKdxQ4wxy50NSRVEetqolPJLetqolPJLfnHaWLZsWRMVFeV0GEopL1u6dGmKMSYyq8f8InlFRUWRkJDgdBhKKS8TkR3ZPaanjUopv6TJSynllzR5KaX8kiYvpZRf0uSllPJLmryUUn5Jk5dSyi9p8vJlB7bC/g1OR6GUT9Lk5auMgcm3wMc94MQBp6NRyudo8vJVO36H5PVw8gDM+pfT0SjlczR5+aol4yC0JMQNhqWfQpIOj1IqM01evujoXlj/IzS9Dbr9CyIqwE+PQobbV4FXym95LHmJyHgR2S8iazJti3WtsLtaRH4UkeKeKt+vLf0UMtIg7i4oEgFXvwh7V8GSj5yOTCmf4cma1ydA9wu2jQNGGGNisAsKXLiaiko/C0s/hlpdoEwtu63hdVCzE8x+AY7uczY+pXyEx5KXMSYe12KWmdTFLugJ8Btwg6fK91sbp8PRPdDi7v9tE4Ger0PaSfjt/5yLTSkf4u02r7XYBRnALhmV7YKkInKviCSISEJycrJXgvMJS8ZBiapQ9+rzt5etDVc+DKu+hO0LnIlNKR/i7eR1FzBURJZiF6w8k90TjTFjjDFxxpi4yMgsJ1IseJI3wrZ4iLsTArJYab7tY1CyGkx73J5eKlWIeTV5GWM2GGOuMsY0xy5gucWb5fu8JR9BYAg0vT3rx0OKQo9XIXkDLPrAu7Ep5WO8mrxEpJzr3wDgGeBDb5bv004fg5WTocG1UOwSNc3oHhDdE+a+DIcvuUC0UgWaJ7tKTAYWAtEikiQig4EBIrIJ2IBdwvxjT5Xvd1ZPgdNHzm+oz073l+3woV+e8nxcSvkojy3AYYwZkM1D73iqTL9ljD1lrBADVVte/vmlqkP7J2D2fyBxJtTp6vkYlfIx2sPeF+z6E/atsbUukZy9ps1DUKY2TH8Czp7ybHxK+SBNXr5g8VgoUgJi+uX8NUFFbN+vg9vg97c9FppSvkqTl9OO7Yd130OTWyAkPHevrdUJGl4P89+0c38pVYho8nLass8g4yy0GJy311/9XwgMhp+ftG1nShUSmryclJ4GCR9DjQ5Qtk7e9lG8EnR6GhJnwIaf3BufUj5Mk5eTEn+FI0nQ8p787aflfVCuIfw8As4cd09sSvk4TV5OWjwWileGuj3yt5/AIOj1hk2E8151T2xK+ThNXk5J2Qxb50DzO23yya/qraHJrbDwPV20QxUKmryckjAeAoKgWTbjGPOi278hpJjt+6WN96qA81gPe3UJZ07AiolQvw9ElHfffsPLQpdnYdpjsHoqNM5FvzGVc+lnYeUXsHe1vdIbEAgBwfbHKDAo0+0LH8vqftDFtzPfDwyG4lUgQOsZF9Lk5YQ1U+HU4fw31Gel+R2wfCL8+jTUvQpCS7i/jMLqXNKKfxUO7YSQCDAZdsrujLP2tidE1oPO/wf1euV8BEYhoMnL24yxDfXlGkC11u7ff0Cgbbwf2xnmvAg9XnF/GYVNepodOD/vFTi4HSo1hZ5vQJ1u5yeTjHOJzJXMMtJtwrvovmtbeubnZnP/1GFYPBq+vBUqx0HX56BGe8feCl+iycvbkhLsYhq93vDcr2jlZnbxjsVjbM/9irGeKaegy0i3p9/zXoEDW6BCYxjwBdTtnvVnFxAAASFAiHvjiLsLVn5up0H69Bqo1dk2D1Rq6t5y/IyeSHvbknH2dKNxf8+W0+X/IKy0nXU1w0OnMwXVuaT1/hXw7b0QXBT6T4L74u18at4+dQt0Xdh5aBlc9V/YvQLGdIQpt0NKondj8SGavLzpeCqs/QZib7ZLmnlSWCm46j+QtASWT/BsWQVFRgas+QY+aA1fD7aN5Td9ZpNW/d7OtzcFh0KbB+HhldDhSdg8yybY7x8slBNTavLypuWfQfqZvI9jzK3YAVCtDcx8ziZOlbWMDDs4/sMrYeqdNknd+DHc/zs06Ot7V/pCi9shYcNWQMt77aIsI5vBr/8sVJ+zj30qBVhGuu3bFdUOytX3Tpki0Ot1OHUEZj3vnTL9iTGw/icY3c6egqWfhRs+giF/QKPrfS9pXahYJPR4GR5aaqdTWvQBvBMLc1+B00edjs7jfPzTKUASf7OX171V6zqnfENoNcTOXrFriXfL9lXGwMafYXR7exXv7Em4bgw88CfE3Jj1yk2+rGQ1uPZ9GLIQanWEuS/CO01g0ShIO+10dB6jyctbloyDYhWgXm/vl91xBERUtJ1X09O8X76vMAY2zYCxnWDyzXbNgGtHwQOLIba//yWtC5WrB/0nwt2z7Y/WLyPg3eawfJKt+Rcwmry84cBW2DzTdiANDPZ++UUioPtLtotGwkfeL99pxtj3f1xX+LwfnEiFPu/Bgwm2K4k7xpb6kirNYdAPMPA7O+ri+6H2IsT6HwvUsDFNXt6QMB4kAJoPci6GBtdCzU4w+wU4us+5OLzJGNgyBz66CibeAMf2wTXvwINLodlAZ35IvKlWJ7hnDtw0ATDw5W0wrgtsned0ZG6hycvTzp60w3Xq97YTBzpFxM55n3YKZjzjXBzesi0ePu4BE66FI39BrzdtP6nmd0CQmzuR+jIRaNDHtof1fd/+cH3WBz7rC38tczq6fNHk5Wlrv4WTB3O2HqOnla0NVz5sh7psm+90NJ6RugU+6W17oh/cbhP2sOX2QklhSloXCgyCprfZK5NXv2QHlY/tBF8OhORNTkeXJ2L84Bw4Li7OJCQkOB1G3ozpZGc3feBP5zs5gp3R4oMrICgM7l9QsP5Dp5+1YzoP7YCOT9taVnCo01H5ptNHYeH78Md7cPa4bfvr9h8oWtrpyM4jIkuNMXFZPaY1L0/6aynsXpa79Rg9LaQo9HgNUjbafkEFyfw37EWJvu9Dq/s1cV1KkQh7FfrhldBqKKz80l7QSNnsdGQ5psnLk5aMh+Bwexnel0R3h+iedsDxoV1OR+Meu1dA/GsQcxPUv8bpaPxHeBm7AtWgH+HUIdugvy3e6ahyRJOXp5w4YOftanyTb86p1f1lezXu16ecjiT/0k7Dd0OgaFmdAiivqreGu2dBRAWYcB0s/dTpiC5Lk5enrJhkr+z5QkN9VkpVh/ZP2L4/ib85HU3+zH0Z9q+DPiN9rs3Gr5SuAYNn2KX4fhxmx0r6cOdWTV6ekJEBSz6ykw1WaOR0NNlr8xCUqWPnvD970ulo8iYpAX5/G5rcBnWvdjoa/xdaAm6ZYpfTW/gefHGLz46T1OTlCVtmw8FtvlvrOieoiB24fXC77bzqb86ehG/vh4hK0P1Fp6MpOAKDoOertptJ4m8wvrtPto0WqOSVuO8oA8Ys4q9DDtciloyD8Ei7wIavq9kR4gbbX9lVXzkdTe7MfgFSE6Hvu77ZrujvWt4Dt35lJxQY29nWcn1IgUpeRYsEsXTnQd6c4WCnu4M7YNMv0GyQ//Sh6v6ynffrhwf9p9f1jj9sP6W4u+y0yMozaneBwb/ZLjYf97QzzPoIjyUvERkvIvtFZE2mbU1EZJGIrBCRBBFp6c4yK5cM4842UXyzPIl1u4+4c9c5t/Rj26cr7k5nys+LoBA7Y2h4JHxxq++PfTxzHL4baqeC6fYfp6Mp+MrVszNVVG5uZ5id+7JPDPD2ZM3rE6D7BdteBf5ljGkCPOu671ZDO9ameGgwr/ziwKrRaaftvFnRPaFEFe+Xnx/FIuHmz21fny9v8+15oGY+b9sUr/0AihRzOprCIbwM3P4dxN4Cc1+Cr+92/CKPx5KXMSYeOHDhZqC463YJYLe7yy1RNJgHO9Vm3qZkft+c4u7dX9ra7+x0K96ecNBdKja2CSFpsZ37ywd+XS+ydZ5dFemKIRDV1uloCpegIvb70fV524fxk96O1tK93eb1CPCaiOwCXgey7SEpIve6Ti0TkpOTc1XIwNbVqVwyjJd+Xk9Ghhf/Ay4ZB2VqQ42O3ivT3RpeB+2H25kw/hztdDTnO3XELjZRupZd+kt5nwi0fdROerh/ne2Rv3fN5V/nAd5OXkOAR40xVYFHgWxnxjPGjDHGxBlj4iIjI3NVSGhwIE9cXZc1fx3hx1Vur9xlbc9KW2OJG+z7c59fTsenIbqXXXV761yno/mfGc/AkSQ7+2lIUaejKdzqXwN3/mwXxx1/NWz8xeshePt/2SDgG9ftrwC3Nthn1je2Mg0qFue1XzdyOs0LvYSXjLMzNTQZ4PmyPC0gAK4fDWXrwpRBdiZYpyXOhGWfQusHodoVTkejACo1gXtm27ONyTfbGSq82NTg7eS1G+jgut0Z8NiKmQEBwlM965F08CQTF+30VDHWyUO2j1Tjfna9xIKgSAQM+NzenuxwL+uTh+CHhyCyHnT6p3NxqIsVr2RrYPWvgRn/hB8ftlMTeYEnu0pMBhYC0SKSJCKDgXuAN0RkJfAicK+nygdoVyeSdnXK8u7sRA6f9OAbunIypJ30/R71uVW6JvT7BFI2wTf3Obfy9i8j7BTO147SaW58UUhR6PcptHvc1o4nXm8nJvAwT15tHGCMqWiMCTbGVDHGfGSMWWCMaW6MiTXGXGGMWeqp8s95sns9Dp04y4fztnimgIwMe8pYpSVUjPVMGU6q1clOmbJxmr1E7m0bptsfh3aPQeVm3i9f5UxAgL2Ict1o2LnIK3OD+XnL8uU1qlyC65pWZvyCbew57IF+KdvmQermglfryuyK++3A5/hX7bTW3nLigD0NKR8D7f/hvXJV3sXeDLf/4JW5wQp88gJ4rFtdjMEzw4aWjIOiZeyy8AWVCPR+09YuvxsKe1Z5p9zpT9j5/68b5T9DrZTX5gYrFMmraumiDGpTna+XJbFxrxsbng8nwcbp0Oz2gt8WE1TE9u0JLWmHEB33cAfgtd/Bmq+hw5NQIcazZSn388LcYIUieQE80Kk2xYoEuXfY0NJP7KXh5n40jjE/IsrDzZPg+H6YcjuknfFMOceSbQ//ik1sh0jlnzw8N1ihSV4li4YwtFNtZm/Yz8ItqfnfYdoZWx2ue7WdlbSwqNzMrja943f45Un3798Y+OkR+yW/7sOCt5p1YZPV3GBuGlJUaJIXwB1toqhUItQ9w4bW/2BrIC3ucU9w/qRxP7v+Y8J4O2OsO63+Cjb8ZPtzlavv3n0r55ybG6x4JQgr6ZZdFqrkFRocyGNXRbMq6TDTVu/J386WfASlogrvXFJdnoM6V8HP/4Dtv7tnn0f22Eb6Ki3tFNWqYKndxSawoCJu2V2hSl4A1zWtTL0KEbz260bOpOWx0+W+tbDzj4IxjjGvAgLhhnFQqgZMGWhn28wPY2zDbtoZ2xk1INA9caoCq9D9zwsMEEb0qMfOAyf4/M8dud9BRoYdsBwcbpdPL8xCS8CAyZCeZocQnTme930tnwiJM6Drc1C2tvtiVAVWoUteAB3qRtKmVhlGzt7M0VO5HDa0ZKydaeHqF3SZLYCydeDGj2DfGtsHLC8Dcw/tgl+egupt7ZUppXKgUCYvEeGpHvU5cPwMo+flYsaE5E3w27O2raewdI/IiTrdoNu/YN13EP967l5rjJ0732RA3/cK72m4yrVC+02JqVKCPrGVGLdgK3sPn7r8C9LPwrf3QnBR6POu7XWu/qfNMGjcH+a8ABum5fx1CeNtTfaq/9iOjUrlUKFNXgDDr44mPcPw9swcDBuKfx12L4feb9lhD+p8InDNO1CpKXxzL+xbd/nXHNgGM/4PanayqwAplQuFOnlVLV2Uga2imJKwi8R9l+j5m7QU4l+DxjdDw2u9Fp/fCQ6zi3iEhMMXAy49LUpGhp3SOSDQni5qTVblUqFOXgAPdq5NeMglhg2dOWFPFyMq2p7C6tKKV4L+k+DIbvjqDnslMiuLR8OOBXD1i/630pLyCYU+eZUOD+H+jrWYuX4/f27NYtjQb8/aKW+u/UBXZc6pqi2g99t2uqAZz1z8eMpmmPkve+GjsHc3UXlW6JMXwF1X1qBC8VBe+nkDJvOl/s2zbNeIVkOhZofsd6Au1vRW+779Ocr24TonIx2+G2J7WV8zUk8XVZ5p8gLCQgJ5rFtdVuw6xM9r9tqNJw7A9w9A2WhdZiuvuv3HNsb/9CjsWmy3/fGuXWWp52tQvKKz8Sm/psnL5YbmVahbvhiv/rKBs+kZdozd8WS4foxtiFa5FxgEN46H4pXtHGCbZ8Kc/0K93hDTz+nolJ/T5OVybtjQ9tQTLPx+tGsivBF2eSeVd0VL2yFEZ0/AxBvsqkS939bTRZVvmrwy6RRdju7VMmiy6j+kV4rTifDcpVx9O4g7JML2BSuWu0WElcqKzvSWiRjDq0GjCTJpfFp+BHfpRHjuE90Dntyukwsqt9GaV2ZLxlF893y+Lz+U1xLS2X8kB8OGVM5p4lJupMnrnJRE26erdjda93uCs+kZvD3LYwt6K6XySZMX2EHX39xrVwDq+x5RkcW4rVV1vlyyi837jzkdnVIqC5q8AOa/AbuXnTfo+qHOtQkLDuRVd642pJRyG01efy2Fea9CzE3Q8Lq/N5cpVoT7O9Rkxrp9JGy/xABjpZQjCnfyOnMCvrnP1rZ6vnbRw3e1rUG5iCK8OH39+cOGlFKOK9zJa+bzkJpoB11nsRxT0ZAgHu1Wl2U7D/HrWvesNaeUco/Cm7y2zLbTslwxBGp2zPZp/ZpXoXa5TMOGlFI+oXAmr5MH4TvXoOuuz13yqUGBATzZvR5bU47z5ZJdXgpQKXU5hTN5TXvCrnZ9/egcDbruWr8cLaJK8fbMRI6fzmZyPaWUV3kseYnIeBHZLyJrMm37UkRWuP62i8gKT5WfrTVfw5qp0OFJO996DogIT/WsT8qx04ydn4vVhpRSHuPJmtcnQPfMG4wx/Y0xTYwxTYCvgW88WP7FjuyGnx6DynHQ9rFcvbRZtVL0aFSBMfFbST562kMBKqVyymPJyxgTD2TZQUpEBLgJmOyp8rMIyC74kHYarhudp3F2w6+O5kxaBu/MysFqQ0opj3KqzasdsM8Yk+3gQRG5V0QSRCQhOTk5/yUuGQdbZtn1AfO4nHzNyGIMaFmNyYt3sTVZhw0p5SSnktcALlPrMsaMMcbEGWPiIiPzOf9TSqJdH7BWF2hxd752NaxLHUKDAnhx+gbStOuEUo7xevISkSDgeuBLrxSYngbf3mcXfOj7fr5n8IyMKMLQTrWZuX4fXd+cx7fLk0jP0N73SnlbjpKXiISLSIDrdl0R6SMiwXkssyuwwRiTlMfX5878N+z4xd5vuW3Bh6EdazH29jjCQoJ49MuVXPXWPH5YuVuTmFJelNOaVzwQKiKVgRnAQOzVxGyJyGRgIRAtIkkiMtj10M14q6H+r2Uw7xW72EOj6922WxGhW4PyTHuoLR/e1oyggACGTV5Oj3fimb56DxmaxJTyOMnJgGMRWWaMaSYiDwFhxphXRWSFq8uDx8XFxZmEhITcvejsSRjdHk4fg6F/QFgpzwQHZGQYpq/Zw9szE9m8/xj1KkTwaLe6XNWgPKILTSiVZyKy1BgTl9VjOa15iYi0Bm4Fprm2BbojOI+Z+TykbHINuvZc4gIICBB6N67Er4+0552bm3AmLYP7Jiyl97sLmLlun85IoZQH5DR5PQI8BXxrjFkrIjWBOR6LKr+2zIY/P4SW90GtTl4rNjBA6NukMjMebc8b/WI5eiqNuz9LoO/7vzNn435NYkq5UY5OG897gW24L2aMOeKZkC6Wq9PGkwfhgzYQEg73xUNIUc8Gdwln0zP4dtlfjJydSNLBkzStVpLHutWlbe2yejqpVA7k+7RRRD4XkeIiEg6sAdaJyHB3Buk204fDsX120LWDiQsgODCAm1pUZfbjHXnxuhj2HT7FwI8Wc9PohfyxJcXR2JTydzk9bWzgqmldC/wM1MBecfQtW2bD6q/soOvKzZ2O5m8hQQHcckU15gzvyH/6NmTngRPcMvZPbh6zkMXbdIpppfIip8kr2NWv61rgB2PMWcD3GnBqdIC+H0C7x52OJEtFggIZ2DqKecM78dw1DdiSfJybRi/ktnF/snSHJjGlciOnyWs0sB0IB+JFpDrgtTavHAsIhKa3+vzipqHBgdx5ZQ3ih3fimV71Wb/nCDeMWsig8YtZseuQ0+Ep5Rdy3WD/9wtFgowxXpmZL0/9vPzIiTNpfLZwB6PnbeHgibN0qVeOR7vVpVHlEk6HppSjLtVgn9NOqiWA54D2rk3zgH8bYw67LcpLKOjJ65xjp9P49I/tjInfyuGTZ+nWoDxPXBVNdIUIp0NTyhHu6KQ6HjiKnYPrJuwp48fuCU+dU6xIEA90qs38JzvxWLe6LNqayjXvLmDCwu3aR0ypC+S05nXRUCCfHx5UABw4fobHp6xgzsZkesZU4OUbGlM8NK/j4ZXyP+6oeZ0UkbaZdnglcNIdwanslQ4P4aNBLRjRox6/rt1H75ELWJV0yOmwlPIJOU1e9wPvuxbN2A68B9znsajU3wIChPs71GLKfa1IS8/ghlF/8Mnv2/Q0UhV6OUpexpiVxphYoDHQ2BjTFOjs0cjUeZpXL820Ye3oUDeS539cx5CJyzh88qzTYSnlmFzNpGqMOZJpTGPult9R+VYqPISxt8fxTK/6zFy/j14j52u/MFVo5WcaaB1Z7AAR4e52NZlyf2uMgX4f/sG4+Vv1NFIVOvlJXvq/xUHNqpVi+rB2dIwuxwvT1nPPZ0s5dOKM02Ep5TWXTF4iclREjmTxdxSo5KUYVTZKFA1mzMDmPNu7AfM27afXyAUs23nQ6bCU8opLJi9jTIQxpngWfxHGGN8eQFhIiAh3ta3B1PvbEBAAN324kDHxW3QefVXgObVuo3Kz2Kol+emhdnStX54Xp2/g7s8SOHhcTyNVwaXJqwApERbMqNua8a8+DVmQmELPkfN1qh1VYGnyKmBEhEFtovh6SBtCggK4afQiRs3V00hV8GjyKqBiqpTgx4fa0r1RBV75ZQN3fbqE1GOnnQ5LKbfR5FWAFQ8N5r0BTXnh2kb8sSWVXiMX6LTTqsDQ5FXAiQi3tarOt0PbEBYSyICxi3h/zmY9jVR+T5NXIdGwkj2N7BVTkdd+3cigjxeToqeRyo9p8ipEihUJ4p2bm/DS9TEs3naAnu/MZ+GWVKfDUipPNHkVMiLCgJbV+O6BKykWGsSt4xYxclYi6XoaqfyMJq9Cqn7F4vz4YFv6NqnMm79t4voPftc+YcqvaPIqxMKLBPHmTbG8c3MT9h05zQ2jFvLQ5OUkHTzhdGhKXZYmr0JOROjbpDKzn+jAsC51mLF2L13emMcbMzZy/LRXVrZTKk80eSkAioYE8Vi3usx+oiPdG1Xg3dmb6fT6XKYuTdJuFconafJS56lcMox3bm7K10PaULFkGE98tZJrP/idJdu1PUz5Fo8lLxEZLyL7RWTNBdsfEpENIrJWRF71VPkqf5pXL8W3Q9rwdv8m7D9ymn4fLuSBz5dpe5jyGZ6seX0CdM+8QUQ6AX2BWGNMQ+B1D5av8ikgQLi2qW0Pe7hLHWat30fnN+bx+q/aHqac57HkZYyJBy481xgCvGyMOe16zn5Pla/cp2hIEI92q8vsxzvSs1EF3puzmY6vz+WrhF3aHqYc4+02r7pAOxH5U0TmiUiL7J4oIveKSIKIJCQnJ3sxRJWdSiXDePvmpnwztA2VS4YxfOoq+ryvg72VM7ydvIKA0kArYDgwRUSyXIXIGDPGGBNnjImLjIz0ZozqMppVK8W3Q9vwzs1NSD12hptGL+SBScvYdUDbw5T3eDt5JQHfGGsxkAGU9XIMyg3+7h/2eEce7VqX2Rv20+XNebz6ywaOaXuY8gJvJ6/vgE4AIlIXCAFSvByDcqOwkEAe7lqH2U90oFdMRT6Yu4VOr89liraHKQ/zZFeJycBCIFpEkkRkMDAeqOnqPvEFMMjoaqkFQsUSYbzVvwnfDm1DlVJh/GPqKq55bwF/btVZK5RniD/kjri4OJOQkOB0GCqHjDH8sHI3r/y8gd2HT9EzpgJP9ahP1dJFnQ5N+RkRWWqMicvqMV17UbndufawqxpUYOz8rYyau4WZ6/YzuF0NhnasRURosNMhqgJAhwcpjwkLCWRYlzrMeaIjvWMrMsrVHjZ58U6dP0zlmyYv5XEVSoTy5k1N+P6BK4kqE85T36ym18j5/L5Zr9WovNPkpbwmtmpJvrq/Ne/d0pSjp9K4ddyf3P3pErYmH3M6NOWHNHkprxIRejeuxKzHO/CP7tEs2nqAq96K598/ruPwibNOh6f8iCYv5YjQ4ECGdqzNnCc60i+uCp/8sY0Or8/hk9+3cTY9w+nwlB/Q5KUcFRlRhJeub8y0Ye1oWKk4z/+4ju5vxzN7wz78oRuPco4mL+UT6lcszsTBVzD29jgyDNz1SQK3j1/Mxr1HnQ5N+ShNXspniAjdGpTn10fa83+9G7By1yF6vBPPP79dTaoukKsuoMlL+ZyQoAAGt63BvOGduL11FF8s2UXH1+Yyet4WTqelOx2e8hGavJTPKhUewvN9GvLrI+1pUaM0L/28gW5vxvPz6j3aHqY0eSnfV7tcMcbf0YLP7mpJaHAAQyYto/+YRaxOOux0aMpBmryU32hfN5Lpw9rxwrWN2LL/GH3eX8DjU1ay78gpp0NTDtDkpfxKUGAAt7WqzpzhHbm3XU1+XLmbjq/NZeSsRE6e0fawwkSTl/JLxUODeapnfX57rD0doyN587dNdH5jLt8t/0snQSwkNHkpv1a9TDijbmvOl/e2okyxEB75cgXXffA7v67dq0msgNPJCFWBkZFh+HpZEu/MSiTp4ElqRYZzX/ta9G1aiSJBgU6Hp/LgUpMRavJSBU5aegbTVu/hw3lbWb/nCOWLF2Fw2xoMaFlNJ0L0M5q8VKFkjCE+MYUP525h4dZUIkKDGNiqOndcGUW5iFCnw1M5oMlLFXordx1idPwWfl6zl+DAAG5oVoV729ekRtlwp0NTl6DJSymXbSnHGRO/la+XJXE2PYMejSpwf4daNK5S0unQVBY0eSl1gf1HT/HJ79uZsGgHR0+l0bpmGe7vWIv2dcqSzSLuygGavJTKxtFTZ5m8eCcfLdjGviOnaVCxOPd1qEmvmIoEBWpPIqdp8lLqMk6npfP98t2Mjt/CluTjVC0dxj3tatKveVXCQrSbhVM0eSmVQxkZhpnr9/HhvC0s23mI0uEhDGodxe2tq1MqPMTp8AodTV5K5ZIxhiXbD/LhvC3M3rCfsOBAbm5Zlbvb1aRyyTCnwys0dMVspXJJRGhZozQta5Rm496jjI7fwoSFO5iwcAd9YitxX4daRFeIcDrMQk1rXkrl0F+HTvLR/G18sWQnJ86kc0WN0nSIjqR9nUgaVCxOQIBepXQ3PW1Uyo0OnTjDhIU7mL5mL+v3HAGgTHgIbeuUpX2dSNrVKUu54tqD3x00eSnlIfuPnGLB5hTiNyUzPzGF1ONnAKhXIYL2dW2tLC6qFKHBesUyLzR5KeUFGRmGdXuOMD/RJrOEHQc4m24IDQ7gihplaFenLB3qRlK7XDHtCJtDmryUcsCJM2ks2ppK/KYU4hOT2Zp8HICKJUJpV6cs7epE0rZ2We2CcQmOJC8RGQ/0BvYbYxq5tj0P3AMku572tDFm+uX2pclLFQRJB0+wINEmsgWJKRw5lYYINK5cgnZ1ImlfN5Km1UoSrD37/+ZU8moPHAM+uyB5HTPGvJ6bfWWVvM6ePUtSUhKnTuniC/kVGhpKlSpVCA7Wua68JT3DsDLpEPNdtbIVuw6RnmEoViSI1rXK0L5OWdrXjaR6mcI964Uj/byMMfEiEuWp/SclJREREUFUVJS2H+SDMYbU1FSSkpKoUaOG0+EUGoEBQrNqpWhWrRQPd63D4ZNnWbglhXhXe9lv6/YBUK10UQa3rcHAVtW1K8YFnOik+qCI3A4kAI8bYw5m9SQRuRe4F6BatWoXPX7q1ClNXG4gIpQpU4bk5OTLP1l5TImwYLo3qkj3RhUxxrA99QTxm5KZtnoPz/2wll/X7uW1frHauz8Tb59cjwJqAU2APcAb2T3RGDPGGBNnjImLjIzM8jmauNxD30ffIiLUKBvOoDZRfHlvK166PoaVuw7R/a14vkrYpauFu3g1eRlj9hlj0o0xGcBYoKU3y1fK34gIA1pW45dH2lO/UnGGT13FPZ8tZf9Rbev1avISkYqZ7l4HrPFm+Ur5q6qli/LFPa14pld94hOTufqteKav3uN0WI7yWPISkcnAQiBaRJJEZDDwqoisFpFVQCfgUU+V72mHDh3igw8+yPXrevbsyaFDh3L9ujvuuIOpU6fm+nWq4AgIEO5uV5Ppw9pStXRRhk5axsNfLOfwibNOh+YIT15tHJDF5o88Uda/flzLut1H3LrPBpWK89w1DbN9/FzyGjp06Hnb09LSCArK/m2dPv2y3dqUuqTa5SL4ekgbRs3dwshZiSzamsorNzSmY3Q5p0PzKu0Nl0cjRoxgy5YtNGnShBYtWtCuXTv69OlDgwYNALj22mtp3rw5DRs2ZMyYMX+/LioqipSUFLZv3079+vW55557aNiwIVdddRUnT57MUdmzZs2iadOmxMTEcNddd3H69Om/Y2rQoAGNGzfmiSeeAOCrr76iUaNGxMbG0r59eze/C8opwYEBDOtSh+8euJLiocHc8fESnv52NcdPpzkdmvcYY3z+r3nz5uZC69atu2ibN23bts00bNjQGGPMnDlzTNGiRc3WrVv/fjw1NdUYY8yJEydMw4YNTUpKijHGmOrVq5vk5GSzbds2ExgYaJYvX26MMaZfv35mwoQJ2ZY3aNAg89VXX5mTJ0+aKlWqmI0bNxpjjBk4cKB56623TEpKiqlbt67JyMgwxhhz8OBBY4wxjRo1MklJSedty4rT76fKu5Nn0syL09aZqBE/mbavzDJ/bk11OiS3ARJMNnlBa15u0rJly/M6eY4cOZLY2FhatWrFrl27SExMvOg1NWrUoEmTJgA0b96c7du3X7acjRs3UqNGDerWrQvAoEGDiI+Pp0SJEoSGhjJ48GC++eYbihYtCsCVV17JHXfcwdixY0lPT8//gSqfExocyFM96zPlvtYIQv8xC/nvtHWcOluwP29NXm4SHv6/YRxz585l5syZLFy4kJUrV9K0adMshzEVKVLk79uBgYGkpeW9yh8UFMTixYu58cYb+emnn+jevTsAH374IS+88AK7du2iefPmpKam5rkM5dtaRJXm54fbcUvLaoydv43e7y5gVdIhp8PyGE1eeRQREcHRo0ezfOzw4cOUKlWKokWLsmHDBhYtWuS2cqOjo9m+fTubN28GYMKECXTo0IFjx45x+PBhevbsyVtvvcXKlSsB2LJlC1dccQX//ve/iYyMZNeuXW6LRfme8CJB/Pe6GD69qyXHTqVx3Qd/8NZvmzibnuF0aG6nc9jnUZkyZbjyyitp1KgRYWFhlC9f/u/Hunfvzocffkj9+vWJjo6mVatWbis3NDSUjz/+mH79+pGWlkaLFi24//77OXDgAH379uXUqVMYY3jzzTcBGD58OImJiRhj6NKlC7GxsW6LRfmuDnUj+fWR9jz/41remZXI7A37efOmWOqULzjz7vvtfF7r16+nfv36DkVU8Oj7WXD9smYPT3+7hmOn0xh+VTR3ta1BoJ8M8r7UrBJ62qhUAde9UUV+faQ9HepG8t/p6xkwZhE7U084HVa+afLyMQ888ABNmjQ57+/jjz92Oizl5yIjijBmYHNe7xfL+j1H6P5OPJP+3OHXg7y1zcvHvP/++06HoAooEeHG5lVoXasM/5i6kn9+u4YZa/fxyg2NqVDC/1Y70pqXUoVM5ZJhTLjrCv7dtyF/bkvl6rfj+X7FX35XC9PkpVQhFBAg3N46ip8fbk+tyHAe/mIFXd+cx7j5WzngWr7N12nyUqoQq1E2nK/ub8Pr/WIpERbMC9PW0+rFWQybvJyFW1J9ujambV5KFXKBAbYt7MbmVdiw9whfLN7FN8uS+GHlbmqUDWdAy6rc0KwKZYoVufzOvEhrXl5SrFixbB/bvn07jRo18mI0SmWtXoXiPN+nIYv/2ZU3b4qlbLEQXpy+gVYvzeKBz5fx++YUMjJ8ozZWMGpeP4+Avavdu88KMdDjZffuUyk/ERocyPXNqnB9syok7jvK5MW7+GZ5EtNW7aF6maLc3KIaNzavQmSEc7UxrXnl0YgRI87r1vD888/zwgsv0KVLF5o1a0ZMTAzff/99rvd76tQp7rzzTmJiYmjatClz5swBYO3atbRs2ZImTZrQuHFjEhMTOX78OL169SI2NpZGjRrx5Zdfuu34lDqnTvkInr2mAYue6sI7NzehQvFQXvllA61fmsWQiUuJ35TsTG0su7lyfOnPF+fzWrZsmWnfvv3f9+vXr2927txpDh8+bIwxJjk52dSqVevv+bXCw8Oz3VfmucFef/11c+eddxpjjFm/fr2pWrWqOXnypHnwwQfNxIkTjTHGnD592pw4ccJMnTrV3H333X/v59ChQ3k+HqffT+VfEvcdNS/8tNY0+devpvqTdh6x92Ynmn2HT7q1HHQ+L/dr2rQp+/fvZ/fu3axcuZJSpUpRoUIFnn76aRo3bkzXrl3566+/2LdvX672u2DBAm677TYA6tWrR/Xq1dm0aROtW7fmxRdf5JVXXmHHjh2EhYURExPDb7/9xpNPPsn8+fMpUaKEJw5VqYvULleMf/ZqwKKnuzByQFOqlirKa79upPXLs7n3swTmbNxPuodrYwWjzcsh/fr1Y+rUqezdu5f+/fszadIkkpOTWbp0KcHBwURFRWU5j1de3HLLLVxxxRVMmzaNnj17Mnr0aDp37syyZcuYPn06zzzzDF26dOHZZ591S3lK5USRoED6xFaiT2wltqUc54slO5makMSMdfuoXDKM/i2qclNcVY/04NfklQ/9+/fnnnvuISUlhXnz5jFlyhTKlStHcHAwc+bMYceOHbneZ7t27Zg0aRKdO3dm06ZN7Ny5k+joaLZu3UrNmjUZNmwYO3fuZNWqVdSrV4/SpUtz2223UbJkScaNG+eBo1QqZ2qUDeepHvV5vFs0v63bxxdLdvLmb5t4e+YmOtcrx4CW1egYXc5tM1po8sqHhg0bcvToUSpXrkzFihW59dZbueaaa4iJiSEuLo569erlep9Dhw5lyJAhxMTEEBQUxCeffEKRIkWYMmUKEyZMIDg4+O/T0yVLljB8+HACAgIIDg5m1KhRHjhKpXInJCiAXo0r0qtxRXakHufLJbuYkpDEzPUJVCwRyncPXEn54vmviel8XgrQ91N51tn0DGat30d8Ygr/vbYRIjmrfV1qPi+teSmlPC44MIDujSrSvVFFt+1Tk5cXrV69moEDB563rUiRIvz5558ORaSU//Lr5GWMyXH10xfExMSwYsUKp8O4iD80HSh1Ib/t5xUaGkpqqm+PevcHxhhSU1MJDfW/yehU4ea3Na8qVaqQlJREcnKy06H4vdDQUKpUqeJ0GErlit8mr+Dg4PNWqFZKFS5+e9qolCrcNHkppfySJi+llF/yix72IpIM5GagYFkgxUPheFtBOhYoWMdTkI4FfPN4qhtjIrN6wC+SV26JSEJ2Qwr8TUE6FihYx1OQjgX873j0tFEp5Zc0eSml/FJBTV5jnA7AjQrSsUDBOp6CdCzgZ8dTINu8lFIFX0GteSmlCjhNXkopv1SgkpeIdBeRjSKyWURGOB1PfohIVRGZIyLrRGStiDzsdEz5JSKBIrJcRH5yOpb8EpGSIjJVRDaIyHoRae10THklIo+6vmNrRGSyiPjFFCMFJnmJSCDwPtADaAAMEJEGzkaVL2nA48aYBkAr4AE/Px6Ah4H1TgfhJu8Avxhj6gGx+OlxiUhlYBgQZ4xpBAQCNzsbVc4UmOQFtAQ2G2O2GmPOAF8AfR2OKc+MMXuMMctct49i/3NUdjaqvBORKkAvwO+XOBKREkB74CMAY8wZY8whR4PKnyAgTESCgKLAbofjyZGClLwqA7sy3U/Cj/+zZyYiUUBTwJ/ni34b+AeQ4XAc7lADSAY+dp0GjxORcKeDygtjzF/A68BOYA9w2Bgzw9mocqYgJa8CSUSKAV8DjxhjjjgdT16ISG9gvzFmqdOxuEkQ0AwYZYxpChwH/LKNVURKYc9QagCVgHARuc3ZqHKmICWvv4Cqme5XcW3zWyISjE1ck4wx3zgdTz5cCfQRke3Y0/nOIjLR2ZDyJQlIMsacqwlPxSYzf9QV2GaMSTbGnAW+Ado4HFOOFKTktQSoIyI1RCQE2+j4g8Mx5ZnYlUU+AtYbY950Op78MMY8ZYypYoyJwn4us40xfvHrnhVjzF5gl4hEuzZ1AdY5GFJ+7ARaiUhR13euC35y8cFvp4G+kDEmTUQeBH7FXjEZb4xZ63BY+XElMBBYLSIrXNueNsZMdy4klclDwCTXD+VW4E6H48kTY8yfIjIVWIa9wr0cPxkmpMODlFJ+qSCdNiqlChFNXkopv6TJSynllzR5KaX8kiYvpZRf0uSlPEpE0kVkRaY/t/VEF5EoEVnjrv0p/1Jg+nkpn3XSGNPE6SBUwaM1L+UIEdkuIq+KyGoRWSwitV3bo0RktoisEpFZIlLNtb28iHwrIitdf+eGsASKyFjXfFQzRCTM9fxhrrnQVonIFw4dpvIgTV7K08IuOG3sn+mxw8aYGOA97KwTAO8CnxpjGgOTgJGu7SOBecaYWOw4wnOjJ+oA7xtjGgKHgBtc20cATV37ud8zh6acpD3slUeJyDFjTLEstm8HOhtjtroGoO81xpQRkRSgojHmrGv7HmNMWdeq6VWMMacz7SMK+M0YU8d1/0kg2Bjzgoj8AhwDvgO+M8Yc8/ChKi/Tmpdyksnmdm6cznQ7nf+14/bCzqzbDFjimmhPFSCavJST+mf6d6Hr9h/8bxriW4H5rtuzgCHw91z4JbLbqYgEAFWNMXOAJ4ESwEW1P+Xf9NdIeVpYplkxwM77fq67RCkRWYWtPQ1wbXsIO0PpcOxspedma3gYGCMig7E1rCHYmT+zEghMdCU4AUb6+TTNKgva5qUc4WrzijPGpDgdi/JPetqolPJLWvNSSvklrXkppfySJi+llF/S5KWU8kuavJRSfkmTl1LKL/0/rC1qDE44mlAAAAAASUVORK5CYII=\n",
+ "text/plain": [
+ ""
+ ]
+ },
+ "metadata": {
+ "needs_background": "light"
+ },
+ "output_type": "display_data"
+ }
+ ],
+ "source": [
+ "training_vis(train_losses, valid_losses)"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "#### 模型评估\n",
+ "\n",
+ "在测试集上评估模型效果。"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 30,
+ "metadata": {
+ "execution": {
+ "iopub.execute_input": "2021-11-07T11:41:57.180629Z",
+ "iopub.status.busy": "2021-11-07T11:41:57.180063Z",
+ "iopub.status.idle": "2021-11-07T11:41:57.544721Z",
+ "shell.execute_reply": "2021-11-07T11:41:57.544233Z",
+ "shell.execute_reply.started": "2021-11-07T11:31:56.594261Z"
+ },
+ "papermill": {
+ "duration": 0.680346,
+ "end_time": "2021-11-07T11:41:57.544847",
+ "exception": false,
+ "start_time": "2021-11-07T11:41:56.864501",
+ "status": "completed"
+ },
+ "tags": []
+ },
+ "outputs": [
+ {
+ "data": {
+ "text/plain": [
+ ""
+ ]
+ },
+ "execution_count": 30,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "# 加载最佳模型权重\n",
+ "checkpoint = torch.load('../input/ai-earth-model-weights/task04_model_weights.pth')\n",
+ "model = Model()\n",
+ "model.load_state_dict(checkpoint['state_dict'])"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 31,
+ "metadata": {
+ "execution": {
+ "iopub.execute_input": "2021-11-07T11:41:58.181883Z",
+ "iopub.status.busy": "2021-11-07T11:41:58.181002Z",
+ "iopub.status.idle": "2021-11-07T11:41:58.182836Z",
+ "shell.execute_reply": "2021-11-07T11:41:58.183296Z",
+ "shell.execute_reply.started": "2021-11-07T11:31:57.001527Z"
+ },
+ "papermill": {
+ "duration": 0.323811,
+ "end_time": "2021-11-07T11:41:58.183429",
+ "exception": false,
+ "start_time": "2021-11-07T11:41:57.859618",
+ "status": "completed"
+ },
+ "tags": []
+ },
+ "outputs": [],
+ "source": [
+ "# 测试集路径\n",
+ "test_path = '../input/ai-earth-tests/'\n",
+ "# 测试集标签路径\n",
+ "test_label_path = '../input/ai-earth-tests-labels/'"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 32,
+ "metadata": {
+ "execution": {
+ "iopub.execute_input": "2021-11-07T11:41:58.817618Z",
+ "iopub.status.busy": "2021-11-07T11:41:58.816950Z",
+ "iopub.status.idle": "2021-11-07T11:42:00.269201Z",
+ "shell.execute_reply": "2021-11-07T11:42:00.268643Z",
+ "shell.execute_reply.started": "2021-11-07T11:31:57.010291Z"
+ },
+ "papermill": {
+ "duration": 1.770946,
+ "end_time": "2021-11-07T11:42:00.269350",
+ "exception": false,
+ "start_time": "2021-11-07T11:41:58.498404",
+ "status": "completed"
+ },
+ "tags": []
+ },
+ "outputs": [],
+ "source": [
+ "import os\n",
+ "\n",
+ "# 读取测试数据和测试数据的标签\n",
+ "files = os.listdir(test_path)\n",
+ "X_test = []\n",
+ "y_test = []\n",
+ "for file in files:\n",
+ " X_test.append(np.load(test_path + file))\n",
+ " y_test.append(np.load(test_label_path + file))"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 33,
+ "metadata": {
+ "execution": {
+ "iopub.execute_input": "2021-11-07T11:42:00.911046Z",
+ "iopub.status.busy": "2021-11-07T11:42:00.909973Z",
+ "iopub.status.idle": "2021-11-07T11:42:00.973831Z",
+ "shell.execute_reply": "2021-11-07T11:42:00.973361Z",
+ "shell.execute_reply.started": "2021-11-07T11:31:58.325122Z"
+ },
+ "papermill": {
+ "duration": 0.395072,
+ "end_time": "2021-11-07T11:42:00.973969",
+ "exception": false,
+ "start_time": "2021-11-07T11:42:00.578897",
+ "status": "completed"
+ },
+ "tags": []
+ },
+ "outputs": [
+ {
+ "data": {
+ "text/plain": [
+ "((103, 12, 24, 72, 4), (103, 24))"
+ ]
+ },
+ "execution_count": 33,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "X_test = np.array(X_test)\n",
+ "y_test = np.array(y_test)\n",
+ "X_test.shape, y_test.shape"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 34,
+ "metadata": {
+ "execution": {
+ "iopub.execute_input": "2021-11-07T11:42:01.612810Z",
+ "iopub.status.busy": "2021-11-07T11:42:01.611682Z",
+ "iopub.status.idle": "2021-11-07T11:42:01.638035Z",
+ "shell.execute_reply": "2021-11-07T11:42:01.637526Z",
+ "shell.execute_reply.started": "2021-11-07T11:31:58.416006Z"
+ },
+ "papermill": {
+ "duration": 0.3441,
+ "end_time": "2021-11-07T11:42:01.638176",
+ "exception": false,
+ "start_time": "2021-11-07T11:42:01.294076",
+ "status": "completed"
+ },
+ "tags": []
+ },
+ "outputs": [],
+ "source": [
+ "testset = AIEarthDataset(X_test, y_test)\n",
+ "testloader = DataLoader(testset, batch_size=batch_size, shuffle=False)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 35,
+ "metadata": {
+ "execution": {
+ "iopub.execute_input": "2021-11-07T11:42:02.286777Z",
+ "iopub.status.busy": "2021-11-07T11:42:02.285865Z",
+ "iopub.status.idle": "2021-11-07T11:42:02.621942Z",
+ "shell.execute_reply": "2021-11-07T11:42:02.622610Z",
+ "shell.execute_reply.started": "2021-11-07T11:31:58.447987Z"
+ },
+ "papermill": {
+ "duration": 0.666798,
+ "end_time": "2021-11-07T11:42:02.622817",
+ "exception": false,
+ "start_time": "2021-11-07T11:42:01.956019",
+ "status": "completed"
+ },
+ "tags": []
+ },
+ "outputs": [
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "4it [00:00, 12.75it/s]"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "Score: 20.274\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "\n"
+ ]
+ }
+ ],
+ "source": [
+ "# 在测试集上评估模型效果\n",
+ "model.eval()\n",
+ "model.to(device)\n",
+ "preds = np.zeros((len(y_test),24))\n",
+ "for i, data in tqdm(enumerate(testloader)):\n",
+ " data, labels = data\n",
+ " data = data.to(device)\n",
+ " labels = labels.to(device)\n",
+ " pred = model(data)\n",
+ " preds[i*batch_size:(i+1)*batch_size] = pred.detach().cpu().numpy()\n",
+ "s = score(y_test, preds)\n",
+ "print('Score: {:.3f}'.format(s))"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "papermill": {
+ "duration": 0.311337,
+ "end_time": "2021-11-07T11:42:03.280397",
+ "exception": false,
+ "start_time": "2021-11-07T11:42:02.969060",
+ "status": "completed"
+ },
+ "tags": []
+ },
+ "source": [
+ "## 总结\n",
+ "\n",
+ "总结起来,该方案主要有以下两方面值得我们学习:\n",
+ "\n",
+ "- 充分考虑到数据量小、特征少的数据情况,对时间和空间分别进行卷积操作,用GAP层对提取的信息进行降维,尽可能减少每一层的参数量、增加模型层数以提取更丰富的特征。\n",
+ "- 对问题背景有深刻地理解,考虑到不同时间尺度序列所携带的信息不同,用池化层变换时间尺度,并用RNN进行信息提取,综合三种不同时间尺度的序列信息得到最终的预测序列。\n",
+ "\n",
+ "该方案在构造模型时充分考虑了数据集情况和问题背景,并能灵活运用各种网络层来处理特定问题,这种模型构造思路要求对各种网络层的作用有较为深刻地理解。希望大家在学习该方案的时候不仅能够学习模型的构造方法,同时也能够在今后的其他任务中逐渐体会和掌握这种模型的构造思路。"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "## 参考文献\n",
+ "\n",
+ "1. Top1思路分享:https://tianchi.aliyun.com/forum/postDetail?spm=5176.12586969.1002.6.561d482cp7CFlx&postId=210391"
+ ]
+ }
+ ],
+ "metadata": {
+ "kernelspec": {
+ "display_name": "Python 3",
+ "language": "python",
+ "name": "python3"
+ },
+ "language_info": {
+ "codemirror_mode": {
+ "name": "ipython",
+ "version": 3
+ },
+ "file_extension": ".py",
+ "mimetype": "text/x-python",
+ "name": "python",
+ "nbconvert_exporter": "python",
+ "pygments_lexer": "ipython3",
+ "version": "3.7.3"
+ },
+ "papermill": {
+ "default_parameters": {},
+ "duration": 571.585002,
+ "end_time": "2021-11-07T11:42:05.308202",
+ "environment_variables": {},
+ "exception": null,
+ "input_path": "__notebook__.ipynb",
+ "output_path": "__notebook__.ipynb",
+ "parameters": {},
+ "start_time": "2021-11-07T11:32:33.723200",
+ "version": "2.3.3"
+ }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 5
+}