{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Train ML model to correct predictions of week 3-4 & 5-6\n",
    "\n",
    "This notebook creates a Machine Learning model `ML_model` to predict weeks 3-4 & 5-6 based on `S2S` weeks 3-4 & 5-6 forecasts; predictions are compared against `CPC` observations for the [`s2s-ai-challenge`](https://s2s-ai-challenge.github.io/)."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Synopsis"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Method: `ML-based mean bias reduction`\n",
    "\n",
    "- calculate the ML-based bias from the 2000-2019 deterministic ensemble mean hindcasts\n",
    "- remove that ML-based bias from the 2020 deterministic ensemble mean forecast"
   ]
  },
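  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "A minimal sketch of the method above (notation is ours, not from the challenge): with ensemble-mean hindcasts $f_{y,w}$ and observations $o_{y,w}$ for year $y$ and week-of-year $w$,\n",
    "\n",
    "$$\\mathrm{bias}_w = \\frac{1}{20}\\sum_{y=2000}^{2019}\\left(f_{y,w} - o_{y,w}\\right), \\qquad \\hat{f}_{2020,w} = f_{2020,w} - \\mathrm{bias}_w$$"
   ]
  },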
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Data used\n",
    "\n",
    "type: renku datasets\n",
    "\n",
    "Training-input for Machine Learning model:\n",
    "- hindcasts of models:\n",
    "    - ECMWF: `ecmwf_hindcast-input_2000-2019_biweekly_deterministic.zarr`\n",
    "\n",
    "Forecast-input for Machine Learning model:\n",
    "- real-time 2020 forecasts of models:\n",
    "    - ECMWF: `ecmwf_forecast-input_2020_biweekly_deterministic.zarr`\n",
    "\n",
    "Compare Machine Learning model forecast against ground truth:\n",
    "- `CPC` observations:\n",
    "    - `hindcast-like-observations_biweekly_deterministic.zarr`\n",
    "    - `forecast-like-observations_2020_biweekly_deterministic.zarr`"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Resources used\n",
    "\n",
    "for training; details in the reproducibility section\n",
    "\n",
    "- platform: renku\n",
    "- memory: 8 GB\n",
    "- processors: 2 CPU\n",
    "- storage required: 10 GB"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Safeguards\n",
    "\n",
    "All points have to be [x] checked. If not, your submission is invalid.\n",
    "\n",
    "Changes to the code after submission are not possible, as the `commit` before the `tag` will be reviewed.\n",
    "(Only in exceptional cases, and where previous effort toward reproducibility is evident, improvements to readability and reproducibility may be allowed after November 1st 2021.)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Safeguards to prevent [overfitting](https://en.wikipedia.org/wiki/Overfitting?wprov=sfti1) \n",
    "\n",
    "If the organizers suspect overfitting, your contribution can be disqualified.\n",
    "\n",
    "  - [x] We did not use 2020 observations in training (explicit overfitting and cheating)\n",
    "  - [x] We did not repeatedly verify our model on 2020 observations and incrementally improve our RPSS (implicit overfitting)\n",
    "  - [x] We provide RPSS scores for the training period with script `print_RPS_per_year`, see in section 6.3 `predict`.\n",
    "  - [x] We tried our best to prevent [data leakage](https://en.wikipedia.org/wiki/Leakage_(machine_learning)?wprov=sfti1).\n",
    "  - [x] We honor the `train-validate-test` [split principle](https://en.wikipedia.org/wiki/Training,_validation,_and_test_sets). This means that the hindcast data is split into `train` and `validate`, whereas `test` is withheld.\n",
    "  - [x] We did not use `test` explicitly in training or implicitly in incrementally adjusting parameters.\n",
    "  - [x] We considered [cross-validation](https://en.wikipedia.org/wiki/Cross-validation_(statistics))."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Safeguards for Reproducibility\n",
    "Notebook/code must be independently reproducible from scratch by the organizers (after the competition); if not, no prize\n",
    "  - [x] All training data is publicly available (no pre-trained private neural networks, as they are not reproducible for us)\n",
    "  - [x] Code is well documented, readable and reproducible.\n",
    "  - [x] Code to reproduce training and predictions should preferably run within a day on the described architecture. If training takes longer than a day, please justify why this is needed. Please do not submit training pipelines which take weeks to train."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Todos to improve template\n",
    "\n",
    "This is just a demo.\n",
    "\n",
    "- [ ] use multiple predictor variables and two predicted variables\n",
    "- [ ] for both `lead_time`s in one go\n",
    "- [ ] consider seasonality, for now all `forecast_time` months are mixed\n",
    "- [ ] make probabilistic predictions with a `category` dim; for now predictions are deterministic"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Imports"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/opt/conda/lib/python3.8/site-packages/xarray/backends/cfgrib_.py:27: UserWarning: Failed to load cfgrib - most likely there is a problem accessing the ecCodes library. Try `import cfgrib` to get the full error message\n",
      "  warnings.warn(\n"
     ]
    }
   ],
   "source": [
    "from tensorflow.keras.layers import Input, Dense, Flatten\n",
    "from tensorflow.keras.models import Sequential\n",
    "from tensorflow import keras\n",
    "\n",
    "import matplotlib.pyplot as plt\n",
    "\n",
    "import xarray as xr\n",
    "xr.set_options(display_style='text')\n",
    "import numpy as np\n",
    "\n",
    "from dask.utils import format_bytes\n",
    "import xskillscore as xs\n",
    "%load_ext tensorboard"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Get training data\n",
    "\n",
    "preprocessing of input data may be done in a separate notebook/script"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Hindcast\n",
    "\n",
    "get weekly initialized hindcasts"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "v='t2m'"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\u001b[33m\u001b[1mWarning: \u001b[0mRun CLI commands only from project's root directory.\n",
      "\u001b[0m\n"
     ]
    }
   ],
   "source": [
    "# preprocessed as renku dataset\n",
    "!renku storage pull ../data/ecmwf_hindcast-input_2000-2019_biweekly_deterministic.zarr"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/opt/conda/lib/python3.8/site-packages/xarray/backends/plugins.py:61: RuntimeWarning: Engine 'cfgrib' loading failed:\n",
      "/opt/conda/lib/python3.8/site-packages/gribapi/_bindings.cpython-38-x86_64-linux-gnu.so: undefined symbol: codes_bufr_key_is_header\n",
      "  warnings.warn(f\"Engine {name!r} loading failed:\\n{ex}\", RuntimeWarning)\n"
     ]
    }
   ],
   "source": [
    "hind_2000_2019 = xr.open_zarr(\"../data/ecmwf_hindcast-input_2000-2019_biweekly_deterministic.zarr\", consolidated=True)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Forecast"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\u001b[33m\u001b[1mWarning: \u001b[0mRun CLI commands only from project's root directory.\n",
      "\u001b[0m\n"
     ]
    }
   ],
   "source": [
    "# preprocessed as renku dataset\n",
    "!renku storage pull ../data/ecmwf_forecast-input_2020_biweekly_deterministic.zarr"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [],
   "source": [
    "fct_2020 = xr.open_zarr(\"../data/ecmwf_forecast-input_2020_biweekly_deterministic.zarr\", consolidated=True)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Observations\n",
    "\n",
    "categorized in terciles corresponding to hindcasts/forecasts"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\u001b[33m\u001b[1mWarning: \u001b[0mRun CLI commands only from project's root directory.\n",
      "\u001b[0m\n"
     ]
    }
   ],
   "source": [
    "!renku storage pull ../data/hindcast-like-observations_2000-2019_biweekly_terciled.zarr"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "Frozen(SortedKeysDict({'category': 3, 'forecast_time': 1060, 'latitude': 121, 'lead_time': 2, 'longitude': 240}))"
      ]
     },
     "execution_count": 12,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "obs_2000_2019_p = xr.open_dataset('../data/hindcast-like-observations_2000-2019_biweekly_terciled.zarr', engine='zarr')\n",
    "\n",
    "obs_2000_2019_p.sizes"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {},
   "outputs": [],
   "source": [
    "# convert one-hot tercile probabilities to a single category index (0, 1, 2)\n",
    "obs_2000_2019_p2 = obs_2000_2019_p.assign_coords(category=[0,1,2])\n",
    "obs_2000_2019_p2 = (obs_2000_2019_p2 * obs_2000_2019_p2.category).sum('category', skipna=False).compute()\n",
    "\n",
    "#obs_2000_2019_p2.isel(forecast_time=[2,4]).t2m.plot(col='lead_time', row='forecast_time')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {},
   "outputs": [],
   "source": [
    "obs_2020_p = xr.open_dataset('../data/forecast-like-observations_2020_biweekly_terciled.nc')\n",
    "obs_2020_p2 = obs_2020_p.assign_coords(category=[0,1,2])\n",
    "obs_2020_p2 = (obs_2020_p2 * obs_2020_p2.category).sum('category', skipna=False).compute()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# ML model"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "based on [Weatherbench](https://github.com/pangeo-data/WeatherBench/blob/master/quickstart.ipynb)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "fatal: destination path 'WeatherBench' already exists and is not an empty directory.\n"
     ]
    }
   ],
   "source": [
    "# run once only and don't commit\n",
    "!git clone https://github.com/pangeo-data/WeatherBench/"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "metadata": {},
   "outputs": [],
   "source": [
    "import sys\n",
    "sys.path.insert(1, 'WeatherBench')\n",
    "from WeatherBench.src.train_nn import DataGenerator, PeriodicConv2D, create_predictions\n",
    "import tensorflow.keras as keras"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 34,
   "metadata": {},
   "outputs": [],
   "source": [
    "bs=32\n",
    "\n",
    "import numpy as np\n",
    "class DataGenerator(keras.utils.Sequence):\n",
    "    def __init__(self, fct, verif, lead_time, batch_size=bs, shuffle=True, load=True):\n",
    "        \"\"\"\n",
    "        Data generator for WeatherBench data.\n",
    "        Template from https://stanford.edu/~shervine/blog/keras-how-to-generate-data-on-the-fly\n",
    "\n",
    "        Args:\n",
    "            fct: forecasts from S2S models: xr.DataArray (xr.Dataset doesn't work properly)\n",
    "            verif: observations with same dimensionality (xr.Dataset doesn't work properly)\n",
    "            lead_time: Lead_time as in model\n",
    "            batch_size: Batch size\n",
    "            shuffle: bool. If True, data is shuffled.\n",
    "            load: bool. If True, dataset is loaded into RAM.\n",
    "            \n",
    "        Todo:\n",
    "        - use `number` in a better way; now uses only the ensemble mean forecast\n",
    "        - don't use .sel(lead_time=lead_time), to train over all lead_times at once\n",
    "        - be sensitive with forecast_time; pool a few around the given weekofyear\n",
    "        - use more variables as predictors\n",
    "        - predict more variables\n",
    "        \"\"\"\n",
    "\n",
    "        if isinstance(fct, xr.Dataset):\n",
    "            print('convert fct to array')\n",
    "            fct = fct.to_array().transpose(...,'variable')\n",
    "            self.fct_dataset=True\n",
    "        else:\n",
    "            self.fct_dataset=False\n",
    "            \n",
    "        if isinstance(verif, xr.Dataset):\n",
    "            print('convert verif to array')\n",
    "            verif = verif.to_array().transpose(...,'variable')\n",
    "            self.verif_dataset=True\n",
    "        else:\n",
    "            self.verif_dataset=False\n",
    "        \n",
    "        #self.fct = fct\n",
    "        self.batch_size = batch_size\n",
    "        self.shuffle = shuffle\n",
    "        self.lead_time = lead_time\n",
    "\n",
    "        self.fct_data = fct.transpose('forecast_time', ...).sel(lead_time=lead_time)\n",
    "        self.fct_mean = self.fct_data.mean('forecast_time').compute()\n",
    "        self.fct_std = self.fct_data.std('forecast_time').compute()\n",
    "        \n",
    "        self.verif_data = verif.transpose('forecast_time', ...).sel(lead_time=lead_time)\n",
    "        #self.verif_mean = self.verif_data.mean('forecast_time').compute() if mean is None else mean\n",
    "        #self.verif_std = self.verif_data.std('forecast_time').compute() if std is None else std\n",
    "\n",
    "        # Normalize\n",
    "        self.fct_data = (self.fct_data - self.fct_mean) / self.fct_std\n",
    "        #self.verif_data = (self.verif_data - self.verif_mean) / self.verif_std\n",
    "        #self.verif_data = self.verif_data.astype('int32')\n",
    "        \n",
    "        self.n_samples = self.fct_data.forecast_time.size\n",
    "        self.forecast_time = self.fct_data.forecast_time\n",
    "\n",
    "        self.on_epoch_end()\n",
    "\n",
    "        # For some weird reason calling .load() earlier messes up the mean and std computations\n",
    "        if load:\n",
    "            # print('Loading data into RAM')\n",
    "            self.fct_data.load()\n",
    "\n",
    "    def __len__(self):\n",
    "        'Denotes the number of batches per epoch'\n",
    "        return int(np.ceil(self.n_samples / self.batch_size))\n",
    "\n",
    "    def __getitem__(self, i):\n",
    "        'Generate one batch of data'\n",
    "        idxs = self.idxs[i * self.batch_size:(i + 1) * self.batch_size]\n",
    "        # got all nan if nans not masked\n",
    "        X = self.fct_data.isel(forecast_time=idxs).values\n",
    "        y = self.verif_data.isel(forecast_time=idxs).values\n",
    "        return X, y\n",
    "\n",
    "    def on_epoch_end(self):\n",
    "        'Updates indexes after each epoch'\n",
    "        self.idxs = np.arange(self.n_samples)\n",
    "        if self.shuffle:\n",
    "            np.random.shuffle(self.idxs)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 35,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<pre>&lt;xarray.DataArray &#x27;lead_time&#x27; ()&gt;\n",
       "array(1209600000000000, dtype=&#x27;timedelta64[ns]&#x27;)\n",
       "Coordinates:\n",
       "    lead_time  timedelta64[ns] 14 days\n",
       "Attributes:\n",
       "    comment:  lead_time describes bi-weekly aggregates. The pd.Timedelta corr...</pre>"
      ],
      "text/plain": [
       "<xarray.DataArray 'lead_time' ()>\n",
       "array(1209600000000000, dtype='timedelta64[ns]')\n",
       "Coordinates:\n",
       "    lead_time  timedelta64[ns] 14 days\n",
       "Attributes:\n",
       "    comment:  lead_time describes bi-weekly aggregates. The pd.Timedelta corr..."
      ]
     },
     "execution_count": 35,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# 2 bi-weekly `lead_time`: week 3-4\n",
    "lead = hind_2000_2019.isel(lead_time=0).lead_time\n",
    "\n",
    "lead"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 36,
   "metadata": {},
   "outputs": [],
   "source": [
    "mask = obs_2000_2019_p2.isel(forecast_time=0, lead_time=0, drop=True).notnull()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## data prep: train, valid, test\n",
    "\n",
    "[Use the hindcast period to split train and valid.](https://en.wikipedia.org/wiki/Training,_validation,_and_test_sets) Do not use the 2020 data for testing!"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 278,
   "metadata": {},
   "outputs": [],
   "source": [
    "# time is the forecast_time\n",
    "time_train_start,time_train_end='2000','2017' # train\n",
    "time_valid_start,time_valid_end='2018','2019' # valid\n",
    "time_test = '2020'                            # test"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 279,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Jan only\n",
    "tattr='month'\n",
    "tattr_label=1\n",
    "attr='season'\n",
    "attr_label='DJF'\n",
    "bs=16\n",
    "hind_2000_2019 = hind_2000_2019.sel(forecast_time=getattr(hind_2000_2019.forecast_time.dt, tattr)==tattr_label)\n",
    "obs_2000_2019_p2 = obs_2000_2019_p2.sel(forecast_time=getattr(obs_2000_2019_p2.forecast_time.dt, tattr)==tattr_label)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 280,
   "metadata": {},
   "outputs": [],
   "source": [
    "dg_train = DataGenerator(\n",
    "    hind_2000_2019.mean('realization').sel(forecast_time=slice(time_train_start,time_train_end))[v],\n",
    "    obs_2000_2019_p2.sel(forecast_time=slice(time_train_start,time_train_end))[v],\n",
    "    lead_time=lead, batch_size=bs, shuffle=True, load=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 281,
   "metadata": {},
   "outputs": [],
   "source": [
    "dg_valid = DataGenerator(\n",
    "    hind_2000_2019.mean('realization').sel(forecast_time=slice(time_valid_start,time_valid_end))[v],\n",
    "    obs_2000_2019_p2.sel(forecast_time=slice(time_valid_start,time_valid_end))[v],\n",
    "    lead_time=lead, batch_size=bs, shuffle=False, load=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 282,
   "metadata": {},
   "outputs": [],
   "source": [
    "# do not use, delete?\n",
    "dg_test = DataGenerator(\n",
    "    fct_2020.mean('realization').sel(forecast_time=time_test)[v],\n",
    "    obs_2020_p2.sel(forecast_time=time_test)[v],\n",
    "    lead_time=lead, batch_size=bs, shuffle=False, load=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 283,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "((16, 121, 240), (16, 121, 240))"
      ]
     },
     "execution_count": 283,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "X, y = dg_train[0]\n",
    "X.shape, y.shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 284,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "((10, 121, 240), (10, 121, 240))"
      ]
     },
     "execution_count": 284,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "X, y = dg_valid[0]\n",
    "X.shape, y.shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 285,
   "metadata": {},
   "outputs": [],
   "source": [
    "# short look into training data\n",
    "# any problem from normalizing?\n",
    "i=4\n",
    "#xr.DataArray(np.vstack([X[i],y[i]])).plot(yincrease=False, robust=True)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## `fit`"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 286,
   "metadata": {},
   "outputs": [],
   "source": [
    "from tensorflow.keras.models import Sequential\n",
    "from tensorflow.keras.layers import Convolution2D, MaxPooling2D, Flatten, Dense, Dropout"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 314,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Model: \"sequential_6\"\n",
      "_________________________________________________________________\n",
      "Layer (type)                 Output Shape              Param #   \n",
      "=================================================================\n",
      "periodic_conv2d_17 (Periodic (None, 121, 240, 64)      5248      \n",
      "_________________________________________________________________\n",
      "dropout_11 (Dropout)         (None, 121, 240, 64)      0         \n",
      "_________________________________________________________________\n",
      "periodic_conv2d_18 (Periodic (None, 121, 240, 16)      25616     \n",
      "_________________________________________________________________\n",
      "dense_7 (Dense)              (None, 121, 240, 8)       136       \n",
      "_________________________________________________________________\n",
      "dense_8 (Dense)              (None, 121, 240, 3)       27        \n",
      "=================================================================\n",
      "Total params: 31,027\n",
      "Trainable params: 31,027\n",
      "Non-trainable params: 0\n",
      "_________________________________________________________________\n"
     ]
    }
   ],
   "source": [
    "dr=.01\n",
    "cnn = keras.models.Sequential([\n",
    "    PeriodicConv2D(filters=64, kernel_size=9, conv_kwargs={'activation':'elu'},\n",
    "                   input_shape=(121, 240, 1)),\n",
    "    #Dropout(dr),\n",
    "    #PeriodicConv2D(filters=16, kernel_size=9, conv_kwargs={'activation':'relu'}),\n",
    "    Dropout(dr),\n",
    "    PeriodicConv2D(filters=16, kernel_size=5),\n",
    "    Dense(units=8),\n",
    "    Dense(units=3, activation='softmax')\n",
    "])\n",
    "cnn.summary()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 315,
   "metadata": {},
   "outputs": [],
   "source": [
    "optimizers = ['adam','RMSprop', 'sgd', keras.optimizers.Adam(learning_rate=0.01)]\n",
    "losses = ['sparse_categorical_crossentropy', 'categorical_crossentropy']\n",
    "\n",
    "metrics=[keras.metrics.CategoricalAccuracy(),'accuracy']\n",
    "\n",
    "#cnn.compile(optimizer=optimizers[0], loss=losses[0], metrics=metrics[-1])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 317,
   "metadata": {},
   "outputs": [],
   "source": [
    "# https://stackoverflow.com/questions/34875944/how-to-write-a-custom-loss-function-in-tensorflow#37573411\n",
    "\n",
    "import tensorflow as tf\n",
    "\n",
    "def rps_loss(y_true, y_pred, sample_weight=None):\n",
    "    y_true = tf.one_hot(tf.cast(y_true,'int32'), depth=3)\n",
    "    Fc = tf.cumsum(y_pred, axis=-1)\n",
    "    Oc = tf.cumsum(y_true, axis=-1)\n",
    "    diff = tf.subtract(Fc, Oc)\n",
    "    rps = tf.reduce_sum(tf.square(diff), axis=-1)\n",
    "    # mask ocean\n",
    "    y_one = tf.reduce_sum(y_true, axis=-1)\n",
    "    rps = tf.multiply(rps, y_one)\n",
    "    # spatial mean # todo add weights\n",
    "    return rps\n",
    "\n",
    "\n",
    "#xr.DataArray(tf.reduce_mean(rps_loss(obs_2020_p2.isel(lead_time=1)[v], preds_test[v].isel(lead_time=1)),axis=0),dims=['latitude','longitude'],coords={'latitude':mask.latitude,'longitude':mask.longitude}).plot(yincrease=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 318,
   "metadata": {},
   "outputs": [],
   "source": [
    "cnn.compile(optimizer='adam', loss=rps_loss, metrics='accuracy')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 319,
   "metadata": {},
   "outputs": [],
   "source": [
    "import warnings\n",
    "warnings.simplefilter(\"ignore\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 320,
   "metadata": {},
   "outputs": [],
   "source": [
    "tensorboard_callback = keras.callbacks.TensorBoard(log_dir=\"./logs\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 321,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 1/3\n",
      "6/6 [==============================] - ETA: 0s - loss: 0.0964 - accuracy: 0.0983WARNING:tensorflow:5 out of the last 16 calls to <function Model.make_test_function.<locals>.test_function at 0x7fd18e625940> triggered tf.function retracing. Tracing is expensive and the excessive number of tracings could be due to (1) creating @tf.function repeatedly in a loop, (2) passing tensors with different shapes, (3) passing Python objects instead of tensors. For (1), please define your @tf.function outside of the loop. For (2), @tf.function has experimental_relax_shapes=True option that relaxes argument shapes that can avoid unnecessary retracing. For (3), please refer to https://www.tensorflow.org/guide/function#controlling_retracing and https://www.tensorflow.org/api_docs/python/tf/function for  more details.\n",
      "6/6 [==============================] - 11s 1s/step - loss: 0.0965 - accuracy: 0.0980 - val_loss: 0.1008 - val_accuracy: 0.0815\n",
      "Epoch 2/3\n",
      "6/6 [==============================] - 8s 1s/step - loss: 0.0960 - accuracy: 0.0954 - val_loss: 0.1004 - val_accuracy: 0.0843\n",
      "Epoch 3/3\n",
      "6/6 [==============================] - 9s 1s/step - loss: 0.0943 - accuracy: 0.0982 - val_loss: 0.1015 - val_accuracy: 0.0842\n"
     ]
    }
   ],
   "source": [
    "history = cnn.fit(dg_train, epochs=3, validation_data=dg_valid, callbacks=[tensorboard_callback])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 232,
   "metadata": {},
   "outputs": [],
   "source": [
    "#history.history['loss']"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 233,
   "metadata": {},
   "outputs": [],
   "source": [
    "#!jupyter labextension install jupyterlab_tensorboard"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 234,
   "metadata": {},
   "outputs": [],
   "source": [
    "#%tensorboard --logdir logs/fit"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## `predict`\n",
    "\n",
    "Create predictions and print `mean(variable, lead_time, longitude, weighted latitude)` RPSS for all years as calculated by `skill_by_year`."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 235,
   "metadata": {},
   "outputs": [],
   "source": [
    "# idea: build image classifier: show RPS on valid after each epoch\n",
    "# https://www.tensorflow.org/tensorboard/image_summaries"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 255,
   "metadata": {},
   "outputs": [],
   "source": [
    "def _create_predictions(model, dg, lead):\n",
    "    preds = model.predict_proba(dg)\n",
    "    da = xr.DataArray(preds,\n",
    "                      dims=['forecast_time', 'latitude', 'longitude','category'],\n",
    "                      coords={'forecast_time': dg.fct_data.forecast_time, 'latitude': dg.fct_data.latitude,\n",
    "                                'longitude': dg.fct_data.longitude, 'category': ['below normal','near normal','above normal']},\n",
    "                      )\n",
    "    da = da.where(mask[v]) # apply the observation mask (possibly redundant)\n",
    "    da = da.assign_coords(lead_time=lead)\n",
    "    return da"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 256,
   "metadata": {},
   "outputs": [],
   "source": [
    "#cnn.evaluate(dg_valid)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 257,
   "metadata": {},
   "outputs": [],
   "source": [
    "#obs_2000_2019_p[v].isel(forecast_time=[2,20,25,-3,-2,-1]).sel(lead_time=lead).squeeze().plot(col='category', row='forecast_time')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "preds = _create_predictions(cnn, dg_train, lead)\n",
    "#preds.isel(forecast_time=[2,20,25,-3,-2,-1]).squeeze().plot(col='category', row='forecast_time', vmin=.13,vmax=.53, cmap='RdBu')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": 274,
   "metadata": {},
   "outputs": [],
   "source": [
    "from scripts import add_valid_time_from_forecast_reference_time_and_lead_time\n",
    "\n",
    "# note: the CNN was trained for a single lead_time, so applying it to every\n",
    "# lead_time is not meaningful, but the results have the expected dimensions;\n",
    "# ideally, train a separate model for each lead_time\n",
    "\n",
    "def create_predictions(cnn, fct, obs, time):\n",
    "    preds_test=[]\n",
    "    for lead in fct.lead_time:\n",
    "        dg = DataGenerator(fct.mean('realization').sel(forecast_time=time)[v],\n",
    "                           obs.sel(forecast_time=time)[v],\n",
    "                           lead_time=lead, batch_size=bs, shuffle=False)\n",
    "        preds_test.append(_create_predictions(cnn, dg, lead))\n",
    "    preds_test = xr.concat(preds_test, 'lead_time')\n",
    "    preds_test['lead_time'] = fct.lead_time\n",
    "    # add valid_time coord\n",
    "    preds_test = add_valid_time_from_forecast_reference_time_and_lead_time(preds_test)\n",
    "    preds_test = preds_test.to_dataset(name=v)\n",
    "    # placeholder: reuse the t2m probabilities for tp so the output contains both variables\n",
    "    preds_test['tp'] = preds_test['t2m']\n",
    "    return preds_test"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### `predict` training period in-sample"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 275,
   "metadata": {},
   "outputs": [],
   "source": [
    "from scripts import skill_by_year"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "if os.environ['HOME'] == '/home/jovyan':\n",
    "    import pandas as pd\n",
    "    # assume on renku with small memory\n",
    "    step = 2\n",
    "    skill_list = []\n",
    "    for year in np.arange(int(time_train_start), int(time_train_end), step): # loop over pairs of years to consume less memory on renku; arange's end is exclusive, so no -1 here\n",
    "        preds_is = create_predictions(cnn, hind_2000_2019, obs_2000_2019_p2, time=slice(str(year), str(year+step-1))).compute()\n",
    "        skill_list.append(skill_by_year(preds_is))\n",
    "    skill = pd.concat(skill_list)\n",
    "else: # with larger memory, simply do\n",
    "    preds_is = create_predictions(cnn, hind_2000_2019, obs_2000_2019_p2, time=slice(time_train_start, time_train_end))\n",
    "    skill = skill_by_year(preds_is)\n",
    "skill"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### `predict` validation period out-of-sample"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 323,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>RPSS</th>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>year</th>\n",
       "      <th></th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>2018</th>\n",
       "      <td>-2.197396</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>2019</th>\n",
       "      <td>-0.333983</td>\n",