{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Train ML model to correct predictions of week 3-4 & 5-6\n",
    "\n",
    "This notebook create a Machine Learning `ML_model` to predict weeks 3-4 & 5-6 based on `S2S` weeks 3-4 & 5-6 forecasts and is compared to `CPC` observations for the [`s2s-ai-challenge`](https://s2s-ai-challenge.github.io/)."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Synopsis"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Method: `ML-based mean bias reduction`\n",
    "\n",
    "- calculate the ML-based bias from 2000-2019 deterministic ensemble mean forecast\n",
    "- remove that the ML-based bias from 2020 forecast deterministic ensemble mean forecast"
   ]
  },
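  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "For intuition, a plain (non-ML) mean bias reduction would look like the hedged sketch below; this notebook instead learns the correction with a CNN. The variable names (`hind`, `obs`, `fct`) are illustrative and not defined here."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# hypothetical sketch (not executed): climatological mean bias reduction per weekofyear\n",
    "# bias = (hind.mean('realization') - obs).groupby('forecast_time.weekofyear').mean()\n",
    "# fct_corrected = fct.mean('realization').groupby('forecast_time.weekofyear') - bias"
   ]
  },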
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Data used\n",
    "\n",
    "type: renku datasets\n",
    "\n",
    "Training-input for Machine Learning model:\n",
    "- hindcasts of models:\n",
    "    - ECMWF: `ecmwf_hindcast-input_2000-2019_biweekly_deterministic.zarr`\n",
    "\n",
    "Forecast-input for Machine Learning model:\n",
    "- real-time 2020 forecasts of models:\n",
    "    - ECMWF: `ecmwf_forecast-input_2020_biweekly_deterministic.zarr`\n",
    "\n",
    "Compare Machine Learning model forecast against against ground truth:\n",
    "- `CPC` observations:\n",
    "    - `hindcast-like-observations_biweekly_deterministic.zarr`\n",
    "    - `forecast-like-observations_2020_biweekly_deterministic.zarr`"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Resources used\n",
    "\n",
    "for training; details under the reproducibility safeguards\n",
    "\n",
    "- platform: renku\n",
    "- memory: 8 GB\n",
    "- processors: 2 CPU\n",
    "- storage required: 10 GB"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Safeguards\n",
    "\n",
    "All points have to be [x] checked. If not, your submission is invalid.\n",
    "\n",
    "Changes to the code after submissions are not possible, as the `commit` before the `tag` will be reviewed.\n",
    "(Only in exceptions and if previous effort in reproducibility can be found, it may be allowed to improve readability and reproducibility after November 1st 2021.)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Safeguards to prevent [overfitting](https://en.wikipedia.org/wiki/Overfitting?wprov=sfti1) \n",
    "\n",
    "If the organizers suspect overfitting, your contribution can be disqualified.\n",
    "\n",
    "  - [x] We did not use 2020 observations in training (explicit overfitting and cheating)\n",
    "  - [x] We did not repeatedly verify my model on 2020 observations and incrementally improved my RPSS (implicit overfitting)\n",
    "  - [x] We provide RPSS scores for the training period with script `print_RPS_per_year`, see in section 6.3 `predict`.\n",
    "  - [x] We tried our best to prevent [data leakage](https://en.wikipedia.org/wiki/Leakage_(machine_learning)?wprov=sfti1).\n",
    "  - [x] We honor the `train-validate-test` [split principle](https://en.wikipedia.org/wiki/Training,_validation,_and_test_sets). This means that the hindcast data is split into `train` and `validate`, whereas `test` is withheld.\n",
    "  - [x] We did not use `test` explicitly in training or implicitly in incrementally adjusting parameters.\n",
    "  - [x] We considered [cross-validation](https://en.wikipedia.org/wiki/Cross-validation_(statistics))."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Safeguards for Reproducibility\n",
    "Notebook/code must be independently reproducible from scratch by the organizers (after the competition), if not possible: no prize\n",
    "  - [x] All training data is publicly available (no pre-trained private neural networks, as they are not reproducible for us)\n",
    "  - [x] Code is well documented, readable and reproducible.\n",
    "  - [x] Code to reproduce training and predictions is preferred to run within a day on the described architecture. If the training takes longer than a day, please justify why this is needed. Please do not submit training piplelines, which take weeks to train."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Todos to improve template\n",
    "\n",
    "This is just a demo.\n",
    "\n",
    "- [ ] use multiple predictor variables and two predicted variables\n",
    "- [ ] for both `lead_time`s in one go\n",
    "- [ ] consider seasonality, for now all `forecast_time` months are mixed\n",
    "- [ ] make probabilistic predictions with `category` dim, for now works deterministic"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Imports"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "from tensorflow.keras.layers import Input, Dense, Flatten\n",
    "from tensorflow.keras.models import Sequential\n",
    "\n",
    "import matplotlib.pyplot as plt\n",
    "\n",
    "import xarray as xr\n",
    "xr.set_options(display_style='text')\n",
    "import numpy as np\n",
    "\n",
    "from dask.utils import format_bytes\n",
    "import xskillscore as xs"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Get training data\n",
    "\n",
    "preprocessing of input data may be done in separate notebook/script"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Hindcast\n",
    "\n",
    "get weekly initialized hindcasts"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "v='t2m'"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\u001b[33m\u001b[1mWarning: \u001b[0mRun CLI commands only from project's root directory.\n",
      "\u001b[0m\n"
     ]
    }
   ],
   "source": [
    "# preprocessed as renku dataset\n",
    "!renku storage pull ../data/ecmwf_hindcast-input_2000-2019_biweekly_deterministic.zarr"
   ]
  },
  {
   "cell_type": "code",
   "metadata": {},
   "outputs": [],
   "source": [
    "hind_2000_2019 = xr.open_zarr(\"../data/ecmwf_hindcast-input_2000-2019_biweekly_deterministic.zarr\", consolidated=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\u001b[33m\u001b[1mWarning: \u001b[0mRun CLI commands only from project's root directory.\n",
      "\u001b[0m\n"
     ]
    }
   ],
   "source": [
    "# preprocessed as renku dataset\n",
    "!renku storage pull ../data/ecmwf_forecast-input_2020_biweekly_deterministic.zarr"
   ]
  },
  {
   "cell_type": "code",
   "metadata": {},
   "outputs": [],
   "source": [
    "fct_2020 = xr.open_zarr(\"../data/ecmwf_forecast-input_2020_biweekly_deterministic.zarr\", consolidated=True)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Observations\n",
    "corresponding to hindcasts"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\u001b[33m\u001b[1mWarning: \u001b[0mRun CLI commands only from project's root directory.\n",
      "\u001b[0m\n"
     ]
    }
   ],
   "source": [
    "# preprocessed as renku dataset\n",
    "!renku storage pull ../data/hindcast-like-observations_2000-2019_biweekly_deterministic.zarr"
   ]
  },
  {
   "cell_type": "code",
   "metadata": {},
   "outputs": [],
   "source": [
    "obs_2000_2019 = xr.open_zarr(\"../data/hindcast-like-observations_2000-2019_biweekly_deterministic.zarr\", consolidated=True)#[v]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\u001b[33m\u001b[1mWarning: \u001b[0mRun CLI commands only from project's root directory.\n",
      "\u001b[0m\n"
     ]
    }
   ],
   "source": [
    "# preprocessed as renku dataset\n",
    "!renku storage pull ../data/forecast-like-observations_2020_biweekly_deterministic.zarr"
   ]
  },
  {
   "cell_type": "code",
   "metadata": {},
   "outputs": [],
   "source": [
    "obs_2020 = xr.open_zarr(\"../data/forecast-like-observations_2020_biweekly_deterministic.zarr\", consolidated=True)#[v]"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# ML model"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "based on [Weatherbench](https://github.com/pangeo-data/WeatherBench/blob/master/quickstart.ipynb)"
   ]
  },
  {
   "cell_type": "code",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Cloning into 'WeatherBench'...\n",
      "remote: Enumerating objects: 718, done.\u001b[K\n",
      "remote: Counting objects: 100% (3/3), done.\u001b[K\n",
      "remote: Compressing objects: 100% (3/3), done.\u001b[K\n",
      "remote: Total 718 (delta 0), reused 0 (delta 0), pack-reused 715\u001b[K\n",
      "Receiving objects: 100% (718/718), 17.77 MiB | 5.87 MiB/s, done.\n",
      "Resolving deltas: 100% (424/424), done.\n"
     ]
    }
   ],
   "source": [
    "# run once only and dont commit\n",
    "!git clone https://github.com/pangeo-data/WeatherBench/"
   ]
  },
  {
   "cell_type": "code",
   "metadata": {},
   "outputs": [],
   "source": [
    "import sys\n",
    "sys.path.insert(1, 'WeatherBench')\n",
    "from WeatherBench.src.train_nn import DataGenerator, PeriodicConv2D, create_predictions\n",
    "import tensorflow.keras as keras"
   ]
  },
  {
   "cell_type": "code",
   "metadata": {},
   "outputs": [],
   "source": [
    "bs=32\n",
    "\n",
    "import numpy as np\n",
    "class DataGenerator(keras.utils.Sequence):\n",
    "    def __init__(self, fct, verif, lead_time, batch_size=bs, shuffle=True, load=True,\n",
    "                 mean=None, std=None):\n",
    "        \"\"\"\n",
    "        Data generator for WeatherBench data.\n",
    "        Template from https://stanford.edu/~shervine/blog/keras-how-to-generate-data-on-the-fly\n",
    "\n",
    "        Args:\n",
    "            fct: forecasts from S2S models: xr.DataArray (xr.Dataset doesnt work properly)\n",
    "            verif: observations with same dimensionality (xr.Dataset doesnt work properly)\n",
    "            lead_time: Lead_time as in model\n",
    "            batch_size: Batch size\n",
    "            shuffle: bool. If True, data is shuffled.\n",
    "            load: bool. If True, datadet is loaded into RAM.\n",
    "            mean: If None, compute mean from data.\n",
    "            std: If None, compute standard deviation from data.\n",
    "            \n",
    "        Todo:\n",
    "        - use number in a better way, now uses only ensemble mean forecast\n",
    "        - dont use .sel(lead_time=lead_time) to train over all lead_time at once\n",
    "        - be sensitive with forecast_time, pool a few around the weekofyear given\n",
    "        - use more variables as predictors\n",
    "        - predict more variables\n",
    "        \"\"\"\n",
    "\n",
    "        if isinstance(fct, xr.Dataset):\n",
    "            print('convert fct to array')\n",
    "            fct = fct.to_array().transpose(...,'variable')\n",
    "            self.fct_dataset=True\n",
    "        else:\n",
    "            self.fct_dataset=False\n",
    "            \n",
    "        if isinstance(verif, xr.Dataset):\n",
    "            print('convert verif to array')\n",
    "            verif = verif.to_array().transpose(...,'variable')\n",
    "            self.verif_dataset=True\n",
    "        else:\n",
    "            self.verif_dataset=False\n",
    "        \n",
    "        #self.fct = fct\n",
    "        self.batch_size = batch_size\n",
    "        self.shuffle = shuffle\n",
    "        self.lead_time = lead_time\n",
    "\n",
    "        self.fct_data = fct.transpose('forecast_time', ...).sel(lead_time=lead_time)\n",
    "        self.fct_mean = self.fct_data.mean('forecast_time').compute() if mean is None else mean\n",
    "        self.fct_std = self.fct_data.std('forecast_time').compute() if std is None else std\n",
    "        \n",
    "        self.verif_data = verif.transpose('forecast_time', ...).sel(lead_time=lead_time)\n",
    "        self.verif_mean = self.verif_data.mean('forecast_time').compute() if mean is None else mean\n",
    "        self.verif_std = self.verif_data.std('forecast_time').compute() if std is None else std\n",
    "\n",
    "        # Normalize\n",
    "        self.fct_data = (self.fct_data - self.fct_mean) / self.fct_std\n",
    "        self.verif_data = (self.verif_data - self.verif_mean) / self.verif_std\n",
    "        \n",
    "        self.n_samples = self.fct_data.forecast_time.size\n",
    "        self.forecast_time = self.fct_data.forecast_time\n",
    "\n",
    "        self.on_epoch_end()\n",
    "\n",
    "        # For some weird reason calling .load() earlier messes up the mean and std computations\n",
    "        if load:\n",
    "            # print('Loading data into RAM')\n",
    "            self.fct_data.load()\n",
    "\n",
    "    def __len__(self):\n",
    "        'Denotes the number of batches per epoch'\n",
    "        return int(np.ceil(self.n_samples / self.batch_size))\n",
    "\n",
    "    def __getitem__(self, i):\n",
    "        'Generate one batch of data'\n",
    "        idxs = self.idxs[i * self.batch_size:(i + 1) * self.batch_size]\n",
    "        # got all nan if nans not masked\n",
    "        X = self.fct_data.isel(forecast_time=idxs).fillna(0.).values\n",
    "        y = self.verif_data.isel(forecast_time=idxs).fillna(0.).values\n",
    "        return X, y\n",
    "\n",
    "    def on_epoch_end(self):\n",
    "        'Updates indexes after each epoch'\n",
    "        self.idxs = np.arange(self.n_samples)\n",
    "        if self.shuffle == True:\n",
    "            np.random.shuffle(self.idxs)"
   ]
  },
  {
   "cell_type": "code",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<pre>&lt;xarray.DataArray &#x27;lead_time&#x27; ()&gt;\n",
       "array(1209600000000000, dtype=&#x27;timedelta64[ns]&#x27;)\n",
       "Coordinates:\n",
       "    lead_time  timedelta64[ns] 14 days\n",
       "Attributes:\n",
       "    aggregate:      The pd.Timedelta corresponds to the first day of a biweek...\n",
       "    description:    Forecast period is the time interval between the forecast...\n",
       "    long_name:      lead time\n",
       "    standard_name:  forecast_period\n",
       "    week34_t2m:     mean[14 days, 27 days]\n",
       "    week34_tp:      28 days minus 14 days\n",
       "    week56_t2m:     mean[28 days, 41 days]\n",
       "    week56_tp:      42 days minus 28 days</pre>"
      ],
      "text/plain": [
       "<xarray.DataArray 'lead_time' ()>\n",
       "array(1209600000000000, dtype='timedelta64[ns]')\n",
       "Coordinates:\n",
       "    lead_time  timedelta64[ns] 14 days\n",
       "Attributes:\n",
       "    aggregate:      The pd.Timedelta corresponds to the first day of a biweek...\n",
       "    description:    Forecast period is the time interval between the forecast...\n",
       "    long_name:      lead time\n",
       "    standard_name:  forecast_period\n",
       "    week34_t2m:     mean[14 days, 27 days]\n",
       "    week34_tp:      28 days minus 14 days\n",
       "    week56_t2m:     mean[28 days, 41 days]\n",
       "    week56_tp:      42 days minus 28 days"
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# 2 bi-weekly `lead_time`: week 3-4\n",
    "lead = hind_2000_2019.isel(lead_time=0).lead_time\n",
    "\n",
    "lead"
   ]
  },
  {
   "cell_type": "code",
   "metadata": {},
   "outputs": [],
   "source": [
    "# mask, needed?\n",
    "hind_2000_2019 = hind_2000_2019.where(obs_2000_2019.isel(forecast_time=0, lead_time=0,drop=True).notnull())"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## data prep: train, valid, test\n",
    "\n",
    "[Use the hindcast period to split train and valid.](https://en.wikipedia.org/wiki/Training,_validation,_and_test_sets) Do not use the 2020 data for testing!"
   ]
  },
  {
   "cell_type": "code",
   "metadata": {},
   "outputs": [],
   "source": [
    "# time is the forecast_time\n",
    "time_train_start,time_train_end='2000','2017' # train\n",
    "time_valid_start,time_valid_end='2018','2019' # valid\n",
    "time_test = '2020'                            # test"
   ]
  },
  {
   "cell_type": "code",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/opt/conda/lib/python3.8/site-packages/dask/array/numpy_compat.py:40: RuntimeWarning: invalid value encountered in true_divide\n",
      "  x = np.divide(x1, x2, out)\n",
      "/opt/conda/lib/python3.8/site-packages/dask/array/numpy_compat.py:40: RuntimeWarning: invalid value encountered in true_divide\n",
      "  x = np.divide(x1, x2, out)\n",
      "/opt/conda/lib/python3.8/site-packages/dask/array/numpy_compat.py:40: RuntimeWarning: invalid value encountered in true_divide\n",
      "  x = np.divide(x1, x2, out)\n",
      "/opt/conda/lib/python3.8/site-packages/dask/array/numpy_compat.py:40: RuntimeWarning: invalid value encountered in true_divide\n",
      "  x = np.divide(x1, x2, out)\n",
      "/opt/conda/lib/python3.8/site-packages/dask/array/numpy_compat.py:40: RuntimeWarning: invalid value encountered in true_divide\n",
      "  x = np.divide(x1, x2, out)\n"
     ]
    }
   ],
   "source": [
    "dg_train = DataGenerator(\n",
    "    hind_2000_2019.mean('realization').sel(forecast_time=slice(time_train_start,time_train_end))[v],\n",
    "    obs_2000_2019.sel(forecast_time=slice(time_train_start,time_train_end))[v],\n",
    "    lead_time=lead, batch_size=bs, load=True)"
   ]
  },
  {
   "cell_type": "code",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/opt/conda/lib/python3.8/site-packages/dask/array/numpy_compat.py:40: RuntimeWarning: invalid value encountered in true_divide\n",
      "  x = np.divide(x1, x2, out)\n",
      "/opt/conda/lib/python3.8/site-packages/dask/array/numpy_compat.py:40: RuntimeWarning: invalid value encountered in true_divide\n",
      "  x = np.divide(x1, x2, out)\n",
      "/opt/conda/lib/python3.8/site-packages/dask/array/numpy_compat.py:40: RuntimeWarning: invalid value encountered in true_divide\n",
      "  x = np.divide(x1, x2, out)\n",
      "/opt/conda/lib/python3.8/site-packages/dask/array/numpy_compat.py:40: RuntimeWarning: invalid value encountered in true_divide\n",
      "  x = np.divide(x1, x2, out)\n",
      "/opt/conda/lib/python3.8/site-packages/dask/array/numpy_compat.py:40: RuntimeWarning: invalid value encountered in true_divide\n",
      "  x = np.divide(x1, x2, out)\n"
     ]
    }
   ],
   "source": [
    "dg_valid = DataGenerator(\n",
    "    hind_2000_2019.mean('realization').sel(forecast_time=slice(time_valid_start,time_valid_end))[v],\n",
    "    obs_2000_2019.sel(forecast_time=slice(time_valid_start,time_valid_end))[v],\n",
    "    lead_time=lead, batch_size=bs, shuffle=False, load=True)"
   ]
  },
  {
   "cell_type": "code",
   "metadata": {},
   "outputs": [],
   "source": [
    "# do not use, delete?\n",
    "dg_test = DataGenerator(\n",
    "    fct_2020.mean('realization').sel(forecast_time=time_test)[v],\n",
    "    obs_2020.sel(forecast_time=time_test)[v],\n",
    "    lead_time=lead, batch_size=bs, load=True, mean=dg_train.fct_mean, std=dg_train.fct_std, shuffle=False)"
   ]
  },
  {
   "cell_type": "code",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "((32, 121, 240), (32, 121, 240))"
      ]
     },
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "X, y = dg_valid[0]\n",
    "X.shape, y.shape"
   ]
  },
  {
   "cell_type": "code",
   "metadata": {},
   "outputs": [],
   "source": [
    "# short look into training data: large biases\n",
    "# any problem from normalizing?\n",
    "# i=4\n",
    "# xr.DataArray(np.vstack([X[i],y[i]])).plot(yincrease=False, robust=True)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## `fit`"
   ]
  },
  {
   "cell_type": "code",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "WARNING:tensorflow:AutoGraph could not transform <bound method PeriodicPadding2D.call of <WeatherBench.src.train_nn.PeriodicPadding2D object at 0x7f457c77da90>> and will run it as-is.\n",
      "Please report this to the TensorFlow team. When filing the bug, set the verbosity to 10 (on Linux, `export AUTOGRAPH_VERBOSITY=10`) and attach the full output.\n",
      "Cause: module 'gast' has no attribute 'Index'\n",
      "To silence this warning, decorate the function with @tf.autograph.experimental.do_not_convert\n",
      "WARNING: AutoGraph could not transform <bound method PeriodicPadding2D.call of <WeatherBench.src.train_nn.PeriodicPadding2D object at 0x7f457c77da90>> and will run it as-is.\n",
      "Please report this to the TensorFlow team. When filing the bug, set the verbosity to 10 (on Linux, `export AUTOGRAPH_VERBOSITY=10`) and attach the full output.\n",
      "Cause: module 'gast' has no attribute 'Index'\n",
      "To silence this warning, decorate the function with @tf.autograph.experimental.do_not_convert\n"
     ]
    }
   ],
   "source": [
    "cnn = keras.models.Sequential([\n",
    "    PeriodicConv2D(filters=32, kernel_size=5, conv_kwargs={'activation':'relu'}, input_shape=(32, 64, 1)),\n",
    "    PeriodicConv2D(filters=1, kernel_size=5)\n",
    "])"
   ]
  },
  {
   "cell_type": "code",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Model: \"sequential\"\n",
      "_________________________________________________________________\n",
      "Layer (type)                 Output Shape              Param #   \n",
      "=================================================================\n",
      "periodic_conv2d (PeriodicCon (None, 32, 64, 32)        832       \n",
      "_________________________________________________________________\n",
      "periodic_conv2d_1 (PeriodicC (None, 32, 64, 1)         801       \n",
      "=================================================================\n",
      "Total params: 1,633\n",
      "Trainable params: 1,633\n",
      "Non-trainable params: 0\n",
      "_________________________________________________________________\n"
     ]
    }
   ],
   "source": [
    "cnn.summary()"
   ]
  },
  {
   "cell_type": "code",
   "metadata": {},
   "outputs": [],
   "source": [
    "cnn.compile(keras.optimizers.Adam(1e-4), 'mse')"
   ]
  },
  {
   "cell_type": "code",
   "metadata": {},
   "outputs": [],
   "source": [
    "import warnings\n",
    "warnings.simplefilter(\"ignore\")"
   ]
  },
  {
   "cell_type": "code",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "30/30 [==============================] - 34s 1s/step - loss: 0.1241 - val_loss: 0.0690\n",
      "30/30 [==============================] - 25s 822ms/step - loss: 0.0681 - val_loss: 0.0543\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "<tensorflow.python.keras.callbacks.History at 0x7f457c7db1c0>"
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "cnn.fit(dg_train, epochs=2, validation_data=dg_valid)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## `predict`\n",
    "\n",
    "Create predictions and print `mean(variable, lead_time, longitude, weighted latitude)` RPSS for all years as calculated by `skill_by_year`."
   ]
  },
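  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "For reference, `RPSS = 1 - RPS(forecast) / RPS(climatology)`, where the climatological reference assigns probability 1/3 to each tercile category. Below is a hedged sketch with `xskillscore` (assuming terciled observations `obs_p` and probabilistic predictions `preds`, each with a `category` dim as produced by `make_probabilistic`; the official scoring is done by `skill_by_year`):"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# hypothetical sketch (not executed): RPSS against a 1/3-per-category climatology\n",
    "# rps_ml = xs.rps(obs_p, preds, category_edges=None, dim=[], input_distributions='p')\n",
    "# rps_clim = xs.rps(obs_p, xr.ones_like(obs_p) / 3, category_edges=None, dim=[], input_distributions='p')\n",
    "# rpss = 1 - rps_ml.mean('forecast_time') / rps_clim.mean('forecast_time')"
   ]
  },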
  {
   "cell_type": "code",
   "metadata": {},
   "source": [
    "from scripts import add_valid_time_from_forecast_reference_time_and_lead_time\n",
    "\n",
    "def _create_predictions(model, dg, lead):\n",
    "    \"\"\"Create non-iterative predictions\"\"\"\n",
    "    preds = model.predict(dg).squeeze()\n",
    "    # Unnormalize\n",
    "    preds = preds * dg.fct_std.values + dg.fct_mean.values\n",
    "    if dg.verif_dataset:\n",
    "        da = xr.DataArray(\n",
    "                    preds,\n",
    "                    dims=['forecast_time', 'latitude', 'longitude','variable'],\n",
    "                    coords={'forecast_time': dg.fct_data.forecast_time, 'latitude': dg.fct_data.latitude,\n",
    "                            'longitude': dg.fct_data.longitude},\n",
    "                ).to_dataset() # doesnt work yet\n",
    "    else:\n",
    "        da = xr.DataArray(\n",
    "                    preds,\n",
    "                    dims=['forecast_time', 'latitude', 'longitude'],\n",
    "                    coords={'forecast_time': dg.fct_data.forecast_time, 'latitude': dg.fct_data.latitude,\n",
    "                            'longitude': dg.fct_data.longitude},\n",
    "                )\n",
    "    da = da.assign_coords(lead_time=lead)\n",
    "    # da = add_valid_time_from_forecast_reference_time_and_lead_time(da)\n",
    "    return da"
   ]
  },
  {
   "cell_type": "code",
   "metadata": {},
   "outputs": [],
   "source": [
    "# optionally masking the ocean when making probabilistic\n",
    "mask = obs_2020.std(['lead_time','forecast_time']).notnull()"
   ]
  },
  {
   "cell_type": "code",
   "metadata": {},
   "outputs": [],
   "source": [
    "from scripts import make_probabilistic"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 30,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\u001b[33m\u001b[1mWarning: \u001b[0mRun CLI commands only from project's root directory.\n",
      "\u001b[0m\n"
     ]
    }
   ],
   "source": [
    "!renku storage pull ../data/hindcast-like-observations_2000-2019_biweekly_tercile-edges.nc"
   ]
  },
  {
   "cell_type": "code",
   "metadata": {},
   "outputs": [],
   "source": [
    "cache_path='../data'\n",
    "tercile_file = f'{cache_path}/hindcast-like-observations_2000-2019_biweekly_tercile-edges.nc'\n",
    "tercile_edges = xr.open_dataset(tercile_file)"
   ]
  },
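  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The tercile edges split each variable into three equally likely categories (below/near/above normal). Below is a hedged sketch of the idea behind `make_probabilistic`; the project script handles the full logic, including the `realization` dim and masking, and the dim name `category_edge` is an assumption:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# hypothetical sketch (not executed): one-hot tercile categories from the edges\n",
    "# lower, upper = tercile_edges.isel(category_edge=0), tercile_edges.isel(category_edge=1)\n",
    "# below, above = ds < lower, ds > upper\n",
    "# normal = ~below & ~above\n",
    "# probs = xr.concat([below, normal, above], 'category').astype('float')  # mean over realization -> probabilities"
   ]
  },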
  {
   "cell_type": "code",
   "metadata": {},
   "outputs": [],
   "source": [
    "# this is not useful but results have expected dimensions\n",
    "# actually train for each lead_time\n",
    "\n",
    "def create_predictions(cnn, fct, obs, time):\n",
    "    preds_test=[]\n",
    "    for lead in fct.lead_time:\n",
    "        dg = DataGenerator(fct.mean('realization').sel(forecast_time=time)[v],\n",
    "                           obs.sel(forecast_time=time)[v],\n",
    "                           lead_time=lead, batch_size=bs, mean=dg_train.fct_mean, std=dg_train.fct_std, shuffle=False)\n",
    "        preds_test.append(_create_predictions(cnn, dg, lead))\n",
    "    preds_test = xr.concat(preds_test, 'lead_time')\n",
    "    preds_test['lead_time'] = fct.lead_time\n",
    "    # add valid_time coord\n",
    "    preds_test = add_valid_time_from_forecast_reference_time_and_lead_time(preds_test)\n",
    "    preds_test = preds_test.to_dataset(name=v)\n",
    "    # add fake var\n",
    "    preds_test['tp'] = preds_test['t2m']\n",
    "    # make probabilistic\n",
    "    preds_test = make_probabilistic(preds_test.expand_dims('realization'), tercile_edges, mask=mask)\n",
    "    return preds_test"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### `predict` training period in-sample"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 33,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\u001b[33m\u001b[1mWarning: \u001b[0mRun CLI commands only from project's root directory.\n",
      "\u001b[0m\n"
     ]
    }
   ],
   "source": [
    "!renku storage pull ../data/forecast-like-observations_2020_biweekly_terciled.nc"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 34,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\u001b[33m\u001b[1mWarning: \u001b[0mRun CLI commands only from project's root directory.\n",
      "\u001b[0m\n"
     ]
    }
   ],
   "source": [
    "!renku storage pull ../data/hindcast-like-observations_2000-2019_biweekly_terciled.zarr"
   ]
  },
  {
   "cell_type": "code",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>RPSS</th>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>year</th>\n",
       "      <th></th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>2000</th>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>2001</th>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>2002</th>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>2003</th>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>2004</th>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>2005</th>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>2006</th>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>2007</th>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>2008</th>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>2009</th>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>2010</th>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>2011</th>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>2012</th>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>2013</th>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>2014</th>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>2015</th>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "          RPSS\n",
       "year          \n",
       "2000 -1.228445\n",
       "2001 -1.382888\n",
       "2002 -1.453018\n",
       "2003 -1.398054\n",
       "2004 -1.410522\n",